lingvo.core.model_pruning.pruning module

Helper functions to add support for magnitude-based model pruning.

    # Adds variables and ops to the graph to enable
    # elementwise masking of weights
    apply_mask(weights)

    # Returns a list containing the sparsity of each of the weight tensors
    get_weight_sparsity()

    # Returns a list of all the masked weight tensorflow variables
    get_masked_weights()

    # Returns a list of all the mask tensorflow variables
    get_masks()

    # Returns a list of all the thresholds
    get_thresholds()

    # Returns a list of all the weight tensors that have been masked
    get_weights()

The Pruning class uses a tf.HParams object to set up the parameters for model pruning. Here is a typical usage:

    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

    # Create a pruning object using the pruning_hparams
    p = pruning.Pruning(pruning_hparams)

    # Add mask update ops to the graph
    mask_update_op = p.conditional_mask_update_op()

    # Add the summaries
    p.add_pruning_summaries()

    # Run the op
    session.run(mask_update_op)

    # A Pruning object also accepts an externally defined sparsity:
    sparsity = tf.Variable(0.5, name="ConstantSparsity")
    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)

lingvo.core.model_pruning.pruning.apply_mask(x, scope='')[source]

Apply mask to a given weight tensor.

Parameters
  • x – Input weight tensor

  • scope – The current variable scope. Defaults to “”.

Returns

Tensor representing masked_weights
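To illustrate what elementwise masking means, here is a minimal NumPy sketch that multiplies a weight array by a 0/1 mask of the same shape. The names apply_mask_sketch, weights, and mask are hypothetical. The sketch takes the mask as an argument, whereas the module description says apply_mask itself adds the required variables and ops to the graph.

```python
import numpy as np


def apply_mask_sketch(weights, mask):
    """Return weights with masked-out entries set to zero (illustrative only)."""
    weights = np.asarray(weights, dtype=float)
    mask = np.asarray(mask, dtype=float)
    if weights.shape != mask.shape:
        raise ValueError("mask must have the same shape as weights")
    # A mask entry of 1 keeps the weight, 0 removes it.
    return weights * mask


weights = np.array([[0.5, -0.1], [0.02, -2.0]])
mask = np.array([[1.0, 0.0], [0.0, 1.0]])
masked_weights = apply_mask_sketch(weights, mask)
```

Only the two weights whose mask entry is 1 survive; the others become zero.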

lingvo.core.model_pruning.pruning.get_masked_weights()[source]
lingvo.core.model_pruning.pruning.get_masks()[source]
lingvo.core.model_pruning.pruning.get_thresholds()[source]
lingvo.core.model_pruning.pruning.get_weights()[source]
lingvo.core.model_pruning.pruning.get_weight_sparsity()[source]

Get sparsity of the weights.

Parameters

None

Returns

A list containing the sparsity of each of the weight tensors
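As a rough illustration, sparsity can be read as the fraction of entries that are zero. The NumPy sketch below computes that fraction for a list of hypothetical mask arrays. The helper name sparsity_of and the choice to measure sparsity on masks are assumptions made for this example, not a statement about how the library computes it.

```python
import numpy as np


def sparsity_of(mask):
    """Fraction of entries in `mask` that are exactly zero."""
    mask = np.asarray(mask)
    return float(np.mean(mask == 0))


# Two hypothetical masks: one mostly pruned, one untouched.
masks = [np.array([[1, 0], [0, 0]]), np.ones((2, 3))]
sparsities = [sparsity_of(m) for m in masks]
```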

lingvo.core.model_pruning.pruning.get_pruning_hparams()[source]

Get a tf.HParams object with the default values for the hyperparameters.

name: string

name of the pruning specification. Used for adding summaries and ops under a common tensorflow name_scope

begin_pruning_step: integer

the global step at which to begin pruning

end_pruning_step: integer

the global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops

weight_sparsity_map: list of strings

comma-separated list of {weight_variable_name:target sparsity} or {regex:target sparsity} pairs. For layers/weights not in this list, the sparsity specified by the target_sparsity hyperparameter is used. E.g. [conv1:0.9,conv2/kernel:0.8]

block_dims_map: list of strings

comma-separated list of {weight variable name:block_height x block_width} or {regex:block_height x block_width} pairs. For layers/weights not in this list, the block dims specified by the block_height and block_width hyperparameters are used. E.g. [dense1:4x4,dense2:1x16,dense3:1x1]

threshold_decay: float

the decay factor to use for exponential decay of the thresholds

pruning_frequency: integer

how often the masks are updated, in number of global steps

nbins: integer

number of bins to use for histogram computation

block_height: integer

number of rows in a block (defaults to 1), can be -1 in which case it is set to the size of the corresponding weight tensor.

block_width: integer

number of cols in a block (defaults to 1), can be -1 in which case it is set to the size of the corresponding weight tensor.

block_pooling_function: string

Whether to perform average (AVG) or max (MAX) pooling in the block (default: AVG)

initial_sparsity: float

initial sparsity value

target_sparsity: float

target sparsity value

sparsity_function_begin_step: integer

the global step at which the gradual sparsity function begins to take effect

sparsity_function_end_step: integer

the global step used as the end point for the gradual sparsity function

sparsity_function_exponent: float

exponent = 1 is linearly varying sparsity between initial and final. exponent > 1 varies more slowly towards the end than the beginning

use_tpu: boolean (defaults to False)

Indicates whether to use TPU
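The weight_sparsity_map and block_dims_map entries described above are plain "name:value" strings. A minimal sketch of how such entries could be parsed and looked up follows. The helper names, the use of re.search for regex keys, and the first-match-wins rule are assumptions made for illustration; the library's own parsing may differ.

```python
import re


def parse_sparsity_map(entries):
    """Parse ['conv1:0.9', ...] into {'conv1': 0.9, ...}."""
    result = {}
    for entry in entries:
        name, _, value = entry.rpartition(":")
        if not name:
            raise ValueError(f"malformed entry: {entry!r}")
        result[name] = float(value)
    return result


def parse_block_dims_map(entries):
    """Parse ['dense1:4x4', ...] into {'dense1': (4, 4), ...}."""
    result = {}
    for entry in entries:
        name, _, dims = entry.rpartition(":")
        if not name:
            raise ValueError(f"malformed entry: {entry!r}")
        height, width = (int(d) for d in dims.split("x"))
        result[name] = (height, width)
    return result


def lookup(mapping, weight_name, default):
    """Return the value of the first key that matches weight_name as a regex."""
    for pattern, value in mapping.items():
        if re.search(pattern, weight_name):
            return value
    # Not listed: fall back to the global hyperparameter value.
    return default


sparsity_map = parse_sparsity_map(["conv1:0.9", "conv2/kernel:0.8"])
block_dims_map = parse_block_dims_map(["dense1:4x4", "dense2:1x16", "dense3:1x1"])
conv2_sparsity = lookup(sparsity_map, "model/conv2/kernel", default=0.5)
other_sparsity = lookup(sparsity_map, "model/dense/kernel", default=0.5)
```

Weights that match no entry fall back to the default, mirroring the role of target_sparsity, block_height, and block_width for unlisted layers.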

We use the following sparsity function:

    num_steps = (sparsity_function_end_step - sparsity_function_begin_step) / pruning_frequency

    sparsity(step) = (initial_sparsity - target_sparsity) * [1 - step / (num_steps - 1)]**exponent + target_sparsity
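The sparsity function above can be sketched in plain Python as follows. This transcribes the formula as written. It assumes step counts pruning updates from 0 to num_steps - 1, and it clamps the progress term to [0, 1], which the formula does not state. The function name gradual_sparsity is hypothetical.

```python
def gradual_sparsity(step, initial_sparsity, target_sparsity,
                     begin_step, end_step, pruning_frequency, exponent):
    """Sparsity after `step` pruning updates, following the documented formula."""
    num_steps = (end_step - begin_step) / pruning_frequency
    if num_steps <= 1:
        raise ValueError("the schedule needs more than one pruning update")
    # Assumption: clamp so sparsity stays between the initial and target values.
    progress = min(max(step / (num_steps - 1), 0.0), 1.0)
    return ((initial_sparsity - target_sparsity) * (1.0 - progress) ** exponent
            + target_sparsity)


# Example: 11 pruning updates ramping from 0.0 to 0.9 with a cubic schedule.
schedule = [
    gradual_sparsity(s, 0.0, 0.9, begin_step=0, end_step=1100,
                     pruning_frequency=100, exponent=3)
    for s in range(11)
]
```

With exponent = 3 the schedule rises quickly at first and flattens as it approaches target_sparsity, matching the description of sparsity_function_exponent.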

Parameters

None

Returns

tf.HParams object initialized to default values
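One way to read begin_pruning_step, end_pruning_step, and pruning_frequency together is as a condition deciding whether the masks are updated at a given global step. The sketch below is one interpretation under stated assumptions: the function name and the last_update_step argument are hypothetical, and the exact comparison operators used by the library are not confirmed by this page.

```python
def should_update_masks(global_step, last_update_step, begin_pruning_step,
                        end_pruning_step, pruning_frequency):
    """Return True if a mask update would be due at `global_step`."""
    # A negative end_pruning_step means pruning continues until training stops.
    within_range = global_step >= begin_pruning_step and (
        end_pruning_step < 0 or global_step <= end_pruning_step)
    # Update only once at least pruning_frequency steps have elapsed.
    is_due = global_step >= last_update_step + pruning_frequency
    return within_range and is_due


due_now = should_update_masks(100, 0, 0, -1, 100)
too_early = should_update_masks(50, 0, 0, -1, 100)
after_end = should_update_masks(2000, 1000, 0, 1500, 100)
```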

class lingvo.core.model_pruning.pruning.Pruning(spec=None, global_step=None, sparsity=None)[source]

Bases: object

_validate_spec()[source]
_setup_global_step(global_step)[source]
_setup_sparsity()[source]
_setup_last_update_step()[source]
_get_block_dims_map()[source]

Returns the map of layer name: block dims.

_get_block_dims(weight_name)[source]

Returns the block dims for the given layer/weight name.

_get_weight_sparsity_map()[source]

Returns the map of weight_name:sparsity parsed from the hparams.

_get_sparsity(weight_name)[source]

Returns target sparsity for the given layer/weight name.

_update_mask(weights, threshold)[source]

Updates the mask for a given weight tensor.

This function first computes the cdf of the weight tensor, and estimates the threshold value such that 'desired_sparsity' fraction of weights have magnitude less than the threshold.

Parameters
  • weights – The weight tensor that needs to be masked.

  • threshold – The current threshold value. The function will compute a new threshold and return the exponential moving average using the current value of threshold

Returns

new_threshold: The new value of the threshold based on weights, and sparsity at the current global_step

new_mask: A numpy array of the same size and shape as weights containing 0 or 1 to indicate which of the values in weights falls below the threshold

Raises

ValueError – if sparsity is not defined
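The threshold estimation described above can be sketched with NumPy. This is an illustration under assumptions, not the library's code: the function name, the way the decay weights the old threshold against the new estimate, and the convention that a mask value of 1 means "keep" are all choices made for this example.

```python
import numpy as np


def update_mask_sketch(weights, current_threshold, sparsity,
                       threshold_decay=0.0, nbins=256):
    """Estimate a magnitude threshold from a histogram CDF and build a 0/1 mask."""
    if sparsity is None:
        raise ValueError("sparsity is not defined")
    abs_weights = np.abs(np.asarray(weights, dtype=float))
    max_value = abs_weights.max()
    # Normalized cumulative histogram of the weight magnitudes.
    hist, _ = np.histogram(abs_weights, bins=nbins, range=(0.0, max_value))
    cdf = np.cumsum(hist) / abs_weights.size
    # Bins whose cumulative mass is still below the desired sparsity,
    # rescaled from a bin count to a magnitude.
    estimate = np.sum(cdf < sparsity) / nbins * max_value
    # Assumed convention: decay weights the old threshold, (1 - decay) the new one.
    new_threshold = (threshold_decay * current_threshold
                     + (1.0 - threshold_decay) * estimate)
    # Assumed convention: 1 keeps a weight, 0 prunes it.
    new_mask = (abs_weights >= new_threshold).astype(float)
    return float(new_threshold), new_mask


weights = np.array([0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8, 0.9, -1.0])
new_threshold, new_mask = update_mask_sketch(
    weights, current_threshold=0.0, sparsity=0.5)
```

Because the threshold comes from a binned CDF, the achieved sparsity only approximates the requested value; a larger nbins narrows the gap.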

_maybe_update_block_mask(weights, threshold)[source]

Performs block-granular masking of the weights.

Block pruning occurs only if the block_height or block_width is > 1 and if the weight tensor, when squeezed, has ndims = 2. Otherwise, elementwise pruning occurs.

Parameters
  • weights – The weight tensor that needs to be masked.

  • threshold – The current threshold value. The function will compute a new threshold and return the exponential moving average using the current value of threshold

Returns

new_threshold: The new value of the threshold based on weights, and sparsity at the current global_step

new_mask: A numpy array of the same size and shape as weights containing 0 or 1 to indicate which of the values in weights falls below the threshold

Raises

ValueError – if block pooling function is not AVG or MAX
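To make block-granular masking concrete, the NumPy sketch below pools weight magnitudes over blocks, thresholds the pooled values, and expands the block decisions back to the full shape. It is a simplified illustration. The function name is hypothetical, the threshold is passed in rather than re-estimated, block dims of -1 are not handled, and the 2-D shape is assumed to be divisible by the block dims.

```python
import numpy as np


def block_mask_sketch(weights, threshold, block_height=1, block_width=1,
                      pooling="AVG"):
    """Return a 0/1 mask in which whole blocks are kept or pruned together."""
    if pooling not in ("AVG", "MAX"):
        raise ValueError("block pooling function must be AVG or MAX")
    weights = np.asarray(weights, dtype=float)
    abs_weights = np.abs(np.squeeze(weights))
    # Fall back to elementwise masking when block pruning does not apply.
    if abs_weights.ndim != 2 or (block_height <= 1 and block_width <= 1):
        return (np.abs(weights) >= threshold).astype(float)
    rows, cols = abs_weights.shape
    if rows % block_height or cols % block_width:
        raise ValueError("this sketch assumes the shape is divisible by the block dims")
    # Split into a (row_blocks, block_height, col_blocks, block_width) view.
    blocks = abs_weights.reshape(rows // block_height, block_height,
                                 cols // block_width, block_width)
    pooled = blocks.mean(axis=(1, 3)) if pooling == "AVG" else blocks.max(axis=(1, 3))
    block_mask = (pooled >= threshold).astype(float)
    # Expand each block decision back to block_height x block_width entries.
    full_mask = np.kron(block_mask, np.ones((block_height, block_width)))
    return full_mask.reshape(weights.shape)


weights = np.array([[0.90, 0.80, 0.01, 0.02],
                    [0.70, 0.60, 0.03, 0.04],
                    [0.05, 0.06, 0.50, 0.40],
                    [0.07, 0.08, 0.30, 0.20]])
mask = block_mask_sketch(weights, threshold=0.1, block_height=2, block_width=2)
```

Here the top-left and bottom-right 2x2 blocks have average magnitude above 0.1 and are kept whole, while the other two blocks are pruned whole.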

_get_mask_assign_ops()[source]
mask_update_op()[source]
conditional_mask_update_op()[source]
add_pruning_summaries()[source]

Adds summaries of weight sparsities and thresholds.

print_hparams()[source]