-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ViewPatterns #-}

module TensorFlow.Examples.MNIST.Parse where

import Control.Monad (when, liftM)
import Data.Binary.Get (Get, runGet, getWord32be, getLazyByteString)
import Data.ByteString.Lazy (toStrict, readFile)
import Data.List.Split (chunksOf)
import Data.ProtoLens (Message, decodeMessageOrDie)
import Data.Text (Text)
import Data.Word (Word8, Word32)
import Prelude hiding (readFile)
import qualified Codec.Compression.GZip as GZip
import qualified Data.ByteString.Lazy as L
import qualified Data.Text as Text
import qualified Data.Vector as V

-- | A single MNIST image sample: its pixel bytes, one 'Word8' per pixel, in
-- the order they appear in the image data file.
type MNIST = V.Vector Word8

-- | Produces a Unicode rendering of an MNIST digit sample.
--
-- Each pixel becomes one character chosen by its intensity: 0 renders as a
-- space, and every other value is bucketed by @n `div` 64@ into one of
-- U+2591, U+2592, U+2593, U+2588 (light shade .. full block), covering
-- 1-63, 64-127, 128-191 and 192-255 respectively. The characters are split
-- into lines of 28 (the MNIST image width), and the result always ends with
-- one extra newline, so an empty sample renders as just @"\n"@.
drawMNIST :: MNIST -> Text
drawMNIST = chunk . block
  where
    -- Render all pixels in one pass; building the Text with repeated
    -- 'Text.cons' would copy the accumulated text for every pixel.
    block :: V.Vector Word8 -> Text
    block = Text.pack . map shade . V.toList

    -- Map one pixel intensity to its display character.
    shade :: Word8 -> Char
    shade 0 = ' '
    -- 'n `div` 64' is at most 3 for a Word8, so the index is always in range.
    shade n = "\9617\9618\9619\9608" !! fromIntegral (n `div` 64)

    -- Break the rendered pixels into 28-character rows, each terminated by
    -- a newline, and add the trailing blank line.
    chunk :: Text -> Text
    chunk t = Text.unlines (Text.chunksOf 28 t) <> "\n"

-- | Checks the file's leading magic number, failing if it is not one of the
-- expected MNIST values.
--
-- The magic number is read as a big-endian 'Word32'. Both 2049 (the MNIST
-- label-file magic) and 2051 (the image-file magic) are accepted, so this
-- check does not tell label files from image files. Any other value means
-- the file is byte-swapped (little endian) or is not an MNIST file at all.
checkEndian :: Get ()
checkEndian = do
    magic <- getWord32be
    when (magic `notElem` expected) $
        fail $ "Expected big endian MNIST magic number 2049 (labels) or "
            ++ "2051 (images), but got " ++ show magic
            ++ "; the file may be little endian or not an MNIST file."
  where
    expected :: [Word32]
    expected = [2049, 2051]

-- | Reads a gzip-compressed MNIST image file and returns its samples.
--
-- After the magic number ('checkEndian') the header holds three big-endian
-- 'Word32's: the sample count, rows, and columns. Each returned sample is a
-- vector of @rows * cols@ pixel bytes, in file order.
--
-- The data is decoded before this action returns, so a malformed or
-- truncated file raises an error here rather than when the result is
-- first inspected.
readMNISTSamples :: FilePath -> IO [MNIST]
readMNISTSamples path = do
    raw <- GZip.decompress <$> readFile path
    -- '$!' forces 'runGet', so parse failures surface inside this IO action.
    return $! runGet getMNIST raw
  where
    getMNIST :: Get [MNIST]
    getMNIST = do
        checkEndian
        -- Parse header data.
        cnt  <- fromIntegral <$> getWord32be
        rows <- fromIntegral <$> getWord32be
        cols <- fromIntegral <$> getWord32be
        let sampleSize = rows * cols :: Int
        -- Read all of the data, then split into samples. The byte count is
        -- multiplied in the Int64 that 'getLazyByteString' takes, not in Int.
        pixels <- getLazyByteString $
            fromIntegral cnt * fromIntegral rows * fromIntegral cols
        return $
            if sampleSize == 0
                -- 'chunksOf 0' yields an infinite list of empty chunks, so
                -- emit exactly 'cnt' empty samples instead.
                then replicate cnt V.empty
                else V.fromList <$> chunksOf sampleSize (L.unpack pixels)

-- | Reads a gzip-compressed MNIST label file and returns its labels, one
-- byte per sample.
--
-- After the magic number ('checkEndian') the header holds a single
-- big-endian 'Word32' giving the number of label bytes that follow.
--
-- The data is decoded before this action returns, so a malformed or
-- truncated file raises an error here rather than when the labels are
-- first used.
readMNISTLabels :: FilePath -> IO [Word8]
readMNISTLabels path = do
    raw <- GZip.decompress <$> readFile path
    -- '$!' forces 'runGet', so parse failures surface inside this IO action.
    return $! runGet getLabels raw
  where
    getLabels :: Get [Word8]
    getLabels = do
        checkEndian
        -- Parse header data: the number of labels.
        cnt <- fromIntegral <$> getWord32be
        -- Read all of the labels.
        L.unpack <$> getLazyByteString cnt

-- | Reads a file and decodes its whole contents as a protocol buffer
-- message, dying (via 'decodeMessageOrDie') if decoding fails.
--
-- The decoded message is forced before this action returns, so a decoding
-- failure happens at this call rather than wherever the message is first
-- used.
readMessageFromFileOrDie :: Message m => FilePath -> IO m
readMessageFromFileOrDie path = do
    pb <- readFile path
    -- '$!' runs the decoder now; with a plain 'return' the failure would
    -- be deferred inside an unevaluated thunk.
    return $! decodeMessageOrDie (toStrict pb)

-- TODO: Write a writeMessageToFileOrDie counterpart, plus non-fatal
-- read/write variants that report errors instead of dying.