-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- | TensorBoard summary generation and event logging. Provides type-safe
-- wrappers around the raw, string-emitting summary ops in CoreOps.
--
-- Example use:
--
-- > -- Call summary functions while constructing the graph.
-- > createModel = do
-- >   loss <- -- ...
-- >   TF.scalarSummary "loss" loss
-- >
-- > -- Write summaries to an EventWriter.
-- > train = TF.withEventWriter "/path/to/logs" $ \eventWriter -> do
-- >     summaryTensor <- TF.build TF.mergeAllSummaries
-- >     forM_ [1..] $ \step -> do
-- >         if (step `mod` 100 == 0)
-- >             then do
-- >                 ((), summaryBytes) <- TF.run (trainStep, summaryTensor)
-- >                 let summary = decodeMessageOrDie (TF.unScalar summaryBytes)
-- >                 TF.logSummary eventWriter step summary
-- >             else TF.run_ trainStep

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}


module TensorFlow.Logging
    ( EventWriter
    , withEventWriter
    , logEvent
    , logGraph
    , logSummary
    , SummaryTensor
    , histogramSummary
    , imageSummary
    , scalarSummary
    , mergeAllSummaries
    ) where

import Control.Concurrent (forkFinally)
import Control.Concurrent.MVar (MVar, newEmptyMVar, readMVar, putMVar)
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TBMQueue (TBMQueue, newTBMQueueIO, closeTBMQueue, writeTBMQueue)
import Control.Monad.Catch (MonadMask, bracket)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Resource (runResourceT)
import Data.ByteString (ByteString)
import Data.Conduit ((.|))
import Data.Conduit.TQueue (sourceTBMQueue)
import Data.ProtoLens.Default(def)
import Data.Int (Int64)
import Data.Word (Word8, Word16)
import Data.ProtoLens (encodeMessage)
import Data.Time.Clock (getCurrentTime)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
import Lens.Family2 ((.~), (&))
import Network.HostName (getHostName)
import Proto.Tensorflow.Core.Framework.Summary (Summary)
import Proto.Tensorflow.Core.Util.Event (Event)
import Proto.Tensorflow.Core.Util.Event_Fields (fileVersion, graphDef, step, summary, wallTime)
import System.Directory (createDirectoryIfMissing)
import System.FilePath ((</>))
import TensorFlow.Build (MonadBuild, Build, asGraphDef)
import TensorFlow.Ops (scalar)
import TensorFlow.Records.Conduit (sinkTFRecords)
import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries)
import TensorFlow.Types (TensorType, type(/=), OneOf)
import Text.Printf (printf)
import qualified Data.ByteString.Lazy as L
import qualified Data.Conduit as Conduit
import qualified Data.Conduit.List as Conduit
import qualified Data.Text as T
import qualified TensorFlow.GenOps.Core as CoreOps

-- | Handle for logging TensorBoard events safely from multiple threads.
--
-- Events are handed to a background writer thread through a bounded STM
-- queue; the 'MVar' is filled once that writer thread has exited.
data EventWriter
    = EventWriter
        (TBMQueue Event)  -- events waiting to be serialized to the event file
        (MVar ())         -- filled by the writer thread when it finishes

-- | Writes Event protocol buffers to event files.
--
-- Opens an 'EventWriter' for @logdir@, runs the given action with it, and
-- closes the writer afterwards, also when the action throws. Closing
-- blocks until the background writer thread has exited.
withEventWriter ::
    (MonadIO m, MonadMask m)
    => FilePath
    -- ^ logdir. Local filesystem directory where event file will be written.
    -> (EventWriter -> m a)
    -> m a
withEventWriter logdir =
    bracket (liftIO (newEventWriter logdir)) (liftIO . closeEventWriter)

-- | Creates @logdir@ if needed, opens a new event file in it, and starts a
-- background thread that serializes queued events into that file.
--
-- The file is named @events.out.tfevents.<seconds>.<hostname>@, where
-- @<seconds>@ is the creation time since the epoch, zero-padded to 10
-- digits. The first event written carries the wall time and the file
-- version string @brain.Event:2@.
newEventWriter :: FilePath -> IO EventWriter
newEventWriter logdir = do
    createDirectoryIfMissing True logdir
    t <- doubleWallTime
    hostname <- getHostName
    let filename :: FilePath
        filename = printf (logdir </> "events.out.tfevents.%010d.%s")
                          (truncate t :: Integer) hostname
    -- Asynchronously consume events from a queue.
    -- We use a bounded queue to ensure the producer doesn't get too far ahead
    -- of the consumer. The buffer size was picked arbitrarily.
    q <- newTBMQueueIO 1024
    -- Use an MVar to signal that the worker thread has completed.
    done <- newEmptyMVar
    let writer = EventWriter q done
        consumeQueue = runResourceT $ Conduit.runConduit $
            sourceTBMQueue q
            .| Conduit.map (L.fromStrict . encodeMessage)
            .| sinkTFRecords filename
        -- Close the queue whenever the worker exits, including when it dies
        -- with an exception (e.g. the event file cannot be opened). Without
        -- this, nothing drains the bounded queue any more and producers
        -- block forever in 'logEvent' once it fills up. Closing an already
        -- closed queue (the normal shutdown path) is harmless.
        onWorkerExit _ = atomically (closeTBMQueue q) >> putMVar done ()
    _ <- forkFinally consumeQueue onWorkerExit
    logEvent writer $ def & wallTime .~ t
                          & fileVersion .~ T.pack "brain.Event:2"
    return writer

