-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}  -- For Fetchable (TensorExpr a)
module TensorFlow.Nodes where

import Control.Applicative (liftA2, liftA3)
import Data.Functor.Identity (Identity)
import Data.Map.Strict (Map)
import Data.Set (Set)
import Data.Text (Text)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set

import TensorFlow.Build
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types
import qualified TensorFlow.Internal.FFI as FFI

-- | Types that contain ops which can be run.
class Nodes n where
    -- | Collects the names of all nodes referenced by the given value.
    getNodes :: n -> Build (Set NodeName)

-- | Relates a tensor representation @repr@ (e.g. 'Tensor', 'ControlNode', or
-- a collection of them such as a tuple or list) to a type @result@ that it can
-- be fetched into.
class Nodes repr => Fetchable repr result where
    -- | Describes which nodes must be fetched for @repr@ and how to decode
    -- the fetched data into a @result@.
    getFetch :: repr -> Build (Fetch result)

-- | Fetch action. Keeps track of what needs to be fetched and how to decode
-- the fetched data.
data Fetch a = Fetch
    { fetches :: Set Text
      -- ^ Names of the nodes to fetch.
    , fetchRestore :: Map Text FFI.TensorData -> a
      -- ^ Builds an @a@ from the fetched data, which is keyed by name.
    }

-- Mapping a function over a 'Fetch' only post-processes the decoded value;
-- the set of names to fetch is left unchanged.
instance Functor Fetch where
    fmap f (Fetch names decode) = Fetch names (fmap f decode)

-- 'pure' fetches nothing and ignores the fetched data. Combining two fetches
-- requests the union of the names both sides need; the combined decoder
-- passes the same fetched map to both decoders and applies the decoded
-- function to the decoded argument.
instance Applicative Fetch where
    pure = Fetch Set.empty . const
    Fetch namesF decodeF <*> Fetch namesX decodeX =
        Fetch (Set.union namesF namesX) $ \fetched ->
            decodeF fetched (decodeX fetched)

-- | Runs every action in a traversable container (left to right) and
-- combines all of their results with the 'Monoid' operation, e.g. to union
-- several 'Build' actions that each yield a 'Set' of node names.
nodesUnion :: (Monoid b, Traversable t, Applicative f) => t (f b) -> f b
nodesUnion actions = foldMap id <$> traverse id actions

-- A pair refers to every node referenced by either component.
instance (Nodes t1, Nodes t2) => Nodes (t1, t2) where
    getNodes (first, second) = liftA2 Set.union (getNodes first) (getNodes second)

-- A triple refers to every node referenced by any of its components.
instance (Nodes t1, Nodes t2, Nodes t3) => Nodes (t1, t2, t3) where
    getNodes (first, second, third) =
        liftA3 (\p q r -> Set.unions [p, q, r])
               (getNodes first)
               (getNodes second)
               (getNodes third)

-- Fetching a pair fetches both components and pairs up the decoded results.
instance (Fetchable t1 a1, Fetchable t2 a2) => Fetchable (t1, t2) (a1, a2) where
    getFetch (first, second) =
        liftA2 (liftA2 (,)) (getFetch first) (getFetch second)

-- Fetching a triple fetches all three components and tuples up the decoded
-- results.
instance (Fetchable t1 a1, Fetchable t2 a2, Fetchable t3 a3)
         => Fetchable (t1, t2, t3) (a1, a2, a3) where
    getFetch (first, second, third) =
        liftA3 (liftA3 (,,))
               (getFetch first)
               (getFetch second)
               (getFetch third)

-- A list refers to every node referenced by any of its elements.
instance Nodes t => Nodes [t] where
    getNodes items = nodesUnion [getNodes item | item <- items]

-- Fetching a list fetches every element and decodes them into a list in the
-- same order.
instance Fetchable t a => Fetchable [t] [a] where
    getFetch = fmap sequenceA . traverse getFetch

-- 'Nothing' refers to no nodes; 'Just' refers to the nodes of its contents.
instance Nodes t => Nodes (Maybe t) where
    getNodes = maybe (pure Set.empty) getNodes

