Class Precision<T extends TNumber>

java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.Precision<T>
Type Parameters:
T - The data type for the metric result
All Implemented Interfaces:
Metric

public class Precision<T extends TNumber> extends BaseMetric
Computes the precision of the predictions with respect to the labels.

The metric creates two local variables, truePositives and falsePositives, that are used to compute the precision. The precision is ultimately returned by an idempotent operation that divides truePositives by the sum of truePositives and falsePositives.

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values.
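The following minimal plain-Java sketch shows the two accumulators and the division described above. It does not use the TensorFlow Java API. The class `PrecisionSketch` and its methods are hypothetical. It assumes that labels are treated as booleans (nonzero means positive) and that a prediction counts as positive when it is strictly above the threshold. Returning 0 when nothing has been predicted positive is also an assumption, not documented behaviour.

```java
/**
 * Hypothetical plain-Java model of the Precision metric at one threshold.
 * It mirrors the documented arithmetic only; it is not the library code.
 */
final class PrecisionSketch {
    private final float threshold;
    private double truePositives;  // weighted count: predicted positive, label positive
    private double falsePositives; // weighted count: predicted positive, label negative

    PrecisionSketch(float threshold) {
        this.threshold = threshold;
    }

    /** Accumulates statistics; a null sampleWeights array means every weight is 1. */
    void updateState(float[] labels, float[] predictions, float[] sampleWeights) {
        for (int i = 0; i < predictions.length; i++) {
            double w = (sampleWeights == null) ? 1.0 : sampleWeights[i]; // a weight of 0 masks the entry
            boolean predictedPositive = predictions[i] > threshold;     // above the threshold is "true"
            boolean actualPositive = labels[i] != 0f;                   // label treated as a boolean
            if (predictedPositive && actualPositive) {
                truePositives += w;
            } else if (predictedPositive) {
                falsePositives += w;
            }
        }
    }

    /** Idempotent: reading the result never changes the accumulated state. */
    double result() {
        double denominator = truePositives + falsePositives;
        // Assumption: report 0 instead of NaN when nothing was predicted positive.
        return denominator == 0.0 ? 0.0 : truePositives / denominator;
    }

    void resetStates() {
        truePositives = 0.0;
        falsePositives = 0.0;
    }

    public static void main(String[] args) {
        float[] labels = {0f, 1f, 1f, 1f};
        float[] predictions = {1f, 0f, 1f, 1f};

        PrecisionSketch metric = new PrecisionSketch(0.5f);
        metric.updateState(labels, predictions, null);
        System.out.println(metric.result()); // TP = 2, FP = 1, so 2/3

        metric.resetStates();
        // A weight of 0 on the first entry masks its false positive.
        metric.updateState(labels, predictions, new float[] {0f, 1f, 1f, 1f});
        System.out.println(metric.result()); // TP = 2, FP = 0, so 1.0
    }
}
```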

If topK is set, the metric calculates precision as how often, on average, a class among the top-k classes with the highest predicted values of a batch entry is correct, that is, can be found in the label for that entry.
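The top-k rule can be sketched in plain Java. For each batch entry, the sketch treats the k highest-scoring classes as the predicted positives and pools the counts across the batch. The class and method names are hypothetical. Ignoring the threshold when topK is set is a simplifying assumption, not documented behaviour.

```java
import java.util.Arrays;

/** Hypothetical illustration of precision at top-k; not the library implementation. */
final class TopKPrecisionSketch {
    /** Indices of the k largest scores in row; the stable sort breaks ties by lower index. */
    static int[] topKIndices(float[] row, int k) {
        Integer[] order = new Integer[row.length];
        for (int i = 0; i < row.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, (a, b) -> Float.compare(row[b], row[a])); // descending by score
        int[] top = new int[Math.min(k, row.length)];
        for (int i = 0; i < top.length; i++) {
            top[i] = order[i];
        }
        return top;
    }

    /** Pools true and false positives over every entry's top-k classes. */
    static double precisionAtK(float[][] labels, float[][] predictions, int k) {
        double truePositives = 0;
        double falsePositives = 0;
        for (int r = 0; r < predictions.length; r++) {
            for (int c : topKIndices(predictions[r], k)) {
                if (labels[r][c] != 0f) {
                    truePositives++;  // the top-k class appears in the label
                } else {
                    falsePositives++; // the top-k class is not a correct label
                }
            }
        }
        double denominator = truePositives + falsePositives;
        return denominator == 0 ? 0.0 : truePositives / denominator;
    }

    public static void main(String[] args) {
        float[][] labels = {{0f, 1f, 0f, 1f}, {0f, 0f, 1f, 0f}};
        float[][] predictions = {{0.9f, 0.8f, 0.2f, 0.7f}, {0.1f, 0.2f, 0.9f, 0.3f}};
        // Row 0 top-2 = {0, 1}: one hit. Row 1 top-2 = {2, 3}: one hit. 2 / 4 = 0.5
        System.out.println(precisionAtK(labels, predictions, 2));
    }
}
```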

If classId is specified, the metric calculates precision by considering only the entries in the batch for which classId is above the thresholds and/or in the top-k highest predictions, and computing the fraction of them for which classId is indeed a correct label.
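The classId rule can be sketched in the same way. In this hypothetical plain-Java helper, an entry counts as a positive prediction for classId when its score for that class is above the threshold. If topK is non-null, the class must also rank in that entry's top k. Two further assumptions apply: the documented "and/or" is read as "both conditions when both are given", and ties count in classId's favour.

```java
/** Hypothetical single-class precision helper; not the library implementation. */
final class ClassIdPrecisionSketch {
    static double precisionForClass(float[][] labels, float[][] predictions,
                                    int classId, float threshold, Integer topK) {
        double truePositives = 0;
        double falsePositives = 0;
        for (int r = 0; r < predictions.length; r++) {
            float[] row = predictions[r];
            if (classId < 0 || classId >= row.length) {
                throw new IllegalArgumentException("classId must be in [0, numClasses)");
            }
            boolean predictedPositive = row[classId] > threshold;
            if (predictedPositive && topK != null) {
                // Rank of classId = number of classes with a strictly higher score.
                int higher = 0;
                for (float score : row) {
                    if (score > row[classId]) {
                        higher++;
                    }
                }
                predictedPositive = higher < topK;
            }
            if (!predictedPositive) {
                continue; // entries not predicted as classId are ignored
            }
            if (labels[r][classId] != 0f) {
                truePositives++;
            } else {
                falsePositives++;
            }
        }
        double denominator = truePositives + falsePositives;
        return denominator == 0 ? 0.0 : truePositives / denominator;
    }

    public static void main(String[] args) {
        float[][] labels = {{1f, 0f}, {0f, 1f}, {0f, 1f}};
        float[][] predictions = {{0.9f, 0.1f}, {0.6f, 0.4f}, {0.2f, 0.8f}};
        // Class 0 is predicted for rows 0 and 1; only row 0 is labelled class 0.
        System.out.println(precisionForClass(labels, predictions, 0, 0.5f, null)); // 0.5
        // Class 1 is predicted only for row 2, which is labelled class 1.
        System.out.println(precisionForClass(labels, predictions, 1, 0.5f, null)); // 1.0
    }
}
```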

  • Constructor Details

    • Precision

      public Precision(long seed, Class<T> type)
      Creates a Precision Metric with a name of Class.getSimpleName(), no topK or classId values, and a threshold of DEFAULT_THRESHOLD.
      Parameters:
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Precision

      public Precision(String name, long seed, Class<T> type)
      Creates a Precision Metric with no topK or classId values and a threshold of DEFAULT_THRESHOLD.
      Parameters:
      name - name of the metric instance. If null, name defaults to Class.getSimpleName().
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Precision

      public Precision(float threshold, long seed, Class<T> type)
      Creates a Precision Metric with a name of Class.getSimpleName() and no topK or classId values.
      Parameters:
      threshold - Optional threshold value in the range [0, 1]. A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is true, below is false). One metric value is generated for each threshold value.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Precision

      public Precision(float[] thresholds, long seed, Class<T> type)
      Creates a Precision Metric with a name of Class.getSimpleName() and no topK or classId values.
      Parameters:
      thresholds - Optional threshold values in the range [0, 1]. A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is true, below is false). One metric value is generated for each threshold value.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Precision

      public Precision(String name, float threshold, long seed, Class<T> type)
      Creates a Precision Metric with no topK or classId values.
      Parameters:
      name - name of the metric instance. If null, name defaults to Class.getSimpleName().
      threshold - Optional threshold value in the range [0, 1]. A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is true, below is false). One metric value is generated for each threshold value.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Precision

      public Precision(String name, float[] thresholds, long seed, Class<T> type)
      Creates a Precision Metric with no topK or classId values.
      Parameters:
      name - name of the metric instance. If null, name defaults to Class.getSimpleName().
      thresholds - Optional threshold values in the range [0, 1]. A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is true, below is false). One metric value is generated for each threshold value.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Precision

      public Precision(float threshold, Integer topK, Integer classId, long seed, Class<T> type)
      Creates a Precision Metric with a name of Class.getSimpleName().
      Parameters:
      threshold - Optional threshold value in the range [0, 1]. A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is true, below is false). One metric value is generated for each threshold value.
      topK - An optional value specifying the top-k predictions to consider when calculating precision.
      classId - Optional Integer class ID for which we want binary metrics. This must be in the half-open interval [0, numClasses), where numClasses is the last dimension of predictions.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Precision

      public Precision(float[] thresholds, Integer topK, Integer classId, long seed, Class<T> type)
      Creates a Precision Metric with a name of Class.getSimpleName().
      Parameters:
      thresholds - Optional threshold values in the range [0, 1]. A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is true, below is false). One metric value is generated for each threshold value.
      topK - An optional value specifying the top-k predictions to consider when calculating precision.
      classId - Optional Integer class ID for which we want binary metrics. This must be in the half-open interval [0, numClasses), where numClasses is the last dimension of predictions.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Precision

      public Precision(String name, float threshold, Integer topK, Integer classId, long seed, Class<T> type)
      Creates a Precision Metric.
      Parameters:
      name - name of the metric instance. If null, name defaults to Class.getSimpleName().
      threshold - Optional threshold value in the range [0, 1]. A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is true, below is false). One metric value is generated for each threshold value.
      topK - An optional value specifying the top-k predictions to consider when calculating precision.
      classId - Optional Integer class ID for which we want binary metrics. This must be in the half-open interval [0, numClasses), where numClasses is the last dimension of predictions.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Precision

      public Precision(String name, float[] thresholds, Integer topK, Integer classId, long seed, Class<T> type)
      Creates a Precision Metric.
      Parameters:
      name - name of the metric instance. If null, name defaults to Class.getSimpleName().
      thresholds - Optional threshold values in the range [0, 1]. A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is true, below is false). One metric value is generated for each threshold value.
      topK - An optional value specifying the top-k predictions to consider when calculating precision.
      classId - Optional Integer class ID for which we want binary metrics. This must be in the half-open interval [0, numClasses), where numClasses is the last dimension of predictions.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
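The constructors that take a thresholds array produce one precision value per threshold. The following plain-Java sketch is a hypothetical helper, not part of TensorFlow Java. It checks that each threshold lies in [0, 1] and recomputes true and false positives for each one, assuming "above the threshold" means strictly greater.

```java
import java.util.Arrays;

/** Hypothetical helper: one precision value per threshold; not the library implementation. */
final class MultiThresholdPrecisionSketch {
    static double[] precisionPerThreshold(float[] labels, float[] predictions, float[] thresholds) {
        double[] result = new double[thresholds.length];
        for (int t = 0; t < thresholds.length; t++) {
            float threshold = thresholds[t];
            if (threshold < 0f || threshold > 1f) {
                throw new IllegalArgumentException("threshold must be in [0, 1]: " + threshold);
            }
            double truePositives = 0;
            double falsePositives = 0;
            for (int i = 0; i < predictions.length; i++) {
                if (predictions[i] > threshold) { // above the threshold is "true"
                    if (labels[i] != 0f) {
                        truePositives++;
                    } else {
                        falsePositives++;
                    }
                }
            }
            double denominator = truePositives + falsePositives;
            result[t] = denominator == 0 ? 0.0 : truePositives / denominator;
        }
        return result;
    }

    public static void main(String[] args) {
        float[] labels = {0f, 1f, 1f, 1f};
        float[] predictions = {0.9f, 0.2f, 0.6f, 0.8f};
        float[] thresholds = {0.1f, 0.5f, 0.85f};
        // Higher thresholds keep fewer predictions; only the 0.9 false positive survives 0.85.
        System.out.println(Arrays.toString(precisionPerThreshold(labels, predictions, thresholds)));
        // [0.75, 0.6666666666666666, 0.0]
    }
}
```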
  • Method Details

    • init

      protected void init(Ops tf)
      Initialize the TensorFlow Ops
      Specified by:
      init in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
      Accumulates true positive and false positive statistics.
      Specified by:
      updateStateList in interface Metric
      Overrides:
      updateStateList in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      labels - the ground truth values, with the same dimensions as predictions. Will be cast to TBool.
      predictions - the predictions, each element must be in the range [0, 1].
      sampleWeights - Optional weighting of each example. Defaults to 1. Rank is either 0 or the same rank as labels, and must be broadcastable to labels.
      Returns:
      a List of Operations to update the metric state.
    • result

      public <U extends TNumber> Operand<U> result(Ops tf, Class<U> resultType)
      Gets the current result of the metric.
      Type Parameters:
      U - the data type for the result
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      resultType - the data type for the result
      Returns:
      the result, possibly with control dependencies
    • resetStates

      public Op resetStates(Ops tf)
      Resets any state variables to their initial values
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      Returns:
      the operation for doing the reset
    • getThresholds

      public float[] getThresholds()
      Gets the thresholds
      Returns:
      the thresholds
    • getTopK

      public Integer getTopK()
      Gets the topK value, which may be null.
      Returns:
      the topK value or null
    • getClassId

      public Integer getClassId()
      Gets the classId, which may be null.
      Returns:
      the classId or null
    • getTruePositives

      public Variable<T> getTruePositives()
      Gets the truePositives variable
      Returns:
      the truePositives
    • getFalsePositives

      public Variable<T> getFalsePositives()
      Gets the falsePositives variable
      Returns:
      the falsePositives
    • getTruePositivesName

      public String getTruePositivesName()
      Gets the name of the truePositives variable
      Returns:
      the truePositivesName
    • getFalsePositivesName

      public String getFalsePositivesName()
      Gets the name of the falsePositives variable
      Returns:
      the falsePositivesName