-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}

-- | Parallel lookups on the list of tensors.
module TensorFlow.EmbeddingOps where

import Control.Monad (zipWithM)
import Data.Int (Int32, Int64)
import TensorFlow.Build (MonadBuild)
import TensorFlow.Ops (shape, vector)  -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value, Rendered, colocateWith, render)
import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps

-- | Looks up `ids` in a list of embedding tensors.
--
-- This function is used to perform parallel lookups on the list of
-- tensors in `params`.  It is a generalization of `TF.gather`, where
-- `params` is interpreted as a partition of a larger embedding
-- tensor.
--
-- The partition_strategy is "mod", we assign each id to partition
-- `p = id % len(params)`. For instance,
-- 13 ids are split across 5 partitions as:
-- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
--
-- The results of the lookup are concatenated into a dense
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v1 v2 m .
                   ( MonadBuild m
                   , Rendered (Tensor v1)
                   , TensorType a
                   , OneOf '[Int64, Int32] b
                   , Num b
                   )
                => [Tensor v1 a]
                -- ^ A list of tensors which can be concatenated along
                -- dimension 0. Each `Tensor` must be appropriately
                -- sized for `mod` partition strategy.
                -> Tensor v2 b
                -- ^ A `Tensor` with type `int32` or `int64`
                -- containing the ids to be looked up in `params`.
                -- The ids are required to have fewer than 2^31
                -- entries.
                -> m (Tensor Value a)
                -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
-- Fast path: with a single partition every id maps to partition 0, so the
-- lookup is a plain gather colocated with the one parameter tensor.
embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
embeddingLookup params@(p0 : _) ids = do
    -- Do np separate lookups, finding embeddings for plist[p] in params[p].
    -- Each gather is colocated with its own partition so the work runs on
    -- whichever device holds that shard of the embedding table.
    partitionedResult <- zipWithM
                        (\p g -> colocateWith p $ render $ CoreOps.gather p g)
                        params gatherIds
    -- Stitch the per-partition results back into the original id order.
    let unshapedResult = CoreOps.dynamicStitch pindices partitionedResult
    -- Shape restoration is not as optimal as it would be with client
    -- side shape tracking.
    paramShape <- colocateWith p0 (render (shape p0))
    -- Result shape is shape(ids) ++ shape(params)[1:]; the leading
    -- (partitioned) dimension of params is replaced by the ids' shape.
    let finalShape = CoreOps.concat 0 [shape ids, tailShape]
        tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
    render $ CoreOps.reshape unshapedResult finalShape
  where
    -- Number of partitions. Avoids genericLength here which would be
    -- evaluated by TF. Kept polymorphic (NoMonomorphismRestriction) so it
    -- can be used both as a Tensor scalar and as the Int64 partition count.
    np = fromIntegral (length params)
    -- Flatten ids to rank 1 so partitioning works for any input rank.
    flatIds = CoreOps.reshape ids (singleton (-1))
    -- "mod" strategy: id i belongs to partition (i `mod` np) ...
    pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
    -- ... and its row within that partition is (i `div` np).
    newIds = flatIds `CoreOps.div` np
    -- Positions of each id in the flattened input, used to undo the
    -- partitioning when stitching results back together.
    originalIndices = CoreOps.range 0 (CoreOps.size flatIds) 1
    -- Partition list of ids based on assignments into np separate lists.
    gatherIds = CoreOps.dynamicPartition np newIds pAssignments
    -- Similarly, partition the original indices.
    pindices = CoreOps.dynamicPartition np originalIndices pAssignments
    -- Rank-1 tensor holding a single Int32, for slice offsets/sizes.
    singleton i = vector [i :: Int32]

embeddingLookup [] _ = error "embeddingLookup requires params to be non empty"