-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}

module TensorFlow.BuildOp
    ( BuildResult(..)
    , buildOp
    , PureResult(..)
    , pureOp
    , eqLengthGuard
    , BuildInputs(..)
    , OpParams
    )
  where

import Control.Monad (liftM2, replicateM)
import Control.Monad.Reader (ReaderT, runReaderT, ask)
import Control.Monad.State.Strict (State, evalState, get, put)
import Data.Int (Int64)

import TensorFlow.Build
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types

-- | State threaded while assembling the results of a single op.
--
-- The 'OutputIx' field is the index that the next recorded output will
-- receive.  The list of 'Int64' holds the element counts that have not yet
-- been consumed by list-valued results (one count per list).
data ResultState = ResultState !OutputIx [Int64] deriving Show

-- | Monad in which op results are assembled: it can read the 'NodeName'
-- supplied to 'runReaderT' and it threads a 'ResultState' through the
-- computation.
type Result = ReaderT NodeName (State ResultState)

-- | Class of types that can be used as op outputs.
--
-- This module provides instances for tensors, control nodes, tensor lists,
-- plain lists and tuples of results.
class BuildResult a where
    -- | Produce one result value inside the 'Result' monad, reading the
    -- node name and updating the 'ResultState' as needed.
    buildResult :: Result a

-- | A pair of results.  Components are built left to right, so they consume
-- the 'ResultState' in the order in which they appear in the tuple.
instance (BuildResult a1, BuildResult a2) => BuildResult (a1, a2) where
    buildResult = (,) <$> buildResult <*> buildResult

-- | A triple of results.  Components are built left to right, so they
-- consume the 'ResultState' in the order in which they appear in the tuple.
instance (BuildResult a1, BuildResult a2, BuildResult a3) => BuildResult (a1, a2, a3) where
    buildResult = (,,) <$> buildResult <*> buildResult <*> buildResult

-- | A 4-tuple of results.  Components are built left to right, so they
-- consume the 'ResultState' in the order in which they appear in the tuple.
instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4)
         => BuildResult (a1, a2, a3, a4) where
    buildResult = (,,,) <$> buildResult <*> buildResult <*> buildResult <*> buildResult

-- | A 5-tuple of results.  Components are built left to right, so they
-- consume the 'ResultState' in the order in which they appear in the tuple.
instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4, BuildResult a5)
         => BuildResult (a1, a2, a3, a4, a5) where
    buildResult = (,,,,) <$> buildResult
                      <*> buildResult
                      <*> buildResult
                      <*> buildResult
                      <*> buildResult

-- | A 6-tuple of results.  Components are built left to right, so they
-- consume the 'ResultState' in the order in which they appear in the tuple.
instance ( BuildResult a1
         , BuildResult a2
         , BuildResult a3
         , BuildResult a4
         , BuildResult a5
         , BuildResult a6
         )
         => BuildResult (a1, a2, a3, a4, a5, a6) where
    buildResult = (,,,,,)
               <$> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult

-- | A 7-tuple of results.  Components are built left to right, so they
-- consume the 'ResultState' in the order in which they appear in the tuple.
instance ( BuildResult a1
         , BuildResult a2
         , BuildResult a3
         , BuildResult a4
         , BuildResult a5
         , BuildResult a6
         , BuildResult a7
         )
         => BuildResult (a1, a2, a3, a4, a5, a6, a7) where
    buildResult = (,,,,,,)
               <$> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult

-- | An 8-tuple of results.  Components are built left to right, so they
-- consume the 'ResultState' in the order in which they appear in the tuple.
instance ( BuildResult a1
         , BuildResult a2
         , BuildResult a3
         , BuildResult a4
         , BuildResult a5
         , BuildResult a6
         , BuildResult a7
         , BuildResult a8
         )
         => BuildResult (a1, a2, a3, a4, a5, a6, a7, a8) where
    buildResult = (,,,,,,,)
               <$> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult

-- | Allocate the next output of the current node.
--
-- Reads the node name from the environment, pairs it with the current
-- output index via 'output', and advances the index stored in the
-- 'ResultState' by one.  The list of pending counts is left untouched.
recordResult :: Result Output
recordResult = do
    o <- ask
    ResultState i ns <- get
    -- Force the new state and the returned output eagerly.
    put $! ResultState (i+1) ns
    return $! output i o

-- | A single tensor result: takes the next output from 'recordResult',
-- lifts it into @v@ with 'pure' and wraps it in the 'Tensor' constructor.
instance (TensorKind v, Rendered (Tensor v)) => BuildResult (Tensor v a) where
    buildResult = Tensor . pure <$> recordResult

