-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}  -- For the Render class

module TensorFlow.Tensor where

import Data.ByteString (ByteString)
import Data.String (IsString(..))
import qualified Data.Text as Text
import Lens.Family2 ((^.))
import Lens.Family2.State ((%=), use)

import Proto.Tensorflow.Core.Framework.NodeDef_Fields (device)
import TensorFlow.Build
import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..))
import TensorFlow.Types
    ( TensorType
    , TensorData(..)
    , ListOf(..)
    )
import qualified TensorFlow.Internal.FFI as FFI

-- | A named output of a TensorFlow operation.
--
-- The type parameter @a@ is the type of the elements in the 'Tensor'; it is a
-- phantom parameter (no field mentions it), used only for type checking.  The
-- parameter @v@ is either:
--
--   * 'Build': An unrendered, immutable value.
--   * 'Value': A rendered, immutable value.
--   * 'Ref': A rendered stateful handle (e.g., a variable).
--
-- Note that 'expr', 'value', 'render' and 'renderValue' can help convert
-- between the different types of 'Tensor'.
data Tensor v a where
    Tensor :: TensorKind v => {tensorOutput :: v Output} -> Tensor v a

-- | Marker wrapper for a rendered, immutable tensor component.
--
-- Structurally this is an identity functor: it holds exactly one @a@, and the
-- 'Functor', 'Applicative' and 'Monad' instances apply functions directly to
-- that payload.
newtype Value a = Value {runValue :: a}
    deriving Functor

instance Applicative Value where
    pure = Value
    Value f <*> Value x = Value $ f x

instance Monad Value where
    -- Binding just unwraps the payload and hands it to the continuation.
    f >>= g = g $ runValue f

-- | Marker wrapper for a rendered, stateful tensor handle (e.g. a variable).
--
-- Like 'Value' it is structurally an identity functor: it holds exactly one
-- @a@, and the 'Functor', 'Applicative' and 'Monad' instances apply functions
-- directly to that payload.
newtype Ref a = Ref {runRef :: a}
    deriving Functor

instance Applicative Ref where
    pure = Ref
    Ref f <*> Ref x = Ref $ f x

instance Monad Ref where
    -- Binding just unwraps the payload and hands it to the continuation.
    f >>= g = g $ runRef f

-- | Cast a 'Tensor Ref' into a 'Tensor Value'. This behaves like a no-op:
-- the underlying 'Output' is moved unchanged from the 'Ref' wrapper into a
-- 'Value' wrapper, with no 'Build' action involved.
value :: Tensor Ref a -> Tensor Value a
value (Tensor o) = Tensor $ Value $ runRef o

-- | Render a 'Tensor' of any kind into a 'Tensor Value'.
--
-- The tensor's output is first lifted into an unrendered @Tensor Build@ via
-- 'toBuild'. That tensor is then passed to 'render', so the usual 'render'
-- behaviour applies (name, scope, device and control inputs come from the
-- 'MonadBuild' context).
renderValue :: MonadBuild m => Tensor v a -> m (Tensor Value a)
renderValue (Tensor o) = render $ Tensor $ toBuild o

-- | Pairs a graph 'Output' with the raw tensor data that should be fed into
-- it when the graph is run. Build one with 'feed'.
data Feed
    = Feed
        Output          -- ^ The tensor output being fed.
        FFI.TensorData  -- ^ The data to substitute for that output.

-- | A class ensuring that a given tensor is rendered, i.e., has a fixed
-- name, device, etc.
class Rendered t where
    -- | The concrete graph 'Output' that this rendered tensor refers to.
    renderedOutput :: t a -> Output

-- | A 'Value' tensor is already rendered; its output is read straight out of
-- the 'Value' wrapper.
instance Rendered (Tensor Value) where
    renderedOutput = runValue . tensorOutput

-- | A 'Ref' tensor is already rendered; its output is read straight out of
-- the 'Ref' wrapper.
instance Rendered (Tensor Ref) where
    renderedOutput = runRef . tensorOutput

-- | The name of the graph node that produces a rendered tensor's output.
tensorNodeName :: Rendered t => t a -> NodeName
tensorNodeName = outputNodeName . renderedOutput


-- | Create a 'Feed' for feeding the given data into a 'Tensor' when running
-- the graph.
--
-- The typed 'TensorData' is unwrapped to its raw FFI payload and paired with
-- the tensor's rendered 'Output'.
--
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
-- rendered 'Tensor' may be different than feeding the original 'Tensor'.
feed :: Rendered t => t a -> TensorData a -> Feed
feed t (TensorData td) = Feed (renderedOutput t) td

