Class BinaryAccuracy<T extends TNumber>
java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.Mean<T>
org.tensorflow.framework.metrics.BinaryAccuracy<T>
- Type Parameters:
T - the data type for the metric result
- All Implemented Interfaces:
Metric
Metric that calculates how often predictions match binary labels.
This metric creates two local variables, total and count, that are used to compute the
frequency with which predictions match labels. This frequency is ultimately
returned as binary accuracy: an idempotent operation that simply divides total by count.
If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values.
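The total/count bookkeeping can be illustrated without TensorFlow. The following is a minimal plain-Java sketch, not the library's implementation: the class MeanAccumulatorSketch and its methods are hypothetical, and it assumes the values passed in are already per-element match indicators (1 for a match, 0 otherwise).

```java
/**
 * Hypothetical plain-Java sketch (not part of TensorFlow Java) of the
 * total/count bookkeeping described above. Each value is assumed to be a
 * per-element match indicator: 1 if prediction and label agree, else 0.
 */
final class MeanAccumulatorSketch {
    private double total = 0.0; // running sum of weight * value
    private double count = 0.0; // running sum of weights

    /** Accumulates one batch; a null weights array means every weight is 1. */
    void update(float[] values, float[] weights) {
        if (values == null) {
            throw new IllegalArgumentException("values must not be null");
        }
        if (weights != null && weights.length != values.length) {
            throw new IllegalArgumentException("weights must match values in length");
        }
        for (int i = 0; i < values.length; i++) {
            double w = (weights == null) ? 1.0 : weights[i];
            total += w * values[i];
            count += w; // a weight of 0 masks the value out of both sums
        }
    }

    /** Divides total by count; calling it repeatedly does not change state. */
    double result() {
        // Assumption of this sketch: report 0 rather than NaN before any update.
        return count == 0.0 ? 0.0 : total / count;
    }

    /** Returns both accumulators to their initial values. */
    void reset() {
        total = 0.0;
        count = 0.0;
    }
}
```

For example, updating with {1, 0, 1, 1} and no weights gives 0.75; a further update with values {1, 0} and weights {1, 0} contributes only the first value, giving 4 / 5 = 0.8.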
Field Summary
Fields
- count: the variable that holds the count of the metric values.
- static final String COUNT
- static final float DEFAULT_THRESHOLD: the default threshold value for deciding whether prediction values are 1 or 0.
- protected org.tensorflow.framework.metrics.impl.LossMetric loss: the loss function interface.
- protected final MetricReduction reduction
- total: the variable that holds the total of the metric values.
- static final String TOTAL
Constructor Summary
Constructors
- BinaryAccuracy(float threshold, long seed, Class<T> type): Creates a BinaryAccuracy Metric using Class.getSimpleName() for the metric name.
- BinaryAccuracy(long seed, Class<T> type): Creates a BinaryAccuracy Metric using Class.getSimpleName() for the metric name and DEFAULT_THRESHOLD for the threshold value.
- BinaryAccuracy(String name, float threshold, long seed, Class<T> type): Creates a BinaryAccuracy Metric.
Method Summary
- <U extends TNumber> Operand<U> call(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Class<U> resultType): Calculates how often predictions match binary labels.
- getCount(): Gets the count variable.
- getInternalType(): Gets the type for the variables.
- org.tensorflow.framework.metrics.impl.LossMetric getLoss(): Gets the loss function.
- getTotal(): Gets the total variable.
- protected void init(Ops tf): Initialize the TensorFlow Ops.
- resetStates(Ops tf): Resets any state variables to their initial values.
- result(Ops tf, Class<U> type): Gets the current result of the metric.
- protected void setLoss(org.tensorflow.framework.metrics.impl.LossMetric loss): Sets the AbstractLoss function for this wrapper.
- List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights): Updates the metric variables based on the inputs.
- List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights): Creates operations that update the state of the mean metric by calling the loss function and passing the loss to the Mean metric, which calculates the weighted mean of the loss over many iterations.
Methods inherited from class BaseMetric
callOnce, checkIsGraph, getName, getSeed, getTF, getVariableName, isInitialized, setInitialized, setName, setTF, updateState, updateState

Field Details

DEFAULT_THRESHOLD
public static final float DEFAULT_THRESHOLD
The default threshold value for deciding whether prediction values are 1 or 0.
- See Also:

loss
protected org.tensorflow.framework.metrics.impl.LossMetric loss
The loss function interface.

TOTAL
- See Also:

COUNT
- See Also:

reduction

total

count
The variable that holds the count of the metric values. For MetricReduction.WEIGHTED_MEAN, this count may be weighted.

Constructor Details

BinaryAccuracy(long seed, Class<T> type)
Creates a BinaryAccuracy Metric using Class.getSimpleName() for the metric name and DEFAULT_THRESHOLD for the threshold value.
- Parameters:
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the variables

BinaryAccuracy(float threshold, long seed, Class<T> type)
Creates a BinaryAccuracy Metric using Class.getSimpleName() for the metric name.
- Parameters:
threshold - a threshold for deciding whether prediction values are 1 or 0
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the variables

BinaryAccuracy(String name, float threshold, long seed, Class<T> type)
Creates a BinaryAccuracy Metric.
- Parameters:
name - the name of the metric; if null, Class.getSimpleName() is used
threshold - a threshold for deciding whether prediction values are 1 or 0
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the variables

Method Details

call
public <U extends TNumber> Operand<U> call(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Class<U> resultType)
Calculates how often predictions match binary labels.
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
labels - the truth values or labels, shape = [batch_size, d0, .. dN].
predictions - the predictions, shape = [batch_size, d0, .. dN].
- Returns:
Binary accuracy values, shape = [batch_size, d0, .. dN-1].
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
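To make the shape contract concrete, here is a plain-Java sketch for rank-2 inputs, written under stated assumptions rather than taken from the library: the class name BinaryAccuracySketch is hypothetical, the threshold is passed explicitly (the value of DEFAULT_THRESHOLD is not reproduced here), and a prediction is assumed to count as 1 only when it is strictly greater than the threshold, following the common Keras convention.

```java
/**
 * Hypothetical plain-Java sketch (not the TensorFlow Java API) of the
 * per-sample binary accuracy described for call(): rank-2 inputs of shape
 * [batchSize, d0] produce one accuracy value per sample, shape [batchSize].
 */
final class BinaryAccuracySketch {
    private BinaryAccuracySketch() {}

    static float[] call(float[][] labels, float[][] predictions, float threshold) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions differ in batch size");
        }
        float[] result = new float[labels.length];
        for (int b = 0; b < labels.length; b++) {
            if (labels[b].length != predictions[b].length) {
                throw new IllegalArgumentException("row " + b + " has mismatched lengths");
            }
            int matches = 0;
            for (int j = 0; j < labels[b].length; j++) {
                // Assumption: values strictly greater than the threshold become 1, others 0.
                float predicted = predictions[b][j] > threshold ? 1f : 0f;
                if (predicted == labels[b][j]) {
                    matches++;
                }
            }
            // Mean over the last axis: fraction of matching elements in this sample.
            result[b] = (float) matches / labels[b].length;
        }
        return result;
    }
}
```

With labels {{1, 0, 1}, {0, 0, 1}}, predictions {{0.9, 0.2, 0.4}, {0.3, 0.1, 0.8}}, and a threshold of 0.5, the sketch returns approximately {0.667, 1.0}: the first sample misses only its third element, where 0.4 thresholds to 0 against a label of 1.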

getLoss
public org.tensorflow.framework.metrics.impl.LossMetric getLoss()
Gets the loss function.
- Returns:
the loss function.

setLoss
protected void setLoss(org.tensorflow.framework.metrics.impl.LossMetric loss)
Sets the AbstractLoss function for this wrapper.
- Parameters:
loss - the loss function.

updateStateList
public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
Creates operations that update the state of the mean metric by calling the loss function and passing the loss to the Mean metric, which calculates the weighted mean of the loss over many iterations.
- Specified by:
updateStateList in interface Metric
- Overrides:
updateStateList in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
labels - the truth values or labels.
predictions - the predictions.
sampleWeights - optional sample weights that act as a coefficient for the loss. If a scalar is provided, the loss is simply scaled by the given value. If sampleWeights is a tensor of size [batch_size], the total loss for each sample of the batch is rescaled by the corresponding element in the sampleWeights vector. If the shape of sampleWeights is [batch_size, d0, .. dN-1] (or can be broadcast to this shape), each loss element of predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss functions reduce by 1 dimension, usually axis=-1.)
- Returns:
a List of control operations that update the Mean state variables.
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
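The scalar and [batch_size] weighting cases can be sketched in plain Java. This is an illustrative sketch, not the library code: SampleWeightSketch and its methods are hypothetical, it assumes per-sample values have already been reduced over the last axis (so a [batch_size] weight vector lines up one-to-one), and it does not handle the general [batch_size, d0, .. dN-1] broadcast.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java sketch of how sampleWeights might be broadcast
 * against per-sample values before being folded into a weighted mean.
 * Only three cases are handled: null, a scalar (length 1), or one weight per sample.
 */
final class SampleWeightSketch {
    private SampleWeightSketch() {}

    /** Expands null, scalar, or per-sample weights to one weight per sample. */
    static float[] broadcast(float[] sampleWeights, int batchSize) {
        float[] out = new float[batchSize];
        if (sampleWeights == null) {
            Arrays.fill(out, 1f); // default weight of 1
        } else if (sampleWeights.length == 1) {
            Arrays.fill(out, sampleWeights[0]); // scalar: same weight for every sample
        } else if (sampleWeights.length == batchSize) {
            System.arraycopy(sampleWeights, 0, out, 0, batchSize); // one weight per sample
        } else {
            throw new IllegalArgumentException("cannot broadcast weights of length "
                + sampleWeights.length + " to batch size " + batchSize);
        }
        return out;
    }

    /** Returns {total, count}: the weighted sum of values and the sum of weights. */
    static double[] weightedTotals(float[] perSampleValues, float[] sampleWeights) {
        float[] w = broadcast(sampleWeights, perSampleValues.length);
        double total = 0.0;
        double count = 0.0;
        for (int i = 0; i < perSampleValues.length; i++) {
            total += (double) w[i] * perSampleValues[i];
            count += w[i];
        }
        return new double[] {total, count};
    }
}
```

Under these assumptions, a scalar weight multiplies the weighted total and the weight count by the same factor, so it leaves the resulting mean unchanged; only weights that differ between samples (including 0 to mask a sample) shift the mean.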

init
Initialize the TensorFlow Ops.
- Specified by:
init in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.

resetStates

updateStateList
public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
Updates the metric variables based on the inputs. The values argument is required; sampleWeights is an optional additional input.
- Specified by:
updateStateList in interface Metric
- Overrides:
updateStateList in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
values - the inputs to be passed to update state; this may not be null.
sampleWeights - sample weights to be applied to values; defaults to 1 if null.
- Returns:
the result with a control dependency on the update state Operands
- Throws:
IllegalArgumentException - if values is null.

result
Gets the current result of the metric.
- Type Parameters:
U - the data type for the result
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
type - the data type for the result
- Returns:
the result, possibly with control dependencies

getTotal

getCount

getInternalType