{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module TensorFlow.Types
( TensorType(..)
, TensorData(..)
, TensorDataType(..)
, Scalar(..)
, Shape(..)
, protoShape
, Attribute(..)
, DataType(..)
, ResourceHandle
, Variant
, ListOf(..)
, List
, (/:/)
, TensorTypeProxy(..)
, TensorTypes(..)
, TensorTypeList
, fromTensorTypeList
, fromTensorTypes
, OneOf
, type (/=)
, OneOfs
, TypeError
, ExcludedCase
, NoneOf
, type (\\)
, Delete
, AllTensorTypes
) where
import Data.ProtoLens.Message(defMessage)
import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Maybe (fromMaybe)
import Data.ProtoLens.TextFormat (showMessageShort)
import Data.Proxy (Proxy(..))
import Data.String (IsString)
import Data.Word (Word8, Word16, Word32, Word64)
import Foreign.Storable (Storable)
import GHC.Exts (Constraint, IsList(..))
import Lens.Family2 (Lens', view, (&), (.~), (^..), under)
import Lens.Family2.Unchecked (adapter)
import Text.Printf (printf)
import qualified Data.Attoparsec.ByteString as Atto
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Lazy as L
import qualified Data.Vector as V
import qualified Data.Vector.Storable as S
import Proto.Tensorflow.Core.Framework.AttrValue
( AttrValue
, AttrValue'ListValue
)
import Proto.Tensorflow.Core.Framework.AttrValue_Fields
( b
, f
, i
, s
, list
, type'
, shape
, tensor
)
import Proto.Tensorflow.Core.Framework.ResourceHandle
(ResourceHandleProto)
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
(TensorProto)
import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
( boolVal
, doubleVal
, floatVal
, intVal
, int64Val
, resourceHandleVal
, stringVal
, uint32Val
, uint64Val
)
import Proto.Tensorflow.Core.Framework.TensorShape
(TensorShapeProto)
import Proto.Tensorflow.Core.Framework.TensorShape_Fields
( dim
, size
, unknownRank
)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
import qualified TensorFlow.Internal.FFI as FFI
-- | Resource handles are represented directly by the generated
-- 'ResourceHandleProto' message type.
type ResourceHandle = ResourceHandleProto
-- | Empty marker type (no constructors) used to tag variant tensors.
data Variant
-- | Scalar element types that a TensorFlow tensor can hold.
--
-- An instance supplies the 'DataType' tags for plain and reference tensors
-- of the type, plus a lens onto the 'TensorProto' field storing its values.
class TensorType a where
    -- | Data type tag of a (non-reference) tensor with this element type.
    -- The argument only selects the instance and is never inspected.
    tensorType :: a -> DataType
    -- | Data type tag of a reference tensor with this element type.
    tensorRefType :: a -> DataType
    -- | Lens onto the 'TensorProto' field holding values of this type.
    tensorVal :: Lens' TensorProto [a]
-- | 32-bit floats, stored in the 'floatVal' field.
instance TensorType Float where
    tensorType = const DT_FLOAT
    tensorRefType = const DT_FLOAT_REF
    tensorVal = floatVal
-- | 64-bit floats, stored in the 'doubleVal' field.
instance TensorType Double where
    tensorType = const DT_DOUBLE
    tensorRefType = const DT_DOUBLE_REF
    tensorVal = doubleVal
-- | 32-bit signed integers, stored in the 'intVal' field.
instance TensorType Int32 where
    tensorType = const DT_INT32
    tensorRefType = const DT_INT32_REF
    tensorVal = intVal
-- | 64-bit signed integers, stored in the 'int64Val' field.
instance TensorType Int64 where
    tensorType = const DT_INT64
    tensorRefType = const DT_INT64_REF
    tensorVal = int64Val
-- | Lens that views a list of 'Int32' as a list of another integral type.
--
-- Each element is converted with 'fromIntegral' in both directions. Values
-- that do not fit the target type are converted exactly as 'fromIntegral'
-- converts them; there is no range check.
integral :: Integral a => Lens' [Int32] [a]
integral = under (adapter widen narrow)
  where
    widen :: Integral b => [Int32] -> [b]
    widen = map fromIntegral
    narrow :: Integral b => [b] -> [Int32]
    narrow = map fromIntegral
-- | 8-bit unsigned integers. Values live in the 32-bit 'intVal' field and
-- are converted element-wise through 'integral'.
instance TensorType Word8 where
    tensorType = const DT_UINT8
    tensorRefType = const DT_UINT8_REF
    tensorVal = intVal . integral
-- | 16-bit unsigned integers. Values live in the 32-bit 'intVal' field and
-- are converted element-wise through 'integral'.
instance TensorType Word16 where
    tensorType = const DT_UINT16
    tensorRefType = const DT_UINT16_REF
    tensorVal = intVal . integral
-- | 32-bit unsigned integers, stored in the 'uint32Val' field.
instance TensorType Word32 where
    tensorType = const DT_UINT32
    tensorRefType = const DT_UINT32_REF
    tensorVal = uint32Val
-- | 64-bit unsigned integers, stored in the 'uint64Val' field.
instance TensorType Word64 where
    tensorType = const DT_UINT64
    tensorRefType = const DT_UINT64_REF
    tensorVal = uint64Val
-- | 16-bit signed integers. Values live in the 32-bit 'intVal' field and
-- are converted element-wise through 'integral'.
instance TensorType Int16 where
    tensorType = const DT_INT16
    tensorRefType = const DT_INT16_REF
    tensorVal = intVal . integral
-- | 8-bit signed integers. Values live in the 32-bit 'intVal' field and
-- are converted element-wise through 'integral'.
instance TensorType Int8 where
    tensorType = const DT_INT8
    tensorRefType = const DT_INT8_REF
    tensorVal = intVal . integral
-- | Byte strings, stored in the 'stringVal' field.
instance TensorType ByteString where
    tensorType = const DT_STRING
    tensorRefType = const DT_STRING_REF
    tensorVal = stringVal
-- | Booleans, stored in the 'boolVal' field.
instance TensorType Bool where
    tensorType = const DT_BOOL
    tensorRefType = const DT_BOOL_REF
    tensorVal = boolVal
-- | Single-precision complex numbers.
--
-- No 'TensorProto' field is wired up for these yet, so 'tensorVal' is a
-- placeholder that raises an error if it is used.
instance TensorType (Complex Float) where
    tensorType _ = DT_COMPLEX64
    -- The reference tag, as for every other element type; returning the
    -- plain DT_COMPLEX64 here would mislabel reference tensors.
    tensorRefType _ = DT_COMPLEX64_REF
    tensorVal = error "TODO (Complex Float)"
-- | Double-precision complex numbers.
--
-- No 'TensorProto' field is wired up for these yet, so 'tensorVal' is a
-- placeholder that raises an error if it is used.
instance TensorType (Complex Double) where
    tensorType _ = DT_COMPLEX128
    -- The reference tag, as for every other element type; returning the
    -- plain DT_COMPLEX128 here would mislabel reference tensors.
    tensorRefType _ = DT_COMPLEX128_REF
    tensorVal = error "TODO (Complex Double)"
-- | Resource handles, stored in the 'resourceHandleVal' field.
instance TensorType ResourceHandle where
    tensorType = const DT_RESOURCE
    tensorRefType = const DT_RESOURCE_REF
    tensorVal = resourceHandleVal
-- | Variant tensors. Only the type tags are available; 'tensorVal' is a
-- placeholder that raises an error if it is used.
instance TensorType Variant where
    tensorType = const DT_VARIANT
    tensorRefType = const DT_VARIANT_REF
    tensorVal = error "TODO Variant"
-- | Raw tensor contents from the FFI layer, tagged with the Haskell element
-- type @a@ as a phantom parameter (it does not appear in the field).
newtype TensorData a = TensorData
    { unTensorData :: FFI.TensorData
    }
-- | Element types that can be converted between a Haskell container @s@
-- and raw 'TensorData'.
class TensorType a => TensorDataType s a where
    -- | Decode raw tensor data into the container.
    decodeTensorData :: TensorData a -> s a
    -- | Encode the container's elements as tensor data with the given shape.
    encodeTensorData :: Shape -> s a -> TensorData a
-- | Decode by reinterpreting the tensor's raw byte buffer as a storable
-- vector of @a@.
--
-- The element type is taken on trust from the phantom parameter; the data
-- type recorded in the tensor is not checked.
simpleDecode :: Storable a => TensorData a -> S.Vector a
simpleDecode td = S.unsafeCast (FFI.tensorDataBytes (unTensorData td))
-- | Encode a storable vector by reinterpreting its elements as raw bytes,
-- tagging the result with @a@'s 'tensorType'.
--
-- Calls 'error' if the vector's length differs from the number of elements
-- implied by the shape (the product of its dimensions).
simpleEncode :: forall a . (TensorType a, Storable a)
             => Shape -> S.Vector a -> TensorData a
simpleEncode (Shape dims) vec
    | expected /= actual =
        error $ printf
            "simpleEncode: bad vector length for shape %v: expected=%d got=%d"
            (show dims) expected (S.length vec)
    | otherwise =
        TensorData $ FFI.TensorData dims elemType (S.unsafeCast vec)
  where
    expected, actual :: Int64
    expected = product dims
    actual = fromIntegral (S.length vec)
    -- 'tensorType' ignores its argument, so 'undefined' only picks the type.
    elemType = tensorType (undefined :: a)
-- | Encoded and decoded by reinterpreting the tensor's byte buffer.
instance TensorDataType S.Vector Float where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode
-- | Encoded and decoded by reinterpreting the tensor's byte buffer.
instance TensorDataType S.Vector Double where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode
-- | Encoded and decoded by reinterpreting the tensor's byte buffer.
instance TensorDataType S.Vector Int8 where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode
-- | Encoded and decoded by reinterpreting the tensor's byte buffer.
instance TensorDataType S.Vector Int16 where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode
-- | Encoded and decoded by reinterpreting the tensor's byte buffer.
instance TensorDataType S.Vector Int32 where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode
-- | Encoded and decoded by reinterpreting the tensor's byte buffer.
instance TensorDataType S.Vector Int64 where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode
-- | Encoded and decoded by reinterpreting the tensor's byte buffer.
instance TensorDataType S.Vector Word8 where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode
-- | Encoded and decoded by reinterpreting the tensor's byte buffer.
instance TensorDataType S.Vector Word16 where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode
-- | Booleans are stored one per byte. Encoding writes 1 for 'True' and 0
-- for 'False'; decoding treats any non-zero byte as 'True'.
--
-- Like 'simpleEncode', encoding calls 'error' when the vector length does
-- not match the number of elements implied by the shape.
instance TensorDataType S.Vector Bool where
    decodeTensorData =
        S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
    encodeTensorData (Shape xs) v
        | product xs /= fromIntegral (S.length v) =
            error $ printf
                "encodeTensorData Bool: bad vector length for shape %v: expected=%d got=%d"
                (show xs) (product xs) (S.length v)
        | otherwise =
            TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert $ v
      where
        fromBool :: Bool -> Word8
        fromBool x = if x then 1 else 0
-- | Boxed vectors reuse the storable-vector instance for the same element
-- type, converting between the two representations on the way in and out.
instance {-# OVERLAPPABLE #-} (Storable a, TensorDataType S.Vector a, TensorType a)
    => TensorDataType V.Vector a where
    decodeTensorData td = toBoxed (decodeTensorData td)
      where
        toBoxed :: S.Vector a -> V.Vector a
        toBoxed = S.convert
    encodeTensorData shape vec = encodeTensorData shape (toStorable vec)
      where
        toStorable :: V.Vector a -> S.Vector a
        toStorable = S.convert
-- | Placeholder: complex tensors cannot be converted yet, so both
-- directions raise an error when used.
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Float) where
    encodeTensorData = error "TODO (Complex Float)"
    decodeTensorData = error "TODO (Complex Float)"
-- | Placeholder: complex tensors cannot be converted yet, so both
-- directions raise an error when used.
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
    encodeTensorData = error "TODO (Complex Double)"
    decodeTensorData = error "TODO (Complex Double)"
-- | String tensors. The byte layout, as written by 'encodeTensorData' here:
--
--   * a table holding one 64-bit offset per element (written little-endian),
--   * followed by a data region where, at each element's offset, the string
--     length is stored as a varint, followed by the raw string bytes.
--
-- Decoding calls 'error' on a malformed buffer. Encoding calls 'error' when
-- the number of strings differs from the number of elements implied by the
-- shape, since the decoder would reject such a tensor.
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
    decodeTensorData tensorData =
        either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
            if expected /= count
                then Left $ "decodeTensorData for ByteString count mismatch " ++
                            show (expected, count)
                else V.mapM decodeString (S.convert offsets)
      where
        -- Number of offsets actually present in the buffer.
        expected :: Int
        expected = S.length offsets
        -- Number of elements implied by the tensor's dimensions.
        count :: Int
        count = fromIntegral $ product $ FFI.tensorDataDimensions
                    $ unTensorData tensorData
        bytes :: S.Vector Word8
        bytes = FFI.tensorDataBytes $ unTensorData tensorData
        -- NOTE(review): offsets are read in host byte order while the encoder
        -- writes them little-endian; the two agree only on little-endian hosts.
        offsets :: S.Vector Word64
        offsets = S.take count $ S.unsafeCast bytes
        -- The data region starts right after the table of 8-byte offsets.
        dataBytes :: ByteString
        dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes
        decodeString :: Word64 -> Either String ByteString
        decodeString offset =
            let stringDataStart = B.drop (fromIntegral offset) dataBytes
            in Atto.eitherResult $ Atto.parse stringParser stringDataStart
        -- A varint length prefix followed by that many bytes.
        stringParser :: Atto.Parser ByteString
        stringParser = getVarInt >>= Atto.take . fromIntegral

    encodeTensorData (Shape xs) vec
        | product xs /= fromIntegral (V.length vec) =
            error $ printf
                "encodeTensorData ByteString: bad vector length for shape %v: expected=%d got=%d"
                (show xs) (product xs) (V.length vec)
        | otherwise = TensorData $ FFI.TensorData xs dt byteVector
      where
        -- 'tensorType' ignores its argument, so 'undefined' only picks the type.
        dt = tensorType (undefined :: ByteString)
        -- Append one string: record its offset in the table and its varint
        -- length plus bytes in the data blob. The next offset is forced
        -- before the tuple is returned, so the strict fold below does not
        -- accumulate a chain of unevaluated additions.
        addString :: (Builder, Builder, Word64)
                  -> ByteString
                  -> (Builder, Builder, Word64)
        addString (table, strings, offset) str =
            offset' `seq`
                ( table <> Builder.word64LE offset
                , strings <> lengthBytes <> Builder.byteString str
                , offset'
                )
          where
            strLen :: Word64
            strLen = fromIntegral $ B.length str
            lengthBytes :: Builder
            lengthBytes = putVarInt strLen
            lengthBytesLen :: Word64
            lengthBytesLen =
                fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes
            offset' :: Word64
            offset' = offset + lengthBytesLen + strLen
        -- Build the offset table and the data blob in one pass.
        (table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec
        -- The offset table comes first, followed by the string data.
        bytes = table' <> strings'
        byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
newtype Scalar a = Scalar {Scalar a -> a
unScalar :: a}
deriving (Int -> Scalar a -> [Char] -> [Char]
[Scalar a] -> [Char] -> [Char]
Scalar a -> [Char]
(Int -> Scalar a -> [Char] -> [Char])
-> (Scalar a -> [Char])
-> ([Scalar a] -> [Char] -> [Char])
-> Show (Scalar a)
forall a. Show a => Int -> Scalar a -> [Char] -> [Char]
forall a. Show a => [Scalar a] -> [Char] -> [Char]
forall a. Show a => Scalar a -> [Char]
forall a.
(Int -> a -> [Char] -> [Char])
-> (a -> [Char]) -> ([a] -> [Char] -> [Char]) -> Show a
showList :: [Scalar a] -> [Char] -> [Char]
$cshowList :: forall a. Show a => [Scalar a] -> [Char] -> [Char]
show :: Scalar a -> [Char]
$cshow :: forall a. Show a => Scalar a -> [Char]
showsPrec :: Int -> Scalar a -> [Char] -> [Char]
$cshowsPrec :: forall a. Show a => Int -> Scalar a -> [Char] -> [Char]
Show, Scalar a -> Scalar a -> Bool
(Scalar a -> Scalar a -> Bool)
-> (Scalar a -> Scalar a -> Bool) -> Eq (Scalar a)
forall a. Eq a => Scalar a -> Scalar a -> Bool
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
/= :: Scalar a -> Scalar a -> Bool
$c/= :: forall a. Eq a => Scalar a -> Scalar a -> Bool
== :: Scalar a -> Scalar a -> Bool
$c== :: forall a. Eq a => Scalar a -> Scalar a -> Bool
Eq, Eq (Scalar a)
Eq (Scalar a) =>
(Scalar a -> Scalar a -> Ordering)
-> (Scalar a -> Scalar a -> Bool)
-> (Scalar a -> Scalar a -> Bool)
-> (Scalar a -> Scalar a -> Bool)
-> (Scalar a -> Scalar a -> Bool)
-> (Scalar a -> Scalar a -> Scalar a)
-> (Scalar a -> Scalar a -> Scalar a)
-> Ord (Scalar a)
Scalar a -> Scalar a -> Bool
Scalar a -> Scalar a -> Ordering
Scalar a -> Scalar a -> Scalar a
forall a.
Eq a =>
(a -> a -> Ordering)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> a)
-> (a -> a -> a)
-> Ord a
forall a. Ord a => Eq (Scalar a)
forall a. Ord a => Scalar a -> Scalar a -> Bool
forall a. Ord a => Scalar a -> Scalar a -> Ordering
forall a. Ord a => Scalar a -> Scalar a -> Scalar a
min :: Scalar a -> Scalar a -> Scalar a
$cmin :: forall a. Ord a => Scalar a -> Scalar a -> Scalar a
max :: Scalar a -> Scalar a -> Scalar a
$cmax :: forall a. Ord a => Scalar a -> Scalar a -> Scalar a
>= :: Scalar a -> Scalar a -> Bool
$c>= :: forall a. Ord a => Scalar a -> Scalar a -> Bool
> :: Scalar a -> Scalar a -> Bool
$c> :: forall a. Ord a => Scalar a -> Scalar a -> Bool
<= :: Scalar a -> Scalar a -> Bool
$c<= :: forall a. Ord a => Scalar a -> Scalar a -> Bool
< :: Scalar a -> Scalar a -> Bool
$c< :: forall a. Ord a => Scalar a -> Scalar a -> Bool
compare :: Scalar a -> Scalar a -> Ordering
$ccompare :: forall a. Ord a => Scalar a -> Scalar a -> Ordering
$cp1Ord :: forall a. Ord a => Eq (Scalar a)
Ord, Integer -> Scalar a
Scalar a -> Scalar a
Scalar a -> Scalar a -> Scalar a
(Scalar a -> Scalar a -> Scalar a)
-> (Scalar a -> Scalar a -> Scalar a)
-> (Scalar a -> Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Integer -> Scalar a)
-> Num (Scalar a)
forall a. Num a => Integer -> Scalar a
forall a. Num a => Scalar a -> Scalar a
forall a. Num a => Scalar a -> Scalar a -> Scalar a
forall a.
(a -> a -> a)
-> (a -> a -> a)
-> (a -> a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (Integer -> a)
-> Num a
fromInteger :: Integer -> Scalar a
$cfromInteger :: forall a. Num a => Integer -> Scalar a
signum :: Scalar a -> Scalar a
$csignum :: forall a. Num a => Scalar a -> Scalar a
abs :: Scalar a -> Scalar a
$cabs :: forall a. Num a => Scalar a -> Scalar a
negate :: Scalar a -> Scalar a
$cnegate :: forall a. Num a => Scalar a -> Scalar a
* :: Scalar a -> Scalar a -> Scalar a
$c* :: forall a. Num a => Scalar a -> Scalar a -> Scalar a
- :: Scalar a -> Scalar a -> Scalar a
$c- :: forall a. Num a => Scalar a -> Scalar a -> Scalar a
+ :: Scalar a -> Scalar a -> Scalar a
$c+ :: forall a. Num a => Scalar a -> Scalar a -> Scalar a
Num, Num (Scalar a)
Num (Scalar a) =>
(Scalar a -> Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Rational -> Scalar a)
-> Fractional (Scalar a)
Rational -> Scalar a
Scalar a -> Scalar a
Scalar a -> Scalar a -> Scalar a
forall a. Fractional a => Num (Scalar a)
forall a. Fractional a => Rational -> Scalar a
forall a. Fractional a => Scalar a -> Scalar a
forall a. Fractional a => Scalar a -> Scalar a -> Scalar a
forall a.
Num a =>
(a -> a -> a) -> (a -> a) -> (Rational -> a) -> Fractional a
fromRational :: Rational -> Scalar a
$cfromRational :: forall a. Fractional a => Rational -> Scalar a
recip :: Scalar a -> Scalar a
$crecip :: forall a. Fractional a => Scalar a -> Scalar a
/ :: Scalar a -> Scalar a -> Scalar a
$c/ :: forall a. Fractional a => Scalar a -> Scalar a -> Scalar a
$cp1Fractional :: forall a. Fractional a => Num (Scalar a)
Fractional, Fractional (Scalar a)
Scalar a
Fractional (Scalar a) =>
Scalar a
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a -> Scalar a)
-> (Scalar a -> Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> (Scalar a -> Scalar a)
-> Floating (Scalar a)
Scalar a -> Scalar a
Scalar a -> Scalar a -> Scalar a
forall a. Floating a => Fractional (Scalar a)
forall a. Floating a => Scalar a
forall a. Floating a => Scalar a -> Scalar a
forall a. Floating a => Scalar a -> Scalar a -> Scalar a
forall a.
Fractional a =>
a
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a -> a)
-> (a -> a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> (a -> a)
-> Floating a
log1mexp :: Scalar a -> Scalar a
$clog1mexp :: forall a. Floating a => Scalar a -> Scalar a
log1pexp :: Scalar a -> Scalar a
$clog1pexp :: forall a. Floating a => Scalar a -> Scalar a
expm1 :: Scalar a -> Scalar a
$cexpm1 :: forall a. Floating a => Scalar a -> Scalar a
log1p :: Scalar a -> Scalar a
$clog1p :: forall a. Floating a => Scalar a -> Scalar a
atanh :: Scalar a -> Scalar a
$catanh :: forall a. Floating a => Scalar a -> Scalar a
acosh :: Scalar a -> Scalar a
$cacosh :: forall a. Floating a => Scalar a -> Scalar a
asinh :: Scalar a -> Scalar a
$casinh :: forall a. Floating a => Scalar a -> Scalar a
tanh :: Scalar a -> Scalar a
$ctanh :: forall a. Floating a => Scalar a -> Scalar a
cosh :: Scalar a -> Scalar a
$ccosh :: forall a. Floating a => Scalar a -> Scalar a
sinh :: Scalar a -> Scalar a
$csinh :: forall a. Floating a => Scalar a -> Scalar a
atan :: Scalar a -> Scalar a
$catan :: forall a. Floating a => Scalar a -> Scalar a
acos :: Scalar a -> Scalar a
$cacos :: forall a. Floating a => Scalar a -> Scalar a
asin :: Scalar a -> Scalar a
$casin :: forall a. Floating a => Scalar a -> Scalar a
tan :: Scalar a -> Scalar a
$ctan :: forall a. Floating a => Scalar a -> Scalar a
cos :: Scalar a -> Scalar a
$ccos :: forall a. Floating a => Scalar a -> Scalar a
sin :: Scalar a -> Scalar a
$csin :: forall a. Floating a => Scalar a -> Scalar a
logBase :: Scalar a -> Scalar a -> Scalar a
$clogBase :: forall a. Floating a => Scalar a -> Scalar a -> Scalar a
** :: Scalar a -> Scalar a -> Scalar a
$c** :: forall a. Floating a => Scalar a -> Scalar a -> Scalar a
sqrt :: Scalar a -> Scalar a
$csqrt :: forall a. Floating a => Scalar a -> Scalar a
log :: Scalar a -> Scalar a
$clog :: forall a. Floating a => Scalar a -> Scalar a
exp :: Scalar a -> Scalar a
$cexp :: forall a. Floating a => Scalar a -> Scalar a
pi :: Scalar a
$cpi :: forall a. Floating a => Scalar a
$cp1Floating :: forall a. Floating a => Fractional (Scalar a)
Floating, Num (Scalar a)
Ord (Scalar a)
(Num (Scalar a), Ord (Scalar a)) =>
(Scalar a -> Rational) -> Real (Scalar a)
Scalar a -> Rational
forall a. (Num a, Ord a) => (a -> Rational) -> Real a
forall a. Real a => Num (Scalar a)
forall a. Real a => Ord (Scalar a)
forall a. Real a => Scalar a -> Rational
toRational :: Scalar a -> Rational
$ctoRational :: forall a. Real a => Scalar a -> Rational
$cp2Real :: forall a. Real a => Ord (Scalar a)
$cp1Real :: forall a. Real a => Num (Scalar a)
Real, Floating (Scalar a)
RealFrac (Scalar a)
(RealFrac (Scalar a), Floating (Scalar a)) =>
(Scalar a -> Integer)
-> (Scalar a -> Int)
-> (Scalar a -> (Int, Int))
-> (Scalar a -> (Integer, Int))
-> (Integer -> Int -> Scalar a)
-> (Scalar a -> Int)
-> (Scalar a -> Scalar a)
-> (Int -> Scalar a -> Scalar a)
-> (Scalar a -> Bool)
-> (Scalar a -> Bool)
-> (Scalar a -> Bool)
-> (Scalar a -> Bool)
-> (Scalar a -> Bool)
-> (Scalar a -> Scalar a -> Scalar a)
-> RealFloat (Scalar a)
Int -> Scalar a -> Scalar a
Integer -> Int -> Scalar a
Scalar a -> Bool
Scalar a -> Int
Scalar a -> Integer
Scalar a -> (Int, Int)
Scalar a -> (Integer, Int)
Scalar a -> Scalar a
Scalar a -> Scalar a -> Scalar a
forall a. RealFloat a => Floating (Scalar a)
forall a. RealFloat a => RealFrac (Scalar a)
forall a. RealFloat a => Int -> Scalar a -> Scalar a
forall a. RealFloat a => Integer -> Int -> Scalar a
forall a. RealFloat a => Scalar a -> Bool
forall a. RealFloat a => Scalar a -> Int
forall a. RealFloat a => Scalar a -> Integer
forall a. RealFloat a => Scalar a -> (Int, Int)
forall a. RealFloat a => Scalar a -> (Integer, Int)
forall a. RealFloat a => Scalar a -> Scalar a
forall a. RealFloat a => Scalar a -> Scalar a -> Scalar a
forall a.
(RealFrac a, Floating a) =>
(a -> Integer)
-> (a -> Int)
-> (a -> (Int, Int))
-> (a -> (Integer, Int))
-> (Integer -> Int -> a)
-> (a -> Int)
-> (a -> a)
-> (Int -> a -> a)
-> (a -> Bool)
-> (a -> Bool)
-> (a -> Bool)
-> (a -> Bool)
-> (a -> Bool)
-> (a -> a -> a)
-> RealFloat a
atan2 :: Scalar a -> Scalar a -> Scalar a
$catan2 :: forall a. RealFloat a => Scalar a -> Scalar a -> Scalar a
isIEEE :: Scalar a -> Bool
$cisIEEE :: forall a. RealFloat a => Scalar a -> Bool
isNegativeZero :: Scalar a -> Bool
$cisNegativeZero :: forall a. RealFloat a => Scalar a -> Bool
isDenormalized :: Scalar a -> Bool
$cisDenormalized :: forall a. RealFloat a => Scalar a -> Bool
isInfinite :: Scalar a -> Bool
$cisInfinite :: forall a. RealFloat a => Scalar a -> Bool
isNaN :: Scalar a -> Bool
$cisNaN :: forall a. RealFloat a => Scalar a -> Bool
scaleFloat :: Int -> Scalar a -> Scalar a
$cscaleFloat :: forall a. RealFloat a => Int -> Scalar a -> Scalar a
significand :: Scalar a -> Scalar a
$csignificand :: forall a. RealFloat a => Scalar a -> Scalar a
exponent :: Scalar a -> Int
$cexponent :: forall a. RealFloat a => Scalar a -> Int
encodeFloat :: Integer -> Int -> Scalar a
$cencodeFloat :: forall a. RealFloat a => Integer -> Int -> Scalar a
decodeFloat :: Scalar a -> (Integer, Int)
$cdecodeFloat :: forall a. RealFloat a => Scalar a -> (Integer, Int)
floatRange :: Scalar a -> (Int, Int)
$cfloatRange :: forall a. RealFloat a => Scalar a -> (Int, Int)
floatDigits :: Scalar a -> Int
$cfloatDigits :: forall a. RealFloat a => Scalar a -> Int
floatRadix :: Scalar a -> Integer
$cfloatRadix :: forall a. RealFloat a => Scalar a -> Integer
$cp2RealFloat :: forall a. RealFloat a => Floating (Scalar a)
$cp1RealFloat :: forall a. RealFloat a => RealFrac (Scalar a)
RealFloat,
Fractional (Scalar a)
Real (Scalar a)
(Real (Scalar a), Fractional (Scalar a)) =>
(forall b. Integral b => Scalar a -> (b, Scalar a))
-> (forall b. Integral b => Scalar a -> b)
-> (forall b. Integral b => Scalar a -> b)
-> (forall b. Integral b => Scalar a -> b)
-> (forall b. Integral b => Scalar a -> b)
-> RealFrac (Scalar a)
Scalar a -> b
Scalar a -> b
Scalar a -> b
Scalar a -> b
Scalar a -> (b, Scalar a)
forall b. Integral b => Scalar a -> b
forall b. Integral b => Scalar a -> (b, Scalar a)
forall a.
(Real a, Fractional a) =>
(forall b. Integral b => a -> (b, a))
-> (forall b. Integral b => a -> b)
-> (forall b. Integral b => a -> b)
-> (forall b. Integral b => a -> b)
-> (forall b. Integral b => a -> b)
-> RealFrac a
forall a. RealFrac a => Fractional (Scalar a)
forall a. RealFrac a => Real (Scalar a)
forall a b. (RealFrac a, Integral b) => Scalar a -> b
forall a b. (RealFrac a, Integral b) => Scalar a -> (b, Scalar a)
floor :: Scalar a -> b
$cfloor :: forall a b. (RealFrac a, Integral b) => Scalar a -> b
ceiling :: Scalar a -> b
$cceiling :: forall a b. (RealFrac a, Integral b) => Scalar a -> b
round :: Scalar a -> b
$cround :: forall a b. (RealFrac a, Integral b) => Scalar a -> b
truncate :: Scalar a -> b
$ctruncate :: forall a b. (RealFrac a, Integral b) => Scalar a -> b
properFraction :: Scalar a -> (b, Scalar a)
$cproperFraction :: forall a b. (RealFrac a, Integral b) => Scalar a -> (b, Scalar a)
$cp2RealFrac :: forall a. RealFrac a => Fractional (Scalar a)
$cp1RealFrac :: forall a. RealFrac a => Real (Scalar a)
RealFrac, [Char] -> Scalar a
([Char] -> Scalar a) -> IsString (Scalar a)
forall a. IsString a => [Char] -> Scalar a
forall a. ([Char] -> a) -> IsString a
fromString :: [Char] -> Scalar a
$cfromString :: forall a. IsString a => [Char] -> Scalar a
IsString)
-- | A 'Scalar' is carried as a tensor holding exactly one element.
instance (TensorDataType V.Vector a, TensorType a) => TensorDataType Scalar a where
    -- Decode as a vector first, then require that it holds a single element.
    decodeTensorData tensorData =
        Scalar (headFromSingleton (decodeTensorData tensorData))
    -- Wrap the value in a one-element vector and encode that with the
    -- caller-supplied shape.
    encodeTensorData dims (Scalar value) =
        encodeTensorData dims (V.singleton value)
-- | Extract the only element of a one-element vector.
--
-- Calls 'error', reporting the actual length, when the vector does not hold
-- exactly one element.
headFromSingleton :: V.Vector a -> a
headFromSingleton vec = case V.length vec of
    1 -> V.head vec
    n -> error ("Unable to extract singleton from tensor of length " ++ show n)
-- | Shape of a tensor: one 'Int64' size per dimension, in the same order as
-- the @dim@ entries of a 'TensorShapeProto'.
--
-- A shape of unknown rank has no 'Shape' value; it is represented as
-- @Nothing :: Maybe Shape@ (see 'protoMaybeShape').
newtype Shape = Shape [Int64]
    deriving (Show)

-- | Lets a 'Shape' be written as a list literal (with @OverloadedLists@)
-- and converted to and from @[Int64]@.
instance IsList Shape where
    type Item Shape = Int64
    fromList dims = Shape dims
    toList (Shape dims) = dims
-- | View a 'TensorShapeProto' as a known-rank 'Shape'.
--
-- Reading through this lens calls 'error' if the proto is marked as having
-- unknown rank. Use 'protoMaybeShape' when that case must be handled.
-- Writing a 'Shape' produces a fresh proto with the given dimensions.
protoShape :: Lens' TensorShapeProto Shape
protoShape = under (adapter toShape fromShape)
  where
    toShape :: TensorShapeProto -> Shape
    toShape proto = case view protoMaybeShape proto of
        Just known -> known
        Nothing -> error $
            "Can't convert TensorShapeProto with unknown rank to Shape: "
            ++ showMessageShort proto

    fromShape :: Shape -> TensorShapeProto
    fromShape known = defMessage & protoMaybeShape .~ Just known
-- | View a 'TensorShapeProto' as an optional 'Shape'.
--
-- A proto whose @unknownRank@ flag is set maps to 'Nothing'. Otherwise the
-- @size@ of each @dim@ entry becomes one element of the 'Shape'.
-- Conversely, 'Nothing' is written by setting @unknownRank@, and
-- @Just shape@ by filling in one @dim@ per size.
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
protoMaybeShape = under (adapter decode encode)
  where
    decode :: TensorShapeProto -> Maybe Shape
    decode proto
        | view unknownRank proto = Nothing
        | otherwise = Just (Shape (proto ^.. dim . traverse . size))

    encode :: Maybe Shape -> TensorShapeProto
    encode optionalShape = case optionalShape of
        Nothing -> defMessage & unknownRank .~ True
        Just (Shape dims) ->
            defMessage & dim .~ [defMessage & size .~ d | d <- dims]
-- | Types that can be stored in, and read back from, an 'AttrValue'.
--
-- Each instance selects the 'AttrValue' field (or the field of its nested
-- list value) that holds values of that type.
class Attribute a where
    attrLens :: Lens' AttrValue a

-- Scalar attributes: each maps directly onto a single 'AttrValue' field.

instance Attribute Float where
    attrLens focus = f focus

instance Attribute ByteString where
    attrLens focus = s focus

instance Attribute Int64 where
    attrLens focus = i focus

instance Attribute DataType where
    attrLens focus = type' focus

instance Attribute TensorProto where
    attrLens focus = tensor focus

instance Attribute Bool where
    attrLens focus = b focus

-- Shapes go through the @shape@ proto field. A known-rank 'Shape' uses
-- 'protoShape', which errors on unknown rank when read. 'Maybe' 'Shape'
-- uses 'protoMaybeShape', which maps unknown rank to 'Nothing'.

instance Attribute Shape where
    attrLens focus = shape (protoShape focus)

instance Attribute (Maybe Shape) where
    attrLens focus = shape (protoMaybeShape focus)

-- List attributes: the raw list value, or one typed field inside it.

instance Attribute AttrValue'ListValue where
    attrLens focus = list focus

instance Attribute [DataType] where
    attrLens focus = list (type' focus)

instance Attribute [Int64] where
    attrLens focus = list (i focus)
-- | A heterogeneous list whose element types are tracked at the type level.
--
-- Every element is wrapped in the type constructor @f@. The promoted list
-- @as@ records each element's type, in order.
data ListOf f as where
    Nil :: ListOf f '[]
    (:/) :: f a -> ListOf f as -> ListOf f (a ': as)

infixr 5 :/

-- | Apply the class @f@ to every type in @as@ and conjoin the results.
type family All f as :: Constraint where
    All f '[] = ()
    All f (a ': as) = (f a, All f as)

-- | Apply the type constructor @f@ to every element of the type-level list
-- @as@.
type family Map f as where
    Map f '[] = '[]
    Map f (a ': as) = f a ': Map f as

-- | Element-wise equality. Both operands share the index @as@, so they
-- always have the same length and the same element types.
instance All Eq (Map f as) => Eq (ListOf f as) where
    Nil == Nil = True
    (x :/ xs) == (y :/ ys) = x == y && xs == ys
    -- The catch-all below is only compiled for GHC < 8.0. Presumably older
    -- pattern-match checkers could not use the GADT indices to see that the
    -- two clauses above are exhaustive.
#if __GLASGOW_HASKELL__ < 800
    _ == _ = False
#endif

-- | Renders the list using @:/@ constructor syntax.
--
-- Precedences follow the @infixr 5 :/@ declaration:
--
-- * The result is parenthesised when the surrounding context binds tighter
--   than 5.
-- * The head is shown at precedence 6.
-- * The tail is shown at precedence 5, so right-nested cells print without
--   redundant parentheses, e.g. @Identity 1 :/ Identity 2 :/ Nil@.
instance All Show (Map f as) => Show (ListOf f as) where
    showsPrec _ Nil = showString "Nil"
    showsPrec d (x :/ xs) = showParen (d > consPrec)
        $ showsPrec (consPrec + 1) x
        . showString " :/ "
        . showsPrec consPrec xs
      where
        -- Must agree with the fixity declaration of ':/'.
        consPrec :: Int
        consPrec = 5
-- | A heterogeneous list of plain (unwrapped) values.
type List = ListOf Identity

-- | Cons a bare value onto a 'List', wrapping it in 'Identity'.
(/:/) :: a -> List as -> List (a ': as)
x /:/ rest = Identity x :/ rest

infixr 5 /:/
-- | Constrains @t@ to be a tensor type drawn from the list @allowed@.
--
-- It requires a 'TensorType' instance for @t@ and for every member of
-- @allowed@. It also rules @t@ out against each entry of 'AllTensorTypes'
-- that is not in @allowed@.
type OneOf allowed t
    = (TensorType t, TensorTypes' allowed, NoneOf (AllTensorTypes \\ allowed) t)

-- | List version of 'OneOf': every type in @ts@ must be drawn from
-- @allowed@. @ts@ must also be reflectable via 'TensorTypes'.
type OneOfs allowed ts = (TensorTypes ts, TensorTypes' allowed,
                          NoneOfs (AllTensorTypes \\ allowed) ts)

-- | @NoneOf excluded t@ for every @t@ in the list.
type family NoneOfs excluded ts :: Constraint where
    NoneOfs excluded '[] = ()
    NoneOfs excluded (t ': ts) = (NoneOf excluded t, NoneOfs excluded ts)
-- | A value-level witness that @t@ has a 'TensorType' instance.
-- Pattern matching on the constructor brings @TensorType t@ into scope.
data TensorTypeProxy t where
    TensorTypeProxy :: TensorType t => TensorTypeProxy t

-- | A list of 'TensorType' witnesses, one per type in the index.
type TensorTypeList = ListOf TensorTypeProxy
-- | The runtime 'DataType' of each element of a 'TensorTypeList', in list
-- order.
fromTensorTypeList :: TensorTypeList ts -> [DataType]
fromTensorTypeList proxies = case proxies of
    Nil -> []
    witness :/ rest -> dataTypeOf witness : fromTensorTypeList rest
  where
    -- Matching the GADT constructor supplies the 'TensorType' dictionary.
    -- 'tensorType' is applied to 'undefined' here, so it must only use the
    -- argument's type, never its value.
    dataTypeOf :: forall t. TensorTypeProxy t -> DataType
    dataTypeOf TensorTypeProxy = tensorType (undefined :: t)
-- | The 'DataType's of a type-level list of tensor types. The list is
-- selected by the 'Proxy' argument, whose value is never inspected.
fromTensorTypes :: forall as . TensorTypes as => Proxy as -> [DataType]
fromTensorTypes proxy = fromTensorTypeList (witnessesFor proxy)
  where
    witnessesFor :: TensorTypes xs => Proxy xs -> TensorTypeList xs
    witnessesFor _ = tensorTypes
-- | Type-level lists of tensor types that can be reflected into a
-- 'TensorTypeList' of witnesses.
class TensorTypes (ts :: [*]) where
    tensorTypes :: TensorTypeList ts

-- | The empty type list has no witnesses.
instance TensorTypes '[] where
    tensorTypes = Nil

-- | A witness for the head type, followed by the witnesses for the tail.
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts) where
    tensorTypes = (:/) TensorTypeProxy tensorTypes
-- | Require a 'TensorType' instance for every type in the list.
--
-- Unlike the 'TensorTypes' class, this is only a constraint; it does not
-- build a 'TensorTypeList'.
--
-- The equations peel off up to four elements per step. The result is the
-- same conjunction the one-element equation alone would produce.
-- NOTE(review): the unrolling looks intended to reduce the number of
-- type-family reduction steps for long lists. Confirm before simplifying.
type family TensorTypes' (ts :: [*]) :: Constraint where
    TensorTypes' (t1 ': t2 ': t3 ': t4 ': ts)
        = (TensorType t1, TensorType t2, TensorType t3, TensorType t4
          , TensorTypes' ts)
    TensorTypes' (t1 ': t2 ': t3 ': ts)
        = (TensorType t1, TensorType t2, TensorType t3, TensorTypes' ts)
    TensorTypes' (t1 ': t2 ': ts)
        = (TensorType t1, TensorType t2, TensorTypes' ts)
    TensorTypes' (t ': ts) = (TensorType t, TensorTypes' ts)
    TensorTypes' '[] = ()
-- | Type-level inequality: trivially satisfied when @x@ and @y@ differ.
--
-- When they are equal it reduces to @TypeError x ~ ExcludedCase@. That can
-- never hold, because 'TypeError' and 'ExcludedCase' are distinct data
-- types, so the program is rejected with an error that mentions the
-- offending type @x@.
type family x /= y :: Constraint where
    x /= x = TypeError x ~ ExcludedCase
    x /= y = ()

-- | Wraps the offending type in the unsatisfiable constraint produced by
-- '/='.
data TypeError a

-- | The impossible right-hand side of the equality produced by '/='.
data ExcludedCase
-- | Every tensor element type that 'OneOf' / 'OneOfs' know how to exclude.
--
-- 'OneOf' subtracts its allowed list from this one and forbids the
-- remainder via 'NoneOf'.
-- NOTE(review): this module imports Word32, Word64 and Complex, which do
-- not appear here. If any of them has a 'TensorType' instance, 'OneOf'
-- cannot exclude it (for example @OneOf '[Float] Word32@ would then be
-- satisfied). Confirm the list is complete.
type AllTensorTypes =
    '[ Float
     , Double
     , Int8
     , Int16
     , Int32
     , Int64
     , Word8
     , Word16
     , ByteString
     , Bool
     ]
-- | Remove every occurrence of @x@ from a type-level list.
type family Delete x xs where
    Delete x '[] = '[]
    Delete x (x ': rest) = Delete x rest
    Delete x (y ': rest) = y ': Delete x rest

-- | Type-level list difference: remove from @xs@ every element of @ys@.
type family xs \\ ys where
    xs \\ '[] = xs
    xs \\ (y ': ys) = Delete y xs \\ ys
-- | @NoneOf ts a@ holds when @a@ differs from every type in @ts@. Each
-- comparison uses '/=', so a match is reported as a type error.
--
-- The equations compare up to four elements per step. The result is the
-- same conjunction the one-element equation alone would produce.
-- NOTE(review): the unrolling presumably reduces type-family reduction
-- steps. Confirm before simplifying.
type family NoneOf ts a :: Constraint where
    NoneOf (t1 ': t2 ': t3 ': t4 ': ts) a
        = (a /= t1, a /= t2, a /= t3, a /= t4, NoneOf ts a)
    NoneOf (t1 ': t2 ': t3 ': ts) a = (a /= t1, a /= t2, a /= t3, NoneOf ts a)
    NoneOf (t1 ': t2 ': ts) a = (a /= t1, a /= t2, NoneOf ts a)
    NoneOf (t1 ': ts) a = (a /= t1, NoneOf ts a)
    NoneOf '[] a = ()