-- | An implementation of ResourceHandle-based variables.
--
-- The main difference between this and 'Ref'-based variables is
-- that reads are explicit, via the 'readValue' op.
--
-- TODO: given that distinction, figure out a good story around
-- gradients and save/restore.  Then, merge this module into
-- TensorFlow.Ops.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.Variable
    ( Variable
    , variable
    , variable'
    , readValue
    , initializedValue
    , initializedVariable
    , initializedVariable'
    , zeroInitializedVariable
    , zeroInitializedVariable'
    , assign
    , assign'
    , assignAdd
    , assignAdd'
    , resourceApplyAdam
    , resourceApplyAdam'
    ) where

import qualified Data.Complex
import qualified Data.Int
import qualified Data.Word
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 ((.~), (&))
import TensorFlow.Core
import TensorFlow.Build (opDef)
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
import TensorFlow.Output (opInputs, unNodeName)
import TensorFlow.Tensor (Rendered(..), ToTensor(..), renderValue, tensorNodeName)
import TensorFlow.Types (tensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Ops (zeros)

-- | A variable that lives behind a TensorFlow resource handle.
--
-- Reading the current contents is always an explicit op; see 'readValue'.
data Variable a = Variable
    { variableHandle :: Tensor Value ResourceHandle
      -- ^ The resource handle produced when the variable was created.
    , initializedValue :: Maybe (Tensor Value a)
      -- ^ The initial value of a 'Variable' created with
      -- 'initializedVariable'; 'Nothing' for variables created with
      -- 'variable'.
    }

-- | Rendering a 'Variable' yields the rendered output of its resource handle.
instance Rendered Variable where
    renderedOutput var = renderedOutput (variableHandle var)

-- | Converting a 'Variable' to a tensor reads its current value with
-- 'readValue'.
instance ToTensor Variable where
    toTensor var = readValue var

-- | Creates a new, uninitialized variable with the given shape.
--
-- No initializer is registered for it; use 'initializedVariable' (or
-- 'zeroInitializedVariable') to get one that is set the next time the
-- session runs.
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
variable shape = variable' id shape

-- | Like 'variable', but with extra 'OpParams' applied to the op that
-- creates the variable's handle.
variable' :: forall m a . (MonadBuild m, TensorType a)
          => OpParams -> Shape -> m (Variable a)
variable' params = variableInternal params . Just

-- | Shared implementation behind 'variable'' and 'initializedVariable''.
--
-- Builds a handle op whose @shared_name@ attribute is that op's own node
-- name and whose @shape@ attribute is the given (possibly unknown) shape.
-- The caller's 'OpParams' are composed outermost, so they are applied after
-- the @shared_name@ and @shape@ settings. The returned 'Variable' never
-- carries an 'initializedValue'.
variableInternal :: forall m a . (MonadBuild m, TensorType a)
                 => OpParams -> Maybe Shape -> m (Variable a)
variableInternal params maybeShape = build $ do
    -- Each variable needs a unique "shared_name".  RecursiveDo (MonadFix)
    -- lets the attribute refer to the node name of the very op it is
    -- attached to, without exposing more internals of the Build module.
    rec resource <- CoreOps.varHandleOp' handleAttrs dataType placeholderShape
        let sharedName = encodeUtf8 (unNodeName (tensorNodeName resource))
            handleAttrs = params
                        . (opAttr "shared_name" .~ sharedName)
                        . (opAttr "shape" .~ maybeShape)
    pure (Variable resource Nothing)
  where
    dataType = tensorType (undefined :: a)
    -- Generated ops don't support unknown shapes. As a workaround, pass a
    -- rank-zero shape here and override it with the "shape" attribute set
    -- through OpParams above.
    -- TODO: Consider supporting this better in op generation.
    placeholderShape = Shape []

-- | Creates a variable initialized to the given value.
--
-- The assignment is registered as an initializer, so it takes effect the
-- next time the session runs.
initializedVariable :: (MonadBuild m, TensorType a)
                    => Tensor v a -> m (Variable a)
initializedVariable initial = initializedVariable' id initial

-- | Like 'initializedVariable', but with extra 'OpParams' applied to the op
-- that creates the variable's handle.
--
-- The handle is created without a static shape. The initializer is then
-- rendered and assigned to the handle, and that assignment is registered
-- with 'addInitializer'. The rendered initializer is kept in
-- 'initializedValue' of the result.
initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
                    => OpParams -> Tensor v a -> m (Variable a)
initializedVariable' params initializer = do
    -- The shape is not known initially. 'variableInternal' only ever returns
    -- a bare handle (its 'initializedValue' is always 'Nothing'), so project
    -- the handle out directly rather than pattern matching with an
    -- unreachable 'error' branch. The annotation pins the element type of
    -- the intermediate 'Variable', which would otherwise be ambiguous.
    h <- variableHandle <$> (variableInternal params Nothing :: m (Variable a))
    initializer' <- renderValue initializer
    i <- CoreOps.assignVariableOp h initializer'
    addInitializer =<< group i
    return (Variable h (Just initializer'))

-- | Creates a variable of the given shape whose initializer fills it with
-- zeros.
zeroInitializedVariable
  :: (MonadBuild m, TensorType a, Num a) => Shape -> m (Variable a)
zeroInitializedVariable shape = zeroInitializedVariable' id shape

-- | Like 'zeroInitializedVariable', but with extra 'OpParams' applied to the
-- op that creates the variable's handle.
zeroInitializedVariable'
  :: (MonadBuild m, TensorType a, Num a) => OpParams -> Shape -> m (Variable a)
zeroInitializedVariable' params shape =
    initializedVariable' params (zeros shape)

-- | Gets the value stored in a variable.
--
-- This op is stateful, since its result depends on the variable's current
-- value; however, it may be CSE'd with other reads in the same context.
-- Use 'render' together with (for example) 'withControlDependencies' to fix
-- the context a read belongs to:
--
-- >   runSession $ do
-- >     v <- variable []
-- >     a <- assign v 24
-- >     r <- withControlDependencies a $ render $ readValue v + 18
-- >     result <- run r
-- >     liftIO $ (42 :: Float) @=? unScalar result
--
readValue :: TensorType a => Variable a -> Tensor Build a
readValue var = readValue' id var

-- | Like 'readValue', but with extra 'OpParams' applied to the
-- @ReadVariableOp@.
--
-- The op's @dtype@ attribute is set from the variable's element type, and its
-- inputs are built from the variable's handle. It is constructed with
-- 'pureOp', which is why identical reads may be CSE'd (see 'readValue').
readValue' :: forall a . TensorType a
    => OpParams -> Variable a -> Tensor Build a
readValue' params (Variable resource _) = pureOp [] readOp
  where
    readOp = do
        inputs <- buildInputs resource
        pure (opDef "ReadVariableOp" & configure inputs)
    configure inputs = params
                     . (opAttr "dtype" .~ tensorType (undefined :: a))
                     . (opInputs .~ inputs)

-- | Sets the value of a variable, returning the assignment's 'ControlNode'.
assign :: (MonadBuild m, TensorType a)
    => Variable a -> Tensor v a -> m ControlNode
assign var newValue = assign' id var newValue

-- | Like 'assign', but with extra 'OpParams' applied to the assignment op.
assign' :: (MonadBuild m, TensorType a)
    => OpParams -> Variable a -> Tensor v a -> m ControlNode
assign' params Variable{variableHandle = resource} newValue =
    CoreOps.assignVariableOp' params resource newValue

-- | Increments the value of a variable by the given tensor, returning the
-- update's 'ControlNode'.
assignAdd :: (MonadBuild m, TensorType a)
    => Variable a -> Tensor v a -> m ControlNode
assignAdd var delta = assignAdd' id var delta

-- | Like 'assignAdd', but with extra 'OpParams' applied to the update op.
assignAdd' :: (MonadBuild m, TensorType a)
    => OpParams -> Variable a -> Tensor v a -> m ControlNode
assignAdd' params Variable{variableHandle = resource} delta =
    CoreOps.assignAddVariableOp' params resource delta

-- | Updates @var@, together with its running estimates @m@ and @v@,
-- according to the Adam algorithm:
--
-- > lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
-- > m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
-- > v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t
-- > variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
resourceApplyAdam ::
    ( MonadBuild m
    , OneOf '[ (Data.Complex.Complex Double)
             , (Data.Complex.Complex Float)
             , Data.Int.Int16
             , Data.Int.Int32
             , Data.Int.Int64
             , Data.Int.Int8
             , Data.Word.Word16
             , Data.Word.Word8
             , Double
             , Float
             ] t
    )
    => Variable t -- ^ __var__: Should be from a Variable().
    -> Variable t -- ^ __m__: Should be from a Variable().
    -> Variable t -- ^ __v__: Should be from a Variable().
    -> Tensor v1 t -- ^ __beta1_power__: Must be a scalar.
    -> Tensor v2 t -- ^ __beta2_power__: Must be a scalar.
    -> Tensor v3 t -- ^ __lr__: Scaling factor. Must be a scalar.
    -> Tensor v4 t -- ^ __beta1__: Momentum factor. Must be a scalar.
    -> Tensor v5 t -- ^ __beta2__: Momentum factor. Must be a scalar.
    -> Tensor v6 t -- ^ __epsilon__: Ridge term. Must be a scalar.
    -> Tensor v7 t -- ^ __grad__: The gradient.
    -> m ControlNode
resourceApplyAdam var mEstimate vEstimate =
    resourceApplyAdam' id var mEstimate vEstimate

-- | Like 'resourceApplyAdam', but with extra 'OpParams' applied to the op
-- built by 'CoreOps.resourceApplyAdam''. Only the resource handles of the
-- three 'Variable's are passed on to that op.
resourceApplyAdam' ::
    ( MonadBuild m
    , OneOf '[ (Data.Complex.Complex Double)
             , (Data.Complex.Complex Float)
             , Data.Int.Int16
             , Data.Int.Int32
             , Data.Int.Int64
             , Data.Int.Int8
             , Data.Word.Word16
             , Data.Word.Word8
             , Double
             , Float
             ] t
    )
    => OpParams
    -> Variable t -- ^ __var__: Should be from a Variable().
    -> Variable t -- ^ __m__: Should be from a Variable().
    -> Variable t -- ^ __v__: Should be from a Variable().
    -> Tensor v1 t -- ^ __beta1_power__: Must be a scalar.
    -> Tensor v2 t -- ^ __beta2_power__: Must be a scalar.
    -> Tensor v3 t -- ^ __lr__: Scaling factor. Must be a scalar.
    -> Tensor v4 t -- ^ __beta1__: Momentum factor. Must be a scalar.
    -> Tensor v5 t -- ^ __beta2__: Momentum factor. Must be a scalar.
    -> Tensor v6 t -- ^ __epsilon__: Ridge term. Must be a scalar.
    -> Tensor v7 t -- ^ __grad__: The gradient.
    -> m ControlNode
resourceApplyAdam' params Variable{variableHandle = varHandle}
                          Variable{variableHandle = mHandle}
                          Variable{variableHandle = vHandle} =
    CoreOps.resourceApplyAdam' params varHandle mHandle vHandle