Class AUC<T extends TNumber>

java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.AUC<T>
Type Parameters:
T - The data type for the metric result
All Implemented Interfaces:
Metric

public class AUC<T extends TNumber> extends BaseMetric
Metric that computes the approximate AUC (Area under the curve) via a Riemann sum.

This metric creates four local variables, truePositives, trueNegatives, falsePositives and falseNegatives, that are used to compute the AUC. To discretize the AUC curve, a linearly spaced set of thresholds is used. At each threshold, the metric computes the recall together with either the false positive rate (ROC curve) or the precision (PR curve). The area under the ROC curve is then computed using the recall values as heights over the false positive rate, while the area under the PR curve is computed using the precision values as heights over the recall.
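To make the four accumulators concrete, here is a plain-Java sketch that does not use TensorFlow. It assumes binary labels (0 or 1) and counts a prediction as positive when it is strictly greater than a threshold, which is consistent with the counts in the Usage example. The class ConfusionCounts and its methods are hypothetical illustrations, not part of the TensorFlow Java API.

```java
/** Hypothetical sketch of the confusion-matrix accumulators behind AUC. */
final class ConfusionCounts {
    final double[] tp, fp, tn, fn; // one entry per threshold

    private ConfusionCounts(int numThresholds) {
        tp = new double[numThresholds];
        fp = new double[numThresholds];
        tn = new double[numThresholds];
        fn = new double[numThresholds];
    }

    /** Accumulates (optionally weighted) counts; sampleWeights may be null. */
    static ConfusionCounts compute(float[] labels, float[] predictions,
                                   float[] sampleWeights, float[] thresholds) {
        ConfusionCounts c = new ConfusionCounts(thresholds.length);
        for (int t = 0; t < thresholds.length; t++) {
            for (int i = 0; i < labels.length; i++) {
                double w = sampleWeights == null ? 1.0 : sampleWeights[i];
                boolean predictedPositive = predictions[i] > thresholds[t];
                boolean actualPositive = labels[i] != 0f; // labels assumed to be 0 or 1
                if (predictedPositive && actualPositive) c.tp[t] += w;
                else if (predictedPositive) c.fp[t] += w;
                else if (actualPositive) c.fn[t] += w;
                else c.tn[t] += w;
            }
        }
        return c;
    }

    /** recall = tp / (tp + fn), per threshold. */
    double[] recall() { return ratio(tp, fn); }

    /** precision = tp / (tp + fp), per threshold. */
    double[] precision() { return ratio(tp, fp); }

    /** false positive rate = fp / (fp + tn), per threshold. */
    double[] falsePositiveRate() { return ratio(fp, tn); }

    // a / (a + b) element-wise, returning 0 when the denominator is 0
    private static double[] ratio(double[] a, double[] b) {
        double[] r = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            double d = a[i] + b[i];
            r[i] = d == 0 ? 0 : a[i] / d;
        }
        return r;
    }
}
```

Applied to the labels and predictions of the Usage example with thresholds [-1e-7, 0.5, 1 + 1e-7], this reproduces the tp, fp, fn and tn values listed in that example's comments.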

This value is ultimately returned as auc, an idempotent operation that computes the area under the discretized ROC or PR curve from the aforementioned variables. The numThresholds parameter controls the degree of discretization: larger numbers of thresholds more closely approximate the true AUC, and the quality of the approximation may vary dramatically depending on numThresholds. Alternatively, the thresholds parameter can be used to specify thresholds manually, for example to split the predictions more evenly.
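The sketch below shows one way the thresholds might be built. It assumes the Keras convention in which numThresholds counts two bracketing values (-EPSILON and 1 + EPSILON) plus numThresholds - 2 evenly spaced interior points. EPSILON = 1e-7 is an assumption chosen to match the threshold values in the Usage example. The class Thresholds is hypothetical, and it also applies the checks described under Throws for the full constructor.

```java
/** Hypothetical sketch of threshold construction for AUC. */
final class Thresholds {
    static final float EPSILON = 1e-7f; // assumed value, mirrors Keras backend.epsilon()

    /** numThresholds linearly spaced values, including both brackets. */
    static float[] linear(int numThresholds) {
        if (numThresholds < 2) {
            throw new IllegalArgumentException("numThresholds must be >= 2, got " + numThresholds);
        }
        float[] result = new float[numThresholds];
        result[0] = -EPSILON;
        // interior points i / (numThresholds - 1) for i = 1 .. numThresholds - 2
        for (int i = 1; i < numThresholds - 1; i++) {
            result[i] = (float) i / (numThresholds - 1);
        }
        result[numThresholds - 1] = 1f + EPSILON;
        return result;
    }

    /** Validates explicit thresholds to [0, 1] and brackets them. */
    static float[] bracket(float[] thresholds) {
        float[] result = new float[thresholds.length + 2];
        result[0] = -EPSILON;
        for (int i = 0; i < thresholds.length; i++) {
            float t = thresholds[i];
            if (t < 0f || t > 1f) {
                throw new IllegalArgumentException("threshold must be in [0, 1], got " + t);
            }
            result[i + 1] = t;
        }
        result[result.length - 1] = 1f + EPSILON;
        return result;
    }
}
```

For example, linear(3) yields [-1e-7, 0.5, 1 + 1e-7], and linear(5) places its interior points at 0.25, 0.5 and 0.75.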

For best results, predictions should be distributed approximately uniformly in the range [0, 1] and not peaked around 0 or 1; otherwise, the quality of the AUC approximation may be poor. Setting summationMethod to AUCSummationMethod.MINORING or AUCSummationMethod.MAJORING can help quantify the error in the approximation by providing a lower or upper bound estimate of the AUC, respectively.
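A plain-Java sketch of the three Riemann sums over an ROC curve follows. It takes per-threshold false positive rates and recall values ordered by increasing threshold. The nested enum Method is a hypothetical stand-in for AUCSummationMethod, and only the ROC case is shown. For a PR curve with interpolation, Keras uses a dedicated precision-recall interpolation rather than this trapezoid rule.

```java
/** Hypothetical sketch: Riemann-sum AUC over an ROC curve. */
final class RiemannAuc {
    enum Method { INTERPOLATION, MINORING, MAJORING }

    /**
     * fpRate and recall are ordered by increasing threshold, so fpRate is
     * non-increasing; slice i has width fpRate[i] - fpRate[i + 1].
     */
    static double rocAuc(double[] fpRate, double[] recall, Method method) {
        double auc = 0;
        for (int i = 0; i < fpRate.length - 1; i++) {
            double width = fpRate[i] - fpRate[i + 1];
            double height;
            switch (method) {
                case INTERPOLATION: // trapezoid: mean of the two heights
                    height = (recall[i] + recall[i + 1]) / 2;
                    break;
                case MINORING: // lower Riemann sum
                    height = Math.min(recall[i], recall[i + 1]);
                    break;
                default: // MAJORING: upper Riemann sum
                    height = Math.max(recall[i], recall[i + 1]);
                    break;
            }
            auc += width * height;
        }
        return auc;
    }
}
```

With the rates from the Usage example (fpRate = [1, 0, 0], recall = [1, 0.5, 0]), interpolation gives 0.75, while minoring and majoring give 0.5 and 1.0, bracketing the interpolated value.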

Usage:

AUC m = new org.tensorflow.framework.metrics.AUC(tf, 3);
m.updateState(tf.constant(new float[] {0, 0, 1, 1}),
              tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}));

// threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
// tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
// recall = [1, 0.5, 0], fpRate = [1, 0, 0]
// auc = ((1 + 0.5) / 2) * (1 - 0) + ((0.5 + 0) / 2) * (0 - 0) = 0.75
Operand<TFloat32> result = m.result();
System.out.println(result.data().getFloat()); // prints 0.75

m.resetStates();
// the third argument holds sample weights [1, 0, 0, 1], so only the
// first and last examples contribute:
// tp = [1, 1, 0], fp = [1, 0, 0], fn = [0, 0, 1], tn = [0, 1, 1]
// recall = [1, 1, 0], fpRate = [1, 0, 0]
// auc = ((1 + 1) / 2) * (1 - 0) + ((1 + 0) / 2) * (0 - 0) = 1.0
m.updateState(tf.constant(new float[] {0, 0, 1, 1}),
              tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}),
              tf.constant(new float[] {1, 0, 0, 1}));
result = m.result();
System.out.println(result.data().getFloat()); // prints 1.0
  • Field Details

  • Constructor Details

    • AUC

      public AUC(long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, DEFAULT_NUM_THRESHOLDS for the numThresholds, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
      Parameters:
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(String name, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using DEFAULT_NUM_THRESHOLDS for the numThresholds, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
      Parameters:
      name - the name of the metric, if null defaults to DEFAULT_NAME
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(int numThresholds, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
      Parameters:
      numThresholds - the number of thresholds to use when discretizing the ROC curve. Values must be > 1.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(float[] thresholds, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for numThresholds, false for multiLabel, and null for labelWeights.
      Parameters:
      thresholds - Optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1].
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(String name, int numThresholds, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
      Parameters:
      name - the name of the metric, if null defaults to DEFAULT_NAME
      numThresholds - the number of thresholds to use when discretizing the ROC curve. Values must be > 1.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(String name, float[] thresholds, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using null for numThresholds, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, false for multiLabel, and null for labelWeights.
      Parameters:
      name - the name of the metric, if null defaults to DEFAULT_NAME
      thresholds - Optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1].
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(String name, int numThresholds, AUCCurve curve, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
      Parameters:
      name - the name of the metric, if null defaults to DEFAULT_NAME
      numThresholds - the number of thresholds to use when discretizing the curve. Values must be > 1.
      curve - specifies the type of the curve to be computed, AUCCurve.ROC or AUCCurve.PR for the Precision-Recall-curve.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(String name, float[] thresholds, AUCCurve curve, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using null for numThresholds, AUCSummationMethod.INTERPOLATION for the summation method, false for multiLabel, and null for labelWeights.
      Parameters:
      name - the name of the metric, if null defaults to DEFAULT_NAME
      thresholds - Optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1].
      curve - specifies the type of the curve to be computed, AUCCurve.ROC or AUCCurve.PR for the Precision-Recall-curve.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(int numThresholds, AUCCurve curve, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
      Parameters:
      numThresholds - the number of thresholds to use when discretizing the curve. Values must be > 1.
      curve - specifies the type of the curve to be computed, AUCCurve.ROC or AUCCurve.PR for the Precision-Recall-curve.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(float[] thresholds, AUCCurve curve, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using null for numThresholds, AUCSummationMethod.INTERPOLATION for the summation method, false for multiLabel, and null for labelWeights.
      Parameters:
      thresholds - Optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1].
      curve - specifies the type of the curve to be computed, AUCCurve.ROC or AUCCurve.PR for the Precision-Recall-curve.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(int numThresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, null for thresholds, false for multiLabel, and null for labelWeights.
      Parameters:
      numThresholds - the number of thresholds to use when discretizing the curve. Values must be > 1.
      curve - specifies the type of the curve to be computed, AUCCurve.ROC or AUCCurve.PR for the Precision-Recall-curve.
      summationMethod - Specifies the Riemann summation method used
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(float[] thresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, null for numThresholds, false for multiLabel, and null for labelWeights.
      Parameters:
      thresholds - Optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1].
      curve - specifies the type of the curve to be computed, AUCCurve.ROC or AUCCurve.PR for the Precision-Recall-curve.
      summationMethod - Specifies the Riemann summation method used
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(String name, int numThresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using null for thresholds, false for multiLabel, and null for labelWeights.
      Parameters:
      name - the name of the metric, if null defaults to DEFAULT_NAME
      numThresholds - the number of thresholds to use when discretizing the curve. Values must be > 1.
      curve - specifies the type of the curve to be computed, AUCCurve.ROC or AUCCurve.PR for the Precision-Recall-curve.
      summationMethod - Specifies the Riemann summation method used.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(String name, float[] thresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric using null for numThresholds, false for multiLabel, and null for labelWeights.
      Parameters:
      name - the name of the metric, if null defaults to DEFAULT_NAME
      thresholds - Optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1].
      curve - specifies the type of the curve to be computed, AUCCurve.ROC or AUCCurve.PR for the Precision-Recall-curve.
      summationMethod - Specifies the Riemann summation method used.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
    • AUC

      public AUC(String name, int numThresholds, AUCCurve curve, AUCSummationMethod summationMethod, float[] thresholds, boolean multiLabel, Operand<T> labelWeights, long seed, Class<T> type)
      Creates an AUC (Area under the curve) metric.
      Parameters:
      name - the name of the metric, if name is null then use DEFAULT_NAME.
      numThresholds - the number of thresholds to use when discretizing the curve. This includes the bracketing 0 and 1 thresholds, so the value must be ≥ 2.
      curve - specifies the type of the curve to be computed, AUCCurve.ROC or AUCCurve.PR for the Precision-Recall-curve.
      summationMethod - Specifies the Riemann summation method used
      thresholds - Optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1]. The provided thresholds are automatically bracketed with -EPSILON below and 1 + EPSILON above.
      multiLabel - boolean indicating whether multilabel data should be treated as such, in which case AUC is computed separately for each label and then averaged across labels, or, when false, flattened into a single label before AUC computation. In the latter case, when multilabel data is passed to AUC, each label-prediction pair is treated as an individual data point. Should be set to false for multi-class data.
      labelWeights - non-negative weights used to compute AUCs for multilabel data. When multiLabel is true, the weights are applied to the individual label AUCs when they are averaged to produce the multi-label AUC. When multiLabel is false, they are used to weight the individual label predictions in computing the confusion matrix on the flattened data.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the confusion matrix variables.
      Throws:
      IllegalArgumentException - if numThresholds is less than 2 and thresholds is null, or if a threshold value is less than 0 or greater than 1.
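The two multiLabel modes described above can be illustrated in plain Java. This is a hedged sketch, not the TensorFlow Java implementation: MultiLabelSketch and its methods are hypothetical, the weighted mean is assumed to return 0 when the weights sum to 0, and the weight of each flattened pair is assumed to be the product of its sample weight and its label weight, which is how Keras combines them.

```java
/** Hypothetical helpers for the two multiLabel modes of AUC. */
final class MultiLabelSketch {

    /** multiLabel == true: mean of per-label AUCs, weighted if labelWeights != null. */
    static double combine(double[] perLabelAuc, double[] labelWeights) {
        double weighted = 0, total = 0;
        for (int l = 0; l < perLabelAuc.length; l++) {
            double w = labelWeights == null ? 1.0 : labelWeights[l];
            if (w < 0) {
                throw new IllegalArgumentException("label weights must be non-negative");
            }
            weighted += perLabelAuc[l] * w;
            total += w;
        }
        // assumed: 0 when the weights sum to 0 (Keras uses a no-NaN division)
        return total == 0 ? 0 : weighted / total;
    }

    /**
     * multiLabel == false: flattens (N, L) data into N * L pairs in row-major
     * order. Returns {labels, predictions, weights}; each pair's weight is
     * assumed to be sampleWeights[n] * labelWeights[l], with null meaning all ones.
     */
    static float[][] flatten(float[][] labels, float[][] predictions,
                             float[] sampleWeights, float[] labelWeights) {
        int n = labels.length;
        int numLabels = n == 0 ? 0 : labels[0].length;
        float[][] flat = new float[3][n * numLabels];
        for (int i = 0; i < n; i++) {
            for (int l = 0; l < numLabels; l++) {
                int k = i * numLabels + l; // position of example i, label l
                flat[0][k] = labels[i][l];
                flat[1][k] = predictions[i][l];
                float sw = sampleWeights == null ? 1f : sampleWeights[i];
                float lw = labelWeights == null ? 1f : labelWeights[l];
                flat[2][k] = sw * lw;
            }
        }
        return flat;
    }
}
```

For example, per-label AUCs of 0.75 and 1.0 average to 0.875 without labelWeights, and to 0.9375 with labelWeights [1, 3].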
  • Method Details

    • init

      protected void init(Ops tf)
      Initializes the TensorFlow Ops.
      Specified by:
      init in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
      Creates a List of Operations to update the metric state based on labels and predictions.
      Specified by:
      updateStateList in interface Metric
      Overrides:
      updateStateList in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      labels - the labels, with shape (N, Cx, L1?), where N is the number of examples, Cx is zero or more class dimensions, and L1 is a potential extra dimension of size 1 that would be squeezed. Will be cast to T. If multiLabel is true or labelWeights != null, then Cx must be a single dimension.
      predictions - the predictions, with shape (N, Cx, P1?). Will be cast to T.
      sampleWeights - sample weights to be applied to values; may be null. Will be cast to T.
      Returns:
      a List of Operations to update the metric state
    • result

      public <U extends TNumber> Operand<U> result(Ops tf, Class<U> resultType)
      Gets the current result of the metric
      Type Parameters:
      U - the data type for the result
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      resultType - the data type for the result
      Returns:
      the result, possibly with control dependencies
    • resetStates

      public Op resetStates(Ops tf)
      Resets any state variables to their initial values
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      Returns:
      the operation for doing the reset
    • getNumThresholds

      public int getNumThresholds()
      Returns:
      the numThresholds
    • getCurve

      public AUCCurve getCurve()
      Returns:
      the curve
    • getSummationMethod

      public AUCSummationMethod getSummationMethod()
      Returns:
      the summationMethod
    • getThresholds

      public float[] getThresholds()
      Returns:
      the thresholds
    • isMultiLabel

      public boolean isMultiLabel()
      Returns:
      the multiLabel
    • getNumLabels

      public Integer getNumLabels()
      Returns:
      the numLabels
    • setNumLabels

      public void setNumLabels(Integer numLabels)
      Parameters:
      numLabels - the numLabels to set
    • getLabelWeights

      public Operand<T> getLabelWeights()
      Returns:
      the labelWeights
    • getTruePositives

      public Variable<T> getTruePositives()
      Returns:
      the truePositives
    • getFalsePositives

      public Variable<T> getFalsePositives()
      Returns:
      the falsePositives
    • getTrueNegatives

      public Variable<T> getTrueNegatives()
      Returns:
      the trueNegatives
    • getFalseNegatives

      public Variable<T> getFalseNegatives()
      Returns:
      the falseNegatives
    • getTruePositivesName

      public String getTruePositivesName()
      Returns:
      the truePositivesName
    • getFalsePositivesName

      public String getFalsePositivesName()
      Returns:
      the falsePositivesName
    • getTrueNegativesName

      public String getTrueNegativesName()
      Returns:
      the trueNegativesName
    • getFalseNegativesName

      public String getFalseNegativesName()
      Returns:
      the falseNegativesName