lingvo.core.model_pruning.pruning_utils module
Utility functions for adding pruning related ops to the graph.
-
lingvo.core.model_pruning.pruning_utils.weight_mask_variable(var, scope)
Create a mask for the weights.
This function adds a variable ‘mask’ to the graph.
- Parameters
var – the weight variable that needs to be masked
scope – The variable scope of the variable var
- Returns
The mask variable, with the same shape as var, initialized to all 1s.
-
lingvo.core.model_pruning.pruning_utils.weight_threshold_variable(var, scope)
Create a scalar threshold for the weights.
This function adds a variable ‘threshold’ to the graph.
- Parameters
var – The weight variable that needs to be masked
scope – The variable scope of the variable var
- Returns
A scalar threshold variable initialized to 0.
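The mask and threshold variables are usually used together in magnitude-based pruning: weights whose absolute value falls below the threshold get a mask entry of 0. The sketch below illustrates that idea under stated assumptions. It is not the library's code: the helpers `update_mask` and `apply_mask` are hypothetical, plain Python lists stand in for TensorFlow variables, and the `>=` comparison is chosen so that the initial threshold of 0 reproduces the initial all-ones mask.

```python
def update_mask(weights, threshold):
    """Return a 0/1 mask keeping weights whose magnitude is >= threshold.

    Hypothetical helper: an assumed magnitude-pruning rule, not lingvo's API.
    """
    return [[1.0 if abs(w) >= threshold else 0.0 for w in row] for row in weights]


def apply_mask(weights, mask):
    """Element-wise product of the weights and the mask."""
    return [
        [w * m for w, m in zip(w_row, m_row)]
        for w_row, m_row in zip(weights, mask)
    ]


weights = [[0.5, 0.05], [0.01, -2.0]]

# A threshold of 0 keeps every weight, matching the all-ones initial mask.
initial_mask = update_mask(weights, 0.0)   # [[1.0, 1.0], [1.0, 1.0]]

# Raising the threshold zeroes out the small-magnitude weights.
pruned_mask = update_mask(weights, 0.1)    # [[1.0, 0.0], [0.0, 1.0]]
pruned = apply_mask(weights, pruned_mask)  # [[0.5, 0.0], [0.0, -2.0]]
```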
-
lingvo.core.model_pruning.pruning_utils.kronecker_product(mat1, mat2)
Computes the Kronecker product of two matrices mat1 and mat2.
- Parameters
mat1 – A matrix of size m x n
mat2 – A matrix of size p x q
- Returns
The Kronecker product of mat1 and mat2, a matrix of size mp x nq
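To make the output layout concrete, here is a minimal plain-Python sketch of the Kronecker product on lists of rows. The name `kron_sketch` is hypothetical and the code only illustrates the mathematical definition; it is not how the library computes the result on tensors.

```python
def kron_sketch(mat1, mat2):
    """Kronecker product of an m x n and a p x q matrix (lists of rows).

    Each entry mat1[i][j] is replaced by the block mat1[i][j] * mat2,
    giving a matrix with m*p rows and n*q columns.
    """
    return [
        [a * b for a in row1 for b in row2]
        for row1 in mat1
        for row2 in mat2
    ]


mat1 = [[1, 2], [3, 4]]
mat2 = [[0, 1], [1, 0]]
product = kron_sketch(mat1, mat2)
# product == [[0, 1, 0, 2],
#             [1, 0, 2, 0],
#             [0, 3, 0, 4],
#             [3, 0, 4, 0]]
```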
-
lingvo.core.model_pruning.pruning_utils.expand_tensor(tensor, block_dims)
Expands a 2D tensor by replicating the tensor values.
This is equivalent to the Kronecker product of the tensor and a matrix of ones of size block_dims.
Example:
tensor = [[1, 2],
          [3, 4]]
block_dims = [2, 2]
result = [[1, 1, 2, 2],
          [1, 1, 2, 2],
          [3, 3, 4, 4],
          [3, 3, 4, 4]]
- Parameters
tensor – A 2D tensor that needs to be expanded.
block_dims – A list of two integers specifying the expansion factor along each dimension.
- Returns
The expanded tensor
- Raises
ValueError – if tensor is not rank-2 or block_dims does not have 2 elements.
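The replication can also be pictured with a non-square block. The following plain-Python sketch, assuming lists of rows in place of tensors, repeats every value `block_dims[1]` times along a row and every row `block_dims[0]` times. The name `expand_sketch` is hypothetical, and the argument check only loosely mirrors the documented ValueError.

```python
def expand_sketch(tensor, block_dims):
    """Replicate each entry of a 2D list into a block of size block_dims."""
    if len(block_dims) != 2:
        raise ValueError("block_dims must have exactly 2 elements")
    block_rows, block_cols = block_dims
    expanded = []
    for row in tensor:
        # Repeat each value block_cols times within the row.
        wide_row = [value for value in row for _ in range(block_cols)]
        # Repeat the widened row block_rows times (as independent copies).
        expanded.extend(list(wide_row) for _ in range(block_rows))
    return expanded


square = expand_sketch([[1, 2], [3, 4]], [2, 2])
# square == [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]

wide = expand_sketch([[1, 2], [3, 4]], [1, 3])
# wide == [[1, 1, 1, 2, 2, 2], [3, 3, 3, 4, 4, 4]]
```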
-
lingvo.core.model_pruning.pruning_utils.factorized_pool(input_tensor, window_shape, pooling_type, strides, padding, name=None)
Performs m x n pooling through a combination of 1xm and 1xn pooling.
- Parameters
input_tensor – Input tensor. Must be rank 2
window_shape – Pooling window shape
pooling_type – Either ‘MAX’ or ‘AVG’
strides – The stride of the pooling window
padding – ‘SAME’ or ‘VALID’.
name – Name of the op
- Returns
A rank 2 tensor containing the pooled output
- Raises
ValueError – if the input tensor is not rank 2
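The idea of factorizing a 2D pooling window into two 1D passes can be sketched in plain Python. The example below is an illustration under assumptions, not the library's implementation: it covers only max pooling with 'VALID' padding, uses lists of rows instead of tensors, and the helper names are hypothetical. A direct 2D pooling function is included only to check that the two approaches agree.

```python
def pool_1d(row, window, stride):
    """'VALID' max pooling over a single row."""
    return [max(row[i:i + window]) for i in range(0, len(row) - window + 1, stride)]


def transpose(matrix):
    return [list(column) for column in zip(*matrix)]


def factorized_max_pool(matrix, window_shape, strides):
    """Max pooling done as one pass along the rows and one along the columns."""
    height, width = window_shape
    # Pass 1: pool every row with a window of `width` columns.
    pooled = [pool_1d(row, width, strides[1]) for row in matrix]
    # Pass 2: transpose, pool with a window of `height`, transpose back.
    pooled = [pool_1d(row, height, strides[0]) for row in transpose(pooled)]
    return transpose(pooled)


def direct_max_pool(matrix, window_shape, strides):
    """Reference 2D max pooling, used only for comparison."""
    height, width = window_shape
    return [
        [
            max(matrix[r + dr][c + dc] for dr in range(height) for dc in range(width))
            for c in range(0, len(matrix[0]) - width + 1, strides[1])
        ]
        for r in range(0, len(matrix) - height + 1, strides[0])
    ]


matrix = [
    [1, 5, 2, 0],
    [3, 4, 8, 1],
    [7, 2, 6, 9],
    [0, 3, 4, 5],
]
factorized = factorized_max_pool(matrix, (2, 2), (2, 2))  # [[5, 8], [7, 9]]
direct = direct_max_pool(matrix, (2, 2), (2, 2))          # [[5, 8], [7, 9]]
```

The two passes give the same answer because the maximum over a rectangular window equals the maximum of the per-row maxima within that window.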
-
lingvo.core.model_pruning.pruning_utils.partitioned_variable_assign(partitioned_var, new_value)
Assign op for partitioned variables.
- Parameters
partitioned_var – A partitioned tensorflow variable
new_value – Value to be assigned to partitioned_var
- Returns
A tensorflow op that groups the assign ops for each of the variable slices
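Conceptually, assigning to a partitioned variable means cutting the new value into pieces that match each slice and assigning them one by one. The sketch below shows that idea with plain Python lists. It assumes the variable is partitioned along its first axis, and the name `assign_partitioned_sketch` is hypothetical; the real function returns a grouped TensorFlow op rather than mutating lists.

```python
def assign_partitioned_sketch(partitions, new_value):
    """Overwrite each slice in place with its share of new_value's rows.

    Assumes the slices partition the variable along axis 0.
    """
    sizes = [len(partition) for partition in partitions]
    if sum(sizes) != len(new_value):
        raise ValueError("new_value does not match the combined partition size")
    start = 0
    for partition, size in zip(partitions, sizes):
        # Each slice receives the rows that fall inside its range.
        partition[:] = new_value[start:start + size]
        start += size


# A 3 x 2 variable stored as two slices of 2 rows and 1 row.
partitions = [[[0, 0], [0, 0]], [[0, 0]]]
assign_partitioned_sketch(partitions, [[1, 2], [3, 4], [5, 6]])
# partitions == [[[1, 2], [3, 4]], [[5, 6]]]
```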