TensorFlow 2.x in TFX

TensorFlow 2.0 was released in 2019, with tight integration of Keras, eager execution by default, and Pythonic function execution, among other new features and improvements.

This guide provides a comprehensive technical overview of TF 2.x in TFX.

Which version to use?

TFX is compatible with TensorFlow 2.x, and the high-level APIs that existed in TensorFlow 1.x (particularly Estimators) continue to work.

Start new projects in TensorFlow 2.x

Since TensorFlow 2.x retains the high-level capabilities of TensorFlow 1.x, there is no advantage to using the older version on new projects, even if you don't plan to use the new features.

Therefore, if you are starting a new TFX project, we recommend that you use TensorFlow 2.x. You may want to update your code later as full support for Keras and other new features becomes available, and the scope of those changes will be much more limited if you start with TensorFlow 2.x than if you upgrade from TensorFlow 1.x in the future.

Converting existing projects to TensorFlow 2.x

Code written for TensorFlow 1.x is largely compatible with TensorFlow 2.x and will continue to work in TFX.

However, if you'd like to take advantage of improvements and new features as they become available in TF 2.x, you can follow the instructions for migrating to TF 2.x.

Estimator

Because the Estimator API has been fully dropped since TensorFlow 2.16, we decided to discontinue support for it.

Native Keras (i.e. Keras without Estimator)

Note

Full support for all features in Keras is in progress. In most cases, Keras in TFX will work as expected, but it does not yet work with Sparse Features for FeatureColumns.

Examples and Colab

Here are several examples with native Keras:

We also have a per-component Keras Colab.

TFX Components

The following sections explain how related TFX components support native Keras.

Transform

Transform currently has experimental support for Keras models.

The Transform component itself can be used for native Keras without change. The preprocessing_fn definition remains the same, using TensorFlow and tf.Transform ops.

The serving function and eval function change for native Keras. Details are discussed in the Trainer and Evaluator sections below.

Note

Transformations within the preprocessing_fn cannot be applied to the label feature for training or eval.

Trainer

Keras Module file with Transform

The training module file must contain a run_fn, which will be called by the GenericExecutor. A typical Keras run_fn looks like this:

def run_fn(fn_args: TrainerFnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

  # Train and eval files contain transformed examples.
  # _input_fn reads the dataset based on the transformed schema from tft.
  train_dataset = _input_fn(fn_args.train_files, fn_args.data_accessor,
                            tf_transform_output.transformed_metadata.schema)
  eval_dataset = _input_fn(fn_args.eval_files, fn_args.data_accessor,
                           tf_transform_output.transformed_metadata.schema)

  model = _build_keras_model()

  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(model,
                                    tf_transform_output).get_concrete_function(
                                        tf.TensorSpec(
                                            shape=[None],
                                            dtype=tf.string,
                                            name='examples')),
  }
  model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)

In the run_fn above, a serving signature is needed when exporting the trained model so that the model can take raw examples for prediction. A typical serving function looks like this:

def _get_serve_tf_examples_fn(model, tf_transform_output):
  """Returns a function that parses a serialized tf.Example."""

  # The layer is added as an attribute of the model to make sure that
  # the model assets are handled correctly when exporting.
  model.tft_layer = tf_transform_output.transform_features_layer()

  @tf.function
  def serve_tf_examples_fn(serialized_tf_examples):
    """Returns the output to be used in the serving signature."""
    feature_spec = tf_transform_output.raw_feature_spec()
    feature_spec.pop(_LABEL_KEY)
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)

    transformed_features = model.tft_layer(parsed_features)

    return model(transformed_features)

  return serve_tf_examples_fn

In the serving function above, tf.Transform transformations need to be applied to the raw data for inference, using the tft.TransformFeaturesLayer layer. The _serving_input_receiver_fn that was previously required for Estimators is no longer needed with Keras.

Keras Module file without Transform

This is similar to the module file shown above, but without the transformations:

def _get_serve_tf_examples_fn(model, schema):

  @tf.function
  def serve_tf_examples_fn(serialized_tf_examples):
    feature_spec = _get_raw_feature_spec(schema)
    feature_spec.pop(_LABEL_KEY)
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
    return model(parsed_features)

  return serve_tf_examples_fn


def run_fn(fn_args: TrainerFnArgs):
  schema = io_utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema())

  # Train and eval files contain raw examples.
  # _input_fn reads the dataset based on raw data schema.
  train_dataset = _input_fn(fn_args.train_files, fn_args.data_accessor, schema)
  eval_dataset = _input_fn(fn_args.eval_files, fn_args.data_accessor, schema)

  model = _build_keras_model()

  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(model, schema).get_concrete_function(
              tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')),
  }
  model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)

tf.distribute.Strategy

At this time, TFX only supports single-worker strategies (e.g., MirroredStrategy, OneDeviceStrategy).

To use a distribution strategy, create an appropriate tf.distribute.Strategy and move the creation and compiling of the Keras model inside a strategy scope.

For example, replace the model = _build_keras_model() line above with:

  mirrored_strategy = tf.distribute.MirroredStrategy()
  with mirrored_strategy.scope():
    model = _build_keras_model()

  # Rest of the code can be unchanged.
  model.fit(...)

To verify the device (CPU/GPU) used by MirroredStrategy, enable info-level TensorFlow logging:

import logging
logging.getLogger("tensorflow").setLevel(logging.INFO)

and you should be able to see Using MirroredStrategy with devices (...) in the log.
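If you would rather check for that message programmatically than scan the log by eye, one option is to attach a small handler to the same "tensorflow" logger. The following is a minimal sketch using only the standard library. The handler class and the helper function are hypothetical, and the logged line is a stand-in for what strategy creation would emit, not real TensorFlow output.

```python
import logging


class ListHandler(logging.Handler):
    """Collects formatted log messages in memory (hypothetical helper)."""

    def __init__(self):
        super().__init__(level=logging.INFO)
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


def saw_mirrored_strategy(messages):
    """Returns True if any message reports MirroredStrategy devices."""
    return any("Using MirroredStrategy with devices" in m for m in messages)


tf_logger = logging.getLogger("tensorflow")
tf_logger.setLevel(logging.INFO)
handler = ListHandler()
tf_logger.addHandler(handler)

# Stand-in for the line logged when the strategy is created. In a real
# module file, creating tf.distribute.MirroredStrategy() would go here.
tf_logger.info("Using MirroredStrategy with devices (...)")

found = saw_mirrored_strategy(handler.messages)
```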

Note

The environment variable TF_FORCE_GPU_ALLOW_GROWTH=true might be needed to work around a GPU out-of-memory issue. For details, please refer to the TensorFlow GPU guide.
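When the variable cannot be exported in the shell that launches the pipeline, it can also be set from Python. The sketch below assumes that TensorFlow reads the variable when it initializes its GPU devices, so the assignment is placed before TensorFlow is imported.

```python
import os

# Set the variable before TensorFlow initializes any GPU. Placing this
# at the very top of the module, ahead of the TensorFlow import, is the
# safest position.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

# import tensorflow as tf  # import TensorFlow only after the variable is set
```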

Evaluator

In TFMA v0.2x, ModelValidator and Evaluator were combined into a single new Evaluator component. The new Evaluator component can both evaluate a single model and validate the current model against previous models. With this change, the Pusher component now consumes a blessing result from Evaluator instead of ModelValidator.

See Evaluator for more information.
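To make the idea of validating a model against a previous one more concrete, here is a purely conceptual sketch in plain Python. It is not the TFMA or Evaluator API: the function, its parameters, and the two kinds of check (an absolute lower bound and a minimum improvement over a baseline) are assumptions chosen for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BlessingDecision:
    """Outcome of a hypothetical validation step."""
    blessed: bool
    reason: str


def decide_blessing(
    candidate_metric: float,
    baseline_metric: Optional[float] = None,
    lower_bound: float = 0.0,
    min_improvement: float = 0.0,
) -> BlessingDecision:
    """Decides whether a candidate model should be blessed.

    Assumes a higher metric is better (for example, accuracy).
    """
    # Single-model check: the candidate must clear an absolute bar.
    if candidate_metric < lower_bound:
        return BlessingDecision(False, "below absolute lower bound")
    # With no previous model, only the absolute check applies.
    if baseline_metric is None:
        return BlessingDecision(True, "no baseline; absolute check passed")
    # Comparison check: the candidate must not regress against the baseline.
    if candidate_metric - baseline_metric < min_improvement:
        return BlessingDecision(False, "does not improve on baseline")
    return BlessingDecision(True, "passed absolute and baseline checks")


first_model = decide_blessing(0.9, baseline_metric=None, lower_bound=0.6)
regression = decide_blessing(0.8, baseline_metric=0.85, lower_bound=0.6)
```

In this sketch, a downstream step in the role of Pusher would deploy the model only when blessed is True.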