Skip to content

TFMA Utils

tensorflow_model_analysis.utils

Init module for TensorFlow Model Analysis utils.

Classes

CombineFnWithModels

CombineFnWithModels(model_loaders: Dict[str, ModelLoader])

Bases: CombineFn

Abstract class for CombineFns that need the shared models.

Initializes CombineFn using dict of loaders keyed by model location.

Source code in tensorflow_model_analysis/utils/model_util.py
def __init__(self, model_loaders: Dict[str, types.ModelLoader]):
    """Initializes CombineFn using dict of loaders keyed by model location.

    Args:
    ----
      model_loaders: Model loaders keyed by model location. Nothing is loaded
        here; the loaded-model map starts out as None.
    """
    self._model_loaders = model_loaders
    # Neither the models nor their load time are known yet.
    self._loaded_models = self._model_load_seconds = None
    metric_name = "model_load_seconds"
    self._model_load_seconds_distribution = beam.metrics.Metrics.distribution(
        constants.METRICS_NAMESPACE, metric_name
    )
Functions
setup
setup()
Source code in tensorflow_model_analysis/utils/model_util.py
def setup(self):
    if self._loaded_models is None:
        self._loaded_models = {}
        for model_name, model_loader in self._model_loaders.items():
            self._loaded_models[model_name] = model_loader.load(
                model_load_time_callback=self._set_model_load_seconds
            )
        if self._model_load_seconds is not None:
            self._model_load_seconds_distribution.update(self._model_load_seconds)
            self._model_load_seconds = None

DoFnWithModels

DoFnWithModels(model_loaders: Dict[str, ModelLoader])

Bases: DoFn

Abstract class for DoFns that need the shared models.

Initializes DoFn using dict of model loaders keyed by model location.

Source code in tensorflow_model_analysis/utils/model_util.py
def __init__(self, model_loaders: Dict[str, types.ModelLoader]):
    """Initializes DoFn using dict of model loaders keyed by model location.

    Args:
    ----
      model_loaders: Model loaders keyed by model location. Nothing is loaded
        here; the loaded-model map starts out as None.
    """
    self._model_loaders = model_loaders
    # Neither the models nor their load time are known yet.
    self._loaded_models = self._model_load_seconds = None
    metric_name = "model_load_seconds"
    self._model_load_seconds_distribution = beam.metrics.Metrics.distribution(
        constants.METRICS_NAMESPACE, metric_name
    )
Functions
finish_bundle
finish_bundle()
Source code in tensorflow_model_analysis/utils/model_util.py
def finish_bundle(self):
    """Flushes a pending model-load time into the load-time distribution.

    The update happens here rather than in setup because Beam metrics are not
    supported in setup.
    """
    seconds = self._model_load_seconds
    if seconds is None:
        return
    self._model_load_seconds_distribution.update(seconds)
    # Reset so the same load time is not reported again by a later bundle.
    self._model_load_seconds = None
process
process(elem)
Source code in tensorflow_model_analysis/utils/model_util.py
def process(self, elem):
    """Abstract per-element hook; concrete DoFns must supply their own.

    Raises:
    ------
      NotImplementedError: Always.
    """
    message = "Subclasses are expected to override this."
    raise NotImplementedError(message)
setup
setup()
Source code in tensorflow_model_analysis/utils/model_util.py
def setup(self):
    self._loaded_models = {}
    for model_name, model_loader in self._model_loaders.items():
        self._loaded_models[model_name] = model_loader.load(
            model_load_time_callback=self._set_model_load_seconds
        )

Functions

calculate_confidence_interval

calculate_confidence_interval(
    t_distribution_value: ValueWithTDistribution,
)

Calculate confidence intervals based on a 95% confidence level.

Source code in tensorflow_model_analysis/utils/math_util.py
def calculate_confidence_interval(t_distribution_value: types.ValueWithTDistribution):
    """Calculate confidence intervals based on a 95% confidence level.

    Args:
    ----
      t_distribution_value: Value exposing sample_mean,
        sample_standard_deviation and sample_degrees_of_freedom. The sample
        standard deviation is used directly as the standard error.

    Returns:
    -------
      A (sample_mean, lower_bound, upper_bound) tuple where the bounds are
      sample_mean -/+ sample_standard_deviation * t, and t is the 0.975
      quantile of Student's t distribution with sample_degrees_of_freedom
      degrees of freedom.
    """
    significance = 0.05
    std_err = t_distribution_value.sample_standard_deviation
    critical_t = stats.t.ppf(
        1 - significance / 2.0, t_distribution_value.sample_degrees_of_freedom
    )
    # std_err must stay the left operand so that std_err.__mul__ is used
    # rather than critical_t.__mul__.
    # TODO(b/197669322): make StructuredMetricValues robust to operand ordering.
    margin = std_err * critical_t
    mean = t_distribution_value.sample_mean
    return mean, mean - margin, mean + margin

compound_key

compound_key(
    keys: Sequence[str], separator: str = KEY_SEPARATOR
) -> str

Returns a compound key based on a list of keys.


keys: Keys used to make up compound key. separator: Separator between keys. To ensure the keys can be parsed out of any compound key created, any use of a separator within a key will be replaced by two separators.

