
TFMA Writers

tensorflow_model_analysis.writers

Init module for TensorFlow Model Analysis writers.

Classes

Writer

Bases: NamedTuple

Attributes
ptransform: PTransform (instance attribute)
stage_name: str (instance attribute)
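
A Writer simply pairs a Beam stage name with the PTransform that performs the write. The following is a minimal sketch of constructing a custom Writer around the Write helper documented below; the output path is hypothetical and beam.io.WriteToText stands in for any write transform.

import apache_beam as beam
from tensorflow_model_analysis import writers

# Illustrative only: a writer whose PTransform writes the 'metrics' output of an
# Evaluation dict to text files under a hypothetical path.
custom_writer = writers.Writer(
    stage_name="WriteMetricsAsText",
    ptransform=writers.Write(  # pylint: disable=no-value-for-parameter
        key="metrics",
        ptransform=beam.io.WriteToText("/tmp/metrics"),
    ),
)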

Functions

EvalConfigWriter

EvalConfigWriter(
    output_path: str,
    eval_config: EvalConfig,
    output_file_format: str = EVAL_CONFIG_FILE_FORMAT,
    data_location: Optional[str] = None,
    data_file_format: Optional[str] = None,
    model_locations: Optional[Dict[str, str]] = None,
    filename: Optional[str] = None,
) -> Writer

Returns eval config writer.


Parameters

    output_path: Output path to write the config to.
    eval_config: EvalConfig.
    output_file_format: Output file format. Currently only 'json' is supported.
    data_location: Optional path indicating where the data is read from. This is only used for display purposes.
    data_file_format: Optional format of the input examples. This is only used for display purposes.
    model_locations: Dict of model locations keyed by model name. This is only used for display purposes.
    filename: Name of the file to store the config as.

Source code in tensorflow_model_analysis/writers/eval_config_writer.py
def EvalConfigWriter(  # pylint: disable=invalid-name
    output_path: str,
    eval_config: config_pb2.EvalConfig,
    output_file_format: str = EVAL_CONFIG_FILE_FORMAT,
    data_location: Optional[str] = None,
    data_file_format: Optional[str] = None,
    model_locations: Optional[Dict[str, str]] = None,
    filename: Optional[str] = None,
) -> writer.Writer:
    """Returns eval config writer.

    Args:
    ----
      output_path: Output path to write config to.
      eval_config: EvalConfig.
      output_file_format: Output file format. Currently only 'json' is supported.
      data_location: Optional path indicating where data is read from. This is
        only used for display purposes.
      data_file_format: Optional format of the input examples. This is only used
        for display purposes.
      model_locations: Dict of model locations keyed by model name. This is only
        used for display purposes.
      filename: Name of file to store the config as.
    """
    if data_location is None:
        data_location = "<user provided PCollection>"
    if data_file_format is None:
        data_file_format = "<unknown>"
    if model_locations is None:
        model_locations = {"": "<unknown>"}
    if filename is None:
        filename = EVAL_CONFIG_FILE + "." + output_file_format

    return writer.Writer(
        stage_name="WriteEvalConfig",
        ptransform=_WriteEvalConfig(  # pylint: disable=no-value-for-parameter
            eval_config=eval_config,
            output_path=output_path,
            output_file_format=output_file_format,
            data_location=data_location,
            data_file_format=data_file_format,
            model_locations=model_locations,
            filename=filename,
        ),
    )
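
A minimal sketch of constructing the config writer follows. The paths and model name are hypothetical, and the data_location, data_file_format, and model_locations arguments are shown only because they feed the display-only fields described above. The resulting Writer can then be included in the writers list passed to an evaluation pipeline (for example tfma.ExtractEvaluateAndWriteResults, whose writers argument is assumed here rather than documented on this page).

import tensorflow_model_analysis as tfma

eval_config = tfma.EvalConfig(model_specs=[tfma.ModelSpec(label_key="label")])

# Hypothetical paths used purely for illustration.
config_writer = tfma.writers.EvalConfigWriter(
    output_path="/tmp/eval_output",
    eval_config=eval_config,
    data_location="/tmp/examples.tfrecord",
    data_file_format="tfrecords",
    model_locations={"candidate": "/tmp/saved_model"},
)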

MetricsPlotsAndValidationsWriter

MetricsPlotsAndValidationsWriter(
    output_paths: Dict[str, str],
    eval_config: EvalConfig,
    add_metrics_callbacks: Optional[
        List[AddMetricsCallbackType]
    ] = None,
    metrics_key: str = METRICS_KEY,
    plots_key: str = PLOTS_KEY,
    attributions_key: str = ATTRIBUTIONS_KEY,
    validations_key: str = VALIDATIONS_KEY,
    output_file_format: str = _TFRECORD_FORMAT,
    rubber_stamp: Optional[bool] = False,
    stage_name: str = METRICS_PLOTS_AND_VALIDATIONS_WRITER_STAGE_NAME,
) -> Writer

Returns metrics and plots writer.

Note that sharding will be enabled by default if an output_file_format is provided. The files will be named <output_path>-SSSSS-of-NNNNN.<output_file_format>, where SSSSS is the shard number and NNNNN is the number of shards.


Parameters

    output_paths: Output paths keyed by output key (e.g. 'metrics', 'plots', 'validation').
    eval_config: Eval config.
    add_metrics_callbacks: Optional list of metric callbacks (if used).
    metrics_key: Name to use for the metrics key in the Evaluation output.
    plots_key: Name to use for the plots key in the Evaluation output.
    attributions_key: Name to use for the attributions key in the Evaluation output.
    validations_key: Name to use for the validations key in the Evaluation output.
    output_file_format: File format to use when saving files. Currently 'tfrecord' and 'parquet' are supported, and 'tfrecord' is the default. If using parquet, the output metrics and plots files will contain two columns, 'slice_key' and 'serialized_value'. The 'slice_key' column will be a structured column matching the metrics_for_slice_pb2.SliceKey proto. The 'serialized_value' column will contain a serialized MetricsForSlice or PlotsForSlice proto. The validation result file will contain a single column 'serialized_value' containing a single serialized ValidationResult proto.
    rubber_stamp: True if this model is being rubber stamped. When a model is rubber stamped, diff thresholds will be ignored if an associated baseline model is not passed.
    stage_name: The stage name to use when this writer is added to the Beam pipeline.

Source code in tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py
def MetricsPlotsAndValidationsWriter(  # pylint: disable=invalid-name
    output_paths: Dict[str, str],
    eval_config: config_pb2.EvalConfig,
    add_metrics_callbacks: Optional[List[types.AddMetricsCallbackType]] = None,
    metrics_key: str = constants.METRICS_KEY,
    plots_key: str = constants.PLOTS_KEY,
    attributions_key: str = constants.ATTRIBUTIONS_KEY,
    validations_key: str = constants.VALIDATIONS_KEY,
    output_file_format: str = _TFRECORD_FORMAT,
    rubber_stamp: Optional[bool] = False,
    stage_name: str = METRICS_PLOTS_AND_VALIDATIONS_WRITER_STAGE_NAME,
) -> writer.Writer:
    """Returns metrics and plots writer.

    Note, sharding will be enabled by default if an output_file_format is provided.
    The files will be named <output_path>-SSSSS-of-NNNNN.<output_file_format>
    where SSSSS is the shard number and NNNNN is the number of shards.

    Args:
    ----
      output_paths: Output paths keyed by output key (e.g. 'metrics', 'plots',
        'validation').
      eval_config: Eval config.
      add_metrics_callbacks: Optional list of metric callbacks (if used).
      metrics_key: Name to use for metrics key in Evaluation output.
      plots_key: Name to use for plots key in Evaluation output.
      attributions_key: Name to use for attributions key in Evaluation output.
      validations_key: Name to use for validations key in Evaluation output.
      output_file_format: File format to use when saving files. Currently
        'tfrecord' and 'parquet' are supported and 'tfrecord' is the default.
        If using parquet, the output metrics and plots files will contain two
        columns, 'slice_key' and 'serialized_value'. The 'slice_key' column will
        be a structured column matching the metrics_for_slice_pb2.SliceKey proto.
        The 'serialized_value' column will contain a serialized MetricsForSlice or
        PlotsForSlice proto. The validation result file will contain a single
        column 'serialized_value' which will contain a single serialized
        ValidationResult proto.
      rubber_stamp: True if this model is being rubber stamped. When a model is
        rubber stamped diff thresholds will be ignored if an associated baseline
        model is not passed.
      stage_name: The stage name to use when this writer is added to the Beam
        pipeline.
    """
    return writer.Writer(
        stage_name=stage_name,
        ptransform=_WriteMetricsPlotsAndValidations(  # pylint: disable=no-value-for-parameter
            output_paths=output_paths,
            eval_config=eval_config,
            add_metrics_callbacks=add_metrics_callbacks or [],
            metrics_key=metrics_key,
            plots_key=plots_key,
            attributions_key=attributions_key,
            validations_key=validations_key,
            output_file_format=output_file_format,
            rubber_stamp=rubber_stamp,
        ),
    )
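
A minimal sketch of constructing this writer with per-output paths follows. The paths are hypothetical; the dict keys reuse the output-key constants referenced in the signature above.

import tensorflow_model_analysis as tfma
from tensorflow_model_analysis import constants

eval_config = tfma.EvalConfig(model_specs=[tfma.ModelSpec(label_key="label")])

# Hypothetical output locations keyed by output name.
output_paths = {
    constants.METRICS_KEY: "/tmp/eval_output/metrics",
    constants.PLOTS_KEY: "/tmp/eval_output/plots",
    constants.VALIDATIONS_KEY: "/tmp/eval_output/validations",
}

metrics_writer = tfma.writers.MetricsPlotsAndValidationsWriter(
    output_paths=output_paths,
    eval_config=eval_config,
    output_file_format="parquet",  # 'tfrecord' is the default.
)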

Write

Write(
    evaluation_or_validation: Union[Evaluation, Validation],
    key: str,
    ptransform: PTransform,
) -> Optional[PCollection]

Writes given Evaluation or Validation data using given writer PTransform.


Parameters

    evaluation_or_validation: Evaluation or Validation data.
    key: Key for the Evaluation or Validation output to write. It is valid for the key to not exist in the dict (in which case the write is a no-op).
    ptransform: PTransform to use for writing.


Raises

    ValueError: If the Evaluation or Validation is empty. The key does not need to exist in the Evaluation or Validation, but the dict must not be empty.


Returns

    The result of the underlying Beam write PTransform. This makes it possible for interactive environments to execute your writer, as well as for downstream Beam stages to make use of the files that are written.

Source code in tensorflow_model_analysis/writers/writer.py
@beam.ptransform_fn
def Write(
    evaluation_or_validation: Union[evaluator.Evaluation, validator.Validation],
    key: str,
    ptransform: beam.PTransform,
) -> Optional[beam.PCollection]:
    """Writes given Evaluation or Validation data using given writer PTransform.

    Args:
    ----
      evaluation_or_validation: Evaluation or Validation data.
      key: Key for Evaluation or Validation output to write. It is valid for the
        key to not exist in the dict (in which case the write is a no-op).
      ptransform: PTransform to use for writing.

    Raises:
    ------
      ValueError: If Evaluation or Validation is empty. The key does not need to
        exist in the Evaluation or Validation, but the dict must not be empty.

    Returns:
    -------
      The result of the underlying beam write PTransform. This makes it possible
      for interactive environments to execute your writer, as well as for
      downstream Beam stages to make use of the files that are written.
    """
    if not evaluation_or_validation:
        raise ValueError("Evaluations and Validations cannot be empty")
    if key in evaluation_or_validation:
        return evaluation_or_validation[key] | ptransform
    return None
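
A minimal sketch of applying Write inside a Beam pipeline follows; the in-memory Evaluation dict, its contents, and the output path are all hypothetical.

import apache_beam as beam
from tensorflow_model_analysis import writers

with beam.Pipeline() as pipeline:
    # Stand-in Evaluation dict with a single 'metrics' output.
    evaluation = {
        "metrics": pipeline | "CreateMetrics" >> beam.Create(["metric-a", "metric-b"]),
    }
    # Writes the 'metrics' PCollection as text; if the key were missing,
    # the write would be a no-op and None would be returned.
    _ = evaluation | "WriteMetrics" >> writers.Write(
        key="metrics",
        ptransform=beam.io.WriteToText("/tmp/metrics"),
    )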

convert_slice_metrics_to_proto

convert_slice_metrics_to_proto(
    metrics: Tuple[
        SliceKeyOrCrossSliceKeyType, MetricsDict
    ],
    add_metrics_callbacks: Optional[
        List[AddMetricsCallbackType]
    ],
) -> MetricsForSlice

Converts the given slice metrics into serialized proto MetricsForSlice.


Parameters

    metrics: The slice metrics.
    add_metrics_callbacks: A list of metric callbacks. This should be the same list as the one passed to tfma.Evaluate().


Returns

    The MetricsForSlice proto.


Raises

    TypeError: If the type of the feature value in the slice key cannot be recognized.

Source code in tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py
def convert_slice_metrics_to_proto(
    metrics: Tuple[slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict],
    add_metrics_callbacks: Optional[List[types.AddMetricsCallbackType]],
) -> metrics_for_slice_pb2.MetricsForSlice:
    """Converts the given slice metrics into serialized proto MetricsForSlice.

    Args:
    ----
      metrics: The slice metrics.
      add_metrics_callbacks: A list of metric callbacks. This should be the same
        list as the one passed to tfma.Evaluate().

    Returns:
    -------
      The MetricsForSlice proto.

    Raises:
    ------
      TypeError: If the type of the feature value in slice key cannot be
        recognized.
    """
    result = metrics_for_slice_pb2.MetricsForSlice()
    slice_key, slice_metrics = metrics

    if slicer.is_cross_slice_key(slice_key):
        result.cross_slice_key.CopyFrom(slicer.serialize_cross_slice_key(slice_key))
    else:
        result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

    slice_metrics = slice_metrics.copy()

    if metric_keys.ERROR_METRIC in slice_metrics:
        logging.warning(
            "Error for slice: %s with error message: %s ",
            slice_key,
            slice_metrics[metric_keys.ERROR_METRIC],
        )
        result.metrics[metric_keys.ERROR_METRIC].debug_message = slice_metrics[
            metric_keys.ERROR_METRIC
        ]
        return result

    # Convert the metrics from add_metrics_callbacks to the structured output if
    # defined.
    if add_metrics_callbacks and (
        not any(isinstance(k, metric_types.MetricKey) for k in slice_metrics)
    ):
        for add_metrics_callback in add_metrics_callbacks:
            if hasattr(add_metrics_callback, "populate_stats_and_pop"):
                add_metrics_callback.populate_stats_and_pop(
                    slice_key, slice_metrics, result.metrics
                )
    for key in sorted(slice_metrics):
        value = slice_metrics[key]
        if isinstance(value, types.ValueWithTDistribution):
            unsampled_value = value.unsampled_value
            _, lower_bound, upper_bound = math_util.calculate_confidence_interval(value)
            confidence_interval = metrics_for_slice_pb2.ConfidenceInterval(
                lower_bound=convert_metric_value_to_proto(lower_bound),
                upper_bound=convert_metric_value_to_proto(upper_bound),
                standard_error=convert_metric_value_to_proto(
                    value.sample_standard_deviation
                ),
                degrees_of_freedom={"value": value.sample_degrees_of_freedom},
            )
            metric_value = convert_metric_value_to_proto(unsampled_value)
            if isinstance(key, metric_types.MetricKey):
                result.metric_keys_and_values.add(
                    key=key.to_proto(),
                    value=metric_value,
                    confidence_interval=confidence_interval,
                )
            else:
                # For v1 we continue to populate bounded_value for backwards
                # compatibility. If metric can be stored to double_value metrics,
                # replace it with a bounded_value.
                # TODO(b/171992041): remove the string-typed metric key branch once v1
                # code is removed.
                if metric_value.WhichOneof("type") == "double_value":
                    # setting bounded_value clears double_value in the same oneof scope.
                    metric_value.bounded_value.value.value = unsampled_value
                    metric_value.bounded_value.lower_bound.value = lower_bound
                    metric_value.bounded_value.upper_bound.value = upper_bound
                    metric_value.bounded_value.methodology = (
                        metrics_for_slice_pb2.BoundedValue.POISSON_BOOTSTRAP
                    )
                result.metrics[key].CopyFrom(metric_value)
        elif isinstance(value, metrics_for_slice_pb2.BoundedValue):
            metric_value = metrics_for_slice_pb2.MetricValue(
                double_value=wrappers_pb2.DoubleValue(value=value.value.value)
            )
            confidence_interval = metrics_for_slice_pb2.ConfidenceInterval(
                lower_bound=metrics_for_slice_pb2.MetricValue(
                    double_value=wrappers_pb2.DoubleValue(value=value.lower_bound.value)
                ),
                upper_bound=metrics_for_slice_pb2.MetricValue(
                    double_value=wrappers_pb2.DoubleValue(value=value.upper_bound.value)
                ),
            )
            result.metric_keys_and_values.add(
                key=key.to_proto(),
                value=metric_value,
                confidence_interval=confidence_interval,
            )
        else:
            metric_value = convert_metric_value_to_proto(value)
            if isinstance(key, metric_types.MetricKey):
                result.metric_keys_and_values.add(
                    key=key.to_proto(), value=metric_value
                )
            else:
                # TODO(b/171992041): remove the string-typed metric key branch once v1
                # code is removed.
                result.metrics[key].CopyFrom(metric_value)
    return result
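
A minimal sketch of converting a single slice's metrics follows. The empty tuple denotes the overall (unsliced) slice key, the metric value is made up, and the plain string key exercises the legacy branch above that populates the metrics map (MetricKey-typed keys would populate metric_keys_and_values instead).

from tensorflow_model_analysis import writers

# Hypothetical metrics for the overall slice.
metrics_proto = writers.convert_slice_metrics_to_proto(
    ((), {"accuracy": 0.87}),
    add_metrics_callbacks=None,
)
print(metrics_proto.metrics["accuracy"])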