| Safe Haskell | None |
|---|---|
| Language | Haskell2010 |

Parallel lookups on the list of tensors.

## Synopsis

- `embeddingLookup :: forall a b v1 v2 m. (MonadBuild m, Rendered (Tensor v1), TensorType a, OneOf '[Int64, Int32] b, Num b) => [Tensor v1 a] -> Tensor v2 b -> m (Tensor Value a)`

## Documentation

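`embeddingLookup`

| Type | Description |
|---|---|
| `:: forall a b v1 v2 m. (MonadBuild m, Rendered (Tensor v1), TensorType a, OneOf '[Int64, Int32] b, Num b)` | |
| `=> [Tensor v1 a]` | `params`: a list of tensors which can be concatenated along dimension 0. Each tensor must be appropriately sized for the "mod" partition strategy. |
| `-> Tensor v2 b` | `ids`: a tensor of `Int32` or `Int64` ids to look up in `params`. |
| `-> m (Tensor Value a)` | A dense tensor with shape `shape(ids) + shape(params)[1:]`. |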

Looks up `ids` (the second argument) in a list of embedding tensors, `params` (the first argument).

It performs parallel lookups on the list of tensors in `params`. It is a generalization of `gather`, where `params` is interpreted as a partition of a larger embedding tensor.
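To make the "generalization of `gather`" concrete, here is a minimal pure-list model. It is not the TensorFlow Haskell API: a tensor is just a list of rows, and `gather`, `lookupPartitioned` and `splitMod` are hypothetical helpers. It assumes the "mod" strategy described on this page (id `i` goes to partition `i mod len(params)`) and, as in TensorFlow's Python `embedding_lookup`, that the id is stored at row `i div len(params)` inside that partition.

```haskell
-- Pure-list model of the lookup; NOT the TensorFlow Haskell API.
-- A "tensor" here is just a list of rows (the row type is left abstract).

-- | Plain gather on one tensor: select row i for every id i.
gather :: [row] -> [Int] -> [row]
gather params ids = map (params !!) ids

-- | Lookup on a partitioned tensor with the "mod" strategy. Assumption:
-- as in TensorFlow's Python embedding_lookup, id i lives in partition
-- (i `mod` p) at row (i `div` p) of that partition.
lookupPartitioned :: [[row]] -> [Int] -> [row]
lookupPartitioned parts ids = map pick ids
  where
    p = length parts
    pick i = (parts !! (i `mod` p)) !! (i `div` p)

-- | Split a full table into p partitions: row r goes to partition (r `mod` p).
splitMod :: Int -> [row] -> [[row]]
splitMod p rows =
  [ [r | (i, r) <- zip [0 ..] rows, i `mod` p == k] | k <- [0 .. p - 1] ]

main :: IO ()
main = do
  let full = [[fromIntegral r, fromIntegral r + 0.5] | r <- [0 .. 6 :: Int]] :: [[Double]]
      ids = [6, 0, 4, 4, 1]
  -- Looking up in 3 mod-partitions returns the same rows as gathering
  -- from the unpartitioned table.
  print (lookupPartitioned (splitMod 3 full) ids == gather full ids)
  -- With a single partition the lookup is exactly gather.
  print (lookupPartitioned [full] ids == gather full ids)
```

Both lines print `True`: the partitions only change where a row is stored, not which row an id selects.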

The partition strategy is "mod": each id is assigned to partition `p = id % len(params)`. For instance, 13 ids (0 through 12) are split across 5 partitions as `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`.
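The assignment rule can be checked with a few lines of plain Haskell. This is a sketch on ordinary lists, not the library's implementation; `partitionIds` is a hypothetical helper that groups ids by `id mod numPartitions` and reproduces the 13-ids-over-5-partitions example.

```haskell
-- Sketch of the "mod" partition assignment on plain lists (not library code).

-- | Group ids by partition: id i goes to partition (i `mod` numPartitions).
-- Ids keep their original relative order inside each partition.
partitionIds :: Int -> [Int] -> [[Int]]
partitionIds numPartitions ids =
  [ [i | i <- ids, i `mod` numPartitions == p] | p <- [0 .. numPartitions - 1] ]

main :: IO ()
main = print (partitionIds 5 [0 .. 12])
-- prints [[0,5,10],[1,6,11],[2,7,12],[3,8],[4,9]]
```

Partitions 0 to 2 receive three ids each and partitions 3 and 4 receive two, because 13 is not a multiple of 5.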

The results of the lookup are concatenated into a dense tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
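The shape rule is list concatenation of dimensions. The sketch below models shapes as plain `[Int]` lists; `lookupShape` is a hypothetical helper, not part of the TensorFlow Haskell API, and it assumes (consistent with the requirement that `params` can be concatenated along dimension 0) that every partition shares the same trailing dimensions.

```haskell
-- Sketch of the output-shape rule with shapes as plain lists of dimensions.
-- Hypothetical helper; not part of the TensorFlow Haskell API.

-- | shape(ids) + shape(params)[1:]: keep every dimension of ids, then append
-- every dimension of a params partition except its first (row) dimension.
-- Returns Nothing if there are no partitions or their trailing dims differ.
lookupShape :: [Int] -> [[Int]] -> Maybe [Int]
lookupShape idsShape paramShapes =
  case map (drop 1) paramShapes of
    (trailing : rest) | all (== trailing) rest -> Just (idsShape ++ trailing)
    _ -> Nothing

main :: IO ()
main = do
  -- 10 x 4 embedding split by "mod" over 3 partitions (4 + 3 + 3 rows),
  -- looked up with ids of shape [2, 3].
  print (lookupShape [2, 3] [[4, 4], [3, 4], [3, 4]])
  -- Partitions with different embedding widths cannot be concatenated.
  print (lookupShape [5] [[4, 4], [3, 8]])
```

The first call prints `Just [2,3,4]`: each of the 2 x 3 ids is replaced by a row of width 4. The second prints `Nothing`.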