Class BaseMetric
java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
- All Implemented Interfaces:
Metric
- Direct Known Subclasses:
AUC, FalseNegatives, FalsePositives, Mean, MeanIoU, MeanTensor, Precision, PrecisionAtRecall, Recall, RecallAtPrecision, SensitivityAtSpecificity, SpecificityAtSensitivity, Sum, TrueNegatives, TruePositives
-
Constructor Summary
Constructors
- protected BaseMetric(long seed): Creates a Metric with a name of Class.getSimpleName().
- protected BaseMetric(String name, long seed): Creates a Metric.
-
Method Summary
Methods
- final <T extends TNumber> Operand<T> callOnce(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights, Class<T> type): Calls update state once, followed by a call to get the result.
- protected void checkIsGraph(Ops tf): Checks if the TensorFlow Ops encapsulates a Graph environment.
- String getName(): The name for this metric.
- long getSeed(): Gets the random number generator seed value.
- protected Ops getTF(): Gets the TensorFlow Ops for this metric.
- protected String getVariableName(String varName): Gets a formatted name for a variable, in the form name + "_" + varName.
- protected abstract void init(Ops tf): Initializes the TensorFlow Ops.
- boolean isInitialized(): Checks whether the Metric is initialized or not.
- protected void setInitialized(boolean initialized): Sets the initialized indicator.
- void setName(String name): Sets the metric name.
- protected void setTF(Ops tf): Sets the TensorFlow Ops for this metric.
- final Op updateState(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights): Creates a NoOp Operation with control dependencies to update the metric state.
- final Op updateState(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights): Creates a NoOp Operation with control dependencies to update the metric state.
- List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights): Creates a List of Operations to update the metric state based on input values.
- List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights): Creates a List of Operations to update the metric state based on labels and predictions.
Methods inherited from class Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Metric
resetStates, result
-
Constructor Details
-
BaseMetric
protected BaseMetric(long seed)
Creates a Metric with a name of Class.getSimpleName().
- Parameters:
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
-
BaseMetric
protected BaseMetric(String name, long seed)
Creates a Metric.
- Parameters:
name - the name for this metric. If null, the name defaults to Class.getSimpleName().
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
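To make the default-name rule concrete, here is a minimal plain-Java sketch, not the TensorFlow Java source. The classes NamedMetricSketch and MySum are hypothetical; they only mirror the two documented constructors, where a null name falls back to Class.getSimpleName() of the concrete subclass.

```java
// Hypothetical sketch (not the TensorFlow Java source) of the documented naming rule:
// a null name falls back to the concrete class's simple name.
abstract class NamedMetricSketch {
    private final String name;
    private final long seed;

    // Mirrors BaseMetric(long seed): no explicit name is supplied.
    protected NamedMetricSketch(long seed) {
        this(null, seed);
    }

    // Mirrors BaseMetric(String name, long seed).
    protected NamedMetricSketch(String name, long seed) {
        // getClass() returns the runtime subclass, so the default reflects the concrete metric.
        this.name = (name != null) ? name : getClass().getSimpleName();
        this.seed = seed;
    }

    String getName() {
        return name;
    }

    long getSeed() {
        return seed;
    }
}

// Hypothetical concrete metric used only to show the default.
class MySum extends NamedMetricSketch {
    MySum(long seed) {
        super(seed);
    }

    MySum(String name, long seed) {
        super(name, seed);
    }
}
```

With this sketch, new MySum(42L).getName() returns "MySum", while new MySum("train_sum", 42L).getName() returns "train_sum". Calling getClass() rather than hard-coding the base class is what makes the default reflect the subclass.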
-
-
Method Details
-
updateStateList
public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
Creates a List of Operations to update the metric state based on input values. This is an empty implementation that should be overridden by subclasses, if needed.
- Specified by:
updateStateList in interface Metric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
values - the inputs to be passed to update state; must not be null.
sampleWeights - sample weights to be applied to the values; may be null.
- Returns:
- a List of Operations to update the metric state
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not have a Graph environment.
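The override pattern for updateStateList can be sketched in plain Java, assuming a simplified model in which a Runnable stands in for an Op and a double[] for an Operand. ListMetricSketch and WeightedSumSketch are hypothetical names. The base method returns an empty list as documented; the subclass builds one deferred update that only runs when the caller executes the returned list, much as graph ops run only when the graph is executed.

```java
import java.util.Collections;
import java.util.List;
import java.util.Objects;

// Hypothetical plain-Java model, not TensorFlow Java: a Runnable stands in for an Op,
// and double[] stands in for an Operand. Building the list does not run the updates.
abstract class ListMetricSketch {
    // Mirrors the documented default: an empty list, to be overridden if needed.
    List<Runnable> updateStateList(double[] values, double[] sampleWeights) {
        return Collections.emptyList();
    }
}

// Hypothetical subclass that accumulates a weighted sum of the values.
class WeightedSumSketch extends ListMetricSketch {
    private double total;

    @Override
    List<Runnable> updateStateList(double[] values, double[] sampleWeights) {
        Objects.requireNonNull(values, "values may not be null"); // sampleWeights may be null
        Runnable addToTotal = () -> {
            for (int i = 0; i < values.length; i++) {
                double weight = (sampleWeights == null) ? 1.0 : sampleWeights[i];
                total += weight * values[i];
            }
        };
        return List.of(addToTotal);
    }

    double total() {
        return total;
    }
}
```

Here a null sampleWeights is treated as a weight of 1.0 for every value; that default is an assumption of the sketch.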
-
updateStateList
public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
Creates a List of Operations to update the metric state based on labels and predictions. This is an empty implementation that should be overridden by subclasses, if needed.
- Specified by:
updateStateList in interface Metric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
labels - the labels.
predictions - the predictions.
sampleWeights - sample weights to be applied to the metric values; may be null.
- Returns:
- a List of Operations to update the metric state
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not have a Graph environment.
-
updateState
public final Op updateState(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
Creates a NoOp Operation with control dependencies to update the metric state.
- Specified by:
updateState in interface Metric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
values - the inputs to be passed to update state; must not be null.
sampleWeights - sample weights to be applied to the values; may be null.
- Returns:
- the Operation to update the metric state
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not have a Graph environment.
-
updateState
public final Op updateState(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
Creates a NoOp Operation with control dependencies to update the metric state.
- Specified by:
updateState in interface Metric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
labels - the labels.
predictions - the predictions.
sampleWeights - sample weights to be applied to the metric values; may be null.
- Returns:
- the Operation to update the metric state
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not have a Graph environment.
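Both updateState overloads are final and produce a single NoOp whose control dependencies are, presumably, the ops returned by the matching updateStateList. One way to picture that contract is the hypothetical plain-Java model below (TemplateMetricSketch and CountSketch are illustrative names, and Runnables stand in for Ops): the final method collects whatever updateStateList returns and wraps it in one grouping operation that does nothing except trigger those dependencies.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical plain-Java model of the documented template, not TensorFlow Java:
// Runnables stand in for Ops, and the returned Runnable plays the role of the NoOp
// whose only job is to trigger its control dependencies.
abstract class TemplateMetricSketch {
    // Subclasses supply the actual update ops here (default: none).
    protected List<Runnable> updateStateList(double[] values) {
        return List.of();
    }

    // final, like BaseMetric.updateState: subclasses customize updateStateList instead.
    public final Runnable updateState(double[] values) {
        List<Runnable> controlDependencies = new ArrayList<>(updateStateList(values));
        // The "no-op" does no work of its own; running it runs its dependencies.
        return () -> controlDependencies.forEach(Runnable::run);
    }
}

// Hypothetical subclass that counts how many values it has seen.
class CountSketch extends TemplateMetricSketch {
    private long count;

    @Override
    protected List<Runnable> updateStateList(double[] values) {
        return List.of(() -> count += values.length);
    }

    long count() {
        return count;
    }
}
```

Because updateState is final in this sketch, as it is in BaseMetric, a subclass customizes behavior only through updateStateList, and callers always receive a single handle to execute.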
-
callOnce
public final <T extends TNumber> Operand<T> callOnce(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights, Class<T> type)
Calls update state once, followed by a call to get the result.
- Specified by:
callOnce in interface Metric
- Type Parameters:
T - the data type for the metric result
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
values - the inputs to be passed to update state; must not be null.
sampleWeights - sample weights to be applied to the values; may be null.
type - the data type for the result.
- Returns:
- the result, possibly with control dependencies
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not have a Graph environment.
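The ordering that callOnce describes (update once, then read the result) can be modeled as follows. This is a hypothetical plain-Java sketch, not TensorFlow Java: CallOnceSketch and MeanSketch are illustrative names, a Supplier<Double> stands in for the returned Operand<T>, and the result() method here is only a stand-in for the Metric result method.

```java
import java.util.List;
import java.util.function.Supplier;

// Hypothetical plain-Java model of callOnce, not TensorFlow Java: a Supplier stands in
// for the returned Operand<T>, and evaluating it runs the update ops before reading the result.
abstract class CallOnceSketch {
    protected abstract List<Runnable> updateStateList(double[] values);

    protected abstract double result();

    public final Supplier<Double> callOnce(double[] values) {
        List<Runnable> controlOps = updateStateList(values);
        return () -> {
            controlOps.forEach(Runnable::run); // control dependencies first
            return result();                   // then read the metric result
        };
    }
}

// Hypothetical running mean over every value passed in so far.
class MeanSketch extends CallOnceSketch {
    private double total;
    private long count;

    @Override
    protected List<Runnable> updateStateList(double[] values) {
        return List.of(() -> {
            for (double v : values) {
                total += v;
            }
            count += values.length;
        });
    }

    @Override
    protected double result() {
        return count == 0 ? 0.0 : total / count;
    }
}
```

In this sketch, new MeanSketch().callOnce(new double[] {1.0, 2.0, 3.0}).get() yields 2.0; a second call on the same instance with {5.0} yields 2.75, because the state accumulates across calls.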
-
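getVariableName
protected String getVariableName(String varName)
Gets a formatted name for a variable, in the form name + "_" + varName.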
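As a small illustration of the naming rule, the hypothetical helper below (VariableNameSketch is not part of the library) simply concatenates the two parts. Prefixing with the metric name keeps variables from different metric instances apart; for example, a metric named "val_mean" with a variable called "count" would get "val_mean_count" (both names are made up for the example).

```java
// Hypothetical helper (not part of TensorFlow Java) applying the documented rule
// name + "_" + varName, so each metric's state variables get distinct names.
final class VariableNameSketch {
    private VariableNameSketch() {
    }

    static String variableName(String metricName, String varName) {
        return metricName + "_" + varName;
    }
}
```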
-
getName
public String getName()
Gets the name of this metric, which defaults to Class.getSimpleName() when no name is given.
- Returns:
- the name of this metric
-
setName
public void setName(String name)
Sets the metric name.
-
getSeed
public long getSeed()
Gets the random number generator seed value.
- Returns:
- the random number generator seed value
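The determinism that the seed provides can be illustrated by analogy with java.util.Random, which the JDK specifies to produce the same sequence for the same seed. SeededFillSketch is a hypothetical helper and does not use TensorFlow initializers; it only shows the "same seed, same shape, same values" property described for the constructors.

```java
import java.util.Random;

// Analogy only: java.util.Random, not a TensorFlow initializer. The same seed and
// size always yield the same values, which is the property the seed parameter provides.
final class SeededFillSketch {
    private SeededFillSketch() {
    }

    static double[] fill(long seed, int size) {
        Random rng = new Random(seed);
        double[] out = new double[size];
        for (int i = 0; i < size; i++) {
            out[i] = rng.nextGaussian();
        }
        return out;
    }
}
```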
-
init
protected abstract void init(Ops tf)
Initializes the TensorFlow Ops.
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
- Throws:
IllegalArgumentException - if the TensorFlow Ops does not have a Graph environment.
-
getTF
protected Ops getTF()
Gets the TensorFlow Ops for this metric.
- Returns:
- the TensorFlow Ops for this metric.
-
setTF
protected void setTF(Ops tf)
Sets the TensorFlow Ops for this metric. This should be called from the init(Ops) implementation.
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not have a Graph environment.
-
isInitialized
public boolean isInitialized()
Checks whether the Metric is initialized or not.
- Returns:
- true if the Metric has been initialized.
-
setInitialized
protected void setInitialized(boolean initialized)
Sets the initialized indicator.
- Parameters:
initialized - the initialized indicator
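The init, setTF, isInitialized and setInitialized hooks suggest a lazy-initialization pattern. The plain-Java sketch below is one assumption about how a subclass might use them, not the library's implementation: FakeOps is a hypothetical stand-in for Ops that only records whether it wraps a Graph, and LazySumSketch creates its state on the first init call and skips that step afterwards.

```java
// Hypothetical plain-Java model, not TensorFlow Java. FakeOps stands in for
// org.tensorflow.op.Ops and only records whether it wraps a Graph environment.
final class FakeOps {
    final boolean graph;

    FakeOps(boolean graph) {
        this.graph = graph;
    }
}

abstract class LazyMetricSketch {
    private FakeOps tf;
    private boolean initialized;

    // Mirrors init(Ops): each subclass decides how to create its state.
    protected abstract void init(FakeOps tf);

    protected FakeOps getTF() {
        return tf;
    }

    // Mirrors setTF(Ops): rejects anything that is not a Graph environment.
    protected void setTF(FakeOps tf) {
        if (!tf.graph) {
            throw new IllegalArgumentException("The Ops must encapsulate a Graph environment");
        }
        this.tf = tf;
    }

    public boolean isInitialized() {
        return initialized;
    }

    protected void setInitialized(boolean initialized) {
        this.initialized = initialized;
    }
}

// Hypothetical subclass that creates its state only on the first init call.
class LazySumSketch extends LazyMetricSketch {
    int variablesCreated; // counts state creation, standing in for creating variables

    @Override
    protected void init(FakeOps tf) {
        if (isInitialized()) {
            return; // state already exists; do not create it twice
        }
        setTF(tf);
        variablesCreated++;
        setInitialized(true);
    }
}
```

Guarding on isInitialized() keeps repeated init calls idempotent, which matters if init is invoked from every update call.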
-
checkIsGraph
protected void checkIsGraph(Ops tf)
Checks if the TensorFlow Ops encapsulates a Graph environment.
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
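The guard itself reduces to a single check. In this hypothetical sketch an enum, ExecEnvSketch, replaces the execution environment that the real method inspects through the Ops scope, and GraphCheckSketch is an illustrative name; only the documented behavior (throw IllegalArgumentException unless the environment is a Graph) is reproduced.

```java
// Hypothetical stand-in for the execution environment behind an Ops instance.
enum ExecEnvSketch { GRAPH, EAGER }

final class GraphCheckSketch {
    private GraphCheckSketch() {
    }

    // Mirrors the documented contract of checkIsGraph: only a Graph environment passes.
    static void checkIsGraph(ExecEnvSketch env) {
        if (env != ExecEnvSketch.GRAPH) {
            throw new IllegalArgumentException(
                    "The TensorFlow Ops must encapsulate a Graph environment, but found " + env);
        }
    }
}
```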
-