-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- | This module contains definitions for some built-in TensorFlow operations.
--
-- Note that certain "stateful" ops like 'variable' and 'assign' return a
-- 'Build' action (e.g., @Build (Tensor Ref a)@) instead of a pure value; the
-- returned 'Tensor's are always rendered in the current 'Build' context.  This
-- approach helps us avoid problems with inlining or common subexpression
-- elimination, by writing
--
-- > do
-- >     v <- variable []
-- >     w <- assign v 3
-- >     render $ w * w
--
-- instead of
--
-- > let
-- >    v = variable []
-- >    w = assign v 3
-- > in w * w
--
-- since the latter could be reasonably transformed by the compiler into (or
-- vice versa)
--
-- > let
-- >    v = variable []
-- >    w = assign v 3
-- >    w' = assign v 3
-- > in w * w'
--
-- Ops should return a 'Build' action if their original 'OpDef' marks them as
-- stateful, or if they take any Refs as input.  (This mirrors the rules that
-- TensorFlow uses to avoid common subexpression elimination.)
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module TensorFlow.Ops
    ( CoreOps.add
    , CoreOps.add'
    , CoreOps.abs
    , CoreOps.abs'
    , CoreOps.addN
    , CoreOps.addN'
    , CoreOps.argMax
    , CoreOps.argMax'
    , CoreOps.assign
    , CoreOps.assign'
    , CoreOps.broadcastGradientArgs
    , CoreOps.broadcastGradientArgs'
    , CoreOps.cast
    , CoreOps.cast'
    , CoreOps.concat
    , CoreOps.concat'
    , constant
    , constant'
    , CoreOps.equal
    , CoreOps.equal'
    , expandDims
    , expandDims'
    , initializedVariable
    , initializedVariable'
    , zeroInitializedVariable
    , zeroInitializedVariable'
    , CoreOps.fill
    , CoreOps.fill'
    , CoreOps.identity
    , CoreOps.identity'
    , CoreOps.matMul
    , CoreOps.matMul'
    , CoreOps.einsum
    , CoreOps.einsum'
    , matTranspose
    , matTranspose'
    , CoreOps.mean
    , CoreOps.mean'
    , CoreOps.mul
    , CoreOps.mul'
    , CoreOps.neg
    , CoreOps.neg'
    , CoreOps.oneHot
    , CoreOps.oneHot'
    , CoreOps.pack
    , CoreOps.pack'
    , placeholder
    , placeholder'
    , CoreOps.range
    , CoreOps.range'
    , reducedShape
    , reduceMean
    , reduceMean'
    , CoreOps.relu
    , CoreOps.relu'
    , CoreOps.reluGrad
    , CoreOps.reluGrad'
    , CoreOps.tanh
    , CoreOps.tanhGrad
    , CoreOps.reshape
    , CoreOps.reshape'
    , restore
    , restoreFromName
    , save
    , scalar
    , scalar'
    , shape
    , shape'
    , CoreOps.sigmoid
    , CoreOps.sigmoidGrad
    , CoreOps.sign
    , CoreOps.sign'
    , CoreOps.size
    , CoreOps.size'
    , CoreOps.softmax
    , CoreOps.softmax'
    , CoreOps.softmaxCrossEntropyWithLogits
    , CoreOps.softmaxCrossEntropyWithLogits'
    , CoreOps.sparseToDense
    , CoreOps.sparseToDense'
    , CoreOps.sub
    , CoreOps.sub'
    , CoreOps.sum
    , CoreOps.sum'
    , reduceSum
    , reduceSum'
    , CoreOps.transpose
    , CoreOps.transpose'
    , truncatedNormal
    , truncatedNormal'
    , CoreOps.variable
    , CoreOps.variable'
    , vector
    , vector'
    , zeros
    , CoreOps.zerosLike
    , CoreOps.zerosLike'
    , scalarize
    ) where

