# Source code for lingvo.core.bfloat16_variables

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various methods for bfloat16 training & inference.

Bfloat16VariableSaveable: Saveable that restores variable into bfloat16 type.

Usage:

  Given a checkpoint_path containing a variable of type tf.float32, this
  saveable allows restoring it as tf.bfloat16. This is particularly useful for
  inference.

  Say: checkpoint_path contains a variable "var" with dtype tf.float32::

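      variable_name = "var"
      original_dtype = tf.float32
      # Empty slice spec, the same value that
      # get_saver_spec_for_variables_with_bf16_overrides passes.
      slice_spec = ""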

      bfloat16_var = tf.Variable(
          0.0, name=variable_name, dtype=tf.bfloat16, use_resource=True)

      saveable = bfloat16_variables.Bfloat16VariableSaveable(
          bfloat16_var, original_dtype, slice_spec, variable_name)
      saver = tf.train.Saver(
          {variable_name: saveable}, restore_sequentially=True)
      saver.restore(sess, checkpoint_path)

      # bfloat16_var is now loaded from the checkpoint.
"""
import lingvo.compat as tf
from tensorflow.python.training import saver


class Bfloat16VariableSaveable(saver.BaseSaverBuilder.SaveableObject):
  """Saveable that loads Variables as bfloat16.

  The save spec is declared with `orig_dtype`, and `restore` casts the
  restored tensor to tf.bfloat16 before assigning it to `self.op`.
  """

  def __init__(self, var, orig_dtype, slice_spec, name):
    """Builds a single-spec saveable wrapping `var`.

    Args:
      var: The variable to restore into. It is handed to the base class
        constructor, and its `device` is used for the save spec.
      orig_dtype: The dtype given to the save spec; in the module's usage this
        is the dtype of the value stored in the checkpoint (e.g. tf.float32).
      slice_spec: The slice spec forwarded to the save spec.
      name: The name forwarded to both the save spec and the base class.
    """

    # TODO(rohananil): Investigate if we can avoid using a callable, instead
    # change the saveable api to make use of dtype passed in.
    def _make_callable_var():
      return var

    spec = saver.BaseSaverBuilder.SaveSpec(
        _make_callable_var,
        slice_spec,
        name,
        dtype=orig_dtype,
        device=var.device)
    super().__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Casts the restored tensor to bfloat16 and assigns it to `self.op`.

    Args:
      restored_tensors: The restored tensors; only the first element is used.
      restored_shapes: Either None, or a sequence whose first element is the
        shape the restored tensor is reshaped to before the assignment.

    Returns:
      The result of the `tf.assign` call.
    """
    restored_tensor = restored_tensors[0]
    if restored_shapes is not None:
      restored_tensor = tf.reshape(restored_tensor, restored_shapes[0])
    # Shape validation is requested only when no reshape was asked for and the
    # target's static shape is fully defined.
    return tf.assign(
        self.op,
        tf.cast(restored_tensor, tf.bfloat16),
        validate_shape=restored_shapes is None and
        self.op.get_shape().is_fully_defined())
def get_saver_spec_for_variables_with_bf16_overrides(variables_to_restore):
  """Returns a dictionary containing overrides to load variables as bf16.

  Args:
    variables_to_restore: A mapping from variable name (in the checkpoint) to
      the Variable object.

  Returns:
    A new saver dictionary which can be used to load from checkpoints. Each
    variable whose dtype is tf.bfloat16 is replaced by a
    Bfloat16VariableSaveable declared with tf.float32 as the original dtype;
    every other variable is passed through unchanged.
  """
  saver_dict = {}
  for var_name, v in variables_to_restore.items():
    if v.dtype == tf.bfloat16:
      # TODO(rohananil): Add support for PartitionedVariables if there is
      # demand.
      saver_dict[var_name] = Bfloat16VariableSaveable(
          v, tf.float32, '', var_name)
    else:
      saver_dict[var_name] = v
  return saver_dict