Class Hinge<T extends TNumber>
java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.Mean<T>
org.tensorflow.framework.metrics.Hinge<T>
- Type Parameters:
T - The data type for the metric result.
- All Implemented Interfaces:
Metric
-
Field Summary
Fields
- count: the variable that holds the count of the metric values.
- static final String COUNT
- protected org.tensorflow.framework.metrics.impl.LossMetric loss: the loss function interface.
- protected final MetricReduction reduction
- total: the variable that holds the total of the metric values.
- static final String TOTAL
-
Constructor Summary
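Constructors
- Hinge(long seed, Class<T> type): Creates a Hinge metric using Class.getSimpleName() for the metric name.
- Hinge(String name, long seed, Class<T> type): Creates a Hinge metric.
-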
Method Summary
Methods
- <U extends TNumber> Operand<U> call(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Class<U> resultType): Computes the hinge loss between labels and predictions.
- getCount(): Gets the count variable.
- getInternalType(): Gets the type for the variables.
- org.tensorflow.framework.metrics.impl.LossMetric getLoss(): Gets the loss function.
- getTotal(): Gets the total variable.
- protected void init(Ops tf): Initializes the TensorFlow Ops.
- resetStates(Ops tf): Resets any state variables to their initial values.
- result(Ops tf, Class<U> type): Gets the current result of the metric.
- protected void setLoss(org.tensorflow.framework.metrics.impl.LossMetric loss): Sets the AbstractLoss function for this wrapper.
- List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights): Updates the metric variables based on the inputs.
- List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights): Creates operations that update the state of the mean metric by calling the loss function and passing the loss to the Mean metric, which computes the weighted mean of the loss over many iterations.
Methods inherited from class BaseMetric:
callOnce, checkIsGraph, getName, getSeed, getTF, getVariableName, isInitialized, setInitialized, setName, setTF, updateState, updateState
-
Field Details
-
loss
protected org.tensorflow.framework.metrics.impl.LossMetric loss
The loss function interface.
-
TOTAL
- See Also: Constant Field Values
-
COUNT
- See Also: Constant Field Values
-
reduction
-
total
The variable that holds the total of the metric values.
-
count
The variable that holds the count of the metric values. For MetricReduction.WEIGHTED_MEAN, this count may be weighted.
-
-
Constructor Details
-
Hinge
Creates a Hinge metric using Class.getSimpleName() for the metric name.
- Parameters:
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the type for the variables and result.
-
Hinge
Creates a Hinge metric.
- Parameters:
name - the name of this metric; if null, the metric name is Class.getSimpleName().
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the type for the variables and result.
-
-
Method Details
-
call
public <U extends TNumber> Operand<U> call(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Class<U> resultType)
Computes the hinge loss between labels and predictions.
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
labels - the truth values or labels, shape = [batch_size, d0, .. dN].
predictions - the predictions, shape = [batch_size, d0, .. dN].
resultType - the data type for the result.
- Returns:
- the hinge loss between labels and predictions.
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
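The arithmetic behind call can be approximated in plain Java. This is a minimal sketch that assumes the Keras-style hinge definition (per sample, the mean over the last axis of max(1 - label * prediction, 0)) and assumes that, as in Keras, labels consisting only of 0s and 1s are remapped to -1 and 1. HingeSketch and its methods are hypothetical; they operate on double arrays of shape [batch_size][d] rather than on Operand values in a Graph, and are not part of the TensorFlow Java API.

```java
/**
 * Hypothetical plain-Java sketch of the hinge loss, assuming Keras-style semantics:
 * per sample, loss = mean over the last axis of max(1 - label * prediction, 0).
 * Illustration on double arrays only; not the TensorFlow Java implementation.
 */
final class HingeSketch {
    private HingeSketch() {}

    /** Assumption (Keras-style): if every label is 0 or 1, remap labels to -1 or 1. */
    static double[][] maybeConvertLabels(double[][] labels) {
        for (double[] row : labels) {
            for (double v : row) {
                if (v != 0.0 && v != 1.0) {
                    return labels; // not binary labels, use them as-is
                }
            }
        }
        double[][] converted = new double[labels.length][];
        for (int i = 0; i < labels.length; i++) {
            converted[i] = new double[labels[i].length];
            for (int j = 0; j < labels[i].length; j++) {
                converted[i][j] = 2.0 * labels[i][j] - 1.0; // 0 -> -1, 1 -> 1
            }
        }
        return converted;
    }

    /** Returns one loss per sample; both inputs have shape [batch_size][d]. */
    static double[] hinge(double[][] labels, double[][] predictions) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions differ in batch size");
        }
        double[][] y = maybeConvertLabels(labels);
        double[] losses = new double[y.length];
        for (int i = 0; i < y.length; i++) {
            if (y[i].length != predictions[i].length) {
                throw new IllegalArgumentException("labels and predictions differ in shape");
            }
            double sum = 0.0;
            for (int j = 0; j < y[i].length; j++) {
                sum += Math.max(1.0 - y[i][j] * predictions[i][j], 0.0);
            }
            losses[i] = sum / y[i].length; // reduce the last axis by its mean
        }
        return losses;
    }
}
```

For labels {{0, 1}, {0, 0}} and predictions {{0.6, 0.4}, {0.4, 0.6}}, this sketch yields per-sample losses of about 1.1 and 1.5.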
-
getLoss
public org.tensorflow.framework.metrics.impl.LossMetric getLoss()
Gets the loss function.
- Returns:
- the loss function.
-
setLoss
protected void setLoss(org.tensorflow.framework.metrics.impl.LossMetric loss)
Sets the AbstractLoss function for this wrapper.
- Parameters:
loss - the loss function.
-
updateStateList
public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
Creates operations that update the state of the mean metric by calling the loss function and passing the loss to the Mean metric, which computes the weighted mean of the loss over many iterations.
- Specified by:
updateStateList in interface Metric
- Overrides:
updateStateList in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
labels - the truth values or labels.
predictions - the predictions.
sampleWeights - optional sample weights that act as a coefficient for the loss. If a scalar is provided, the loss is simply scaled by the given value. If sampleWeights is a tensor of size [batch_size], the total loss for each sample of the batch is rescaled by the corresponding element in the sampleWeights vector. If the shape of sampleWeights is [batch_size, d0, .. dN-1] (or can be broadcast to this shape), each loss element of predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss functions reduce by 1 dimension, usually axis=-1.)
- Returns:
- a List of control operations that update the Mean state variables.
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
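The update path described above (compute the loss, then fold it into a weighted mean across iterations) can be modeled with running total and count accumulators. HingeMeanSketch is a hypothetical plain-Java illustration, not the library's implementation. It assumes labels are already -1 or 1, supports only the scalar and [batch_size] forms of sampleWeights, and returns 0 when nothing has been accumulated; all three are simplifying assumptions.

```java
/**
 * Hypothetical plain-Java model of how a Hinge metric accumulates state: each
 * update computes per-sample hinge losses, then folds them into a weighted
 * running mean (total / count), as the description delegates to Mean.
 */
final class HingeMeanSketch {
    private double total; // running sum of weight * per-sample loss
    private double count; // running sum of the weights

    /** Per-sample hinge loss; this sketch assumes labels are already -1 or 1. */
    private static double sampleLoss(double[] labels, double[] predictions) {
        double sum = 0.0;
        for (int j = 0; j < labels.length; j++) {
            sum += Math.max(1.0 - labels[j] * predictions[j], 0.0);
        }
        return sum / labels.length; // mean over the last axis
    }

    /**
     * sampleWeights may be null (every weight is 1), hold one value (a scalar
     * applied to every sample), or hold one weight per sample ([batch_size]).
     */
    void update(double[][] labels, double[][] predictions, double[] sampleWeights) {
        int batch = labels.length;
        if (predictions.length != batch) {
            throw new IllegalArgumentException("labels and predictions differ in batch size");
        }
        if (sampleWeights != null && sampleWeights.length != 1 && sampleWeights.length != batch) {
            throw new IllegalArgumentException("sampleWeights must be a scalar or have batch_size entries");
        }
        for (int i = 0; i < batch; i++) {
            double w = sampleWeights == null ? 1.0
                    : (sampleWeights.length == 1 ? sampleWeights[0] : sampleWeights[i]);
            total += w * sampleLoss(labels[i], predictions[i]);
            count += w;
        }
    }

    /** Weighted mean so far; returning 0 for an empty state is an assumption. */
    double result() {
        return count == 0.0 ? 0.0 : total / count;
    }

    /** Clears the accumulated state, analogous to resetStates. */
    void reset() {
        total = 0.0;
        count = 0.0;
    }
}
```

Feeding two samples with losses of about 1.1 and 1.5 in separate updates gives a mean of about 1.3; weighting the second sample by 3 in a single update shifts the result to about 1.4. Because a scalar weight in this sketch scales both total and count, on its own it leaves the mean of a single batch unchanged.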
-
init
Initializes the TensorFlow Ops.
- Specified by:
init in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
-
resetStates
Resets any state variables to their initial values.
-
updateStateList
public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
Updates the metric variables based on the inputs. The values argument is required; sampleWeights is an optional additional input.
- Specified by:
updateStateList in interface Metric
- Overrides:
updateStateList in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
values - the inputs used to update the state; must not be null.
sampleWeights - sample weights to be applied to values; defaults to 1 if null.
- Returns:
- the result, with a control dependency on the update-state operands.
- Throws:
IllegalArgumentException - if values is null.
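The inherited Mean update, where total accumulates weighted values and count accumulates the weights themselves, can be sketched in plain Java. WeightedMeanSketch is hypothetical and only mirrors the documented behavior: values must not be null, and missing weights default to 1. It also assumes one-dimensional values with either no weights or exactly one weight per value, and it returns 0 from result() when count is 0, which is an assumption rather than documented behavior.

```java
/**
 * Hypothetical plain-Java model of a weighted-mean metric:
 * total += sum(values * weights), count += sum(weights), result = total / count.
 * With weights, count is a sum of weights rather than the number of values seen.
 */
final class WeightedMeanSketch {
    private double total; // weighted sum of all values seen so far
    private double count; // sum of all weights seen so far

    void updateState(double[] values, double[] sampleWeights) {
        if (values == null) {
            throw new IllegalArgumentException("values must not be null");
        }
        if (sampleWeights != null && sampleWeights.length != values.length) {
            throw new IllegalArgumentException("sampleWeights must match values in length");
        }
        for (int i = 0; i < values.length; i++) {
            double w = sampleWeights == null ? 1.0 : sampleWeights[i]; // default weight is 1
            total += values[i] * w;
            count += w;
        }
    }

    /** Current weighted mean; 0 for an empty state is an assumption of this sketch. */
    double result() {
        return count == 0.0 ? 0.0 : total / count;
    }

    double getTotal() {
        return total;
    }

    double getCount() {
        return count;
    }

    /** Returns both accumulators to their initial values. */
    void resetStates() {
        total = 0.0;
        count = 0.0;
    }
}
```

This also illustrates the note on the count field: after updating with {1, 2, 3} unweighted and then {4} at weight 0.5, count is 3.5 rather than 4, and total is 8.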
-
result
Gets the current result of the metric.
- Type Parameters:
U - the data type for the result.
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
type - the data type for the result.
- Returns:
- the result, possibly with control dependencies.
-
getTotal
Gets the total variable.
-
getCount
Gets the count variable.
-
getInternalType
Gets the type for the variables.
-