-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Queues in TensorFlow graph. Very limited support for now.
module TensorFlow.Queue (Queue, makeQueue, enqueue, dequeue) where

import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Proxy (Proxy(..))
import Lens.Family2 ((.~), (&))
import TensorFlow.Build (ControlNode, MonadBuild, build, addInitializer, opAttr, opDef)
import TensorFlow.BuildOp (buildOp)
import TensorFlow.ControlFlow (group)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Tensor (Ref, Value, Tensor, TensorList)
import TensorFlow.Types (TensorTypes, fromTensorTypes)

-- | A queue carrying tuples.
--
-- The type-level list @as@ is a phantom parameter naming the component types
-- of each tuple; the only data stored is the queue's 'Handle'.
data Queue (as :: [*]) = Queue { handle :: Handle }

-- | The tensor wrapped by a 'Queue': a 'Ref' tensor of 'ByteString'.
type Handle = Tensor Ref ByteString

-- | Adds the given values to the queue.
--
-- Delegates to 'CoreOps.queueEnqueue' on the queue's underlying handle and
-- yields the resulting 'ControlNode'.
enqueue :: forall as v m . (MonadBuild m, TensorTypes as)
           => Queue as
           -> TensorList v as
           -> m ControlNode
enqueue = CoreOps.queueEnqueue . handle

-- | Retrieves the values from the queue.
--
-- Delegates to 'CoreOps.queueDequeue' on the queue's underlying handle.
dequeue :: forall as m . (MonadBuild m, TensorTypes as)
           => Queue as
           -> m (TensorList Value as)
           -- ^ Dequeued tensors. They are coupled in the sense
           -- that values appear together, even if they are
           -- not consumed together.
dequeue = CoreOps.queueDequeue . handle

-- | Creates a new queue with the given capacity and shared name.
--
-- Builds a @FIFOQueue@ op whose @component_types@ attribute is derived from
-- the type-level list @as@, groups the resulting handle into a control node
-- and passes that node to 'addInitializer', then wraps the handle in a
-- 'Queue'.
makeQueue :: forall as m . (MonadBuild m, TensorTypes as)
              => Int64  -- ^ The upper bound on the number of elements in
                        --  this queue. Negative numbers mean no limit.
              -> ByteString -- ^ If non-empty, this queue will be shared
                            -- under the given name across multiple sessions.
              -> m (Queue as)
makeQueue capacity sharedName = do
    q <- build $ buildOp [] (opDef "FIFOQueue"
                     & opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
                     & opAttr "shared_name" .~ sharedName
                     & opAttr "capacity" .~ capacity
                    )
    -- Group the queue op into a control node and register it as an initializer.
    group q >>= addInitializer
    return (Queue q)

-- TODO(gnezdo): Figure out the closing story for queues.