{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module TensorFlow.BuildOp
( BuildResult(..)
, buildOp
, PureResult(..)
, pureOp
, eqLengthGuard
, BuildInputs(..)
, OpParams
)
where
import Control.Monad (liftM2, replicateM)
import Control.Monad.Reader (ReaderT, runReaderT, ask)
import Control.Monad.State.Strict (State, evalState, get, put)
import Data.Int (Int64)
import TensorFlow.Build
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types
-- | State threaded while an op's results are assembled.
--
-- The 'OutputIx' is the index that the next recorded output will receive.
-- The @['Int64']@ holds the element counts that have not yet been consumed
-- by list-valued results (one count per list, taken from the front).
data ResultState = ResultState !OutputIx [Int64] deriving Show
-- | Monad in which op results are assembled: it reads the 'NodeName' of the
-- op whose outputs are being produced and threads a 'ResultState'.
type Result = ReaderT NodeName (State ResultState)
-- | Class of types that can be produced as the result of building an op.
class BuildResult a where
    -- | Construct a value of the result type inside 'Result', claiming
    -- whatever outputs of the op the instance needs.
    buildResult :: Result a
-- | A pair of results. Components are built left to right, so the first
-- component claims its outputs before the second.
instance (BuildResult a1, BuildResult a2) => BuildResult (a1, a2) where
    buildResult = (,) <$> buildResult <*> buildResult
-- | A triple of results, built component by component from left to right.
instance (BuildResult a1, BuildResult a2, BuildResult a3)
         => BuildResult (a1, a2, a3) where
    buildResult = (,,) <$> buildResult <*> buildResult <*> buildResult
-- | A 4-tuple of results, built component by component from left to right.
instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4)
         => BuildResult (a1, a2, a3, a4) where
    buildResult = (,,,)
                  <$> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
-- | A 5-tuple of results, built component by component from left to right.
instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4, BuildResult a5)
         => BuildResult (a1, a2, a3, a4, a5) where
    buildResult = (,,,,)
                  <$> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
-- | A 6-tuple of results, built component by component from left to right.
instance ( BuildResult a1
         , BuildResult a2
         , BuildResult a3
         , BuildResult a4
         , BuildResult a5
         , BuildResult a6
         )
         => BuildResult (a1, a2, a3, a4, a5, a6) where
    buildResult = (,,,,,)
                  <$> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
-- | A 7-tuple of results, built component by component from left to right.
instance ( BuildResult a1
         , BuildResult a2
         , BuildResult a3
         , BuildResult a4
         , BuildResult a5
         , BuildResult a6
         , BuildResult a7
         )
         => BuildResult (a1, a2, a3, a4, a5, a6, a7) where
    buildResult = (,,,,,,)
                  <$> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
-- | An 8-tuple of results, built component by component from left to right.
instance ( BuildResult a1
         , BuildResult a2
         , BuildResult a3
         , BuildResult a4
         , BuildResult a5
         , BuildResult a6
         , BuildResult a7
         , BuildResult a8
         )
         => BuildResult (a1, a2, a3, a4, a5, a6, a7, a8) where
    buildResult = (,,,,,,,)
                  <$> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
-- | Claim the next output of the op being built.
--
-- Reads the op's 'NodeName' from the environment, pairs it (via 'output')
-- with the current output index, and advances the index by one. The pending
-- list counts are left untouched.
recordResult :: Result Output
recordResult = do
    o <- ask
    ResultState i ns <- get
    -- Force the new state so increments do not pile up as thunks.
    put $! ResultState (i+1) ns
    return $! output i o
-- | A single tensor result: claims one output with 'recordResult' and wraps
-- it, lifted into @v@ with 'pure', in the 'Tensor' constructor.
instance (TensorKind v, Rendered (Tensor v)) => BuildResult (Tensor v a) where
    buildResult = Tensor . pure <$> recordResult
-- | A control-node result: wraps the op's 'NodeName' without claiming any
-- output index.
instance BuildResult ControlNode where
    buildResult = ControlNode <$> ask
-- | A heterogeneous list of tensors, one per type in @as@.
--
-- The type-level list is walked via 'tensorTypes'; each element claims one
-- output through the 'Tensor' instance, in list order.
instance (TensorKind v, Rendered (Tensor v), TensorTypes as) => BuildResult (TensorList v as) where
    buildResult = loop (tensorTypes :: TensorTypeList as)
      where
        -- The signature is required: 'loop' recurses at a different
        -- type-level list on each step.
        loop :: TensorTypeList bs -> Result (TensorList v bs)
        loop Nil = return Nil
        loop (TensorTypeProxy :/ ls) = do
            t <- buildResult
            ts <- loop ls
            return (t :/ ts)
