Class BinaryCrossentropy

java.lang.Object
org.tensorflow.framework.losses.BinaryCrossentropy
All Implemented Interfaces:
Loss

public class BinaryCrossentropy extends Object
Computes the cross-entropy loss between true labels and predicted labels.

Use this cross-entropy loss when there are only two label classes (assumed to be 0 and 1). For each example, there should be a single floating-point value per prediction.

Standalone usage:

   Operand<TFloat32> labels =
       tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}});
   Operand<TFloat32> predictions =
       tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}});
   // tf is assumed to be an existing Ops instance
   BinaryCrossentropy bce = new BinaryCrossentropy();
   Operand<TFloat32> result = bce.call(tf, labels, predictions);
   // produces 0.815f

Calling with sample weight:

   Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
   Operand<TFloat32> result = bce.call(tf, labels, predictions, sampleWeight);
   // produces 0.458f

Using SUM reduction type:

   BinaryCrossentropy bce = new BinaryCrossentropy(Reduction.SUM);
   Operand<TFloat32> result = bce.call(tf, labels, predictions);
   // produces 1.630f

Using NONE reduction type:

   BinaryCrossentropy bce = new BinaryCrossentropy(Reduction.NONE);
   Operand<TFloat32> result = bce.call(tf, labels, predictions);
   // produces [0.916f, 0.714f]
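
The values quoted in the usage examples above can be checked by hand. The following is a minimal plain-Java sketch with no TensorFlow dependency; the class and method names are hypothetical, and it assumes that the loss is averaged over the last axis per sample and that the default reduction divides the summed loss by the batch size. It also skips the epsilon clipping a real implementation would likely apply.

```java
import java.util.Locale;

// Plain-Java sketch with no TensorFlow dependency; all names here are illustrative.
final class BceArithmetic {

    // Binary cross-entropy averaged over the last axis of one sample.
    // No epsilon clipping is applied, so predictions must lie strictly inside (0, 1).
    static double perSample(double[] labels, double[] predictions) {
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            sum -= labels[i] * Math.log(predictions[i])
                + (1.0 - labels[i]) * Math.log(1.0 - predictions[i]);
        }
        return sum / labels.length;
    }

    // Corresponds to Reduction.NONE: one loss value per sample.
    static double[] none(double[][] labels, double[][] predictions) {
        double[] out = new double[labels.length];
        for (int i = 0; i < labels.length; i++) {
            out[i] = perSample(labels[i], predictions[i]);
        }
        return out;
    }

    // Corresponds to Reduction.SUM: per-sample losses added together.
    static double sum(double[][] labels, double[][] predictions) {
        double total = 0.0;
        for (double v : none(labels, predictions)) {
            total += v;
        }
        return total;
    }

    // Assumed default reduction: summed loss divided by the batch size.
    static double batchAverage(double[][] labels, double[][] predictions) {
        return sum(labels, predictions) / labels.length;
    }

    public static void main(String[] args) {
        double[][] labels = {{0, 1}, {0, 0}};
        double[][] predictions = {{0.6, 0.4}, {0.4, 0.6}};
        double[] each = none(labels, predictions);
        System.out.println(String.format(Locale.ROOT, "NONE: [%.3f, %.3f]", each[0], each[1]));
        System.out.println(String.format(Locale.ROOT, "SUM: %.3f", sum(labels, predictions)));
        System.out.println(String.format(Locale.ROOT, "default: %.3f", batchAverage(labels, predictions)));
    }
}
```

Under these assumptions the sketch yields per-sample losses of about 0.916 and 0.714, a sum of about 1.630, and a batch average of about 0.815, matching the comments in the examples.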
  • Field Details

    • FROM_LOGITS_DEFAULT

      public static final boolean FROM_LOGITS_DEFAULT
    • LABEL_SMOOTHING_DEFAULT

      public static final float LABEL_SMOOTHING_DEFAULT
    • REDUCTION_DEFAULT

      public static final Reduction REDUCTION_DEFAULT
    • reduction

      protected final Reduction reduction
  • Constructor Details

    • BinaryCrossentropy

      public BinaryCrossentropy()
      Creates a Binary Crossentropy loss using Class.getSimpleName() as the loss name, FROM_LOGITS_DEFAULT for fromLogits, LABEL_SMOOTHING_DEFAULT for labelSmoothing, and a reduction of AbstractLoss.REDUCTION_DEFAULT.
    • BinaryCrossentropy

      public BinaryCrossentropy(Reduction reduction)
      Creates a Binary Crossentropy loss using Class.getSimpleName() as the loss name, FROM_LOGITS_DEFAULT for fromLogits, and LABEL_SMOOTHING_DEFAULT for labelSmoothing
      Parameters:
      reduction - Type of Reduction to apply to the loss.
    • BinaryCrossentropy

      public BinaryCrossentropy(boolean fromLogits)
      Creates a Binary Crossentropy loss using Class.getSimpleName() as the loss name, a labelSmoothing of LABEL_SMOOTHING_DEFAULT, and a reduction of AbstractLoss.REDUCTION_DEFAULT.
      Parameters:
      fromLogits - Whether to interpret predictions as a tensor of logit values
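
What interpreting predictions as logits means can be sketched in plain Java. A logit is an unbounded score that becomes a probability through the sigmoid function, and the loss can be computed from the logit directly with the numerically stable expression max(x, 0) - x * z + log(1 + exp(-|x|)), which TensorFlow documents for its sigmoid cross-entropy op. Whether this class uses exactly that expression internally is an assumption; the class and method names below are hypothetical.

```java
// Plain-Java sketch; names are illustrative and not part of the TensorFlow API.
final class LogitsBce {

    // Maps a logit to a probability in (0, 1).
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Cross-entropy for a single label/probability pair (no epsilon clipping).
    static double fromProbability(double label, double p) {
        return -(label * Math.log(p) + (1.0 - label) * Math.log(1.0 - p));
    }

    // Cross-entropy computed directly from a logit:
    // max(x, 0) - x * z + log(1 + exp(-|x|)), which avoids log(0) for extreme logits.
    static double fromLogit(double label, double logit) {
        return Math.max(logit, 0.0) - logit * label + Math.log1p(Math.exp(-Math.abs(logit)));
    }

    public static void main(String[] args) {
        double logit = 2.0;
        // Both routes agree for moderate logits.
        System.out.println(fromProbability(0.0, sigmoid(logit)));
        System.out.println(fromLogit(0.0, logit));
        // For an extreme logit the two-step route overflows; the direct route stays finite.
        System.out.println(fromProbability(1.0, sigmoid(-800.0)));
        System.out.println(fromLogit(1.0, -800.0));
    }
}
```

For moderate logits both routes give the same loss; for extreme logits the sigmoid saturates to exactly 0 or 1 in floating point, which is the usual reason to pass raw logits together with fromLogits set to true instead of applying a sigmoid first.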
    • BinaryCrossentropy

      public BinaryCrossentropy(String name, boolean fromLogits)
      Creates a Binary Crossentropy loss using a labelSmoothing of LABEL_SMOOTHING_DEFAULT and a reduction of AbstractLoss.REDUCTION_DEFAULT.
      Parameters:
      name - the name of the loss
      fromLogits - Whether to interpret predictions as a tensor of logit values
    • BinaryCrossentropy

      public BinaryCrossentropy(boolean fromLogits, float labelSmoothing)
      Creates a Binary Crossentropy loss using Class.getSimpleName() as the loss name and a reduction of AbstractLoss.REDUCTION_DEFAULT.
      Parameters:
      fromLogits - Whether to interpret predictions as a tensor of logit values
      labelSmoothing - A number in the range [0, 1]. When 0, no smoothing occurs. When greater than 0, the loss is computed between the predicted labels and a smoothed version of the true labels, where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing correspond to heavier smoothing.
    • BinaryCrossentropy

      public BinaryCrossentropy(String name, boolean fromLogits, float labelSmoothing)
      Creates a Binary Crossentropy loss using a reduction of AbstractLoss.REDUCTION_DEFAULT.
      Parameters:
      name - the name of the loss
      fromLogits - Whether to interpret predictions as a tensor of logit values
      labelSmoothing - A number in the range [0, 1]. When 0, no smoothing occurs. When greater than 0, the loss is computed between the predicted labels and a smoothed version of the true labels, where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing correspond to heavier smoothing.
    • BinaryCrossentropy

      public BinaryCrossentropy(boolean fromLogits, float labelSmoothing, Reduction reduction)
      Creates a Binary Crossentropy loss
      Parameters:
      fromLogits - Whether to interpret predictions as a tensor of logit values
      labelSmoothing - A number in the range [0, 1]. When 0, no smoothing occurs. When greater than 0, the loss is computed between the predicted labels and a smoothed version of the true labels, where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing correspond to heavier smoothing.
      reduction - Type of Reduction to apply to the loss.
    • BinaryCrossentropy

      public BinaryCrossentropy(String name, boolean fromLogits, float labelSmoothing, Reduction reduction)
      Creates a Binary Crossentropy loss
      Parameters:
      name - the name of the loss
      fromLogits - Whether to interpret predictions as a tensor of logit values
      labelSmoothing - A number in the range [0, 1]. When 0, no smoothing occurs. When greater than 0, the loss is computed between the predicted labels and a smoothed version of the true labels, where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing correspond to heavier smoothing.
      reduction - Type of Reduction to apply to the loss.
      Throws:
      IllegalArgumentException - if labelSmoothing is outside the inclusive range [0, 1].
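
The labelSmoothing behaviour described for the constructors can be illustrated with a short plain-Java sketch. It assumes the Keras-style formula smoothed = label * (1 - labelSmoothing) + 0.5 * labelSmoothing, which squeezes labels towards 0.5 as the parameter description states; the exact formula used by this class is not given on this page, and the class and method names below are hypothetical.

```java
// Plain-Java sketch; names are illustrative and the smoothing formula is an assumption.
final class LabelSmoothingSketch {

    // Returns a smoothed copy of the labels; rejects values outside [0, 1] (including NaN).
    static double[] smooth(double[] labels, double labelSmoothing) {
        if (!(labelSmoothing >= 0.0 && labelSmoothing <= 1.0)) {
            throw new IllegalArgumentException(
                "labelSmoothing must be in the inclusive range [0, 1], got " + labelSmoothing);
        }
        double[] out = new double[labels.length];
        for (int i = 0; i < labels.length; i++) {
            // Each label moves towards 0.5 by a fraction equal to labelSmoothing.
            out[i] = labels[i] * (1.0 - labelSmoothing) + 0.5 * labelSmoothing;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] smoothed = smooth(new double[] {0.0, 1.0}, 0.2);
        System.out.println(smoothed[0] + ", " + smoothed[1]);
    }
}
```

With a labelSmoothing of 0.2, labels 0 and 1 become roughly 0.1 and 0.9; a value of 0 leaves the labels unchanged, and a value of 1 maps every label to 0.5.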
  • Method Details

    • call

      public <T extends TNumber> Operand<T> call(Ops tf, Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights)
      Generates an Operand that calculates the loss.

      In Graph mode, the computation throws TFInvalidArgumentException if the predictions values are outside the range [0, 1]. In Eager mode, this call throws IllegalArgumentException if the predictions values are outside the range [0, 1].

      Type Parameters:
      T - The data type of the predictions, sampleWeights and loss.
      Parameters:
      tf - the TensorFlow Ops
      labels - the truth values or labels
      predictions - the predictions; values must be in the inclusive range [0, 1].
      sampleWeights - Optional. sampleWeights acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor of size [batch_size], then the total loss for each sample of the batch is rescaled by the corresponding element in the sampleWeights vector. If the shape of sampleWeights is [batch_size, d0, .. dN-1] (or can be broadcast to this shape), then each loss element of predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss functions reduce by 1 dimension, usually axis=-1.)
      Returns:
      the loss
      Throws:
      IllegalArgumentException - if the predictions are outside the range [0, 1].
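
The scalar and per-sample cases of sampleWeights can be sketched in plain Java. The class and method names are hypothetical, a length-1 array stands in for a scalar weight, and the reduction is assumed to divide by the batch size rather than by the sum of the weights, which is consistent with the 0.458 value in the sample-weight example of the class description.

```java
// Plain-Java sketch; names are illustrative and not part of the TensorFlow API.
final class SampleWeightSketch {

    // Scales each per-sample loss by its weight.
    // A length-1 weights array is treated as a scalar applied to every sample.
    static double[] applyWeights(double[] perSampleLoss, double[] weights) {
        if (weights.length != 1 && weights.length != perSampleLoss.length) {
            throw new IllegalArgumentException("weights must be scalar or match the batch size");
        }
        double[] out = new double[perSampleLoss.length];
        for (int i = 0; i < perSampleLoss.length; i++) {
            double w = weights.length == 1 ? weights[0] : weights[i];
            out[i] = perSampleLoss[i] * w;
        }
        return out;
    }

    // Assumed reduction: sum of the weighted losses divided by the batch size.
    static double batchAverage(double[] values) {
        double total = 0.0;
        for (double v : values) {
            total += v;
        }
        return total / values.length;
    }

    public static void main(String[] args) {
        // Per-sample losses for the labels and predictions used in the class description.
        double[] losses = {-Math.log(0.4), (-Math.log(0.6) - Math.log(0.4)) / 2.0};
        System.out.println(batchAverage(applyWeights(losses, new double[] {1.0, 0.0})));
        System.out.println(batchAverage(applyWeights(losses, new double[] {2.0})));
    }
}
```

With weights {1, 0} the second sample contributes nothing and the result is about 0.458; a scalar weight of 2 simply doubles the unweighted loss.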
    • call

      public <T extends TNumber> Operand<T> call(Ops tf, Operand<? extends TNumber> labels, Operand<T> predictions)
      Calculates the loss
      Type Parameters:
      T - The data type of the predictions and loss.
      Parameters:
      tf - the TensorFlow Ops
      labels - the truth values or labels
      predictions - the predictions
      Returns:
      the loss
    • getReduction

      public Reduction getReduction()
      Gets the loss reduction
      Returns:
      the loss reduction
    • getName

      public String getName()
      Gets the name for this loss
      Returns:
      the name for this loss