-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module TensorFlow.Internal.FFI
    ( TensorFlowException(..)
    , Raw.Session
    , withSession
    , extendGraph
    , run
    , TensorData(..)
    , setSessionConfig
    , setSessionTarget
    , getAllOpList
      -- * Internal helper.
    , useProtoAsVoidPtrLen
    )
    where

import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Monad (when)
import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask_)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Bits (Bits, toIntegralSized)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S
import qualified Data.Vector.Storable.Mutable as M

import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)

import qualified TensorFlow.Internal.Raw as Raw

-- | Exception raised when a TensorFlow C API call fails. It carries the
-- status code reported by TensorFlow and a human-readable message.
data TensorFlowException = TensorFlowException Raw.Code T.Text
    deriving (Eq, Show, Typeable)

instance Exception TensorFlowException

-- | All of the data needed to represent a tensor: its shape, its element
-- type and its contents as raw bytes.
data TensorData = TensorData
    { tensorDataDimensions :: [Int64]
      -- ^ Size of each dimension, one entry per dimension.
    , tensorDataType       :: !DataType
      -- ^ Element type of the tensor.
    , tensorDataBytes      :: !(S.Vector Word8)
      -- ^ The tensor contents as an untyped byte buffer.
    }
  deriving (Eq, Show)

-- | Runs the given action after creating a session with options
-- populated by the given optionSetter.
--
-- Both the session options and the session are released when the action
-- finishes, including when it throws. On teardown the session is closed
-- first, then every task spawned through the callback handed to the action
-- is cancelled and awaited, and only then is the session deleted.
withSession :: (MonadIO m, MonadMask m)
            => (Raw.SessionOptions -> IO ())
            -> ((IO () -> IO ()) -> Raw.Session -> m a)
            -- ^ The action can spawn concurrent tasks which will
            -- be canceled before withSession returns.
            -> m a
withSession optionSetter action = do
    runners <- liftIO (newMVar [])
    bracket (liftIO Raw.newSessionOptions)
            (liftIO . Raw.deleteSessionOptions) $ \options ->
        bracket (liftIO (openSession options))
                (liftIO . teardown runners)
                (action (asyncCollector runners))
  where
    -- Applies the caller's configuration, then creates the session.
    openSession options = do
        optionSetter options
        checkStatus (Raw.newSession options)
    -- Closes the session to nudge the pending run calls to fail and exit.
    -- Whether or not closing succeeds, all runners are collected before the
    -- session is deleted.
    teardown runners session =
        checkStatus (Raw.closeSession session) `finally` drainThenDelete
      where
        drainThenDelete = do
            pending <- takeMVar runners
            mapM_ shutDownRunner pending
            checkStatus (Raw.deleteSession session)

-- | Starts @runner@ on a new thread with 'async' and prepends its handle to
-- the list held in @drain@, so the session teardown can later cancel and
-- await it. The update goes through 'modifyMVarMasked_', so spawning and
-- recording happen with asynchronous exceptions masked.
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
asyncCollector drain runner = modifyMVarMasked_ drain $ \running -> do
    spawned <- async runner
    return (spawned : running)

-- | Cancels a runner and waits for it to finish. Any exception the runner
-- ended with is printed to stdout rather than rethrown.
shutDownRunner :: Async () -> IO ()
shutDownRunner r = do
    cancel r
    -- TODO(gnezdo): manage exceptions better than print.
    outcome <- waitCatch r
    case outcome of
        Left err -> print err
        Right () -> return ()

-- | Serializes the given 'GraphDef' and hands the bytes to
-- 'Raw.extendGraph' for this session. Throws 'TensorFlowException' if the
-- call reports a non-OK status.
extendGraph :: Raw.Session -> GraphDef -> IO ()
extendGraph session pb = useProtoAsVoidPtrLen pb extend
  where
    extend ptr len = checkStatus (Raw.extendGraph session ptr len)


-- | Runs the session once.
--
-- Each feed value is converted to a 'Raw.Tensor' and passed under its name.
-- The fetch names get an output array of null tensors that 'Raw.run' fills
-- in. Those outputs are read back in array order and converted to
-- 'TensorData'. Target names are passed to 'Raw.run' but yield no outputs
-- here.
--
-- Throws 'TensorFlowException' if 'Raw.run' reports a non-OK status. The
-- feed tensors created here are deleted whether the call succeeds or fails.
run :: Raw.Session
    -> [(B.ByteString, TensorData)] -- ^ Feeds.
    -> [B.ByteString]               -- ^ Fetches.
    -> [B.ByteString]               -- ^ Targets.
    -> IO [TensorData]