Source code in tensorflow_model_analysis/utils/util.py
def compound_key(keys: Sequence[str], separator: str = KEY_SEPARATOR) -> str:
    """Returns a compound key based on a list of keys.

    Every occurrence of the separator inside an individual key is escaped by
    doubling it, and the escaped keys are then joined with the separator.

    NOTE(review): this escaping is ambiguous when a key starts or ends with the
    separator, e.g. with separator '/' both ['a/', 'b'] and ['a', '/b'] yield
    'a///b'.

    Args:
    ----
      keys: Keys used to make up compound key.
      separator: Separator between keys. To ensure the keys can be parsed out of
        any compound key created, any use of a separator within a key will be
        replaced by two separators.

    Returns:
    -------
      The joined compound key.
    """
    doubled = separator * 2
    escaped = (key.replace(separator, doubled) for key in keys)
    return separator.join(escaped)

create_keys_key

create_keys_key(key: str) -> str

Creates secondary key representing the sparse keys associated with key.

Source code in tensorflow_model_analysis/utils/util.py
def create_keys_key(key: str) -> str:
    """Creates secondary key representing the sparse keys associated with key.

    The result is ``key`` followed by "_" and the module-level KEYS_SUFFIX.
    """
    parts = (key, KEYS_SUFFIX)
    return "_".join(parts)

create_values_key

create_values_key(key: str) -> str

Creates secondary key representing sparse values associated with key.

Source code in tensorflow_model_analysis/utils/util.py
def create_values_key(key: str) -> str:
    """Creates secondary key representing sparse values associated with key.

    The result is ``key`` followed by "_" and the module-level VALUES_SUFFIX.
    """
    parts = (key, VALUES_SUFFIX)
    return "_".join(parts)

get_baseline_model_spec

get_baseline_model_spec(
    eval_config: EvalConfig,
) -> Optional[ModelSpec]

Returns baseline model spec.

Source code in tensorflow_model_analysis/utils/model_util.py
def get_baseline_model_spec(
    eval_config: config_pb2.EvalConfig,
) -> Optional[config_pb2.ModelSpec]:
    """Returns baseline model spec.

    Returns the first spec in eval_config.model_specs whose is_baseline is
    truthy, or None when there is no such spec.
    """
    baseline_specs = (spec for spec in eval_config.model_specs if spec.is_baseline)
    return next(baseline_specs, None)

get_by_keys

get_by_keys(
    data: Mapping[str, Any],
    keys: Sequence[Any],
    default_value=None,
    optional: bool = False,
) -> Any

Returns value with given key(s) in (possibly multi-level) dict.

The keys represent multiple levels of indirection into the data. For example, if 3 keys are passed then the data is expected to be a dict of dict of dict. For compatibility with data that uses prefixing to separate the keys in a single dict, lookups will also be searched for under the keys separated by '/'. For example, the keys 'head1' and 'probabilities' could be stored in a single dict as 'head1/probabilities'.


data: Dict to get value from. keys: Sequence of keys to lookup in data. None keys will be ignored. default_value: Default value if not found. optional: Whether the key is optional or not. If default value is None and optional is False then a ValueError will be raised if key not found.


ValueError: If (non-optional) key is not found.

Source code in tensorflow_model_analysis/utils/util.py
def get_by_keys(
    data: Mapping[str, Any],
    keys: Sequence[Any],
    default_value=None,
    optional: bool = False,
) -> Any:
    """Returns value with given key(s) in (possibly multi-level) dict.

    The keys represent multiple levels of indirection into the data. For example
    if 3 keys are passed then the data is expected to be a dict of dict of dict.
    For compatibility with data that uses prefixing to separate the keys in a
    single dict, lookups will also be searched for under the keys separated by
    '/'. For example, the keys 'head1' and 'probabilities' could be stored in a
    single dict as 'head1/probabilities'. Prefix matching only applies to string
    keys; entries with non-string keys are only matched exactly.

    Args:
    ----
      data: Dict to get value from.
      keys: Sequence of keys to lookup in data. None keys will be ignored.
      default_value: Default value if not found.
      optional: Whether the key is optional or not. If default value is None and
        optional is False then a ValueError will be raised if key not found.

    Returns:
    -------
      The value found. If the lookup fails, or the found value is an empty dict,
      default_value is returned when it is not None, and None is returned when
      optional is True.

    Raises:
    ------
      ValueError: If no keys are given, if an intermediate value is not a dict,
        or if a (non-optional) key is not found.
    """
    if not keys:
        raise ValueError("no keys provided to get_by_keys: %s" % data)

    def format_keys(key_list):
        return "->".join([str(k) for k in key_list if k is not None])

    value = data
    keys_matched = 0
    for i, key in enumerate(keys):
        if key is None:
            # None keys are skipped but still count as matched levels.
            keys_matched += 1
            continue

        if not isinstance(value, Mapping):
            raise ValueError(
                'expected dict for "%s" but found %s: %s'
                % (format_keys(keys[: i + 1]), type(value), data)
            )

        if key in value:
            value = value[key]
            keys_matched += 1
            continue

        # If values have prefixes matching the key, return those values (stripped
        # of the prefix) instead. Only string keys can carry a "key/" prefix;
        # guarding both sides avoids TypeError/AttributeError on non-str keys.
        prefix_matches = {}
        if isinstance(key, str):
            prefix = key + "/"
            for k, v in value.items():
                if isinstance(k, str) and k.startswith(prefix):
                    prefix_matches[k[len(prefix) :]] = v
        if prefix_matches:
            value = prefix_matches
            keys_matched += 1
            continue

        break

    if keys_matched < len(keys) or (isinstance(value, Mapping) and not value):
        if default_value is not None:
            return default_value
        if optional:
            return None
        raise ValueError(
            '"%s" key not found (or value is empty dict): %s'
            % (format_keys(keys[: keys_matched + 1]), data)
        )
    return value

