# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conformer layers as in https://arxiv.org/abs/2005.08100."""
from lingvo import compat as tf
from lingvo.core import activations
from lingvo.core import base_layer
from lingvo.core import batch_major_attention as attention_lib
from lingvo.core import bn_layers
from lingvo.core import conv_layers_with_time_padding
from lingvo.core import gshard_builder
from lingvo.core import gshard_utils
from lingvo.core import hyperparams as hparams_lib
from lingvo.core import layers
from lingvo.core import layers_with_attention
from lingvo.core import plot
from lingvo.core import py_utils
from lingvo.core import recurrent
from lingvo.core import summary_utils
class LConvLayer(base_layer.BaseLayer):
r"""Lightweight conv layer.
architecture::

  input
  /   \
  |   ln(.)                   # input_dim
  |   fflayer(.)              # 2 * input_dim
  |    |
  |   glu(.)                  # input_dim
  |   depthwise_conv_1d(.)
  |   norm(.)
  |   act(.)
  |    |
  |   fflayer(.)
  |   dropout(.)
  \   /
    +
    |
  output
"""
@classmethod
def Params(cls):
p = super().Params()
p.Define('input_dim', None, 'Input (and, in fact, output) dimension.')
p.Define('kernel_size', None, 'Kernel size of 1d depthwise conv.')
p.Define('conv_activation', 'SWISH', 'Activation after normalization.')
p.Define(
'is_causal', False, 'Whether this is a causal layer. '
'If set to true, use '
'conv_layers_with_time_padding.CausalDepthwiseConv2DLayer for '
'`depthwise_conv_tpl`.')
p.Define(
'glu_activation', 'NONE',
'Activation in GLU. Check lingvo.core.activations._ACTIVATIONS for '
'other options.')
p.Define('dropout_prob', 0., 'Dropout probability.')
p.Define('ln_tpl', layers.LayerNorm.Params(), 'Input layer norm template.')
p.Define('linear_start_tpl', layers.FCLayer.Params(), 'Linear start layer.')
p.Define(
'depthwise_conv_tpl',
conv_layers_with_time_padding.DepthwiseConv2DLayer.Params(),
'Depthwise conv template. For causal layer, use '
'conv_layers_with_time_padding.CausalDepthwiseConv2DLayer.')
p.Define('conv_norm_layer_tpl', bn_layers.BatchNormLayer.Params(),
'Normalization layer after conv.')
p.Define('linear_end_tpl', layers.FCLayer.Params(), 'Linear end layer.')
p.Define('dropout_tpl', layers.DropoutLayer.Params(),
'Residual dropout layer.')
p.Define(
'split_act_gated_linear_start', False,
'Separate act and gated linear start to remove data formatting '
'overheads')
p.linear_start_tpl.Set(activation='NONE', has_bias=True)
p.linear_end_tpl.Set(activation='NONE', has_bias=True)
# SPMD partition related params.
#
# d - model_dim
# f - ff_hidden_dim (here ff_hidden_dim has the same size as model_dim)
# h - height
# w - width
# i - in_channels
# m - channel_multiplier
# b - batch_size
# l - seq_len
p.weight_split_dims_mapping = hparams_lib.Params()
wp = p.weight_split_dims_mapping
wp.Define(
'df', None,
'Mesh split for lconv linear start weight with the shape of '
'[model_dim, ff_hidden_dim], the default hidden_dim is the same as '
'the model_dim.')
wp.Define(
'hwim', None,
'Mesh split for lconv depthwise conv weight with the shape of '
'[height, width, in_channels, channel_multiplier]. Width and '
'channel_multiplier are both 1 for the common use case.')
wp.Define(
'fd', None, 'Mesh split for lconv linear end weight with the shape of '
'[ff_hidden_dim, model_dim], the default hidden_dim is the same as '
'the model_dim.')
p.activation_split_dims_mapping = hparams_lib.Params()
ap = p.activation_split_dims_mapping
ap.Define(
'blf', None, 'Mesh split for lconv linear start activation and lconv '
'depthwise conv after normalization with the shape of '
'[batch_size, seq_len, ff_hidden_dim], the default hidden_dim is the '
'same as model_dim.')
ap.Define(
'bld', None,
'Mesh split for lconv linear end activation with the shape of '
'[batch_size, seq_len, model_dim].')
return p
@classmethod
def SetCanonicalShardingParams(cls, params):
"""Set up canonical SPMD sharding params."""
assert params.device_mesh.ndim >= 2
wp = params.weight_split_dims_mapping
wp.df = [0, 1]
wp.hwim = [-1, -1, 1, -1]
# TODO(shibow/rpang): understand the effects of fd sharding, especially why
# [-1, -1] performs better when bld is [0, -1, -1].
wp.fd = [1, 0]
ap = params.activation_split_dims_mapping
ap.blf = [0, -1, 1]
ap.bld = [1, -1, -1]
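# A minimal sketch of enabling the canonical sharding above, assuming a 2x2
# device mesh; the mesh shape and dimension values are illustrative only.
#
#   import numpy as np
#   lconv_p = LConvLayer.CommonParams(input_dim=512, kernel_size=32)
#   lconv_p.split_act_gated_linear_start = True
#   lconv_p.device_mesh = np.arange(4).reshape([2, 2])
#   LConvLayer.SetCanonicalShardingParams(lconv_p)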
@classmethod
def CommonParams(cls,
input_dim=None,
kernel_size=None,
is_causal=False,
conv_activation='SWISH',
dropout_prob=0.):
p = cls.Params().Set(
input_dim=input_dim,
is_causal=is_causal,
kernel_size=kernel_size,
conv_activation=conv_activation,
dropout_prob=dropout_prob)
if is_causal:
p.depthwise_conv_tpl = (
conv_layers_with_time_padding.CausalDepthwiseConv2DLayer.Params())
return p
@classmethod
def SetFPropDtype(cls, p, fprop_dtype):
p.fprop_dtype = fprop_dtype
if fprop_dtype == tf.bfloat16 and not py_utils.use_tpu():
# Depthwise conv supports bfloat16 only on TPUs.
p.depthwise_conv_tpl.fprop_dtype = tf.float32
if issubclass(p.conv_norm_layer_tpl.cls, bn_layers.BatchNormLayer):
# Batch norm supports bfloat16 only on TPUs.
p.conv_norm_layer_tpl.fprop_dtype = tf.float32
return p
def __init__(self, params):
super().__init__(params)
p = self.params
ln_p = p.ln_tpl.Copy().Set(name='ln', input_dim=p.input_dim)
self.CreateChild('ln', ln_p)
if p.split_act_gated_linear_start:
linear_start_act_p = p.linear_start_tpl.Copy().Set(
input_dim=p.input_dim,
output_dim=p.input_dim,
device_mesh=p.device_mesh,
weight_split_dims_mapping=p.weight_split_dims_mapping.df,
activation_split_dims_mapping=p.activation_split_dims_mapping.blf)
linear_start_gated_p = p.linear_start_tpl.Copy().Set(
input_dim=p.input_dim,
output_dim=p.input_dim,
device_mesh=p.device_mesh,
weight_split_dims_mapping=p.weight_split_dims_mapping.df,
activation_split_dims_mapping=p.activation_split_dims_mapping.blf)
self.CreateChild('linear_start_act', linear_start_act_p)
self.CreateChild('linear_start_gated', linear_start_gated_p)
else:
linear_start_p = p.linear_start_tpl.Copy().Set(
name='linear_start',
input_dim=p.input_dim,
output_dim=2 * p.input_dim)
self.CreateChild('linear_start', linear_start_p)
linear_end_p = p.linear_end_tpl.Copy().Set(
name='linear_end',
input_dim=p.input_dim,
output_dim=p.input_dim,
device_mesh=p.device_mesh,
weight_split_dims_mapping=p.weight_split_dims_mapping.fd,
activation_split_dims_mapping=p.activation_split_dims_mapping.bld)
self.CreateChild('linear_end', linear_end_p)
if p.conv_norm_layer_tpl.cls == layers.LayerNorm:
norm_p = p.conv_norm_layer_tpl.Copy().Set(
name='norm_layer', input_dim=p.input_dim)
else:
norm_p = p.conv_norm_layer_tpl.Copy().Set(
name='norm_layer', dim=p.input_dim)
if p.conv_norm_layer_tpl.cls == bn_layers.GroupNormLayer:
norm_p.cumulative = p.is_causal
self.CreateChild('norm', norm_p)
if (p.is_causal and p.depthwise_conv_tpl.cls ==
conv_layers_with_time_padding.DepthwiseConv2DLayer):
# If causal, switch to causal depthwise conv.
depthwise_conv_p = (
conv_layers_with_time_padding.CausalDepthwiseConv2DLayer.Params())
hparams_lib.CopyFieldsTo(p.depthwise_conv_tpl, depthwise_conv_p)
else:
depthwise_conv_p = p.depthwise_conv_tpl.Copy()
if issubclass(depthwise_conv_p.cls,
conv_layers_with_time_padding.DepthwiseConv2DLayer):
depthwise_conv_p.filter_shape = (p.kernel_size, 1, p.input_dim, 1)
else:
depthwise_conv_p.filter_shape = (p.kernel_size, 1, p.input_dim,
p.input_dim)
# 1d depthwise conv with channel_multiplier = 1
depthwise_conv_p.Set(
name='depthwise_conv',
filter_stride=(1, 1))
self.CreateChild('depthwise_conv1d', depthwise_conv_p)
dropout_p = p.dropout_tpl.Copy().Set(
name='dropout', keep_prob=1. - p.dropout_prob)
self.CreateChild('dropout', dropout_p)
def _GLU(self, gated_inputs, act_inputs):
p = self.params
return self._ApplyActivation(act_inputs,
p.glu_activation) * tf.sigmoid(gated_inputs)
def _ApplyActivation(self, inputs, act_name):
if act_name == 'NONE':
return inputs
return activations.GetFn(act_name)(inputs)
def _Normalize(self, theta, inputs, paddings):
"""Applies normalization.
Args:
theta: A NestedMap of layer params.
inputs: [b, t, 1, d].
paddings: [b, t].
Returns:
A Tensor of shape [b, t, d].
"""
if isinstance(self.norm, bn_layers.GroupNormLayer):
assert self.norm.params.input_rank == 4
inputs, _ = self.norm.FProp(theta.norm, inputs, paddings)
# [b, t, d]
inputs = tf.squeeze(inputs, 2)
else:
# [b, t, 1, d] -> [b, t, d]
inputs = tf.squeeze(inputs, 2)
if isinstance(self.norm, bn_layers.BatchNormLayer):
inputs = self.norm.FProp(theta.norm, inputs, paddings)
elif isinstance(self.norm, layers.LayerNorm):
inputs = self.norm.FProp(theta.norm, inputs)
else:
raise NotImplementedError(
'Only bn_layers.{BatchNormLayer,GroupNormLayer}, layers.LayerNorm '
'are supported.')
return self._CastToFPropDtype(inputs)
def FProp(self, theta, inputs, paddings):
"""Builds FProp graph.
Args:
theta: A NestedMap of Tensors, see base class.
inputs: A Tensor of shape [batch, seqlen, dim0].
paddings: A Tensor of shape [batch, seqlen].
Returns:
output: A Tensor of shape [batch, seqlen, dim0].
out_paddings: A Tensor of shape [batch, seqlen].
"""
p = self.params
with tf.name_scope(p.name):
inputs, paddings = self._CastToFPropDtype((inputs, paddings))
unnormalized_inputs = inputs
inputs = self.ln.FProp(theta.ln, inputs)
inputs = self._CastToFPropDtype(inputs)
if p.split_act_gated_linear_start:
act_inputs = self.linear_start_act.FProp(theta.linear_start_act, inputs)
gated_inputs = self.linear_start_gated.FProp(theta.linear_start_gated,
inputs)
else:
inputs = self.linear_start.FProp(theta.linear_start, inputs)
gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
inputs = self._GLU(gated_inputs, act_inputs)
# TODO(jamesqin): introduce depthwise conv2d with 3d inputs.
# [b, t, d] --> [b, t, 1, d]
inputs = tf.expand_dims(inputs, 2)
adapted_blf_dims_mapping = None
if p.activation_split_dims_mapping.blf is not None:
adapted_blf_dims_mapping = p.activation_split_dims_mapping.blf.copy()
adapted_blf_dims_mapping.insert(2, -1)
inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
adapted_blf_dims_mapping)
theta.depthwise_conv1d.w = gshard_utils.MeshSplit(
theta.depthwise_conv1d.w, p.device_mesh,
p.weight_split_dims_mapping.hwim)
if inputs.dtype == tf.bfloat16 and not py_utils.use_tpu():
# Depthwise conv doesn't support bfloat16 on CPU.
inputs = tf.cast(inputs, tf.float32)
paddings = tf.cast(paddings, tf.float32)
inputs, paddings = self.depthwise_conv1d.FProp(theta.depthwise_conv1d,
inputs, paddings)
inputs, paddings = self._CastToFPropDtype((inputs, paddings))
inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
adapted_blf_dims_mapping)
inputs = self._Normalize(theta, inputs, paddings)
inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
p.activation_split_dims_mapping.blf)
inputs = self._ApplyActivation(inputs, p.conv_activation)
inputs = self.linear_end.FProp(theta.linear_end, inputs)
inputs = self.dropout.FProp(theta.dropout, inputs)
output = inputs + unnormalized_inputs
return output, paddings
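# A minimal non-causal usage sketch of the FProp path above; shapes and
# dimension values are illustrative only.
#
#   lconv_p = LConvLayer.CommonParams(
#       input_dim=8, kernel_size=3).Set(name='lconv')
#   lconv = lconv_p.Instantiate()
#   inputs = tf.zeros([2, 10, 8])   # [batch, seqlen, dim]
#   paddings = tf.zeros([2, 10])    # [batch, seqlen]
#   outputs, out_paddings = lconv.FProp(lconv.theta, inputs, paddings)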
def zero_state(self, batch_size):
p = self.params
with tf.name_scope('zero_state'):
if p.is_causal:
with tf.name_scope('depthwise_conv1d'):
res = py_utils.NestedMap(
conv_state=self.depthwise_conv1d.zero_state(batch_size))
if hasattr(self.norm, 'zero_state'):
with tf.name_scope('norm'):
res.norm_state = self.norm.zero_state(batch_size)
return res
else:
# If not causal, depthwise_conv1d does not have zero_state().
return py_utils.NestedMap()
def _NormalizeStep(self, theta, inputs, paddings, state0, state1):
if hasattr(self.norm, 'StreamStep'):
# TODO(jamesqin): support 3d inputs.
# At present it's guaranteed GroupNorm.
assert (isinstance(self.norm, bn_layers.GroupNormLayer) and
self.norm.params.input_rank == 4)
inputs, paddings, norm_state1 = self.norm.StreamStep(
theta.norm, inputs, paddings, state0.norm_state)
# [b, t, d]
inputs = tf.squeeze(inputs, 2)
state1.norm_state = norm_state1
else:
# [b, t, 1, d] -> [b, t, d]
inputs = tf.squeeze(inputs, 2)
if isinstance(self.norm, layers.LayerNorm):
inputs = self.norm.FProp(theta.norm, inputs)
else:
raise NotImplementedError(
'Only bn_layers.GroupNormLayer, layers.LayerNorm are supported.')
# [b, t, d]
return inputs, paddings
def StreamStep(self, theta, inputs, paddings, state0):
"""Streams t steps.
Args:
theta: A NestedMap of layer params.
inputs: [b, t, d].
paddings: A 0/1 valued tensor of shape [b, t].
state0: A NestedMap of tensors of the same struct as returned by
zero_state().
Returns:
output: A Tensor of shape [b, t, d].
paddings: the same as input paddings.
state1: A NestedMap of tensors of the same struct as state0.
"""
p = self.params
assert p.is_causal
state1 = py_utils.NestedMap()
with tf.name_scope(f'{p.name}/StreamStep'):
unnormalized_inputs = inputs
inputs = self.ln.FProp(theta.ln, inputs)
if p.split_act_gated_linear_start:
act_inputs = self.linear_start_act.FProp(theta.linear_start_act, inputs)
gated_inputs = self.linear_start_gated.FProp(theta.linear_start_gated,
inputs)
else:
inputs = self.linear_start.FProp(theta.linear_start, inputs)
gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
inputs = self._GLU(gated_inputs, act_inputs)
# TODO(jamesqin): introduce depthwise conv2d with 3d inputs.
# TODO(jamesqin): optimize DepthwiseConv1D.StreamStep()
# [b, t, d] --> [b, t, 1, d]
inputs = tf.expand_dims(inputs, 2)
# [b, t, 1, d]
inputs, paddings, conv_state1 = self.depthwise_conv1d.StreamStep(
theta.depthwise_conv1d, inputs, paddings, state0.conv_state)
state1.conv_state = conv_state1
# [b, t, d]
inputs, paddings = self._NormalizeStep(theta, inputs, paddings, state0,
state1)
inputs = self._ApplyActivation(inputs, p.conv_activation)
inputs = self.linear_end.FProp(theta.linear_end, inputs)
inputs = self.dropout.FProp(theta.dropout, inputs)
output = inputs + unnormalized_inputs
return output, paddings, state1
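# A minimal streaming sketch, assuming a causal layer and a norm that supports
# streaming (e.g. GroupNormLayer); dimension values and the per-step chunking
# below are illustrative only.
#
#   lconv_p = LConvLayer.CommonParams(
#       input_dim=8, kernel_size=3, is_causal=True).Set(name='lconv')
#   lconv_p.conv_norm_layer_tpl = bn_layers.GroupNormLayer.Params().Set(
#       num_groups=2)
#   lconv = lconv_p.Instantiate()
#   state = lconv.zero_state(batch_size=2)
#   for frame, frame_paddings in stream:  # [2, 1, 8] and [2, 1] per step.
#     out, out_paddings, state = lconv.StreamStep(
#         lconv.theta, frame, frame_paddings, state)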
def GShardMoELayerParams(num_devices,
num_groups,
num_experts,
per_expert_capacity_dim,
use_densebuilder=False):
if use_densebuilder:
moe_cls = gshard_builder.DenseBuilder
else:
moe_cls = gshard_builder.MoEBuilder
return moe_cls.Params().Set(
num_devices=num_devices,
num_groups=num_groups,
e_dim=num_experts,
c_dim=per_expert_capacity_dim)
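# A minimal sketch of plugging an MoE feed-forward module into the conformer
# block defined below; device, group, expert and capacity values are
# illustrative only.
#
#   moe_ff_tpl = GShardMoELayerParams(
#       num_devices=8, num_groups=8, num_experts=8, per_expert_capacity_dim=16)
#   conformer_p = ConformerLayer.CommonParams(
#       input_dim=512, atten_num_heads=8, kernel_size=32,
#       fflayer_hidden_dim=2048, fflayer_start_tpl=moe_ff_tpl)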
class ConformerLayer(base_layer.BaseLayer):
"""Conformer layer as in https://arxiv.org/abs/2005.08100.
Canonical version (with default params):

  x = x + 1/2 * FFN(x)
  x = x + MHSA(x)
  x = x + Lconv(x)
  x = x + 1/2 * FFN(x)
  y = ln(x)

Optionally one can change the order of MHSA and conv.
"""
@classmethod
def Params(cls):
p = super().Params()
p.Define('input_dim', None, 'Input dimension.')
p.Define(
'is_causal', False, 'Whether to use causal lconv and MHSA layers. '
'Note that atten_right_context must not be infinite (None) if is_causal '
'is True. It is important to always set is_causal for the streaming '
'case, rather than inferring it from atten_{left,right}_context.')
p.Define(
'layer_order', 'mhsa_before_conv',
'Only mhsa, conv, mhsa_before_conv or conv_before_mhsa are '
'supported.')
p.Define('dropout_prob', None, 'Dropout prob of inner components.')
# fflayer tpl
p.Define(
'fflayer_start_tpl',
layers_with_attention.TransformerFeedForwardLayer.Params(),
'Layer params for Feed forward layer at the beginning. Supports '
'using gshard_builder.MoEBuilder.Params() as well wherein the '
'MoE() will be used. If set to None, this layer is excluded.')
p.Define('trans_atten_tpl',
attention_lib.TransformerAttentionLayer.Params(),
'Self attention layer params.')
p.Define(
'lconv_tpl', LConvLayer.Params(),
'Convolution module params. If set to None, this layer is excluded.')
p.Define(
'fflayer_end_tpl',
layers_with_attention.TransformerFeedForwardLayer.Params(),
'Layer params for Feed forward layer at the end. Supports using '
'gshard_builder.MoEBuilder.Params() as well wherein the MoE() '
'will be used.')
p.Define(
'fflayer_weight_sharing', False,
'If True, will ignore `fflayer_end_tpl`, and will make the fflayer_end '
'layer as a weight-shared copy of the fflayer_start layer.')
p.Define('final_ln_tpl', layers.LayerNorm.Params(), 'Final layer norm.')
# If adapter_tpl is set, layer_out = adapter(conformer(layer_in))
# The adapter must
# 1. have instance method FProp(self, theta, in_nmap) -> out_nmap, where
# {in,out}_nmap must have 'features' and 'paddings' and $adapter_p.task_ids
# fields.
# 2. have class method SetInputDim(cls, p, input_dim)
p.Define('adapter_tpl', None, 'If set, runs an adapter layer in the end.')
# https://b/167460492#comment16
p.Define(
'remat', False, 'If to rematerialize the layer. If true, '
'intermediate tensors are not saved in FProp().')
p.Define(
'list_regex_dtypes', [],
'A list of (regex, dtype) to set the data types of variables using '
'regex. The default value is [] to use the existing data types without '
'any changes. If a variable name matches the first regex in the list, '
'the variable data type will be set by the corresponding dtype.')
p.Define('allow_attention_summaries', False,
'Allow plotting attention histogram and plot summaries.')
return p
@classmethod
def CommonParams(cls,
*,
input_dim=None,
is_causal=False,
atten_num_heads=None,
atten_local_context=None,
atten_left_context=None,
atten_right_context=None,
atten_chunk_size=None,
atten_logit_cap=0.0,
use_relative_atten=True,
query_stride=1,
kernel_size=None,
fflayer_hidden_dim=None,
fflayer_activation='SWISH',
fflayer_residual_weight=0.5,
layer_order='mhsa_before_conv',
dropout_prob=0.,
conv_norm_layer_tpl=None,
fprop_dtype=None,
use_moe_in_fflayer_start=False,
use_moe_in_fflayer_end=False,
moe_num_partitions=None,
moe_num_experts=None,
moe_num_groups=None,
moe_per_capacity_dim=None,
fflayer_start_tpl=None,
fflayer_end_tpl=None,
trans_atten_tpl=None,
lconv_tpl=None,
list_regex_dtypes=None):
assert all([input_dim])
if layer_order != 'conv':
assert atten_num_heads or trans_atten_tpl
if layer_order == 'mhsa':
assert not any([kernel_size, conv_norm_layer_tpl, lconv_tpl])
else:
assert kernel_size
if _AttenCtxIsSet(atten_local_context):
assert not _AttenCtxIsSet(atten_left_context) and not _AttenCtxIsSet(
atten_right_context
), ('atten_local_context and atten_{left,right}_context cannot be set '
'at the same time.')
atten_left_context = atten_local_context + 1 # including self position.
atten_right_context = atten_local_context
if is_causal and trans_atten_tpl is None:
# None is different from 0, the former is 'infinite'.
assert atten_right_context is not None, (
'is_causal is not compatible with infinite atten_right_context '
'(None).')
p = cls.Params().Set(
input_dim=input_dim,
is_causal=is_causal,
layer_order=layer_order,
dropout_prob=dropout_prob)
if use_moe_in_fflayer_start:
assert fflayer_start_tpl is None
fflayer_start_tpl = GShardMoELayerParams(moe_num_partitions,
moe_num_experts, moe_num_groups,
moe_per_capacity_dim)
if use_moe_in_fflayer_end:
assert fflayer_end_tpl is None
fflayer_end_tpl = GShardMoELayerParams(moe_num_partitions,
moe_num_experts, moe_num_groups,
moe_per_capacity_dim)
def _ConfigureFF(fflayer_tpl):
config_kwargs = dict(
input_dim=input_dim,
hidden_dim=fflayer_hidden_dim,
activation=fflayer_activation,
residual_weight=fflayer_residual_weight,
dropout_prob=dropout_prob)
if fflayer_tpl is None:
return cls.ConfigFFLayer(
tpl=layers_with_attention.TransformerFeedForwardLayer.Params(),
**config_kwargs)
elif issubclass(fflayer_tpl.cls, gshard_builder.MoEBuilder):
# TODO(rpang): make users call ConfigMoEParams directly.
return cls.ConfigMoEParams(tpl=fflayer_tpl, **config_kwargs)
elif fflayer_tpl.use_block_diagonal_matmul_pl:
# Override params other than block_diag_matmul options.
return cls.ConfigFFLayer(tpl=fflayer_tpl, **config_kwargs)
else:
assert fflayer_hidden_dim is None, fflayer_hidden_dim
assert fflayer_activation is None, fflayer_activation
assert fflayer_residual_weight is None, fflayer_residual_weight
return fflayer_tpl
# Set the two feed forward modules.
p.fflayer_start_tpl = _ConfigureFF(fflayer_start_tpl)
p.fflayer_end_tpl = _ConfigureFF(fflayer_end_tpl)
# Set the MHSA module.
if trans_atten_tpl is not None:
assert atten_left_context is None
assert atten_right_context is None
assert use_relative_atten is None
assert atten_num_heads is None
assert query_stride == 1
p.trans_atten_tpl = trans_atten_tpl
else:
atten_tpl = cls._ConfigSelfAttenContext(
atten_left_context,
atten_right_context,
atten_chunk_size=atten_chunk_size,
use_relative_atten=use_relative_atten,
query_stride=query_stride,
relative_pos_emb_dim=input_dim)
if query_stride == 1:
p.trans_atten_tpl = attention_lib.TransformerAttentionLayer.Params(
).Set(
atten_tpl=atten_tpl, num_heads=atten_num_heads)
else:
p.trans_atten_tpl = attention_lib.FunnelTransformerAttentionLayer.Params(
).Set(
atten_tpl=atten_tpl, num_heads=atten_num_heads)
p.trans_atten_tpl.funnel_tpl.stride = query_stride
p.trans_atten_tpl.res_funnel_tpl.stride = query_stride
# Set the convolution module.
if lconv_tpl is not None:
p.lconv_tpl = lconv_tpl
if kernel_size:
p.lconv_tpl.kernel_size = kernel_size
if conv_norm_layer_tpl is not None:
p.lconv_tpl.conv_norm_layer_tpl = conv_norm_layer_tpl
if fprop_dtype is not None:
p.cls.SetFPropDtype(p, fprop_dtype)
if isinstance(p.trans_atten_tpl.atten_tpl, list):
for a in p.trans_atten_tpl.atten_tpl:
a.atten_logit_cap = atten_logit_cap
else:
p.trans_atten_tpl.atten_tpl.atten_logit_cap = atten_logit_cap
if list_regex_dtypes is not None:
p.list_regex_dtypes = list_regex_dtypes
if layer_order == 'mhsa':
p.lconv_tpl = None
return p
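# A minimal configuration sketch for the canonical conformer block; dimension
# and context values are illustrative only.
#
#   conformer_p = ConformerLayer.CommonParams(
#       input_dim=512,
#       atten_num_heads=8,
#       atten_local_context=15,   # local self-attention window.
#       kernel_size=32,           # lconv kernel size.
#       fflayer_hidden_dim=2048,
#       dropout_prob=0.1).Set(name='conformer')
#   conformer = conformer_p.Instantiate()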
@classmethod
def Stride(cls, params):
if 'funnel_tpl' in params.trans_atten_tpl:
return params.trans_atten_tpl.funnel_tpl.stride
return 1
@classmethod
def _ConfigSelfAttenContext(cls, atten_left_context, atten_right_context, *,
use_relative_atten, atten_chunk_size,
query_stride, relative_pos_emb_dim):
# TODO(jamesqin): add an attention factory in batch_major_attention.
if atten_chunk_size is not None:
if use_relative_atten:
atten_type = 'chunk_relative'
else:
atten_type = 'chunk'
elif not _AttenCtxIsSet(atten_left_context) and not _AttenCtxIsSet(
atten_right_context):
# No atten context set, each position attends to all positions.
atten_type = 'global' if not use_relative_atten else 'global_relative'
elif not _AttenCtxIsSet(atten_left_context) and atten_right_context == 0:
# Left context is infinite, right context is 0.
assert not use_relative_atten, (
'Relative attention isn\'t supported for causal attention.')
atten_type = 'global_causal'
else:
atten_type = 'local_relative' if use_relative_atten else 'local'
if atten_type == 'global_relative':
atten_tpl = (
attention_lib.MultiHeadedAttentionXL.Params().Set(
rel_pos_emb_dim=relative_pos_emb_dim))
elif atten_type == 'local_relative':
atten_tpl = attention_lib.LocalSelfAttentionXL.Params().Set(
left_context=atten_left_context,
right_context=atten_right_context,
rel_pos_emb_dim=relative_pos_emb_dim,
query_stride=query_stride)
elif atten_type == 'local':
atten_tpl = attention_lib.LocalSelfAttention.Params().Set(
left_context=atten_left_context,
right_context=atten_right_context,
query_stride=query_stride)
elif atten_type == 'chunk':
atten_tpl = attention_lib.ChunkwiseSelfAttention.Params().Set(
left_context=atten_left_context,
right_context=atten_right_context,
chunk_size=atten_chunk_size)
elif atten_type == 'chunk_relative':
atten_tpl = attention_lib.ChunkwiseSelfAttentionXL.Params().Set(
left_context=atten_left_context,
right_context=atten_right_context,
chunk_size=atten_chunk_size,
rel_pos_emb_dim=relative_pos_emb_dim)
else:
# No op for 'global' atten
assert atten_type in ('global', 'global_causal'), (
f'Unknown atten_type {atten_type}')
atten_tpl = attention_lib.MultiHeadedAttention.Params()
return atten_tpl
@classmethod
def SetFPropDtype(cls, p, fprop_dtype):
p.fprop_dtype = fprop_dtype
for sub_p in (p.lconv_tpl, p.trans_atten_tpl, p.fflayer_start_tpl,
p.fflayer_end_tpl):
if sub_p is not None:
sub_p.cls.SetFPropDtype(sub_p, fprop_dtype)
return p
@classmethod
def ConfigFFLayer(cls, *, tpl, input_dim, hidden_dim, activation,
residual_weight, dropout_prob):
"""Configures tpl params for a feed-forward layer."""
p = tpl.Copy().Set(
input_dim=input_dim,
hidden_dim=hidden_dim,
activation=activation,
residual_weight=residual_weight,
residual_dropout_prob=dropout_prob,
relu_dropout_prob=dropout_prob)
return p
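# A minimal sketch of configuring a feed-forward template directly (this is
# what CommonParams does internally via _ConfigureFF); values are illustrative.
#
#   ff_tpl = ConformerLayer.ConfigFFLayer(
#       tpl=layers_with_attention.TransformerFeedForwardLayer.Params(),
#       input_dim=512, hidden_dim=2048, activation='SWISH',
#       residual_weight=0.5, dropout_prob=0.1)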
@classmethod
def ConfigMoEParams(cls, *, tpl, input_dim, hidden_dim, activation,
residual_weight, dropout_prob):
"""Configures tpl params for a MoE layer."""
moe_builder_p = tpl.Copy().Set(
model_dim=input_dim,
dropout_rate=dropout_prob,
moe_hidden_dim=hidden_dim,
moe_activation=activation)
if moe_builder_p.cls == gshard_builder.MoEBuilder:
if moe_builder_p.num_devices is None:
raise ValueError('num_devices must be specified for MoEBuilder.')
if residual_weight != 0.5:
raise ValueError('residual_weight must be 0.5')
return moe_builder_p
def __init__(self, params):
super().__init__(params)
p = self.params
assert p.layer_order in [
'mhsa', 'conv', 'mhsa_before_conv', 'conv_before_mhsa'
]
if p.layer_order == 'mhsa':
assert not self.has_lconv, 'mhsa must not have a lconv block.'
# Change the variable dtypes by list_regex_dtypes.
with py_utils.VariableListDtypeRegexScope(self.params.list_regex_dtypes):
if self.has_fflayer_start:
fflayer_start_p = self._ConfigFFLayerOrMoEParams(
p.fflayer_start_tpl, 'fflayer_start')
if fflayer_start_p.name:
assert fflayer_start_p.name == 'fflayer_start_moe'
else:
fflayer_start_p.name = 'fflayer_start'
self.CreateChild(fflayer_start_p.name, fflayer_start_p)
fflayer_end_p = self._ConfigFFLayerOrMoEParams(p.fflayer_end_tpl,
'fflayer_end')
if fflayer_end_p.name:
assert fflayer_end_p.name == 'fflayer_end_moe'
else:
fflayer_end_p.name = 'fflayer_end'
if not p.fflayer_weight_sharing:
self.CreateChild(fflayer_end_p.name, fflayer_end_p)
else:
self.AddChild(fflayer_end_p.name, self.children[fflayer_start_p.name])
# For local MHSA, is_masked is ignored, thus it's safe to set is_masked
# based on p.is_causal, for global and local MHSA cases.
if self.has_mhsa:
trans_atten_p = p.trans_atten_tpl.Copy().Set(
input_dim=p.input_dim,
is_masked=p.is_causal,
atten_dropout_prob=p.dropout_prob,
residual_dropout_prob=p.dropout_prob)
if tf.logging.vlog_is_on(2):
for line in trans_atten_p.atten_tpl.ToText().split('\n'):
tf.logging.info('ConformerLayer.atten_tpl: %s', line)
self.CreateChild('trans_atten', trans_atten_p)
if self.has_lconv:
lconv_p = p.lconv_tpl.Copy().Set(
input_dim=p.input_dim, is_causal=p.is_causal)
self.CreateChild('lconv', lconv_p)
ln_p = p.final_ln_tpl.Copy().Set(name='final_ln', input_dim=p.input_dim)
self.CreateChild('final_ln', ln_p)
if p.adapter_tpl:
p.adapter_tpl.cls.SetInputDim(p.adapter_tpl, p.input_dim)
self.CreateChild('adapter', p.adapter_tpl)
# lconv and fflayer_start get special treatment: they can be absent,
# because Transformer doesn't have those blocks.
@property
def has_lconv(self):
return bool(self.params.lconv_tpl)
@property
def has_mhsa(self):
return bool('mhsa' in self.params.layer_order)
@property
def has_fflayer_start(self):
return bool(self.params.fflayer_start_tpl)
def _ConfigFFLayerOrMoEParams(self, fflayer_tpl, name_prefix):
p = self.params
fflayer_tpl = fflayer_tpl.Copy()
if not issubclass(fflayer_tpl.cls, gshard_builder.MoEBuilder):
if hasattr(fflayer_tpl, 'input_dim'):
fflayer_tpl.Set(input_dim=p.input_dim)
return fflayer_tpl
# TODO(rpang): migrate clients to TransformerShardedMoeLayer.
fflayer_tpl.model_dim = p.input_dim
moe_builder = fflayer_tpl.Instantiate()
name = name_prefix + '_moe'
# TODO(rpang): find a way to configure residual weight.
moe_p = moe_builder.EncoderLayer(
name, moe_builder.MoE(name), residual_weight=0.5)
return moe_p
def _SelfAtten(self, theta, inputs, paddings):
if isinstance(self.trans_atten,
attention_lib.FunnelTransformerAttentionLayer):
inputs, paddings, atten_probs = self.trans_atten.FProp(
theta.trans_atten,
query_vec=inputs,
source_vecs=None,
paddings=paddings)
else:
inputs, atten_probs = self.trans_atten.FProp(
theta.trans_atten,
query_vec=inputs,
source_vecs=None,
paddings=paddings)
return inputs, paddings, atten_probs
def _LConv(self, theta, inputs, paddings):
assert self.has_lconv and self.params.layer_order != 'mhsa', (
'mhsa does not have a lconv block.')
inputs, paddings = self.lconv.FProp(theta.lconv, inputs, paddings)
return inputs, paddings
def _MoeOrFFLayer(self, theta, fflayer_name, features, paddings, aux_loss):
"""FProp for MoE or Feed forward layer.
Args:
theta: Layer theta: A NestedMap of Tensors.
fflayer_name: Child FFLayer name as created in __init__.
For example: 'fflayer_end'. If an MoE layer is created instead, it is
assumed to be named (`fflayer_name` + '_moe').
features: [batch, seqlen, dim0].
paddings: [batch, seqlen].
aux_loss: [], can be None.
Returns:
features: [batch, seqlen, dim0].
paddings: [batch, seqlen].
aux_loss: [], is None if input aux_loss is None.
"""
if fflayer_name in self.children:
outputs = self.children[fflayer_name].FProp(
theta.GetItem(fflayer_name), features, paddings)
return outputs, paddings, aux_loss
else:
moe_fflayer_name = fflayer_name + '_moe'
if moe_fflayer_name not in self.children:
raise AssertionError(
'{} child layer not present.'.format(moe_fflayer_name))
if moe_fflayer_name not in theta:
raise AssertionError(
'{} layer theta not present.'.format(moe_fflayer_name))
# 0 - padded positions and 1 - non-padded positions.
segment_ids = tf.cast(1. - paddings, tf.int32)
segment_pos = tf.zeros_like(segment_ids) # not used but required by MoE.
moe_in = py_utils.NestedMap(
vec=features, segment_id=segment_ids, segment_pos=segment_pos)
moe_out = self.children[moe_fflayer_name].FProp(
theta.GetItem(moe_fflayer_name), moe_in)
moe_aux_loss = moe_out.aux_loss
if aux_loss is not None:
assert not moe_aux_loss.shape.rank, 'MoE aux-loss should be a scalar.'
if len(py_utils.GetShape(aux_loss)) == 1:
b_size = py_utils.GetShape(aux_loss)[0]
moe_aux_loss = tf.tile(tf.expand_dims(moe_aux_loss, axis=0), [b_size])
assert moe_aux_loss.shape.rank == aux_loss.shape.rank
aux_loss += moe_aux_loss
else:
aux_loss = moe_aux_loss
return moe_out.vec, paddings, aux_loss
def _FProp(self, theta, in_nmap):
p = self.params
with tf.name_scope(p.name):
features, paddings = in_nmap.features, in_nmap.paddings
aux_loss = in_nmap.Get('aux_loss')
features, paddings = self._CastToFPropDtype((features, paddings))
out_nmap = py_utils.NestedMap()
if self.has_fflayer_start:
features, paddings, aux_loss = self._MoeOrFFLayer(
theta, 'fflayer_start', features, paddings, aux_loss)
atten_probs = None
if p.layer_order == 'mhsa':
features, paddings, atten_probs = self._SelfAtten(
theta, features, paddings)
elif p.layer_order == 'conv':
features, paddings = self._LConv(theta, features, paddings)
elif p.layer_order == 'mhsa_before_conv':
features, paddings, atten_probs = self._SelfAtten(
theta, features, paddings)
features, paddings = self._LConv(theta, features, paddings)
else:
assert p.layer_order == 'conv_before_mhsa'
features, paddings = self._LConv(theta, features, paddings)
features, paddings, atten_probs = self._SelfAtten(
theta, features, paddings)
features, paddings, aux_loss = self._MoeOrFFLayer(theta, 'fflayer_end',
features, paddings,
aux_loss)
features = self.final_ln.FProp(theta.final_ln, features)
out_nmap = in_nmap.DeepCopy()
if p.adapter_tpl:
adapter_in_map = in_nmap.DeepCopy()
adapter_in_map.features, adapter_in_map.paddings = features, paddings
adapter_out_nmap = self.adapter.FProp(theta.adapter, adapter_in_map)
features, paddings = adapter_out_nmap.features, adapter_out_nmap.paddings
features, paddings = self._CastToFPropDtype((features, paddings))
out_nmap.features, out_nmap.paddings = features, paddings
if aux_loss is not None:
out_nmap.aux_loss = aux_loss
self._AddAttentionSummaries(p.name, atten_probs)
return out_nmap
def FProp(self, theta, in_nmap):
p = self.params
if not p.remat:
return self._FProp(theta, in_nmap)
def CellFn(theta, state0, unused_inputs):
out_nmap = self._FProp(theta, state0)
return out_nmap, py_utils.NestedMap()
_, state1 = recurrent.Recurrent(
theta=theta,
state0=in_nmap,
inputs=py_utils.NestedMap(
inputs=tf.zeros([1, 0])), # A dummy input of shape [T, ?].
cell_fn=CellFn,
allow_implicit_capture=p.allow_implicit_capture)
return state1
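# A minimal FProp sketch, assuming a ConformerLayer instance `conformer` built
# as in the CommonParams sketch above; shapes are illustrative only.
#
#   in_nmap = py_utils.NestedMap(
#       features=tf.zeros([2, 10, 512]),   # [batch, seqlen, input_dim]
#       paddings=tf.zeros([2, 10]))        # [batch, seqlen]
#   out_nmap = conformer.FProp(conformer.theta, in_nmap)
#   features, paddings = out_nmap.features, out_nmap.paddings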
def zero_state(self, batch_size):
if self.params.is_causal:
lconv_state = py_utils.NestedMap()
atten_state = py_utils.NestedMap()
if self.has_lconv:
with tf.name_scope('lconv'):
lconv_state = self.lconv.zero_state(batch_size)
if self.has_mhsa:
with tf.name_scope('atten'):
atten_state = self.trans_atten.zero_state(batch_size)
return py_utils.NestedMap(
lconv_state=lconv_state, atten_state=atten_state)
else:
return py_utils.NestedMap()
def StreamStep(self, theta, inputs, paddings, state0):
"""Streams t steps.
Args:
theta: A NestedMap of read-only layer params.
inputs: A tensor of shape [b, t, d].
paddings: A 0/1 valued tensor of shape [b, t].
state0: A NestedMap of tensors of the same struct as returned by
zero_state().
Returns:
outputs: A tensor of shape [b, t, d].
padding: the same as input paddings.
state1: A NestedMap of tensors of the same struct as state0.
"""
p = self.params
assert p.is_causal
assert not p.remat
with tf.name_scope(f'{p.name}/StreamStep'):
features, aux_loss = inputs, None
if self.has_fflayer_start:
features, paddings, aux_loss = self._MoeOrFFLayer(
theta, 'fflayer_start', features, paddings, aux_loss)
if p.layer_order == 'mhsa':
features, paddings, atten_state1 = self.trans_atten.StreamStep(
theta.trans_atten, features, paddings, state0.atten_state)
elif p.layer_order == 'conv':
features, paddings, lconv_state1 = self.lconv.StreamStep(
theta.lconv, features, paddings, state0.lconv_state)
elif p.layer_order == 'mhsa_before_conv':
features, paddings, atten_state1 = self.trans_atten.StreamStep(
theta.trans_atten, features, paddings, state0.atten_state)
features, paddings, lconv_state1 = self.lconv.StreamStep(
theta.lconv, features, paddings, state0.lconv_state)
else:
assert p.layer_order == 'conv_before_mhsa'
features, paddings, lconv_state1 = self.lconv.StreamStep(
theta.lconv, features, paddings, state0.lconv_state)
features, paddings, atten_state1 = self.trans_atten.StreamStep(
theta.trans_atten, features, paddings, state0.atten_state)
if not self.has_lconv:
lconv_state1 = py_utils.NestedMap()
if not self.has_mhsa:
atten_state1 = py_utils.NestedMap()
features, paddings, _ = self._MoeOrFFLayer(theta, 'fflayer_end', features,
paddings, aux_loss)
outputs = self.final_ln.FProp(theta.final_ln, features)
state1 = py_utils.NestedMap(
lconv_state=lconv_state1, atten_state=atten_state1)
return outputs, paddings, state1
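# A minimal streaming sketch, assuming a causal conformer with a finite right
# attention context and a streaming-capable lconv norm; dimension, context and
# chunking values are illustrative only.
#
#   conformer_p = ConformerLayer.CommonParams(
#       input_dim=512, is_causal=True, atten_num_heads=8,
#       atten_left_context=16, atten_right_context=0, use_relative_atten=False,
#       kernel_size=32, fflayer_hidden_dim=2048).Set(name='conformer')
#   conformer_p.lconv_tpl.conv_norm_layer_tpl = (
#       bn_layers.GroupNormLayer.Params().Set(num_groups=8))
#   conformer = conformer_p.Instantiate()
#   state = conformer.zero_state(batch_size=2)
#   for chunk, chunk_paddings in stream:  # [2, t, 512] and [2, t] per step.
#     out, out_paddings, state = conformer.StreamStep(
#         conformer.theta, chunk, chunk_paddings, state)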
def _AddAttentionSummaries(self, name, atten_probs):
# Plots attention prob summaries for joint network.
# TODO(ankurbpn): Check why op wasn't compiling on TPUs.
p = self.params
if not py_utils.use_tpu(
) and p.allow_attention_summaries and atten_probs is not None:
atten_shape = tf.shape(atten_probs)
atten_probs = tf.reshape(
atten_probs, [atten_shape[0], atten_shape[1], -1, atten_shape[-1]])
# Only plots first example of the batch.
atten_probs = tf.reduce_mean(atten_probs[0:1, :, :, :], 1)
self._AddAttenProbsImageSummary(name, atten_probs)
self._AddAttenProbsHistogramSummary(name, atten_probs)
def _AddAttenProbsHistogramSummary(self, name, atten_probs):
"""Add histogram summary of attention probs."""
summary_utils.histogram(name + '/atten_probs', atten_probs)
def _AddAttenProbsImageSummary(self, name, atten_probs):
"""Add image summary of input attention probabilities."""
def PlotAttention(fig, axes, cur_atten_probs, title):
plot.AddImage(fig, axes, cur_atten_probs, title=title)
axes.set_ylabel(plot.ToUnicode('Output sequence index'), wrap=True)
axes.set_xlabel(plot.ToUnicode('Input sequence index'), wrap=True)
with plot.MatplotlibFigureSummary(
name + '/atten_example',
figsize=(10, 10),
max_outputs=1,
subplot_grid_shape=(1, 1)) as fig:
# Extract first entry in batch of attention prob matrices
# [tgt_len, src_len]
fig.AddSubplot([atten_probs], PlotAttention, title='atten_probs')
def ApplyGshard(conformer_tpl,
device_mesh=None,
proj_w_split_list=None,
proj_activation_split_list=None,
atten_dnh_w_split=None,
atten_blnh_activation_split=None,
atten_bld_activation_split=None,
lconv_df_w_split=None,
lconv_hwim_w_split=None,
lconv_fd_w_split=None,
lconv_blf_activation_split=None,
lconv_bld_activation_split=None,
moe_device_mesh_shape=None,
moe_emh_split=None,
moe_ehm_split=None,
moe_eah_split=None,
moe_eam_split=None,
moe_egcm_split=None,
moe_gecm_split=None,
moe_gsec_split=None,
moe_blm_split=None):
"""Applies gshard on conformer params.
Args:
conformer_tpl: A NestedMap of conformer Params.
device_mesh: A numpy.ndarray specifying the device mesh on which the
computation is sharded.
proj_w_split_list: A list of mesh split specifying how weights are sharded
for fflayer.
proj_activation_split_list: A list of mesh split specifying how activations
are sharded for fflayer.
atten_dnh_w_split: Mesh split of the attention projection weight with the
shape of [model_dim, num_heads, dim_per_head].
atten_blnh_activation_split: Mesh split of the attention activation with
shape of [batch, seq_len, num_heads, dim_per_head].
atten_bld_activation_split: Mesh split of the attention activation with
shape of [batch, seq_len, model_dim].
lconv_df_w_split: Mesh split of the weights in lconv with the shape of
[model_dim, ff_hidden_dim].
lconv_hwim_w_split: Mesh split of the depthwise conv weight in lconv with
the shape of [height, width, in_channels, channel_multiplier].
lconv_fd_w_split: Mesh split of the weights in lconv with the shape of
[ff_hidden_dim, model_dim].
lconv_blf_activation_split: Mesh split of the activations in lconv with the
shape of [batch, seq_len, ff_hidden_dim].
lconv_bld_activation_split: Mesh split of the activations in lconv with the
shape of [batch, seq_len, model_dim].
moe_device_mesh_shape: Device mesh shape for MoE params.
moe_emh_split: split dimension for MoE EMH activation. See: gshard_layers.
moe_ehm_split: split dimension for MoE EHM activation. See: gshard_layers.
moe_eah_split: split dimension for MoE EAH activation. See: gshard_layers.
moe_eam_split: split dimension for MoE EAM activation. See: gshard_layers.
moe_egcm_split: split dimension for MoE EGCM activation. See: gshard_layers.
moe_gecm_split: split dimension for MoE GECM activation. See: gshard_layers.
moe_gsec_split: split dimension for MoE GSEC activation. See: gshard_layers.
moe_blm_split: split dimension for MoE BLM activation. See: gshard_layers.
Returns:
The updated conformer_tpl.
"""
# Not all attention classes support gshard. If not, errors would be thrown here.
conformer_tpl.trans_atten_tpl.atten_tpl.device_mesh = device_mesh
conformer_tpl.trans_atten_tpl.atten_tpl.weight_split_dims_mapping = (
atten_dnh_w_split)
conformer_tpl.trans_atten_tpl.atten_tpl.proj_tpl.weight_split_dims_mapping = (
atten_dnh_w_split)
conformer_tpl.trans_atten_tpl.atten_tpl.activation_split_dims_mapping.blnh = (
atten_blnh_activation_split)
conformer_tpl.trans_atten_tpl.atten_tpl.activation_split_dims_mapping.bld = (
atten_bld_activation_split)
# TODO(jamesqin): support residual_proj xla sharding too.
def _ApplyGshardToMoELayer(fflayer_tpl):
if not issubclass(fflayer_tpl.cls, gshard_builder.DenseBuilder):
return
fflayer_tpl.device_mesh = device_mesh
fflayer_tpl.device_mesh_shape = moe_device_mesh_shape
fflayer_tpl.emh_split = moe_emh_split
fflayer_tpl.ehm_split = moe_ehm_split
fflayer_tpl.eah_split = moe_eah_split
fflayer_tpl.eam_split = moe_eam_split
fflayer_tpl.egcm_split = moe_egcm_split
fflayer_tpl.gecm_split = moe_gecm_split
fflayer_tpl.gsec_split = moe_gsec_split
fflayer_tpl.blm_split = moe_blm_split
# Set sharding in FFLayers.
if not issubclass(conformer_tpl.fflayer_start_tpl.cls,
gshard_builder.MoEBuilder):
conformer_tpl.fflayer_start_tpl.fflayer_tpl.Set(
device_mesh=device_mesh,
weight_split_dims_mapping_list=proj_w_split_list,
activation_split_dims_mapping_list=proj_activation_split_list)
else:
_ApplyGshardToMoELayer(conformer_tpl.fflayer_start_tpl)
if not issubclass(conformer_tpl.fflayer_end_tpl.cls,
gshard_builder.MoEBuilder):
conformer_tpl.fflayer_end_tpl.fflayer_tpl.Set(
device_mesh=device_mesh,
weight_split_dims_mapping_list=proj_w_split_list,
activation_split_dims_mapping_list=proj_activation_split_list)
else:
_ApplyGshardToMoELayer(conformer_tpl.fflayer_end_tpl)
# Set sharding in LConv layer.
conformer_tpl.lconv_tpl.Set(
split_act_gated_linear_start=True, device_mesh=device_mesh)
lconv_w_split = conformer_tpl.lconv_tpl.weight_split_dims_mapping
lconv_w_split.df = lconv_df_w_split
lconv_w_split.hwim = lconv_hwim_w_split
lconv_w_split.fd = lconv_fd_w_split
lconv_activation_split = conformer_tpl.lconv_tpl.activation_split_dims_mapping
lconv_activation_split.blf = lconv_blf_activation_split
lconv_activation_split.bld = lconv_bld_activation_split
return conformer_tpl
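# A minimal sketch of sharding a conformer template over a 2D device mesh; the
# mesh shape and all split assignments below are illustrative only.
#
#   import numpy as np
#   mesh = np.arange(8).reshape([2, 4])
#   conformer_tpl = ApplyGshard(
#       conformer_tpl,
#       device_mesh=mesh,
#       proj_w_split_list=[[0, 1], [1, 0]],
#       proj_activation_split_list=[[0, -1, 1], [0, -1, -1]],
#       atten_dnh_w_split=[0, 1, -1],
#       atten_blnh_activation_split=[0, -1, 1, -1],
#       atten_bld_activation_split=[0, -1, -1],
#       lconv_df_w_split=[0, 1],
#       lconv_hwim_w_split=[-1, -1, 1, -1],
#       lconv_fd_w_split=[1, 0],
#       lconv_blf_activation_split=[0, -1, 1],
#       lconv_bld_activation_split=[0, -1, -1])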