run session feeds fetches targets = do
    let nullTensor = Raw.Tensor nullPtr
    -- Use mask to avoid leaking input tensors before they are passed to 'run'
    -- and output tensors before they are passed to 'createTensorData'.
    mask_ $
        -- Feeds
        withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
        mapM (createRawTensor . snd) feeds >>= \feedTensors ->
        withArrayLen feedTensors $ \_ cFeedTensors ->
        -- Fetches.
        withStringArrayLen fetches $ \fetchesLen fetchNames ->
        -- tensorOuts is an array of null Tensor pointers that will be filled
        -- by the call to Raw.run.
        withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
        -- Targets.
        withStringArrayLen targets $ \targetsLen ctargets -> do
            let runStep = checkStatus $ Raw.run
                    session
                    nullPtr
                    feedNames cFeedTensors (safeConvert feedsLen)
                    fetchNames tensorOuts (safeConvert fetchesLen)
                    ctargets (safeConvert targetsLen)
                    nullPtr
            -- The feed tensors remain ours after the call (the success path
            -- deletes them). Delete them even when checkStatus throws;
            -- otherwise every failed run leaks all of its inputs.
            runStep `finally` mapM_ Raw.deleteTensor feedTensors
            outTensors <- peekArray fetchesLen tensorOuts
            mapM createTensorData outTensors


-- Internal.


-- | Same as 'fromIntegral', but throws an error if conversion is "lossy",
-- i.e. when the value is not representable in the target type.
safeConvert ::
    forall a b. (Show a, Show b, Bits a, Bits b, Integral a, Integral b)
    => a -> b
safeConvert x = fromMaybe lossy (toIntegralSized x)
  where
    -- The error message shows what plain 'fromIntegral' would have produced.
    wrapped = fromIntegral x :: b
    lossy = error ("Failed to convert " ++ show x ++ ", got " ++ show wrapped)


-- | Use a list of ByteString as a list of CString.
--
-- Each string is copied into a temporary NUL-terminated buffer with
-- 'B.useAsCString'. The buffers reach @fn@ in the same order as the input
-- and are only valid while @fn@ runs.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = go strings id
  where
    -- acc is a difference list of the C strings converted so far.
    go [] acc = fn (acc [])
    -- TODO(fmayle): Is it worth using unsafeAsCString here?
    go (s:rest) acc = B.useAsCString s $ \c -> go rest (acc . (c :))


-- | Use a list of ByteString as an array of CString.
--
-- The callback receives the number of strings and a pointer to a temporary
-- array of NUL-terminated copies. Neither may be used after the callback
-- returns.
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
withStringArrayLen xs fn =
    withStringList xs $ \cstrings -> withArrayLen cstrings fn


-- | Create a Raw.Tensor from a TensorData.
--
-- The bytes are copied into a buffer obtained from 'mallocArray'. That
-- buffer is handed to 'Raw.newTensor' together with 'tensorDeallocFunPtr',
-- which calls 'free' on it. The dimensions are converted with
-- 'safeConvert'.
createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData dims dt byteVec) =
    withArrayLen (safeConvert <$> dims) $ \rank cdims -> do
        buf <- copyBytes
        Raw.newTensor (toEnum (fromEnum dt))
                      cdims (safeConvert rank)
                      (castPtr buf) (safeConvert byteCount)
                      tensorDeallocFunPtr nullPtr
  where
    byteCount = S.length byteVec
    -- Copies the vector's contents into a freshly malloc'd buffer.
    copyBytes = do
        buf <- mallocArray byteCount
        S.unsafeWith byteVec $ \src -> copyArray buf src byteCount
        return buf

