{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module TensorFlow.Session (
Session,
SessionT,
Options,
sessionConfig,
sessionTarget,
sessionTracer,
runSession,
runSessionWithOptions,
MonadBuild(..),
extend,
addGraphDef,
run,
runWithFeeds,
run_,
runWithFeeds_,
asyncProdNodes,
) where
import Data.ProtoLens.Message(defMessage)
import Control.Monad (forever, unless, void)
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (MonadTrans, lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString)
import Data.Default (Default, def)
import Data.ProtoLens (showMessage)
import Data.Set (Set)
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 (Lens', (^.), (&), (.~))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import TensorFlow.Build
import TensorFlow.Nodes
import TensorFlow.Output (NodeName, unNodeName)
import TensorFlow.Tensor
import qualified Data.ByteString.Builder as Builder
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified TensorFlow.Internal.FFI as FFI
-- | A logging callback: receives a rendered message and performs some IO
-- with it. The default 'Options' tracer ignores every message.
type Tracer = Builder.Builder -> IO ()

-- | State shared by every action running inside a single 'SessionT'.
data SessionState
    = SessionState {
          rawSession :: FFI.Session
          -- ^ Handle to the underlying FFI session; passed to 'FFI.run' and
          -- 'FFI.extendGraph'.
        , asyncCollector :: IO () -> IO ()
          -- ^ Supplied by 'FFI.withSession'; 'asyncProdNodes' hands its
          -- background loop to this function.
        , tracer :: Tracer
          -- ^ Logger taken from the session 'Options'.
        }
-- | The TensorFlow session monad transformer.
--
-- A 'ReaderT' carrying the 'SessionState' layered over a 'BuildT', so a
-- session action can both reach the live session and build new graph
-- nodes. Every instance below is obtained via GeneralizedNewtypeDeriving
-- from the wrapped transformer stack.
newtype SessionT m a
    = Session (ReaderT SessionState (BuildT m) a)
    deriving ( Functor
             , Applicative
             , Monad
             , MonadIO
             , MonadThrow
             , MonadCatch
             , MonadMask
             , MonadFail
             )
-- | Lift a base-monad action through both the 'BuildT' layer and the
-- 'ReaderT' layer of the session stack.
instance MonadTrans SessionT where
    lift = Session . lift . lift
-- | A 'SessionT' running directly over 'IO'.
type Session = SessionT IO

-- | Run 'Session' actions in a new TensorFlow session created with the
-- default 'Options'.
runSession :: (MonadMask m, MonadIO m) => SessionT m a -> m a
runSession = runSessionWithOptions def
-- | Customization for a session. Build one from 'def' and update it through
-- the lenses 'sessionTarget', 'sessionConfig' and 'sessionTracer'.
data Options = Options
    { _sessionTarget :: ByteString
      -- ^ Target string handed to 'FFI.setSessionTarget'.
    , _sessionConfig :: ConfigProto
      -- ^ Config proto handed to 'FFI.setSessionConfig'.
    , _sessionTracer :: Tracer
      -- ^ Logger stored in the session state.
    }
-- | Default options: an empty target string, a default 'ConfigProto'
-- message, and a tracer that discards every message.
instance Default Options where
    def = Options
          { _sessionTarget = ""
          , _sessionConfig = defMessage
          , _sessionTracer = const (return ())
          }
-- | Lens onto the session target, which 'runSessionWithOptions' passes to
-- 'FFI.setSessionTarget'.
-- NOTE(review): the accepted target formats are defined by the linked
-- TensorFlow runtime, not by this module.
sessionTarget :: Lens' Options ByteString
sessionTarget = lens _sessionTarget (\opts target -> opts { _sessionTarget = target })
-- | Lens onto the session config proto, which 'runSessionWithOptions'
-- passes to 'FFI.setSessionConfig'.
sessionConfig :: Lens' Options ConfigProto
sessionConfig = lens _sessionConfig (\opts cfg -> opts { _sessionConfig = cfg })
-- | Lens onto the logger used to monitor session progress; it is stored in
-- the 'SessionState' and used, for example, by 'extend'.
sessionTracer :: Lens' Options Tracer
sessionTracer = lens _sessionTracer (\opts tr -> opts { _sessionTracer = tr })
-- | Run 'Session' actions in a new TensorFlow session configured by the
-- given 'Options'.
--
-- The target and config from @options@ are written onto the FFI session
-- options by @applyOptions@, which is handed to 'FFI.withSession'. The
-- action then runs against a fresh build state ('evalBuildT'). The session
-- handle, the async collector supplied by 'FFI.withSession' and the
-- configured tracer are available to it through 'SessionState'.
-- NOTE(review): session teardown is whatever 'FFI.withSession' does on exit
-- — confirm there.
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
runSessionWithOptions options (Session m) =
    FFI.withSession applyOptions $ \collector rawSess ->
        let initState = SessionState rawSess collector (options ^. sessionTracer)
        in evalBuildT (runReaderT m initState)
  where
    -- Copy the user-facing options onto the FFI session options.
    applyOptions opt = do
        FFI.setSessionTarget (options ^. sessionTarget) opt
        FFI.setSessionConfig (options ^. sessionConfig) opt
-- | Graph-building actions run in the 'BuildT' layer beneath the
-- session's reader.
instance Monad m => MonadBuild (SessionT m) where
    build = Session . lift . build
-- | Add all pending rendered nodes to the TensorFlow graph and run any
-- pending initializers.
--
-- Buffered nodes are drained with 'flushNodeBuffer'. If there are any, they
-- are wrapped in a 'GraphDef', logged through the session tracer and passed
-- to 'FFI.extendGraph'. Afterwards, any initializers drained with
-- 'flushInitializers' are run as targets, with no feeds and no fetches.
--
-- 'run', 'runWithFeeds', 'run_', 'runWithFeeds_' and 'asyncProdNodes' all
-- call this implicitly.
extend :: MonadIO m => SessionT m ()
extend = do
    session <- Session (asks rawSession)
    trace <- Session (asks tracer)
    nodesToExtend <- build flushNodeBuffer
    unless (null nodesToExtend) $ liftIO $ do
        let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend
        trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
        FFI.extendGraph session graphDef
    -- Initializers run only after the graph has been extended above.
    initializers <- build flushInitializers
    unless (null initializers) $
        void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
-- | Run a subgraph @t@ with no feeds, rendering and extending any dependent
-- nodes not yet in the graph, and fetch the corresponding values for @a@.
run :: (MonadIO m, Fetchable t a) => t -> SessionT m a
run = runWithFeeds []
-- | Run a subgraph @t@, feeding the given input values, and fetch the
-- corresponding result values for @a@.
--
-- The nodes of @t@ ('getNodes') become the run targets. 'getFetch'
-- describes which tensors to fetch and how to rebuild @a@ from them.
runWithFeeds :: (MonadIO m, Fetchable t a) => [Feed] -> t -> SessionT m a
runWithFeeds feeds t = do
    ns <- build $ getNodes t
    -- Both builds happen before runFetchWithFeeds calls 'extend', so
    -- anything they render is flushed into the graph before the run.
    fetch <- build $ getFetch t
    runFetchWithFeeds feeds ns fetch
-- | Extend the graph, then run it once.
--
-- The run feeds @feeds@, fetches the tensors named by the 'Fetch' and runs
-- the @target@ nodes. The fetched tensors, keyed by fetch name, are turned
-- into the result by the restore function carried by the 'Fetch'.
runFetchWithFeeds :: MonadIO m => [Feed] -> Set NodeName -> Fetch a -> SessionT m a
runFetchWithFeeds feeds target (Fetch fetch restore) = do
    extend
    session <- Session (asks rawSession)
    -- Convert the set once, so the names sent to FFI.run and the keys zipped
    -- with its results come from the very same ordered list.
    let fetchList = Set.toList fetch
        fetchNames = encodeUtf8 <$> fetchList
        targetNames = toNodeNames (Set.toList target)
    runResult <- liftIO $ FFI.run session (fixFeeds feeds) fetchNames targetNames
    -- NOTE(review): zip silently drops the surplus if FFI.run returns a
    -- different number of tensors than were requested.
    return $ restore (Map.fromList (zip fetchList runResult))
-- | UTF-8 encode node names into the byte strings the FFI layer takes.
toNodeNames :: [NodeName] -> [ByteString]
toNodeNames = map (encodeUtf8 . unNodeName)
-- | Like 'run', but fetches nothing: renders and extends any dependent nodes,
-- then runs the nodes of @t@ as targets.
run_ :: (MonadIO m, Nodes t) => t -> SessionT m ()
run_ = runWithFeeds_ []
-- | Like 'runWithFeeds', but fetches nothing: feeds the given inputs and runs
-- the nodes of @t@ as targets.
runWithFeeds_ :: (MonadIO m, Nodes t) => [Feed] -> t -> SessionT m ()
runWithFeeds_ feeds t = do
    ns <- build $ getNodes t
    -- @pure ()@ is the trivial 'Fetch' (expected to request no tensors).
    runFetchWithFeeds feeds ns (pure ())
-- | Convert feeds into the (UTF-8 encoded output name, tensor) pairs that
-- 'FFI.run' takes.
fixFeeds :: [Feed] -> [(ByteString, FFI.TensorData)]
fixFeeds = map toPair
  where
    toPair (Feed o d) = (encodeUtf8 (encodeOutput o), d)
-- | Extend the graph with the given nodes, then hand the session's async
-- collector a loop that runs them as targets forever (no feeds, no
-- fetches).
--
-- Graph extension happens synchronously in the caller.
-- NOTE(review): how and when the loop executes or is stopped is decided by
-- the collector that 'FFI.withSession' supplies — confirm there.
asyncProdNodes :: (MonadIO m, Nodes t)
               => t  -- ^ Nodes to evaluate repeatedly.
               -> SessionT m ()
asyncProdNodes nodes = do
    target <- build (getNodes nodes)
    extend
    let targetNames = toNodeNames (Set.toList target)
    state <- Session ask
    let loop = forever (void (FFI.run (rawSession state) [] [] targetNames))
    liftIO (asyncCollector state loop)