# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils function for program.py."""
import os
import lingvo.compat as tf
[docs]def SummaryToCsv(summaries):
"""Convert summary (Dict[str, tf.Summary]) to csv format."""
res = ''
for k, s in summaries.items():
res += f'{k},{s.value[0].simple_value}\n'
return res
[docs]def CsvToSummary(csv):
"""Convert csv format to summary (Dict[str, tf.Summary])."""
summaries = {}
for l in csv.split('\n'):
row = l.split(',')
if len(row) != 2:
tf.logging.warn(f'Failed to parse csv line: {l}, will ignore it.')
continue
s = tf.Summary()
v = s.value.add()
v.tag, v.simple_value = row[0], float(row[1])
summaries.update({v.tag: s})
return summaries
[docs]class DecodeStatusCache:
"""Maintain status file to keep decoding datasets status.
Status file should have following format:
- 1st line is checkpoint key, e.g. ckpt-123
- the rest lines are dataset names that has been decoded.
Here's an example:
ckpt-123
Dev
Test
"""
def __init__(self, program_dir):
self.ckpt_key = ''
self.decoded_datasets = []
self.status_file = os.path.join(program_dir, 'decoded_datasets.txt')
# TODO(xingwu): Consider add a TTL.
self.cache_dir = os.path.join(program_dir, 'cache')
tf.io.gfile.makedirs(self.cache_dir)
if tf.io.gfile.exists(self.status_file):
with tf.io.gfile.GFile(self.status_file, 'r') as f:
content = list(l.strip() for l in f.readlines())
if content:
self.ckpt_key = content[0]
if len(content) > 1:
self.decoded_datasets = content[1:]
[docs] def UpdateCkpt(self, ckpt_key):
"""Update checkpoint key in the status."""
if ckpt_key != self.ckpt_key:
self.ckpt_key = ckpt_key
self.decoded_datasets = []
with tf.io.gfile.GFile(self.status_file, 'w') as f:
f.write(self.ckpt_key)
[docs] def UpdateDataset(self, dataset_name, summaries):
"""Update decoded dataset in the status."""
cache_file = os.path.join(self.cache_dir, f'{dataset_name}.csv')
with tf.io.gfile.GFile(cache_file, 'w') as f:
f.write(SummaryToCsv(summaries))
with tf.io.gfile.GFile(self.status_file, 'w+') as f:
f.write(f.read().strip() + '\n' + dataset_name)
[docs] def TryLoadCache(self, ckpt_key, dataset_name):
"""Try load summary cache for ckpt_key, dataset_name.
Args:
ckpt_key: str, checkpoint key, e.g. ckpt-123
dataset_name: str, the dataset name, e.g. Test
Returns:
summaries if load successful, otherwise, return None
"""
if ckpt_key == self.ckpt_key and dataset_name in self.decoded_datasets:
cache_file = os.path.join(self.cache_dir, f'{dataset_name}.csv')
if not tf.io.gfile.exists(cache_file):
tf.logging.warn(f'cached summary {cache_file} is gone!')
return None
with tf.io.gfile.GFile(cache_file, 'r') as f:
summaries = CsvToSummary(f.read())
with tf.io.gfile.GFile(self.status_file, 'w+') as f:
f.write(f.read().strip() + '\n' + dataset_name)
return summaries
return None
[docs]class TriggerScheduler:
"""A trigger scheduler with offset, and interval.
Maintains an counter, incremented when Trigger() called. ShouldRun() only
returns True when (counter - offset) % interval == 0.
"""
def __init__(self, offset, interval):
self.offset = offset
self.interval = interval
self.counter = -offset
[docs] def Trigger(self):
self.counter += 1
if self.counter >= self.interval:
self.counter = 0
[docs] def ShouldRun(self):
return self.counter == 0