-- | This module helps parse the proto OpDef into a Haskell type which is more
-- descriptive of how the attributes and arguments will be used in the
-- generated code.
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module TensorFlow.OpGen.ParsedOp
    ( ParsedOp(..)
    , Name(..)
    , HaskellName(..)
    , TFName(..)
    , Attr(..)
    , AttrType(..)
    , AttrBaseType(..)
    , TypeParam(..)
    , ParsedArg(..)
    , ParsedArgCase(..)
    , ArgType(..)
    , ArgKind(..)
    , parseOp
    , camelCase
    ) where

import Data.Char (toUpper, toLower)
import Data.List (sortBy)
import Data.List.NonEmpty (NonEmpty, nonEmpty)
import Data.Maybe (mapMaybe)
import Data.Ord (comparing)
import qualified Data.Set as Set
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 ((^.))
import Proto.Tensorflow.Core.Framework.AttrValue_Fields (list)
import Proto.Tensorflow.Core.Framework.OpDef
    ( OpDef
    , OpDef'ArgDef
    , OpDef'AttrDef
    )
import Proto.Tensorflow.Core.Framework.OpDef_Fields
    ( allowedValues
    , attr
    , maybe'defaultValue
    , description
    , name
    , inputArg
    , isRef
    , isStateful
    , outputArg
    , summary
    , typeListAttr
    , numberAttr
    , typeAttr
    , type'
    )

import Proto.Tensorflow.Core.Framework.Types (DataType(DT_RESOURCE))

-- | An 'OpDef' reorganised around how each of its parts is used by the
-- code generator.
data ParsedOp = ParsedOp
    { parsedOpName :: Name
        -- ^ The op's name, in both TensorFlow and Haskell form.
    , parsedOpSummary :: Text
        -- ^ The op's one-line summary.
    , parsedOpDescription :: Text
        -- ^ The op's longer description.
    , parsedInputs :: [ParsedArg]
        -- ^ Input arguments.
    , parsedOutputs :: [ParsedArg]
        -- ^ Output arguments.
    , explicitInputAttrs :: [Attr AttrType]
        -- ^ Attributes that must be set explicitly when creating the op.
        -- Associated with the type of the attribute.
    , inferredTypeAttrs :: [Attr TypeParam]
        -- ^ Attributes that are type parameters.
    , inferredListSizeAttrs :: [Attr (NonEmpty Name)]
        -- ^ Attributes which are list sizes (ints) that are inferred
        -- automatically from one or more of the input tensors.
        -- Associated with the list of tensors whose size it describes.
    , parsedOpIsMonadic :: Bool
        -- ^ Whether this op is stateful or takes a stateful input.  Such ops
        -- should not be CSE'd and must be monadic in our API (i.e., return a
        -- Build action).
    }

-- | A name taken from the 'OpDef', kept both as written in the proto and
-- in a form usable as a Haskell variable.
data Name = Name
    { haskellName :: HaskellName
        -- ^ Haskell-safe form of the name.
    , tfName :: TFName
        -- ^ The name exactly as it appears in the proto.
    }

-- | A raw name as specified in the OpDef proto.
--
-- The 'Ord' instance lets names be stored in a 'Data.Set.Set' and used as a
-- sort key.
newtype TFName = TFName { unTFName :: Text }
    deriving (Eq, Ord)

-- | A name that's appropriate for a variable in a Haskell source file.
newtype HaskellName = HaskellName { unHaskellName :: Text }

-- | A named attribute, associated with some information about it.
data Attr a = Attr
    { attrName :: Name
        -- ^ The attribute's name.
    , attrDescription :: Text
        -- ^ The attribute's description.
    , attrInfo :: a
        -- ^ Extra information, which depends on how the attribute is used.
    }

-- | The type of an attribute.
data AttrType = AttrSingle AttrBaseType  -- ^ A single value.
              | AttrList AttrBaseType    -- ^ A list of values.
              deriving Eq

-- | The element type of an attribute value; each constructor corresponds
-- to one OpDef attribute type string.
data AttrBaseType
    = AttrBytes   -- ^ @string@
    | AttrInt64   -- ^ @int@
    | AttrFloat   -- ^ @float@
    | AttrBool    -- ^ @bool@
    | AttrType    -- ^ @type@
    | AttrShape   -- ^ @shape@
    | AttrTensor  -- ^ @tensor@
    deriving Eq

-- | Information about an attribute that acts as a type parameter.
data TypeParam = TypeParam
    { typeParamIsList :: Bool
        -- ^ 'True' for @list(type)@ attributes, 'False' for @type@.
    , typeParamRestrictions :: Maybe (NonEmpty DataType)
        -- ^ The list of allowed types (see: TensorFlow.Types.OneOf).
        -- If 'Nothing', then any type is acceptable.
    }

