lingvo.core.bfloat16_variables module
Various methods for bfloat16 training & inference.
Bfloat16VariableSaveable: a Saveable that restores a variable as bfloat16.
Usage:
Given a checkpoint at checkpoint_path that stores a variable as tf.float32, this saveable restores it as tf.bfloat16. This is mainly useful for inference.
For example, suppose the checkpoint at checkpoint_path contains a variable “var” with dtype tf.float32:
    # Assumes `sess`, `checkpoint_path`, and `slice_spec` are already defined.
    variable_name = "var"
    original_dtype = tf.float32
    bfloat16_var = tf.Variable(
        0.0, name=variable_name, dtype=tf.bfloat16, use_resource=True)
    saveable = bfloat16_variables.Bfloat16VariableSaveable(
        bfloat16_var, original_dtype, slice_spec, variable_name)
    saver = tf.train.Saver(
        {variable_name: saveable}, restore_sequentially=True)
    saver.restore(sess, checkpoint_path)
    # bfloat16_var is now loaded from the checkpoint.
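To give a sense of what restoring a tf.float32 value as tf.bfloat16 does numerically, the following is a minimal, framework-free sketch of the conversion. bfloat16 keeps the sign bit, the 8 exponent bits, and the top 7 mantissa bits of a float32, so the conversion amounts to keeping the upper 16 bits. The helper names are hypothetical, and the round-to-nearest-even step is an assumption about how the cast rounds, not something stated in this module's documentation.

```python
import struct


def float32_to_bfloat16_bits(x: float) -> int:
    """Returns the 16-bit bfloat16 pattern for x (hypothetical helper).

    Assumes round-to-nearest-even; x must be representable as a float32.
    """
    if x != x:  # NaN: return a canonical quiet NaN pattern.
        return 0x7FC0
    # Reinterpret the float32 encoding of x as an unsigned 32-bit integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, on the 16 bits that will be dropped.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bfloat16_bits_to_float(b: int) -> float:
    """Expands a bfloat16 bit pattern back to a Python float (hypothetical helper)."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def as_bfloat16(x: float) -> float:
    """Returns the value x takes after a float32 -> bfloat16 round trip."""
    return bfloat16_bits_to_float(float32_to_bfloat16_bits(x))


print(as_bfloat16(1.0))         # 1.0 (exactly representable)
print(as_bfloat16(3.14159265))  # 3.140625 (only 7 mantissa bits survive)
```

The exponent range is unchanged, so magnitudes are preserved, but each restored value carries only about two to three significant decimal digits. That trade-off is usually acceptable for inference, which is the use case named above.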
- class lingvo.core.bfloat16_variables.Bfloat16VariableSaveable(var, orig_dtype, slice_spec, name)[source]
Bases: SaveableObject
Saveable that loads Variables as bfloat16.
- restore(restored_tensors, restored_shapes)[source]
Restores this object from ‘restored_tensors’.
- Parameters
restored_tensors – the tensors that were loaded from a checkpoint
restored_shapes – the shapes this object should conform to after restore, or None.
- Returns
An operation that restores the state of the object.
- Raises
ValueError – If the object cannot be restored using the provided parameters.
- lingvo.core.bfloat16_variables.get_saver_spec_for_variables_with_bf16_overrides(variables_to_restore)[source]
Returns a dictionary containing overrides to load variables as bf16.
- Parameters
variables_to_restore – A mapping from variable name (as stored in the checkpoint) to the Variable object.
- Returns
A saver dictionary which can be used to load from checkpoints.
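The shape of such an override dictionary can be sketched without TensorFlow. The example below is an illustration under stated assumptions, not the library's implementation: it assumes that only variables whose dtype is bfloat16 are wrapped in a saveable, that the checkpoint stores them as float32, and that all other variables pass through unchanged. FakeVariable, FakeBf16Saveable, and saver_spec_with_bf16_overrides are hypothetical stand-ins.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass(frozen=True)
class FakeVariable:
    """Hypothetical stand-in for a variable; only its dtype matters here."""
    name: str
    dtype: str


@dataclass(frozen=True)
class FakeBf16Saveable:
    """Hypothetical stand-in for Bfloat16VariableSaveable."""
    var: FakeVariable
    orig_dtype: str
    slice_spec: str
    name: str


def saver_spec_with_bf16_overrides(
        variables_to_restore: Dict[str, FakeVariable]) -> Dict[str, Any]:
    """Builds a {checkpoint name: variable or saveable} mapping (sketch)."""
    saver_dict: Dict[str, Any] = {}
    for ckpt_name, var in variables_to_restore.items():
        if var.dtype == "bfloat16":
            # Assumption: the checkpoint holds float32 and the variable is
            # not partitioned, hence the empty slice_spec.
            saver_dict[ckpt_name] = FakeBf16Saveable(
                var, "float32", "", ckpt_name)
        else:
            # Non-bfloat16 variables are restored as usual.
            saver_dict[ckpt_name] = var
    return saver_dict


spec = saver_spec_with_bf16_overrides({
    "encoder/w": FakeVariable("encoder/w", "bfloat16"),
    "global_step": FakeVariable("global_step", "int64"),
})
```

With the real module, the returned dictionary would be passed to a saver in the same way as the {variable_name: saveable} dictionary in the usage example at the top of this page.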