-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- We use UndecidableInstances for type families with recursive definitions
-- like "\\".  Those instances will terminate since each equation unwraps one
-- cons cell of a type-level list.
{-# LANGUAGE UndecidableInstances #-}

module TensorFlow.Types
    ( TensorType(..)
    , TensorData(..)
    , TensorDataType(..)
    , Scalar(..)
    , Shape(..)
    , protoShape
    , Attribute(..)
    , DataType(..)
    , ResourceHandle
    , Variant
    -- * Lists
    , ListOf(..)
    , List
    , (/:/)
    , TensorTypeProxy(..)
    , TensorTypes(..)
    , TensorTypeList
    , fromTensorTypeList
    , fromTensorTypes
    -- * Type constraints
    , OneOf
    , type (/=)
    , OneOfs
    -- ** Implementation of constraints
    , TypeError
    , ExcludedCase
    , NoneOf
    , type (\\)
    , Delete
    , AllTensorTypes
    ) where

import Data.ProtoLens.Message(defMessage)
import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Maybe (fromMaybe)
import Data.ProtoLens.TextFormat (showMessageShort)
import Data.Proxy (Proxy(..))
import Data.String (IsString)
import Data.Word (Word8, Word16, Word32, Word64)
import Foreign.Storable (Storable)
import GHC.Exts (Constraint, IsList(..))
import Lens.Family2 (Lens', view, (&), (.~), (^..), under)
import Lens.Family2.Unchecked (adapter)
import Text.Printf (printf)
import qualified Data.Attoparsec.ByteString as Atto
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Lazy as L
import qualified Data.Vector as V
import qualified Data.Vector.Storable as S
import Proto.Tensorflow.Core.Framework.AttrValue
    ( AttrValue
    , AttrValue'ListValue
    )
import Proto.Tensorflow.Core.Framework.AttrValue_Fields
    ( b
    , f
    , i
    , s
    , list
    , type'
    , shape
    , tensor
    )

import Proto.Tensorflow.Core.Framework.ResourceHandle
    (ResourceHandleProto)
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
    (TensorProto)
import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
    ( boolVal
    , doubleVal
    , floatVal
    , intVal
    , int64Val
    , resourceHandleVal
    , stringVal
    , uint32Val
    , uint64Val
    )

import Proto.Tensorflow.Core.Framework.TensorShape
    (TensorShapeProto)
import Proto.Tensorflow.Core.Framework.TensorShape_Fields
    ( dim
    , size
    , unknownRank
    )
import Proto.Tensorflow.Core.Framework.Types (DataType(..))

import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
import qualified TensorFlow.Internal.FFI as FFI

-- | Element type of @DT_RESOURCE@ tensors: an alias for the generated
-- 'ResourceHandleProto' message.
type ResourceHandle = ResourceHandleProto

-- | Dynamic type.
-- TensorFlow variants aren't supported yet. This type acts as a placeholder to
-- simplify op generation. It has no constructors, so it is only meaningful at
-- the type level (e.g. as the element type of a tensor).
data Variant

-- | Scalar element types that TensorFlow tensors can hold.
--
-- An instance ties a Haskell type to its TensorFlow 'DataType' tags and to the
-- repeated 'TensorProto' field its values are stored in.
class TensorType a where
    -- | The 'DataType' tag of tensors whose elements have this type. The
    -- argument is a type witness only; callers may pass @undefined@.
    tensorType :: a -> DataType

    -- | The reference (@_REF@) counterpart of 'tensorType'.
    tensorRefType :: a -> DataType

    -- | Lens onto the 'TensorProto' field holding values of this type.
    tensorVal :: Lens' TensorProto [a]

-- | 32-bit floats: @DT_FLOAT@, stored in 'floatVal'.
instance TensorType Float where
    tensorType _ = DT_FLOAT
    tensorRefType _ = DT_FLOAT_REF
    tensorVal = floatVal

-- | 64-bit floats: @DT_DOUBLE@, stored in 'doubleVal'.
instance TensorType Double where
    tensorType _ = DT_DOUBLE
    tensorRefType _ = DT_DOUBLE_REF
    tensorVal = doubleVal

-- | Signed 32-bit integers: @DT_INT32@, stored directly in 'intVal'.
instance TensorType Int32 where
    tensorType _ = DT_INT32
    tensorRefType _ = DT_INT32_REF
    tensorVal = intVal

-- | Signed 64-bit integers: @DT_INT64@, stored in 'int64Val'.
instance TensorType Int64 where
    tensorType _ = DT_INT64
    tensorRefType _ = DT_INT64_REF
    tensorVal = int64Val

-- | Lens that views a list of 'Int32' as a list of another integral type,
-- converting every element with 'fromIntegral' in both directions.
--
-- Used for element types that are stored in the 'Int32' 'intVal' field.
-- The conversions are unchecked: a stored value that does not fit the target
-- type wraps around (e.g. @300 :: Int32@ reads back as @44 :: Word8@).
integral :: Integral a => Lens' [Int32] [a]
integral = under (adapter (fmap fromIntegral) (fmap fromIntegral))

-- | Unsigned bytes: @DT_UINT8@. Values live in the 'Int32' 'intVal' field and
-- are converted through 'integral'.
instance TensorType Word8 where
    tensorType _ = DT_UINT8
    tensorRefType _ = DT_UINT8_REF
    tensorVal = intVal . integral

-- | Unsigned 16-bit integers: @DT_UINT16@. Values live in the 'Int32'
-- 'intVal' field and are converted through 'integral'.
instance TensorType Word16 where
    tensorType _ = DT_UINT16
    tensorRefType _ = DT_UINT16_REF
    tensorVal = intVal . integral

-- | Unsigned 32-bit integers: @DT_UINT32@, stored in 'uint32Val'.
instance TensorType Word32 where
    tensorType _ = DT_UINT32
    tensorRefType _ = DT_UINT32_REF
    tensorVal = uint32Val

-- | Unsigned 64-bit integers: @DT_UINT64@, stored in 'uint64Val'.
instance TensorType Word64 where
    tensorType _ = DT_UINT64
    tensorRefType _ = DT_UINT64_REF
    tensorVal = uint64Val

-- | Signed 16-bit integers: @DT_INT16@. Values live in the 'Int32' 'intVal'
-- field and are converted through 'integral'.
instance TensorType Int16 where
    tensorType _ = DT_INT16
    tensorRefType _ = DT_INT16_REF
    tensorVal = intVal . integral

-- | Signed bytes: @DT_INT8@. Values live in the 'Int32' 'intVal' field and
-- are converted through 'integral'.
instance TensorType Int8 where
    tensorType _ = DT_INT8
    tensorRefType _ = DT_INT8_REF
    tensorVal = intVal . integral

-- | Byte strings: @DT_STRING@, stored in 'stringVal'.
instance TensorType ByteString where
    tensorType _ = DT_STRING
    tensorRefType _ = DT_STRING_REF
    tensorVal = stringVal

-- | Booleans: @DT_BOOL@, stored in 'boolVal'.
instance TensorType Bool where
    tensorType _ = DT_BOOL
    tensorRefType _ = DT_BOOL_REF
    tensorVal = boolVal

-- | Single-precision complex numbers: @DT_COMPLEX64@. Reading or writing
-- their 'TensorProto' values is not implemented yet ('tensorVal' errors).
instance TensorType (Complex Float) where
    tensorType _ = DT_COMPLEX64
    -- Like every other instance, the reference type uses the dedicated @_REF@
    -- tag rather than the plain value tag.
    tensorRefType _ = DT_COMPLEX64_REF
    tensorVal = error "TODO (Complex Float)"

-- | Double-precision complex numbers: @DT_COMPLEX128@. Reading or writing
-- their 'TensorProto' values is not implemented yet ('tensorVal' errors).
instance TensorType (Complex Double) where
    tensorType _ = DT_COMPLEX128
    -- Like every other instance, the reference type uses the dedicated @_REF@
    -- tag rather than the plain value tag.
    tensorRefType _ = DT_COMPLEX128_REF
    tensorVal = error "TODO (Complex Double)"

-- | Resource handles: @DT_RESOURCE@, stored in 'resourceHandleVal'.
instance TensorType ResourceHandle where
    tensorType _ = DT_RESOURCE
    tensorRefType _ = DT_RESOURCE_REF
    tensorVal = resourceHandleVal

-- | Variants: @DT_VARIANT@. Only the type tags are provided; 'tensorVal' is
-- not implemented and errors if used.
instance TensorType Variant where
    tensorType _ = DT_VARIANT
    tensorRefType _ = DT_VARIANT_REF
    tensorVal = error "TODO Variant"

-- | Tensor data with the correct memory layout for tensorflow.
--
-- The type parameter @a@ records the element type at the type level; the
-- wrapped 'FFI.TensorData' carries the dimensions, the 'DataType' tag and the
-- raw bytes.
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }

-- | Containers @s@ whose elements of type @a@ can be converted to and from
-- 'TensorData'.
--
-- 'S.Vector' is the most efficient to encode/decode for most element types.
class TensorType a => TensorDataType s a where
    -- | Decode the bytes of a 'TensorData' into an @s a@.
    decodeTensorData :: TensorData a -> s a

    -- | Encode an @s a@ into a 'TensorData' of the given 'Shape'.
    --
    -- Elements are taken in row major order, so the last index varies
    -- fastest:
    --
    -- >  element 0:   index (0, ..., 0)
    -- >  element 1:   index (0, ..., 1)
    -- >  ...
    encodeTensorData
        :: Shape         -- ^ Dimensions of the resulting tensor.
        -> s a           -- ^ Elements in row major order.
        -> TensorData a

-- All types, besides ByteString and Bool, are encoded as simple arrays and we
-- can use Vector.Storable to encode/decode by type casting pointers.

-- TODO(fmayle): Assert that the data type matches the return type.

-- | Reinterpret the raw bytes of a 'TensorData' as a storable vector of @a@.
--
-- The bytes are cast in place with 'S.unsafeCast'; the element type is
-- trusted, not checked against the tensor's 'DataType'.
simpleDecode :: Storable a => TensorData a -> S.Vector a
simpleDecode = S.unsafeCast . FFI.tensorDataBytes . unTensorData

-- | Encode a storable vector as 'TensorData' with the given shape by casting
-- its storage to bytes (no per-element conversion).
--
-- Calls 'error' if the vector length differs from the product of the shape's
-- dimensions.
simpleEncode :: forall a . (TensorType a, Storable a)
             => Shape -> S.Vector a -> TensorData a
simpleEncode (Shape xs) v
    | expected /= fromIntegral (S.length v) =
        error $ printf
            "simpleEncode: bad vector length for shape %v: expected=%d got=%d"
            (show xs) expected (S.length v)
    | otherwise = TensorData (FFI.TensorData xs dt (S.unsafeCast v))
  where
    expected = product xs
    -- The 'undefined' is never evaluated; it only selects the instance.
    dt = tensorType (undefined :: a)

-- | Encoded by casting the vector storage ('simpleDecode' / 'simpleEncode').
instance TensorDataType S.Vector Float where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

-- | Encoded by casting the vector storage ('simpleDecode' / 'simpleEncode').
instance TensorDataType S.Vector Double where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

-- | Encoded by casting the vector storage ('simpleDecode' / 'simpleEncode').
instance TensorDataType S.Vector Int8 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

-- | Encoded by casting the vector storage ('simpleDecode' / 'simpleEncode').
instance TensorDataType S.Vector Int16 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

-- | Encoded by casting the vector storage ('simpleDecode' / 'simpleEncode').
instance TensorDataType S.Vector Int32 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

-- | Encoded by casting the vector storage ('simpleDecode' / 'simpleEncode').
instance TensorDataType S.Vector Int64 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

-- | Encoded by casting the vector storage ('simpleDecode' / 'simpleEncode').
instance TensorDataType S.Vector Word8 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

-- | Encoded by casting the vector storage ('simpleDecode' / 'simpleEncode').
instance TensorDataType S.Vector Word16 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
-- encoding more expensive. It may make sense to define a custom boolean type.

-- | TensorFlow stores each bool as one byte, so elements are converted one by
-- one instead of casting the vector storage.
instance TensorDataType S.Vector Bool where
    -- Any non-zero byte decodes as 'True'.
    decodeTensorData =
        S.map (/= (0 :: Word8)) . FFI.tensorDataBytes . unTensorData
    -- 'True' is written as 1 and 'False' as 0. As in 'simpleEncode', a vector
    -- whose length disagrees with the shape is rejected here rather than
    -- handed to TensorFlow as a malformed buffer.
    encodeTensorData (Shape xs) v
        | expected /= fromIntegral (S.length v) =
            error $ printf
                "encodeTensorData Bool: bad vector length for shape %v: expected=%d got=%d"
                (show xs) expected (S.length v)
        | otherwise =
            TensorData $ FFI.TensorData xs DT_BOOL $ S.map fromBool v
      where
        expected = product xs
        fromBool :: Bool -> Word8
        fromBool x = if x then 1 else 0

-- | Boxed vectors delegate to the 'S.Vector' instance for the same element
-- type, converting between the two vector representations.
instance {-# OVERLAPPABLE #-} (Storable a, TensorDataType S.Vector a, TensorType a)
    => TensorDataType V.Vector a where
    decodeTensorData = (S.convert :: S.Vector a -> V.Vector a) . decodeTensorData
    encodeTensorData x = encodeTensorData x . (S.convert :: V.Vector a -> S.Vector a)

-- | Not implemented yet: both conversions call 'error'.
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Float) where
    decodeTensorData = error "TODO (Complex Float)"
    encodeTensorData = error "TODO (Complex Float)"

-- | Not implemented yet: both conversions call 'error'.
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
    decodeTensorData = error "TODO (Complex Double)"
    encodeTensorData = error "TODO (Complex Double)"

instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
    -- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
    --   table offsets for each element :: [Word64]
    --   at each element offset:
    --     string length :: VarInt64
    --     string data   :: [Word8]
    --
    -- Offsets are relative to the first byte after the offset table.
    decodeTensorData tensorData =
        either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
            if expected /= count
                then Left $ "decodeTensorData for ByteString count mismatch " ++
                            show (expected, count)
                else V.mapM decodeString (S.convert offsets)
      where
        -- Offsets actually present; fewer than 'count' means the buffer is
        -- too short to hold the whole offset table.
        expected = S.length offsets
        -- Number of elements implied by the tensor's dimensions.
        count = fromIntegral $ product $ FFI.tensorDataDimensions
                    $ unTensorData tensorData
        bytes = FFI.tensorDataBytes $ unTensorData tensorData
        -- NOTE(review): the table is read in host byte order, while
        -- 'encodeTensorData' writes it little-endian; the two agree only on
        -- little-endian hosts.
        offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64
        dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes
        decodeString :: Word64 -> Either String ByteString
        decodeString offset
            -- Reject offsets beyond the data so that a huge value cannot wrap
            -- to a negative 'Int' and silently decode from the start.
            | offset > fromIntegral (B.length dataBytes) =
                Left $ "string offset out of range " ++ show offset
            | otherwise =
                let stringDataStart = B.drop (fromIntegral offset) dataBytes
                in Atto.eitherResult $ Atto.parse stringParser stringDataStart
        -- A varint length prefix followed by that many bytes.
        stringParser :: Atto.Parser ByteString
        stringParser = getVarInt >>= Atto.take . fromIntegral
    -- Calls 'error' if the number of strings differs from the product of the
    -- shape's dimensions, matching 'simpleEncode'.
    encodeTensorData (Shape xs) vec
        | expected /= fromIntegral (V.length vec) =
            error $ printf
                "encodeTensorData ByteString: bad vector length for shape %v: expected=%d got=%d"
                (show xs) expected (V.length vec)
        | otherwise = TensorData $ FFI.TensorData xs dt byteVector
      where
        expected = product xs
        dt = tensorType (undefined :: ByteString)
        -- Add a string to an offset table and data blob.
        addString :: (Builder, Builder, Word64)
                  -> ByteString
                  -> (Builder, Builder, Word64)
        addString (table, strings, offset) str =
            ( table <> Builder.word64LE offset
            , strings <> lengthBytes <> Builder.byteString str
            , offset + lengthBytesLen + strLen
            )
          where
            strLen = fromIntegral $ B.length str
            lengthBytes = putVarInt $ fromIntegral $ B.length str
            lengthBytesLen =
                fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes
        -- Encode all strings.
        (table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec
        -- Concat offset table with data.
        bytes = table' <> strings'
        -- Convert to Vector Word8.
        byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes

-- | A single value of type @a@, exchanged with TensorFlow as a one-element
-- tensor.
--
-- 'Show', 'Eq' and 'Ord' are derived structurally; the numeric classes and
-- 'IsString' are inherited from @a@ (GeneralizedNewtypeDeriving is enabled
-- for this module), so a 'Scalar' can be used directly in arithmetic and
-- with string literals.
newtype Scalar a = Scalar {unScalar :: a}
    deriving ( Show
             , Eq
             , Ord
             , Num
             , Fractional
             , Floating
             , Real
             , RealFloat
             , RealFrac
             , IsString
             )

-- | Scalars piggy-back on the vector encoding: encoding wraps the value in a
-- one-element 'V.Vector', and decoding requires the decoded vector to hold
-- exactly one element (see 'headFromSingleton').
instance (TensorDataType V.Vector a, TensorType a) => TensorDataType Scalar a where
    decodeTensorData td = Scalar (headFromSingleton (decodeTensorData td))
    encodeTensorData sh (Scalar y) = encodeTensorData sh (V.fromList [y])

-- | Returns the only element of a vector.
--
-- Calls 'error' when the vector does not contain exactly one element; the
-- message reports the length that was found.
headFromSingleton :: V.Vector a -> a
headFromSingleton vec =
    case V.length vec of
        1 -> V.head vec
        n -> error ("Unable to extract singleton from tensor of length "
                    ++ show n)


-- | Shape (dimensions) of a tensor: one 'Int64' size per dimension.
--
-- TensorFlow also supports shapes of unknown rank; those are represented as
-- @Nothing :: Maybe Shape@ in Haskell.
newtype Shape = Shape [Int64]
    deriving (Show)

-- | Lets a 'Shape' be written as a list of dimension sizes (for example with
-- @OverloadedLists@).
instance IsList Shape where
    type Item Shape = Int64
    -- The representation already is a list of sizes, so conversion in
    -- either direction only adds or removes the constructor.
    fromList dims = Shape dims
    toList (Shape dims) = dims

-- | A lens that views a 'TensorShapeProto' as a 'Shape' of known rank.
--
-- Viewing a proto whose rank is unknown calls 'error'; use 'protoMaybeShape'
-- when that case must be handled.  Setting rebuilds the proto from
-- 'defMessage', so the result always has a known rank.
protoShape :: Lens' TensorShapeProto Shape
protoShape = under (adapter toShape fromShape)
  where
    toShape proto =
        case view protoMaybeShape proto of
            Just knownShape -> knownShape
            Nothing -> error
                ("Can't convert TensorShapeProto with unknown rank to Shape: "
                 ++ showMessageShort proto)
    fromShape knownShape = defMessage & protoMaybeShape .~ Just knownShape

-- | A lens that views a 'TensorShapeProto' as a 'Maybe' 'Shape', where
-- 'Nothing' stands for a shape of unknown rank.
--
-- Every set builds a fresh proto from 'defMessage', so nothing from the
-- previous proto survives apart from what is written here.
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
protoMaybeShape = under (adapter fromProto toProto)
  where
    fromProto :: TensorShapeProto -> Maybe Shape
    fromProto proto
        | view unknownRank proto = Nothing
        | otherwise = Just (Shape (proto ^.. dim . traverse . size))
    toProto :: Maybe Shape -> TensorShapeProto
    toProto mshape =
        case mshape of
            Nothing -> defMessage & unknownRank .~ True
            Just (Shape ds) ->
                defMessage & dim .~ map (\d -> defMessage & size .~ d) ds


-- | Haskell types that can be read from and written to an 'AttrValue' proto
-- through a lens.
class Attribute a where
    -- | The field (or composition of fields) of 'AttrValue' that holds a
    -- value of type @a@.
    attrLens :: Lens' AttrValue a

-- | Float attributes are stored in the @f@ field of 'AttrValue'.
instance Attribute Float where
    attrLens = f

-- | Byte-string attributes are stored in the @s@ field of 'AttrValue'.
instance Attribute ByteString where
    attrLens = s

-- | Integer attributes are stored in the @i@ field of 'AttrValue'.
instance Attribute Int64 where
    attrLens = i

-- | Data-type attributes are stored in the @type@ field of 'AttrValue'
-- (exposed as the @type'@ lens).
instance Attribute DataType where
    attrLens = type'

-- | Tensor attributes are stored in the @tensor@ field of 'AttrValue'.
instance Attribute TensorProto where
    attrLens = tensor

-- | Boolean attributes are stored in the @b@ field of 'AttrValue'.
instance Attribute Bool where
    attrLens = b

-- | Shapes of known rank, stored in the @shape@ field of 'AttrValue' and
-- converted with 'protoShape' (which calls 'error' on an unknown rank).
instance Attribute Shape where
    attrLens = shape . protoShape

-- | Shapes stored in the @shape@ field of 'AttrValue', with 'Nothing'
-- representing a shape of unknown rank (via 'protoMaybeShape').
instance Attribute (Maybe Shape) where
    attrLens = shape . protoMaybeShape

-- TODO(gnezdo): support generating list(Foo) from [Foo].

-- | Exposes the raw @list@ field of 'AttrValue'.
instance Attribute AttrValue'ListValue where
    attrLens = list

-- | A list of data types, stored in the @type@ entries of the @list@ field.
instance Attribute [DataType] where
    attrLens = list . type'

-- | A list of integers, stored in the @i@ entries of the @list@ field.
instance Attribute [Int64] where
    attrLens = list . i

-- | A heterogeneous list type.
--
-- The type-level list @as@ records the element types in order; an element
-- of type @a@ is stored wrapped in the type constructor @f@.
data ListOf f as where
    -- | The empty list.
    Nil :: ListOf f '[]
    -- | Prepends an element, extending the type-level list by one type.
    (:/) :: f a -> ListOf f as -> ListOf f (a ': as)

infixr 5 :/

-- | Requires the class @f@ to hold for every type in the type-level list
-- @as@, expanding to a nested tuple of constraints.
type family All f as :: Constraint where
    All f '[] = ()
    All f (a ': as) = (f a, All f as)

-- | Applies the type constructor @f@ to every element of the type-level
-- list @as@.
type family Map f as where
    Map f '[] = '[]
    Map f (a ': as) = f a ': Map f as

-- | Lists are equal when they agree element by element.  Both operands share
-- the same type-level index, so they always have the same length.
instance All Eq (Map f as) => Eq (ListOf f as) where
    lhs == rhs =
        case (lhs, rhs) of
            (Nil, Nil) -> True
            (x :/ xs, y :/ ys) -> x == y && xs == ys
            -- Newer versions of GHC use the GADT to tell that the previous
            -- cases are exhaustive.
#if __GLASGOW_HASKELL__ < 800
            _ -> False
#endif

-- | Renders a 'ListOf' with its infix constructor, e.g. @x :/ y :/ Nil@.
--
-- Precedence handling matches the @infixr 5@ fixity of ':/':
--
-- * the whole expression is parenthesised only when the surrounding context
--   binds tighter than 5 (previously the threshold was 10, which dropped
--   required parentheses for contexts at precedence 6-10);
--
-- * the head is shown at precedence 6, so a head that is itself an infix
--   expression at this level gets parenthesised;
--
-- * the tail is shown at precedence 5, because ':/' associates to the right
--   and @x :/ y :/ Nil@ needs no inner parentheses.
--
-- Output at precedence 0-5 and 11+ is unchanged from the previous version.
instance All Show (Map f as) => Show (ListOf f as) where
    showsPrec _ Nil = showString "Nil"
    showsPrec d (x :/ xs) =
        showParen (d > consPrec) $
            showsPrec (consPrec + 1) x
                . showString " :/ "
                . showsPrec consPrec xs
      where
        -- Precedence of the ':/' constructor, from its fixity declaration.
        consPrec :: Int
        consPrec = 5

-- | A heterogeneous list of plain values (each stored wrapped in 'Identity').
type List = ListOf Identity

-- | Equivalent of ':/' for lists: prepends a plain value, wrapping it in
-- 'Identity'.
(/:/) :: a -> List as -> List (a ': as)
x /:/ xs = Identity x :/ xs

infixr 5 /:/

-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
--
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
-- natural representation as a disjunction, i.e.,
--
-- @
--    a == Double || a == Float
-- @
--
-- into a conjunction like
--
-- @
--     a \/= Int32 && a \/= Int64 && a \/= ByteString && ...
-- @
--
-- using an enumeration of all the possible 'TensorType's.  The conjunction
-- form is used because a tuple of constraints can only express conjunction.
type OneOf ts a
    -- Assert `TensorTypes' ts` to make error messages a little better.
    = (TensorType a, TensorTypes' ts, NoneOf (AllTensorTypes \\ ts) a)

-- | Like 'OneOf', but requires every type in the list @as@ to be one of the
-- types in @ts@.
type OneOfs ts as = (TensorTypes as, TensorTypes' ts,
                        NoneOfs (AllTensorTypes \\ ts) as)

-- | Applies @'NoneOf' ts@ to every type in the list @as@.
type family NoneOfs ts as :: Constraint where
    NoneOfs ts '[] = ()
    NoneOfs ts (a ': as) = (NoneOf ts a, NoneOfs ts as)

-- | A proxy that carries a 'TensorType' instance: pattern matching on
-- 'TensorTypeProxy' brings @TensorType a@ into scope.
data TensorTypeProxy a where
    TensorTypeProxy :: TensorType a => TensorTypeProxy a

-- | A list of 'TensorTypeProxy's, indexed by the type-level list of types.
type TensorTypeList = ListOf TensorTypeProxy

-- | The 'DataType' of every entry of a 'TensorTypeList', in order.
fromTensorTypeList :: TensorTypeList ts -> [DataType]
fromTensorTypeList tys =
    case tys of
        Nil -> []
        proxy :/ rest -> dataTypeOf proxy : fromTensorTypeList rest
  where
    -- Matching on the constructor brings the 'TensorType' instance into
    -- scope; 'tensorType' is only used for its type here, so it is handed a
    -- placeholder value, just as before.
    dataTypeOf :: TensorTypeProxy a -> DataType
    dataTypeOf p@TensorTypeProxy = tensorType (valueOf p)
    valueOf :: TensorTypeProxy a -> a
    valueOf _ = undefined

-- | The 'DataType's of the type-level list @as@, in order.
fromTensorTypes :: forall as . TensorTypes as => Proxy as -> [DataType]
fromTensorTypes asProxy = fromTensorTypeList (typesFor asProxy)
  where
    -- Pins the polymorphic 'tensorTypes' to the list named by the proxy.
    typesFor :: TensorTypes xs => Proxy xs -> TensorTypeList xs
    typesFor _ = tensorTypes

-- | Type-level lists whose elements are all 'TensorType's, so that a
-- 'TensorTypeList' describing them can be built.
class TensorTypes (ts :: [*]) where
    tensorTypes :: TensorTypeList ts

-- | The empty type-level list is described by the empty proxy list.
instance TensorTypes '[] where
    tensorTypes = Nil

-- | A constraint that the input is a list of 'TensorTypes'.
--
-- A non-empty list is described by a proxy for its head (capturing the
-- head's 'TensorType' instance) followed by the proxies for its tail.
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts) where
    tensorTypes = (:/) TensorTypeProxy tensorTypes

-- | A simpler version of the 'TensorTypes' class, that doesn't run
-- afoul of @-Wsimplifiable-class-constraints@.
--
-- It reduces to a tuple requiring 'TensorType' for every element of @ts@.
--
-- In more detail: the constraint @OneOf '[Double, Float] a@ leads
-- to the constraint @TensorTypes' '[Double, Float]@, as a safety-check
-- to give better error messages.  However, if @TensorTypes'@ were a class,
-- then GHC 8.2.1 would complain with the above warning unless @NoMonoBinds@
-- were enabled.  So instead, we use a separate type family for this purpose.
-- For more details: https://ghc.haskell.org/trac/ghc/ticket/11948
--
-- NOTE(review): there is no GHC extension named @NoMonoBinds@; this
-- presumably refers to @MonoLocalBinds@ (enabled in this module) or its
-- negation.  Confirm against the ticket above.
type family TensorTypes' (ts :: [*]) :: Constraint where
    -- Specialize this type family when `ts` is a long list, to avoid deeply
    -- nested tuples of constraints.  Works around a bug in ghc-8.0:
    -- https://ghc.haskell.org/trac/ghc/ticket/12175
    TensorTypes' (t1 ': t2 ': t3 ': t4 ': ts)
        = (TensorType t1, TensorType t2, TensorType t3, TensorType t4
              , TensorTypes' ts)
    TensorTypes' (t1 ': t2 ': t3 ': ts)
        = (TensorType t1, TensorType t2, TensorType t3, TensorTypes' ts)
    TensorTypes' (t1 ': t2 ': ts)
        = (TensorType t1, TensorType t2, TensorTypes' ts)
    TensorTypes' (t ': ts) = (TensorType t, TensorTypes' ts)
    TensorTypes' '[] = ()

-- | A constraint checking that two types are different.
--
-- When @a@ and @b@ are the same type, the first equation reduces it to the
-- unsatisfiable equality @TypeError a ~ ExcludedCase@, so the resulting
-- compile error mentions @TypeError a@ and hence the offending type.
type family a /= b :: Constraint where
    a /= a = TypeError a ~ ExcludedCase
    a /= b = ()

-- | Helper types to produce a reasonable type error message when the Constraint
-- @a \/= a@ fails.
-- TODO(judahjacobson): Use ghc-8's CustomTypeErrors for this.
--
-- Both are empty data types (no constructors); they appear only in the
-- deliberately unsatisfiable equality produced by '/='.
data TypeError a
-- | Right-hand side of the unsatisfiable equality produced by '/='.
data ExcludedCase

-- | An enumeration of all valid 'TensorType's.
--
-- 'OneOf' and 'OneOfs' work by excluding every member of this list that is
-- not explicitly allowed.  A 'TensorType' that is missing from this list is
-- therefore never excluded, i.e. it satisfies every 'OneOf' constraint.
type AllTensorTypes =
    -- NOTE: This list should be kept in sync with
    -- TensorFlow.OpGen.dtTypeToHaskell.
    -- TODO: Add support for Complex Float/Double.
    '[ Float
     , Double
     , Int8
     , Int16
     , Int32
     , Int64
     , Word8
     , Word16
     , ByteString
     , Bool
     ]

-- | Removes a type from the given list of types.
--
-- Every occurrence of @a@ is removed; the remaining types keep their order.
type family Delete a as where
    Delete a '[] = '[]
    Delete a (a ': as) = Delete a as
    Delete a (b ': as) = b ': Delete a as

-- | Takes the difference of two lists of types.
--
-- Each type in @bs@ is removed from @as@ in turn using 'Delete', so all of
-- its occurrences disappear.
type family as \\ bs where
    as \\ '[] = as
    as \\ (b ': bs) = Delete b as \\ bs

-- | A constraint that the type @a@ doesn't appear in the type list @ts@.
-- Assumes that @a@ and each of the elements of @ts@ are 'TensorType's.
--
-- It expands to one @a \/= t@ constraint for each element @t@ of @ts@.
type family NoneOf ts a :: Constraint where
    -- Specialize this type family when `ts` is a long list, to avoid deeply
    -- nested tuples of constraints.  Works around a bug in ghc-8.0:
    -- https://ghc.haskell.org/trac/ghc/ticket/12175
    NoneOf (t1 ': t2 ': t3 ': t4 ': ts) a
        = (a /= t1, a /= t2, a /= t3, a /= t4, NoneOf ts a)
    NoneOf (t1 ': t2 ': t3 ': ts) a = (a /= t1, a /= t2, a /= t3, NoneOf ts a)
    NoneOf (t1 ': t2 ': ts) a = (a /= t1, a /= t2, NoneOf ts a)
    NoneOf (t1 ': ts) a = (a /= t1, NoneOf ts a)
    NoneOf '[] a = ()