get_model_spec

get_model_spec(
    eval_config: EvalConfig, model_name: str
) -> Optional[ModelSpec]

Returns model spec with given model name.

Source code in tensorflow_model_analysis/utils/model_util.py
def get_model_spec(
    eval_config: config_pb2.EvalConfig, model_name: str
) -> Optional[config_pb2.ModelSpec]:
    """Returns model spec with given model name.

    When the config holds exactly one spec and model_name is empty, that single
    spec is returned regardless of its name. Otherwise the first spec whose
    name equals model_name is returned, or None if none matches.
    """
    specs = eval_config.model_specs
    if not model_name and len(specs) == 1:
        return specs[0]
    return next((spec for spec in specs if spec.name == model_name), None)

get_model_type

get_model_type(
    model_spec: Optional[ModelSpec],
    model_path: Optional[str] = "",
    tags: Optional[List[str]] = None,
) -> str

Returns model type for given model spec taking into account defaults.

The defaults are chosen such that, if the model spec sets a model_type, that type is used. Otherwise, if a model_path is provided and the model can be loaded as a keras model, then TF_KERAS is assumed. Otherwise TF_GENERIC is assumed. The tags argument and the model spec's signature name do not affect the result in the implementation shown below.


model_spec: Model spec. model_path: Optional model path used to check whether the model is a keras model. tags: Optional tags (not used by the implementation shown below).

Source code in tensorflow_model_analysis/utils/model_util.py
def get_model_type(
    model_spec: Optional[config_pb2.ModelSpec],
    model_path: Optional[str] = "",
    tags: Optional[List[str]] = None,
) -> str:
    """Returns model type for given model spec taking into account defaults.

    The model type is resolved in this order:
      1. If model_spec is given and sets model_type, that value is returned.
      2. Otherwise, if model_path is provided and the model at that path can be
         loaded as a keras Model, TF_KERAS is returned.
      3. Otherwise TF_GENERIC is returned.

    Args:
    ----
      model_spec: Model spec (may be None).
      model_path: Optional model path used to check whether the model is a
        keras model.
      tags: Not used by this function; accepted for backward compatibility with
        existing callers.

    Returns:
    -------
      The resolved model type string.
    """
    if model_spec and model_spec.model_type:
        return model_spec.model_type

    if model_path:
        try:
            keras_model = tf.keras.models.load_model(model_path)
            # In some cases, tf.keras.models.load_model can successfully load a
            # saved_model but it won't actually be a keras model.
            if isinstance(keras_model, tf.keras.models.Model):
                return constants.TF_KERAS
        except Exception:  # pylint: disable=broad-except
            # Best effort: a path that cannot be loaded as keras falls through
            # to the generic default instead of failing type resolution.
            pass

    # The previously computed (and never used) signature_name was removed; the
    # model spec's signature does not influence the result.
    return constants.TF_GENERIC

get_non_baseline_model_specs

get_non_baseline_model_specs(
    eval_config: EvalConfig,
) -> Iterable[ModelSpec]

Returns non-baseline model specs.

Source code in tensorflow_model_analysis/utils/model_util.py
def get_non_baseline_model_specs(
    eval_config: config_pb2.EvalConfig,
) -> Iterable[config_pb2.ModelSpec]:
    """Returns non-baseline model specs.

    The result is a new list, in config order, of every spec whose is_baseline
    is falsy.
    """
    return list(filter(lambda spec: not spec.is_baseline, eval_config.model_specs))

has_change_threshold

has_change_threshold(eval_config: EvalConfig) -> bool

Checks whether the eval_config has any change thresholds.


eval_config: the TFMA eval_config.


True when there are change thresholds otherwise False.

Source code in tensorflow_model_analysis/utils/config_util.py
def has_change_threshold(eval_config: config_pb2.EvalConfig) -> bool:
    """Checks whether the eval_config has any change thresholds.

    The following thresholds are inspected, per metrics spec:
      - each metric's threshold, per-slice thresholds and cross-slice thresholds;
      - the spec-level thresholds map;
      - the spec-level per-slice and cross-slice threshold maps.

    A change threshold counts as present when its ByteSize() is non-zero.

    Args:
    ----
      eval_config: the TFMA eval_config.

    Returns:
    -------
      True when there are change thresholds otherwise False.
    """

    def _all_thresholds():
        # Yields thresholds in the same order the checks were originally made.
        for metrics_spec in eval_config.metrics_specs:
            for metric in metrics_spec.metrics:
                yield metric.threshold
                for per_slice in metric.per_slice_thresholds:
                    yield per_slice.threshold
                for cross_slice in metric.cross_slice_thresholds:
                    yield cross_slice.threshold
            yield from metrics_spec.thresholds.values()
            for per_slice_group in metrics_spec.per_slice_thresholds.values():
                for per_slice in per_slice_group.thresholds:
                    yield per_slice.threshold
            for cross_slice_group in metrics_spec.cross_slice_thresholds.values():
                for cross_slice in cross_slice_group.thresholds:
                    yield cross_slice.threshold

    return any(
        threshold.change_threshold.ByteSize() for threshold in _all_thresholds()
    )

