Copyright 2021 The TensorFlow Authors.¶
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Model analysis using TFX Pipeline and TensorFlow Model Analysis¶
Note: We recommend running this tutorial in a Colab notebook, with no setup required! Just click "Run in Google Colab".
In this notebook-based tutorial, we will create and run a TFX pipeline which creates a simple classification model and analyzes its performance across multiple runs. This notebook is based on the TFX pipeline we built in Simple TFX Pipeline Tutorial. If you have not read that tutorial yet, you should read it before proceeding with this notebook.
As you tweak your model or train it with a new dataset, you need to check whether your model has improved or become worse. Just checking top-level metrics like accuracy might not be enough. Every trained model should be evaluated before it is pushed to production.
We will add an Evaluator component to the pipeline created in the previous tutorial. The Evaluator component performs deep analysis of your models and compares the new model against a baseline to determine whether it is "good enough". It is implemented using the TensorFlow Model Analysis library.
Please see Understanding TFX Pipelines to learn more about various concepts in TFX.
Set Up¶
The setup process is the same as in the previous tutorial.
We first need to install the TFX Python package and download the dataset which we will use for our model.
Upgrade Pip¶
To avoid upgrading Pip in a system when running locally, check to make sure that we are running in Colab. Local systems can of course be upgraded separately.
try:
  import colab
  !pip install --upgrade pip
except:
  pass
Install TFX¶
!pip install -U tfx
Did you restart the runtime?¶
If you are using Google Colab, the first time that you run the cell above, you must restart the runtime by clicking the "RESTART RUNTIME" button above or using the "Runtime > Restart runtime ..." menu. This is because of the way that Colab loads packages.
Check the TensorFlow and TFX versions.
import tensorflow as tf
print('TensorFlow version: {}'.format(tf.__version__))
from tfx import v1 as tfx
print('TFX version: {}'.format(tfx.__version__))
Set up variables¶
There are some variables used to define a pipeline. You can customize these variables as you want. By default all output from the pipeline will be generated under the current directory.
import os
PIPELINE_NAME = "penguin-tfma"
# Output directory to store artifacts generated from the pipeline.
PIPELINE_ROOT = os.path.join('pipelines', PIPELINE_NAME)
# Path to a SQLite DB file to use as an MLMD storage.
METADATA_PATH = os.path.join('metadata', PIPELINE_NAME, 'metadata.db')
# Output directory where created models from the pipeline will be exported.
SERVING_MODEL_DIR = os.path.join('serving_model', PIPELINE_NAME)
from absl import logging
logging.set_verbosity(logging.INFO) # Set default logging level.
Prepare example data¶
We will use the same Palmer Penguins dataset.
There are four numeric features in this dataset which were already normalized
to have range [0,1]. We will build a classification model which predicts the
species
of penguins.
Because TFX ExampleGen reads inputs from a directory, we need to create a directory and copy the dataset to it.
import urllib.request
import tempfile
DATA_ROOT = tempfile.mkdtemp(prefix='tfx-data') # Create a temporary directory.
_data_url = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/penguin/data/labelled/penguins_processed.csv'
_data_filepath = os.path.join(DATA_ROOT, "data.csv")
urllib.request.urlretrieve(_data_url, _data_filepath)
Create a pipeline¶
We will add an Evaluator component to the pipeline we created in the Simple TFX Pipeline Tutorial.
An Evaluator component requires input data from an ExampleGen component, a model from a Trainer component, and a tfma.EvalConfig object. We can optionally supply a baseline model, which is used to compare metrics with the newly trained model.
An Evaluator creates two kinds of output artifacts, ModelEvaluation and ModelBlessing. ModelEvaluation contains the detailed evaluation result, which can be investigated and visualized further with the TFMA library. ModelBlessing contains a boolean result indicating whether the model passed the given criteria, which later components like Pusher can use as a signal.
Write model training code¶
We will use the same model code as in the Simple TFX Pipeline Tutorial.
_trainer_module_file = 'penguin_trainer.py'
%%writefile {_trainer_module_file}
# Copied from https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple
from typing import List
from absl import logging
import tensorflow as tf
from tensorflow import keras
from tensorflow_transform.tf_metadata import schema_utils
from tfx.components.trainer.executor import TrainerFnArgs
from tfx.components.trainer.fn_args_utils import DataAccessor
from tfx_bsl.tfxio import dataset_options
from tensorflow_metadata.proto.v0 import schema_pb2
_FEATURE_KEYS = [
    'culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g'
]
_LABEL_KEY = 'species'

_TRAIN_BATCH_SIZE = 20
_EVAL_BATCH_SIZE = 10

# Since we're not generating or creating a schema, we will instead create
# a feature spec. Since there are a fairly small number of features this is
# manageable for this dataset.
_FEATURE_SPEC = {
    **{
        feature: tf.io.FixedLenFeature(shape=[1], dtype=tf.float32)
        for feature in _FEATURE_KEYS
    },
    _LABEL_KEY: tf.io.FixedLenFeature(shape=[1], dtype=tf.int64)
}
def _input_fn(file_pattern: List[str],
              data_accessor: DataAccessor,
              schema: schema_pb2.Schema,
              batch_size: int = 200) -> tf.data.Dataset:
  """Generates features and label for training.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    schema: schema of the input data.
    batch_size: representing the number of consecutive elements of returned
      dataset to combine in a single batch

  Returns:
    A dataset that contains (features, indices) tuple where features is a
      dictionary of Tensors, and indices is a single Tensor of label indices.
  """
  return data_accessor.tf_dataset_factory(
      file_pattern,
      dataset_options.TensorFlowDatasetOptions(
          batch_size=batch_size, label_key=_LABEL_KEY),
      schema=schema).repeat()
def _build_keras_model() -> tf.keras.Model:
  """Creates a DNN Keras model for classifying penguin data.

  Returns:
    A Keras Model.
  """
  # The model below is built with Functional API, please refer to
  # https://www.tensorflow.org/guide/keras/overview for all API options.
  inputs = [keras.layers.Input(shape=(1,), name=f) for f in _FEATURE_KEYS]
  d = keras.layers.concatenate(inputs)
  for _ in range(2):
    d = keras.layers.Dense(8, activation='relu')(d)
  outputs = keras.layers.Dense(3)(d)

  model = keras.Model(inputs=inputs, outputs=outputs)
  model.compile(
      optimizer=keras.optimizers.Adam(1e-2),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[keras.metrics.SparseCategoricalAccuracy()])

  model.summary(print_fn=logging.info)
  return model
# TFX Trainer will call this function.
def run_fn(fn_args: TrainerFnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """

  # This schema is usually either an output of SchemaGen or a manually-curated
  # version provided by the pipeline author. A schema can also be derived from
  # the TFT graph if a Transform component is used. When either is missing,
  # `schema_from_feature_spec` can be used to generate a schema from a very
  # simple feature_spec, but the schema returned will be very primitive.
  schema = schema_utils.schema_from_feature_spec(_FEATURE_SPEC)

  train_dataset = _input_fn(
      fn_args.train_files,
      fn_args.data_accessor,
      schema,
      batch_size=_TRAIN_BATCH_SIZE)
  eval_dataset = _input_fn(
      fn_args.eval_files,
      fn_args.data_accessor,
      schema,
      batch_size=_EVAL_BATCH_SIZE)

  model = _build_keras_model()
  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  # The result of the training should be saved in `fn_args.serving_model_dir`
  # directory.
  model.save(fn_args.serving_model_dir, save_format='tf')
Write a pipeline definition¶
We will define a function to create a TFX pipeline. In addition to the Evaluator component we mentioned above, we will add one more node called Resolver.
To check that a new model is better than the previous one, we need to compare it against a previously published model, called the baseline. ML Metadata (MLMD) tracks all previous artifacts of the pipeline, and Resolver can find the latest blessed model -- a model that passed the Evaluator successfully -- in MLMD using a strategy class called LatestBlessedModelStrategy.
import tensorflow_model_analysis as tfma
def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                     module_file: str, serving_model_dir: str,
                     metadata_path: str) -> tfx.dsl.Pipeline:
  """Creates a three component penguin pipeline with TFX."""
  # Brings data into the pipeline.
  example_gen = tfx.components.CsvExampleGen(input_base=data_root)

  # Uses user-provided Python function that trains a model.
  trainer = tfx.components.Trainer(
      module_file=module_file,
      examples=example_gen.outputs['examples'],
      train_args=tfx.proto.TrainArgs(num_steps=100),
      eval_args=tfx.proto.EvalArgs(num_steps=5))

  # NEW: Get the latest blessed model for Evaluator.
  model_resolver = tfx.dsl.Resolver(
      strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
      model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
      model_blessing=tfx.dsl.Channel(
          type=tfx.types.standard_artifacts.ModelBlessing)).with_id(
              'latest_blessed_model_resolver')

  # NEW: Uses TFMA to compute evaluation statistics over features of a model and
  # perform quality validation of a candidate model (compared to a baseline).
  eval_config = tfma.EvalConfig(
      model_specs=[tfma.ModelSpec(label_key='species')],
      slicing_specs=[
          # An empty slice spec means the overall slice, i.e. the whole dataset.
          tfma.SlicingSpec(),
          # Calculate metrics for each penguin species.
          tfma.SlicingSpec(feature_keys=['species']),
      ],
      metrics_specs=[
          tfma.MetricsSpec(per_slice_thresholds={
              'sparse_categorical_accuracy':
                  tfma.PerSliceMetricThresholds(thresholds=[
                      tfma.PerSliceMetricThreshold(
                          slicing_specs=[tfma.SlicingSpec()],
                          threshold=tfma.MetricThreshold(
                              value_threshold=tfma.GenericValueThreshold(
                                  lower_bound={'value': 0.6}),
                              # Change threshold will be ignored if there is no
                              # baseline model resolved from MLMD (first run).
                              change_threshold=tfma.GenericChangeThreshold(
                                  direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                                  absolute={'value': -1e-10}))
                      )]),
          })],
  )
  evaluator = tfx.components.Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer.outputs['model'],
      baseline_model=model_resolver.outputs['model'],
      eval_config=eval_config)

  # Checks whether the model passed the validation steps and pushes the model
  # to a file destination if check passed.
  pusher = tfx.components.Pusher(
      model=trainer.outputs['model'],
      model_blessing=evaluator.outputs['blessing'],  # Pass an evaluation result.
      push_destination=tfx.proto.PushDestination(
          filesystem=tfx.proto.PushDestination.Filesystem(
              base_directory=serving_model_dir)))

  components = [
      example_gen,
      trainer,
      # Following two components were added to the pipeline.
      model_resolver,
      evaluator,
      pusher,
  ]

  return tfx.dsl.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=tfx.orchestration.metadata
      .sqlite_metadata_connection_config(metadata_path),
      components=components)
We need to supply the following information to the Evaluator via eval_config:
- Additional metrics to configure (if you want more metrics than those defined in the model).
- Slices to configure.
- Model validation thresholds, if validation is to be included.
Because SparseCategoricalAccuracy was already included in the model.compile() call, it will be included in the analysis automatically, so we do not add any additional metrics here. SparseCategoricalAccuracy will also be used to decide whether the model is good enough.
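If you did want extra metrics, you could list them explicitly in metrics_specs. The snippet below is a minimal sketch (not part of this pipeline) showing how TFMA's built-in ExampleCount metric could be requested; the variable name extra_metrics_spec is only for illustration.
# Sketch only: additional metrics can be requested through a MetricsSpec.
# 'ExampleCount' is a standard TFMA metric class name.
extra_metrics_spec = tfma.MetricsSpec(metrics=[
    tfma.MetricConfig(class_name='ExampleCount'),
])
# Such a spec could then be added to `metrics_specs` in the EvalConfig above.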
We compute the metrics for the whole dataset and for each penguin species.
SlicingSpec
specifies how we aggregate the declared metrics.
There are two thresholds that a new model has to pass: an absolute threshold of 0.6 and a relative threshold requiring it to score higher than the baseline model (the absolute={'value': -1e-10} setting effectively means "not worse than the baseline, allowing for tiny floating-point noise"). When you run the pipeline for the first time, the change_threshold will be ignored and only the value_threshold will be checked. If you run the pipeline more than once, the Resolver will find a model from the previous run and it will be used as the baseline model for the comparison.
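To make that logic concrete, here is an illustrative sketch in plain Python (not TFMA's actual implementation, and ignoring exact boundary semantics) of how the two thresholds combine for sparse_categorical_accuracy:
# Illustration only: roughly how value_threshold and change_threshold interact.
def would_be_blessed(candidate_accuracy, baseline_accuracy=None):
  # value_threshold: the candidate must clear the 0.6 lower bound.
  if candidate_accuracy < 0.6:
    return False
  # change_threshold: the candidate must not be worse than the baseline;
  # the -1e-10 tolerance absorbs floating-point noise. On the first run
  # there is no baseline, so this check is skipped.
  if baseline_accuracy is not None:
    return candidate_accuracy - baseline_accuracy >= -1e-10
  return True

print(would_be_blessed(0.95))        # First run: only the value threshold applies.
print(would_be_blessed(0.95, 0.97))  # Worse than the baseline -> not blessed.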
See Evaluator component guide for more information.
Run the pipeline¶
We will use LocalDagRunner
as in the previous tutorial.
tfx.orchestration.LocalDagRunner().run(
  _create_pipeline(
      pipeline_name=PIPELINE_NAME,
      pipeline_root=PIPELINE_ROOT,
      data_root=DATA_ROOT,
      module_file=_trainer_module_file,
      serving_model_dir=SERVING_MODEL_DIR,
      metadata_path=METADATA_PATH))
When the pipeline completes, you should see something like the following:
INFO:absl:Blessing result True written to pipelines/penguin-tfma/Evaluator/blessing/4.
You can also manually check the output directory where the generated artifacts are stored. If you visit pipelines/penguin-tfma/Evaluator/blessing/ with a file browser, you will see a file named BLESSED or NOT_BLESSED depending on the evaluation result.
If the blessing result is False, Pusher will refuse to push the model to the serving_model_dir, because the model is not good enough to be used in production.
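If you prefer checking this from code instead of a file browser, a small sketch like the following inspects the most recent blessing artifact directory (the run number in the path depends on your execution history):
import glob
import os

# Sketch only: look inside the most recently modified blessing artifact directory.
blessing_dirs = glob.glob(os.path.join(PIPELINE_ROOT, 'Evaluator', 'blessing', '*'))
if blessing_dirs:
  latest_blessing_dir = max(blessing_dirs, key=os.path.getmtime)
  print(os.listdir(latest_blessing_dir))  # Expect ['BLESSED'] or ['NOT_BLESSED'].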
You can run the pipeline again, possibly with different evaluation configs. Even if you run the pipeline with the exact same config and dataset, the trained model might be slightly different due to the inherent randomness of model training, which can lead to a NOT_BLESSED model.
Examine outputs of the pipeline¶
You can use TFMA to investigate and visualize the evaluation result in the ModelEvaluation artifact.
NOTE: If you are not on Colab, install the Jupyter extensions. You need the TensorFlow Model Analysis extension to see the visualization from TFMA. This extension is already installed on Google Colab, but you might need to install it if you are running this notebook in another environment. See the installation directions for the Jupyter extension in the Install guide.
Get analysis result from output artifacts¶
You can use MLMD APIs to locate these outputs programmatically. First, we will define some utility functions to search for the output artifacts that were just produced.
from ml_metadata.proto import metadata_store_pb2
# Non-public APIs, just for showcase.
from tfx.orchestration.portable.mlmd import execution_lib
# TODO(b/171447278): Move these functions into the TFX library.
def get_latest_artifacts(metadata, pipeline_name, component_id):
  """Output artifacts of the latest run of the component."""
  context = metadata.store.get_context_by_type_and_name(
      'node', f'{pipeline_name}.{component_id}')
  executions = metadata.store.get_executions_by_context(context.id)
  latest_execution = max(executions,
                         key=lambda e: e.last_update_time_since_epoch)
  return execution_lib.get_output_artifacts(metadata, latest_execution.id)
We can find the latest execution of the Evaluator component and get its output artifacts.
# Non-public APIs, just for showcase.
from tfx.orchestration.metadata import Metadata
from tfx.types import standard_component_specs
metadata_connection_config = tfx.orchestration.metadata.sqlite_metadata_connection_config(
    METADATA_PATH)

with Metadata(metadata_connection_config) as metadata_handler:
  # Find output artifacts from MLMD.
  evaluator_output = get_latest_artifacts(metadata_handler, PIPELINE_NAME,
                                          'Evaluator')
  eval_artifact = evaluator_output[standard_component_specs.EVALUATION_KEY][0]
Evaluator always returns one evaluation artifact, and we can visualize it using the TensorFlow Model Analysis library. For example, the following code will render the accuracy metrics for each penguin species.
import tensorflow_model_analysis as tfma
eval_result = tfma.load_eval_result(eval_artifact.uri)
tfma.view.render_slicing_metrics(eval_result, slicing_column='species')
If you choose 'sparse_categorical_accuracy' in the Show drop-down list, you can see the accuracy values per species. You might want to add more slices and check whether your model performs well for all distributions and whether there is any possible bias.
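Besides the sliced metrics, you can also inspect the validation outcome that produced the blessing. The following is a small sketch using tfma.load_validation_result; its validation_ok field corresponds to the blessed / not-blessed decision.
# Sketch only: load the validation (blessing) outcome recorded by TFMA.
validation_result = tfma.load_validation_result(eval_artifact.uri)
print(validation_result.validation_ok)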
Next steps¶
Learn more about model analysis in the TensorFlow Model Analysis library tutorial.
You can find more resources on https://www.tensorflow.org/tfx/tutorials.
Please see Understanding TFX Pipelines to learn more about various concepts in TFX.