lingvo.core.moe_layers module
Layers and utilities that facilitate building MOE models.
class lingvo.core.moe_layers.VarLayer(*args, **kwargs)
Bases: lingvo.core.base_layer.BaseLayer
Container for variables.
FProp(theta, *args, **kwargs)
Forward propagation.
The central interface that subclasses should implement. The caller calls FProp with a theta dictionary. E.g.:
```python
foo = InstanceOfASubClassOfFoo(params)
y = foo.FProp(foo.theta, x)
```
The implementation of FProp() computes a function given the theta and the inputs. E.g.:
```python
def FProp(self, theta, *args):
  subs = self.children
  inputs = args[0]
  a0 = subs.linear.FProp(theta.linear, inputs)
  a1 = subs.softmax.FProp(theta.softmax, a0)
  # The same layer applied twice.
  a2 = subs.linear.FProp(theta.linear, a1)
  return a2
```
Parameters
theta – A NestedMap object containing weights’ values of this layer and its children layers.
*args – List args.
**kwargs – Keyword args.
lingvo.core.moe_layers.ShardedWeightParams(shape, init=None, dtype=None, collections=None, tensor_split_dims_mapping=None)
Returns hyperparams for a weight variable with optional XLA sharding.
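To illustrate what a tensor_split_dims_mapping expresses, here is a minimal sketch assuming the common GShard/XLA convention: entry i names the device-mesh dimension along which tensor dimension i is split, and -1 means the dimension is replicated. The helper local_shard_shape is hypothetical and not part of Lingvo; it also assumes every split is even.

```python
def local_shard_shape(shape, device_mesh_shape, tensor_split_dims_mapping):
  """Per-device shape of a tensor under a split-dims mapping (hypothetical helper).

  Assumes tensor_split_dims_mapping[i] is the mesh dimension that splits tensor
  dimension i, or -1 if dimension i is replicated, and that splits are even.
  """
  if len(shape) != len(tensor_split_dims_mapping):
    raise ValueError('One mapping entry is needed per tensor dimension.')
  local = []
  for dim_size, mesh_dim in zip(shape, tensor_split_dims_mapping):
    if mesh_dim == -1:
      local.append(dim_size)  # replicated: every device holds the full dimension
    else:
      num_shards = device_mesh_shape[mesh_dim]
      if dim_size % num_shards:
        raise ValueError('Dimension size must be divisible by the mesh size.')
      local.append(dim_size // num_shards)
  return local


# An [E=8, M=512, H=2048] expert weight on a 4x2 mesh: experts split over mesh
# dimension 0, M replicated, H split over mesh dimension 1.
print(local_shard_shape([8, 512, 2048], [4, 2], [0, -1, 1]))  # [2, 512, 1024]
```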
class lingvo.core.moe_layers.ShardedVarLayer(*args, **kwargs)
Bases: lingvo.core.moe_layers.VarLayer
Container for variables whose values are sharded across different devices.
FProp(theta, *args, **kwargs)
Forward propagation.
class lingvo.core.moe_layers.StateLayer(*args, **kwargs)
Bases: lingvo.core.base_layer.BaseLayer
Container for recurrent state for incremental decoding.
It has two operation modes.
During training, it does nothing. It expects FProp(x, t) to be called with t=None and returns x unchanged.
During decoding, it expects t to be an int32 scalar and x a tensor of shape [batch, 1, ...]. It updates the state x_full[:, t, :] <- x[:, 0, :] and returns x_full, whose shape is [batch, time, ...].
The state is stored as the theta.state attribute.
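The two modes above can be mimicked on NumPy arrays. This is a sketch, not StateLayer's code: state_layer_fprop is a hypothetical name, the state is passed explicitly instead of through theta, and the one-hot select is just one static-shape way to write x_full[:, t, :] <- x[:, 0, :]; the real layer may implement the update differently.

```python
import numpy as np


def state_layer_fprop(x, x_full=None, t=None):
  """Mimics the two StateLayer modes on NumPy arrays (illustrative sketch).

  Training (t is None): returns x unchanged.
  Decoding: x is [batch, 1, ...] and x_full is [batch, time, ...]; writes
  x[:, 0] into x_full[:, t] and returns the updated copy of x_full.
  """
  if t is None:
    return x
  time = x_full.shape[1]
  # One-hot mask over the time axis, shaped [1, time, 1, ...] for broadcasting;
  # selecting with a mask keeps every array shape fixed across steps.
  mask = (np.arange(time) == t).reshape((1, time) + (1,) * (x_full.ndim - 2))
  # Where the mask is set, take x (broadcast along time); elsewhere keep x_full.
  return np.where(mask, x, x_full)


batch, time, dim = 2, 4, 3
x_full = np.zeros((batch, time, dim))
for t in range(time):
  step = np.full((batch, 1, dim), float(t + 1))  # stand-in for one decoder output
  x_full = state_layer_fprop(step, x_full, t)
print(x_full[0, :, 0])  # [1. 2. 3. 4.]
```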
To construct the initial state, call the InitState classmethod with the root layer. InitState() traverses the root layer's children recursively, initializes the internal state of each StateLayer instance, and returns a nested tuple of states.
For incremental iteration, the static methods work as follows:
```python
dec = builder.DecoderLayerStack(...).Instantiate()
state0 = StateLayer.InitState(dec, shape=[tgt_batch, max_len])
theta0 = StateLayer.UpdateTheta(dec, dec.theta, state0, t=0)
# (FProp in nested StateLayer now has access to 'state0' and 't')
dec.FProp(theta0, ...)
# FProp will modify theta0 in-place
state1 = state0.copy()
state1 = StateLayer.UpdateState(dec, theta0, state1)
```
_use_flat_beam_search = False
NewState(shape)
Returns the initial state.
Parameters
shape – [batch, time] for beam_search_tpu_helper or [batch, beam, time] for flat_beam_search.
Returns
A zero-initialized state tensor with shape [batch, time, …] for beam_search_tpu_helper or [time, batch, beam, …] for flat_beam_search.
Raises
ValueError – if the length of shape is not 2 or 3.
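The shape rules for NewState can be sketched as follows. new_state is a hypothetical helper, and feature_shape is an assumed parameter standing in for the trailing "…" dimensions, which the docstring leaves unspecified.

```python
import numpy as np


def new_state(shape, feature_shape=(), dtype=np.float32):
  """Zero state following the shape rules documented for NewState (sketch).

  feature_shape stands in for the trailing '...' dimensions, which the
  docstring leaves unspecified.
  """
  if len(shape) == 2:
    batch, time = shape
    lead = (batch, time)  # beam_search_tpu_helper layout
  elif len(shape) == 3:
    batch, beam, time = shape
    lead = (time, batch, beam)  # flat_beam_search layout: time moves to the front
  else:
    raise ValueError('len(shape) must be 2 or 3, got %d' % len(shape))
  return np.zeros(lead + tuple(feature_shape), dtype)


print(new_state([8, 16], [512]).shape)     # (8, 16, 512)
print(new_state([8, 4, 16], [512]).shape)  # (16, 8, 4, 512)
```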
FProp(theta, x)
Forward propagation.
class lingvo.core.moe_layers.OverrideLayer(*args, **kwargs)
Bases: lingvo.core.base_layer.BaseLayer
Allows overriding arbitrary tensors in the graph.
If key is not set in the global context, FProp does nothing. Otherwise, it returns the value associated with key.
To override a tensor during my_layer.FProp:
```python
OverrideLayer.Set(key, value)
out_with_override = my_layer.FProp(...)
OverrideLayer.Clear()
```
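The override pattern can be illustrated with a plain Python class. OverrideSketch is a hypothetical stand-in, not Lingvo's OverrideLayer: it assumes the "global context" is the class-level _OVERRIDE dict listed below and that each layer is constructed with its key directly (how the real layer receives its key is not shown here). The try/finally makes sure Clear() runs even if FProp raises.

```python
class OverrideSketch:
  """Hypothetical stand-in for the OverrideLayer pattern (not Lingvo's code)."""

  _OVERRIDE = {}  # class-level dict shared by all instances: the "global context"

  def __init__(self, key):
    self.key = key

  @classmethod
  def Set(cls, key, value):
    cls._OVERRIDE[key] = value

  @classmethod
  def Clear(cls):
    cls._OVERRIDE.clear()

  def FProp(self, x):
    # Pass x through unless an override is registered under this layer's key.
    return self._OVERRIDE.get(self.key, x)


layer = OverrideSketch('attention_probs')
print(layer.FProp(1.0))  # 1.0
OverrideSketch.Set('attention_probs', 0.0)
try:
  print(layer.FProp(1.0))  # 0.0
finally:
  OverrideSketch.Clear()  # always reset so later calls are unaffected
print(layer.FProp(1.0))  # 1.0
```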
_OVERRIDE = {}
FProp(theta, x)
Forward propagation.
Bases: lingvo.core.base_layer.BaseLayer
Shared weights for embedding lookup and softmax.
Returns the layer params.
Forward propagation.
lingvo.core.moe_layers.Top2GatingOnLogits(inputs, paddings, logits, num_devices, experts_dim, expert_capacity_dim, fprop_dtype, use_xla_sharding=True, second_expert_policy='all', second_expert_threshold=0.0, legacy_mtf_behavior=True, capacity_factor=None)
Computes Top-2 gating for Mixture-of-Experts.
There are two expected usages of this function:
1. Used with xla_sharding. In this case, inputs is a tensor sharded across multiple TPU cores, and the operations within this function are automatically sharded/replicated across TPU cores.
2. Used within other projects where inputs is always local to one TPU core. All computations below are carried out on one TPU core only. This function tries to dispatch examples across TPU cores so that each expert is assigned no more than expert_capacity_dim examples.
In the tensor shapes below, ` marks the common way of splitting along the mesh dimension (e.g. G`SM).
Dimensions cheat sheet:
G: group_dim
S: group_size_dim
E: number of experts
C: capacity per expert
M: model_dim (same as input_dim, same as output_dim)
B: original batch_dim
L: original sequence_length_dim
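The relation between the original B, L dimensions and the G, S grouping can be shown with a plain reshape. This sketch assumes a row-major reshape of the B*L tokens into G equal groups; the values of B, L, M and G are arbitrary, and how Lingvo chooses G (for example, from the number of devices) is not shown.

```python
import numpy as np

B, L, M = 4, 6, 8          # original batch, sequence length, model dim
G = 3                      # illustrative number of groups
assert B * L % G == 0, 'tokens must divide evenly into groups'
S = B * L // G             # group size: tokens per group

blm = np.arange(B * L * M, dtype=np.float32).reshape(B, L, M)
gsm = blm.reshape(G, S, M)  # flatten the B*L tokens, then split them into G groups
print(gsm.shape)  # (3, 8, 8)
```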
Note that for local_dispatch, the original batch BLM is reshaped into GSM, and each group g = 0...G-1 is dispatched independently.
Parameters
inputs – G`SM Tensor.
paddings – G`S Tensor.
logits – G`SE Tensor.
num_devices – number of MoE devices for local dispatch.
experts_dim – number of experts.
expert_capacity_dim – number of examples per minibatch (group) per expert. Each example is typically a vector of size input_dim, representing an embedded token or an element of Transformer layer output.
fprop_dtype – activations datatype to use.
use_xla_sharding – bool, True if this function is used for the xla_sharding case.
second_expert_policy – ‘all’, ‘sampling’ or ‘random’.
‘all’: we greedily pick the 2nd expert.
‘sampling’: we sample the 2nd expert from the softmax.
‘random’: we optionally dispatch to the second-best expert with probability proportional to (weight / second_expert_threshold).
second_expert_threshold – threshold for probability normalization for second_expert_policy == ‘random’.
legacy_mtf_behavior – bool, True to match legacy mtf behavior exactly.
capacity_factor – if set, increases expert_capacity_dim to at least (group_size * capacity_factor) / experts_dim, where group_size is the size of the S (group_size_dim) dimension of inputs. If the value of expert_capacity_dim is already big enough, no change is made.
TODO(lepikhin): get rid of the legacy_mtf_behavior flag.
Returns
A tuple (aux_loss, combine_tensor, dispatch_tensor).
aux_loss: auxiliary loss, for equalizing the expert assignment ratios.
combine_tensor: G`SEC Tensor for combining expert outputs.
dispatch_tensor: G`SEC Tensor, scattering/dispatching inputs to experts.
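The gating described above can be sketched in NumPy. This is a simplified, single-device illustration assuming the GShard-style top-2 algorithm: it ignores paddings, fprop_dtype, sharding, the ‘sampling’ and ‘random’ policies, and legacy_mtf_behavior, and the E*E scaling of the auxiliary loss is an assumption. top2_gating_sketch is a hypothetical name, not part of Lingvo.

```python
import numpy as np


def top2_gating_sketch(logits, capacity):
  """Simplified top-2 gating. logits: [G, S, E]. Returns (aux_loss, combine, dispatch)."""
  num_experts = logits.shape[-1]
  z = logits - logits.max(axis=-1, keepdims=True)
  raw_gates = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)  # softmax, [G, S, E]
  eye_e = np.eye(num_experts)

  # First choice: the highest-probability expert for each token.
  index_1 = raw_gates.argmax(axis=-1)                 # [G, S]
  mask_1 = eye_e[index_1]                             # [G, S, E]
  gate_1 = (raw_gates * mask_1).sum(axis=-1)          # [G, S]
  # Second choice: the best expert once the first choice is masked out.
  without_1 = raw_gates * (1.0 - mask_1)
  index_2 = without_1.argmax(axis=-1)
  mask_2 = eye_e[index_2]
  gate_2 = (without_1 * mask_2).sum(axis=-1)

  # Load-balancing loss: fraction of first choices per expert times mean gate prob.
  density_1 = mask_1.mean(axis=1)                     # [G, E]
  density_1_proxy = raw_gates.mean(axis=1)            # [G, E]
  aux_loss = (density_1 * density_1_proxy).mean() * num_experts * num_experts

  # Renormalize the two gate values so they sum to (almost) 1.
  denom = gate_1 + gate_2 + 1e-9
  gate_1, gate_2 = gate_1 / denom, gate_2 / denom

  # Slot of each token in its expert's buffer: exclusive cumulative sum over S.
  pos_1 = np.cumsum(mask_1, axis=1) - mask_1
  mask_1 = mask_1 * (pos_1 < capacity)                # drop first choices over capacity
  used_1 = mask_1.sum(axis=1, keepdims=True)          # [G, 1, E] slots taken by 1st choices
  pos_1 = (pos_1 * mask_1).sum(axis=-1).astype(int)   # [G, S]
  keep_1 = mask_1.sum(axis=-1)                        # [G, S], 1.0 if kept

  # Second choices fill the slots left after all first choices.
  pos_2 = np.cumsum(mask_2, axis=1) - mask_2 + used_1
  mask_2 = mask_2 * (pos_2 < capacity)
  pos_2 = (pos_2 * mask_2).sum(axis=-1).astype(int)
  keep_2 = mask_2.sum(axis=-1)

  eye_c = np.eye(capacity)
  combine = ((gate_1 * keep_1)[..., None, None] * eye_e[index_1][..., None]
             * eye_c[pos_1][..., None, :]
             + (gate_2 * keep_2)[..., None, None] * eye_e[index_2][..., None]
             * eye_c[pos_2][..., None, :])            # [G, S, E, C]
  dispatch = (combine > 0).astype(logits.dtype)       # 1.0 where a token occupies a slot
  return aux_loss, combine, dispatch


rng = np.random.default_rng(0)
aux_loss, combine, dispatch = top2_gating_sketch(rng.normal(size=(2, 8, 4)), capacity=4)
print(combine.shape)  # (2, 8, 4, 4)
```

With capacity at least 2*S, every token lands in exactly two slots; smaller capacities drop first choices beyond capacity, and second choices only get the slots left over.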
lingvo.core.moe_layers.Top2Gating(w, inputs, paddings, num_devices, experts_dim, expert_capacity_dim, local_dispatch, fprop_dtype, use_xla_sharding=True, second_expert_policy='all', second_expert_threshold=0.0, legacy_mtf_behavior=True, capacity_factor=None)
Computes Top-2 gating for Mixture-of-Experts.
See Top2GatingOnLogits for more details.
Note that for local_dispatch, the original batch BLM is reshaped into GSM, and each group g = 0...G-1 is dispatched independently.
Parameters
w – gating weights for each expert.
inputs – G`SM Tensor.
paddings – G`S Tensor.
num_devices – number of MoE devices for local dispatch.
experts_dim – number of experts.
expert_capacity_dim – number of examples per minibatch (group) per expert. Each example is typically a vector of size input_dim, representing an embedded token or an element of Transformer layer output.
local_dispatch – whether dispatch is local to the group (G dim).
fprop_dtype – activations datatype to use.
use_xla_sharding – bool, True if this function is used for the xla_sharding case.
second_expert_policy – ‘all’ or ‘random’. With ‘random’, we optionally dispatch to the second-best expert with probability proportional to (weight / second_expert_threshold).
second_expert_threshold – threshold for probability normalization for second_expert_policy == ‘random’.
legacy_mtf_behavior – True for legacy behavior with no re-normalization of expert assignment weights if we go over capacity or randomly decide to not dispatch to second expert.
capacity_factor – if set, increases expert_capacity_dim to at least (group_size * capacity_factor) / experts_dim, where group_size is the size of the S (group_size_dim) dimension of inputs. If the value of expert_capacity_dim is already big enough, no change is made.
Returns
A tuple (dispatch_tensor, combine_tensor, aux_loss).
dispatch_tensor: G`SEC Tensor, scattering/dispatching inputs to experts.
combine_tensor: G`SEC Tensor, combining expert outputs.
aux_loss: auxiliary loss, equalizing the expert assignment ratios.
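Two pieces that Top2Gating adds on top of Top2GatingOnLogits can be sketched separately. Both helpers are hypothetical: gating_logits assumes w has shape [M, E] (the docstring only says "gating weights for each expert"), and effective_capacity assumes group_size is the number of tokens per group (S in the cheat sheet above) and that the automatic capacity is rounded down with int(); the exact rounding is not documented here.

```python
import numpy as np


def gating_logits(w, inputs):
  """Router logits: project each GSM token onto the E expert columns of w [M, E]."""
  return np.einsum('GSM,ME->GSE', inputs, w)


def effective_capacity(expert_capacity_dim, group_size, capacity_factor, experts_dim):
  """Raises expert_capacity_dim to at least group_size * capacity_factor / experts_dim."""
  if capacity_factor is None:
    return expert_capacity_dim
  auto_capacity = int(group_size * capacity_factor / experts_dim)  # rounding assumed
  return max(expert_capacity_dim, auto_capacity)


print(effective_capacity(4, 1024, 2.0, 32))    # 64: raised from 4
print(effective_capacity(100, 1024, 2.0, 32))  # 100: already big enough
```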
lingvo.core.moe_layers.FeedForwardNetworksApplyGating(gating, inputs, reshaped_inputs, wi_split, wo_split, num_devices, num_groups, bi_split=None, bo_split=None, dropout_rate=0.0, device_mesh=None, gsm_split=None, egcm_split=None, gecm_split=None, gsec_split=None, eah_split=None, eam_split=None, activation_name='RELU')
Applies top-2 gating to feedforward networks.
Parameters
gating – the return value of Top2Gating, consisting of: dispatch_tensor, a G`SEC Tensor scattering/dispatching inputs to experts; combine_tensor, a G`SEC Tensor combining expert outputs; and aux_loss, an auxiliary loss equalizing the expert assignment ratios.
inputs – G`SM Tensor.
reshaped_inputs – G`SM Tensor.
wi_split – First projection weights [E, M, H] of the feedforward networks.
wo_split – Last projection weights [E, H, M] of the feedforward networks.
num_devices – number of devices.
num_groups – number of groups (generally equal or proportional to num_devices).
bi_split – First projection bias [E, 1, H] of the feedforward networks.
bo_split – Last projection bias [E, 1, M] of the feedforward networks.
dropout_rate – Dropout rate.
device_mesh – Device mesh as a numpy ND array of device IDs. Split arguments must be set if device_mesh is not None.
gsm_split – Mesh split for GSM tensors.
egcm_split – Mesh split for EGCM tensors.
gecm_split – Mesh split for GECM tensors.
gsec_split – Mesh split for GSEC tensors.
eah_split – Mesh split for EAH tensors.
eam_split – Mesh split for EAM tensors.
activation_name – Default: RELU. Activation function for feed-forward.
Returns
outputs: G`SM Tensor.
aux_loss: scalar auxiliary loss.
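The dispatch, expert computation and combine steps can be sketched as four einsums over the shapes named above. This is an illustration under stated assumptions, not the function's implementation: it assumes a ReLU activation and omits the biases (bi_split, bo_split), dropout, reshaped_inputs and all device_mesh split annotations. moe_ffn_sketch is a hypothetical name.

```python
import numpy as np


def moe_ffn_sketch(dispatch, combine, inputs, wi, wo):
  """dispatch/combine: [G, S, E, C]; inputs: [G, S, M]; wi: [E, M, H]; wo: [E, H, M]."""
  # Scatter tokens into per-expert capacity slots.
  expert_inputs = np.einsum('GSEC,GSM->EGCM', dispatch, inputs)
  # Each expert applies its own two-layer ReLU feed-forward network.
  hidden = np.maximum(np.einsum('EGCM,EMH->EGCH', expert_inputs, wi), 0.0)
  expert_outputs = np.einsum('EGCH,EHM->EGCM', hidden, wo)
  # Gather back to token order, weighting each expert's output by its gate value.
  return np.einsum('GSEC,EGCM->GSM', combine, expert_outputs)


G, S, M, H = 2, 3, 4, 5
rng = np.random.default_rng(0)
inputs = rng.normal(size=(G, S, M))
wi = rng.normal(size=(1, M, H))
wo = rng.normal(size=(1, H, M))
# With one expert and capacity C = S, routing token s to slot s is the identity.
route = np.broadcast_to(np.eye(S)[None, :, None, :], (G, S, 1, S))
out = moe_ffn_sketch(route, route, inputs, wi, wo)
dense = np.maximum(inputs @ wi[0], 0.0) @ wo[0]
print(np.allclose(out, dense))  # True
```

The single-expert case is a convenient check: MoE routing then reduces to an ordinary dense feed-forward layer.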
lingvo.core.moe_layers.GatherK(selected_pos, values, k, num_devices=1)
Gathers up to k elements from the given tensors at the selected positions under SPMD.
Example:
```python
# Input
k = 3
selected_pos = [
  [0, 0, 1, 1],
  [0, 1, 1, 0],
  [0, 0, 0, 0],
  [1, 1, 1, 0],
  [1, 1, 1, 1],  # topk(k=3) largest indices are selected in this row.
]
value_2d = [
  [1, 3, 5, 7],
  [9, 11, 13, 15],
  [17, 19, 21, 23],
  [25, 27, 29, 31],
  [33, 35, 37, 39],
]
# Output:
output = [
  [0, 5, 7],
  [0, 11, 13],
  [0, 0, 0],
  [25, 27, 29],
  [35, 37, 39],
]
# Output padding:
output_padding = [
  [1, 0, 0],
  [1, 0, 0],
  [1, 1, 1],
  [0, 0, 0],
  [0, 0, 0],
]
```
Parameters
selected_pos – a 0/1 2D tf.int32 tensor of shape [batch, time].
values – a list of tensors, each of rank at least 2, with shape [batch, time, …].
k – a scalar tf.int32 tensor or a Python int. On TPU, k must be a compile-time constant.
num_devices – number of TPU devices used in xla_sharding SPMD.
Returns
A tuple (output, padding).
output: a list of tensors of shape [batch, k, …].
padding: a 2D 0/1 tensor of shape [batch, k], ‘1’s are padded locations.
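The semantics shown in the example can be reproduced by a single-device reference loop. gather_k_reference is a hypothetical helper inferred from the example (keep the k largest selected indices in ascending order, left-pad shorter rows); it says nothing about how GatherK achieves this under SPMD.

```python
import numpy as np


def gather_k_reference(selected_pos, values, k):
  """Single-device reference for the GatherK example (not the SPMD implementation).

  For each row, keeps the k largest selected indices in ascending order and
  left-pads shorter rows with zeros, marking padded slots with 1.
  """
  selected_pos = np.asarray(selected_pos)
  values = [np.asarray(v) for v in values]
  batch = selected_pos.shape[0]
  outputs = [np.zeros((batch, k) + v.shape[2:], v.dtype) for v in values]
  padding = np.ones((batch, k), np.int32)
  for b in range(batch):
    idx = np.nonzero(selected_pos[b])[0][-k:]  # at most k, largest indices kept
    n = len(idx)
    padding[b, k - n:] = 0
    for out, v in zip(outputs, values):
      out[b, k - n:] = v[b, idx]
  return outputs, padding


selected_pos = [[0, 0, 1, 1], [0, 1, 1, 0], [0, 0, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
value_2d = np.arange(1, 40, 2).reshape(5, 4)  # 1, 3, ..., 39 as in the example
(output,), padding = gather_k_reference(selected_pos, [value_2d], k=3)
print(output.tolist())   # [[0, 5, 7], [0, 11, 13], [0, 0, 0], [25, 27, 29], [35, 37, 39]]
print(padding.tolist())  # [[1, 0, 0], [1, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0]]
```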
lingvo.core.moe_layers.GetSentenceEmbeddings(inputs, segment_id)
Returns the average sentence embedding to gate by.
Example:
```
inputs: <tf.Variable 'Variable:0' shape=(10, 3) dtype=float64, numpy=
array([[0.41258181, 0.61071571, 0.63777673],
       [0.65571443, 0.54297766, 0.10288261],
       [0.8577837 , 0.81915847, 0.61996602],
       [0.46897136, 0.92662692, 0.32942232],
       [0.60162383, 0.3385829 , 0.3408632 ],
       [0.40774807, 0.86139635, 0.00927162],
       [0.56126334, 0.51748817, 0.07791397],
       [0.06595223, 0.95529216, 0.34458149],
       [0.1238971 , 0.49897169, 0.25216722],
       [0.11221774, 0.50284604, 0.84106974]])>
segment_id: <tf.Variable 'Variable:0' shape=(10,) dtype=int64, numpy=array([1, 1, 2, 0, 0, 3, 3, 3, 3, 0])>
```
Parameters
inputs – G`SM Tensor.
segment_id – G`S Tensor.
Returns
sentence_embeddings: GSM Tensor that is an average of the input embeddings per segment.
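One way to compute a per-segment average that keeps the GSM shape is to give every position the mean of all positions sharing its segment id. This is a sketch under assumptions: it treats every id, including 0, as a segment (the docstring does not say whether 0 marks padding), takes segment_id as [G, S], and builds an S x S equality matrix, which costs O(S^2) per group. sentence_embeddings_sketch is a hypothetical name.

```python
import numpy as np


def sentence_embeddings_sketch(inputs, segment_id):
  """inputs: [G, S, M]; segment_id: [G, S]. Returns [G, S, M] per-segment means."""
  # same[g, s, t] = 1 if positions s and t in group g share a segment id.
  same = (segment_id[:, :, None] == segment_id[:, None, :]).astype(inputs.dtype)
  totals = np.einsum('GST,GTM->GSM', same, inputs)
  return totals / same.sum(axis=-1, keepdims=True)  # count >= 1 (s matches itself)


inputs = np.array([[[1.0], [3.0], [5.0], [7.0]]])
segment_id = np.array([[1, 1, 2, 2]])
print(sentence_embeddings_sketch(inputs, segment_id)[0, :, 0])  # [2. 2. 6. 6.]
```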
lingvo.core.moe_layers.SentenceTop2Gating(w, inputs, paddings, segment_id, num_devices, experts_dim, expert_capacity_dim, local_dispatch, fprop_dtype, use_xla_sharding=True, second_expert_policy='all', second_expert_threshold=0.0, legacy_mtf_behavior=True, embedding_type='sentence', capacity_factor=None)
Computes Top-2 sentence gating for Mixture-of-Experts.
Instead of gating each token individually, this function uses embedding_type to compute a sentence-wise embedding, from which it creates dispatch and combine tensors that gate the entire sentence.
Note that for local_dispatch, the original batch BLM is reshaped into GSM, and each group g = 0...G-1 is dispatched independently.
Parameters
w – gating weights for each expert.
inputs – G`SM Tensor.
paddings – G`S Tensor.
segment_id – G`SM Tensor used for differentiating different sentences in an input example.
num_devices – number of MoE devices for local dispatch.
experts_dim – number of experts.
expert_capacity_dim – number of examples per minibatch (group) per expert. Each example is typically a vector of size input_dim, representing an embedded token or an element of Transformer layer output.
local_dispatch – whether dispatch is local to the group (G dim).
fprop_dtype – activations datatype to use.
use_xla_sharding – bool, True if this function is used for the xla_sharding case.
second_expert_policy – ‘all’ or ‘random’. With ‘random’, we optionally dispatch to the second-best expert with probability proportional to (weight / second_expert_threshold).
second_expert_threshold – threshold for probability normalization for second_expert_policy == ‘random’.
legacy_mtf_behavior – True for legacy behavior with no re-normalization of expert assignment weights if we go over capacity or randomly decide to not dispatch to second expert.
embedding_type – ‘sentence’ by default. Options: ‘sentence’. Setting this option calls GetSentenceEmbeddings.
capacity_factor – if set, increases expert_capacity_dim to at least (group_size * capacity_factor) / experts_dim, where group_size is the size of the S (group_size_dim) dimension of inputs. If the value of expert_capacity_dim is already big enough, no change is made.
Returns
A tuple (dispatch_tensor, combine_tensor, aux_loss).
dispatch_tensor: G`SEC Tensor, scattering/dispatching inputs to experts.
combine_tensor: G`SEC Tensor, combining expert outputs.
aux_loss: auxiliary loss, equalizing the expert assignment ratios.
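The effect of sentence-level gating on the router can be illustrated by computing logits from per-segment mean embeddings: every token of a sentence then receives identical logits and therefore the same top-1 and top-2 experts. This sketch assumes w is [M, E] and that the sentence embedding is a plain per-segment mean; it stops at the logits, so the capacity limit, which could still drop individual tokens during dispatch, is not modeled. sentence_gating_logits is a hypothetical name.

```python
import numpy as np


def sentence_gating_logits(w, inputs, segment_id):
  """Router logits computed from per-sentence mean embeddings (illustrative)."""
  same = (segment_id[:, :, None] == segment_id[:, None, :]).astype(inputs.dtype)
  sentence_emb = np.einsum('GST,GTM->GSM', same, inputs) / same.sum(-1, keepdims=True)
  return np.einsum('GSM,ME->GSE', sentence_emb, w)


rng = np.random.default_rng(0)
inputs = rng.normal(size=(1, 6, 4))
w = rng.normal(size=(4, 3))  # assumed [M, E] gating weights
segment_id = np.array([[1, 1, 1, 2, 2, 2]])
logits = sentence_gating_logits(w, inputs, segment_id)
top1 = logits.argmax(-1)
print(top1)  # tokens in the same sentence share one top-1 expert
```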
lingvo.core.moe_layers.TaskTop2Gating(w, inputs, paddings, task_embeddings, num_devices, experts_dim, expert_capacity_dim, local_dispatch, fprop_dtype, use_xla_sharding=True, second_expert_policy='all', second_expert_threshold=0.0, legacy_mtf_behavior=True)
Computes Top-2 sentence gating for Mixture-of-Experts.
Instead of gating each token individually, this function uses task_embeddings to create the dispatch and combine tensors.
Note that for local_dispatch, the original batch BLM is reshaped into GSM, and each group g = 0...G-1 is dispatched independently.
Parameters
w – gating weights for each expert.
inputs – G`SM Tensor.
paddings – G`S Tensor.
task_embeddings – G`SM Tensor.
num_devices – number of MoE devices for local dispatch.
experts_dim – number of experts.
expert_capacity_dim – number of examples per minibatch (group) per expert. Each example is typically a vector of size input_dim, representing an embedded token or an element of Transformer layer output.
local_dispatch – whether dispatch is local to the group (G dim).
fprop_dtype – activations datatype to use.
use_xla_sharding – bool, True if this function is used for the xla_sharding case.
second_expert_policy – ‘all’ or ‘random’. With ‘random’, we optionally dispatch to the second-best expert with probability proportional to (weight / second_expert_threshold).
second_expert_threshold – threshold for probability normalization for second_expert_policy == ‘random’.
legacy_mtf_behavior – True for legacy behavior with no re-normalization of expert assignment weights if we go over capacity or randomly decide to not dispatch to second expert.
Returns
A tuple (dispatch_tensor, combine_tensor, aux_loss).
dispatch_tensor: G`SEC Tensor, scattering/dispatching inputs to experts.
combine_tensor: G`SEC Tensor, combining expert outputs.
aux_loss: auxiliary loss, equalizing the expert assignment ratios.