-- | Stops accepting new events and waits for the background writer thread
-- to exit.
--
-- Closing the queue lets the writer thread's conduit source finish
-- (presumably after draining what is already queued — see stm-conduit's
-- 'sourceTBMQueue'). 'readMVar' then blocks until that thread has signalled
-- completion through @done@.
closeEventWriter :: EventWriter -> IO ()
closeEventWriter (EventWriter q done) = do
    atomically (closeTBMQueue q)
    readMVar done

-- | Logs the given Event protocol buffer.
--
-- The event is enqueued for the background writer thread rather than being
-- written directly. Per the stm-chans documentation, 'writeTBMQueue' retries
-- while the bounded queue is full and discards the value once the queue has
-- been closed.
logEvent :: MonadIO m => EventWriter -> Event -> m ()
logEvent (EventWriter q _) pb = liftIO (atomically (writeTBMQueue q pb))

-- | Logs the graph for the given 'Build' action.
--
-- The action is converted purely with 'asGraphDef'. The serialized
-- @GraphDef@ is stored in the event's @graphDef@ field; no wall time or step
-- is set on this event.
logGraph :: MonadIO m => EventWriter -> Build a -> m ()
logGraph writer build = logEvent writer graphEvent
  where
    graphBytes = encodeMessage (asGraphDef build)
    graphEvent = (def :: Event) & graphDef .~ graphBytes

-- | Logs the given Summary event together with a global step (pass 0 if no
-- step is applicable).
--
-- The event is stamped with the current wall-clock time in seconds since the
-- epoch.
logSummary :: MonadIO m => EventWriter -> Int64 -> Summary -> m ()
logSummary writer step' summaryProto = do
    t <- liftIO doubleWallTime
    logEvent writer $ def & wallTime .~ t
                          & step .~ step'
                          & summary .~ summaryProto

-- | Current wall-clock time as fractional seconds since the POSIX epoch.
--
-- Goes through 'Rational' so the conversion from the clock's
-- picosecond-resolution 'NominalDiffTime' to 'Double' is a single rounding.
doubleWallTime :: IO Double
doubleWallTime =
    fromRational . toRational . utcTimeToPOSIXSeconds <$> getCurrentTime

-- | Adds a 'CoreOps.histogramSummary' node. The tag argument is intentionally
-- limited to a single value for simplicity.
--
-- The resulting summary tensor is registered with 'addSummary'.
histogramSummary ::
    (MonadBuild m, TensorType t, t /= ByteString, t /= Bool)
    -- NOTE: stated as exclusions instead of the numeric 'OneOf' list that
    -- 'CoreOps.histogramSummary' requires; presumably equivalent via
    -- TensorFlow.Types' definition of 'OneOf' — confirm before changing.
    => ByteString -> Tensor v t -> m ()
histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag)

-- | Adds a 'CoreOps.imageSummary' node. The tag argument is intentionally
-- limited to a single value for simplicity.
--
-- The resulting summary tensor is registered with 'addSummary'.
imageSummary ::
    (OneOf '[Word8, Word16, Float] t, MonadBuild m)
    => ByteString
    -> Tensor v t
    -> m ()
imageSummary tag = addSummary . CoreOps.imageSummary (scalar tag)

-- | Adds a 'CoreOps.scalarSummary' node tagged with the given name.
--
-- The resulting summary tensor is registered with 'addSummary'.
scalarSummary ::
    (TensorType t, t /= ByteString, t /= Bool, MonadBuild m)
    -- NOTE: stated as exclusions instead of the numeric 'OneOf' list that
    -- 'CoreOps.scalarSummary' requires; presumably equivalent via
    -- TensorFlow.Types' definition of 'OneOf' — confirm before changing.
    => ByteString -> Tensor v t -> m ()
scalarSummary tag = addSummary . CoreOps.scalarSummary (scalar tag)

-- | Merge all summaries accumulated in the 'Build' into one summary.
--
-- Gathers the summary tensors returned by 'collectAllSummaries' (presumably
-- those registered through 'addSummary', as 'scalarSummary',
-- 'histogramSummary' and 'imageSummary' do), combines them with
-- 'CoreOps.mergeSummary', and renders the merged tensor.
--
-- NOTE(review): if no summaries were added, this builds a merge of zero
-- inputs, which TensorFlow's MergeSummary op may reject at run time — confirm.
mergeAllSummaries :: MonadBuild m => m SummaryTensor
mergeAllSummaries = do
    summaries <- collectAllSummaries
    render (CoreOps.mergeSummary summaries)