-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
module TensorFlow.Build
    ( -- * Graph node types
      ControlNode(..)
    , Unique
    -- * Ops
    , explicitName
    , implicitName
    , opDef
    , opDefWithName
    , opName
    , opType
    , opAttr
    , opInputs
    , opControlInputs
    -- * The Build monad
    , GraphState
    , renderedNodeDefs
    , BuildT
    , Build
    , MonadBuild(..)
    , addInitializer
    , hoistBuildT
    , evalBuildT
    , runBuildT
    , asGraphDef
    , addGraphDef
    , flushInitializers
    , flushNodeBuffer
    , summaries
    -- * Creating and looking up Ops
    , getOrAddOp
    , addNewOp
    , encodeOutput
    , lookupNode
    -- * Modifying all nodes in a Build action
    , withStateLens
    , withDevice
    , withNameScope
    , withNodeDependencies
    ) where

import Data.ProtoLens.Message(defMessage)
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.Fix (MonadFix(..))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
import Data.Functor.Identity (Identity(..))
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.Set (Set)
import Data.String (IsString(..))
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 (Lens', (.~), (^.), (&))
import Lens.Family2.State.Strict (MonadState, use, uses, (.=), (<>=), (%=))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef)
import Proto.Tensorflow.Core.Framework.NodeDef_Fields
    ( attr
    , input
    , device
    , name
    , op
    )

import TensorFlow.Output

-- | Counter for the IDs given to implicitly-named nodes; 'GraphState'
-- stores the next one to hand out in '_nextUnique'.  'Enum' is derived so
-- the counter can be advanced with 'succ'.
newtype Unique = Unique Int
    deriving (Eq, Ord, Enum)

--------------

-- | A pending name that asks the builder to pick the node's name itself
-- (implicitly-named nodes get a 'Unique' ID, see 'GraphState').
implicitName :: PendingNodeName
implicitName = ImplicitName

-- | A pending name fixed by the caller to the given 'Text'.
explicitName :: Text -> PendingNodeName
explicitName = ExplicitName

-- | One component of a node name scope.  'GraphState' keeps the active
-- scopes as a list of these ('_currentScope').
newtype Scope = Scope {unScope :: Text}
    deriving (Eq, Ord, IsString)

-- | Shows a scope exactly as its underlying 'Text' would be shown
-- (i.e. quoted and escaped).
instance Show Scope where
    show = show . unScope

-- | An 'OpDef' of the given type with an implicit name; shorthand for
-- @'opDefWithName' 'ImplicitName'@.
opDef :: OpType -> OpDef
opDef = opDefWithName ImplicitName

-- | An 'OpDef' with the given pending name and op type, and no attributes,
-- data inputs, or control inputs.
opDefWithName :: PendingNodeName -> OpType -> OpDef
opDefWithName n t = OpDef
    { _opName = n
    , _opType = t
    , _opAttrs = Map.empty
    , _opInputs = []
    , _opControlInputs = []
    }

-- | The state threaded through 'BuildT' (as the state of a strict 'StateT')
-- while a graph is being constructed.
data GraphState = GraphState
    { _renderedNodes :: !(Map.Map PendingNode NodeDef)
        -- ^ Nodes which have been rendered.  Keeps track of the unique ID we
        -- assign each implicitly-named node.  Also prevents us from adding the
        -- same node (implicit or explicit) more than once to the nodeBuffer.
    , _renderedNodeDefs :: !(Map.Map NodeName NodeDef)
        -- ^ The NodeDefs of nodes which have been rendered. Used by the
        -- Gradient module to inspect the node graph.
    , _nodeBuffer :: [NodeDef]
        -- ^ A list of nodes that should be passed to TensorFlow during
        -- the next call to Session.extend (TF_ExtendGraph).
    , _nextUnique :: !Unique
        -- ^ Unique ID for the next node
    -- TODO(judahjacobson): watch for clashes between auto and user names.
    , _defaultDevice :: !(Maybe Device)
        -- ^ Device for newly added nodes, if any (presumably managed by
        -- 'withDevice' -- confirm against its definition).
    , _currentScope :: [Scope]
        -- ^ Active name scopes (presumably managed by 'withNameScope').
    , _defaultControlInputs :: !(Set NodeName)
        -- ^ Control inputs to attach by default (presumably managed by
        -- 'withNodeDependencies').
    , _initializationNodes  :: [NodeName]
      -- ^ The nodes to run next time a TF.run is issued, typically
      -- variable initializers.
    , _summaries :: [Output]
      -- ^ The tensors for summary (ByteString type)
    }

