-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE BangPatterns #-}

{-|
Module      : TensorFlow.Internal.VarInt
Description : Encoders and decoders for varint types.

Originally taken from internal proto-lens code.
-}
module TensorFlow.Internal.VarInt
    ( getVarInt
    , putVarInt
    ) where

import Data.Attoparsec.ByteString as Parse
import Data.Bits
import Data.ByteString.Lazy.Builder as Builder
import Data.Word (Word64)

-- | Decode an unsigned varint (little-endian base-128).
--
-- Each input byte contributes its low 7 bits, least significant group
-- first. A set high bit means another byte follows.
--
-- A 'Word64' never needs more than 10 bytes (ceiling (64 / 7)). The parser
-- therefore fails if 10 bytes have been read and the continuation bit is
-- still set, rather than consuming input without bound.
getVarInt :: Parser Word64
getVarInt = loop 0 1 0
  where
    maxBytes :: Int
    maxBytes = 10

    -- i: bytes consumed so far; s: place value of the next 7-bit group;
    -- n: value accumulated so far.
    loop :: Int -> Word64 -> Word64 -> Parser Word64
    loop !i !s !n
        | i >= maxBytes = fail "getVarInt: varint is longer than 10 bytes"
        | otherwise = do
            b <- anyWord8
            let n' = n + s * fromIntegral (b .&. 127)
            if (b .&. 128) == 0
                then return n'
                else loop (i + 1) (128 * s) n'

-- | Encode a 'Word64' as an unsigned varint (little-endian base-128).
--
-- Each output byte carries 7 bits of the value, least significant group
-- first. Every byte except the last has its high bit (128) set to signal
-- continuation.
--
-- Values below 128 encode to a single byte; 'maxBound' encodes to 10 bytes.
putVarInt :: Word64 -> Builder
putVarInt n
    | n < 128 = Builder.word8 (fromIntegral n)
    | otherwise =
        -- Low 7 bits plus the continuation flag, then the remaining bits.
        Builder.word8 (fromIntegral ((n .&. 127) .|. 128))
            <> putVarInt (n `shiftR` 7)