-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- | Encoder and decoder for the TensorFlow \"TFRecords\" format.

{-# LANGUAGE Rank2Types #-}
module TensorFlow.Records
  (
  -- * Records
    putTFRecord
  , getTFRecord
  , getTFRecords

  -- * Implementation

  -- | These may be useful for encoding or decoding to types other than
  -- 'ByteString' that have their own Cereal codecs.
  , getTFRecordLength
  , getTFRecordData
  , putTFRecordLength
  , putTFRecordData
  ) where

import Control.Exception (evaluate)
import Control.Monad (when)
import Data.ByteString.Unsafe (unsafePackCStringLen)
import qualified Data.ByteString.Builder as B (Builder)
import Data.ByteString.Builder.Extra (runBuilder, Next(..))
import qualified Data.ByteString.Lazy as BL
import Data.Serialize.Get
  ( Get
  , getBytes
  , getWord32le
  , getWord64le
  , getLazyByteString
  , isEmpty
  , lookAhead
  )
import Data.Serialize
  ( Put
  , execPut
  , putLazyByteString
  , putWord32le
  , putWord64le
  )
import Data.Word (Word8, Word64)
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Ptr (Ptr, castPtr)
import System.IO.Unsafe (unsafePerformIO)

import TensorFlow.CRC32C (crc32cLBSMasked, crc32cUpdate, crc32cMask)

-- | Parse one TFRecord.
--
-- Runs 'getTFRecordLength' to obtain the payload length and feeds it to
-- 'getTFRecordData', whose result (the payload) is returned.  The parse
-- fails if either of those parsers fails.
getTFRecord :: Get BL.ByteString
getTFRecord = getTFRecordLength >>= getTFRecordData

-- | Parse many TFRecords as a list.
--
-- Repeatedly runs 'getTFRecord' until 'isEmpty' reports that no input is
-- left, collecting the payloads in the order they were read.
--
-- Note you probably want streaming instead, as provided by the
-- tensorflow-records-conduit package.
getTFRecords :: Get [BL.ByteString]
getTFRecords = do
  done <- isEmpty
  if done
    then return []
    else (:) <$> getTFRecord <*> getTFRecords

-- Reads a little-endian 32-bit checksum from the input and compares it with
-- the result of 'crc32cLBSMasked' applied to the given bytes.  Fails the
-- parse, reporting both values, when they differ.
getCheckMaskedCRC32C :: BL.ByteString -> Get ()
getCheckMaskedCRC32C bs = do
  wireCRC <- getWord32le
  let maskedCRC = crc32cLBSMasked bs
  when (maskedCRC /= wireCRC) $ fail $
      "getCheckMaskedCRC32C: CRC mismatch, computed: " ++ show maskedCRC ++
      ", expected: " ++ show wireCRC

-- | Get a length and verify its checksum.
--
-- Consumes an 8-byte little-endian length followed by its 4-byte checksum,
-- and fails the parse if the checksum does not match.
getTFRecordLength :: Get Word64
getTFRecordLength = do
  -- Peek at the raw length bytes without consuming them, so the checksum is
  -- computed over exactly the bytes that 'getWord64le' decodes next.
  buf <- lookAhead (getBytes 8)
  getWord64le <* getCheckMaskedCRC32C (BL.fromStrict buf)

-- | Get a record payload and verify its checksum.
--
-- The argument is the payload length in bytes.  Lengths above
-- @0x7fffffffffffffff@ are rejected before reading, because the length is
-- converted with 'fromIntegral' for 'getLazyByteString'.  Otherwise the
-- payload is read, its trailing checksum is verified, and the payload is
-- returned.
getTFRecordData :: Word64 -> Get BL.ByteString
getTFRecordData len
  | len > 0x7fffffffffffffff =
      fail "getTFRecordData: Record size overflows Int64"
  | otherwise = do
      bs <- getLazyByteString (fromIntegral len)
      getCheckMaskedCRC32C bs
      return bs

-- Writes the result of 'crc32cLBSMasked' for the given bytes as a
-- little-endian 32-bit word.
putMaskedCRC32C :: BL.ByteString -> Put
putMaskedCRC32C = putWord32le . crc32cLBSMasked

-- Runs a Builder that's known to write a fixed number of bytes on an 'alloca'
-- buffer, and runs the given IO action on the result.
--
-- The first argument is the exact number of bytes the Builder is expected to
-- write; a buffer of that size is allocated and passed to the action once
-- the Builder has finished.  The buffer is only valid while the action runs.
--
-- Raises exceptions if the Builder yields ByteString chunks, asks for more
-- buffer space than expected, or finishes after writing a different number
-- of bytes than expected (the action would otherwise read bytes that were
-- never written).
unsafeWithFixedWidthBuilder :: Int -> B.Builder -> (Ptr Word8 -> IO r) -> IO r
unsafeWithFixedWidthBuilder n b act = allocaBytes n $ \ptr -> do
  (written, signal) <- runBuilder b ptr n
  case signal of
    Done
      | written == n -> act ptr
      | otherwise -> error $
          "unsafeWithFixedWidthBuilder: Builder wrote " ++ show written ++
          " bytes, expected " ++ show n ++ "."
    More _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned More."
    Chunk _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned Chunk."

-- | Put a record length and its checksum.
--
-- Writes the length as an 8-byte little-endian word, followed by a
-- little-endian 32-bit checksum: 'crc32cUpdate' with initial value 0 over
-- those 8 serialized bytes, passed through 'crc32cMask'.
putTFRecordLength :: Word64 -> Put
putTFRecordLength x =
  let put = putWord64le x
      len = 8
      crc = crc32cMask $ unsafePerformIO $
          -- Serialized Word64 is always 8 bytes, so we can go fast by using
          -- alloca.  The IO here only serializes into a private temporary
          -- buffer and reads it back, so its result depends on x alone.
          unsafeWithFixedWidthBuilder len (execPut put) $ \ptr -> do
              -- Wraps the temporary buffer without copying it.
              str <- unsafePackCStringLen (castPtr ptr, len)
              -- Force the result to ensure it's evaluated before freeing ptr.
              evaluate $ crc32cUpdate 0 str
  in  put *> putWord32le crc

-- | Put a record payload and its checksum.
--
-- Writes the payload bytes as given, followed by the checksum emitted by
-- 'putMaskedCRC32C' for those same bytes.
putTFRecordData :: BL.ByteString -> Put
putTFRecordData bs = putLazyByteString bs *> putMaskedCRC32C bs

-- | Put one TFRecord with the given contents.
--
-- Writes the length of the contents via 'putTFRecordLength', then the
-- contents themselves via 'putTFRecordData'.
putTFRecord :: BL.ByteString -> Put
putTFRecord bs =
  putTFRecordLength (fromIntegral $ BL.length bs) *> putTFRecordData bs