lingvo.core.model_pruning.pruning module
Helper functions to add support for magnitude-based model pruning.
    # Adds variables and ops to the graph to enable
    # elementwise masking of weights
    apply_mask(weights)

    # Returns a list containing the sparsity of each of the weight tensors
    get_weight_sparsity()

    # Returns a list of all the masked weight tensorflow variables
    get_masked_weights()

    # Returns a list of all the mask tensorflow variables
    get_masks()

    # Returns a list of all the thresholds
    get_thresholds()

    # Returns a list of all the weight tensors that have been masked
    get_weights()
The Pruning class uses a tf.HParams object to set up the parameters for model pruning. Here's a typical usage:
    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

    # Create a pruning object using the pruning_hparams
    p = pruning.Pruning(pruning_hparams)

    # Add mask update ops to the graph
    mask_update_op = p.conditional_mask_update_op()

    # Add the summaries
    p.add_pruning_summaries()

    # Run the op
    session.run(mask_update_op)

    # An object of the Pruning class also accepts externally defined sparsity:
    sparsity = tf.Variable(0.5, name="ConstantSparsity")
    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
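In a training loop, the mask update op is typically run alongside the train op once pruning begins. The following is a minimal sketch under that assumption; train_op and num_train_steps are hypothetical stand-ins built elsewhere, not part of this module:

    import tensorflow as tf
    from lingvo.core.model_pruning import pruning

    pruning_hparams = pruning.get_pruning_hparams().parse(
        "name=my_pruning,begin_pruning_step=1000,target_sparsity=0.9")
    p = pruning.Pruning(pruning_hparams,
                        global_step=tf.train.get_or_create_global_step())
    mask_update_op = p.conditional_mask_update_op()
    p.add_pruning_summaries()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(num_train_steps):  # num_train_steps: hypothetical
            sess.run(train_op)            # train_op: hypothetical, built elsewhere
            # The conditional update is a no-op outside
            # [begin_pruning_step, end_pruning_step] and between
            # pruning_frequency intervals.
            sess.run(mask_update_op)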
lingvo.core.model_pruning.pruning.apply_mask(x, scope='')
Apply mask to a given weight tensor.
- Parameters
x – Input weight tensor
scope – The current variable scope. Defaults to ''.
- Returns
Tensor representing masked_weights
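For illustration, a minimal sketch of wrapping a weight variable with apply_mask, assuming a TF1-style variable scope (the names conv1 and kernel are hypothetical):

    import tensorflow as tf
    from lingvo.core.model_pruning import pruning

    with tf.variable_scope("conv1"):
        kernel = tf.get_variable("kernel", shape=[3, 3, 64, 128])
        # Creates the companion mask and threshold variables and returns
        # the elementwise-masked weights.
        masked_kernel = pruning.apply_mask(kernel)
    # Use masked_kernel (not kernel) downstream so pruning takes effect.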
lingvo.core.model_pruning.pruning.get_weight_sparsity()
Get sparsity of the weights.
- Parameters
None
- Returns
A list containing the sparsity of each of the weight tensors
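A short usage sketch (the numeric values shown are illustrative, not real output):

    # One entry per masked weight tensor, in creation order.
    sparsities = session.run(pruning.get_weight_sparsity())
    # e.g. [0.0, 0.45, 0.45] once masks have been updated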
lingvo.core.model_pruning.pruning.get_pruning_hparams()
Get a tf.HParams object with the default values for the hyperparameters.
- name: string
name of the pruning specification. Used for adding summaries and ops under a common tensorflow name_scope
- begin_pruning_step: integer
the global step at which to begin pruning
- end_pruning_step: integer
the global step at which to terminate pruning. Defaults to -1, implying that pruning continues until training stops
- weight_sparsity_map: list of strings
comma separated list of {weight_variable_name:target_sparsity} or {regex:target_sparsity} pairs. For layers/weights not in this list, the sparsity specified by the target_sparsity hyperparameter is used. E.g. [conv1:0.9,conv2/kernel:0.8]
- block_dims_map: list of strings
comma separated list of {weight_variable_name:block_height x block_width} or {regex:block_height x block_width} pairs. For layers/weights not in this list, the block dims specified by the block_height and block_width hyperparameters are used. E.g. [dense1:4x4,dense2:1x16,dense3:1x1]
- threshold_decay: float
the decay factor to use for exponential decay of the thresholds
- pruning_frequency: integer
How often should the masks be updated? (in # of global_steps)
- nbins: integer
number of bins to use for histogram computation
- block_height: integer
number of rows in a block (defaults to 1), can be -1 in which case it is set to the size of the corresponding weight tensor.
- block_width: integer
number of cols in a block (defaults to 1), can be -1 in which case it is set to the size of the corresponding weight tensor.
- block_pooling_function: string
Whether to perform average (AVG) or max (MAX) pooling in the block (default: AVG)
- initial_sparsity: float
initial sparsity value
- target_sparsity: float
target sparsity value
- sparsity_function_begin_step: integer
the global step at which the gradual sparsity function begins to take effect
- sparsity_function_end_step: integer
the global step used as the end point for the gradual sparsity function
- sparsity_function_exponent: float
exponent = 1 is linearly varying sparsity between initial and final. exponent > 1 varies more slowly towards the end than the beginning
- use_tpu: boolean
indicates whether to use TPU. Defaults to False
We use the following sparsity function:

    num_steps = (sparsity_function_end_step - sparsity_function_begin_step) / pruning_frequency
    sparsity(step) = (initial_sparsity - target_sparsity) * [1 - step/(num_steps - 1)]**exponent + target_sparsity
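A literal Python transcription of this schedule may make it concrete. This is a sketch, assuming step counts mask updates since sparsity_function_begin_step; all argument defaults below are illustrative:

    def gradual_sparsity(step, initial_sparsity=0.0, target_sparsity=0.9,
                         begin_step=0, end_step=10000, pruning_frequency=10,
                         exponent=3.0):
        """Sparsity after `step` mask updates, per the formula above."""
        num_steps = (end_step - begin_step) // pruning_frequency
        frac = 1.0 - float(step) / (num_steps - 1)
        return (initial_sparsity - target_sparsity) * frac**exponent \
            + target_sparsity

    # With the defaults above, num_steps = 1000:
    # gradual_sparsity(0) == 0.0 and gradual_sparsity(999) == 0.9.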
- Parameters
None
- Returns
tf.HParams object initialized to default values
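For illustration, a hypothetical hyperparameter string in the comma-separated name=value format accepted by tf.HParams.parse (all values here are made up):

    pruning_hparams = pruning.get_pruning_hparams().parse(
        "name=example_pruning,begin_pruning_step=1000,"
        "end_pruning_step=20000,pruning_frequency=100,"
        "target_sparsity=0.9,weight_sparsity_map=[conv1:0.8,conv2/kernel:0.75]")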
class lingvo.core.model_pruning.pruning.Pruning(spec=None, global_step=None, sparsity=None)
Bases: object
_get_weight_sparsity_map()
Returns the map of weight_name:sparsity parsed from the hparams.
_update_mask(weights, threshold)
Updates the mask for a given weight tensor.
This function first computes the cdf of the weight tensor, and estimates the threshold value such that 'desired_sparsity' fraction of weights have magnitude less than the threshold.
- Parameters
weights – The weight tensor that needs to be masked.
threshold – The current threshold value. The function computes a new threshold and returns the exponential moving average of the new and current values (see threshold_decay).
- Returns
new_threshold – The new value of the threshold based on weights and sparsity at the current global_step.
new_mask – A numpy array of the same size and shape as weights, containing 0 or 1 to indicate which of the values in weights falls below the threshold.
- Raises
ValueError – if sparsity is not defined
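The CDF-based estimate can be sketched in plain numpy. This is an illustrative reimplementation under stated assumptions (a histogram with nbins bins over the weight magnitudes), not the module's actual code:

    import numpy as np

    def estimate_threshold(weights, desired_sparsity, nbins=256):
        """Find t such that ~desired_sparsity of |weights| fall below t."""
        abs_w = np.abs(weights).ravel()
        counts, edges = np.histogram(abs_w, bins=nbins)
        cdf = np.cumsum(counts) / float(abs_w.size)
        # Index of the first bin whose cumulative mass reaches the target.
        idx = int(np.searchsorted(cdf, desired_sparsity))
        return edges[min(idx + 1, nbins)]

    def elementwise_mask(weights, threshold):
        """0/1 mask keeping weights whose magnitude exceeds the threshold."""
        return (np.abs(weights) > threshold).astype(weights.dtype)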
_maybe_update_block_mask(weights, threshold)
Performs block-granular masking of the weights.
Block pruning occurs only if the block_height or block_width is > 1 and if the weight tensor, when squeezed, has ndims = 2. Otherwise, elementwise pruning occurs.
- Parameters
weights – The weight tensor that needs to be masked.
threshold – The current threshold value. The function computes a new threshold and returns the exponential moving average of the new and current values (see threshold_decay).
- Returns
new_threshold – The new value of the threshold based on weights and sparsity at the current global_step.
new_mask – A numpy array of the same size and shape as weights, containing 0 or 1 to indicate which of the values in weights falls below the threshold.
- Raises
ValueError – if block pooling function is not AVG or MAX
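A numpy sketch of the block-granular idea, assuming a 2-D weight whose dimensions are divisible by the block dims (AVG and MAX pooling shown; the helper name is hypothetical):

    import numpy as np

    def block_mask(weights, threshold, block_height=4, block_width=4,
                   pooling="AVG"):
        """Pool |weights| into blocks, threshold pooled values, upsample."""
        h, w = weights.shape
        blocks = np.abs(weights).reshape(
            h // block_height, block_height, w // block_width, block_width)
        if pooling == "AVG":
            pooled = blocks.mean(axis=(1, 3))
        elif pooling == "MAX":
            pooled = blocks.max(axis=(1, 3))
        else:
            raise ValueError("Block pooling function must be AVG or MAX")
        small_mask = (pooled > threshold).astype(weights.dtype)
        # Every weight in a block shares the block's keep/prune decision.
        return np.kron(small_mask,
                       np.ones((block_height, block_width), weights.dtype))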