merge_extracts

merge_extracts(
    extracts: List[Extracts],
    squeeze_two_dim_vector: bool = True,
) -> Extracts

Merges list of extracts into a single extract with multidimensional data.

Note: Running split_extracts followed by merge_extracts with default options will not reproduce the exact shape of the original extracts. Arrays in shape (x,1) will be flattened to (x,). To maintain the original shape of extract values of array shape (x,1), you must run with these options: split_extracts(extracts, expand_zero_dims=False) followed by merge_extracts(extracts, squeeze_two_dim_vector=False).


extracts: Batched TFMA Extracts. squeeze_two_dim_vector: Determines how the function will handle arrays of shape (x,1). If squeeze_two_dim_vector is True, the array will be squeezed to shape (x,).


A single Extracts whose values have been grouped into batches.

Source code in tensorflow_model_analysis/utils/util.py
def merge_extracts(
    extracts: List[types.Extracts], squeeze_two_dim_vector: bool = True
) -> types.Extracts:
    """Merges list of extracts into a single extract with multidimensional data.

    Note: Running split_extracts followed by merge_extracts with default options
      will not reproduce the exact shape of the original extracts. Arrays in shape
      (x,1) will be flattened to (x,). To maintain the original shape of extract
      values of array shape (x,1), you must run with these options:
      split_extracts(extracts, expand_zero_dims=False)
      merge_extracts(extracts, squeeze_two_dim_vector=False)

    Args:
    ----
      extracts: Batched TFMA Extracts.
      squeeze_two_dim_vector: Determines how the function will handle arrays of
        shape (x,1). If squeeze_two_dim_vector is True, the array will be squeezed
        to shape (x,).

    Returns:
    -------
      A single Extracts whose values have been grouped into batches.

    Raises:
    ------
      RuntimeError: If converting the merged values for some key fails; the
        underlying exception is chained as the cause.
    """

    def merge_with_lists(
        target: types.Extracts, index: int, key: str, value: Any, num_extracts: int
    ):
        """Merges key and value into the target extracts as a list of values.

        Args:
        ----
         target: The extract to store all merged all the data.
         index: The index at which the value should be stored. It is in accordance
           with the order of extracts in the batch.
         key: The key of the key-value pair to store in the target.
         value: The value of the key-value pair to store in the target.
         num_extracts: The total number of extracts to be merged in this target.
        """
        if isinstance(value, Mapping):
            if key not in target:
                target[key] = {}
            target = target[key]
            for k, v in value.items():
                merge_with_lists(target, index, k, v, num_extracts)
        else:
            # If key is newly found, we create a list with length of extracts,
            # so that every value of the i th extracts will go to the i th position.
            # And the extracts without this key will have value np.array([]).
            # (The placeholder array is shared, but slots are only ever replaced,
            # never mutated in place.)
            if key not in target:
                target[key] = [np.array([])] * num_extracts
            target[key][index] = value

    def merge_lists(target: types.Extracts) -> types.Extracts:
        """Converts target's leaves which are lists to batched np.array's, etc."""
        if isinstance(target, Mapping):
            result = {}
            for key, value in target.items():
                try:
                    result[key] = merge_lists(value)
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to convert value for key: {key} and value: {value}"
                    ) from e
            # Return the dict converted above. Re-running merge_lists on every
            # child here would double the work at each nesting level, making
            # the cost exponential in the nesting depth.
            return result
        elif target and (
            np.any([isinstance(t, tf.compat.v1.SparseTensorValue) for t in target])
            # NOTE(review): unlike the check above, only the first element is
            # tested for types.SparseTensorValue; preserved as-is.
            or isinstance(target[0], types.SparseTensorValue)
        ):
            t = tf.compat.v1.sparse_concat(
                0,
                [tf.sparse.expand_dims(to_tensorflow_tensor(t), 0) for t in target],
                expand_nonconcat_dim=True,
            )
            return to_tensor_value(t)
        elif target and np.any(
            [isinstance(t, types.RaggedTensorValue) for t in target]
        ):
            t = tf.concat(
                [tf.expand_dims(to_tensorflow_tensor(t), 0) for t in target], 0
            )
            return to_tensor_value(t)
        elif (
            all(isinstance(t, np.ndarray) for t in target)
            and len({t.shape for t in target}) > 1
        ):
            target = (t.squeeze() for t in target)
            return types.VarLenTensorValue.from_dense_rows(target)
        # If all value in the target are scalar numpy array, we stack them.
        # This is to avoid np.array([np.array(b'abc'), np.array(b'abcd')])
        # and stack to np.array([b'abc', b'abcd'])
        elif all(isinstance(t, np.ndarray) and t.shape == () for t in target):  # pylint: disable=g-explicit-bool-comparison
            return np.stack(target)
        elif all(t is None for t in target):
            return None
        else:
            # Compatibility shim for NumPy 1.24. See:
            # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
            try:
                arr = np.array(target)
            except ValueError:
                arr = np.array(target, dtype=object)
            # Flatten values that were originally single item lists into a single list
            # e.g. [[1], [2], [3]] -> [1, 2, 3]
            if squeeze_two_dim_vector and len(arr.shape) == 2 and arr.shape[1] == 1:
                return arr.squeeze(axis=1)
            return arr

    result = {}
    num_extracts = len(extracts)
    for i, x in enumerate(extracts):
        # Empty/None extracts contribute nothing; their slots keep the
        # np.array([]) placeholder.
        if x:
            for k, v in x.items():
                merge_with_lists(result, i, k, v, num_extracts)
    return merge_lists(result)

