Class SquaredHinge
java.lang.Object
org.tensorflow.framework.losses.SquaredHinge
- All Implemented Interfaces:
Loss
Computes the squared hinge loss between labels and predictions.
loss = square(maximum(1 - labels * predictions, 0))
Label values are expected to be -1 or 1. If binary (0 or 1) labels are provided,
they are converted to -1 or 1.
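You can check the formula and the label conversion by hand, outside TensorFlow. The following is a minimal pure-Java sketch and not the library's implementation. The class `SquaredHingeFormula` and its methods are hypothetical. The sketch assumes, as Keras does, that labels are converted with `2 * label - 1` only when every label in the tensor is 0 or 1. Mixed labels such as {-1, 0, 1} are left unchanged.

```java
import java.util.Arrays;

/** Hypothetical pure-Java sketch of the element-wise squared hinge formula (no TensorFlow). */
final class SquaredHingeFormula {

  /** Returns true when every label is 0 or 1, i.e. the labels look binary. */
  static boolean isBinary(float[][] labels) {
    for (float[] row : labels) {
      for (float y : row) {
        if (y != 0f && y != 1f) {
          return false;
        }
      }
    }
    return true;
  }

  /** Assumption (Keras-style): map {0, 1} labels to {-1, 1} via 2y - 1, otherwise return them as-is. */
  static float[][] convertLabels(float[][] labels) {
    if (!isBinary(labels)) {
      return labels;
    }
    float[][] out = new float[labels.length][];
    for (int i = 0; i < labels.length; i++) {
      out[i] = new float[labels[i].length];
      for (int j = 0; j < labels[i].length; j++) {
        out[i][j] = 2f * labels[i][j] - 1f;
      }
    }
    return out;
  }

  /** loss = square(maximum(1 - label * prediction, 0)) for a single element. */
  static float elementLoss(float label, float prediction) {
    float margin = Math.max(1f - label * prediction, 0f);
    return margin * margin;
  }

  /** Converts labels, then applies the formula element by element (same shapes assumed, no reduction). */
  static float[][] elementLosses(float[][] labels, float[][] predictions) {
    float[][] y = convertLabels(labels);
    float[][] out = new float[y.length][];
    for (int i = 0; i < y.length; i++) {
      out[i] = new float[y[i].length];
      for (int j = 0; j < y[i].length; j++) {
        out[i][j] = elementLoss(y[i][j], predictions[i][j]);
      }
    }
    return out;
  }

  public static void main(String[] args) {
    float[][] labels = {{0f, 1f}, {0f, 0f}};
    float[][] predictions = {{0.6f, 0.4f}, {0.4f, 0.6f}};
    // Labels become {{-1, 1}, {-1, -1}}; the losses are roughly {{2.56, 0.36}, {1.96, 2.56}}.
    System.out.println(Arrays.deepToString(elementLosses(labels, predictions)));
  }
}
```

Because the margin is clipped at 0, a correctly signed prediction with magnitude at least 1 contributes no loss.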
Standalone usage:
// tf is an Ops instance, for example Ops.create() for eager execution.
Operand<TFloat32> labels =
    tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}});
Operand<TFloat32> predictions =
    tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}});
SquaredHinge squaredHinge = new SquaredHinge();
Operand<TFloat32> result = squaredHinge.call(tf, labels, predictions);
// produces 1.86f
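The next sketch shows where 1.86f comes from. It uses the hypothetical class `SquaredHingeStandalone`, written in plain Java with no TensorFlow dependency. It averages the element-wise losses over the last axis to get one loss per sample, then averages those per-sample losses. It assumes the default reduction (`AbstractLoss.REDUCTION_DEFAULT`) divides the sum of per-sample losses by their count, as SUM_OVER_BATCH_SIZE does. That matches the numbers above, but this page does not confirm it.

```java
/** Hypothetical pure-Java sketch reproducing the standalone example without TensorFlow. */
final class SquaredHingeStandalone {

  /** Per-sample loss: mean of square(max(1 - y * p, 0)) over the last axis. */
  static float[] perSampleLosses(float[][] labels, float[][] predictions) {
    // Assumption: binary {0, 1} labels are converted to {-1, 1} only if all labels are 0 or 1.
    boolean binary = true;
    for (float[] row : labels) {
      for (float y : row) {
        binary &= (y == 0f || y == 1f);
      }
    }
    float[] out = new float[labels.length];
    for (int i = 0; i < labels.length; i++) {
      float sum = 0f;
      for (int j = 0; j < labels[i].length; j++) {
        float y = binary ? 2f * labels[i][j] - 1f : labels[i][j];
        float margin = Math.max(1f - y * predictions[i][j], 0f);
        sum += margin * margin;
      }
      out[i] = sum / labels[i].length;
    }
    return out;
  }

  /** Assumed default reduction (SUM_OVER_BATCH_SIZE): sum of per-sample losses divided by their count. */
  static float mean(float[] losses) {
    float sum = 0f;
    for (float v : losses) {
      sum += v;
    }
    return sum / losses.length;
  }

  public static void main(String[] args) {
    float[][] labels = {{0f, 1f}, {0f, 0f}};
    float[][] predictions = {{0.6f, 0.4f}, {0.4f, 0.6f}};
    float[] perSample = perSampleLosses(labels, predictions);
    // Per-sample losses are roughly [1.46, 2.26]; their mean is roughly 1.86.
    System.out.println(mean(perSample));
  }
}
```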
Calling with sample weight:
Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
Operand<TFloat32> result = squaredHinge.call(tf, labels, predictions,
    sampleWeight);
// produces 0.73f
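The 0.73f result fits the following arithmetic. Each per-sample loss is multiplied by its weight, and the weighted sum is divided by the number of samples, not by the sum of the weights: (1.46 * 1 + 2.26 * 0) / 2 = 0.73. Below is a minimal sketch of that arithmetic, starting from the per-sample losses of the standalone example. The class `WeightedReductionSketch` is hypothetical, and the divisor is an assumption inferred from the documented result.

```java
/** Hypothetical sketch of per-sample weighting followed by the assumed default reduction. */
final class WeightedReductionSketch {

  /** Scales each per-sample loss by its weight, then divides the sum by the number of samples. */
  static float weightedSumOverBatchSize(float[] perSampleLosses, float[] sampleWeights) {
    if (perSampleLosses.length != sampleWeights.length) {
      throw new IllegalArgumentException("expected one weight per sample");
    }
    float sum = 0f;
    for (int i = 0; i < perSampleLosses.length; i++) {
      sum += perSampleLosses[i] * sampleWeights[i];
    }
    // Assumption: the divisor is the number of loss elements, not the sum of the weights.
    return sum / perSampleLosses.length;
  }

  public static void main(String[] args) {
    float[] perSample = {1.46f, 2.26f};
    float[] weights = {1f, 0f};
    // (1.46 * 1 + 2.26 * 0) / 2 is roughly 0.73.
    System.out.println(weightedSumOverBatchSize(perSample, weights));
  }
}
```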
Using SUM reduction type:
SquaredHinge squaredHinge = new SquaredHinge(Reduction.SUM);
Operand<TFloat32> result = squaredHinge.call(tf, labels, predictions);
// produces 3.72f
Using NONE reduction type:
SquaredHinge squaredHinge = new SquaredHinge(Reduction.NONE);
Operand<TFloat32> result = squaredHinge.call(tf, labels, predictions);
// produces [1.46f, 2.26f]
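The reduction types differ only in how the per-sample losses [1.46, 2.26] are combined. NONE returns them unchanged, SUM adds them (3.72), and the default divides that sum by the number of samples (1.86). The sketch below shows this in plain Java. The enum `Mode` and the class `ReductionSketch` are hypothetical stand-ins, not the TensorFlow `Reduction` type.

```java
import java.util.Arrays;

/** Hypothetical sketch (no TensorFlow) of the NONE, SUM and SUM_OVER_BATCH_SIZE reductions. */
final class ReductionSketch {

  /** Hypothetical stand-in for the reduction types; not the TensorFlow Reduction enum. */
  enum Mode {
    NONE,
    SUM,
    SUM_OVER_BATCH_SIZE
  }

  /** Returns the per-sample losses for NONE, otherwise a one-element array with the reduced value. */
  static float[] reduce(float[] perSampleLosses, Mode mode) {
    if (mode == Mode.NONE) {
      // No reduction: one loss value per sample.
      return perSampleLosses.clone();
    }
    float sum = 0f;
    for (float v : perSampleLosses) {
      sum += v;
    }
    // SUM keeps the total; SUM_OVER_BATCH_SIZE divides it by the number of samples.
    float reduced = (mode == Mode.SUM) ? sum : sum / perSampleLosses.length;
    return new float[] {reduced};
  }

  public static void main(String[] args) {
    float[] perSample = {1.46f, 2.26f};
    System.out.println(Arrays.toString(reduce(perSample, Mode.NONE)));
    System.out.println(Arrays.toString(reduce(perSample, Mode.SUM)));
    System.out.println(Arrays.toString(reduce(perSample, Mode.SUM_OVER_BATCH_SIZE)));
  }
}
```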
Field Summary
Fields
- REDUCTION_DEFAULT
- reduction
Constructor Summary
Constructors
- SquaredHinge(): Creates a Squared Hinge loss using Class.getSimpleName() as the loss name and a loss Reduction of AbstractLoss.REDUCTION_DEFAULT.
- SquaredHinge(String name, Reduction reduction): Creates a Squared Hinge loss.
- SquaredHinge(Reduction reduction): Creates a Squared Hinge loss using Class.getSimpleName() as the loss name.
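Method Summary
- <T extends TNumber> Operand<T> call(Ops tf, Operand<? extends TNumber> labels, Operand<T> predictions): Calculates the loss.
- <T extends TNumber> Operand<T> call(Ops tf, Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights): Generates an Operand that calculates the loss.
- String getName(): Gets the name for this loss.
- Reduction getReduction(): Gets the loss reduction.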
Field Details
REDUCTION_DEFAULT
reduction
Constructor Details
SquaredHinge
public SquaredHinge()
Creates a Squared Hinge loss using Class.getSimpleName() as the loss name and a loss Reduction of AbstractLoss.REDUCTION_DEFAULT.
SquaredHinge
public SquaredHinge(Reduction reduction)
Creates a Squared Hinge loss using Class.getSimpleName() as the loss name.
- Parameters:
  reduction - Type of Reduction to apply to the loss.
SquaredHinge
public SquaredHinge(String name, Reduction reduction)
Creates a Squared Hinge loss.
Method Details
call
public <T extends TNumber> Operand<T> call(Ops tf, Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights)
Generates an Operand that calculates the loss.
If run in Graph mode, the computation will throw TFInvalidArgumentException if the label values are not in the set [-1., 0., 1.]. In Eager mode, this call will throw IllegalArgumentException if the label values are not in the set [-1., 0., 1.].
- Type Parameters:
  T - the data type of the predictions, sampleWeights and loss.
- Parameters:
  tf - the TensorFlow Ops
  labels - the truth values or labels; each must be -1, 0, or 1. Values are expected to be -1 or 1; if binary (0 or 1) labels are provided, they are converted to -1 or 1.
  predictions - the predictions; values must be in the range [0., 1.] inclusive.
  sampleWeights - optional sample weights that act as a coefficient for the loss. If a scalar is provided, the loss is simply scaled by the given value. If sampleWeights is a tensor of size [batch_size], the total loss for each sample of the batch is rescaled by the corresponding element of sampleWeights. If sampleWeights has shape [batch_size, d0, ..., dN-1] (or can be broadcast to that shape), each loss element of predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss functions reduce by one dimension, usually axis=-1.)
- Returns:
  the loss
- Throws:
  IllegalArgumentException - if the predictions are outside the range [0., 1.].
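In Eager mode, the label check described above rejects any label outside {-1, 0, 1}. Below is a minimal sketch of that validation on plain Java arrays. `LabelCheckSketch` and `checkLabels` are hypothetical names. The error message is invented for illustration and is not taken from TensorFlow. The sketch does not reproduce the Graph-mode path, where the failure surfaces as a TFInvalidArgumentException when the graph runs.

```java
/** Hypothetical sketch of an eager-style label check: every label must be -1, 0, or 1. */
final class LabelCheckSketch {

  /** Throws IllegalArgumentException if any label is not in the set {-1, 0, 1}. */
  static void checkLabels(float[][] labels) {
    for (int i = 0; i < labels.length; i++) {
      for (int j = 0; j < labels[i].length; j++) {
        float y = labels[i][j];
        if (y != -1f && y != 0f && y != 1f) {
          // Hypothetical message; the real library's wording may differ.
          throw new IllegalArgumentException(
              "label at [" + i + "][" + j + "] is " + y + ", expected -1, 0, or 1");
        }
      }
    }
  }

  public static void main(String[] args) {
    checkLabels(new float[][] {{0f, 1f}, {-1f, 1f}}); // accepted: all labels are in {-1, 0, 1}
    try {
      checkLabels(new float[][] {{0.5f, 1f}});
    } catch (IllegalArgumentException e) {
      System.out.println("rejected: " + e.getMessage());
    }
  }
}
```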
call
public <T extends TNumber> Operand<T> call(Ops tf, Operand<? extends TNumber> labels, Operand<T> predictions)
Calculates the loss.
- Type Parameters:
  T - the data type of the predictions and loss.
- Parameters:
  tf - the TensorFlow Ops
  labels - the truth values or labels
  predictions - the predictions
- Returns:
  the loss
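getReduction
public Reduction getReduction()
Gets the loss reduction.
getName
public String getName()
Gets the name for this loss.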