import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Int (Int32, Int64)
import Data.Word (Word16)
import Prelude hiding (abs, sum, concat)
import Data.ProtoLens.Default(def)
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 ((.~), (&))
import Text.Printf (printf)
import Proto.Tensorflow.Core.Framework.Tensor  (TensorProto)
import Proto.Tensorflow.Core.Framework.Tensor_Fields
    ( dtype
    , tensorShape
    )
import qualified Proto.Tensorflow.Core.Framework.TensorShape_Fields
  as TensorShape

import TensorFlow.Build
import TensorFlow.BuildOp
import TensorFlow.ControlFlow (group)
import TensorFlow.Tensor
import TensorFlow.Types

import qualified TensorFlow.GenOps.Core as CoreOps

import qualified Prelude (abs)

-- TODO: Look into hs-boot refactoring to allow mutually recursive imports.
-- | Must be defined as an orphan because of the dependency order between Ops
-- and Tensor.
--
-- The instance head matches any @Tensor v a@, and the indirect constraint
-- @v ~ Build@ then fixes @v@.  This helps disambiguate types: for example, in
-- @neg 1 :: Tensor Build Float@ it determines the type of the subexpression
-- @1@.
--
-- Each method delegates to the corresponding op in "TensorFlow.GenOps.Core";
-- integer literals are turned into scalar constants via 'scalar'.
instance ( TensorType a
         , Num a
         , v ~ Build
         , OneOf '[ Double, Float, Int32, Int64
                  , Complex Float, Complex Double] a) => Num (Tensor v a) where
    (+)         = CoreOps.add
    (*)         = CoreOps.mul
    (-)         = CoreOps.sub
    abs         = CoreOps.abs
    signum      = CoreOps.sign
    negate      = CoreOps.neg
    -- Convert the literal to the element type, then wrap it as a constant.
    fromInteger = scalar . fromInteger

-- | Transposes a tensor by applying the axis permutation @[1, 0]@, i.e. by
-- swapping its two axes.
--
-- NOTE(review): the permutation has exactly two entries, so this presumably
-- only accepts rank-2 (matrix) inputs; confirm against 'CoreOps.transpose'.
matTranspose :: TensorType a => Tensor v a -> Tensor Build a
matTranspose = matTranspose' id

-- | Like 'matTranspose', but applies the given 'OpParams' to the underlying
-- @CoreOps.transpose'@ op.
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Build a
matTranspose' params x = CoreOps.transpose' params x swapAxes
  where
    -- Output axis 0 takes input axis 1 and vice versa.
    swapAxes = vector [1, 0 :: Int32]

-- | Builds a @Placeholder@ op with the given shape.  Its @dtype@ attribute is
-- taken from the requested element type @a@.
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
placeholder = placeholder' id

-- | Like 'placeholder', but applies the given 'OpParams' to the op
-- definition.  @params@ is applied after the @dtype@ and @shape@ attributes
-- are set, so it can override them.
placeholder' :: forall m a . (MonadBuild m, TensorType a)
             => OpParams -> Shape -> m (Tensor Value a)
placeholder' params pShape
    -- Note: we don't use CoreOps.placeholder' since that op isn't stateful,
    -- and thus would be CSE'd.
    = build $ buildOp [] $ opDef "Placeholder"
                & opAttr "dtype" .~ tensorType (undefined :: a)
                & opAttr "shape" .~ pShape
                & params

-- | Creates a variable initialized to the given value.
-- Initialization happens the next time the session runs.
initializedVariable :: (MonadBuild m, TensorType a)
                    => Tensor v a -> m (Tensor Ref a)
initializedVariable = initializedVariable' id

-- | Like 'initializedVariable', but applies the given 'OpParams' to the
-- underlying @CoreOps.variable'@ op.  The assignment of @initializer@ is
-- registered with 'addInitializer', and the (uninitialized) variable itself
-- is returned.
initializedVariable' :: (MonadBuild m, TensorType a)
                    => OpParams -> Tensor v a -> m (Tensor Ref a)