model_construct_fn

model_construct_fn(
    eval_saved_model_path: Optional[str] = None,
    add_metrics_callbacks: Optional[
        List[AddMetricsCallbackType]
    ] = None,
    include_default_metrics: Optional[bool] = None,
    additional_fetches: Optional[List[str]] = None,
    blacklist_feature_fetches: Optional[List[str]] = None,
    tags: Optional[List[str]] = None,
    model_type: Optional[str] = TFMA_EVAL,
) -> Callable[[], Any]

Returns function for constructing shared models.

Source code in tensorflow_model_analysis/utils/model_util.py
def model_construct_fn(  # pylint: disable=invalid-name
    eval_saved_model_path: Optional[str] = None,
    add_metrics_callbacks: Optional[List[types.AddMetricsCallbackType]] = None,
    include_default_metrics: Optional[bool] = None,
    additional_fetches: Optional[List[str]] = None,
    blacklist_feature_fetches: Optional[List[str]] = None,
    tags: Optional[List[str]] = None,
    model_type: Optional[str] = constants.TFMA_EVAL,
) -> Callable[[], Any]:
    """Returns function for constructing shared models.

    Only eval_saved_model_path, tags and model_type are referenced by the
    returned callable; the other parameters are accepted but unused here.

    Args:
    ----
      eval_saved_model_path: Directory holding the model to load.
      add_metrics_callbacks: Unused here.
      include_default_metrics: Unused here.
      additional_fetches: Unused here.
      blacklist_feature_fetches: Unused here.
      tags: SavedModel tags; must not be None.
      model_type: Selects how the model is loaded (keras, tflite, tfjs, or a
        SavedModel for any other value).

    Returns:
    -------
      A zero-argument callable that loads and returns the model (None for TFJS).

    Raises:
    ------
      ValueError: If tags is None.
    """
    if tags is None:
        raise ValueError("Model tags must be specified.")

    def _has_saved_model(directory):
        """Returns True if directory holds a binary or text SavedModel proto."""
        for proto_name in (
            tf.saved_model.SAVED_MODEL_FILENAME_PB,
            tf.saved_model.SAVED_MODEL_FILENAME_PBTXT,
        ):
            if tf.io.gfile.exists(os.path.join(directory, proto_name)):
                return True
        return False

    def construct_fn():  # pylint: disable=invalid-name
        """Function for constructing shared models."""
        # If we are evaluating on TPU, initialize the TPU.
        # TODO(b/143484017): Add model warmup for TPU.
        if tf.saved_model.TPU in tags:
            tf.tpu.experimental.initialize_tpu_system()

        if model_type == constants.TF_KERAS:
            return tf.keras.models.load_model(eval_saved_model_path)

        if model_type == constants.TF_LITE:
            # tf.lite.Interpreter is not thread-safe, so only the raw file bytes
            # are read here; the consuming PTransform builds the Interpreter.
            tflite_path = os.path.join(eval_saved_model_path, _TFLITE_FILE_NAME)
            with tf.io.gfile.GFile(tflite_path, "rb") as tflite_file:
                tflite_bytes = tflite_file.read()
            if not _has_saved_model(eval_saved_model_path):
                return ModelContents(tflite_bytes)
            # A SavedModel next to the .tflite file is loaded too, so it can be
            # used for computing the Transformed Features and Labels.
            saved_model = tf.compat.v1.saved_model.load_v2(
                eval_saved_model_path, tags=tags
            )
            saved_model.contents = tflite_bytes
            return saved_model

        if model_type == constants.TF_JS:
            # TFJS models are invoked via a subprocess call, so nothing is loaded.
            return None

        return tf.compat.v1.saved_model.load_v2(eval_saved_model_path, tags=tags)

    return construct_fn

unique_key

unique_key(
    key: str,
    current_keys: List[str],
    update_keys: Optional[bool] = False,
) -> str

Returns a unique key given a list of current keys.

If the key exists in current_keys then a new key with _1, _2, ..., etc appended will be returned, otherwise the key will be returned as passed.


key: desired key name. current_keys: List of current key names. update_keys: True to append the new key to current_keys.

Source code in tensorflow_model_analysis/utils/util.py
def unique_key(
    key: str, current_keys: List[str], update_keys: Optional[bool] = False
) -> str:
    """Returns a unique key given a list of current keys.

    The key is returned unchanged when it is not already taken. Otherwise the
    first free name among key_1, key_2, ... is returned.

    Args:
    ----
      key: desired key name.
      current_keys: List of current key names.
      update_keys: True to append the chosen key to current_keys (in place).

    Returns:
    -------
      A key name that was not present in current_keys.
    """
    candidate = key
    suffix = 0
    while candidate in current_keys:
        suffix += 1
        candidate = "%s_%d" % (key, suffix)
    if update_keys:
        current_keys.append(candidate)
    return candidate

update_eval_config_with_defaults

