{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module TensorFlow.Nodes where
import Control.Applicative (liftA2, liftA3)
import Data.Functor.Identity (Identity)
import Data.Map.Strict (Map)
import Data.Set (Set)
import Data.Text (Text)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import TensorFlow.Build
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types
import qualified TensorFlow.Internal.FFI as FFI
-- | Types whose values refer to graph nodes that can be run.
--
-- 'getNodes' computes, inside 'Build', the set of node names the value
-- refers to.
class Nodes t where
    getNodes :: t -> Build (Set NodeName)
-- | Types @t@ whose values can be fetched into a Haskell value of type @a@.
--
-- Every 'Fetchable' type is also 'Nodes'. 'getFetch' builds a 'Fetch' that
-- records which outputs to fetch and how to decode them into an @a@.
class Nodes t => Fetchable t a where
    getFetch :: t -> Build (Fetch a)
-- | A fetch action: the set of output names to fetch, paired with a
-- function that decodes the fetched data into a value of type @a@.
data Fetch a = Fetch
    { -- | Names of the outputs to fetch.
      fetches :: Set Text
      -- | Builds an @a@ from the fetched data, keyed by output name.
    , fetchRestore :: Map Text FFI.TensorData -> a
    }
-- | Mapping over a 'Fetch' keeps the fetched names unchanged and
-- post-composes the decoding function.
instance Functor Fetch where
    fmap f (Fetch names restore) = Fetch names (f . restore)
-- | 'pure' fetches nothing and ignores the fetched data.
--
-- '<*>' fetches the union of both sides' names. It feeds the same fetched
-- map to both decoders and applies the left result to the right one.
instance Applicative Fetch where
    pure x = Fetch Set.empty (const x)
    Fetch names restoreF <*> Fetch names' restoreX =
        Fetch (Set.union names names')
              (\tensors -> restoreF tensors (restoreX tensors))
-- | Runs every action in a traversable container and combines their
-- monoidal results with 'mappend', in traversal order.
--
-- An empty container yields @pure mempty@. The 'Nodes' instances use this
-- to union the node sets of a collection's components.
nodesUnion :: (Monoid b, Traversable t, Applicative f) => t (f b) -> f b
nodesUnion actions = foldMap id <$> sequenceA actions
-- | A pair refers to the union of the nodes of its two components.
instance (Nodes t1, Nodes t2) => Nodes (t1, t2) where
    getNodes (x, y) = nodesUnion [getNodes x, getNodes y]
-- | A triple refers to the union of the nodes of its three components.
instance (Nodes t1, Nodes t2, Nodes t3) => Nodes (t1, t2, t3) where
    getNodes (x, y, z) = nodesUnion [getNodes x, getNodes y, getNodes z]
-- | Fetching a pair fetches both components and decodes them into a pair.
--
-- The outer '<$>'/'<*>' combine the 'Build' actions; the inner 'liftA2'
-- combines the resulting 'Fetch' values.
instance (Fetchable t1 a1, Fetchable t2 a2) => Fetchable (t1, t2) (a1, a2) where
    getFetch (x, y) = liftA2 (,) <$> getFetch x <*> getFetch y
-- | Fetching a triple fetches all three components and decodes them into
-- a triple.
--
-- The outer '<$>'/'<*>' combine the 'Build' actions; the inner 'liftA3'
-- combines the resulting 'Fetch' values.
instance (Fetchable t1 a1, Fetchable t2 a2, Fetchable t3 a3)
    => Fetchable (t1, t2, t3) (a1, a2, a3) where
    getFetch (x, y, z) =
        liftA3 (,,) <$> getFetch x <*> getFetch y <*> getFetch z
-- | A list refers to the union of its elements' nodes; an empty list refers
-- to no nodes.
instance Nodes t => Nodes [t] where
    getNodes = nodesUnion . map getNodes
-- | Fetching a list builds one 'Fetch' per element, then merges them into a
-- single 'Fetch' that decodes to a list in the same order.
instance Fetchable t a => Fetchable [t] [a] where
    getFetch ts = sequenceA <$> traverse getFetch ts
-- | @Just x@ refers to the nodes of @x@; 'Nothing' refers to no nodes.
instance Nodes t => Nodes (Maybe t) where
    getNodes = nodesUnion . fmap getNodes
-- | @Just x@ fetches @x@ and decodes to @Just@ its value.
--
-- 'Nothing' fetches nothing and decodes to 'Nothing'.
instance Fetchable t a => Fetchable (Maybe t) (Maybe a) where
    getFetch = fmap sequenceA . traverse getFetch
-- | A control node refers to exactly the node it wraps.
instance Nodes ControlNode where
    getNodes (ControlNode o) = pure (Set.singleton o)
-- | Fetching a control node fetches no outputs and yields @()@.
--
-- The head matches any result type @a@ and then imposes @a ~ ()@. It is
-- written this way rather than as @Fetchable ControlNode ()@ so that type
-- inference can fix the result type to @()@ without an annotation at the
-- use site.
instance a ~ () => Fetchable ControlNode a where
    getFetch _ = pure (pure ())
-- | The empty heterogeneous list refers to no nodes.
instance Nodes (ListOf f '[]) where
    getNodes _ = pure Set.empty
-- | A non-empty heterogeneous list refers to the union of its head's nodes
-- and its tail's nodes.
instance (Nodes (f a), Nodes (ListOf f as)) => Nodes (ListOf f (a ': as)) where
    getNodes (x :/ xs) = liftA2 Set.union (getNodes x) (getNodes xs)
-- | Fetching the empty heterogeneous list fetches nothing and yields 'Nil'.
--
-- As with 'ControlNode', the @l ~ List '[]@ constraint lets inference fix
-- the result type.
instance l ~ List '[] => Fetchable (ListOf f '[]) l where
    getFetch _ = pure (pure Nil)
-- | Fetching a non-empty heterogeneous list fetches its head and its tail,
-- then conses the decoded head onto the decoded tail with '/:/'.
instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity)
    => Fetchable (ListOf f (t ': ts)) (ListOf i (a ': as)) where
    getFetch (x :/ xs) = liftA2 (/:/) <$> getFetch x <*> getFetch xs
-- | A tensor refers to the single node that produces its output.
instance Nodes (Tensor v a) where
    getNodes (Tensor o) = Set.singleton . outputNodeName <$> toBuild o
-- | Builds a 'Fetch' that retrieves a single tensor's output as
-- 'TensorData'.
--
-- The output's name is computed in 'Build' with 'encodeOutput' and is the
-- only entry in 'fetches'. The decoder looks that name up in the fetched
-- map. It then compares the fetched data's type ('FFI.tensorDataType')
-- with the 'tensorType' of @a@. Forcing the decoded value calls 'error' if
-- the name is absent from the map or the two types differ.
fetchTensorVector :: forall a v . (TensorType a)
                  => Tensor v a -> Build (Fetch (TensorData a))
fetchTensorVector (Tensor o) = do
    outputName <- encodeOutput <$> toBuild o
    pure $ Fetch (Set.singleton outputName) $ \tensors ->
        -- Explicit lookup instead of the partial 'Map.!', so a missing
        -- output reports which key was absent.
        case Map.lookup outputName tensors of
            Nothing ->
                error $ "Missing fetched tensor: " ++ show outputName
            Just tensorData ->
                let actualType = FFI.tensorDataType tensorData
                in if expectedType /= actualType
                       then error $ "Bad tensor type: expected "
                                    ++ show expectedType
                                    ++ ", got "
                                    ++ show actualType
                       else TensorData tensorData
  where
    -- 'undefined' serves only as a type witness for @a@ (presumably
    -- 'tensorType' never forces its argument).
    expectedType = tensorType (undefined :: a)
-- | Fetching a tensor as raw 'TensorData' uses 'fetchTensorVector' directly.
--
-- The @a ~ a'@ constraint lets the tensor's element type and the requested
-- result type each determine the other during type inference.
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (TensorData a') where
    getFetch = fetchTensorVector
-- | Fetching a tensor into any 'TensorDataType' container @s@: fetch the raw
-- 'TensorData' with 'fetchTensorVector', then decode it with
-- 'decodeTensorData'.
--
-- NOTE(review): this head also matches @s ~ TensorData@, so it overlaps the
-- @Fetchable (Tensor v a) (TensorData a')@ instance; confirm how that
-- overlap is resolved in the full module (e.g. an OVERLAPPING pragma).
instance (TensorType a, TensorDataType s a, a ~ a') => Fetchable (Tensor v a) (s a') where
    getFetch t = fmap decodeTensorData <$> fetchTensorVector t