initializedVariable' params initializer = do
    -- The shape is not known initially, so the variable is declared with an
    -- empty shape and the initializing assignment disables shape validation.
    v <- CoreOps.variable' params []
    i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v initializer
    addInitializer =<< group i
    pure v

-- | Creates a zero-initialized variable with the given shape.
zeroInitializedVariable
  :: (MonadBuild m, TensorType a, Num a) => Shape -> m (Tensor Ref a)
zeroInitializedVariable = zeroInitializedVariable' id

-- | Like 'zeroInitializedVariable', but passes the given 'OpParams' through
-- to 'initializedVariable''.  The initial value is @'zeros' shape@.
zeroInitializedVariable'
  :: (MonadBuild m, TensorType a, Num a)
  => OpParams -> Shape -> m (Tensor Ref a)
zeroInitializedVariable' params = initializedVariable' params . zeros

-- TODO: Support heterogeneous list of tensors.
-- | Builds a @Save@ op that writes the given tensors to a checkpoint file.
--
-- Each tensor is stored under the name of its rendered output.  All tensors
-- must share the element type @a@, which is recorded once per tensor in the
-- op's @T@ attribute.
save :: forall a m v . (Rendered (Tensor v), MonadBuild m, TensorType a)
        => ByteString    -- ^ File path.
        -> [Tensor v a]  -- ^ Tensors to save.
        -> m ControlNode
save path xs = build $ do
    let toByteStringTensor =
            scalar . encodeUtf8 . encodeOutput . renderedOutput
        names = fmap toByteStringTensor xs
        types = replicate (length xs) (tensorType (undefined :: a))
    names' <- buildInputs $ CoreOps.pack names
    xs' <- buildInputs xs
    path' <- buildInputs $ scalar path
    -- Input order: file path, packed tensor names, then the tensor values.
    buildOp [] $ opDef "Save"
                    & opAttr "T" .~ types
                    & opInputs .~ (path' ++ names' ++ xs')

-- | Restore a tensor's value from a checkpoint file.
--
-- This version allows restoring from a checkpoint file that uses a different
-- tensor name than the variable.  A @Restore@ op reads the entry called
-- @name@ from the file at @path@, and the result is assigned to @x@; the
-- returned 'ControlNode' groups that assignment.
restoreFromName :: forall a m . (MonadBuild m, TensorType a)
                => ByteString    -- ^ File path.
                -> ByteString    -- ^ Tensor name override.
                -> Tensor Ref a  -- ^ Tensor to restore.
                -> m ControlNode
restoreFromName path name x = build $ do
    path' <- buildInputs $ scalar path
    name' <- buildInputs $ scalar name
    restoreOp <- buildOp [] $ opDef "Restore"
                               & opAttr "dt" .~ tensorType (undefined :: a)
                               & opInputs .~ (path' ++ name')
    group =<< CoreOps.assign x (restoreOp :: Tensor Value a)

-- | Restore a tensor's value from a checkpoint file.
--
-- The checkpoint entry is looked up under the rendered output name of @x@,
-- which is the same naming scheme 'save' uses when writing tensors.
restore :: forall a m . (MonadBuild m, TensorType a)
        => ByteString    -- ^ File path.
        -> Tensor Ref a  -- ^ Tensor to restore.
        -> m ControlNode
restore path x = restoreFromName path checkpointName x
  where
    checkpointName :: ByteString
    checkpointName = encodeUtf8 $ encodeOutput $ renderedOutput x

-- | Create a constant tensor.
--
-- The values should be in row-major order, e.g.,
--
-- > element 0:   index (0, ..., 0)
-- > element 1:   index (0, ..., 1)
-- > ...
--
-- Calls 'error' if the number of values differs from the product of the
-- shape's dimensions.
constant :: TensorType a => Shape -> [a] -> Tensor Build a
constant = constant' id

-- | Like 'constant', but applies the given 'OpParams' to the underlying
-- @Const@ op.  @params@ is composed after the @value@ attribute is set, so it
-- can override it.
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Build a
constant' params (Shape cShape) values
    | invalidLength = error invalidLengthMsg
    | otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
  where
    expectedLength :: Int64
    expectedLength = product cShape

    actualLength :: Int
    actualLength = length values

    invalidLength :: Bool
    invalidLength = expectedLength /= fromIntegral actualLength

    invalidLengthMsg :: String
    invalidLengthMsg = printf "invalid tensor length: expected %d got %d"
                              expectedLength
                              actualLength

    -- The serialized constant: element dtype, one dim entry per shape
    -- dimension, and the flattened values.
    typedNode :: TensorProto
    typedNode = def
                & dtype .~ tensorType (undefined :: a)
                & tensorShape . TensorShape.dim .~
                      [def & TensorShape.size .~ x | x <- cShape]
                & tensorVal .~ values

-- | Reshape an N-D tensor down to a scalar by passing an empty shape vector
-- to 'CoreOps.reshape'.
--
-- NOTE(review): presumably the input must hold exactly one element for the
-- reshape to succeed; confirm against the op's documentation.
--
-- See 'TensorFlow.GenOps.Core.reshape'.
scalarize :: TensorType a => Tensor v a -> Tensor Build a
scalarize t = CoreOps.reshape t (vector scalarShape)
  where
    scalarShape :: [Int32]
    scalarShape = []

-- | Sum a tensor down to a scalar by reducing over every axis.
--
-- See 'TensorFlow.GenOps.Core.sum'.
reduceSum :: (OneOf '[ Double, Float, Int32, Int64
                     , Complex Float, Complex Double] a) =>
             Tensor v a -> Tensor Build a
