-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeApplications #-}

module TensorFlow.Gradient
    ( GradientCompatible
    , gradients
    ) where

import Control.Monad (forM, zipWithM)
import Control.Monad.State.Strict (State, evalState, gets, modify)
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.ProtoLens.Default(def)
import Data.Int (Int32, Int64)
import Data.Foldable (foldlM)
import Data.List (foldl', sortBy)
import Data.Map.Strict (Map)
import qualified Data.IntSet as IntSet
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
import Data.Ord (comparing)
import Data.ProtoLens.TextFormat (showMessage)
import Data.Set (Set)
import Data.Text (Text)
import Data.Tuple (swap)
import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~), under)
import Lens.Family2.State.Strict (uses)
import Lens.Family2.Stock (at, intAt)
import Lens.Family2.Unchecked (lens, adapter)
import Prelude hiding (sum, tanh)
import Text.Printf (printf)
import qualified Data.Graph.Inductive.Basic as FGL
import qualified Data.Graph.Inductive.Graph as FGL
import qualified Data.Graph.Inductive.PatriciaTree as FGL
import qualified Data.Graph.Inductive.Query.DFS as FGL
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as Text

import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Build
    ( MonadBuild
    , Build
    , build
    , renderedNodeDefs
    , opDef
    , opAttr
    , opInputs
    )
import TensorFlow.BuildOp
import TensorFlow.Ops
    ( addN
    , broadcastGradientArgs
    , expandDims
    , fill
    , matMul
    , matMul'
    , reducedShape
    , reluGrad
    , tanh
    , tanhGrad
    , reshape
    , scalar
    , shape
    , softmaxCrossEntropyWithLogits
    , sum
    , sigmoid
    , sigmoidGrad
    , scalarize
    , vector
    , zerosLike
    )
import TensorFlow.Output
    ( NodeName(..)
    , Output(..)
    , OutputIx(..)
    , outputIndex
    )
import TensorFlow.Tensor
    ( Tensor(..)
    , Value
    , render
    , expr
    , Rendered
    , tensorNodeName
    , renderedOutput
    , renderValue
    , ToTensor(..)
    )
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef)
import Proto.Tensorflow.Core.Framework.NodeDef_Fields
    ( attr, input, op, name)

type GradientCompatible a =
    -- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason.
    (Num a, OneOf '[ Float, Complex Float, Complex Double ] a)

-- TODO(fmayle): Support control flow.
-- TODO(fmayle): Support gate_gradients-like option to avoid race conditions.
-- TODO(fmayle): Do we need to consider control inputs? See _PendingCount in
-- tensorflow/python/ops/gradients.py.
-- TODO(fmayle): Maybe store the gradient functions and numOutputs on the OpDef.


-- | Gradient of @y@ w.r.t. each element of @xs@.
gradients :: forall a v1 t m . ( MonadBuild m
                               , Rendered t
                               , ToTensor t
                               , GradientCompatible a
                               )
          => Tensor v1 a  -- ^ The output of the graph.
          -> [t a]        -- ^ Tensors for which gradients are computed.
          -> m [Tensor Value a]
gradients y xs = build $ do
    -- The gradients are computed using "reverse accumulation", similarly to
    -- what is described here:
    -- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation
    --
    -- The code is summarised as follows:
    --
    -- 1. Create an fgl graph of the relevant nodes (ops) and edges (tensors).
    -- 2. Initialize the gradient of y to 1 (∂y/∂y = 1) and the rest of the
    --    tensors' gradients to nothing.
    -- 3. Process the nodes in reverse topological order (i.e. each node comes
    --    after all of its outputs so that the output gradients for a node have
    --    been completely calculated before it is processed):
    --      a. Record the gradient for each of the node's output tensors (∂y/∂w
    --         for each output tensor w).
    --      b. Calculate the gradient of y w.r.t. each of the node's input
    --         tensors using the gradients of the node's output tensors.
    --
    --         Written differently, for each output tensor w and input tensor v:
    --           ∂y/∂w = ...            (calculated in previous steps)
    --           ∂w/∂v = ...            (op specific)
    --           ∂y/∂v = ∂y/∂w * ∂w/∂v  (technically, if tensor v is an input
    --                                   to multiple nodes, then this is only
    --                                   part of ∂y/∂v)
    --
    -- 4. Lookup the recorded gradient for each x in xs.
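    --
    -- As a concrete instance of step 3b above: if tensor v feeds two different
    -- nodes whose outputs are w_1 and w_2, both contributions are accumulated
    -- as pending gradients and later summed, giving
    --   ∂y/∂v = ∂y/∂w_1 * ∂w_1/∂v + ∂y/∂w_2 * ∂w_2/∂v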

    y' <- renderValue y
    let yName = tensorNodeName y'
    yOne <- render $ fill (shape y') (scalar 1)
    -- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName?
    nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
        (\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
        . flip Map.lookup
    let (gr, nodeMap) = createGraph yName nodeDefLookup
        xnodes = mapMaybe (\x -> nodeMap ^. at (outputNodeName $ renderedOutput x)) xs
        -- make a set of the nodes reachable from the xnodes
        -- The xnodes are not part of this set (unless reachable from another xnode)
        reachableSet = computeReachableSet xnodes gr

    -- Set gradient of y to one.
    -- TODO: nicer
    let initPending :: Map.Map FGL.Node (PendingGradients a)
            = Map.empty & (at (nodeMap Map.! yName)
                                . nonEmpty
                                . outputIxAt (outputIndex $ renderedOutput y')
                                . nonEmpty
                                .~ [yOne]
                                )
    -- Calculate the gradients of y w.r.t. each node in the graph.
    gradientMap <- graphGrads gr reachableSet initPending
    -- Lookup the gradients for each x.
    forM xs $ \x ->
        let Output i xName = renderedOutput x
        in maybe (render $ zerosLike $ toTensor x) return $ do
            n <- nodeMap ^. at xName
            gradientMap ^. at n . nonEmpty . outputIxAt i
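
-- A minimal usage sketch of 'gradients', in the style of this package's
-- gradient tests (it assumes a qualified import @TF@ covering TensorFlow.Core,
-- TensorFlow.Ops, TensorFlow.GenOps.Core, and this module; values are
-- illustrative):
--
-- > [dx, db] <- TF.runSession $ do
-- >     x <- TF.render $ TF.scalar (3 :: Float)
-- >     b <- TF.render $ TF.scalar (4 :: Float)
-- >     let y = x `TF.mul` x `TF.add` b
-- >     TF.gradients y [x, b] >>= TF.run
-- > -- dx evaluates to 6 (∂y/∂x) and db to 1 (∂y/∂b).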

-- | Compute a set of nodes reachable from the start nodes
--
-- the start nodes are excluded, unless reachable from another start node
computeReachableSet :: [FGL.Node] -> Graph -> IntSet.IntSet
computeReachableSet vs g =
  IntSet.fromList $ concatMap (drop 1 . FGL.preorder) (FGL.dff vs g)
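-- Note: 'FGL.dff vs g' returns one depth-first tree per start node, and
-- 'drop 1' of each tree's preorder listing removes that tree's root, which is
-- what keeps the start nodes themselves out of the set unless another start
-- node reaches them.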

outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
outputIxAt = intAt . unOutputIx

-- | Incomplete gradients of a node's outputs.
--
-- The lists represent partial sums. The key is an OutputIx sans newtype.
type PendingGradients a = IntMap.IntMap [Tensor Value a]

-- | Gradients of a node's outputs. The key is an OutputIx sans newtype.
-- TODO: precache the rendering?
type Gradients a = IntMap.IntMap (Tensor Value a)

-- | Graph of TensorFlow operations.
type Graph = FGL.Gr NodeDef EdgeLabel

-- | Data associated with an edge.
--
-- Pair of
--   1. Output index of a tensor from the source node.
--   2. Input index that the tensor connects to on the destination node.
type EdgeLabel = (OutputIx, OutputIx)
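
-- For example, if output 0 of node A feeds the second input (input index 1)
-- of node B, the edge from A to B is labelled (OutputIx 0, OutputIx 1).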


-- | State used for calculating gradients.
data GradientsState a = GradientsState
                      { _gradientsPending :: !(Map FGL.Node (PendingGradients a))
                      , _gradientsResult  :: !(Map FGL.Node (Gradients a))
                      }

gradientsPending :: Lens' (GradientsState a) (Map FGL.Node (PendingGradients a))
gradientsPending = lens _gradientsPending (\x y -> x { _gradientsPending = y })

gradientsResult :: Lens' (GradientsState a) (Map FGL.Node (Gradients a))
gradientsResult = lens _gradientsResult (\x y -> x { _gradientsResult = y })


-- TODO(fmayle): Use something like Data.List.Safe.
-- | Safe version of (!!).
safeIndex :: [a] -> Int -> Maybe a
_      `safeIndex` n | n < 0   = Nothing
[]     `safeIndex` _           = Nothing
(x:_)  `safeIndex` 0           = Just x
(_:xs) `safeIndex` n           = xs `safeIndex` (n-1)

-- Copy of http://hackage.haskell.org/package/lens-3.9.0.2/docs/Control-Lens-Iso.html#v%3anon
anon :: a -> (a -> Bool) -> Lens' (Maybe a) a
anon a p = under (adapter (fromMaybe a) go) where
  go b | p b       = Nothing
       | otherwise = Just b

non :: Eq a => a -> Lens' (Maybe a) a
non a = anon a (a==)

-- | Lens that defaults Nothing to mempty.
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
nonEmpty = anon mempty null
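
-- For example, @m ^. at k . nonEmpty@ yields 'mempty' when @k@ is absent, and
-- @m & at k . nonEmpty %~ (g:)@ creates the entry @[g]@ when @k@ is absent,
-- which is how pending gradients are accumulated below.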

-- TODO: strictness (e.g., foldlM')

-- | Calculate the gradients for every node in a graph.
graphGrads :: forall a. GradientCompatible a
           => Graph
           -> IntSet.IntSet
           -> Map FGL.Node (PendingGradients a)
           -- ^ Initial gradients (usually just 1 for the node of interest).
           -> Build (Map FGL.Node (Gradients a))
graphGrads gr reachableSet initPending =
    view gradientsResult <$> foldlM go initState nodeOrder
  where
    initState = GradientsState initPending Map.empty
    -- Reverse topological sort.
    nodeOrder = FGL.topsort . FGL.grev $ gr
    go :: GradientsState a -> Int -> Build (GradientsState a)
    go state node = do
        -- Aggregate the accumulated gradients for this node.
        outputGrads <-
                sumPendingGradient (state ^. gradientsPending . at node . nonEmpty)
        if null outputGrads
           then pure state
           else do
              let nextState = state & gradientsResult %~ Map.insert node outputGrads
              -- Only consider nodes that are reachable from the inputs to
              -- avoid calculating gradients that won't be used.
              if node `IntSet.member` reachableSet
                then do
                  let ctx = FGL.context gr node
                  inputGrads <- calculateInputGrads ctx outputGrads gr
                  -- Calculate the gradients for each of the node's inputs.
                  pure $ updatePendingGradients ctx inputGrads nextState
                else
                  pure nextState

-- | Reduce accumulated gradients for each output to one Tensor.
sumPendingGradient :: GradientCompatible a
                   => PendingGradients a -> Build (Gradients a)
sumPendingGradient = sequence . IntMap.mapMaybe f
  where
    f [] = Nothing
    f [x] = Just (pure x)
    f xs = Just (render $ addN xs)


-- | Calculate the gradients of a node's input tensors.
--
-- This is mostly just a wrapper around opGrad.
calculateInputGrads :: forall a. GradientCompatible a
                    => FGL.Context NodeDef EdgeLabel
                    -> Gradients a  -- ^ Output gradients of the node.
                    -> Graph
                    -> Build [Maybe (Tensor Value a)]
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = do
    fullOutGrads <- fullOutputGrads (numOutputs nodeDef) (nodeDefName nodeDef)
                        outputGrads
    traverse (traverse render) $ opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
  where
    -- Create a tensor from an edge (technically an Output, but it seems less
    -- confusing to refer to it as a tensor here).
    edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
    edgeToTensor ((i, _), n) =
        case FGL.lab gr n of
            Just edgeNodeDef -> Output i (NodeName $ edgeNodeDef ^. name)
            Nothing -> error $ "calculateInputGrads: missing input node for "
                               ++ Text.unpack (nodeDef ^. name)
    -- Input tensors, sorted by input index.
    inputTensors = map edgeToTensor $ sortBy (comparing (snd . fst)) inputEdges

-- | Convert a Map of gradients to a list, with zeros for missing outputs.
fullOutputGrads :: (TensorType a, Num a)
                => OutputIx  -- ^ Number of outputs.
                -> NodeName
                -> Gradients a
                -> Build [Tensor Value a]
fullOutputGrads n o gs =
    mapM (\i -> maybe (render $ zero i) return (gs ^. outputIxAt i)) [0..n-1]
  where
    -- A tensor of zeros with the same shape as the i'th output.
    zero i = zerosLike $ toT (Output i o)


-- | Update the pending gradients of a node's inputs.
updatePendingGradients :: forall a. (TensorType a, Num a)
                       => FGL.Context NodeDef EdgeLabel
                       -> [Maybe (Tensor Value a)]
                       -- ^ Gradient of each input tensor.
                       -> GradientsState a
                       -> GradientsState a
updatePendingGradients (inputEdges, _, nodeDef, _) inputGrads initState =
    foldl' go initState inputEdges
  where
    go :: GradientsState a -> (EdgeLabel, FGL.Node) -> GradientsState a
    go state ((outIndex, OutputIx inIndex), node) =
        case maybeGradient of
            Nothing -> state
            Just g ->
                -- Add to the list of pending gradients for this tensor.
                state & gradientsPending
                      . at node
                      . nonEmpty
                      . outputIxAt outIndex
                      . nonEmpty
                      %~ (g:)
      where
        badSizeErr = error $ printf "updatePendingGradients: bad input index \
                                    \%d for inputGrads of length %d in %s"
                                    inIndex (length inputGrads)
                                    (show (nodeDef ^. name))
        maybeGradient = fromMaybe badSizeErr (safeIndex inputGrads inIndex)


-- | Create a graph that includes a node and its transitive dependencies.
createGraph :: NodeName -> (NodeName -> NodeDef)
            -> (Graph, Map NodeName FGL.Node)
createGraph nodeName nodeDefLookup = (FGL.nmap nodeDefLookup graph, nodeMap)
  where
    -- Parse a tensor name.
    parseTensorName :: Text -> Maybe (NodeName, OutputIx)
    parseTensorName n
        | Text.null n        = error "parseTensorName: empty name"
        | Text.head n == '^' = Nothing  -- Control edge
        | otherwise          =
            let (nm, indexStr) = Text.breakOn ":" n
                index | Text.null indexStr = 0
                      | otherwise = read $ Text.unpack $ Text.tail indexStr
            in Just (NodeName nm, OutputIx index)

    -- Build a map from node name to outward edges.
    --
    -- The state is the set of visited nodes.
    collect :: Maybe (NodeName, OutputIx, OutputIx)
            -> NodeName
            -> State (Set NodeName)
                     (Map NodeName [(NodeName, OutputIx, OutputIx)])
    collect outgoingEdge nm = do
        let nextLookup = Map.singleton nm (maybeToList outgoingEdge)
        seen <- gets (Set.member nm)
        modify (Set.insert nm)
        if seen
            then pure nextLookup
            else do
                let inputs = nodeDefLookup nm ^. input
                    recurse inIndex (parentName, outIndex) =
                        collect (Just (nm, outIndex, inIndex)) parentName
                subEdgeLookups <-
                    zipWithM recurse [0..] $ mapMaybe parseTensorName inputs
                pure $ Map.unionsWith (++) (nextLookup:subEdgeLookups)

    edgeLookup = evalState (collect Nothing nodeName) Set.empty
    -- Associate an ID with each node name.
    nodeMap = Map.fromList $ zip (Map.keys edgeLookup) [0..]
    -- Create the graph.
    graph = FGL.mkGraph (swap <$> Map.toList nodeMap)
                        [ (nodeMap Map.! n, nodeMap Map.! m, (i, j))
                        | (n, edges) <- Map.toList edgeLookup
                        , (m, i, j) <- edges
                        ]

-- | Function to compute the gradient of y w.r.t. each input.
--
-- Let y be an arbitrary tensor
-- and [w_0, ..., w_n] be the output tensors of a node
-- and [v_0, ..., v_n] be the input tensors of the same node.
--
-- Given [∂y/∂w_0, ..., ∂y/∂w_n] and [v_0, ..., v_n], a GradientFunc computes
-- [∂y/∂v_0, ..., ∂y/∂v_n] for a particular op type.
--
-- A Nothing gradient is equivalent to zero (but allows for short circuiting
-- computation when all the gradients for something are Nothing).
type GradientFunc a = NodeDef
                    -> [Output]
                    -- ^ Input tensors.
                    -> [Tensor Value a]
                    -- ^ Gradient of y w.r.t. each output tensor.
                    -> [Maybe (Tensor Build a)]
                    -- ^ Gradient of y w.r.t. each input tensor.
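--
-- For example, the "Abs" entry of 'opGrad' below is a GradientFunc: given the
-- single input x and the single output gradient dz = ∂y/∂|x|, it returns a
-- single Just gradient, dz times the sign of x.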


-- TODO(fmayle): Assert the type is correct.
-- | Create a Tensor from an Output.
toT :: Output -> Tensor Build a
toT = Tensor . pure


-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
-- simple slicing operations.
flatSlice :: forall v1 t . TensorType t
         => Tensor v1 t    -- ^ __input__
         -> Int32          -- ^ __begin__: specifies the offset into the first dimension of
                           -- 'input' to slice from.
         -> Int32          -- ^ __size__: specifies the number of elements of the first dimension
                           -- of 'input' to slice. If size is -1, all remaining elements in the dimension
                           -- are included in the slice (i.e. this is equivalent to setting
                           -- size = input.dim_size(0) - begin).
         -> Tensor Build t -- ^ __output__
flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
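
-- For example, 'flatSlice t 2 3' takes elements 2 through 4 along the first
-- dimension of t, and 'flatSlice t 2 (-1)' takes everything from index 2 on.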

nodeDefName :: NodeDef -> NodeName
nodeDefName = NodeName . view name

-- | Gradient helper for binary component-wise operations.
-- See https://github.com/tensorflow/tensorflow/blob/e9de087fa7f59c39bbe12ac2c83c5547c83f746c/tensorflow/core/ops/math_grad.cc#L329
gradForBinaryCwise :: ( OneOf '[ Int32, Int64, Float, Double, Complex Float, Complex Double ] t
                      )
                   => (Tensor v1 t, Tensor v1 t)
                   -> (Tensor v1 t, Tensor v1 t)
                   -> [ Maybe (Tensor Build t) ]
gradForBinaryCwise (x, gx) (y, gy) =
    [ Just dx
    , Just dy ]
  where
    dx = reshape (sum gx rx) sx
    dy = reshape (sum gy ry) sy
    sx = shape x
    sy = shape y
    (rx, ry) = broadcastGradientArgs sx sy
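
-- A minimal usage sketch (assuming x, y and an incoming gradient dz are all
-- Tensor Build t values, as in the "Maximum" case below): pass each operand
-- paired with its locally computed gradient and let gradForBinaryCwise reduce
-- over the broadcast dimensions. For element-wise addition both local
-- gradients are dz itself:
--
-- > gradForBinaryCwise (x, dz) (y, dz)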

-- | The gradient function for an op type.
--
-- These implementations should match their python counterparts in:
-- third_party/tensorflow/python/ops/*_grad.py
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a

opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
opGrad "Tanh" _ [toT -> x] [dz] = [Just $ tanhGrad (tanh x) dz]
opGrad "Sigmoid" _ [toT -> x] [dz] = [Just $ sigmoidGrad (sigmoid x) dz]

opGrad "Concat" _ _ix :: [Output]
_ix [dy :: Tensor Value a
dy]
    -- Concat concatenates input tensors
    --   x1 of shape s1 = [k1, ..., ki_1, ..., kn]
    --   x2 of shape s2 = [k1, ..., ki_2, ..., kn]
    --    .           .     .          .        .
    --    .           .     .          .        .
    --    .           .     .          .        .
    --   xm of shape sm = [k1, ..., ki_m, ..., kn]
    --  along dimension i to an output tensor
    --   y  of shape sy = [k1, ..., k, ..., kn]
    --  where k = sum ki = sum [ki_1,...,ki_m]
    --
    --  The incoming gradient dy from backpropagation is
    --   simply forwarded split across input tensors yielding dx.
    --   Forwarded gradients have shapes s = [s1, ..., sm].
    | m == 1    = Nothing : [Just $ expr dy]
    | otherwise = Nothing : map Just (dx `reshapeZip` s)
  where
    reshapeZip = zipWith reshape
    dx = CoreOps.splitV (fromIntegral m) dy ki _i
    s  :: [Tensor Build Int32]
    s  = map shape x
    x  :: [Tensor Build a]
    x  = map toT $ tail _ix
    -- i: concat dimension. Adjusted modulo n to handle negative indices.
    _i = toT (head _ix) `CoreOps.floorMod` n
    i  = reshape _i $ vector [1 :: Int32]
    -- sizes along concatenated dimension
    ki :: Tensor Build Int32
    ki = CoreOps.concat 0 $ map (\t -> CoreOps.slice t i $ vector [1 :: Int32]) s
    m  = length x
    n  = CoreOps.rank (head x)
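
-- A concrete sketch of the above (not from the original source): concatenating
-- x1 of shape [2, k] and x2 of shape [3, k] along dimension 0 yields y of
-- shape [5, k]; the incoming dy of shape [5, k] is split back by splitV into
-- dx1 of shape [2, k] and dx2 of shape [3, k], while the concat-dimension
-- input itself receives Nothing.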

opGrad "Square" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x] [dz :: Tensor Value a
dz] =
    -- TODO(fmayle): Handle complex numbers.
    -- TODO(fmayle): The python code makes dz a control dependency of the 2*x
    -- (for performance reasons?). Will need to put these functions in the Build
    -- monad to replicate that.
    [Just $ dz `CoreOps.mul` (2 * x)]

opGrad "Gather" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT -> Tensor Build Int32
indices] [dz :: Tensor Value a
dz] =
    -- TODO(fmayle): The python version uses a better performance implementation
    -- when the shape is known without having to run the graph.
    -- TODO(fmayle): We shouldn't convert the result to a dense tensor. Sparse
    -- tensor support will require some thinking.
    [ Just $ CoreOps.unsortedSegmentSum values indices' numRows
    , Nothing
    ]
  where
    -- TODO(gnezdo): Use colocateWith but it requires Build monad.
    denseShape = shape (x :: Tensor Build a)
    numRows = scalarize $ flatSlice denseShape 0 1
    valuesShape = CoreOps.concat 0 [ allDimensions
                                   , flatSlice denseShape 1 (-1)
                                   ]
    values = reshape dz valuesShape
    -- TODO(fmayle): This could be either Int32 or Int64.
    indices' = reshape indices allDimensions :: Tensor Build Int32

opGrad "Max" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT -> Tensor Build Int32
indices] [dz :: Tensor Value a
dz] =
    [Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
  where
    sx = shape (x :: Tensor Build a)
    outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
    y = CoreOps.max x indices
    y' = reshape y outputShapeKeptDims
    dz' = reshape dz outputShapeKeptDims
    indicators = CoreOps.cast $ CoreOps.equal y' x
    numSelected = reshape (sum indicators indices) outputShapeKeptDims
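
-- A worked example (illustrative only): for x = [1, 3, 3] reduced over all
-- dimensions, y = 3, so indicators = [0, 1, 1] and numSelected = 2; the
-- incoming gradient dz is shared equally between the tied maxima, giving
-- [0, dz/2, dz/2].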

-- Min and Max have identical gradient implementations.
opGrad "Min" u :: NodeDef
u v :: [Output]
v w :: [Tensor Value a]
w = Text -> GradientFunc a
forall a. GradientCompatible a => Text -> GradientFunc a
opGrad "Max" NodeDef
u [Output]
v [Tensor Value a]
w

-- Element wise maximum gradient
-- See https://github.com/tensorflow/tensorflow/blob/e9de087fa7f59c39bbe12ac2c83c5547c83f746c/tensorflow/core/ops/math_grad.cc#L473
opGrad "Maximum" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
y] [dz :: Tensor Value a
dz] =
    (Tensor Build a, Tensor Build a)
-> (Tensor Build a, Tensor Build a) -> [Maybe (Tensor Build a)]
forall t (v1 :: * -> *).
OneOf
  '[Int32, Int64, Float, Double, Complex Float, Complex Double] t =>
(Tensor v1 t, Tensor v1 t)
-> (Tensor v1 t, Tensor v1 t) -> [Maybe (Tensor Build t)]
gradForBinaryCwise (Tensor Build a
x, Tensor Build a
gx) (Tensor Build a
y, Tensor Build a
gy)
  where
    xmask :: Tensor Build Bool
xmask = Tensor Build a -> Tensor Build a -> Tensor Build Bool
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Int16, Int32, Int64, Int8, Word16, Word32, Word64, Word8, Double,
    Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build Bool
CoreOps.greaterEqual Tensor Build a
x Tensor Build a
y
    gx :: Tensor Build a
gx = Tensor Build Bool
-> Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t.
TensorType t =>
Tensor v'1 Bool -> Tensor v'2 t -> Tensor v'3 t -> Tensor Build t
CoreOps.select Tensor Build Bool
xmask Tensor Value a
dz (Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) t.
TensorType t =>
Tensor v'1 t -> Tensor Build t
CoreOps.zerosLike Tensor Value a
dz)
    gy :: Tensor Build a
gy = Tensor Build Bool
-> Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t.
TensorType t =>
Tensor v'1 Bool -> Tensor v'2 t -> Tensor v'3 t -> Tensor Build t
CoreOps.select (Tensor Build Bool -> Tensor Build Bool
forall (v'1 :: * -> *). Tensor v'1 Bool -> Tensor Build Bool
CoreOps.logicalNot Tensor Build Bool
xmask) Tensor Value a
dz (Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) t.
TensorType t =>
Tensor v'1 t -> Tensor Build t
CoreOps.zerosLike Tensor Value a
dz)

opGrad "Sum" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT -> Tensor Build Int32
indices] [dz :: Tensor Value a
dz] =
    [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.tile Tensor Build a
grad Tensor Build Int32
tileScaling, Maybe (Tensor Build a)
forall a. Maybe a
Nothing ]
  where
    -- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad.
    sx :: Tensor Build Int32
sx = Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
x :: Tensor Build a)
    outputShapeKeptDims :: Tensor Build Int32
outputShapeKeptDims = Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall t1 t2 (v1 :: * -> *) (v2 :: * -> *).
(OneOf '[Int32, Int64] t1, OneOf '[Int32, Int64] t2) =>
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Build Int32
reducedShape Tensor Build Int32
sx (Tensor Build Int32
indices :: Tensor Build Int32)
    tileScaling :: Tensor Build Int32
tileScaling = Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v1 :: * -> *) (v2 :: * -> *).
Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Build Int32
safeShapeDiv Tensor Build Int32
sx Tensor Build Int32
outputShapeKeptDims
    grad :: Tensor Build a
grad = Tensor Value a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
reshape Tensor Value a
dz Tensor Build Int32
outputShapeKeptDims

opGrad "Mean" u :: NodeDef
u v :: [Output]
v@[Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, _] w :: [Tensor Value a]
w =
    [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a
dz Tensor Build a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
    Word8, Double, Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
`CoreOps.div` (Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) t.
TensorType t =>
Tensor v'1 t -> Tensor Build t
CoreOps.stopGradient (Tensor Build a -> Tensor Build a)
-> Tensor Build a -> Tensor Build a
forall a b. (a -> b) -> a -> b
$ Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) srcT dstT.
(TensorType srcT, TensorType dstT) =>
Tensor v'1 srcT -> Tensor Build dstT
CoreOps.cast (Tensor Build Int32 -> Tensor Build a)
-> Tensor Build Int32 -> Tensor Build a
forall a b. (a -> b) -> a -> b
$ Tensor Build Int32
factor), Maybe (Tensor Build a)
forall a. Maybe a
Nothing]
  where
    [Just dz :: Tensor Build a
dz, Nothing] = Text -> GradientFunc a
forall a. GradientCompatible a => Text -> GradientFunc a
opGrad "Sum" NodeDef
u [Output]
v [Tensor Value a]
w
    inputShape :: Tensor Build Int32
inputShape = Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
x :: Tensor Build a)
    outputShape :: Tensor Build Int32
outputShape = Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
dz :: Tensor Build a)
    -- TODO(fmayle): Add fast path when shape is known.
    inputSize :: Tensor Build Int32
inputSize = Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t tidx.
(OneOf
   '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
     Word32, Word64, Word8, Double, Float]
   t,
 OneOf '[Int32, Int64] tidx) =>
Tensor v'1 t -> Tensor v'2 tidx -> Tensor Build t
CoreOps.prod Tensor Build Int32
inputShape (Tensor Build Int32 -> Tensor Build Int32)
-> Tensor Build Int32 -> Tensor Build Int32
forall a b. (a -> b) -> a -> b
$ Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) t.
TensorType t =>
Tensor v'1 t -> Tensor Build Int32
rangeOfRank Tensor Build Int32
inputShape
    outputSize :: Tensor Build Int32
outputSize = Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t tidx.
(OneOf
   '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
     Word32, Word64, Word8, Double, Float]
   t,
 OneOf '[Int32, Int64] tidx) =>
Tensor v'1 t -> Tensor v'2 tidx -> Tensor Build t
CoreOps.prod Tensor Build Int32
outputShape (Tensor Build Int32 -> Tensor Build Int32)
-> Tensor Build Int32 -> Tensor Build Int32
forall a b. (a -> b) -> a -> b
$ Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) t.
TensorType t =>
Tensor v'1 t -> Tensor Build Int32
rangeOfRank Tensor Build Int32
outputShape
    factor :: Tensor Build Int32
factor = Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v1 :: * -> *) (v2 :: * -> *).
Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Build Int32
safeShapeDiv Tensor Build Int32
inputSize Tensor Build Int32
outputSize
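
-- Intuitively (illustrative note), the Sum gradient is divided by
-- factor = inputSize / outputSize, i.e. the number of elements averaged over;
-- e.g. reducing a [2, 3] tensor over dimension 1 gives factor = 6 / 2 = 3.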

opGrad "Add" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
y] [dz :: Tensor Value a
dz] =
    [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
reshape (Tensor Value a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tidx.
(OneOf
   '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
     Word32, Word64, Word8, Double, Float]
   t,
 OneOf '[Int32, Int64] tidx) =>
Tensor v'1 t -> Tensor v'2 tidx -> Tensor Build t
sum Tensor Value a
dz Tensor Build Int32
rx) Tensor Build Int32
sx
    , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
reshape (Tensor Value a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tidx.
(OneOf
   '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
     Word32, Word64, Word8, Double, Float]
   t,
 OneOf '[Int32, Int64] tidx) =>
Tensor v'1 t -> Tensor v'2 tidx -> Tensor Build t
sum Tensor Value a
dz Tensor Build Int32
ry) Tensor Build Int32
sy ]
  where
    sx :: Tensor Build Int32
sx = Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
x :: Tensor Build a)
    sy :: Tensor Build Int32
sy = Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
y :: Tensor Build a)
    (rx :: Tensor Build Int32
rx, ry :: Tensor Build Int32
ry) = Tensor Build Int32
-> Tensor Build Int32 -> (Tensor Build Int32, Tensor Build Int32)
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf '[Int32, Int64] t =>
Tensor v'1 t -> Tensor v'2 t -> (Tensor Build t, Tensor Build t)
broadcastGradientArgs Tensor Build Int32
sx Tensor Build Int32
sy

-- Copies the gradients to all inputs
-- Not broadcasting
opGrad "AddN" _ inputs :: [Output]
inputs [dz :: Tensor Value a
dz] =
    (Output -> Maybe (Tensor Build a))
-> [Output] -> [Maybe (Tensor Build a)]
forall a b. (a -> b) -> [a] -> [b]
map ((Maybe (Tensor Build a) -> Output -> Maybe (Tensor Build a)
forall a b. a -> b -> a
const (Maybe (Tensor Build a) -> Output -> Maybe (Tensor Build a))
-> (Tensor Value a -> Maybe (Tensor Build a))
-> Tensor Value a
-> Output
-> Maybe (Tensor Build a)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> (Tensor Value a -> Tensor Build a)
-> Tensor Value a
-> Maybe (Tensor Build a)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Tensor Value a -> Tensor Build a
forall (v :: * -> *) a.
TensorKind v =>
Tensor v a -> Tensor Build a
expr) Tensor Value a
dz) [Output]
inputs

opGrad "Sub" u :: NodeDef
u v :: [Output]
v w :: [Tensor Value a]
w =
    [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just Tensor Build a
x, Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (-Tensor Build a
y)]
  where
    [Just x :: Tensor Build a
x, Just y :: Tensor Build a
y] = Text -> GradientFunc a
forall a. GradientCompatible a => Text -> GradientFunc a
opGrad "Add" NodeDef
u [Output]
v [Tensor Value a]
w

opGrad "SoftmaxCrossEntropyWithLogits" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
y] [dz :: Tensor Value a
dz, _] =
    [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build Int32 -> Tensor Build a
forall t (v1 :: * -> *) (v2 :: * -> *).
TensorType t =>
Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
expandDims Tensor Value a
dz (-1) Tensor Build a -> Tensor Build a -> Tensor Build a
forall a. Num a => a -> a -> a
* (Tensor Build a, Tensor Build a) -> Tensor Build a
forall a b. (a, b) -> b
snd (Tensor Build a
-> Tensor Build a -> (Tensor Build a, Tensor Build a)
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf '[Word16, Double, Float] t =>
Tensor v'1 t -> Tensor v'2 t -> (Tensor Build t, Tensor Build t)
softmaxCrossEntropyWithLogits Tensor Build a
x Tensor Build a
y)
    , Maybe (Tensor Build a)
forall a. Maybe a
Nothing ]
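
-- The second result of softmaxCrossEntropyWithLogits is the per-example
-- backprop term softmax(x) - y, so the input gradient above amounts to
-- expandDims dz (-1) * (softmax(x) - y); the labels y receive no gradient.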

opGrad "Mul" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
y] [dz :: Tensor Value a
dz] =
    -- TODO(fmayle): Handle complex numbers.
    [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
reshape (Tensor Build a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tidx.
(OneOf
   '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
     Word32, Word64, Word8, Double, Float]
   t,
 OneOf '[Int32, Int64] tidx) =>
Tensor v'1 t -> Tensor v'2 tidx -> Tensor Build t
sum (Tensor Value a
dz Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
    Word8, Double, Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
`CoreOps.mul` Tensor Build a
y) Tensor Build Int32
rx) Tensor Build Int32
sx
    , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
reshape (Tensor Build a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tidx.
(OneOf
   '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
     Word32, Word64, Word8, Double, Float]
   t,
 OneOf '[Int32, Int64] tidx) =>
Tensor v'1 t -> Tensor v'2 tidx -> Tensor Build t
sum (Tensor Build a
x Tensor Build a -> Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
    Word8, Double, Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
`CoreOps.mul` Tensor Value a
dz) Tensor Build Int32
ry) Tensor Build Int32
sy ]
  where
    sx :: Tensor Build Int32
sx = Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
x :: Tensor Build a)
    sy :: Tensor Build Int32
sy = Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
y :: Tensor Build a)
    (rx :: Tensor Build Int32
rx, ry :: Tensor Build Int32
ry) = Tensor Build Int32
-> Tensor Build Int32 -> (Tensor Build Int32, Tensor Build Int32)
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf '[Int32, Int64] t =>
Tensor v'1 t -> Tensor v'2 t -> (Tensor Build t, Tensor Build t)
broadcastGradientArgs Tensor Build Int32
sx Tensor Build Int32
sy

opGrad "Div" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
y] [dz :: Tensor Value a
dz] =
    -- TODO(fmayle): Handle complex numbers.
    -- TODO(gnezdo): Provide Fractional instance and use '/' instead of div.
    [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
reshape (Tensor Build a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tidx.
(OneOf
   '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
     Word32, Word64, Word8, Double, Float]
   t,
 OneOf '[Int32, Int64] tidx) =>
Tensor v'1 t -> Tensor v'2 tidx -> Tensor Build t
sum (Tensor Value a
dz Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
    Word8, Double, Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
`CoreOps.div` Tensor Build a
y) Tensor Build Int32
rx) Tensor Build Int32
sx
    , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
reshape (Tensor Build a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tidx.
(OneOf
   '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
     Word32, Word64, Word8, Double, Float]
   t,
 OneOf '[Int32, Int64] tidx) =>
Tensor v'1 t -> Tensor v'2 tidx -> Tensor Build t
sum (Tensor Value a
dz Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
    Word8, Double, Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
`CoreOps.mul` (Tensor Build a -> Tensor Build a
forall a. Num a => a -> a
negate Tensor Build a
x Tensor Build a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
    Word8, Double, Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
`CoreOps.div` (Tensor Build a
y Tensor Build a -> Tensor Build a -> Tensor Build a
forall a. Num a => a -> a -> a
* Tensor Build a
y)))
                         Tensor Build Int32
ry)
                Tensor Build Int32
sy
    ]
  where
    sx :: Tensor Build Int32
sx = Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
x :: Tensor Build a)
    sy :: Tensor Build Int32
sy = Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
y :: Tensor Build a)
    (rx :: Tensor Build Int32
rx, ry :: Tensor Build Int32
ry) = Tensor Build Int32
-> Tensor Build Int32 -> (Tensor Build Int32, Tensor Build Int32)
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf '[Int32, Int64] t =>
Tensor v'1 t -> Tensor v'2 t -> (Tensor Build t, Tensor Build t)
broadcastGradientArgs Tensor Build Int32
sx Tensor Build Int32
sy

opGrad "MatMul" nodeDef :: NodeDef
nodeDef [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
y] [dz :: Tensor Value a
dz] =
    let transposeA :: Bool
transposeA = NodeDef -> Text -> Bool
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "transpose_a"
        transposeB :: Bool
transposeB = NodeDef -> Text -> Bool
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "transpose_b"
        transAttrs :: a -> a -> OpDef -> OpDef
transAttrs a :: a
a b :: a
b =
            (Text -> Lens' OpDef a
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "transpose_a" (forall (f :: * -> *). Identical f => LensLike' f OpDef a)
-> a -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ a
a) (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef a
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "transpose_b" (forall (f :: * -> *). Identical f => LensLike' f OpDef a)
-> a -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ a
b)
    in case (Bool
transposeA, Bool
transposeB) of
       (False, False) ->
           [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
(OpDef -> OpDef) -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
matMul' (Bool -> Bool -> OpDef -> OpDef
forall a a. (Attribute a, Attribute a) => a -> a -> OpDef -> OpDef
transAttrs Bool
False Bool
True) Tensor Value a
dz Tensor Build a
y
           , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> Tensor Build a -> Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
(OpDef -> OpDef) -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
matMul' (Bool -> Bool -> OpDef -> OpDef
forall a a. (Attribute a, Attribute a) => a -> a -> OpDef -> OpDef
transAttrs Bool
True Bool
False) Tensor Build a
x Tensor Value a
dz]
       (False, True) ->
           [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
matMul Tensor Value a
dz Tensor Build a
y
           , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
(OpDef -> OpDef) -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
matMul' (Bool -> Bool -> OpDef -> OpDef
forall a a. (Attribute a, Attribute a) => a -> a -> OpDef -> OpDef
transAttrs Bool
True Bool
False) Tensor Value a
dz Tensor Build a
x]
       (True, False) ->
           [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> Tensor Build a -> Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
(OpDef -> OpDef) -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
matMul' (Bool -> Bool -> OpDef -> OpDef
forall a a. (Attribute a, Attribute a) => a -> a -> OpDef -> OpDef
transAttrs Bool
False Bool
True) Tensor Build a
y Tensor Value a
dz
           , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
matMul Tensor Build a
x Tensor Value a
dz]
       (True, True) ->
           [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> Tensor Build a -> Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
(OpDef -> OpDef) -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
matMul' (Bool -> Bool -> OpDef -> OpDef
forall a a. (Attribute a, Attribute a) => a -> a -> OpDef -> OpDef
transAttrs Bool
True Bool
True) Tensor Build a
y Tensor Value a
dz
           , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
(OpDef -> OpDef) -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
matMul' (Bool -> Bool -> OpDef -> OpDef
forall a a. (Attribute a, Attribute a) => a -> a -> OpDef -> OpDef
transAttrs Bool
True Bool
True) Tensor Value a
dz Tensor Build a
x]
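
-- For the untransposed case this matches the usual matrix-calculus identities
-- dx = dz * y^T and dy = x^T * dz, expressed here by toggling the
-- transpose_a/transpose_b attributes rather than materializing transposes.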

opGrad "BatchMatMul" nodeDef :: NodeDef
nodeDef [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
y] [dz :: Tensor Value a
dz] =
    let adjX :: Bool
adjX = NodeDef -> Text -> Bool
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "adj_x"
        adjY :: Bool
adjY = NodeDef -> Text -> Bool
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "adj_y"
        adjAttrs :: a -> a -> OpDef -> OpDef
adjAttrs a :: a
a b :: a
b =
            (Text -> Lens' OpDef a
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "adj_x" (forall (f :: * -> *). Identical f => LensLike' f OpDef a)
-> a -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ a
a) (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef a
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "adj_y" (forall (f :: * -> *). Identical f => LensLike' f OpDef a)
-> a -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ a
b)
    in case (Bool
adjX, Bool
adjY) of
        (False, False) ->
            [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
(OpDef -> OpDef) -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
CoreOps.batchMatMul' (Bool -> Bool -> OpDef -> OpDef
forall a a. (Attribute a, Attribute a) => a -> a -> OpDef -> OpDef
adjAttrs Bool
False Bool
True) Tensor Value a
dz Tensor Build a
y
            , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> Tensor Build a -> Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
(OpDef -> OpDef) -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
CoreOps.batchMatMul' (Bool -> Bool -> OpDef -> OpDef
forall a a. (Attribute a, Attribute a) => a -> a -> OpDef -> OpDef
adjAttrs Bool
True Bool
False) Tensor Build a
x Tensor Value a
dz]
        (False, True) ->
            [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
CoreOps.batchMatMul Tensor Value a
dz Tensor Build a
y
            , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
(OpDef -> OpDef) -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
CoreOps.batchMatMul' (Bool -> Bool -> OpDef -> OpDef
forall a a. (Attribute a, Attribute a) => a -> a -> OpDef -> OpDef
adjAttrs Bool
True Bool
False) Tensor Value a
dz Tensor Build a
x]
        (True, False) ->
            [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> Tensor Build a -> Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
(OpDef -> OpDef) -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
CoreOps.batchMatMul' (Bool -> Bool -> OpDef -> OpDef
forall a a. (Attribute a, Attribute a) => a -> a -> OpDef -> OpDef
adjAttrs Bool
False Bool
True) Tensor Build a
y Tensor Value a
dz
            , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
CoreOps.batchMatMul Tensor Build a
x Tensor Value a
dz]
        (True, True) ->
            [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> Tensor Build a -> Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
(OpDef -> OpDef) -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
CoreOps.batchMatMul' (Bool -> Bool -> OpDef -> OpDef
forall a a. (Attribute a, Attribute a) => a -> a -> OpDef -> OpDef
adjAttrs Bool
True Bool
True) Tensor Build a
y Tensor Value a
dz
            , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int32, Int64, Word16, Double,
    Float]
  t =>
(OpDef -> OpDef) -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
CoreOps.batchMatMul' (Bool -> Bool -> OpDef -> OpDef
forall a a. (Attribute a, Attribute a) => a -> a -> OpDef -> OpDef
adjAttrs Bool
True Bool
True) Tensor Value a
dz Tensor Build a
x]

opGrad "Transpose" _ [_, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT -> Tensor Build Int32
p] [dz :: Tensor Value a
dz] =
    [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.transpose Tensor Value a
dz
            (Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) t.
OneOf '[Int32, Int64] t =>
Tensor v'1 t -> Tensor Build t
CoreOps.invertPermutation Tensor Build Int32
p :: Tensor Build Int32)
    , Maybe (Tensor Build a)
forall a. Maybe a
Nothing
    ]

opGrad "Conv2D" nodeDef :: NodeDef
nodeDef [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
y] [dz :: Tensor Value a
dz] =
    [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> ByteString
-> Tensor Build Int32
-> Tensor Build a
-> Tensor Value a
-> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t.
OneOf '[Int32, Word16, Double, Float] t =>
(OpDef -> OpDef)
-> ByteString
-> Tensor v'1 Int32
-> Tensor v'2 t
-> Tensor v'3 t
-> Tensor Build t
CoreOps.conv2DBackpropInput'
                ((Text -> Lens' OpDef [Int64]
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "strides" (forall (f :: * -> *). Identical f => LensLike' f OpDef [Int64])
-> [Int64] -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ [Int64]
strides)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef Bool
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "use_cudnn_on_gpu" (forall (f :: * -> *). Identical f => LensLike' f OpDef Bool)
-> Bool -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ Bool
useCudnnOnGpu)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef ByteString
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "data_format" (forall (f :: * -> *). Identical f => LensLike' f OpDef ByteString)
-> ByteString -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ ByteString
dataFormat))
                ByteString
padding (Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape Tensor Build a
x) Tensor Build a
y Tensor Value a
dz
    , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> ByteString
-> Tensor Build a
-> Tensor Build Int32
-> Tensor Value a
-> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t.
OneOf '[Word16, Double, Float] t =>
(OpDef -> OpDef)
-> ByteString
-> Tensor v'1 t
-> Tensor v'2 Int32
-> Tensor v'3 t
-> Tensor Build t
CoreOps.conv2DBackpropFilter'
                ((Text -> Lens' OpDef [Int64]
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "strides" (forall (f :: * -> *). Identical f => LensLike' f OpDef [Int64])
-> [Int64] -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ [Int64]
strides)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef Bool
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "use_cudnn_on_gpu" (forall (f :: * -> *). Identical f => LensLike' f OpDef Bool)
-> Bool -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ Bool
useCudnnOnGpu)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef ByteString
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "data_format" (forall (f :: * -> *). Identical f => LensLike' f OpDef ByteString)
-> ByteString -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ ByteString
dataFormat))
                ByteString
padding Tensor Build a
x (Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape Tensor Build a
y) Tensor Value a
dz
    ]
  where
    strides :: [Int64]
strides = NodeDef -> Text -> [Int64]
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "strides" :: [Int64]
    padding :: ByteString
padding = NodeDef -> Text -> ByteString
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "padding" :: ByteString
    useCudnnOnGpu :: Bool
useCudnnOnGpu = NodeDef -> Text -> Bool
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "use_cudnn_on_gpu" :: Bool
    dataFormat :: ByteString
dataFormat = NodeDef -> Text -> ByteString
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "data_format" :: ByteString

opGrad "Conv2DBackpropInput" nodeDef :: NodeDef
nodeDef [_, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
y] [dz :: Tensor Value a
dz] =
    [ Maybe (Tensor Build a)
forall a. Maybe a
Nothing
    , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> ByteString
-> Tensor Value a
-> Tensor Build Int32
-> Tensor Build a
-> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t.
OneOf '[Word16, Double, Float] t =>
(OpDef -> OpDef)
-> ByteString
-> Tensor v'1 t
-> Tensor v'2 Int32
-> Tensor v'3 t
-> Tensor Build t
CoreOps.conv2DBackpropFilter'
                ((Text -> Lens' OpDef [Int64]
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "strides" (forall (f :: * -> *). Identical f => LensLike' f OpDef [Int64])
-> [Int64] -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ [Int64]
strides)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef Bool
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "use_cudnn_on_gpu" (forall (f :: * -> *). Identical f => LensLike' f OpDef Bool)
-> Bool -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ Bool
useCudnnOnGpu)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef ByteString
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "data_format" (forall (f :: * -> *). Identical f => LensLike' f OpDef ByteString)
-> ByteString -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ ByteString
dataFormat))
                ByteString
padding Tensor Value a
dz (Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape Tensor Build a
x) Tensor Build a
y
    , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> ByteString -> Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf '[Int32, Word16, Double, Float] t =>
(OpDef -> OpDef)
-> ByteString -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
CoreOps.conv2D'
                ((Text -> Lens' OpDef [Int64]
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "strides" (forall (f :: * -> *). Identical f => LensLike' f OpDef [Int64])
-> [Int64] -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ [Int64]
strides)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef Bool
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "use_cudnn_on_gpu" (forall (f :: * -> *). Identical f => LensLike' f OpDef Bool)
-> Bool -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ Bool
useCudnnOnGpu)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef ByteString
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "data_format" (forall (f :: * -> *). Identical f => LensLike' f OpDef ByteString)
-> ByteString -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ ByteString
dataFormat))
                ByteString
padding Tensor Value a
dz Tensor Build a
x
    ]
  where
    strides :: [Int64]
strides = NodeDef -> Text -> [Int64]
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "strides" :: [Int64]
    padding :: ByteString
padding = NodeDef -> Text -> ByteString
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "padding" :: ByteString
    useCudnnOnGpu :: Bool
useCudnnOnGpu = NodeDef -> Text -> Bool
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "use_cudnn_on_gpu" :: Bool
    dataFormat :: ByteString
dataFormat = NodeDef -> Text -> ByteString
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "data_format" :: ByteString

opGrad "DepthwiseConv2dNative" nodeDef :: NodeDef
nodeDef [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
y] [dz :: Tensor Value a
dz] =
    [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> ByteString
-> Tensor Build Int32
-> Tensor Build a
-> Tensor Value a
-> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t.
OneOf '[Word16, Double, Float] t =>
(OpDef -> OpDef)
-> ByteString
-> Tensor v'1 Int32
-> Tensor v'2 t
-> Tensor v'3 t
-> Tensor Build t
CoreOps.depthwiseConv2dNativeBackpropInput'
                ((Text -> Lens' OpDef [Int64]
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "strides" (forall (f :: * -> *). Identical f => LensLike' f OpDef [Int64])
-> [Int64] -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ [Int64]
strides)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef ByteString
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "data_format" (forall (f :: * -> *). Identical f => LensLike' f OpDef ByteString)
-> ByteString -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ ByteString
dataFormat))
                ByteString
padding (Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape Tensor Build a
x) Tensor Build a
y Tensor Value a
dz
    , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> ByteString
-> Tensor Build a
-> Tensor Build Int32
-> Tensor Value a
-> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t.
OneOf '[Word16, Double, Float] t =>
(OpDef -> OpDef)
-> ByteString
-> Tensor v'1 t
-> Tensor v'2 Int32
-> Tensor v'3 t
-> Tensor Build t
CoreOps.depthwiseConv2dNativeBackpropFilter'
                ((Text -> Lens' OpDef [Int64]
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "strides" (forall (f :: * -> *). Identical f => LensLike' f OpDef [Int64])
-> [Int64] -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ [Int64]
strides)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef ByteString
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "data_format" (forall (f :: * -> *). Identical f => LensLike' f OpDef ByteString)
-> ByteString -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ ByteString
dataFormat))
                ByteString
padding Tensor Build a
x (Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape Tensor Build a
y) Tensor Value a
dz
    ]
  where
    strides :: [Int64]
strides = NodeDef -> Text -> [Int64]
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "strides" :: [Int64]
    padding :: ByteString
padding = NodeDef -> Text -> ByteString
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "padding" :: ByteString
    dataFormat :: ByteString
dataFormat = NodeDef -> Text -> ByteString
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "data_format" :: ByteString

opGrad "DepthwiseConv2dNativeBackpropInput" nodeDef :: NodeDef
nodeDef [_, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
y] [dz :: Tensor Value a
dz] =
    [ Maybe (Tensor Build a)
forall a. Maybe a
Nothing
    , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> ByteString
-> Tensor Value a
-> Tensor Build Int32
-> Tensor Build a
-> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t.
OneOf '[Word16, Double, Float] t =>
(OpDef -> OpDef)
-> ByteString
-> Tensor v'1 t
-> Tensor v'2 Int32
-> Tensor v'3 t
-> Tensor Build t
CoreOps.depthwiseConv2dNativeBackpropFilter'
                ((Text -> Lens' OpDef [Int64]
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "strides" (forall (f :: * -> *). Identical f => LensLike' f OpDef [Int64])
-> [Int64] -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ [Int64]
strides)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef ByteString
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "data_format" (forall (f :: * -> *). Identical f => LensLike' f OpDef ByteString)
-> ByteString -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ ByteString
dataFormat))
                ByteString
padding Tensor Value a
dz (Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape Tensor Build a
x) Tensor Build a
y
    , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> ByteString -> Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf '[Word16, Double, Float] t =>
(OpDef -> OpDef)
-> ByteString -> Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
CoreOps.depthwiseConv2dNative'
                ((Text -> Lens' OpDef [Int64]
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "strides" (forall (f :: * -> *). Identical f => LensLike' f OpDef [Int64])
-> [Int64] -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ [Int64]
strides)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef ByteString
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "data_format" (forall (f :: * -> *). Identical f => LensLike' f OpDef ByteString)
-> ByteString -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ ByteString
dataFormat))
                ByteString
padding Tensor Value a
dz Tensor Build a
x
    ]
  where
    strides :: [Int64]
strides = NodeDef -> Text -> [Int64]
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "strides" :: [Int64]
    padding :: ByteString
padding = NodeDef -> Text -> ByteString
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "padding" :: ByteString
    dataFormat :: ByteString
dataFormat = NodeDef -> Text -> ByteString
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "data_format" :: ByteString

opGrad "MaxPool" nodeDef :: NodeDef
nodeDef [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x] [dz :: Tensor Value a
dz] =
    [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> ByteString
-> Tensor Build a
-> Tensor Build a
-> Tensor Value a
-> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t.
OneOf
  '[Int16, Int32, Int64, Int8, Word16, Word32, Word64, Word8, Double,
    Float]
  t =>
(OpDef -> OpDef)
-> ByteString
-> Tensor v'1 t
-> Tensor v'2 t
-> Tensor v'3 t
-> Tensor Build t
CoreOps.maxPoolGrad'
                ((Text -> Lens' OpDef [Int64]
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "ksize" (forall (f :: * -> *). Identical f => LensLike' f OpDef [Int64])
-> [Int64] -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ [Int64]
ksize)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef [Int64]
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "strides" (forall (f :: * -> *). Identical f => LensLike' f OpDef [Int64])
-> [Int64] -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ [Int64]
strides)
                    (OpDef -> OpDef) -> (OpDef -> OpDef) -> OpDef -> OpDef
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text -> Lens' OpDef ByteString
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "data_format" (forall (f :: * -> *). Identical f => LensLike' f OpDef ByteString)
-> ByteString -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ ByteString
dataFormat))
                ByteString
padding Tensor Build a
x Tensor Build a
output Tensor Value a
dz
    ]
  where
    output :: Tensor Build a
    output :: Tensor Build a
output = Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT (Output -> Tensor Build a) -> Output -> Tensor Build a
forall a b. (a -> b) -> a -> b
$ OutputIx -> NodeName -> Output
Output 0 (NodeDef -> NodeName
nodeDefName NodeDef
nodeDef)
    ksize :: [Int64]
ksize = NodeDef -> Text -> [Int64]
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "ksize" :: [Int64]
    strides :: [Int64]
strides = NodeDef -> Text -> [Int64]
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "strides" :: [Int64]
    padding :: ByteString
padding = NodeDef -> Text -> ByteString
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "padding" :: ByteString
    dataFormat :: ByteString
dataFormat = NodeDef -> Text -> ByteString
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "data_format" :: ByteString

opGrad "Reshape" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, _] [dz :: Tensor Value a
dz] = [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
reshape Tensor Value a
dz (Tensor Build Int32 -> Tensor Build a)
-> Tensor Build Int32 -> Tensor Build a
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
x :: Tensor Build a), Maybe (Tensor Build a)
forall a. Maybe a
Nothing]
opGrad "ExpandDims" n :: NodeDef
n xs :: [Output]
xs@[Output -> Tensor Build Any
forall a. Output -> Tensor Build a
toT -> Tensor Build Any
_, _] dzs :: [Tensor Value a]
dzs@[_] = Text -> GradientFunc a
forall a. GradientCompatible a => Text -> GradientFunc a
opGrad "Reshape" NodeDef
n [Output]
xs [Tensor Value a]
dzs
opGrad "Squeeze" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x] [dz :: Tensor Value a
dz] = [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
reshape Tensor Value a
dz (Tensor Build Int32 -> Tensor Build a)
-> Tensor Build Int32 -> Tensor Build a
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
x :: Tensor Build a)]
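-- The gradient of Pad is a slice of dz that drops the padded border: the
-- first column of the Nx2 pad pattern gives the slice offsets, and the shape
-- of the original input gives the slice size.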
opGrad "Pad" _ [Output -> Tensor Build Float
forall a. Output -> Tensor Build a
toT -> Tensor Build Float
x, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT -> Tensor Build Int32
padPattern] [dz :: Tensor Value a
dz] =
  [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a
-> Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t index.
(TensorType t, OneOf '[Int32, Int64] index) =>
Tensor v'1 t
-> Tensor v'2 index -> Tensor v'3 index -> Tensor Build t
CoreOps.slice Tensor Value a
dz Tensor Build Int32
gradientSliceBegin Tensor Build Int32
gradientSliceSize, Maybe (Tensor Build a)
forall a. Maybe a
Nothing]
  where
    v1 :: Tensor Build Int32
v1 = [Int32] -> Tensor Build Int32
forall a. TensorType a => [a] -> Tensor Build a
vector [1]
     -- For some reason rankx' has an empty shape
    rankx' :: Tensor Build Int32
rankx' = Tensor Build Float -> Tensor Build Int32
forall (v'1 :: * -> *) t.
TensorType t =>
Tensor v'1 t -> Tensor Build Int32
CoreOps.rank (Tensor Build Float
x :: Tensor Build Float)
    rankx :: Tensor Build Int32
rankx = Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.reshape Tensor Build Int32
rankx' Tensor Build Int32
v1
    -- Size of column that is sliced from pad pattern
    padPatternSliceSize :: Tensor Build Int32
padPatternSliceSize = Tensor Build Int32 -> [Tensor Build Int32] -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
TensorType t =>
Tensor v'1 Int32 -> [Tensor v'2 t] -> Tensor Build t
CoreOps.concat 0 [Tensor Build Int32
rankx, Tensor Build Int32
v1]
    padPatternSliceBegin :: Tensor Build Int32
padPatternSliceBegin = [Int32] -> Tensor Build Int32
forall a. TensorType a => [a] -> Tensor Build a
vector [0, 0]
    Tensor Build Int32
padPatternSliced :: Tensor Build Int32 = Tensor Build Int32
-> Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t index.
(TensorType t, OneOf '[Int32, Int64] index) =>
Tensor v'1 t
-> Tensor v'2 index -> Tensor v'3 index -> Tensor Build t
CoreOps.slice Tensor Build Int32
padPattern Tensor Build Int32
padPatternSliceBegin Tensor Build Int32
padPatternSliceSize
    -- The slice of the pad pattern has the same rank as the pad pattern itself
    gradientSliceBegin :: Tensor Build Int32
gradientSliceBegin = Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.reshape Tensor Build Int32
padPatternSliced Tensor Build Int32
rankx
    gradientSliceSize :: Tensor Build Int32
gradientSliceSize = Tensor Build Float -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build Float
x :: Tensor Build Float)

-- Gradient for Slice
-- Create an Nx2 padding where N is the rank of (the grad of) Slice. The first
-- column says how many zeros are to be prepended for each dimension, and the
-- second column says how many zeros are appended.
-- The number of zeros to prepend is given by the beginvec.
-- The number of zeros to append is the shape of the inputvec minus both the
-- beginvec and the sizevec, elementwise.
-- Some more reshaping is needed to assemble this padding tensor with the
-- right dimensions.
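--
-- As an illustrative example: slicing a [4, 5] input with begin = [1, 2] and
-- size = [2, 2] gives a dz of shape [2, 2]; beforePad is [[1], [2]], afterPad
-- is [[1], [1]] (from [4, 5] - [2, 2] - [1, 2]), so paddings is
-- [[1, 1], [2, 1]] and CoreOps.pad restores a [4, 5] gradient.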
opGrad "Slice" _ [Output -> Tensor Build Float
forall a. Output -> Tensor Build a
toT -> Tensor Build Float
inputvec, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT -> Tensor Build Int32
beginvec, _] [dz :: Tensor Value a
dz] =
   [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.pad Tensor Value a
dz Tensor Build Int32
paddings, Maybe (Tensor Build a)
forall a. Maybe a
Nothing, Maybe (Tensor Build a)
forall a. Maybe a
Nothing]
  where
    v1 :: Tensor Build Int32
v1 = [Int32] -> Tensor Build Int32
forall a. TensorType a => [a] -> Tensor Build a
vector [1 :: Int32]
    inputRank' :: Tensor Build Int32
inputRank' = Tensor Build Float -> Tensor Build Int32
forall (v'1 :: * -> *) t.
TensorType t =>
Tensor v'1 t -> Tensor Build Int32
CoreOps.rank (Tensor Build Float
inputvec :: Tensor Build Float)
    -- For some reason inputRank' has an empty shape
    inputRank :: Tensor Build Int32
inputRank = Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.reshape Tensor Build Int32
inputRank' Tensor Build Int32
v1
    padShape :: Tensor Build Int32
padShape = Tensor Build Int32 -> [Tensor Build Int32] -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
TensorType t =>
Tensor v'1 Int32 -> [Tensor v'2 t] -> Tensor Build t
CoreOps.concat 0 [Tensor Build Int32
inputRank, Tensor Build Int32
v1]
    beforePad :: Tensor Build Int32
beforePad = Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.reshape Tensor Build Int32
beginvec Tensor Build Int32
padShape
    afterPad :: Tensor Build Int32
afterPad = Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.reshape (Tensor Build Float -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape Tensor Build Float
inputvec Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall a. Num a => a -> a -> a
- Tensor Value a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape Tensor Value a
dz Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall a. Num a => a -> a -> a
- Tensor Build Int32
beginvec) Tensor Build Int32
padShape
    paddings :: Tensor Build Int32
paddings = Tensor Build Int32 -> [Tensor Build Int32] -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
TensorType t =>
Tensor v'1 Int32 -> [Tensor v'2 t] -> Tensor Build t
CoreOps.concat 1 [Tensor Build Int32
beforePad, Tensor Build Int32
afterPad]

-- TODO: This could be either Int32 or Int64.
opGrad "BatchToSpaceND" _ [_, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT @Int32 -> Tensor Build Int32
blockShape, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT @Int32 -> Tensor Build Int32
crops] [dz :: Tensor Value a
dz] =
  [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a
-> Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t
       tblock_shape tpaddings.
(TensorType t, OneOf '[Int32, Int64] tblock_shape,
 OneOf '[Int32, Int64] tpaddings) =>
Tensor v'1 t
-> Tensor v'2 tblock_shape
-> Tensor v'3 tpaddings
-> Tensor Build t
CoreOps.spaceToBatchND Tensor Value a
dz Tensor Build Int32
blockShape Tensor Build Int32
crops, Maybe (Tensor Build a)
forall a. Maybe a
Nothing, Maybe (Tensor Build a)
forall a. Maybe a
Nothing]

-- TODO: This could be either Int32 or Int64.
opGrad "SpaceToBatchND" _ [_, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT @Int32 -> Tensor Build Int32
blockShape, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT @Int32 -> Tensor Build Int32
paddings] [dz :: Tensor Value a
dz] =
  [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a
-> Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t
       tblock_shape tpaddings.
(TensorType t, OneOf '[Int32, Int64] tblock_shape,
 OneOf '[Int32, Int64] tpaddings) =>
Tensor v'1 t
-> Tensor v'2 tblock_shape
-> Tensor v'3 tpaddings
-> Tensor Build t
CoreOps.batchToSpaceND Tensor Value a
dz Tensor Build Int32
blockShape Tensor Build Int32
paddings, Maybe (Tensor Build a)
forall a. Maybe a
Nothing, Maybe (Tensor Build a)
forall a. Maybe a
Nothing]

opGrad "OneHot" _ _ _ = [Maybe (Tensor Build a)
forall a. Maybe a
Nothing, Maybe (Tensor Build a)
forall a. Maybe a
Nothing, Maybe (Tensor Build a)
forall a. Maybe a
Nothing, Maybe (Tensor Build a)
forall a. Maybe a
Nothing]
opGrad "TruncatedNormal" _ _ _ = [Maybe (Tensor Build a)
forall a. Maybe a
Nothing]

opGrad "RefIdentity" _ _ [dz :: Tensor Value a
dz] = [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build a
forall (v :: * -> *) a.
TensorKind v =>
Tensor v a -> Tensor Build a
expr Tensor Value a
dz]
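-- The gradient of Cast is a cast back in the opposite direction: reverseCast
-- below builds a "Cast" op by hand with the SrcT and DstT attributes of the
-- forward op swapped and dz as its only input.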
opGrad "Cast" nodeDef :: NodeDef
nodeDef _ [dz :: Tensor Value a
dz] = [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just Tensor Build a
reverseCast]
  where
    -- TODO(gnezdo): too permissive, python only allows float types as src_type.
    reverseCast :: Tensor Build a
reverseCast =
        [Int64] -> Build OpDef -> Tensor Build a
forall a. PureResult a => [Int64] -> Build OpDef -> a
pureOp [] (Build OpDef -> Tensor Build a) -> Build OpDef -> Tensor Build a
forall a b. (a -> b) -> a -> b
$ OpDef -> Build OpDef
forall (f :: * -> *) a. Applicative f => a -> f a
pure (OpType -> OpDef
opDef "Cast"
                 OpDef -> (OpDef -> OpDef) -> OpDef
forall s t. s -> (s -> t) -> t
& Text -> Lens' OpDef ByteString
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "DstT" (forall (f :: * -> *). Identical f => LensLike' f OpDef ByteString)
-> ByteString -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ (NodeDef -> Text -> ByteString
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "SrcT" :: ByteString)
                 OpDef -> (OpDef -> OpDef) -> OpDef
forall s t. s -> (s -> t) -> t
& Text -> Lens' OpDef ByteString
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "SrcT" (forall (f :: * -> *). Identical f => LensLike' f OpDef ByteString)
-> ByteString -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ (NodeDef -> Text -> ByteString
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "DstT" :: ByteString)
                 OpDef -> (OpDef -> OpDef) -> OpDef
forall s t. s -> (s -> t) -> t
& Lens' OpDef [Output]
forall (f :: * -> *). Identical f => LensLike' f OpDef [Output]
opInputs (forall (f :: * -> *). Identical f => LensLike' f OpDef [Output])
-> [Output] -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ [Tensor Value a -> Output
forall (t :: * -> *) a. Rendered t => t a -> Output
renderedOutput Tensor Value a
dz])

opGrad "DynamicStitch" nodeDef :: NodeDef
nodeDef inputs :: [Output]
inputs [dz :: Tensor Value a
dz] =
    Node -> Maybe (Tensor Build a) -> [Maybe (Tensor Build a)]
forall a. Node -> a -> [a]
replicate Node
halfLen Maybe (Tensor Build a)
forall a. Maybe a
Nothing [Maybe (Tensor Build a)]
-> [Maybe (Tensor Build a)] -> [Maybe (Tensor Build a)]
forall a. [a] -> [a] -> [a]
++ [Maybe (Tensor Build a)]
valuesGrads
  where
    halfLen :: Node
halfLen =
        let len :: Node
len = [Output] -> Node
forall (t :: * -> *) a. Foldable t => t a -> Node
length [Output]
inputs
            half :: Node
half = Node
len Node -> Node -> Node
forall a. Integral a => a -> a -> a
`div` 2
        in if 2 Node -> Node -> Node
forall a. Num a => a -> a -> a
* Node
half Node -> Node -> Bool
forall a. Eq a => a -> a -> Bool
== Node
len
           then Node
half
           else [Char] -> Node
forall a. HasCallStack => [Char] -> a
error ("Uneven input size " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ (Node, [Char]) -> [Char]
forall a. Show a => a -> [Char]
show (Node
len, NodeDef -> [Char]
forall msg. Message msg => msg -> [Char]
showMessage NodeDef
nodeDef))
    valuesGrads :: [Maybe (Tensor Build a)]
valuesGrads = [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.gather Tensor Value a
dz (Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT Output
idx :: Tensor Build Int32)
                  | Output
idx <- Node -> [Output] -> [Output]
forall a. Node -> [a] -> [a]
take Node
halfLen [Output]
inputs
                  ]
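
-- The "DynamicPartition" gradient below is assembled by partitioning the
-- original element indices in the same way the data was partitioned, and then
-- dynamic-stitching the incoming gradients back together with those indices;
-- the result is reshaped to the shape of the original input.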

opGrad "DynamicPartition" nodeDef :: NodeDef
nodeDef [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
xs, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT -> Tensor Build Int32
indices] dz :: [Tensor Value a]
dz =
    [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just Tensor Build a
reconstructed, Maybe (Tensor Build a)
forall a. Maybe a
Nothing ]
  where
    reconstructed :: Tensor Build a
reconstructed = Tensor Build a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.reshape Tensor Build a
stitched
                    (Tensor Build a -> Tensor Build Int32
forall (v'1 :: * -> *) t out_type.
(TensorType t, OneOf '[Int32, Int64] out_type) =>
Tensor v'1 t -> Tensor Build out_type
CoreOps.shape (Tensor Build a
xs :: Tensor Build a) :: Tensor Build Int32)
    stitched :: Tensor Build a
stitched = [Tensor Build Int32] -> [Tensor Value a] -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
TensorType t =>
[Tensor v'1 Int32] -> [Tensor v'2 t] -> Tensor Build t
CoreOps.dynamicStitch [Tensor Build Int32]
partitionedIndices [Tensor Value a]
dz
    partitionedIndices :: [Tensor Build Int32]
partitionedIndices = Int64
-> Tensor Build Int32 -> Tensor Build Int32 -> [Tensor Build Int32]
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
TensorType t =>
Int64 -> Tensor v'1 t -> Tensor v'2 Int32 -> [Tensor Build t]
CoreOps.dynamicPartition Int64
np Tensor Build Int32
originalIndices Tensor Build Int32
indices
    np :: Int64
np = NodeDef -> Text -> Int64
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "num_partitions" :: Int64
    originalIndices :: Tensor Build Int32
originalIndices =
        Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.reshape (Tensor Build Int32
-> Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) tidx.
OneOf '[Int32, Int64, Word16, Double, Float] tidx =>
Tensor v'1 tidx
-> Tensor v'2 tidx -> Tensor v'3 tidx -> Tensor Build tidx
CoreOps.range 0 (Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) t out_type.
(TensorType t, OneOf '[Int32, Int64] out_type) =>
Tensor v'1 t -> Tensor Build out_type
CoreOps.size Tensor Build Int32
indices) 1) Tensor Build Int32
prefixShape
    prefixShape :: Tensor Build Int32
prefixShape = Tensor Build Int32 -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shapeInt32 Tensor Build Int32
indices
    shapeInt32 :: Tensor v'1 t -> Tensor Build Int32
shapeInt32 t :: Tensor v'1 t
t = Tensor v'1 t -> Tensor Build Int32
forall (v'1 :: * -> *) t out_type.
(TensorType t, OneOf '[Int32, Int64] out_type) =>
Tensor v'1 t -> Tensor Build out_type
CoreOps.shape Tensor v'1 t
t :: Tensor Build Int32

opGrad "Select" _ [Output -> Tensor Build Bool
forall a. Output -> Tensor Build a
toT -> Tensor Build Bool
c, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, _] [dz :: Tensor Value a
dz] =
    [ Maybe (Tensor Build a)
forall a. Maybe a
Nothing
    , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build Bool
-> Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t.
TensorType t =>
Tensor v'1 Bool -> Tensor v'2 t -> Tensor v'3 t -> Tensor Build t
CoreOps.select Tensor Build Bool
c Tensor Value a
dz Tensor Build a
zeros
    , Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build Bool
-> Tensor Build a -> Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t.
TensorType t =>
Tensor v'1 Bool -> Tensor v'2 t -> Tensor v'3 t -> Tensor Build t
CoreOps.select Tensor Build Bool
c Tensor Build a
zeros Tensor Value a
dz
    ]
  where zeros :: Tensor Build a
zeros = Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) t.
TensorType t =>
Tensor v'1 t -> Tensor Build t
CoreOps.zerosLike Tensor Build a
x

-- TODO(gnezdo): Unlike Python, no control dependency on dz.
opGrad "Log" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x] [dz :: Tensor Value a
dz] = [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a
dz Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
    Word8, Double, Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
`CoreOps.mul` Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
    Double, Float]
  t =>
Tensor v'1 t -> Tensor Build t
CoreOps.inv Tensor Build a
x ]
-- TODO(gnezdo): Reuse the output instead of doing another exp,
-- though, it is probably CSE'd away anyway.
opGrad "Exp" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x] [dz :: Tensor Value a
dz] = [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a
dz Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
    Word8, Double, Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
`CoreOps.mul` Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) t.
OneOf '[Complex Double, Complex Float, Word16, Double, Float] t =>
Tensor v'1 t -> Tensor Build t
CoreOps.exp Tensor Build a
x ]
opGrad "SparseSegmentSum" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT -> Tensor Build Int32
y, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT -> Tensor Build Int32
t] [dz :: Tensor Value a
dz] =
    [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a
-> Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t tindices
       tnumsegments.
(OneOf
   '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
     Word32, Word64, Word8, Double, Float]
   t,
 OneOf '[Int32, Int64] tindices,
 OneOf '[Int32, Int64] tnumsegments) =>
Tensor v'1 t
-> Tensor v'2 tindices -> Tensor v'3 tnumsegments -> Tensor Build t
CoreOps.unsortedSegmentSum
             (Tensor Value a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.gather Tensor Value a
dz (Tensor Build Int32
t :: Tensor Build Int32))
             (Tensor Build Int32
y :: Tensor Build Int32) Tensor Build Int32
inputRows
    , Maybe (Tensor Build a)
forall a. Maybe a
Nothing
    , Maybe (Tensor Build a)
forall a. Maybe a
Nothing
    ]
  where inputRows :: Tensor Build Int32
inputRows = Tensor Build Int32 -> Int32 -> Int32 -> Tensor Build Int32
forall (v1 :: * -> *) t.
TensorType t =>
Tensor v1 t -> Int32 -> Int32 -> Tensor Build t
flatSlice (Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
x :: Tensor Build a)) 0 1

opGrad "LabelClasses" _ _ _ = [Maybe (Tensor Build a)
forall a. Maybe a
Nothing, Maybe (Tensor Build a)
forall a. Maybe a
Nothing]
opGrad "LabelWeights" _ _ _ = [Maybe (Tensor Build a)
forall a. Maybe a
Nothing]
opGrad "Size" _ _ _ = [Maybe (Tensor Build a)
forall a. Maybe a
Nothing]

-- TODO (jcberentsen): Python implementation uses set_shape for
-- static shape inference, which is unsupported.
-- TODO: implement support for static shape inference
opGrad "Tile" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT -> Tensor Build Int32
multiples] [dz :: Tensor Value a
dz] =
    [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just Tensor Build a
inputGrad, Maybe (Tensor Build a)
forall a. Maybe a
Nothing]
  where
    inputGrad :: Tensor Build a
inputGrad = Tensor Build a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tidx.
(OneOf
   '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
     Word32, Word64, Word8, Double, Float]
   t,
 OneOf '[Int32, Int64] tidx) =>
Tensor v'1 t -> Tensor v'2 tidx -> Tensor Build t
sum Tensor Build a
reshapedDz Tensor Build Int32
axes
    inputShape :: Tensor Build Int32
inputShape = Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
x :: Tensor Build a)
    packed :: Tensor Build Int32
packed = [Tensor Build Int32] -> Tensor Build Int32
forall (v'1 :: * -> *) t.
TensorType t =>
[Tensor v'1 t] -> Tensor Build t
CoreOps.pack [Tensor Build Int32
multiples, Tensor Build Int32
inputShape]
    perm :: Tensor Build Int32
perm = [Int32] -> Tensor Build Int32
forall a. TensorType a => [a] -> Tensor Build a
vector [1, 0 :: Int32]
    splitShape :: Tensor Build Int32
splitShape = Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.reshape (Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.transpose Tensor Build Int32
packed Tensor Build Int32
perm) Tensor Build Int32
allDimensions
    axes :: Tensor Build Int32
axes = Tensor Build Int32
-> Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) tidx.
OneOf '[Int32, Int64, Word16, Double, Float] tidx =>
Tensor v'1 tidx
-> Tensor v'2 tidx -> Tensor v'3 tidx -> Tensor Build tidx
CoreOps.range 0 (Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) t out_type.
(TensorType t, OneOf '[Int32, Int64] out_type) =>
Tensor v'1 t -> Tensor Build out_type
CoreOps.size Tensor Build Int32
splitShape) (2 :: Tensor Build Int32)
    reshapedDz :: Tensor Build a
reshapedDz = Tensor Value a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
CoreOps.reshape Tensor Value a
dz Tensor Build Int32
splitShape
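-- As an illustrative example: tiling a [2, 3] input with multiples [2, 1]
-- produces a [4, 3] output, so dz has shape [4, 3]; splitShape is
-- [2, 2, 1, 3], axes is [0, 2], and summing the reshaped dz over those axes
-- yields the [2, 3] input gradient.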

opGrad "ResizeBilinear" nodeDef :: NodeDef
nodeDef [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, _] [dz :: Tensor Value a
dz] =
    [ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ (OpDef -> OpDef)
-> Tensor Build Float -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf '[Word16, Double, Float] t =>
(OpDef -> OpDef)
-> Tensor v'1 Float -> Tensor v'2 t -> Tensor Build t
CoreOps.resizeBilinearGrad'
               (Text -> Lens' OpDef Bool
forall a. Attribute a => Text -> Lens' OpDef a
opAttr "align_corners" (forall (f :: * -> *). Identical f => LensLike' f OpDef Bool)
-> Bool -> OpDef -> OpDef
forall s t a b. Setter s t a b -> b -> s -> t
.~ Bool
align)
               (Tensor Value a -> Tensor Build Float
forall (v'1 :: * -> *) srcT dstT.
(TensorType srcT, TensorType dstT) =>
Tensor v'1 srcT -> Tensor Build dstT
CoreOps.cast Tensor Value a
dz)
               Tensor Build a
x

    , Maybe (Tensor Build a)
forall a. Maybe a
Nothing
    ]
  where
    align :: Bool
align = NodeDef -> Text -> Bool
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
nodeDef "align_corners" :: Bool

opGrad "ZerosLike" _ _ _ = [Maybe (Tensor Build a)
forall a. Maybe a
Nothing]
opGrad "Fill" _ _ [dz :: Tensor Value a
dz] = [Maybe (Tensor Build a)
forall a. Maybe a
Nothing, Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tidx.
(OneOf
   '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
     Word32, Word64, Word8, Double, Float]
   t,
 OneOf '[Int32, Int64] tidx) =>
Tensor v'1 t -> Tensor v'2 tidx -> Tensor Build t
sum Tensor Value a
dz Tensor Build Int32
rx]
  where
    rx :: Tensor Build Int32
rx = Tensor Value a -> Tensor Build Int32
forall (v'1 :: * -> *) t.
TensorType t =>
Tensor v'1 t -> Tensor Build Int32
rangeOfRank Tensor Value a
dz

-- Treat read ops as an identity function on the variable. This allows us to
-- take gradients w.r.t. the variable handle instead of the result of a read
-- op. If a variable is read multiple times, the gradients will propagate back
-- through each read.
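--
-- A rough usage sketch (illustrative only; 'initializedVariable' and
-- 'readValue' are assumed to come from TensorFlow.Variable, and the variable
-- is passed to 'gradients' directly via its ToTensor instance):
--
-- > v <- initializedVariable 3
-- > loss <- render $ readValue v * readValue v
-- > [dv] <- gradients loss [v]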
opGrad "ReadVariableOp" _ _ [dz :: Tensor Value a
dz] = [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build a
forall (v :: * -> *) a.
TensorKind v =>
Tensor v a -> Tensor Build a
expr Tensor Value a
dz]

opGrad "Const" _ _ _ = [Maybe (Tensor Build a)
forall a. Maybe a
Nothing, Maybe (Tensor Build a)
forall a. Maybe a
Nothing]
opGrad "StopGradient" _ _ _ = [Maybe (Tensor Build a)
forall a. Maybe a
Nothing]
opGrad "VarHandleOp" _ _ _ = []

opGrad "Sqrt" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x] [dz :: Tensor Value a
dz] = [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a
sq' Tensor Build a -> Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
    Word8, Double, Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
`CoreOps.mul` Tensor Value a
dz]
  where
    sq' :: Tensor Build a
sq' = a -> Tensor Build a
forall a. TensorType a => a -> Tensor Build a
scalar 1 Tensor Build a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
    Word8, Double, Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
`CoreOps.div` (a -> Tensor Build a
forall a. TensorType a => a -> Tensor Build a
scalar 2 Tensor Build a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
  '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
    Word8, Double, Float]
  t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
`CoreOps.mul` Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) t.
OneOf '[Complex Double, Complex Float, Word16, Double, Float] t =>
Tensor v'1 t -> Tensor Build t
CoreOps.sqrt Tensor Build a
x)

opGrad n nodeDef ins grads =
    error $ "no gradient implemented for " ++
            show (n, length ins, length grads, showMessage nodeDef, ins)

-- | The number of outputs for an op type.
numOutputs :: NodeDef -> OutputIx
numOutputs o =
    case o ^. op of
        "Abs" -> 1
        "Add" -> 1
        "AddN" -> 1
        "BatchToSpaceND" -> 1
        "BatchMatMul" -> 1
        "Cast" -> 1
        "Const" -> 1
        "Concat" -> 1
        "Conv2D" -> 1
        "Conv2DBackpropInput" -> 1
        "DepthwiseConv2dNative" -> 1
        "DepthwiseConv2dNativeBackpropInput" -> 1
        "Div" -> 1
        "DynamicStitch" -> 1
        "DynamicPartition" ->
            Int64 -> OutputIx
forall a b. (Integral a, Num b) => a -> b
fromIntegral (NodeDef -> Text -> Int64
forall a1. Attribute a1 => NodeDef -> Text -> a1
lookupAttr NodeDef
o "num_partitions" :: Int64)
        "Exp" -> 1
        "ExpandDims" -> 1
        "Gather" -> 1
        "LabelClasses" -> 1
        "LabelWeights" -> 1
        "Log" -> 1
        "MatMul" -> 1
        "Max" -> 1
        "Maximum" -> 1
        "MaxPool" -> 1
        "Mean" -> 1
        "Min" -> 1
        "Mul" -> 1
        "Neg" -> 1
        "Pad" -> 1
        "Placeholder" -> 1
        "StopGradient" -> 1
        "OneHot" -> 1
        "ReadVariableOp" -> 1
        "RefIdentity" -> 1
        "Relu" -> 1
        "ReluGrad" -> 1
        "Reshape" -> 1
        "Select" -> 1
        "Sigmoid" -> 1
        "Size" -> 1
        "Slice" -> 1
        "SoftmaxCrossEntropyWithLogits" -> 2
        "SpaceToBatchND" -> 1
        "SparseSegmentSum" -> 1
        "Square" -> 1
        "Squeeze" -> 1
        "Sqrt" -> 1
        "Sub" -> 1
        "Sum" -> 1
        "Tanh" -> 1
        "Tile" -> 1
        "ResizeBilinear" -> 1
        "Transpose" -> 1
        "TruncatedNormal" -> 1
        "VarHandleOp" -> 1
        "Variable" -> 1
        "ZerosLike" -> 1
        "Fill" -> 1
        _ -> error $ "numOutputs not implemented for " ++ show (o ^. op)

-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Build Int32
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
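-- For example, dividing the shape vectors [6, 0] and [3, 0] elementwise gives
-- [2, 0]: the 0 / 0 entry becomes 0 `div` max 0 1 = 0.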

allDimensions :: Tensor Build Int32
allDimensions = vector [-1 :: Int32]

rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Build Int32
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1

lookupAttr :: Attribute a1 => NodeDef -> Text -> a1
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens