Class BaseMetric

java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
All Implemented Interfaces:
Metric
Direct Known Subclasses:
AUC, FalseNegatives, FalsePositives, Mean, MeanIoU, MeanTensor, Precision, PrecisionAtRecall, Recall, RecallAtPrecision, SensitivityAtSpecificity, SpecificityAtSensitivity, Sum, TrueNegatives, TruePositives

public abstract class BaseMetric extends Object implements Metric
Abstract base class for metrics. It stores the metric's name, random seed, TensorFlow Ops, and initialization state.
  • Constructor Details

    • BaseMetric

      protected BaseMetric(long seed)
      Creates a metric whose name defaults to the simple name of its class, as returned by Class.getSimpleName().
      Parameters:
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
    • BaseMetric

      protected BaseMetric(String name, long seed)
      Creates a metric with the given name.
      Parameters:
      name - the name for this metric. If null, the name defaults to Class.getSimpleName().
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
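The naming rule for the two constructors can be sketched in plain Java. This is a minimal, hypothetical model, not the TensorFlow Java source: MetricNamingSketch and AccuracySketch are invented names, and the fallback is assumed to be computed with getClass().getSimpleName(), so the default name is that of the concrete subclass rather than of the abstract base.

```java
// Hypothetical, self-contained sketch of the constructor naming rules.
// It does not use TensorFlow Java; all class names here are illustrative only.
abstract class MetricNamingSketch {
    private String name;
    private final long seed;

    // Mirrors BaseMetric(long seed): no explicit name is given.
    protected MetricNamingSketch(long seed) {
        this(null, seed);
    }

    // Mirrors BaseMetric(String name, long seed).
    protected MetricNamingSketch(String name, long seed) {
        // getClass() returns the runtime class, so the default name is the
        // concrete subclass's simple name, not that of this abstract base.
        this.name = (name == null) ? getClass().getSimpleName() : name;
        this.seed = seed;
    }

    public String getName() { return name; }
    public void setName(String name) { this.name = name; }
    public long getSeed() { return seed; }
}

// Hypothetical concrete metric used only to exercise the constructors.
class AccuracySketch extends MetricNamingSketch {
    AccuracySketch(long seed) { super(seed); }
    AccuracySketch(String name, long seed) { super(name, seed); }
}

class MetricNamingDemo {
    public static void main(String[] args) {
        System.out.println(new AccuracySketch(1001L).getName());                 // AccuracySketch
        System.out.println(new AccuracySketch("val_accuracy", 1001L).getName()); // val_accuracy
    }
}
```

Passing an explicit name is what lets two instances of the same metric class be told apart.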
  • Method Details

    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
      Creates a List of Operations to update the metric state based on input values.

      This default implementation is empty; override it in a subclass if needed.

      Specified by:
      updateStateList in interface Metric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      values - the inputs used to update the metric state; must not be null.
      sampleWeights - sample weights to apply to the values; may be null.
      Returns:
      a List of Operations to update the metric state
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not have a Graph environment.
    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
      Creates a List of Operations to update the metric state based on labels and predictions.

      This default implementation is empty; override it in a subclass if needed.

      Specified by:
      updateStateList in interface Metric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      labels - the labels (ground-truth values)
      predictions - the predictions
      sampleWeights - sample weights to apply to the metric values; may be null.
      Returns:
      a List of Operations to update the metric state
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not have a Graph environment.
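The contract of the two updateStateList overloads (empty by default, overridden where a metric needs it) can be illustrated with a stdlib-only sketch. It assumes simplified stand-ins: double[] in place of Operand, and Runnable in place of a graph Op that only takes effect when it is executed. StateListSketch and WeightedSumSketch are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical sketch of the "empty default, override if needed" contract of the
// two updateStateList overloads. It does not use TensorFlow Java: a Runnable
// stands in for a deferred graph Op, and double[] stands in for an Operand.
abstract class StateListSketch {
    // Default for the values overload: contributes no update operations.
    public List<Runnable> updateStateList(double[] values, double[] sampleWeights) {
        return Collections.emptyList();
    }

    // Default for the labels/predictions overload: contributes no update operations.
    public List<Runnable> updateStateList(double[] labels, double[] predictions, double[] sampleWeights) {
        return Collections.emptyList();
    }
}

// Hypothetical weighted-sum metric that overrides only the values overload.
class WeightedSumSketch extends StateListSketch {
    double total = 0.0; // stands in for a metric state variable

    @Override
    public List<Runnable> updateStateList(double[] values, double[] sampleWeights) {
        List<Runnable> ops = new ArrayList<>();
        // The update is only described here; it runs when the caller executes it.
        ops.add(() -> {
            for (int i = 0; i < values.length; i++) {
                // A null sampleWeights array means every value has weight 1.
                double w = (sampleWeights == null) ? 1.0 : sampleWeights[i];
                total += values[i] * w;
            }
        });
        return ops;
    }
}

class StateListDemo {
    public static void main(String[] args) {
        WeightedSumSketch sum = new WeightedSumSketch();
        for (Runnable op : sum.updateStateList(new double[] {1.0, 2.0, 3.0}, null)) {
            op.run();
        }
        System.out.println(sum.total); // 6.0
        // The labels/predictions overload was not overridden, so it contributes nothing.
        System.out.println(sum.updateStateList(new double[] {1.0}, new double[] {1.0}, null).isEmpty()); // true
    }
}
```

A metric that consumes labels and predictions (a confusion-matrix count, for example) would override the three-argument overload instead.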
    • updateState

      public final Op updateState(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
      Creates a NoOp operation whose control dependencies are the operations that update the metric state.
      Specified by:
      updateState in interface Metric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      values - the inputs used to update the metric state; must not be null.
      sampleWeights - sample weights to apply to the values; may be null.
      Returns:
      the Operation to update the metric state
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not have a Graph environment.
    • updateState

      public final Op updateState(Ops tf, Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
      Creates a NoOp operation whose control dependencies are the operations that update the metric state.
      Specified by:
      updateState in interface Metric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      labels - the labels (ground-truth values)
      predictions - the predictions
      sampleWeights - sample weights to apply to the metric values; may be null.
      Returns:
      the Operation to update the metric state
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not have a Graph environment.
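Because updateState is final while updateStateList is overridable, subclasses presumably customize the update through updateStateList and inherit the grouping into a single NoOp. The stdlib-only sketch below models that template-method relationship under loose assumptions: double[] stands in for Operand, a Runnable stands in for an Op, and the Runnable returned by updateState plays the role of the NoOp that carries the control dependencies. UpdateStateSketch and CountSketch are hypothetical names.

```java
import java.util.Collections;
import java.util.List;

// Hypothetical sketch of how a final updateState can be built on an
// overridable updateStateList (a template-method design). It does not use
// TensorFlow Java: running the returned "no-op" first runs its control
// dependencies, loosely mirroring a NoOp executed in a graph session.
abstract class UpdateStateSketch {
    // Overridable hook; the default contributes no update operations.
    public List<Runnable> updateStateList(double[] values, double[] sampleWeights) {
        return Collections.emptyList();
    }

    // Final: subclasses customize updateStateList, never this method.
    public final Runnable updateState(double[] values, double[] sampleWeights) {
        List<Runnable> controlDependencies = updateStateList(values, sampleWeights);
        return () -> {
            for (Runnable dependency : controlDependencies) {
                dependency.run();
            }
            // The no-op itself does nothing beyond triggering its dependencies.
        };
    }
}

// Hypothetical metric that counts how many values it has seen.
class CountSketch extends UpdateStateSketch {
    long count = 0;

    @Override
    public List<Runnable> updateStateList(double[] values, double[] sampleWeights) {
        return List.of(() -> count += values.length);
    }
}

class UpdateStateDemo {
    public static void main(String[] args) {
        CountSketch counter = new CountSketch();
        Runnable update = counter.updateState(new double[] {0.1, 0.2, 0.3}, null);
        update.run();
        update.run(); // each execution applies the update again
        System.out.println(counter.count); // 6
    }
}
```

Returning one operation gives callers a single handle to run, however many state variables a metric updates internally.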
    • callOnce

      public final <T extends TNumber> Operand<T> callOnce(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights, Class<T> type)
      Calls updateState once, then returns the metric result.
      Specified by:
      callOnce in interface Metric
      Type Parameters:
      T - The data type for the metric result
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      values - the inputs used to update the metric state; must not be null.
      sampleWeights - sample weights to apply to the values; may be null.
      type - the data type for the result
      Returns:
      the result, possibly with control dependencies
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not have a Graph environment.
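callOnce pairs one state update with a read of the result. The hypothetical sketch below shows that sequencing with plain doubles in place of tensors; CallOnceSketch and WeightedMeanSketch are invented names, and applyUpdate and result are illustrative method names, not the Metric interface's actual signatures.

```java
// Hypothetical sketch of the callOnce contract: apply a single state update,
// then return the metric result. It does not use TensorFlow Java; plain
// doubles stand in for tensors.
abstract class CallOnceSketch {
    protected abstract void applyUpdate(double[] values, double[] sampleWeights);

    public abstract double result();

    // Final, like BaseMetric.callOnce: one update followed by one read of the result.
    public final double callOnce(double[] values, double[] sampleWeights) {
        applyUpdate(values, sampleWeights);
        return result();
    }
}

// Hypothetical weighted-mean metric.
class WeightedMeanSketch extends CallOnceSketch {
    private double weightedTotal = 0.0;
    private double totalWeight = 0.0;

    @Override
    protected void applyUpdate(double[] values, double[] sampleWeights) {
        for (int i = 0; i < values.length; i++) {
            // A null sampleWeights array means every value has weight 1.
            double w = (sampleWeights == null) ? 1.0 : sampleWeights[i];
            weightedTotal += values[i] * w;
            totalWeight += w;
        }
    }

    @Override
    public double result() {
        // Report 0 before any weight has been seen instead of dividing by zero.
        return totalWeight == 0.0 ? 0.0 : weightedTotal / totalWeight;
    }
}

class CallOnceDemo {
    public static void main(String[] args) {
        WeightedMeanSketch mean = new WeightedMeanSketch();
        System.out.println(mean.callOnce(new double[] {1.0, 2.0, 3.0, 4.0}, null)); // 2.5
    }
}
```

In this sketch the accumulated state is kept between calls, so a second callOnce reports the mean over both batches.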
    • getVariableName

      protected String getVariableName(String varName)
      Gets a formatted name for a variable, in the form name + "_" + varName, where name is the name of this metric.
      Parameters:
      varName - the base name for the variable
      Returns:
      the formatted variable name
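The naming rule for getVariableName can be shown directly. In this hypothetical sketch, VariableNameSketch is an invented class, and the metric name "mean" and the variable names "total" and "count" are illustrative only.

```java
// Hypothetical sketch of the documented naming rule for metric variables:
// the metric name, an underscore, then the base variable name.
class VariableNameSketch {
    private final String name;

    VariableNameSketch(String name) {
        this.name = name;
    }

    String getVariableName(String varName) {
        return name + "_" + varName;
    }

    public static void main(String[] args) {
        VariableNameSketch metric = new VariableNameSketch("mean"); // illustrative name
        System.out.println(metric.getVariableName("total")); // mean_total
        System.out.println(metric.getVariableName("count")); // mean_count
    }
}
```

Under this rule, two instances of the same metric class that are given different names also get different variable names.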
    • getName

      public String getName()
      Gets the name of this metric. The name defaults to Class.getSimpleName() unless one is passed to the constructor or set with setName(String).
      Returns:
      the name of this metric
    • setName

      public void setName(String name)
      Sets the name of this metric.
      Parameters:
      name - the metric name
    • getSeed

      public long getSeed()
      Gets the random number generator seed value.
      Returns:
      the random number generator seed value
    • init

      protected abstract void init(Ops tf)
      Initializes this metric with the given TensorFlow Ops.
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not have a Graph environment.
    • getTF

      protected Ops getTF()
      Gets the TensorFlow Ops for this metric.
      Returns:
      the TensorFlow Ops for this metric.
    • setTF

      protected void setTF(Ops tf)
      Sets the TensorFlow Ops for this metric.

      Subclasses should call this method from their init(Ops) implementation.

      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not have a Graph environment.
    • isInitialized

      public boolean isInitialized()
      Checks whether this metric has been initialized.
      Returns:
      true if this metric has been initialized; false otherwise.
    • setInitialized

      protected void setInitialized(boolean initialized)
      Sets the flag that indicates whether this metric has been initialized.
      Parameters:
      initialized - true if this metric has been initialized; false otherwise.
    • checkIsGraph

      protected void checkIsGraph(Ops tf)
      Checks if the TensorFlow Ops encapsulates a Graph environment.
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      Throws:
      IllegalArgumentException - if the TensorFlow Ops scope does not encapsulate a Graph environment.
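The protected hooks above suggest an initialization lifecycle: init(Ops) records the Ops through setTF, which rejects scopes without a Graph environment, and then marks the metric as initialized. The stdlib-only sketch below shows one plausible way a subclass might use these hooks, initializing lazily on its first update. OpsStandIn, LifecycleSketch, and LazyMetricSketch are hypothetical names; OpsStandIn is an invented placeholder for org.tensorflow.op.Ops that only records whether it wraps a Graph environment.

```java
// Hypothetical sketch of the init/setTF/getTF/isInitialized/setInitialized/
// checkIsGraph hooks. It does not use TensorFlow Java.
final class OpsStandIn {
    final boolean graph; // true if this stand-in "encapsulates" a Graph environment

    OpsStandIn(boolean graph) {
        this.graph = graph;
    }
}

abstract class LifecycleSketch {
    private OpsStandIn tf;
    private boolean initialized;

    protected abstract void init(OpsStandIn tf);

    protected void checkIsGraph(OpsStandIn tf) {
        if (!tf.graph) {
            throw new IllegalArgumentException("The Ops scope must encapsulate a Graph environment");
        }
    }

    protected OpsStandIn getTF() { return tf; }

    protected void setTF(OpsStandIn tf) {
        checkIsGraph(tf);
        this.tf = tf;
    }

    public boolean isInitialized() { return initialized; }

    protected void setInitialized(boolean initialized) { this.initialized = initialized; }
}

// Hypothetical subclass that initializes lazily, on its first update.
class LazyMetricSketch extends LifecycleSketch {
    int initCalls = 0; // counts how often one-time setup actually ran

    @Override
    protected void init(OpsStandIn tf) {
        setTF(tf);    // rejects non-Graph environments before anything else happens
        initCalls++;  // a real metric might create its state variables here
        setInitialized(true);
    }

    void update(OpsStandIn tf) {
        if (!isInitialized()) {
            init(tf);
        }
        // State-update work would follow here, using getTF().
    }
}

class LifecycleDemo {
    public static void main(String[] args) {
        LazyMetricSketch metric = new LazyMetricSketch();
        OpsStandIn graphOps = new OpsStandIn(true);
        metric.update(graphOps);
        metric.update(graphOps);
        System.out.println(metric.initCalls); // 1
        try {
            new LazyMetricSketch().update(new OpsStandIn(false));
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

Because setTF validates before setInitialized(true) runs, a metric handed a non-Graph scope in this sketch stays uninitialized.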