update_eval_config_with_defaults(
    eval_config: EvalConfig,
    maybe_add_baseline: Optional[bool] = None,
    maybe_remove_baseline: Optional[bool] = None,
    has_baseline: Optional[bool] = False,
    rubber_stamp: Optional[bool] = False,
) -> EvalConfig

Returns a new config with default settings applied.

a) Add or remove a model_spec according to "has_baseline". b) Fix the model names (model_spec.name) to tfma.CANDIDATE_KEY and tfma.BASELINE_KEY. c) Update the metrics_specs with the fixed model name.


eval_config: Original eval config. maybe_add_baseline: DEPRECATED. True to add a baseline ModelSpec to the config as a copy of the candidate ModelSpec that should already be present. This is only applied if a single ModelSpec already exists in the config and that spec doesn't have a name associated with it. When applied the model specs will use the names tfma.CANDIDATE_KEY and tfma.BASELINE_KEY. Only one of maybe_add_baseline or maybe_remove_baseline should be used. maybe_remove_baseline: DEPRECATED. True to remove a baseline ModelSpec from the config if it already exists. Removal of the baseline also removes any change thresholds. Only one of maybe_add_baseline or maybe_remove_baseline should be used. has_baseline: True to add a baseline ModelSpec to the config as a copy of the candidate ModelSpec that should already be present. This is only applied if a single ModelSpec already exists in the config and that spec doesn't have a name associated with it. When applied the model specs will use the names tfma.CANDIDATE_KEY and tfma.BASELINE_KEY. False to remove a baseline ModelSpec from the config if it already exists. Removal of the baseline also removes any change thresholds. Only one of has_baseline or maybe_remove_baseline should be used. rubber_stamp: True if this model is being rubber stamped. When a model is rubber stamped diff thresholds will be ignored if an associated baseline model is not passed.


RuntimeError: on missing baseline model for non-rubberstamp cases.

Source code in tensorflow_model_analysis/utils/config_util.py
def _clear_change_thresholds_in_place(eval_config: config_pb2.EvalConfig) -> None:
    """Clears change_threshold on every non-empty threshold in eval_config."""
    for metrics_spec in eval_config.metrics_specs:
        for metric in metrics_spec.metrics:
            if metric.threshold.ByteSize():
                metric.threshold.ClearField("change_threshold")
            for per_slice_threshold in metric.per_slice_thresholds:
                if per_slice_threshold.threshold.ByteSize():
                    per_slice_threshold.threshold.ClearField("change_threshold")
            for cross_slice_threshold in metric.cross_slice_thresholds:
                if cross_slice_threshold.threshold.ByteSize():
                    cross_slice_threshold.threshold.ClearField("change_threshold")
        for threshold in metrics_spec.thresholds.values():
            if threshold.ByteSize():
                threshold.ClearField("change_threshold")
        for per_slice_thresholds in metrics_spec.per_slice_thresholds.values():
            for per_slice_threshold in per_slice_thresholds.thresholds:
                if per_slice_threshold.threshold.ByteSize():
                    per_slice_threshold.threshold.ClearField("change_threshold")
        for cross_slice_thresholds in metrics_spec.cross_slice_thresholds.values():
            for cross_slice_threshold in cross_slice_thresholds.thresholds:
                if cross_slice_threshold.threshold.ByteSize():
                    cross_slice_threshold.threshold.ClearField("change_threshold")