-- 'Nothing' fetches nothing and decodes to 'Nothing'; 'Just' fetches its
-- contents and wraps the decoded value in 'Just'.
instance Fetchable t a => Fetchable (Maybe t) (Maybe a) where
    getFetch Nothing = pure (pure Nothing)
    getFetch (Just inner) = fmap Just <$> getFetch inner

-- A control node refers to exactly its own node.
instance Nodes ControlNode where
    getNodes (ControlNode name) = return (Set.singleton name)

-- We use the constraint @(a ~ ())@ to help with type inference. For example,
-- if @t :: ControlNode@, then this constraint ensures that
-- @run t :: Session ()@. If we used @instance Fetchable ControlNode ()@
-- instead, then that expression would be ambiguous without explicitly
-- specifying the return type.
--
-- Fetching a control node requests no tensor data and decodes to @()@.
instance a ~ () => Fetchable ControlNode a where
    getFetch _ = pure (pure ())

-- An empty heterogeneous list refers to no nodes.
instance Nodes (ListOf f '[]) where
    getNodes = const (pure Set.empty)

-- A non-empty heterogeneous list refers to the nodes of its head together
-- with the nodes of its tail.
instance (Nodes (f a), Nodes (ListOf f as)) => Nodes (ListOf f (a ': as)) where
    getNodes (x :/ xs) = Set.union <$> getNodes x <*> getNodes xs

-- An empty heterogeneous list fetches nothing and decodes to 'Nil'. The
-- instance matches any result type @l@ and then fixes it to @List '[]@
-- through the equality constraint.
instance l ~ List '[] => Fetchable (ListOf f '[]) l where
    getFetch = const (pure (pure Nil))

-- A non-empty heterogeneous list fetches its head and its tail, then conses
-- the decoded head onto the decoded tail.
instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity)
    => Fetchable (ListOf f (t ': ts)) (ListOf i (a ': as)) where
    getFetch (x :/ xs) = liftA2 (liftA2 (/:/)) (getFetch x) (getFetch xs)

-- A tensor refers to the single node that produces its output.
instance Nodes (Tensor v a) where
    getNodes (Tensor o) = do
        output <- toBuild o
        pure (Set.singleton (outputNodeName output))

-- | Builds a 'Fetch' that requests the raw data of a single tensor output
-- and, when decoding, checks that the fetched data's runtime type matches the
-- tensor's static element type @a@.
--
-- Decoding calls 'error' if the fetched map has no entry for the tensor's
-- output, or if the fetched data's type differs from @a@.
fetchTensorVector :: forall a v . (TensorType a)
                  => Tensor v a -> Build (Fetch (TensorData a))
fetchTensorVector (Tensor o) = do
    outputName <- encodeOutput <$> toBuild o
    pure $ Fetch (Set.singleton outputName) $ \tensors ->
        let tensorData = case Map.lookup outputName tensors of
                Just fetched -> fetched
                -- The name was requested via 'fetches', so a missing entry
                -- is an invariant violation; report which output it was
                -- instead of the generic 'Map.!' failure.
                Nothing -> error $ "fetchTensorVector: no fetched data for output "
                                   ++ show outputName
            -- 'tensorType' only inspects the type, never the value.
            expectedType = tensorType (undefined :: a)
            actualType = FFI.tensorDataType tensorData
            badTypeError = error $ "Bad tensor type: expected "
                                   ++ show expectedType
                                   ++ ", got "
                                   ++ show actualType
        in if expectedType /= actualType
               then badTypeError
               else TensorData tensorData

-- The constraint "a ~ a'" lets the tensor's element type and the fetched
-- result's element type constrain each other during type inference.
-- The raw fetched data is returned as-is, after the type check performed by
-- 'fetchTensorVector'.
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (TensorData a') where
    getFetch tensor = fetchTensorVector tensor

-- Fetches a tensor into any container @s@ with a 'TensorDataType' instance.
-- The raw data is fetched with 'fetchTensorVector' and then converted with
-- 'decodeTensorData'.
-- NOTE(review): this instance head also matches @TensorData a'@; confirm how
-- overlap with the 'TensorData' instance is resolved at use sites.
instance (TensorType a, TensorDataType s a, a ~ a') => Fetchable (Tensor v a) (s a') where
    getFetch = fmap (fmap decodeTensorData) . fetchTensorVector