Class CosineSimilarity<T extends TNumber>
java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.Mean<T>
org.tensorflow.framework.metrics.CosineSimilarity<T>
- Type Parameters:
T - The data type for the metric result.
- All Implemented Interfaces:
Metric
A metric that computes the cosine similarity between labels and predictions.
Note that the value lies between -1 and 1. In the loss formulation below, which negates the similarity, 0 indicates orthogonality, values closer to -1 indicate greater similarity, and values closer to 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where you try to maximize the proximity between predictions and targets. If either labels or predictions is a zero vector, the cosine similarity is 0 regardless of the proximity between predictions and targets.
loss = -sum(l2_norm(y_true) * l2_norm(y_pred))
-
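As a sanity check on the formula above, the arithmetic can be traced in plain Java. This is a minimal sketch, not TensorFlow code: the class `CosineSketch` and its methods are hypothetical, the inputs are single 1-D `double` arrays rather than tensors, and the sum runs over the whole vector. A zero vector is left as all zeros when normalized, which gives the zero-vector behavior described above (TensorFlow's `tf.math.l2_normalize` divides by `sqrt(max(sum(x^2), epsilon))`, which has the same effect for an exact zero vector).

```java
/**
 * Hypothetical plain-Java sketch (not TensorFlow code) of
 * loss = -sum(l2_norm(y_true) * l2_norm(y_pred)) for two 1-D vectors.
 */
final class CosineSketch {
    private CosineSketch() {}

    /** Scales v to unit L2 norm; a zero vector is returned as all zeros. */
    static double[] l2Normalize(double[] v) {
        double sumSq = 0.0;
        for (double x : v) {
            sumSq += x * x;
        }
        double[] out = new double[v.length];
        if (sumSq == 0.0) {
            return out; // zero vector: no direction, so every product below is 0
        }
        double norm = Math.sqrt(sumSq);
        for (int i = 0; i < v.length; i++) {
            out[i] = v[i] / norm;
        }
        return out;
    }

    /** Cosine similarity: sum of element-wise products of the normalized vectors, in [-1, 1] up to rounding. */
    static double cosineSimilarity(double[] labels, double[] predictions) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions must have the same length");
        }
        double[] a = l2Normalize(labels);
        double[] b = l2Normalize(predictions);
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /** Loss form of the formula: the negated similarity, so identical directions give -1. */
    static double cosineLoss(double[] labels, double[] predictions) {
        return -cosineSimilarity(labels, predictions);
    }
}
```

With these definitions, `cosineSimilarity(new double[]{1, 2}, new double[]{2, 4})` is 1 up to rounding and `cosineLoss` of the same pair is -1, so the loss is smallest when the two vectors point the same way.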
Field Summary
Fields
- count: the variable that holds the count of the metric values.
- static final String COUNT
- static final int DEFAULT_AXIS
- protected org.tensorflow.framework.metrics.impl.LossMetric loss: the loss function interface.
- protected final MetricReduction reduction
- total: the variable that holds the total of the metric values.
- static final String TOTAL
-
Constructor Summary
Constructors
- CosineSimilarity(int[] axis, long seed, Class<T> type): Creates a CosineSimilarity metric, using Class.getSimpleName() for the metric name.
- CosineSimilarity(int axis, long seed, Class<T> type): Creates a metric that computes the cosine similarity between labels and predictions, using Class.getSimpleName() for the metric name.
- CosineSimilarity(long seed, Class<T> type): Creates a metric that computes the cosine similarity between labels and predictions with a default axis, DEFAULT_AXIS, using Class.getSimpleName() for the metric name.
- CosineSimilarity(String name, int[] axis, long seed, Class<T> type): Creates a CosineSimilarity metric.
- CosineSimilarity(String name, int axis, long seed, Class<T> type): Creates a metric that computes the cosine similarity between labels and predictions.
- CosineSimilarity(String name, long seed, Class<T> type): Creates a metric that computes the cosine similarity between labels and predictions with a default axis, DEFAULT_AXIS.
-
Method Summary
Methods
- <U extends TNumber> Operand<U> call(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Class<U> resultType): Computes the cosine similarity loss between labels and predictions.
- getCount(): Gets the count variable.
- getInternalType(): Gets the type for the variables.
- org.tensorflow.framework.metrics.impl.LossMetric getLoss(): Gets the loss function.
- getTotal(): Gets the total variable.
- protected void init(Ops tf): Initialize the TensorFlow Ops.
- resetStates(Ops tf): Resets any state variables to their initial values.
- result(Ops tf, Class<U> type): Gets the current result of the metric.
- protected void setLoss(org.tensorflow.framework.metrics.impl.LossMetric loss): Sets the loss function for this wrapper.
- List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights): Updates the metric variables based on the inputs.
- List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights): Creates Operations that update the state of the mean metric by calling the loss function and passing the loss to the Mean metric, which calculates the weighted mean of the loss over many iterations.
Methods inherited from class BaseMetric
callOnce, checkIsGraph, getName, getSeed, getTF, getVariableName, isInitialized, setInitialized, setName, setTF, updateState, updateState
-
Field Details
-
DEFAULT_AXIS
public static final int DEFAULT_AXIS
-
loss
protected org.tensorflow.framework.metrics.impl.LossMetric loss
The loss function interface.
-
TOTAL
-
COUNT
-
reduction
-
total
-
count
The variable that holds the count of the metric values. For MetricReduction.WEIGHTED_MEAN, this count may be weighted.
-
-
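The note on count can be made concrete by looking at how one batch would change total and count. This sketch assumes the Keras semantics that this API is modeled on: under WEIGHTED_MEAN the count grows by the sum of the sample weights, whereas under SUM_OVER_BATCH_SIZE it grows by the number of values. `CountSketch` and its `Reduction` enum are hypothetical stand-ins that only borrow the MetricReduction constant names; they are not the library's classes.

```java
/**
 * Hypothetical sketch of how one batch could change the total and count
 * variables, assuming Keras-style reduction semantics.
 */
final class CountSketch {
    private CountSketch() {}

    /** Hypothetical stand-in that only borrows MetricReduction constant names. */
    enum Reduction { SUM_OVER_BATCH_SIZE, WEIGHTED_MEAN }

    /** Amount added to count for one batch with the given sample weights. */
    static double countIncrement(Reduction reduction, double[] sampleWeights) {
        switch (reduction) {
            case SUM_OVER_BATCH_SIZE:
                // every value counts once, whatever its weight
                return sampleWeights.length;
            case WEIGHTED_MEAN: {
                // the count is the sum of the weights, so it "may be weighted"
                double sum = 0.0;
                for (double w : sampleWeights) {
                    sum += w;
                }
                return sum;
            }
            default:
                throw new IllegalArgumentException("unsupported reduction: " + reduction);
        }
    }

    /** Amount added to total: the weighted sum of the values (same in both modes). */
    static double totalIncrement(double[] values, double[] sampleWeights) {
        if (values.length != sampleWeights.length) {
            throw new IllegalArgumentException("one weight per value is expected");
        }
        double sum = 0.0;
        for (int i = 0; i < values.length; i++) {
            sum += values[i] * sampleWeights[i];
        }
        return sum;
    }
}
```

For values {1, 3} with weights {1, 3}, total grows by 10 in both modes, but count grows by 4 under WEIGHTED_MEAN (mean 2.5) and by 2 under SUM_OVER_BATCH_SIZE (mean 5).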
Constructor Details
-
CosineSimilarity(long seed, Class<T> type)
Creates a metric that computes the cosine similarity between labels and predictions with a default axis, DEFAULT_AXIS, using Class.getSimpleName() for the metric name.
- Parameters:
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the type for the variables and result
-
CosineSimilarity(int axis, long seed, Class<T> type)
Creates a metric that computes the cosine similarity between labels and predictions, using Class.getSimpleName() for the metric name.
- Parameters:
axis - the dimension along which the cosine similarity is computed.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the type for the variables and result
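To illustrate what the axis argument selects, the sketch below computes cosine similarity along one axis of a rank-2 array. It is a plain-Java illustration with a hypothetical class name (`AxisSketch`), not the library's implementation: reducing along axis 1 (or -1, the last axis) compares row against row and yields one value per row, for example one per sample in a [batch_size, features] input, while axis 0 compares columns. It uses dot / (|a| |b|), which is algebraically the same as summing the products of the L2-normalized vectors.

```java
/**
 * Hypothetical sketch: cosine similarity of two 2-D arrays along one axis.
 * Axis 1 (or -1) compares rows; axis 0 compares columns.
 */
final class AxisSketch {
    private AxisSketch() {}

    static double[] cosineAlongAxis(double[][] labels, double[][] predictions, int axis) {
        int rows = labels.length;
        int cols = labels[0].length;
        int normalizedAxis = axis < 0 ? axis + 2 : axis; // map -1 to 1 for a rank-2 input
        if (normalizedAxis != 0 && normalizedAxis != 1) {
            throw new IllegalArgumentException("axis must be in [-2, 1] for a 2-D input");
        }
        int outLen = normalizedAxis == 1 ? rows : cols;    // the reduced axis disappears
        int reduceLen = normalizedAxis == 1 ? cols : rows; // length of each compared vector
        double[] result = new double[outLen];
        for (int k = 0; k < outLen; k++) {
            double dot = 0.0, normA = 0.0, normB = 0.0;
            for (int j = 0; j < reduceLen; j++) {
                double a = normalizedAxis == 1 ? labels[k][j] : labels[j][k];
                double b = normalizedAxis == 1 ? predictions[k][j] : predictions[j][k];
                dot += a * b;
                normA += a * a;
                normB += b * b;
            }
            // a zero vector on either side yields 0, as in the class description
            result[k] = (normA == 0.0 || normB == 0.0) ? 0.0 : dot / Math.sqrt(normA * normB);
        }
        return result;
    }
}
```

For labels {{1, 0}, {0, 1}} and predictions {{1, 0}, {1, 0}}, axis -1 gives {1.0, 0.0}, while axis 0 gives roughly {0.7071, 0.0}; the second column of predictions is a zero vector, so its similarity is 0.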
-
CosineSimilarity(int[] axis, long seed, Class<T> type)
Creates a CosineSimilarity metric, using Class.getSimpleName() for the metric name.
- Parameters:
axis - the dimensions along which the cosine similarity is computed.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the type for the variables and result
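When axis is an array, the similarity covers all listed dimensions at once. Assuming the L2 normalization and the final sum both run over every listed dimension, as in the formula in the class description, a rank-2 input with axis {0, 1} behaves like two flattened vectors. A minimal plain-Java sketch of that case, with a hypothetical class name `AllAxesSketch`:

```java
/**
 * Hypothetical sketch: cosine similarity reduced over both axes {0, 1} of a 2-D array,
 * i.e. each array treated as one flattened vector.
 */
final class AllAxesSketch {
    private AllAxesSketch() {}

    static double cosineOverAxes01(double[][] labels, double[][] predictions) {
        double dot = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < labels.length; i++) {
            for (int j = 0; j < labels[i].length; j++) {
                dot += labels[i][j] * predictions[i][j];
                normA += labels[i][j] * labels[i][j];
                normB += predictions[i][j] * predictions[i][j];
            }
        }
        // a zero array on either side yields 0
        return (normA == 0.0 || normB == 0.0) ? 0.0 : dot / Math.sqrt(normA * normB);
    }
}
```

With labels {{1, 0}, {0, 1}} and predictions {{1, 0}, {1, 0}}, the result is 0.5: the dot product is 1 and each squared norm is 2.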
-
CosineSimilarity(String name, long seed, Class<T> type)
Creates a metric that computes the cosine similarity between labels and predictions with a default axis, DEFAULT_AXIS.
- Parameters:
name - the name of this metric; if null, the metric name is Class.getSimpleName().
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the type for the variables and result
-
CosineSimilarity(String name, int axis, long seed, Class<T> type)
Creates a metric that computes the cosine similarity between labels and predictions.
- Parameters:
name - the name of this metric; if null, the metric name is Class.getSimpleName().
axis - the dimension along which the cosine similarity is computed.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the type for the variables and result
-
CosineSimilarity(String name, int[] axis, long seed, Class<T> type)
Creates a CosineSimilarity metric.
- Parameters:
name - the name of this metric; if null, the metric name is Class.getSimpleName().
axis - the dimensions along which the cosine similarity is computed.
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the type for the variables and result
-
-
Method Details
-
call
public <U extends TNumber> Operand<U> call(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Class<U> resultType)
Computes the cosine similarity loss between labels and predictions.
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
labels - the truth values or labels
predictions - the predictions
- Returns:
- the cosine similarity loss
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
-
getLoss
public org.tensorflow.framework.metrics.impl.LossMetric getLoss()
Gets the loss function.
- Returns:
- the loss function.
-
setLoss
protected void setLoss(org.tensorflow.framework.metrics.impl.LossMetric loss)
Sets the loss function for this wrapper.
- Parameters:
loss - the loss function.
-
updateStateList
public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
Creates Operations that update the state of the mean metric by calling the loss function and passing the loss to the Mean metric, which calculates the weighted mean of the loss over many iterations.
- Specified by:
updateStateList in interface Metric
- Overrides:
updateStateList in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
labels - the truth values or labels
predictions - the predictions
sampleWeights - Optional sampleWeights acts as a coefficient for the loss. If a scalar is provided, the loss is simply scaled by the given value. If sampleWeights is a tensor of size [batch_size], the total loss for each sample of the batch is rescaled by the corresponding element in the sampleWeights vector. If the shape of sampleWeights is [batch_size, d0, .. dN-1] (or can be broadcast to this shape), each loss element of predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss functions reduce by 1 dimension, usually axis=-1.)
- Returns:
- a List of control operations that updates the Mean state variables.
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
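The update described above (compute a loss per sample, then fold it into a weighted mean) can be sketched in plain Java. Assumptions: `CosineMeanSketch` is a hypothetical class, WEIGHTED_MEAN semantics apply (count grows by the sum of the weights), a null sampleWeights behaves like all ones, a single-element weight array is broadcast to every sample, and the sketch accumulates the similarity itself; if the wrapped loss is the negated form, every accumulated value simply flips sign.

```java
/**
 * Hypothetical sketch of the labels/predictions form of updateStateList:
 * per-row cosine similarity folded into a weighted running mean
 * (WEIGHTED_MEAN semantics assumed). Plain Java, no TensorFlow.
 */
final class CosineMeanSketch {
    private double total = 0.0; // weighted sum of per-sample similarities
    private double count = 0.0; // sum of the weights seen so far

    /** One batch; sampleWeights may be null (all ones), length 1 (broadcast scalar), or one per row. */
    CosineMeanSketch update(double[][] labels, double[][] predictions, double[] sampleWeights) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions need the same batch size");
        }
        if (sampleWeights != null && sampleWeights.length != 1 && sampleWeights.length != labels.length) {
            throw new IllegalArgumentException("sampleWeights must be null, scalar, or one per sample");
        }
        for (int i = 0; i < labels.length; i++) {
            double w = sampleWeights == null ? 1.0
                    : sampleWeights.length == 1 ? sampleWeights[0] : sampleWeights[i];
            total += w * rowCosine(labels[i], predictions[i]);
            count += w;
        }
        return this; // allows chaining several batches
    }

    /** Current weighted mean of the per-sample similarities. */
    double result() {
        return total / count;
    }

    /** Cosine similarity of one row; 0 when either row is a zero vector. */
    private static double rowCosine(double[] a, double[] b) {
        double dot = 0.0, na = 0.0, nb = 0.0;
        for (int j = 0; j < a.length; j++) {
            dot += a[j] * b[j];
            na += a[j] * a[j];
            nb += b[j] * b[j];
        }
        return (na == 0.0 || nb == 0.0) ? 0.0 : dot / Math.sqrt(na * nb);
    }
}
```

For labels {{1, 0}, {0, 1}} and predictions {{1, 0}, {1, 0}}, the per-sample similarities are 1 and 0, so the result is 0.5 whether sampleWeights is null or {3.0}: within a single batch a scalar weight scales total and count alike and cancels out. Per-sample weights {3.0, 1.0} shift the result to 0.75.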
-
init
Initialize the TensorFlow Ops.
- Specified by:
init in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
-
resetStates
-
updateStateList
public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
Updates the metric variables based on the inputs. The values argument is required; sampleWeights is an optional additional input.
- Specified by:
updateStateList in interface Metric
- Overrides:
updateStateList in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
values - the inputs to be passed to update state; this may not be null.
sampleWeights - sample weights to be applied to values; defaults to 1 if null.
- Returns:
- the result with a control dependency on update state Operands
- Throws:
IllegalArgumentException - if values is null
-
result
Gets the current result of the metric.
- Type Parameters:
U - the data type for the result
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
type - the data type for the result
- Returns:
- the result, possibly with control dependencies
-
getTotal
-
getCount
-
getInternalType
-