# Source code for lingvo.tools.count_records

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Tool to count number of records in a dataset.

Most other file formats have efficient ways to fetch the number of records in a
dataset.  However, some formats such as TFRecord require you to essentially
scan the files to perform this count.

This is a short little beam script that can leverage many machines to read
all of the files in parallel potentially faster than a single machine script.
It is recommended that for other file formats, simply reading the metadata
available in their formats should work; this file should not really be
extended to any other format that already has efficient ways of counting
records.
"""

from absl import app
from absl import flags

import apache_beam as beam
from lingvo.tools import beam_utils

# Command-line flags. Each defaults to None; the `__main__` guard marks all
# three as required before the app runs.
flags.DEFINE_string(
    name='input_file_pattern',
    default=None,
    help='Path to read input')
flags.DEFINE_string(
    name='output_count_file',
    default=None,
    help='File to write output to.')
flags.DEFINE_string(
    name='record_format',
    default=None,
    help='Record format of the input, e.g., tfrecord.')

# Module-level handle used to read the parsed flag values.
FLAGS = flags.FLAGS


def main(argv):
  """Counts the records matching the input pattern and writes the total.

  Builds a Beam pipeline that reads every record from
  `FLAGS.input_file_pattern` (in `FLAGS.record_format`), emits a 1 per record,
  sums those globally, and writes the resulting count as text to
  `FLAGS.output_count_file`.

  Args:
    argv: Command-line arguments as handed over by `app.run`. The first
      element is skipped; the rest are forwarded to Beam as pipeline options.
  """
  beam_utils.BeamInit()

  # Construct pipeline options from argv.
  options = beam.options.pipeline_options.PipelineOptions(argv[1:])

  # Records are only counted, never decoded, so read them as raw bytes.
  reader = beam_utils.GetReader(
      FLAGS.record_format,
      FLAGS.input_file_pattern,
      value_coder=beam.coders.BytesCoder())

  with beam_utils.GetPipelineRoot(options=options) as root:
    _ = (
        root
        | 'Read' >> reader  # Read each record.
        | 'EmitOne' >> beam.Map(lambda _: 1)  # Emit a 1 for each record.
        | 'Count' >> beam.CombineGlobally(sum)  # Sum counts.
        | 'WriteToText' >> beam.io.WriteToText(FLAGS.output_count_file))


if __name__ == '__main__':
  # All three flags default to None, so require them explicitly before the
  # app (and therefore the pipeline) is started.
  flags.mark_flags_as_required(
      ['input_file_pattern', 'output_count_file', 'record_format'])
  app.run(main)