Class MeanIoU<T extends TNumber>

java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.MeanIoU<T>
Type Parameters:
T - The data type for the metric result
All Implemented Interfaces:
Metric

public class MeanIoU<T extends TNumber> extends BaseMetric
Computes the mean Intersection-Over-Union metric.

Mean Intersection-Over-Union (IoU) is a common evaluation metric for semantic image segmentation. It first computes the IoU for each semantic class and then averages the per-class values. IoU is defined as IoU = true_positive / (true_positive + false_positive + false_negative). The predictions are accumulated in a confusion matrix, weighted by sampleWeights, and the metric is then computed from that matrix.
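To make the formula concrete, the following plain-Java sketch computes the per-class IoU and their average from a confusion matrix. It does not use the TensorFlow Java API. The class name MeanIouFromConfusionMatrix is hypothetical, and the sketch assumes rows index true labels and columns index predictions (each class's IoU is the same under either orientation, because false positives and false negatives both appear in the denominator). It also assumes every class occurs at least once; otherwise the denominator is 0 and the division yields NaN.

```java
/**
 * Hypothetical helper, not part of the TensorFlow Java API: computes per-class IoU
 * and their mean from a square confusion matrix in plain Java.
 */
final class MeanIouFromConfusionMatrix {

    /** IoU of class c = TP / (TP + FP + FN). Rows are true labels, columns are predictions. */
    static double classIou(double[][] cm, int c) {
        double truePositive = cm[c][c];
        double falsePositive = 0.0; // predicted as c, but the true label is another class
        double falseNegative = 0.0; // true label is c, but predicted as another class
        for (int k = 0; k < cm.length; k++) {
            if (k == c) {
                continue;
            }
            falsePositive += cm[k][c];
            falseNegative += cm[c][k];
        }
        // Assumes class c occurs at least once; a zero denominator gives NaN here.
        return truePositive / (truePositive + falsePositive + falseNegative);
    }

    /** Unweighted average of the per-class IoU values. */
    static double meanIou(double[][] cm) {
        double sum = 0.0;
        for (int c = 0; c < cm.length; c++) {
            sum += classIou(cm, c);
        }
        return sum / cm.length;
    }

    public static void main(String[] args) {
        // Two classes; cm[label][prediction].
        double[][] cm = {
            {3, 1}, // label 0: 3 predicted as 0, 1 predicted as 1
            {2, 4}  // label 1: 2 predicted as 0, 4 predicted as 1
        };
        System.out.println("IoU(0) = " + classIou(cm, 0));
        System.out.println("IoU(1) = " + classIou(cm, 1));
        System.out.println("mean   = " + meanIou(cm));
    }
}
```

For the matrix in main, class 0 has TP = 3, FP = 2, FN = 1, so IoU(0) = 3/6 = 0.5; class 1 has TP = 4, FP = 1, FN = 2, so IoU(1) = 4/7; the mean is (0.5 + 4/7) / 2 ≈ 0.536.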

If sampleWeights is null, the weights default to 1. Use a sample weight of 0 to mask values.
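A minimal sketch of how sample weights could enter the confusion matrix, written in plain Java rather than TensorFlow ops. WeightedConfusionMatrix and accumulate are hypothetical names; the sketch assumes labels and predictions are already class indices and that there is one weight per example. Each example adds its weight to the cell at (label, prediction), so a null weight array behaves like all-ones weights, and a weight of 0 masks the example.

```java
/**
 * Hypothetical sketch (not the library's implementation): accumulates a weighted
 * confusion matrix from flat label/prediction arrays. A null weights array means
 * every example gets weight 1; a weight of 0 leaves the matrix unchanged (masking).
 */
final class WeightedConfusionMatrix {

    static double[][] accumulate(int numClasses, int[] labels, int[] predictions, double[] weights) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions must have the same size");
        }
        if (weights != null && weights.length != labels.length) {
            throw new IllegalArgumentException("weights must have one entry per example");
        }
        double[][] cm = new double[numClasses][numClasses];
        for (int i = 0; i < labels.length; i++) {
            double w = (weights == null) ? 1.0 : weights[i];
            cm[labels[i]][predictions[i]] += w; // row = true label, column = prediction
        }
        return cm;
    }

    public static void main(String[] args) {
        int[] labels = {0, 0, 1, 1};
        int[] predictions = {0, 1, 1, 1};
        // Unweighted: all four examples count once.
        double[][] all = accumulate(2, labels, predictions, null);
        // Masked: the second example (label 0 predicted as 1) has weight 0 and is ignored.
        double[][] masked = accumulate(2, labels, predictions, new double[] {1, 0, 1, 1});
        System.out.println(java.util.Arrays.deepToString(all));    // prints [[1.0, 1.0], [0.0, 2.0]]
        System.out.println(java.util.Arrays.deepToString(masked)); // prints [[1.0, 0.0], [0.0, 2.0]]
    }
}
```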

  • Constructor Details

    • MeanIoU

      protected MeanIoU(long numClasses, long seed, Class<T> type)
      Creates a MeanIoU metric, using Class.getSimpleName() as the name
      Parameters:
      numClasses - the number of possible labels the prediction task can have
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
    • MeanIoU

      protected MeanIoU(String name, long numClasses, long seed, Class<T> type)
      Creates a MeanIoU metric
      Parameters:
      name - the name of the metric; if null, Class.getSimpleName() is used
      numClasses - the number of possible labels the prediction task can have
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the data type for the variables
  • Method Details

    • init

      protected void init(Ops tf)
      Initializes the TensorFlow Ops
      Specified by:
      init in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
    • resetStates

      public Op resetStates(Ops tf)
      Resets any state variables to their initial values
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      Returns:
      the operation for doing the reset
    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
      Accumulates the confusion matrix statistics.
      Specified by:
      updateStateList in interface Metric
      Overrides:
      updateStateList in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      labels - the labels
      predictions - the predictions
      sampleWeights - optional weighting of each example; defaults to 1 if null. Its rank must be either 0 or equal to the rank of labels, and it must be broadcastable to labels.
      Returns:
      the operations that update the totalConfusionMatrix variable
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
      IllegalArgumentException - if the weights rank is neither 0 nor equal to the labels rank, or if the predictions size is not equal to the labels size
    • result

      public <U extends TNumber> Operand<U> result(Ops tf, Class<U> resultType)
      Gets the current result of the metric
      Type Parameters:
      U - the data type for the result
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      resultType - the data type for the result
      Returns:
      the result, possibly with control dependencies
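The lifecycle implied by resetStates can be sketched without TensorFlow: the accumulated confusion matrix grows across update calls and returns to its initial value on reset. ResettableConfusionMatrix is a hypothetical class, and the sketch assumes the initial value is an all-zero numClasses by numClasses matrix with unit weights per example.

```java
import java.util.Arrays;

/**
 * Hypothetical, framework-free analogue of the metric's state lifecycle: state accumulates
 * across update calls until reset() zeroes it. Assumes the initial state is all zeros.
 */
final class ResettableConfusionMatrix {
    private final double[][] total;

    ResettableConfusionMatrix(int numClasses) {
        total = new double[numClasses][numClasses];
    }

    /** Adds one batch of (label, prediction) pairs, each with weight 1. */
    void update(int[] labels, int[] predictions) {
        for (int i = 0; i < labels.length; i++) {
            total[labels[i]][predictions[i]] += 1.0;
        }
    }

    /** Restores the assumed initial (all-zero) state. */
    void reset() {
        for (double[] row : total) {
            Arrays.fill(row, 0.0);
        }
    }

    /** Number of accumulated examples, i.e., the sum of all matrix cells. */
    double count() {
        double sum = 0.0;
        for (double[] row : total) {
            for (double v : row) {
                sum += v;
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        ResettableConfusionMatrix m = new ResettableConfusionMatrix(2);
        m.update(new int[] {0, 1}, new int[] {0, 0});
        m.update(new int[] {1}, new int[] {1});
        System.out.println(m.count()); // prints 3.0: both batches were accumulated
        m.reset();
        System.out.println(m.count()); // prints 0.0: state is back to its initial value
    }
}
```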
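The argument checks listed in the Throws section of updateStateList can be restated as a small plain-Java function. UpdateStateChecks and check are hypothetical names; shapes are modeled as long[] arrays and sizes are compared as element counts. The broadcastability requirement on sampleWeights and the Graph-environment check are not covered by this sketch.

```java
/**
 * Hypothetical, framework-free restatement of the documented argument checks for
 * updateStateList. Shapes are plain long[] arrays; this is not TensorFlow's own code.
 */
final class UpdateStateChecks {

    /** Total number of elements for a shape; an empty shape (a scalar) has size 1. */
    static long size(long[] shape) {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    /**
     * @param weightsShape null when no sample weights are given (weights then default to 1)
     * @throws IllegalArgumentException if the weights rank is neither 0 nor the labels rank,
     *     or if the predictions size differs from the labels size
     */
    static void check(long[] labelsShape, long[] predictionsShape, long[] weightsShape) {
        if (weightsShape != null && weightsShape.length != 0
                && weightsShape.length != labelsShape.length) {
            throw new IllegalArgumentException("weights rank " + weightsShape.length
                    + " must be 0 or equal to labels rank " + labelsShape.length);
        }
        if (size(predictionsShape) != size(labelsShape)) {
            throw new IllegalArgumentException("predictions size " + size(predictionsShape)
                    + " != labels size " + size(labelsShape));
        }
    }

    public static void main(String[] args) {
        check(new long[] {4, 2}, new long[] {8}, new long[] {}); // scalar weight, equal sizes: accepted
        check(new long[] {4, 2}, new long[] {4, 2}, null);      // no weights: accepted
        try {
            check(new long[] {4, 2}, new long[] {4, 2}, new long[] {4}); // weights rank 1 vs labels rank 2
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```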
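A sketch of what result could compute from the accumulated confusion matrix, in plain Java. This page does not say how classes that never occur are handled; the sketch assumes the behavior of Keras' tf.keras.metrics.MeanIoU, which leaves classes with a zero denominator out of the average and returns 0 when no class is valid. MeanIouResult is a hypothetical name, and the conversion to the requested resultType is omitted.

```java
/**
 * Hypothetical sketch of reading a mean-IoU result from an accumulated confusion matrix.
 * ASSUMPTION (not stated on this page): as in Keras' MeanIoU, classes whose denominator
 * TP + FP + FN is 0 are excluded from the average, and an empty matrix yields 0, not NaN.
 */
final class MeanIouResult {

    static double result(double[][] cm) {
        int n = cm.length;
        double iouSum = 0.0;
        int validClasses = 0;
        for (int c = 0; c < n; c++) {
            double rowSum = 0.0; // TP + FN: examples whose true label is c
            double colSum = 0.0; // TP + FP: examples predicted as c
            for (int k = 0; k < n; k++) {
                rowSum += cm[c][k];
                colSum += cm[k][c];
            }
            double tp = cm[c][c];
            double denominator = rowSum + colSum - tp; // TP + FP + FN
            if (denominator != 0.0) {
                iouSum += tp / denominator;
                validClasses++;
            }
        }
        return validClasses == 0 ? 0.0 : iouSum / validClasses;
    }

    public static void main(String[] args) {
        // Class 2 never occurs in labels or predictions, so it is left out of the mean.
        double[][] cm = {
            {3, 1, 0},
            {2, 4, 0},
            {0, 0, 0}
        };
        System.out.println(result(cm));
        System.out.println(result(new double[3][3])); // nothing accumulated yet
    }
}
```

With class 2 absent, the result is the mean over the two valid classes, (0.5 + 4/7) / 2, rather than an average over all three.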