-- | A node definition without its final name: the enclosing scopes, the
-- requested name, and the 'NodeDef'.  Used as the key of the
-- '_renderedNodes' map in 'GraphState', which is how a node is prevented
-- from being rendered twice.
-- The NodeDef contained inside has an empty "name" field.
data PendingNode = PendingNode [Scope] !PendingNodeName !NodeDef
    deriving (Eq, Ord)

-- | The 'NodeDef' carried by a 'PendingNode'.
--
-- Returns an _incomplete_ NodeDef. The name is fixed by addNewOpFromPending.
pendingNodeDef :: PendingNode -> NodeDef
pendingNodeDef (PendingNode _ _ n) = n

-- | The state of an empty graph: nothing rendered or buffered, the unique
-- counter at 0, no default device, top-level scope, no default control
-- inputs, and no pending initializers or summaries.
initGraphState :: GraphState
initGraphState = GraphState
    { _renderedNodes = Map.empty
    , _renderedNodeDefs = Map.empty
    , _nodeBuffer = []
    , _nextUnique = Unique 0
    , _defaultDevice = Nothing
    , _currentScope = []
    , _defaultControlInputs = Set.empty
    , _initializationNodes = []
    , _summaries = []
    }

-- | Lens onto '_renderedNodes': every node rendered so far, keyed by its
-- 'PendingNode' form.
renderedNodes :: Lens' GraphState (Map.Map PendingNode NodeDef)
renderedNodes = lens _renderedNodes (\g x -> g { _renderedNodes = x })

-- | Lens onto '_renderedNodeDefs': the 'NodeDef' of each rendered node,
-- keyed by its final 'NodeName'.
renderedNodeDefs :: Lens' GraphState (Map.Map NodeName NodeDef)
renderedNodeDefs = lens _renderedNodeDefs (\g x -> g { _renderedNodeDefs = x })

-- | Lens onto '_nodeBuffer': nodes not yet handed to TensorFlow.
nodeBuffer :: Lens' GraphState [NodeDef]
nodeBuffer = lens _nodeBuffer (\g x -> g { _nodeBuffer = x })

-- | Lens onto '_nextUnique': the ID to give the next implicitly-named node.
nextUnique :: Lens' GraphState Unique
nextUnique = lens _nextUnique (\g x -> g { _nextUnique = x })

-- | Lens onto '_defaultDevice'.
defaultDevice :: Lens' GraphState (Maybe Device)
defaultDevice = lens _defaultDevice (\g x -> g { _defaultDevice = x })

-- | Lens onto '_currentScope', the list of active name scopes.
currentScope :: Lens' GraphState [Scope]
currentScope = lens _currentScope (\g x -> g { _currentScope = x })

-- | Lens onto '_defaultControlInputs'.
defaultControlInputs :: Lens' GraphState (Set NodeName)
defaultControlInputs =
    lens _defaultControlInputs (\g x -> g { _defaultControlInputs = x })

-- | Lens onto '_initializationNodes': nodes to run before the next
-- session run, typically variable initializers.
initializationNodes :: Lens' GraphState [NodeName]
initializationNodes =
    lens _initializationNodes (\g x -> g { _initializationNodes = x })

-- | Lens onto '_summaries': the summary tensors collected so far.
summaries :: Lens' GraphState [Output]
summaries = lens _summaries (\g x -> g { _summaries = x })

-- | An action for building nodes in a TensorFlow graph.
-- Used to manage build state internally as part of the @Session@ monad.
--
-- A thin newtype over a strict 'StateT' carrying 'GraphState'; every
-- instance below is obtained with GeneralizedNewtypeDeriving from that
-- 'StateT'.  ('MonadFail' is taken from the Prelude, base >= 4.13.)
newtype BuildT m a = BuildT (StateT GraphState m a)
    deriving ( Functor, Applicative, Monad, MonadIO, MonadTrans
             , MonadState GraphState, MonadThrow, MonadCatch, MonadMask
             , MonadFix, MonadFail
             )

-- | An action for building nodes in a TensorFlow graph.
-- This is 'BuildT' over 'Identity', i.e. a pure build with no other effects.
type Build = BuildT Identity

