{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module TensorFlow.Tensor where
import Data.ByteString (ByteString)
import Data.String (IsString(..))
import qualified Data.Text as Text
import Lens.Family2 ((^.))
import Lens.Family2.State ((%=), use)
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (device)
import TensorFlow.Build
import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..))
import TensorFlow.Types
( TensorType
, TensorData(..)
, ListOf(..)
)
import qualified TensorFlow.Internal.FFI as FFI
-- | A handle to an output of a TensorFlow operation.
--
-- @v@ is the tensor kind (this module gives 'TensorKind' instances for
-- 'Build', 'Value' and 'Ref'). @a@ does not occur in the field, so it is a
-- phantom type tracking the element type.
--
-- The constructor captures the @TensorKind v@ dictionary, so pattern
-- matching on 'Tensor' brings that constraint into scope.
data Tensor v a where
    Tensor
        :: TensorKind v
        => { tensorOutput :: v Output }
        -> Tensor v a
-- | The tensor kind produced by 'render' and by 'value'.
--
-- 'Value' is a transparent wrapper with no effect of its own. Its
-- 'Functor', 'Applicative' and 'Monad' instances therefore just apply
-- functions to the wrapped payload.
newtype Value a = Value {runValue :: a}

instance Functor Value where
    fmap f (Value x) = Value (f x)

instance Applicative Value where
    pure = Value
    -- Applying a wrapped function is plain 'fmap' over the argument.
    Value f <*> v = fmap f v

instance Monad Value where
    -- Nothing to sequence: unwrap (lazily, newtype pattern) and continue.
    Value x >>= k = k x
-- | The rendered tensor kind returned by 'tensorRefFromName'.
-- 'value' converts a @Tensor Ref@ into a @Tensor Value@.
--
-- Like 'Value', 'Ref' is a transparent wrapper. Its 'Functor',
-- 'Applicative' and 'Monad' instances just apply functions to the payload.
newtype Ref a = Ref {runRef :: a}

instance Functor Ref where
    fmap f (Ref x) = Ref (f x)

instance Applicative Ref where
    pure = Ref
    -- Applying a wrapped function is plain 'fmap' over the argument.
    Ref f <*> r = fmap f r

instance Monad Ref where
    -- Nothing to sequence: unwrap (lazily, newtype pattern) and continue.
    Ref x >>= k = k x
-- | Reinterpret a 'Ref' tensor as a 'Value' tensor.
--
-- The wrapped 'Output' is reused unchanged; only the kind wrapper changes.
-- This is a pure function and runs no 'Build' action.
value :: Tensor Ref a -> Tensor Value a
value (Tensor (Ref out)) = Tensor (Value out)
-- | Render a tensor of any kind into a 'Value' tensor.
--
-- The tensor is first re-expressed as a 'Build' tensor with 'expr'. It is
-- then passed to 'render' in the current 'MonadBuild' context.
renderValue :: MonadBuild m => Tensor v a -> m (Tensor Value a)
-- Matching the 'Tensor' constructor brings the @TensorKind v@ dictionary
-- into scope, which 'expr' needs; the signature does not mention it.
renderValue t@(Tensor _) = render (expr t)
-- | Pairs an 'Output' with the raw 'FFI.TensorData' to feed into it.
-- See 'feed', which builds one from a rendered tensor.
data Feed where
    Feed :: Output -> FFI.TensorData -> Feed
-- | Tensors whose 'Output' can be read directly, without running a 'Build'
-- action.
--
-- Instances are provided for @Tensor Value@ and @Tensor Ref@. There is no
-- instance for @Tensor Build@ here.
class Rendered t where
    renderedOutput :: t a -> Output

instance Rendered (Tensor Value) where
    renderedOutput (Tensor (Value out)) = out

instance Rendered (Tensor Ref) where
    renderedOutput (Tensor (Ref out)) = out
-- | The 'NodeName' of the node that produced a rendered tensor.
tensorNodeName :: Rendered t => t a -> NodeName
tensorNodeName t = outputNodeName (renderedOutput t)
-- | Build a 'Feed' that supplies the given data to a rendered tensor.
--
-- The type parameter @a@ must agree between the tensor and the
-- 'TensorData'. The raw payload inside the 'TensorData' is what gets stored.
feed :: Rendered t => t a -> TensorData a -> Feed
feed t dat = case dat of
    TensorData raw -> Feed (renderedOutput t) raw