-- | Create a 'Tensor' for a given name.  This can be used to reference nodes
-- in a 'GraphDef' that was loaded via 'addGraphDef'.
--
-- The name is converted to an 'Output' through that type's 'IsString'
-- instance. NOTE(review): this presumably accepts @node@ or @node:index@;
-- confirm in "TensorFlow.Output". The resulting output is lifted into the
-- tensor kind @v@ with 'pure'.
--
-- TODO(judahjacobson): add more safety checks here.
tensorFromName :: TensorKind v => Text.Text -> Tensor v a
tensorFromName = Tensor . pure . fromString . Text.unpack

-- | Like 'tensorFromName', but type-restricted to 'Value'.
tensorValueFromName :: Text.Text -> Tensor Value a
tensorValueFromName = tensorFromName

-- | Like 'tensorFromName', but type-restricted to 'Ref'.
tensorRefFromName :: Text.Text -> Tensor Ref a
tensorRefFromName = tensorFromName

-- | A heterogeneous list of tensors of the same kind @v@, indexed by the
-- type-level list of their element types.
type TensorList v = ListOf (Tensor v)

-- | The rendered 'Output' of every tensor in a 'TensorList', in list order.
tensorListOutputs :: Rendered (Tensor v) => TensorList v as -> [Output]
tensorListOutputs Nil = []
tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts

-- | Places all nodes rendered in the given action on the same device as the
-- given tensor (see also 'withDevice').
--
-- The device is read from the 'NodeDef' of the node that produced @t@. The
-- action must actually render the desired tensors as a side effect; a pure
-- 'return' would not have the desired effect.
colocateWith :: (MonadBuild m, Rendered t) => t b -> m a -> m a
colocateWith t x = do
    -- Look up the producing node of @t@ and take its device field.
    d <- build $ Device . (^. device)
               <$> lookupNode (outputNodeName $ renderedOutput t)
    withDevice (Just d) x


-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
-- the 'MonadBuild' context.  Also renders any dependencies of the 'Tensor' that
-- weren't already rendered.
--
-- This operation is idempotent; calling 'render' on the same input in the same
-- context will produce the same result.  However, rendering the same
-- @Tensor Build@ in two different contexts may result in two different
-- @Tensor Value@s.
render :: MonadBuild m => Tensor Build a -> m (Tensor Value a)
render (Tensor t) = Tensor . Value <$> build t

-- | Convert a 'Tensor' of any kind into an unrendered @Tensor Build@ by
-- lifting its output into the 'Build' monad with 'toBuild'.
--
-- TODO: better name.
expr :: TensorKind v => Tensor v a -> Tensor Build a
expr (Tensor o) = Tensor $ toBuild o

-- | Records the given summary tensor in the 'Build' state so it can later be
-- retrieved with 'collectAllSummaries'. The tensor is expected to produce a
-- Summary protocol buffer in serialized (string) form.
--
-- The tensor's output is prepended to the collected summaries. For safety,
-- use the pre-composed functions: Logging.scalarSummary and
-- Logging.histogramSummary.
addSummary :: (MonadBuild m, TensorKind v)
           => Tensor v ByteString -- ^ A 'SummaryTensor'
           -> m ()
addSummary t = build $ do
    -- TODO: more generic way
    o <- toBuild $ tensorOutput t
    summaries %= (o :)

-- | Retrieves the summary ops collected thus far. Typically this only
-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
-- repeatedly, the values accumulate.
--
-- Summaries are stored by prepending, so the most recently added one comes
-- first. Each stored output is re-wrapped as a rendered 'Value' tensor.
collectAllSummaries :: MonadBuild m => m [SummaryTensor]
collectAllSummaries = build $ map (Tensor . Value) <$> use summaries

-- | Synonym for the tensors that return serialized Summary proto.
type SummaryTensor = Tensor Value ByteString

-- | An internal class for kinds of Tensors.
--
-- 'toBuild' lifts a @v@-wrapped value into the 'Build' monad, so every tensor
-- kind can be handled uniformly as an (unrendered) build action.
class Monad v => TensorKind v where
    toBuild :: v a -> Build a

-- | A rendered 'Value' needs no graph construction; its payload is returned.
instance TensorKind Value where
    toBuild = return . runValue

-- | A rendered 'Ref' needs no graph construction; its payload is returned.
instance TensorKind Ref where
    toBuild = return . runRef

-- | 'Build' actions are already in the target monad.
instance TensorKind Build where
    toBuild = id


-- | Types which can be converted to an unrendered 'Tensor'.
class ToTensor t where
    toTensor :: TensorType a => t a -> Tensor Build a

-- | Every kind of 'Tensor' converts through 'expr'.
instance TensorKind v => ToTensor (Tensor v) where
    toTensor = expr