# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines trials for parameter exploration."""
import time
from lingvo.core import hyperparams
class Trial:
  """Base class for a trial.

  A trial is one run of a model with a particular hyperparameter setting,
  typically managed by an external tuner for parameter exploration.
  """

  @classmethod
  def Params(cls):
    """Default parameters for a trial."""
    p = hyperparams.Params()
    p.Define(
        'report_interval_seconds', 600,
        'Interval between reporting trial results and checking for early '
        'stopping.')
    p.Define('vizier_objective_metric_key', 'loss',
             'Which eval metric to use as the "objective value" for tuning.')
    p.Define(
        'report_during_training', False,
        'Whether to report objective metrics during the training process.')
    return p

  def __init__(self, params):
    self._params = params.Copy()
    # Earliest wall-clock time at which the next training-progress report
    # may be sent; initialized to "now" so the first report is immediate.
    self._next_report_time = time.time()

  @property
  def report_interval_seconds(self):
    return self._params.report_interval_seconds

  @property
  def objective_metric_key(self):
    return self._params.vizier_objective_metric_key

  def Name(self):
    """Returns the name of this trial."""
    raise NotImplementedError('Abstract method')

  def OverrideModelParams(self, model_params):
    """Modifies `model_params` according to trial params.

    Through this method a `Trial` may tweak model hyperparams (e.g., learning
    rate, shape, depth, or width of networks).

    Args:
      model_params: the original model hyperparams.

    Returns:
      The modified `model_params`.
    """
    raise NotImplementedError('Abstract method')

  def ShouldStop(self):
    """Returns whether the trial should stop."""
    raise NotImplementedError('Abstract method')

  def ReportDone(self, infeasible=False, infeasible_reason=''):
    """Report that the trial is completed."""
    raise NotImplementedError('Abstract method')

  def ShouldStopAndMaybeReport(self, global_step, metrics_dict):
    """Returns whether the trial should stop.

    Args:
      global_step: The global step counter.
      metrics_dict: If not None, contains the metrics that should be
        reported. If None, does nothing but returns whether the
        trial should stop.

    Returns:
      Whether the trial should stop.
    """
    # Without metrics (or with in-training reporting disabled) this is a
    # plain early-stopping check.
    if not metrics_dict or not self._params.report_during_training:
      return self.ShouldStop()
    # Rate-limit reports to at most one per report_interval_seconds.
    if time.time() < self._next_report_time:
      return False
    self._next_report_time = time.time() + self.report_interval_seconds
    return self._DoReportTrainingProgress(global_step, metrics_dict)

  def _DoReportTrainingProgress(self, global_step, metrics_dict):
    """Reports training progress; returns whether the trial should stop."""
    raise NotImplementedError('Abstract method')

  def ReportEvalMeasure(self, global_step, metrics_dict, checkpoint_path):
    """Reports eval measurement and returns whether the trial should stop."""
    raise NotImplementedError('Abstract method')
class NoOpTrial(Trial):
  """A Trial implementation that does nothing.

  Used when no tuner is managing the run: never stops early and never
  reports anything.
  """

  def __init__(self):
    super().__init__(Trial.Params())

  def Name(self):
    return ''

  def OverrideModelParams(self, model_params):
    # Leaves the model hyperparams untouched.
    return model_params

  def ShouldStop(self):
    return False

  def ReportDone(self, infeasible=False, infeasible_reason=''):
    return False

  def ShouldStopAndMaybeReport(self, global_step, metrics_dict):
    del global_step, metrics_dict  # Unused
    return False

  def ReportEvalMeasure(self, global_step, metrics_dict, checkpoint_path):
    del global_step, metrics_dict, checkpoint_path  # Unused
    return False
class TunerManagedError(BaseException):
  """Base class for errors that should be propagated to the tuner.

  In base_runner.py, the training loop catches all exceptions and treats
  unknown errors as failures. However, in some cases (e.g. PyGlove uses
  an EarlyStoppingError to signal early stopping that might take place
  at any moment), it requires the error to propagate to the tuner. This
  class is a base for such errors.

  Deriving from BaseException (not Exception) keeps these errors out of
  broad `except Exception` handlers so they reach the tuner.
  """