Class MeanIoU<T extends TNumber>
java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.MeanIoU<T>
- Type Parameters:
T - the data type for the metric result
- All Implemented Interfaces:
Metric
Computes the mean Intersection-Over-Union metric.
Mean Intersection-Over-Union is a common evaluation metric for semantic image segmentation.
It first computes the IoU for each semantic class and then averages the IoU over the classes.
IoU is defined as follows: IoU = true_positive / (true_positive + false_positive +
false_negative). The predictions are accumulated in a confusion matrix, weighted by
sampleWeights, and the metric is then calculated from that matrix.
If sampleWeights is null, the weights default to 1. Use a sample weight of 0 to mask
values.
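The description does not spell out how the true positives, false positives and false negatives are read from the confusion matrix. Writing C for the accumulated matrix, with rows indexed by the true label and columns by the prediction (an assumed orientation, matching tf.math.confusion_matrix), each cell holding the sum of the sample weights that fell into it, and N for numClasses:

```latex
\begin{aligned}
\mathrm{TP}_i &= C_{ii}, \qquad
\mathrm{FP}_i = \sum_{j=0}^{N-1} C_{ji} - C_{ii}, \qquad
\mathrm{FN}_i = \sum_{j=0}^{N-1} C_{ij} - C_{ii},\\[4pt]
\mathrm{IoU}_i &= \frac{\mathrm{TP}_i}{\mathrm{TP}_i + \mathrm{FP}_i + \mathrm{FN}_i}
              = \frac{C_{ii}}{\sum_{j} C_{ij} + \sum_{j} C_{ji} - C_{ii}},\\[4pt]
\mathrm{MeanIoU} &= \frac{1}{N}\sum_{i=0}^{N-1} \mathrm{IoU}_i
\end{aligned}
```

Transposing the row/column convention only swaps FP_i and FN_i, so each IoU_i is unchanged. Keras's tf.keras.metrics.MeanIoU averages only over classes whose denominator is nonzero; the description above does not say whether this class does the same.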
Method Summary
- protected void init(Ops tf): Initialize the TensorFlow Ops
- resetStates(Ops tf): Resets any state variables to their initial values
- result(Ops tf, Class<U> resultType): Gets the current result of the metric
- List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights): Accumulates the confusion matrix statistics.
Methods inherited from class BaseMetric
callOnce, checkIsGraph, getName, getSeed, getTF, getVariableName, isInitialized, setInitialized, setName, setTF, updateState, updateState, updateStateList
- callOnce(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights, Class<T> type): Calls update state once, followed by a call to get the result.
- protected void checkIsGraph(Ops tf): Checks if the TensorFlow Ops encapsulates a Graph environment.
- getName(): The name for this metric.
- long getSeed(): Gets the random number generator seed value.
- protected Ops getTF(): Gets the TensorFlow Ops for this metric.
- protected String getVariableName(String varName): Gets a formatted name for a variable, in the form BaseMetric.name + "_" + varName.
- boolean isInitialized: Checks whether the Metric is initialized or not.
- protected void setInitialized(boolean initialized): Sets the initialized indicator.
- void setName: Sets the metric name.
- protected void setTF: Sets the TensorFlow Ops for this metric.
- final Op updateState: Creates a NoOp Operation with control dependencies to update the metric state.
- final Op updateState(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights): Creates a NoOp Operation with control dependencies to update the metric state.
- updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights): Creates a List of Operations to update the metric state based on input values.
-
Field Details
-
TOTAL_CONFUSION_MATRIX
-
Constructor Details
-
MeanIoU
Creates a MeanIoU metric, using Class.getSimpleName() as the metric name.
- Parameters:
numClasses - the number of possible labels the prediction task can have
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the variables
-
MeanIoU
Creates a MeanIoU metric.
- Parameters:
name - the name of the metric; if null, Class.getSimpleName() is used
numClasses - the number of possible labels the prediction task can have
seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
type - the data type for the variables
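Both constructors describe the same setup: numClasses fixes how many labels the metric tracks, and a null name falls back to Class.getSimpleName(). The following is a minimal plain-Java sketch of that setup, assuming the state is a numClasses by numClasses matrix of zeros (the TOTAL_CONFUSION_MATRIX field suggests such a variable). MeanIouStateSketch and its methods are hypothetical, and seed and type are omitted because plain arrays need neither.

```java
import java.util.Arrays;

// Hypothetical sketch of the state a MeanIoU-style metric sets up when constructed.
// NOT the TensorFlow Java implementation; it only mirrors the documented behaviour
// of the name and numClasses constructor parameters.
class MeanIouStateSketch {
    private final String name;
    private final int numClasses;
    // numClasses x numClasses accumulator; rows = true labels, columns = predictions (assumed).
    private final double[][] totalConfusionMatrix;

    MeanIouStateSketch(String name, int numClasses) {
        // The positivity check is a choice of this sketch, not documented behaviour.
        if (numClasses <= 0) {
            throw new IllegalArgumentException("numClasses must be positive: " + numClasses);
        }
        // Documented default: a null name falls back to the simple class name.
        this.name = (name == null) ? getClass().getSimpleName() : name;
        this.numClasses = numClasses;
        this.totalConfusionMatrix = new double[numClasses][numClasses]; // Java zero-fills arrays
    }

    MeanIouStateSketch(int numClasses) {
        this(null, numClasses);
    }

    String getName() {
        return name;
    }

    // Returns a copy so callers cannot mutate the accumulated state.
    double[][] confusionMatrix() {
        double[][] copy = new double[numClasses][];
        for (int i = 0; i < numClasses; i++) {
            copy[i] = Arrays.copyOf(totalConfusionMatrix[i], numClasses);
        }
        return copy;
    }

    public static void main(String[] args) {
        MeanIouStateSketch unnamed = new MeanIouStateSketch(3);
        MeanIouStateSketch named = new MeanIouStateSketch("segmentation_miou", 3);
        System.out.println(unnamed.getName()); // MeanIouStateSketch
        System.out.println(named.getName());   // segmentation_miou
        System.out.println(Arrays.deepToString(unnamed.confusionMatrix()));
    }
}
```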
-
-
Method Details
-
init
Initializes the TensorFlow Ops.
- Specified by:
init in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
-
resetStates
Resets any state variables to their initial values.
-
updateStateList
public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
Accumulates the confusion matrix statistics.
- Specified by:
updateStateList in interface Metric
- Overrides:
updateStateList in class BaseMetric
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
labels - the labels
predictions - the predictions
sampleWeights - optional weighting of each example. Defaults to 1 if null. The rank is either 0 or the same as the rank of labels, and the weights must be broadcastable to labels.
- Returns:
- the operations that update the totalConfusionMatrix variable
- Throws:
IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
IllegalArgumentException - if the weights rank is not 0 and the weights rank != labels rank, or if the predictions size is not equal to the labels size
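The accumulation and the argument checks listed above can be sketched in plain Java, with int arrays standing in for the labels and predictions tensors, and either a double[] (same rank as labels) or a single double (rank 0) standing in for sampleWeights. ConfusionAccumulatorSketch is hypothetical: the real method builds TensorFlow operations and handles tensors of any rank rather than mutating an array.

```java
import java.util.Arrays;

// Hypothetical sketch of the accumulation step described for updateStateList.
// Plain Java arrays stand in for tensors; this is not the TensorFlow Java code.
class ConfusionAccumulatorSketch {
    private final double[][] totalConfusionMatrix;

    ConfusionAccumulatorSketch(int numClasses) {
        totalConfusionMatrix = new double[numClasses][numClasses];
    }

    // Rank-0 case: one scalar weight is broadcast to every example.
    void update(int[] labels, int[] predictions, double scalarWeight) {
        double[] weights = new double[labels.length];
        Arrays.fill(weights, scalarWeight);
        update(labels, predictions, weights);
    }

    // Same-rank case: one weight per label; null means every weight is 1.
    void update(int[] labels, int[] predictions, double[] sampleWeights) {
        if (predictions.length != labels.length) {
            throw new IllegalArgumentException(
                "predictions size " + predictions.length + " != labels size " + labels.length);
        }
        if (sampleWeights != null && sampleWeights.length != labels.length) {
            throw new IllegalArgumentException(
                "sampleWeights size " + sampleWeights.length + " != labels size " + labels.length);
        }
        for (int i = 0; i < labels.length; i++) {
            double w = (sampleWeights == null) ? 1.0 : sampleWeights[i];
            // Row = true label, column = predicted label; a weight of 0 masks the example.
            // Class ids outside [0, numClasses) throw ArrayIndexOutOfBoundsException here.
            totalConfusionMatrix[labels[i]][predictions[i]] += w;
        }
    }

    // Returns a copy of the accumulated matrix.
    double[][] matrix() {
        double[][] copy = new double[totalConfusionMatrix.length][];
        for (int i = 0; i < copy.length; i++) {
            copy[i] = totalConfusionMatrix[i].clone();
        }
        return copy;
    }

    public static void main(String[] args) {
        ConfusionAccumulatorSketch acc = new ConfusionAccumulatorSketch(3);
        acc.update(new int[] {0, 0, 1, 2}, new int[] {0, 1, 1, 2}, new double[] {1, 1, 0, 1});
        acc.update(new int[] {2}, new int[] {0}, 0.5);
        // Prints [[1.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.5, 0.0, 1.0]]
        System.out.println(Arrays.deepToString(acc.matrix()));
    }
}
```

Row 1 stays empty because the only example with label 1 carries weight 0, and the scalar 0.5 is broadcast to the single example of the second batch.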
-
result
Gets the current result of the metric.
- Type Parameters:
U - the data type for the result
- Parameters:
tf - the TensorFlow Ops encapsulating a Graph environment.
resultType - the data type for the result
- Returns:
- the result, possibly with control dependencies
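The step from the accumulated confusion matrix to the value that result returns can be sketched in plain Java. MeanIouResultSketch is hypothetical and returns a double instead of an Operand of resultType. Skipping classes whose denominator is zero follows Keras's tf.keras.metrics.MeanIoU and is an assumption for this class.

```java
// Hypothetical sketch of turning an accumulated confusion matrix into a mean IoU.
// Following Keras's tf.keras.metrics.MeanIoU (an assumption for this class),
// classes whose denominator TP + FP + FN is zero are left out of the average.
class MeanIouResultSketch {
    static double meanIou(double[][] cm) {
        int n = cm.length;
        double iouSum = 0.0;
        int validClasses = 0;
        for (int i = 0; i < n; i++) {
            double rowSum = 0.0; // TP + FN for class i (row = true label)
            double colSum = 0.0; // TP + FP for class i (column = prediction)
            for (int j = 0; j < n; j++) {
                rowSum += cm[i][j];
                colSum += cm[j][i];
            }
            double tp = cm[i][i];
            double denominator = rowSum + colSum - tp; // TP + FP + FN
            if (denominator != 0.0) {
                iouSum += tp / denominator;
                validClasses++;
            }
        }
        // Empty matrix: report 0 rather than NaN.
        return validClasses == 0 ? 0.0 : iouSum / validClasses;
    }

    public static void main(String[] args) {
        double[][] cm = {
            {2, 1, 0},
            {0, 1, 0},
            {0, 0, 0}, // class 2 never occurs, so it is skipped
        };
        // Class 0: 2 / 3, class 1: 1 / 2, mean = 7 / 12
        System.out.println(meanIou(cm));
    }
}
```

For the example matrix, class 0 has IoU 2/3, class 1 has IoU 1/2 and class 2 is skipped, so the printed mean is 7/12 (about 0.583).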
-