-- | An input or output argument (Tensor) for an op.
data ParsedArg = ParsedArg
    { parsedArgName :: Name
        -- ^ The argument's name.
    , parsedArgDescription :: Text
        -- ^ The argument's description.
    , parsedArgCase :: ParsedArgCase
        -- ^ Whether this is a single tensor or a list, and its type and kind.
    }

-- | The shape of an argument: one tensor, a homogeneous list of tensors, or
-- a heterogeneous list of tensors.
data ParsedArgCase
    = SimpleArg { argType :: ArgType, argKind :: ArgKind }
        -- ^ A single tensor.
    | ListArg
        { argLength :: Name  -- ^ The attribute that specifies this list's length.
        , argType :: ArgType
        , argKind :: ArgKind
        }
        -- ^ A list of tensors that all have the same type.
    | MixedListArg { argTypeAttr :: Name, argKind :: ArgKind }
        -- ^ A heterogeneous list.

-- | The element type of an argument, or 'Nothing' for a 'MixedListArg'.
-- A mixed list has no single type, so this avoids the partial 'argType'
-- selector.
maybeArgType :: ParsedArgCase -> Maybe ArgType
maybeArgType MixedListArg{} = Nothing
maybeArgType a = Just $ argType a

-- | The type of an argument.
data ArgType
    = ArgTypeFixed DataType -- ^ A fixed type.
    | ArgTypeAttr Name  -- ^ A type that depends on an attribute.

-- | The kind of an op input or output (not including the argument type @a@).
data ArgKind
    = ArgTensorRef        -- ^ @Tensor Ref a@
    | ArgTensorValue      -- ^ @Tensor Value a@
    | ArgTensorBuild      -- ^ @Tensor Build a@
    | ArgSomeTensor Text  -- ^ @Tensor v a@; the 'Text' is the variable @v@.
    deriving (Eq)

-- | Whether an argument refers to mutable state. This is the case for
-- ref-kind tensors and for arguments with the fixed type 'DT_RESOURCE'.
isRefCase :: ParsedArgCase -> Bool
isRefCase a
    | ArgTensorRef <- argKind a = True
    | Just (ArgTypeFixed DT_RESOURCE) <- maybeArgType a = True
    | otherwise = False

-- | Builds a 'Name' from a proto name. The Haskell form lower-cases the
-- first character and appends a prime if the result is a reserved keyword.
makeName :: Text -> Name
makeName n = Name
    { haskellName = HaskellName $ fixReservedName $ lowCase n
    , tfName = TFName n
    }