-- | A control-node result: wraps the current node name in 'ControlNode'.
-- It does not touch the 'ResultState', so no output index is consumed.
instance BuildResult ControlNode where
    buildResult = ControlNode <$> ask

-- | A heterogeneous list of tensors: one tensor result is built for every
-- entry of the type-level list @as@, in order.
instance (TensorKind v, Rendered (Tensor v), TensorTypes as) => BuildResult (TensorList v as) where
  buildResult = loop (tensorTypes :: TensorTypeList as)
    where
        -- Walk the list of type proxies, building one tensor per proxy.
        -- The type variable @v@ is the one bound by the instance head
        -- (ScopedTypeVariables).
        loop :: TensorTypeList bs -> Result (TensorList v bs)
        loop Nil = return Nil
        loop (TensorTypeProxy :/ ls) = do
            t <- buildResult
            ts <- loop ls
            return (t :/ ts)

-- | A list of results.  The number of elements is taken from the head of the
-- pending counts in the 'ResultState'; that count is removed from the state
-- and the element result is then built that many times.
--
-- Calls 'error' if no counts are left.
instance BuildResult a => BuildResult [a] where
    buildResult = do
        ResultState i ns <- get
        case ns of
            [] -> error $ "Ran out of counts in buildResult. " ++
                          "Likely misuse of buildOp."
            (n : rest) -> do
                -- Drop the consumed count before building the elements.
                put $! ResultState i rest
                replicateM (fromIntegral n) buildResult

-- | Add a new op to the graph and assemble its outputs.
--
-- The 'OpDef' is registered with 'addNewOp'; the resulting node name is then
-- used to run 'buildResult', starting at output index 0.  The @sizes@
-- argument supplies the element counts consumed by list-valued results, one
-- count per list, in order.
buildOp :: BuildResult a => [Int64] -> OpDef -> Build a
buildOp sizes o = do
    n <- addNewOp o
    return $ flip evalState (ResultState 0 sizes) (runReaderT buildResult n)

-- | Returns true if all the integers in each tuple are identical.
-- Throws an error with a descriptive message if not.
--
-- Each entry pairs an attribute name with a list of @(String, Int)@ pairs.
-- Only the 'Int' components are compared; the 'String' components appear
-- solely in the error message.  An entry with no pairs is accepted.
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
eqLengthGuard = all eachOk
  where
    eachOk :: (String, [(String, Int)]) -> Bool
    eachOk (_, []) = True
    -- The next line has (== 1) . length . nub in disguise
    eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs ||
        error ("number_attr " ++ numberAttrName ++
               " contains tensors with different length " ++ show pairs)

-----------


-- | Class of types that can be used as op outputs.
--
-- Unlike 'BuildResult', whose environment is a 'NodeName', the environment
-- here is a @'Build' 'OpDef'@ action.
-- NOTE(review): presumably this is the class used by 'pureOp' — confirm.
class PureResult a where
    -- | Produce one result value, reading the @'Build' 'OpDef'@ action from
    -- the environment and threading the 'ResultState'.
    pureResult :: ReaderT (Build OpDef) (State ResultState) a

-- | A single tensor result of a pure op.
--
-- Takes the current output index and advances it in the 'ResultState'.
-- The returned 'Tensor' wraps a 'Build' action that runs the op-producing
-- action from the environment, passes the resulting 'OpDef' to
-- 'getOrAddOp', and pairs the node name with the reserved output index.
-- None of that 'Build' work happens while 'pureResult' itself runs.
instance PureResult (Tensor Build a) where
    pureResult = do
        ResultState i ns <- get
        put $! ResultState (i+1) ns
        makeOp <- ask
        return $ Tensor $ do
            o <- makeOp
            -- TODO: unify with BuildResult (Tensor v)
            output i <$> getOrAddOp o

-- | A pair of pure results.  Components are built left to right, so they
-- consume the 'ResultState' in the order in which they appear in the tuple.
instance (PureResult a1, PureResult a2) => PureResult (a1, a2) where
    pureResult = (,) <$> pureResult <*> pureResult

-- | A triple of pure results.  Components are built left to right, so they
-- consume the 'ResultState' in the order in which they appear in the tuple.
instance (PureResult a1, PureResult a2, PureResult a3) => PureResult (a1, a2, a3) where
    pureResult = (,,) <$> pureResult <*> pureResult <*> pureResult

-- | A 4-tuple of pure results.  Components are built left to right, so they
-- consume the 'ResultState' in the order in which they appear in the tuple.
instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4)
         => PureResult (a1, a2, a3, a4) where
    pureResult = (,,,) <$> pureResult <*> pureResult <*> pureResult <*> pureResult

