Class VarianceScaling<T extends TFloating>

java.lang.Object
org.tensorflow.framework.initializers.BaseInitializer<T>
org.tensorflow.framework.initializers.VarianceScaling<T>
Type Parameters:
T - The TType for the call operation
All Implemented Interfaces:
Initializer<T>
Direct Known Subclasses:
Glorot, He, LeCun

public class VarianceScaling<T extends TFloating> extends BaseInitializer<T>
Initializer capable of adapting its scale to the shape of weight tensors.

With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from a truncated or untruncated normal distribution, respectively, with a mean of zero and a standard deviation (measured after truncation, if used) of stddev = Math.sqrt(scale / n), where n is:

  • number of input units in the weight tensor, if mode=FAN_IN
  • number of output units, if mode=FAN_OUT
  • average of the numbers of input and output units, if mode=FAN_AVG
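The mode-dependent choice of n and the resulting standard deviation can be sketched in plain Java. This is a minimal illustration, not the library's code: the Mode enum below is a hypothetical stand-in for VarianceScaling.Mode, and the fan values are passed in directly (here for a 2-D weight of shape [200, 50]) rather than derived from a tensor shape.

```java
/** Sketch: the standard deviation used for NORMAL / TRUNCATED_NORMAL, per mode. */
public class VarianceScalingStddev {

    /** Hypothetical mirror of VarianceScaling.Mode, for illustration only. */
    enum Mode { FAN_IN, FAN_OUT, FAN_AVG }

    /** Picks n from the fan values according to the mode. */
    static double n(Mode mode, double fanIn, double fanOut) {
        switch (mode) {
            case FAN_IN:  return fanIn;
            case FAN_OUT: return fanOut;
            default:      return (fanIn + fanOut) / 2.0; // FAN_AVG
        }
    }

    /** stddev = sqrt(scale / n) */
    static double stddev(double scale, Mode mode, double fanIn, double fanOut) {
        return Math.sqrt(scale / n(mode, fanIn, fanOut));
    }

    public static void main(String[] args) {
        // Hypothetical dense weight with fanIn = 200, fanOut = 50 and scale = 2.0
        for (Mode m : Mode.values()) {
            System.out.printf("%s -> stddev = %.4f%n", m, stddev(2.0, m, 200, 50));
        }
    }
}
```

With these numbers, FAN_IN gives the smallest n-scaled spread only when fanIn exceeds fanOut; FAN_AVG always lands between the other two.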

With distribution=UNIFORM, samples are drawn from a uniform distribution within [-limit, limit], where limit = Math.sqrt(3 * scale / n).
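The factor of 3 comes from the variance of a uniform distribution on [-a, a], which is a^2 / 3; choosing a = Math.sqrt(3 * scale / n) therefore gives the same variance, scale / n, as the normal case. The sketch below checks this numerically. It assumes java.util.Random as the random source, which is not the generator TensorFlow uses, and the helper names limit and sample are hypothetical.

```java
import java.util.Random;

/** Sketch: UNIFORM variant, sampling from [-limit, limit) with limit = sqrt(3 * scale / n). */
public class VarianceScalingUniform {

    /** limit = sqrt(3 * scale / n), so that Var = limit^2 / 3 = scale / n. */
    static double limit(double scale, double n) {
        return Math.sqrt(3.0 * scale / n);
    }

    /** Draws count values uniformly from [-limit, limit). */
    static double[] sample(int count, double scale, double n, long seed) {
        double lim = limit(scale, n);
        Random rng = new Random(seed); // stand-in for TensorFlow's random ops (assumption)
        double[] out = new double[count];
        for (int i = 0; i < count; i++) {
            // nextDouble() is in [0, 1); map it linearly onto [-lim, lim)
            out[i] = (2.0 * rng.nextDouble() - 1.0) * lim;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] w = sample(4, 1.0, 3.0, 1234L); // scale = 1, n = 3, so limit = 1
        for (double v : w) {
            System.out.println(v);
        }
    }
}
```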

Examples:

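     import org.tensorflow.Operand;
     import org.tensorflow.framework.initializers.VarianceScaling;
     import org.tensorflow.ndarray.Shape;
     import org.tensorflow.op.Ops;
     import org.tensorflow.types.TFloat32;

     long seed = 1234L;
     double scale = 0.1;
     // tf is an existing Ops instance
     VarianceScaling<TFloat32> initializer = new VarianceScaling<>(
         scale, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, seed);
     Operand<TFloat32> values =
         initializer.call(tf, tf.constant(Shape.of(2, 2)), TFloat32.class);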
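The Direct Known Subclasses listed above correspond to the conventional scale/mode pairs from the literature (and from Keras): Glorot uses scale 1.0 with FAN_AVG, He uses scale 2.0 with FAN_IN, and LeCun uses scale 1.0 with FAN_IN. That the Java subclasses pass exactly these values to this constructor is an assumption here; the sketch only evaluates the formulas above for each pair, using a hypothetical Preset class and Mode enum.

```java
/** Sketch: the conventional scale/mode presets behind Glorot, He and LeCun. */
public class VarianceScalingPresets {

    /** Hypothetical mirror of VarianceScaling.Mode, for illustration only. */
    enum Mode { FAN_IN, FAN_OUT, FAN_AVG }

    /** Hypothetical (scale, mode) pair plugged into the VarianceScaling formulas. */
    static final class Preset {
        final String name;
        final double scale;
        final Mode mode;

        Preset(String name, double scale, Mode mode) {
            this.name = name;
            this.scale = scale;
            this.mode = mode;
        }

        double n(double fanIn, double fanOut) {
            switch (mode) {
                case FAN_IN:  return fanIn;
                case FAN_OUT: return fanOut;
                default:      return (fanIn + fanOut) / 2.0; // FAN_AVG
            }
        }

        /** stddev for NORMAL / TRUNCATED_NORMAL: sqrt(scale / n) */
        double normalStddev(double fanIn, double fanOut) {
            return Math.sqrt(scale / n(fanIn, fanOut));
        }

        /** limit for UNIFORM: sqrt(3 * scale / n) */
        double uniformLimit(double fanIn, double fanOut) {
            return Math.sqrt(3.0 * scale / n(fanIn, fanOut));
        }
    }

    static final Preset GLOROT = new Preset("Glorot", 1.0, Mode.FAN_AVG); // limit = sqrt(6 / (fanIn + fanOut))
    static final Preset HE = new Preset("He", 2.0, Mode.FAN_IN);          // stddev = sqrt(2 / fanIn)
    static final Preset LECUN = new Preset("LeCun", 1.0, Mode.FAN_IN);    // stddev = sqrt(1 / fanIn)

    public static void main(String[] args) {
        double fanIn = 256, fanOut = 128;
        for (Preset p : new Preset[] {GLOROT, HE, LECUN}) {
            System.out.printf("%-6s stddev=%.5f limit=%.5f%n",
                p.name, p.normalStddev(fanIn, fanOut), p.uniformLimit(fanIn, fanOut));
        }
    }
}
```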
  • Constructor Details

    • VarianceScaling

      public VarianceScaling(long seed)
      Creates a VarianceScaling Initializer with the default scale, mode, and distribution.
      Parameters:
      seed - Used to create random seeds.
    • VarianceScaling

      public VarianceScaling(double scale, VarianceScaling.Mode mode, VarianceScaling.Distribution distribution, long seed)
      Creates a VarianceScaling Initializer.
      Parameters:
      scale - Scaling factor (a positive value).
      mode - Determines n, the count used to scale the variance: FAN_IN, FAN_OUT, or FAN_AVG.
      distribution - Random distribution to use: TRUNCATED_NORMAL, NORMAL, or UNIFORM.
      seed - Used to create random seeds.
  • Method Details

    • call

      public Operand<T> call(Ops tf, Operand<TInt64> dims, Class<T> type)
      Generates the operation used to perform the initialization.
      Parameters:
      tf - the TensorFlow Ops
      dims - the dimensions of the shape of the tensor to initialize
      type - the data type of the tensor
      Returns:
      An operand for the initialization.
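Conceptually, call turns the dims operand into fan-in and fan-out counts, picks n from the mode, and fills a tensor of that shape with random values. The plain-Java sketch below reuses the shape, scale, mode, and seed of the Examples call without TensorFlow. Several details are assumptions: fans follow the Keras _compute_fans convention (for rank greater than 2, all leading dimensions form the receptive field), java.util.Random replaces TensorFlow's random ops so the values will differ from the library's, and only the untruncated NORMAL case is shown.

```java
import java.util.Arrays;
import java.util.Random;

/** Sketch of what call(tf, dims, type) conceptually computes, without TensorFlow. */
public class VarianceScalingCallSketch {

    /** Hypothetical mirror of VarianceScaling.Mode, for illustration only. */
    enum Mode { FAN_IN, FAN_OUT, FAN_AVG }

    /** Returns {fanIn, fanOut}, assuming the Keras _compute_fans convention. */
    static double[] computeFans(long[] dims) {
        if (dims.length == 0) return new double[] {1, 1};
        if (dims.length == 1) return new double[] {dims[0], dims[0]};
        if (dims.length == 2) return new double[] {dims[0], dims[1]};
        // Convolution kernels, e.g. [kh, kw, in, out]: leading dims form the receptive field
        double receptiveField = 1;
        for (int i = 0; i < dims.length - 2; i++) receptiveField *= dims[i];
        return new double[] {
            dims[dims.length - 2] * receptiveField, dims[dims.length - 1] * receptiveField};
    }

    /** Returns a flat array of normal samples with stddev = sqrt(scale / n); truncation omitted. */
    static float[] call(long[] dims, double scale, Mode mode, long seed) {
        double[] fans = computeFans(dims);
        double n = mode == Mode.FAN_IN ? fans[0]
                 : mode == Mode.FAN_OUT ? fans[1]
                 : (fans[0] + fans[1]) / 2.0;
        double stddev = Math.sqrt(scale / n);
        long size = 1;
        for (long d : dims) size *= d;
        Random rng = new Random(seed); // same seed -> same values within this sketch
        float[] out = new float[Math.toIntExact(size)];
        for (int i = 0; i < out.length; i++) {
            out[i] = (float) (rng.nextGaussian() * stddev);
        }
        return out;
    }

    public static void main(String[] args) {
        long[] convKernel = {3, 3, 16, 32}; // hypothetical 3x3 conv, 16 -> 32 channels
        System.out.println(Arrays.toString(computeFans(convKernel)));
        System.out.println(Arrays.toString(call(new long[] {2, 2}, 0.1, Mode.FAN_IN, 1234L)));
    }
}
```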