Source code for lingvo.core.adagraft

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AdaGraft optimizer https://arxiv.org/abs/2002.11803 ."""

import lingvo.compat as tf


class AdaGraftOptimizer(tf.train.Optimizer):
  """Optimizer which combines per-layer direction and magnitude from two optimizers.

  Disentangling Adaptive Gradient Methods from Learning Rates
  Naman Agarwal, Rohan Anil, Elad Hazan, Tomer Koren, Cyril Zhang
  https://arxiv.org/abs/2002.11803
  """

  def __init__(self,
               learning_rate,
               magnitude_optimizer,
               direction_optimizer,
               diagnostic=False,
               use_global_norm=False,
               name="AdaGraft"):
    """Construct a new AdaGraft optimizer.

    Args:
      learning_rate: A `Tensor` or a floating point value. The learning rate.
      magnitude_optimizer: Child optimizer from which to inherit step sizes.
      direction_optimizer: Child optimizer from which to inherit step
        directions.
      diagnostic: Whether to record per-tensor step norms.
      use_global_norm: Whether to graft global l2 norms rather than per-layer
        norms.
      name: Optional name prefix for the operations created when applying
        gradients.
    """
    super().__init__(False, name)
    self._learning_rate = learning_rate
    self.magnitude_optimizer = magnitude_optimizer
    self.direction_optimizer = direction_optimizer
    self.diagnostic = diagnostic
    self.use_global_norm = use_global_norm
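
  # A minimal usage sketch (illustrative, not part of this module): graft SGD
  # step magnitudes onto Adam step directions. `my_loss` is an assumed
  # placeholder for a training loss, and the child optimizers are assumed to
  # be the standard `tf.compat.v1.train` ones. The grafted step is further
  # scaled by `learning_rate` (see `_internal_apply_dense`).
  #
  #   opt = AdaGraftOptimizer(
  #       learning_rate=1.0,
  #       magnitude_optimizer=tf.compat.v1.train.GradientDescentOptimizer(0.1),
  #       direction_optimizer=tf.compat.v1.train.AdamOptimizer(1e-3))
  #   train_op = opt.minimize(my_loss)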

  def _create_slots(self, var_list):
    self.magnitude_optimizer._create_slots(var_list)  # pylint: disable=protected-access
    self.direction_optimizer._create_slots(var_list)  # pylint: disable=protected-access

    for v in var_list:
      with tf.ops.colocate_with(v):
        self._zeros_slot(v, "scratch_copy", self._name)
        if self.diagnostic or self.use_global_norm:
          self._get_or_make_slot(v, tf.constant(0.0), "m_step_norm", self._name)
          self._get_or_make_slot(v, tf.constant(0.0), "d_step_norm", self._name)

  def _prepare(self):
    self.magnitude_optimizer._prepare()  # pylint: disable=protected-access
    self.direction_optimizer._prepare()  # pylint: disable=protected-access

    learning_rate = self._call_if_callable(self._learning_rate)
    self._learning_rate_tensor = tf.convert_to_tensor(
        learning_rate, name="learning_rate")

    if self.use_global_norm:
      # Create a list of all variables for the global grafting in _finish.
      self._variables = []

  def _apply_dense(self, grad, var):  # pylint: disable=g-doc-args
    return self._internal_apply_dense(
        grad,
        var,
        self.magnitude_optimizer._apply_dense,  # pylint: disable=protected-access
        self.direction_optimizer._apply_dense)  # pylint: disable=protected-access

  def _resource_apply_dense(self, grad, var):
    return self._internal_apply_dense(
        grad,
        var,
        self.magnitude_optimizer._resource_apply_dense,  # pylint: disable=protected-access
        self.direction_optimizer._resource_apply_dense)  # pylint: disable=protected-access

  def _internal_apply_dense(self, grad, var, magnitude_optimizer_apply_fn,
                            direction_optimizer_apply_fn):  # pylint: disable=g-doc-args
    """Main optimization logic of AdaGraft, which calls the child optimizers.

    Args:
      grad: Tensor containing gradients.
      var: Tensor containing parameter values.
      magnitude_optimizer_apply_fn: Apply function of the magnitude optimizer.
      direction_optimizer_apply_fn: Apply function of the direction optimizer.

    Returns:
      The final update op, which increments var by the grafted step.

    Pseudocode:
    - Copy the weights into the scratch space 'scratch_copy'.
    - Run magnitude_optimizer in place.
    - Use the scratch copy to measure how far we moved ('magnitude_step').
    - Copy the weights back.
    - Run direction_optimizer in place.
    - Move the weights along the direction optimizer's step, rescaled to the
      magnitude optimizer's step norm.
    """
    if self.use_global_norm:
      self._variables.append(var)

    # Slot holding a copy of the current parameter values.
    scratch_slot = self.get_slot(var, "scratch_copy")
    old_var = tf.assign(scratch_slot, var)

    with tf.control_dependencies([old_var]):
      m_updated_var = magnitude_optimizer_apply_fn(grad, var)  # pylint: disable=protected-access

    # Run the magnitude optimizer and compute the norm of its step.
    with tf.control_dependencies([m_updated_var]):
      m_step = var - old_var
      m_step_norm = tf.norm(m_step)
      if self.diagnostic or self.use_global_norm:
        m_step_norm = tf.assign(self.get_slot(var, "m_step_norm"), m_step_norm)

    # Restore the original weights, then run the direction optimizer.
    with tf.control_dependencies([m_step_norm]):
      flushed_var = tf.assign(var, old_var)
    with tf.control_dependencies([flushed_var]):
      d_updated_var = direction_optimizer_apply_fn(grad, var)  # pylint: disable=protected-access

    # Compute the direction step and its norm, then apply the grafted update.
    with tf.control_dependencies([d_updated_var]):
      d_step = var - old_var
      d_step_norm = tf.norm(d_step)
      if self.diagnostic or self.use_global_norm:
        d_step_norm = tf.assign(self.get_slot(var, "d_step_norm"), d_step_norm)
      if self.use_global_norm:
        # Defer the actual update to _finish: restore the weights and stash
        # the direction step in the scratch slot.
        flushed_var = tf.assign(var, old_var)
        with tf.control_dependencies([d_step_norm, flushed_var]):
          return tf.assign(scratch_slot, d_step)
      # Per-layer grafting: rescale the direction step to the magnitude
      # optimizer's step norm.
      step = tf.where(
          tf.greater(d_step_norm, 0),
          (m_step_norm / tf.maximum(d_step_norm, 1e-30)) * d_step,
          tf.zeros_like(d_step))

    return tf.assign(var, old_var + self._learning_rate_tensor * step)
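
  # Per-layer grafting in isolation (a sketch of the rescaling above, using
  # plain numpy names that are not part of this module): if the magnitude
  # optimizer proposes step m and the direction optimizer proposes step d,
  # the grafted step keeps d's direction at m's length.
  #
  #   import numpy as np
  #   m = np.array([0.3, 0.4])   # ||m|| = 0.5
  #   d = np.array([-4.0, 3.0])  # ||d|| = 5.0
  #   step = (np.linalg.norm(m) / np.linalg.norm(d)) * d
  #   # step == [-0.4, 0.3]: same direction as d, same norm as m.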

  def _finish(self, update_ops, name_scope):
    with tf.control_dependencies(update_ops):
      ops1 = self.magnitude_optimizer._finish([], name_scope + "_m")  # pylint: disable=protected-access
      ops2 = self.direction_optimizer._finish([], name_scope + "_d")  # pylint: disable=protected-access

    if self.use_global_norm:  # Apply global grafting.
      with tf.control_dependencies([ops1, ops2]):
        m_global_norm = tf.Variable(0.)
        d_global_norm = tf.Variable(0.)

        # Accumulate squared per-layer step norms into global norms.
        for var in self._variables:
          m_step_norm = self.get_slot(var, "m_step_norm")
          d_step_norm = self.get_slot(var, "d_step_norm")
          tf.assign_add(m_global_norm, m_step_norm**2)
          tf.assign_add(d_global_norm, d_step_norm**2)

        multiplier = tf.sqrt(m_global_norm / tf.maximum(d_global_norm, 1e-30))

        step_ops = []
        for var in self._variables:
          d_step = self.get_slot(var, "scratch_copy")
          # Guard against a zero direction step for this variable.
          d_step_norm = self.get_slot(var, "d_step_norm")
          step = tf.where(
              tf.greater(d_step_norm, 0), multiplier * d_step,
              tf.zeros_like(d_step))
          step_op = tf.assign_add(var, self._learning_rate_tensor * step)
          step_ops.append(step_op)

        return tf.group(*step_ops, name=name_scope)

    return tf.group(*([ops1, ops2] + update_ops), name=name_scope)
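
  # Global grafting in _finish (a sketch, not part of this module): instead of
  # matching norms layer by layer, one rescaling factor is shared across all
  # variables,
  #
  #   multiplier = sqrt(sum_i ||m_step_i||^2) / sqrt(sum_i ||d_step_i||^2)
  #   step_i = multiplier * d_step_i
  #
  # so the concatenated direction step is rescaled to the concatenated
  # magnitude step's global l2 norm before the learning rate is applied.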

  # Sparse gradients are not currently handled; supporting them is left to
  # future work.
  def _resource_apply_sparse(self, grad_values, var, grad_indices):
    return tf.no_op()

  def _apply_sparse(self, grad, var):
    return tf.no_op()