# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Stand-alone script to evalute the word error rate (WER) for ASR tasks.

THIS SCRIPT IS NO LONGER SUPPORTED. PLEASE USE simple_wer_v2.py INSTEAD.

Tensorflow and Lingvo are not required to run this script.

Example of Usage::

  python simple_wer.py file_hypothesis file_reference
  python simple_wer.py file_hypothesis file_reference diagnosis_html

where `file_hypothesis` is the file name for hypothesis text and
`file_reference` is the file name for reference text.
`diagnosis_html` (optional) is the html filename to diagnose the errors.

Or you can use this file as a library, and call either of the following:

  - ``ComputeWER(hyp, ref)``    compute WER for one pair of hypothesis/reference
  - ``AverageWERs(hyps, refs)`` average WER for a list of hypotheses/references

Note that to evaluate ASR output, the following pre-processing is applied:

  - change transcripts to lower-case
  - remove punctuation: ``" , . ! ? (  ) [ ]``
  - remove extra empty spaces
"""

import re
import sys


def ComputeEditDistanceMatrix(hs, rs):
  """Compute edit distance between two lists of strings.

  Args:
    hs: the list of words in the hypothesis sentence
    rs: the list of words in the reference sentence

  Returns:
    Edit distance matrix (in the format of list of lists), where the first
    index is the reference and the second index is the hypothesis.
  """
  dr, dh = len(rs) + 1, len(hs) + 1
  dists = [[]] * dr

  # Initialization.
  for i in range(dr):
    dists[i] = [0] * dh
    for j in range(dh):
      if i == 0:
        dists[0][j] = j
      elif j == 0:
        dists[i][0] = i

  # Do dynamic programming.
  for i in range(1, dr):
    for j in range(1, dh):
      if rs[i - 1] == hs[j - 1]:
        dists[i][j] = dists[i - 1][j - 1]
      else:
        tmp0 = dists[i - 1][j - 1] + 1
        tmp1 = dists[i][j - 1] + 1
        tmp2 = dists[i - 1][j] + 1
        dists[i][j] = min(tmp0, tmp1, tmp2)

  return dists

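# Illustrative sketch (not part of the original module): a quick sanity check
# of the matrix layout. dists[i][j] holds the edit distance between the first
# i reference words and the first j hypothesis words, so the bottom-right
# entry is the total edit distance.
def _ExampleEditDistanceMatrix():
  dists = ComputeEditDistanceMatrix(['a', 'b'], ['a', 'c'])
  # One substitution ('b' vs 'c') yields a total distance of 1.
  assert dists == [[0, 1, 2],
                   [1, 0, 1],
                   [2, 1, 1]]
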
def PreprocessTxtBeforeWER(txt):
  """Preprocess text before WER calculation."""

  # Lowercase, remove \t and new line.
  txt = re.sub(r'[\t\n]', ' ', txt.lower())

  # Remove punctuation before space.
  txt = re.sub(r'[,.\?!]+ ', ' ', txt)

  # Remove punctuation before end.
  txt = re.sub(r'[,.\?!]+$', ' ', txt)

  # Remove punctuation after space.
  txt = re.sub(r' [,.\?!]+', ' ', txt)

  # Remove quotes, [, ], ( and ).
  txt = re.sub(r'["\(\)\[\]]', '', txt)

  # Remove extra space.
  txt = re.sub(' +', ' ', txt.strip())

  return txt

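# Illustrative sketch (not part of the original module): the preprocessing
# lowercases the text, strips the punctuation listed in the module docstring,
# and collapses repeated whitespace.
def _ExamplePreprocess():
  assert PreprocessTxtBeforeWER('Hello, World!  (test)') == 'hello world test'
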
def _GenerateAlignedHtml(hyp, ref, err_type):
  """Generate an html element to highlight the difference between hyp and ref.

  Args:
    hyp: Hypothesis string.
    ref: Reference string.
    err_type: one of 'none', 'sub', 'del', 'ins'.

  Returns:
    a html string where disagreements are highlighted:
      - substitutions in yellow, with the hypothesis word struck through
        (``<del> </del>``) next to the reference word
      - deletions (words missing from hyp) in red
      - insertions (extra words in hyp) in green, struck through
  """
  highlighted_html = ''
  if err_type == 'none':
    highlighted_html += '%s ' % hyp

  elif err_type == 'sub':
    highlighted_html += """<span style="background-color: yellow">
        <del>%s</del></span><span style="background-color: yellow">
        %s </span> """ % (hyp, ref)

  elif err_type == 'del':
    highlighted_html += """<span style="background-color: red">
        %s </span> """ % (ref)

  elif err_type == 'ins':
    highlighted_html += """<span style="background-color: green">
        <del>%s</del> </span> """ % (hyp)

  else:
    raise ValueError('unknown err_type ' + err_type)

  return highlighted_html

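# Illustrative sketch (not part of the original module): for matching words
# the helper simply emits the word followed by a space; the error types wrap
# the words in highlighted <span> elements as described above.
def _ExampleAlignedHtml():
  assert _GenerateAlignedHtml('hello', 'hello', 'none') == 'hello '
  assert '<del>hi</del>' in _GenerateAlignedHtml('hi', 'hey', 'sub')
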
def GenerateSummaryFromErrs(nref, errs):
  """Generate strings to summarize word errors.

  Args:
    nref: integer of total words in references
    errs: dict of three types of errors. e.g. ``{'sub': 10, 'ins': 15, 'del': 3}``

  Returns:
    Two strings:
      - string summarizing total error, total word, WER
      - string breaking down the three error types: deletion, insertion,
        substitution
  """
  total_error = sum(errs.values())
  str_sum = 'total error = %d, total word = %d, wer = %.2f%%' % (
      total_error, nref, total_error * 100.0 / nref)

  str_details = 'Error breakdown: del = %.2f%%, ins=%.2f%%, sub=%.2f%%' % (
      errs['del'] * 100.0 / nref, errs['ins'] * 100.0 / nref,
      errs['sub'] * 100.0 / nref)

  return str_sum, str_details

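# Illustrative sketch (not part of the original module): with 100 reference
# words and the error counts from the docstring above, the summary strings
# come out as follows.
def _ExampleSummary():
  str_sum, str_details = GenerateSummaryFromErrs(
      100, {'sub': 10, 'ins': 15, 'del': 3})
  assert str_sum == 'total error = 28, total word = 100, wer = 28.00%'
  assert str_details == 'Error breakdown: del = 3.00%, ins=15.00%, sub=10.00%'
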
def ComputeWER(hyp, ref, diagnosis=False):
  """Computes WER for ASR, ignoring punctuation, spacing and capitalization.

  Args:
    hyp: Hypothesis string.
    ref: Reference string.
    diagnosis (optional): whether to generate diagnosis str (in html format)

  Returns:
    A tuple of 3 elements:
      - dict of three types of errors. e.g. ``{'sub': 0, 'ins': 0, 'del': 0}``
      - num of reference words, integer
      - aligned html string for diagnosis (empty if diagnosis = False)
  """
  hyp = PreprocessTxtBeforeWER(hyp)
  ref = PreprocessTxtBeforeWER(ref)

  # Compute edit distance.
  hs = hyp.split()
  rs = ref.split()
  distmat = ComputeEditDistanceMatrix(hs, rs)

  # Back trace, to distinguish different errors: insertion, deletion,
  # substitution.
  ih, ir = len(hs), len(rs)
  errs = {'sub': 0, 'ins': 0, 'del': 0}
  aligned_html = ''
  while ih > 0 or ir > 0:
    err_type = ''

    # Distinguish error type by back tracking.
    if ir == 0:
      err_type = 'ins'
    elif ih == 0:
      err_type = 'del'
    else:
      if hs[ih - 1] == rs[ir - 1]:  # Correct.
        err_type = 'none'
      elif distmat[ir][ih] == distmat[ir - 1][ih - 1] + 1:  # Substitution.
        err_type = 'sub'
      elif distmat[ir][ih] == distmat[ir - 1][ih] + 1:  # Deletion.
        err_type = 'del'
      elif distmat[ir][ih] == distmat[ir][ih - 1] + 1:  # Insertion.
        err_type = 'ins'
      else:
        raise ValueError('fail to parse edit distance matrix')

    # Generate aligned_html.
    if diagnosis:
      if ih == 0 or not hs:
        tmph = ' '
      else:
        tmph = hs[ih - 1]
      if ir == 0 or not rs:
        tmpr = ' '
      else:
        tmpr = rs[ir - 1]
      aligned_html = _GenerateAlignedHtml(tmph, tmpr, err_type) + aligned_html

    # If no error, go to previous ref and hyp.
    if err_type == 'none':
      ih, ir = ih - 1, ir - 1
      continue

    # Update error.
    errs[err_type] += 1

    # Adjust position of ref and hyp.
    if err_type == 'del':
      ir = ir - 1
    elif err_type == 'ins':
      ih = ih - 1
    else:  # err_type == 'sub'
      ih, ir = ih - 1, ir - 1

  assert distmat[-1][-1] == sum(errs.values())

  # Num of words. For empty ref we set num = 1.
  nref = max(len(rs), 1)

  return errs, nref, aligned_html

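# Illustrative sketch (not part of the original module): a hypothesis missing
# three reference words counts as three deletions, while the punctuation and
# case differences are removed by the preprocessing.
def _ExampleComputeWER():
  errs, nref, _ = ComputeWER('The cat sat!', 'the cat sat on the mat')
  assert errs == {'sub': 0, 'ins': 0, 'del': 3}
  assert nref == 6  # WER = 3 / 6 = 50%.
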
def AverageWERs(hyps, refs, verbose=True, diagnosis=False):
  """Computes average WER from a list of references/hypotheses.

  Args:
    hyps: list of hypothesis strings.
    refs: list of reference strings.
    verbose: optional (default True); whether to print the summary.
    diagnosis (optional): whether to generate a list of diagnosis html
      strings.

  Returns:
    A tuple of 3 elements:
      - dict of three types of errors. e.g. ``{'sub': 0, 'ins': 0, 'del': 0}``
      - num of reference words, integer
      - list of aligned html strings for diagnosis (empty if
        diagnosis = False)
  """
  totalw = 0
  total_errs = {'sub': 0, 'ins': 0, 'del': 0}
  aligned_html_list = []
  for hyp, ref in zip(hyps, refs):
    errs_i, nref_i, diag_str = ComputeWER(hyp, ref, diagnosis)
    if diagnosis:
      aligned_html_list += [diag_str]

    totalw += nref_i
    total_errs['sub'] += errs_i['sub']
    total_errs['ins'] += errs_i['ins']
    total_errs['del'] += errs_i['del']

  if verbose:
    str_summary, str_details = GenerateSummaryFromErrs(totalw, total_errs)
    print(str_summary)
    print(str_details)

  return total_errs, totalw, aligned_html_list

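# Illustrative sketch (not part of the original module): error counts are
# pooled across all pairs before the WER is computed, so longer references
# carry proportionally more weight than in a per-sentence average.
def _ExampleAverageWERs():
  errs, totalw, _ = AverageWERs(
      ['hello word', 'the cat'], ['hello world', 'the cat'], verbose=False)
  assert errs == {'sub': 1, 'ins': 0, 'del': 0}
  assert totalw == 4  # Pooled WER = 1 / 4 = 25%.
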
def main(argv):
  hyp = open(argv[1], 'r').read()
  ref = open(argv[2], 'r').read()

  if len(argv) == 4:
    diagnosis = True
    fn_output = argv[3]
  else:
    diagnosis = False
    fn_output = None

  errs, nref, aligned_html = ComputeWER(hyp, ref, diagnosis)
  str_summary, str_details = GenerateSummaryFromErrs(nref, errs)
  print(str_summary)
  print(str_details)

  if fn_output:
    with open(fn_output, 'wt') as fp:
      fp.write('<html><body>')
      fp.write('<div>%s</div>' % aligned_html)
      fp.write('</body></html>')

if __name__ == '__main__':
  print('THIS SCRIPT IS NO LONGER SUPPORTED. '
        'PLEASE USE simple_wer_v2.py INSTEAD.')

  if len(sys.argv) < 3 or len(sys.argv) > 4:
    print("""
Example of Usage:

  python simple_wer.py file_hypothesis file_reference
or
  python simple_wer.py file_hypothesis file_reference diagnosis_html

where file_hypothesis is the file name for hypothesis text,
file_reference is the file name for reference text, and
diagnosis_html (optional) is the html filename to diagnose the errors.

Or you can use this file as a library, and call either of the following:

  - ComputeWER(hyp, ref)    to compute WER for one pair of hypothesis/reference
  - AverageWERs(hyps, refs) to average WER for a list of hypotheses/references
""")
    sys.exit(1)

  main(sys.argv)