Class BinaryAccuracy<T extends TNumber>

Type Parameters:
T - The data type for the metric result
All Implemented Interfaces:
Metric

public class BinaryAccuracy<T extends TNumber> extends Mean<T>
Metric that calculates how often predictions match binary labels.

This metric creates two local variables, total and count, that are used to compute the frequency with which predictions match labels. This frequency is ultimately returned as binary accuracy: an idempotent operation that simply divides total by count.

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values.
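The total/count bookkeeping described above can be modeled in plain Java, without TensorFlow. This is a minimal sketch. The class `RunningBinaryAccuracy` and its methods are hypothetical and not part of the library. The sketch assumes each element has already been reduced to a 0/1 match value, so `total` accumulates weighted matches and `count` accumulates weights. It also assumes the result is 0 when nothing has been counted yet.

```java
/** Hypothetical plain-Java model of the total/count state; not the TensorFlow API. */
final class RunningBinaryAccuracy {
    private double total = 0.0; // running sum of weight * match
    private double count = 0.0; // running sum of weights

    /** Accumulates one batch; matches[i] is 1.0 if prediction i matched its label, else 0.0. */
    public void update(double[] matches, double[] sampleWeights) {
        for (int i = 0; i < matches.length; i++) {
            // A null weight array means every weight defaults to 1; a weight of 0 masks the element.
            double w = (sampleWeights == null) ? 1.0 : sampleWeights[i];
            total += w * matches[i];
            count += w;
        }
    }

    /** Divides total by count without modifying state (returns 0 when count is 0; an assumption). */
    public double result() {
        return count == 0.0 ? 0.0 : total / count;
    }

    /** Returns both state variables to their initial value of zero. */
    public void reset() {
        total = 0.0;
        count = 0.0;
    }

    public static void main(String[] args) {
        RunningBinaryAccuracy acc = new RunningBinaryAccuracy();
        acc.update(new double[] {1, 0, 1, 1}, null);            // 3 of 4 match
        System.out.println(acc.result());                      // 0.75
        acc.update(new double[] {0, 0}, new double[] {1, 0});  // second element is masked
        System.out.println(acc.result());                      // 3 / 5 = 0.6
    }
}
```

Because `result()` only reads the two accumulators, calling it repeatedly returns the same value. This mirrors the "idempotent" wording above.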

  • Field Details

  • Constructor Details

    • BinaryAccuracy

      public BinaryAccuracy(long seed, Class<T> type)
      Creates a BinaryAccuracy Metric using Class.getSimpleName() for the metric name and DEFAULT_THRESHOLD for the threshold value.
      Parameters:
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • BinaryAccuracy

      public BinaryAccuracy(float threshold, long seed, Class<T> type)
      Creates a BinaryAccuracy Metric using Class.getSimpleName() for the metric name.
      Parameters:
      threshold - a threshold for deciding whether prediction values are 1 or 0
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • BinaryAccuracy

      public BinaryAccuracy(String name, float threshold, long seed, Class<T> type)
      Creates a BinaryAccuracy Metric.
      Parameters:
      name - the name of the metric; if null, Class.getSimpleName() is used
      threshold - a threshold for deciding whether prediction values are 1 or 0
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
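The following plain-Java sketch shows how a threshold might turn raw prediction scores into 0/1 values before they are compared with the labels. It makes two assumptions that this page does not confirm. First, as in Keras, a prediction counts as 1 only when it is strictly greater than the threshold. Second, 0.5 is used as a stand-in for DEFAULT_THRESHOLD. `BinaryThreshold` is a hypothetical helper.

```java
/** Hypothetical helper illustrating threshold-based binarization; not the TensorFlow API. */
final class BinaryThreshold {
    // Stand-in for DEFAULT_THRESHOLD; the real constant's value is assumed here.
    static final float ASSUMED_DEFAULT_THRESHOLD = 0.5f;

    /** Returns 1.0f where prediction > threshold (assumed strict comparison), else 0.0f. */
    static float[] binarize(float[] predictions, float threshold) {
        float[] out = new float[predictions.length];
        for (int i = 0; i < predictions.length; i++) {
            out[i] = predictions[i] > threshold ? 1.0f : 0.0f;
        }
        return out;
    }

    /** Fraction of elements whose binarized prediction equals the label. */
    static float accuracy(float[] labels, float[] predictions, float threshold) {
        float[] binary = binarize(predictions, threshold);
        int matches = 0;
        for (int i = 0; i < labels.length; i++) {
            if (binary[i] == labels[i]) {
                matches++;
            }
        }
        return (float) matches / labels.length;
    }

    public static void main(String[] args) {
        float[] labels = {1f, 0f, 1f, 0f};
        float[] predictions = {0.9f, 0.4f, 0.6f, 0.7f};
        System.out.println(accuracy(labels, predictions, ASSUMED_DEFAULT_THRESHOLD)); // 0.75
        System.out.println(accuracy(labels, predictions, 0.65f));                     // 0.5
    }
}
```

Raising the threshold makes the metric stricter about what counts as a positive prediction. In this example that flips the 0.6 score from 1 to 0.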
  • Method Details

    • call

      public <U extends TNumber> Operand<U> call(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Class<U> resultType)
      Calculates how often predictions match binary labels.
      Type Parameters:
      U - the data type for the result
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      labels - the truth values or labels, shape = [batch_size, d0, .. dN].
      predictions - the predictions, shape = [batch_size, d0, .. dN].
      resultType - the data type for the result
      Returns:
      Binary accuracy values. shape = [batch_size, d0, .. dN-1]
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
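The return shape [batch_size, d0, .. dN-1] indicates that element-wise matches are averaged along the last axis. Each sample keeps its leading dimensions and loses dN. The following plain-Java sketch covers the 2-D case ([batch_size, d0]). It assumes the predictions have already been thresholded to 0/1. `LastAxisAccuracy` and `perSampleAccuracy` are hypothetical names.

```java
import java.util.Arrays;

/** Hypothetical sketch of reducing element matches over the last axis; not the TensorFlow API. */
final class LastAxisAccuracy {
    /** labels and binaryPredictions have shape [batch][d]; the result has shape [batch]. */
    static double[] perSampleAccuracy(double[][] labels, double[][] binaryPredictions) {
        double[] result = new double[labels.length];
        for (int b = 0; b < labels.length; b++) {
            int matches = 0;
            for (int j = 0; j < labels[b].length; j++) {
                if (labels[b][j] == binaryPredictions[b][j]) {
                    matches++;
                }
            }
            // Mean over the last axis: the d dimension is reduced away.
            result[b] = (double) matches / labels[b].length;
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] labels = {{1, 0, 1, 1}, {0, 0, 1, 0}};
        double[][] preds = {{1, 0, 0, 1}, {0, 0, 1, 0}};
        System.out.println(Arrays.toString(perSampleAccuracy(labels, preds))); // [0.75, 1.0]
    }
}
```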
    • getLoss

      public org.tensorflow.framework.metrics.impl.LossMetric getLoss()
      Gets the loss function.
      Returns:
      the loss function.
    • setLoss

      protected void setLoss(org.tensorflow.framework.metrics.impl.LossMetric loss)
      Sets the loss function for this wrapper.
      Parameters:
      loss - the loss function.
    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
      Creates operations that update the state of the mean metric by calling the loss function and passing the loss to the Mean metric, which computes the weighted mean of the loss over many iterations.
      Specified by:
      updateStateList in interface Metric
      Overrides:
      updateStateList in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      labels - the truth values or labels
      predictions - the predictions
      sampleWeights - Optional coefficient for the loss. If a scalar is provided, the loss is simply scaled by the given value. If sampleWeights is a tensor of size [batch_size], the total loss for each sample of the batch is rescaled by the corresponding element in the sampleWeights vector. If the shape of sampleWeights is [batch_size, d0, .. dN-1] (or can be broadcast to this shape), each loss element of predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss functions reduce by 1 dimension, usually axis=-1.)
      Returns:
      a List of control operations that update the Mean state variables.
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
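The sampleWeights shapes listed above can be illustrated with a small plain-Java model of a weighted mean over per-sample values of shape [batch_size]. The model assumes that, as in a Mean-style metric, weights are broadcast to the values' shape. Each weight is then multiplied into total and also summed into count. Under that assumption, a scalar weight cancels out of the final ratio, while per-sample weights change it. `WeightedMeanSketch` and `weightedMean` are hypothetical names.

```java
/** Hypothetical weighted-mean model illustrating sampleWeights broadcasting; not the TensorFlow API. */
final class WeightedMeanSketch {
    /**
     * values has shape [batch]; weights has length 1 (a scalar, broadcast to every sample)
     * or length batch (one weight per sample).
     */
    static double weightedMean(double[] values, double[] weights) {
        if (weights.length != 1 && weights.length != values.length) {
            throw new IllegalArgumentException("weights cannot be broadcast to values");
        }
        double total = 0.0;
        double count = 0.0;
        for (int i = 0; i < values.length; i++) {
            double w = weights.length == 1 ? weights[0] : weights[i]; // broadcast a scalar weight
            total += w * values[i];
            count += w;
        }
        return total / count;
    }

    public static void main(String[] args) {
        double[] perSampleValues = {1.0, 0.0, 0.5, 0.5};
        System.out.println(weightedMean(perSampleValues, new double[] {3.0}));        // 0.5 (scalar cancels)
        System.out.println(weightedMean(perSampleValues, new double[] {1, 0, 0, 1})); // 0.75
        System.out.println(weightedMean(perSampleValues, new double[] {2, 1, 1, 0})); // 0.625
    }
}
```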
    • init

      protected void init(Ops tf)
      Initializes the TensorFlow Ops.
      Specified by:
      init in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
    • resetStates

      public Op resetStates(Ops tf)
      Resets any state variables to their initial values.
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      Returns:
      the operation for doing the reset
    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
      Updates the metric variables based on the inputs. values is required; sampleWeights is an optional additional input.
      Specified by:
      updateStateList in interface Metric
      Overrides:
      updateStateList in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      values - the inputs to be passed to update state; must not be null
      sampleWeights - sample weights to apply to values; defaults to 1 if null.
      Returns:
      the result with a control dependency on update state Operands
      Throws:
      IllegalArgumentException - if values is null
    • result

      public <U extends TNumber> Operand<U> result(Ops tf, Class<U> type)
      Gets the current result of the metric.
      Type Parameters:
      U - the data type for the result
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      type - the data type for the result
      Returns:
      the result, possibly with control dependencies
    • getTotal

      public Variable<T> getTotal()
      Gets the total variable
      Returns:
      the total variable
    • getCount

      public Variable<T> getCount()
      Gets the count variable
      Returns:
      the count variable
    • getInternalType

      public Class<T> getInternalType()
      Gets the type for the variables
      Returns:
      the type for the variables