Copyright 2021 The TensorFlow Authors.¶
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Note: We recommend running this tutorial in a Colab notebook, with no setup required! Just click "Run in Google Colab".
TensorFlow Model Analysis¶
An Example of a Key Component of TensorFlow Extended (TFX)
TensorFlow Model Analysis (TFMA) is a library for performing model evaluation across different slices of data. TFMA performs its computations in a distributed manner over large amounts of data using Apache Beam.
This example colab notebook illustrates how TFMA can be used to investigate and visualize the performance of a model with respect to characteristics of the dataset. We'll use a model that we trained previously, and now you get to play with the results! The model we trained was for the Chicago Taxi Example, which uses the Taxi Trips dataset released by the City of Chicago. Explore the full dataset in the BigQuery UI.
As a modeler and developer, think about how this data is used and the potential benefits and harm a model's predictions can cause. A model like this could reinforce societal biases and disparities. Is a feature relevant to the problem you want to solve or will it introduce bias? For more information, read about ML fairness.
Note: In order to understand TFMA and how it works with Apache Beam, you'll need to know a little bit about Apache Beam itself. The Beam Programming Guide is a great place to start.
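To get a feel for the Beam programming model (once the packages below are installed), here is a minimal, hypothetical pipeline, unrelated to the taxi data, that just creates a small in-memory collection, transforms it, and prints the results:
# A minimal Apache Beam sketch (hypothetical example, not part of the TFMA flow):
# build a small in-memory PCollection, apply an element-wise transform, and
# print the results using the local (direct) runner.
import apache_beam as beam

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | 'Create' >> beam.Create([1, 2, 3, 4])
      | 'Square' >> beam.Map(lambda x: x * x)
      | 'Print' >> beam.Map(print))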
The columns in the dataset are:
| pickup_community_area | fare                   | trip_start_month     |
| trip_start_hour       | trip_start_day         | trip_start_timestamp |
| pickup_latitude       | pickup_longitude       | dropoff_latitude     |
| dropoff_longitude     | trip_miles             | pickup_census_tract  |
| dropoff_census_tract  | payment_type           | company              |
| trip_seconds          | dropoff_community_area | tips                 |
Install Jupyter Extensions¶
Note: If running in a local Jupyter notebook, then these Jupyter extensions must be installed in the environment before running Jupyter.
jupyter nbextension enable --py widgetsnbextension --sys-prefix
jupyter nbextension install --py --symlink tensorflow_model_analysis --sys-prefix
jupyter nbextension enable --py tensorflow_model_analysis --sys-prefix
Install TensorFlow Model Analysis (TFMA)¶
This will pull in all the dependencies, and will take a minute.
# Upgrade pip to the latest, and install TFMA.
!pip install -U pip
!pip install tensorflow-model-analysis
Now you must restart the runtime before running the cells below.
# This setup was tested with TF 2.10 and TFMA 0.41 (using colab), but it should
# also work with the latest release.
import sys
# Confirm that we're using Python 3
assert sys.version_info.major==3, 'This notebook must be run using Python 3.'
import tensorflow as tf
print('TF version: {}'.format(tf.__version__))
import apache_beam as beam
print('Beam version: {}'.format(beam.__version__))
import tensorflow_model_analysis as tfma
print('TFMA version: {}'.format(tfma.__version__))
NOTE: The output above should be clear of errors before proceeding. Re-run the install if you are still seeing errors. Also, make sure to restart the runtime/kernel before moving to the next step.
Load The Files¶
We'll download a tar file that has everything we need. That includes:
- Training and evaluation datasets
- Data schema
- Training and serving saved models (keras and estimator) and eval saved models (estimator).
# Download the tar file from GCP and extract it
import io, os, tempfile
TAR_NAME = 'saved_models-2.2'
BASE_DIR = tempfile.mkdtemp()
DATA_DIR = os.path.join(BASE_DIR, TAR_NAME, 'data')
MODELS_DIR = os.path.join(BASE_DIR, TAR_NAME, 'models')
SCHEMA = os.path.join(BASE_DIR, TAR_NAME, 'schema.pbtxt')
OUTPUT_DIR = os.path.join(BASE_DIR, 'output')
!curl -O https://storage.googleapis.com/artifacts.tfx-oss-public.appspot.com/datasets/{TAR_NAME}.tar
!tar xf {TAR_NAME}.tar
!mv {TAR_NAME} {BASE_DIR}
!rm {TAR_NAME}.tar
print("Here's what we downloaded:")
!ls -R {BASE_DIR}
Parse the Schema¶
Among the things we downloaded was a schema for our data that was created by TensorFlow Data Validation. Let's parse that now so that we can use it with TFMA.
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.lib.io import file_io
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow.core.example import example_pb2
schema = schema_pb2.Schema()
contents = file_io.read_file_to_string(SCHEMA)
schema = text_format.Parse(contents, schema)
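As a quick, optional sanity check (not part of the original flow), we can list the feature names and types we just parsed; the next step relies on these types.
# Optional sanity check: print each feature name and its type from the parsed schema.
for feature in schema.feature:
  print(feature.name, schema_pb2.FeatureType.Name(feature.type))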
Use the Schema to Create TFRecords¶
We need to give TFMA access to our dataset, so let's create a TFRecords file. We can use our schema to create it, since it gives us the correct type for each feature.
import csv

datafile = os.path.join(DATA_DIR, 'eval', 'data.csv')
reader = csv.DictReader(open(datafile, 'r'))
examples = []
for line in reader:
  example = example_pb2.Example()
  for feature in schema.feature:
    key = feature.name
    if feature.type == schema_pb2.FLOAT:
      example.features.feature[key].float_list.value[:] = (
          [float(line[key])] if len(line[key]) > 0 else [])
    elif feature.type == schema_pb2.INT:
      example.features.feature[key].int64_list.value[:] = (
          [int(line[key])] if len(line[key]) > 0 else [])
    elif feature.type == schema_pb2.BYTES:
      example.features.feature[key].bytes_list.value[:] = (
          [line[key].encode('utf8')] if len(line[key]) > 0 else [])
  # Add a new column 'big_tipper' that indicates if tips was > 20% of the fare.
  # TODO(b/157064428): Remove after label transformation is supported for Keras.
  big_tipper = float(line['tips']) > float(line['fare']) * 0.2
  example.features.feature['big_tipper'].float_list.value[:] = [big_tipper]
  examples.append(example)

tfrecord_file = os.path.join(BASE_DIR, 'train_data.rio')
with tf.io.TFRecordWriter(tfrecord_file) as writer:
  for example in examples:
    writer.write(example.SerializeToString())
!ls {tfrecord_file}
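As an optional check that the records were written correctly, we can read the first one back and parse it as a tf.train.Example:
# Optional: read the first serialized record back and parse it as a sanity check.
raw_record = next(iter(tf.data.TFRecordDataset(tfrecord_file)))
parsed = tf.train.Example.FromString(raw_record.numpy())
print(sorted(parsed.features.feature.keys())[:5])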
Setup and Run TFMA¶
TFMA supports a number of different model types, including TF keras models, models based on generic TF2 signature APIs, and TF estimator based models. The get_started guide has the full list of model types supported and any restrictions. For this example we are going to show how to configure a keras-based model as well as an estimator-based model that was saved as an EvalSavedModel. See the FAQ for examples of other configurations.
TFMA supports calculating metrics that were used at training time (i.e. built-in metrics) as well as metrics defined after the model was saved, as part of the TFMA configuration settings. For our keras setup we will demonstrate adding our metrics and plots manually as part of our configuration (see the metrics guide for information on the metrics and plots that are supported). For the estimator setup we will use the built-in metrics that were saved with the model. Our setups also include a number of slicing specs, which are discussed in more detail in the following sections.
After creating a tfma.EvalConfig and tfma.EvalSharedModel we can then run TFMA using tfma.run_model_analysis. This will create a tfma.EvalResult which we can use later for rendering our metrics and plots.
Keras¶
import tensorflow_model_analysis as tfma

# Setup tfma.EvalConfig settings
keras_eval_config = text_format.Parse("""
  ## Model information
  model_specs {
    # For keras (and serving models) we need to add a `label_key`.
    label_key: "big_tipper"
  }

  ## Post training metric information. These will be merged with any built-in
  ## metrics from training.
  metrics_specs {
    metrics { class_name: "ExampleCount" }
    metrics { class_name: "AUC" }
    metrics { class_name: "Precision" }
    metrics { class_name: "Recall" }
    metrics { class_name: "MeanPrediction" }
    metrics { class_name: "Calibration" }
    metrics { class_name: "CalibrationPlot" }
    metrics { class_name: "ConfusionMatrixPlot" }
    # ... add additional metrics and plots ...
  }

  ## Slicing information
  slicing_specs {}  # overall slice
  slicing_specs {
    feature_keys: ["trip_start_hour"]
  }
  slicing_specs {
    feature_keys: ["trip_start_day"]
  }
  slicing_specs {
    feature_values: {
      key: "trip_start_month"
      value: "1"
    }
  }
""", tfma.EvalConfig())

# Create a tfma.EvalSharedModel that points at our keras model.
keras_model_path = os.path.join(MODELS_DIR, 'keras', '2')
keras_eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path=keras_model_path,
    eval_config=keras_eval_config)

keras_output_path = os.path.join(OUTPUT_DIR, 'keras')

# Run TFMA
keras_eval_result = tfma.run_model_analysis(
    eval_shared_model=keras_eval_shared_model,
    eval_config=keras_eval_config,
    data_location=tfrecord_file,
    output_path=keras_output_path)
Estimator¶
import tensorflow_model_analysis as tfma

# Setup tfma.EvalConfig settings
estimator_eval_config = text_format.Parse("""
  ## Model information
  model_specs {
    # To use EvalSavedModel set `signature_name` to "eval".
    signature_name: "eval"
  }

  ## Post training metric information. These will be merged with any built-in
  ## metrics from training.
  metrics_specs {
    metrics { class_name: "ConfusionMatrixPlot" }
    # ... add additional metrics and plots ...
  }

  ## Slicing information
  slicing_specs {}  # overall slice
  slicing_specs {
    feature_keys: ["trip_start_hour"]
  }
  slicing_specs {
    feature_keys: ["trip_start_day"]
  }
  slicing_specs {
    feature_values: {
      key: "trip_start_month"
      value: "1"
    }
  }
""", tfma.EvalConfig())

# Create a tfma.EvalSharedModel that points at our eval saved model.
estimator_base_model_path = os.path.join(
    MODELS_DIR, 'estimator', 'eval_model_dir')
estimator_model_path = os.path.join(
    estimator_base_model_path, os.listdir(estimator_base_model_path)[0])
estimator_eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path=estimator_model_path,
    eval_config=estimator_eval_config)

estimator_output_path = os.path.join(OUTPUT_DIR, 'estimator')

# Run TFMA
estimator_eval_result = tfma.run_model_analysis(
    eval_shared_model=estimator_eval_shared_model,
    eval_config=estimator_eval_config,
    data_location=tfrecord_file,
    output_path=estimator_output_path)
Visualizing Metrics and Plots¶
Now that we've run the evaluation, let's take a look at our visualizations using TFMA. For the following examples, we will visualize the results from running the evaluation on the keras model. To view the estimator-based model instead, update eval_result_path to point at our estimator_output_path variable.
eval_result_path = keras_output_path
# eval_result_path = estimator_output_path
eval_result = keras_eval_result
# eval_result = estimator_eval_result
Rendering Metrics¶
TFMA provides dataframe APIs in tfma.experimental.dataframe to load the materialized output as Pandas DataFrames. To view metrics you can use metrics_as_dataframes(tfma.load_metrics(eval_path)), which returns an object that may contain several DataFrames, one for each metric value type (double_value, confusion_matrix_at_thresholds, bytes_value, and array_value). The specific DataFrames populated depend on the eval result. Here, we show the double_value DataFrame as an example.
import tensorflow_model_analysis.experimental.dataframe as tfma_dataframe

dfs = tfma_dataframe.metrics_as_dataframes(
    tfma.load_metrics(eval_result_path))
display(dfs.double_value.head())
Each of the DataFrames has a column multi-index with the top-level columns: slices, metric_keys, and metric_values. The exact columns of each group can change according to the payload. We can use the DataFrame.columns API to inspect all the multi-index columns. For example, the slices columns are 'Overall', 'trip_start_day', 'trip_start_hour', and 'trip_start_month', which are configured by the slicing_specs in the eval_config.
print(dfs.double_value.columns)
Auto pivoting¶
The DataFrame is verbose by design so that there is no loss of information from the payload. However, sometimes, for direct consumption, we might want to organize the information in a more concise but lossy form: slices as rows and metrics as columns. TFMA provides an auto_pivot API for this purpose. The utility pivots on all of the non-unique columns inside metric_keys, and condenses all the slices into one stringified_slices column by default.
tfma_dataframe.auto_pivot(dfs.double_value).head()
Filtering slices¶
Since the outputs are DataFrames, any native DataFrame APIs can be used to slice and dice the DataFrame. For example, if we are only interested in trip_start_hour of 1, 3, 5, 7 and not in trip_start_day, we can use the DataFrame's .loc filtering logic. Again, we use the auto_pivot function to re-organize the DataFrame in the slice vs. metrics view after the filtering is performed.
df_double = dfs.double_value
df_filtered = (df_double
    .loc[df_double.slices.trip_start_hour.isin([1, 3, 5, 7])]
)
display(tfma_dataframe.auto_pivot(df_filtered))
Sorting by metric values¶
We can also sort slices by metrics value. As an example, we show how to sort slices in the above DataFrame by ascending AUC, so that we can find poorly performing slices. This involves two steps: auto-pivoting so that slices are represented as rows and columns are metrics, and then sorting the pivoted DataFrame by the AUC column.
# Pivoted table sorted by AUC in ascending order.
df_sorted = (
    tfma_dataframe.auto_pivot(df_double)
    .sort_values(by='auc', ascending=True)
)
display(df_sorted.head())
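Besides these DataFrame views, TFMA also ships an interactive widget for browsing metrics by slice. As a sketch using the same eval_result loaded above, we can key it on a slicing column such as trip_start_hour:
# Interactive slicing-metrics browser, keyed on the trip_start_hour slices.
tfma.view.render_slicing_metrics(eval_result, slicing_column='trip_start_hour')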
Rendering Plots¶
Any plots that were added to the tfma.EvalConfig as post training metric_specs can be displayed using tfma.view.render_plot.
As with metrics, plots can be viewed by slice. Unlike metrics, only plots for a particular slice value can be displayed, so the tfma.SlicingSpec must be used, and it must specify both a slice feature name and value. If no slice is provided then the plots for the Overall slice are used.
In the example below we are displaying the CalibrationPlot and ConfusionMatrixPlot plots that were computed for the trip_start_hour:1 slice.
tfma.view.render_plot(
    eval_result,
    tfma.SlicingSpec(feature_values={'trip_start_hour': '1'}))
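If you omit the tfma.SlicingSpec, the plots computed for the Overall slice are rendered instead:
# With no SlicingSpec, the plots for the Overall slice are shown.
tfma.view.render_plot(eval_result)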
Tracking Model Performance Over Time¶
Your training dataset will be used for training your model, and will hopefully be representative of your test dataset and the data that will be sent to your model in production. However, while the data in inference requests may initially look like your training data, in many cases it will start to change enough that the performance of your model will change.
That means that you need to monitor and measure your model's performance on an ongoing basis, so that you can be aware of and react to changes. Let's take a look at how TFMA can help.
Let's load 3 different model runs and use TFMA to see how they compare using render_time_series.
# Note this re-uses the EvalConfig from the keras setup.

# Run eval on each saved model
output_paths = []
for i in range(3):
  # Create a tfma.EvalSharedModel that points at our saved model.
  eval_shared_model = tfma.default_eval_shared_model(
      eval_saved_model_path=os.path.join(MODELS_DIR, 'keras', str(i)),
      eval_config=keras_eval_config)
  output_path = os.path.join(OUTPUT_DIR, 'time_series', str(i))
  output_paths.append(output_path)
  # Run TFMA
  tfma.run_model_analysis(eval_shared_model=eval_shared_model,
                          eval_config=keras_eval_config,
                          data_location=tfrecord_file,
                          output_path=output_path)
First, we'll imagine that we trained and deployed our model yesterday, and now we want to see how it's doing on the new data coming in today. The visualization will start by displaying AUC. From the UI you can:
- Add other metrics using the "Add metric series" menu.
- Close unwanted graphs by clicking the x.
- Hover over data points (the ends of line segments in the graph) to get more details.
Note: In the metric series charts the X axis is the model directory name of the model run that you're examining. These names themselves are not meaningful.
eval_results_from_disk = tfma.load_eval_results(output_paths[:2])
tfma.view.render_time_series(eval_results_from_disk)
Now we'll imagine that another day has passed and we want to see how it's doing on the new data coming in today, compared to the previous two days:
eval_results_from_disk = tfma.load_eval_results(output_paths)
tfma.view.render_time_series(eval_results_from_disk)
Model Validation¶
TFMA can be configured to evaluate multiple models at the same time. Typically this is done to compare a new model against a baseline (such as the currently serving model) to determine what the performance differences in metrics (e.g. AUC, etc) are relative to the baseline. When thresholds are configured, TFMA will produce a tfma.ValidationResult
record indicating whether the performance matches expecations.
Let's re-configure our keras evaluation to compare two models: a candidate and a baseline. We will also validate the candidate's performance against the baseline by setting a tmfa.MetricThreshold
on the AUC metric.
# Setup tfma.EvalConfig setting
eval_config_with_thresholds = text_format.Parse("""
  ## Model information
  model_specs {
    name: "candidate"
    # For keras we need to add a `label_key`.
    label_key: "big_tipper"
  }
  model_specs {
    name: "baseline"
    # For keras we need to add a `label_key`.
    label_key: "big_tipper"
    is_baseline: true
  }

  ## Post training metric information
  metrics_specs {
    metrics { class_name: "ExampleCount" }
    metrics { class_name: "BinaryAccuracy" }
    metrics { class_name: "BinaryCrossentropy" }
    metrics {
      class_name: "AUC"
      threshold {
        # Ensure that AUC is always > 0.9
        value_threshold {
          lower_bound { value: 0.9 }
        }
        # Ensure that AUC does not drop by more than a small epsilon
        # e.g. (candidate - baseline) > -1e-10 or candidate > baseline - 1e-10
        change_threshold {
          direction: HIGHER_IS_BETTER
          absolute { value: -1e-10 }
        }
      }
    }
    metrics { class_name: "AUCPrecisionRecall" }
    metrics { class_name: "Precision" }
    metrics { class_name: "Recall" }
    metrics { class_name: "MeanLabel" }
    metrics { class_name: "MeanPrediction" }
    metrics { class_name: "Calibration" }
    metrics { class_name: "CalibrationPlot" }
    metrics { class_name: "ConfusionMatrixPlot" }
    # ... add additional metrics and plots ...
  }

  ## Slicing information
  slicing_specs {}  # overall slice
  slicing_specs {
    feature_keys: ["trip_start_hour"]
  }
  slicing_specs {
    feature_keys: ["trip_start_day"]
  }
  slicing_specs {
    feature_keys: ["trip_start_month"]
  }
  slicing_specs {
    feature_keys: ["trip_start_hour", "trip_start_day"]
  }
""", tfma.EvalConfig())

# Create tfma.EvalSharedModels that point at our keras models.
candidate_model_path = os.path.join(MODELS_DIR, 'keras', '2')
baseline_model_path = os.path.join(MODELS_DIR, 'keras', '1')

eval_shared_models = [
  tfma.default_eval_shared_model(
      model_name=tfma.CANDIDATE_KEY,
      eval_saved_model_path=candidate_model_path,
      eval_config=eval_config_with_thresholds),
  tfma.default_eval_shared_model(
      model_name=tfma.BASELINE_KEY,
      eval_saved_model_path=baseline_model_path,
      eval_config=eval_config_with_thresholds),
]

validation_output_path = os.path.join(OUTPUT_DIR, 'validation')

# Run TFMA
eval_result_with_validation = tfma.run_model_analysis(
    eval_shared_models,
    eval_config=eval_config_with_thresholds,
    data_location=tfrecord_file,
    output_path=validation_output_path)
When running evaluations with one or more models against a baseline, TFMA automatically adds diff metrics for all of the metrics computed during the evaluation. These metrics are named after the corresponding metric but with _diff appended to the metric name.
Let's take a look at the metrics produced by our run:
tfma.view.render_time_series(eval_result_with_validation)
Now let's look at the output from our validation checks. To view the validation results we use tfma.load_validation_result. For our example, the validation fails because AUC is below the threshold.
validation_result = tfma.load_validation_result(validation_output_path)
print(validation_result.validation_ok)
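Since validation_result is a proto record, printing the whole message is a quick (if verbose) way to see exactly which slice and metric thresholds failed:
# The full ValidationResult proto lists the failing slice/metric/threshold combinations.
print(validation_result)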
Copyright © 2020 The TensorFlow Authors.¶
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Note: This site provides applications using data that has been modified for use from its original source, www.cityofchicago.org, the official website of the City of Chicago. The City of Chicago makes no claims as to the content, accuracy, timeliness, or completeness of any of the data provided at this site. The data provided at this site is subject to change at any time. It is understood that the data provided at this site is being used at one’s own risk.