Source code for lingvo.core.model_pruning.pruning

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Forked with minor changes from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/model_pruning/python/pruning.py  pylint: disable=line-too-long
"""Helper functions to add support for magnitude-based model pruning.

  # Adds variables and ops to the graph to enable
  # elementwise masking of weights
  apply_mask(weights)

  # Returns a list containing the sparsity of each of the weight tensors
  get_weight_sparsity()

  # Returns a list of all the masked weight tensorflow variables
  get_masked_weights()

  # Returns a list of all the mask tensorflow variables
  get_masks()

  # Returns a list of all the thresholds
  get_thresholds()

  # Returns a list of all the weight tensors that have been masked
  get_weights()

  The Pruning class uses a tf.HParams object to set up the
  parameters for model pruning. Here's a typical usage:

  # Parse pruning hyperparameters
  pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

  # Create a pruning object using the pruning_hparams
  p = pruning.Pruning(pruning_hparams)

  # Add mask update ops to the graph
  mask_update_op = p.conditional_mask_update_op()

  # Add the summaries
  p.add_pruning_summaries()

  # Run the op
  session.run(mask_update_op)

  # A Pruning object also accepts an externally defined sparsity variable:
  sparsity = tf.Variable(0.5, name="ConstantSparsity")
  p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
from lingvo import compat as tf
from lingvo.core.model_pruning import hparam
from lingvo.core.model_pruning import pruning_utils

from tensorflow.python.ops import variables  # pylint: disable=g-direct-tensorflow-import

_MASK_COLLECTION = pruning_utils.MASK_COLLECTION
_THRESHOLD_COLLECTION = pruning_utils.THRESHOLD_COLLECTION
_MASKED_WEIGHT_COLLECTION = pruning_utils.MASKED_WEIGHT_COLLECTION
_WEIGHT_COLLECTION = pruning_utils.WEIGHT_COLLECTION
_MASKED_WEIGHT_NAME = pruning_utils.MASKED_WEIGHT_NAME


def apply_mask(x, scope=''):
  """Apply mask to a given weight tensor.

  Args:
    x: Input weight tensor
    scope: The current variable scope. Defaults to "".

  Returns:
    Tensor representing masked_weights
  """
  mask = pruning_utils.weight_mask_variable(x, scope)
  threshold = pruning_utils.weight_threshold_variable(x, scope)
  # Add masked_weights in the weights name scope so as to make it easier
  # for the quantization library to add quant ops.
  masked_weights = tf.multiply(mask, x, _MASKED_WEIGHT_NAME)
  # Make sure the mask for a given variable is not added multiple times to
  # the collection. This is particularly important when applying masks to
  # RNN weight variables.
  if mask not in tf.get_collection_ref(_MASK_COLLECTION):
    tf.add_to_collection(_THRESHOLD_COLLECTION, threshold)
    tf.add_to_collection(_MASK_COLLECTION, mask)
    tf.add_to_collection(_MASKED_WEIGHT_COLLECTION, masked_weights)
    tf.add_to_collection(_WEIGHT_COLLECTION, x)
  return masked_weights

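# Illustrative usage sketch (not part of the forked module): wires
# apply_mask() between a weight variable and the op that consumes it, so the
# forward pass sees the masked weights. The variable name and shapes below
# are hypothetical.
def _example_masked_matmul(inputs):
  weights = tf.get_variable('fc/weights', shape=[128, 64])
  # Downstream ops must consume masked_weights, not weights, for pruning to
  # affect the model output.
  masked_weights = apply_mask(weights, scope='fc')
  return tf.matmul(inputs, masked_weights)
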
def get_masked_weights():
  return tf.get_collection(_MASKED_WEIGHT_COLLECTION)


def get_masks():
  return tf.get_collection(_MASK_COLLECTION)


def get_thresholds():
  return tf.get_collection(_THRESHOLD_COLLECTION)


def get_weights():
  return tf.get_collection(_WEIGHT_COLLECTION)

def get_weight_sparsity():
  """Get sparsity of the weights.

  Returns:
    A list containing the sparsity of each of the weight tensors
  """
  masks = get_masks()
  return [tf.nn.zero_fraction(mask) for mask in masks]

def get_pruning_hparams():
  """Get a tf.HParams object with the default values for the hyperparameters.

    name: string
      name of the pruning specification. Used for adding summaries and ops
      under a common tensorflow name_scope
    begin_pruning_step: integer
      the global step at which to begin pruning
    end_pruning_step: integer
      the global step at which to terminate pruning. Defaults to -1, implying
      that pruning continues until training stops
    weight_sparsity_map: list of strings
      comma separated list of {weight_variable_name:target_sparsity} or
      {regex:target_sparsity} pairs. For layers/weights not in this list,
      sparsity as specified by the target_sparsity hyperparameter is used.
      Eg. [conv1:0.9,conv2/kernel:0.8]
    block_dims_map: list of strings
      comma separated list of {weight_variable_name:block_height x
      block_width} or {regex:block_height x block_width} pairs. For
      layers/weights not in this list, block dims as specified by the
      block_height and block_width hyperparameters are used.
      Eg. [dense1:4x4,dense2:1x16,dense3:1x1]
    threshold_decay: float
      the decay factor to use for exponential decay of the thresholds
    pruning_frequency: integer
      how often the masks should be updated (in number of global_steps)
    nbins: integer
      number of bins to use for histogram computation
    block_height: integer
      number of rows in a block (defaults to 1); can be -1, in which case it
      is set to the size of the corresponding weight tensor
    block_width: integer
      number of cols in a block (defaults to 1); can be -1, in which case it
      is set to the size of the corresponding weight tensor
    block_pooling_function: string
      whether to perform average (AVG) or max (MAX) pooling in the block
      (default: AVG)
    initial_sparsity: float
      initial sparsity value
    target_sparsity: float
      target sparsity value
    sparsity_function_begin_step: integer
      the global step at which the gradual sparsity function begins to take
      effect
    sparsity_function_end_step: integer
      the global step used as the end point for the gradual sparsity function
    sparsity_function_exponent: float
      exponent = 1 gives linearly varying sparsity between initial and final;
      exponent > 1 varies sparsity more slowly towards the end than the
      beginning
    use_tpu: boolean
      indicates whether to use TPU

    We use the following sparsity function:

      num_steps = (sparsity_function_end_step -
                   sparsity_function_begin_step) / pruning_frequency
      sparsity(step) = (initial_sparsity - target_sparsity) *
                       [1 - step / (num_steps - 1)]**exponent + target_sparsity

  Returns:
    tf.HParams object initialized to default values
  """
  return hparam.HParams(
      name='model_pruning',
      begin_pruning_step=0,
      end_pruning_step=-1,
      weight_sparsity_map=[''],
      block_dims_map=[''],
      threshold_decay=0.0,
      pruning_frequency=10,
      nbins=256,
      block_height=1,
      block_width=1,
      block_pooling_function='AVG',
      initial_sparsity=0.0,
      target_sparsity=0.5,
      sparsity_function_begin_step=0,
      sparsity_function_end_step=100,
      sparsity_function_exponent=3.0,
      use_tpu=False)

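# Illustrative sketch (not part of the forked module): the gradual sparsity
# schedule documented above, written in plain Python for clarity. The real
# computation is built as TensorFlow ops in Pruning._setup_sparsity, which
# clamps the progress fraction to [0, 1] in the same way and ignores the
# num_steps discretization. E.g. with the defaults, step 50 gives
# (0.0 - 0.5) * (1 - 0.5)**3 + 0.5 = 0.4375.
def _example_sparsity_schedule(step,
                               initial_sparsity=0.0,
                               target_sparsity=0.5,
                               begin_step=0,
                               end_step=100,
                               exponent=3.0):
  progress = min(1.0,
                 max(0.0,
                     float(step - begin_step) / (end_step - begin_step)))
  return (initial_sparsity - target_sparsity) * (
      1 - progress)**exponent + target_sparsity
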
class Pruning(object):

  def __init__(self, spec=None, global_step=None, sparsity=None):
    """Set up the specification for model pruning.

    If a spec is provided, the sparsity is set up based on the
    sparsity_function in the spec. The effect of sparsity_function is
    overridden if the sparsity variable is passed to the constructor. This
    enables setting up arbitrary sparsity profiles externally and passing
    them to the pruning functions.

    Args:
      spec: Pruning spec as defined in pruning.proto
      global_step: A tensorflow variable that is used while setting up the
        sparsity function
      sparsity: A tensorflow scalar variable storing the sparsity
    """
    # Pruning specification
    self._spec = spec if spec else get_pruning_hparams()

    # Sanity check for pruning hparams
    self._validate_spec()

    # A tensorflow variable that tracks the sparsity function.
    # If not provided as input, the graph must already contain the
    # global_step variable before calling this constructor.
    self._global_step = self._setup_global_step(global_step)

    # Stores the tensorflow sparsity variable.
    # Built using self._setup_sparsity() or provided externally
    self._sparsity = (
        sparsity if sparsity is not None else self._setup_sparsity())

    # List of tensorflow assignment ops for new masks and thresholds
    self._assign_ops = []

    # Tensorflow variable keeping track of the last global step when the
    # masks were updated
    self._last_update_step = self._setup_last_update_step()

    # Block dimensions
    self._block_dims = [self._spec.block_height, self._spec.block_width]

    # Block pooling function
    self._block_pooling_function = self._spec.block_pooling_function

    # Mapping of layer/weight names and block dims
    self._block_dims_map = self._get_block_dims_map()

    # Mapping of weight names and target sparsity
    self._weight_sparsity_map = self._get_weight_sparsity_map()

  def _validate_spec(self):
    spec = self._spec
    if spec.begin_pruning_step < 0:
      raise ValueError('Illegal value for begin_pruning_step')

    if spec.begin_pruning_step >= spec.end_pruning_step:
      if spec.end_pruning_step != -1:
        raise ValueError(
            'Pruning must begin before it can end. begin_step=%d, '
            'end_step=%d. Set end_pruning_step to -1 if pruning is required '
            'until training stops.' %
            (spec.begin_pruning_step, spec.end_pruning_step))

    if spec.sparsity_function_begin_step < 0:
      raise ValueError('Illegal value for sparsity_function_begin_step')

    if spec.sparsity_function_begin_step >= spec.sparsity_function_end_step:
      raise ValueError('Sparsity function requires begin_step < end_step')

    if not 0.0 <= spec.threshold_decay < 1.0:
      raise ValueError('threshold_decay must be in range [0,1)')

    if not 0.0 <= spec.initial_sparsity < 1.0:
      raise ValueError('initial_sparsity must be in range [0,1)')

    if not 0.0 <= spec.target_sparsity < 1.0:
      raise ValueError('target_sparsity must be in range [0,1)')

  def _setup_global_step(self, global_step):
    graph_global_step = global_step
    if graph_global_step is None:
      graph_global_step = tf.train.get_global_step()

    return tf.cast(graph_global_step, tf.int32)

  def _setup_sparsity(self):
    begin_step = self._spec.sparsity_function_begin_step
    end_step = self._spec.sparsity_function_end_step
    initial_sparsity = self._spec.initial_sparsity
    target_sparsity = self._spec.target_sparsity
    exponent = self._spec.sparsity_function_exponent

    with tf.name_scope(self._spec.name):
      p = tf.minimum(
          1.0,
          tf.maximum(
              0.0,
              tf.div(
                  tf.cast(self._global_step - begin_step, tf.float32),
                  end_step - begin_step)))
      sparsity = tf.add(
          tf.multiply(initial_sparsity - target_sparsity,
                      tf.pow(1 - p, exponent)),
          target_sparsity,
          name='sparsity')

    return sparsity

  def _setup_last_update_step(self):
    with tf.variable_scope(
        self._spec.name, use_resource=self._spec.use_tpu) as scope:
      try:
        last_update_step = tf.get_variable(
            'last_mask_update_step', [],
            initializer=tf.zeros_initializer(),
            trainable=False,
            dtype=tf.int32)
      except ValueError:
        scope.reuse_variables()
        last_update_step = tf.get_variable(
            'last_mask_update_step', dtype=tf.int32)
    return last_update_step

  def _get_block_dims_map(self):
    """Returns the map of layer name: block dims."""
    block_dims_map = {}
    val_list = self._spec.block_dims_map
    filtered_val_list = [l for l in val_list if l]
    for val in filtered_val_list:
      weight_name, block_dims_str = val.split(':')
      block_dims_str = block_dims_str.split('x')
      if len(block_dims_str) != 2:
        raise ValueError('Expected 2 values for block dim for %s, got %s' %
                         (weight_name, block_dims_str))
      block_dims = [int(block_dims_str[0]), int(block_dims_str[1])]
      block_dims_map[re.compile(weight_name)] = block_dims

    return block_dims_map

  def _get_block_dims(self, weight_name):
    """Returns the block dims for the given layer/weight name."""
    block_dims_list = [
        block_dims for regexp, block_dims in self._block_dims_map.items()
        if regexp.search(weight_name)
    ]
    if not block_dims_list:
      return self._block_dims
    if len(block_dims_list) > 1:
      raise ValueError('Multiple matches in block_dims_map for weight %s' %
                       weight_name)
    return block_dims_list[0]

  def _get_weight_sparsity_map(self):
    """Returns the map of weight_name:sparsity parsed from the hparams."""
    weight_sparsity_map = {}
    val_list = self._spec.weight_sparsity_map
    filtered_val_list = [l for l in val_list if l]
    for val in filtered_val_list:
      weight_name, sparsity = val.split(':')
      if float(sparsity) >= 1.0:
        raise ValueError('Weight sparsity must be less than 1.0')
      weight_sparsity_map[re.compile(weight_name)] = float(sparsity)

    return weight_sparsity_map

  def _get_sparsity(self, weight_name):
    """Returns target sparsity for the given layer/weight name."""
    target_sparsity = [
        sparsity for regexp, sparsity in self._weight_sparsity_map.items()
        if regexp.search(weight_name)
    ]
    if not target_sparsity:
      return self._sparsity

    if len(target_sparsity) > 1:
      raise ValueError(
          'Multiple matches in weight_sparsity_map for weight %s' %
          weight_name)
    # TODO(suyoggupta): This will work when initial_sparsity = 0. Generalize
    # to handle other cases as well.
    return tf.multiply(self._sparsity,
                       tf.div(target_sparsity[0], self._spec.target_sparsity))

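  # Worked example (illustrative): with target_sparsity=0.5 and a
  # weight_sparsity_map entry 'conv1:0.9', conv1's sparsity at every step is
  # the global schedule value scaled by 0.9 / 0.5 = 1.8, so it reaches 0.9
  # when the schedule reaches 0.5. As the TODO above notes, this scaling is
  # exact only when initial_sparsity = 0.
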
  def _update_mask(self, weights, threshold):
    """Updates the mask for a given weight tensor.

    This function sorts the magnitudes of the weight tensor and estimates
    the threshold value such that the desired sparsity fraction of weights
    have magnitude less than the threshold.

    Args:
      weights: The weight tensor that needs to be masked.
      threshold: The current threshold value. The function will compute a new
        threshold and return the exponential moving average using the current
        value of threshold

    Returns:
      new_threshold: The new value of the threshold based on weights, and
        sparsity at the current global_step
      new_mask: A tensor of the same size and shape as weights containing
        0 or 1 to indicate which of the values in weights falls below the
        threshold

    Raises:
      ValueError: if sparsity is not defined
    """
    if self._sparsity is None:
      raise ValueError('Sparsity variable undefined')

    sparsity = self._get_sparsity(weights.op.name)
    with tf.name_scope(weights.op.name + '_pruning_ops'):
      abs_weights = tf.abs(weights)

      k = tf.cast(
          tf.round(
              tf.cast(tf.size(abs_weights), tf.float32) * (1 - sparsity)),
          tf.int32)
      # Sort the entire array
      values, _ = tf.nn.top_k(
          tf.reshape(abs_weights, [-1]), k=tf.size(abs_weights))
      # Grab the (k-1)th value
      current_threshold = tf.gather(values, k - 1)
      smoothed_threshold = tf.add_n([
          tf.multiply(current_threshold, 1 - self._spec.threshold_decay),
          tf.multiply(threshold, self._spec.threshold_decay)
      ])

      new_mask = tf.cast(
          tf.greater_equal(abs_weights, smoothed_threshold), tf.float32)
    return smoothed_threshold, new_mask

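  # Worked example (illustrative): for weights = [-0.9, 0.1, 0.4, -0.2] and
  # sparsity = 0.5, k = round(4 * (1 - 0.5)) = 2 weights survive. Sorting the
  # magnitudes in descending order gives [0.9, 0.4, 0.2, 0.1], so
  # current_threshold = 0.4 and, with threshold_decay = 0, the new mask is
  # [1, 0, 1, 0].
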
  def _maybe_update_block_mask(self, weights, threshold):
    """Performs block-granular masking of the weights.

    Block pruning occurs only if the block_height or block_width is > 1 and
    if the weight tensor, when squeezed, has ndims = 2. Otherwise,
    elementwise pruning occurs.

    Args:
      weights: The weight tensor that needs to be masked.
      threshold: The current threshold value. The function will compute a new
        threshold and return the exponential moving average using the current
        value of threshold

    Returns:
      new_threshold: The new value of the threshold based on weights, and
        sparsity at the current global_step
      new_mask: A tensor of the same size and shape as weights containing
        0 or 1 to indicate which of the values in weights falls below the
        threshold

    Raises:
      ValueError: if the block pooling function is not AVG or MAX
    """
    block_dims = self._get_block_dims(weights.op.name)
    squeezed_weights = tf.squeeze(weights)
    if squeezed_weights.get_shape().ndims != 2 or block_dims == [1, 1]:
      return self._update_mask(weights, threshold)

    for i in range(2):
      if block_dims[i] == -1:
        block_dims[i] = squeezed_weights.get_shape()[i]

    if self._block_pooling_function not in ['AVG', 'MAX']:
      raise ValueError('Unknown pooling function for block sparsity: %s' %
                       self._block_pooling_function)

    with tf.name_scope(weights.op.name + '_pruning_ops'):
      abs_weights = tf.abs(squeezed_weights)

      pool_window = block_dims
      pool_fn = pruning_utils.factorized_pool
      squeeze_axis = None
      if not self._spec.use_tpu:
        pool_fn = tf.nn.pool
        abs_weights = tf.reshape(
            abs_weights,
            [1, abs_weights.get_shape()[0], abs_weights.get_shape()[1], 1])
        squeeze_axis = [0, 3]

      pooled_weights = pool_fn(
          abs_weights,
          window_shape=pool_window,
          pooling_type=self._block_pooling_function,
          strides=pool_window,
          padding='SAME',
          name=weights.op.name + '_pooled')

      if pooled_weights.get_shape().ndims != 2:
        pooled_weights = tf.squeeze(pooled_weights, axis=squeeze_axis)

      smoothed_threshold, new_mask = self._update_mask(pooled_weights,
                                                       threshold)

      updated_mask = pruning_utils.expand_tensor(new_mask, block_dims)
      sliced_mask = tf.slice(
          updated_mask, [0, 0],
          [squeezed_weights.get_shape()[0],
           squeezed_weights.get_shape()[1]])

    return smoothed_threshold, tf.reshape(sliced_mask, tf.shape(weights))

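  # Worked example (illustrative): with block_dims = [1, 2], a squeezed 2x4
  # weight matrix is AVG-pooled into a 2x2 matrix of per-block mean
  # magnitudes. _update_mask() is applied to that pooled matrix, and the
  # resulting 2x2 mask is expanded back to 2x4 so that all weights in a
  # block share the same mask value.
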
  def _get_mask_assign_ops(self):
    # Make sure the assignment ops have not already been added to the list
    if self._assign_ops:
      raise ValueError(
          'Assign op list not empty. _get_mask_assign_ops() called twice?')

    masks = get_masks()
    weights = get_weights()
    thresholds = get_thresholds()

    if len(masks) != len(thresholds):
      raise ValueError(
          'Number of masks %s and number of thresholds %s mismatch' %
          (len(masks), len(thresholds)))

    for index, mask in enumerate(masks):
      threshold = thresholds[index]
      weight = weights[index]
      is_partitioned = isinstance(weight, variables.PartitionedVariable)
      if is_partitioned:
        weight = weight.as_tensor()

      new_threshold, new_mask = self._maybe_update_block_mask(
          weight, threshold)
      self._assign_ops.append(
          pruning_utils.variable_assign(threshold, new_threshold))

      self._assign_ops.append(
          pruning_utils.partitioned_variable_assign(mask, new_mask)
          if is_partitioned else pruning_utils.variable_assign(
              mask, new_mask))

  def mask_update_op(self):
    with tf.name_scope(self._spec.name):
      if not self._assign_ops:
        self._get_mask_assign_ops()

      with tf.control_dependencies([
          tf.assign(
              self._last_update_step,
              self._global_step,
              name='last_mask_update_step_assign')
      ]):
        with tf.control_dependencies(self._assign_ops):
          tf.logging.info('Updating masks.')
          return tf.no_op('mask_update')

  def conditional_mask_update_op(self):

    def maybe_update_masks():
      with tf.name_scope(self._spec.name):
        is_step_within_pruning_range = tf.logical_and(
            tf.greater_equal(self._global_step,
                             self._spec.begin_pruning_step),
            # If end_pruning_step is negative, keep pruning forever!
            tf.logical_or(
                tf.less_equal(self._global_step,
                              self._spec.end_pruning_step),
                tf.less(self._spec.end_pruning_step, 0)))
        is_pruning_step = tf.less_equal(
            tf.add(self._last_update_step, self._spec.pruning_frequency),
            self._global_step)
        return tf.logical_and(is_step_within_pruning_range, is_pruning_step)

    def mask_update_op():
      return self.mask_update_op()

    def no_update_op():
      return tf.no_op()

    return tf.cond(maybe_update_masks(), mask_update_op, no_update_op)

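  # Illustrative training-loop sketch (train_op, session, and num_steps are
  # hypothetical names): running the conditional update op every step lets
  # the ops built above decide whether the step falls inside
  # [begin_pruning_step, end_pruning_step] and on a pruning_frequency
  # boundary:
  #
  #   mask_update_op = p.conditional_mask_update_op()
  #   for _ in range(num_steps):
  #     session.run(train_op)
  #     session.run(mask_update_op)
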
  def add_pruning_summaries(self):
    """Adds summaries of weight sparsities and thresholds."""
    with tf.name_scope(self._spec.name + '_summaries'):
      tf.summary.scalar('sparsity', self._sparsity)
      tf.summary.scalar('last_mask_update_step', self._last_update_step)
      masks = get_masks()
      thresholds = get_thresholds()
      for mask, threshold in zip(masks, thresholds):
        tf.summary.scalar(mask.op.name + '/sparsity',
                          tf.nn.zero_fraction(mask))
        tf.summary.scalar(threshold.op.name + '/threshold', threshold)

  def print_hparams(self):
    tf.logging.info(self._spec.to_json())