def update_eval_config_with_defaults(
    eval_config: config_pb2.EvalConfig,
    maybe_add_baseline: Optional[bool] = None,
    maybe_remove_baseline: Optional[bool] = None,
    has_baseline: Optional[bool] = False,
    rubber_stamp: Optional[bool] = False,
) -> config_pb2.EvalConfig:
    """Returns a new config with default settings applied.

    a) Add or remove a model_spec according to "has_baseline".
    b) Fix the model names (model_spec.name) to tfma.CANDIDATE_KEY and
       tfma.BASELINE_KEY.
    c) Update the metrics_specs with the fixed model name.

    Additionally, when confidence intervals are requested without a method,
    JACKKNIFE is selected.

    Args:
    ----
      eval_config: Original eval config. It is copied, not modified.
      maybe_add_baseline: DEPRECATED. True to add a baseline ModelSpec to the
        config as a copy of the candidate ModelSpec that should already be
        present. This is only applied if a single ModelSpec already exists in the
        config and that spec doesn't have a name associated with it. When applied
        the model specs will use the names tfma.CANDIDATE_KEY and
        tfma.BASELINE_KEY. Only one of maybe_add_baseline or maybe_remove_baseline
        should be used.
      maybe_remove_baseline: DEPRECATED. True to remove a baseline ModelSpec from
        the config if it already exists. Removal of the baseline also removes any
        change thresholds. Only one of maybe_add_baseline or maybe_remove_baseline
        should be used.
      has_baseline: True to add a baseline ModelSpec to the config as a copy of
        the candidate ModelSpec that should already be present. This is only
        applied if a single ModelSpec already exists in the config and that spec
        doesn't have a name associated with it. When applied the model specs will
        use the names tfma.CANDIDATE_KEY and tfma.BASELINE_KEY. False to remove a
        baseline ModelSpec from the config if it already exists. Removal of the
        baseline also removes any change thresholds. Only one of has_baseline or
        maybe_remove_baseline should be used. Ignored when either deprecated
        flag is set.
      rubber_stamp: True if this model is being rubber stamped. When a model is
        rubber stamped diff thresholds will be ignored if an associated baseline
        model is not passed.

    Returns:
    -------
      The updated copy of eval_config.

    Raises:
    ------
      RuntimeError: on missing baseline model for non-rubberstamp cases.
      ValueError: if both deprecated flags are set, or a deprecated flag is
        combined with a truthy has_baseline.
    """
    # A baseline is expected when requested via has_baseline or via the
    # deprecated maybe_add_baseline flag.
    if (
        not (has_baseline or maybe_add_baseline)
        and has_change_threshold(eval_config)
        and not rubber_stamp
    ):
        raise RuntimeError(
            "There are change thresholds, but the baseline is missing. "
            "This is allowed only when rubber stamping (first run)."
        )

    updated_config = config_pb2.EvalConfig()
    updated_config.CopyFrom(eval_config)
    # if user requests CIs but doesn't set method, use JACKKNIFE
    if (
        eval_config.options.compute_confidence_intervals.value
        and eval_config.options.confidence_intervals.method
        == config_pb2.ConfidenceIntervalOptions.UNKNOWN_CONFIDENCE_INTERVAL_METHOD
    ):
        updated_config.options.confidence_intervals.method = (
            config_pb2.ConfidenceIntervalOptions.JACKKNIFE
        )
    if maybe_add_baseline and maybe_remove_baseline:
        raise ValueError(
            "only one of maybe_add_baseline and maybe_remove_baseline should be used"
        )
    if maybe_add_baseline or maybe_remove_baseline:
        logging.warning(
            """"maybe_add_baseline" and "maybe_remove_baseline" are deprecated,
        please use "has_baseline" instead."""
        )
        if has_baseline:
            raise ValueError(
                """"maybe_add_baseline" and "maybe_remove_baseline" are ignored if
          "has_baseline" is set."""
            )
    elif has_baseline is not None:
        # has_baseline is only consulted when the deprecated flags are unset.
        # Otherwise its default (False) would force baseline removal and undo a
        # deprecated maybe_add_baseline=True request.
        if has_baseline:
            maybe_add_baseline = True
        else:
            maybe_remove_baseline = True

    # Has a baseline model.
    if (
        maybe_add_baseline
        and len(updated_config.model_specs) == 1
        and not updated_config.model_specs[0].name
    ):
        baseline = updated_config.model_specs.add()
        baseline.CopyFrom(updated_config.model_specs[0])
        baseline.name = constants.BASELINE_KEY
        baseline.is_baseline = True
        updated_config.model_specs[0].name = constants.CANDIDATE_KEY
        logging.info(
            "Adding default baseline ModelSpec based on the candidate ModelSpec "
            'provided. The candidate model will be called "%s" and the baseline '
            'will be called "%s": updated_config=\n%s',
            constants.CANDIDATE_KEY,
            constants.BASELINE_KEY,
            updated_config,
        )

    # Does not have a baseline.
    if maybe_remove_baseline:
        non_baseline_specs = [
            model_spec
            for model_spec in updated_config.model_specs
            if not model_spec.is_baseline
        ]
        del updated_config.model_specs[:]
        updated_config.model_specs.extend(non_baseline_specs)
        _clear_change_thresholds_in_place(updated_config)
        logging.info(
            "Request was made to ignore the baseline ModelSpec and any change "
            "thresholds. This is likely because a baseline model was not provided: "
            "updated_config=\n%s",
            updated_config,
        )

    if not updated_config.model_specs:
        updated_config.model_specs.add()

    model_names = [spec.name for spec in updated_config.model_specs]
    if len(model_names) == 1 and model_names[0]:
        logging.info(
            'ModelSpec name "%s" is being ignored and replaced by "" because a '
            "single ModelSpec is being used",
            model_names[0],
        )
        updated_config.model_specs[0].name = ""
        model_names = [""]
    for metrics_spec in updated_config.metrics_specs:
        if not metrics_spec.model_names:
            metrics_spec.model_names.extend(model_names)
        elif len(model_names) == 1:
            # A single model is always addressed by the empty name.
            del metrics_spec.model_names[:]
            metrics_spec.model_names.append("")

    return updated_config

verify_and_update_eval_shared_models

verify_and_update_eval_shared_models(
    eval_shared_model: Optional[
        MaybeMultipleEvalSharedModels
    ],
) -> Optional[List[EvalSharedModel]]

Verifies eval shared models and normalizes them to produce a single list.

The output is normalized such that if a list or dict contains a single entry, the model name will always be empty.


eval_shared_model: None, a single model, a list of models, or a dict of models keyed by model name.


A list of models or None.


ValueError if dict is passed and keys don't match model names or a multi-item list is passed without model names.

