Source code for lingvo.trainer_impl

# Lint as: python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
r"""Trainer.

To run locally:

.. code-block:: bash

  $ bazel build -c opt //lingvo:trainer
  $ bazel-bin/lingvo/trainer --logtostderr \
      --model=image.mnist.LeNet5 --mode=sync --logdir=/tmp/lenet5 \
      --run_locally=cpu

To use GPU, add `--config=cuda` to the build command and set
`--run_locally=gpu`, for example:
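
.. code-block:: bash

  $ bazel build -c opt --config=cuda //lingvo:trainer
  $ bazel-bin/lingvo/trainer --logtostderr \
      --model=image.mnist.LeNet5 --mode=sync --logdir=/tmp/lenet5 \
      --run_locally=gpu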
"""
# pylint: enable=line-too-long
import os
import re
import time

import lingvo.compat as tf
from lingvo.core import base_model
from lingvo.core import checkpointer
from lingvo.core import cluster_factory
from lingvo.core import metrics
from lingvo.core import py_utils
from lingvo.core import summary_utils

from lingvo import base_runner


class Trainer(base_runner.BaseRunner):
  """Trainer on non-TPU."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._job_name = 'trainer'
    with self._graph.as_default(), tf.container(self._container_id):
      if self.params.cluster.task == 0:
        self._summary_writer = self._CreateSummaryWriter(self._train_dir)
        self._CreateTF2SummaryWriter(self._train_dir)
      else:
        self._summary_writer = None

      with self._cluster, tf.device(
          self._cluster.GetPlacer()), self._TF2SummaryContext():
        self._model = self.params.Instantiate()
        self._params = self._model.params
        self._model.ConstructFPropBPropGraph()

      # Per-task summary writers only exist for multi-task models; single-task
      # models have no task_schedule, so the AttributeError path is expected.
      try:
        self._task_probs_summary_writers = []
        for task in self._model.task_schedule.tasks:
          path = os.path.join(self._train_dir, task)
          tf.io.gfile.makedirs(path)
          self._task_probs_summary_writers.append(
              self._CreateSummaryWriter(path))
      except AttributeError:
        tf.logging.info('AttributeError. Expected for single task models.')
        self._task_probs_summary_writers = []

      self._CreateTF2SummaryOps()
      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()
      self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
      tf.logging.info('Trainer number of enqueue ops: %d',
                      len(self.enqueue_ops))

    self._step_rate_tracker = summary_utils.StepRateTracker()

    if self.params.cluster.task == 0:
      self._WriteToLog(self.params.ToText(), self._train_dir,
                       'trainer_params.txt')
    # Stagger worker start-up: worker i waits for i * (i + 1) / 2 *
    # start_up_delay_steps global steps before training. E.g. with
    # start_up_delay_steps=200, worker 0 starts immediately, worker 1 at
    # step 200, and worker 2 at step 600.
    worker_id = self.params.cluster.task
    self._start_up_delay_steps = (((worker_id + 1) * worker_id / 2) *
                                  self.params.train.start_up_delay_steps)

  def _SummarizeValue(self, steps, tag, value, writer=None):
    if writer:
      writer.add_summary(metrics.CreateScalarSummary(tag, value), steps)
    elif self._summary_writer:
      self._summary_writer.add_summary(
          metrics.CreateScalarSummary(tag, value), steps)

  def Start(self):
    self._RunLoop('trainer', self._Loop)

  def StartEnqueueOp(self, op):
    self._RunLoop(
        'trainer/enqueue_op/%s' % op.name, self._LoopEnqueue, loop_args=[op])

  def _LoopEnqueue(self, op):
    # Evaler/Controller jobs may find that the trial is infeasible and report
    # done earlier. This is an important check since the trainer may retry
    # indefinitely without it.
    if self._trial.ShouldStop():
      tf.logging.info('Training skipped (trial requested to stop).')
      return
    return super()._LoopEnqueue(op)

  def _Loop(self):
    # Evaler/Controller jobs may find that the trial is infeasible and report
    # done earlier. This is an important check since the trainer may retry
    # indefinitely without it.
    if self._trial.ShouldStop():
      tf.logging.info('Training skipped (trial requested to stop).')
      return
    with tf.container(
        self._container_id), self._cluster, self._GetSession() as sess:
      # This initializes local tables.
      sess.run(self._initialize_tables)
      # This initializes local variables.
      sess.run(self._initialize_local_vars)
      self._InitializeTF2SummaryWriter(sess)
      for task in self._model.tasks:
        task.input.Initialize(sess)
      global_step = self._WaitUntilInit(sess, self._start_up_delay_steps)

      status_interval_steps = 100
      next_status_step = 1
      eval_metrics = None
      while True:
        if (self._trial.ShouldStopAndMaybeReport(global_step, eval_metrics) or
            self._ShouldStop(sess, global_step)):
          tf.logging.info('Training finished.')
          if self._early_stop:
            time.sleep(300)  # controller hangs if it doesn't finish first
          self._DequeueThreadComplete()
          return

        # If a task is explicitly specified, only train that task.
        if self._model_task_name:
          task = self._model.GetTask(self._model_task_name)
        else:
          # Note: This is a slightly stale global_step value from the previous
          # sess.run() call.
          # For multi-task models, `self._model.task_schedule.cur_probs` will
          # be updated.
          task = self._model.SampleTask(global_step)
          if self._task_probs_summary_writers:
            for index, prob in enumerate(self._model.task_schedule.cur_probs):
              self._SummarizeValue(global_step, 'task_probability', prob,
                                   self._task_probs_summary_writers[index])
            try:
              for index, task in enumerate(self._model.tasks):
                self._SummarizeValue(global_step, 'task_weight',
                                     sess.run(task.vars.task_weight),
                                     self._task_probs_summary_writers[index])
            except AttributeError:
              pass

        (_, eval_metrics, per_example_tensors) = sess.run([
            task.train_op,
            task.eval_metrics,
            task.per_example_tensors,
        ])
        # Explicitly fetch global_step after running train_op.
        # TODO(b/151181934): Investigate this behavior further.
        task_global_step = sess.run(task.global_step)
        task.ProcessFPropResults(sess, task_global_step, eval_metrics,
                                 per_example_tensors)
        self._RunTF2SummaryOps(sess)
        global_step = sess.run(self._model.global_step)

        step_rate, example_rate, total_examples = (
            self._step_rate_tracker.ComputeStepRate(
                global_step, eval_metrics['num_samples_in_batch'][0]))
        self._SummarizeValue(global_step, 'global_step/sec', step_rate)
        self._SummarizeValue(global_step, 'examples/sec', example_rate)
        self._SummarizeValue(global_step, 'total_samples', total_examples)
        msg = 'step:%6d, steps/sec: %0.2f, examples/sec: %0.2f' % (
            global_step, step_rate, example_rate)
        for key, (val, _) in sorted(eval_metrics.items()):
          msg += ' %s:%.8g' % (key, val)
          self._SummarizeValue(global_step, key, val)
        if global_step >= next_status_step:
          self._SetStatusMessage(msg)
          self._ExportMetrics(
              # Metrics expects python int, but global_step is numpy.int64.
              global_step=int(global_step),
              step_rate=step_rate,
              example_rate=example_rate)
          next_status_step = global_step + status_interval_steps
        else:
          tf.logging.info(msg)
        self._model.ProcessFPropResults(sess, global_step, eval_metrics,
                                        per_example_tensors)


def GetDecoderDir(logdir, decoder_type, model_task_name):
  if model_task_name:
    decoder_dir = '%s_%s' % (decoder_type, model_task_name)
  else:
    decoder_dir = decoder_type
  return os.path.join(logdir, decoder_dir)
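
# For instance (hypothetical arguments), GetDecoderDir('/tmp/lenet5',
# 'decoder_dev', 'mnist') returns '/tmp/lenet5/decoder_dev_mnist', while with
# model_task_name=None it returns '/tmp/lenet5/decoder_dev'.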


def _GetCheckpointIdForDecodeOut(ckpt_id_from_file, global_step):
  """Retrieves the checkpoint id for the decoder out file.

  Compares the checkpoint id found in the checkpoint file name to the global
  step. If they diverge, uses the id retrieved from the file name and prints a
  warning.

  Args:
    ckpt_id_from_file: Checkpoint id from the checkpoint file path.
    global_step: int specifying the global step of the model.

  Returns:
    Checkpoint id as int.
  """
  tf.logging.info('Loaded checkpoint is at global step: %d', global_step)
  tf.logging.info('Checkpoint id according to checkpoint path: %d',
                  ckpt_id_from_file)
  if global_step != ckpt_id_from_file:
    tf.logging.warning(
        'Checkpoint id %d != global step %d. '
        'Will use checkpoint id from checkpoint file for '
        'writing decoder output.', ckpt_id_from_file, global_step)
  return ckpt_id_from_file
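
# For example, a checkpoint file named 'ckpt-12300' loaded at global step
# 12345 logs a warning and yields checkpoint id 12300 (values illustrative).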


class Decoder(base_runner.BaseRunner):
  """Decoder."""

  def __init__(self, decoder_type, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._job_name = 'decoder_' + decoder_type
    self.params.cluster.do_eval = True
    self._cluster = cluster_factory.Cluster(self.params.cluster)
    self._decoder_dir = GetDecoderDir(self._logdir, self._job_name,
                                      self._model_task_name)
    tf.io.gfile.makedirs(self._decoder_dir)

    self._decode_path = None
    # Multitask params doesn't have 'task'.
    if 'task' in self.params:
      self._decode_path = checkpointer.GetSpecificCheckpoint(
          self.params.task.eval.load_checkpoint_from)

    self._should_report_metrics = self._job_name.startswith(
        self._cluster.reporting_job)

    with self._graph.as_default(), tf.container(self._container_id):
      self._summary_writer = self._CreateSummaryWriter(self._decoder_dir)
      self._CreateTF2SummaryWriter(self._decoder_dir)
      with self._cluster, tf.device(
          self._cluster.GetPlacer()), self._TF2SummaryContext():
        self._model = self.params.Instantiate()
        self._params = self._model.params
        self._task = self._model.GetTask(self._model_task_name)
        # Note: different graphs are being constructed for different model
        # tasks, which may result in different node names being chosen.
        # Variable names, however, have to stay the same between train and
        # decode.
        cluster = self._cluster
        with tf.device(cluster.input_device):
          input_batch = self._task.input_generator.GetPreprocessedInputBatch()

        self._dec_output = self._task.Decode(input_batch)

        for key in self._task.input_generator.GetCpuPassthroughKeys():
          if key in input_batch:
            if key in self._dec_output:
              tf.logging.warning(f'Key {key} already present in decode '
                                 f'output. Not adding from input batch.')
            else:
              self._dec_output[key] = input_batch[key]

        self._summary_op = tf.summary.merge_all()
        self.checkpointer = self._CreateCheckpointer(self._train_dir,
                                                     self._model)
      self._CreateTF2SummaryOps()
      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()
      # No queues are allowed for decoder models.
      self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
      assert not self.enqueue_ops

    # Saves the params and the graph def.
    self._WriteToLog(self.params.ToText(), self._decoder_dir, 'params.txt')
    if self.params.cluster.task == 0:
      tf.io.write_graph(self._graph.as_graph_def(), self._decoder_dir,
                        '%s.pbtxt' % self._job_name)

  def _CreateCheckpointer(self, train_dir, model):
    """Wrapper method for override purposes."""
    return checkpointer.Checkpointer(train_dir, model)

  def Start(self):
    self._RunLoop(self._job_name, self._Loop)

  def _Loop(self):
    with tf.container(self._container_id), self._cluster, self._GetSession(
        inline=False) as sess:
      # This initializes local tables.
      sess.run(self._initialize_tables)
      # This initializes local variables.
      sess.run(self._initialize_local_vars)
      self._InitializeTF2SummaryWriter(sess)
      self._task.input.Initialize(sess)

      if self._decode_path:
        self.DecodeCheckpoint(sess, self._decode_path)
        py_utils.UpdateProcessedCheckpoints(self._decoder_dir,
                                            self._decode_path)
      elif self._task.params.eval.decode_all_checkpoints:
        self._RunOnAllCheckpoints(sess, self.DecodeCheckpoint,
                                  self._decoder_dir)
      else:
        self._RunOnLatestCheckpoints(sess, self.DecodeCheckpoint,
                                     self._decoder_dir)

    if self._should_report_metrics:
      tf.logging.info('Reporting trial done.')
      self._trial.ReportDone()
    tf.logging.info('Decoding finished.')

  @classmethod
  def GetDecodeOutPath(cls, decoder_dir, checkpoint_id):
    """Gets the path to the decode out file."""
    out_dir = cls._GetTtlDir(decoder_dir, duration='7d')
    return os.path.join(out_dir, 'decoder_out_%09d' % checkpoint_id)

  def GetCkptIdFromFile(self, checkpoint_path):
    return int(re.sub(r'.*ckpt-', '', checkpoint_path))
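
  # For example (hypothetical path), GetCkptIdFromFile strips everything up
  # to and including the final 'ckpt-', so '/tmp/lenet5/train/ckpt-00012300'
  # yields 12300; GetDecodeOutPath then formats it back into a file name
  # ending in 'decoder_out_000012300'.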

  def _RemoveScalarSummaries(self, summaries):
    proto = tf.Summary()
    proto.ParseFromString(summaries)
    # Delete in reverse index order so that removing an entry does not shift
    # the indices of entries that have not been examined yet.
    for i in reversed(range(len(proto.value))):
      if proto.value[i].WhichOneof('value') == 'simple_value':
        del proto.value[i]
    return proto.SerializeToString()

  def DecodeCheckpoint(self, sess, checkpoint_path):
    """Decodes `samples_per_summary` examples using `checkpoint_path`."""
    p = self._task.params
    ckpt_id_from_file = self.GetCkptIdFromFile(checkpoint_path)
    if ckpt_id_from_file < p.eval.start_decoder_after:
      return

    samples_per_summary = p.eval.decoder_samples_per_summary
    if samples_per_summary is None:
      samples_per_summary = p.eval.samples_per_summary
    if samples_per_summary == 0:
      assert self._task.input.params.resettable

    self.checkpointer.RestoreFromPath(sess, checkpoint_path)

    global_step = sess.run(py_utils.GetGlobalStep())

    if self._task.input.params.resettable:
      tf.logging.info('Resetting input_generator.')
      self._task.input.Reset(sess)

    dec_metrics = self._task.CreateDecoderMetrics()
    if not dec_metrics:
      tf.logging.info('Empty decoder metrics')
      return
    buffered_decode_out = []
    num_examples_metric = dec_metrics['num_samples_in_batch']
    start_time = time.time()
    while samples_per_summary == 0 or (num_examples_metric.total_value <
                                       samples_per_summary):
      try:
        is_first_loop = num_examples_metric.total_value == 0
        tf.logging.info('Fetching dec_output.')
        fetch_start = time.time()
        run_options = tf.RunOptions(report_tensor_allocations_upon_oom=False)

        # NOTE: We intentionally do not generate scalar summaries by default,
        # because the decoder is run multiple times for each checkpoint.
        # Multiple summaries at the same step are often confusing. Instead,
        # models should generate aggregate summaries using
        # PostProcessDecodeOut. Other types of summaries (images, audio etc.)
        # will be generated for the first eval batch.
        if self._summary_op is not None and is_first_loop:
          dec_out, summaries = sess.run([self._dec_output, self._summary_op],
                                        options=run_options)
          summaries = self._RemoveScalarSummaries(summaries)

          # Add non-scalar summaries only for the first batch of data.
          self._summary_writer.add_summary(summaries, global_step)
          self._summary_writer.flush()
        else:
          dec_out = sess.run(self._dec_output, options=run_options)

        self._RunTF2SummaryOps(sess)
        post_process_start = time.time()
        tf.logging.info('Done fetching (%f seconds)' %
                        (post_process_start - fetch_start))
        decode_out = self._task.PostProcessDecodeOut(dec_out, dec_metrics)

        if decode_out:
          if isinstance(decode_out, dict):
            decode_out = decode_out.items()

          if is_first_loop:
            # Add summaries only for the first batch of data.
            for key, value in decode_out:
              if isinstance(value, tf.Summary):
                tf.logging.info(f'Adding summary {key} with tags '
                                f'{[x.tag for x in value.value]}.')
                self._summary_writer.add_summary(value, global_step)
                self._summary_writer.flush()

          buffered_decode_out.extend(
              kv for kv in decode_out if not isinstance(kv[1], tf.Summary))

        tf.logging.info(
            'Total examples done: %d/%d '
            '(%f seconds decode postprocess)', num_examples_metric.total_value,
            samples_per_summary,
            time.time() - post_process_start)
      except tf.errors.OutOfRangeError:
        if not self._task.input.params.resettable:
          raise
        break
    tf.logging.info('Done decoding ckpt: %s', checkpoint_path)

    summaries = {k: v.Summary(k) for k, v in dec_metrics.items()}
    elapsed_secs = time.time() - start_time
    example_rate = num_examples_metric.total_value / elapsed_secs
    summaries['examples/sec'] = metrics.CreateScalarSummary(
        'examples/sec', example_rate)
    summaries['total_samples'] = metrics.CreateScalarSummary(
        'total_samples', num_examples_metric.total_value)
    self._WriteSummaries(
        self._summary_writer,
        os.path.basename(self._decoder_dir),
        global_step,
        summaries,
        text_filename=os.path.join(self._decoder_dir,
                                   'score-{:08d}.txt'.format(global_step)))
    self._ExportMetrics(
        # Metrics expects python int, but global_step is numpy.int64.
        decode_checkpoint=int(global_step),
        dec_metrics=dec_metrics,
        example_rate=example_rate)
    # global_step and the checkpoint id from the checkpoint file might be
    # different. For consistency of checkpoint filename and decoder_out file,
    # use the checkpoint id as derived from the checkpoint filename.
    checkpoint_id = _GetCheckpointIdForDecodeOut(ckpt_id_from_file,
                                                 global_step)
    decode_out_path = self.GetDecodeOutPath(self._decoder_dir, checkpoint_id)

    decode_finalize_args = base_model.DecodeFinalizeArgs(
        decode_out_path=decode_out_path, decode_out=buffered_decode_out)
    self._task.DecodeFinalize(decode_finalize_args)

    if self._should_report_metrics:
      tf.logging.info('Reporting eval measure for step %d.' % global_step)
      self._trial.ReportEvalMeasure(global_step, dec_metrics, checkpoint_path)

  def DecodeLatestCheckpoint(self, last_path=None):
    """Runs decoder on the latest checkpoint."""
    with tf.container(
        self._container_id), self._cluster, self._GetSession() as sess:
      # This initializes local tables.
      sess.run(self._initialize_tables)
      # This initializes local variables.
      sess.run(self._initialize_local_vars)
      self._task.input.Initialize(sess)
      path = tf.train.latest_checkpoint(self._train_dir)
      if not path:
        tf.logging.info('No checkpoint available.')
        return
      elif path == last_path:
        tf.logging.info('Latest checkpoint was already decoded.')
        return
      self.DecodeCheckpoint(sess, path)