-- | Change a name so it doesn't conflict with any Haskell keywords, by
-- appending a prime (e.g. @type@ becomes @type'@). Other names are
-- returned unchanged.
fixReservedName :: Text -> Text
fixReservedName n
    | n `Set.member` reservedKeywords = n <> "'"
    | otherwise = n

-- | Words that cannot be used as Haskell variable names.
reservedKeywords :: Set.Set Text
reservedKeywords = Set.fromList $
    -- Haskell2010 keywords:
    -- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
    -- We don't include keywords that are allowed to be variable names,
    -- in particular: "as", "forall", and "hiding".
    [ "case"
    , "class"
    , "data"
    , "default"
    , "deriving"
    , "do"
    , "else"
    , "foreign"
    , "if"
    , "import"
    , "in"
    , "infix"
    , "infixl"
    , "infixr"
    , "instance"
    , "let"
    , "module"
    , "newtype"
    , "of"
    , "then"
    , "type"
    , "where"
    ]
    ++  -- Nonstandard extensions
    [ "mdo"   -- RecursiveDo
    , "rec"   -- Arrows, RecursiveDo
    , "proc"  -- Arrows
    ]

-- | Lower-case the first character of the given text; the rest is left
-- unchanged.
lowCase :: Text -> Text
lowCase = forceCase toLower

-- | Applies a character conversion to the first character of the text only.
-- Empty text is returned unchanged.
forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
                            (Text.uncons s)

-- | Converts an underscore-separated name to upper camel case by
-- upper-casing the first character of each @_@-separated segment and
-- concatenating them, e.g. @foo_bar@ becomes @FooBar@.
camelCase :: Text -> Text
camelCase s = Text.concat $ map upCase $ Text.splitOn "_" s

-- | Upper-case the first character of the given text; the rest is left
-- unchanged.
upCase :: Text -> Text
upCase = forceCase toUpper


-- | Converts an 'OpDef' proto into a 'ParsedOp'.
--
-- Inputs and outputs get a tensor kind. Each attribute is classified as an
-- inferred type parameter, an inferred list size, or an attribute that must
-- be passed explicitly. Attributes that fit none of these (e.g. ones with
-- default values) appear in none of the three lists.
parseOp :: OpDef -> ParsedOp
parseOp o = ParsedOp
    { parsedOpName = makeName $ o ^. name
    , parsedOpSummary = o ^. summary
    , parsedOpDescription = o ^. description
    , ..
    }
  where
    -- Monadic if the op is stateful, takes a ref/resource input, or has no
    -- outputs at all.
    parsedOpIsMonadic :: Bool
    parsedOpIsMonadic = o ^. isStateful
                    || any (isRefCase . parsedArgCase) parsedInputs
                    || null (o ^. outputArg)

    -- Each input gets its own tensor-kind variable (v'1, v'2, ...).
    parsedInputs :: [ParsedArg]
    parsedInputs = zipWith (\t a -> parseArg a (inputTensorKind t a))
                           tensorKindParams (o ^. inputArg)

    tensorKindParams :: [Text]
    tensorKindParams = ["v'" <> Text.pack (show x) | x <- [1 :: Integer ..]]

    -- Output kinds depend on whether the op as a whole is monadic.
    parsedOutputs :: [ParsedArg]
    parsedOutputs = map (\a -> parseArg a (outputTensorKind parsedOpIsMonadic a))
                        (o ^. outputArg)

    -- Integer attributes that can be inferred from the size of at least one
    -- input list.
    inferredListSizeAttrs :: [Attr (NonEmpty Name)]
    inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
                                $ o ^. attr

    -- Attributes the generated code can fill in without the caller's help.
    implicitAttrs :: Set.Set TFName
    implicitAttrs = Set.fromList $ map tfName $
                        map attrName inferredTypeAttrs
                            ++ map attrName inferredListSizeAttrs

    inferredTypeAttrs :: [Attr TypeParam]
    inferredTypeAttrs = mapMaybeAttrs (getInferredTypeAttr argTypeParams)
                            $ o ^. attr

    -- Attributes that determine the type of some input or output.
    argTypeParams :: Set.Set TFName
    argTypeParams = Set.fromList $ map tfName $
                        mapMaybe (getArgTypeParam . parsedArgCase) $
                            parsedInputs ++ parsedOutputs

    -- Attributes that can't be inferred and don't have defaults, so must be
    -- passed as separate arguments to the op. Sorted by proto name so the
    -- generated argument order is deterministic.
    explicitInputAttrs :: [Attr AttrType]
    explicitInputAttrs = sortBy (comparing (tfName . attrName))
                        $ mapMaybeAttrs (getExplicitInputAttr o implicitAttrs)
                        $ o ^. attr

-- TODO(judahjacobson): Some arguments should be refs.
-- | The kind of an input tensor. It is 'ArgTensorRef' when the proto marks
-- the argument @isRef@. Otherwise it is 'ArgSomeTensor', parameterised by
-- the given kind variable @v@.
inputTensorKind :: Text -> OpDef'ArgDef -> ArgKind
inputTensorKind v a
    | a ^. isRef = ArgTensorRef
    | otherwise = ArgSomeTensor v

-- | The kind of an output tensor:
--
-- * 'ArgTensorRef' when the proto marks the argument @isRef@;
-- * otherwise 'ArgTensorValue' when the op is monadic;
-- * otherwise 'ArgTensorBuild'.
outputTensorKind :: Bool -> OpDef'ArgDef -> ArgKind
outputTensorKind isMonadic a
    | a ^. isRef = ArgTensorRef
    | isMonadic = ArgTensorValue
    | otherwise = ArgTensorBuild

-- | The type of an attribute that the caller must pass explicitly, or
-- 'Nothing' in any of these cases:
--
-- * the attribute is in @implicitAttrs@;
-- * it has a default value;
-- * its type is not one the generator passes explicitly.
--
-- Calls 'error' through 'parseAttrType' if the attribute passes the first
-- two checks but its type string is unrecognized.
getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr o implicitAttrs a
    | TFName (a ^. name) `Set.notMember` implicitAttrs
    , a ^. maybe'defaultValue == Nothing
    , t <- parseAttrType o (a ^. type')
    , t `elem` explicitTypes = Just t
    | otherwise = Nothing
  where
    -- Attribute types that can be passed as explicit arguments. Tensor
    -- attributes and all list types except list(type) are excluded.
    explicitTypes :: [AttrType]
    explicitTypes =
        map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape, AttrBytes]
        ++ [AttrList AttrType]

-- | If the attribute determines the type of some argument (it is in
-- @argTypeParams@) and has type @type@ or @list(type)@, returns it as a
-- 'TypeParam' together with its allowed types (if any are listed).
-- Otherwise returns 'Nothing'.
getInferredTypeAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe TypeParam
getInferredTypeAttr argTypeParams a
    -- Set.notMember is O(log n), whereas Foldable's notElem scans the set.
    | TFName (a ^. name) `Set.notMember` argTypeParams = Nothing
    | otherwise = case a ^. type' of
        "type" -> Just $ TypeParam False allowed
        "list(type)" -> Just $ TypeParam True allowed
        _ -> Nothing
  where
    -- An empty allowed-values list means there is no restriction.
    allowed :: Maybe (NonEmpty DataType)
    allowed = nonEmpty (a ^. allowedValues . list . type')

-- | The attribute that determines an argument's type. For a 'MixedListArg'
-- this is its type-list attribute. Returns 'Nothing' for arguments with a
-- fixed type.
getArgTypeParam :: ParsedArgCase -> Maybe Name
getArgTypeParam SimpleArg { argType = ArgTypeAttr n } = Just n
getArgTypeParam ListArg { argType = ArgTypeAttr n } = Just n
getArgTypeParam MixedListArg { argTypeAttr = n } = Just n
getArgTypeParam _ = Nothing

-- | For an @int@ attribute, returns the names of the input list arguments
-- whose length it specifies. Returns 'Nothing' if there are no such inputs
-- or the attribute is not an @int@.
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
getInferredListSizeAttr inputs a
    | a ^. type' == "int"
        = nonEmpty [ t | ParsedArg { parsedArgName = t
                                   , parsedArgCase = ListArg { argLength = n }
                                   } <- inputs
                       , attrTFName == tfName n ]
    | otherwise = Nothing
  where
    -- Computed once rather than once per input.
    attrTFName :: TFName
    attrTFName = TFName (a ^. name)

-- | Like mapMaybe, but associates the attribute name/description with the given info.
-- Attributes for which @f@ returns 'Nothing' are dropped; the order of the
-- rest is preserved.
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
mapMaybeAttrs f = mapMaybe $ \a -> do
    x <- f a
    Just Attr
        { attrName = makeName (a ^. name)
        , attrDescription = a ^. description
        , attrInfo = x
        }

-- | Converts a proto argument definition into a 'ParsedArg' with the given
-- tensor kind.
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
parseArg a tKind = ParsedArg
    { parsedArgName = makeName (a ^. name)
    , parsedArgDescription = a ^. description
    , parsedArgCase = parseArgCase a tKind
    }

-- | Classifies an argument:
--
-- * a non-empty @typeListAttr@ makes it a 'MixedListArg';
-- * otherwise a non-empty @numberAttr@ makes it a 'ListArg' whose length is
--   that attribute;
-- * otherwise it is a 'SimpleArg'.
--
-- For 'ListArg' and 'SimpleArg', the type comes from @typeAttr@ when that
-- is non-empty, and from the fixed @type'@ field otherwise.
parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase
parseArgCase a tKind
    | Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind
    | Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind
    | otherwise = SimpleArg thisArgType tKind
  where
    thisArgType :: ArgType
    thisArgType
        | Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n
        | otherwise = ArgTypeFixed (a ^. type')

    -- An empty attribute name is treated as "no attribute".
    maybeAttr :: Text -> Maybe Name
    maybeAttr "" = Nothing
    maybeAttr t = Just $ makeName t

-- | Maps an OpDef attribute type string (e.g. @"int"@ or @"list(shape)"@)
-- to an 'AttrType'. The 'OpDef' is used only for the error message.
--
-- This function is partial: an unrecognized type string causes an 'error'
-- that names both the type and the op.
parseAttrType :: OpDef -> Text -> AttrType
parseAttrType o = \case
    "string" -> AttrSingle AttrBytes
    "int" -> AttrSingle AttrInt64
    "float" -> AttrSingle AttrFloat
    "bool" -> AttrSingle AttrBool
    "type" -> AttrSingle AttrType
    "shape" -> AttrSingle AttrShape
    "tensor" -> AttrSingle AttrTensor
    "list(string)" -> AttrList AttrBytes
    "list(int)" -> AttrList AttrInt64
    "list(float)" -> AttrList AttrFloat
    "list(bool)" -> AttrList AttrBool
    "list(type)" -> AttrList AttrType
    "list(shape)" -> AttrList AttrShape
    "list(tensor)" -> AttrList AttrTensor
    t -> error $ "parseAttrType: unrecognized type " ++ show t
              ++ " for op " ++ show (o ^. name)