lingvo.core.entmax module

Defines entmax sampling and the entmax loss.

This is the entmax demonstrated in this publication: https://arxiv.org/pdf/2004.02644.pdf. The implementation is based on https://github.com/deep-spin/entmax, which is written in PyTorch. We hope to use it in Meena2 to unify training and inference under the entmax framework, which can produce sparse probabilities (unlike softmax, entmax can assign exactly zero probability to low-scoring entries).
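For reference, the alpha-entmax mapping from the entmax literature can be summarized as follows. This is a summary of the published definition, not a transcription of this module's code; alpha = 1.5 is the default used by every function below.

```latex
\begin{aligned}
\operatorname{entmax}_\alpha(\mathbf{z}) &= \operatorname*{arg\,max}_{\mathbf{p} \in \Delta^{d}} \; \mathbf{p}^\top \mathbf{z} + H^{\mathrm{T}}_\alpha(\mathbf{p}),
\qquad H^{\mathrm{T}}_\alpha(\mathbf{p}) = \frac{1}{\alpha(\alpha - 1)} \sum_{j} \bigl(p_j - p_j^{\alpha}\bigr), \quad \alpha > 1, \\
\operatorname{entmax}_\alpha(\mathbf{z})_i &= \bigl[(\alpha - 1) z_i - \tau\bigr]_+^{1/(\alpha - 1)},
\quad \text{where } \tau \text{ is the unique value with } \textstyle\sum_i \operatorname{entmax}_\alpha(\mathbf{z})_i = 1.
\end{aligned}
```

As alpha approaches 1 this recovers softmax, and alpha = 2 gives sparsemax. For alpha > 1, any entry with (alpha - 1) z_i <= tau receives exactly zero probability, which is where the sparsity comes from.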

lingvo.core.entmax._calculate_probability(inputs: Tensor, alpha: float = 1.5)

Calculates the unnormalized, element-wise entmax probability term of inputs for the given alpha.
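As a rough illustration, the pure-Python sketch below computes the element-wise term max(x, 0) ** (1 / (alpha - 1)) for a single scalar. It assumes this is the quantity the helper computes, following the deep-spin/entmax reference; entmax_term is a hypothetical name and is not part of lingvo.

```python
def entmax_term(x: float, alpha: float = 1.5) -> float:
    """Hypothetical scalar helper, not the lingvo API.

    Assumed to mirror the element-wise entmax term
    max(x, 0) ** (1 / (alpha - 1)); requires alpha > 1.
    """
    if alpha <= 1.0:
        raise ValueError("alpha must be > 1")
    # Clipping at zero is what lets entmax assign exactly zero probability.
    return max(x, 0.0) ** (1.0 / (alpha - 1.0))


# With alpha = 1.5 the exponent is 2, so the term is a squared ReLU.
print([entmax_term(v) for v in (-1.0, 0.0, 0.5, 2.0)])  # [0.0, 0.0, 0.25, 4.0]
```

With alpha = 2 the exponent is 1 and the term reduces to a plain ReLU, which corresponds to the sparsemax case.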

lingvo.core.entmax.entmax_support(inputs: Tensor, alpha: float = 1.5, axis: int = -1, n_iter: int = 50, ensure_sum_one: bool = True) → Tensor

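Calculates the alpha-entmax probabilities of inputs along axis. The normalizing threshold is found by bisection with n_iter iterations; if ensure_sum_one is True, the result is renormalized so that it sums to one along axis.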
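The sketch below illustrates the bisection idea on a single Python list, in the spirit of the deep-spin/entmax reference: scale the scores by alpha - 1, then bisect on the threshold tau until the clipped powers sum to one. It is a minimal sketch under the assumption that lingvo's TensorFlow code follows the same scheme; entmax_bisect_1d is a hypothetical name, and it handles only one 1-D vector rather than a tensor with an axis argument.

```python
from typing import List, Sequence


def entmax_bisect_1d(
    logits: Sequence[float],
    alpha: float = 1.5,
    n_iter: int = 50,
    ensure_sum_one: bool = True,
) -> List[float]:
    """Hypothetical pure-Python alpha-entmax over one 1-D list (not the lingvo API).

    p_i = max((alpha - 1) * z_i - tau, 0) ** (1 / (alpha - 1)), where tau is
    found by bisection so that sum(p) == 1. Requires alpha > 1.
    """
    if alpha <= 1.0:
        raise ValueError("alpha must be > 1 (alpha -> 1 is the softmax limit)")
    if not logits:
        raise ValueError("logits must be non-empty")
    exponent = 1.0 / (alpha - 1.0)
    x = [(alpha - 1.0) * z for z in logits]
    max_x = max(x)

    def probs(tau: float) -> List[float]:
        return [max(v - tau, 0.0) ** exponent for v in x]

    # At tau_lo the largest entry alone contributes 1, so sum(p) >= 1.
    # At tau_hi every entry is at most 1 / len(x), so sum(p) <= 1.
    tau_lo = max_x - 1.0
    tau_hi = max_x - (1.0 / len(x)) ** (alpha - 1.0)

    dm = tau_hi - tau_lo
    p = probs(tau_lo)
    for _ in range(n_iter):
        dm /= 2.0
        tau_m = tau_lo + dm
        p = probs(tau_m)
        # sum(p) is non-increasing in tau: move tau_lo up while sum(p) >= 1.
        if sum(p) >= 1.0:
            tau_lo = tau_m

    if ensure_sum_one:
        # Bisection stops at finite precision; renormalize the final iterate.
        total = sum(p)
        p = [v / total for v in p]
    return p


print(entmax_bisect_1d([3.0, 1.0, 0.0]))  # [1.0, 0.0, 0.0]
print(entmax_bisect_1d([1.0, 0.5, -1.0], alpha=2.0))  # approximately [0.75, 0.25, 0.0]
```

On [3.0, 1.0, 0.0] the two smaller scores receive exactly zero probability, whereas softmax would give every entry a positive share. With alpha = 2.0 the same routine computes sparsemax.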

lingvo.core.entmax.entmax_loss(labels: Tensor, inputs: Tensor, alpha: float = 1.5, n_iter: int = 50, ensure_sum_one: bool = True) → Tensor

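Calculates the entmax loss between labels and the alpha-entmax probabilities of inputs; alpha, n_iter, and ensure_sum_one play the same roles as in entmax_support.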
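The deep-spin/entmax reference defines the entmax loss as a Fenchel-Young loss: loss(z, y) = (p* - e_y) . z + H_alpha(p*), where p* is the alpha-entmax of the scores z, e_y is the one-hot label vector, and H_alpha is the Tsallis entropy (1 - sum_j p*_j ** alpha) / (alpha * (alpha - 1)). Its gradient with respect to z is simply p* - e_y. The pure-Python sketch below assumes lingvo's loss has this form and, for brevity, takes a single integer label and a 1-D list of scores; entmax_loss_1d and _entmax_1d are hypothetical names.

```python
from typing import List, Sequence


def _entmax_1d(z: Sequence[float], alpha: float, n_iter: int = 50) -> List[float]:
    """Hypothetical compact alpha-entmax (alpha > 1) via bisection on tau."""
    x = [(alpha - 1.0) * v for v in z]
    exponent = 1.0 / (alpha - 1.0)
    lo = max(x) - 1.0  # sum of clipped powers is >= 1 here
    hi = max(x) - (1.0 / len(x)) ** (alpha - 1.0)  # and <= 1 here
    p = [max(v - lo, 0.0) ** exponent for v in x]
    for _ in range(n_iter):
        mid = (lo + hi) / 2.0
        p = [max(v - mid, 0.0) ** exponent for v in x]
        if sum(p) >= 1.0:
            lo = mid
        else:
            hi = mid
    total = sum(p)
    return [v / total for v in p]  # renormalize, like ensure_sum_one=True


def entmax_loss_1d(
    scores: Sequence[float], label: int, alpha: float = 1.5, n_iter: int = 50
) -> float:
    """Hypothetical entmax (Fenchel-Young) loss for one integer label."""
    if alpha <= 1.0:
        raise ValueError("alpha must be > 1")
    p = _entmax_1d(scores, alpha, n_iter)
    # Tsallis alpha-entropy of p* (p* sums to one).
    tsallis = (1.0 - sum(pj ** alpha for pj in p)) / (alpha * (alpha - 1.0))
    # (p* - e_y) . z, where e_y is the one-hot vector for `label`.
    linear = sum(
        (pj - (1.0 if j == label else 0.0)) * zj
        for j, (pj, zj) in enumerate(zip(p, scores))
    )
    return linear + tsallis


print(entmax_loss_1d([3.0, 1.0, 0.0], label=0))  # 0.0
print(entmax_loss_1d([3.0, 1.0, 0.0], label=1))  # 2.0
```

Because entmax can output an exactly one-hot p*, the loss reaches exactly zero when p* matches the label, which softmax cross-entropy never does for finite scores.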