Class RecallAtPrecision<T extends TNumber>

java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.RecallAtPrecision<T>
Type Parameters:
T - The data type for the metric result
All Implemented Interfaces:
Metric

public class RecallAtPrecision<T extends TNumber> extends BaseMetric
Computes the best recall achievable where precision is >= the specified value.

For a given score-label distribution, the required precision might not be achievable; in that case, 0.0 is returned as the recall.

This metric creates four local variables (truePositives, trueNegatives, falsePositives and falseNegatives) that are used to compute the recall at the given precision. The threshold for the given precision value is computed and then used to evaluate the corresponding recall.
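To make the selection rule concrete, here is a plain-Java sketch, not the TensorFlow implementation, that takes per-threshold counts, computes precision and recall at each threshold, and returns the largest recall among thresholds whose precision meets the target, or 0.0 when none does. The class and method names are hypothetical, and treating 0/0 as 0 is an assumption.

```java
/** Hypothetical sketch: best recall subject to a minimum precision, from per-threshold counts. */
public class RecallAtPrecisionSketch {

    /**
     * Returns the largest recall among thresholds whose precision is at least
     * minPrecision, or 0.0 if no threshold reaches it.
     * tp[i], fp[i] and fn[i] are the (possibly weighted) counts at threshold i.
     */
    public static double bestRecall(double[] tp, double[] fp, double[] fn, double minPrecision) {
        double best = 0.0;
        for (int i = 0; i < tp.length; i++) {
            double predictedPositive = tp[i] + fp[i];
            double actualPositive = tp[i] + fn[i];
            // Assumption: 0/0 is treated as 0, so a threshold with no positive
            // predictions only qualifies when minPrecision is 0.
            double precision = predictedPositive > 0 ? tp[i] / predictedPositive : 0.0;
            double recall = actualPositive > 0 ? tp[i] / actualPositive : 0.0;
            if (precision >= minPrecision && recall > best) {
                best = recall;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Counts at three hypothetical thresholds (low, medium, high).
        double[] tp = {4, 3, 1};
        double[] fp = {4, 1, 0};
        double[] fn = {0, 1, 3};
        System.out.println(bestRecall(tp, fp, fn, 0.7)); // 0.75
        System.out.println(bestRecall(tp, fp, fn, 1.0)); // 0.25
    }
}
```

With these counts the thresholds have precisions 0.5, 0.75 and 1.0 and recalls 1.0, 0.75 and 0.25, so a target of 0.7 admits the last two thresholds and yields 0.75, while a target of 1.0 admits only the last one.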

If sampleWeights is null, weights default to 1. Use a sample weight of 0 to mask a value.
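The weighting rule can be illustrated with a small plain-Java sketch that accumulates the four confusion counts at a single threshold. This is hypothetical code, not this class's implementation: it assumes binary 0/1 labels, flat arrays instead of tensors, and a strict predictions > threshold comparison.

```java
import java.util.Arrays;

/** Hypothetical sketch: weighted confusion counts at one threshold. */
public class WeightedConfusionSketch {

    /**
     * Returns {truePositives, trueNegatives, falsePositives, falseNegatives}.
     * A null weights array gives every example weight 1; a weight of 0
     * removes that example from all four counts.
     */
    public static double[] counts(int[] labels, float[] predictions, float[] weights, float threshold) {
        double tp = 0, tn = 0, fp = 0, fn = 0;
        for (int i = 0; i < labels.length; i++) {
            double w = (weights == null) ? 1.0 : weights[i];
            boolean predictedPositive = predictions[i] > threshold; // assumed strict comparison
            boolean actualPositive = labels[i] == 1;
            if (predictedPositive && actualPositive) {
                tp += w;
            } else if (!predictedPositive && !actualPositive) {
                tn += w;
            } else if (predictedPositive) {
                fp += w;
            } else {
                fn += w;
            }
        }
        return new double[] {tp, tn, fp, fn};
    }

    public static void main(String[] args) {
        int[] labels = {1, 0, 1, 0};
        float[] predictions = {0.9f, 0.8f, 0.3f, 0.1f};
        // No weights: one example lands in each cell.
        System.out.println(Arrays.toString(counts(labels, predictions, null, 0.5f))); // [1.0, 1.0, 1.0, 1.0]
        // A weight of 0 masks the false positive at index 1.
        float[] weights = {1f, 0f, 1f, 1f};
        System.out.println(Arrays.toString(counts(labels, predictions, weights, 0.5f))); // [1.0, 1.0, 0.0, 1.0]
    }
}
```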

  • Field Details

  • Constructor Details

    • RecallAtPrecision

      public RecallAtPrecision(float precision, long seed, Class<T> type)
      Creates a RecallAtPrecision metric with a name of Class.getSimpleName() and DEFAULT_NUM_THRESHOLDS for the number of thresholds.
      Parameters:
      precision - the precision. A scalar value in range [0, 1]
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
      Throws:
      IllegalArgumentException - if numThresholds <= 0 or if precision is not in the range [0, 1].
    • RecallAtPrecision

      public RecallAtPrecision(String name, float precision, long seed, Class<T> type)
      Creates a RecallAtPrecision metric with DEFAULT_NUM_THRESHOLDS for the number of thresholds.
      Parameters:
      name - the name of the metric. If null, defaults to Class.getSimpleName()
      precision - the precision. A scalar value in range [0, 1]
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
      Throws:
      IllegalArgumentException - if numThresholds <= 0 or if precision is not in the range [0, 1].
    • RecallAtPrecision

      public RecallAtPrecision(float precision, int numThresholds, long seed, Class<T> type)
      Creates a RecallAtPrecision metric with a name of Class.getSimpleName().
      Parameters:
      precision - the precision. A scalar value in range [0, 1]
      numThresholds - the number of thresholds to use for matching the given precision. Defaults to 200.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
      Throws:
      IllegalArgumentException - if numThresholds <= 0 or if precision is not in the range [0, 1].
    • RecallAtPrecision

      public RecallAtPrecision(String name, float precision, int numThresholds, long seed, Class<T> type)
      Creates a RecallAtPrecision metric.
      Parameters:
      name - the name of the metric. If null, defaults to Class.getSimpleName()
      precision - the precision. A scalar value in range [0, 1]
      numThresholds - the number of thresholds to use for matching the given precision. Defaults to 200.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
      Throws:
      IllegalArgumentException - if numThresholds <= 0 or if precision is not in the range [0, 1].
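The constructor contracts can be sketched in plain Java. The validate helper mirrors the documented IllegalArgumentException conditions; its messages are illustrative. This page does not say how the values returned by getThresholds() are spaced, so the hypothetical thresholds helper assumes an evenly spaced grid over [0, 1], with a single 0.5 threshold when numThresholds is 1. All names here are hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch: argument checks and an assumed threshold grid. */
public class ThresholdGridSketch {

    /** Mirrors the documented constructor checks; the messages are illustrative. */
    public static void validate(float precision, int numThresholds) {
        if (numThresholds <= 0) {
            throw new IllegalArgumentException("numThresholds must be > 0");
        }
        if (precision < 0f || precision > 1f) {
            throw new IllegalArgumentException("precision must be in the range [0, 1]");
        }
    }

    /**
     * Assumed layout: numThresholds evenly spaced values from 0 to 1 inclusive,
     * or a single 0.5 threshold when numThresholds == 1.
     */
    public static float[] thresholds(int numThresholds) {
        if (numThresholds <= 0) {
            throw new IllegalArgumentException("numThresholds must be > 0");
        }
        if (numThresholds == 1) {
            return new float[] {0.5f};
        }
        float[] t = new float[numThresholds];
        for (int i = 0; i < numThresholds; i++) {
            t[i] = (float) i / (numThresholds - 1);
        }
        return t;
    }

    public static void main(String[] args) {
        validate(0.8f, 200); // accepted
        System.out.println(Arrays.toString(thresholds(5))); // [0.0, 0.25, 0.5, 0.75, 1.0]
        try {
            validate(1.5f, 200);
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```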
  • Method Details

    • result

      public <U extends TNumber> Operand<U> result(Ops tf, Class<U> resultType)
      Gets the current result of the metric.
      Type Parameters:
      U - the data type for the result
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      resultType - the data type for the result
      Returns:
      the result, possibly with control dependencies
    • getPrecision

      public float getPrecision()
      Gets the precision
      Returns:
      the precision
    • init

      protected void init(Ops tf)
      Initializes the TensorFlow Ops
      Specified by:
      init in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
      Accumulates confusion matrix statistics.
      Specified by:
      updateStateList in interface Metric
      Overrides:
      updateStateList in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      labels - The ground truth values.
      predictions - the predictions
      sampleWeights - Optional weighting of each example. Defaults to 1. Rank is either 0, or the same rank as labels, and must be broadcastable to labels.
      Returns:
      a List of Operations to update the metric state.
    • resetStates

      public Op resetStates(Ops tf)
      Resets any state variables to their initial values
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      Returns:
      the operation for doing the reset
    • getTruePositives

      public Variable<T> getTruePositives()
      Gets the truePositives variable
      Returns:
      the truePositives
    • getFalsePositives

      public Variable<T> getFalsePositives()
      Gets the falsePositives variable
      Returns:
      the falsePositives
    • getTrueNegatives

      public Variable<T> getTrueNegatives()
      Gets the trueNegatives variable
      Returns:
      the trueNegatives
    • getFalseNegatives

      public Variable<T> getFalseNegatives()
      Gets the falseNegatives variable
      Returns:
      the falseNegatives
    • getNumThresholds

      public int getNumThresholds()
      Gets the numThresholds
      Returns:
      the numThresholds
    • getThresholds

      public float[] getThresholds()
      Gets the thresholds
      Returns:
      the thresholds
    • getTruePositivesName

      public String getTruePositivesName()
      Gets the truePositives variable name
      Returns:
      the truePositivesName
    • getFalsePositivesName

      public String getFalsePositivesName()
      Gets the falsePositives variable name
      Returns:
      the falsePositivesName
    • getTrueNegativesName

      public String getTrueNegativesName()
      Gets the trueNegatives variable name
      Returns:
      the trueNegativesName
    • getFalseNegativesName

      public String getFalseNegativesName()
      Gets the falseNegatives variable name
      Returns:
      the falseNegativesName
    • getType

      public Class<T> getType()
      Gets the internalType
      Returns:
      the internalType
    • getInternalType

      public Class<T> getInternalType()