-- | Change the base monad of a 'BuildT' action with a natural
-- transformation.  The graph state passes through unchanged; only the
-- underlying monadic action is transformed.
--
-- This is Control.Monad.Morph.hoist sans the dependency.
hoistBuildT :: (forall a . m a -> n a) -> BuildT m b -> BuildT n b
hoistBuildT f (BuildT m) = BuildT (mapStateT f m)

-- | Run a 'BuildT' action starting from 'initGraphState', returning both
-- its result and the final 'GraphState'.
runBuildT :: BuildT m a -> m (a, GraphState)
runBuildT (BuildT f) = runStateT f initGraphState

-- | Run a 'BuildT' action starting from 'initGraphState' and return only
-- its result; the final 'GraphState' is discarded.
evalBuildT :: Monad m => BuildT m a -> m a
evalBuildT (BuildT f) = evalStateT f initGraphState

-- | Monads that can embed a 'Build' action, including any explicit op
-- renderings that action performs.
class Monad n => MonadBuild n where
    -- | Run a pure 'Build' action inside @n@.
    build :: Build a -> n a

-- | A 'BuildT' over any monad can run a pure 'Build' action.  The action
-- shares the same 'GraphState', and its 'Identity' result is re-wrapped
-- in @m@.
instance Monad m => MonadBuild (BuildT m) where
    build = hoistBuildT (return . runIdentity)

-- | Get all the NodeDefs that have accumulated so far, and clear that buffer.
flushNodeBuffer :: MonadBuild m => m [NodeDef]
flushNodeBuffer = build $ do
    ns <- use nodeBuffer
    nodeBuffer .= []
    return ns

-- | Get all the initializers that have accumulated so far, and clear
-- that buffer.
flushInitializers :: Monad m => BuildT m [NodeName]
flushInitializers = do
    ns <- use initializationNodes
    initializationNodes .= []
    return ns

-- | Schedule the given node to be executed before the next
-- 'TensorFlow.Session.run', by pushing its name onto the front of the
-- pending initializer list.
addInitializer :: MonadBuild m => ControlNode -> m ()
addInitializer (ControlNode initName) =
    build (initializationNodes %= (initName :))

-- | Produce a GraphDef proto representation of the nodes that are rendered in
-- the given 'Build' action.
--
-- The action is run from its initial state; its return value is ignored and
-- only the node buffer of its final state ends up in the 'GraphDef'.
asGraphDef :: Build a -> GraphDef
asGraphDef b = defMessage & node .~ (finalState ^. nodeBuffer)
  where
    (_, finalState) = runIdentity (runBuildT b)

-- TODO: check against existing nodes for conflicts?

-- | Append all nodes of the given 'GraphDef' to the end of the node buffer.
--
-- Only the buffer is touched: the nodes are not recorded in the
-- rendered-node maps, so 'lookupNode' will not find them.
addGraphDef :: MonadBuild m => GraphDef -> m ()
addGraphDef g = build (nodeBuffer <>= (g ^. node))

-- | Return the name of the node for the given op, rendering it first only if
-- an identical pending node has not been rendered before (deduplication).
getOrAddOp :: OpDef -> Build NodeName
getOrAddOp o = do
    pending <- getPendingNode o
    cached <- uses renderedNodes (Map.lookup pending)
    case cached of
        Nothing -> addNewOpFromPending pending
        Just rendered -> return (NodeName (rendered ^. name))

-- | Fetch the 'NodeDef' of a node previously rendered through
-- 'addNewOpFromPending' (i.e. via 'getOrAddOp' or 'addNewOp').
--
-- This is partial: it calls 'error' when no node with that name has been
-- recorded.
lookupNode :: NodeName -> Build NodeDef
lookupNode n = do
    found <- uses renderedNodeDefs (Map.lookup n)
    case found of
        Nothing -> error $ "lookupNode: unknown node name " ++ show n
        Just nd -> return nd

-- | Always render a fresh node for the given 'OpDef', skipping the dedup
-- check done by 'getOrAddOp'.  Meant for "stateful" ops that must not be
-- merged (e.g, "variable" and "assign").
addNewOp :: OpDef -> Build NodeName
addNewOp o = addNewOpFromPending =<< getPendingNode o

-- | Render a pending node unconditionally.
--
-- The node's final name is chosen by 'renderPendingNode'. The resulting
-- 'NodeDef' is then pushed onto the front of the node buffer and recorded in
-- both 'renderedNodes' (keyed by the pending node) and 'renderedNodeDefs'
-- (keyed by name).
addNewOpFromPending :: PendingNode -> Build NodeName
addNewOpFromPending pending = do
    finalName <- renderPendingNode pending
    let rendered = pendingNodeDef pending & name .~ unNodeName finalName
    nodeBuffer %= (rendered :)
    renderedNodes %= Map.insert pending rendered
    renderedNodeDefs %= Map.insert finalName rendered
    pure finalName

