-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module TensorFlow.ControlFlow
    ( -- * Dependencies
      withControlDependencies
    , group
      -- * Operations
    , noOp
    ) where

import TensorFlow.BuildOp
import TensorFlow.Build
import TensorFlow.Nodes

-- | Modify a 'Build' action, such that all new ops rendered in it will depend
-- on the nodes in the first argument.
--
-- The nodes of @deps@ are collected with 'getNodes' (lifted via 'build'), and
-- @act@ is then run under 'withNodeDependencies' for that node set.
withControlDependencies
    :: (MonadBuild m, Nodes t)
    => t    -- ^ values whose nodes become control dependencies
    -> m a  -- ^ action whose newly rendered ops receive those dependencies
    -> m a
withControlDependencies deps act = do
    nodes <- build $ getNodes deps
    withNodeDependencies nodes act

-- TODO(judahjacobson): Reimplement withDependencies.

-- | Create an op that groups multiple operations.
--
-- When this op finishes, all ops in the input @deps@ have finished.  This op
-- has no output.
--
-- Implemented as a 'noOp' built under 'withControlDependencies' on @deps@.
group :: (MonadBuild m, Nodes t) => t -> m ControlNode
group deps = withControlDependencies deps noOp

-- | Does nothing.  Only useful as a placeholder for control edges.
--
-- Builds an op of type @\"NoOp\"@ via 'buildOp' and lifts it with 'build'.
noOp :: MonadBuild m => m ControlNode
noOp = build $ buildOp [] $ opDef "NoOp"