Class Recall<T extends TNumber>

java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.Recall<T>
Type Parameters:
T - The data type for the metric result
All Implemented Interfaces:
Metric

public class Recall<T extends TNumber> extends BaseMetric
Computes the recall of the predictions with respect to the labels.

This metric creates two local variables, truePositives and falseNegatives, that are used to compute the recall. The value is ultimately returned as recall, an idempotent operation that simply divides truePositives by the sum of truePositives and falseNegatives.
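The division described above can be sketched in plain Java. This is an illustration only, not the TensorFlow implementation: the class name RecallBasicSketch and its methods are hypothetical, and returning 0 when the denominator is 0 is an assumption of this sketch.

```java
// Plain-Java sketch (hypothetical helper, not part of the TensorFlow API).
final class RecallBasicSketch {

    /** Returns truePositives / (truePositives + falseNegatives); 0 if the denominator is 0 (assumption). */
    static double recall(double truePositives, double falseNegatives) {
        double denominator = truePositives + falseNegatives;
        return denominator == 0.0 ? 0.0 : truePositives / denominator;
    }

    /** Counts true positives and false negatives against a single threshold, then divides. */
    static double recall(boolean[] labels, float[] predictions, float threshold) {
        double tp = 0.0;
        double fn = 0.0;
        for (int i = 0; i < labels.length; i++) {
            if (!labels[i]) {
                continue; // negative labels do not contribute to recall
            }
            if (predictions[i] > threshold) {
                tp++; // positive label, predicted positive
            } else {
                fn++; // positive label, predicted negative
            }
        }
        return recall(tp, fn);
    }

    public static void main(String[] args) {
        boolean[] labels = {false, true, true, true};
        float[] predictions = {0.9f, 0.1f, 0.8f, 0.7f};
        // Two of the three positive labels are above the threshold, so this prints 2/3.
        System.out.println(recall(labels, predictions, 0.5f));
    }
}
```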

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values.
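As a rough illustration of how sample weights affect the counts, the sketch below adds each entry's weight, rather than 1, to the running totals. The class WeightedRecallSketch is hypothetical, and it assumes one weight per entry (no broadcasting).

```java
// Plain-Java sketch (hypothetical helper, not part of the TensorFlow API).
final class WeightedRecallSketch {

    /** Weighted recall; a null sampleWeights array means every entry has weight 1. */
    static double recall(boolean[] labels, float[] predictions, float[] sampleWeights, float threshold) {
        double tp = 0.0;
        double fn = 0.0;
        for (int i = 0; i < labels.length; i++) {
            if (!labels[i]) {
                continue; // negative labels do not contribute to recall
            }
            double weight = sampleWeights == null ? 1.0 : sampleWeights[i];
            if (predictions[i] > threshold) {
                tp += weight;
            } else {
                fn += weight; // a weight of 0 masks this entry entirely
            }
        }
        double denominator = tp + fn;
        return denominator == 0.0 ? 0.0 : tp / denominator;
    }

    public static void main(String[] args) {
        boolean[] labels = {true, true, true, false};
        float[] predictions = {0.9f, 0.2f, 0.8f, 0.6f};
        // Unweighted: 2 true positives, 1 false negative, so recall is 2/3.
        System.out.println(recall(labels, predictions, null, 0.5f));
        // Masking the missed entry with weight 0 removes the false negative: prints 1.0.
        System.out.println(recall(labels, predictions, new float[] {1f, 0f, 1f, 1f}, 0.5f));
    }
}
```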

If topK is set, the metric calculates recall as how often on average a class among the labels of a batch entry is in the top-k predictions.
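A minimal plain-Java sketch of the top-k case follows. It makes simplifying assumptions that are not taken from the library: a class counts as predicted whenever its score is among the k highest in its row, no threshold is applied, and ties are broken by the lower class index. TopKRecallSketch is a hypothetical name.

```java
// Plain-Java sketch (hypothetical helper, not part of the TensorFlow API).
final class TopKRecallSketch {

    /** True if classIndex is among the k highest scores; ties favor the lower index (assumption). */
    static boolean inTopK(float[] scores, int classIndex, int k) {
        int rank = 0; // number of classes ranked ahead of classIndex
        for (int j = 0; j < scores.length; j++) {
            boolean higher = scores[j] > scores[classIndex];
            boolean tieWithLowerIndex = scores[j] == scores[classIndex] && j < classIndex;
            if (higher || tieWithLowerIndex) {
                rank++;
            }
        }
        return rank < k;
    }

    /** Fraction of labeled classes that appear in the top-k predictions of their batch entry. */
    static double recall(boolean[][] labels, float[][] predictions, int topK) {
        double tp = 0.0;
        double fn = 0.0;
        for (int row = 0; row < labels.length; row++) {
            for (int c = 0; c < labels[row].length; c++) {
                if (!labels[row][c]) {
                    continue;
                }
                if (inTopK(predictions[row], c, topK)) {
                    tp++;
                } else {
                    fn++;
                }
            }
        }
        double denominator = tp + fn;
        return denominator == 0.0 ? 0.0 : tp / denominator;
    }

    public static void main(String[] args) {
        boolean[][] labels = {{false, true, true}, {true, false, false}};
        float[][] predictions = {{0.1f, 0.7f, 0.2f}, {0.3f, 0.6f, 0.1f}};
        System.out.println(recall(labels, predictions, 1)); // 1 of 3 labeled classes is in the top 1
        System.out.println(recall(labels, predictions, 2)); // all 3 labeled classes are in the top 2: 1.0
    }
}
```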

If classId is specified, the metric calculates recall by considering only the entries in the batch for which classId is in the label, and computing the fraction of them for which classId is above the threshold and/or in the top-k predictions.
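The classId behavior can be sketched in plain Java as restricting the computation to one column of the labels and predictions. The class ClassIdRecallSketch is hypothetical, and the sketch covers only the threshold case, not the combination with top-k.

```java
// Plain-Java sketch (hypothetical helper, not part of the TensorFlow API).
final class ClassIdRecallSketch {

    /** Recall for a single class: only rows whose label contains classId are considered. */
    static double recall(boolean[][] labels, float[][] predictions, int classId, float threshold) {
        double tp = 0.0;
        double fn = 0.0;
        for (int row = 0; row < labels.length; row++) {
            int numClasses = predictions[row].length;
            if (classId < 0 || classId >= numClasses) {
                throw new IllegalArgumentException("classId must be in [0, " + numClasses + ")");
            }
            if (!labels[row][classId]) {
                continue; // classId is not in this entry's label
            }
            if (predictions[row][classId] > threshold) {
                tp++;
            } else {
                fn++;
            }
        }
        double denominator = tp + fn;
        return denominator == 0.0 ? 0.0 : tp / denominator;
    }

    public static void main(String[] args) {
        boolean[][] labels = {{true, false}, {true, true}, {false, true}};
        float[][] predictions = {{0.9f, 0.2f}, {0.4f, 0.8f}, {0.7f, 0.6f}};
        System.out.println(recall(labels, predictions, 0, 0.5f)); // 0.5
        System.out.println(recall(labels, predictions, 1, 0.5f)); // 1.0
    }
}
```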

  • Field Details

  • Constructor Details

    • Recall

      public Recall(long seed, Class<T> type)
      Creates a Recall metric with a name of Class.getSimpleName(), topK and classId set to null, and thresholds set to DEFAULT_THRESHOLD.
      Parameters:
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Recall

      public Recall(String name, long seed, Class<T> type)
      Creates a Recall metric with topK and classId set to null and thresholds set to DEFAULT_THRESHOLD.
      Parameters:
      name - name of the metric instance. If null, name defaults to Class.getSimpleName().
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Recall

      public Recall(float threshold, long seed, Class<T> type)
      Creates a Recall metric with a name of Class.getSimpleName(), and topK and classId set to null.
      Parameters:
      threshold - A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is `true`, below is `false`).
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Recall

      public Recall(float[] thresholds, long seed, Class<T> type)
      Creates a Recall metric with a name of Class.getSimpleName(), and topK and classId set to null.
      Parameters:
      thresholds - A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to DEFAULT_THRESHOLD.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
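When several thresholds are supplied, one way to picture the computation is as one recall value per threshold. The plain-Java sketch below works under that assumption; the per-threshold result shape is assumed here, not confirmed by this page, and ThresholdsRecallSketch is a hypothetical name.

```java
// Plain-Java sketch (hypothetical helper, not part of the TensorFlow API).
final class ThresholdsRecallSketch {

    /** Returns one recall value per threshold (assumed result shape). */
    static double[] recall(boolean[] labels, float[] predictions, float[] thresholds) {
        double[] result = new double[thresholds.length];
        for (int t = 0; t < thresholds.length; t++) {
            double tp = 0.0;
            double fn = 0.0;
            for (int i = 0; i < labels.length; i++) {
                if (!labels[i]) {
                    continue; // negative labels do not contribute to recall
                }
                if (predictions[i] > thresholds[t]) {
                    tp++;
                } else {
                    fn++;
                }
            }
            double denominator = tp + fn;
            result[t] = denominator == 0.0 ? 0.0 : tp / denominator;
        }
        return result;
    }

    public static void main(String[] args) {
        boolean[] labels = {true, true, true, false};
        float[] predictions = {0.9f, 0.6f, 0.3f, 0.8f};
        // Recall falls as the threshold rises: 3/3, 2/3, 1/3.
        for (double r : recall(labels, predictions, new float[] {0.2f, 0.5f, 0.7f})) {
            System.out.println(r);
        }
    }
}
```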
    • Recall

      public Recall(String name, float threshold, long seed, Class<T> type)
      Creates a Recall metric with topK and classId set to null.
      Parameters:
      name - name of the metric instance. If null, name defaults to Class.getSimpleName().
      threshold - A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is `true`, below is `false`).
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Recall

      public Recall(String name, float[] thresholds, long seed, Class<T> type)
      Creates a Recall metric with topK and classId set to null.
      Parameters:
      name - name of the metric instance. If null, name defaults to Class.getSimpleName().
      thresholds - A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to DEFAULT_THRESHOLD.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Recall

      public Recall(Integer topK, Integer classId, long seed, Class<T> type)
      Creates a Recall metric with a name of Class.getSimpleName() and using a threshold value of DEFAULT_THRESHOLD.
      Parameters:
      topK - An optional value specifying the top-k predictions to consider when calculating recall.
      classId - Optional Integer class ID for which we want binary metrics. This must be in the half-open interval [0, numClasses), where numClasses is the last dimension of predictions.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Recall

      public Recall(String name, Integer topK, Integer classId, long seed, Class<T> type)
      Creates a Recall metric using a threshold value of DEFAULT_THRESHOLD.
      Parameters:
      name - name of the metric instance. If null, name defaults to Class.getSimpleName().
      topK - An optional value specifying the top-k predictions to consider when calculating recall.
      classId - Optional Integer class ID for which we want binary metrics. This must be in the half-open interval [0, numClasses), where numClasses is the last dimension of predictions.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Recall

      public Recall(float threshold, Integer topK, Integer classId, long seed, Class<T> type)
      Creates a Recall metric with a name of Class.getSimpleName().
      Parameters:
      threshold - A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is `true`, below is `false`).
      topK - An optional value specifying the top-k predictions to consider when calculating recall.
      classId - Optional Integer class ID for which we want binary metrics. This must be in the half-open interval [0, numClasses), where numClasses is the last dimension of predictions.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Recall

      public Recall(float[] thresholds, Integer topK, Integer classId, long seed, Class<T> type)
      Creates a Recall metric with a name of Class.getSimpleName().
      Parameters:
      thresholds - A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to DEFAULT_THRESHOLD.
      topK - An optional value specifying the top-k predictions to consider when calculating recall.
      classId - Optional Integer class ID for which we want binary metrics. This must be in the half-open interval [0, numClasses), where numClasses is the last dimension of predictions.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Recall

      public Recall(String name, float threshold, Integer topK, Integer classId, long seed, Class<T> type)
      Creates a Recall metric.
      Parameters:
      name - name of the metric instance. If null, name defaults to Class.getSimpleName().
      threshold - A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is `true`, below is `false`).
      topK - An optional value specifying the top-k predictions to consider when calculating recall.
      classId - Optional Integer class ID for which we want binary metrics. This must be in the half-open interval [0, numClasses), where numClasses is the last dimension of predictions.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • Recall

      public Recall(String name, float[] thresholds, Integer topK, Integer classId, long seed, Class<T> type)
      Creates a Recall metric.
      Parameters:
      name - name of the metric instance. If null, name defaults to Class.getSimpleName().
      thresholds - A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to DEFAULT_THRESHOLD.
      topK - An optional value specifying the top-k predictions to consider when calculating recall.
      classId - Optional Integer class ID for which we want binary metrics. This must be in the half-open interval [0, numClasses), where numClasses is the last dimension of predictions.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
  • Method Details

    • init

      protected void init(Ops tf)
      Initialize the TensorFlow Ops
      Specified by:
      init in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
    • resetStates

      public Op resetStates(Ops tf)
      Resets any state variables to their initial values
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      Returns:
      the operation for doing the reset
    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
      Accumulates true positive and false negative statistics.
      Specified by:
      updateStateList in interface Metric
      Overrides:
      updateStateList in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      labels - the ground truth values, with the same dimensions as predictions. Will be cast to TBool.
      predictions - the predictions. Each element must be in the range [0, 1].
      sampleWeights - Optional weighting of each example. Defaults to 1. Rank is either 0 or the same rank as labels, and must be broadcastable to labels.
      Returns:
      a List of Operations to update the metric state.
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
    • result

      public <U extends TNumber> Operand<U> result(Ops tf, Class<U> resultType)
      Description copied from interface: Metric
      Gets the current result of the metric
      Type Parameters:
      U - the data type for the result
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      resultType - the data type for the result
      Returns:
      the result, possibly with control dependencies
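The update, result, and reset life cycle described by the methods above can be pictured with a small plain-Java accumulator. This is a sketch only: StreamingRecallSketch is hypothetical, it mutates fields directly instead of building graph operations, and it uses a single threshold with no sample weights.

```java
// Plain-Java sketch (hypothetical helper, not part of the TensorFlow API).
final class StreamingRecallSketch {
    private final float threshold;
    private double truePositives;  // accumulated across batches
    private double falseNegatives; // accumulated across batches

    StreamingRecallSketch(float threshold) {
        this.threshold = threshold;
    }

    /** Accumulates true positive and false negative counts for one batch. */
    void updateState(boolean[] labels, float[] predictions) {
        for (int i = 0; i < labels.length; i++) {
            if (!labels[i]) {
                continue; // negative labels do not contribute to recall
            }
            if (predictions[i] > threshold) {
                truePositives++;
            } else {
                falseNegatives++;
            }
        }
    }

    /** Idempotent: reads the accumulated counts without modifying them. */
    double result() {
        double denominator = truePositives + falseNegatives;
        return denominator == 0.0 ? 0.0 : truePositives / denominator;
    }

    /** Resets the state variables to their initial values. */
    void resetStates() {
        truePositives = 0.0;
        falseNegatives = 0.0;
    }

    public static void main(String[] args) {
        StreamingRecallSketch metric = new StreamingRecallSketch(0.5f);
        metric.updateState(new boolean[] {true, true}, new float[] {0.9f, 0.1f});
        System.out.println(metric.result()); // 0.5 after the first batch
        metric.updateState(new boolean[] {true, false}, new float[] {0.8f, 0.9f});
        System.out.println(metric.result()); // 2 of 3 positives found so far
        metric.resetStates();
        System.out.println(metric.result()); // 0.0 after the reset
    }
}
```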
    • getThresholds

      public float[] getThresholds()
      Gets the thresholds
      Returns:
      the thresholds
    • getTopK

      public Integer getTopK()
      Gets the topK value
      Returns:
      the topK value
    • getClassId

      public Integer getClassId()
      Gets the class id
      Returns:
      the class id
    • getTruePositives

      public Variable<T> getTruePositives()
      Gets the truePositives variable
      Returns:
      the truePositives variable
    • getFalseNegatives

      public Variable<T> getFalseNegatives()
      Gets the falseNegatives variable
      Returns:
      the falseNegatives variable
    • getTruePositivesName

      public String getTruePositivesName()
      Gets the truePositives variable name
      Returns:
      the truePositives variable name
    • getFalseNegativesName

      public String getFalseNegativesName()
      Gets the falseNegatives variable name
      Returns:
      the falseNegatives variable name