-- | Build the 'PendingNode' for an 'OpDef' under the current graph state.
--
-- The state supplies the current scope, the default device, and the
-- ambient control dependencies. The node proto gets:
--
-- * the op type and attributes;
-- * the encoded data inputs followed by the control inputs (the op's own
--   control inputs, then the ambient ones);
-- * the device.
--
-- Nothing is rendered or written to the state here.
getPendingNode :: OpDef -> Build PendingNode
getPendingNode o = do
    -- An empty string in the proto field means that no specific
    -- device is specified.
    dev <- uses defaultDevice (maybe "" deviceName)
    scope <- use currentScope
    ambientDeps <- uses defaultControlInputs Set.toList
    let dataInputs = map encodeOutput (o ^. opInputs)
        -- Control dependencies are encoded as "^<node name>".
        controlInputs =
            [ "^" <> unNodeName dep
            | dep <- o ^. opControlInputs ++ ambientDeps
            ]
        nodeProto = defMessage
            & op .~ (unOpType (o ^. opType) :: Text)
            & attr .~ _opAttrs o
            & input .~ (dataInputs ++ controlInputs)
            & device .~ dev
    return (PendingNode scope (o ^. opName) nodeProto)

-- | Pick a name for a pending node.  If it has an explicit name, just use that;
-- if the name is implicit, assign a new unique name based on the op type
-- (@<op>_<k>@, where @k@ is taken from and then advances 'nextUnique').
--
-- Every scope in the node's scope list contributes a @"<scope>/"@ prefix.
-- 'withNameScope' pushes each new scope onto the front of that list, so the
-- list is innermost-first. It is reversed here so that the outermost scope
-- comes first: @withNameScope "a" (withNameScope "b" x)@ names the nodes of
-- @x@ @"a/b/..."@, matching TensorFlow's name-scope convention.
renderPendingNode :: PendingNode -> Build NodeName
renderPendingNode (PendingNode scope pendingName nodeDef)
    = NodeName . (scopePrefix <>) <$> getName
  where
    -- Outermost scope first; see the note above on list order.
    scopePrefix = Text.concat $ fmap ((<> "/") . unScope) (reverse scope)
    getName = case pendingName of
        ExplicitName n -> return n
        ImplicitName -> do
            u@(Unique k) <- use nextUnique
            nextUnique .= succ u
            return $ nodeDef ^. op <> "_" <> Text.pack (show k)

-- | Turn an 'Output' into a string representation for the TensorFlow
-- foreign APIs: output 0 is the bare node name, any other output @i@ is
-- written as @"<node>:<i>"@.
encodeOutput :: Output -> Text
encodeOutput (Output (OutputIx ix) n)
    | ix == 0 = unNodeName n
    | otherwise = unNodeName n <> Text.pack (':' : show ix)

-- | Temporarily modify one field of the 'GraphState' while running an
-- action, then write the saved value back.
--
-- Any changes the action itself makes to that same field are discarded by
-- the restore. The restore only happens if the action returns normally;
-- if it throws, the modified value is left in place.
withStateLens :: MonadBuild m => Lens' GraphState a -> (a -> a) -> m b -> m b
withStateLens accessor f act = do
    saved <- build (use accessor)
    build (accessor %= f)
    act <* build (accessor .= saved)

-- | Use the given device (or none, for 'Nothing') as the default for every
-- node rendered in the given action, unless further overridden by a nested
-- 'withDevice'.
withDevice :: MonadBuild m => Maybe Device -> m a -> m a
withDevice d = withStateLens defaultDevice (\_ -> d)

-- | Apply an additional name scope to every node rendered in the given
-- action.  The scope is pushed onto the front of the current scope list for
-- the duration of the action.
withNameScope :: MonadBuild m => Text -> m a -> m a
withNameScope s = withStateLens currentScope (\scopes -> Scope s : scopes)

-- | Make every node rendered in the given action depend on the given nodes.
-- The set is unioned with any control inputs already in effect; each
-- pending node then lists them as @"^name"@ control inputs.
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
withNodeDependencies nodes =
    withStateLens defaultControlInputs (\current -> Set.union current nodes)