# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Eager runners."""
import os
import time
from lingvo import compat as tf
from lingvo.core import base_model
from lingvo.core import checkpointer
from lingvo.core import cluster_factory
from lingvo.core import metrics
from lingvo.core import py_utils
from lingvo.core import summary_utils
from lingvo import base_runner
FLAGS = tf.flags.FLAGS
class Trainer(base_runner.BaseRunner):
"""Trainer that runs in eager mode."""
def Start(self):
"""Run training."""
super().Start()
with self._cluster:
model = self._params.Instantiate()
ckptr = self._CreateCheckpointer(self._train_dir, model)
task = model.GetTask(self._model_task_name)
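# TrainFunc runs a single training step: it builds the forward/backward
# graph under a gradient tape and returns the task's eval metrics and
# per-example output tensors.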
@tf.function(autograph=False)
def TrainFunc():
with py_utils.GradientTape(persistent=True):
model.ConstructFPropBPropGraph()
return task.eval_metrics, task.per_example_tensors
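# Track steps/sec and examples/sec between iterations for logging and
# scalar summaries.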
step_rate_tracker = summary_utils.StepRateTracker()
summary_writer = tf.compat.v2.summary.create_file_writer(self._train_dir)
# Attempt to restore the checkpoint after a 'dummy run' that initializes
# the optimizer and its slot variables. The dummy run is also needed for
# V2 checkpoints, even though they support delayed loading: if the
# restored checkpoint already exceeds max_steps, the slot variables would
# otherwise be lost when the checkpoint is overwritten.
_, _ = TrainFunc()
global_step = py_utils.GetOrCreateGlobalStepVar()
# Reset global_step after the dummy run, before loading checkpoints.
global_step.assign(0)
path = ckptr.Restore()
if path:
tf.logging.info(f'Loaded checkpoints from {path}.')
else:
tf.logging.info('Did not find checkpoints in the current directory.')
global_step = model.global_step.numpy()
# Save at the beginning of training
ckptr.Save(gsteps=global_step)
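# Main training loop: run one training step per iteration until
# _ShouldStop signals completion, checkpointing and writing summaries
# along the way.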
while True:
if self._ShouldStop(global_step):
break
tf.logging.info('Starting train function.')
metrics_dict, outfeed = TrainFunc()
tf.logging.info('Train function complete.')
global_step = model.global_step.numpy()
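# The outfeed should be empty unless the task defines per-example tensors.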
if not task.per_example_tensors:
assert not outfeed
else:
# TODO(laigd): debugging only, remove later.
tf.logging.info(f'outfeed: {outfeed}')
ckptr.MaybeSave(gsteps=global_step)
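# Compute throughput since the previous step for logging and summaries.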
step_rate, example_rate, total_examples = (
step_rate_tracker.ComputeStepRate(
global_step, metrics_dict['num_samples_in_batch'][0].numpy()))
msg = 'step:%6d, steps/sec: %0.2f, examples/sec: %0.2f' % (
global_step, step_rate, example_rate)
# Write summaries.
with summary_writer.as_default():
tf.compat.v2.summary.scalar(
'global_step/sec', step_rate, step=global_step)
tf.compat.v2.summary.scalar(
'examples/sec', example_rate, step=global_step)
tf.compat.v2.summary.scalar(
'total_samples', total_examples, step=global_step)
for key, (val, _) in sorted(metrics_dict.items()):
msg += ' %s:%.8g' % (key, val)
tf.compat.v2.summary.scalar(key, val, step=global_step)
summary_writer.flush()
# Log training progress.
self._SetStatusMessage(msg)
# Also save at the end of training
ckptr.Save(gsteps=global_step)
class TrainSummaries(base_runner.BaseRunner):
"""Write training summaries."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
logdir = os.path.join(self._logdir, 'train_summaries')
if self._model_task_name:
logdir += '_' + self._model_task_name
tf.io.gfile.makedirs(logdir)
self._summary_writer = tf.compat.v2.summary.create_file_writer(logdir)
def Start(self):
"""Start."""
super().Start()
with self._cluster:
model = self._params.Instantiate()
ckptr = self._CreateCheckpointer(self._train_dir, model)
task = model.GetTask(self._model_task_name)
next_summary_step = 1
global_step = model.global_step.numpy()
last_path = None
# Initialize the datasets and iterators before `tf.function` because
# `tf.function` does not trace python side effects.
# https://www.tensorflow.org/guide/function#executing_python_side_effects
_ = task.GetInputBatch()
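# ModelFunc traces the forward/backward graph with this runner's summary
# writer as the default, so summaries emitted during FProp are written to
# the train_summaries directory.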
@tf.function(autograph=False)
def ModelFunc():
with self._summary_writer.as_default():
with py_utils.GradientTape(persistent=True):
model.ConstructFPropBPropGraph()
return task.eval_metrics
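# Poll for new checkpoints; when one appears, restore it and write
# summaries if enough global steps have elapsed since the last summary.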
while True:
if self._ShouldStop(global_step):
break
time.sleep(30) # Wait some time between loops.
path = tf.train.latest_checkpoint(ckptr.checkpoint_dir)
if path == last_path:
continue
# Attempt to restore the checkpoint
path = ckptr.Restore()
if not path:
continue
last_path = path
global_step = model.global_step.numpy()
if global_step >= next_summary_step:
_ = ModelFunc()
self._SetStatusMessage(f'Write summary @{global_step}')
self._summary_writer.flush()
next_summary_step = (
global_step + model.params.train.summary_interval_steps)
class Evaler(base_runner.BaseRunner):
"""Evaler."""
def __init__(self, eval_type, *args, **kwargs):
super().__init__(*args, **kwargs)
self.params.cluster.do_eval = True
self._cluster = cluster_factory.Cluster(self.params.cluster)
self._eval_type = eval_type
self._eval_dir = os.path.join(self._logdir, f'eval_{eval_type}')
if self._model_task_name:
self._eval_dir += '_' + self._model_task_name
tf.io.gfile.makedirs(self._eval_dir)
self._summary_writer = tf.compat.v2.summary.create_file_writer(
self._eval_dir)
def Start(self):
"""Start."""
super().Start()
with self._cluster:
self._model = self._params.Instantiate()
self._checkpointer = self._CreateCheckpointer(self._train_dir,
self._model)
self._task = self._model.GetTask(self._model_task_name)
self._eval_fn = self._GetEvalFunc()
self._eval_fn_with_summary = self._GetEvalFunc(write_summary=True)
self._eval_path = checkpointer.GetSpecificCheckpoint(
self._task.params.eval.load_checkpoint_from)
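# Evaluate either a specific checkpoint named in the eval params, every
# checkpoint in the training directory, or the latest checkpoint as it
# becomes available.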
if self._eval_path:
self._EvalOnce(path=self._eval_path)
py_utils.UpdateProcessedCheckpoints(self._eval_dir, self._eval_path)
elif self._task.params.eval.eval_all_checkpoints:
self._RunOnAllCheckpoints(
runner_fn=self._EvalOnce, runner_dir=self._eval_dir)
else:
self._RunOnLatestCheckpoints(
runner_fn=self._EvalOnce, runner_dir=self._eval_dir)
def _GetEvalFunc(self, write_summary=False):
@tf.function(autograph=False)
def EvalFunc():
if write_summary:
# TODO(jiaweix): Investigate how to only write non-scalar summaries.
with self._summary_writer.as_default():
self._model.ConstructFPropGraph()
else:
self._model.ConstructFPropGraph()
return self._task.eval_metrics
return EvalFunc
def _EvalOnce(self, sess=None, path=''):
"""Eval a single checkpoint."""
with self._cluster:
# Attempt to restore the checkpoint
self._checkpointer.RestoreFromPath(checkpoint_path=path)
# Save any additional information to disk before evaluation.
if self._eval_type == 'train':
self._task.Export(path)
global_step = self._model.global_step.numpy()
if global_step < self._task.params.eval.start_eval_after:
return
if self._task.input.params.resettable:
tf.logging.info('Resetting input_generator.')
self._task.input_generator.Reset()
metrics_dict = None
num_samples_metric = None
samples_per_summary = self._task.params.eval.samples_per_summary
if samples_per_summary == 0:
assert self._task.input.params.resettable
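# Evaluate batches until samples_per_summary examples have been processed,
# or, when samples_per_summary is 0, until the (resettable) input raises
# OutOfRangeError.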
while (samples_per_summary == 0 or metrics_dict is None or
num_samples_metric.total_value < samples_per_summary):
try:
# The Evaler calls FProp multiple times for each checkpoint. Multiple
# summaries at the same step are often confusing, so instead models
# should update eval_metrics and generate aggregate summaries. Other
# types of summaries (images, audio, etc.) are generated for the
# first batch only.
eval_fn = (
self._eval_fn_with_summary
if metrics_dict is None else self._eval_fn)
eval_metrics = eval_fn()
if metrics_dict is None:
metrics_dict = {
name: metrics.AverageMetric() for name in eval_metrics
}
num_samples_metric = metrics_dict['num_samples_in_batch']
eval_metrics = py_utils.Transform(lambda x: x.numpy(), eval_metrics)
for name, (value, weight) in eval_metrics.items():
metrics_dict[name].Update(value, weight)
tf.logging.info('Total examples done: %d/%d',
num_samples_metric.total_value, samples_per_summary)
except tf.errors.OutOfRangeError:
if not self._task.input.params.resettable:
raise
break
if metrics_dict is None:
metrics_dict = {}
# Replace average values with total values for certain metrics.
if 'num_predictions' in metrics_dict:
metrics_dict['num_predictions'].total_weight = 1.0
if 'num_words' in metrics_dict:
metrics_dict['num_words'].total_weight = 1.0
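# Write the aggregated eval metrics as scalar summaries at this
# checkpoint's global step.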
msg = 'step:%6d' % global_step
with self._summary_writer.as_default():
tf.compat.v2.summary.scalar(
'total_samples', num_samples_metric.total_value, step=global_step)
for key, metric in sorted(metrics_dict.items()):
msg += ' %s:%.8g' % (key, metric.value)
tf.compat.v2.summary.scalar(key, metric.value, step=global_step)
self._summary_writer.flush()
self._SetStatusMessage(msg)
def _GetCheckpointIdForDecodeOut(ckpt_id_from_file, global_step):
"""Retrieve the checkpoint id for the decoder out file.
Compares the checkpoint id found in the checkpoint file name to global
step. If they diverge, uses the retrieved id and prints a warning.
Args:
ckpt_id_from_file: Checkpoint Id from the checkpoint file path.
global_step: int specifying the global step of the model.
Returns:
Checkpoint id as int.
"""
tf.logging.info('Loaded checkpoint is at global step: %d', global_step)
tf.logging.info('Checkpoint id according to checkpoint path: %d',
ckpt_id_from_file)
if global_step != ckpt_id_from_file:
tf.logging.warning(
'Checkpoint id %d != global step %d. '
'Will use checkpoint id from checkpoint file for '
'writing decoder output.', ckpt_id_from_file, global_step)
return ckpt_id_from_file
class Decoder(base_runner.BaseRunner):
"""Decoder."""
def __init__(self, decoder_type, *args, **kwargs):
super().__init__(*args, **kwargs)
self.params.cluster.do_eval = True
self._cluster = cluster_factory.Cluster(self.params.cluster)
self._decoder_dir = os.path.join(self._logdir, f'decoder_{decoder_type}')
if self._model_task_name:
self._decoder_dir += '_' + self._model_task_name
tf.io.gfile.makedirs(self._decoder_dir)
self._summary_writer = tf.compat.v2.summary.create_file_writer(
self._decoder_dir)
def Start(self):
"""Start."""
super().Start()
with self._cluster:
self._model = self._params.Instantiate()
self._checkpointer = self._CreateCheckpointer(self._train_dir,
self._model)
self._task = self._model.GetTask(self._model_task_name)
self._decode_fn = self._GetDecodeFunc()
self._decode_fn_with_summary = self._GetDecodeFunc(write_summary=True)
self._decode_path = checkpointer.GetSpecificCheckpoint(
self._task.params.eval.load_checkpoint_from)
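# As in the Evaler, decode either a specific checkpoint, every checkpoint,
# or the latest checkpoint as it becomes available.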
if self._decode_path:
self._DecodeOnce(path=self._decode_path)
py_utils.UpdateProcessedCheckpoints(self._decoder_dir, self._decode_path)
elif self._task.params.eval.decode_all_checkpoints:
self._RunOnAllCheckpoints(
runner_fn=self._DecodeOnce, runner_dir=self._decoder_dir)
else:
self._RunOnLatestCheckpoints(
runner_fn=self._DecodeOnce, runner_dir=self._decoder_dir)
def _GetDecodeFunc(self, write_summary=False):
@tf.function(autograph=False)
def DecodeFunc():
if write_summary:
# TODO(jiaweix): Investigate how to only write non-scalar summaries.
with self._summary_writer.as_default():
input_batch, dec_output = self._model.ConstructDecodeGraph(
self._model_task_name)
else:
input_batch, dec_output = self._model.ConstructDecodeGraph(
self._model_task_name)
return input_batch, dec_output
return DecodeFunc
@classmethod
def GetDecodeOutPath(cls, decoder_dir, checkpoint_id):
"""Gets the path to decode out file."""
out_dir = cls._GetTtlDir(decoder_dir, duration='7d')
return os.path.join(out_dir, 'decoder_out_%09d' % checkpoint_id)
def _DecodeOnce(self, sess=None, path=''):
"""Decode a single checkpoint."""
with self._cluster:
# Attempt to restore the checkpoint
self._checkpointer.RestoreFromPath(checkpoint_path=path)
global_step = self._model.global_step.numpy()
if global_step < self._task.params.eval.start_decoder_after:
return
if self._task.input.params.resettable:
tf.logging.info('Resetting input_generator.')
self._task.input_generator.Reset()
dec_metrics = self._task.CreateDecoderMetrics()
if not dec_metrics:
tf.logging.info('Empty decoder metrics')
return
buffered_decode_out = []
num_samples_metric = dec_metrics['num_samples_in_batch']
samples_per_summary = self._task.params.eval.decoder_samples_per_summary
if samples_per_summary is None:
samples_per_summary = self._task.params.eval.samples_per_summary
if samples_per_summary == 0:
assert self._task.input.params.resettable
start_time = time.time()
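# Decode batches until samples_per_summary examples have been processed,
# or, when samples_per_summary is 0, until the (resettable) input is
# exhausted.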
while samples_per_summary == 0 or (num_samples_metric.total_value <
samples_per_summary):
try:
tf.logging.info('Fetching dec_output.')
fetch_start = time.time()
# The Decoder calls FProp multiple times for each checkpoint. Multiple
# summaries at the same step are often confusing, so instead models
# should generate aggregate summaries using PostProcessDecodeOut.
# Other types of summaries (images, audio, etc.) are generated for
# the first batch only.
is_first_loop = num_samples_metric.total_value == 0
decode_fn = (
self._decode_fn_with_summary
if is_first_loop else self._decode_fn)
input_batch, dec_output = decode_fn()
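# Copy CPU-passthrough keys from the input batch into the decoder output
# so they are available to PostProcessDecodeOut.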
for key in self._task.input_generator.GetCpuPassthroughKeys():
if key in input_batch:
if key in dec_output:
tf.logging.warning(
f'Key {key} already present in decode output. '
f'Not adding from input batch.')
else:
dec_output[key] = input_batch[key]
dec_output = py_utils.Transform(lambda x: x.numpy(), dec_output)
post_process_start = time.time()
tf.logging.info('Done fetching (%f seconds)' %
(post_process_start - fetch_start))
decode_out = self._task.PostProcessDecodeOut(dec_output, dec_metrics)
if decode_out:
if isinstance(decode_out, dict):
decode_out = decode_out.items()
if is_first_loop:
# Add summaries only for the first batch of data.
with self._summary_writer.as_default():
for key, value in decode_out:
if isinstance(value, tf.Summary):
tf.logging.info(f'Adding summary {key} with tags '
f'{[x.tag for x in value.value]}.')
tf.compat.v2.summary.experimental.write_raw_pb(
tf.constant(value.SerializeToString()), global_step)
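# Buffer non-summary key/value pairs; they are written to the decoder
# output file by DecodeFinalize below.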
buffered_decode_out.extend(
kv for kv in decode_out if not isinstance(kv[1], tf.Summary))
tf.logging.info(
'Total examples done: %d/%d '
'(%f seconds decode postprocess)', num_samples_metric.total_value,
samples_per_summary,
time.time() - post_process_start)
except tf.errors.OutOfRangeError:
if not self._task.input.params.resettable:
raise
break
tf.logging.info('Done decoding ckpt: %s', path)
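# Report decode throughput and the accumulated decoder metrics for this
# checkpoint.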
elapsed_secs = time.time() - start_time
example_rate = num_samples_metric.total_value / elapsed_secs
msg = 'step:%6d, elapsed_secs: %0.2f, examples/sec: %0.2f' % (
global_step, elapsed_secs, example_rate)
with self._summary_writer.as_default():
tf.compat.v2.summary.scalar(
'decode_secs', elapsed_secs, step=global_step)
tf.compat.v2.summary.scalar(
'examples/sec', example_rate, step=global_step)
tf.compat.v2.summary.scalar(
'total_samples', num_samples_metric.total_value, step=global_step)
for key, metric in sorted(dec_metrics.items()):
msg += ' %s:%.8g' % (key, metric.value)
tf.compat.v2.summary.scalar(key, metric.value, step=global_step)
self._summary_writer.flush()
self._SetStatusMessage(msg)
self._ExportMetrics(
# Metrics expects python int, but global_step is numpy.int64.
decode_checkpoint=int(global_step),
dec_metrics=dec_metrics,
example_rate=example_rate)
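# Write the buffered decode output for this checkpoint to a decoder output
# file under a TTL'd directory.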
decode_out_path = self.GetDecodeOutPath(self._decoder_dir, global_step)
decode_finalize_args = base_model.DecodeFinalizeArgs(
decode_out_path=decode_out_path, decode_out=buffered_decode_out)
self._task.DecodeFinalize(decode_finalize_args)