-- | A homogeneous list of results.
--
-- The number of elements is not known from the type, so it is taken from the
-- front of the pending counts in 'ResultState'; that count is removed before
-- the elements are built. Calls 'error' if no count is left.
instance BuildResult a => BuildResult [a] where
    buildResult = do
        ResultState i ns <- get
        case ns of
            [] -> error $ "Ran out of counts in buildResult. " ++
                          "Likely misuse of buildOp."
            (n : rest) -> do
                -- Drop the consumed count; the output index is unchanged
                -- here and is advanced by the element instances.
                put $! ResultState i rest
                replicateM (fromIntegral n) buildResult
-- | Add a new op to the graph and build its result value.
--
-- The op is registered with 'addNewOp'; 'buildResult' is then run against
-- the resulting node name, starting from output index 0.
buildOp :: BuildResult a
        => [Int64]  -- ^ Element counts consumed, in order, by list-valued results.
        -> OpDef    -- ^ Definition of the op to add.
        -> Build a
buildOp sizes o = do
    n <- addNewOp o
    return $ flip evalState (ResultState 0 sizes) (runReaderT buildResult n)
-- | Check that the lengths grouped under each number attribute all agree.
--
-- Each entry pairs a number-attribute name with a list of @(name, length)@
-- pairs. Within an entry, every length must equal the first one; an entry
-- with no pairs passes trivially. The first component of each pair is not
-- inspected and only appears in the error message.
--
-- Returns 'True' when every entry is consistent. On the first inconsistent
-- entry it calls 'error' instead of returning 'False'.
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
eqLengthGuard = all eachOk
  where
    eachOk (_, []) = True
    -- The first length is the reference that the remaining ones must match.
    eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs ||
        error ("number_attr " ++ numberAttrName ++
               " contains tensors with different length " ++ show pairs)
-- | Class of types that can be produced as the result of a pure op.
class PureResult a where
    -- | Construct the result, given a 'Build' action that yields the op's
    -- 'OpDef' and the same 'ResultState' bookkeeping used for output indices.
    pureResult :: ReaderT (Build OpDef) (State ResultState) a
-- | A single tensor result of a pure op.
--
-- Claims the current output index immediately, but defers graph
-- construction: the returned 'Tensor' holds a 'Build' action that runs the
-- op-producing action, passes the 'OpDef' to 'getOrAddOp', and pairs the
-- resulting node name with the claimed index via 'output'.
instance PureResult (Tensor Build a) where
    pureResult = do
        ResultState i ns <- get
        put $! ResultState (i+1) ns
        makeOp <- ask
        return $ Tensor $ do
            o <- makeOp
            output i <$> getOrAddOp o
-- | A pair of pure results. Components are built left to right, so the first
-- component claims its output indices before the second.
instance (PureResult a1, PureResult a2) => PureResult (a1, a2) where
    pureResult = (,) <$> pureResult <*> pureResult
-- | A triple of pure results, built component by component from left to
-- right.
instance (PureResult a1, PureResult a2, PureResult a3)
         => PureResult (a1, a2, a3) where
    pureResult = (,,) <$> pureResult <*> pureResult <*> pureResult
-- | A 4-tuple of pure results, built component by component from left to
-- right.
instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4)
         => PureResult (a1, a2, a3, a4) where
    pureResult = (,,,)
                 <$> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
-- | A 5-tuple of pure results, built component by component from left to
-- right.
instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4, PureResult a5)
         => PureResult (a1, a2, a3, a4, a5) where
    pureResult = (,,,,)
                 <$> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
-- | A 6-tuple of pure results, built component by component from left to
-- right.
instance ( PureResult a1
         , PureResult a2
         , PureResult a3
         , PureResult a4
         , PureResult a5
         , PureResult a6
         )
         => PureResult (a1, a2, a3, a4, a5, a6) where
    pureResult = (,,,,,)
                 <$> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
