Class Mean<T extends TNumber>
java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.Mean<T>
- Type Parameters:
T - The data type for the metric result
- All Implemented Interfaces:
Metric
- Direct Known Subclasses:
Accuracy, BinaryAccuracy, BinaryCrossentropy, CategoricalAccuracy, CategoricalCrossentropy, CategoricalHinge, CosineSimilarity, Hinge, KLDivergence, LogCoshError, MeanAbsoluteError, MeanAbsolutePercentageError, MeanRelativeError, MeanSquaredError, MeanSquaredLogarithmicError, Poisson, RootMeanSquaredError, SparseCategoricalAccuracy, SparseCategoricalCrossentropy, SparseTopKCategoricalAccuracy, SquaredHinge, TopKCategoricalAccuracy
A metric that implements a weighted mean (MetricReduction.WEIGHTED_MEAN).
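A weighted mean is kept as two running values: a weighted total and a (possibly weighted) count, with the result reported as their ratio. The sketch below shows that bookkeeping in plain Java. `WeightedMeanSketch` and its methods are hypothetical illustrations, not the TensorFlow Java API; plain `double` fields stand in for the metric's total and count variables. Returning 0 when no weight has been accumulated, and zero as the initial state, are assumptions of this sketch.

```java
/**
 * Hypothetical plain-Java sketch of a weighted-mean reduction (not the
 * TensorFlow Java API). Two running values stand in for the metric's
 * total and count variables.
 */
final class WeightedMeanSketch {
  private double total = 0.0; // running sum of weight * value
  private double count = 0.0; // running sum of weights

  /** Folds one batch of values and matching weights into the running state. */
  void update(double[] values, double[] weights) {
    for (int i = 0; i < values.length; i++) {
      total += weights[i] * values[i];
      count += weights[i];
    }
  }

  /** Weighted mean so far; returns 0 when no weight has been seen (assumed convention). */
  double result() {
    return count == 0.0 ? 0.0 : total / count;
  }

  /** Returns both running values to zero (assumed initial value). */
  void reset() {
    total = 0.0;
    count = 0.0;
  }

  public static void main(String[] args) {
    WeightedMeanSketch m = new WeightedMeanSketch();
    m.update(new double[] {1, 2, 3}, new double[] {1, 1, 1}); // total 6, count 3
    m.update(new double[] {10}, new double[] {3});             // total 36, count 6
    System.out.println(m.result());                            // prints 6.0
    m.reset();
    System.out.println(m.result());                            // prints 0.0
  }
}
```

The state persists across calls to `update`, so the result reflects every batch seen since the last reset rather than only the most recent one.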
Field Summary
Fields: TOTAL, COUNT, reduction, total, count
Constructor Summary
Constructors:
- Mean(long seed, Class<T> type): Creates a Reducible Metric with a metric reduction of MetricReduction.SUM, using Class.getSimpleName() for the metric name.
- Mean(String name, long seed, Class<T> type): Creates a Reducible Metric with a metric reduction of MetricReduction.SUM.
Method Summary
- getCount(): Gets the count variable.
- getInternalType(): Gets the type for the variables.
- getTotal(): Gets the total variable.
- protected void init(Ops tf): Initializes the TensorFlow Ops.
- resetStates(Ops tf): Resets any state variables to their initial values.
- result(Ops tf, Class<U> type): Gets the current result of the metric.
- List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights): Updates the metric variables based on the inputs.

Methods inherited from class BaseMetric
- callOnce(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights, Class<T> type): Calls update state once, followed by a call to get the result.
- protected void checkIsGraph(Ops tf): Checks if the TensorFlow Ops encapsulates a Graph environment.
- getName(): The name for this metric.
- long getSeed(): Gets the random number generator seed value.
- protected Ops getTF(): Gets the TensorFlow Ops for this metric.
- protected String getVariableName(String varName): Gets a formatted name for a variable, in the form BaseMetric.name + "_" + varName.
- boolean isInitialized: Checks whether the Metric is initialized or not.
- protected void setInitialized(boolean initialized): Sets the initialized indicator.
- void setName: Sets the metric name.
- protected void setTF: Sets the TensorFlow Ops for this metric.
- final Op updateState: Creates a NoOp Operation with control dependencies to update the metric state.
- final Op updateState(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights): Creates a NoOp Operation with control dependencies to update the metric state.
- updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights): Creates a List of Operations to update the metric state based on labels and predictions.
Field Details
- TOTAL
- COUNT
- reduction
- total
- count: the variable that holds the count of the metric values. For MetricReduction.WEIGHTED_MEAN, this count may be weighted.
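The weighted reduction can be summarized by the equations below. This is a sketch that assumes values $v_i$ with sample weights $w_i$ (taking $w_i = 1$ when no sample weights are given, as described for updateStateList) and assumes that total accumulates weighted values; the library's exact computation is not shown on this page.

```latex
\begin{aligned}
\text{total} &= \sum_{i=1}^{n} w_i \, v_i, \\
\text{count} &= \sum_{i=1}^{n} w_i, \\
\text{result} &= \frac{\text{total}}{\text{count}}.
\end{aligned}
```

For example, values (1, 2, 3, 10) with weights (1, 1, 1, 3) give total = 36, count = 6, and result = 6, whereas the unweighted mean of the same values is 4.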
Constructor Details
Mean
Creates a Reducible Metric with a metric reduction of MetricReduction.SUM, using Class.getSimpleName() for the metric name.
- Parameters:
  seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
  type - the type for the variables and result
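The seed guarantee means initialization is reproducible: the same seed and shape yield the same values. The sketch below demonstrates that property with `java.util.Random` as a stand-in generator. `SeededInitializerSketch.uniform` is a hypothetical helper, not a TensorFlow Java initializer, and it flattens a rows x cols shape into a row-major array.

```java
import java.util.Arrays;
import java.util.Random;

/**
 * Hypothetical sketch of a seeded random initializer (not the TensorFlow
 * Java initializer API): the same seed and shape always yield the same values.
 */
final class SeededInitializerSketch {
  /** Fills a rows x cols matrix, flattened row-major, from a Random seeded with seed. */
  static double[] uniform(long seed, int rows, int cols) {
    Random rng = new Random(seed); // fixed seed => fixed sequence
    double[] out = new double[rows * cols];
    for (int i = 0; i < out.length; i++) {
      out[i] = rng.nextDouble(); // uniform in [0, 1)
    }
    return out;
  }

  public static void main(String[] args) {
    double[] first = uniform(1001L, 2, 3);
    double[] second = uniform(1001L, 2, 3);
    System.out.println(Arrays.equals(first, second)); // prints true
  }
}
```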
Mean
Creates a Reducible Metric with a metric reduction of MetricReduction.SUM.
- Parameters:
  name - the name for this metric. If null, the name defaults to Class.getSimpleName().
  seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
  type - the type for the variables and result
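The naming rules on this page, where a null name falls back to the simple class name and variables are named name + "_" + varName, can be sketched in plain Java as follows. `MetricNamingSketch` is hypothetical; the string values given to `TOTAL` and `COUNT` are assumptions for illustration, and calling `getSimpleName()` on the runtime class is this sketch's reading of "Class.getSimpleName()".

```java
/**
 * Hypothetical sketch of the documented naming rules (not the TensorFlow
 * Java API): a null name falls back to the simple class name, and variable
 * names take the form name + "_" + varName.
 */
final class MetricNamingSketch {
  static final String TOTAL = "total"; // assumed value, for illustration only
  static final String COUNT = "count"; // assumed value, for illustration only

  private final String name;

  MetricNamingSketch(String name) {
    // Fall back to the simple name of the runtime class when no name is given.
    this.name = (name != null) ? name : getClass().getSimpleName();
  }

  /** Mirrors the documented form: name + "_" + varName. */
  String getVariableName(String varName) {
    return name + "_" + varName;
  }

  public static void main(String[] args) {
    System.out.println(new MetricNamingSketch(null).getVariableName(TOTAL));
    // prints MetricNamingSketch_total
    System.out.println(new MetricNamingSketch("loss_mean").getVariableName(COUNT));
    // prints loss_mean_count
  }
}
```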
Method Details
init
Initializes the TensorFlow Ops.
- Specified by:
  init in class BaseMetric
- Parameters:
  tf - the TensorFlow Ops encapsulating a Graph environment.
resetStates
Resets any state variables to their initial values.
updateStateList
public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
Updates the metric variables based on the inputs. The values argument is required; sampleWeights is an optional additional input.
- Specified by:
  updateStateList in interface Metric
- Overrides:
  updateStateList in class BaseMetric
- Parameters:
  tf - the TensorFlow Ops encapsulating a Graph environment.
  values - the inputs to be passed to update state; must not be null.
  sampleWeights - sample weights to be applied to values; defaults to 1 if null.
- Returns:
  the result with a control dependency on the update state Operands
- Throws:
  IllegalArgumentException - if values is null
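The input contract above (values is mandatory, null sample weights mean a weight of 1 for every value) can be sketched as a small validation helper. `UpdateInputsSketch.resolveWeights` is hypothetical, and plain `double[]` arrays stand in for `Operand` inputs; this is not how the TensorFlow Java method is implemented, only an illustration of the documented behavior.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of the documented updateStateList input contract
 * (not the TensorFlow Java API): values is required, and missing sample
 * weights default to 1.
 */
final class UpdateInputsSketch {
  /** Returns the weights to apply, defaulting every weight to 1 when sampleWeights is null. */
  static double[] resolveWeights(double[] values, double[] sampleWeights) {
    if (values == null) {
      throw new IllegalArgumentException("values must not be null");
    }
    if (sampleWeights == null) {
      double[] ones = new double[values.length];
      Arrays.fill(ones, 1.0); // default weight of 1 per value
      return ones;
    }
    return sampleWeights;
  }

  public static void main(String[] args) {
    System.out.println(Arrays.toString(resolveWeights(new double[] {4, 5}, null)));
    // prints [1.0, 1.0]
    try {
      resolveWeights(null, null);
    } catch (IllegalArgumentException e) {
      System.out.println("rejected: " + e.getMessage());
      // prints rejected: values must not be null
    }
  }
}
```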
result
Gets the current result of the metric.
- Type Parameters:
  U - the data type for the result
- Parameters:
  tf - the TensorFlow Ops encapsulating a Graph environment.
  type - the data type for the result
- Returns:
  the result, possibly with control dependencies
getTotal
Gets the total variable.
getCount
Gets the count variable.
getInternalType
Gets the type for the variables.