Source code in tensorflow_model_analysis/utils/model_util.py
def verify_and_update_eval_shared_models(
    eval_shared_model: Optional[types.MaybeMultipleEvalSharedModels],
) -> Optional[List[types.EvalSharedModel]]:
    """Verifies eval shared models and normalizes them into a single list.

    The output is normalized such that if a list or dict contains a single entry,
    the model name will always be empty.

    Args:
    ----
      eval_shared_model: None, a single model, a list of models, or a dict of
        models keyed by model name.

    Returns:
    -------
      A list of models, or None when eval_shared_model is falsy (None, or an
      empty list/dict).

    Raises:
    ------
      ValueError if dict is passed and keys don't match model names or a
      multi-item list is passed without model names.
    """
    if not eval_shared_model:
        return None
    eval_shared_models = []
    if isinstance(eval_shared_model, dict):
        for k, v in eval_shared_model.items():
            # A non-empty key and a non-empty model_name must agree.
            if v.model_name and k and k != v.model_name:
                raise ValueError(
                    "keys for EvalSharedModel dict do not match "
                    f"model_names: dict={eval_shared_model}"
                )
            # Fill in a missing model_name from the dict key.
            if not v.model_name and k:
                v = v._replace(model_name=k)
            eval_shared_models.append(v)
    elif isinstance(eval_shared_model, list):
        # Ensure we don't modify the input list when updating model_name, below.
        eval_shared_models = eval_shared_model.copy()
    else:
        eval_shared_models = [eval_shared_model]
    if len(eval_shared_models) > 1:
        # With several models, every one must be named so they can be told apart.
        for v in eval_shared_models:
            if not v.model_name:
                raise ValueError(
                    "model_name is required when passing multiple EvalSharedModels: "
                    f"eval_shared_models={eval_shared_models}"
                )
    # To maintain consistency between settings where single models are used,
    # always use '' as the model name regardless of whether a name is passed.
    elif len(eval_shared_models) == 1 and eval_shared_models[0].model_name:
        eval_shared_models[0] = eval_shared_models[0]._replace(model_name="")
    # Sanity-check the element types. This only validates; nothing is converted.
    for model in eval_shared_models:
        assert isinstance(model, types.EvalSharedModel)
    return eval_shared_models  # pytype: disable=bad-return-type  # py310-upgrade

verify_eval_config

verify_eval_config(
    eval_config: EvalConfig,
    baseline_required: Optional[bool] = None,
)

Verifies eval config.

Source code in tensorflow_model_analysis/utils/config_util.py
def verify_eval_config(
    eval_config: config_pb2.EvalConfig, baseline_required: Optional[bool] = None
):
    """Verifies eval config.

    Args:
    ----
      eval_config: The EvalConfig to verify.
      baseline_required: If truthy, exactly one ModelSpec with is_baseline set
        must be present.

    Raises:
    ------
      ValueError: If no model_specs are set; if a ModelSpec sets both the
        singular and plural form of its label, prediction, or example weight
        key; if two ModelSpecs share the same name; if more than one ModelSpec
        is marked as baseline; if multiple ModelSpecs are used and any of them
        is unnamed; if a baseline is required but missing; or if a
        per_slice_thresholds entry has no slicing_specs.
    """
    if not eval_config.model_specs:
        raise ValueError(
            f"At least one model_spec is required: eval_config=\n{eval_config}"
        )

    model_specs_by_name = {}
    baseline = None
    for spec in eval_config.model_specs:
        if spec.label_key and spec.label_keys:
            raise ValueError(
                "only one of label_key or label_keys should be used at "
                f"a time: model_spec=\n{spec}"
            )
        if spec.prediction_key and spec.prediction_keys:
            raise ValueError(
                "only one of prediction_key or prediction_keys should be used at "
                f"a time: model_spec=\n{spec}"
            )
        if spec.example_weight_key and spec.example_weight_keys:
            raise ValueError(
                "only one of example_weight_key or example_weight_keys should be "
                f"used at a time: model_spec=\n{spec}"
            )
        # Detect duplicates against the names seen so far. model_specs holds
        # ModelSpec messages, not names, so a string can never be found in it.
        if spec.name in model_specs_by_name:
            raise ValueError(
                f'more than one model_spec found for model "{spec.name}": {[spec, model_specs_by_name[spec.name]]}'
            )
        model_specs_by_name[spec.name] = spec
        if spec.is_baseline:
            if baseline is not None:
                raise ValueError(
                    "only one model_spec may be a baseline, found: "
                    f"{spec} and {baseline}"
                )
            baseline = spec

    # An unnamed ("") spec is only allowed when it is the sole spec.
    if len(model_specs_by_name) > 1 and "" in model_specs_by_name:
        raise ValueError(
            "A name is required for all ModelSpecs when multiple "
            f"models are used: eval_config=\n{eval_config}"
        )

    if baseline_required and not baseline:
        raise ValueError(
            f"A baseline ModelSpec is required: eval_config=\n{eval_config}"
        )

    # Raise exception if per_slice_thresholds has no slicing_specs.
    for metric_spec in eval_config.metrics_specs:
        for name, per_slice_thresholds in metric_spec.per_slice_thresholds.items():
            for per_slice_threshold in per_slice_thresholds.thresholds:
                if not per_slice_threshold.slicing_specs:
                    raise ValueError(
                        "slicing_specs must be set on per_slice_thresholds but found "
                        f"per_slice_threshold=\n{per_slice_threshold}\n"
                        f"for metric name {name} in metric_spec:\n{metric_spec}"
                    )
        for metric_config in metric_spec.metrics:
            for per_slice_threshold in metric_config.per_slice_thresholds:
                if not per_slice_threshold.slicing_specs:
                    raise ValueError(
                        "slicing_specs must be set on per_slice_thresholds but found "
                        f"per_slice_threshold=\n{per_slice_threshold}\n"
                        f"for metric config:\t{metric_config}"
                    )