{-# LANGUAGE Rank2Types #-}
module TensorFlow.Records
(
putTFRecord
, getTFRecord
, getTFRecords
, getTFRecordLength
, getTFRecordData
, putTFRecordLength
, putTFRecordData
) where
import Control.Exception (evaluate)
import Control.Monad (when)
import Data.ByteString.Unsafe (unsafePackCStringLen)
import qualified Data.ByteString.Builder as B (Builder)
import Data.ByteString.Builder.Extra (runBuilder, Next(..))
import qualified Data.ByteString.Lazy as BL
import Data.Serialize.Get
( Get
, getBytes
, getWord32le
, getWord64le
, getLazyByteString
, isEmpty
, lookAhead
)
import Data.Serialize
( Put
, execPut
, putLazyByteString
, putWord32le
, putWord64le
)
import Data.Word (Word8, Word64)
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Ptr (Ptr, castPtr)
import System.IO.Unsafe (unsafePerformIO)
import TensorFlow.CRC32C (crc32cLBSMasked, crc32cUpdate, crc32cMask)
-- | Parse a single record: read the checked length header with
-- 'getTFRecordLength', then hand that length to 'getTFRecordData' to read
-- and check the payload.
getTFRecord :: Get BL.ByteString
getTFRecord = getTFRecordLength >>= getTFRecordData
-- | Parse records back to back until 'isEmpty' reports that no input
-- remains, returning the payloads in the order they were read.
--
-- The whole result list is built by this one parser; a failure in any
-- record (see 'getTFRecord') fails the entire parse.
getTFRecords :: Get [BL.ByteString]
getTFRecords = do
    done <- isEmpty
    if done
        then return []
        else (:) <$> getTFRecord <*> getTFRecords
-- | Read a little-endian 'Data.Word.Word32' checksum from the input and
-- compare it with @crc32cLBSMasked bs@.
--
-- Consumes four bytes. Fails the parser (via 'fail') with a message showing
-- both the computed and the on-wire value when they differ.
getCheckMaskedCRC32C :: BL.ByteString -> Get ()
getCheckMaskedCRC32C bs = do
    wireCRC <- getWord32le
    let maskedCRC = crc32cLBSMasked bs
    when (maskedCRC /= wireCRC) $ fail $
        "getCheckMaskedCRC32C: CRC mismatch, computed: " ++ show maskedCRC ++
        ", expected: " ++ show wireCRC
-- | Parse the record length header: a little-endian 'Word64' followed by
-- its checksum.
--
-- The eight raw length bytes are first captured with 'lookAhead' (which
-- does not consume them) so the checksum can be computed over exactly the
-- bytes on the wire; they are then consumed by 'getWord64le', and the
-- checksum that follows is verified by 'getCheckMaskedCRC32C'.
getTFRecordLength :: Get Word64
getTFRecordLength = do
    rawLen <- lookAhead (getBytes 8)
    getWord64le <* getCheckMaskedCRC32C (BL.fromStrict rawLen)
-- | Parse a record payload of the given length followed by its checksum.
--
-- Fails the parser if @len@ exceeds @0x7fffffffffffffff@, since the length
-- is converted with 'fromIntegral' to the @Int64@ that 'getLazyByteString'
-- expects. Also fails (inside 'getCheckMaskedCRC32C') when the checksum
-- read after the payload does not match.
getTFRecordData :: Word64 -> Get BL.ByteString
getTFRecordData len
    | len > 0x7fffffffffffffff =
        fail "getTFRecordData: Record size overflows Int64"
    | otherwise = do
        bs <- getLazyByteString (fromIntegral len)
        getCheckMaskedCRC32C bs
        return bs
-- | Serialize @crc32cLBSMasked@ of the given lazy 'BL.ByteString' as a
-- little-endian 32-bit word. Only the checksum is written, not the data.
putMaskedCRC32C :: BL.ByteString -> Put
putMaskedCRC32C = putWord32le . crc32cLBSMasked
-- | Run a 'B.Builder' into a temporary buffer of @n@ bytes and pass a
-- pointer to that buffer to @act@.
--
-- The buffer comes from 'allocaBytes', so the pointer must not be used
-- (directly or through anything that aliases it) after @act@ returns.
--
-- The builder must finish within the @n@ bytes supplied: if 'runBuilder'
-- signals anything other than 'Done' (that is, 'More' or 'Chunk'), this
-- calls 'error'. The number of bytes actually written is ignored.
unsafeWithFixedWidthBuilder :: Int -> B.Builder -> (Ptr Word8 -> IO r) -> IO r
unsafeWithFixedWidthBuilder n b act = allocaBytes n $ \ptr -> do
    (_, signal) <- runBuilder b ptr n
    case signal of
        Done -> act ptr
        More _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned More."
        Chunk _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned Chunk."
-- | Serialize a record length header: the length as a little-endian
-- 'Word64', followed by a little-endian 32-bit checksum of those eight
-- bytes.
--
-- The checksum is @crc32cMask@ applied to @crc32cUpdate 0@ over the encoded
-- length. To obtain the encoded bytes, the same 'Put' is rendered into a
-- temporary 8-byte buffer by 'unsafeWithFixedWidthBuilder'.
putTFRecordLength :: Word64 -> Put
putTFRecordLength x =
    let put = putWord64le x
        len = 8
        -- unsafePerformIO: the IO action only fills and reads a scratch
        -- buffer that never escapes, so the result depends on @x@ alone.
        crc = crc32cMask $ unsafePerformIO $
            unsafeWithFixedWidthBuilder len (execPut put) $ \ptr -> do
                -- unsafePackCStringLen wraps the buffer without copying.
                str <- unsafePackCStringLen (castPtr ptr, len)
                -- Force the checksum here, while the buffer is still
                -- allocated; a lazy thunk would read it after it is freed.
                evaluate $ crc32cUpdate 0 str
    in  put *> putWord32le crc
-- | Serialize a record payload: the bytes themselves, followed by their
-- checksum as written by 'putMaskedCRC32C'. No length header is written.
putTFRecordData :: BL.ByteString -> Put
putTFRecordData bs = putLazyByteString bs *> putMaskedCRC32C bs
-- | Serialize one complete record: the length header for the payload
-- ('putTFRecordLength' applied to 'BL.length'), then the payload and its
-- checksum ('putTFRecordData').
putTFRecord :: BL.ByteString -> Put
putTFRecord bs =
    putTFRecordLength (fromIntegral $ BL.length bs) *> putTFRecordData bs