Class KLDivergence
java.lang.Object
org.tensorflow.framework.losses.KLDivergence
- All Implemented Interfaces:
Loss
Computes Kullback-Leibler divergence loss between labels and predictions.
loss = labels * log(labels / predictions)
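The per-sample loss sums this element-wise term over the last axis, with both inputs clipped away from zero so the logarithm stays finite. A minimal plain-Java sketch of that computation, independent of TensorFlow (the clipping epsilon is an illustrative assumption, not a value taken from this API):

```java
public class KLDivSketch {
    // Illustrative clipping epsilon; the framework's exact value is an assumption here.
    static final double EPSILON = 1e-7;

    // Sum of labels * log(labels / predictions) over one sample (the last axis),
    // with inputs clipped into [EPSILON, 1] to avoid log(0) and division by zero.
    static double klDivergence(double[] labels, double[] predictions) {
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double y = Math.min(Math.max(labels[i], EPSILON), 1.0);
            double p = Math.min(Math.max(predictions[i], EPSILON), 1.0);
            sum += y * Math.log(y / p);
        }
        return sum;
    }

    public static void main(String[] args) {
        // Same inputs as the standalone usage example below.
        double[][] labels = {{0.0, 1.0}, {0.0, 0.0}};
        double[][] predictions = {{0.6, 0.4}, {0.4, 0.6}};
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            total += klDivergence(labels[i], predictions[i]);
        }
        // The default reduction averages the per-sample losses over the batch.
        System.out.println(total / labels.length); // ~0.458
    }
}
```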
Standalone usage:
Operand<TFloat32> labels =
tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}});
Operand<TFloat32> predictions =
tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}});
KLDivergence kld = new KLDivergence();
Operand<TFloat32> result = kld.call(tf, labels, predictions);
// produces 0.458f
Calling with sample weight:
Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
Operand<TFloat32> result = kld.call(tf, labels, predictions, sampleWeight);
// produces 0.366f
Using SUM reduction type:
KLDivergence kld = new KLDivergence(Reduction.SUM);
Operand<TFloat32> result = kld.call(tf, labels, predictions);
// produces 0.916f
Using NONE reduction type:
KLDivergence kld = new KLDivergence(Reduction.NONE);
Operand<TFloat32> result = kld.call(tf, labels, predictions);
// produces [0.916f, -3.08e-06f]
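The reduction types above differ only in how the vector of per-sample losses is collapsed. A minimal sketch of that final step, using the per-sample values from the examples (names here are illustrative, not part of this API):

```java
import java.util.Arrays;

public class ReductionSketch {
    // Per-sample losses from the examples above (one value per batch element).
    static final double[] PER_SAMPLE = {0.916, -3.08e-6};

    static double[] none(double[] losses) {            // Reduction.NONE: keep the vector
        return losses.clone();
    }

    static double sum(double[] losses) {               // Reduction.SUM: total over the batch
        return Arrays.stream(losses).sum();
    }

    static double sumOverBatchSize(double[] losses) {  // default: mean over the batch
        return sum(losses) / losses.length;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(none(PER_SAMPLE)));
        System.out.println(sum(PER_SAMPLE));              // ~0.916
        System.out.println(sumOverBatchSize(PER_SAMPLE)); // ~0.458
    }
}
```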
Field Summary
Fields:
- REDUCTION_DEFAULT
- reduction
Constructor Summary
Constructors:
- KLDivergence(): Creates a Kullback-Leibler Divergence AbstractLoss using Class.getSimpleName() as the loss name and a Reduction of AbstractLoss.REDUCTION_DEFAULT.
- KLDivergence(Reduction reduction): Creates a Kullback-Leibler Divergence AbstractLoss using Class.getSimpleName() as the loss name.
- KLDivergence(String name, Reduction reduction): Creates a Kullback-Leibler Divergence AbstractLoss.
Method Summary
Methods:
- call(Ops tf, Operand<? extends TNumber> labels, Operand<T> predictions): Calculates the loss.
- call(Ops tf, Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights): Generates an Operand that calculates the loss.
- getName(): Gets the name for this loss.
- getReduction(): Gets the loss reduction.
Field Details
- REDUCTION_DEFAULT
- reduction
Constructor Details
KLDivergence
public KLDivergence()
Creates a Kullback-Leibler Divergence AbstractLoss using Class.getSimpleName() as the loss name and a Reduction of AbstractLoss.REDUCTION_DEFAULT.

KLDivergence
public KLDivergence(Reduction reduction)
Creates a Kullback-Leibler Divergence AbstractLoss using Class.getSimpleName() as the loss name.
Parameters:
reduction - Type of Reduction to apply to the loss.

KLDivergence
public KLDivergence(String name, Reduction reduction)
Creates a Kullback-Leibler Divergence AbstractLoss.
Parameters:
reduction - Type of Reduction to apply to the loss.
-
Method Details
-
call
public <T extends TNumber> Operand<T> call(Ops tf, Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights)
Generates an Operand that calculates the loss.
Type Parameters:
T - The data type of the predictions, sampleWeights and loss.
Parameters:
tf - the TensorFlow Ops
labels - the truth values or labels
predictions - the predictions
sampleWeights - Optional sampleWeights acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor of size [batch_size], then the total loss for each sample of the batch is rescaled by the corresponding element in the sampleWeights vector. If the shape of sampleWeights is [batch_size, d0, .. dN-1] (or can be broadcast to this shape), then each loss element of predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss functions reduce by 1 dimension, usually axis=-1.)
Returns:
- the loss
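The [batch_size] weighting case described above can be sketched in plain Java: each per-sample loss is rescaled by its matching weight before the batch reduction. This is an illustrative sketch of the described behavior, not the framework's implementation:

```java
public class SampleWeightSketch {
    // Rescale each per-sample loss by the matching weight, then average over
    // the batch (the default reduction). Assumes sampleWeights has shape
    // [batch_size], matching the per-sample loss vector.
    static double weightedLoss(double[] perSampleLosses, double[] sampleWeights) {
        double total = 0.0;
        for (int i = 0; i < perSampleLosses.length; i++) {
            total += perSampleLosses[i] * sampleWeights[i];
        }
        return total / perSampleLosses.length;
    }

    public static void main(String[] args) {
        // Per-sample losses and weights from the "Calling with sample weight" example.
        double[] losses = {0.916, -3.08e-6};
        double[] weights = {0.8, 0.2};
        System.out.println(weightedLoss(losses, weights)); // ~0.366
    }
}
```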
-
call
public <T extends TNumber> Operand<T> call(Ops tf, Operand<? extends TNumber> labels, Operand<T> predictions)
Calculates the loss.
Type Parameters:
T - The data type of the predictions and loss.
Parameters:
tf - the TensorFlow Ops
labels - the truth values or labels
predictions - the predictions
Returns:
- the loss
-
getReduction
Gets the loss reduction.

getName
Gets the name for this loss.