Source code for lingvo.core.differentiable_assignment

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Implementation of differentiable assignment operators in TF.

References:
[1] Csiszár, 2008. On Iterative Algorithms with an Information Geometry
Background.
[2] Cuturi, 2013. Lightspeed Computation of Optimal Transport.
[3] Schmitzer, 2019. Stabilized Sparse Scaling Algorithms for Entropy
Regularized Transport Problems.
"""

from lingvo import compat as tf
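
# Added note (a sketch, not from the original source): the operator below
# solves an entropy-regularized assignment problem of roughly this form,
# per entry of the batch:
#
#   maximize    sum_{j,k} score[j, k] * p[j, k] + eps * H(p)
#   subject to  sum_k p[j, k] = row_sums[j],
#               sum_j p[j, k] = col_sums[k],
#               0 <= p[j, k] <= elementwise_upper_bound[j, k],
#
# where H(p) is an entropy term and eps controls how soft the resulting
# assignment is. The exact form of H(p) here is an assumption; see the
# references above for the standard formulations.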


def max_assignment(score: tf.Tensor,
                   *,
                   elementwise_upper_bound: tf.Tensor,
                   row_sums: tf.Tensor,
                   col_sums: tf.Tensor,
                   epsilon: float = 0.1,
                   num_iterations: int = 50,
                   use_epsilon_scaling: bool = True):
  """Differentiable max assignment with margin and upper bound constraints.

  Args:
    score: a 3D tensor of size [batch_size, n_rows, n_columns]. score[i, j, k]
      denotes the weight if the assignment on this entry is non-zero.
    elementwise_upper_bound: a 3D tensor of size [batch_size, n_rows,
      n_columns]. Each entry denotes the maximum value assignment[i, j, k] can
      take and must be non-negative. For example, upper_bound[i, j, k] = 1.0
      for a binary assignment problem.
    row_sums: a 2D tensor of size [batch_size, n_rows]. The row sum
      constraint: the output assignment p[i, j, :] must sum to row_sums[i, j].
    col_sums: a 2D tensor of size [batch_size, n_columns]. The column sum
      constraint: the output assignment p[i, :, k] must sum to col_sums[i, k].
    epsilon: the epsilon coefficient of entropy regularization. The value
      should be within the range (0, 1]. `0.01` might work better than `0.1`;
      `0.1` may not push the assignment close enough to 0 or 1.
    num_iterations: the maximum number of iterations to perform.
    use_epsilon_scaling: whether to use epsilon scaling. In practice, the
      convergence of the iterative algorithm is much better if we start by
      solving the optimization with a larger epsilon value and re-use the
      solution (i.e. dual variables) for the instance with a smaller epsilon.
      This is called the epsilon scaling trick; see
      [Schmitzer 2019](https://arxiv.org/pdf/1610.06519.pdf) as a reference.
      Here, if use_epsilon_scaling=True, after each iteration we decrease the
      running epsilon by a constant factor until it reaches the target epsilon
      value. We found this to work well for gradient back-propagation, while
      the original scaling trick doesn't.

  Returns:
    A tuple with the following values.
      - assignment: a 3D tensor of size [batch_size, n_rows, n_columns].
        The output assignment.
      - used_iter: a scalar tensor indicating the number of iterations used.
      - eps: a scalar tensor indicating the stopping epsilon value.
      - delta: a scalar tensor indicating the stopping delta value (the
        relative change on the margins of assignment p in the last iteration).
  """
  # Check that all shapes are correct.
  score_shape = score.shape
  bsz = score_shape[0]
  n = score_shape[1]
  m = score_shape[2]
  score = tf.ensure_shape(score, [bsz, n, m])
  elementwise_upper_bound = tf.ensure_shape(elementwise_upper_bound,
                                            [bsz, n, m])
  row_sums = tf.ensure_shape(tf.expand_dims(row_sums, axis=2), [bsz, n, 1])
  col_sums = tf.ensure_shape(tf.expand_dims(col_sums, axis=1), [bsz, 1, m])

  # The total sum of row sums must equal the total sum of column sums.
  sum_diff = tf.reduce_sum(row_sums, axis=1) - tf.reduce_sum(col_sums, axis=2)
  sum_diff = tf.abs(sum_diff)
  tf.Assert(tf.reduce_all(sum_diff < 1e-6), [sum_diff])

  # Convert the upper_bound constraint into another margin constraint
  # by adding auxiliary variables & scores. Tensors `a`, `b` and `c`
  # represent the margins (i.e. reduced sums) of the 3 axes respectively.
  max_row_sums = tf.reduce_sum(elementwise_upper_bound, axis=-1, keepdims=True)
  max_col_sums = tf.reduce_sum(elementwise_upper_bound, axis=-2, keepdims=True)
  score_ = tf.stack([score, tf.zeros_like(score)], axis=1)  # (bsz, 2, n, m)
  a = tf.stack([row_sums, max_row_sums - row_sums], axis=1)  # (bsz, 2, n, 1)
  b = tf.stack([col_sums, max_col_sums - col_sums], axis=1)  # (bsz, 2, 1, m)
  c = tf.expand_dims(elementwise_upper_bound, axis=1)  # (bsz, 1, n, m)

  # Clip log(0) to a large negative value (-1e+36) to avoid getting inf or
  # NaN values in the computation. We cannot use a larger magnitude because
  # float32 would turn it into `-inf` automatically.
  tf.Assert(tf.reduce_all(a >= 0), [a])
  tf.Assert(tf.reduce_all(b >= 0), [b])
  tf.Assert(tf.reduce_all(c >= 0), [c])
  log_a = tf.maximum(tf.math.log(a), -1e+36)
  log_b = tf.maximum(tf.math.log(b), -1e+36)
  log_c = tf.maximum(tf.math.log(c), -1e+36)

  # Initialize the dual variables of the margin constraints.
  u = tf.zeros_like(a)
  v = tf.zeros_like(b)
  w = tf.zeros_like(c)

  eps = tf.constant(1.0 if use_epsilon_scaling else epsilon, dtype=score.dtype)
  epsilon = tf.constant(epsilon, dtype=score.dtype)

  def do_updates(cur_iter, eps, u, v, w):  # pylint: disable=unused-argument
    # Epsilon scaling, i.e. gradually decrease `eps` until it
    # reaches the target `epsilon` value.
    cur_iter = tf.cast(cur_iter, u.dtype)
    scaling = tf.minimum(0.6 * 1.04**cur_iter, 0.85)
    eps = tf.maximum(epsilon, eps * scaling)
    score_div_eps = score_ / eps

    # Update u.
    log_q_1 = score_div_eps + (w + v) / eps
    log_q_1 = tf.reduce_logsumexp(log_q_1, axis=-1, keepdims=True)
    new_u = (log_a - tf.maximum(log_q_1, -1e+30)) * eps

    # Update v.
    log_q_2 = score_div_eps + (w + new_u) / eps
    log_q_2 = tf.reduce_logsumexp(log_q_2, axis=-2, keepdims=True)
    new_v = (log_b - tf.maximum(log_q_2, -1e+30)) * eps

    # Update w.
    log_q_3 = score_div_eps + (new_u + new_v) / eps
    log_q_3 = tf.reduce_logsumexp(log_q_3, axis=-3, keepdims=True)
    new_w = (log_c - tf.maximum(log_q_3, -1e+30)) * eps

    return eps, new_u, new_v, new_w

  def compute_relative_changes(eps, u, v, w, new_eps, new_u, new_v, new_w):
    prev_sum_uvw = tf.stop_gradient((u + v + w) / eps)
    sum_uvw = tf.stop_gradient((new_u + new_v + new_w) / new_eps)

    # Compute the relative changes on the margins of P.
    # This will be used for the stopping criteria.
    # Note the last update on w guarantees that the margin constraint c is
    # satisfied, so we don't need to check it here.
    p = tf.exp(tf.stop_gradient(score_ / new_eps + sum_uvw))
    p_a = tf.reduce_sum(p, axis=-1, keepdims=True)
    p_b = tf.reduce_sum(p, axis=-2, keepdims=True)
    delta_a = tf.abs(a - p_a) / (a + 1e-6)
    delta_b = tf.abs(b - p_b) / (b + 1e-6)
    new_delta = tf.reduce_max(delta_a)
    new_delta = tf.maximum(new_delta, tf.reduce_max(delta_b))

    # Compute the relative change of the assignment solution P itself.
    # This will also be used for the stopping criteria.
    delta_p = tf.abs(tf.exp(prev_sum_uvw) - tf.exp(sum_uvw)) / (
        tf.exp(sum_uvw) + 1e-6)
    new_delta = tf.maximum(new_delta, tf.reduce_max(delta_p))

    return new_delta

  for cur_iter in tf.range(num_iterations):
    prev_eps, prev_u, prev_v, prev_w = eps, u, v, w
    eps, u, v, w = do_updates(cur_iter, eps, u, v, w)
    delta = compute_relative_changes(prev_eps, prev_u, prev_v, prev_w, eps, u,
                                     v, w)
  cur_iter = num_iterations

  assignment = tf.exp((score_ + u + v + w) / eps)
  assignment = assignment[:, 0]
  return assignment, cur_iter, eps, delta
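
A minimal usage sketch (an added illustration, not part of the original module), assuming eager execution: a single 3x3 binary assignment with unit row and column margins. The shapes and values below are assumptions chosen for demonstration.

import numpy as np

from lingvo import compat as tf
from lingvo.core import differentiable_assignment

score = tf.constant(np.random.randn(1, 3, 3), dtype=tf.float32)
upper_bound = tf.ones([1, 3, 3])  # Binary problem: every entry is capped at 1.
row_sums = tf.ones([1, 3])        # Each row of the assignment must sum to 1.
col_sums = tf.ones([1, 3])        # Each column sums to 1; totals match (3 == 3).

assignment, used_iter, eps, delta = differentiable_assignment.max_assignment(
    score,
    elementwise_upper_bound=upper_bound,
    row_sums=row_sums,
    col_sums=col_sums,
    epsilon=0.01,  # Smaller epsilon drives entries closer to exactly 0 or 1.
    num_iterations=50)

# The result approximates a permutation matrix: rows and columns sum to ~1,
# and gradients flow back to `score` through the iterative updates.
print("assignment:", assignment)
print("stopping eps and delta:", eps, delta)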