# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculate and update metrics for Decode Post-Process."""
import collections
from typing import Any, Dict
import lingvo.compat as tf
from lingvo.tasks.asr import decoder_utils
# transcripts: [num_utts]. A sequence of transcripts (references), one for
# each utterance.
# topk_decoded: [num_utts, num_topk_hyps]. A sequence, for each utterance, of a
# sequence of the top K hypotheses.
# filtered_transcripts: [num_utts]. Same as transcripts, but after simple
# filtering.
# filtered_top_hyps: [num_utts]. A sequence of filtered top hypotheses, one for
# each utterance.
# topk_scores: [num_utts, num_topk_hyps]. A sequence, for each utterance, of a
# sequence of the top K decoder scores corresponding to the top K hypotheses.
# utt_id: [num_utts]. A sequence of utterance_ids.
# norm_wer_errors: [num_utts, num_topk_hyps]. A sequence, for each utterance,
# of a sequence of the normalized WER errors corresponding to the top K
# hypotheses.
# target_labels: [num_utts, max_target_length]. A sequence, for each utterance,
# of a sequence of labels for the target (reference) tokens (e.g. word-pieces).
# target_paddings: [num_utts, max_target_length]. A sequence, for each
# utterance, of a sequence of padding values for the target (reference) token
# sequence, used to determine the non-padding labels.
# topk_ids: [num_utts * num_topk_hyps, max_target_length]. A sequence, for each
# top K hypothesis of ALL utterances, of a sequence of token labels/ids
# (up to max_target_length).
# topk_lens: [num_utts * num_topk_hyps]. A sequence, for each top K hypothesis
# of ALL utterances, of the length of its token labels/ids sequence.
PostProcessInputs = collections.namedtuple('postprocess_input', [
    'transcripts', 'topk_decoded', 'filtered_transcripts', 'filtered_top_hyps',
    'topk_scores', 'utt_id', 'norm_wer_errors', 'target_labels',
    'target_paddings', 'topk_ids', 'topk_lens'
])
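
# The fields above are typically batched numpy objects produced by the ASR
# decoder. As an illustrative, hypothetical example (not part of this module),
# an instance for a single utterance with two hypotheses might look like:
#
#   import numpy as np
#   example = PostProcessInputs(
#       transcripts=np.array([b'hello world']),
#       topk_decoded=np.array([[b'hello world', b'hello word']]),
#       filtered_transcripts=np.array([b'hello world']),
#       filtered_top_hyps=np.array([b'hello world']),
#       topk_scores=np.array([[-1.2, -3.4]]),
#       utt_id=np.array([b'utt_0']),
#       norm_wer_errors=np.array([[0, 1]]),
#       target_labels=np.array([[3, 7, 0]]),
#       target_paddings=np.array([[0., 0., 1.]]),
#       topk_ids=np.array([[3, 7, 0], [3, 9, 0]]),
#       topk_lens=np.array([2, 2]))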


def GetRefIds(ref_ids, ref_paddings):
  """Returns the reference token ids at non-padded positions."""
  assert len(ref_ids) == len(ref_paddings)
  return_ids = []
  for i in range(len(ref_ids)):
    if ref_paddings[i] == 0:
      return_ids.append(ref_ids[i])
  return return_ids
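
# A small, hypothetical usage sketch of GetRefIds: with
# target_labels[i] == [3, 7, 5] and target_paddings[i] == [0., 0., 1.],
# GetRefIds([3, 7, 5], [0., 0., 1.]) returns [3, 7], i.e. only the labels at
# non-padded positions.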


def CalculateMetrics(
    postprocess_inputs: PostProcessInputs,
    dec_metrics_dict: Dict[str, Any],
    add_summary: bool,
    use_tpu: bool,
    log_utf8: bool,
):
  """Calculate and update metrics.

  Args:
    postprocess_inputs: A PostProcessInputs namedtuple of post-process input
      objects/tensors.
    dec_metrics_dict: A dictionary of metric names to metrics.
    add_summary: Whether to add detailed summary logging for processing each
      utterance.
    use_tpu: Whether TPU is used (for decoding).
    log_utf8: DecoderMetrics param. If True, decode reference and hypothesis
      bytes to UTF-8 for logging.
  """
  (transcripts, topk_decoded, filtered_transcripts, filtered_top_hyps,
   topk_scores, utt_id, norm_wer_errors, target_labels, target_paddings,
   topk_ids, topk_lens) = postprocess_inputs

  if not transcripts.size:
    return

  # Case sensitive WERs.
  total_ins, total_subs, total_dels, total_errs = 0, 0, 0, 0
  # Case insensitive WERs.
  ci_total_ins, ci_total_subs, ci_total_dels, ci_total_errs = 0, 0, 0, 0
  total_oracle_errs = 0
  total_ref_words = 0
  total_token_errs = 0
  total_ref_tokens = 0
  total_accurate_sentences = 0

  for i in range(len(transcripts)):
    ref_str = transcripts[i]
    if not use_tpu:
      tf.logging.info('utt_id: %s', utt_id[i])
    if add_summary:
      tf.logging.info(' ref_str: %s',
                      ref_str.decode('utf-8') if log_utf8 else ref_str)
    hyps = topk_decoded[i]
    num_hyps_per_beam = len(hyps)
    ref_ids = GetRefIds(target_labels[i], target_paddings[i])
    hyp_index = i * num_hyps_per_beam
    top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
    if add_summary:
      tf.logging.info(' ref_ids: %s', ref_ids)
      tf.logging.info(' top_hyp_ids: %s', top_hyp_ids)

    total_ref_tokens += len(ref_ids)
    _, _, _, token_errs = decoder_utils.EditDistanceInIds(ref_ids, top_hyp_ids)
    total_token_errs += token_errs

    filtered_ref = filtered_transcripts[i]
    oracle_errs = norm_wer_errors[i][0]
    for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
      oracle_errs = min(oracle_errs, norm_wer_errors[i, n])
      if add_summary:
        tf.logging.info(' %f: %s', score,
                        hyp_str.decode('utf-8') if log_utf8 else hyp_str)
      # Only aggregate scores of the top hypothesis.
      if n != 0:
        continue

      filtered_hyp = filtered_top_hyps[i]
      ins, subs, dels, errs = decoder_utils.EditDistance(
          filtered_ref, filtered_hyp)
      total_ins += ins
      total_subs += subs
      total_dels += dels
      total_errs += errs

      # Calculating case insensitive WERs.
      ci_ins, ci_subs, ci_dels, ci_errs = decoder_utils.EditDistance(
          filtered_ref.lower(), filtered_hyp.lower())
      ci_total_ins += ci_ins
      ci_total_subs += ci_subs
      ci_total_dels += ci_dels
      ci_total_errs += ci_errs

      ref_words = len(decoder_utils.Tokenize(filtered_ref))
      total_ref_words += ref_words
      if norm_wer_errors[i, n] == 0:
        total_accurate_sentences += 1
      tf.logging.info(
          ' ins: %d, subs: %d, del: %d, total: %d, ref_words: %d, wer: %f',
          ins, subs, dels, errs, ref_words, errs / max(1, ref_words))
      tf.logging.info(
          ' ci_ins: %d, ci_subs: %d, ci_del: %d, ci_total: %d, '
          'ref_words: %d, ci_wer: %f', ci_ins, ci_subs, ci_dels, ci_errs,
          ref_words, ci_errs / max(1, ref_words))

    total_oracle_errs += oracle_errs

  non_zero_total_ref_words = max(1., total_ref_words)
  dec_metrics_dict['wer'].Update(total_errs / non_zero_total_ref_words,
                                 total_ref_words)
  dec_metrics_dict['error_rates/ins'].Update(
      total_ins / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['error_rates/sub'].Update(
      total_subs / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['error_rates/del'].Update(
      total_dels / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['error_rates/wer'].Update(
      total_errs / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['case_insensitive_error_rates/ins'].Update(
      ci_total_ins / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['case_insensitive_error_rates/sub'].Update(
      ci_total_subs / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['case_insensitive_error_rates/del'].Update(
      ci_total_dels / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['case_insensitive_error_rates/wer'].Update(
      ci_total_errs / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['oracle_norm_wer'].Update(
      total_oracle_errs / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['sacc'].Update(total_accurate_sentences / len(transcripts),
                                  len(transcripts))
  dec_metrics_dict['ter'].Update(total_token_errs / max(1., total_ref_tokens),
                                 total_ref_tokens)
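
# The metric objects in dec_metrics_dict must support Update(value, weight)
# for every key used above. A minimal sketch of such a dictionary, assuming
# lingvo's metrics.AverageMetric (the actual caller may construct it
# differently):
#
#   from lingvo.core import metrics
#   dec_metrics_dict = {
#       name: metrics.AverageMetric() for name in (
#           'wer', 'error_rates/ins', 'error_rates/sub', 'error_rates/del',
#           'error_rates/wer', 'case_insensitive_error_rates/ins',
#           'case_insensitive_error_rates/sub',
#           'case_insensitive_error_rates/del',
#           'case_insensitive_error_rates/wer', 'oracle_norm_wer', 'sacc',
#           'ter')
#   }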