Class TruePositives<T extends TNumber>

java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.TruePositives<T>
Type Parameters:
T - The data type for the metric result
All Implemented Interfaces:
Metric

public class TruePositives<T extends TNumber> extends BaseMetric
Metric that calculates the number of true positives.

If sampleWeights is given, the metric calculates the sum of the weights of the true positives instead of their count. This metric creates one local variable, accumulator, that is used to keep track of the number of true positives.

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values.
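As a rough illustration of the counting rule above, the following plain-Java sketch computes a weighted true-positive count for a single threshold. It does not use the TensorFlow Java API: the class WeightedTruePositivesSketch and its truePositives method are hypothetical. It assumes that a non-zero label counts as positive and that a prediction strictly greater than the threshold counts as a positive prediction.

```java
/**
 * Conceptual sketch only (not the TensorFlow Java API): weighted true-positive
 * count for one threshold, using plain arrays instead of tensors.
 */
final class WeightedTruePositivesSketch {

    /**
     * Sums the weights of examples that are actually positive and predicted positive.
     * Assumption: a non-zero label is positive; prediction > threshold is a positive prediction.
     * A null weights array means every weight is 1; a weight of 0 masks that example.
     */
    static double truePositives(float[] labels, float[] predictions, float[] weights, float threshold) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions must have the same length");
        }
        if (weights != null && weights.length != labels.length) {
            throw new IllegalArgumentException("weights must have the same length as labels");
        }
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            boolean actualPositive = labels[i] != 0f;
            boolean predictedPositive = predictions[i] > threshold;
            if (actualPositive && predictedPositive) {
                sum += (weights == null) ? 1.0 : weights[i];
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        float[] labels = {0f, 1f, 1f, 1f};
        float[] predictions = {0f, 1f, 1f, 0f};
        // No weights: each true positive counts as 1.
        System.out.println(truePositives(labels, predictions, null, 0.5f)); // 2.0
        // The weight of 0 masks the second example; the third contributes 3.
        float[] weights = {1f, 0f, 3f, 1f};
        System.out.println(truePositives(labels, predictions, weights, 0.5f)); // 3.0
    }
}
```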

  • Constructor Details

    • TruePositives

      public TruePositives(long seed, Class<T> type)
      Creates a TruePositives metric, using Class.getSimpleName() for the metric name and a default threshold of DEFAULT_THRESHOLD.
      Parameters:
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • TruePositives

      public TruePositives(float threshold, long seed, Class<T> type)
      Creates a TruePositives metric, using Class.getSimpleName() for the metric name.
      Parameters:
      threshold - a threshold value in the range [0, 1]. A threshold is compared with prediction values to determine the truth value of predictions (i.e., a prediction above the threshold is true and one below it is false). One metric value is generated for each threshold value.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • TruePositives

      public TruePositives(float[] thresholds, long seed, Class<T> type)
      Creates a TruePositives metric, using Class.getSimpleName() for the metric name.
      Parameters:
      thresholds - threshold values in the range [0, 1]. A threshold is compared with prediction values to determine the truth value of predictions (i.e., a prediction above the threshold is true and one below it is false). One metric value is generated for each threshold value.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • TruePositives

      public TruePositives(String name, long seed, Class<T> type)
      Creates a TruePositives metric, using a default threshold of DEFAULT_THRESHOLD.
      Parameters:
      name - the name of the metric; if null, Class.getSimpleName() is used
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • TruePositives

      public TruePositives(String name, float threshold, long seed, Class<T> type)
      Creates a TruePositives metric.
      Parameters:
      name - the name of the metric; if null, Class.getSimpleName() is used
      threshold - a threshold value in the range [0, 1]. A threshold is compared with prediction values to determine the truth value of predictions (i.e., a prediction above the threshold is true and one below it is false). One metric value is generated for each threshold value.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • TruePositives

      public TruePositives(String name, float[] thresholds, long seed, Class<T> type)
      Creates a TruePositives metric.
      Parameters:
      name - the name of the metric; if null, Class.getSimpleName() is used
      thresholds - threshold values in the range [0, 1]. A threshold is compared with prediction values to determine the truth value of predictions (i.e., a prediction above the threshold is true and one below it is false). One metric value is generated for each threshold value.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
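The constructors that take float[] thresholds produce one metric value per threshold. A minimal plain-Java sketch of that behavior follows. MultiThresholdSketch is a hypothetical helper, not part of TensorFlow Java, and the range check on thresholds is an assumption made for illustration rather than documented library behavior. It uses the same assumed rules as a single-threshold count: a non-zero label is positive, and a prediction strictly above a threshold is a positive prediction.

```java
import java.util.Arrays;

/**
 * Conceptual sketch only (not the TensorFlow Java API): one true-positive
 * count per threshold, as described for the float[] thresholds constructors.
 */
final class MultiThresholdSketch {

    static float[] truePositives(float[] labels, float[] predictions, float[] thresholds) {
        // Assumption for illustration: reject thresholds outside [0, 1].
        for (float t : thresholds) {
            if (t < 0f || t > 1f) {
                throw new IllegalArgumentException("threshold out of [0, 1]: " + t);
            }
        }
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions must have the same length");
        }
        float[] counts = new float[thresholds.length];
        for (int j = 0; j < thresholds.length; j++) {
            for (int i = 0; i < labels.length; i++) {
                // Positive label and prediction strictly above the j-th threshold.
                if (labels[i] != 0f && predictions[i] > thresholds[j]) {
                    counts[j] += 1f;
                }
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        float[] labels = {0f, 1f, 1f, 1f};
        float[] predictions = {0.2f, 0.9f, 0.6f, 0.3f};
        float[] thresholds = {0.15f, 0.5f, 0.95f};
        // Raising the threshold can only lower the true-positive count.
        System.out.println(Arrays.toString(truePositives(labels, predictions, thresholds))); // [3.0, 2.0, 0.0]
    }
}
```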
  • Method Details

    • init

      protected void init(Ops tf)
      Initializes the TensorFlow Ops.
      Specified by:
      init in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
      Accumulates the metric statistics.
      Specified by:
      updateStateList in interface Metric
      Overrides:
      updateStateList in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      labels - the ground truth values.
      predictions - the predictions.
      sampleWeights - Optional weighting of each example. Defaults to 1. Rank is either 0, or the same rank as labels, and must be broadcastable to labels.
      Returns:
      a List of Operations to update the metric state.
    • result

      public <U extends TNumber> Operand<U> result(Ops tf, Class<U> type)
      Gets the current result of the metric.
      Type Parameters:
      U - the data type for the result
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      type - the data type for the result
      Returns:
      the result, possibly with control dependencies
    • resetStates

      public Op resetStates(Ops tf)
      Resets any state variables to their initial values.
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      Returns:
      the operation for doing the reset
    • getThresholds

      public float[] getThresholds()
      Gets the thresholds.
      Returns:
      the thresholds
    • getAccumulatorName

      public String getAccumulatorName()
      Gets the name of the accumulator variable.
      Returns:
      the accumulatorName
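The methods above follow an accumulate, read, reset lifecycle: updateStateList adds each batch's statistics to the accumulator, result reads the running total, and resetStates returns it to its initial value. The sketch below mimics that lifecycle with a plain Java array. TruePositivesAccumulatorSketch and its methods are hypothetical stand-ins, not TensorFlow Java calls; it assumes a non-zero label is positive and a prediction strictly above a threshold is a positive prediction, and it only accepts per-example weights (not the rank-0 broadcast form).

```java
import java.util.Arrays;

/**
 * Conceptual sketch only (not the TensorFlow Java API): mirrors the
 * accumulate / read / reset lifecycle of updateStateList, result and
 * resetStates, with a plain array standing in for the accumulator variable.
 */
final class TruePositivesAccumulatorSketch {
    private final float[] thresholds;
    private final double[] accumulator; // one running total per threshold

    TruePositivesAccumulatorSketch(float[] thresholds) {
        this.thresholds = thresholds.clone();
        this.accumulator = new double[thresholds.length];
    }

    /** Adds this batch's weighted true positives to the running totals; null weights mean 1. */
    void updateState(float[] labels, float[] predictions, float[] weights) {
        if (labels.length != predictions.length || (weights != null && weights.length != labels.length)) {
            throw new IllegalArgumentException("labels, predictions and weights must have the same length");
        }
        for (int j = 0; j < thresholds.length; j++) {
            for (int i = 0; i < labels.length; i++) {
                if (labels[i] != 0f && predictions[i] > thresholds[j]) {
                    accumulator[j] += (weights == null) ? 1.0 : weights[i];
                }
            }
        }
    }

    /** Returns a copy of the current totals, analogous to result(). */
    double[] result() {
        return accumulator.clone();
    }

    /** Zeroes the totals, analogous to resetStates(). */
    void reset() {
        Arrays.fill(accumulator, 0.0);
    }

    /** Returns a copy of the thresholds, analogous to getThresholds() (copying is this sketch's choice). */
    float[] getThresholds() {
        return thresholds.clone();
    }

    public static void main(String[] args) {
        TruePositivesAccumulatorSketch tp = new TruePositivesAccumulatorSketch(new float[] {0.5f});
        // Batch 1: only the first example is a true positive (weight 1).
        tp.updateState(new float[] {1f, 1f, 0f}, new float[] {0.9f, 0.4f, 0.8f}, null);
        // Batch 2: both are true positives, but the second is masked by weight 0.
        tp.updateState(new float[] {1f, 1f}, new float[] {0.7f, 0.6f}, new float[] {2f, 0f});
        System.out.println(Arrays.toString(tp.result())); // [3.0]
        tp.reset();
        System.out.println(Arrays.toString(tp.result())); // [0.0]
    }
}
```

Keeping the totals across calls, rather than recomputing them per batch, is what lets the metric report a value over a whole evaluation run; the reset step is what starts a fresh run.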