Source code for lingvo.core.conformer_layer

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conformer layers as in https://arxiv.org/abs/2005.08100."""

from lingvo import compat as tf

from lingvo.core import activations
from lingvo.core import base_layer
from lingvo.core import batch_major_attention as attention_lib
from lingvo.core import bn_layers
from lingvo.core import conv_layers_with_time_padding
from lingvo.core import gshard_builder
from lingvo.core import gshard_utils
from lingvo.core import hyperparams as hparams_lib
from lingvo.core import layers
from lingvo.core import layers_with_attention
from lingvo.core import plot
from lingvo.core import py_utils
from lingvo.core import recurrent
from lingvo.core import summary_utils


class LConvLayer(base_layer.BaseLayer):
  r"""Lightweight conv layer.

  architecture::

    input
    /   \
    |   ln(.)                   # input_dim
    |   fflayer(.)              # 2 * input_dim
    |    |
    |   glu(.)                  # input_dim
    |   depthwise_conv_1d(.)
    |   norm(.)
    |   act(.)
    |    |
    |   fflayer(.)
    |   dropout(.)
    \   /
      +
      |
    output
  """
  @classmethod
  def Params(cls):
    p = super().Params()
    p.Define('input_dim', None, 'Input and (in fact) output dimension.')
    p.Define('kernel_size', None, 'Kernel size of 1d depthwise conv.')
    p.Define('conv_activation', 'SWISH', 'Activation after normalization.')
    p.Define(
        'is_causal', False, 'Whether this is a causal layer. '
        'If set to true, use '
        'conv_layers_with_time_padding.CausalDepthwiseConv2DLayer for '
        '`depthwise_conv_tpl`.')
    p.Define(
        'glu_activation', 'NONE',
        'Activation in GLU. Check lingvo.core.activations._ACTIVATIONS for '
        'other options.')
    p.Define('dropout_prob', 0., 'Dropout probability.')
    p.Define('ln_tpl', layers.LayerNorm.Params(), 'Input layer norm template.')
    p.Define('linear_start_tpl', layers.FCLayer.Params(), 'Linear start layer.')
    p.Define(
        'depthwise_conv_tpl',
        conv_layers_with_time_padding.DepthwiseConv2DLayer.Params(),
        'Depthwise conv template. For causal layer, use '
        'conv_layers_with_time_padding.CausalDepthwiseConv2DLayer.')
    p.Define('conv_norm_layer_tpl', bn_layers.BatchNormLayer.Params(),
             'Normalization layer after conv.')
    p.Define('linear_end_tpl', layers.FCLayer.Params(), 'Linear end layer.')
    p.Define('dropout_tpl', layers.DropoutLayer.Params(),
             'Residual dropout layer.')
    p.Define(
        'split_act_gated_linear_start', False,
        'Separate act and gated linear start to remove data formatting '
        'overheads.')
    p.linear_start_tpl.Set(activation='NONE', has_bias=True)
    p.linear_end_tpl.Set(activation='NONE', has_bias=True)

    # SPMD partition related params.
    #
    # d - model_dim
    # f - ff_hidden_dim (here ff_hidden_dim has the same size as model_dim)
    # h - height
    # w - width
    # i - in_channels
    # m - channel_multiplier
    # b - batch_size
    # l - seq_len
    p.weight_split_dims_mapping = hparams_lib.Params()
    wp = p.weight_split_dims_mapping
    wp.Define(
        'df', None,
        'Mesh split for lconv linear start weight with the shape of '
        '[model_dim, ff_hidden_dim], the default hidden_dim is the same as '
        'the model_dim.')
    wp.Define(
        'hwim', None,
        'Mesh split for lconv depthwise conv weight with the shape of '
        '[height, width, in_channels, channel_multiplier]. Width and '
        'channel_multiplier are both 1 for the common use case.')
    wp.Define(
        'fd', None,
        'Mesh split for lconv linear end weight with the shape of '
        '[ff_hidden_dim, model_dim], the default hidden_dim is the same as '
        'the model_dim.')
    p.activation_split_dims_mapping = hparams_lib.Params()
    ap = p.activation_split_dims_mapping
    ap.Define(
        'blf', None,
        'Mesh split for lconv linear start activation and lconv '
        'depthwise conv after normalization with the shape of '
        '[batch_size, seq_len, ff_hidden_dim], the default hidden_dim is the '
        'same as model_dim.')
    ap.Define(
        'bld', None,
        'Mesh split for lconv linear end activation with the shape of '
        '[batch_size, seq_len, model_dim].')
    return p
  @classmethod
  def SetCanonicalShardingParams(cls, params):
    """Set up canonical SPMD sharding params."""
    assert params.device_mesh.ndim >= 2
    wp = params.weight_split_dims_mapping
    wp.df = [0, 1]
    wp.hwim = [-1, -1, 1, -1]
    # TODO(shibow/rpang): understand the effects of fd sharding, especially why
    # [-1, -1] performs better when bld is [0, -1, -1].
    wp.fd = [1, 0]
    ap = params.activation_split_dims_mapping
    ap.blf = [0, -1, 1]
    ap.bld = [1, -1, -1]
  @classmethod
  def CommonParams(cls,
                   input_dim=None,
                   kernel_size=None,
                   is_causal=False,
                   conv_activation='SWISH',
                   dropout_prob=0.):
    p = cls.Params().Set(
        input_dim=input_dim,
        is_causal=is_causal,
        kernel_size=kernel_size,
        conv_activation=conv_activation,
        dropout_prob=dropout_prob)
    if is_causal:
      p.depthwise_conv_tpl = (
          conv_layers_with_time_padding.CausalDepthwiseConv2DLayer.Params())
    return p
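  # Example (a sketch, not part of the library): building LConvLayer params via
  # CommonParams. The dimensions are arbitrary illustrative values, and
  # Instantiate() assumes a Lingvo context in which variables can be created.
  #
  #   p = LConvLayer.CommonParams(input_dim=256, kernel_size=32)
  #   p.name = 'lconv'
  #   lconv = p.Instantiate()
  #
  # For a streaming (causal) variant, CommonParams(is_causal=True, ...) swaps
  # in CausalDepthwiseConv2DLayer for the depthwise conv template.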
  @classmethod
  def SetFPropDtype(cls, p, fprop_dtype):
    p.fprop_dtype = fprop_dtype

    if fprop_dtype == tf.bfloat16 and not py_utils.use_tpu():
      # Depthwise conv supports bfloat16 only on TPUs.
      p.depthwise_conv_tpl.fprop_dtype = tf.float32
      if issubclass(p.conv_norm_layer_tpl.cls, bn_layers.BatchNormLayer):
        # Batch norm does not support bfloat16 off TPUs either.
        p.conv_norm_layer_tpl.fprop_dtype = tf.float32

    return p
  def __init__(self, params):
    super().__init__(params)
    p = self.params

    ln_p = p.ln_tpl.Copy().Set(name='ln', input_dim=p.input_dim)
    self.CreateChild('ln', ln_p)

    if p.split_act_gated_linear_start:
      linear_start_act_p = p.linear_start_tpl.Copy().Set(
          input_dim=p.input_dim,
          output_dim=p.input_dim,
          device_mesh=p.device_mesh,
          weight_split_dims_mapping=p.weight_split_dims_mapping.df,
          activation_split_dims_mapping=p.activation_split_dims_mapping.blf)
      linear_start_gated_p = p.linear_start_tpl.Copy().Set(
          input_dim=p.input_dim,
          output_dim=p.input_dim,
          device_mesh=p.device_mesh,
          weight_split_dims_mapping=p.weight_split_dims_mapping.df,
          activation_split_dims_mapping=p.activation_split_dims_mapping.blf)
      self.CreateChild('linear_start_act', linear_start_act_p)
      self.CreateChild('linear_start_gated', linear_start_gated_p)
    else:
      linear_start_p = p.linear_start_tpl.Copy().Set(
          name='linear_start',
          input_dim=p.input_dim,
          output_dim=2 * p.input_dim)
      self.CreateChild('linear_start', linear_start_p)
    linear_end_p = p.linear_end_tpl.Copy().Set(
        name='linear_end',
        input_dim=p.input_dim,
        output_dim=p.input_dim,
        device_mesh=p.device_mesh,
        weight_split_dims_mapping=p.weight_split_dims_mapping.fd,
        activation_split_dims_mapping=p.activation_split_dims_mapping.bld)
    self.CreateChild('linear_end', linear_end_p)

    if p.conv_norm_layer_tpl.cls == layers.LayerNorm:
      norm_p = p.conv_norm_layer_tpl.Copy().Set(
          name='norm_layer', input_dim=p.input_dim)
    else:
      norm_p = p.conv_norm_layer_tpl.Copy().Set(
          name='norm_layer', dim=p.input_dim)
    if p.conv_norm_layer_tpl.cls == bn_layers.GroupNormLayer:
      norm_p.cumulative = p.is_causal
    self.CreateChild('norm', norm_p)

    if (p.is_causal and p.depthwise_conv_tpl.cls ==
        conv_layers_with_time_padding.DepthwiseConv2DLayer):
      # If causal, switch to causal depthwise conv.
      depthwise_conv_p = (
          conv_layers_with_time_padding.CausalDepthwiseConv2DLayer.Params())
      hparams_lib.CopyFieldsTo(p.depthwise_conv_tpl, depthwise_conv_p)
    else:
      depthwise_conv_p = p.depthwise_conv_tpl.Copy()
    if issubclass(depthwise_conv_p.cls,
                  conv_layers_with_time_padding.DepthwiseConv2DLayer):
      depthwise_conv_p.filter_shape = (p.kernel_size, 1, p.input_dim, 1)
    else:
      depthwise_conv_p.filter_shape = (p.kernel_size, 1, p.input_dim,
                                       p.input_dim)
    # 1d depthwise conv with channel_multiplier = 1.
    depthwise_conv_p.Set(name='depthwise_conv', filter_stride=(1, 1))
    self.CreateChild('depthwise_conv1d', depthwise_conv_p)

    dropout_p = p.dropout_tpl.Copy().Set(
        name='dropout', keep_prob=1. - p.dropout_prob)
    self.CreateChild('dropout', dropout_p)
  def _GLU(self, gated_inputs, act_inputs):
    p = self.params
    return self._ApplyActivation(act_inputs,
                                 p.glu_activation) * tf.sigmoid(gated_inputs)
  def _ApplyActivation(self, inputs, act_name):
    if act_name == 'NONE':
      return inputs
    return activations.GetFn(act_name)(inputs)
  def _Normalize(self, theta, inputs, paddings):
    """Applies normalization.

    Args:
      theta: A NestedMap of layer params.
      inputs: [b, t, 1, d].
      paddings: [b, t].

    Returns:
      A Tensor of shape [b, t, d].
    """
    if isinstance(self.norm, bn_layers.GroupNormLayer):
      assert self.norm.params.input_rank == 4
      inputs, _ = self.norm.FProp(theta.norm, inputs, paddings)
      # [b, t, d]
      inputs = tf.squeeze(inputs, 2)
    else:
      # [b, t, 1, d] -> [b, t, d]
      inputs = tf.squeeze(inputs, 2)
      if isinstance(self.norm, bn_layers.BatchNormLayer):
        inputs = self.norm.FProp(theta.norm, inputs, paddings)
      elif isinstance(self.norm, layers.LayerNorm):
        inputs = self.norm.FProp(theta.norm, inputs)
      else:
        raise NotImplementedError(
            'Only bn_layers.{BatchNormLayer,GroupNormLayer}, layers.LayerNorm '
            'are supported.')
    return self._CastToFPropDtype(inputs)
  def FProp(self, theta, inputs, paddings):
    """Builds FProp graph.

    Args:
      theta: A NestedMap of Tensors, see base class.
      inputs: A Tensor of shape [batch, seqlen, dim0].
      paddings: A Tensor of shape [batch, seqlen].

    Returns:
      output: A Tensor of shape [batch, seqlen, dim0].
      out_paddings: A Tensor of shape [batch, seqlen].
    """
    p = self.params
    with tf.name_scope(p.name):
      inputs, paddings = self._CastToFPropDtype((inputs, paddings))
      unnormalized_inputs = inputs

      inputs = self.ln.FProp(theta.ln, inputs)
      inputs = self._CastToFPropDtype(inputs)
      if p.split_act_gated_linear_start:
        act_inputs = self.linear_start_act.FProp(theta.linear_start_act,
                                                 inputs)
        gated_inputs = self.linear_start_gated.FProp(theta.linear_start_gated,
                                                     inputs)
      else:
        inputs = self.linear_start.FProp(theta.linear_start, inputs)
        gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
      inputs = self._GLU(gated_inputs, act_inputs)

      # TODO(jamesqin): introduce depthwise conv2d with 3d inputs.
      # [b, t, d] --> [b, t, 1, d]
      inputs = tf.expand_dims(inputs, 2)
      adapted_blf_dims_mapping = None
      if p.activation_split_dims_mapping.blf is not None:
        adapted_blf_dims_mapping = p.activation_split_dims_mapping.blf.copy()
        adapted_blf_dims_mapping.insert(2, -1)
      inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
                                      adapted_blf_dims_mapping)
      theta.depthwise_conv1d.w = gshard_utils.MeshSplit(
          theta.depthwise_conv1d.w, p.device_mesh,
          p.weight_split_dims_mapping.hwim)
      if inputs.dtype == tf.bfloat16 and not py_utils.use_tpu():
        # Depthwise conv doesn't support bfloat16 on CPU.
        inputs = tf.cast(inputs, tf.float32)
        paddings = tf.cast(paddings, tf.float32)
      inputs, paddings = self.depthwise_conv1d.FProp(theta.depthwise_conv1d,
                                                     inputs, paddings)
      inputs, paddings = self._CastToFPropDtype((inputs, paddings))
      inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
                                      adapted_blf_dims_mapping)

      inputs = self._Normalize(theta, inputs, paddings)
      inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
                                      p.activation_split_dims_mapping.blf)

      inputs = self._ApplyActivation(inputs, p.conv_activation)

      inputs = self.linear_end.FProp(theta.linear_end, inputs)
      inputs = self.dropout.FProp(theta.dropout, inputs)

      output = inputs + unnormalized_inputs
      return output, paddings
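  # Shape walkthrough for FProp (a sketch, not part of the library). Assumes
  # an instantiated `lconv` as in the sketch above, with variables initialized:
  #
  #   x = tf.random.normal([2, 17, 256])   # [b, t, d]
  #   pad = tf.zeros([2, 17])              # [b, t], 0 marks valid positions
  #   y, out_pad = lconv.FProp(lconv.theta, x, pad)
  #   # y: [2, 17, 256]; the residual connection keeps the output dim equal to
  #   # the input dim.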
  def zero_state(self, batch_size):
    p = self.params
    with tf.name_scope('zero_state'):
      if p.is_causal:
        with tf.name_scope('depthwise_conv1d'):
          res = py_utils.NestedMap(
              conv_state=self.depthwise_conv1d.zero_state(batch_size))
        if hasattr(self.norm, 'zero_state'):
          with tf.name_scope('norm'):
            res.norm_state = self.norm.zero_state(batch_size)
        return res
      else:
        # If not causal, depthwise_conv1d does not have zero_state().
        return py_utils.NestedMap()
  def _NormalizeStep(self, theta, inputs, paddings, state0, state1):
    if hasattr(self.norm, 'StreamStep'):
      # TODO(jamesqin): support 3d inputs.
      # At present it's guaranteed GroupNorm.
      assert (isinstance(self.norm, bn_layers.GroupNormLayer) and
              self.norm.params.input_rank == 4)
      inputs, paddings, norm_state1 = self.norm.StreamStep(
          theta.norm, inputs, paddings, state0.norm_state)
      # [b, t, d]
      inputs = tf.squeeze(inputs, 2)
      state1.norm_state = norm_state1
    else:
      # [b, t, 1, d] -> [b, t, d]
      inputs = tf.squeeze(inputs, 2)
      if isinstance(self.norm, layers.LayerNorm):
        inputs = self.norm.FProp(theta.norm, inputs)
      else:
        raise NotImplementedError(
            'Only bn_layers.GroupNormLayer, layers.LayerNorm are supported.')
    # [b, t, d]
    return inputs, paddings
  def StreamStep(self, theta, inputs, paddings, state0):
    """Streams t steps.

    Args:
      theta: A NestedMap of layer params.
      inputs: [b, t, d].
      paddings: A 0/1 valued tensor of shape [b, t].
      state0: A NestedMap of tensors of the same struct as returned by
        zero_state().

    Returns:
      output: A Tensor of shape [b, t, d].
      paddings: the same as the input paddings.
      state1: A NestedMap of tensors of the same struct as state0.
    """
    p = self.params
    assert p.is_causal

    state1 = py_utils.NestedMap()
    with tf.name_scope(f'{p.name}/StreamStep'):
      unnormalized_inputs = inputs

      inputs = self.ln.FProp(theta.ln, inputs)
      if p.split_act_gated_linear_start:
        act_inputs = self.linear_start_act.FProp(theta.linear_start_act,
                                                 inputs)
        gated_inputs = self.linear_start_gated.FProp(theta.linear_start_gated,
                                                     inputs)
      else:
        inputs = self.linear_start.FProp(theta.linear_start, inputs)
        gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
      inputs = self._GLU(gated_inputs, act_inputs)

      # TODO(jamesqin): introduce depthwise conv2d with 3d inputs.
      # TODO(jamesqin): optimize DepthwiseConv1D.StreamStep()
      # [b, t, d] --> [b, t, 1, d]
      inputs = tf.expand_dims(inputs, 2)
      # [b, t, 1, d]
      inputs, paddings, conv_state1 = self.depthwise_conv1d.StreamStep(
          theta.depthwise_conv1d, inputs, paddings, state0.conv_state)
      state1.conv_state = conv_state1
      # [b, t, d]
      inputs, paddings = self._NormalizeStep(theta, inputs, paddings, state0,
                                             state1)

      inputs = self._ApplyActivation(inputs, p.conv_activation)

      inputs = self.linear_end.FProp(theta.linear_end, inputs)
      inputs = self.dropout.FProp(theta.dropout, inputs)

      output = inputs + unnormalized_inputs
      return output, paddings, state1
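# Streaming sketch for a causal LConvLayer (illustrative only, not part of the
# library): processing a long sequence chunk-by-chunk with StreamStep is
# expected to match FProp on the whole sequence, since the conv state carries
# the left context across chunk boundaries.
#
#   state = lconv.zero_state(batch_size=2)
#   outs = []
#   for x_chunk, pad_chunk in chunks:   # each chunk: [b, chunk_t, d], [b, chunk_t]
#     y_chunk, pad_chunk, state = lconv.StreamStep(
#         lconv.theta, x_chunk, pad_chunk, state)
#     outs.append(y_chunk)
#   y = tf.concat(outs, axis=1)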
def _AttenCtxIsSet(atten_context):
  return atten_context is not None and atten_context >= 0
def GShardMoELayerParams(num_devices,
                         num_groups,
                         num_experts,
                         per_expert_capacity_dim,
                         use_densebuilder=False):
  if use_densebuilder:
    moe_cls = gshard_builder.DenseBuilder
  else:
    moe_cls = gshard_builder.MoEBuilder
  return moe_cls.Params().Set(
      num_devices=num_devices,
      num_groups=num_groups,
      e_dim=num_experts,
      c_dim=per_expert_capacity_dim)
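# Example (a sketch): building MoE feed-forward params for a Conformer block.
# The expert/group/capacity numbers below are arbitrary placeholders.
#
#   moe_ff_tpl = GShardMoELayerParams(
#       num_devices=8, num_groups=8, num_experts=16,
#       per_expert_capacity_dim=64)
#   # Pass as fflayer_start_tpl / fflayer_end_tpl to
#   # ConformerLayer.CommonParams, or equivalently set
#   # use_moe_in_fflayer_{start,end}=True with the moe_* arguments there.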
class ConformerLayer(base_layer.BaseLayer):
  """Conformer layer as in https://arxiv.org/abs/2005.08100.

  Canonical version (with default params.)
    x = x + 1/2 * FFN(x)
    x = x + MHSA(x)
    x = x + Lconv(x)
    x = x + 1/2 * FFN(x)
    y = ln(x)

  Optionally one can change the order of MHSA and conv.
  """
  @classmethod
  def Params(cls):
    p = super().Params()
    p.Define('input_dim', None, 'Input dimension.')
    p.Define(
        'is_causal', False,
        'Whether to use causal lconv and MHSA layers. '
        'Note that atten_right_context must not be infinite (None) if '
        'is_causal is True. It is important to always set is_causal for the '
        'streaming case, and not expect to infer it from '
        'atten_{left,right}_context.')
    p.Define(
        'layer_order', 'mhsa_before_conv',
        'Only mhsa, conv, mhsa_before_conv or conv_before_mhsa are '
        'supported.')
    p.Define('dropout_prob', None, 'Dropout prob of inner components.')

    # fflayer tpl
    p.Define(
        'fflayer_start_tpl',
        layers_with_attention.TransformerFeedForwardLayer.Params(),
        'Layer params for Feed forward layer at the beginning. Supports '
        'using gshard_builder.MoEBuilder.Params() as well wherein the '
        'MoE() will be used. If set to None, this layer is excluded.')
    p.Define('trans_atten_tpl',
             attention_lib.TransformerAttentionLayer.Params(),
             'Self attention layer params.')
    p.Define(
        'lconv_tpl', LConvLayer.Params(),
        'Convolution module params. If set to None, this layer is excluded.')
    p.Define(
        'fflayer_end_tpl',
        layers_with_attention.TransformerFeedForwardLayer.Params(),
        'Layer params for Feed forward layer at the end. Supports using '
        'gshard_builder.MoEBuilder.Params() as well wherein the MoE() '
        'will be used.')
    p.Define(
        'fflayer_weight_sharing', False,
        'If True, will ignore `fflayer_end_tpl`, and will make the '
        'fflayer_end layer a weight-shared copy of the fflayer_start layer.')
    p.Define('final_ln_tpl', layers.LayerNorm.Params(), 'Final layer norm.')
    # If adapter_tpl is set, layer_out = adapter(conformer(layer_in))
    # The adapter must
    # 1. have instance method FProp(self, theta, in_nmap) -> out_nmap, where
    #    {in,out}_nmap must have 'features' and 'paddings' and
    #    $adapter_p.task_ids fields.
    # 2. have class method SetInputDim(cls, p, input_dim)
    p.Define('adapter_tpl', None, 'If set, runs an adapter layer in the end.')
    # https://b/167460492#comment16
    p.Define(
        'remat', False, 'Whether to rematerialize the layer. If true, '
        'intermediate tensors are not saved in FProp().')
    p.Define(
        'list_regex_dtypes', [],
        'A list of (regex, dtype) to set the data types of variables using '
        'regex. The default value is [] to use the existing data types without '
        'any changes. If a variable name matches the first regex in the list, '
        'the variable data type will be set by the corresponding dtype.')
    p.Define('allow_attention_summaries', False,
             'Allow plotting attention histogram and plot summaries.')
    return p
  @classmethod
  def CommonParams(cls,
                   *,
                   input_dim=None,
                   is_causal=False,
                   atten_num_heads=None,
                   atten_local_context=None,
                   atten_left_context=None,
                   atten_right_context=None,
                   atten_chunk_size=None,
                   atten_logit_cap=0.0,
                   use_relative_atten=True,
                   query_stride=1,
                   kernel_size=None,
                   fflayer_hidden_dim=None,
                   fflayer_activation='SWISH',
                   fflayer_residual_weight=0.5,
                   layer_order='mhsa_before_conv',
                   dropout_prob=0.,
                   conv_norm_layer_tpl=None,
                   fprop_dtype=None,
                   use_moe_in_fflayer_start=False,
                   use_moe_in_fflayer_end=False,
                   moe_num_partitions=None,
                   moe_num_experts=None,
                   moe_num_groups=None,
                   moe_per_capacity_dim=None,
                   fflayer_start_tpl=None,
                   fflayer_end_tpl=None,
                   trans_atten_tpl=None,
                   lconv_tpl=None,
                   list_regex_dtypes=None):
    assert all([input_dim])
    if layer_order != 'conv':
      assert atten_num_heads or trans_atten_tpl
    if layer_order == 'mhsa':
      assert not any([kernel_size, conv_norm_layer_tpl, lconv_tpl])
    else:
      assert kernel_size
    if _AttenCtxIsSet(atten_local_context):
      assert not _AttenCtxIsSet(atten_left_context) and not _AttenCtxIsSet(
          atten_right_context
      ), ('atten_local_context and atten_{left,right}_context can not be set '
          'at the same time.')
      atten_left_context = atten_local_context + 1  # including self position.
      atten_right_context = atten_local_context

    if is_causal and trans_atten_tpl is None:
      # None is different from 0, the former is 'infinite'.
      assert atten_right_context is not None, (
          'is_causal is not compatible with infinite atten_right_context '
          '(None).')

    p = cls.Params().Set(
        input_dim=input_dim,
        is_causal=is_causal,
        layer_order=layer_order,
        dropout_prob=dropout_prob)

    if use_moe_in_fflayer_start:
      assert fflayer_start_tpl is None
      fflayer_start_tpl = GShardMoELayerParams(moe_num_partitions,
                                               moe_num_experts, moe_num_groups,
                                               moe_per_capacity_dim)
    if use_moe_in_fflayer_end:
      assert fflayer_end_tpl is None
      fflayer_end_tpl = GShardMoELayerParams(moe_num_partitions,
                                             moe_num_experts, moe_num_groups,
                                             moe_per_capacity_dim)

    def _ConfigureFF(fflayer_tpl):
      config_kwargs = dict(
          input_dim=input_dim,
          hidden_dim=fflayer_hidden_dim,
          activation=fflayer_activation,
          residual_weight=fflayer_residual_weight,
          dropout_prob=dropout_prob)
      if fflayer_tpl is None:
        return cls.ConfigFFLayer(
            tpl=layers_with_attention.TransformerFeedForwardLayer.Params(),
            **config_kwargs)
      elif issubclass(fflayer_tpl.cls, gshard_builder.MoEBuilder):
        # TODO(rpang): make users call ConfigMoEParams directly.
        return cls.ConfigMoEParams(tpl=fflayer_tpl, **config_kwargs)
      elif fflayer_tpl.use_block_diagonal_matmul_pl:
        # Override params other than block_diag_matmul options.
        return cls.ConfigFFLayer(tpl=fflayer_tpl, **config_kwargs)
      else:
        assert fflayer_hidden_dim is None, fflayer_hidden_dim
        assert fflayer_activation is None, fflayer_activation
        assert fflayer_residual_weight is None, fflayer_residual_weight
        return fflayer_tpl

    # Set the two feed forward modules.
    p.fflayer_start_tpl = _ConfigureFF(fflayer_start_tpl)
    p.fflayer_end_tpl = _ConfigureFF(fflayer_end_tpl)

    # Set the MHSA module.
    if trans_atten_tpl is not None:
      assert atten_left_context is None
      assert atten_right_context is None
      assert use_relative_atten is None
      assert atten_num_heads is None
      assert query_stride == 1
      p.trans_atten_tpl = trans_atten_tpl
    else:
      atten_tpl = cls._ConfigSelfAttenContext(
          atten_left_context,
          atten_right_context,
          atten_chunk_size=atten_chunk_size,
          use_relative_atten=use_relative_atten,
          query_stride=query_stride,
          relative_pos_emb_dim=input_dim)
      if query_stride == 1:
        p.trans_atten_tpl = attention_lib.TransformerAttentionLayer.Params(
        ).Set(
            atten_tpl=atten_tpl, num_heads=atten_num_heads)
      else:
        p.trans_atten_tpl = (
            attention_lib.FunnelTransformerAttentionLayer.Params().Set(
                atten_tpl=atten_tpl, num_heads=atten_num_heads))
        p.trans_atten_tpl.funnel_tpl.stride = query_stride
        p.trans_atten_tpl.res_funnel_tpl.stride = query_stride

    # Set the convolution module.
    if lconv_tpl is not None:
      p.lconv_tpl = lconv_tpl
    if kernel_size:
      p.lconv_tpl.kernel_size = kernel_size
    if conv_norm_layer_tpl is not None:
      p.lconv_tpl.conv_norm_layer_tpl = conv_norm_layer_tpl
    if fprop_dtype is not None:
      p.cls.SetFPropDtype(p, fprop_dtype)
    if isinstance(p.trans_atten_tpl.atten_tpl, list):
      for a in p.trans_atten_tpl.atten_tpl:
        a.atten_logit_cap = atten_logit_cap
    else:
      p.trans_atten_tpl.atten_tpl.atten_logit_cap = atten_logit_cap
    if list_regex_dtypes is not None:
      p.list_regex_dtypes = list_regex_dtypes
    if layer_order == 'mhsa':
      p.lconv_tpl = None
    return p
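  # Example (a sketch, not part of the library): a canonical Conformer block
  # with local relative self-attention. Dimensions and context sizes are
  # illustrative; Instantiate() assumes a Lingvo variable-creation context.
  #
  #   p = ConformerLayer.CommonParams(
  #       input_dim=512,
  #       atten_num_heads=8,
  #       atten_local_context=15,
  #       kernel_size=32,
  #       fflayer_hidden_dim=2048,
  #       dropout_prob=0.1)
  #   p.name = 'conformer_block'
  #   layer = p.Instantiate()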
  @classmethod
  def Stride(cls, params):
    if 'funnel_tpl' in params.trans_atten_tpl:
      return params.trans_atten_tpl.funnel_tpl.stride
    return 1
  @classmethod
  def NumOutputNodes(cls, p):
    return p.input_dim
  @classmethod
  def _ConfigSelfAttenContext(cls, atten_left_context, atten_right_context, *,
                              use_relative_atten, atten_chunk_size,
                              query_stride, relative_pos_emb_dim):
    # TODO(jamesqin): add an attention factory in batch_major_attention.
    if atten_chunk_size is not None:
      if use_relative_atten:
        atten_type = 'chunk_relative'
      else:
        atten_type = 'chunk'
    elif not _AttenCtxIsSet(atten_left_context) and not _AttenCtxIsSet(
        atten_right_context):
      # No atten context set, each position attends to all positions.
      atten_type = 'global' if not use_relative_atten else 'global_relative'
    elif not _AttenCtxIsSet(atten_left_context) and atten_right_context == 0:
      # Left context is infinite, right context is 0.
      assert not use_relative_atten, (
          'Relative attention isn\'t supported for causal attention.')
      atten_type = 'global_causal'
    else:
      atten_type = 'local_relative' if use_relative_atten else 'local'

    if atten_type == 'global_relative':
      atten_tpl = (
          attention_lib.MultiHeadedAttentionXL.Params().Set(
              rel_pos_emb_dim=relative_pos_emb_dim))
    elif atten_type == 'local_relative':
      atten_tpl = attention_lib.LocalSelfAttentionXL.Params().Set(
          left_context=atten_left_context,
          right_context=atten_right_context,
          rel_pos_emb_dim=relative_pos_emb_dim,
          query_stride=query_stride)
    elif atten_type == 'local':
      atten_tpl = attention_lib.LocalSelfAttention.Params().Set(
          left_context=atten_left_context,
          right_context=atten_right_context,
          query_stride=query_stride)
    elif atten_type == 'chunk':
      atten_tpl = attention_lib.ChunkwiseSelfAttention.Params().Set(
          left_context=atten_left_context,
          right_context=atten_right_context,
          chunk_size=atten_chunk_size)
    elif atten_type == 'chunk_relative':
      atten_tpl = attention_lib.ChunkwiseSelfAttentionXL.Params().Set(
          left_context=atten_left_context,
          right_context=atten_right_context,
          chunk_size=atten_chunk_size,
          rel_pos_emb_dim=relative_pos_emb_dim)
    else:
      # No op for 'global' atten
      assert atten_type in ('global', 'global_causal'), (
          f'Unknown atten_type {atten_type}')
      atten_tpl = attention_lib.MultiHeadedAttention.Params()
    return atten_tpl
  @classmethod
  def SetFPropDtype(cls, p, fprop_dtype):
    p.fprop_dtype = fprop_dtype

    for sub_p in (p.lconv_tpl, p.trans_atten_tpl, p.fflayer_start_tpl,
                  p.fflayer_end_tpl):
      if sub_p is not None:
        sub_p.cls.SetFPropDtype(sub_p, fprop_dtype)

    return p
  @classmethod
  def ConfigFFLayer(cls, *, tpl, input_dim, hidden_dim, activation,
                    residual_weight, dropout_prob):
    """Configures tpl params for a feed-forward layer."""
    p = tpl.Copy().Set(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        activation=activation,
        residual_weight=residual_weight,
        residual_dropout_prob=dropout_prob,
        relu_dropout_prob=dropout_prob)
    return p
  @classmethod
  def ConfigMoEParams(cls, *, tpl, input_dim, hidden_dim, activation,
                      residual_weight, dropout_prob):
    """Configures tpl params for a MoE layer."""
    moe_builder_p = tpl.Copy().Set(
        model_dim=input_dim,
        dropout_rate=dropout_prob,
        moe_hidden_dim=hidden_dim,
        moe_activation=activation)
    if moe_builder_p.cls == gshard_builder.MoEBuilder:
      if moe_builder_p.num_devices is None:
        raise ValueError('num_devices must be specified for MoEBuilder.')
    if residual_weight != 0.5:
      raise ValueError('residual_weight must be 0.5')
    return moe_builder_p
  def __init__(self, params):
    super().__init__(params)
    p = self.params
    assert p.layer_order in [
        'mhsa', 'conv', 'mhsa_before_conv', 'conv_before_mhsa'
    ]
    if p.layer_order == 'mhsa':
      assert not self.has_lconv, 'mhsa must not have a lconv block.'

    # Change the variable dtypes by list_regex_dtypes.
    with py_utils.VariableListDtypeRegexScope(self.params.list_regex_dtypes):
      if self.has_fflayer_start:
        fflayer_start_p = self._ConfigFFLayerOrMoEParams(
            p.fflayer_start_tpl, 'fflayer_start')
        if fflayer_start_p.name:
          assert fflayer_start_p.name == 'fflayer_start_moe'
        else:
          fflayer_start_p.name = 'fflayer_start'
        self.CreateChild(fflayer_start_p.name, fflayer_start_p)

      fflayer_end_p = self._ConfigFFLayerOrMoEParams(p.fflayer_end_tpl,
                                                     'fflayer_end')
      if fflayer_end_p.name:
        assert fflayer_end_p.name == 'fflayer_end_moe'
      else:
        fflayer_end_p.name = 'fflayer_end'
      if not p.fflayer_weight_sharing:
        self.CreateChild(fflayer_end_p.name, fflayer_end_p)
      else:
        self.AddChild(fflayer_end_p.name, self.children[fflayer_start_p.name])

      # For local MHSA, is_masked is ignored, thus it's safe to set is_masked
      # based on p.is_causal, for global and local MHSA cases.
      if self.has_mhsa:
        trans_atten_p = p.trans_atten_tpl.Copy().Set(
            input_dim=p.input_dim,
            is_masked=p.is_causal,
            atten_dropout_prob=p.dropout_prob,
            residual_dropout_prob=p.dropout_prob)
        if tf.logging.vlog_is_on(2):
          for line in trans_atten_p.atten_tpl.ToText().split('\n'):
            tf.logging.info('ConformerLayer.atten_tpl: %s', line)
        self.CreateChild('trans_atten', trans_atten_p)

      if self.has_lconv:
        lconv_p = p.lconv_tpl.Copy().Set(
            input_dim=p.input_dim, is_causal=p.is_causal)
        self.CreateChild('lconv', lconv_p)

      ln_p = p.final_ln_tpl.Copy().Set(name='final_ln', input_dim=p.input_dim)
      self.CreateChild('final_ln', ln_p)

      if p.adapter_tpl:
        p.adapter_tpl.cls.SetInputDim(p.adapter_tpl, p.input_dim)
        self.CreateChild('adapter', p.adapter_tpl)

  # lconv and fflayer_start get special treatment: they can be absent, because
  # Transformer doesn't have those blocks.
  @property
  def has_lconv(self):
    return bool(self.params.lconv_tpl)

  @property
  def has_mhsa(self):
    return bool('mhsa' in self.params.layer_order)

  @property
  def has_fflayer_start(self):
    return bool(self.params.fflayer_start_tpl)
  def _ConfigFFLayerOrMoEParams(self, fflayer_tpl, name_prefix):
    p = self.params
    fflayer_tpl = fflayer_tpl.Copy()
    if not issubclass(fflayer_tpl.cls, gshard_builder.MoEBuilder):
      if hasattr(fflayer_tpl, 'input_dim'):
        fflayer_tpl.Set(input_dim=p.input_dim)
      return fflayer_tpl
    # TODO(rpang): migrate clients to TransformerShardedMoeLayer.
    fflayer_tpl.model_dim = p.input_dim
    moe_builder = fflayer_tpl.Instantiate()
    name = name_prefix + '_moe'
    # TODO(rpang): find a way to configure residual weight.
    moe_p = moe_builder.EncoderLayer(
        name, moe_builder.MoE(name), residual_weight=0.5)
    return moe_p
  def _SelfAtten(self, theta, inputs, paddings):
    if isinstance(self.trans_atten,
                  attention_lib.FunnelTransformerAttentionLayer):
      inputs, paddings, atten_probs = self.trans_atten.FProp(
          theta.trans_atten,
          query_vec=inputs,
          source_vecs=None,
          paddings=paddings)
    else:
      inputs, atten_probs = self.trans_atten.FProp(
          theta.trans_atten,
          query_vec=inputs,
          source_vecs=None,
          paddings=paddings)
    return inputs, paddings, atten_probs
  def _LConv(self, theta, inputs, paddings):
    assert self.has_lconv and self.params.layer_order != 'mhsa', (
        'mhsa does not have a lconv block.')
    inputs, paddings = self.lconv.FProp(theta.lconv, inputs, paddings)
    return inputs, paddings
  def _MoeOrFFLayer(self, theta, fflayer_name, features, paddings, aux_loss):
    """FProp for MoE or Feed forward layer.

    Args:
      theta: Layer theta, a NestedMap of Tensors.
      fflayer_name: Child FFLayer name as created in __init__, for example
        'fflayer_end'. This assumes that the MoE layer, if created, follows
        the naming convention (`fflayer_name` + `_moe`).
      features: [batch, seqlen, dim0].
      paddings: [batch, seqlen].
      aux_loss: [], can be None.

    Returns:
      features: [batch, seqlen, dim0].
      paddings: [batch, seqlen].
      aux_loss: [], is None if input aux_loss is None.
    """
    if fflayer_name in self.children:
      outputs = self.children[fflayer_name].FProp(
          theta.GetItem(fflayer_name), features, paddings)
      return outputs, paddings, aux_loss
    else:
      moe_fflayer_name = fflayer_name + '_moe'
      if moe_fflayer_name not in self.children:
        raise AssertionError(
            '{} child layer not present.'.format(moe_fflayer_name))
      if moe_fflayer_name not in theta:
        raise AssertionError(
            '{} layer theta not present.'.format(moe_fflayer_name))
      # 0 - padded positions and 1 - non-padded positions.
      segment_ids = tf.cast(1. - paddings, tf.int32)
      segment_pos = tf.zeros_like(segment_ids)  # not used but required by MoE.
      moe_in = py_utils.NestedMap(
          vec=features, segment_id=segment_ids, segment_pos=segment_pos)
      moe_out = self.children[moe_fflayer_name].FProp(
          theta.GetItem(moe_fflayer_name), moe_in)
      moe_aux_loss = moe_out.aux_loss
      if aux_loss is not None:
        assert not moe_aux_loss.shape.rank, 'MoE aux-loss should be a scalar.'
        if len(py_utils.GetShape(aux_loss)) == 1:
          b_size = py_utils.GetShape(aux_loss)[0]
          moe_aux_loss = tf.tile(tf.expand_dims(moe_aux_loss, axis=0), [b_size])
        assert moe_aux_loss.shape.rank == aux_loss.shape.rank
        aux_loss += moe_aux_loss
      else:
        aux_loss = moe_aux_loss
      return moe_out.vec, paddings, aux_loss
  def _FProp(self, theta, in_nmap):
    p = self.params
    with tf.name_scope(p.name):
      features, paddings = in_nmap.features, in_nmap.paddings
      aux_loss = in_nmap.Get('aux_loss')
      features, paddings = self._CastToFPropDtype((features, paddings))
      out_nmap = py_utils.NestedMap()
      if self.has_fflayer_start:
        features, paddings, aux_loss = self._MoeOrFFLayer(
            theta, 'fflayer_start', features, paddings, aux_loss)
      atten_probs = None
      if p.layer_order == 'mhsa':
        features, paddings, atten_probs = self._SelfAtten(
            theta, features, paddings)
      elif p.layer_order == 'conv':
        features, paddings = self._LConv(theta, features, paddings)
      elif p.layer_order == 'mhsa_before_conv':
        features, paddings, atten_probs = self._SelfAtten(
            theta, features, paddings)
        features, paddings = self._LConv(theta, features, paddings)
      else:
        assert p.layer_order == 'conv_before_mhsa'
        features, paddings = self._LConv(theta, features, paddings)
        features, paddings, atten_probs = self._SelfAtten(
            theta, features, paddings)
      features, paddings, aux_loss = self._MoeOrFFLayer(theta, 'fflayer_end',
                                                        features, paddings,
                                                        aux_loss)
      features = self.final_ln.FProp(theta.final_ln, features)
      out_nmap = in_nmap.DeepCopy()
      if p.adapter_tpl:
        adapter_in_map = in_nmap.DeepCopy()
        adapter_in_map.features, adapter_in_map.paddings = features, paddings
        adapter_out_nmap = self.adapter.FProp(theta.adapter, adapter_in_map)
        features, paddings = (adapter_out_nmap.features,
                              adapter_out_nmap.paddings)
        features, paddings = self._CastToFPropDtype((features, paddings))
      out_nmap.features, out_nmap.paddings = features, paddings
      if aux_loss is not None:
        out_nmap.aux_loss = aux_loss
      self._AddAttentionSummaries(p.name, atten_probs)
      return out_nmap
  def FProp(self, theta, in_nmap):
    p = self.params
    if not p.remat:
      return self._FProp(theta, in_nmap)

    def CellFn(theta, state0, unused_inputs):
      out_nmap = self._FProp(theta, state0)
      return out_nmap, py_utils.NestedMap()

    _, state1 = recurrent.Recurrent(
        theta=theta,
        state0=in_nmap,
        inputs=py_utils.NestedMap(
            inputs=tf.zeros([1, 0])),  # A dummy input of shape [T, ?].
        cell_fn=CellFn,
        allow_implicit_capture=p.allow_implicit_capture)
    return state1
  def zero_state(self, batch_size):
    if self.params.is_causal:
      lconv_state = py_utils.NestedMap()
      atten_state = py_utils.NestedMap()
      if self.has_lconv:
        with tf.name_scope('lconv'):
          lconv_state = self.lconv.zero_state(batch_size)
      if self.has_mhsa:
        with tf.name_scope('atten'):
          atten_state = self.trans_atten.zero_state(batch_size)
      return py_utils.NestedMap(
          lconv_state=lconv_state, atten_state=atten_state)
    else:
      return py_utils.NestedMap()
  def StreamStep(self, theta, inputs, paddings, state0):
    """Streams t steps.

    Args:
      theta: A NestedMap of read-only layer params.
      inputs: A tensor of shape [b, t, d].
      paddings: A 0/1 valued tensor of shape [b, t].
      state0: A NestedMap of tensors of the same struct as returned by
        zero_state().

    Returns:
      outputs: A tensor of shape [b, t, d].
      paddings: the same as the input paddings.
      state1: A NestedMap of tensors of the same struct as state0.
    """
    p = self.params
    assert p.is_causal
    assert not p.remat

    with tf.name_scope(f'{p.name}/StreamStep'):
      features, aux_loss = inputs, None
      if self.has_fflayer_start:
        features, paddings, aux_loss = self._MoeOrFFLayer(
            theta, 'fflayer_start', features, paddings, aux_loss)

      if p.layer_order == 'mhsa':
        features, paddings, atten_state1 = self.trans_atten.StreamStep(
            theta.trans_atten, features, paddings, state0.atten_state)
      elif p.layer_order == 'conv':
        features, paddings, lconv_state1 = self.lconv.StreamStep(
            theta.lconv, features, paddings, state0.lconv_state)
      elif p.layer_order == 'mhsa_before_conv':
        features, paddings, atten_state1 = self.trans_atten.StreamStep(
            theta.trans_atten, features, paddings, state0.atten_state)
        features, paddings, lconv_state1 = self.lconv.StreamStep(
            theta.lconv, features, paddings, state0.lconv_state)
      else:
        assert p.layer_order == 'conv_before_mhsa'
        features, paddings, lconv_state1 = self.lconv.StreamStep(
            theta.lconv, features, paddings, state0.lconv_state)
        features, paddings, atten_state1 = self.trans_atten.StreamStep(
            theta.trans_atten, features, paddings, state0.atten_state)
      if not self.has_lconv:
        lconv_state1 = py_utils.NestedMap()
      if not self.has_mhsa:
        atten_state1 = py_utils.NestedMap()

      features, paddings, _ = self._MoeOrFFLayer(theta, 'fflayer_end', features,
                                                 paddings, aux_loss)
      outputs = self.final_ln.FProp(theta.final_ln, features)
      state1 = py_utils.NestedMap(
          lconv_state=lconv_state1, atten_state=atten_state1)
      return outputs, paddings, state1
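  # Streaming sketch (illustrative only, not part of the library): a causal
  # ConformerLayer `layer` (p.is_causal=True with a finite atten_right_context,
  # as required by the assertions above) can be run chunk-by-chunk:
  #
  #   state = layer.zero_state(batch_size)
  #   for feats_chunk, pad_chunk in chunks:   # each chunk: [b, chunk_t, d]
  #     out_chunk, pad_chunk, state = layer.StreamStep(
  #         layer.theta, feats_chunk, pad_chunk, state)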
  def _AddAttentionSummaries(self, name, atten_probs):
    # Plots attention prob summaries for joint network.
    # TODO(ankurbpn): Check why op wasn't compiling on TPUs.
    p = self.params
    if (not py_utils.use_tpu() and p.allow_attention_summaries and
        atten_probs is not None):
      atten_shape = tf.shape(atten_probs)
      atten_probs = tf.reshape(
          atten_probs, [atten_shape[0], atten_shape[1], -1, atten_shape[-1]])
      # Only plots first example of the batch.
      atten_probs = tf.reduce_mean(atten_probs[0:1, :, :, :], 1)
      self._AddAttenProbsImageSummary(name, atten_probs)
      self._AddAttenProbsHistogramSummary(name, atten_probs)
  def _AddAttenProbsHistogramSummary(self, name, atten_probs):
    """Add histogram summary of attention probs."""
    summary_utils.histogram(name + '/atten_probs', atten_probs)
  def _AddAttenProbsImageSummary(self, name, atten_probs):
    """Add image summary of input attention probabilities."""

    def PlotAttention(fig, axes, cur_atten_probs, title):
      plot.AddImage(fig, axes, cur_atten_probs, title=title)
      axes.set_ylabel(plot.ToUnicode('Output sequence index'), wrap=True)
      axes.set_xlabel(plot.ToUnicode('Input sequence index'), wrap=True)

    with plot.MatplotlibFigureSummary(
        name + '/atten_example',
        figsize=(10, 10),
        max_outputs=1,
        subplot_grid_shape=(1, 1)) as fig:
      # Extract first entry in batch of attention prob matrices
      # [tgt_len, src_len]
      fig.AddSubplot([atten_probs], PlotAttention, title='atten_probs')
def ApplyGshard(conformer_tpl,
                device_mesh=None,
                proj_w_split_list=None,
                proj_activation_split_list=None,
                atten_dnh_w_split=None,
                atten_blnh_activation_split=None,
                atten_bld_activation_split=None,
                lconv_df_w_split=None,
                lconv_hwim_w_split=None,
                lconv_fd_w_split=None,
                lconv_blf_activation_split=None,
                lconv_bld_activation_split=None,
                moe_device_mesh_shape=None,
                moe_emh_split=None,
                moe_ehm_split=None,
                moe_eah_split=None,
                moe_eam_split=None,
                moe_egcm_split=None,
                moe_gecm_split=None,
                moe_gsec_split=None,
                moe_blm_split=None):
  """Applies gshard on conformer params.

  Args:
    conformer_tpl: A NestedMap of conformer Params.
    device_mesh: A numpy.ndarray specifying the device mesh on which the
      computation is sharded.
    proj_w_split_list: A list of mesh splits specifying how weights are sharded
      for fflayer.
    proj_activation_split_list: A list of mesh splits specifying how
      activations are sharded for fflayer.
    atten_dnh_w_split: Mesh split of the attention projection weight with the
      shape of [model_dim, num_heads, dim_per_head].
    atten_blnh_activation_split: Mesh split of the attention activation with
      the shape of [batch, seq_len, num_heads, dim_per_head].
    atten_bld_activation_split: Mesh split of the attention activation with the
      shape of [batch, seq_len, model_dim].
    lconv_df_w_split: Mesh split of the weights in lconv with the shape of
      [model_dim, ff_hidden_dim].
    lconv_hwim_w_split: Mesh split of the depthwise conv weight in lconv with
      the shape of [height, width, in_channels, channel_multiplier].
    lconv_fd_w_split: Mesh split of the weights in lconv with the shape of
      [ff_hidden_dim, model_dim].
    lconv_blf_activation_split: Mesh split of the activations in lconv with the
      shape of [batch, seq_len, ff_hidden_dim].
    lconv_bld_activation_split: Mesh split of the activations in lconv with the
      shape of [batch, seq_len, model_dim].
    moe_device_mesh_shape: Device mesh shape for MoE params.
    moe_emh_split: Split dimension for MoE EMH activation. See: gshard_layers.
    moe_ehm_split: Split dimension for MoE EHM activation. See: gshard_layers.
    moe_eah_split: Split dimension for MoE EAH activation. See: gshard_layers.
    moe_eam_split: Split dimension for MoE EAM activation. See: gshard_layers.
    moe_egcm_split: Split dimension for MoE EGCM activation. See:
      gshard_layers.
    moe_gecm_split: Split dimension for MoE GECM activation. See:
      gshard_layers.
    moe_gsec_split: Split dimension for MoE GSEC activation. See:
      gshard_layers.
    moe_blm_split: Split dimension for MoE BLM activation. See: gshard_layers.

  Returns:
    The updated conformer_tpl.
  """
  # Not all attention classes support gshard; if one doesn't, errors will be
  # thrown here.
  conformer_tpl.trans_atten_tpl.atten_tpl.device_mesh = device_mesh
  conformer_tpl.trans_atten_tpl.atten_tpl.weight_split_dims_mapping = (
      atten_dnh_w_split)
  conformer_tpl.trans_atten_tpl.atten_tpl.proj_tpl.weight_split_dims_mapping = (
      atten_dnh_w_split)
  conformer_tpl.trans_atten_tpl.atten_tpl.activation_split_dims_mapping.blnh = (
      atten_blnh_activation_split)
  conformer_tpl.trans_atten_tpl.atten_tpl.activation_split_dims_mapping.bld = (
      atten_bld_activation_split)
  # TODO(jamesqin): support residual_proj xla sharding too.
  def _ApplyGshardToMoELayer(fflayer_tpl):
    if not issubclass(fflayer_tpl.cls, gshard_builder.DenseBuilder):
      return
    fflayer_tpl.device_mesh = device_mesh
    fflayer_tpl.device_mesh_shape = moe_device_mesh_shape
    fflayer_tpl.emh_split = moe_emh_split
    fflayer_tpl.ehm_split = moe_ehm_split
    fflayer_tpl.eah_split = moe_eah_split
    fflayer_tpl.eam_split = moe_eam_split
    fflayer_tpl.egcm_split = moe_egcm_split
    fflayer_tpl.gecm_split = moe_gecm_split
    fflayer_tpl.gsec_split = moe_gsec_split
    fflayer_tpl.blm_split = moe_blm_split

  # Set sharding in FFLayers.
  if not issubclass(conformer_tpl.fflayer_start_tpl.cls,
                    gshard_builder.MoEBuilder):
    conformer_tpl.fflayer_start_tpl.fflayer_tpl.Set(
        device_mesh=device_mesh,
        weight_split_dims_mapping_list=proj_w_split_list,
        activation_split_dims_mapping_list=proj_activation_split_list)
  else:
    _ApplyGshardToMoELayer(conformer_tpl.fflayer_start_tpl)
  if not issubclass(conformer_tpl.fflayer_end_tpl.cls,
                    gshard_builder.MoEBuilder):
    conformer_tpl.fflayer_end_tpl.fflayer_tpl.Set(
        device_mesh=device_mesh,
        weight_split_dims_mapping_list=proj_w_split_list,
        activation_split_dims_mapping_list=proj_activation_split_list)
  else:
    _ApplyGshardToMoELayer(conformer_tpl.fflayer_end_tpl)

  # Set sharding in LConv layer.
  conformer_tpl.lconv_tpl.Set(
      split_act_gated_linear_start=True, device_mesh=device_mesh)
  lconv_w_split = conformer_tpl.lconv_tpl.weight_split_dims_mapping
  lconv_w_split.df = lconv_df_w_split
  lconv_w_split.hwim = lconv_hwim_w_split
  lconv_w_split.fd = lconv_fd_w_split
  lconv_activation_split = conformer_tpl.lconv_tpl.activation_split_dims_mapping
  lconv_activation_split.blf = lconv_blf_activation_split
  lconv_activation_split.bld = lconv_bld_activation_split
  return conformer_tpl
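# Example (a sketch): annotating a Conformer template for 2D SPMD sharding.
# The mesh shape and split assignments below are illustrative placeholders
# only; in practice they must match the device topology and mirror the
# dimension-name conventions documented above (see also
# LConvLayer.SetCanonicalShardingParams for the canonical lconv splits).
#
#   device_mesh = np.arange(16).reshape([8, 2])
#   conformer_tpl = ApplyGshard(
#       conformer_tpl,
#       device_mesh=device_mesh,
#       proj_w_split_list=[[0, 1], [1, 0]],
#       proj_activation_split_list=[[0, -1, 1], [0, -1, -1]],
#       atten_dnh_w_split=[0, 1, -1],
#       atten_blnh_activation_split=[0, -1, 1, -1],
#       atten_bld_activation_split=[0, -1, -1],
#       lconv_df_w_split=[0, 1],
#       lconv_hwim_w_split=[-1, -1, 1, -1],
#       lconv_fd_w_split=[1, 0],
#       lconv_blf_activation_split=[0, -1, 1],
#       lconv_bld_activation_split=[0, -1, -1])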