# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Debug print tf records in text format."""
import lingvo.compat as tf
import six
# Input selection: which files to read and how their records are encoded.
tf.flags.DEFINE_string(
    'input_filepattern', '', 'File pattern of binary tfrecord files.')
tf.flags.DEFINE_string(
    'input_format', 'tf.Example',
    'Input format: only "tf.Example" supported for now.')

# Record range: skip a prefix of the records, then optionally cap how many
# are handled. The default of -1 for print_only_n disables the cap.
tf.flags.DEFINE_integer('skip_first_n', 0, 'Skip first records.')
tf.flags.DEFINE_integer(
    'print_only_n', -1, 'Only print a certain number of records.')

# Output rendering.
tf.flags.DEFINE_bool('abbreviated', True, 'Print in abbreviated format.')
tf.flags.DEFINE_bool(
    'bytes_as_utf8', True, 'Print byte strings as UTF-8 strings')
tf.flags.DEFINE_bool(
    'count_only', False, "Don't print, just count number of entries")

# Module-wide handle through which the functions below read the flag values.
FLAGS = tf.flags.FLAGS
def _ListDebugString(values, to_string=str, num_head=6, num_tail=2):
  """Formats a sequence for debug output, eliding the middle of long ones.

  Args:
    values: Sequence to format. Must support len() and slicing.
    to_string: Callable applied to each element shown in the abbreviated
      form. It is not used when the sequence is short enough to be returned
      as repr(values).
    num_head: Number of leading elements kept when abbreviating.
    num_tail: Number of trailing elements kept when abbreviating.

  Returns:
    repr(values) when len(values) <= num_head + num_tail. Otherwise a string
    such as '[0 1 2 3 4 5 ... 8 9]': the kept elements separated by single
    spaces with '...' standing in for the elided middle.
  """
  num_values = len(values)
  if num_values <= num_head + num_tail:
    return repr(values)
  head = [to_string(v) for v in values[:num_head]]
  # Slice from an explicit start index: values[-0:] would select everything
  # instead of nothing when num_tail is 0.
  tail = [to_string(v) for v in values[num_values - num_tail:]]
  return '[' + ' '.join(head + ['...'] + tail) + ']'
def _CustomShortDebugString(tf_example):
  """Renders an example as one abbreviated 'name: values' line per feature.

  Features are emitted in sorted order of their names. Byte-string values are
  converted to text with six.ensure_text(..., 'utf-8') first when
  --bytes_as_utf8 is set.

  Args:
    tf_example: Message exposing `features.feature`, a mapping from feature
      name to a value with `bytes_list`, `float_list` and `int64_list` fields
      (presumably a tf.train.Example; that is what _PrintFiles passes).

  Returns:
    The per-feature lines joined with newlines; '' if there are no features.
    A feature with none of the three list fields set is shown as '<unset>'.

  Raises:
    UnicodeDecodeError: if --bytes_as_utf8 is set and a bytes value is not
      valid UTF-8.
  """
  lines = []
  for name, value in sorted(tf_example.features.feature.items()):
    if value.HasField('bytes_list'):
      byte_values = value.bytes_list.value
      if FLAGS.bytes_as_utf8:
        byte_values = [six.ensure_text(v, 'utf-8') for v in byte_values]
      value_string = _ListDebugString(byte_values)
    elif value.HasField('float_list'):
      value_string = _ListDebugString(value.float_list.value)
    elif value.HasField('int64_list'):
      value_string = _ListDebugString(value.int64_list.value, to_string=repr)
    else:
      # No list field is set on this feature. Without this branch,
      # value_string would be unbound for the first feature, or would
      # silently repeat the previous feature's text.
      value_string = '<unset>'
    lines.append('%s: %s' % (name, value_string))
  return '\n'.join(lines)
def _PrintFiles():
  """Logs, or merely counts, the records in files matching the file pattern.

  Files matching --input_filepattern are read in the order returned by
  tf.io.gfile.glob. The first --skip_first_n records (counted across all
  files) are skipped. If --print_only_n is non-negative, processing stops
  once that many further records have been handled. With --count_only the
  records are only counted; otherwise each one is parsed as a
  tf.train.Example and logged, abbreviated when --abbreviated is set. The
  final log line reports the record index reached, which includes skipped
  records.

  Raises:
    ValueError: if a record is about to be printed and --input_format is not
      'tf.Example'.
  """
  num_to_skip = FLAGS.skip_first_n
  max_records = FLAGS.print_only_n
  entry = 0
  limit_reached = False
  for filepath in tf.io.gfile.glob(FLAGS.input_filepattern):
    for serialized in tf.compat.v1.io.tf_record_iterator(filepath):
      if entry < num_to_skip:
        entry += 1
        continue

      # `>=`, not `>`: entry - num_to_skip is the number of records already
      # handled, so waiting until it exceeds the limit handles one too many.
      if max_records >= 0 and entry - num_to_skip >= max_records:
        limit_reached = True
        break

      if FLAGS.count_only:
        entry += 1
        if entry % 100000 == 0:
          tf.logging.info('Counted %d entries so far...', entry)
        continue

      # Explicit raise instead of assert, which is stripped under `python -O`.
      if FLAGS.input_format != 'tf.Example':
        raise ValueError(
            'Unsupported --input_format %r; only "tf.Example" is supported.' %
            (FLAGS.input_format,))

      ex = tf.train.Example()
      ex.ParseFromString(serialized)

      if entry == num_to_skip:
        # NOTE(review): _PrintHeader is not defined anywhere in the visible
        # module, so calling it unconditionally raises NameError on the first
        # printed record. Call it only if it exists; define it or delete this
        # hook once the intent is confirmed.
        print_header = globals().get('_PrintHeader')
        if print_header is not None:
          print_header(ex)

      if FLAGS.abbreviated:
        text_format = _CustomShortDebugString(ex)
      else:
        text_format = str(ex)
      tf.logging.info('== Record [%d]\n%s', entry, text_format)
      entry += 1

    if limit_reached:
      # Leave the file loop too, rather than opening every remaining file
      # only to break at its first record.
      break
  tf.logging.info('== Total entries: %d', entry)
def main(_):
  """Entry point: enables INFO logging, then processes the matching files.

  Args:
    _: Unused positional argument supplied by the caller (tf.app.run).
  """
  # Records are emitted via tf.logging.info, so INFO verbosity must be on.
  tf.logging.set_verbosity(tf.logging.INFO)
  _PrintFiles()
if __name__ == '__main__':
  # Only when executed as a script: hand control to the app runner with main.
  tf.app.run(main)