T2T: Life of an Example

This doc explains how a training example flows through T2T, from data generation to training, evaluation, and decoding.

Some key files and their functions:

Data Generation

The t2t-datagen binary is the entrypoint for data generation. It simply looks up the Problem specified by --problem and calls Problem.generate_data(data_dir, tmp_dir).

All Problems are expected to generate two sets of sharded TFRecord files, one for training and one for evaluation, containing tensorflow.Example protocol buffers. The expected names of the files are given by Problem.{training, dev}_filepaths. Typically, the features in the Example will be "inputs" and "targets"; however, some tasks use a different on-disk representation that is converted to "inputs" and "targets" online in the input pipeline (e.g. image problems are typically stored with the features "image/encoded" and "image/format", and the decoding happens in the input pipeline).
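
As a rough picture of what "sharded" means here, the sketch below builds one file path per shard for each split. The helper name sharded_filepaths and the exact naming pattern are assumptions made for illustration; the real names come from Problem.{training, dev}_filepaths.

```python
def sharded_filepaths(data_dir, problem_name, split, num_shards):
    """Hypothetical helper: returns one path per shard for a given split."""
    return [
        "%s/%s-%s-%05d-of-%05d" % (data_dir, problem_name, split, shard, num_shards)
        for shard in range(num_shards)
    ]


# Assumed shard counts, chosen only to keep the example small.
train_paths = sharded_filepaths("/tmp/t2t_data", "toy_copy", "train", 3)
dev_paths = sharded_filepaths("/tmp/t2t_data", "toy_copy", "dev", 1)
```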

For tasks that require a vocabulary, this is also the point at which the vocabulary is generated and all examples are encoded.
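
A toy version of "generate the vocabulary, then encode all examples" might look like the following. It uses whitespace tokenization and two reserved tokens as simplifying assumptions; T2T's own text encoders are more involved, and build_vocab and encode are hypothetical names.

```python
def build_vocab(sentences, reserved=("<pad>", "<EOS>")):
    """Maps each token to an integer id; reserved tokens take the lowest ids."""
    tokens = sorted({tok for s in sentences for tok in s.split()})
    return {tok: i for i, tok in enumerate(list(reserved) + tokens)}


def encode(sentence, vocab):
    """Converts a sentence to token ids and appends the end-of-sequence id."""
    return [vocab[tok] for tok in sentence.split()] + [vocab["<EOS>"]]


corpus = ["hello world", "hello there"]
vocab = build_vocab(corpus)
encoded = [encode(s, vocab) for s in corpus]
```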

There are several utility functions in generator_utils that are commonly used by Problems to generate data. Several are highlighted below:

Data Input Pipeline

Once the data is produced on disk, training, evaluation, and inference (if decoding from the dataset) consume it by way of the T2T input pipeline, defined by Problem.input_fn.

The entire input pipeline is implemented with the tf.data.Dataset API.

The input function has two main parts: first, reading and processing individual examples, which is done in Problem.dataset, and second, batching, which is done in Problem.input_fn after the call to Problem.dataset.
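
The split between per-example processing and batching can be sketched with plain Python generators. This is only an analogy for the structure described above: the real pipeline operates on tf.data.Dataset objects, and the function signatures here are assumptions.

```python
def dataset(raw_records):
    """Part 1: read and preprocess individual examples."""
    for record in raw_records:
        # Stand-in preprocessing: keep only the two expected features.
        yield {"inputs": record["inputs"], "targets": record["targets"]}


def input_fn(raw_records, batch_size):
    """Part 2: group the processed examples into batches."""
    batch = []
    for example in dataset(raw_records):
        batch.append(example)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # final, possibly smaller, batch


records = [{"inputs": [i], "targets": [i + 1], "extra": None} for i in range(5)]
batches = list(input_fn(records, batch_size=2))
```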

Problem subclasses may override the entire input_fn or portions of it (e.g. example_reading_spec to indicate the names, types, and shapes of features on disk). Typically they only override portions.


Problems that have fixed size features (e.g. image problems) can use hp.batch_size to set the batch size.

Variable length Problems are bucketed by sequence length and then batched out of those buckets. This significantly improves performance over a naive batching scheme for variable length sequences because each example in a batch must be padded to match the example with the maximum length in the batch.
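
To see why bucketing helps, the sketch below counts how many token slots (real tokens plus padding) each scheme allocates for the same examples. It is a simplified illustration with made-up sequence lengths and bucket boundaries, and it uses a fixed number of examples per batch, which is an assumption rather than a description of the T2T batching scheme.

```python
def padded_size(batch):
    """Total token slots once every example is padded to the batch maximum."""
    return max(len(x) for x in batch) * len(batch)


def naive_batches(examples, batch_size):
    """Batches examples in arrival order, ignoring their lengths."""
    return [examples[i:i + batch_size] for i in range(0, len(examples), batch_size)]


def bucketed_batches(examples, boundaries, batch_size):
    """Groups examples into length buckets, then batches within each bucket."""
    buckets = [[] for _ in range(len(boundaries) + 1)]
    for ex in examples:
        # Bucket index = number of boundaries the example's length exceeds.
        idx = sum(len(ex) > b for b in boundaries)
        buckets[idx].append(ex)
    batches = []
    for bucket in buckets:
        batches.extend(naive_batches(bucket, batch_size))
    return batches


# Alternating short and long sequences: a bad case for naive batching.
examples = [[1] * n for n in (2, 30, 3, 28, 4, 31, 2, 29)]
real_tokens = sum(len(x) for x in examples)
naive_slots = sum(padded_size(b) for b in naive_batches(examples, 2))
bucketed_slots = sum(padded_size(b) for b in bucketed_batches(examples, [8, 16], 2))
```

With these lengths, the naive scheme pairs each short sequence with a long one, so nearly half of the allocated slots are padding, while the bucketed scheme pads only within groups of similar length.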

Controlling hparams:

Building the Model

At this point, the input features typically have "inputs" and "targets", each of which is a batched 4-D Tensor (e.g. of shape [batch_size, sequence_length, 1, 1] for text input or [batch_size, height, width, 3] for image input).
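
For text, the 4-D layout simply adds two trailing dimensions of size 1 to a [batch_size, sequence_length] batch of token ids. The nested-list sketch below illustrates the shape only; to_4d and shape are hypothetical helpers, and the real features are Tensors.

```python
def to_4d(batch_of_token_ids):
    """Expands [batch, length] nested lists to [batch, length, 1, 1]."""
    return [[[[tok]] for tok in seq] for seq in batch_of_token_ids]


def shape(nested):
    """Returns the dimensions of a rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return dims


text_batch = [[5, 6, 7], [8, 9, 0]]  # batch_size=2, sequence_length=3
text_shape = shape(to_4d(text_batch))
```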

The Estimator model function is created by T2TModel.estimator_model_fn, which may be overridden in its entirety by subclasses if desired. Typically, subclasses only override T2TModel.body.

The model function constructs a T2TModel, calls it, and then calls T2TModel.{estimator_spec_train, estimator_spec_eval, estimator_spec_predict} depending on the mode.
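
The mode-dependent dispatch can be sketched as follows. The ToyModel class, the mode constants, and the method signatures are assumptions for illustration; only the three estimator_spec_* method names come from the text above, and the returned dictionaries stand in for real EstimatorSpec objects.

```python
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


class ToyModel:
    def __call__(self, features):
        # Stand-in for the forward pass; returns (logits, loss).
        return sum(features["inputs"]), 0.5

    def estimator_spec_train(self, loss):
        return {"mode": TRAIN, "loss": loss}

    def estimator_spec_eval(self, logits, loss):
        return {"mode": EVAL, "loss": loss, "predictions": logits}

    def estimator_spec_predict(self, features):
        logits, _ = self(features)
        return {"mode": PREDICT, "predictions": logits}


def model_fn(features, mode):
    """Constructs the model, calls it, and picks the spec for the mode."""
    model = ToyModel()
    if mode == PREDICT:
        return model.estimator_spec_predict(features)
    logits, loss = model(features)
    if mode == TRAIN:
        return model.estimator_spec_train(loss)
    return model.estimator_spec_eval(logits, loss)


train_spec = model_fn({"inputs": [1, 2, 3]}, TRAIN)
eval_spec = model_fn({"inputs": [1, 2, 3]}, EVAL)
predict_spec = model_fn({"inputs": [1, 2, 3]}, PREDICT)
```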

A call of a T2TModel internally calls bottom, body, top, and loss, all of which can be overridden by subclasses (typically only body is).
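
A minimal sketch of that call sequence, with a subclass that overrides only body, is shown below. ToyT2TModel and Doubler are hypothetical classes; the default bottom, top, and loss here are deliberately trivial placeholders (a float conversion, a pass-through, and a mean squared error) rather than the real defaults.

```python
class ToyT2TModel:
    def bottom(self, features):
        # Placeholder input transformation: convert ids to floats.
        return {k: [float(v) for v in vals] for k, vals in features.items()}

    def body(self, features):
        raise NotImplementedError  # subclasses supply the core computation

    def top(self, body_output, features):
        # Placeholder output transformation: pass through unchanged.
        return list(body_output)

    def loss(self, logits, features):
        # Placeholder loss: mean squared error against the targets.
        targets = features["targets"]
        return sum((l - t) ** 2 for l, t in zip(logits, targets)) / len(targets)

    def __call__(self, features):
        transformed = self.bottom(features)
        body_output = self.body(transformed)
        logits = self.top(body_output, features)
        return logits, self.loss(logits, features)


class Doubler(ToyT2TModel):
    """Overrides only body, as most subclasses are described to do."""

    def body(self, features):
        return [2.0 * x for x in features["inputs"]]


logits, loss = Doubler()({"inputs": [1, 2, 3], "targets": [2, 4, 7]})
```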

The default implementations of bottom, top, and loss depend on the Modality specified for the input and target features (e.g. SymbolModality.bottom embeds integer tokens and SymbolModality.loss is softmax_cross_entropy).
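
The two behaviours named in the example (embedding integer tokens, and a softmax cross-entropy loss) can be written out in plain Python. ToySymbolModality is a hypothetical class, and the tiny embedding table and logits are made-up values; this is a sketch of the arithmetic, not the SymbolModality implementation.

```python
import math


class ToySymbolModality:
    def __init__(self, embedding_table):
        self.embedding_table = embedding_table  # one vector per token id

    def bottom(self, token_ids):
        """Embeds integer tokens by looking up their vectors."""
        return [self.embedding_table[t] for t in token_ids]

    def loss(self, logits, target_ids):
        """Mean softmax cross-entropy over positions."""
        total = 0.0
        for row, target in zip(logits, target_ids):
            m = max(row)  # subtract the max for numerical stability
            log_z = m + math.log(sum(math.exp(x - m) for x in row))
            total += log_z - row[target]
        return total / len(target_ids)


modality = ToySymbolModality([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
embedded = modality.bottom([2, 1])
# Uniform logits over 3 symbols give a loss of log(3) for any target.
uniform_loss = modality.loss([[0.0, 0.0, 0.0]], [1])
```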

Estimator and Experiment

The actual training loop and related services (checkpointing, summaries, continuous evaluation, etc.) are all handled by Estimator and Experiment objects. t2t_trainer.py is the main entrypoint and uses trainer_lib.py to construct the various components.


System Overview for Train/Eval

See t2t_trainer.py and trainer_lib.py.