-- Delegate to the primed variant, as 'reduceMean' does, so the all-axes
-- computation lives in one place.
reduceSum = reduceSum' id

-- | Like 'reduceSum', but applies the given 'OpParams' to the underlying
-- @CoreOps.sum'@ op.
reduceSum' :: (OneOf '[ Double, Float, Int32, Int64
                      , Complex Float, Complex Double] a) =>
              OpParams -> Tensor v a -> Tensor Build a
reduceSum' params x = CoreOps.sum' params x allAxes
  where
    -- Axes 0 .. rank x - 1, i.e. @range 0 (rank x) 1@.
    allAxes :: Tensor Build Int32
    allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1

-- | Computes the mean of all elements of a tensor by reducing over every
-- axis, producing a scalar.
--
-- See 'TensorFlow.GenOps.Core.mean'.
reduceMean
  :: ( TensorType a
     , OneOf '[ Double, Float, Complex Float, Complex Double] a
     )
  => Tensor v a -> Tensor Build a
reduceMean = reduceMean' id

-- | Like 'reduceMean', but applies the given 'OpParams' to the underlying
-- @CoreOps.mean'@ op.
reduceMean'
  :: ( TensorType a
     , OneOf '[ Double, Float, Complex Float, Complex Double] a
     )
  => OpParams -> Tensor v a -> Tensor Build a
reduceMean' params x = CoreOps.mean' params x allAxes
  where
    -- Axes 0 .. rank x - 1, i.e. @range 0 (rank x) 1@.
    allAxes :: Tensor Build Int32
    allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1

-- | Create a constant vector whose single dimension is the list's length.
vector :: TensorType a => [a] -> Tensor Build a
vector = vector' id

-- | Like 'vector', but applies the given 'OpParams' via 'constant''.
vector' :: TensorType a => OpParams -> [a] -> Tensor Build a
vector' params xs = constant' params [fromIntegral (length xs)] xs

-- | Create a constant scalar (a tensor with an empty shape).
scalar :: TensorType a => a -> Tensor Build a
scalar = scalar' id

-- | Like 'scalar', but applies the given 'OpParams' via 'constant''.
scalar' :: TensorType a => OpParams -> a -> Tensor Build a
scalar' params x = constant' params [] [x]

