# Source code for lingvo.base_trial

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines trials for parameter exploration."""

import time

from lingvo.core import hyperparams


class Trial:
  """Base class for a trial.

  A trial wraps one point in a hyperparameter search: it can rewrite model
  params, report training/eval metrics to the tuner, and tell the training
  loop when to stop early. Concrete tuner integrations subclass this and
  implement the abstract methods.
  """

  @classmethod
  def Params(cls):
    """Default parameters for a trial."""
    p = hyperparams.Params()
    p.Define(
        'report_interval_seconds', 600,
        'Interval between reporting trial results and checking for early '
        'stopping.')
    p.Define('vizier_objective_metric_key', 'loss',
             'Which eval metric to use as the "objective value" for tuning.')
    p.Define(
        'report_during_training', False,
        'Whether to report objective metrics during the training process.')
    return p

  def __init__(self, params):
    # Copy so later mutation of the caller's params cannot affect this trial.
    self._params = params.Copy()
    # First report is allowed immediately; subsequent reports are throttled
    # by report_interval_seconds (see ShouldStopAndMaybeReport).
    self._next_report_time = time.time()

  @property
  def report_interval_seconds(self):
    """Minimum number of seconds between successive progress reports."""
    return self._params.report_interval_seconds

  @property
  def objective_metric_key(self):
    """Key of the eval metric used as the tuning objective value."""
    return self._params.vizier_objective_metric_key

  def Name(self):
    """Returns a string identifying this trial."""
    raise NotImplementedError('Abstract method')

  def OverrideModelParams(self, model_params):
    """Modifies `model_params` according to trial params.

    Through this method a `Trial` may tweak model hyperparams (e.g.,
    learning rate, shape, depth, or width of networks).

    Args:
      model_params: the original model hyperparams.

    Returns:
      The modified `model_params`.
    """
    raise NotImplementedError('Abstract method')

  def ShouldStop(self):
    """Returns whether the trial should stop."""
    raise NotImplementedError('Abstract method')

  def ReportDone(self, infeasible=False, infeasible_reason=''):
    """Report that the trial is completed."""
    raise NotImplementedError('Abstract method')

  def ShouldStopAndMaybeReport(self, global_step, metrics_dict):
    """Returns whether the trial should stop.

    Args:
      global_step: The global step counter.
      metrics_dict: If not None, contains the metric should be reported. If
        None, do nothing but returns whether the trial should stop.
    """
    if not metrics_dict or not self._params.report_during_training:
      return self.ShouldStop()
    if time.time() < self._next_report_time:
      # Throttled: too soon since the last report; do not stop.
      return False
    self._next_report_time = time.time() + self.report_interval_seconds
    return self._DoReportTrainingProgress(global_step, metrics_dict)

  def _DoReportTrainingProgress(self, global_step, metrics_dict):
    """Reports training progress; returns whether the trial should stop."""
    raise NotImplementedError('Abstract method')

  def ReportEvalMeasure(self, global_step, metrics_dict, checkpoint_path):
    """Reports eval measurement and returns whether the trial should stop."""
    raise NotImplementedError('Abstract method')
class NoOpTrial(Trial):
  """A Trial implementation that does nothing.

  Used when no tuner is driving training: every query reports nothing and
  never requests early stopping.
  """

  def __init__(self):
    super().__init__(Trial.Params())

  def Name(self):
    """Returns an empty name; there is no real trial behind this object."""
    return ''

  def OverrideModelParams(self, model_params):
    """Returns `model_params` unchanged."""
    return model_params

  def ShouldStop(self):
    """Never requests stopping."""
    return False

  def ReportDone(self, infeasible=False, infeasible_reason=''):
    """No-op; reports nothing and never requests stopping."""
    return False

  def ShouldStopAndMaybeReport(self, global_step, metrics_dict):
    """No-op; ignores metrics and never requests stopping."""
    del global_step, metrics_dict  # Unused
    return False

  def ReportEvalMeasure(self, global_step, metrics_dict, checkpoint_path):
    """No-op; ignores eval measurements and never requests stopping."""
    del global_step, metrics_dict, checkpoint_path  # Unused
    return False
class TunerManagedError(BaseException):
  """Base class for error that should be propagated to the tuner.

  In base_runner.py, the training loop catches all exceptions and treats
  unknown errors as failure. However, in some cases (e.g. PyGlove uses an
  EarlyStoppingError to signal early stopping that might take place at any
  moment), it requires the error to propagate to the tuner. This class is a
  base for such errors.
  """