Class BinaryCrossentropy<T extends TNumber>

Type Parameters:
T - The data type for the metric result
All Implemented Interfaces:
Metric

public class BinaryCrossentropy<T extends TNumber> extends Mean<T>
A Metric that computes the binary cross-entropy loss between true labels and predicted labels.

Use this cross-entropy metric when there are only two label classes (0 and 1).
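For reference, the per-element quantity is -[y * log(p) + (1 - y) * log(1 - p)] for a true label y and a predicted probability p. The plain-Java sketch below computes it for a single element. It is independent of TensorFlow; the class name BceSketch and the clipping constant EPSILON = 1e-7 are assumptions for illustration, not part of this API.

```java
// A minimal sketch of element-wise binary cross-entropy on probabilities.
// Hypothetical helper, not part of the TensorFlow Java API.
// Assumption: predictions are clipped to [EPSILON, 1 - EPSILON] before the log;
// EPSILON = 1e-7 is an assumed value.
final class BceSketch {
    static final double EPSILON = 1e-7;

    private BceSketch() {}

    /** Loss for one label y (0, 1, or a smoothed value in between) and one probability p. */
    static double bce(double y, double p) {
        // Clip so that log() stays finite when p is exactly 0 or 1.
        double q = Math.min(Math.max(p, EPSILON), 1.0 - EPSILON);
        return -(y * Math.log(q) + (1.0 - y) * Math.log(1.0 - q));
    }
}
```

Clipping keeps the loss finite when a prediction is exactly 0 or 1; whether this class clips, and with which epsilon, is not stated on this page.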

  • Constructor Details

    • BinaryCrossentropy

      public BinaryCrossentropy(boolean fromLogits, float labelSmoothing, long seed, Class<T> type)
      Creates a BinaryCrossentropy metric named Class.getSimpleName().
      Parameters:
      fromLogits - Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution.
      labelSmoothing - the value used to smooth labels. When 0, no smoothing occurs. When > 0, the loss is computed between the predicted labels and a smoothed version of the true labels, where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing correspond to heavier smoothing.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the type for the variables and result
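      A minimal sketch of label smoothing, assuming the Keras-style formula y' = y * (1 - labelSmoothing) + 0.5 * labelSmoothing (this page does not confirm the exact formula). LabelSmoothingSketch is a hypothetical helper, not part of the TensorFlow Java API.

```java
// Hypothetical helper (not part of the TensorFlow Java API).
// Assumption: smoothing follows y' = y * (1 - s) + 0.5 * s.
final class LabelSmoothingSketch {
    private LabelSmoothingSketch() {}

    static double smooth(double label, double s) {
        // s = 0 leaves the label unchanged; s = 1 maps every label to 0.5.
        return label * (1.0 - s) + 0.5 * s;
    }
}
```

      With labelSmoothing = 0.2, a label of 1 becomes 0.9 and a label of 0 becomes 0.1, so the metric no longer rewards fully confident predictions as strongly.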
    • BinaryCrossentropy

      public BinaryCrossentropy(String name, boolean fromLogits, float labelSmoothing, long seed, Class<T> type)
      Creates a BinaryCrossentropy metric.
      Parameters:
      name - the name of this metric; if null, the name is Class.getSimpleName().
      fromLogits - Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution.
      labelSmoothing - the value used to smooth labels. When 0, no smoothing occurs. When > 0, the loss is computed between the predicted labels and a smoothed version of the true labels, where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing correspond to heavier smoothing.
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the type for the variables and result
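      When fromLogits is true, each prediction x is a logit rather than a probability. One common way to evaluate the loss for a label z is the numerically stable form max(x, 0) - x * z + log(1 + exp(-|x|)), which is the form documented for TensorFlow's sigmoid_cross_entropy_with_logits; whether this metric uses exactly that form internally is an assumption. The hypothetical LogitBceSketch below compares it with first converting the logit to a probability through a sigmoid.

```java
// Hypothetical helper (not part of the TensorFlow Java API) contrasting two ways
// of computing binary cross-entropy from a logit x and a label z.
final class LogitBceSketch {
    private LogitBceSketch() {}

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Numerically stable form: max(x, 0) - x * z + log(1 + exp(-|x|)).
    static double bceFromLogit(double z, double x) {
        return Math.max(x, 0.0) - x * z + Math.log1p(Math.exp(-Math.abs(x)));
    }

    // Naive form: convert the logit to a probability first, then apply the usual formula.
    static double bceViaSigmoid(double z, double x) {
        double p = sigmoid(x);
        return -(z * Math.log(p) + (1.0 - z) * Math.log(1.0 - p));
    }
}
```

      For a large logit such as x = 100 with z = 1, the sigmoid rounds to exactly 1.0 in double precision and the naive path yields NaN (0 * log(0)), while the stable form stays finite.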
  • Method Details

    • call

      public <U extends TNumber> Operand<U> call(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Class<U> resultType)
      Computes the binary cross-entropy loss between labels and predictions.
      Type Parameters:
      U - the data type for the result
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      labels - the truth values or labels; must have the same shape as predictions, shape = [batch_size, d0, .. dN].
      predictions - the predictions, shape = [batch_size, d0, .. dN].
      resultType - the data type for the result
      Returns:
      the binary cross-entropy loss, shape = [batch_size, d0, .. dN-1].
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
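      The shape change from [batch_size, d0, .. dN] to [batch_size, d0, .. dN-1] comes from reducing the last axis. The plain-Java sketch below handles two-dimensional inputs ([batch_size, d0] to [batch_size]), assuming the reduction is a mean over that axis, as in Keras's binary_crossentropy, and an assumed clipping constant of 1e-7. PerSampleBceSketch is hypothetical.

```java
// Hypothetical helper (not part of the TensorFlow Java API).
// labels and predictions have shape [batch_size, d0]; the result has shape [batch_size].
final class PerSampleBceSketch {
    private PerSampleBceSketch() {}

    static double[] perSample(double[][] labels, double[][] predictions) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("batch sizes differ");
        }
        double eps = 1e-7; // assumed clipping constant
        double[] out = new double[labels.length];
        for (int i = 0; i < labels.length; i++) {
            if (labels[i].length != predictions[i].length) {
                throw new IllegalArgumentException("last dimensions differ");
            }
            double sum = 0.0;
            for (int j = 0; j < labels[i].length; j++) {
                double y = labels[i][j];
                double p = Math.min(Math.max(predictions[i][j], eps), 1.0 - eps);
                sum += -(y * Math.log(p) + (1.0 - y) * Math.log(1.0 - p));
            }
            // Averaging over the last axis (d0 here) removes that dimension.
            out[i] = sum / labels[i].length;
        }
        return out;
    }
}
```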
    • getLoss

      public org.tensorflow.framework.metrics.impl.LossMetric getLoss()
      Gets the loss function.
      Returns:
      the loss function.
    • setLoss

      protected void setLoss(org.tensorflow.framework.metrics.impl.LossMetric loss)
      Sets the loss function for this wrapper.
      Parameters:
      loss - the loss function.
    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
      Creates operations that update the state of the mean metric. They call the loss function and pass the resulting loss to the Mean metric, which computes the weighted mean of the loss across iterations.
      Specified by:
      updateStateList in interface Metric
      Overrides:
      updateStateList in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      labels - the truth values or labels
      predictions - the predictions
      sampleWeights - optional coefficient for the loss. If a scalar is provided, the loss is simply scaled by the given value. If sampleWeights is a tensor of size [batch_size], the total loss for each sample of the batch is rescaled by the corresponding element in the sampleWeights vector. If the shape of sampleWeights is [batch_size, d0, .. dN-1] (or can be broadcast to this shape), each loss element of predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss functions reduce by one dimension, usually axis=-1.)
      Returns:
      a List of control operations that update the Mean state variables.
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
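      For per-sample losses of shape [batch_size], the first two sampleWeights cases amount to an element-wise product, with a scalar broadcast to every sample. SampleWeightSketch is a hypothetical plain-Java illustration of that rule; it does not cover the [batch_size, d0, .. dN-1] case.

```java
// Hypothetical helper (not part of the TensorFlow Java API).
// losses has shape [batch_size]; weights is a scalar (length 1) or has shape [batch_size].
final class SampleWeightSketch {
    private SampleWeightSketch() {}

    static double[] applyWeights(double[] losses, double[] weights) {
        if (weights.length != 1 && weights.length != losses.length) {
            throw new IllegalArgumentException("weights must be a scalar or match batch_size");
        }
        double[] out = new double[losses.length];
        for (int i = 0; i < losses.length; i++) {
            // A scalar weight is broadcast to every sample; otherwise use the matching element.
            double w = weights.length == 1 ? weights[0] : weights[i];
            out[i] = losses[i] * w;
        }
        return out;
    }
}
```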
    • init

      protected void init(Ops tf)
      Initializes the TensorFlow Ops.
      Specified by:
      init in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
    • resetStates

      public Op resetStates(Ops tf)
      Resets any state variables to their initial values.
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      Returns:
      the operation for doing the reset
    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
      Updates the metric variables based on the inputs. values is required; sampleWeights is optional.
      Specified by:
      updateStateList in interface Metric
      Overrides:
      updateStateList in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      values - the inputs to be passed to update state; must not be null.
      sampleWeights - sample weights to be applied to values; defaults to 1 if null.
      Returns:
      the result with a control dependency on update state Operands
      Throws:
      IllegalArgumentException - if values is null
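      The total and count variables exposed by getTotal and getCount suggest the usual weighted-mean bookkeeping: total accumulates the sum of weight * value, count accumulates the sum of weights, and the result is total / count. The hypothetical WeightedMeanSketch below models update (treating null sampleWeights as all ones, per the parameter description), result, and reset in plain Java; returning 0 when count is 0 is an assumption.

```java
// Hypothetical plain-Java model of weighted-mean state (not the TensorFlow Java implementation).
final class WeightedMeanSketch {
    private double total = 0.0; // running sum of weight * value
    private double count = 0.0; // running sum of weights

    // values: one batch of loss values; weights: same length as values, or null (treated as all 1s).
    void update(double[] values, double[] weights) {
        if (values == null) {
            throw new IllegalArgumentException("values must not be null");
        }
        if (weights != null && weights.length != values.length) {
            throw new IllegalArgumentException("weights must match values");
        }
        for (int i = 0; i < values.length; i++) {
            double w = (weights == null) ? 1.0 : weights[i];
            total += w * values[i];
            count += w;
        }
    }

    // Assumption: returns 0 when no weight has been accumulated yet.
    double result() {
        return count == 0.0 ? 0.0 : total / count;
    }

    // Restores the initial state, as resetStates does for the metric variables.
    void reset() {
        total = 0.0;
        count = 0.0;
    }
}
```

      Because state persists between calls to update, the result is a mean over every batch seen since the last reset, not just the most recent one.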
    • result

      public <U extends TNumber> Operand<U> result(Ops tf, Class<U> type)
      Gets the current result of the metric.
      Type Parameters:
      U - the data type for the result
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      type - the data type for the result
      Returns:
      the result, possibly with control dependencies
    • getTotal

      public Variable<T> getTotal()
      Gets the total variable.
      Returns:
      the total variable
    • getCount

      public Variable<T> getCount()
      Gets the count variable.
      Returns:
      the count variable
    • getInternalType

      public Class<T> getInternalType()
      Gets the type for the variables.
      Returns:
      the type for the variables