TensorFlow 2.x in TFX¶
TensorFlow 2.0 was released in 2019, with tight integration of Keras, eager execution by default, and Pythonic function execution, among other new features and improvements.
This guide provides a comprehensive technical overview of TF 2.x in TFX.
Which version to use?¶
TFX is compatible with TensorFlow 2.x. Most high-level APIs that existed in TensorFlow 1.x continue to work; the exception is the Estimator API, which is no longer supported (see the Estimator section).
Start new projects in TensorFlow 2.x¶
Since TensorFlow 2.x retains the high-level capabilities of TensorFlow 1.x, there is no advantage to using the older version on new projects, even if you don't plan to use the new features.
Therefore, if you are starting a new TFX project, we recommend that you use TensorFlow 2.x. You may want to update your code later as full support for Keras and other new features becomes available, and the scope of those changes will be much smaller if you start with TensorFlow 2.x than if you upgrade from TensorFlow 1.x in the future.
Converting existing projects to TensorFlow 2.x¶
Code written for TensorFlow 1.x is largely compatible with TensorFlow 2.x and will continue to work in TFX.
However, if you'd like to take advantage of improvements and new features as they become available in TF 2.x, you can follow the instructions for migrating to TF 2.x.
Estimator¶
The Estimator API was removed entirely in TensorFlow 2.16, so TFX has discontinued support for it.
Native Keras (i.e. Keras without Estimator)¶
Note
Full support for all Keras features is still in progress. In most cases, Keras in TFX works as expected, but it does not yet work with sparse features for FeatureColumns.
Examples and Colab¶
Here are several examples with native Keras:
- Penguin (module file): 'Hello world' end-to-end example.
- MNIST (module file): Image end-to-end example.
- Taxi (module file): end-to-end example with advanced Transform usage.
We also have a per-component Keras Colab.
TFX Components¶
The following sections explain how related TFX components support native Keras.
Transform¶
Transform currently has experimental support for Keras models.
The Transform component itself can be used for native Keras without change. The preprocessing_fn definition remains the same, using TensorFlow and tf.Transform ops.
The serving function and eval function are changed for native Keras. Details will be discussed in the following Trainer and Evaluator sections.
Note
Transformations within the preprocessing_fn cannot be applied to the label feature for training or eval.
Trainer¶
Keras Module file with Transform¶
The training module file must contain a run_fn, which is called by the GenericExecutor. A typical Keras run_fn looks like this:
def run_fn(fn_args: TrainerFnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

  # Train and eval files contain transformed examples.
  # _input_fn reads the dataset based on the transformed schema from tft.
  train_dataset = _input_fn(fn_args.train_files, fn_args.data_accessor,
                            tf_transform_output.transformed_metadata.schema)
  eval_dataset = _input_fn(fn_args.eval_files, fn_args.data_accessor,
                           tf_transform_output.transformed_metadata.schema)

  model = _build_keras_model()

  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  # The serving signature accepts a batch of serialized tf.Example strings.
  signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(model,
                                    tf_transform_output).get_concrete_function(
                                        tf.TensorSpec(
                                            shape=[None],
                                            dtype=tf.string,
                                            name='examples')),
  }
  model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)
In the run_fn above, a serving signature is needed when exporting the trained model so that the model can accept raw examples for prediction. A typical serving function looks like this:
def _get_serve_tf_examples_fn(model, tf_transform_output):
  """Returns a function that parses a serialized tf.Example."""

  # The layer is added as an attribute to the model in order to make sure that
  # the model assets are handled correctly when exporting.
  model.tft_layer = tf_transform_output.transform_features_layer()

  @tf.function
  def serve_tf_examples_fn(serialized_tf_examples):
    """Returns the output to be used in the serving signature."""
    feature_spec = tf_transform_output.raw_feature_spec()
    feature_spec.pop(_LABEL_KEY)
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)

    transformed_features = model.tft_layer(parsed_features)

    return model(transformed_features)

  return serve_tf_examples_fn
In the serving function above, the tf.Transform transformations must be applied to the raw data at inference time, using the tft.TransformFeaturesLayer layer. The _serving_input_receiver_fn that was previously required for Estimators is no longer needed with Keras.
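As a small illustration of the parsing step inside the serving function, the sketch below serializes one tf.Example and parses it with a feature spec from which the label has been removed. The feature names and the spec are hypothetical stand-ins for what raw_feature_spec() would return. Serving requests normally carry no label, and a FixedLenFeature without a default value makes parsing fail when its key is missing, which is presumably why the label is popped before parsing.

```python
import tensorflow as tf

# Hypothetical raw feature spec; in a real module file this would come from
# tf_transform_output.raw_feature_spec() or from the schema.
raw_feature_spec = {
    "feature_a": tf.io.FixedLenFeature([1], tf.float32),
    "label": tf.io.FixedLenFeature([1], tf.int64),
}

# Drop the label, mirroring feature_spec.pop(_LABEL_KEY) in the serving function.
serving_spec = dict(raw_feature_spec)
serving_spec.pop("label")

# Build one serialized tf.Example that, like a serving request, has no label.
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "feature_a": tf.train.Feature(
                float_list=tf.train.FloatList(value=[1.5])
            ),
        }
    )
)
serialized_examples = tf.constant([example.SerializeToString()])

# Parse the batch of serialized examples into a dict of dense tensors.
parsed_features = tf.io.parse_example(serialized_examples, serving_spec)
```

Here parsed_features contains only feature_a; in the real serving function this dict is then passed to model.tft_layer and on to the model.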
Keras Module file without Transform¶
This is similar to the module file shown above, but without the transformations:
def _get_serve_tf_examples_fn(model, schema):

  @tf.function
  def serve_tf_examples_fn(serialized_tf_examples):
    feature_spec = _get_raw_feature_spec(schema)
    feature_spec.pop(_LABEL_KEY)
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
    return model(parsed_features)

  return serve_tf_examples_fn


def run_fn(fn_args: TrainerFnArgs):
  schema = io_utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema())

  # Train and eval files contain raw examples.
  # _input_fn reads the dataset based on the raw data schema.
  train_dataset = _input_fn(fn_args.train_files, fn_args.data_accessor, schema)
  eval_dataset = _input_fn(fn_args.eval_files, fn_args.data_accessor, schema)

  model = _build_keras_model()

  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(model, schema).get_concrete_function(
              tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')),
  }
  model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)
tf.distribute.Strategy¶
At this time, TFX supports only single-worker strategies (e.g., MirroredStrategy, OneDeviceStrategy).
To use a distribution strategy, create an appropriate tf.distribute.Strategy and move the creation and compiling of the Keras model inside a strategy scope.
For example, replace the model = _build_keras_model() line above with:
mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
  model = _build_keras_model()

# Rest of the code can be unchanged.
model.fit(...)
To verify the device (CPU/GPU) used by MirroredStrategy, enable info-level TensorFlow logging. You should then see Using MirroredStrategy with devices (...) in the log.
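One way to enable info-level logging is through Python's standard logging module. This is a minimal sketch; it relies on "tensorflow" being the name of TensorFlow's Python logger, and it should run before the strategy is created so that the device message is emitted at a visible level.

```python
import logging

# Raise the verbosity of TensorFlow's Python logger to INFO so that messages
# such as the MirroredStrategy device list are not filtered out.
tf_logger = logging.getLogger("tensorflow")
tf_logger.setLevel(logging.INFO)
```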
Note
Setting the environment variable TF_FORCE_GPU_ALLOW_GROWTH=true may be needed to work around GPU out-of-memory errors. For details, refer to the TensorFlow GPU guide.
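If exporting the variable in the shell is inconvenient, it can also be set from Python. The sketch below assumes that the assignment takes effect only if it runs before TensorFlow initializes the GPU, so it is placed ahead of the TensorFlow import.

```python
import os

# Assumption: this must run before TensorFlow initializes any GPU device,
# so it is placed ahead of `import tensorflow`.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

# import tensorflow as tf  # import TensorFlow only after the variable is set
```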
Evaluator¶
In TFMA v0.2x, ModelValidator and Evaluator have been combined into a single new Evaluator component. The new Evaluator component can both evaluate a single model and validate the current model against previous models. With this change, the Pusher component now consumes a blessing result from Evaluator instead of ModelValidator.