-- | A 5-tuple of pure results.  Components are built left to right, so they
-- consume the 'ResultState' in the order in which they appear in the tuple.
instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4, PureResult a5)
         => PureResult (a1, a2, a3, a4, a5) where
    pureResult = (,,,,) <$> pureResult
                      <*> pureResult
                      <*> pureResult
                      <*> pureResult
                      <*> pureResult

-- | A 6-tuple of pure results. Each component is produced by its own
-- 'pureResult'; the applicative chain runs them left to right, so they
-- share the same reader environment and thread one 'ResultState'.
instance ( PureResult a1
         , PureResult a2
         , PureResult a3
         , PureResult a4
         , PureResult a5
         , PureResult a6
         )
         => PureResult (a1, a2, a3, a4, a5, a6) where
    pureResult = (,,,,,)
               <$> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult

-- | A 7-tuple of pure results. Each component is produced by its own
-- 'pureResult'; the applicative chain runs them left to right, so they
-- share the same reader environment and thread one 'ResultState'.
instance ( PureResult a1
         , PureResult a2
         , PureResult a3
         , PureResult a4
         , PureResult a5
         , PureResult a6
         , PureResult a7
         )
         => PureResult (a1, a2, a3, a4, a5, a6, a7) where
    pureResult = (,,,,,,)
               <$> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult

-- | An 8-tuple of pure results. Each component is produced by its own
-- 'pureResult'; the applicative chain runs them left to right, so they
-- share the same reader environment and thread one 'ResultState'.
instance ( PureResult a1
         , PureResult a2
         , PureResult a3
         , PureResult a4
         , PureResult a5
         , PureResult a6
         , PureResult a7
         , PureResult a8
         )
         => PureResult (a1, a2, a3, a4, a5, a6, a7, a8) where
    pureResult = (,,,,,,,)
               <$> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult

-- | A list of pure results. The length of the list is taken from the head
-- of the pending counts stored in 'ResultState'; that count is removed
-- from the state before the elements are produced.
--
-- Calls 'error' if no counts remain.
instance PureResult a => PureResult [a] where
    pureResult = do
        ResultState i ns <- get
        case ns of
            [] -> error $ "Ran out of counts in pureResult. " ++
                          "Likely misuse of pureOp with output lists."
            n : rest -> do
                -- Pop the consumed count; the output index is left as is
                -- here. '$!' forces the new state before it is stored.
                put $! ResultState i rest
                replicateM (fromIntegral n) pureResult

-- | A heterogeneous list of tensors. One tensor is produced per entry of
-- the type-level list @as@ by walking the 'TensorTypeList' witness
-- obtained from 'tensorTypes'.
instance TensorTypes as => PureResult (TensorList Build as) where
    pureResult = loop (tensorTypes :: TensorTypeList as)
      where
        -- Recurse over the witness; the head tensor is produced before
        -- the tail, so results follow the order of the type list.
        loop :: TensorTypeList bs -> ReaderT (Build OpDef) (State ResultState)
                                        (TensorList Build bs)
        loop Nil = return Nil
        loop (TensorTypeProxy :/ ls) = do
            t <- pureResult
            ts <- loop ls
            return (t :/ ts)

-- | Construct the result(s) of an op purely.
--
-- Runs 'pureResult' with the given @'Build' 'OpDef'@ as the reader
-- environment and @ResultState 0 sizes@ as the initial state. @sizes@
-- supplies the element counts consumed by list-valued results.
pureOp :: PureResult a => [Int64] -> Build OpDef -> a
pureOp sizes o = evalState (runReaderT pureResult o) (ResultState 0 sizes)

-----
-- Class of types that can be used as arguments

-- | Class of types that can be used as arguments.
class BuildInputs a where
    -- | Produce, inside 'Build', the list of 'Output's for the given value.
    buildInputs :: a -> Build [Output]

-- | A list of inputs: the outputs of every element are computed in order
-- and concatenated into a single flat list.
instance BuildInputs a => BuildInputs [a] where
    buildInputs = fmap concat . mapM buildInputs

-- | A single tensor contributes exactly one 'Output', obtained by
-- converting the wrapped value with 'toBuild'.
instance BuildInputs (Tensor v a) where
    buildInputs (Tensor t) = do
        o <- toBuild t
        return [o]

-- | A heterogeneous list of tensors: the outputs of the head come first,
-- followed by the outputs of the rest of the list.
instance BuildInputs (ListOf (Tensor v) as) where
    buildInputs Nil = return []
    buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts)

----

-- | Parameters to build an op (for example, the node name or optional
-- attributes), represented as a function that transforms an 'OpDef'.
--
-- TODO: be more type safe.
type OpParams = OpDef -> OpDef