Source code for lingvo.tasks.milan.common_schema

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines a common `tf.train.Example` format for image-caption(-like) data."""

import lingvo.compat as tf


[docs]def _Feature(shape, dtype=tf.float32): shape = tf.TensorShape(shape) if shape.is_fully_defined(): return tf.io.FixedLenFeature(shape, dtype=dtype) else: if shape[:1].is_fully_defined() or not shape[1:].is_fully_defined(): raise ValueError(f'Unsupported sequence shape {shape}') return tf.io.FixedLenSequenceFeature( shape[1:], dtype=dtype, allow_missing=True)
[docs]def ImageFeatures(images_per_example=1): """Returns definitions of common image features. Args: images_per_example: Number of images stored in each example. Returns: A dict of feature definitions usable with `tf.io.parse_example`. """ images_shape = ([images_per_example] if images_per_example > 1 else []) return { 'image/encoded': _Feature(images_shape, tf.string), 'image/id': _Feature(images_shape, tf.int64), }
[docs]def TextFeatures(captions_per_example=1, bert_embeddings_shape=None): """Returns definitions of the common text features. Args: captions_per_example: Number of text captions stored in each example. bert_embeddings_shape: Optional time-major shape of BERT embedding sequences to include in the schema (if given). Set the leading (time) dimension to `None` if the sequences have variable length. Returns: A dict of feature definitions usable with `tf.io.parse_example`. """ captions_shape = ([captions_per_example] if captions_per_example > 1 else []) features = { 'text/captions': _Feature(captions_shape, tf.string), 'text/id': _Feature(captions_shape, tf.int64), } if bert_embeddings_shape is not None: max_length, feature_dim = bert_embeddings_shape features.update({ # Token-level BERT embeddings. 'text/bert/token_features': _Feature(captions_shape + [max_length, feature_dim]), # Lengths (in tokens) of the token-level embeddings. 'text/bert/lengths': _Feature(captions_shape, tf.int64) }) return features
[docs]def AudioFeatures(mfcc_shape=None, cpc8k_shape=None): """Returns definitions of common audio features. Args: mfcc_shape: Optional time-major shape of MFCC features to include in the schema (if given). Set the leading (time) dimension to `None` if the sequences have variable length. cpc8k_shape: Optional time-major shape of CPC-8K features to include. Returns: A dict of feature definitions usable with `tf.io.parse_example`. """ features = {'audio/id': _Feature([], tf.int64)} if mfcc_shape is not None: features['audio/mfccs'] = _Feature(mfcc_shape) if cpc8k_shape is not None: features['audio/cpc8k/features'] = _Feature(cpc8k_shape) return features