{-# NOINLINE tensorDeallocFunPtr #-}
-- | C-callable deallocator passed to 'Raw.newTensor' by 'createRawTensor'.
-- It frees its first (data pointer) argument and ignores the other two.
--
-- Built once through 'unsafePerformIO'. The NOINLINE pragma keeps GHC from
-- inlining the definition, which would create a fresh wrapper at each use.
tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn
tensorDeallocFunPtr = unsafePerformIO (Raw.wrapTensorDealloc freeData)
  where
    freeData ptr _len _arg = free ptr

-- | Create a TensorData from a Raw.Tensor.
--
-- Takes ownership of the Raw.Tensor.
-- TODO: Currently, it just makes a copy of the Tensor (and then deletes it),
-- since the raw pointer may refer to storage inside a mutable TensorFlow
-- variable.  We should avoid that copy when it's not needed; for example,
-- by making TensorData wrap an IOVector, and changing the code that uses it.
createTensorData :: Raw.Tensor -> IO TensorData
createTensorData t = do
    shape <- Raw.numDims t >>= \rank -> mapM (Raw.dim t) [0 .. rank - 1]
    elemType <- toEnum . fromEnum <$> Raw.tensorType t
    byteCount <- safeConvert <$> Raw.tensorByteSize t
    src <- Raw.tensorData t
    -- Copy the bytes out rather than wrapping them: they may point into a
    -- mutable variable's memory. newForeignPtr_ attaches no finalizer, so
    -- the tensor keeps ownership of src until it is deleted below.
    payload <- newForeignPtr_ (castPtr src) >>= \fp ->
        S.freeze (M.unsafeFromForeignPtr0 fp byteCount)
    Raw.deleteTensor t
    return (TensorData (safeConvert <$> shape) elemType payload)

-- | Runs the given action which does FFI calls updating a provided
-- status object. If the status is not OK it is thrown as
-- TensorFlowException.
--
-- The status object is created and deleted around the action with
-- 'bracket', so it is released even when the action or the check throws.
checkStatus :: (Raw.Status -> IO a) -> IO a
checkStatus fn = bracket Raw.newStatus Raw.deleteStatus callAndCheck
  where
    -- Runs the FFI action, then inspects the status it filled in.
    callAndCheck status = do
        result <- fn status
        code <- Raw.getCode status
        when (code /= Raw.TF_OK) (throwFrom status code)
        return result
    -- Copies the message out while the status is still alive. Bytes that
    -- are not valid UTF-8 are replaced rather than rejected.
    throwFrom status code = do
        raw <- Raw.message status >>= B.packCString
        throwM (TensorFlowException code (T.decodeUtf8With T.lenientDecode raw))

-- | Serializes the given 'ConfigProto' and applies it to the session
-- options via 'Raw.setConfig'. Throws 'TensorFlowException' on a non-OK
-- status.
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
setSessionConfig pb opt = useProtoAsVoidPtrLen pb apply
  where
    apply ptr len = checkStatus (Raw.setConfig opt ptr len)

-- | Passes the target string, as a temporary NUL-terminated C string, to
-- 'Raw.setTarget' on the given session options.
setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO ()
setSessionTarget target options =
    B.useAsCString target (Raw.setTarget options)

-- | Serializes the given msg and provides it as (ptr,len) argument
-- to the given action.
--
-- The pointer is only valid while the action runs. The length is converted
-- with 'safeConvert', so a length that doesn't fit in @c@ raises an error
-- instead of wrapping.
useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
                        msg -> (Ptr b -> c -> IO a) -> IO a
useProtoAsVoidPtrLen msg f =
    B.useAsCStringLen (encodeMessage msg) (uncurry withBytes)
  where
    withBytes bytes len = f (castPtr bytes) (safeConvert len)

-- | Returns the serialized OpList of all OpDefs defined in this
-- address space.
--
-- Throws a 'TensorFlowException' with code 'Raw.TF_UNKNOWN' if
-- 'Raw.getAllOpList' returns a null buffer. The buffer is owned by a
-- 'ForeignPtr' whose finalizer is 'Raw.deleteBuffer'.
getAllOpList :: IO B.ByteString
getAllOpList = do
    -- Masked so the buffer cannot leak between the C call and the point
    -- where its finalizer is attached.
    buffer <- mask_ $ do
        raw <- Raw.getAllOpList
        when (raw == nullPtr) $
            throwM (TensorFlowException Raw.TF_UNKNOWN
                                        "GetAllOpList failure, check logs")
        newForeignPtr Raw.deleteBuffer raw
    -- Makes a copy because it is more reliable than eviscerating
    -- Buffer to steal its memory (including custom deallocator).
    withForeignPtr buffer $ \ptr -> do
        dataPtr <- Raw.getBufferData ptr
        size <- Raw.getBufferLength ptr
        B.packCStringLen (castPtr dataPtr, safeConvert size)