# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common utilities."""
# ==============================================================================
# Note: Avoid adding dependencies to py_utils beyond standard python packages
# and tensorflow.
# ==============================================================================
import collections as py_collections
import contextlib
import functools
import hashlib
import inspect
import math
import numbers
import os
import pkgutil
import re
import threading
import traceback
import typing
from typing import Any, Callable, Optional, Union
import lingvo.compat as tf
from lingvo.core import cluster_factory
from lingvo.core import gshard_utils
from lingvo.core import hyperparams
from lingvo.core import nested_map
from lingvo.core import ops
from lingvo.core import py_utils_flags
from lingvo.core import retry
from lingvo.core import symbolic
from lingvo.core import thread_local_utils
from lingvo.core import tshape
import numpy as np
import six
# pylint: disable=g-direct-tensorflow-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.tpu import topology as tf_topology
from tensorflow.python.tpu import tpu_function
from tensorflow.python.util import deprecation
# pylint: enable=g-direct-tensorflow-import
FLAGS = tf.flags.FLAGS
# pylint: disable=protected-access
_FromGlobal = py_utils_flags._FromGlobal
# pylint: enable=protected-access
use_xla = py_utils_flags.use_xla
use_tpu = py_utils_flags.use_tpu
testonly_skip_norm_layers = py_utils_flags.testonly_skip_norm_layers
tpu_compat = py_utils_flags.tpu_compat
use_stateless_vars_init = py_utils_flags.use_stateless_vars_init
ENQUEUE_OPS = '__lingvo_enqueue_ops'
# pylint: disable=protected-access
deprecation._PRINT_DEPRECATION_WARNINGS = False
# pylint: enable=protected-access
ThreadLocalStack = thread_local_utils.ThreadLocalStack
ThreadLocalDict = thread_local_utils.ThreadLocalDict
NestedMap = nested_map.NestedMap
Params = hyperparams.Params
InstantiableParams = hyperparams.InstantiableParams
def Assert(condition, data, *args, **kwargs):
if py_utils_flags.enable_asserts():
return tf.Assert(condition, data, *args, **kwargs)
else:
return tf.no_op()
def assert_equal(*args, **kwargs):  # pylint: disable=invalid-name
if py_utils_flags.enable_asserts():
return tf.assert_equal(*args, **kwargs)
else:
return tf.no_op()
def assert_greater_equal(*args, **kwargs):  # pylint: disable=invalid-name
if py_utils_flags.enable_asserts():
return tf.debugging.assert_greater_equal(*args, **kwargs)
else:
return tf.no_op()
def assert_greater(*args, **kwargs):  # pylint: disable=invalid-name
if py_utils_flags.enable_asserts():
return tf.assert_greater(*args, **kwargs)
else:
return tf.no_op()
def assert_less_equal(*args, **kwargs):  # pylint: disable=invalid-name
if py_utils_flags.enable_asserts():
return tf.debugging.assert_less_equal(*args, **kwargs)
else:
return tf.no_op()
def assert_less(*args, **kwargs):  # pylint: disable=invalid-name
if py_utils_flags.enable_asserts():
return tf.assert_less(*args, **kwargs)
else:
return tf.no_op()
def assert_between(x, l, r, *args, **kwargs):  # pylint: disable=invalid-name
x = tf.convert_to_tensor(x)
l = tf.cast(tf.convert_to_tensor(l), x.dtype)
r = tf.cast(tf.convert_to_tensor(r), x.dtype)
return tf.group([
assert_greater_equal(x, l, *args, **kwargs),
assert_less(x, r, *args, **kwargs)
])
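# Illustrative usage sketch (not part of the library; the helper name below is
# hypothetical): the assert_* helpers above are typically attached to a tensor
# via control dependencies. When --enable_asserts is false they return
# tf.no_op(), so the whole expression reduces to tf.identity(x).
def _ExampleAssertBetweenUsage(x):
  return with_dependencies([assert_between(x, 0.0, 1.0)], x)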
def assert_shape_match(*args, **kwargs):  # pylint: disable=invalid-name
if py_utils_flags.enable_asserts():
filepath, line, func, _ = traceback.extract_stack(limit=3)[-2]
kwargs['msg'] = 'LINGVO ASSERT %s:%s(%s)' % (re.sub(
r'.*/', '', filepath), line, func)
return ops.assert_shape_match(*args, **kwargs)
else:
return tf.no_op()
def assert_same_dim0(xs, *args, **kwargs):  # pylint: disable=invalid-name
if py_utils_flags.enable_asserts():
return ops.assert_same_dim0(xs, *args, **kwargs)
else:
return tf.no_op()
def assert_even_divide(denorm, num):  # pylint: disable=invalid-name
"""Asserts that denorm is evenly divided by num."""
denorm = tf.convert_to_tensor(denorm)
num = tf.convert_to_tensor(num)
if denorm.dtype not in (tf.int32, tf.int64):
raise ValueError('denominator.dtype is not tf.int32 or tf.int64.')
if num.dtype not in (tf.int32, tf.int64):
raise ValueError('numerator.dtype is not tf.int32 or tf.int64.')
num = HasShape(num, GetShape(denorm))
quo = denorm // num
return assert_equal(quo * num, denorm)
def AssertIdShape(expected_ids_shape_pattern, ids_shape, *args):
"""Asserts shape expected_ids_shape_pattern matches all other input shapes."""
def AssertFn(inputs):
dependencies = [
assert_shape_match(inputs.ids_shape, inputs.expected_ids_shape_pattern)
] + [
assert_shape_match(inputs.ids_shape, x_shape) for x_shape in inputs.args
]
return with_dependencies(dependencies, inputs.ids_shape)
inputs = NestedMap(
expected_ids_shape_pattern=expected_ids_shape_pattern,
ids_shape=ids_shape,
args=args)
return CallDefun(AssertFn, Transform(tf.convert_to_tensor, inputs))
def _CheckNumerics(x, message=None, *args, **kwargs):
if x.dtype.is_floating:
x_name = x.name if not tf.executing_eagerly() else '[eager]'
if 'name' not in kwargs:
kwargs['name'] = re.sub(r':\d+', '', x_name) + '_CheckNumerics'
return tf.debugging.check_numerics(x, message if message else x_name, *args,
**kwargs)
else:
return x
def CheckNumerics(inp, message=None, *args, **kwargs):
"""Check numerics for tensors in inp."""
if not py_utils_flags.enable_check_numerics():
return inp
if isinstance(inp, list):
return [_CheckNumerics(x, message, *args, **kwargs) for x in inp]
if isinstance(inp, tuple):
return tuple(_CheckNumerics(x, message, *args, **kwargs) for x in inp)
return _CheckNumerics(inp, message, *args, **kwargs)
def with_dependencies(dependencies, output_tensor):  # pylint: disable=invalid-name
with tf.control_dependencies(dependencies):
return tf.identity(output_tensor)
def _VarInCollection(var, collection) -> bool:
"""Return whether a variable `var` is in the given variable collection."""
# We use variable reference for comparison, since variable is not hashable in
# eager mode.
return var.ref() in [v.ref() for v in collection]
@contextlib.contextmanager
def _PrintOptions(*args, **kwargs):
original = np.get_printoptions()
np.set_printoptions(*args, **kwargs)
try:
yield
finally:
np.set_printoptions(**original)
def _Print(name, x):
with _PrintOptions(linewidth=1000):
tf.logging.info('%s = %s', name, np.array_repr(x))
def Log(value, prefix, **kwargs):
"""Prints out values of tensors.
Useful for debugging. E.g.,
x = ... a tf.Tensor ...
y = ... a tf.Tensor ...
z = compute(x, y)
z = Log(z, 'debug compute()', x=x, y=y)
Args:
value: A Tensor. Logging happens after this tensor is computed.
prefix: Every tensor is logged with this prefix.
**kwargs: keywords and tensors. Tensors are logged in the sorted order of
these keywords.
Returns:
value is returned.
"""
# Ensures tensors are printed in order.
last = value
for k in sorted(kwargs):
with tf.control_dependencies([last]):
last = tf.py_func(_Print, [prefix + ' : ' + k, kwargs[k]], [])
with tf.control_dependencies([last]):
return tf.identity(value)
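# Minimal usage sketch for Log() (illustrative only; assumes graph mode since
# tf.py_func is used internally). The returned tensor must be consumed so the
# print ops actually run.
def _ExampleLogUsage(x, y):
  z = x + y
  return Log(z, 'debug add()', x=x, y=y)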
def Debug(tensor, message='', enabled=True, summarize=100, more=None):
"""Wrapper around tf.Print() and tf.logging.info() to simplify debug printing.
x = py_utils.Debug(x)
When the graph is built a regular log info line will be printed:
-DBG- py_utils_test.py:429 x=Tensor(...
Then when the tensor node is evaluated it will print lines like:
-DBG- py_utils_test.py:429 x Const:0[x.shape=][2 2][x=][[1 2][3 4]]
WARNING: The code that parses local variable names can fail. E.g. don't write
two Debug() calls on one line or a Debug() call that spans more than one line.
Args:
tensor: A tensor to print.
message: A message to print.
enabled: To enable the debugging.
summarize: Integer with number of tensor values to print.
more: An optional list of additional tensors.
Returns:
The tensor.
"""
if not enabled or _FromGlobal('disable_py_utils_debug'):
return tensor
if more is None:
more = []
stack = inspect.stack()[1][0]
caller = inspect.getframeinfo(stack)
caller_var = ''
caller_more_vars = []
if caller.code_context:
# Rough and likely to fail. But better than nothing.
match = re.compile(r'Debug\((.*?)(\)|,).*$').search(caller.code_context[0])
if match:
caller_var = match.groups()[0]
if more:
more_vars = re.compile(r'more=\[(.*?)\].*$').search(
caller.code_context[0]).groups()[0]
if more_vars:
caller_more_vars = more_vars.split(',')
the_class = ''
if 'self' in stack.f_locals:
the_class = stack.f_locals['self'].__class__.__name__
header = '-DBG- {}:{}:{}:{} {} '.format(
os.path.basename(caller.filename), the_class, caller.function,
caller.lineno, message)
info = '{}{}={}'.format(header, caller_var, tensor)
for name, val in zip(caller_more_vars, more):
info += ' {}={}'.format(name.strip(), val)
tf.logging.info(info)
if isinstance(tensor, tf.Tensor):
tensors = []
tensors += [tf.constant('{}.shape='.format(caller_var)), tf.shape(tensor)]
for name, val in zip(caller_more_vars, more):
tensors += [tf.constant('{}.shape='.format(name.strip())), tf.shape(val)]
tensors += [tf.constant('{}='.format(caller_var)), tensor]
for name, val in zip(caller_more_vars, more):
tensors += [tf.constant('{}='.format(name.strip())), val]
name = tensor.name if not tf.executing_eagerly() else '[eager]'
info = '{}{} {}'.format(header, caller_var, name)
return tf.identity(
tf.Print(tensor, tensors, info, summarize=summarize),
re.sub(':.*$', '', name))
return tensor
def _Save(steps, prefix, key, val):
filename = '%s.%08d.%s.npy' % (six.ensure_text(prefix), steps,
six.ensure_text(key))
with tf.io.gfile.GFile(filename, 'w') as outfile:
np.save(outfile, val)
def Save(value, filename_prefix, **kwargs):
"""Saves values of tensors into files.
Useful for debugging. E.g.,
x = ... a tf.Tensor ...
y = ... a tf.Tensor ...
z = compute(x, y)
z = Save(z, '/path/tmp', x=x, y=y, z=z)
Args:
value: A Tensor. Saving happens after this tensor is computed.
filename_prefix: Every tensor is saved with this filename prefix.
**kwargs: keywords and tensors. Tensors are saved in the sorted order of
these keywords.
Returns:
value is returned.
"""
last = value
steps = GetGlobalStep()
for k in sorted(kwargs):
with tf.control_dependencies([last]):
last = tf.py_func(_Save, [steps, filename_prefix, k, kwargs[k]], [])
with tf.control_dependencies([last]):
return tf.identity(value)
def HasRank(tensor, expected_rank):
"""Syntactic sugar for asserting that tensor has the expected rank."""
if tensor.shape.ndims is not None and isinstance(expected_rank, int):
assert tensor.shape.ndims == expected_rank, (
'Ranks did not match, got %d, '
'expected %d') % (tensor.shape.ndims, expected_rank)
return tensor
if py_utils_flags.enable_asserts():
return with_dependencies([tf.assert_equal(tf.rank(tensor), expected_rank)],
tensor)
else:
return tensor
def HasAtLeastRank(tensor, expected_rank):
"""Syntactic sugar for asserting that tensor has rank >= expected_rank."""
if tensor.shape.ndims is not None and isinstance(expected_rank, int):
assert tensor.shape.ndims >= expected_rank, (
'Rank of tensor %d is less than the expected minimum rank %d.') % (
tensor.shape.ndims, expected_rank)
return tensor
if py_utils_flags.enable_asserts():
return with_dependencies(
[tf.debugging.assert_greater_equal(tf.rank(tensor), expected_rank)],
tensor)
else:
return tensor
def GetRank(tensor):
"""Returns tensor's rank as an int if it's available, otherwise a Tensor.
Args:
tensor: The input tensor.
Returns:
Either an int or a Tensor for the rank of the input tensor.
"""
if tensor.shape.ndims is not None:
return tensor.shape.ndims # int
else:
return tf.rank(tensor) # Tensor
def GetShape(tensor, ndims=None):
"""Returns tensor's shape as a list which can be unpacked, unlike tf.shape.
Tries to return static shape if it's available. Note that this means
some of the outputs will be ints while the rest will be Tensors.
Args:
tensor: The input tensor.
ndims: If not None, returns the shapes for the first `ndims` dimensions.
"""
tensor = tf.convert_to_tensor(tensor)
dynamic_shape = tf.shape(tensor)
# Early exit for unranked tensor.
if tensor.shape.ndims is None:
if ndims is None:
return dynamic_shape
else:
return [dynamic_shape[x] for x in range(ndims)]
# Ranked tensor.
if ndims is None:
ndims = tensor.shape.ndims
else:
ndims = min(ndims, tensor.shape.ndims)
# Return mixture of static and dynamic dims.
static_shape = tensor.shape.as_list()
shapes = [
static_shape[x] if static_shape[x] is not None else dynamic_shape[x]
for x in range(ndims)
]
return shapes
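# Sketch of what GetShape() returns (illustrative only). With a fully static
# shape, every element is a Python int; dimensions that are unknown at graph
# construction time come back as scalar Tensors instead.
def _ExampleGetShape():
  x = tf.ones([4, 8])
  batch, depth = GetShape(x)  # batch == 4, depth == 8 (both ints here)
  return batch, depth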
def HasShape(tensor, expected_shape, ndims=None):
"""Syntactic sugar for asserting that tensor has the expected shape.
Args:
tensor: A Tensor.
expected_shape: A Python list or a 1D tensor. Elements of expected_shape can
be -1 which indicate that any size is valid for that dimension.
ndims: If not None, check only the first `ndims` dimensions of `tensor`.
Must be equal to the length of `expected_shape` if not None.
Returns:
The input `tensor` with control dependencies that will raise a runtime
error if dynamic shape checks fail.
Raises:
ValueError: A value error if the assertion fails at static shape checks.
"""
if not py_utils_flags.enable_asserts():
return tensor
filepath, line, func, _ = traceback.extract_stack(limit=3)[-2]
msg = 'LINGVO ASSERT %s:%s(%s)' % (re.sub(r'.*/', '',
filepath), line, func)
tensor_shape = GetShape(tensor)
if ndims is not None:
tensor_shape = tensor_shape[:ndims]
# TODO(jngiam): Attempt to switch back to tf.Assert after it has better
# support on GPUs.
assert_op = ops.assert_shape_match(tensor_shape, expected_shape, msg=msg)
# If expected_shape is a Tensor, then we are unable to perform static checks.
# In this case, we can do a dynamic check and return.
if isinstance(expected_shape, tf.Tensor):
return with_dependencies([assert_op], tensor)
# Infer ranks from the inputs.
expected_rank = len(expected_shape)
if isinstance(tensor_shape, tf.Tensor):
tensor_rank = tensor.shape.ndims
else:
tensor_rank = len(tensor_shape)
# If ndims is None, then either one of the ranks should not be None, or they
# should both match. If both ranks are None, then they are both tensors and
# should be caught by the earlier short-circuit.
if ndims is None:
if (tensor_rank is not None) and (expected_rank != tensor_rank):
raise ValueError('Tensor does not match rank of expected shape.\n'
'Tensor shape: {} Expected shape: {}'.format(
tensor_shape, expected_shape))
# Both tensors can be assumed to be of same rank.
ndims = expected_rank
else:
if (tensor_rank is not None) and (tensor_rank < ndims):
raise ValueError('Tensor has fewer dimensions than ndims.\n'
'Tensor shape: {} ndims: {}'.format(tensor_shape, ndims))
if expected_rank != ndims:
raise ValueError(
'Expected shape must have number of dimensions equal to ndims.\n'
'Expected shape: {} ndims: {}'.format(expected_shape, ndims))
# Ensure that both tensor_shape and expected_shape are both lists.
tensor_shape = tensor_shape[:ndims]
if isinstance(tensor_shape, tf.Tensor):
tensor_shape = tf.unstack(tensor_shape, num=ndims)
# Map tf.Dimension values to their held values.
tensor_shape = [
v.value if isinstance(v, tf.Dimension) else v for v in tensor_shape
]
expected_shape = [
v.value if isinstance(v, tf.Dimension) else v for v in expected_shape
]
all_static_checks = True
for idx, (dim, expected_dim) in enumerate(zip(tensor_shape, expected_shape)):
if isinstance(expected_dim, tf.Tensor):
all_static_checks = False
elif expected_dim == -1:
continue
elif isinstance(dim, tf.Tensor):
all_static_checks = False
elif dim != expected_dim:
raise ValueError('Tensor does not match expected shape on dimension {}.\n'
'Tensor shape: {} Expected shape: {}'.format(
idx, tensor_shape, expected_shape))
if all_static_checks:
return tf.convert_to_tensor(tensor)
else:
return with_dependencies([assert_op], tensor)
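# Sketch of HasShape() with a -1 wildcard (illustrative only): static
# dimensions are checked at graph-construction time, -1 accepts any size, and
# dynamic dimensions are checked by the attached assert op at run time.
def _ExampleHasShape(x):
  return HasShape(x, [-1, 8])  # any leading dim, second dim must be 8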
def HasSameShape(x, ref):
return HasShape(x, GetShape(ref))
def GetSize(tensor):
shape = GetShape(tensor)
if (isinstance(shape, tf.Tensor) or
any([isinstance(x, tf.Tensor) for x in shape])):
return tf.size(tensor)
return np.prod(shape)
def CausalSelfAttenPadding(seqlen, dtype):
"""Wraps tf.linalg.band_part() for tflite compatibility."""
if FLAGS.tflite_compatible:
# [N, 1]
rows = tf.expand_dims(tf.range(seqlen), -1)
# [1, N]
cols = tf.expand_dims(tf.range(seqlen), 0)
row_cols = rows - cols
return tf.where(row_cols < 0, tf.ones([seqlen, seqlen], dtype),
tf.zeros([seqlen, seqlen], dtype))
else:
return 1.0 - tf.linalg.band_part(
tf.ones([seqlen, seqlen], dtype=dtype), -1, 0)
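# Sketch of the mask produced above (illustrative only). For seqlen=3 the
# result has ones strictly above the diagonal (future positions) and zeros on
# and below it:
#   [[0., 1., 1.],
#    [0., 0., 1.],
#    [0., 0., 0.]]
def _ExampleCausalPadding():
  return CausalSelfAttenPadding(3, tf.float32)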
def outside_all_rewrites():  # pylint: disable=invalid-name
return tf.control_dependencies(None)
# TODO(jamesqin): remove once b/147439702 is fixed.
_OUTSIDE_COMPILATION = threading.local()
def RunOnTpuHost(func, *args, **kwargs):
r"""Runs the given function call on TPU host.
Invokes func(\*args, \*\*kwargs) directly if not running on tpu.
Args:
func: the function to invoke.
*args: args of func
**kwargs: kwargs of func
Returns:
The function return value.
"""
if use_tpu() and not getattr(_OUTSIDE_COMPILATION, 'on', False):
_OUTSIDE_COMPILATION.on = True
res = tf.tpu.outside_compilation(func, *args, **kwargs)
_OUTSIDE_COMPILATION.on = False
else:
res = func(*args, **kwargs)
return res
def tpu_host(func):  # pylint: disable=invalid-name
r"""Decorates a python function to only run on TPU hosts.
This function has no effect when running on CPU/GPU.
Example::
@py_utils.tpu_host()
def ComputeWER(self):
# Call a custom op computing WER.
Args:
func: the function to invoke
Returns:
A TPU-host only function
"""
def Wrapped(*args, **kwargs):
return RunOnTpuHost(func, *args, **kwargs)
return Wrapped
# Maps a TPU job name ('/job:xxx') to the job's DeviceAssignment object.
# When there is only a single TPU job, the key could be None.
_tpu_device_assignment_dict = dict()
def SetTpuDeviceAssignment(tpu_device_assignment, job=None):
if job in _tpu_device_assignment_dict:
tf.logging.warning('tpu_device_assignment was already set, '
'overwriting with new assignment.')
_tpu_device_assignment_dict[job] = tpu_device_assignment
# This function should only be called in unit tests.
def ClearTpuDevice():
global _tpu_device_assignment_dict
_tpu_device_assignment_dict = dict()
def GetTpuDeviceAssignment(job=None):
return _tpu_device_assignment_dict[job]
# Whether it's running in eager mode. This is different than
# tf.executing_eagerly(), which will return False inside a tf.function.
_IS_EAGER_MODE = False
# If you get an error "tf.enable_eager_execution must be called at program
# startup." but you are calling this function at the start, check if your change
# adds type hints for "tf.data" and wrap those type hints in quotes.
def SetEagerMode(eager_mode=True, test_mode=False):
"""Switch between Eager and Graph mode. Use this instead of TF APIs."""
global _IS_EAGER_MODE
_IS_EAGER_MODE = eager_mode
# Only change the global flag.
# Used in tests. In those scenarios we might want to use Graph mode along with
# Eager mode. All we need is changing the flag `_IS_EAGER_MODE` without
# calling `enable_eager_execution`/`disable_eager_execution`.
if test_mode:
return
if eager_mode:
tf.enable_eager_execution()
tf.config.set_soft_device_placement(True)
else:
tf.disable_eager_execution()
def IsEagerMode():
return _IS_EAGER_MODE
# Defaults to True.
# Set to false for already existing Eager mode tests.
_EAGER_RNG_ADAPTATION = True
def SetEagerRngAdaptation(mode=True):
"""Whether to reproduce the TF1 stateful RNG behaviors in pure Eager mode."""
global _EAGER_RNG_ADAPTATION
_EAGER_RNG_ADAPTATION = mode
# Maintains a tf.GradientTape stack.
_GRADIENT_TAPE_STACK = ThreadLocalStack()
@contextlib.contextmanager
def GradientTape(*args, **kwargs):
"""Creates a tf.GradientTape and use it for automatic differentiation."""
tape = tf.GradientTape(*args, **kwargs)
_GRADIENT_TAPE_STACK.stack.append(tape)
try:
with tape as t:
yield t
finally:
_GRADIENT_TAPE_STACK.stack.pop()
def SessionConfig(soft_placement=True,
inline=True,
cluster_def=None,
disable_meta_optimizer=False):
"""Returns a session config proto.
Args:
soft_placement: Turns allow_soft_placement on iff True.
inline: Turns do_function_inlining on iff True.
cluster_def: A tf.train.ClusterDef describing the cluster.
disable_meta_optimizer: Turns off grappler/metagraph optimizer.
Returns:
A TF session config proto.
"""
session_config = tf.config_pb2.ConfigProto(
allow_soft_placement=soft_placement,
graph_options=tf.GraphOptions(
optimizer_options=tf.OptimizerOptions(
opt_level=tf.OptimizerOptions.L1, do_function_inlining=inline)),
cluster_def=cluster_def)
session_config.share_cluster_devices_in_session = True
if disable_meta_optimizer:
# Useful if start-up time is critical.
session_config.graph_options.rewrite_options.disable_meta_optimizer = True
# Disable layout optimizer which increases GPU memory usage.
session_config.graph_options.rewrite_options.layout_optimizer = (
rewriter_config_pb2.RewriterConfig.OFF)
return session_config
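# Sketch: building a graph-mode session from the config above with the
# grappler meta optimizer disabled (illustrative only).
def _ExampleSession():
  return tf.Session(config=SessionConfig(disable_meta_optimizer=True))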
def AssertIsCompatible(a, b):
assert a.IsCompatible(b), ('%s vs %s' % (a, b))
def SetShapes(dst_nmap: NestedMap, src_nmap: NestedMap) -> None:
"""Set shapes in dst_nmap using those in src_nmap."""
AssertIsCompatible(src_nmap, dst_nmap)
for src, dst in zip(src_nmap.Flatten(), dst_nmap.Flatten()):
dst.set_shape(src.shape)
def Dtypes(nmap_list) -> list[tf.dtypes.DType]:
"""Returns all tensors' data types in a list."""
return [v.dtype for v in Flatten(nmap_list)]
def Flatten(x) -> list[Any]:
"""Flattens 'x' by extracting tensors from nested structures to a list."""
return tf.nest.flatten(x)
def Pack(tmpl, values):
"""Packs 'values' according to 'tmpl'."""
return tf.nest.pack_sequence_as(tmpl, values)
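# Sketch of the Flatten()/Pack() round trip on a NestedMap (illustrative only):
# Flatten() extracts the leaf tensors in a deterministic order and Pack()
# rebuilds the original structure around a matching flat list.
def _ExampleFlattenPack():
  nmap = NestedMap(a=tf.constant(1), b=NestedMap(c=tf.constant(2)))
  flat = Flatten(nmap)
  return Pack(nmap, flat)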
def ConvertNoneGradientToZeros(xs, dxs):
"""Sanitize dxs so that None becomes zeros appropriately.
Args:
xs: A list of tensors.
dxs: A list of tensors. dxs[i] corresponds to xs[i]'s gradient.
Returns:
A `.NestedMap` same as dxs with None replaced by a zero tensor.
"""
fn = lambda x, dx: tf.zeros_like(x) if dx is None else dx
return Transform(fn, xs, dxs)
def IsCompatible(lhs, rhs):
"""Returns true if lhs and rhs are compatible."""
try:
tf.nest.assert_same_structure(lhs, rhs)
return True
except (ValueError, TypeError):
return False
class _Unique:
"""A helper to uniqify variables in a NestedMap."""
def __init__(self):
self._vset = set()
def __call__(self, v):
if (v is None) or (id(v) in self._vset):
return False
else:
self._vset.add(id(v))
return True
def ToUniqueList(nmap):
"""Returns the flattened `nmap` with duplicates removed."""
return nmap.Filter(_Unique()).Flatten()
def ReadOnlyAttrDictView(backing):
"""Wraps a dict to provide a read-only view of its contents.
Dict keys can also be accessed by attribute.
Args:
backing: Dict-like object to wrap.
Returns:
Read-only Mapping that can be accessed by index (['foo']) or attr (d.foo).
"""
class Wrapper:
"""Wrapper object."""
# Disable pytype attribute checking.
_HAS_DYNAMIC_ATTRIBUTES = True
def __getitem__(self, key):
return backing[key]
def __len__(self):
return len(backing)
def __iter__(self):
return iter(backing)
def __getattr__(self, key):
return backing[key]
def __hasattr__(self, key):
return key in backing
def __setattr__(self, key, value):
raise AttributeError('Dictionary is read-only.')
def __setitem__(self, key, value):
raise AttributeError('Dictionary is read-only.')
return Wrapper()
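# Sketch of ReadOnlyAttrDictView() access patterns (illustrative only): reads
# work by key or attribute, while writes raise AttributeError.
def _ExampleReadOnlyView():
  view = ReadOnlyAttrDictView({'foo': 1})
  assert view['foo'] == 1 and view.foo == 1
  # view.foo = 2  # would raise AttributeError('Dictionary is read-only.')
  return view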
def ToStaticShape(shape):
"""Converts 'shape' to a static shape."""
if isinstance(shape, (list, tuple)):
shape = [
dim.value if isinstance(dim, tf.Dimension) else dim for dim in shape
]
static_shape = []
for dim in shape:
if symbolic.IsExpr(dim):
static_shape.append(symbolic.ToStatic(dim))
else:
static_shape.append(dim)
return static_shape
else:
return shape.value if isinstance(shape, tf.Dimension) else shape
def Zeros(shape, *args, **kwargs):
return tf.zeros(ToStaticShape(shape), *args, **kwargs)
class RNNCellStateInit:
"""State initialization functions for RNN cell init state."""
@staticmethod
def _Params(method, seed):
p = hyperparams.Params()
p.Define('method', method,
'Initialization method. Should be one of zeros, random_normal.')
p.Define('seed', seed, 'Random seed used to generate initial values.')
p.Freeze()
return p
@staticmethod
def Zeros():
"""tf.zeros()."""
return RNNCellStateInit._Params('zeros', seed=None)
@staticmethod
def RandomNormal(seed=None):
"""tf.random.normal()."""
return RNNCellStateInit._Params('random_normal', seed)
def DefaultRNNCellStateInit():
return RNNCellStateInit.Zeros()
def InitRNNCellState(shape, init=None, dtype=None, name=None, is_eval=False):
"""Initial state definitions for RNN cell implementations.
Args:
shape: An array of ints/symbols specifying the shape of the state.
init: Hyperparameters as returned by one of the static implementations in
RNNCellStateInit.
dtype: The dtype of the states. Defaults to tf.float32.
name: A name for the operation. If --stateless_vars_init is set, this name
is used to generate a seed on a per-variable basis. Otherwise, this name
is optional.
is_eval: Bool, set to True if we need special behavior in eval mode.
Returns:
A Tensor of the specified shape, and sampled from the distribution as
defined by the init parameters.
"""
shape = ToStaticShape(shape)
if init is None:
init = DefaultRNNCellStateInit()
if dtype is None:
dtype = tf.float32
method = init.method
if ((method in ['zeros']) or (method in ['random_normal'] and is_eval)):
init_state = tf.zeros(shape=shape, dtype=dtype, name=name)
elif method in ['random_normal']:
if use_stateless_vars_init():
if name is None:
raise ValueError('InitRNNCellState() requires a `name` argument when '
'--stateless_vars_init is enabled.')
seed = _GenerateStatelessRngSeed(name, init.seed)
init_state = stateless_random_ops.stateless_random_normal(
shape=shape, dtype=dtype, name=name, seed=seed)
else:
init_state = tf.random.normal(
shape=shape, dtype=dtype, name=name, seed=init.seed)
else:
raise ValueError('Initialization method (%s) not supported.' % method)
return init_state
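# Sketch: creating a zero-initialized RNN state of shape [batch, dims]
# (illustrative only; `batch` and `dims` are hypothetical ints/symbols).
def _ExampleInitRNNCellState(batch, dims):
  return InitRNNCellState([batch, dims],
                          init=RNNCellStateInit.Zeros(),
                          dtype=tf.float32,
                          name='example_rnn_state')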
class WeightInit:
"""Static class providing weight initialization config params."""
@staticmethod
def _Params(method, scale, seed, custom_v_init=None):
"""Parameters of this class."""
p = hyperparams.Params()
p.Define('method', method, 'Initialization method.')
p.Define('scale', scale, 'Initialization scale.')
p.Define('seed', seed, 'Random seed used to generate initial values.')
p.Define('custom_v_init', custom_v_init,
'A custom tf.init_ops.Initializer instance.')
p.Freeze()
return p
@staticmethod
def Gaussian(scale=1.0, seed=None):
"""scale * tf.random.normal(0, 1.0)."""
return WeightInit._Params('gaussian', scale, seed)
@staticmethod
def Category(scale=2, seed=None):
"""tf.floor(scale * tf.random.uniform(0., 1.0))."""
return WeightInit._Params('category', scale, seed)
@staticmethod
def Xavier(scale=1.0, seed=None):
"""Xavier initialization (x = sqrt(6. / (in + out)); [-x, x])."""
return WeightInit._Params('xavier', scale, seed)
@staticmethod
def XavierWithFixupParams(scale=1.0,
depth=1.0,
layers_per_residual_block=1.0,
seed=None):
"""Xavier initialization with Fixup."""
scale = scale * math.pow(depth, (-1.0 / (2 * layers_per_residual_block)))
return WeightInit._Params('xavier', scale, seed)
@staticmethod
def GeoMeanXavier(scale=1.0, seed=None):
"""A variant of Xavier (x = sqrt(3. / sqrt(in * out)); [-x, x])."""
return WeightInit._Params('geo_mean_xavier', scale, seed)
@staticmethod
def Constant(scale=1.0):
"""scale."""
return WeightInit._Params('constant', scale, 0)
@staticmethod
def TruncatedGaussian(scale=1.0, seed=None):
"""scale * tf.random.truncated_normal(0, 1.0)."""
return WeightInit._Params('truncated_gaussian', scale, seed)
@staticmethod
def GaussianSqrtDim(scale=1.0, seed=None):
"""scale * tf.random.normal(0, 1 / sqrt(dim0))."""
return WeightInit._Params('gaussian_sqrt_dim', scale, seed)
@staticmethod
def GaussianSqrtFanIn(scale=1.0, seed=None):
"""scale * tf.random.normal(0, 1 / sqrt(fan_in))."""
return WeightInit._Params('gaussian_sqrt_fanin', scale, seed)
@staticmethod
def GaussianSqrtFanOut(scale=1.0, seed=None):
"""scale * tf.random.normal(0, 1 / sqrt(fan_out))."""
return WeightInit._Params('gaussian_sqrt_fanout', scale, seed)
@staticmethod
def GaussianSqrtFanAvg(scale=1.0, seed=None):
"""tf.random.normal(0, sqrt(2.0 / (in + out)))."""
return WeightInit._Params('gaussian_sqrt_fanavg', scale, seed)
@staticmethod
def TruncatedGaussianSqrtDim(scale=1.0, seed=None):
"""scale * tf.random.truncated_normal(0, 1 / sqrt(dim0))."""
return WeightInit._Params('truncated_gaussian_sqrt_dim', scale, seed)
@staticmethod
def TruncatedGaussianSqrtFanIn(scale=1.0, seed=None):
"""scale * tf.random.truncated_normal(0, 1 / sqrt(fan_in))."""
return WeightInit._Params('truncated_gaussian_sqrt_fanin', scale, seed)
@staticmethod
def TruncatedGaussianSqrtFanOut(scale=1.0, seed=None):
"""scale * tf.random.truncated_normal(0, 1 / sqrt(fan_out))."""
return WeightInit._Params('truncated_gaussian_sqrt_fanout', scale, seed)
@staticmethod
def CustomVarInit(custom_v_init):
return WeightInit._Params('custom', 1.0, None, custom_v_init)
@staticmethod
def CustomConstantVarInit(custom_v_init):
return WeightInit._Params('custom_constant', 1.0, None, custom_v_init)
@staticmethod
def ScaledDeltaOrthogonal(scale=1.0, seed=None):
return WeightInit._Params('delta_orthogonal', scale, seed)
_DEFAULT_XAVIER_INIT = 1.000001
def DefaultParamInit():
# Here we use 1.000001 as a signature indicating that the user kept the
# default param initializer.
return WeightInit.Xavier(_DEFAULT_XAVIER_INIT)
# TODO(rpang, jonathanasdf): explore adding _is_default to hyperparams.Param.
def IsDefaultParamInit(p):
return (p.method == 'xavier' and
abs(p.scale - _DEFAULT_XAVIER_INIT) < 1e-7 and p.seed is None)
def WeightParams(shape,
init=None,
dtype=None,
collections=None,
device_mesh=None,
tensor_split_dims_mapping=None):
"""Returns a hyperparams for a weight variable given the shape/init/dtype."""
if init is None:
init = WeightInit.Xavier(_DEFAULT_XAVIER_INIT)
if dtype is None:
dtype = tf.float32
if collections is None:
collections = []
if device_mesh is not None:
if tensor_split_dims_mapping is None:
tensor_split_dims_mapping = (-1,) * len(shape)
tf.logging.info(
'Sets tensor_split_dims_mapping of a param of shape {} to {}'.format(
shape, tensor_split_dims_mapping))
assert len(tensor_split_dims_mapping) == len(shape)
p = hyperparams.Params()
p.Define('dtype', dtype, 'The weight data type.')
p.Define('shape', shape, 'The weight shape.')
p.Define('init', init, 'Initialization method.')
p.Define('collections', collections,
'Variable collections this weight belongs to.')
p.Define(
'device_mesh', device_mesh,
'A numpy.ndarray describing the topology of a device mesh to partition'
' this variable onto. Each element in the np.ndarray is the ID of a'
' device in the topology. device_mesh and tensor_split_dims_mapping below'
' together specify how this weight tensor should be sharded across'
' different tpu cores. If None, this variable is not sharded.'
' Here are examples: np.array([0, 1, 2, 3, 4, 5, 6, 7]) which is a 1d'
' mesh with 8 devices, np.array([[0, 1, 2, 3], [4, 5, 6, 7]]) which is'
' a 2d mesh of 8 devices.')
p.Define(
'tensor_split_dims_mapping', tensor_split_dims_mapping,
'A list of integers that map each tensor axis to the device mesh axis'
' along which it is sharded. Its length is the tensor rank, and'
' split_dims_mapping[i] is device mesh axis for tensor dimension i. Use'
' -1 for tensor dimensions that are not sharded. If the list is set to'
' None and a device_mesh is specified, the sharding will be treated as'
' replicated. Here is a concrete example: '
' device_mesh=np.array([[0, 1, 2, 3], [4, 5, 6, 7]]), of shape [2, 4],'
' shape=[x, y, z], so this is a 3d variable.'
' tensor_split_dims_mapping=[-1, -1, 1], in this case, the third dim'
' of the variable is split along the second dim of the mesh. Each '
' split of the variable is of the shape [x, y, z/4].')
# The following two flags are used in Jax only.
p.Define(
'repeat_prefix', None,
'If not None, the full shape of this var is repeat_prefix+shape. '
'For example, if repeat_prefix=[16, 2], and shape=[512, 1024], then '
'real shape of variable is [16, 2, 512, 1024]. "repeat_prefix" is '
'often used if a layer is to be used in a recurrent loop, where '
'logically there are n sub-layers, but for performance/hbm usage '
'reasons we stack all the variables in creating those n-layers.')
p.Define('repeat_prefix_split_dims_mapping', None,
'Tensor split dims mapping for the repeat_prefix dims.')
return p
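# Sketch of a typical WeightParams() config for a dense [input_dim, output_dim]
# weight matrix with Xavier init (illustrative only; the dims and collection
# name are hypothetical).
def _ExampleWeightParams(input_dim, output_dim):
  return WeightParams(
      shape=[input_dim, output_dim],
      init=WeightInit.Xavier(scale=1.0),
      dtype=tf.float32,
      collections=['example_collection'])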
def FindNeeded(endpoints):
"""List names of tensors and operations required to compute endpoints."""
names_seen = set()
queue = []
for e in Flatten(endpoints):
if isinstance(e, tf.Operation):
queue.append(e)
else:
queue.append(e.op)
while queue:
op = queue.pop()
name = op.name
if name not in names_seen:
names_seen.add(name)
names_seen.update((o.name for o in op.outputs))
queue.extend(i.op for i in op.inputs)
queue.extend(op.control_inputs)
return names_seen
class _CollectionGetter:
"""Get graph local value from a defined collection."""
def __init__(self, key, default_factory):
self._key = key
self._default_factory = default_factory
def __call__(self):
collection = tf.get_collection(self._key)
if collection:
assert len(collection) == 1
return collection[0]
value = self._default_factory()
tf.add_to_collection(self._key, value)
return value
def SanitizeScopeKey(key):
"""Removes invalid symbols from name_scope keys."""
if key.startswith('_'):
key = key[1:]
return key.replace('[', '_').replace(']', '')
# Maintain a session for unit tests (initialized in test_utils.py).
_SESSION_SCOPE = ThreadLocalStack()
@contextlib.contextmanager
def UnitTestSessionScope(sess):
_SESSION_SCOPE.stack.append(sess)
try:
yield
finally:
_SESSION_SCOPE.stack.pop()
def GetUnitTestSession():
"""Returns the session registered for unit tests, if any."""
return _SESSION_SCOPE.stack[-1] if _SESSION_SCOPE.stack else None
# Global variable to control multitask variable reuse
# If False (default) the default tf.get_variable is used, that is:
# - Reusing scopes only allow getting existing variables
# - Non-reusing scopes only allow getting new variables
# With GetOpportunisticVariableReuse() == True:
# - Reusing scopes only allow getting existing variables, as usual
# - Non-reusing scopes reuse existing variables or create new ones
_OPPORTUNISTIC_VARIABLE_REUSE = ThreadLocalStack()
@contextlib.contextmanager
def OpportunisticVariableReuseScope(enable_opportunistic_reuse=True):
_OPPORTUNISTIC_VARIABLE_REUSE.stack.append(enable_opportunistic_reuse)
try:
yield
finally:
_OPPORTUNISTIC_VARIABLE_REUSE.stack.pop()
def GetOpportunisticVariableReuse():
"""Get the current variable reuse setting."""
return (_OPPORTUNISTIC_VARIABLE_REUSE.stack[-1]
if _OPPORTUNISTIC_VARIABLE_REUSE.stack else False)
_DISABLE_VARIABLE_NAME_CHECKING = ThreadLocalStack()
@contextlib.contextmanager
def DisableVariableNameChecking(disable=True):
_DISABLE_VARIABLE_NAME_CHECKING.stack.append(disable)
try:
yield
finally:
_DISABLE_VARIABLE_NAME_CHECKING.stack.pop()
_VARIABLE_RENAME_RULES = ThreadLocalStack()
# Global variable to track task calling scope.
# Currently only used for TPU Embedding purposes as a TPUEmbeddingLayer
# may be shared across tasks and the calling task needs to be known
# for tracking embedding activations for backprop.
_TASK_CALL_SCOPE = ThreadLocalStack()
def TaskCallScopeName(task):
"""Get a unique string identifying a task."""
return f'{task.params.name}_{id(task)}'
@contextlib.contextmanager
def TaskCallScope(task):
_TASK_CALL_SCOPE.stack.append(TaskCallScopeName(task))
try:
yield
finally:
_TASK_CALL_SCOPE.stack.pop()
def GetTaskCallScope():
"""Get the current task call scope."""
return _TASK_CALL_SCOPE.stack[-1] if _TASK_CALL_SCOPE.stack else None
@contextlib.contextmanager
def VariableRenameScope(renames):
"""Append the renaming rules to the stack of renames.
Args:
renames: pairs of (regexp, new_name_format). If the regexp matches, the
new_name_format will be interpolated using the matched groups.
Yields:
scope in which the renaming rules are applied
"""
_VARIABLE_RENAME_RULES.stack.append(renames)
try:
yield
finally:
_VARIABLE_RENAME_RULES.stack.pop()
def GetVariableName(name):
"""Get variable name after application of all renaming rules.
Args:
name: untransformed variable name with scope_name prepended
Returns:
name possibly modified using renaming rules
"""
matched = False
new_name = name
for renames in _VARIABLE_RENAME_RULES.stack:
tf.logging.log_first_n(
tf.logging.WARN,
('Renaming variables is not supported in eager mode. '
'Please look into migrating away from variable renaming.'), 1)
for regexp, name_format in renames:
match = re.match(regexp, name)
if match:
if matched:
tf.logging.warning('Multiple matches for: %s', name)
matched = True
new_name = name_format % match.groups()
if new_name != name:
tf.logging.info("WARNING!!! Renaming variable '%s' to '%s'", name, new_name)
return new_name
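# Sketch of how VariableRenameScope() drives GetVariableName() (illustrative
# only; the rename rule and variable name below are hypothetical).
def _ExampleVariableRename():
  with VariableRenameScope([(r'^old_tower/(.*)', 'new_tower/%s')]):
    return GetVariableName('old_tower/fc/w')  # -> 'new_tower/fc/w'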
_LIST_REGEX_DTYPE = ThreadLocalStack()
@contextlib.contextmanager
def VariableListDtypeRegexScope(list_regex_dtypes):
"""Append the list of (regex, dtype) to override the dtype.
Args:
list_regex_dtypes: pairs of (regexp, dtype). If the regexp matches, the data
type of the variable will be changed to the corresponding dtype.
Yields:
scope in which the list of (regex, dtype) is applied.
"""
_LIST_REGEX_DTYPE.stack.append(list_regex_dtypes)
try:
yield
finally:
_LIST_REGEX_DTYPE.stack.pop()
def FindDataType(var_name):
"""Find the data type for var_name.
Args:
var_name: A string, name of the variable.
Returns:
The dtype associated with the first regex that matches var_name, or None if
no match is found.
"""
for regex_dtypes in _LIST_REGEX_DTYPE.stack:
for regex, data_type in regex_dtypes:
if re.match(regex, var_name):
return data_type
return None
def GenerateSeedFromName(name):
"""Generate a random seed from a name string.
Args:
name: A string.
Returns:
An integer seed in the range [0, 2**31 - 1).
"""
md5 = hashlib.md5()
md5.update(six.ensure_binary(name))
return np.int64(int(md5.hexdigest(), 16) % (2**31 - 1))
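# Sketch: GenerateSeedFromName() is a pure function of its input, so the same
# name always yields the same seed (illustrative only; the name is
# hypothetical).
def _ExampleSeedFromName():
  assert GenerateSeedFromName('model/fc/w') == GenerateSeedFromName('model/fc/w')
  return GenerateSeedFromName('model/fc/w')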
def MaybeGenerateSeedFromScope():
"""Generate a random seed from the current name of the scope.
If running in eager mode, this returns 0.
Returns:
An integer seed in the range [0, 2**31 - 1).
"""
if not tf.executing_eagerly():
return GenerateSeedFromName(tf.no_op(name='new_step_seed').name)
return 0
def GenerateSeedFromId(obj_id):
"""Generate a random seed from the id of an object.
If executing deterministically (e.g., in a unit test), generate the seed from
a fixed unique name instead.
Args:
obj_id: id(object).
Returns:
An integer seed in the range [0, 2**31 - 1).
"""
if tf.get_default_graph().seed is not None:
# We are in a program/test which needs deterministic randomization.
with tf.name_scope(''):
return GenerateSeedFromName(tf.no_op(name='new_step_seed').name)
md5 = hashlib.md5()
md5.update(np.int64(obj_id))
return np.int64(int(md5.hexdigest(), 16) % (2**31 - 1))
_VARIABLE_SHAPE_PREFIXES = ThreadLocalStack()
def GetVarLeadingDimsAsCombinedLayers(var):
"""Gets the number of leading dimensions of `var` marked as combined layers.
Such dimensions represent variables from different layers stacked together,
e.g., in RepeatLayer, and optimizers (which have shape-dependent behaviors)
can adjust their behavior based on this information to match the behavior for
separate layer variables.
Args:
var: A variable.
Returns:
An integer representing the number of leading dimensions.
"""
try:
return var.op.get_attr('_num_leading_dims_for_combined_layers')
except ValueError:
return 0
except AttributeError:
# AttributeError: 'DistributedVarOp' object has no attribute 'get_attr'.
return 0
@contextlib.contextmanager
def VariableShapePrefixContext(shape_prefix):
"""Add a shape prefix to variable created by CreateVariable().
This new dimension will be marked as combined-layers. See also comments for
GetVarLeadingDimsAsCombinedLayers().
Args:
shape_prefix: a positive integer of shape prefix.
Yields:
None.
"""
assert shape_prefix > 0, ('%s' % shape_prefix)
_VARIABLE_SHAPE_PREFIXES.stack.append(shape_prefix)
try:
yield
finally:
_VARIABLE_SHAPE_PREFIXES.stack.pop()
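# Sketch: stacking 4 logical copies of a [512, 1024] weight by prefixing its
# shape (illustrative only; the name and dims are hypothetical). Inside the
# context, CreateVariable() produces a [4, 512, 1024] variable whose leading
# dimension is marked as combined layers.
def _ExampleShapePrefix():
  with VariableShapePrefixContext(4):
    return CreateVariable('stacked_w', WeightParams(shape=[512, 1024]))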
def GetVariableShapePrefixes():
"""Return the list of shape prefixes for CreateVariable()."""
return _VARIABLE_SHAPE_PREFIXES.stack
def GetVariableNumLeadingDimsForCombinedLayersContext():
"""Return the number of leading combined-layers dims for CreateVariable()."""
return len(_VARIABLE_SHAPE_PREFIXES.stack)
def GetFanInFanOut(shape, prefix_dims_to_skip):
"""Returns (fan_in, fan_out) of a weight variable of the given shape."""
if not shape:
return None, None
if len(shape) < prefix_dims_to_skip:
raise ValueError(f'Variable shape is {shape} but prefix_dims_to_skip is '
f'{prefix_dims_to_skip}, larger than the shape rank.')
adjusted_shape = shape[prefix_dims_to_skip:]
if len(adjusted_shape) < 1:
return 1, 1
elif len(adjusted_shape) == 1:
# Following _compute_fans() from TF's init_ops.py.
return adjusted_shape[0], adjusted_shape[0]
else:
receptive_field_size = 1
for s in adjusted_shape[:-2]:
receptive_field_size *= s
fan_in = adjusted_shape[-2] * receptive_field_size
fan_out = adjusted_shape[-1] * receptive_field_size
return fan_in, fan_out
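# Worked example for GetFanInFanOut() (illustrative only): for a conv2d kernel
# of shape [3, 3, 64, 128] with no combined-layer prefix dims, the receptive
# field size is 3 * 3 = 9, so fan_in = 9 * 64 = 576 and fan_out = 9 * 128 =
# 1152.
def _ExampleFanInFanOut():
  return GetFanInFanOut([3, 3, 64, 128], prefix_dims_to_skip=0)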
@contextlib.contextmanager
def VariableStore(default_store=None):
"""Keeps track of {variable_name: (variable, var_params)}.
When CreateVariable would result in a variable name that exists in the store,
the existing variable is returned, or an error is raised, depending on whether
the variable scope supports reuse.
This mimics the behavior of tf.compat.v1.get_variable() with regards to
variable reuse, while functioning correctly in TF2 eager context. However, it
only applies to variables created via CreateVariable.
When there are nested VariableStore contexts, they all provide the same
variable store object. That is, the scope of the variable store is the
outermost context.
Args:
default_store: variable store dict. If set, and there is no store in the
stack, use this store instead of creating a new dict.
Yields:
A dictionary representing the variable store.
"""
old_store = _GetVariableStore()
default_store = default_store or {}
store = old_store if old_store is not None else default_store
graph = tf.get_default_graph()
while hasattr(graph, 'outer_graph') and graph.outer_graph:
graph = graph.outer_graph
graph.lingvo_variable_store = store
yield store
graph.lingvo_variable_store = old_store
def _GetVariableStore():
graph = tf.get_default_graph()
while hasattr(graph, 'outer_graph') and graph.outer_graph:
graph = graph.outer_graph
if hasattr(graph, 'lingvo_variable_store'):
return graph.lingvo_variable_store
return None
def _DefaultVariableCreator(**kwargs):
kwargs.pop('var_name')
kwargs.pop('var_params')
return tf.get_variable(**kwargs)
_VARIABLE_CREATOR_STACK = ThreadLocalStack()
def _GetVariableCreator():
fn = _DefaultVariableCreator
# Latest entry in _VARIABLE_CREATOR_STACK is called last.
for wrapper in reversed(_VARIABLE_CREATOR_STACK.stack):
fn = functools.partial(wrapper, fn)
return fn
@contextlib.contextmanager
def VariableCreatorScope(variable_creator):
"""Yields a context around a variable_creator, used by `CreateVariable()`.
The function must have the following signature::
def variable_creator(next_creator, **kwargs)
The function may delegate variable creation to the next variable creator, or
return its own tf.Variable.
This differs from tf.variable_creator_scope in that tf.variable_creator_scope
modifies a tf.Variable() call while this modifies a tf.get_variable() call. As
the code is migrated to TF2 and tf.get_variable() is deprecated, this may be
upgraded to using tf.variable_creator_scope instead.
This differs from tf.variable_scope(custom_getter=variable_creator) in that
the kwargs passed can be manipulated.
Variable creators are resolved from the outermost towards the innermost.
The innermost variable creator function is tf.get_variable.
The passed in kwargs must conform to what tf.get_variable accepts, with the
addition of `var_name` and `var_params`.
Args:
variable_creator: A variable creator function.
"""
_VARIABLE_CREATOR_STACK.stack.append(variable_creator)
try:
yield
finally:
_VARIABLE_CREATOR_STACK.stack.pop()
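# Sketch of a custom variable creator (illustrative only; the creator below is
# hypothetical): it logs the requested variable name and then delegates to the
# next creator in the chain, as described in the docstring above.
def _ExampleLoggingCreator(next_creator, **kwargs):
  tf.logging.info('Creating variable: %s', kwargs['var_name'])
  return next_creator(**kwargs)
# Hypothetical usage:
#   with VariableCreatorScope(_ExampleLoggingCreator):
#     var = CreateVariable('w', WeightParams(shape=[8, 8]))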
def PlaceOnTpuCore(core_id):
"""Returns a VariableCreatorScope that places variables on a given tpu core.
Only applies when running with TPUs.
Does not yet properly support model parallelism.
Args:
core_id: The tpu core id.
"""
def Creator(next_creator, **kwargs):
cluster = cluster_factory.Current()
if use_tpu():
device = cluster.WorkerDeviceInModelSplit(core_id)
elif (
tpu_compat() and
cluster.params.job in ('controller', 'trainer_client', 'executor_tpu')):
# The job is running in a fleet that uses tpu, but does not itself have
# access to the tpu, e.g. controller job. In this case, the returned
# device needs to be the cpu device on the tpu host for the given core.
# FIXME: the current implementation is wrong for large values of core_id.
device = cluster.ListDevices(cluster.params.worker)[0, 0]
else:
device = ''
with tf.device(device):
return next_creator(**kwargs)
return VariableCreatorScope(Creator)
# Variable creators.
def MaybeReuseFromVariableStore(next_creator, **kwargs):
"""Variable creator that attempts to reuse variables from variable store."""
var_name = kwargs['var_name']
p = kwargs['var_params']
store = _GetVariableStore()
if store is not None:
if var_name in store:
if tf.get_variable_scope().reuse:
var, cached_p = store[var_name]
tf.logging.info('Reusing var %s', var.name)
assert cached_p == p.ToText(), (
'Cached config:\n %s vs new config:\n %s' % (cached_p, p.ToText()))
return var
var = next_creator(**kwargs)
# _DISABLE_VARIABLE_NAME_CHECKING is a ThreadLocalStack; name checking is
# skipped only when the innermost DisableVariableNameChecking() context pushed
# True.
if not (_DISABLE_VARIABLE_NAME_CHECKING.stack and
_DISABLE_VARIABLE_NAME_CHECKING.stack[-1]):
if var.name != f'{var_name}/var:0':
raise ValueError(
'Expected %s but created variable %s. Did you mean to set reuse=True '
'or reuse=tf.AUTO_REUSE in VarScope, or did not create a '
'VariableStore for variable reuse?' % (f'{var_name}/var:0', var.name))
tf.logging.vlog(0, 'Creating var %s shape=%s on device %s', var.name,
var.shape, var.device)
for col in p.collections:
tf.add_to_collection(col, var)
if store is not None:
store[var_name] = (var, p.ToText())
return var
def MaybePinVarsToCpu(next_creator, **kwargs):
if _FromGlobal('pin_vars_to_cpu'):
with tf.device('/cpu:0'):
return next_creator(**kwargs)
return next_creator(**kwargs)
def MaybeOpportunisticVariableReuse(next_creator, **kwargs):
if GetOpportunisticVariableReuse():
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
return next_creator(**kwargs)
return next_creator(**kwargs)
def GetLingvoVariableCreator(name, var_name):
"""Returns a variable creator function."""
def LingvoVariableCreator(next_creator, **kwargs):
"""Lingvo variable creator."""
# TODO(yonghui): Possibly get away from variable_scope and implement our own
# variable sharing mechanism.
with tf.variable_scope(name) as scope:
var_scope = tf.VariableScope(
scope.reuse,
custom_getter=scope.custom_getter,
caching_device=scope.caching_device,
use_resource=True)
with tf.variable_scope(var_scope), tf.variable_scope(var_name):
return next_creator(**kwargs)
return LingvoVariableCreator
# TODO(yonghui): Add support for partitioned Variables.
def CreateVariable(name,
params,
trainable=True,
collections=None,
default_seed=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.VariableAggregation.NONE):
"""Creates tf.Variable according to param_config.
Args:
name: A string, name of the variable.
params: A WeightParams specifying the details of how this variable should be
constructed and initialized.
trainable: Whether or not the variable is trainable.
collections: Override the default variable collection (
tf.GraphKeys.GLOBAL_VARIABLES). Note that specifying a collections
argument in `params` does not override this collection; the caller must
set this field explicitly in the call to CreateVariable().
default_seed: Seed to use for initialization if not specified in params.
Used for deterministic initialization in tests.
synchronization: Indicates when a distributed variable will be aggregated.
Accepted values are constants defined in the class
tf.VariableSynchronization. By default the synchronization is set to AUTO
and the current DistributionStrategy chooses when to synchronize.
aggregation: Indicates how a distributed variable will be aggregated.
Accepted values are constants defined in the class tf.VariableAggregation.
Returns:
The created variable.
"""
if use_stateless_vars_init():
return _CreateVariableStateless(name, params, trainable, collections,
default_seed, synchronization, aggregation)
else:
return _CreateVariableStateful(name, params, trainable, collections,
default_seed, synchronization, aggregation)
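# Sketch of a typical CreateVariable() call (illustrative only; the name and
# shape are hypothetical).
def _ExampleCreateVariable():
  wp = WeightParams(
      shape=[128, 64], init=WeightInit.Xavier(), dtype=tf.float32)
  return CreateVariable('proj_w', wp, trainable=True)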
def _CreateVariableStateful(name,
params,
trainable=True,
collections=None,
default_seed=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.VariableAggregation.NONE):
"""Creates tf.Variable using TF stateful RNGs according to param_config.
Args:
name: A string, name of the variable.
params: A WeightParams specifying the details of how this variable should be
constructed and initialized.
trainable: Whether or not the variable is trainable.
collections: Override the default variable collection (
tf.GraphKeys.GLOBAL_VARIABLES).
default_seed: Seed to use for initialization if not specified in params.
Used for deterministic initialization in tests.
synchronization: Indicates when a distributed variable will be aggregated.
Accepted values are constants defined in the class
tf.VariableSynchronization. By default the synchronization is set to AUTO
and the current DistributionStrategy chooses when to synchronize.
aggregation: Indicates how a distributed variable will be aggregated.
Accepted values are constants defined in the class tf.VariableAggregation.
Returns:
The created variable.
"""
p = params.Copy()
shape = tf.TensorShape(ToStaticShape(p.shape)).as_list()
if shape:
assert all([dim_size > 0 for dim_size in shape]), shape
dim0 = shape[0]
else:
dim0 = 1
assert p.init.method == 'constant' or np.all(np.asarray(p.init.scale) >= 0)
method = p.init.method
scale = p.init.scale
seed = p.init.seed
if IsDefaultParamInit(p.init):
tf.logging.warning(
'WARNING!!! var %s is using the default xavier initializer.'
' Make sure this is intended.', name)
with tf.variable_scope(name) as scope:
var_name = GetVariableName(scope.name)
global_seed = None
if tf.executing_eagerly() and _EAGER_RNG_ADAPTATION:
# Equivalent to graph mode `tf.get_default_graph().seed`
# Refer to `tf.get_seed` implementation
global_seed = context.global_seed()
# These stateful rng initializers use e.g. `tf.compat.v1.random_uniform`,
# which use internal counters in addition to the arg seeds to generate nums.
# `set_seed` resets the internal counters, so that repeated calls to the
# same stateful rng kernel with the same seeds returns the same results.
# This was not necessary for `testCreateVariableUniform`.
# However, it is very likely needed for cases where, e.g., `default_seed` is
# specified.
tf.random.set_seed(global_seed)
else:
global_seed = tf.get_default_graph().seed
if global_seed is not None:
# We are in a program/test which needs deterministic randomization.
if seed is None:
if default_seed is not None:
seed = default_seed
else:
# We are not given a per-variable random seed. We use hash of
# variable name as a stable random seed.
seed = GenerateSeedFromName(var_name)
# If var_name matches a regex, then set the var_dtype; else use p.dtype.
var_dtype = FindDataType(var_name)
if var_dtype is None:
var_dtype = p.dtype
init_dtype = var_dtype.real_dtype
# TODO(b/172827074): we do not natively support var initialization for
# int8 type except for constant initialization.
# NOTE: For int8, we initialize by scaling float32 random values to integer.
if init_dtype == tf.int8:
init_dtype = tf.float32
v_init = _CreateVarInitStateful(name, method, shape, dim0, seed, scale,
init_dtype, p.init.custom_v_init)
if var_dtype == tf.complex64:
def ComplexWrapper(init):
def _Wrapper(shape, dtype):
del dtype
# A more complex alternative may be to use the init function for
# magnitudes and uniform random for phases instead.
shape = [2] + shape
value = init(shape, init_dtype)
return tf.complex(value[0], value[1])
return _Wrapper
v_init = ComplexWrapper(v_init)
if var_dtype == tf.int8:
def FloatToInt8Wrapper(init):
def _Wrapper(shape, dtype):
del dtype
value = init(shape, init_dtype)
scale = tf.math.maximum(
tf.math.reduce_min(value) / -127,
tf.math.reduce_max(value) / 127)
value = tf.divide(value, scale)
return tf.cast(value, tf.int8)
return _Wrapper
v_init = FloatToInt8Wrapper(v_init)
with contextlib.ExitStack() as context_stack:
for variable_creator_fn in (GetLingvoVariableCreator(name, var_name),
MaybeOpportunisticVariableReuse,
MaybePinVarsToCpu, MaybeReuseFromVariableStore):
context_stack.enter_context(VariableCreatorScope(variable_creator_fn))
if method == 'custom_constant':
call_shape = None
else:
call_shape = GetVariableShapePrefixes() + list(shape)
var = _GetVariableCreator()(
var_name=var_name,
var_params=p,
name='var',
shape=call_shape,
dtype=var_dtype,
initializer=v_init,
collections=collections,
trainable=trainable,
validate_shape=True,
synchronization=synchronization,
aggregation=aggregation)
combined_layers_dims = GetVariableNumLeadingDimsForCombinedLayersContext()
if combined_layers_dims > 0:
# pylint: disable=protected-access
var.op._set_attr('_num_leading_dims_for_combined_layers',
attr_value_pb2.AttrValue(i=combined_layers_dims))
# Shard the variable according to the sharding spec.
tensor_split_dims_mapping = p.tensor_split_dims_mapping
if tensor_split_dims_mapping is not None:
count = (
len(GetVariableShapePrefixes()) + len(shape) -
len(tensor_split_dims_mapping) -
len(gshard_utils.GetMeshSplitDimPrefixContext()))
tensor_split_dims_mapping = [-1] * count + tensor_split_dims_mapping
var = gshard_utils.MeshSplit(
var, p.device_mesh, tensor_split_dims_mapping, use_sharding_op=False)
return var
def _CreateVariableStateless(name,
params,
trainable=True,
collections=None,
default_seed=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.VariableAggregation.NONE):
"""Creates tf.Variable using TF stateless RNGs according to `params`.
Args:
name: A string, name of the variable.
params: A WeightParams specifying the details of how this variable should be
constructed and initialized.
trainable: Whether or not the variable is trainable.
collections: Override the default variable collection (
tf.GraphKeys.GLOBAL_VARIABLES).
default_seed: Seed to use for initialization if not specified in params.
Used for deterministic initialization in tests.
synchronization: Indicates when a distributed variable will be aggregated.
Accepted values are constants defined in the class
tf.VariableSynchronization. By default the synchronization is set to AUTO
and the current DistributionStrategy chooses when to synchronize.
aggregation: Indicates how a distributed variable will be aggregated.
Accepted values are constants defined in the class tf.VariableAggregation.
Returns:
The created variable.
"""
p = params.Copy()
shape = tf.TensorShape(ToStaticShape(p.shape)).as_list()
if shape:
assert all([dim_size > 0 for dim_size in shape]), shape
dim0 = shape[0]
else:
dim0 = 1
assert p.init.method == 'constant' or np.all(np.asarray(p.init.scale) >= 0)
method = p.init.method
scale = p.init.scale
seed = p.init.seed
if IsDefaultParamInit(p.init):
tf.logging.warning(
'WARNING!!! var %s is using the default xavier initializer.'
' Make sure this is intended.', name)
with tf.variable_scope(name) as scope:
var_name = GetVariableName(scope.name)
user_seed = seed if seed is not None else default_seed
seed = _GenerateStatelessRngSeed(var_name, user_seed)
# If var_name matches a regex, then set the var_dtype; else use p.dtype.
var_dtype = FindDataType(var_name)
if var_dtype is None:
var_dtype = p.dtype
init_dtype = var_dtype.real_dtype
v_init = _CreateVarInitStateless(name, method, shape, dim0, seed, scale,
init_dtype, p.init.custom_v_init)
if var_dtype == tf.complex64:
raise TypeError(
'Stateless variable initialization does not support tf.complex64.')
with contextlib.ExitStack() as context_stack:
for variable_creator_fn in (GetLingvoVariableCreator(name, var_name),
MaybeOpportunisticVariableReuse,
MaybeReuseFromVariableStore):
context_stack.enter_context(VariableCreatorScope(variable_creator_fn))
var = _GetVariableCreator()(
var_name=var_name,
var_params=p,
name='var',
shape=GetVariableShapePrefixes() + list(shape),
dtype=var_dtype,
initializer=v_init,
collections=collections,
trainable=trainable,
validate_shape=True,
synchronization=synchronization,
aggregation=aggregation)
combined_layers_dims = GetVariableNumLeadingDimsForCombinedLayersContext()
if combined_layers_dims > 0:
# pylint: disable=protected-access
var.op._set_attr('_num_leading_dims_for_combined_layers',
attr_value_pb2.AttrValue(i=combined_layers_dims))
# Shard the variable according to the sharding spec.
tensor_split_dims_mapping = p.tensor_split_dims_mapping
if tensor_split_dims_mapping is not None:
count = (
len(GetVariableShapePrefixes()) + len(shape) -
len(tensor_split_dims_mapping) -
len(gshard_utils.GetMeshSplitDimPrefixContext()))
tensor_split_dims_mapping = [-1] * count + tensor_split_dims_mapping
var = gshard_utils.MeshSplit(
var, p.device_mesh, tensor_split_dims_mapping, use_sharding_op=False)
return var
def _CreateVarInitStateful(name,
method,
shape,
dim0,
seed,
scale,
init_dtype,
custom_v_init=None):
"""Creates variable initialization function for a stateful RNG."""
if (method in [
'gaussian_sqrt_dim', 'uniform_sqrt_dim', 'truncated_gaussian_sqrt_dim'
]):
if len(shape) > 2:
# This is probably not the right method to use when len(shape) > 2,
# e.g. dim0 will be 3 with a 3x3 conv2d kernel.
tf.logging.warning(
'Initializing %s of shape %s with method %s: dim0=%s. '
'Make sure that it is intended.', name, shape, method, dim0)
scale *= 1.0 / math.sqrt(dim0)
combined_layers_dims = GetVariableNumLeadingDimsForCombinedLayersContext()
if method in ['gaussian_sqrt_fanin', 'truncated_gaussian_sqrt_fanin']:
fan_in, _ = GetFanInFanOut(shape, combined_layers_dims)
if fan_in is not None:
scale *= 1.0 / math.sqrt(fan_in)
if method in ['gaussian_sqrt_fanout', 'truncated_gaussian_sqrt_fanout']:
_, fan_out = GetFanInFanOut(shape, combined_layers_dims)
if fan_out is not None:
scale *= 1.0 / math.sqrt(fan_out)
if method in ['gaussian_sqrt_fanavg']:
fan_in, fan_out = GetFanInFanOut(shape, combined_layers_dims)
if fan_in is not None and fan_out is not None:
scale *= math.sqrt(2.0 / (fan_in + fan_out))
if method in [
'gaussian', 'gaussian_sqrt_dim', 'gaussian_sqrt_fanin',
'gaussian_sqrt_fanout', 'gaussian_sqrt_fanavg'
]:
v_init = init_ops.random_normal_initializer(
mean=0.0, stddev=scale, seed=seed, dtype=init_dtype)
elif method in ['uniform', 'uniform_sqrt_dim']:
v_init = init_ops.random_uniform_initializer(
minval=-scale, maxval=scale, seed=seed, dtype=init_dtype)
elif method in ['uniform_positive']:
v_init = init_ops.random_uniform_initializer(
minval=0.0, maxval=scale, seed=seed, dtype=init_dtype)
elif method == 'category':
uniform_init = init_ops.random_uniform_initializer(
minval=0.0, maxval=scale, seed=seed, dtype=init_dtype)
v_init = lambda *args, **kwargs: tf.floor(uniform_init(*args, **kwargs))
elif method in ['uniform_unit_scaling']:
v_init = init_ops.uniform_unit_scaling_initializer(
factor=scale, seed=seed, dtype=init_dtype)
elif method in ['uniform_unit_scaling_fan_avg']:
v_init = tf.variance_scaling_initializer(
scale=scale,
mode='fan_avg',
distribution='uniform',
seed=seed,
dtype=init_dtype)
elif method in [
'truncated_gaussian', 'truncated_gaussian_sqrt_dim',
'truncated_gaussian_sqrt_fanin', 'truncated_gaussian_sqrt_fanout'
]:
v_init = init_ops.truncated_normal_initializer(
mean=0.0, stddev=scale, seed=seed, dtype=init_dtype)
elif method in ['constant']:
v_init = init_ops.constant_initializer(value=scale, dtype=init_dtype)
elif method in ['xavier', 'geo_mean_xavier']:
def XavierUniform(shape, dtype):
"""Xavier initialization (x = sqrt(6. / (in + out)); scale*[-x, x])."""
if not shape:
raise ValueError(
'\'shape\' must not be \'None\' or 0 for XavierUniform')
fan_in, fan_out = GetFanInFanOut(shape, combined_layers_dims)
if method == 'xavier':
limit = math.sqrt(6. / (fan_in + fan_out))
elif method == 'geo_mean_xavier':
limit = math.sqrt(3. / math.sqrt(fan_in * fan_out))
return scale * tf.random.uniform(shape, -limit, limit, dtype, seed)
v_init = XavierUniform
elif method in [
'kaiming_uniform_fanin_relu', 'kaiming_uniform_fanin_leakyrelu'
]:
fan_in = np.prod(shape[:-1])
if method == 'kaiming_uniform_fanin_leakyrelu':
# Assume the 'a' parameter is the 'scale' argument.
gain = np.sqrt(2. / (1 + scale**2))
else:
gain = np.sqrt(2.)
std_dev = gain / np.sqrt(fan_in)
bound = np.sqrt(3.0) * std_dev
v_init = init_ops.random_uniform_initializer(
minval=-bound, maxval=bound, seed=seed, dtype=init_dtype)
elif method in ['custom', 'custom_constant']:
v_init = custom_v_init
else:
assert False, 'init_type `%s` not supported.' % method
return v_init
def _GenerateStatelessRngSeed(name, seed):
"""Generates a 2-tuple seed for a stateless variable initializer.
We want to ensure that different variables end up with different random values
even when they are passed the same seed and shape. To this aim, this function
generates a pseudo-unique seed by hashing the variable name and mapping it
into a scalar seed. More specifically, the returned value is a length-2
tf.int32 tensor, where the first element is the user-provided seed and the
second element is obtained by hashing the variable name.
Args:
name: The variable name for which to generate a stateless-like seed.
seed: The user-specified scalar seed.
Returns:
A 2-tuple seed of tf.int32 values (for TPU compatibility).
"""
seed0 = seed or 0
seed1 = GenerateSeedFromName(name)
return tf.constant([seed0, seed1], dtype=tf.int32)
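# Illustrative sketch (not part of the library): two variables created with the
# same user seed still get distinct stateless seeds, because the second element
# is derived from the variable name, e.g.
#   _GenerateStatelessRngSeed('model/layer_0/w', 1234)  # -> [1234, hash('...0/w')]
#   _GenerateStatelessRngSeed('model/layer_1/w', 1234)  # -> [1234, hash('...1/w')]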
def _DeterministicRandomNormalInitializer(seed, mean, stddev):
"""Creates a random normal initializer."""
def DeterministicNormal(shape, dtype):
return stateless_random_ops.stateless_random_normal(
shape=shape, seed=seed, mean=mean, stddev=stddev, dtype=dtype)
return DeterministicNormal
def _DeterministicRandomTruncatedNormalInitializer(seed, mean, stddev):
"""Creates a random truncated normal initializer."""
def DeterministicTruncatedNormal(shape, dtype):
return stateless_random_ops.stateless_truncated_normal(
shape=shape, seed=seed, mean=mean, stddev=stddev, dtype=dtype)
return DeterministicTruncatedNormal
def _DeterministicRandomVarianceScalingInitializer(scale, mode, distribution,
seed):
"""Creates a variance scaling initializer."""
if scale <= 0.:
raise ValueError('`scale` must be positive float.')
if mode not in {'fan_in', 'fan_out', 'fan_avg'}:
raise ValueError('Invalid `mode` argument:', mode)
distribution = distribution.lower()
if distribution not in {
'normal', 'uniform', 'truncated_normal', 'untruncated_normal'
}:
raise ValueError('Invalid `distribution` argument:', distribution)
combined_layers_dims = GetVariableNumLeadingDimsForCombinedLayersContext()
def DeterministicVarianceScaling(shape, dtype):
# This is originally from TensorFlow: python/ops/init_ops.py
scale_shape = shape
# Handle special case of empty list as shape, since fan_in and fan_out
# are numerically added below. Without this, GetFanInFanOut() would
# return None, None instead.
if isinstance(scale_shape, (list, tuple)) and not scale_shape:
fan_in, fan_out = 1, 1
else:
fan_in, fan_out = GetFanInFanOut(scale_shape, combined_layers_dims)
if mode == 'fan_in':
scale_inner = scale / max(1., fan_in)
elif mode == 'fan_out':
scale_inner = scale / max(1., fan_out)
else:
scale_inner = scale / max(1., (fan_in + fan_out) / 2.)
if distribution == 'normal' or distribution == 'truncated_normal':
# constant taken from scipy.stats.truncnorm.std(
# a=-2, b=2, loc=0., scale=1.)
stddev = math.sqrt(scale_inner) / .87962566103423978
return stateless_random_ops.stateless_truncated_normal(
shape=shape, seed=seed, mean=0.0, stddev=stddev, dtype=dtype)
elif distribution == 'untruncated_normal':
stddev = math.sqrt(scale_inner)
return stateless_random_ops.stateless_random_normal(
shape=shape, seed=seed, mean=0.0, stddev=stddev, dtype=dtype)
else:
limit = math.sqrt(3.0 * scale_inner)
return stateless_random_ops.stateless_random_uniform(
shape=shape, seed=seed, minval=-limit, maxval=limit, dtype=dtype)
return DeterministicVarianceScaling
def _CreateVarInitStateless(name,
method,
shape,
dim0,
seed,
scale,
init_dtype,
custom_v_init=None):
"""Creates variable initialization function for a stateless RNG."""
if (method in [
'gaussian_sqrt_dim', 'uniform_sqrt_dim', 'truncated_gaussian_sqrt_dim'
]):
if len(shape) > 2:
# This is probably not the right method to use when len(shape) > 2,
# e.g. dim0 will be 3 with a 3x3 conv2d kernel.
tf.logging.warning(
'Initializing %s of shape %s with method %s: dim0=%s. '
'Make sure that it is intended.', name, shape, method, dim0)
scale *= 1.0 / math.sqrt(dim0)
combined_layers_dims = GetVariableNumLeadingDimsForCombinedLayersContext()
if method in ['gaussian_sqrt_fanin', 'truncated_gaussian_sqrt_fanin']:
fan_in, _ = GetFanInFanOut(shape, combined_layers_dims)
if fan_in is not None:
scale *= 1.0 / math.sqrt(fan_in)
if method in ['gaussian_sqrt_fanout', 'truncated_gaussian_sqrt_fanout']:
_, fan_out = GetFanInFanOut(shape, combined_layers_dims)
if fan_out is not None:
scale *= 1.0 / math.sqrt(fan_out)
if method in ['gaussian_sqrt_fanavg']:
fan_in, fan_out = GetFanInFanOut(shape, combined_layers_dims)
if fan_in is not None and fan_out is not None:
scale *= math.sqrt(2.0 / (fan_in + fan_out))
if method in [
'gaussian', 'gaussian_sqrt_dim', 'gaussian_sqrt_fanin',
'gaussian_sqrt_fanout', 'gaussian_sqrt_fanavg'
]:
v_init = _DeterministicRandomNormalInitializer(
seed=seed, mean=0., stddev=scale)
elif method in ['uniform', 'uniform_sqrt_dim']:
v_init = _DeterministicRandomUniformInitializer(
seed=seed, minval=-scale, maxval=scale)
elif method in ['uniform_positive']:
v_init = _DeterministicRandomUniformInitializer(
seed=seed, minval=0., maxval=scale)
elif method in ['uniform_unit_scaling']:
v_init = _DeterministicRandomUniformUnitScalingInitializer(
seed=seed, factor=scale)
elif method in ['uniform_unit_scaling_fan_avg']:
v_init = _DeterministicRandomVarianceScalingInitializer(
scale=scale, mode='fan_avg', distribution='uniform', seed=seed)
elif method in [
'truncated_gaussian', 'truncated_gaussian_sqrt_dim',
'truncated_gaussian_sqrt_fanin', 'truncated_gaussian_sqrt_fanout'
]:
v_init = _DeterministicRandomTruncatedNormalInitializer(
seed=seed, mean=0., stddev=scale)
elif method in ['constant']:
v_init = init_ops.constant_initializer(value=scale, dtype=init_dtype)
elif method in ['xavier', 'geo_mean_xavier']:
v_init = _DeterministicRandomXavierUniformInitializer(method, scale, seed)
elif method in [
'kaiming_uniform_fanin_relu', 'kaiming_uniform_fanin_leakyrelu'
]:
fan_in = np.prod(shape[:-1])
if method == 'kaiming_uniform_fanin_leakyrelu':
# Assume the 'a' parameter is the 'scale' argument.
gain = np.sqrt(2. / (1 + scale**2))
else:
gain = np.sqrt(2.)
std_dev = gain / np.sqrt(fan_in)
bound = np.sqrt(3.0) * std_dev
v_init = _DeterministicRandomUniformInitializer(
seed=seed, minval=-bound, maxval=bound)
elif method in ['custom', 'custom_constant']:
v_init = custom_v_init
else:
assert False, 'init_type %s not supported.' % method
return v_init
_global_variable_scope = None
def GetGlobalVariableScope():
"""Gets the global variable scope (as if no variable_scope has been set).
Returns:
The VariableScope in effect as if no tf.variable_scope had been entered.
"""
if not _global_variable_scope:
# Each thread gets its own default global variable scope, and we take
# advantage of that in order to get a top-level scope. This avoids the
# need to call tf.get_variable_scope() at the module level, which allows
# this module to be imported without modifying global state (i.e. creating
# the default graph). It is important to not mutate the global state at
# module load time, because it lets us flip flags after import that affect
# core TensorFlow behavior.
def Initialize():
global _global_variable_scope
_global_variable_scope = tf.get_variable_scope()
t = threading.Thread(target=Initialize)
t.start()
t.join()
return _global_variable_scope
_GLOBAL_STEP_STACK = ThreadLocalStack()
@contextlib.contextmanager
def GlobalStepContext(global_step_tensor):
_GLOBAL_STEP_STACK.stack.append(global_step_tensor)
try:
yield
finally:
_GLOBAL_STEP_STACK.stack.pop()
_WARN_ON_GLOBAL_STEP_ACCESS = False
_WARNED_STACKS = set()
@contextlib.contextmanager
def WarnOnGlobalStepAccess():
global _WARN_ON_GLOBAL_STEP_ACCESS
_WARN_ON_GLOBAL_STEP_ACCESS = True
try:
yield
finally:
_WARN_ON_GLOBAL_STEP_ACCESS = False
def _MaybeWarnAboutPotentiallyIncorrectEvaluations():
"""Warn the user about access to the global step variable in pure Eager mode.
This is useful for the TF1 -> TF2 migration project. When a model has a value
that should change as a function of the global step (e.g. a scheduler in the
input pipeline), it might get incorrectly evaluated to a static value in the
TF2 mode trainer at the beginning. This function gives a warning to the user
about such cases, so that issues such as unexpected loss curves can be more
easily triaged.
"""
# TODO(b/212648719): Come up with better notifications for the users.
# `executing_eagerly` filters out access to global step in `tf.function`.
# In tf.function, access to global step (which is a Variable) is safe.
if _WARN_ON_GLOBAL_STEP_ACCESS and tf.executing_eagerly():
# The number of stacks needs to be larger than 2 (to include the call to the
# global step), but not too large, to avoid spamming the logs.
trace = traceback.format_stack(limit=4)
if tuple(trace) not in _WARNED_STACKS:
tf.logging.warning(
'You are accessing the global step in Tensorflow pure Eager mode. '
'If the global step is used to calculate a value, this evaluation '
'will be eagerly executed and the value will not change during '
'training. If you want the value to change during training, please '
'move the evaluation expression into the training loop instead.')
tf.logging.warning(f'Traceback to locate global step access: {trace}')
_WARNED_STACKS.add(tuple(trace))
def GetGlobalStep():
"""Return the global_step."""
_MaybeWarnAboutPotentiallyIncorrectEvaluations()
if _GLOBAL_STEP_STACK.stack:
return _GLOBAL_STEP_STACK.stack[-1]
return tf.train.get_global_step()
def GetOrCreateGlobalStepVar():
"""Return the global_step variable, creating it if it does not exist.
Prefer GetGlobalStep if a tensor rather than a tf.Variable is sufficient.
Returns:
The global_step variable, or a new created one if it does not exist.
"""
with tf.variable_scope(GetGlobalVariableScope(), use_resource=True):
if _FromGlobal('pin_vars_to_cpu'):
with tf.device('/cpu:0'):
return tf.train.get_or_create_global_step()
else:
return tf.train.get_or_create_global_step()
def LogMultiLines(label, lines):
if not isinstance(lines, (list, tuple)):
lines = lines.split('\n')
for line in lines:
tf.logging.info('%s: %s', label, line)
def _LogPlacement(label, theta, copy):
"""Logs theta and its copy's device placement."""
def GetDevices(m):
"""Flattens a `.NestedMap` m and extracts each value's device."""
return [x.device for x in m.Flatten()]
tf.logging.info('=== %s ===', label)
LogMultiLines(
label,
theta.Pack([('%s -> %s' % (x[0], x[1]))
for x in zip(GetDevices(theta), GetDevices(copy))
]).DebugString())
tf.logging.info('==========')
def CreateLocalTheta(theta, device_list=None, label=None):
"""Creates a local copy of theta and shards it across the device list.
Leaves variables intact.
Args:
theta: a `.NestedMap` of variables.
device_list: list of devices to shard across. If None, defaults to a list
[''].
label: Logging label.
Returns:
A `.NestedMap` of identity()-wrapped theta.
"""
class AddIdentity:
"""Helper class."""
def __init__(self, device_list):
self._list = device_list if device_list else ['']
self._index = 0
def __call__(self, x):
if isinstance(x, tf.Variable):
return x
with tf.device(self._list[self._index % len(self._list)]):
self._index += 1
return tf.identity(x)
copy = theta.Transform(AddIdentity(device_list))
_LogPlacement(label, theta, copy)
return copy
def _GetVarsToLoad(all_vars,
variable_loading_rules,
var_ignore_rules,
ckpt_path,
suppress_logging=False):
"""Determines variables to load and their names in checkpoint."""
# This list contains mappings from var names as they appear in the checkpoint
# to the vars in our model they correspond to.
unused_rules = {
regexp: name_format for regexp, name_format in variable_loading_rules
}
vars_to_load = []
for model_var in all_vars:
loaded = False
for regexp, name_format in variable_loading_rules:
match = re.match(regexp, model_var.name)
# Skip if var doesn't match the loading rules, or if it should be ignored.
if not match:
if not suppress_logging:
tf.logging.debug('Loading rules do not match %s.', model_var.name)
continue
elif any(re.match(r, model_var.name) for r in var_ignore_rules):
if not suppress_logging:
tf.logging.debug('Ignoring %s from loading.', model_var.name)
continue
checkpoint_var_name = name_format % match.groups()
if checkpoint_var_name.endswith(':0'):
checkpoint_var_name = checkpoint_var_name[:-2]
if not suppress_logging:
tf.logging.info('Loading %s from %s with regexp: %s', model_var.name,
checkpoint_var_name, regexp)
vars_to_load.append((checkpoint_var_name, model_var))
unused_rules.pop(regexp, None)
loaded = True
break
if not loaded and not suppress_logging:
tf.logging.info(
'Not loading model variable %s from %s as it does not match any rules'
' or matches an ignore rule.', model_var.name, ckpt_path)
if not suppress_logging:
for regexp, name_format in unused_rules.items():
tf.logging.warning(f'User provided rule matched no variables: ({regexp}, '
f'{name_format})')
return vars_to_load
def OverrideVarsFromCheckpoint(all_vars, checkpoint_path,
variable_loading_rules, var_ignore_rules):
"""Add TF graph ops to override variables from a provided checkpoint.
Args:
all_vars: List of all the parameters in the model.
checkpoint_path: A path to the checkpoints of a pretrained model.
variable_loading_rules: A list of tuples of strings defining (regex to match
parameter names in the model to override, format string to determine the
corresponding var in the checkpoint).
var_ignore_rules: A list of regexes to match parameter
names in the model which should not be overridden, even if they match
those in the loading rules.
Returns:
A callable that, when called with a tf.Session, will restore the variables
from the provided checkpoint.
"""
vars_to_load = _GetVarsToLoad(all_vars, variable_loading_rules,
var_ignore_rules, checkpoint_path)
if not vars_to_load:
all_rules_text = '\n'.join(
[f'{k} --> {v}' for k, v in variable_loading_rules])
raise ValueError(f'Variable loading rules {all_rules_text} '
f'did not match any of {len(all_vars)} vars.')
load_var_names = '\n'.join(sorted([v.name for _, v in vars_to_load]))
tf.logging.info(f'Overriding {len(vars_to_load)} vars from '
f'{checkpoint_path}:\n{load_var_names}')
savers = []
while vars_to_load:
# When restoring, it's possible the same value in the checkpoint
# can be restored to multiple variables (e.g. during
# distillation). However, tf.train.Saver, since it's used for
# both saving and restoring, requires the name in the checkpoint
# to be unique for each variable. So, we call it multiple times
# with a unique set of names each time.
unique_vars_to_load = {}
remaining_vars_to_load = []
for k, v in vars_to_load:
if k not in unique_vars_to_load:
unique_vars_to_load[k] = v
else:
remaining_vars_to_load.append((k, v))
savers.append(tf.train.Saver(var_list=unique_vars_to_load, sharded=True))
vars_to_load = remaining_vars_to_load
def _Restore(sess):
for saver in savers:
saver.restore(sess, checkpoint_path)
return _Restore
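# Illustrative sketch (assumed paths and names, not part of the library): the
# loading rules map checkpoint variable names onto model variables via
# (regex, format) pairs, e.g. to initialize an encoder from a pretrained model:
#   restore_fn = OverrideVarsFromCheckpoint(
#       all_vars=tf.global_variables(),
#       checkpoint_path='/path/to/pretrained/ckpt',
#       variable_loading_rules=[('(encoder/.*)', '%s')],
#       var_ignore_rules=['.*global_step.*'])
#   restore_fn(sess)  # sess is an existing tf.Session.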
def OverrideVarsFromCheckpoints(all_vars, ckpts_loading_rules):
"""Add TF graph ops to override model variables from checkpoints.
Args:
all_vars: List of all the parameters in the model.
ckpts_loading_rules: A dictionary of checkpoint path: loading rules.
Checkpoint path must be a path to a pretrained model, and loading rules is
expected to be a tuple of two lists. The first list consists of tuples of
strings defining (regex to match parameter names in the model to override,
format string to determine the corresponding var in the checkpoint), and
the second list consists of regexes to match parameter names in the model
which should not be overridden, even if they match those in the loading
rules.
Returns:
A callable that, when called with a tf.Session, will restore the variables
from checkpoint and return a list of overwritten variables.
Raises:
ValueError: if colliding vars exist or loading rules is not a list.
"""
if len(ckpts_loading_rules) > 1:
tf.logging.info('Overriding vars from multiple checkpoints.')
var_refs_overridden = set()
var_names_overridden = set()
restore_fns = []
for ckpt_path, loading_rules in ckpts_loading_rules.items():
tf.logging.info('Overriding vars from checkpoint: %s', ckpt_path)
if not isinstance(loading_rules, tuple):
raise ValueError('Loading rules for %s must be a tuple of two lists!' %
ckpt_path)
if len(loading_rules) != 2 or not all(
isinstance(l, list) for l in loading_rules):
raise ValueError('Loading rules for %s must be a tuple of two lists!' %
ckpt_path)
# Filter the model variables to be overridden.
to_load_vars = _GetVarsToLoad(
all_vars,
loading_rules[0],
loading_rules[1],
ckpt_path,
suppress_logging=True)
var_refs_to_override = [var[1].ref() for var in to_load_vars]
var_names_to_override = [var[1].name for var in to_load_vars]
overlap_refs = set.intersection(var_refs_overridden, var_refs_to_override)
if overlap_refs:
raise ValueError('Colliding variables to override: %s' % overlap_refs)
restore_fns.append(
OverrideVarsFromCheckpoint(all_vars, ckpt_path, loading_rules[0],
loading_rules[1]))
var_refs_overridden.update(var_refs_to_override)
var_names_overridden.update(var_names_to_override)
def _Restore(sess):
for fn in restore_fns:
fn(sess)
tf.logging.info('Model variables overridden: %s', var_names_overridden)
return var_names_overridden
return _Restore
def ComputeGradientsSimple(loss_or_activations,
all_vars,
grad_aggregation_method,
colocate_gradients_with_ops,
gate_gradients,
activations_grad=None):
"""Compute gradients."""
tape = _GRADIENT_TAPE_STACK.stack[-1] if _GRADIENT_TAPE_STACK.stack else None
if IsEagerMode() and tape:
tf.logging.info('ComputeGradientsSimple: using gradient tape.')
if activations_grad is not None:
raise ValueError('GradientTape does not accept gradient input values.')
if grad_aggregation_method or colocate_gradients_with_ops or gate_gradients:
tf.logging.warning(
'When GradientTape is used, these field will be ignored: '
f'grad_aggregation_method ({grad_aggregation_method}), '
f'colocate_gradients_with_ops ({colocate_gradients_with_ops}), '
f'gate_gradients ({gate_gradients}).')
return tape.gradient(
loss_or_activations,
all_vars,
unconnected_gradients=tf.UnconnectedGradients.NONE)
return tf.gradients(
loss_or_activations,
all_vars,
grad_ys=activations_grad,
aggregation_method=grad_aggregation_method,
colocate_gradients_with_ops=colocate_gradients_with_ops,
gate_gradients=gate_gradients)
def _ComputeGradientsTpu(loss_or_activations,
all_vars,
grad_aggregation_method,
colocate_gradients_with_ops,
gate_gradients,
skip_zero_gradients=None,
use_bf16_gradients_ar=False,
defer_crs_to_apply_grad=False,
activations_grad=None,
is_activations=False,
tpu_embedding_activations=None):
"""Computes gradients for local loss across whole TPU cluster.
This implementation specializes for the case where weight params may be used
a different number of times in the forward computation, so that gradients
should be normalized by the actual number of times they are computed.
TODO(yonghui): Maybe merge this implementation with the _ComputeGradientsTpu
one.
Args:
loss_or_activations: The loss or activations to backprop from.
all_vars: Vars with respect to which gradients are to be computed.
grad_aggregation_method: aggregation method to use when calling
tf.gradients.
colocate_gradients_with_ops: boolean, whether or not to colocate gradient op
with the original op.
gate_gradients: boolean, flag to be passed to tf.gradients.
skip_zero_gradients: whether to skip zero gradients during aggregation.
use_bf16_gradients_ar: Whether to use bfloat16 dtype for gradients
all-reduce.
defer_crs_to_apply_grad: Whether to defer gradient cross replica sum to
apply_gradient. This helps reducing the number of gradient all-reduces
when doing gradient accumulation, which does gradient cross replica sum
only every k steps in a tf.cond. Currently this works only when
skip_zero_gradients is None.
activations_grad: The gradients computed for activations.
is_activations: A boolean, whether the input is loss or activations.
tpu_embedding_activations: A `.NestedMap` of tpu embedding feature name ->
embedding feature tensor.
Returns:
Gradients to be passed back. If tpu_embedding_activations is set, their
gradients will be placed at the end.
Raises:
ValueError: upon invalid arguments.
"""
if is_activations:
assert activations_grad is not None
if not skip_zero_gradients and not is_activations:
# Scale the loss to account for the full batch size.
shards = tpu_function.get_tpu_context().number_of_shards
assert shards
loss_or_activations *= tf.constant(
1.0 / shards, dtype=loss_or_activations.dtype)
else:
assert not tpu_embedding_activations, (
'Gradient computation for tpu embedding activations requires proper '
'loss scaling, and so is not compatible with skip_zero_gradients and '
'is_activations.')
# Computes the gradients.
# Sum the grads so that we can compute statistics across the whole batch.
all_grads = ComputeGradientsSimple(
loss_or_activations=loss_or_activations,
all_vars=all_vars +
(tpu_embedding_activations if tpu_embedding_activations else []),
grad_aggregation_method=grad_aggregation_method,
colocate_gradients_with_ops=colocate_gradients_with_ops,
gate_gradients=gate_gradients,
activations_grad=activations_grad)
if tpu_embedding_activations:
# Note we don't need to aggregate TPU embedding gradients below.
tpu_embedding_grads = all_grads[len(all_vars):]
all_grads = all_grads[:len(all_vars)]
else:
tpu_embedding_grads = []
# NOTE: We can't use tpu_optimizer.CrossShardOptimizer since
# we need to scale the grads *after* the cross_replica_sum to
# match GPU version!
# TODO(cwhipkey): should we do something different here? - we could do
# some operations on the gradients before the aggregation (see comments in
# tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py - see compute_gradients -
# for some more details).
aggregated_grads = []
for g in all_grads:
if g is None:
aggregated_grads.append(None)
continue
if use_bf16_gradients_ar:
g = tf.cast(g, tf.bfloat16)
with tf.ops.colocate_with(g):
if skip_zero_gradients is None:
# loss is already scaled by 1/shards.
if defer_crs_to_apply_grad:
normalized_g = tf.convert_to_tensor(g)
else:
normalized_g = tf.tpu.cross_replica_sum(g)
else:
# Compute the cross-replica mean of 'g', skipping zero gradients.
# Q(yonghui): Is there a better way to detect a non-zero gradient?
# Note(yonghui): gradient of a weight can be zero if that
# weight is not used in the forward computation, e.g. as in
# switchable layers in neural architecture search, pruned by channel
# mask, or sparsified.
if skip_zero_gradients == 'weight':
# Same shape as 'g'.
g_is_non_zero = tf.cast(tf.math.abs(g) > 1e-8, g.dtype)
elif skip_zero_gradients == 'variable':
# A variable-wide 0/1 scalar.
g_is_non_zero = tf.cast(
tf.reduce_sum(tf.math.abs(g)) > 1e-24, g.dtype)
else:
raise ValueError('Unknown skip_zero_gradients: %s' %
skip_zero_gradients)
num_updates = tf.maximum(tf.tpu.cross_replica_sum(g_is_non_zero), 1.0)
normalized_g = tf.tpu.cross_replica_sum(g) / num_updates
aggregated_grads.append(normalized_g)
return aggregated_grads + tpu_embedding_grads
class _VarGrad(typing.NamedTuple):
var: tf.Tensor
grad: Union[tf.Tensor, tf.IndexedSlices]
scale: Optional[tf.Tensor] = None
class VarGrad:
"""A class that holds a variable and a gradient.
This does not inherit from namedtuple so that tf.nest operations do not
recurse into it.
"""
def __init__(self, *args, **kwargs):
self._var_grad = _VarGrad(*args, **kwargs)
def __getitem__(self, key):
return self._var_grad[key]
def __getattr__(self, key):
return getattr(self._var_grad, key)
def __iter__(self):
if self._var_grad.scale is None:
return iter((self._var_grad.var, self._var_grad.grad))
return iter(self._var_grad)
def __repr__(self):
return repr(self._var_grad)
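# Illustrative sketch (not part of the library): a VarGrad unpacks like a
# (var, grad) pair when no scale is set, so both forms below are equivalent:
#   vg = VarGrad(v, g)
#   var, grad = vg                 # tuple-style unpacking
#   var, grad = vg.var, vg.grad    # attribute access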
def SkipNoneGradients(var_grads):
"""Removes pairs whose grad is None."""
for key, (_, g) in var_grads.FlattenItems():
if g is None:
tf.logging.info('ComputeGradients drops %s', key)
return var_grads.Filter(lambda var_grad: var_grad.grad is not None)
def ComputeGradients(
loss_or_activations,
vmap,
grad_aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE,
colocate_gradients_with_ops=True,
gate_gradients=False,
compute_gradients_fn=None,
skip_zero_gradients=None,
use_bf16_gradients_ar=False,
skip_none_gradients=True,
defer_crs_to_apply_grad=False,
activations_grad=None,
is_activations=False,
tpu_embedding_activations=None):
"""Computes gradients of variables in vmap w.r.t loss.
Args:
loss_or_activations: either the loss, which is a scalar tensor, or
activations, which could be a tensor or a list of tensors.
vmap: A `.NestedMap` of variables.
grad_aggregation_method: Specifies the method used to combine gradient
terms. Accepted values are constants defined in the class
AggregationMethod.
colocate_gradients_with_ops: If True, try colocating gradients with the
corresponding op.
gate_gradients: If True, add a tuple around the gradients returned for an
operation. This avoids some race conditions.
compute_gradients_fn: Function to use to compute gradients. If None, use
default. compute_gradients_fn should have the same signature as this
function, but without the last argument.
skip_zero_gradients: Whether to skip aggregating zero gradients. This helps
in case where some weights may not be used in forward computation, e.g.,
sparsely activated networks or switchable layers in neural architectural
search. Only applicable on TPU.
Possible values are:
- None: do not skip zero gradients;
- `variable`: skip if the entire variable's gradients are almost zero;
reduce_sum(abs(grads)) < 1e-8.
- `weight`: skip if the individual weight's gradients are almost zero:
abs(grad) < 1e-8.
use_bf16_gradients_ar: Whether to use bfloat16 dtype for gradients
all-reduce. This applies to TPU only.
skip_none_gradients: Whether to skip gradients that are None.
defer_crs_to_apply_grad: Whether to defer gradient cross replica sum to
apply_gradient. This applies to TPU only.
activations_grad: The gradients computed for activations.
is_activations: A boolean, whether the input is loss or activations.
tpu_embedding_activations: A `.NestedMap` of tpu embedding feature name ->
embedding feature tensor.
Returns:
var_grad - a `.NestedMap` of VarGrad. You can view
var_grad as an ordered list of (key, (var, grad)) tuples. Every
key of var_grad exists in vmap. Every variable in vmap that
contributes to loss must exist in var_grad. Every var of var_grad
must exist in vmap. grad is the corresponding gradient computed
for var. grad is guaranteed to be not None.
If tpu_embedding_activations is set, a sub `.NestedMap` named
tpu_embedding_var_grads will be used to store the VarGrads for the
activations. In this case, key is the feature name, and var in the VarGrad
is the activation tensor (not a real variable).
"""
if not is_activations:
loss_or_activations = HasRank(loss_or_activations, 0)
if not tpu_embedding_activations:
tpu_embedding_activations = NestedMap()
assert isinstance(tpu_embedding_activations, NestedMap)
assert isinstance(vmap, NestedMap)
assert skip_zero_gradients in (None, 'variable', 'weight')
# Uniqify and remove None.
filtered_vmap = vmap.Filter(_Unique())
assert filtered_vmap
# Skip non-trainable variables. Otherwise, tf.Optimizer.apply_gradients
# throws an exception instead of skipping the update.
def Needed(v):
if isinstance(v, tf.Variable):
return v.trainable
return True
filtered_vmap = filtered_vmap.Filter(Needed)
assert filtered_vmap
filtered_vlist = filtered_vmap.Flatten()
# Use caller-supplied gradient function if supplied.
if compute_gradients_fn is not None:
assert not tpu_embedding_activations
take_grad = compute_gradients_fn
else:
# tpu vs non-tpu is slightly different.
if use_tpu():
take_grad = functools.partial(
_ComputeGradientsTpu,
skip_zero_gradients=skip_zero_gradients,
use_bf16_gradients_ar=use_bf16_gradients_ar,
defer_crs_to_apply_grad=defer_crs_to_apply_grad,
activations_grad=activations_grad,
is_activations=is_activations,
tpu_embedding_activations=tpu_embedding_activations.Flatten())
else:
assert not tpu_embedding_activations
take_grad = ComputeGradientsSimple
grads = take_grad(loss_or_activations, filtered_vlist,
grad_aggregation_method, colocate_gradients_with_ops,
gate_gradients)
if tpu_embedding_activations:
tpu_embedding_grads = grads[len(filtered_vlist):]
grads = grads[:len(filtered_vlist)]
else:
tpu_embedding_grads = None
# Formulate pairs of (var, grad) and pack them into the same
# structure as filtered_vmap.
var_grads = filtered_vmap.Pack(
[VarGrad(v, g) for v, g in zip(filtered_vlist, grads)])
if skip_none_gradients:
var_grads = SkipNoneGradients(var_grads)
if tpu_embedding_grads:
# Create VarGrads for TPU embedding activations in a dedicated sub map.
assert 'tpu_embedding_var_grads' not in var_grads
tpu_embedding_activation_list = tpu_embedding_activations.Flatten()
tpu_embedding_var_grads = [
VarGrad(v, g)
for v, g in zip(tpu_embedding_activation_list, tpu_embedding_grads)
]
tpu_embedding_var_grads = tpu_embedding_activations.Pack(
tpu_embedding_var_grads)
# Replace None gradients with zeros, since TPU embedding expects all
# activations to have gradients.
def _NoneToZeros(key, var_grad):
if var_grad.grad is None:
tf.logging.warning(
f'TPU embedding gradient for feature {key} is None. Replacing with '
'zeros.')
return VarGrad(var_grad.var, tf.zeros_like(var_grad.var))
return var_grad
var_grads.tpu_embedding_var_grads = (
tpu_embedding_var_grads.TransformWithKey(_NoneToZeros))
return var_grads
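# Illustrative sketch (assumed names, not part of the library): computing
# gradients for all trainable variables of a layer w.r.t. a scalar loss, where
# `layer` is a hypothetical lingvo layer whose `vars` is a NestedMap:
#   var_grads = ComputeGradients(loss, layer.vars)
#   for key, var_grad in var_grads.FlattenItems():
#     tf.logging.info('grad for %s: %s', key, var_grad.grad)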
def MaskGradients(var_grad, grad_mask):
"""Computes gradients of non-masked variables in vmap w.r.t loss.
Args:
var_grad: A `.NestedMap` of (variable, gradient)
grad_mask: A dict of (variable name, mask).
Returns:
var_grad - a `.NestedMap` of (variable, mask * gradient).
"""
def ApplyMask(entry):
var, grad = entry
mask = grad_mask[var.name]
if isinstance(grad, tf.IndexedSlices):
return VarGrad(var, tf.IndexedSlices(grad.values * mask, grad.indices))
else:
return VarGrad(var, grad * mask)
return var_grad.Transform(ApplyMask)
def ApplyGradMultiplier(vs_gs, grad_scale=None):
"""Scale gradients by grad_scale on same device as corresponding variables.
Args:
vs_gs: A `.NestedMap` of VarGrad.
grad_scale: If None, each vs_gs entry has the scale. Otherwise, grad_scale
applies to every entry.
Returns:
A `.NestedMap` of (variable, gradient * grad_scale). In particular, if
grad_scale is 0, the result gradient is always 0, even if the input
gradient is inf or nan.
"""
def ScaleOrZero(var: tf.Tensor, grad: tf.Tensor,
scale: tf.Tensor) -> tf.Tensor:
grad = CheckNumerics(grad, 'Gradient for %s is not finite.' % var.name)
return tf.where(
tf.equal(scale, 0.), tf.zeros_like(grad),
tf.cast(scale, grad.dtype) * grad)
def Scale(item: VarGrad) -> VarGrad:
"""Scales the gradient."""
var, grad = item
assert grad is not None, ('No grad found for ', var.name)
if grad_scale is None:
scale = item.scale
else:
scale = grad_scale
with tf.device(var.device):
if isinstance(grad, tf.IndexedSlices):
grad = tf.IndexedSlices(
ScaleOrZero(var, grad.values, scale), grad.indices,
grad.dense_shape)
else:
grad = ScaleOrZero(var, grad, scale)
return VarGrad(var, grad)
return vs_gs.Transform(Scale)
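# Illustrative sketch (not part of the library): scaling every gradient by a
# global factor, e.g. the reciprocal of an assumed number of accumulation steps:
#   var_grads = ApplyGradMultiplier(var_grads, grad_scale=1.0 / num_steps)
# Passing grad_scale=0.0 zeroes all gradients, even if some were inf or nan.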
def HasNanOrInf(x):
if isinstance(x, tf.IndexedSlices):
x = x.values
with tf.device(x.device):
if x.dtype.is_complex:
return tf.reduce_any(
[HasNanOrInf(tf.math.real(x)),
HasNanOrInf(tf.math.imag(x))])
return tf.reduce_any(
tf.math.logical_or(tf.math.is_nan(x), tf.math.is_inf(x)))
def HasNanOrInfGradient(var_grads):
"""Returns a bool tensor to indicate if `var_grads` contains NaNs or Infs.
Args:
var_grads: A `.NestedMap` with (var, grad) tuple as the map value.
Returns:
A bool scalar tensor to indicate if the `var_grads` contains NaNs or Infs.
"""
return tf.reduce_any([HasNanOrInf(g) for (_, g) in var_grads.Flatten()])
def ApplyGradNormClipping(vs_gs, norm=1.0):
"""Clip gradients to norm on same device as corresponding variables.
Args:
vs_gs: A `.NestedMap` of VarGrad.
norm: Each tensor's gradient will be scaled down to have a maximum L2-norm
value of `norm`.
Returns:
A `.NestedMap` of VarGrad(variable, clipped_gradient), where each gradient
has been clipped so that its L2-norm is at most `norm`.
"""
def ClipByNorm(var, grad, norm):
grad = CheckNumerics(grad, 'Gradient for %s is not finite.' % var.name)
return tf.clip_by_norm(grad, norm)
def Clip(item):
"""Scales the gradient."""
var, grad = item
assert grad is not None, ('No grad found for ', var.name)
with tf.device(var.device):
if isinstance(grad, tf.IndexedSlices):
grad = tf.IndexedSlices(
ClipByNorm(var, grad.values, norm), grad.indices, grad.dense_shape)
else:
grad = ClipByNorm(var, grad, norm)
return VarGrad(var, grad)
return vs_gs.Transform(Clip)
SKIP_LP_REGULARIZATION = '__lingvo_skip_lp_regularization'
def AdjustGradientsWithLpLoss(var_grads, lp_regularizer_weight, p=2.0):
"""Adjusts the map of (var, grad) with Lp regularization, where p=1.0 or 2.0.
Args:
var_grads: a `.NestedMap` or list of (variable, gradient).
lp_regularizer_weight: Lp regularization weight.
p: For now we support 1.0 or 2.0.
Returns:
A tuple (lp_loss, var_grads).
- lp_loss: A scalar. The lp loss.
- var_grads: a `.NestedMap` or list of (variable, gradient) regulated by Lp.
"""
# TODO(yuancao): For now we support p=1 or 2, but this can be extended to
# lp-norm in general.
assert p in [2.0, 1.0], 'For now we only support L1/L2 regularization.'
def GetVar(item):
var, grad = item
if isinstance(grad, tf.IndexedSlices):
with tf.device(var.device):
ids = HasRank(grad.indices, 1)
uniq_ids = tf.unique(ids).y
return tf.gather(var, uniq_ids)
else:
return var
def ShouldAdjust(v):
return not _VarInCollection(v, tf.get_collection(SKIP_LP_REGULARIZATION))
filtered_var_grads = [
var_grad for var_grad in Flatten(var_grads) if ShouldAdjust(var_grad.var)
]
filtered_vars = Transform(GetVar, filtered_var_grads)
for v in filtered_vars:
v_name = v.name if not tf.executing_eagerly() else '[eager]'
tf.logging.info('AdjustGradientsWithLpLoss: %s', v_name)
if p == 2.0:
lp_loss = 0.5 * lp_regularizer_weight * SumSquared(filtered_vars)
elif p == 1.0:
lp_loss = lp_regularizer_weight * SumAbs(filtered_vars)
def LpGrad(var_grad):
"""Adjusts item's grad w/ Lp loss term."""
var, grad = var_grad
if isinstance(grad, tf.IndexedSlices):
# Question(rpang): do we apply Lp loss here even if 'var' is in
# SKIP_LP_REGULARIZATION?
#
# Note: IndexedSlices appears for embedding lookups.
# Embedding lookup ids can have duplicates. For duplicated ids, we
# only want to count each unique id once.
with tf.device(var.device):
emb = HasRank(var, 2)
vocab_size = tf.shape(emb)[0]
ids = HasRank(grad.indices, 1)
values = tf.gather(emb, ids) # [#ids, dims]
with tf.device(grad.device):
# counts is a vector of size vocab_size. counts[i] is the number of
# occurrences of the i-th word in 'ids'.
counts = tf.math.unsorted_segment_sum(
tf.ones_like(ids, dtype=values.dtype), ids, vocab_size)
# Gradients for duplicated ids will be summed when they get
# applied, and hence we account for that by first dividing
# gradient resulting from lp loss by how many times the id is
# duplicated.
#
# For each id in 'ids', we know counts[id] is non-zero,
# hence, it's always safe to take reciprocal.
weights = tf.math.reciprocal(tf.gather(counts, ids))
weights = tf.expand_dims(weights, -1) # [#ids, 1]
if p == 2.0:
grad_v = values
elif p == 1.0:
grad_v = tf.sign(values)
delta = lp_regularizer_weight * weights * grad_v
grad = tf.IndexedSlices(grad.values + delta, ids)
elif not _VarInCollection(var, tf.get_collection(SKIP_LP_REGULARIZATION)):
with tf.device(var.device):
if p == 2.0:
grad_v = var
elif p == 1.0:
grad_v = tf.sign(var)
delta = lp_regularizer_weight * grad_v
with tf.device(grad.device):
grad += delta
return VarGrad(var, grad)
return lp_loss, Transform(LpGrad, var_grads)
def SplitRecursively(x, num_splits, axis=-1):
"""Splits Tensors in 'x' recursively.
Args:
x: a Tensor, or a list or NestMap containing Tensors to split.
num_splits: number of splits per Tensor.
axis: the split axis.
Returns:
A list of split values of length 'num_splits'.
- If 'x' is a Tensor, a list of split Tensors.
- If 'x' is a list, a list of lists, where each sublist has the same length
as 'x' and the k'th element in each sublist corresponds to a split of the
k'th element from 'x'.
- If 'x' is a `.NestedMap`, a list of `.NestedMap`, where each field
corresponds to a split from the same field of 'x'.
"""
if isinstance(x, tf.Tensor):
return tf.split(x, num_splits, axis=axis)
elif isinstance(x, list):
splits = [SplitRecursively(element, num_splits, axis) for element in x]
splits = list(zip(*splits))
return [list(t) for t in splits]
elif isinstance(x, NestedMap):
results = [NestedMap() for _ in range(num_splits)]
for key, val in x.items():
val_splits = SplitRecursively(val, num_splits, axis)
for i in range(num_splits):
results[i][key] = val_splits[i]
return results
else:
raise TypeError('Unexpected type for SplitRecursively: %s' % type(x))
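# Illustrative sketch (not part of the library): splitting a NestedMap of
# activations (here a hypothetical `batch`) into two halves along the last
# dimension and putting it back together again:
#   halves = SplitRecursively(batch, num_splits=2)   # list of 2 NestedMaps
#   restored = ConcatRecursively(halves)             # same structure as batch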
def ConcatRecursively(splits, axis=-1):
"""Concatenates tensors from 'splits'.
This is the inverse function of SplitRecursively.
Args:
splits: a list of splits to concatenate, where elements can be Tensors,
lists, or `.NestedMap`. The elements must share the same type and
structure. For example, list elements must have the same length;
`.NestedMap` must have the same set of fields.
axis: the concatenation axis.
Returns:
Concatenated data.
- If input 'splits' are Tensors, returns a concatenated Tensor.
- If input 'splits' are lists, returns a list of the same length where the
k'th element represents concatenated data of the k'th element from each
split.
- If input 'splits' are `.NestedMap`, returns a `.NestedMap` with each field
concatenated from corresponding fields of input splits.
Raises:
TypeError: if 'splits' is not a list or elements of 'splits' do not have
known or matching types.
ValueError: if 'splits' is empty or elements of 'splits' do not have
matching structures.
"""
if not isinstance(splits, list):
raise TypeError('Non-list inputs for ConcatRecursively: %s' % splits)
if not splits:
raise ValueError('Empty inputs for ConcatRecursively: %s' % splits)
tmpl = splits[0]
if isinstance(tmpl, tf.Tensor):
return tf.concat(splits, axis=axis)
elif isinstance(tmpl, list):
if not all(isinstance(split, list) for split in splits):
raise TypeError('Type mismatch for ConcatRecursively: %s' % splits)
if not all(len(split) == len(tmpl) for split in splits):
raise ValueError('Length mismatch for ConcatRecursively: %s' % splits)
return [
ConcatRecursively([split[i]
for split in splits], axis)
for i in range(len(tmpl))
]
elif isinstance(tmpl, NestedMap):
if not all(isinstance(split, NestedMap) for split in splits):
raise TypeError('Type mismatch for ConcatRecursively: %s' % splits)
results = NestedMap()
for key in tmpl:
results[key] = ConcatRecursively([split[key] for split in splits], axis)
return results
else:
raise TypeError('Unexpected type for ConcatRecursively: %s' % type(splits))
def WeightedAvg(values, weights, sum_reduction_fn=tf.reduce_sum, name=''):
"""Computes weighted average of values from a tensor.
Args:
values: a tensor of values
weights: a tensor of weights
sum_reduction_fn: called to reduce the values and weights to single value
name: name of metric.
Returns:
A tuple (avg, total_weight).
- avg: weighted average value
- total_weight: sum of all weights
"""
msg = 'shape of values and weights tensors must match for metric ' + name
values = with_dependencies(
[assert_equal(tf.shape(values), tf.shape(weights), message=msg)], values)
total_weight = sum_reduction_fn(weights)
# divide_no_nan only supports tf.{float,complex}*.
dtype = values.dtype if values.dtype is tf.float64 else tf.float32
avg = tf.math.divide_no_nan(
sum_reduction_fn(tf.cast(values, dtype) * tf.cast(weights, dtype)),
tf.cast(total_weight, dtype))
return tf.cast(avg, values.dtype), total_weight
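# Illustrative sketch (not part of the library): the per-example values below
# average to (2*1 + 4*3) / (1 + 3) = 3.5 with total weight 4:
#   avg, total = WeightedAvg(tf.constant([2.0, 4.0]), tf.constant([1.0, 3.0]))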
def WeightedAvgOfMetrics(metrics):
"""Computes the weighted average of metrics in the list.
Args:
metrics: list of dictionaries of metrics
Returns:
ret_dict - dictionary of weighted averages of each metric.
"""
ret_dict = {}
lists_of_metrics = {}
for m in metrics:
for name, (value, weight) in m.items():
if name not in lists_of_metrics:
lists_of_metrics[name] = []
lists_of_metrics[name].append((value, weight))
for name, values_and_weights in sorted(lists_of_metrics.items()):
values = tf.stack([x[0] for x in values_and_weights])
weights = tf.stack([x[1] for x in values_and_weights])
ret_dict[name] = WeightedAvg(values, weights, tf.reduce_sum, name)
return ret_dict
def ConcatPerExampleTensors(per_example):
"""Concatenate per-example tensors from many hosts into one large block.
Args:
per_example: list of dictionaries of per-example tensors.
Returns:
ret_dict - string -> concatenated tensors.
"""
ret_dict = {}
lists_of_per_example = {}
for m in per_example:
for name, value in m.items():
if name not in lists_of_per_example:
lists_of_per_example[name] = []
lists_of_per_example[name].append(value)
for name, values in sorted(lists_of_per_example.items()):
ret_dict[name] = tf.concat(values, 0)
return ret_dict
def CombineMetrics(loss_metric_weight_pairs):
"""Combines metrics from `loss_metric_weight_pairs` according to weights.
Keys must either exist in all metrics, in which case they are processed as a
weighted sum, or exist in only one metrics dict, in which case they are copied.
Args:
loss_metric_weight_pairs: a list of (metrics, weight) pairs, where each
weight is a float and each metrics is a dict with str keys and
(metric_value, target_weight) values.
Returns:
A dict with the same set of keys as input metrics and values of
(weighted_sum(metric_value), weighted_sum(target_weight)).
Raises:
ValueError: if there exists a metric that exists in more than one element
of `loss_metric_weight_pairs` but not in all of them.
"""
all_keys = set(
[k for loss_metrics, _ in loss_metric_weight_pairs for k in loss_metrics]) # pylint: disable=g-complex-comprehension
result = {}
for k in all_keys:
count = 0
for loss_metrics, weight in loss_metric_weight_pairs:
if k in loss_metrics:
count += 1
if count > 1 and count != len(loss_metric_weight_pairs):
raise ValueError('Found metric %s which exists in more than one '
'but not all loss metrics.' % k)
total_val = 0
total_target_weight = 0
for loss_metrics, weight in loss_metric_weight_pairs:
if k in loss_metrics:
val, target_weight = loss_metrics[k]
if count == 1:
# Single metric, don't multiply by weight.
total_val = val * target_weight
total_target_weight = target_weight
else:
# Total weighted sum of all predictions.
total_val += weight * val * target_weight
total_target_weight += weight * target_weight
result[k] = (total_val / total_target_weight, total_target_weight)
return result
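# Illustrative sketch (not part of the library; loss_a/weight_a etc. are
# placeholders): combining a 'loss' metric that appears in both metric dicts
# as a weighted sum of the two entries:
#   combined = CombineMetrics([
#       ({'loss': (loss_a, weight_a)}, 0.7),
#       ({'loss': (loss_b, weight_b)}, 0.3),
#   ])
#   # combined['loss'] == (weighted metric value, weighted target weight)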
def AddVN(p, x, per_step=False):
"""Add variational noise to x.
Args:
p: Layer params, with a `vn` subparam containing `VariationalNoiseParams`.
x: Input to add variational noise to.
per_step: Whether to add per_step noise.
Returns:
The input with variational noise added according to params.
"""
tensor_name = x.name if not tf.executing_eagerly() else '[eager]'
if per_step:
if not p.vn.per_step_vn:
tf.logging.info(
'p.vn.per_step_vn is not set. Not adding per-step vn to ' +
tensor_name)
return x
else:
if not p.vn.global_vn:
tf.logging.info('p.vn.global_vn is not set. Not adding global vn to ' +
tensor_name)
return x
tf.logging.info(
f"Add {'per-step' if per_step else 'global'} vn to {tensor_name}: {p.vn}")
if p.vn.scale is None:
raise ValueError('VN scale must be set.')
if p.vn.deterministic:
noises = DeterministicVN(p, tf.shape(x), mean=0.0, std=1.0)
noises = tf.cast(noises, x.dtype)
else:
if per_step:
# recurrent.py does not support stateful random ops in cell_fn due to
# rematerialization.
raise ValueError('per_step vn requires deterministic=True.')
noises = tf.random.normal(
tf.shape(x), stddev=1.0, seed=p.vn.seed, dtype=x.dtype)
scale = tf.where(GetGlobalStep() >= p.vn.start_step, p.vn.scale, 0.0)
return x + tf.cast(scale, x.dtype) * noises
def VariationalNoiseParams(scale,
global_vn=False,
per_step_vn=False,
seed=None,
deterministic=None,
start_step=0):
"""Returns a hyperparams for variational noise."""
if deterministic is None:
deterministic = cluster_factory.Current().in_unit_test
p = hyperparams.Params()
p.Define(
'scale', scale,
'Std of the variational noise to apply. This can be a scalar,'
' or a scalar tensor.')
p.Define('global_vn', global_vn,
'Adds global variational noise every training step iff True.')
p.Define('per_step_vn', per_step_vn,
'Adds per-timestep variational noise iff True.')
p.Define('seed', seed, 'Random seed used to generate noise.')
p.Define(
'deterministic', deterministic, 'If true, generate noise using'
'stateless random ops that are compatible with TF functional ops.')
p.Define(
'start_step', start_step,
'Step starting from which variational noise is added during training.')
return p
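# Illustrative sketch (not part of the library): enabling deterministic global
# variational noise with stddev 0.1 starting at step 1000 on a layer's params p:
#   p.vn = VariationalNoiseParams(
#       scale=0.1, global_vn=True, deterministic=True, start_step=1000)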
def DefaultVN():
return VariationalNoiseParams(scale=None)
# To disable VN of a layer, we use 1.0 as the first input parameter
# of the following function, because otherwise it would be the same as
# DefaultVN(), which gets updated by the parent configuration in
# CopyBaseParams().
def DisableVN():
return VariationalNoiseParams(1.0, False, False)
# Step seed keyed by graph.
_STEP_SEED_DICT = ThreadLocalDict()
# The step seed will increment by np.prod(_STEP_SEED_INCREMENT.stack)
_STEP_SEED_INCREMENT = ThreadLocalStack()
@contextlib.contextmanager
def StepSeedIncrementContext(step):
"""Adds an element to _STEP_SEED_INCREMENT."""
assert step > 0, ('%s' % step)
_STEP_SEED_INCREMENT.stack.append(step)
try:
yield
finally:
_STEP_SEED_INCREMENT.stack.pop()
def GetStepSeed():
"""Gets step_seed."""
key = id(tf.get_default_graph())
if key not in _STEP_SEED_DICT.dict:
ResetStepSeed()
return _STEP_SEED_DICT.dict[key]
def ResetStepSeed(seed=0):
"""Resets step_seed to specified value."""
key = id(tf.get_default_graph())
_STEP_SEED_DICT.dict[key] = tf.convert_to_tensor(seed, dtype=tf.int64)
def MaybeResetStepSeedFromScope():
"""In graph mode, resets step_seed according to the current named scope.
This is used in graph mode to avoid "tensor is from a different graph"
errors that happen when we share random seed tensors too much.
See b/129159299 for more context.
Eager mode does not have this problem, so in eager mode we do nothing.
"""
if not tf.executing_eagerly():
ResetStepSeed(GenerateSeedFromName(tf.no_op(name='new_step_seed').name))
def MaybeResetStepSeed(seed):
"""If we're in graph mode, reset the step seed."""
if not tf.executing_eagerly():
ResetStepSeed(seed)
def GetIncStepSeed():
"""Returns and increments the step_seed."""
step_seed = GetStepSeed()
# TODO(lepikhin): introduce a routine filling a queue of uint32 random seeds
# independent of underlying PRNG used by tensorflow.
inc = np.prod(_STEP_SEED_INCREMENT.stack)
ResetStepSeed(step_seed + inc)
return step_seed
def GenerateStepSeedPair(p, op_seed=None):
"""Generates a seed pair for deterministic random operations in functional
loops.
This function retrieves a unique seed pair on each call, based on the current
global step and step seed. The step seed ensures this function returns a
unique seed pair on each call: calling this function automatically increments
the step seed. The step seed is automatically reset at the beginning of each
global step in the model's FProp and works transparently through recurrent.py.
Args:
p: A hyperparams.Params object, containing keys 'random_seed' and
'is_inference'.
op_seed: An additional operation-level seed to apply.
Returns:
A size 2 tensor of op seeds to use for stateless_random ops.
"""
seed_dtype = tf.int32 if use_tpu() else tf.int64
if p.is_inference and p.random_seed is None:
# Ensure GetIncStepSeed is called even inside the shortcut.
# This ensures if p.random_seed is set for other ops that use this function
# that they will get the same seed pair whether or not p.random_seed is set
# for this specific call.
GetIncStepSeed()
# Unlike tf.random*, stateless random ops are completely determined by the
# passed-in seeds. This means at inference time the same inputs will produce
# the same outputs, even if the model is supposed to have randomness such as
# dropout during inference. We inject additional randomness only during
# inference if the graph is exported with random_seed=None as a workaround.
return tf.random.uniform([2], maxval=seed_dtype.max, dtype=seed_dtype)
global_step = tf.cast(GetGlobalStep(), seed_dtype)
step_seed = tf.cast(GetIncStepSeed(), seed_dtype)
seeds = tf.stack([global_step, step_seed])
if p.random_seed is not None:
seeds += p.random_seed
if op_seed is not None:
op_seed = tf.cast(op_seed, seed_dtype)
seeds += op_seed
return seeds
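# Illustrative sketch (not part of the library; `p` is a layer's params and
# `shape` is a placeholder): drawing a fresh seed pair for a stateless random
# op inside a layer's FProp:
#   seeds = GenerateStepSeedPair(p)
#   noise = tf.random.stateless_normal(shape, seed=seeds)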
def DeterministicDropout(x, keep_prob, seeds, noise_shape=None, name=None):
"""Similar to `tf.nn.dropout()`, but fully deterministic.
Args:
x: A float Tensor on which to apply dropout.
keep_prob: A scalar `Tensor` of keep probability.
seeds: A Tensor of shape [2]. 2 seeds for deterministic random number
generator.
noise_shape: A 1-D `Tensor` of type `int32`, representing the shape for
randomly generated keep/drop flags.
name: An optional name for this operation.
Returns:
A Tensor with the same shape as `x`.
Raises:
InvalidArgumentError: if keep_prob is invalid.
"""
if isinstance(keep_prob, numbers.Real):
if keep_prob <= 0 or keep_prob > 1:
raise tf.errors.InvalidArgumentError(
'keep_prob must be in range (0, 1]. Value: {}'.format(keep_prob))
if keep_prob == 1:
return x
with tf.name_scope(name, 'dropout', [x]) as name:
if use_tpu():
seeds = tf.cast(seeds, tf.int32)
keep_prob = tf.convert_to_tensor(
keep_prob, dtype=tf.float32, name='keep_prob')
# uniform in [keep_prob, 1.0 + keep_prob)
# StatelessRandomUniform op does not support non-float (e.g. bfloat16) dtype
# and non-int32 seed types.
noise_shape = noise_shape or GetShape(x)
random_tensor = keep_prob + tf.random.stateless_uniform(
noise_shape, seed=seeds, dtype=tf.float32)
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
binary_tensor = tf.floor(random_tensor)
if x.dtype != tf.float32:
binary_tensor = tf.cast(binary_tensor, x.dtype)
keep_prob = tf.cast(keep_prob, dtype=x.dtype)
result = tf.div(x, keep_prob) * binary_tensor
result.set_shape(x.get_shape())
return result
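# Illustrative sketch (not part of the library; `p` and `x` are placeholders):
# the same seed pair always yields the same dropout mask, which is what makes
# recomputation inside functional loops safe:
#   seeds = GenerateStepSeedPair(p)
#   y1 = DeterministicDropout(x, keep_prob=0.9, seeds=seeds)
#   y2 = DeterministicDropout(x, keep_prob=0.9, seeds=seeds)  # identical to y1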
def DeterministicVN(params, noise_shape, mean=0.0, std=1.0, name=None):
"""Produces fully deterministic Gaussian noise from shape, mean and std.
Args:
params: Nested map of params.
noise_shape: A 1-D `Tensor` of type `int32`, representing the shape for
randomly generated Gaussian noise.
mean: Mean for the Gaussian noise.
std: Standard deviation for noise.
name: An optional name for this operation.
Returns:
A Tensor with the shape noise_shape and type fprop_dtype.
"""
with tf.name_scope(name, 'gaussian_noise') as name:
seeds = GenerateStepSeedPair(params, params.vn.seed)
random_tensor = mean + (
std * tf.random.stateless_normal(noise_shape, seed=seeds))
if FPropDtype(params) != tf.float32:
random_tensor = tf.cast(random_tensor, FPropDtype(params))
return random_tensor
BATCH_NORM_UPDATES = 'batch_norm_updates'
_BATCH_NORM_UPDATES_DICT = '__batch_norm_update_dict'
_get_batch_norm_updates_dict = _CollectionGetter(_BATCH_NORM_UPDATES_DICT,
lambda: {})
[docs]def UpdateBatchNormVars(batch_norm_var, batch_norm_stats, decay):
"""Update batch normalization moving averages."""
with tf.name_scope(
'AssignMovingAvg', values=[
batch_norm_var,
batch_norm_stats,
decay,
]) as scope:
with tf.ops.colocate_with(batch_norm_var):
decay = tf.convert_to_tensor(
1.0 - decay, dtype=batch_norm_var.dtype.base_dtype)
update_delta = (batch_norm_var - tf.cast(
batch_norm_stats, batch_norm_var.dtype.base_dtype)) * decay
has_nan_or_inf = tf.reduce_any(
tf.math.logical_or(
tf.math.is_nan(update_delta), tf.math.is_inf(update_delta)))
update_delta = tf.where(has_nan_or_inf, tf.zeros_like(update_delta),
update_delta)
bn_update = tf.assign_sub(batch_norm_var, update_delta, name=scope)
tf.add_to_collection(BATCH_NORM_UPDATES, bn_update)
if not tf.executing_eagerly_outside_functions():
bn_update_dict = _get_batch_norm_updates_dict()
if bn_update.name in bn_update_dict:
raise ValueError(f'BN update {bn_update.name} already exists.')
bn_update_dict[bn_update.name] = (batch_norm_var, batch_norm_stats)
return bn_update
[docs]def FindRelevantBatchNormUpdates(loss, batch_norm_updates):
"""Finds and returns a list of relevant batch-normalization updates.
Args:
loss: The loss that is being optimized for. A tensor or a list of tensors.
batch_norm_updates: A list of batch normalization updates.
Returns:
    A pair of lists. The first list contains all the batch normalization
    updates that are relevant to the loss being optimized, and the second list
    contains those updates in batch_norm_updates that are not in the first
    list.
"""
if tf.executing_eagerly_outside_functions():
return [], []
dependent_ops_and_tensors = set(FindNeeded(loss))
relevant_updates = []
irrelevant_updates = []
bn_update_dict = _get_batch_norm_updates_dict()
for bn_update in batch_norm_updates:
assert bn_update.name in bn_update_dict, (
f'{bn_update.name} is probably not a valid batch normalization update '
'op. Make sure batch normalization is done through calling'
' the py_utils.UpdateBatchNormVars helper routine.')
bn_stat_name = bn_update_dict[bn_update.name][1].name
if bn_stat_name in dependent_ops_and_tensors:
# If a batch normalization stat is computed in the forward pass in
# computing loss, then the corresponding batch normalization update is
# relevant. Otherwise, it is not.
relevant_updates.append(bn_update)
else:
irrelevant_updates.append(bn_update)
return relevant_updates, irrelevant_updates
_SAMPLE_STEP_STACK = ThreadLocalStack()
[docs]@contextlib.contextmanager
def SampleStep(step):
"""A context for a sample step during decoding.
Example usage::
with py_utils.SampleStep(step):
sample = self.DecodeOneStep()
Args:
step: the step tensor.
Yields:
a context manager for the step scope.
"""
try:
_SAMPLE_STEP_STACK.stack.append(step)
yield step
finally:
_SAMPLE_STEP_STACK.stack.pop()
[docs]def _GetSampleStep():
return _SAMPLE_STEP_STACK.stack[-1] if _SAMPLE_STEP_STACK.stack else None
[docs]def AddDebugTensor(tensor, summarize=None, name=None):
"""Adds `tensor` to the debug collection.
Prints the tensor if `--print_debug_tensors` is True.
Args:
tensor: A tensor.
summarize: Only print this many entries of each tensor. If None, then a
maximum of 3 elements are printed per input tensor.
name: An optional name for the tensor.
Returns:
A Tensor that evaluates to the same value as the input tensor.
"""
if _FromGlobal('print_debug_tensors'):
step = _GetSampleStep()
tensors_to_print = ([] if step is None else [step]) + [tensor]
with tf.name_scope(name) as s:
tensor = tf.Print(
tensor,
tensors_to_print,
message='DEBUG tensor %s' % s,
name=name,
summarize=summarize)
return tensor
[docs]def ArgMax(inputs):
"""tf.argmax wrapper.
Args:
    inputs: A tensor whose last dimension is reduced over.
  Returns:
    A tensor of rank tf.rank(inputs)-1. If i == ret[indices], then
    inputs[indices, i] is the maximum among inputs[indices, :].
"""
if use_tpu():
return tf.argmax(inputs, axis=-1, output_type=tf.int32)
else:
return tf.argmax(inputs, axis=-1)
[docs]def _EnsureMatrixShape(x):
if x.shape.ndims is None:
x.set_shape([None, None])
else:
assert x.shape.ndims == 2
return x
[docs]def Matmul(x, y, *args, **kwargs):
"""tf.matmul wrapper expecting x and y are actually matrices."""
x = _EnsureMatrixShape(x)
y = _EnsureMatrixShape(y)
return tf.matmul(x, y, *args, **kwargs)
[docs]def clip_by_value(t, clip_value_min, clip_value_max, name=None): # pylint: disable=invalid-name
if t.dtype.is_complex:
return tf.complex(
tf.clip_by_value(
tf.math.real(t), clip_value_min, clip_value_max, '%s_real' % name),
tf.clip_by_value(
tf.math.imag(t), clip_value_min, clip_value_max, '%s_imag' % name))
return tf.clip_by_value(t, clip_value_min, clip_value_max, name)
[docs]def SumSquared(tensor_list):
return _TransformAndSum(tensor_list, lambda v: v**2)
[docs]def SumAbs(tensor_list):
return _TransformAndSum(tensor_list, tf.abs)
[docs]def ReduceRms(x: tf.Tensor) -> tf.Tensor:
"""Computes root mean square of tensor x with numerical stability."""
if not x.shape.is_fully_defined():
raise ValueError('Shape of x must be fully defined.')
if not x.shape.as_list():
return x
denom = functools.reduce((lambda x, y: x * y), x.shape.as_list())
if denom <= 1e8:
return tf.math.sqrt(tf.math.reduce_mean(tf.math.square(x)))
tf.logging.info('reduce_rms %s denom=%d', x, denom)
sum_square_x = tf.math.reduce_sum(tf.math.reduce_sum(tf.math.square(x), -1))
avg_square_x = sum_square_x / tf.constant(denom, dtype=sum_square_x.dtype)
return tf.math.sqrt(avg_square_x)
[docs]def PiecewiseConstant(x_in, boundaries, values, vdtype):
"""Returns the piecewise value of x_in."""
x_in = tf.cast(tf.convert_to_tensor(x_in), tf.float32)
assert len(values) == len(boundaries) + 1
assert sorted(boundaries) == list(boundaries)
bs = tf.convert_to_tensor(boundaries, dtype=tf.float32)
vs = tf.convert_to_tensor(values, dtype=vdtype)
# The following is equivalent to 'return vs[index]'.
index = tf.reduce_sum(tf.cast(tf.greater_equal(x_in, bs), tf.int32))
one_hot_vec = tf.one_hot(
tf.expand_dims(index, 0), depth=len(values), dtype=vdtype)
return Matmul(tf.reshape(vs, (1, -1)), tf.transpose(one_hot_vec))[0][0]
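# Illustrative sketch (assumed values): a step-dependent scalar schedule built
# with PiecewiseConstant, e.g. 1.0 before step 1000, 0.1 until step 2000, and
# 0.01 afterwards. `step` can be a Python int or a scalar tensor.
def _ExamplePiecewiseSchedule(step):
  return PiecewiseConstant(
      step, boundaries=[1000, 2000], values=[1.0, 0.1, 0.01],
      vdtype=tf.float32)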
[docs]def PadSequenceDimension(x, length, pad_val, shape=None, axis=1):
"""Pads x to `length` using `pad_val` along the axis dim.
Assumes `x` is a tensor with rank >= 2, and it only pads `x` to `length`
along the axis dim. Explicitly sets the returned tensor shape to `shape` if
given. Raises runtime errors if x.shape[axis] > length or
x.shape[i] != shape[i] where i != axis.
Args:
x: the tensor to be padded with axis dimension being the time. E.g., x
usually has shape [batch, seq_len, ...], when axis=1.
length: an int to specify the length to pad x to.
pad_val: an int or float used to pad x.
shape: an int array specifying the shape of the padded tensor if specified.
    axis: The dimension along which x will be padded; defaults to 1.
Returns:
The padded tensor with shape [batch, seq_len, ...], where
ret[:, :seq_len, ...] == x, when axis=1, and similarly for other axes.
"""
if x.shape.ndims is not None:
rank = x.shape.ndims
assert rank >= 2
slen = GetShape(x, rank)[axis]
pad_len = length - slen
pad = [[0, 0] for _ in range(rank)]
pad[axis][1] = pad_len
else:
rank = tf.rank(x)
with tf.control_dependencies([assert_greater_equal(rank, 2)]):
slen = tf.shape(x)[axis]
pad_len = length - slen
pad = tf.scatter_nd([[axis, 1]], [pad_len], [rank, 2])
x = tf.pad(x, pad, constant_values=pad_val)
if x.shape.ndims is not None and isinstance(length, int):
static_shape = x.shape.as_list()
static_shape[axis] = length
x.set_shape(static_shape)
if shape:
if not isinstance(shape, (list, tuple)):
raise TypeError('Shape must be a list or tuple.')
x = HasRank(x, len(shape))
x = tf.ensure_shape(x, shape)
return x
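# Illustrative sketch (assumed shapes/names): pad a float [batch, seqlen, dim]
# tensor to a fixed maximum length along the time axis; positions beyond the
# original seqlen are filled with zeros.
def _ExamplePadToMaxLen(x, max_len=16):
  return PadSequenceDimension(x, length=max_len, pad_val=0.0, axis=1)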
[docs]def PadSequenceTo(xs, padding, length, pad_val):
"""Pads `xs` and `padding` to `length` using `pad_val` along the 2nd dim.
Pads `xs` to `length` using `pad_val`, and `padding` using 1.
  Raises an error if `x.shape[:2]` and `padding.shape` are not the same.
Args:
xs: A Tensor or a list of Tensors of shape [batch, seqlen] or [batch,
seqlen, ...].
padding: A 0/1 Tensor of shape [batch, seqlen]. 1 is for padded locations.
length: A Python int, the length to pad to.
pad_val: A Python numeric, used for padding x.
Returns:
A tuple of padded xs and padding.
"""
if not isinstance(xs, (list, tuple)):
new_xs = [xs]
else:
new_xs = xs
res = []
for x in new_xs:
batch, slen = GetShape(x, 2)
padding = HasRank(padding, 2)
padding = HasShape(padding, [batch, slen])
new_x = PadSequenceDimension(x, length, pad_val)
res.append(new_x)
padding = PadSequenceDimension(padding, length, tf.cast(1, padding.dtype))
if not isinstance(xs, (list, tuple)):
assert len(res) == 1
return res[0], padding
else:
return tuple(res), padding
[docs]def ApplyPadding(padding, x, padded=None, use_select=True, ensure_shape=True):
"""Applies padding to a tensor.
This is preferable to using arithmetic means for masking out padded values
such as::
# Equiv to ApplyPadding(padding, x)
x *= 1.0 - padding
# Equiv to ApplyPadding(padding, new, old)
new = old * padding + new * (1 - padding)
Aside from just being easier to read and reason about, using this function
is friendly to quantized representations because it does not mix arithmetic
on the padding values with the values in the tensor being padded (which can
have a very different range than the 0..1 padding tensor).
In addition, this works around issues in quantized schemes where we are
guaranteed to have an exact 0 but not necessarily any other number (i.e. 1).
Args:
padding: Tensor of padding values where 0 == keep and 1 == pad.
x: Tensor to apply padding to.
padded: Optional. Values to include for padded elements. Defaults to zeros.
Must have a shape broadcastable to 'x' if specified.
use_select: Controls whether padding is applied with a select-mask
(True/default) or arithmetically (False). Some platforms have a
sensitivity to one or the other and this is used to work around such
issues.
ensure_shape: If true, ensures the shape of the result is the same as of x.
Returns:
A tensor with the same shape as x with padded values masked.
"""
padding = with_dependencies([
Assert(
tf.reduce_all(
tf.math.logical_or(
tf.equal(padding, tf.zeros([], padding.dtype)),
tf.equal(padding, tf.ones([], padding.dtype)))), [padding])
], padding)
if use_select:
if padded is None:
padded = tf.zeros([], x.dtype)
if padding.dtype != tf.bool:
padding = padding > tf.zeros([], padding.dtype)
result = tf.where_v2(padding, padded, x)
else:
result = x * tf.cast(1.0 - tf.cast(padding, tf.float32), x.dtype)
if padded is not None:
result += padded * tf.cast(padding, padded.dtype)
if ensure_shape:
result = tf.ensure_shape(result, x.shape)
return result
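# Illustrative sketch (assumed shapes/names): zero out padded time steps of
# activations `x` of shape [batch, time, dim] using `paddings` of shape
# [batch, time] where 1 marks a padded position. The expand_dims makes the
# padding broadcastable against `x`.
def _ExampleMaskPaddedSteps(x, paddings):
  return ApplyPadding(tf.expand_dims(paddings, -1), x)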
[docs]def LengthsFromPaddings(paddings, dtype=None):
"""Computes lengths of each sequence in a batch, ignoring trailing padding.
Note the following isn't guaranteed due to leading paddings.
PaddingsFromLengths(LengthsFromPaddings(x)) == x
Args:
paddings: a tensor with shape [batch, length].
dtype: A type to optionally cast the result to.
Returns:
lengths tensor shaped [batch] containing the unpadded length of each
sequence in the batch.
"""
paddings = HasRank(paddings, 2)
mask = 1 - tf.cast(paddings, tf.int32)
# We cannot just use `tf.reduce_sum(mask, axis=1)` to compute the number of
# elements prior to the trailing padding, because there might be leading
# padding. Thus, to identify the amount of trailing padding, we notice that
# the mask values for all the trailing padding will be zero, and thus in the
# cumsum below they will all be equal to the last element of the cumsum. Note
# that the final unpadded value will also be equal to the final cumsum value.
cumsum = tf.cumsum(mask, axis=1)
same_as_last_element = tf.equal(cumsum, cumsum[:, -1:])
# Counting the number of elements with the same value as the last cumsum value
# gives us num_trailing_paddings + 1, and so counting the number of elements
# that *differ* from the last cumsum value gives us the unpadded_length - 1.
unpadded_length = tf.reduce_sum(
1 - tf.cast(same_as_last_element, tf.int32), axis=1) + 1
# Special case for when the mask is all zeros.
# In this case, all the entries in the cumsum will be equal to the last
# element, so the number that differ would be zero, and thus the
# unpadded_length value would be 1 (which is incorrect). We thus set it to 0.
all_zero_mask = tf.equal(tf.reduce_sum(mask, axis=1), 0)
result = tf.where(all_zero_mask, tf.zeros_like(unpadded_length),
unpadded_length)
if dtype and result.dtype != dtype:
result = tf.cast(result, dtype)
return result
[docs]def PaddingsFromLengths(lengths, maxlen=None):
"""Computes paddings Tensor from lengths.
Note the following isn't guaranteed due to leading paddings.
PaddingsFromLengths(LengthsFromPaddings(x)) == x.
This method does not generate leading paddings.
Args:
lengths: A int32 Tensor of shape [B].
maxlen: None or a Python int or a scalar Tensor.
Returns:
A 0/1 valued Tensor of shape [B, maxlen or ?] where 1s are padded positions.
"""
lengths = HasRank(lengths, 1)
if maxlen is not None:
lengths = with_dependencies(
[assert_less_equal(tf.cast(tf.reduce_max(lengths), tf.int32), maxlen)],
lengths)
return 1. - tf.sequence_mask(lengths, maxlen=maxlen, dtype=tf.float32)
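# Illustrative sketch: converting between lengths and 0/1 paddings. Assuming
# there is no leading padding, the two helpers above are inverses of each
# other (see the notes in their docstrings).
def _ExampleLengthsRoundTrip(paddings):
  lengths = LengthsFromPaddings(paddings)  # [batch]
  return PaddingsFromLengths(lengths, maxlen=GetShape(paddings)[1])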
[docs]def TrimTrailingPaddings(inputs, paddings):
"""Trims trailing paddings from inputs.
Since the number of dimensions is not fixed, this will not work on TPU.
Args:
inputs: a tensor with shape [batch, length, ...].
paddings: a tensor with shape [batch, length].
Returns:
Trimmed inputs and paddings. For compatibility reasons, the trimmed tensors
will always have length at least 1.
"""
paddings = HasRank(paddings, 2)
max_length = tf.maximum(tf.reduce_max(LengthsFromPaddings(paddings)), 1)
output_shape = tf.shape(inputs)
output_shape = tf.concat([[output_shape[0], max_length], output_shape[2:]],
axis=0)
outputs = tf.slice(inputs, tf.zeros_like(output_shape), output_shape)
out_paddings = tf.slice(paddings, [0, 0],
tf.stack([output_shape[0], max_length]))
return outputs, out_paddings
[docs]def ReversePaddedSequence(inputs, paddings):
"""Reverse inputs based on paddings.
  Only reverses the unpadded portion of `inputs`. It assumes inputs are only
  padded at the end.
Args:
inputs: a tensor of [seq_length, batch_size, num_input_nodes].
paddings: a tensor of float32/float64 zero or one of shape [seq_length,
batch_size, 1].
Returns:
A reversed tensor of the same shape as `inputs`.
"""
inversed_paddings = 1.0 - tf.squeeze(paddings, 2)
inputs_length = tf.cast(
tf.math.rint(tf.reduce_sum(inversed_paddings, axis=0)), tf.int32)
return tf.reverse_sequence(inputs, inputs_length, seq_axis=0, batch_axis=1)
[docs]def ConcatenatePaddedSequences(input0, input1, padding0, padding1, seq_dim=1):
"""Concatenates input sequences with varying lengths as defined by paddings.
This is a helper function for concatenating 2 batches of input sequences,
where each example in the batch can have different lengths, as defined by
the corresponding paddings. To concatenate correctly, it makes use of
tf.reverse_sequence to partially reverse the sequences before
concatenating them together.
NOTE: We assume that the tensors have no leading paddings.
Args:
    input0: A tensor of size [batch, max_length, ...] or [max_length, batch,
      ...] depending on the value set for seq_dim.
    input1: A tensor of size [batch, max_length, ...] or [max_length, batch,
      ...] depending on the value set for seq_dim.
padding0: A Tensor of size [batch, max_length] or [max_length, batch]
corresponding to the padding for input0.
padding1: A Tensor of size [batch, max_length] or [max_length, batch]
corresponding to the padding for input1.
seq_dim: int, the time axis along which the tensors will be concatenated.
Should be 0 or 1. Assumes that batch_dim is 1 - seq_dim.
Returns:
The concatenation of input0 and input1, and the corresponding padding.
Raises:
tf.errors.InvalidArgumentError when seq_dim is not 0 or 1.
"""
if seq_dim != 0 and seq_dim != 1:
raise tf.errors.InvalidArgumentError(None, None, 'seq_dim must be 0 or 1.')
batch_dim = 1 - seq_dim
  # input0 and input1 should have the same batch size and the same rank.
input0 = with_dependencies([
assert_equal(GetShape(input0)[batch_dim],
GetShape(input1)[batch_dim]),
assert_equal(GetRank(input0), GetRank(input1))
], input0)
batch_size = GetShape(padding0)[batch_dim]
# batch dimension of inputs and paddings should match.
input0 = with_dependencies([
assert_equal(GetShape(input0)[batch_dim], batch_size),
assert_equal(GetShape(padding1)[batch_dim], batch_size)
], input0)
input0_seq_dim = tf.cast(
tf.tile([tf.shape(padding0)[seq_dim]], [batch_size]), dtype=tf.int32)
input1_seq_dim = tf.cast(
tf.tile([tf.shape(padding1)[seq_dim]], [batch_size]), dtype=tf.int32)
# LengthsFromPaddings assumes that paddings is of size [batch, max_length].
if seq_dim == 1:
seq_length0 = LengthsFromPaddings(padding0)
seq_length1 = LengthsFromPaddings(padding1)
else:
seq_length0 = LengthsFromPaddings(tf.transpose(padding0))
seq_length1 = LengthsFromPaddings(tf.transpose(padding1))
# We assume that the tensors have no leading paddings.
# TODO(arunnt): Concatenate tensors with leading paddings correctly.
seq_length0 = with_dependencies([
assert_equal(
seq_length0,
tf.cast(tf.reduce_sum(1.0 - padding0, seq_dim), dtype=tf.int32))
], seq_length0)
seq_length1 = with_dependencies([
assert_equal(
seq_length1,
tf.cast(tf.reduce_sum(1.0 - padding1, seq_dim), dtype=tf.int32))
], seq_length1)
# Concatenate input sequences.
reversed_input0 = tf.reverse_sequence(
input0, seq_length0, seq_axis=seq_dim, batch_axis=batch_dim)
reversed_input1 = tf.reverse_sequence(
input1, input1_seq_dim, seq_axis=seq_dim, batch_axis=batch_dim)
reversed_concat = tf.concat([reversed_input1, reversed_input0], axis=seq_dim)
concat_inputs = tf.reverse_sequence(
reversed_concat,
seq_length0 + input1_seq_dim,
seq_axis=seq_dim,
batch_axis=batch_dim)
# Concatenate paddings. Note that paddings are always a Tensor of 0s and 1s,
# so, unlike the inputs, we don't have to reverse padding1, we can simply
# concatenate reversed padding0 and padding1.
reversed_padding0 = tf.reverse_sequence(
padding0, input0_seq_dim, seq_axis=seq_dim, batch_axis=batch_dim)
reversed_concat_padding = tf.concat([reversed_padding0, padding1],
axis=seq_dim)
concat_paddings = tf.reverse_sequence(
reversed_concat_padding,
input0_seq_dim + seq_length1,
seq_axis=seq_dim,
batch_axis=batch_dim)
return concat_inputs, concat_paddings
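# Illustrative sketch (assumed names): append sequence batch `b` after
# sequence batch `a`, both shaped [batch, time, ...] with [batch, time]
# paddings and no leading padding.
def _ExampleConcatSequences(a, a_paddings, b, b_paddings):
  return ConcatenatePaddedSequences(a, b, a_paddings, b_paddings, seq_dim=1)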
[docs]def ShiftLeft(tensor, shift_size, pad_val=0, axis=1):
"""Shifts the values in a tensor to the left along the axis dimension.
The first shift_size values are dropped, and the tensor is padded on the
right with pad_val.
Args:
tensor: the input tensor with the axis dim being time.
shift_size: the number of frames >= 0 to shift.
pad_val: the value to pad on the right of the tensor.
    axis: The dimension along which the tensor will be shifted; defaults to 1.
  Returns:
    A tensor shifted left along the axis dimension.
"""
rank = tensor.shape.rank
with tf.control_dependencies(
[assert_greater_equal(rank, 2),
assert_greater_equal(shift_size, 0)]):
time = GetShape(tensor)[axis]
begin = tf.scatter_nd([[axis]], [shift_size], [rank])
return PadSequenceDimension(
tf.slice(tensor, begin, size=[-1] * rank), time, pad_val, axis=axis)
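# Illustrative sketch (assumed shapes): drop the first frame of a float
# [batch, time, dim] feature tensor and pad a zero frame on the right, keeping
# the time length constant.
def _ExampleShiftOneFrame(features):
  return ShiftLeft(features, shift_size=1, pad_val=0.0, axis=1)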
[docs]def CreateIdsAndLabels(ids, paddings, sos_id=1, eos_id=2, trim=False):
"""Creates ids and labels to be used as decoder targets.
Args:
ids: int Tensor of shape [batch, maxlen], without sos or eos.
paddings: float Tensor of shape [batch, maxlen].
sos_id: ID for the sos special token.
eos_id: ID for the eos special token.
    trim: Whether to trim the last elements in the output Tensors, so that the
      lengths of the output Tensors are the same as those of the input Tensors.
      Otherwise, the output Tensors are longer than the input Tensors by one
      because of the added sos / eos.
Returns:
A NestedMap with the following fields, where maxlen' equals maxlen when
trim=True, otherwise maxlen + 1:
- ids: int Tensor of shape [batch, maxlen'], with sos prepended.
- labels: int Tensor of shape [batch, maxlen'], with eos appended.
- paddings: float Tensor of shape [batch, maxlen'].
- weights: float Tensor of shape [batch, maxlen'].
"""
ids = tf.where(
tf.equal(paddings, 0.0), ids, tf.broadcast_to([[eos_id]], GetShape(ids)))
targets = NestedMap()
targets.ids = tf.pad(ids, [[0, 0], [1, 0]], constant_values=sos_id)
targets.labels = tf.pad(ids, [[0, 0], [0, 1]], constant_values=eos_id)
targets.paddings = tf.pad(paddings, [[0, 0], [1, 0]])
targets.weights = 1.0 - targets.paddings
if trim:
targets = targets.Transform(lambda v: v[:, :-1])
return targets
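# Illustrative sketch (assumed names): build decoder targets from tokenized
# ids and their paddings. The returned NestedMap has `ids` with sos prepended
# and `labels` with eos appended, as described in the docstring above.
def _ExampleMakeTargets(token_ids, token_paddings):
  return CreateIdsAndLabels(token_ids, token_paddings, sos_id=1, eos_id=2)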
[docs]def Retry(*args, **kwargs):
return retry.Retry(*args, **kwargs)
# FailedPreconditionError: variables are not initialized.
# AbortedError: processes restarts.
# UnavailableError: Bad hardware status: 0x1
transient_tf_errors = (tf.errors.FailedPreconditionError,
tf.errors.AbortedError, tf.errors.UnavailableError)
[docs]def RetryOnTransientTfError(*args, **kwargs):
return Retry(transient_tf_errors, *args, **kwargs)
[docs]def PadOrTrimTo(x, shape, pad_val=0, pad_after_contents=True):
"""Pad and slice x to the given shape.
Args:
x: A tensor.
shape: The shape of the returned tensor.
pad_val: An int or float used to pad x.
pad_after_contents: Whether to pad and trim after the original contents of
each dimension.
Returns:
'x' is padded with pad_val and sliced so that the result has the given
shape.
Raises:
ValueError: if shape is a tf.TensorShape and not fully defined.
"""
if isinstance(shape, (list, tuple)):
expected_rank = len(shape)
elif isinstance(shape, tf.TensorShape):
if not shape.is_fully_defined():
raise ValueError('shape %s padding %s must be fully defined.' %
(shape, x))
expected_rank = shape.rank
else:
shape = HasRank(shape, 1)
expected_rank = tf.size(shape)
x = HasRank(x, expected_rank)
pad = shape - tf.minimum(tf.shape(x), shape)
zeros = tf.zeros_like(pad)
if pad_after_contents:
# If dim_i is less than shape[i], pads after contents.
paddings = tf.stack([zeros, pad], axis=1)
# If dim_i is larger than shape[i], we slice [0:shape[i]] for dim_i.
slice_begin = zeros
else:
# If dim_i is less than shape[i], pads before contents.
paddings = tf.stack([pad, zeros], axis=1)
# If dim-i is larger than shape[i], we slice [dim_i - shape[i]:dim_i]
# for dim_i.
slice_begin = tf.shape(x) + pad - shape
x = tf.pad(x, paddings, constant_values=pad_val)
x = tf.slice(x, slice_begin, shape)
return tf.reshape(x, shape)
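# Illustrative sketch (assumed shapes): force a rank-2 [batch, ?] tensor to a
# fixed static shape, padding short rows with zeros and trimming long ones.
def _ExampleFixShape(x, batch=8, max_len=32):
  return PadOrTrimTo(x, [batch, max_len], pad_val=0)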
[docs]def ExpandTo(x, target_rank):
"""Expands the last dimension of x until it has rank target_rank."""
if x is None:
return None
shape = GetShape(x)
rank = GetRank(x)
rank_diff = target_rank - rank
if isinstance(rank_diff, tf.Tensor):
new_shape = tf.concat([shape, tf.ones([rank_diff], tf.int32)], -1)
else:
new_shape = shape + [1] * rank_diff
new_x = tf.reshape(x, new_shape)
if not isinstance(target_rank, tf.Tensor):
new_x.shape.with_rank(target_rank)
return new_x
[docs]def ExpandAndPadOrTrimTo(x, target_shape, pad_val=0):
"""Ensures that x is broadcast compatible with target_shape.
x is first expanded to the target rank. Thereafter, if x is not broadcast
compatible with target_shape the non-broadcast compatible dimensions are
either padded or trimmed to the target shape.
Args:
x: A tensor.
target_shape: A tensor shape either as a list or Tensor.
pad_val: The value to pad.
Returns:
A tensor which is broadcast compatible with target_shape.
"""
if x is None:
return None
target_rank = None
if isinstance(target_shape, tf.Tensor):
target_rank = GetShape(target_shape)[0]
else:
target_rank = len(target_shape)
x = ExpandTo(x, target_rank)
x_shape = GetShape(x)
is_static = (not isinstance(x_shape, tf.Tensor) and
all(not isinstance(d, tf.Tensor) for d in x_shape))
if is_static:
masked_target_shape = [
1 if x_shape[i] == 1 else target_shape[i] for i in range(len(x_shape))
]
else:
masked_target_shape = tf.where(
tf.equal(x_shape, 1), tf.ones_like(target_shape), target_shape)
new_x = PadOrTrimTo(x, masked_target_shape, pad_val)
return tf.reshape(new_x, masked_target_shape)
[docs]def RepeatDim(tensor, multiple, axis):
"""Copies elements in tensor's axis "multiple" times, like np.repeat."""
# x = [[1, 2, 3], [4, 5, 6]]
# RepeatDim(x, multiple=2, axis=1) gives:
  # [[1, 1, 2, 2, 3, 3], [4, 4, 5, 5, 6, 6]]
  # As a comparison, tf.tile(x, multiples=[1, 2]) gives:
# [[1, 2, 3, 1, 2, 3], [4, 5, 6, 4, 5, 6]]
if multiple == 1:
return tensor
t_shape = tf.shape(tensor)
tensor_dims = tf.concat(
[t_shape[:axis], [t_shape[axis] * multiple], t_shape[axis + 1:]], 0)
multiple_dims = tf.concat([
tf.fill([axis + 1], 1), [multiple],
tf.fill([tf.rank(tensor) - axis - 1], 1)
], 0)
return tf.reshape(
tf.tile(tf.expand_dims(tensor, axis + 1), multiple_dims), tensor_dims)
[docs]def StackTensorsRecursively(values):
"""Recursively stacks Tensors in a list of `.NestedMap`.
Args:
    values: a list of `.NestedMap` or Tensors to stack.
Returns:
A `.NestedMap` with stacked values or a stacked Tensor.
"""
flatten = [w.Flatten() for w in values]
stacked = []
for i in range(len(flatten[0])):
stacked += [tf.stack([flatten[j][i] for j in range(len(flatten))])]
ret = values[0].Pack(stacked)
return ret
[docs]def MixByWeight(inputs: list[Callable[..., tf.Tensor]],
weights: tf.Tensor,
seed: Optional[int] = None) -> tuple[tf.Tensor, tf.Tensor]:
"""Returns a weighted random choice and bprop type from the give inputs.
Args:
inputs: a list of callables, where each callable returns a tf.Tensor or a
nested structure containing tf.Tensor. Function return types must be
consistent across elements. The tf.Operation to compute the result tensor
will only be invoked for one input at a time. For example, if each fn
represents an input record stream, a record will be drawn only from a
selected stream while the other streams will remain unchanged.
weights: a 1D tensor of float > 0 of the same length as inputs.
seed: random seed.
Returns:
A probabilistic sample from the inputs proportional to the weights. The
return type will be the same as return type of individual 'fn' from the
inputs.
A one-hot vector of the source selected.
"""
weights = tf.convert_to_tensor(weights, dtype=tf.float32)
weights = with_dependencies([
assert_equal(tf.shape(weights), [len(inputs)]),
assert_greater_equal(tf.reduce_min(weights), 0.0)
], weights)
lower = tf.cumsum(weights, exclusive=True)
upper = tf.cumsum(weights, exclusive=False)
r = tf.random.uniform(shape=[], maxval=upper[-1], seed=seed)
return_input = tf.case(
[(tf.math.logical_and(lower[i] <= r, r < upper[i]), inputs[i])
for i in range(len(inputs))],
exclusive=True)
selected_index = tf.case(
[(tf.math.logical_and(lower[i] <= r, r < upper[i]), lambda i=i: i)
for i in range(len(inputs))],
exclusive=True)
bprop_index = tf.one_hot(selected_index, len(inputs), dtype=tf.float32)
return return_input, bprop_index
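# Illustrative sketch (assumed names): sample from one of two input pipelines
# with 80/20 probability. `fn_a` and `fn_b` are callables returning a tensor
# (or nested structure); only the selected callable's ops run per draw.
def _ExampleMixTwoSources(fn_a, fn_b):
  batch, selected = MixByWeight([fn_a, fn_b], weights=[0.8, 0.2])
  return batch, selected  # `selected` is a one-hot [2] vector.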
[docs]def CheckShapes(shapes):
"""Asserts that shapes is a tuple of NestedMap or tshape.Shape."""
assert isinstance(shapes, tuple), str(shapes)
for s in shapes:
if isinstance(s, NestedMap):
assert all([isinstance(t, tshape.Shape) for t in Flatten(s)
]), '{} contains non-tensor value.'.format(s)
else:
assert isinstance(s, tshape.Shape), '{}: {}'.format(type(s), s)
[docs]def FPropDtype(params):
return params.fprop_dtype if params.fprop_dtype is not None else params.dtype
[docs]def UpdateFpropDtype(params, fprop_dtype):
"""Recursively update the fprop_dtype of the Params."""
  # Handle the case when the input "params" is not an instance of hyperparams.
  # For example, when UpdateFpropDtype is called recursively for all the items
  # in the "sub" list of SequentialLayer (see the first elif below).
if not isinstance(params, hyperparams.Params):
return
for key, val in params.IterParams():
if isinstance(val, hyperparams.Params):
UpdateFpropDtype(val, fprop_dtype)
elif isinstance(val, (list, tuple)):
for item in val:
UpdateFpropDtype(item, fprop_dtype)
elif key == 'fprop_dtype':
params.fprop_dtype = fprop_dtype
[docs]def UpdateDtype(params, dtype):
"""Recursively update the dtype of the Params."""
# Handle the case when the input "params" is not an instance of hyperparams
# For example, when UpdateDtype is called recursively for all the items in
# the "sub" list of SequentialLayer (see 1st elif below)
if not isinstance(params, hyperparams.Params):
return
for key, val in params.IterParams():
if isinstance(val, hyperparams.Params):
UpdateDtype(val, dtype)
elif isinstance(val, (list, tuple)):
for item in val:
UpdateDtype(item, dtype)
elif key == 'dtype':
params.dtype = dtype
[docs]def NameScopeDecorator(name_scope):
"""Decorates a python function to introduce a tf.name_scope.
Example::
@py_utils.NameScopeDecorator('foobar')
def MyFoobarMethod(self):
# ... Do TF things
Args:
name_scope: The name scope to introduce.
Returns:
A function decorator.
"""
def Decorator(f):
def Wrapped(*args, **kwargs):
with tf.name_scope(name_scope):
return f(*args, **kwargs)
return Wrapped
return Decorator
[docs]def SequencesToDebugStrings(ids, lens, summarize=5):
"""Returns debug strings for the given sequences.
Args:
ids: int32 of [batch, len].
lens: int32 of [batch].
summarize: number of ids to summarize per sequence.
Returns:
A string tensor of [batch].
"""
num_seqs = tf.shape(lens)[0]
def _Body(i, result):
line = tf.strings.format('{}', ids[i, :lens[i]], summarize=summarize)
return i + 1, tf.concat([result, tf.reshape(line, [1])], axis=0)
i0 = tf.zeros(shape=[], dtype=tf.int32)
result0 = tf.constant('', shape=[0], dtype=tf.string)
_, strs = tf.while_loop(
lambda i, result: i < num_seqs,
_Body, (i0, result0),
shape_invariants=(i0.shape, tf.TensorShape([None])))
return strs
# TODO(jamesqin): follow suggestions in
# b/167460492#comment16
[docs]def RematerializeFn(fn, *xs):
"""Calls fn and rematerializes fn in the backward pass.
`fn(*xs) -> ys`, where xs and ys can be a single tensor or a tuple of tensors.
Args:
fn: A python function to be rematerialized in the backprop pass.
*xs: A single tensor or a list/tuple of tensors. `xs` are input args to the
fn function.
Returns:
`fn(*xs)`
"""
initial_step_seed = GetStepSeed()
final_step_seed = MaybeGenerateSeedFromScope()
def Backward(fwd_xs, fwd_ys, d_fwd_ys):
"""The backward function that rematerializes forward outputs."""
del fwd_ys
always_true = tf.random.uniform([]) < 2.0
# Alternatively, can do this:
# tf.where(tf.math.is_nan(x),
# tf.constant(float('nan'), dtype=x.dtype) * tf.ones_like(x),
# x)
bak_xs = [tf.where(always_true, x, tf.zeros_like(x)) for x in fwd_xs.xs]
for dst, src in zip(bak_xs, xs):
dst.set_shape(src.shape)
ResetStepSeed(initial_step_seed)
ys = fn(*bak_xs)
MaybeResetStepSeed(final_step_seed)
dxs = tf.gradients(ys, bak_xs, grad_ys=d_fwd_ys)
dxs_final = []
for dx, x in zip(dxs, bak_xs):
if dx is None:
dxs_final.append(tf.zeros_like(x))
else:
dxs_final.append(dx)
assert len(dxs_final) == len(bak_xs)
return NestedMap(
initial_step_seed=tf.zeros_like(initial_step_seed), xs=dxs_final)
ys_shapes = []
# TODO(huangyp, yonghui): Check Forward doesn't use any stateful random ops.
def Forward(fwd_xs):
"""Forward function plus sanity checks."""
for dst, src in zip(fwd_xs.xs, xs):
dst.set_shape(src.shape)
ResetStepSeed(fwd_xs.initial_step_seed)
ys = fn(*fwd_xs.xs)
# Some sanity check.
assert not GetExtraInputs()
assert not GetExtraArgs()
assert not GetExtraVars()
if isinstance(ys, tuple):
for y in ys:
assert isinstance(y, tf.Tensor)
ys_shapes.append(y.shape)
else:
assert isinstance(ys, tf.Tensor)
ys_shapes.append(ys.shape)
return ys
ys = CallDefun(
Forward,
NestedMap(initial_step_seed=initial_step_seed, xs=xs),
bak=Backward)
if isinstance(ys, tuple):
for y, s in zip(ys, ys_shapes):
y.set_shape(s)
else:
ys.set_shape(ys_shapes[0])
# TODO(b/129159299): The ResetStepSeed below is needed to work around this
# bug, which is a problem with global tensors being shared by different
# inference graphs. It should be replaced with the new step seed value
# returned from the Forward function when the bug is fixed.
MaybeResetStepSeed(final_step_seed)
return ys
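# Illustrative sketch (assumed names): rematerialize an activation-heavy block
# in the backward pass to save memory. The wrapped function must be a
# deterministic function of its inputs (no stateful random ops).
def _ExampleRematerialize(x, w):
  def Block(x, w):
    return tf.tanh(tf.matmul(x, w))
  return RematerializeFn(Block, x, w)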
# A set of names of stateful random number generator ops.
# See tensorflow/core/ops/random_ops.cc
_STATEFUL_RANDOM_OPS = frozenset({
# pyformat: disable
'RandomUniform',
'RandomUniformInt',
'RandomStandardNormal',
'ParameterizedTruncatedNormal',
'TruncatedNormal',
'RandomShuffle',
'Multinomial',
'RandomGamma',
'RandomPoisson',
'RandomPoissonV2',
# pyformat: enable
})
[docs]def StatefulRandomOpsInDefun(func, graph=None):
"""Checks whether the Defun depends on stateful random number ops.
  Stateful random number generator ops should be avoided in Recurrent() calls.
Otherwise, these ops produce inconsistent values between FProp and BProp.
Args:
func: a _DefinedFunction or ConcreteFunction to check.
graph: a Graph. Set None to use the default graph.
Returns:
A list of names of the stateful random ops.
Raises:
InvalidArgumentError: if the input func/graph is invalid.
"""
if graph is None:
graph = tf.get_default_graph()
func.add_to_graph(graph)
graph_def = graph.as_graph_def()
# A dict from function name to FunctionDef.
func_defs = {x.signature.name: x for x in graph_def.library.function}
if isinstance(func, function._DefinedFunction): # pylint: disable=protected-access
if func.definition.signature.name not in func_defs:
raise tf.errors.InvalidArgumentError(
          None, None, 'Defun {} is not in the graph.'.format(
func.definition.signature.name))
nodes = py_collections.deque(func.definition.node_def)
else:
nodes = py_collections.deque(func.function_def.node_def)
stateful_ops = []
# Recursively search for stateful random op.
while nodes:
node = nodes.pop()
assert isinstance(node, node_def_pb2.NodeDef), node
if node.op in _STATEFUL_RANDOM_OPS:
stateful_ops.append(node.name)
continue
def _AddDefunNodes(func_name):
"""If the given func_name is a Defun, add its sub-nodes into nodes."""
if func_name in func_defs:
nodes.extend(func_defs[func_name].node_def)
# For functional.{While|For|If} ops, add their Defun attr into search.
if node.op == 'While':
_AddDefunNodes(node.attr['body'].func.name)
_AddDefunNodes(node.attr['cond'].func.name)
elif node.op == 'For':
_AddDefunNodes(node.attr['body'].func.name)
elif node.op == 'If':
_AddDefunNodes(node.attr['then_branch'].func.name)
_AddDefunNodes(node.attr['else_branch'].func.name)
elif node.op == 'StatefulPartitionedCall':
_AddDefunNodes(node.attr['f'].func.name)
elif node.op != 'PartitionedCall':
# For other op, check whether itself is a Defun op.
_AddDefunNodes(node.op)
return stateful_ops
[docs]def ToPlaceholders(nmap, dtype=None):
"""Converts every Tensor in nmap to a placeholder."""
  def _ToPlaceholder(x):
    shape = [None for _ in x.shape[:-1]] + [x.shape[-1]]
    return tf.placeholder(dtype=dtype or x.dtype, shape=shape)
  return nmap.Transform(_ToPlaceholder)
[docs]def Softmax(logits, axis=None, extra_logit=None, name=None):
"""Softmax with extra_logits, might be useful for large xformer LM."""
if extra_logit is None:
return tf.nn.softmax(logits, axis=axis, name=name)
axis = -1 if axis is None else axis
def ReduceLogSumExp(x):
max_logit = tf.math.reduce_max(
tf.stop_gradient(x), axis=axis, keepdims=True)
base_logit = tf.math.maximum(max_logit, extra_logit)
x -= base_logit
exp_x = tf.math.exp(x)
sum_exp_x = tf.math.reduce_sum(exp_x, axis=axis, keepdims=True)
sum_exp_x += tf.math.exp(extra_logit - base_logit)
return tf.math.log(sum_exp_x) + base_logit
def LogSoftmax(x):
return x - ReduceLogSumExp(x)
with tf.name_scope(name):
return tf.math.exp(LogSoftmax(logits))
[docs]def SoftmaxCrossEntropyFocalLoss(logits,
label_ids=None,
label_probs=None,
alpha=None,
gamma=None,
stop_gradient_on_focal_loss_coefficient=False):
u"""Focal loss for multinomial (softmax) logistic loss.
[1] Focal loss https://arxiv.org/abs/1708.02002
Args:
logits: [..., C]. Logits for the multinomial logistic regression. C is the
number of classes.
label_ids: [...]. Each entry in labels must be an index in [0, C).
label_probs: [..., C]. Each vector along last dimension must be a valid
probability distribution.
alpha: [C]. The weighting factor alpha. Eq (3) in [1].
gamma: []. Tunable focusing parameter. Eq (4) in [1].
stop_gradient_on_focal_loss_coefficient: If true, stops gradient on the
focal loss coefficient (1-p)^gamma to stabilize the gradient.
Returns:
    loss[i..., j] = FL(pₜ) = -αₜ (1-pₜ)^γ log(pₜ)    Eq (5) in [1].
"""
def _ApplyFocalLossCoefficient(loss, log_probs):
if gamma is not None and gamma != 0:
probs = tf.exp(log_probs)
coefficient = tf.pow(1.0 - probs, gamma)
if stop_gradient_on_focal_loss_coefficient:
coefficient = tf.stop_gradient(coefficient)
loss *= coefficient
return loss
if label_probs is not None:
log_probs = tf.nn.log_softmax(logits)
loss = -(label_probs * log_probs)
loss = _ApplyFocalLossCoefficient(loss, log_probs)
if alpha is not None:
loss *= tf.reshape(
alpha, tf.concat([tf.ones(tf.rank(loss) - 1, tf.int32), [-1]],
axis=0))
loss = tf.reduce_sum(loss, axis=-1)
else:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=label_ids, logits=logits)
loss = _ApplyFocalLossCoefficient(loss, -loss)
if alpha is not None:
loss *= tf.gather(alpha, label_ids)
return loss
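# Illustrative sketch (assumed settings): per-example focal loss with integer
# labels and a focusing parameter of gamma=2; these values are examples, not
# library defaults.
def _ExampleFocalLoss(logits, label_ids):
  return SoftmaxCrossEntropyFocalLoss(logits, label_ids=label_ids, gamma=2.0)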
[docs]def SigmoidCrossEntropyFocalLoss(logits, labels, alpha=None, gamma=None):
u"""Focal loss for binary (sigmoid) logistic loss.
[1] Focal loss https://arxiv.org/abs/1708.02002
Args:
logits: [..., C]. Logits for the sigmoid logistic regression.
labels: [..., C]. 0/1 labels.
alpha: The weighting factor alpha. Eq (3) in [1].
gamma: Tunable focusing parameter. Eq (4) in [1].
Returns:
    loss[i..., j] = FL(pₜ) = -αₜ (1-pₜ)^γ log(pₜ)    Eq (5) in [1].
"""
# [1] Eq (4).
#
# The numerically-stable way to compute
# log(p) for positives;
# log(1 - p) for negatives.
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
if gamma is not None and gamma != 0:
# The modulating factor. Note that
    #   (1 - p)^γ = [1 - σ(x)]^γ = [σ(-x)]^γ, for positives.
    #         p^γ = [σ(x)]^γ,                 for negatives.
loss *= tf.pow(tf.sigmoid(logits * (1 - labels * 2)), gamma)
if alpha is not None:
# [1] Eq (3)
loss *= (alpha * labels + (1 - alpha) * (1 - labels))
return loss
_RECORD_FORMAT_RE = re.compile('(^[A-Za-z_]+):(.*)')
[docs]def ReadFileLines(file_path):
"""Read a text file and return the lines.
If the file cannot be found at the given path, attempt to load it from the
Lingvo package (useful for data dependencies in par files).
Args:
file_path: path to file, either absolute or relative to the bazel workspace.
Returns:
A list of lines from the file.
"""
if not tf.io.gfile.exists(file_path):
try:
lines = pkgutil.get_data(
'lingvo', file_path.replace('lingvo/', '', 1))
if lines:
lines = lines.splitlines(True)
except IOError:
# If pkgutil can't find the file, continue and let GFile raise the error.
lines = None
else:
lines = None
if not lines:
with tf.io.gfile.GFile(file_path, 'r') as f:
lines = f.readlines()
return lines
# Partially borrowed from
# https://github.com/tensorflow/tensor2tensor/blob/32929305e1a4ec926eff24123758b794df35492b/tensor2tensor/layers/common_layers.py#L349
[docs]def CumSum(x, axis=0, exclusive=False, use_einsum=False):
"""A TPU efficient implementation of tf.cumsum().
This is equivalent to tf.cumsum and is faster on TPU as of 08/2019 unless
  the axis dimension is very large. The current TensorFlow implementation is
  based on scanning and reducing, which is not efficient on TPU.
Args:
x: An input Tensor.
axis: An int for the axis.
exclusive: A bool for performing exclusive cumsum.
use_einsum: If true, use einsum on TPU.
Returns:
A Tensor of the same shape as x.
Raises:
ValueError: if the input axis is invalid.
"""
if x.dtype not in (tf.float32, tf.bfloat16) or not use_tpu():
# Fallback to tf.cumsum when inputs are not floats or not running on TPU.
return tf.cumsum(x, axis=axis, exclusive=exclusive)
rank = GetRank(x)
# Needs to know the rank for the final transpose if axis is not the last
# dimension. Otherwise, falls back to tf.cumsum.
if not isinstance(rank, int) and axis != -1:
return tf.cumsum(x, axis=axis, exclusive=exclusive)
if axis < -1:
if axis + rank < 0:
raise ValueError('Unexpected axis: %d (rank = %d)' % (axis, rank))
axis += rank
if use_einsum:
assert isinstance(rank, int) and rank < 26, rank
# Use einsum to avoid data formatting overhead.
a2z = ''.join([chr(i) for i in range(97, 123)]) # abc...xyz
src = a2z[:rank]
if axis == -1:
tgt = src[:-1] + 'z'
else:
tgt = src[:axis] + 'z' + src[axis + 1:]
length = GetShape(x)[axis]
causal_mask = tf.linalg.band_part(
tf.ones([length, length], dtype=x.dtype), 0, -1)
return tf.einsum(f'{src},{src[axis]}z->{tgt}', x, causal_mask)
length = GetShape(x)[axis]
my_range = tf.range(length)
comparator = tf.less if exclusive else tf.less_equal
mask = tf.cast(
comparator(tf.expand_dims(my_range, 1), tf.expand_dims(my_range, 0)),
x.dtype)
result = tf.tensordot(x, mask, axes=[[axis], [0]])
if axis != -1 and axis != rank - 1:
result = tf.transpose(
result,
list(range(axis)) + [rank - 1] + list(range(axis, rank - 1)))
return result
[docs]def ProjectLastDim(inputs, weight, input_dim, output_dim, use_einsum=True):
"""Linear projection on the last dim of the input tensor.
  This is a TPU-efficient implementation that uses Einsum for the compute to
  avoid reshaping the inputs to a rank-2 tensor.
Args:
inputs: An input Tensor, the last dimension of which is input_dim.
weight: A weight matrix with shape [input_dim, output_dim].
input_dim: An integer or a symbolic dim, the last dimension of the inputs.
output_dim: An integer or a symbolic dim, the last dimension of the outputs.
use_einsum: use tf.einsum for calculation (True), or use tf.matmul (False)
Returns:
An output Tensor of the same rank as inputs, the last dimension is
output_dim.
"""
input_dim = int(
symbolic.ToStatic(input_dim) if symbolic.IsExpr(input_dim) else input_dim)
output_dim = int(
symbolic.ToStatic(output_dim) if symbolic.IsExpr(output_dim
) else output_dim)
# Assert input_dim and output_dim
inputs = with_dependencies([assert_equal(GetShape(inputs)[-1], input_dim)],
inputs)
weight = with_dependencies([
assert_equal(GetShape(weight)[0], input_dim),
assert_equal(GetShape(weight)[-1], output_dim)
], weight)
if inputs.shape.rank == 2:
# Shortcut for 2D inputs: just use tf.matmul, no need for Einsum.
outputs = tf.matmul(inputs, weight)
elif (use_einsum and use_tpu() and inputs.shape is not None and
inputs.shape.rank is not None and inputs.shape.rank < 26):
# This is equivalent to:
# outputs = tf.einsum('...y,yz->...z', inputs, weight)
# Unfortunately ... in einsum() leads to extra HBM usage.
s = ''.join([chr(x) for x in range(97, 123)]) # abc...xyz
r = inputs.shape.rank
outputs = tf.einsum('{0}y,yz->{0}z'.format(s[:r - 1]), inputs, weight)
else:
# not use_einsum or not use_tpu() or inputs.shape.rank >= 26
outputs = Matmul(tf.reshape(inputs, [-1, input_dim]), weight)
outputs = tf.reshape(
outputs,
tf.concat([tf.cast(GetShape(inputs)[:-1], tf.int32), [output_dim]],
axis=0))
return outputs
[docs]@contextlib.contextmanager
def RemoveAssertContext(remove=True):
"""Hacks to replace certain unwanted tensorflow ops."""
  # TODO(zhifengc/huangyp): Consider implementing an assert_equal op
  # replacement for lingvo, as assert_equal doesn't support strings on GPUs.
# Hack to replace tf.assert_equal
# TODO(b/136040013): Remove this after migration to tf.function.
if remove:
saved_assert_equal = tf.check_ops.assert_equal
def NoOP(*args, **kwargs): # pylint: disable=unused-argument
return tf.no_op()
tf.check_ops.assert_equal = NoOP # Make assert_equal a no op.
try:
yield
finally:
tf.check_ops.assert_equal = saved_assert_equal
else:
yield
[docs]def TensorSpecs(nmap, keep_shape=True):
"""Transforms tensors in the input nested structure to TensorSpecs."""
if nmap is None:
return None
fn = lambda t: tf.TensorSpec(t.shape if keep_shape else None, t.dtype)
return Transform(fn, nmap)
[docs]def _DefineDefun(fwd, fwd_sig, bak=None, bak_as_function=False, device=None):
"""Wraps fwd in a defun with custom gradient bak.
Args:
fwd: A callable xs: Nested Structure -> ys: Nested Structure.
fwd_sig: A Nested Structure of tf.TensorSpec representing the input
signature of `fwd`, or None (meaning that fwd takes no inputs).
bak: A callable xs, ys, dys: Nested Structure -> dxs[, dcapture]: Nested
Structure. The custom backprop function for `fwd`. bak needs to return
dcapture if fwd uses any implicitly captured tensors, whose gradients are
dcapture.
bak_as_function: Whether to create a TF graph function for `bak`.
device: the device on which to run `fwd` and `bak`.
Returns:
A NestedMap containing:
- call: A callable that will execute `fwd`. It has the same input and output
signatures as `fwd`.
- func: The underlying TF function that `call` calls. If not None, it will
be a _DefinedFunction or ConcreteFunction that takes flat inputs and
returns flat outputs, and can be used by routines that require a TF
function object (e.g. tf.If, tf.While, etc).
Always not None when `bak` is None.
- output_dtypes: A nested structure compatible with the outputs of `fwd`
containing the corresponding output dtypes.
- stateful_ops: A list of (op_name, op_type) tuples representing the
stateful ops used by `fwd`.
- captured_inputs: Implicit inputs captured by `fwd`.
"""
assert fwd is not None
noinline = False
if fwd_sig is None:
fwd_sig = []
get_dtype = lambda x: x.dtype
arg_dtypes = Flatten(Transform(get_dtype, fwd_sig))
get_shape = lambda x: x.shape
arg_shapes = Flatten(Transform(get_shape, fwd_sig))
# Used to hold the backward function used by Grad, which will be defined if
# bak is set.
sigs = NestedMap()
# Output of this method.
res = NestedMap()
python_grad_func = None
if bak:
def Grad(op, *args):
"""Gradient function for the forward function.
Args:
op: The forward operation.
*args: Gradients wrt op.outputs.
Returns:
Tuple of derivatives.
"""
_AssertInputsMatch(op, fwd_sig, res.captured_inputs)
# Ensure dys contains no None.
args = ConvertNoneGradientToZeros(list(op.outputs), list(args))
xs = op.inputs[:len(arg_dtypes)] # The rest are captures.
return sigs.backward(*Flatten([xs, op.outputs, args]))
python_grad_func = Grad
def _SetShape(dst_list, shape_list):
for dst, shape in zip(dst_list, shape_list):
if isinstance(dst, tf.Tensor):
dst.set_shape(shape)
@tf.Defun(*arg_dtypes, python_grad_func=python_grad_func, noinline=noinline)
def Forward(*args):
"""The forward function."""
_SetShape(args, arg_shapes)
with RemoveAssertContext(remove=noinline):
call = lambda: fwd(Pack(fwd_sig, args)) if args else fwd()
if device is None:
# Defun will handle the device assignment.
rets = call()
else:
with tf.device(device):
rets = call()
res.outputs = rets
return Flatten(rets)
forward = Forward
if not arg_dtypes:
    # In this case Forward is an _OverloadedFunction, so we need to
    # instantiate it.
forward = Forward.instantiate([])
# Invokes fwd() to get res.outputs.
forward.add_to_graph(tf.get_default_graph())
res.func = forward
res.stateful_ops = forward.stateful_ops
res.captured_inputs = forward.captured_inputs
output_dtypes = Transform(get_dtype, res.outputs)
output_shapes = Transform(get_shape, res.outputs)
def Call(args=None):
"""Wrapper of fwd."""
if args is None:
flat_rets = forward()
else:
flat_rets = forward(*Flatten(args))
if not isinstance(flat_rets, (tuple, list)):
flat_rets = [flat_rets]
_SetShape(flat_rets, Flatten(output_shapes))
return Pack(output_dtypes, flat_rets)
res.call = Call
if bak:
def Backward(*args):
"""The backward function."""
_SetShape(args, Flatten([arg_shapes, output_shapes, output_shapes]))
xs, ys, dys = Pack([fwd_sig, output_dtypes, output_dtypes], args)
with RemoveAssertContext(remove=noinline):
if device is None:
# Defun will handle the device assignment.
dxs = bak(xs, ys, dys)
else:
with tf.device(device):
dxs = bak(xs, ys, dys)
return Flatten(dxs)
if bak_as_function:
sigs.backward = tf.Defun(
*Flatten([arg_dtypes, output_dtypes, output_dtypes]),
noinline=noinline)(
Backward)
sigs.backward.add_to_graph(tf.get_default_graph())
else:
sigs.backward = Backward
return res
# Global variable to control rendezvous sharing in tf.function.
# If False (default), rendezvous sharing is disabled in tf.function, that is,
# the function body uses a separate rendezvous and can't communicate with the
# parent graph via send/recv.
# With _GetSharedRendezvous() == True, the function body shares the same
# rendezvous with the parent graph and can talk to it using send/recv. This is
# useful for layers like StackedRecurrent.
_SHARED_RENDEZVOUS = ThreadLocalStack()
[docs]@contextlib.contextmanager
def _SharedRendezvousScope(shared_rendezvous=True):
_SHARED_RENDEZVOUS.stack.append(shared_rendezvous)
try:
yield
finally:
_SHARED_RENDEZVOUS.stack.pop()
[docs]def _GetSharedRendezvous():
"""Get the current rendezvous sharing setting."""
return _SHARED_RENDEZVOUS.stack[-1] if _SHARED_RENDEZVOUS.stack else False
[docs]def _ApplySharedRendezvous(func):
"""Apply the rendezvous sharing setting on the given tf.function func."""
# pylint: disable=protected-access
func._shared_rendezvous = _GetSharedRendezvous()
# pylint: enable=protected-access
[docs]def _WrapFunction(func=None, input_signature=None):
"""Wraps func as a tf.function."""
if input_signature is None:
input_signature = []
def Decorated(fn):
@tf.function(input_signature=input_signature, autograph=False)
def Fn(*args):
      # TODO(b/163904067): mimic Defun's behavior and reset the step seed to
# avoid it being used as an implicit capture. This is not a desired
# behavior, it should take the step seed from parent graph instead.
ResetStepSeed()
# Mimic Defun and disable collection sharing.
graph = tf.get_default_graph()
# Don't share summaries collection with parent graph (b/168745134).
graph.clear_collection(tf.GraphKeys.SUMMARIES)
return fn(*args)
_ApplySharedRendezvous(Fn)
# Add the function to the graph so it'll be traced under the current
# context. This is necessary if the function body captures any non-tensor
# values from the environment, like symbolic maps.
cf = Fn.get_concrete_function()
cf.add_to_graph()
return cf
# For the `foo = _WrapFunction(foo, ...)` use case.
if func is not None:
return Decorated(func)
# For the `@_WrapFunction(...)` use case.
return Decorated
[docs]def _DefineFunction(fwd, fwd_sig, bak=None, bak_as_function=False, device=None):
"""Wraps fwd in a defun with custom gradient bak.
Args:
fwd: A callable xs: Nested Structure -> ys: Nested Structure.
fwd_sig: A Nested Structure of tf.TensorSpec representing the input
signature of `fwd`, or None (meaning that fwd takes no inputs).
bak: A callable xs, ys, dys: Nested Structure -> dxs[, dcapture]: Nested
Structure. The custom backprop function for `fwd`. bak needs to return
dcapture if fwd uses any implicitly captured tensors, whose gradients are
dcapture.
bak_as_function: Whether to create a TF graph function for `bak`.
device: the device on which to run `fwd` and `bak`.
Returns:
A NestedMap containing:
- call: A callable that will execute `fwd`. It has the same input and output
signatures as `fwd`.
- func: The underlying TF function that `call` calls. If not None, it will
be a _DefinedFunction or ConcreteFunction that takes flat inputs and
returns flat outputs, and can be used by routines that require a TF
function object (e.g. tf.If, tf.While, etc).
Always not None when `bak` is None.
- outputs: The outputs of `fwd`. Used for reflection only (e.g. to get the
output dtypes, shapes, etc).
- stateful_ops: A list of (op_name, op_type) tuples representing the
stateful ops used by `fwd`.
- captured_inputs: Implicit inputs captured by `fwd`.
"""
assert fwd is not None
noinline = not use_xla()
if fwd_sig is None:
fwd_sig = []
if device is None:
# Get the current device to mimic Defun's behavior.
# pylint: disable=protected-access
device_funcs = tf.get_default_graph()._device_functions_outer_to_inner
device = device_funcs[-1] if device_funcs else None
# pylint: enable=protected-access
# Output of this method.
res = NestedMap()
@_WrapFunction(input_signature=Flatten(fwd_sig))
def Forward(*args):
"""The forward function."""
with RemoveAssertContext(remove=noinline), tf.device(device):
if args:
xs = Pack(fwd_sig, args)
rets = fwd(xs)
else:
rets = fwd()
res.outputs = rets
return Flatten(rets)
res.captured_inputs = Forward.captured_inputs
# Get the stateful ops used in cell_fn. Logic borrowed from
# _EagerDefinedFunction.__init__().
graph = Forward.graph
input_ops = set(arg.op for arg in graph.inputs)
operations = [op for op in graph.get_operations() if op not in input_ops]
res.stateful_ops = [(o.name, o.type) for o in operations if o._is_stateful] # pylint: disable=protected-access
def Call(func, args=None):
"""Wrapper of fwd."""
if args is None:
flat_rets = func()
else:
flat_rets = func(*Flatten(args))
if not isinstance(flat_rets, (tuple, list)):
flat_rets = [flat_rets]
return Pack(res.outputs, flat_rets)
if not bak:
res.func = Forward
res.call = lambda args=None: Call(Forward, args)
return res
shared_rendezvous = _GetSharedRendezvous()
ret_specs = TensorSpecs(res.outputs)
def Backward(*args):
xs, ys, dys = Pack([fwd_sig, ret_specs, ret_specs], args)
with RemoveAssertContext(remove=noinline), tf.device(device):
dxs = bak(xs, ys, dys)
return Flatten(dxs)
if bak_as_function:
backward_cf = _WrapFunction(
Backward, input_signature=Flatten([fwd_sig, ret_specs, ret_specs]))
else:
def BackwardWithSharedRendezvous(*args):
with _SharedRendezvousScope(shared_rendezvous):
return Backward(*args)
backward_cf = BackwardWithSharedRendezvous
@tf.custom_gradient
def ForwardWithGrad(*args):
"""Forward function and its custom gradient."""
# Note that `args` includes implicit captures. This is required by
# tf.custom_gradient so that when the Grad() outputs include gradients to
# implicit captures, they match the inputs to ForwardWithGrad().
#
# However, Forward doesn't take implicit captures as input, so we exclude
# them here.
fwd_args = args[:(len(args) - len(Flatten(res.captured_inputs)))]
op = NestedMap(inputs=args, outputs=Forward(*fwd_args))
def Grad(*args, **kwargs):
"""Gradient function for the forward function.
Args:
*args: Gradients wrt op.outputs.
**kwargs: Additional arguments from tf.custom_gradient.
Returns:
Tuple of derivatives.
"""
if kwargs:
tf.logging.warning(
'Ignoring additional arguments used by tf.custom_gradient: %s',
str(kwargs))
_AssertInputsMatch(op, fwd_sig, res.captured_inputs)
# Ensure dys contains no None.
args = ConvertNoneGradientToZeros(list(op.outputs), list(args))
xs, _ = Pack([fwd_sig, res.captured_inputs], op.inputs)
return backward_cf(*Flatten([xs, op.outputs, args]))
return op.outputs, Grad
res.func = None
forward = lambda *xs: ForwardWithGrad(*Flatten([xs, res.captured_inputs]))
res.call = lambda args=None: Call(forward, args)
return res
# Global variable to control whether to use tf.function.
# If not set, the result is determined by tf2 status. See _UseTfFunction for
# details.
# TODO(laigd): remove after b/169869929 is fixed.
_USE_TF_FUNCTION = ThreadLocalStack()
# Constants for propagating framework tensors through Function.
_FRAMEWORK_TENSOR_GLOBAL_STEP = '_global_step'
[docs]@contextlib.contextmanager
def TfFunctionScope(use_tf_function=True):
_USE_TF_FUNCTION.stack.append(use_tf_function)
try:
yield
finally:
_USE_TF_FUNCTION.stack.pop()
[docs]def _UseTfFunction():
"""Whether to use tf.function instead of tf.Defun."""
if _USE_TF_FUNCTION.stack:
return _USE_TF_FUNCTION.stack[-1]
return True
[docs]class Function(object):
"""Function builds a TensorFlow graph function from a callable.
In the high level this is similar to tf.Defun and tf.function. In fact this
relies on those as underlying implementations, but with specific configuration
so it's easier to use and can work well in some extreme cases in Lingvo.
Example usage:
- No inputs:
>>> @Function()
... def foo():
... return tf.constant(1.0)
>>> y = foo()
- Scalar input:
>>> @Function(fwd_sig=tf.TensorSpec(None, tf.float32))
... def foo(x):
... return x * 2
>>> y = foo(1.0)
- List input:
>>> @Function(fwd_sig=[tf.TensorSpec(None, tf.float32) for _ in range(2)])
... def foo(xs):
... return xs[0] + xs[1]
>>> y = foo([1.0, 2.0])
- Nested input:
>>> @Function(fwd_sig=NestedMap(x=tf.TensorSpec(None, tf.float32)))
... def foo(nmap):
... return nmap.x * 2
>>> y = foo(NestedMap(x=1.0))
- With custom gradient function (other input types mentioned above are also
supported):
>>> def bar(x, y, dy):
... del y, dy
... return 4.0 * x * dy
>>>
>>> @Function(fwd_sig=tf.TensorSpec(None, tf.float32), bak=bar)
... def foo(x):
... return 2.0 * x * x
- Used in control flow ops:
>>> then_branch = Function(tf.TensorSpec([], tf.int32))(lambda x: x / 2)
>>> else_branch = Function(tf.TensorSpec([], tf.int32))(lambda x: 3 * x + 1)
>>> y = tf.If(cond, inputs, then_branch.func, else_branch.func)
"""
def __init__(self,
fwd_sig=None,
bak=None,
bak_as_function=False,
device=None,
use_tf_function=None):
"""Constructor.
Below we assume `fwd` is the input to `__call__` that is used to build the
TensorFlow graph function encapsulated by this object.
Args:
fwd_sig: A Nested Structure of tf.TensorSpec representing the input
signature of `fwd`, or None (meaning that `fwd` takes no inputs). The
actual inputs should be compatible with this (have same shapes and
dtypes).
bak: A callable xs, ys, dys: Nested Structure -> dxs[, dcapture]: Nested
Structure. The custom backprop function for `fwd`. bak needs to return
dcapture if `fwd` uses any implicitly captured tensors, whose gradients
are dcapture.
bak_as_function: Whether to create a TF graph function for `bak`.
device: The device on which to run `fwd` and `bak`. Defaults to the
current device.
use_tf_function: Deprecated.
"""
del use_tf_function
self._fwd_sig = fwd_sig
self._bak = bak
self._bak_as_function = bak_as_function
self._device = device
def __call__(self, fwd):
"""Creates a graph function.
Args:
fwd: a callable xs: Nested Structure -> ys: Nested Structure.
Returns:
A DefinedFunction object encapsulating `fwd` as a graph function.
"""
assert callable(fwd)
return DefinedFunction(fwd, self._fwd_sig, self._bak, self._bak_as_function,
self._device)
[docs]class DefinedFunction(object):
"""Encapsulates a TensorFlow graph function and its properties."""
def __init__(self,
fwd,
fwd_sig=None,
bak=None,
bak_as_function=False,
device=None):
"""Constructor.
Args:
fwd: A callable xs: Nested Structure -> ys: Nested Structure. Used to
build the TensorFlow graph function that this object encapsulates.
fwd_sig: A Nested Structure of tf.TensorSpec representing the input
signature of `fwd`, or None (meaning that `fwd` takes no inputs). The
actual inputs should be compatible with this (have same shapes and
dtypes).
bak: A callable xs, ys, dys: Nested Structure -> dxs[, dcapture]: Nested
Structure. The custom backprop function for `fwd`. bak needs to return
dcapture if `fwd` uses any implicitly captured tensors, whose gradients
are dcapture.
bak_as_function: Whether to create a TF graph function for `bak`.
device: The device on which to run `fwd` and `bak`. Defaults to the
current device.
"""
self._fwd_sig = fwd_sig
wrapped_fwd_sig = fwd_sig
fwd_fn = fwd
bak_fn = bak
graph_random_seed = None
if tf.get_default_graph().seed is not None:
graph_random_seed = tf.get_default_graph().seed
# Wrap the forward function to propagate framework tensors like step_seed
# and global_step.
wrapped_fwd_sig = NestedMap()
self._added_global_step = False
if GetGlobalStep() is not None:
wrapped_fwd_sig[_FRAMEWORK_TENSOR_GLOBAL_STEP] = (
tf.TensorSpec([], tf.int64))
self._added_global_step = True
if fwd_sig is not None:
wrapped_fwd_sig.inputs = fwd_sig
elif not wrapped_fwd_sig:
wrapped_fwd_sig = None
def ForwardWrapped(wrapped_inputs=None):
if graph_random_seed is not None:
tf.random.set_seed(graph_random_seed)
global_step = None
if wrapped_inputs:
assert isinstance(wrapped_inputs, NestedMap)
global_step = wrapped_inputs.get(_FRAMEWORK_TENSOR_GLOBAL_STEP, None)
with GlobalStepContext(global_step):
if wrapped_inputs and 'inputs' in wrapped_inputs:
result = fwd(wrapped_inputs.inputs)
else:
result = fwd()
return result
fwd_fn = ForwardWrapped
if bak:
# Wrap the backward function to return zero gradients for framework
# tensors like step_seed and global_step.
def BackwardWrapped(wrapped_xs, ys, dys):
if graph_random_seed is not None:
tf.random.set_seed(graph_random_seed)
with GlobalStepContext(
wrapped_xs.get(_FRAMEWORK_TENSOR_GLOBAL_STEP, None)):
result = bak(wrapped_xs.inputs, ys, dys)
dxs = Transform(tf.zeros_like, wrapped_xs)
if isinstance(result, tuple) and len(result) == 2:
dxs.inputs, dcapture = result
return dxs, dcapture
else:
dxs.inputs = result
return dxs
bak_fn = BackwardWrapped
self._data = _DefineFunction(
fwd=fwd_fn,
fwd_sig=wrapped_fwd_sig,
bak=bak_fn,
bak_as_function=bak_as_function,
device=device)
def __call__(self, args=None):
"""Invokes the graph function.
Args:
args: the inputs to the graph function, must be compatible with `fwd_sig`.
Returns:
The output tensors with the same structure as the output of `fwd`,
returned by a call to the graph function.
"""
assert IsCompatible(args,
self._fwd_sig), '{} vs {}'.format(args, self._fwd_sig)
return self._data.call(self.AddFrameworkInputs(args))
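  # Minimal sketch of the AddFrameworkInputs helper used by __call__ above and
  # by If()/WhileLoop() below. It mirrors the wrapped_fwd_sig structure built
  # in __init__ (an optional `_global_step` entry plus the user `inputs`); the
  # exact production implementation may differ in detail.
  def AddFrameworkInputs(self, inputs):
    """Wraps `inputs` together with framework tensors such as global_step."""
    result = NestedMap()
    if self._added_global_step:
      global_step = GetGlobalStep()
      assert global_step is not None
      result[_FRAMEWORK_TENSOR_GLOBAL_STEP] = tf.cast(global_step, tf.int64)
    if inputs is not None:
      result.inputs = inputs
    return result if result else None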
@property
def func(self):
"""The underlying TensorFlow graph function that this object encapsulates.
The returned graph function is created by tracing `fwd` during construction.
If not None, it will be a _DefinedFunction or ConcreteFunction that takes
flat inputs and returns flat outputs, and can be used by routines that
require a TensorFlow function object (e.g. tf.If, tf.While, etc).
If no backprop function is provided during construction, the result is never
None.
"""
return self._data.func
@property
def output_dtypes(self):
"""Output dtypes of the graph function.
The result will have the same structure as the outputs of `fwd` but contain
the corresponding output dtypes.
"""
return Transform(lambda x: x.dtype, self._data.outputs)
@property
def stateful_ops(self):
"""Stateful ops used by `fwd`, as a list of (op_name, op_type) tuples."""
return self._data.stateful_ops
@property
def captured_inputs(self):
"""Implicit input tensors captured by `fwd`."""
return self._data.captured_inputs
[docs]def CallDefun(fwd, args=None, bak=None, bak_as_function=False, device=None):
"""Wraps fwd in a defun with custom gradient bak and calls it with args.
Args:
fwd: A callable xs: Nested Structure -> ys: Nested Structure.
args: A Nested Structure of tf.Tensor or None.
bak: A callable xs, ys, dys: Nested Structure -> dxs[, dcapture]: Nested
Structure. The custom backprop function for fwd. bak needs to return
dcapture if fwd uses any implicitly captured tensors, whose gradients are
dcapture.
bak_as_function: Whether to create a TF graph function for bak.
device: the device on which to run fwd and bak.
Returns:
A Nested Structure equivalent to what fwd(args) computes.
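  For example (an illustrative sketch; `_Fwd` and `_Bak` are placeholder
  callables):
  >>> def _Fwd(x):
  ...   return 2.0 * x * x
  >>> def _Bak(x, y, dy):
  ...   del y
  ...   return 4.0 * x * dy
  >>> y = CallDefun(_Fwd, args=tf.constant(3.0), bak=_Bak)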
"""
if args is not None:
args = Transform(tf.convert_to_tensor, args)
sigs = Function(
fwd_sig=TensorSpecs(args),
bak=bak,
bak_as_function=bak_as_function,
device=device)(
fwd=fwd)
if args is None:
return sigs()
else:
return sigs(args)
[docs]def If(cond, inputs, then_branch, else_branch):
"""Helper to construct an if/else statement.
Args:
cond: A scalar `Tensor` that can be converted to boolean.
inputs: A flattenable representing the input tensors of the if/else
statement. Can be None to represent no inputs.
then_branch: A callable 'inputs' -> flattenable. The returned value should
be compatible with what 'else_branch' returns.
else_branch: A callable 'inputs' -> flattenable. The returned value should
be compatible with what 'then_branch' returns.
Returns:
Output returned by the call to either 'then_branch' or 'else_branch'.
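  For example (illustrative; `x` below is a scalar int32 tensor):
  >>> x = tf.constant(7)
  >>> y = If(
  ...     cond=tf.greater(x, 0),
  ...     inputs=x,
  ...     then_branch=lambda v: v + 1,
  ...     else_branch=lambda v: v - 1)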
"""
fwd_sig = TensorSpecs(inputs)
then_sigs = Function(fwd_sig=fwd_sig)(fwd=then_branch)
else_sigs = Function(fwd_sig=fwd_sig)(fwd=else_branch)
assert IsCompatible(then_sigs.output_dtypes, else_sigs.output_dtypes), (
'Outputs of then_branch and else_branch are not compatible: {} vs {}'
.format(then_sigs.output_dtypes, else_sigs.output_dtypes))
if then_sigs.captured_inputs != else_sigs.captured_inputs:
raise ValueError('Differing captured inputs in then and else. '
'Ensure the same tensors are captured in the same order.')
ret = tf.If(
cond=cond,
inputs=Flatten(then_sigs.AddFrameworkInputs(inputs)) +
then_sigs.captured_inputs,
then_branch=then_sigs.func,
else_branch=else_sigs.func)
return Pack(then_sigs.output_dtypes, ret)
[docs]def _Itype():
"""Loop iterator data type."""
return tf.int32 if use_xla() else tf.int64
[docs]def WhileLoop(cond, body, loop_state):
"""Helper to construct a while loop.
Args:
cond: A callable NestedMap -> tf.bool.
body: A callable NestedMap -> NestedMap.
loop_state: A flattenable (NestedMap, list, tuple, etc.) representing the
loop state.
Returns:
The final loop state in the same structure as loop_state.
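  For example, an illustrative countdown loop:
  >>> init = NestedMap(x=tf.constant(10))
  >>> final = WhileLoop(
  ...     cond=lambda s: tf.greater(s.x, 0),
  ...     body=lambda s: NestedMap(x=s.x - 1),
  ...     loop_state=init)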
"""
fwd_sig = TensorSpecs(loop_state)
cond_sigs = Function(fwd_sig=fwd_sig)(fwd=cond)
def BodyWrapped(loop_state):
result = body(loop_state)
# loop_state is augmented with global tensors inside of DefinedFunction.
# WhileLoop needs to return the same structure as the inputs, so we augment
# the return value here to match.
result = cond_sigs.AddFrameworkInputs(result)
return result
body_sigs = Function(fwd_sig=fwd_sig)(fwd=BodyWrapped)
wrapped_inputs = body_sigs.AddFrameworkInputs(loop_state)
new_state = tf.While(
Flatten(wrapped_inputs), cond=cond_sigs.func, body=body_sigs.func)
# The functional `While` used above does not have a registered gradient.
# This was not a problem in Graph mode; however, in Eager mode GradientTape
# will attempt to call the gradient of the While op in the forward pass.
# `stop_gradient` is used to treat the op as a constant in the forward pass.
# This also avoids calling the gradients of other ops inside `While` in the
# forward pass.
# Details in https://www.tensorflow.org/api_docs/python/tf/custom_gradient.
# Guarded by 'IsEagerMode' to limit impact.
if IsEagerMode():
new_state = [tf.stop_gradient(t) for t in new_state]
return Pack(wrapped_inputs, new_state).inputs
[docs]def ForLoop(body, start, limit, delta, loop_state):
"""Helper to construct a for loop.
Args:
body: A callable (tf.int, NestedMap) -> NestedMap.
start: Loop variable's initial value.
limit: Loop variable's limit value.
delta: Loop variable's change per iteration.
loop_state: A flattenable (NestedMap, list, tuple, etc.) representing the
loop state.
Returns:
The final loop state in the same structure as loop_state.
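  For example, an illustrative accumulation loop (final.total == 10):
  >>> init = NestedMap(total=tf.constant(0))
  >>> final = ForLoop(
  ...     body=lambda i, s: NestedMap(total=s.total + tf.cast(i, tf.int32)),
  ...     start=0, limit=5, delta=1, loop_state=init)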
"""
state = NestedMap(
iter=tf.cast(start, _Itype()),
limit=tf.cast(limit, _Itype()),
delta=tf.cast(delta, _Itype()),
loop_state=loop_state)
def LoopCond(state):
return tf.less(state.iter, state.limit)
def LoopBody(state):
state.loop_state = body(state.iter, state.loop_state)
state.iter = tf.add(state.iter, state.delta)
return state
return WhileLoop(LoopCond, LoopBody, state).loop_state
[docs]def TopK(x_in, k):
"""Equivalent to tf.math.top_k(x_in, k) but more efficient on tpu."""
assert k <= 2, 'This implementation is only efficient for small k.'
# TODO(yonghui): Try out an alternative idea where we first reshape x_in as a
# 2d tensor, then call tf.math.top_k, and then reshape back.
x_in_shape = x_in.shape
x_rank = x_in_shape.rank
assert x_rank and x_in_shape.as_list()[x_rank - 1] > 0
last_dim_size = x_in_shape.as_list()[x_rank - 1]
min_value = tf.math.reduce_min(x_in) - 1.0
out_indices = []
out_values = []
for unused_i in range(k):
index_i = tf.math.argmax(x_in, axis=-1, output_type=tf.int32)
mask_i = tf.one_hot(index_i, last_dim_size)
# TODO(yonghui): Would tf.gather be more efficient and numerically stable
# here?
value_i = tf.reduce_sum(mask_i * x_in, -1, keepdims=True)
x_in = (1.0 - mask_i) * x_in + mask_i * min_value
out_indices.append(tf.expand_dims(index_i, -1))
out_values.append(value_i)
if k == 1:
return out_values[0], out_indices[0]
else:
return tf.concat(out_values, x_rank - 1), tf.concat(out_indices, x_rank - 1)
[docs]def ReadVariable(var_op):
"""Returns the value of the given variable operation.
Args:
var_op: the `Operation` object for a VarHandleOp.
Raises:
TypeError: if var_op is not a VarHandleOp.
Returns:
A `Tensor` containing the value of the variable.
"""
if var_op.type != 'VarHandleOp':
raise TypeError('var_op should be a VarHandleOp, got %s' % str(var_op.type))
# Filter out the ReadVariableOps that have control dependencies to avoid
# side-effects when the user runs it.
filter_fn = lambda op: op.type == 'ReadVariableOp' and not op.control_inputs
var_readers = list(filter(filter_fn, var_op.outputs[0].consumers()))
assert var_readers
return var_readers[0].outputs[0]
_TPU_SUMMARY_TENSORS_KEY = ('__lingvo_tpu_summary_tensors')
_TPU_SUMMARY_CONTEXTS = ThreadLocalStack()
[docs]def _GetTpuSummaryTensor():
if _TPU_SUMMARY_CONTEXTS.stack:
return _TPU_SUMMARY_CONTEXTS.stack[-1]
return _CollectionGetter(_TPU_SUMMARY_TENSORS_KEY, lambda: [])()
[docs]@contextlib.contextmanager
def TpuSummaryTensorContext():
"""Creates a context where AddTpuSummaryTensor() will add tensors."""
_TPU_SUMMARY_CONTEXTS.stack.append([])
try:
yield
finally:
_TPU_SUMMARY_CONTEXTS.stack.pop()
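# Illustrative usage (names below are placeholders): summaries added inside the
# context are kept local instead of going to the global collection.
#
#   with TpuSummaryTensorContext():
#     AddTpuSummaryTensor('aux_loss', aux_loss)
#     local_summaries = GetTpuSummaryTensors()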
[docs]def AddTpuSummaryTensor(name, value, weight=1.0):
"""Adds tensor to global collection of summaries, or a local context if any.
This is meant for situations where tf.summary() would normally be used but is
currently not supported. Use py_utils.AddTpuSummaryTensor() in low-level code
to add summary tensors to the global collection of summaries, then recover
all summary tensors from the global collection by calling
py_utils.GetTpuSummaryTensors() from top-level code (for example from the
ComputeLoss method of BaseTask).
In addition to the 'name' argument, the current TensorFlow name scope is also
captured and added to the metric name. This way, for example, summaries from
a repeated layer will appear as separate graphs in TensorBoard.
The weight argument is optional and defaults to 1.0. See BaseTask.ComputeLoss
for the exact definition of weight for eval metrics.
Args:
name: metric name
value: metric value tensor
weight: weight tensor for weighted metrics
"""
tpu_summary_tensors = _GetTpuSummaryTensor()
x = NestedMap()
x.name = name
x.value = value, tf.convert_to_tensor(weight)
x.name_scope = tf.get_default_graph().get_name_scope()
tpu_summary_tensors.append(x)
[docs]def GetTpuSummaryTensors():
"""Returns summary tensors from global collection.
Returns:
A dict containing str keys and (metric, weight) pairs as values
"""
tpu_summary_tensors = _GetTpuSummaryTensor()
return {
'%s/%s' % (x.name, SanitizeScopeKey(x.name_scope)): x.value
for x in tpu_summary_tensors
}
[docs]def ClearTpuSummaryTensors():
tpu_summary_tensors = _GetTpuSummaryTensor()
del tpu_summary_tensors[:]
_SAVED_MODEL_ASSETS = ThreadLocalStack()
[docs]def AddSavedModelAssets(asset):
if tf.executing_eagerly_outside_functions():
_SAVED_MODEL_ASSETS.stack.append(asset)
else:
# Yes, graph collections are bad; however, this seems to be the
# easiest way to get these assets registered from
# TextFileInitializer.
tf.compat.v1.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS,
asset.asset_path)
[docs]def GetSavedModelAssets():
if tf.executing_eagerly_outside_functions():
return _SAVED_MODEL_ASSETS.stack
else:
# Yes, graph collections are bad; however, this seems to be the
# easiest way to get these assets registered from
# TextFileInitializer.
return tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
[docs]def ComputationShape(split_size, topology=None):
"""Decides the computation shape based on the split_size.
Args:
split_size: number of accelerators to use per split.
topology: a serialized string of `tensorflow.tpu.TopologyProto`, or a
`tf.tpu.experimental.Topology` object, that describes the TPU cluster
topology. If not set, it'll use a default setting based on split_size.
Returns:
A 4-element list that describes the computation shape.
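  For example, with no topology given, split_size=8 maps to 4 chips of 2 cores
  each and yields [2, 2, 1, 2].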
"""
if topology:
if isinstance(topology, tf.tpu.experimental.Topology):
topology_info = topology
else:
topology_info = tf_topology.Topology(serialized=topology)
computation_shape = None
if topology and functools.reduce(lambda a, b: a * b,
topology_info.mesh_shape) == split_size:
computation_shape = topology_info.mesh_shape
elif split_size == 1:
computation_shape = [1, 1, 1, 1]
elif topology and topology_info.mesh_shape[
-1] == 1 and split_size in topology_info.mesh_shape:
# For Megacore, if split_size exactly matches one of the mesh dimensions, map
# split_size onto that dimension.
computation_shape = [1, 1, 1, 1]
computation_shape[topology_info.mesh_shape.tolist().index(
split_size)] = split_size
else:
if topology:
cores_per_chip = topology_info.mesh_shape[-1]
else:
cores_per_chip = 2
assert split_size % cores_per_chip == 0
split_chips = split_size // cores_per_chip
if split_chips == 1:
computation_shape = [1, 1, 1, cores_per_chip]
elif split_chips == 2:
computation_shape = [1, 2, 1, cores_per_chip]
elif split_chips == 4:
computation_shape = [2, 2, 1, cores_per_chip]
elif split_chips == 8:
computation_shape = [4, 2, 1, cores_per_chip]
elif split_chips == 12:
computation_shape = [1, 1, 12, cores_per_chip]
elif split_chips == 16:
computation_shape = [4, 4, 1, cores_per_chip]
elif split_chips == 24:
computation_shape = [1, 2, 12, cores_per_chip]
elif split_chips == 32:
if topology and topology_info.mesh_shape[1] == 32:
# Fwd within-replica all-reduces are performed along the column;
# Bwd gradient cross-replica all-reduces are performed along the row.
# This currently has better performance than the strided pattern.
computation_shape = [1, 32, 1, cores_per_chip]
else:
computation_shape = [4, 8, 1, cores_per_chip]
elif split_chips == 64:
computation_shape = [8, 8, 1, cores_per_chip]
elif split_chips == 128:
computation_shape = [8, 16, 1, cores_per_chip]
elif split_chips == 256:
computation_shape = [16, 16, 1, cores_per_chip]
elif split_chips == 512:
computation_shape = [16, 32, 1, cores_per_chip]
elif split_chips == 1024:
computation_shape = [32, 32, 1, cores_per_chip]
elif split_chips == 2048:
computation_shape = [64, 32, 1, cores_per_chip]
elif split_chips == 4096:
computation_shape = [128, 32, 1, cores_per_chip]
else:
assert False, ('Model parallelism with %d devices is currently not'
' supported.' % split_size)
assert computation_shape is not None
return computation_shape
[docs]def ShardedFilePatternToGlob(file_pattern):
"""Converts a file pattern path@shards to path-?????-of-shards."""
if ',' in file_pattern:
raise ValueError(
'ShardedFilePatternToGlob does not support multiple file patterns.')
if '@' not in file_pattern:
return file_pattern
path, shards = file_pattern.split('@')
if shards == '*':
return f'{path}-?????-of-*'
return f'{path}-?????-of-{int(shards):05}'
[docs]def ComputeNceAndAuc(probs, targets, mask):
"""Compute normalized cross entropy and AUC of the PR curve for a batch.
Args:
probs: a tensor of shape [batch, time].
targets: a tensor of shape [batch, time], where each element is either 0 or
1 indicating wrong or correct.
mask: a tensor of shape [batch, time], a mask for hyp sequence.
Returns:
nce: a tensor of shape [1], the normalized cross entropy value.
auc: a tensor of shape [1], the AUC value.
"""
def LogWithClip(tensor, clip_value_min=1e-8):
"""Clip all elements of a tensor to a minimum before taking log."""
return tf.math.log(tf.clip_by_value(tensor, clip_value_min, 1.0))
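  # The normalized cross entropy computed below is
  #   nce = (H(wcr) - mean_BCE(probs, targets)) / H(wcr),
  # where wcr is the fraction of correct (target == 1) tokens under the mask
  # and mean_BCE is the mean binary cross entropy over non-padded positions:
  # nce == 1 when probs predict targets perfectly, and 0 when they are no more
  # informative than the constant batch-level correct ratio.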
bce = -targets * LogWithClip(probs) - (1 - targets) * LogWithClip(1 - probs)
num_cor = tf.reduce_sum(targets * mask)
num_tokens = tf.reduce_sum(mask)
wcr = num_cor / num_tokens
entropy = -wcr * LogWithClip(wcr) - (1 - wcr) * LogWithClip(1 - wcr)
avg_conditional_entropy = tf.reduce_mean(tf.boolean_mask(bce, mask))
nce = (entropy - avg_conditional_entropy) / entropy
auc = tf.metrics.auc(targets, probs, mask, curve='PR')[1]
return nce, auc
[docs]def GatherTensorValuesBySeqIndices(tensor, class_indices, keepdims=False):
"""Gather values from a 3d tensor according to sequences of indices.
Args:
tensor: a 3d tensor of [dim0, dim1, num_class], e.g. output from softmax.
class_indices: a 2d tensor of [dim0, dim1], where the second dim is a
sequence of class indices between 0 and num_class - 1, inclusive.
keepdims: bool, expand the last dimension of the returned tensor if True.
Returns:
A tensor ret of [dim0, dim1], where
ret[b, t] = tensor[b, t, indices[b, t]].
If keepdims is True, then ret has shape [dim0, dim1, 1].
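  For example, with tensor = [[[0.1, 0.7, 0.2], [0.3, 0.3, 0.4]]] and
  class_indices = [[1, 2]], the result is [[0.7, 0.4]].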
"""
tensor = HasRank(tensor, 3)
class_indices = HasRank(class_indices, 2)
tensor = HasShape(tensor, GetShape(class_indices), 2)
dim0 = GetShape(class_indices)[0]
dim1 = GetShape(class_indices)[1]
dim0_indices = tf.tile(tf.expand_dims(tf.range(dim0), axis=-1), [1, dim1])
dim1_indices = tf.tile(tf.expand_dims(tf.range(dim1), axis=0), [dim0, 1])
gather_indices = tf.stack([
tf.cast(dim0_indices, dtype=class_indices.dtype),
tf.cast(dim1_indices, dtype=class_indices.dtype), class_indices
],
axis=-1)
ret = tf.gather_nd(tensor, gather_indices)
if keepdims:
ret = tf.expand_dims(ret, axis=-1)
return ret
[docs]def GetSoftmaxProbsBySeqIndices(logits, indices, keepdims=False):
"""Get softmax probabilities from index sequences given logits sequences.
Args:
logits: a tensor of [batch, time, num_class] or [time, batch, num_class].
indices: a tensor of [batch, time] or [time, batch].
keepdims: bool, expand the last dimension of the returned tensor if True.
Returns:
a tensor of [batch, time] or [time, batch] for the corresponding softmax
probabilities. If keepdims is True, returned tensor has a third dimension
of size 1.
"""
probs = tf.nn.softmax(logits)
return GatherTensorValuesBySeqIndices(probs, indices, keepdims)
[docs]def DivideNoNan(x, y):
"""Equivalent to tf.math.divide_no_nan but supports bfloat16."""
safe_y = tf.where(tf.equal(y, 0.), tf.ones_like(y), y)
return tf.where(tf.equal(y, 0.0), tf.zeros_like(x), x / safe_y)
[docs]def SequencePaddings(seqlen, maxlen=None):
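  # Returns paddings (1.0 at padded positions) for the given sequence lengths,
  # e.g. seqlen=[2] with maxlen=4 yields [[0., 0., 1., 1.]].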
mask = tf.sequence_mask(seqlen, maxlen, dtype=tf.float32)
return 1 - mask
[docs]def AppendDims(x, ndims):
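  # Appends `ndims` trailing singleton dimensions, e.g. a [batch, time] tensor
  # with ndims=2 becomes [batch, time, 1, 1].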
if ndims == 1:
return tf.expand_dims(x, -1)
else:
return tf.reshape(x, tf.concat([tf.shape(x), [1] * ndims], axis=0))
[docs]def MaybeSoftCapLogits(x, cap=0.0):
"""Caps logits x to be within a certain range.
Args:
x: A float tensor, the logit values to be capped.
cap: a float, the limit to cap x within. If cap <= 0.0, x is not capped.
Returns:
logits after capping.
"""
if cap <= 0.0:
return x
else:
return cap * tf.math.tanh(x / cap)
[docs]def GetTpuEmbeddingGraphCollection():
"""Return the graph collection that stores the TpuEmbeddingCollection."""
tpu_emb_graph_collection = tf.get_collection_ref('__tpu_embedding_collection')
assert len(tpu_emb_graph_collection) <= 1
return tpu_emb_graph_collection
[docs]class AuxLossContext:
"""Context that holds a list of aux-losses.
By default it is non-reentrant, but can be specified as reentrant explicitly
when creating an inner context.
"""
_global_stack = []
[docs] @classmethod
def Current(cls):
"""Returns current context or None."""
if cls._global_stack:
return cls._global_stack[-1]
else:
return None
def __init__(self, reentrant=False):
self.aux_loss_tensors = []
self._reentrant = reentrant
[docs] def AddLoss(self, loss):
self.aux_loss_tensors.append(loss)
@property
def aux_losses(self):
return self.aux_loss_tensors
def __enter__(self):
if not self._reentrant:
assert not self._global_stack, 'no re-entry'
self._global_stack.append(self)
return self
def __exit__(self, *args):
self._global_stack.pop()
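# Illustrative usage (the loss tensors below are placeholders): inner layers
# add losses through the current context, and the caller collects them.
#
#   with AuxLossContext() as ctx:
#     ...  # inner layers call AuxLossContext.Current().AddLoss(loss)
#     aux_losses = ctx.aux_losses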
[docs]def GetTrainableVariables(scope, bprop_variable_filter,
bprop_variable_exclusion, vmap):
"""Returns trainable vars.
Args:
scope: A Python str.
bprop_variable_filter: see BaseTask.Params().bprop_variable_filter.
bprop_variable_exclusion: see BaseTask.Params().bprop_variable_exclusion.
vmap: A NestedMap of var_path(str) -> tf Variable.
Returns:
A filtered NestedMap of var_path(str) -> trainable tf Variable.
"""
pos = re.compile(bprop_variable_filter) if bprop_variable_filter else None
neg = re.compile(
bprop_variable_exclusion) if bprop_variable_exclusion else None
def VariableFilter(v):
"""Returns True if variable v should be optimized by this learner."""
if not v.trainable:
return False
if pos and not pos.search(v.name):
tf.logging.info('%s: disabled by bprop_variable_filter: %s', scope,
v.name)
return False
if neg and neg.search(v.name):
tf.logging.info('%s: disabled by bprop_variable_exclusion: %s', scope,
v.name)
return False
return True
return vmap.Filter(VariableFilter)
[docs]def BlockDiagonalMatmul(inputs, w, input_num_blocks, mix_kernel=None):
"""Block diagonal matmul.
With mix, the results from the blocked matmul are (linearly) mixed with
trainable weights in mix_kernel.
Args:
inputs: a tf.Tensor with the last dimension being the dimension for matmul.
w: an order-3 tf.Tensor of shape (input_num_blocks, input_dim //
input_num_blocks, output_dim // input_num_blocks)
input_num_blocks: an int specifying number of blocks for the input.
mix_kernel: an (optional) order-2 tf.Tensor of shape (input_num_blocks,
input_num_blocks).
Returns:
A tf.Tensor of shape inputs.shape[:-1] + [input_num_blocks * w.shape[-1]],
i.e. the last dimension becomes output_dim.
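  For example, inputs of shape [batch, 8] with w of shape [4, 2, 2] and
  input_num_blocks=4 yield an output of shape [batch, 8]; without mix_kernel,
  output[:, 2*i:2*(i+1)] == inputs[:, 2*i:2*(i+1)] @ w[i].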
"""
input_splitted = tf.split(inputs, input_num_blocks, axis=-1)
output_splitted = []
for i, input_i in enumerate(input_splitted):
output_splitted.append(tf.matmul(input_i, w[i, :, :]))
if mix_kernel is not None:
output_mixed = [0.0] * input_num_blocks
for i in range(input_num_blocks):
for j in range(input_num_blocks):
output_mixed[i] += mix_kernel[i, j] * output_splitted[j]
output_splitted = output_mixed
return tf.concat(output_splitted, axis=-1)
[docs]def BlockDiagonalProjectLastDim(inputs,
weight,
input_dim,
output_dim,
num_blocks=1,
mix_kernel=None,
use_einsum=True):
"""Block diagonal linear projection on the last dim with mix.
This is a TPU-efficient implementation that uses Einsum for the compute to
avoid reshaping the inputs to a rank-2 tensor.
Args:
inputs: An input Tensor, the last dimension of which is input_dim.
weight: A weight tensor with shape [num_blocks, input_dim // num_blocks,
output_dim // num_blocks].
input_dim: An integer or a symbolic dim, the last dimension of the inputs.
output_dim: An integer or a symbolic dim, the last dimension of the outputs.
num_blocks: An integer or a symbolic dim, the number of blocks.
mix_kernel: an (optional) order-2 tf.Tensor of shape [num_blocks,
num_blocks].
use_einsum: use tf.einsum for calculation (True), or use tf.matmul (False)
Returns:
An output Tensor of the same rank as inputs, the last dimension is
output_dim.
"""
input_dim = int(
symbolic.ToStatic(input_dim) if symbolic.IsExpr(input_dim) else input_dim)
output_dim = int(
symbolic.ToStatic(output_dim) if symbolic.IsExpr(output_dim
) else output_dim)
# Assert input_dim and output_dim
inputs = with_dependencies([assert_equal(GetShape(inputs)[-1], input_dim)],
inputs)
weight = with_dependencies([
assert_equal(GetShape(weight)[0], num_blocks),
assert_equal(GetShape(weight)[1], input_dim // num_blocks),
assert_equal(GetShape(weight)[-1], output_dim // num_blocks)
], weight)
if (inputs.shape is not None and inputs.shape.rank is not None and
inputs.shape.rank == 2):
# Shortcut for 2D inputs: just use tf.matmul, no need for Einsum.
outputs = BlockDiagonalMatmul(inputs, weight, num_blocks, mix_kernel)
elif (use_einsum and use_tpu() and inputs.shape is not None and
inputs.shape.rank is not None and inputs.shape.rank < 26):
# Avoids reshape if feasible and uses Einsum.
# This is equivalent to:
# outputs = tf.einsum('...y,yz->...z', inputs, weight)
# Unfortunately ... in einsum() leads to extra HBM usage.
s = ''.join([chr(x) for x in range(97, 123)]) # abc...xyz
r = inputs.shape.rank
input_splitted = tf.split(inputs, num_blocks, axis=-1)
output_splitted = []
for i, input_i in enumerate(input_splitted):
output_splitted.append(
tf.einsum('{0}y,yz->{0}z'.format(s[:r - 1]), input_i,
weight[i, :, :]))
if mix_kernel is not None:
output_mixed = [0.0] * num_blocks
for i in range(num_blocks):
for j in range(num_blocks):
output_mixed[i] += mix_kernel[i, j] * output_splitted[j]
output_splitted = output_mixed
outputs = tf.concat(output_splitted, axis=-1)
else:
# not use_einsum or not use_tpu() or inputs.shape.rank >= 26
outputs = BlockDiagonalMatmul(
tf.reshape(inputs, ToStaticShape([-1, input_dim])), weight, num_blocks,
mix_kernel)
outputs = tf.reshape(
outputs,
tf.concat([
tf.cast(GetShape(inputs)[:-1], tf.int32),
ToStaticShape([output_dim])
],
axis=0))
return outputs
[docs]def GetProcessedCheckpoints(runner_dir):
"""Returns the list of checkpoints previously processed by this runner."""
# Set up (or reload) a file storing the list of previously processed
# checkpoints. This caching allows jobs to run on VMs which may be
# interrupted without duplicating work.
processed_ckpts_path = os.path.join(runner_dir, 'processed_ckpts.txt')
if not tf.io.gfile.exists(processed_ckpts_path):
with tf.io.gfile.GFile(processed_ckpts_path, 'w') as f:
f.write('')
with tf.io.gfile.GFile(processed_ckpts_path, 'r') as f:
processed_ckpts = list(line.strip() for line in f.readlines())
return processed_ckpts
[docs]def UpdateProcessedCheckpoints(runner_dir, ckpt_path):
"""Denotes 'ckpt_path' as having been processed by this runner."""
processed_ckpts_path = os.path.join(runner_dir, 'processed_ckpts.txt')
# Some file systems don't support append operations, so we rewrite whole
# file to append the latest checkpoint.
processed_ckpts = GetProcessedCheckpoints(runner_dir)
processed_ckpts.append(ckpt_path)
with tf.io.gfile.GFile(processed_ckpts_path, 'w') as f:
f.write('\n'.join(processed_ckpts) + '\n')
[docs]def MergeDictsWithValueCheck(dict1, dict2):
"""Merge two dictionaries and assert keys in both have identical values."""
for key in set(dict1) & set(dict2):
# The values must be the same object
if dict1[key] is not dict2[key]:
raise RuntimeError(f'The same key {key} corresponds to different values '
f'in the dictionaries: {dict1[key]} vs {dict2[key]}')
return dict1 | dict2
[docs]def MergeDuplicateIds(ids, paddings, extra_tensors=None):
"""Merge consecutive duplicated ids.
Given ids = [4, 4, 5, 6, 6, 5, 0, 0] and paddings = [0, 0, 0, 0, 0, 0, 1, 1],
this function returns ret_ids = [4, 5, 6, 5, 0, 0, 0, 0] and ret_paddings =
[0, 0, 0, 0, 1, 1, 1, 1] by merging consecutive duplicated ids.
Args:
ids: A non-negative tensor of shape [batch, time].
paddings: A padding tensor of shape [batch, time], with "0" for non-padded
and "1" for padded positions.
extra_tensors: A `.NestedMap` containing tensors that need to be
deduplicated according to ids, each tensor at least has two dimensions.
Returns:
ret_ids: same as ids.
ret_paddings: same as paddings.
ret_tensors: same as extra_tensors.
"""
ids = with_dependencies([assert_greater_equal(ids, 0)], ids)
prev_ids = tf.pad(ids, [[0, 0], [1, 0]], constant_values=-1)[:, :-1]
keep = tf.cast(tf.math.not_equal(ids, prev_ids), tf.int32) * tf.cast(
1 - paddings, tf.int32)
b, t = GetShape(ids)
# Generate descend_keep in descending order for each row and set elements in
# the matrix to 0 if they are duplicated ids.
descend_keep = keep * tf.range(t, 0, -1, dtype=tf.int32)
# Get the indices of non-duplicated ids along the time axis.
sorted_indices = tf.argsort(descend_keep, stable=True, direction='DESCENDING')
# Get the batch indices.
batch_indices = tf.tile(tf.expand_dims(tf.range(b), -1), [1, t])
# Stack them to get 2-d indices.
ids_indices = tf.stack([batch_indices, sorted_indices], axis=2)
seq_mask = tf.sequence_mask(tf.reduce_sum(keep, axis=-1), t, paddings.dtype)
ret_paddings = 1. - seq_mask
ret_ids = tf.gather_nd(ids, ids_indices) * tf.cast(seq_mask, ids.dtype)
ret_tensors = NestedMap()
if extra_tensors:
for key, tensor in extra_tensors.items():
tensor_mask = ExpandTo(seq_mask, GetRank(tensor))
ret_tensors[key] = tf.gather_nd(tensor, ids_indices) * tf.cast(
tensor_mask, tensor.dtype)
return ret_ids, ret_paddings, ret_tensors
[docs]def DecodeProtoField(serialized_protos, message_type, field_name, output_type):
"""Decodes a non-repeated field in a proto.
Args:
serialized_protos: A string Tensor of shape [batch].
message_type: Name of the proto message type. Since tf.io.decode_proto() is
called with the default descriptor_source='local://', the C++ (not Python!)
proto definition(s) must be linked to the binary. You can link in a proto
descriptor by creating a cc_library target with alwayslink=1.
field_name: Name of the field.
output_type: A DType for the output.
Returns:
A Tensor of shape [batch].
"""
_, [output] = tf.io.decode_proto(serialized_protos, message_type,
[field_name], [output_type])
return tf.squeeze(output, -1)
[docs]def DecodeRepeatedProtoField(serialized_protos, message_type, field_name,
output_type):
"""Decodes a repeated field in a proto.
Args:
serialized_protos: A string Tensor of shape [batch].
message_type: Name of the proto message type. Since tf.io.decode_proto() is
called with the default descriptor_source='local://', the C++ (not Python!)
proto definition(s) must be linked to the binary. You can link in a proto
descriptor by creating a cc_library target with alwayslink=1.
field_name: Name of the field.
output_type: A DType for the output.
Returns:
A Tensor of shape [batch, field_name_size].
"""
[output] = tf.io.decode_proto(serialized_protos, message_type, [field_name],
[output_type]).values
return output