
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for applying xla-sharding to a model."""

import contextlib
from typing import Dict, List, Optional, Sequence

from lingvo import compat as tf
from lingvo.core import py_utils_flags
from lingvo.core import thread_local_utils
import numpy as np

import sentencepiece as sentencepiece_processor
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.xla import xla_data_pb2
# pylint: disable=g-import-not-at-top
try:
  from tensorflow.python.compiler.xla.experimental import xla_sharding
except ImportError:
  # OSS backward compatibility, can be removed when TF is updated.
  from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding  # pytype: disable=import-error
# pylint: enable=g-direct-tensorflow-import

ThreadLocalStack = thread_local_utils.ThreadLocalStack


def Split(x,
          split_dimension,
          num_devices,
          use_sharding_op=True,
          input_shape=None):
  """Wrapper for xla_sharding.split.

  Args:
    x: Tensor to annotate.
    split_dimension: xla_sharding.split arg.
    num_devices: xla_sharding.split arg.
    use_sharding_op: If true, adds a sharding op to set the sharding:
      tensor = gen_xla_ops.xla_sharding(tensor)

      hyouklee@ on use_sharding_op=False:
        "It adds the sharding attribute to the op itself. The outcome is that
        the information could be lost by TF graph transformations. Also,
        directly attaching the sharding annotation to the op caused some
        compilation failures in the past (due to incompatible shardings), so
        the plan is to make use_sharding_op the default."

        "The only case I would set it to False today is when annotating
        weights. Weight annotation does some special handling, so there may be
        some changes needed in that logic if we add a separate sharding op."
    input_shape: The shape of the original tensor.

  Returns:
    Tensor conditionally annotated with sharding.
  """
  if (not py_utils_flags.use_tpu() or num_devices is None or
      not num_devices > 1):
    return x
  return xla_sharding.split(
      x,
      split_dimension,
      num_devices,
      input_shape=input_shape,
      use_sharding_op=use_sharding_op,
  )
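
# Illustrative usage sketch (not part of the original module): shard the batch
# dimension of an activation across 8 TPU cores. Outside a TPU context, or when
# num_devices is None or not greater than 1, Split() returns `x` unchanged.
#
#   x = tf.ones([16, 1024])
#   x = Split(x, split_dimension=0, num_devices=8)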

def Replicate(x, use_sharding_op=True):
  """Wrapper of xla_sharding.replicate."""
  if not py_utils_flags.use_tpu():
    return x
  return xla_sharding.replicate(x, use_sharding_op=use_sharding_op)

_MESH_SPLIT_DIM_PREFIXES = ThreadLocalStack()
_MANUAL_MESH_DIMS = ThreadLocalStack()

def GetMeshSplitSharding(device_mesh, tensor_split_dims_mapping):
  """Wrapper of xla_sharding.mesh_split_sharding()."""
  # Apply the prefix in the context.
  tensor_split_dims_mapping = (
      _MESH_SPLIT_DIM_PREFIXES.stack + tensor_split_dims_mapping)
  if _MANUAL_MESH_DIMS.stack:
    return xla_sharding.mesh_split_sharding(
        device_mesh,
        tensor_split_dims_mapping,
        manual_mesh_dims=_MANUAL_MESH_DIMS.stack)
  # Do not include manual_mesh_dims to support legacy TF versions.
  return xla_sharding.mesh_split_sharding(device_mesh,
                                          tensor_split_dims_mapping)

def MeshSplit(x,
              device_mesh,
              tensor_split_dims_mapping,
              use_sharding_op=True,
              unspecified_dims=None):
  """Wrapper of xla_sharding.mesh_split()."""
  if (not py_utils_flags.use_tpu() or tensor_split_dims_mapping is None or
      device_mesh is None or device_mesh.size <= 1):
    return x
  # Apply the prefix in the context.
  tensor_split_dims_mapping = (
      _MESH_SPLIT_DIM_PREFIXES.stack + tensor_split_dims_mapping)
  num_tiles = np.prod(
      [device_mesh.shape[i] for i in tensor_split_dims_mapping if i >= 0])
  if num_tiles <= 1:
    return x
  if _MANUAL_MESH_DIMS.stack or unspecified_dims:
    return xla_sharding.mesh_split(
        x,
        device_mesh,
        tensor_split_dims_mapping,
        use_sharding_op=use_sharding_op,
        manual_mesh_dims=_MANUAL_MESH_DIMS.stack,
        unspecified_dims=unspecified_dims)
  # Do not include manual_mesh_dims or unspecified_dims to support legacy TF
  # versions.
  return xla_sharding.mesh_split(
      x,
      device_mesh,
      tensor_split_dims_mapping,
      use_sharding_op=use_sharding_op)
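
# Illustrative usage sketch (mesh shape and mapping are made up): shard a
# [batch, model] tensor over a 2x4 device mesh, splitting the batch dimension
# along mesh axis 0 and the model dimension along mesh axis 1; a mapping entry
# of -1 would leave that tensor dimension unsharded. Like Split(), MeshSplit()
# is a no-op when not running on TPU.
#
#   device_mesh = np.arange(8).reshape([2, 4])
#   y = MeshSplit(y, device_mesh, tensor_split_dims_mapping=[0, 1])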

@contextlib.contextmanager
def MeshSplitDimPrefixContext(prefix_mesh_dim):
  """Adds a prefix mesh dim for tensor_split_dims_mapping in MeshSplit."""
  if prefix_mesh_dim is not None:
    _MESH_SPLIT_DIM_PREFIXES.stack.append(prefix_mesh_dim)
  try:
    yield
  finally:
    if prefix_mesh_dim is not None:
      _MESH_SPLIT_DIM_PREFIXES.stack.pop()
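
# Illustrative usage sketch (shapes and mesh are assumed): inside the context,
# MeshSplit() and GetMeshSplitSharding() prepend the prefix mesh dim to every
# tensor_split_dims_mapping, so callers can annotate tensors that carry an
# extra leading dimension (e.g. a pipeline-stage dimension) without changing
# their own mappings.
#
#   with MeshSplitDimPrefixContext(0):
#     # Effective mapping is [0, 1, -1] for a rank-3 stacked tensor over a
#     # 2-D device mesh.
#     z = MeshSplit(z, device_mesh, tensor_split_dims_mapping=[1, -1])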

def GetMeshSplitDimPrefixContext():
  return _MESH_SPLIT_DIM_PREFIXES.stack

@contextlib.contextmanager
def ManualMeshDimContext(mesh_dim):
  """Adds a context where mesh_dim is used for manual sharding."""
  if mesh_dim is not None:
    _MANUAL_MESH_DIMS.stack.append(mesh_dim)
  try:
    yield
  finally:
    if mesh_dim is not None:
      _MANUAL_MESH_DIMS.stack.pop()

def ZigzagOrderOnDeviceMesh(device_mesh, zigzag_mesh_dim):
  """Permutes device_mesh to form zigzag order along zigzag_mesh_dim."""
  # Where there are no wrap-around links along one edge, we might
  # reduce all-reduce latency along that edge by permuting the device order:
  # instead of
  #   0 - 1 - 2 - 3 - 4 - 5 - 6 - 7
  #   |                           |
  #   +---------------------------+
  # it will be
  #   +-------+-------+-------+
  #   |       |       |       |
  #   0 - 7   1   6   2   5   3 - 4
  #       |       |       |       |
  #       +-------+-------+-------+
  xpose_dims = list(range(len(device_mesh.shape)))
  xpose_dims[0] = zigzag_mesh_dim
  xpose_dims[zigzag_mesh_dim] = 0
  device_mesh = np.transpose(device_mesh, xpose_dims)
  permuted_mesh = np.copy(device_mesh)
  for i in range(device_mesh.shape[0]):
    zigzag_i = i * 2 if i * 2 < device_mesh.shape[0] else (
        device_mesh.shape[0] - i) * 2 - 1
    permuted_mesh[i, ...] = device_mesh[zigzag_i, ...]
  return np.transpose(permuted_mesh, xpose_dims)
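
# Worked example (computed from the code above, not from the original source):
# for a 1-D mesh of 8 devices, the permuted order is [0, 2, 4, 6, 7, 5, 3, 1],
# so consecutive entries (and the closing 1 -> 0 link) are at most two physical
# hops apart and no wrap-around link is needed.
#
#   ZigzagOrderOnDeviceMesh(np.arange(8), zigzag_mesh_dim=0)
#   # -> array([0, 2, 4, 6, 7, 5, 3, 1])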

def GetNonPod2dMesh(device_mesh_shape, physical_mesh_shape):
  """Returns a 2D device mesh on slices smaller than a pod."""
  assert len(device_mesh_shape) == 2
  assert len(physical_mesh_shape) == 3
  if device_mesh_shape[1] != physical_mesh_shape[1] * physical_mesh_shape[2]:
    tf.logging.warning(
        'This only works when device_mesh_shape == [physical_mesh_shape[0], '
        'physical_mesh_shape[1] * physical_mesh_shape[2]]. '
        'E.g., if device_mesh_shape is [32, 16] with physical_mesh_shape '
        '[16, 16, 2], transpose the result of '
        'GetNonPod2dMesh([16, 32], [16, 16, 2]) instead.')
  # Form a ring on the inner mesh dim.
  device_mesh = np.reshape(
      np.arange(0, np.product(device_mesh_shape)), physical_mesh_shape)
  device_mesh = np.transpose(device_mesh, [0, 2, 1])
  device_mesh[:, 1, :] = device_mesh[:, 1, ::-1]
  # Reshape back to the 2D mesh shape.
  device_mesh = np.reshape(device_mesh, device_mesh_shape)
  # Zigzag on the outer mesh dim (e.g. 8), which does not have a wrap-around
  # link either.
  device_mesh = ZigzagOrderOnDeviceMesh(device_mesh, zigzag_mesh_dim=0)
  return device_mesh

def ReshapeDim(x, dim, dim_reshape_segments=None):
  """Reshapes tensor x according to dim_reshape_segments.

  Args:
    x: An input Tensor of shape [..., x.shape[dim], ...].
    dim: The dim that needs to be reshaped.
    dim_reshape_segments: The leading dim size of the reshaped dims.

  Returns:
    A Tensor of shape [..., dim_reshape_segments,
    x.shape[dim] // dim_reshape_segments, ...].
  """
  if dim_reshape_segments is None:
    return x
  assert x.shape[dim] % dim_reshape_segments == 0
  new_shape = list(x.shape[0:dim])
  new_shape.append(dim_reshape_segments)
  new_shape.append(x.shape[dim] // dim_reshape_segments)
  new_shape.extend(d for d in x.shape[dim + 1:])
  return tf.reshape(x, new_shape)
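
# Worked example (shapes are made up): splitting dimension 1 of a [4, 12, 8]
# tensor with dim_reshape_segments=3 yields shape [4, 3, 4, 8].
#
#   x = tf.zeros([4, 12, 8])
#   y = ReshapeDim(x, dim=1, dim_reshape_segments=3)  # y.shape == [4, 3, 4, 8]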

class TensorShardingSpec:
  """Represents a sharding annotation for GShard/XLA."""

  def __init__(self,
               split_dims_mapping: Optional[List[int]] = None,
               device_mesh: Optional[np.ndarray] = None,
               uneven_padding: Optional[List[int]] = None):
    """Creates a sharding specification.

    Args:
      split_dims_mapping: a list of integers that maps each tensor axis to the
        device mesh axis along which it is sharded. Its length is the tensor
        rank, and split_dims_mapping[i] is the device mesh axis for tensor
        dimension i. Use -1 for tensor dimensions that are not sharded. If the
        list is set to None, the sharding will be treated as replicated.
      device_mesh: a numpy.ndarray describing the topology of the device mesh,
        where each element is the ID of the device in the topology. Not needed
        for replicated sharding, in which case it can be set to None.
      uneven_padding: amount of padding applied to the right side of each
        tensor dimension due to uneven partitioning of the shape in SPMD.
    """
    self._split_dims_mapping: Optional[List[int]] = split_dims_mapping
    self._device_mesh: Optional[np.ndarray] = device_mesh
    self._uneven_padding = uneven_padding

  @classmethod
  def FromFullShape(cls, full_shape: Sequence[int],
                    split_dims_mapping: List[int], device_mesh: np.ndarray):
    """Creates tiled sharding spec with uneven padding computed from shape."""
    uneven_padding = [0] * len(split_dims_mapping)
    for i in range(len(split_dims_mapping)):
      if split_dims_mapping[i] >= 0:
        partitions = device_mesh.shape[split_dims_mapping[i]]
        shard_size = (full_shape[i] + partitions - 1) // partitions
        uneven_padding[i] = shard_size * partitions - full_shape[i]
    return TensorShardingSpec(split_dims_mapping, device_mesh, uneven_padding)
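
  # Worked example (hand-computed, hypothetical shapes): with a full shape of
  # [10] split 4 ways, each shard holds ceil(10 / 4) = 3 rows, so
  # 4 * 3 - 10 = 2 rows of padding are recorded in uneven_padding.
  #
  #   spec = TensorShardingSpec.FromFullShape(
  #       [10], split_dims_mapping=[0], device_mesh=np.arange(4))
  #   # spec.uneven_padding == [2]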

  def ApplyToTensor(self,
                    tensor: tf.Tensor,
                    use_sharding_op: bool = True) -> tf.Tensor:
    if self.is_replicated:
      return xla_sharding.replicate(tensor, use_sharding_op=use_sharding_op)
    return MeshSplit(
        tensor,
        self.device_mesh,
        self.split_dims_mapping,
        use_sharding_op=use_sharding_op)

  def ApplyToVariable(self, variable: tf.Variable) -> tf.Variable:
    if self.is_replicated:
      return xla_sharding.replicate(variable, use_sharding_op=False)
    return MeshSplit(
        variable,
        self.device_mesh,
        self.split_dims_mapping,
        use_sharding_op=False)

  def ShardShape(self, full_shape: Sequence[int]) -> Sequence[int]:
    """Returns the shape after applying this sharding to full_shape."""
    if self.is_replicated:
      return full_shape
    shard_shape = list(full_shape)
    for i in range(len(self._split_dims_mapping)):
      if self._split_dims_mapping[i] >= 0:
        partitions = self._device_mesh.shape[self._split_dims_mapping[i]]
        shard_shape[i] = (full_shape[i] + partitions - 1) // partitions
    return shard_shape
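
  # Worked example (hypothetical 2x4 mesh): for a spec with
  # split_dims_mapping=[1, -1] on a [12, 7] tensor, only dimension 0 is
  # sharded, 4 ways along mesh axis 1, so the per-device shard shape is [3, 7].
  #
  #   spec = TensorShardingSpec([1, -1], np.arange(8).reshape([2, 4]))
  #   # spec.ShardShape([12, 7]) == [3, 7]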

  def ManualToAutoPartitioning(self, tensor: tf.Tensor) -> tf.Tensor:
    """Converts manually sharded tensor to full-size for auto partitioning."""
    full_shape = list(tensor.shape)
    if not self.is_replicated:
      for i in range(len(self._split_dims_mapping)):
        if self._split_dims_mapping[i] >= 0:
          full_shape[i] *= self._device_mesh.shape[self._split_dims_mapping[i]]
          if self._uneven_padding is not None and self._uneven_padding[i] > 0:
            full_shape[i] -= self._uneven_padding[i]
    return xla_sharding.manual_to_auto_spmd_partition(
        tensor,
        self.ToXlaOpSharding().SerializeToString(), full_shape)

  def AutoToManualPartitioning(self, tensor: tf.Tensor) -> tf.Tensor:
    """Converts full-size tensor (auto partitioning) to manually sharded."""
    manual = xla_sharding.auto_to_manual_spmd_partition(
        tensor,
        self.ToXlaOpSharding().SerializeToString())
    xla_sharding.Sharding.manual().apply_to_tensor(manual)
    return manual

  def ToXlaOpSharding(self) -> xla_data_pb2.OpSharding:
    if self.is_replicated:
      return xla_sharding.Sharding.replicate().proto
    dims_mapping = _MESH_SPLIT_DIM_PREFIXES.stack + self.split_dims_mapping
    return xla_sharding.mesh_split_sharding(self.device_mesh,
                                            dims_mapping).proto

  @classmethod
  def FromXlaOpSharding(
      cls, op_sharding_proto: xla_data_pb2.OpSharding) -> 'TensorShardingSpec':
    """Parses from an XLA OpSharding proto."""
    if op_sharding_proto.type == xla_data_pb2.OpSharding.OTHER:
      device_mesh_shape = op_sharding_proto.tile_assignment_dimensions
      device_mesh = np.reshape(
          np.array(op_sharding_proto.tile_assignment_devices),
          device_mesh_shape)
      if op_sharding_proto.replicate_on_last_tile_dim:
        split_dims_mapping = list(range(len(device_mesh_shape) - 1))
      else:
        split_dims_mapping = list(range(len(device_mesh_shape)))
      prefix = _MESH_SPLIT_DIM_PREFIXES.stack
      if prefix:
        assert split_dims_mapping[:len(prefix)] == prefix
      return cls(split_dims_mapping[len(prefix):], device_mesh)
    else:
      return cls.ReplicatedSpec()

  def AddLeadingDims(self, num_dims: int = 1) -> 'TensorShardingSpec':
    if self.is_replicated:
      return self
    new_padding = (None if self._uneven_padding is None else
                   [0] * num_dims + self._uneven_padding)
    return TensorShardingSpec([-1] * num_dims + self._split_dims_mapping,
                              self.device_mesh, new_padding)

  def RemoveLeadingDims(self, num_dims: int = 1) -> 'TensorShardingSpec':
    if self.is_replicated:
      return self
    new_padding = (None if self._uneven_padding is None else
                   self._uneven_padding[num_dims:])
    return TensorShardingSpec(self._split_dims_mapping[num_dims:],
                              self.device_mesh, new_padding)

  def RemoveDim(self, dim) -> 'TensorShardingSpec':
    """Returns a copy of self with dimension 'dim' removed."""
    if self.is_replicated:
      return self
    if dim < 0:
      num_dims = len(self._split_dims_mapping)
      dim = num_dims + dim
    assert dim >= 0 and dim < len(self._split_dims_mapping)
    new_padding = (None if self._uneven_padding is None else
                   self._uneven_padding[:dim] + self._uneven_padding[dim + 1:])
    split_dims_mapping = (
        self._split_dims_mapping[:dim] + self._split_dims_mapping[dim + 1:])
    return TensorShardingSpec(split_dims_mapping, self.device_mesh,
                              new_padding)

  @classmethod
  def ReplicatedSpec(cls):
    return TensorShardingSpec()

  @property
  def split_dims_mapping(self) -> Optional[List[int]]:
    return self._split_dims_mapping

  @property
  def device_mesh(self) -> Optional[np.ndarray]:
    return self._device_mesh

  @property
  def is_replicated(self) -> bool:
    if self.device_mesh is None or self.split_dims_mapping is None:
      return True
    for mesh_dim in self.split_dims_mapping:
      if mesh_dim >= 0 and self.device_mesh.shape[mesh_dim] > 1:
        return False
    return True

  @property
  def mesh_dim_to_tensor_dim_mapping(self) -> Dict[int, int]:
    mapping = {}
    if self.is_replicated:
      return mapping
    for i in range(len(self.split_dims_mapping)):
      if self.split_dims_mapping[i] >= 0:
        mapping[self.split_dims_mapping[i]] = i
    return mapping

  @property
  def uneven_padding(self) -> Optional[List[int]]:
    return self._uneven_padding

def GetVarSharding(var: tf.Variable) -> TensorShardingSpec:
  """Returns the sharding directly attached to a variable."""
  sharding = xla_sharding.get_op_sharding(var.op)
  if not sharding:
    return TensorShardingSpec.ReplicatedSpec()

  proto = xla_data_pb2.OpSharding()
  proto.ParseFromString(sharding)
  spec_without_padding = TensorShardingSpec.FromXlaOpSharding(proto)
  # Consider uneven padding.
  return TensorShardingSpec.FromFullShape(
      [int(d) for d in var.shape],
      spec_without_padding.split_dims_mapping,
      spec_without_padding.device_mesh)
_spm_cache = {}

def LoadSpm(model_file):
  """Loads SPM from model_file. Returns SentencePieceProcessor."""
  global _spm_cache
  if model_file in _spm_cache:
    return _spm_cache[model_file]
  else:
    spm = sentencepiece_processor.SentencePieceProcessor()
    spm.Load(model_file)
    _spm_cache[model_file] = spm
    return spm
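
# Illustrative usage sketch (the path is hypothetical): processors are cached
# per model_file, so repeated calls with the same path return the same object.
#
#   spm = LoadSpm('/path/to/spm.model')
#   ids = spm.EncodeAsIds('hello world')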