-- | Random tensor from the unit normal distribution with bounded values.
--
-- This is a type-restricted version of 'TensorFlow.GenOps.Core.truncatedNormal'
-- in which the shape argument is fixed to 'Int64'.
truncatedNormal :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
                => Tensor v Int64  -- ^ Shape.
                -> m (Tensor Value a)
truncatedNormal = CoreOps.truncatedNormal

-- | Like 'truncatedNormal', but with the given 'OpParams' applied to the
-- generated op.
truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
                 => OpParams        -- ^ Extra parameters for the generated op.
                 -> Tensor v Int64  -- ^ Shape.
                 -> m (Tensor Value a)
truncatedNormal' = CoreOps.truncatedNormal'

-- | A tensor of the given 'Shape' with every element set to zero.
--
-- The dimensions are embedded as a constant 'vector' and the zero value as a
-- constant 'scalar', which 'TensorFlow.GenOps.Core.fill' then broadcasts.
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Build a
zeros (Shape dims) = CoreOps.fill dimsTensor zero
  where
    dimsTensor :: Tensor Build Int64
    dimsTensor = vector dims
    -- 'a' is the element type scoped from the signature above.
    zero :: Tensor Build a
    zero = scalar 0

-- | The shape of a tensor, as a rank-1 'Int32' tensor.
--
-- A type-restricted version of 'TensorFlow.GenOps.Core.shape' whose output
-- type is fixed to 'Int32'.
shape :: TensorType t => Tensor v t -> Tensor Build Int32
shape = CoreOps.shape

-- | Like 'shape', but with the given 'OpParams' applied to the generated op.
shape' :: TensorType t => OpParams -> Tensor v t -> Tensor Build Int32
shape' = CoreOps.shape'

-- | A type-restricted version of 'TensorFlow.GenOps.Core.expandDims' whose
-- axis argument is fixed to 'Int32'.
--
-- Per TensorFlow's @ExpandDims@ op definition, the result is the input with
-- a size-1 dimension inserted at the given axis.
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
expandDims = CoreOps.expandDims

-- | Like 'expandDims', but with the given 'OpParams' applied to the generated
-- op.
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
expandDims' = CoreOps.expandDims'

-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
--
-- Given the shape of an input tensor and the axes being reduced over, returns
-- that shape with every reduced axis replaced by 1. Axes are normalised
-- modulo the input rank, so negative axes (e.g. @-1@) are accepted.
--
-- Worked example (values shown in the comments below): an input shape of
-- @[2, 3, 5, 7]@ reduced over axes @[1, 2]@ yields @[2, 1, 1, 7]@.
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
                Tensor v1 t1 -> Tensor v2 t2 -> Tensor Build Int32
reducedShape inputShape axes =
    -- Later index lists win in DynamicStitch, so the reduced axes overwrite
    -- the corresponding entries of the original shape with 1.
    CoreOps.dynamicStitch                        -- [2, 1, 1, 7]
        [allDims, reducedAxes]                   -- [0, 1, 2, 3], [1, 2]
        [shape32, ones]                          -- [2, 3, 5, 7], [1, 1]
  where
    -- Polymorphic in the source element type: used for both t1 and t2.
    asInt32 :: TensorType src => Tensor w src -> Tensor Build Int32
    asInt32 t = CoreOps.cast t

    shape32, rank, allDims, reducedAxes, ones :: Tensor Build Int32
    shape32 = asInt32 inputShape                 -- [2, 3, 5, 7]
    rank = CoreOps.size shape32                  -- 4
    allDims = CoreOps.range 0 rank 1             -- [0, 1, 2, 3]
    -- Adding the rank before taking the modulus maps negative axes onto
    -- their non-negative equivalents.
    reducedAxes = (asInt32 axes + rank) `CoreOps.mod` rank   -- [1, 2]
    ones = CoreOps.fill (shape reducedAxes) 1                -- [1, 1]