-- | Refer to a graph output by its textual name.
--
-- The name is converted to an 'Output' through 'Output''s 'IsString'
-- instance, then lifted into the requested tensor kind with 'pure'. This
-- function itself does not check that the name exists in the graph.
tensorFromName :: TensorKind v => Text.Text -> Tensor v a
tensorFromName name = Tensor (pure (fromString (Text.unpack name)))
-- | 'tensorFromName' specialised to the 'Value' tensor kind.
tensorValueFromName :: Text.Text -> Tensor Value a
tensorValueFromName name = tensorFromName name
-- | 'tensorFromName' specialised to the 'Ref' tensor kind.
tensorRefFromName :: Text.Text -> Tensor Ref a
tensorRefFromName name = tensorFromName name
-- | A 'ListOf' tensors that all share the kind @v@. It is indexed by a
-- type-level list @as@ of their element types.
type TensorList v = ListOf (Tensor v)

-- | The rendered 'Output' of every tensor in a 'TensorList', in list order.
tensorListOutputs :: Rendered (Tensor v) => TensorList v as -> [Output]
tensorListOutputs ts = case ts of
    Nil -> []
    t :/ rest -> renderedOutput t : tensorListOutputs rest
-- | Run a 'MonadBuild' action under the device of the node that produced a
-- rendered tensor.
--
-- The producing node is looked up with 'lookupNode'. Its @device@ field is
-- wrapped in 'Device', and the action runs under @withDevice (Just d)@.
-- NOTE(review): presumably only nodes the action itself renders are
-- affected, so an action that renders nothing places nothing. Confirm this
-- against 'withDevice'.
colocateWith :: (MonadBuild m, Rendered t) => t b -> m a -> m a
colocateWith t x = do
    let producer = outputNodeName (renderedOutput t)
    d <- build (fmap (\nodeDef -> Device (nodeDef ^. device)) (lookupNode producer))
    withDevice (Just d) x
-- | Run a tensor's 'Build' action in the current 'MonadBuild' context.
-- The resulting 'Output' is wrapped as a 'Value' tensor.
render :: MonadBuild m => Tensor Build a -> m (Tensor Value a)
render (Tensor action) = fmap (Tensor . Value) (build action)
-- | View a tensor of any kind as an unrendered 'Build' tensor.
-- Its output is lifted into 'Build' with 'toBuild'.
expr :: TensorKind v => Tensor v a -> Tensor Build a
expr tensor = case tensor of
    -- Matching here, rather than using the 'tensorOutput' selector lazily,
    -- keeps 'expr' strict in its argument.
    Tensor out -> Tensor (toBuild out)
-- | Record a summary tensor's 'Output' in the graph state's @summaries@
-- list, where 'collectAllSummaries' can retrieve it.
--
-- The output is obtained by running 'toBuild' on the tensor's output. It
-- is then prepended to the list.
addSummary :: (MonadBuild m, TensorKind v) => Tensor v ByteString
                        -> m ()
addSummary t =
    build (toBuild (tensorOutput t) >>= \out -> summaries %= (out :))
-- | Every output currently stored in the graph state's @summaries@ list.
-- Each one is wrapped as a 'Value' tensor, in the list's stored order.
collectAllSummaries :: MonadBuild m => m [SummaryTensor]
collectAllSummaries = build (fmap (map asSummary) (use summaries))
  where
    asSummary out = Tensor (Value out)

-- | A 'Value' tensor with 'ByteString' elements. This is the type
-- 'collectAllSummaries' returns.
type SummaryTensor = Tensor Value ByteString
-- | Tensor kinds whose wrapped value can be lifted into a 'Build' action.
--
-- 'Value' and 'Ref' unwrap their payload and 'return' it. 'Build' is
-- already a 'Build' action and is passed through unchanged.
class Monad v => TensorKind v where
    toBuild :: v a -> Build a

instance TensorKind Value where
    toBuild (Value x) = return x

instance TensorKind Ref where
    toBuild (Ref x) = return x

instance TensorKind Build where
    toBuild action = action
-- | Types that can be converted into an unrendered ('Build') 'Tensor'.
class ToTensor t where
    toTensor :: TensorType a => t a -> Tensor Build a

-- | A 'Tensor' of any kind converts by re-expressing it with 'expr'.
instance TensorKind v => ToTensor (Tensor v) where
    toTensor t = expr t