-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}

module TensorFlow.Session (
    Session,
    SessionT,
    Options,
    sessionConfig,
    sessionTarget,
    sessionTracer,
    runSession,
    runSessionWithOptions,
    MonadBuild(..),
    extend,
    addGraphDef,
    run,
    runWithFeeds,
    run_,
    runWithFeeds_,
    asyncProdNodes,
    ) where

import Data.ProtoLens.Message(defMessage)
import Control.Monad (forever, unless, void)
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (MonadTrans, lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString)
import Data.Default (Default, def)
import Data.ProtoLens (showMessage)
import Data.Set (Set)
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 (Lens', (^.), (&), (.~))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import TensorFlow.Build
import TensorFlow.Nodes
import TensorFlow.Output (NodeName, unNodeName)
import TensorFlow.Tensor

import qualified Data.ByteString.Builder as Builder
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified TensorFlow.Internal.FFI as FFI

-- | An action for logging: it receives a 'Builder.Builder' message and
-- performs some 'IO' with it. The tracer installed by the 'Options'
-- default instance ignores every message.
type Tracer = Builder.Builder -> IO ()

-- | Read-only state shared by every action of a session (it is the
-- environment of the 'ReaderT' layer inside 'SessionT').
data SessionState = SessionState
    { rawSession     :: FFI.Session
      -- ^ Handle to the underlying FFI session.
    , asyncCollector :: IO () -> IO ()
      -- ^ Starts the given action concurrently.
    , tracer         :: Tracer
      -- ^ Logger used to report session progress.
    }

-- | Monad transformer in which TensorFlow session actions run.
--
-- It is a 'ReaderT' over the shared 'SessionState', stacked on the 'BuildT'
-- graph-building layer, stacked on the base monad @m@. Every instance in the
-- deriving list is inherited from that stack (GeneralizedNewtypeDeriving).
newtype SessionT m a
    = Session (ReaderT SessionState (BuildT m) a)
    deriving ( Functor
             , Applicative
             , Monad
             , MonadIO
             , MonadThrow
             , MonadCatch
             , MonadMask
             , MonadFail
             )

-- | Lifts a base-monad action through both the 'BuildT' layer and the
-- 'ReaderT' layer of 'SessionT'.
instance MonadTrans SessionT where
    lift action = Session (lift (lift action))

-- | A 'SessionT' whose base monad is 'IO'.
type Session = SessionT IO

-- | Run 'Session' actions in a new TensorFlow session created with the
-- default 'Options' (see 'runSessionWithOptions').
runSession :: (MonadMask m, MonadIO m) => SessionT m a -> m a
runSession = runSessionWithOptions (def :: Options)

-- | Settings used when creating a session. Start from 'def' and modify the
-- fields through the lenses 'sessionTarget', 'sessionConfig' and
-- 'sessionTracer'.
data Options = Options
    { _sessionTarget :: ByteString   -- ^ See 'sessionTarget'.
    , _sessionConfig :: ConfigProto  -- ^ See 'sessionConfig'.
    , _sessionTracer :: Tracer       -- ^ See 'sessionTracer'.
    }

-- | Default options: an empty target string, the default 'ConfigProto'
-- message, and a tracer that ignores every message.
instance Default Options where
    def = Options
        { _sessionTarget = ""
        , _sessionConfig = defMessage
        , _sessionTracer = \_ -> return ()
        }

-- | Lens onto the session target. The target can be: "local", ip:port or
-- host:port. Which targets are supported depends on the linked-in
-- libraries.
sessionTarget :: Lens' Options ByteString
sessionTarget = lens _sessionTarget setTarget
  where
    setTarget opts target = opts { _sessionTarget = target }

-- | Lens onto the 'ConfigProto' used for the created session.
sessionConfig :: Lens' Options ConfigProto
sessionConfig = lens _sessionConfig setConfig
  where
    setConfig opts config = opts { _sessionConfig = config }

-- | Lens onto the logger used to monitor session progress.
sessionTracer :: Lens' Options Tracer
sessionTracer = lens _sessionTracer setTracer
  where
    setTracer opts logger = opts { _sessionTracer = logger }

-- | Run 'Session' actions in a new TensorFlow session configured by the given
-- 'Options'. Build the options from 'def' and update them with the
-- 'sessionTarget', 'sessionConfig' and 'sessionTracer' lenses.
--
-- The target and config are copied into the FFI session options by the setup
-- action passed to 'FFI.withSession'. The tracer is stored in the session
-- state, where actions such as 'extend' use it for logging.
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
runSessionWithOptions options (Session m) =
    -- 'FFI.withSession' supplies the async collector first, then the raw
    -- session handle.
    FFI.withSession applyOptions $ \as rs ->
        let initState = SessionState rs as (options ^. sessionTracer)
        in evalBuildT (runReaderT m initState)
  where
    -- Copy the user-supplied target and config into the FFI session options.
    applyOptions opt = do
        FFI.setSessionTarget (options ^. sessionTarget) opt
        FFI.setSessionConfig (options ^. sessionConfig) opt

-- | Graph-building actions run in the 'BuildT' layer beneath the session's
-- 'ReaderT' environment.
instance Monad m => MonadBuild (SessionT m) where
    build action = Session (lift (build action))

-- | Adds all pending rendered nodes to the TensorFlow graph, then runs any
-- pending initializers.
--
-- 'run', 'runWithFeeds', 'run_', 'runWithFeeds_' and 'asyncProdNodes' all
-- call this function implicitly.
extend :: MonadIO m => SessionT m ()
extend = do
    state <- Session ask
    let session = rawSession state
    pending <- build flushNodeBuffer
    -- Only talk to the FFI when there is something new to add.
    unless (null pending) $ liftIO $ do
        let graphDef = (defMessage :: GraphDef) & node .~ pending
        tracer state ("Session.extend " <> Builder.string8 (showMessage graphDef))
        FFI.extendGraph session graphDef
    -- Initializers run only after the graph extension above, so every node
    -- they reference already exists in the session.
    initializers <- build flushInitializers
    unless (null initializers) $
        void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)