-- | Seven-way product of op outputs.
--
-- Each component is obtained from its own 'pureResult'; the seven actions
-- are sequenced left to right, so they thread the shared 'ResultState'
-- in component order.
instance ( PureResult a1
         , PureResult a2
         , PureResult a3
         , PureResult a4
         , PureResult a5
         , PureResult a6
         , PureResult a7
         )
         => PureResult (a1, a2, a3, a4, a5, a6, a7) where
    pureResult = (,,,,,,)
        <$> pureResult
        <*> pureResult
        <*> pureResult
        <*> pureResult
        <*> pureResult
        <*> pureResult
        <*> pureResult
-- | Eight-way product of op outputs.
--
-- Each component is obtained from its own 'pureResult'; the eight actions
-- are sequenced left to right, so they thread the shared 'ResultState'
-- in component order.
instance ( PureResult a1
         , PureResult a2
         , PureResult a3
         , PureResult a4
         , PureResult a5
         , PureResult a6
         , PureResult a7
         , PureResult a8
         )
         => PureResult (a1, a2, a3, a4, a5, a6, a7, a8) where
    pureResult = (,,,,,,,)
        <$> pureResult
        <*> pureResult
        <*> pureResult
        <*> pureResult
        <*> pureResult
        <*> pureResult
        <*> pureResult
        <*> pureResult
-- | A list-valued op output.
--
-- The number of elements is not known from the type, so it is taken from
-- the head of the count list stored in 'ResultState'. That count is popped
-- (the output index is left untouched by this instance itself) and the
-- element's 'pureResult' is then repeated that many times.
--
-- Calls 'error' when the count list is already empty.
instance PureResult a => PureResult [a] where
    pureResult = do
        ResultState i ns <- get
        case ns of
            [] -> error $ "Ran out of counts in pureResult. " ++
                          "Likely misuse of pureOp with output lists."
            n : rest -> do
                -- Force the new state before storing it so the strict
                -- State monad does not accumulate a thunk here.
                put $! ResultState i rest
                replicateM (fromIntegral n) pureResult
-- | A heterogeneous list of tensors as an op output.
--
-- Walks the type-level list @as@ via its 'TensorTypeList' witness and
-- produces one 'Tensor' per entry, in order, using the 'pureResult' of
-- the individual tensor.
instance TensorTypes as => PureResult (TensorList Build as) where
    pureResult = loop (tensorTypes :: TensorTypeList as)
      where
        -- Structural recursion over the witness list; the local signature
        -- is required because each step is at a different index @bs@.
        loop :: TensorTypeList bs
             -> ReaderT (Build OpDef) (State ResultState) (TensorList Build bs)
        loop Nil = return Nil
        loop (TensorTypeProxy :/ ls) = do
            t <- pureResult
            ts <- loop ls
            return (t :/ ts)
-- | Construct the pure result(s) of an op.
--
-- Runs 'pureResult' with the op definition @o@ as the reader environment
-- and an initial 'ResultState' whose output index is @0@ and whose count
-- list is @sizes@.
--
-- @sizes@ supplies one entry for every list-valued output in the result
-- type, consumed in order; if there are fewer entries than list outputs,
-- the list instance of 'PureResult' calls 'error'.
pureOp :: PureResult a => [Int64] -> Build OpDef -> a
pureOp sizes o = evalState (runReaderT pureResult o) (ResultState 0 sizes)
-- | Types that can be flattened into a list of 'Output's inside 'Build'.
--
-- NOTE(review): presumably these are the values accepted as op inputs;
-- confirm against the callers of 'buildInputs'.
class BuildInputs a where
    -- | Produce the 'Output's that the given value stands for.
    buildInputs :: a -> Build [Output]
-- | A list of inputs: every element is converted in order and the
-- resulting 'Output' lists are concatenated.
instance BuildInputs a => BuildInputs [a] where
    buildInputs xs = concat <$> mapM buildInputs xs
-- | A single tensor contributes exactly one 'Output'.
--
-- The wrapped value is matched with the 'Tensor' constructor and lifted
-- into 'Build' via 'toBuild'.
instance BuildInputs (Tensor v a) where
    buildInputs (Tensor t) = do
        o <- toBuild t
        return [o]
-- | A heterogeneous list of tensors: the 'Output's of the head tensor
-- followed by those of the tail, computed in that order.
instance BuildInputs (ListOf (Tensor v) as) where
    buildInputs Nil = return []
    buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts)
-- | A transformation applied to an 'OpDef'.
--
-- NOTE(review): presumably used to set per-op parameters (e.g. a name or
-- attributes) on the definition before it is built; confirm against callers.
type OpParams = OpDef -> OpDef