Class Mean<T extends TNumber>

java.lang.Object
org.tensorflow.framework.metrics.BaseMetric
org.tensorflow.framework.metrics.Mean<T>
Type Parameters:
T - The data type for the metric result
All Implemented Interfaces:
Metric
Direct Known Subclasses:
Accuracy, BinaryAccuracy, BinaryCrossentropy, CategoricalAccuracy, CategoricalCrossentropy, CategoricalHinge, CosineSimilarity, Hinge, KLDivergence, LogCoshError, MeanAbsoluteError, MeanAbsolutePercentageError, MeanRelativeError, MeanSquaredError, MeanSquaredLogarithmicError, Poisson, RootMeanSquaredError, SparseCategoricalAccuracy, SparseCategoricalCrossentropy, SparseTopKCategoricalAccuracy, SquaredHinge, TopKCategoricalAccuracy

public class Mean<T extends TNumber> extends BaseMetric
A metric that implements a weighted mean (MetricReduction.WEIGHTED_MEAN).
  • Field Details

  • Constructor Details

    • Mean

      public Mean(long seed, Class<T> type)
      Creates a reducible metric with a metric reduction of MetricReduction.WEIGHTED_MEAN, using Class.getSimpleName() as the metric name.
      Parameters:
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the type for the variables and result
    • Mean

      public Mean(String name, long seed, Class<T> type)
      Creates a reducible metric with a metric reduction of MetricReduction.WEIGHTED_MEAN.
      Parameters:
      name - the name for this metric. If null, name defaults to Class.getSimpleName().
      seed - the seed for random number generation. An initializer created with a given seed will always produce the same random tensor for a given shape and data type.
      type - the type for the variables and result
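
      The seed contract described above (same seed, same values) can be illustrated by analogy with java.util.Random. This is only a plain-Java sketch: SeedSketch and draw are hypothetical names, and the code does not use or reproduce TensorFlow's initializers.

```java
import java.util.Random;

/** Analogy only: shows that a seeded generator is reproducible. Not TensorFlow code. */
final class SeedSketch {
    /** Draws n pseudo-random doubles from a generator created with the given seed. */
    static double[] draw(long seed, int n) {
        Random rng = new Random(seed);
        double[] out = new double[n];
        for (int i = 0; i < n; i++) {
            out[i] = rng.nextDouble();
        }
        return out;
    }
}
```

      Calling SeedSketch.draw(1001L, 4) twice returns element-wise equal arrays, which is the same kind of guarantee the seed parameter documents for initializers.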
  • Method Details

    • init

      protected void init(Ops tf)
      Initializes the TensorFlow Ops.
      Specified by:
      init in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
    • resetStates

      public Op resetStates(Ops tf)
      Resets any state variables to their initial values
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      Returns:
      the operation for doing the reset
    • updateStateList

      public List<Op> updateStateList(Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
      Updates the metric variables based on the inputs. The values input is required; sampleWeights is an optional additional input.
      Specified by:
      updateStateList in interface Metric
      Overrides:
      updateStateList in class BaseMetric
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      values - the inputs used to update the state; must not be null.
      sampleWeights - sample weights to be applied to values; defaults to 1 if null.
      Returns:
      a list of operations that update the metric state
      Throws:
      IllegalArgumentException - if values is null
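
      To make the weighted-mean update concrete, here is a minimal plain-Java sketch of the arithmetic, assuming the total variable accumulates value * weight and the count variable accumulates the weights. WeightedMeanSketch is a hypothetical class operating on double arrays; it is not the TensorFlow implementation, it requires weights to match values in length (a simplification), and returning 0 for an empty accumulator is an assumption.

```java
/** Plain-Java illustration of weighted-mean accumulation; not the TensorFlow implementation. */
final class WeightedMeanSketch {
    private double total; // running sum of value * weight
    private double count; // running sum of weights

    /** Folds one batch into the accumulators; null sampleWeights means every weight is 1. */
    void update(double[] values, double[] sampleWeights) {
        if (values == null) {
            throw new IllegalArgumentException("values must not be null");
        }
        if (sampleWeights != null && sampleWeights.length != values.length) {
            // Simplification: this sketch does not broadcast weights.
            throw new IllegalArgumentException("sampleWeights must match values in length");
        }
        for (int i = 0; i < values.length; i++) {
            double w = (sampleWeights == null) ? 1.0 : sampleWeights[i];
            total += values[i] * w;
            count += w;
        }
    }

    /** Returns total / count, or 0 when nothing has been accumulated (assumed behavior). */
    double mean() {
        return count == 0.0 ? 0.0 : total / count;
    }

    double getTotal() { return total; }

    double getCount() { return count; }
}
```

      For example, updating with values {1, 3, 5, 7} and null weights gives total 16 and count 4; a second update with values {10, 20} and weights {1, 0} gives total 26 and count 5, because the zero-weighted sample contributes to neither accumulator.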
    • result

      public <U extends TNumber> Operand<U> result(Ops tf, Class<U> type)
      Gets the current result of the metric
      Type Parameters:
      U - the data type for the result
      Parameters:
      tf - the TensorFlow Ops encapsulating a Graph environment.
      type - the data type for the result
      Returns:
      the result, possibly with control dependencies
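
      The relationship between the accumulators and a result of a caller-chosen type can be sketched in plain Java as follows. MeanResultSketch, meanOrZero, and resultAs are hypothetical names; the zero-count handling and the use of boxed Float and Double in place of TensorFlow types are assumptions made for illustration.

```java
/** Plain-Java sketch of deriving a typed result from total and count; not TensorFlow code. */
final class MeanResultSketch {
    /** Divides total by count, returning 0 instead of NaN when count is 0 (assumed behavior). */
    static double meanOrZero(double total, double count) {
        return count == 0.0 ? 0.0 : total / count;
    }

    /** Computes the mean and converts it to the requested boxed numeric type. */
    static <U extends Number> U resultAs(double total, double count, Class<U> type) {
        double mean = meanOrZero(total, count);
        if (type == Double.class) {
            return type.cast(mean);
        }
        if (type == Float.class) {
            return type.cast((float) mean);
        }
        throw new IllegalArgumentException("unsupported type: " + type);
    }
}
```

      The accumulators keep one internal type while the caller picks the result type at read time, which mirrors the split between getInternalType() and the type argument of result.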
    • getTotal

      public Variable<T> getTotal()
      Gets the total variable
      Returns:
      the total variable
    • getCount

      public Variable<T> getCount()
      Gets the count variable
      Returns:
      the count variable
    • getInternalType

      public Class<T> getInternalType()
      Gets the type for the variables
      Returns:
      the type for the variables