-- | Run a subgraph 't' with no feeds, rendering any dependent nodes that
-- aren't already rendered, and fetch the corresponding values for 'a'.
run :: (MonadIO m, Fetchable t a) => t -> SessionT m a
run = runWithFeeds mempty

-- | Run a subgraph 't' with the given input values fed in, rendering any
-- dependent nodes that aren't already rendered, and fetch the corresponding
-- result values for 'a'.
runWithFeeds :: (MonadIO m, Fetchable t a) => [Feed] -> t -> SessionT m a
runWithFeeds feeds t = do
    targets <- build (getNodes t)
    -- getNodes has already rendered every node of t (and its inputs/deps), so
    -- building the fetch here does not change what the later "extend" call
    -- sends to the session.
    runFetchWithFeeds feeds targets =<< build (getFetch t)

-- | Extends the graph with any pending nodes, then makes a single 'FFI.run'
-- call that feeds @feeds@, fetches the tensors named by the 'Fetch', and runs
-- the @target@ nodes. The fetched tensors are keyed by their output name and
-- passed to the 'Fetch' restore function.
runFetchWithFeeds :: MonadIO m => [Feed] -> Set NodeName -> Fetch a -> SessionT m a
runFetchWithFeeds feeds target (Fetch fetch restore) = do
    extend
    session <- Session (asks rawSession)
    let fetchKeys = Set.toList fetch
    results <- liftIO $ FFI.run session
                                (fixFeeds feeds)
                                (map encodeUtf8 fetchKeys)
                                (toNodeNames (Set.toList target))
    -- NOTE(review): 'zip' silently drops keys if FFI.run returns fewer tensors
    -- than were requested; presumably it returns exactly one per fetch name.
    pure (restore (Map.fromList (zip fetchKeys results)))

-- | Converts node names into the UTF-8 encoded byte strings passed to
-- 'FFI.run'.
toNodeNames :: [NodeName] -> [ByteString]
toNodeNames names = [encodeUtf8 (unNodeName name) | name <- names]

-- | Run a subgraph 't', rendering and extending any dependent nodes that
-- aren't already rendered. This is the non-fetching counterpart of 'run':
-- it is 'runWithFeeds_' with no feeds.
run_ :: (MonadIO m, Nodes t) => t -> SessionT m ()
run_ = runWithFeeds_ mempty

-- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered and feeding the given input values, without fetching any result
-- values. This behaves like 'runWithFeeds' except that it doesn't do any
-- fetches: the nodes of 't' are only run as targets, and the result is '()'.
runWithFeeds_ :: (MonadIO m, Nodes t) => [Feed] -> t -> SessionT m ()
runWithFeeds_ feeds t =
    -- 'pure ()' is an empty 'Fetch': nothing is requested from the session.
    build (getNodes t) >>= \targets -> runFetchWithFeeds feeds targets (pure ())

-- | Converts each 'Feed' into the (UTF-8 encoded output name, tensor data)
-- pair passed to 'FFI.run'.
fixFeeds :: [Feed] -> [(ByteString, FFI.TensorData)]
fixFeeds = map toPair
  where
    toPair (Feed output tensor) = (encodeUtf8 (encodeOutput output), tensor)

-- | Gives the session's 'asyncCollector' (which starts its argument
-- concurrently) an action that runs the given nodes over and over
-- ('forever'). Graph extension happens synchronously, before the loop is
-- handed off; only the repeated runs proceed separately.
--
-- NOTE(review): the loop has no exit condition of its own. It presumably
-- stops when 'runSession' exits or an exception occurs, depending on how the
-- collector supplied by 'FFI.withSession' manages it.
asyncProdNodes :: (MonadIO m, Nodes t)
                  => t  -- ^ Node to evaluate concurrently.
                  -> SessionT m ()
asyncProdNodes nodes = do
    target <- build (getNodes nodes)
    extend
    state <- Session ask
    let targetNames = toNodeNames (Set.toList target)
        -- One run of the target nodes, with no feeds and no fetches.
        prodOnce = void (FFI.run (rawSession state) [] [] targetNames)
    liftIO $ asyncCollector state (forever prodOnce)