Class AUC<T extends TNumber>
- Type Parameters:
T - the data type for the metric result
- All Implemented Interfaces:
Metric
This metric creates four local variables, truePositives, trueNegatives, falsePositives and
falseNegatives, that are used to compute the AUC. To discretize the AUC curve, a linearly spaced
set of thresholds is used to compute a pair of rates at each threshold: recall and false positive
rate for the ROC curve, or precision and recall for the PR curve. The area under the ROC curve is
then computed using the recall values as heights over the false positive rate, while the area
under the PR curve is computed using the precision values as heights over the recall.
This value is ultimately returned by result(), an idempotent operation that computes the area
under the discretized curve built from the aforementioned variables. The numThresholds variable
controls the degree of discretization, with larger numbers of thresholds more closely
approximating the true AUC. The quality of the approximation may vary dramatically depending on
numThresholds. The thresholds parameter can be used to manually specify thresholds that split the
predictions more evenly.
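One way to pick thresholds that split the predictions more evenly is to place them at empirical quantiles of a representative sample of predictions. The sketch below illustrates that idea; it is not something the AUC class does itself, and QuantileThresholdsSketch and quantileThresholds are hypothetical names. The returned array is the kind of value one would pass as the thresholds constructor argument.

```java
import java.util.Arrays;

/** Hypothetical helper: interior thresholds at evenly spaced quantiles of sample predictions. */
final class QuantileThresholdsSketch {
  static float[] quantileThresholds(float[] predictions, int count) {
    if (predictions.length == 0 || count < 1) {
      throw new IllegalArgumentException("need at least one prediction and one threshold");
    }
    float[] sorted = predictions.clone();
    Arrays.sort(sorted);
    float[] thresholds = new float[count];
    for (int k = 1; k <= count; k++) {
      double q = (double) k / (count + 1);                    // quantile level in (0, 1)
      int index = (int) Math.round(q * (sorted.length - 1));  // nearest-rank position
      // Clamp so every threshold satisfies the [0, 1] requirement.
      // Duplicates are possible when many predictions share the same value.
      thresholds[k - 1] = Math.min(1f, Math.max(0f, sorted[index]));
    }
    return thresholds;
  }

  public static void main(String[] args) {
    float[] sample = {0.02f, 0.05f, 0.07f, 0.1f, 0.4f, 0.95f, 0.97f};
    System.out.println(Arrays.toString(quantileThresholds(sample, 3)));
  }
}
```

Compared with linearly spaced thresholds, quantile thresholds place more cut points where predictions are dense, which is the situation the next paragraph warns about.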
For best results, predictions should be distributed approximately uniformly in the range [0, 1]
and not peaked around 0 or 1. The quality of the AUC approximation may be poor if this is not the
case. Setting summationMethod to minoring or majoring can help quantify the error in the
approximation by providing a lower or upper bound estimate of the AUC.
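The summation methods differ only in the height used for each slice between adjacent thresholds. The framework-free sketch below applies all three to the ROC curve (recall over false positive rate); it is a hypothetical helper, not the library's code. Returning 0 for a 0/0 rate is an assumption, and the library may interpolate the PR curve differently, so the sketch covers ROC only. With the counts from the usage example (tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]) it gives 0.5 for minoring, 0.75 for interpolation and 1.0 for majoring.

```java
/** Hypothetical sketch of Riemann summation over the ROC curve; not the library's implementation. */
final class RocRiemannSketch {
  enum Summation { INTERPOLATION, MINORING, MAJORING }

  // Returns 0 instead of NaN when the denominator is 0 (assumed behavior).
  private static double divNoNan(double numerator, double denominator) {
    return denominator == 0 ? 0 : numerator / denominator;
  }

  /** Counts are indexed by threshold, in increasing threshold order. */
  static double rocAuc(double[] tp, double[] fp, double[] fn, double[] tn, Summation method) {
    int n = tp.length;
    double[] recall = new double[n];
    double[] fpRate = new double[n];
    for (int i = 0; i < n; i++) {
      recall[i] = divNoNan(tp[i], tp[i] + fn[i]);
      fpRate[i] = divNoNan(fp[i], fp[i] + tn[i]);
    }
    double auc = 0;
    for (int i = 0; i < n - 1; i++) {
      double height;
      switch (method) {
        case MINORING:
          height = Math.min(recall[i], recall[i + 1]); // lower-bound height
          break;
        case MAJORING:
          height = Math.max(recall[i], recall[i + 1]); // upper-bound height
          break;
        default:
          height = (recall[i] + recall[i + 1]) / 2; // trapezoid height
      }
      // The false positive rate shrinks as the threshold grows, so each width is >= 0.
      auc += (fpRate[i] - fpRate[i + 1]) * height;
    }
    return auc;
  }

  public static void main(String[] args) {
    double[] tp = {2, 1, 0}, fp = {2, 0, 0}, fn = {0, 1, 2}, tn = {0, 2, 2};
    for (Summation method : Summation.values()) {
      System.out.println(method + ": " + rocAuc(tp, fp, fn, tn, method));
    }
  }
}
```

The gap between the minoring and majoring results bounds how much the discretization can distort the interpolated value.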
Usage:
import org.tensorflow.Operand;
import org.tensorflow.types.TFloat32;

AUC m = new org.tensorflow.framework.metrics.AUC(tf, 3);
m.updateState(tf.constant(new float[] {0, 0, 1, 1}),
    tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}));
// threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
// tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
// recall = [1, 0.5, 0], fpRate = [1, 0, 0]
// auc = ((((1 + 0.5) / 2) * (1 - 0)) + (((0.5 + 0) / 2) * (0 - 0))) = 0.75
Operand<TFloat32> result = m.result();
System.out.println(result.data().getFloat()); // prints 0.75

m.resetStates();
// the third argument holds the sample weights [1, 0, 0, 1]
m.updateState(tf.constant(new float[] {0, 0, 1, 1}),
    tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}),
    tf.constant(new float[] {1, 0, 0, 1}));
// weighted: tp = [1, 1, 0], fp = [1, 0, 0], fn = [0, 0, 1], tn = [0, 1, 1]
// recall = [1, 1, 0], fpRate = [1, 0, 0]
// auc = (((1 + 1) / 2) * (1 - 0)) + (((1 + 0) / 2) * (0 - 0)) = 1.0
result = m.result();
System.out.println(result.data().getFloat()); // prints 1.0
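As a cross-check of the weighted arithmetic in the second call, this plain-Java sketch recomputes the interpolated ROC AUC from per-example weights at the thresholds [-1e-7, 0.5, 1 + 1e-7]. It does not call TensorFlow; WeightedAucCheck is a hypothetical name, and treating a prediction as positive when it is strictly greater than the threshold is an assumption that reproduces the counts in the comments above.

```java
/** Hypothetical cross-check of the weighted usage example; it does not call TensorFlow. */
final class WeightedAucCheck {
  /** Interpolated ROC AUC from weighted counts; a prediction is positive if it is > the threshold. */
  static double weightedRocAuc(float[] labels, float[] predictions, float[] weights, double[] thresholds) {
    int n = thresholds.length;
    double[] recall = new double[n];
    double[] fpRate = new double[n];
    for (int k = 0; k < n; k++) {
      double tp = 0, fp = 0, fn = 0, tn = 0;
      for (int i = 0; i < labels.length; i++) {
        boolean predictedPositive = predictions[i] > thresholds[k];
        boolean actualPositive = labels[i] != 0f;
        if (predictedPositive && actualPositive) tp += weights[i];
        else if (predictedPositive) fp += weights[i];
        else if (actualPositive) fn += weights[i];
        else tn += weights[i];
      }
      recall[k] = tp + fn == 0 ? 0 : tp / (tp + fn);
      fpRate[k] = fp + tn == 0 ? 0 : fp / (fp + tn);
    }
    double auc = 0;
    for (int k = 0; k < n - 1; k++) {
      // Trapezoid between adjacent thresholds.
      auc += (fpRate[k] - fpRate[k + 1]) * (recall[k] + recall[k + 1]) / 2;
    }
    return auc;
  }

  public static void main(String[] args) {
    float[] labels = {0, 0, 1, 1};
    float[] predictions = {0f, 0.5f, 0.3f, 0.9f};
    double[] thresholds = {-1e-7, 0.5, 1 + 1e-7};
    System.out.println(weightedRocAuc(labels, predictions, new float[] {1, 1, 1, 1}, thresholds));
    System.out.println(weightedRocAuc(labels, predictions, new float[] {1, 0, 0, 1}, thresholds));
  }
}
```

Zeroing the weights of the two mis-ranked examples (0.5 for a negative, 0.3 for a positive) removes the only ranking error, which is why the weighted AUC rises to 1.0.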
-
Field Summary
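Fields: DEFAULT_NAME, DEFAULT_NUM_THRESHOLDS, EPSILON, FALSE_NEGATIVES, FALSE_POSITIVES, TRUE_NEGATIVES, TRUE_POSITIVES
-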
Constructor Summary
Constructors
AUC(float[] thresholds, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for numThresholds, false for multiLabel, and null for labelWeights.
AUC(float[] thresholds, AUCCurve curve, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using null for numThresholds, AUCSummationMethod.INTERPOLATION for the summation method, false for multiLabel, and null for labelWeights.
AUC(float[] thresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, null for numThresholds, false for multiLabel, and null for labelWeights.
AUC(int numThresholds, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
AUC(int numThresholds, AUCCurve curve, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
AUC(int numThresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, null for thresholds, false for multiLabel, and null for labelWeights.
AUC(long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, DEFAULT_NUM_THRESHOLDS for numThresholds, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
AUC(String name, float[] thresholds, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using null for numThresholds, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, false for multiLabel, and null for labelWeights.
AUC(String name, float[] thresholds, AUCCurve curve, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using null for numThresholds, AUCSummationMethod.INTERPOLATION for the summation method, false for multiLabel, and null for labelWeights.
AUC(String name, float[] thresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using null for numThresholds, false for multiLabel, and null for labelWeights.
AUC(String name, int numThresholds, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
AUC(String name, int numThresholds, AUCCurve curve, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
AUC(String name, int numThresholds, AUCCurve curve, AUCSummationMethod summationMethod, float[] thresholds, boolean multiLabel, Operand<T> labelWeights, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric.
AUC(String name, int numThresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using null for thresholds, false for multiLabel, and null for labelWeights.
AUC(String name, long seed, Class<T> type)
  Creates an AUC (Area under the curve) metric using DEFAULT_NUM_THRESHOLDS for numThresholds, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
-
Method Summary
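Methods
getCurve()
getFalseNegatives()
getFalseNegativesName()
getFalsePositives()
getFalsePositivesName()
getLabelWeights()
getNumLabels()
int getNumThresholds()
getSummationMethod()
float[] getThresholds()
getTrueNegatives()
getTrueNegativesName()
getTruePositives()
getTruePositivesName()
protected void init(Ops tf)
  Initialize the TensorFlow Ops
boolean isMultiLabel()
resetStates(Ops tf)
  Resets any state variables to their initial values
result(Ops tf, Class<U> resultType)
  Gets the current result of the metric
void setNumLabels(Integer numLabels)
List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
  Creates a List of Operations to update the metric state based on labels and predictions.
Methods inherited from class BaseMetric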
callOnce, checkIsGraph, getName, getSeed, getTF, getVariableName, isInitialized, setInitialized, setName, setTF, updateState, updateState, updateStateList
-
Field Details
-
EPSILON
public static final float EPSILON
Default fuzz factor.
- See Also:
-
TRUE_POSITIVES
- See Also:
-
FALSE_POSITIVES
- See Also:
-
TRUE_NEGATIVES
- See Also:
-
FALSE_NEGATIVES
- See Also:
-
DEFAULT_NUM_THRESHOLDS
public static final int DEFAULT_NUM_THRESHOLDS
- See Also:
-
DEFAULT_NAME
- See Also:
-
-
Constructor Details
-
AUC
public AUC(long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, DEFAULT_NUM_THRESHOLDS for numThresholds, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
- Parameters:
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
-
AUC
public AUC(String name, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using DEFAULT_NUM_THRESHOLDS for numThresholds, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
- Parameters:
name - the name of the metric; if null, defaults to DEFAULT_NAME.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
-
AUC
public AUC(int numThresholds, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
- Parameters:
numThresholds - the number of thresholds to use when discretizing the ROC curve. Values must be > 1.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
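To illustrate what numThresholds controls, the framework-free sketch below builds numThresholds linearly spaced thresholds, bracketed by -EPSILON and 1 + EPSILON as in the usage example, and accumulates the four confusion-matrix counts at each one. ThresholdCountsSketch, linearThresholds and confusionCounts are hypothetical names; EPSILON = 1e-7 and the "positive when prediction > threshold" rule are assumptions chosen because they reproduce the counts in the usage example. This is not the TensorFlow Java implementation.

```java
import java.util.Arrays;

/** Hypothetical, framework-free sketch of threshold discretization; not the TensorFlow Java code. */
final class ThresholdCountsSketch {
  static final float EPSILON = 1e-7f; // assumed fuzz factor, matching the usage example

  /** numThresholds values: -EPSILON, evenly spaced interior points, then 1 + EPSILON. */
  static float[] linearThresholds(int numThresholds) {
    if (numThresholds < 2) {
      throw new IllegalArgumentException("numThresholds must be >= 2");
    }
    float[] t = new float[numThresholds];
    t[0] = -EPSILON;
    for (int i = 1; i < numThresholds - 1; i++) {
      t[i] = (float) i / (numThresholds - 1);
    }
    t[numThresholds - 1] = 1f + EPSILON;
    return t;
  }

  /** counts[k] = {tp, fp, fn, tn} at thresholds[k]; a prediction is positive if it is > the threshold. */
  static float[][] confusionCounts(float[] labels, float[] predictions, float[] thresholds) {
    float[][] counts = new float[thresholds.length][4];
    for (int k = 0; k < thresholds.length; k++) {
      for (int i = 0; i < labels.length; i++) {
        boolean predictedPositive = predictions[i] > thresholds[k];
        boolean actualPositive = labels[i] != 0f;
        int slot = predictedPositive ? (actualPositive ? 0 : 1) : (actualPositive ? 2 : 3);
        counts[k][slot] += 1f;
      }
    }
    return counts;
  }

  public static void main(String[] args) {
    float[] thresholds = linearThresholds(3);
    float[][] counts = confusionCounts(
        new float[] {0, 0, 1, 1}, new float[] {0f, 0.5f, 0.3f, 0.9f}, thresholds);
    System.out.println(Arrays.toString(thresholds));
    System.out.println(Arrays.deepToString(counts));
  }
}
```

With numThresholds = 3 there is a single interior threshold (0.5), which is why the usage example's curve has only three points.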
-
AUC
public AUC(float[] thresholds, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for numThresholds, false for multiLabel, and null for labelWeights.
- Parameters:
thresholds - optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1].
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
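The full constructor documents how explicit thresholds are handled: they must lie in [0, 1], they override numThresholds, and they are bracketed with -EPSILON below and 1 + EPSILON above. A minimal sketch of those rules follows, assuming EPSILON = 1e-7 as in the usage example; ThresholdValidationSketch, validate and bracket are hypothetical names, not part of the AUC API.

```java
import java.util.Arrays;

/** Hypothetical sketch of threshold validation and bracketing; not the AUC class's code. */
final class ThresholdValidationSketch {
  static final float EPSILON = 1e-7f; // assumed to match the usage example's 1e-7

  /** Throws IllegalArgumentException under the conditions listed for the full constructor. */
  static void validate(int numThresholds, float[] thresholds) {
    if (thresholds == null) {
      if (numThresholds < 2) {
        throw new IllegalArgumentException("numThresholds must be >= 2, got " + numThresholds);
      }
      return;
    }
    // Explicit thresholds take precedence, so numThresholds is not checked here.
    for (float t : thresholds) {
      if (t < 0f || t > 1f) {
        throw new IllegalArgumentException("threshold values must be in [0, 1], got " + t);
      }
    }
  }

  /** Surrounds explicit thresholds with -EPSILON below and 1 + EPSILON above. */
  static float[] bracket(float[] thresholds) {
    float[] result = new float[thresholds.length + 2];
    result[0] = -EPSILON;
    System.arraycopy(thresholds, 0, result, 1, thresholds.length);
    result[result.length - 1] = 1f + EPSILON;
    return result;
  }

  public static void main(String[] args) {
    float[] user = {0.25f, 0.5f, 0.75f};
    validate(200, user); // accepted: numThresholds is ignored when thresholds are given
    System.out.println(Arrays.toString(bracket(user)));
  }
}
```

The brackets guarantee that the first threshold classifies every prediction in [0, 1] as positive and the last classifies every one as negative, so the curve always reaches both corners.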
-
AUC
public AUC(String name, int numThresholds, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
- Parameters:
name - the name of the metric; if null, defaults to DEFAULT_NAME.
numThresholds - the number of thresholds to use when discretizing the ROC curve. Values must be > 1.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
-
AUC
public AUC(String name, float[] thresholds, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using null for numThresholds, AUCCurve.ROC for the curve type, AUCSummationMethod.INTERPOLATION for the summation method, false for multiLabel, and null for labelWeights.
- Parameters:
name - the name of the metric; if null, defaults to DEFAULT_NAME.
thresholds - optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1].
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
-
AUC
public AUC(String name, int numThresholds, AUCCurve curve, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
- Parameters:
name - the name of the metric; if null, defaults to DEFAULT_NAME.
numThresholds - the number of thresholds to use when discretizing the ROC curve. Values must be > 1.
curve - specifies the type of the curve to be computed: AUCCurve.ROC for the ROC curve or AUCCurve.PR for the Precision-Recall curve.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
-
AUC
public AUC(String name, float[] thresholds, AUCCurve curve, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using null for numThresholds, AUCSummationMethod.INTERPOLATION for the summation method, false for multiLabel, and null for labelWeights.
- Parameters:
name - the name of the metric; if null, defaults to DEFAULT_NAME.
thresholds - optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1].
curve - specifies the type of the curve to be computed: AUCCurve.ROC for the ROC curve or AUCCurve.PR for the Precision-Recall curve.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
-
AUC
public AUC(int numThresholds, AUCCurve curve, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, AUCSummationMethod.INTERPOLATION for the summation method, null for thresholds, false for multiLabel, and null for labelWeights.
- Parameters:
numThresholds - the number of thresholds to use when discretizing the ROC curve. Values must be > 1.
curve - specifies the type of the curve to be computed: AUCCurve.ROC for the ROC curve or AUCCurve.PR for the Precision-Recall curve.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
-
AUC
public AUC(float[] thresholds, AUCCurve curve, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using null for numThresholds, AUCSummationMethod.INTERPOLATION for the summation method, false for multiLabel, and null for labelWeights.
- Parameters:
thresholds - optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1].
curve - specifies the type of the curve to be computed: AUCCurve.ROC for the ROC curve or AUCCurve.PR for the Precision-Recall curve.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
-
AUC
public AUC(int numThresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, null for thresholds, false for multiLabel, and null for labelWeights.
- Parameters:
numThresholds - the number of thresholds to use when discretizing the ROC curve. Values must be > 1.
curve - specifies the type of the curve to be computed: AUCCurve.ROC for the ROC curve or AUCCurve.PR for the Precision-Recall curve.
summationMethod - specifies the Riemann summation method used.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
-
AUC
public AUC(float[] thresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using DEFAULT_NAME for the metric name, null for numThresholds, false for multiLabel, and null for labelWeights.
- Parameters:
thresholds - optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1].
curve - specifies the type of the curve to be computed: AUCCurve.ROC for the ROC curve or AUCCurve.PR for the Precision-Recall curve.
summationMethod - specifies the Riemann summation method used.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
-
AUC
public AUC(String name, int numThresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using null for thresholds, false for multiLabel, and null for labelWeights.
- Parameters:
name - the name of the metric; if null, defaults to DEFAULT_NAME.
numThresholds - the number of thresholds to use when discretizing the ROC curve. Values must be > 1.
curve - specifies the type of the curve to be computed: AUCCurve.ROC for the ROC curve or AUCCurve.PR for the Precision-Recall curve.
summationMethod - specifies the Riemann summation method used.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
-
AUC
public AUC(String name, float[] thresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric using null for numThresholds, false for multiLabel, and null for labelWeights.
- Parameters:
name - the name of the metric; if null, defaults to DEFAULT_NAME.
thresholds - optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1].
curve - specifies the type of the curve to be computed: AUCCurve.ROC for the ROC curve or AUCCurve.PR for the Precision-Recall curve.
summationMethod - specifies the Riemann summation method used.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
-
AUC
public AUC(String name, int numThresholds, AUCCurve curve, AUCSummationMethod summationMethod, float[] thresholds, boolean multiLabel, Operand<T> labelWeights, long seed, Class<T> type)
Creates an AUC (Area under the curve) metric.
- Parameters:
name - the name of the metric; if null, DEFAULT_NAME is used.
numThresholds - the number of thresholds to use when discretizing the ROC curve. This includes the bracketing 0 and 1 thresholds, so the value must be ≥ 2.
curve - specifies the type of the curve to be computed: AUCCurve.ROC for the ROC curve or AUCCurve.PR for the Precision-Recall curve.
summationMethod - specifies the Riemann summation method used.
thresholds - optional values to use as the thresholds for discretizing the curve. If set, the numThresholds parameter is ignored. Values should be in [0, 1]. The provided thresholds are automatically bracketed with (-EPSILON) below and (1 + EPSILON) above.
multiLabel - whether multilabel data should be treated as such, in which case AUC is computed separately for each label and then averaged across labels, or (when false) whether the data should be flattened into a single label before AUC computation. In the latter case, when multilabel data is passed to AUC, each label-prediction pair is treated as an individual data point. Should be set to false for multi-class data.
labelWeights - non-negative weights used to compute AUCs for multilabel data. When multiLabel is true, the weights are applied to the individual label AUCs when they are averaged to produce the multi-label AUC. When it is false, they are used to weight the individual label predictions in computing the confusion matrix on the flattened data.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the confusion matrix variables.
- Throws:
IllegalArgumentException - if numThresholds is less than 2 and thresholds is null, or if a threshold value is less than 0 or greater than 1.
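The labelWeights description says that, with multiLabel set to true, the weights are applied to the per-label AUCs when they are averaged. A weight-normalized mean is one natural reading of that, and the sketch below assumes it, using a plain mean when labelWeights is null. MultiLabelAucSketch and averageLabelAucs are hypothetical names, and returning 0 when every weight is 0 is also an assumption.

```java
/** Hypothetical sketch of combining per-label AUCs when multiLabel is true. */
final class MultiLabelAucSketch {
  /** Plain mean when labelWeights is null, otherwise a weighted mean normalized by the weight sum. */
  static double averageLabelAucs(double[] perLabelAuc, double[] labelWeights) {
    if (labelWeights == null) {
      double sum = 0;
      for (double auc : perLabelAuc) {
        sum += auc;
      }
      return sum / perLabelAuc.length;
    }
    if (labelWeights.length != perLabelAuc.length) {
      throw new IllegalArgumentException("need one weight per label");
    }
    double weightedSum = 0;
    double weightTotal = 0;
    for (int j = 0; j < perLabelAuc.length; j++) {
      if (labelWeights[j] < 0) {
        throw new IllegalArgumentException("label weights must be non-negative");
      }
      weightedSum += labelWeights[j] * perLabelAuc[j];
      weightTotal += labelWeights[j];
    }
    // Assumed convention: report 0 instead of NaN when every weight is 0.
    return weightTotal == 0 ? 0 : weightedSum / weightTotal;
  }

  public static void main(String[] args) {
    double[] perLabel = {0.75, 1.0};
    System.out.println(averageLabelAucs(perLabel, null));
    System.out.println(averageLabelAucs(perLabel, new double[] {1, 3}));
  }
}
```

Normalizing by the weight sum keeps the combined value in [0, 1] regardless of the scale of labelWeights.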
-
-
Method Details
-
init
protected void init(Ops tf)
Initialize the TensorFlow Ops.
- Specified by:
init in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
-
updateStateList
public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
Creates a List of Operations to update the metric state based on labels and predictions.
- Specified by:
updateStateList in interface Metric
- Overrides:
updateStateList in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
labels - shape (N, Cx, L1?), where N is the number of examples, Cx is zero or more class dimensions, and L1 is a potential extra dimension of size 1 that would be squeezed. Will be cast to T. If multiLabel is true or labelWeights != null, then Cx must be a single dimension.
predictions - the predictions, with shape (N, Cx, P1?). Will be cast to T.
sampleWeights - sample weights to be applied to values; may be null. Will be cast to T.
- Returns:
- a List of Operations to update the metric state
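When multiLabel is false, each label-prediction pair of an (N, C) batch becomes its own data point. The sketch below flattens such a batch row by row; giving each point the weight sampleWeights[i] * labelWeights[j] is an assumption about how per-example and per-label weights combine, and FlattenSketch is a hypothetical name rather than part of the library.

```java
import java.util.Arrays;

/** Hypothetical sketch of flattening an (N, C) batch when multiLabel is false. */
final class FlattenSketch {
  final float[] labels;
  final float[] predictions;
  final float[] weights;

  private FlattenSketch(float[] labels, float[] predictions, float[] weights) {
    this.labels = labels;
    this.predictions = predictions;
    this.weights = weights;
  }

  /** sampleWeights (length N) and labelWeights (length C) may each be null, meaning all ones. */
  static FlattenSketch flatten(float[][] labels, float[][] predictions,
      float[] sampleWeights, float[] labelWeights) {
    int n = labels.length;
    int c = n == 0 ? 0 : labels[0].length;
    float[] flatLabels = new float[n * c];
    float[] flatPredictions = new float[n * c];
    float[] flatWeights = new float[n * c];
    for (int i = 0; i < n; i++) {
      for (int j = 0; j < c; j++) {
        int idx = i * c + j; // row-major: example i, label j
        flatLabels[idx] = labels[i][j];
        flatPredictions[idx] = predictions[i][j];
        float sampleWeight = sampleWeights == null ? 1f : sampleWeights[i];
        float labelWeight = labelWeights == null ? 1f : labelWeights[j];
        flatWeights[idx] = sampleWeight * labelWeight; // assumed combination rule
      }
    }
    return new FlattenSketch(flatLabels, flatPredictions, flatWeights);
  }

  public static void main(String[] args) {
    FlattenSketch f = flatten(
        new float[][] {{0, 1}, {1, 0}},
        new float[][] {{0.1f, 0.8f}, {0.6f, 0.3f}},
        new float[] {2f, 1f},
        new float[] {1f, 0.5f});
    System.out.println(Arrays.toString(f.labels) + " " + Arrays.toString(f.weights));
  }
}
```

The flattened arrays can then feed a single confusion matrix, which is why this mode should not be used when per-label AUCs are wanted.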
-
result
Gets the current result of the metric.
- Type Parameters:
U - the data type for the result
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
resultType - the data type for the result
- Returns:
- the result, possibly with control dependencies
-
resetStates
Resets any state variables to their initial values.
-
getNumThresholds
public int getNumThresholds()
- Returns:
- the numThresholds
-
getCurve
- Returns:
- the curve
-
getSummationMethod
- Returns:
- the summationMethod
-
getThresholds
public float[] getThresholds()
- Returns:
- the thresholds
-
isMultiLabel
public boolean isMultiLabel()
- Returns:
- the multiLabel
-
getNumLabels
- Returns:
- the numLabels
-
setNumLabels
- Parameters:
numLabels - the numLabels to set
-
getLabelWeights
-
getTruePositives
-
getFalsePositives
-
getTrueNegatives
-
getFalseNegatives
-
getTruePositivesName
- Returns:
- the truePositivesName
-
getFalsePositivesName
- Returns:
- the falsePositivesName
-
getTrueNegativesName
- Returns:
- the trueNegativesName
-
getFalseNegativesName
- Returns:
- the falseNegativesName
-