Class PrecisionAtRecall<T extends TNumber>

java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.PrecisionAtRecall<T>
Type Parameters:
T - The data type for the metric result
All Implemented Interfaces:
Metric

public class PrecisionAtRecall<T extends TNumber> extends BaseMetric
Computes the best precision where recall is >= the specified value.

This metric creates four local variables, truePositives, trueNegatives, falsePositives and falseNegatives that are used to compute the precision at the given recall. The threshold for the given recall value is computed and used to evaluate the corresponding precision.

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values.
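
The computation described above can be illustrated without TensorFlow. The following is a minimal plain-Java sketch, not the library's implementation: the class name PrecisionAtRecallSketch, the method precisionAtRecall, and the strict "prediction > threshold" comparison are assumptions made for illustration. For each candidate threshold it accumulates weighted confusion counts, then returns the best precision among the thresholds whose recall reaches the target.

```java
/** Illustrative sketch only; names and details are hypothetical, not the TensorFlow API. */
final class PrecisionAtRecallSketch {

    /**
     * Returns the best precision over all thresholds whose recall is >= targetRecall,
     * or 0 if no threshold reaches that recall.
     *
     * @param sampleWeights per-example weights; null means every weight is 1
     */
    static double precisionAtRecall(
            int[] labels,
            double[] predictions,
            double[] sampleWeights,
            double[] thresholds,
            double targetRecall) {
        double best = 0.0;
        for (double threshold : thresholds) {
            double tp = 0.0, fp = 0.0, fn = 0.0;
            for (int i = 0; i < labels.length; i++) {
                // A weight of 0 removes the example from every count (masking).
                double w = sampleWeights == null ? 1.0 : sampleWeights[i];
                // Assumption: a prediction counts as positive when strictly above the threshold.
                boolean predictedPositive = predictions[i] > threshold;
                boolean actualPositive = labels[i] == 1;
                if (predictedPositive && actualPositive) {
                    tp += w;
                } else if (predictedPositive) {
                    fp += w;
                } else if (actualPositive) {
                    fn += w;
                }
            }
            double recall = (tp + fn) == 0.0 ? 0.0 : tp / (tp + fn);
            double precision = (tp + fp) == 0.0 ? 0.0 : tp / (tp + fp);
            if (recall >= targetRecall) {
                best = Math.max(best, precision);
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int[] labels = {0, 0, 0, 1, 1};
        double[] predictions = {0.0, 0.3, 0.8, 0.3, 0.8};
        double[] thresholds = {0.0, 0.25, 0.5, 0.75};

        // Unweighted: null weights behave like weights of 1.
        System.out.println(precisionAtRecall(labels, predictions, null, thresholds, 0.5));

        // Masking the false positive at index 2 with a weight of 0.
        double[] masked = {1, 1, 0, 1, 1};
        System.out.println(precisionAtRecall(labels, predictions, masked, thresholds, 0.5));
    }
}
```

In this sketch, true negatives are not needed for precision or recall, so they are not tracked; the metric itself keeps all four variables.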

  • Field Details

  • Constructor Details

    • PrecisionAtRecall

      public PrecisionAtRecall(float recall, long seed, Class<T> type)
      Creates a PrecisionAtRecall metric with a name of Class.getSimpleName() and DEFAULT_NUM_THRESHOLDS for the number of thresholds.
      Parameters:
      recall - the recall. A scalar value in range [0, 1]
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
      Throws:
      IllegalArgumentException - if recall is not in the range [0, 1].
    • PrecisionAtRecall

      public PrecisionAtRecall(String name, float recall, long seed, Class<T> type)
      Creates a PrecisionAtRecall metric with DEFAULT_NUM_THRESHOLDS for the number of thresholds.
      Parameters:
      name - the name of the metric, if null defaults to Class.getSimpleName()
      recall - the recall. A scalar value in range [0, 1]
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
      Throws:
      IllegalArgumentException - if recall is not in the range [0, 1].
    • PrecisionAtRecall

      public PrecisionAtRecall(float recall, int numThresholds, long seed, Class<T> type)
      Creates a PrecisionAtRecall metric with a name of Class.getSimpleName().
      Parameters:
      recall - the recall. A scalar value in range [0, 1]
      numThresholds - the number of thresholds to use for matching the given recall. Constructors without this parameter use a default of 200.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
      Throws:
      IllegalArgumentException - if numThresholds <= 0 or if recall is not in the range [0, 1].
    • PrecisionAtRecall

      public PrecisionAtRecall(String name, float recall, int numThresholds, long seed, Class<T> type)
      Creates a PrecisionAtRecall metric.
      Parameters:
      name - the name of the metric, if null defaults to Class.getSimpleName()
      recall - the recall. A scalar value in range [0, 1]
      numThresholds - the number of thresholds to use for matching the given recall. Constructors without this parameter use a default of 200.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
      Throws:
      IllegalArgumentException - if numThresholds <= 0 or if recall is not in the range [0, 1].
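
The argument checks documented for the constructors can be sketched in plain Java. This is an illustration under stated assumptions, not the library's code: the class ThresholdSketch and both helper methods are hypothetical, and the even spacing of thresholds over [0, 1] (with a single threshold placed at 0.5) is an assumption, since this page does not say how the thresholds are derived from numThresholds. The actual values returned by getThresholds() may differ, for example at the endpoints.

```java
/** Illustrative sketch only; helper names and threshold spacing are assumptions. */
final class ThresholdSketch {

    /** Mirrors the documented checks: numThresholds > 0 and recall in [0, 1]. */
    static void validate(float recall, int numThresholds) {
        if (numThresholds <= 0) {
            throw new IllegalArgumentException("numThresholds must be > 0");
        }
        // Written as a negated range check so that NaN is rejected as well.
        if (!(recall >= 0f && recall <= 1f)) {
            throw new IllegalArgumentException("recall must be in the range [0, 1]");
        }
    }

    /** Assumption: thresholds are evenly spaced from 0 to 1 inclusive. */
    static float[] evenlySpacedThresholds(int numThresholds) {
        if (numThresholds <= 0) {
            throw new IllegalArgumentException("numThresholds must be > 0");
        }
        if (numThresholds == 1) {
            return new float[] {0.5f};
        }
        float[] thresholds = new float[numThresholds];
        for (int i = 0; i < numThresholds; i++) {
            thresholds[i] = (float) i / (numThresholds - 1);
        }
        return thresholds;
    }

    public static void main(String[] args) {
        validate(0.5f, 200); // passes: both arguments are in range
        float[] thresholds = evenlySpacedThresholds(5);
        System.out.println(java.util.Arrays.toString(thresholds));
    }
}
```

A larger numThresholds gives a finer grid of candidate thresholds, and therefore a closer match to the requested recall, at the cost of larger state variables.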
  • Method Details

    • result

      public <U extends TNumber> Operand<U> result(Ops tf, Class<U> resultType)
      Gets the current result of the metric
      Type Parameters:
      U - the data type for the result
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      resultType - the data type for the result
      Returns:
      the result, possibly with control dependencies
    • getRecall

      public float getRecall()
      Gets the recall value
      Returns:
      the recall value
    • init

      protected void init(Ops tf)
      Initialize the TensorFlow Ops
      Specified by:
      init in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
      Accumulates confusion matrix statistics.
      Specified by:
      updateStateList in interface Metric
      Overrides:
      updateStateList in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      labels - The ground truth values.
      predictions - the predictions
      sampleWeights - Optional weighting of each example. Defaults to 1. Rank is either 0, or the same rank as labels, and must be broadcastable to labels.
      Returns:
      a List of Operations to update the metric state.
    • resetStates

      public Op resetStates(Ops tf)
      Resets any state variables to their initial values
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      Returns:
      the operation for doing the reset
    • getTruePositives

      public Variable<T> getTruePositives()
      Gets the truePositives variable
      Returns:
      the truePositives
    • getFalsePositives

      public Variable<T> getFalsePositives()
      Gets the falsePositives variable
      Returns:
      the falsePositives
    • getTrueNegatives

      public Variable<T> getTrueNegatives()
      Gets the trueNegatives variable
      Returns:
      the trueNegatives
    • getFalseNegatives

      public Variable<T> getFalseNegatives()
      Gets the falseNegatives variable
      Returns:
      the falseNegatives
    • getNumThresholds

      public int getNumThresholds()
      Gets the numThresholds
      Returns:
      the numThresholds
    • getThresholds

      public float[] getThresholds()
      Gets the thresholds
      Returns:
      the thresholds
    • getTruePositivesName

      public String getTruePositivesName()
      Gets the truePositives variable name
      Returns:
      the truePositivesName
    • getFalsePositivesName

      public String getFalsePositivesName()
      Gets the falsePositives variable name
      Returns:
      the falsePositivesName
    • getTrueNegativesName

      public String getTrueNegativesName()
      Gets the trueNegatives variable name
      Returns:
      the trueNegativesName
    • getFalseNegativesName

      public String getFalseNegativesName()
      Gets the falseNegatives variable name
      Returns:
      the falseNegativesName
    • getType

      public Class<T> getType()
      Gets the internalType
      Returns:
      the internalType
    • getInternalType

      public Class<T> getInternalType()