Source code for lingvo.tools.compute_stats

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compute stats from tfrecords files."""

import lingvo.compat as tf
import numpy as np

tf.flags.DEFINE_string('input_filepattern', '',
                       'File pattern of binary tfrecord files.')
tf.flags.DEFINE_integer('frame_size', 1, 'Size of the frame, for reshaping.')
tf.flags.DEFINE_integer('num_buckets', 8, 'Number of buckets for the length.')
tf.flags.DEFINE_string('feature_name', None, 'Name of feature to examine.')

FLAGS = tf.flags.FLAGS
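
# Example invocation (assuming the module is importable as
# lingvo.tools.compute_stats; flag values are illustrative and must match the
# dataset being inspected, e.g. --feature_name=frames for ASR features or
# --feature_name=source_id for MT token ids):
#
#   python -m lingvo.tools.compute_stats \
#     --input_filepattern=/path/to/train.tfrecords-* \
#     --feature_name=frames \
#     --frame_size=80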


class StatsCollector:
  """Collects example lengths and per-dimension mean/variance statistics."""

  def __init__(self):
    self._num_examples = 0
    self._lengths = []
    self._num_frames = 0
    # Running sums of frame values and of squared frame values, used to
    # compute the per-dimension mean and variance.
    self._mean_acc = np.zeros(FLAGS.frame_size, dtype=np.float64)
    self._var_acc = np.zeros(FLAGS.frame_size, dtype=np.float64)

  def _AccumulateMoments(self, float_list):
    # Accumulate first and second moments of each frame dimension.
    frames = np.reshape(float_list, [-1, FLAGS.frame_size])
    self._num_frames += frames.shape[0]
    self._mean_acc += np.sum(frames, axis=0)
    self._var_acc += np.sum(frames * frames, axis=0)

  def _ComputeMeanVar(self):
    mu = self._mean_acc / self._num_frames
    # The user is in charge of replacing NaNs with a floor value.
    # Note: v is sqrt(E[x^2] - E[x]^2), i.e. the per-dimension standard
    # deviation, even though it is logged as 'var' below.
    v = np.sqrt(self._var_acc / self._num_frames - mu * mu)
    return mu, v

  def Accumulate(self, tf_ex):
    self._num_examples += 1
    if self._num_examples % 10000 == 0:
      tf.logging.info('Processing example %u...', self._num_examples)
    v = tf_ex.features.feature[FLAGS.feature_name]
    if v.HasField('float_list'):
      num_frames = len(v.float_list.value) // FLAGS.frame_size
      self._AccumulateMoments(v.float_list.value)
    elif v.HasField('int64_list'):
      num_frames = len(v.int64_list.value) // FLAGS.frame_size
    else:
      tf.logging.fatal(
          'Not sure what to do with value. '
          'Only float/int64 lists are supported: %s', v)
    self._lengths.append(num_frames)

  def _PrintLengthBuckets(self):
    sorted_lengths = sorted(self._lengths)
    num_buckets = FLAGS.num_buckets
    n = len(sorted_lengths)
    # Bucket upper limits are taken at evenly spaced quantiles of the sorted
    # lengths; the last bucket ends at the maximum observed length.
    idx = (n * (np.arange(num_buckets - 1) + 1)) // num_buckets
    buckets = [sorted_lengths[i] for i in idx] + [sorted_lengths[-1]]
    tf.logging.info('== Buckets.')
    tf.logging.info('bucket upper limits: %s', buckets)
    tf.logging.info('Other candidates for last bucket:')
    tf.logging.info('  0.1%% loss: %u', sorted_lengths[int(n * .999)])
    tf.logging.info('  1%% loss: %u', sorted_lengths[int(n * .99)])
    tf.logging.info('  2%% loss: %u', sorted_lengths[int(n * .98)])

  def _PrintMeanVar(self):
    m, v = self._ComputeMeanVar()
    # Temporarily disable numpy's summarization so the full vectors print.
    original = np.get_printoptions()
    np.set_printoptions(threshold=np.inf)
    tf.logging.info('== Mean/variance.')
    tf.logging.info('mean = %s', m)
    tf.logging.info('var = %s', v)
    np.set_printoptions(**original)

  def Print(self):
    tf.logging.info('== Total number of examples: %u', self._num_examples)
    self._PrintLengthBuckets()
    self._PrintMeanVar()


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  if not FLAGS.feature_name:
    tf.logging.fatal(
        'Use a --feature_name to specify what to bucketize on. '
        'For instance, source_id for MT or frames for ASR.')
  stats = StatsCollector()
  for filepath in tf.io.gfile.glob(FLAGS.input_filepattern):
    records = tf.compat.v1.io.tf_record_iterator(filepath)
    for serialized in records:
      ex = tf.train.Example()
      ex.ParseFromString(serialized)
      stats.Accumulate(ex)
  stats.Print()


if __name__ == '__main__':
  tf.app.run(main)
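
For reference, the following is a minimal sketch, separate from the module
above and using made-up paths and a hypothetical 'frames' feature, of how to
produce a TFRecord file in the format this tool expects: each tf.train.Example
carries a float_list (or int64_list) feature whose length is a multiple of
--frame_size.

import numpy as np
import tensorflow as tf

frame_size = 3                           # would be passed as --frame_size=3
frames = np.random.rand(10, frame_size)  # 10 frames of fake feature vectors

ex = tf.train.Example(
    features=tf.train.Features(
        feature={
            'frames': tf.train.Feature(
                float_list=tf.train.FloatList(
                    value=frames.flatten().tolist()))
        }))

with tf.io.TFRecordWriter('/tmp/fake.tfrecords') as writer:
  writer.write(ex.SerializeToString())

Running the tool with --input_filepattern=/tmp/fake.tfrecords
--feature_name=frames --frame_size=3 should then report a single example of
length 10 together with its per-dimension mean and variance.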