lingvo.core.gshard_utils module
Utilities for applying xla-sharding to a model.
- lingvo.core.gshard_utils.Split(x, split_dimension, num_devices, use_sharding_op=True, input_shape=None)[source]
Wrapper for xla_sharding.split.
- Parameters
x – Tensor to annotate.
split_dimension – xla_sharding.split arg.
num_devices – xla_sharding.split arg.
use_sharding_op – If true, adds a sharding op to set the sharding: tensor = gen_xla_ops.xla_sharding(tensor). Per hyouklee@, use_sharding_op=False "adds the sharding attribute to the op itself. The outcome is that the information could be lost by TF graph transformations. Also, directly attaching the sharding annotation to the op caused some compilation failures in the past (due to incompatible shardings), so the plan is to make use_sharding_op the default." "The only case I would set it to False today is when annotating weights. Weight annotation does some special handling, so there may be some changes needed in that logic if we add a separate sharding op."
input_shape – The shape of the original tensor.
- Returns
Tensor conditionally annotated with sharding.
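A minimal usage sketch (not from the library docs; assumes 8 devices and that the graph is compiled with XLA, since sharding annotations have no effect otherwise)::

    import tensorflow as tf
    from lingvo.core import gshard_utils

    x = tf.ones([16, 1024])  # [batch, model_dim]
    # Split the batch (dimension 0) across 8 devices. With the default
    # use_sharding_op=True, a separate sharding op is inserted so the
    # annotation survives TF graph transformations.
    x = gshard_utils.Split(x, split_dimension=0, num_devices=8)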
- lingvo.core.gshard_utils.Replicate(x, use_sharding_op=True)[source]
Wrapper for xla_sharding.replicate.
- lingvo.core.gshard_utils.GetMeshSplitSharding(device_mesh, tensor_split_dims_mapping)[source]
Wrapper for xla_sharding.mesh_split_sharding().
- lingvo.core.gshard_utils.MeshSplit(x, device_mesh, tensor_split_dims_mapping, use_sharding_op=True, unspecified_dims=None)[source]
Wrapper for xla_sharding.mesh_split().
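A hedged sketch of MeshSplit on a 4x2 logical mesh (the mesh values and the convention that -1 in tensor_split_dims_mapping leaves a dimension unsharded are assumptions, not spelled out in this page)::

    import numpy as np
    import tensorflow as tf
    from lingvo.core import gshard_utils

    device_mesh = np.arange(8).reshape([4, 2])  # mesh dims: 4-way x 2-way
    x = tf.ones([16, 1024])  # [batch, model_dim]
    # Shard the batch over mesh dim 0 and model_dim over mesh dim 1.
    x = gshard_utils.MeshSplit(x, device_mesh,
                               tensor_split_dims_mapping=[0, 1])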
- lingvo.core.gshard_utils.MeshSplitDimPrefixContext(prefix_mesh_dim)[source]
Adds a prefix mesh dim for tensor_split_dims_mapping in MeshSplit.
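A hedged sketch of the prefix context (per the description above, MeshSplit calls inside the context see prefix_mesh_dim prepended to their tensor_split_dims_mapping; the stacked-layer reading of the leading dim is illustrative)::

    import numpy as np
    import tensorflow as tf
    from lingvo.core import gshard_utils

    device_mesh = np.arange(8).reshape([4, 2])
    with gshard_utils.MeshSplitDimPrefixContext(0):
      # Effective mapping is [0, -1, 1]: the leading (stacked) dim maps
      # to mesh dim 0 without repeating it at every call site.
      z = gshard_utils.MeshSplit(tf.ones([4, 16, 1024]), device_mesh,
                                 tensor_split_dims_mapping=[-1, 1])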
- lingvo.core.gshard_utils.ManualMeshDimContext(mesh_dim)[source]
Adds a context where mesh_dim is used for manual sharding.
- lingvo.core.gshard_utils.ZigzagOrderOnDeviceMesh(device_mesh, zigzag_mesh_dim)[source]
Permutes device_mesh to form zigzag order along zigzag_mesh_dim.
- lingvo.core.gshard_utils.GetNonPod2dMesh(device_mesh_shape, physical_mesh_shape)[source]
Returns a 2D device mesh on slices smaller than a pod.
- lingvo.core.gshard_utils.ReshapeDim(x, dim, dim_reshape_segments=None)[source]
Reshapes tensor x according to dim_reshape_segments.
- Parameters
x – An input Tensor of shape […, x.shape[dim], …].
dim – The dim that needs to be reshaped.
dim_reshape_segments – The leading dim size of the reshaped dims; must evenly divide x.shape[dim].
- Returns
A Tensor of shape […, dim_reshape_segments, x.shape[dim] // dim_reshape_segments, …].
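A worked example of the shape contract above (assumes dim_reshape_segments evenly divides x.shape[dim])::

    import tensorflow as tf
    from lingvo.core import gshard_utils

    x = tf.ones([8, 1024, 16])
    # dim=1, dim_reshape_segments=4:
    # [8, 1024, 16] -> [8, 4, 256, 16], since 1024 // 4 == 256.
    y = gshard_utils.ReshapeDim(x, dim=1, dim_reshape_segments=4)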
- class lingvo.core.gshard_utils.TensorShardingSpec(split_dims_mapping: Optional[List[int]] = None, device_mesh: Optional[ndarray] = None, uneven_padding: Optional[List[int]] = None)[source]
Bases: object
Represents a sharding annotation for GShard/XLA. A usage sketch follows the member list below.
- classmethod FromFullShape(full_shape: Sequence[int], split_dims_mapping: List[int], device_mesh: ndarray)[source]
Creates tiled sharding spec with uneven padding computed from shape.
- ShardShape(full_shape: Sequence[int]) → Sequence[int] [source]
Returns the shape after applying this sharding to full_shape.
- ManualToAutoPartitioning(tensor: Tensor) → Tensor [source]
Converts manually sharded tensor to full-size for auto partitioning.
- AutoToManualPartitioning(tensor: Tensor) → Tensor [source]
Converts full-size tensor (auto partitioning) to manually sharded.
- classmethod FromXlaOpSharding(op_sharding_proto: OpSharding) → TensorShardingSpec [source]
Parses from an XLA OpSharding proto.
- AddLeadingDims(num_dims: int = 1) → TensorShardingSpec [source]
Returns a copy of self with num_dims unsharded leading dimensions added.
- RemoveLeadingDims(num_dims: int = 1) → TensorShardingSpec [source]
Returns a copy of self with the first num_dims dimensions removed.
- RemoveDim(dim) → TensorShardingSpec [source]
Returns a copy of self with dimension ‘dim’ removed.
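A hedged usage sketch for the class (the 4x2 mesh and the convention that -1 in split_dims_mapping means replicated are assumptions)::

    import numpy as np
    from lingvo.core import gshard_utils

    device_mesh = np.arange(8).reshape([4, 2])
    # Shard dim 0 over mesh dim 0 (4-way); leave dim 1 replicated.
    spec = gshard_utils.TensorShardingSpec.FromFullShape(
        full_shape=[16, 1024], split_dims_mapping=[0, -1],
        device_mesh=device_mesh)
    # Per-shard shape: dim 0 divided by the 4-way mesh dim.
    print(spec.ShardShape([16, 1024]))  # expected: [4, 1024]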
- lingvo.core.gshard_utils.GetVarSharding(var: VariableV1) → TensorShardingSpec [source]
Returns the sharding directly attached to a variable.
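A hedged sketch of reading a variable's sharding (lingvo builds TF1-style graphs, so a graph context is used here; the behavior for a variable with no attached sharding is an assumption)::

    import tensorflow as tf
    from lingvo.core import gshard_utils

    with tf.Graph().as_default():
      w = tf.compat.v1.get_variable('w', shape=[1024, 4096])
      # Read back whatever sharding was attached to 'w' (e.g. via weight
      # annotation) and derive its per-shard shape.
      spec = gshard_utils.GetVarSharding(w)
      shard_shape = spec.ShardShape(w.shape.as_list())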