Class PrecisionAtRecall<T extends TNumber>
java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.PrecisionAtRecall<T>
- Type Parameters:
T - the data type for the metric result
- All Implemented Interfaces:
Metric
Computes the best precision where recall is >= the specified value.
This metric creates four local variables, truePositives, trueNegatives, falsePositives, and falseNegatives, which are used to compute the precision at the given recall. The threshold for the given recall value is computed and used to evaluate the corresponding precision.
If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values.
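The computation described above can be made concrete with a minimal plain-Java sketch that uses arrays instead of TensorFlow Operands. It is not the library's implementation. The class name PrecisionAtRecallSketch, the evenly spaced thresholds with nudged end points, the rule that a prediction counts as positive when it is strictly greater than a threshold, and the 0/1 labels are all assumptions. The sketch reads "best precision where recall is >= the specified value" as the maximum precision over all thresholds whose recall meets the target, and it returns 0 when no threshold qualifies.

```java
import java.util.Arrays;

/**
 * Plain-Java sketch of "best precision where recall >= target".
 * NOT the TensorFlow Java implementation: it uses float arrays instead of
 * Operands and assumes 0/1 labels and evenly spaced thresholds.
 */
final class PrecisionAtRecallSketch {
    private final float targetRecall;
    private final float[] thresholds;
    // Weighted confusion-matrix counts, one slot per threshold
    // (mirrors the four accumulator variables described above).
    private final double[] tp, fp, tn, fn;

    PrecisionAtRecallSketch(float targetRecall, int numThresholds) {
        if (numThresholds < 2) {
            // The sketch's threshold spacing divides by (numThresholds - 1).
            throw new IllegalArgumentException("sketch needs numThresholds >= 2");
        }
        this.targetRecall = targetRecall;
        this.thresholds = new float[numThresholds];
        for (int i = 0; i < numThresholds; i++) {
            thresholds[i] = (float) i / (numThresholds - 1);
        }
        // Assumption: nudge the end points so predictions of exactly 0 or 1
        // land on a definite side of the first and last thresholds.
        thresholds[0] = -1e-7f;
        thresholds[numThresholds - 1] = 1f + 1e-7f;
        tp = new double[numThresholds];
        fp = new double[numThresholds];
        tn = new double[numThresholds];
        fn = new double[numThresholds];
    }

    /** Accumulates counts; a null weights array means every weight is 1. */
    void update(float[] labels, float[] predictions, float[] weights) {
        for (int i = 0; i < labels.length; i++) {
            double w = (weights == null) ? 1.0 : weights[i]; // weight 0 masks example i
            boolean positive = labels[i] > 0.5f;             // assumption: labels are 0 or 1
            for (int t = 0; t < thresholds.length; t++) {
                boolean predictedPositive = predictions[i] > thresholds[t];
                if (positive && predictedPositive) tp[t] += w;
                else if (positive) fn[t] += w;
                else if (predictedPositive) fp[t] += w;
                else tn[t] += w;
            }
        }
    }

    /** Max precision over thresholds whose recall meets the target; 0 if none does. */
    double result() {
        double best = 0.0;
        for (int t = 0; t < thresholds.length; t++) {
            double recall = safeDivide(tp[t], tp[t] + fn[t]);
            if (recall >= targetRecall) {
                best = Math.max(best, safeDivide(tp[t], tp[t] + fp[t]));
            }
        }
        return best;
    }

    /** Sets all accumulated counts back to zero. */
    void reset() {
        Arrays.fill(tp, 0.0);
        Arrays.fill(fp, 0.0);
        Arrays.fill(tn, 0.0);
        Arrays.fill(fn, 0.0);
    }

    private static double safeDivide(double num, double den) {
        return den == 0.0 ? 0.0 : num / den;
    }

    public static void main(String[] args) {
        PrecisionAtRecallSketch m = new PrecisionAtRecallSketch(0.5f, 5);
        m.update(new float[] {0, 0, 1, 1}, new float[] {0.1f, 0.6f, 0.7f, 0.8f}, null);
        System.out.println(m.result()); // 1.0
    }
}
```

With the data in main and thresholds 0, 0.25, 0.5, 0.75, 1, the threshold 0.75 keeps recall at exactly 0.5 with no false positives, so the result is 1.0. Raising the target recall to 0.6 rules that threshold out, and the best remaining precision is 2/3.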
-
Field Summary
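Fields
DEFAULT_NUM_THRESHOLDS, TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, FALSE_NEGATIVES, numThresholds, thresholds, truePositives, falsePositives, trueNegatives, falseNegatives
-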
Constructor Summary
Constructors
PrecisionAtRecall(float recall, int numThresholds, long seed, Class<T> type) - Creates a PrecisionAtRecall metric with a name of Class.getSimpleName().
PrecisionAtRecall(float recall, long seed, Class<T> type) - Creates a PrecisionAtRecall metric with a name of Class.getSimpleName() and DEFAULT_NUM_THRESHOLDS for the number of thresholds.
PrecisionAtRecall(String name, float recall, int numThresholds, long seed, Class<T> type) - Creates a PrecisionAtRecall metric.
PrecisionAtRecall(String name, float recall, long seed, Class<T> type) - Creates a PrecisionAtRecall metric with DEFAULT_NUM_THRESHOLDS for the number of thresholds.
-
Method Summary
Methods
getFalseNegatives() - Gets the falseNegatives variable
getFalseNegativesName() - Gets the falseNegatives variable name
getFalsePositives() - Gets the falsePositives variable
getFalsePositivesName() - Gets the falsePositives variable name
int getNumThresholds() - Gets the numThresholds
float getRecall() - Gets the recall value
float[] getThresholds() - Gets the thresholds
getTrueNegatives() - Gets the trueNegatives variable
getTrueNegativesName() - Gets the trueNegatives variable name
getTruePositives() - Gets the truePositives variable
getTruePositivesName() - Gets the truePositives variable name
getType() - Gets the internalType
protected void init(Ops tf) - Initializes the TensorFlow Ops
resetStates(Ops tf) - Resets any state variables to their initial values
result(Ops tf, Class<U> resultType) - Gets the current result of the metric
updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights) - Accumulates confusion matrix statistics.
Methods inherited from class BaseMetric
callOnce, checkIsGraph, getName, getSeed, getTF, getVariableName, isInitialized, setInitialized, setName, setTF, updateState, updateState, updateStateList
-
Field Details
-
DEFAULT_NUM_THRESHOLDS
public static final int DEFAULT_NUM_THRESHOLDS
The default number of thresholds (200).
-
TRUE_POSITIVES
-
FALSE_POSITIVES
-
TRUE_NEGATIVES
-
FALSE_NEGATIVES
-
numThresholds
protected final int numThresholds
-
thresholds
protected final float[] thresholds
-
truePositives
-
falsePositives
-
trueNegatives
-
falseNegatives
-
-
Constructor Details
-
PrecisionAtRecall
PrecisionAtRecall(float recall, long seed, Class<T> type)
Creates a PrecisionAtRecall metric with a name of Class.getSimpleName() and DEFAULT_NUM_THRESHOLDS for the number of thresholds.
- Parameters:
recall - the recall. A scalar value in the range [0, 1].
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the variables.
- Throws:
IllegalArgumentException - if numThresholds <= 0 or if recall is not in the range [0, 1].
-
PrecisionAtRecall
PrecisionAtRecall(String name, float recall, long seed, Class<T> type)
Creates a PrecisionAtRecall metric with DEFAULT_NUM_THRESHOLDS for the number of thresholds.
- Parameters:
name - the name of the metric; if null, defaults to Class.getSimpleName().
recall - the recall. A scalar value in the range [0, 1].
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the variables.
- Throws:
IllegalArgumentException - if numThresholds <= 0 or if recall is not in the range [0, 1].
-
PrecisionAtRecall
PrecisionAtRecall(float recall, int numThresholds, long seed, Class<T> type)
Creates a PrecisionAtRecall metric with a name of Class.getSimpleName().
- Parameters:
recall - the recall. A scalar value in the range [0, 1].
numThresholds - the number of thresholds to use for matching the given recall. The constructors that omit this parameter use DEFAULT_NUM_THRESHOLDS (200).
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the variables.
- Throws:
IllegalArgumentException - if numThresholds <= 0 or if recall is not in the range [0, 1].
-
PrecisionAtRecall
PrecisionAtRecall(String name, float recall, int numThresholds, long seed, Class<T> type)
Creates a PrecisionAtRecall metric.
- Parameters:
name - the name of the metric; if null, defaults to Class.getSimpleName().
recall - the recall. A scalar value in the range [0, 1].
numThresholds - the number of thresholds to use for matching the given recall. The constructors that omit this parameter use DEFAULT_NUM_THRESHOLDS (200).
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the variables.
- Throws:
IllegalArgumentException - if numThresholds <= 0 or if recall is not in the range [0, 1].
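The four constructors differ only in which arguments they default. One common Java way to express that, shown here as a hypothetical sketch rather than the TensorFlow Java source, is to have the shorter overloads delegate through this(...) to a single canonical constructor that applies the defaults (a null name falls back to getSimpleName(), an omitted numThresholds becomes DEFAULT_NUM_THRESHOLDS, 200) and performs the documented IllegalArgumentException checks. The class PrecisionAtRecallConfig is invented for illustration, and the Class<T> type parameter is left out.

```java
/**
 * Hypothetical stand-in showing how the four documented constructors can delegate
 * to one canonical constructor. Not the TensorFlow Java source; the Class<T> type
 * parameter is omitted.
 */
final class PrecisionAtRecallConfig {
    static final int DEFAULT_NUM_THRESHOLDS = 200; // "Defaults to 200" per the constructor docs

    final String name;
    final float recall;
    final int numThresholds;
    final long seed;

    PrecisionAtRecallConfig(float recall, long seed) {
        this(null, recall, DEFAULT_NUM_THRESHOLDS, seed);
    }

    PrecisionAtRecallConfig(String name, float recall, long seed) {
        this(name, recall, DEFAULT_NUM_THRESHOLDS, seed);
    }

    PrecisionAtRecallConfig(float recall, int numThresholds, long seed) {
        this(null, recall, numThresholds, seed);
    }

    /** Canonical constructor: applies defaults and validates arguments. */
    PrecisionAtRecallConfig(String name, float recall, int numThresholds, long seed) {
        if (numThresholds <= 0) {
            throw new IllegalArgumentException("numThresholds must be > 0, got " + numThresholds);
        }
        // Written as a negated range check so that NaN is rejected too.
        if (!(recall >= 0f && recall <= 1f)) {
            throw new IllegalArgumentException("recall must be in [0, 1], got " + recall);
        }
        this.name = (name != null) ? name : getClass().getSimpleName();
        this.recall = recall;
        this.numThresholds = numThresholds;
        this.seed = seed;
    }

    public static void main(String[] args) {
        PrecisionAtRecallConfig c = new PrecisionAtRecallConfig(0.8f, 42L);
        System.out.println(c.name + " " + c.numThresholds); // PrecisionAtRecallConfig 200
    }
}
```

Keeping all validation in the canonical constructor means every overload enforces the same rules without duplicating the checks.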
-
-
Method Details
-
result
Gets the current result of the metric.
- Type Parameters:
U - the data type for the result
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
resultType - the data type for the result
- Returns:
- the result, possibly with control dependencies
-
getRecall
public float getRecall()
Gets the recall value.
- Returns:
- the recall value
-
init
Initializes the TensorFlow Ops.
- Specified by:
init in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
-
updateStateList
public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
Accumulates confusion matrix statistics.
- Specified by:
updateStateList in interface Metric
- Overrides:
updateStateList in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
labels - the ground truth values.
predictions - the predictions.
sampleWeights - optional weighting of each example. Defaults to 1. Rank is either 0 or the same rank as labels, and must be broadcastable to labels.
- Returns:
- a List of Operations to update the metric state.
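The sampleWeights rules (a default of 1, rank 0 or the same rank as labels, broadcastable to labels) can be illustrated for the 1-D case with a hypothetical helper. SampleWeightsSketch is not part of TensorFlow Java; it handles only plain arrays and applies the usual broadcasting rule that a size-1 dimension stretches to match, whereas the real metric works on Operands of arbitrary rank.

```java
import java.util.Arrays;

/**
 * Hypothetical helper: resolves sampleWeights for a batch of n examples with 1-D labels.
 * null means weight 1, a rank-0 scalar broadcasts, and a rank-1 array must be
 * broadcastable to the labels (length 1 or length n).
 */
final class SampleWeightsSketch {
    /** Rank-0 case: one weight applied to every example. */
    static float[] broadcast(float scalarWeight, int n) {
        float[] out = new float[n];
        Arrays.fill(out, scalarWeight);
        return out;
    }

    /** Rank-1 case (same rank as 1-D labels); null defaults every weight to 1. */
    static float[] broadcast(float[] weights, int n) {
        if (weights == null) {
            return broadcast(1f, n);
        }
        if (weights.length == 1) {
            return broadcast(weights[0], n); // a size-1 dimension stretches to n
        }
        if (weights.length == n) {
            return weights.clone();
        }
        throw new IllegalArgumentException(
            "sampleWeights of length " + weights.length + " is not broadcastable to " + n + " labels");
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(broadcast(null, 3)));                  // [1.0, 1.0, 1.0]
        System.out.println(Arrays.toString(broadcast(new float[] {1, 0, 1}, 3))); // [1.0, 0.0, 1.0]
    }
}
```

A resolved weight of 0 contributes nothing to any confusion-matrix count, which is how a 0 in sampleWeights masks an example.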
-
resetStates
-
getTruePositives
-
getFalsePositives
-
getTrueNegatives
-
getFalseNegatives
-
getNumThresholds
public int getNumThresholds()
Gets the numThresholds.
- Returns:
- the numThresholds
-
getThresholds
public float[] getThresholds()
Gets the thresholds.
- Returns:
- the thresholds
-
getTruePositivesName
Gets the truePositives variable name.
- Returns:
- the truePositivesName
-
getFalsePositivesName
Gets the falsePositives variable name.
- Returns:
- the falsePositivesName
-
getTrueNegativesName
Gets the trueNegatives variable name.
- Returns:
- the trueNegativesName
-
getFalseNegativesName
Gets the falseNegatives variable name.
- Returns:
- the falseNegativesName
-
getType
-
getInternalType
-