-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{- | Rendering of TensorFlow operations as Haskell functions.

The basic type signature generated for each op is:

> {constraints} => {mandatory attrs} -> {input tensors} -> {output tensors}

where:

* @{mandatory attrs}@ is of the form @A_1 -> ... -> A_N@, where each @A@ is an
 op attribute that doesn't have a default and can't be inferred from other
 inputs.

* @{constraints}@ restrict the type parameters of the input and output tensors
 (for example: 'TensorType' or 'OneOf').

* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
the form @Tensor Ref a@ or @Tensor v a@ (or a list of one of those types),
and @a@ is either a concrete type or a (constrained) type variable.

* @{output tensors}@ is of the form @(T_1,...,T_N)@ for "pure" ops, and
@m' (T_1,...,T_N)@ (under a @MonadBuild m'@ constraint) for "stateful" ops.
An op is considered "stateful" if it takes a @Tensor Ref@ or
@Tensor v ResourceHandle@ as input, or if it's explicitly marked
\"Stateful\" in its @REGISTER_OP@ definition.  (If there are no outputs, it
is either @ControlNode@ or @m' ControlNode@.)
-}

module TensorFlow.OpGen
  ( OpGenFlags(..)
  , docOpList
  , flagParser)
  where

import Data.Foldable (toList)
import Data.Maybe (fromMaybe)
import Data.ProtoLens.Default(def)
import Data.ProtoLens (showMessage)
import Data.List (sortOn)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NE
import Lens.Family2 ((^.), (.~), (&), view)
import Options.Applicative (Parser, help, long, strOption, value)
import Proto.Tensorflow.Core.Framework.OpDef
  ( OpList
  , OpDef
  )
import Proto.Tensorflow.Core.Framework.OpDef_Fields
  ( attr
  , inputArg
  , name
  , op
  , outputArg
  )
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import System.FilePath (takeBaseName)
import TensorFlow.OpGen.ParsedOp
import Text.PrettyPrint.Mainland
  ( Doc
  , (<+>)
  , (</>)
  , (<+/>)
  , brackets
  , comma
  , commasep
  , dquotes
  , empty
  , enclose
  , flatten
  , folddoc
  , hang
  , indent
  , parens
  , sep
  , stack
  , strictText
  , tuple
  )
import qualified Data.Set as Set
import qualified Data.Text as Text

-- | Command-line configuration for the op code generator.
data OpGenFlags = OpGenFlags
     { outputFile :: String
       -- ^ Path of the Haskell file to write.  Its base name (minus an
       -- optional @_ops_op_lib@ suffix) also names the generated module.
     , prefix :: String
       -- ^ Module prefix; the generated module is @prefix.CamelCaseName@.
     , excludeList :: String
       -- ^ Comma-separated TensorFlow op names that are not rendered.
     }

-- | Command-line parser for 'OpGenFlags'.
--
-- @--output@ and @--prefix@ are required; @--exclude_list@ defaults to the
-- empty string (nothing excluded).
flagParser :: Parser OpGenFlags
flagParser = OpGenFlags
     <$> strOption (mconcat [ long "output"
                            , help "File to write."
                            ])
     <*> strOption (mconcat [ long "prefix"
                            , help "Haskell package prefix to use"
                            ])
     <*> strOption (mconcat [ long "exclude_list"
                            , value ""
                            , help "Comma separated Ops names to ignore"
                            ])


-- | Render a complete Haskell module wrapping every op in the 'OpList'.
--
-- Ops named in 'excludeList' are skipped.  The remaining ops are emitted in
-- order of their TensorFlow name, one blank line apart.  Each op's wrapper
-- (see 'renderOp') is followed by a comment holding its proto definition
-- (see 'extras').
docOpList :: OpGenFlags -> OpList -> Doc
docOpList flags opList =
  stack [ "{-# LANGUAGE ConstraintKinds #-}"
        , "{-# LANGUAGE DataKinds #-}"
        , "{-# LANGUAGE FlexibleContexts #-}"
        , "{-# LANGUAGE FlexibleInstances #-}"
        , "{-# LANGUAGE OverloadedStrings #-}"
        , "{-# LANGUAGE ScopedTypeVariables #-}"
          -- Avoids reports about shadowing standard library names.
        , "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}"
          -- eqLengthGuard never returns false and dies instead.
        , "{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}"
        , "module" <+> strictText moduleName <+> "where"
        , empty
        , imports
        , empty
        , folddoc (\x y -> x </> empty </> y)
                  (map renderOpAndExtras $
                   sortOn (view name) $
                   filter (not . isExcluded . view name) $
                   toList $ opList ^. op)
        ]
  where
    moduleName :: Text.Text
    moduleName = Text.pack (prefix flags) <> "." <> camelCase baseName
    -- Discards the optional trailing _ops_op_lib
    baseName = fromMaybe shortName (Text.stripSuffix "_ops_op_lib" shortName)
    shortName = Text.pack (takeBaseName $ outputFile flags)
    -- Entries are whitespace-trimmed so "Foo, Bar" excludes both ops.
    -- Empty entries (from the default "" or a trailing comma) are dropped.
    exclusions :: Set.Set Text.Text
    exclusions = Set.fromList
               $ filter (not . Text.null)
               $ map Text.strip
               $ Text.splitOn ","
               $ Text.pack (excludeList flags)
    isExcluded opName = opName `Set.member` exclusions
    renderOpAndExtras o = renderOp (parseOp o) </> extras o

-- | Import declarations emitted near the top of every generated module,
-- one per line.
imports :: Doc
imports = stack [
      "import Data.ByteString (ByteString)"
    , "import Data.Complex (Complex)"
    , "import Data.Int (Int8, Int16, Int32, Int64)"
    , "import Data.Proxy (Proxy(Proxy))"
    , "import Data.Word (Word8, Word16, Word32, Word64)"
    , "import Lens.Family2 ((.~), (&))"
    , "import TensorFlow.Build"
    , "import TensorFlow.BuildOp"
    , "import TensorFlow.Tensor"
    , "import TensorFlow.Types"
    ]

-- | Render a 'Name' in one of three forms:
--
-- * 'renderHaskellName': its Haskell identifier spelling.
-- * 'renderTFName': its original TensorFlow spelling.
-- * 'renderQuotedTFName': its TensorFlow spelling in double quotes, for use
--   as a string literal in generated code.
renderHaskellName, renderTFName, renderQuotedTFName :: Name -> Doc
renderHaskellName = strictText . unHaskellName . haskellName
renderTFName = strictText . unTFName . tfName
renderQuotedTFName = dquotes . renderTFName


-- | Generate the source code for a single op.
-- For example:
--
-- -- | {haddock comment}
-- foo :: {type sig}
-- foo = foo' id
-- foo' :: {type sig with a leading "OpParams ->"}
-- foo' op'options attr1 attr2 input1 input2 | eqLengthGuard [...] = {function body}
--   where {inferred list-size attributes, if any}
--
-- The unprimed function applies no extra op options.  The primed variant
-- takes an options function that is applied to the built OpDef.
renderOp :: ParsedOp -> Doc
renderOp pOp = stack $
    [ haddocks
    -- Prevent unreasonably long compilation times on ghc-7.10, due
    -- to stack calling "-dump-hi" which (unnecessarily) includes the
    -- inlining information, and is large for ops with many arguments.
#if __GLASGOW_HASKELL__ < 800
    , "{-# NOINLINE" <+> n <+> "#-}"
#endif
    , n <+> "::" <+> hang 0 (typeSig empty pOp)
    , n <+> "=" <+> n <> "' id"
    , n' <+> "::" <+> hang 0 (typeSig "OpParams ->" pOp)
    , n' <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
                <+> "=" </>  -- args are indented
                    -- the body needs to be indented wrt the name
                    indent indentation (functionBody pOp)
    ] ++ whereClause listSizeAttrs
  where
    -- Haskell name of the op's wrapper, and of its options-taking variant.
    n = renderHaskellName $ parsedOpName pOp
    n' = n <> "'"
    listSizeAttrs = inferredListSizeAttrs pOp
    -- Parameters of the primed function: the options function, then the
    -- mandatory attributes, then the input tensors.
    args = sep $ "op'options"
               : map renderHaskellName
                    (map attrName (explicitInputAttrs pOp)
                     ++ map parsedArgName (parsedInputs pOp))
    haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp)
                                           (parsedOpDescription pOp)

-- | A check that all lists of the given size have the given length.
-- For example:
--   eqLengthGuard [("N", [("input1", length input1), ("input2", length input2)])]
--
-- One entry is emitted per list-size attribute.  Each entry pairs the
-- quoted attribute name with the quoted name and length of every input
-- list sharing that attribute.  With no attributes it renders
-- @eqLengthGuard []@.
funcGuard :: [Attr (NonEmpty Name)] -> Doc
funcGuard attrs = "eqLengthGuard" <+> brackets (commasep entries)
  where
    entries =
        [ parens $ nAttr <> comma <+>
          brackets (commasep $ map renderTensorName (toList $ attrInfo a))
        | a <- attrs
        , let nAttr = renderQuotedTFName (attrName a)
        ]
    -- ("input1", length input1)
    renderTensorName x = parens $ renderQuotedTFName x <> comma <+>
                         "length" <+> renderHaskellName x

-- | Define the implicit list length attributes.
-- For example:
--   where
--     n1 = fromIntegral (length input1) :: Int64
--     n2 = fromIntegral (length input2) :: Int64
--
-- Each attribute is bound to the length of the first input list that uses
-- it.  Agreement among the other lists is checked by the 'funcGuard' that
-- 'renderOp' emits for the same attributes.  Produces no lines when there
-- are no such attributes.
whereClause :: [Attr (NonEmpty Name)] -> [Doc]
whereClause [] = []
whereClause as =
    [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
  where
    defineLengthAttr a = renderHaskellAttrName a <+> "="
                         <+> "fromIntegral (length"
                         <+> renderHaskellName (NE.head $ attrInfo a)
                         <> ") :: Int64"

-- | Render an attribute's name as a Haskell identifier.
renderHaskellAttrName :: Attr a -> Doc
renderHaskellAttrName = renderHaskellName . attrName

-- | Render the body of an op's primed wrapper function.
--
-- The generated code first collects all input tensors with @buildInputs@
-- into @op'inputs@.  It then builds the OpDef: op name, inferred type
-- attributes, explicit attributes, inferred list-size attributes, the
-- caller's @op'options@, and finally the inputs.  Monadic ops run this
-- under @build@ / @buildOp@; pure ops run it under @pureOp@.  Both are
-- given the Haskell names of the list-length attributes of any list-valued
-- outputs.
functionBody :: ParsedOp -> Doc
functionBody pOp
    | parsedOpIsMonadic pOp
        = "build $ do"
            </> indent indentation (bindOpInputsVar
                                    </> "buildOp" <+> outputListsSizes <+> opDef)
    | otherwise
        = "pureOp" <+> outputListsSizes <+> "$ do"
            </> indent indentation (bindOpInputsVar </> "return" <+> opDef)
  where
    outputListsSizes = brackets $ commasep
        [ renderHaskellName a
        | ParsedArg { parsedArgCase = ListArg { argLength = a } }
            <- parsedOutputs pOp
        ]
    opInputsVar = "op'inputs"
    bindOpInputsVar = opInputsVar <+> "<- fmap Prelude.concat $ Prelude.sequence"
                        <+> brackets (commasep $ map ("buildInputs" <+>) tensorArgs)
    -- Renders one "& opAttr \"name\" .~ value" line.
    setAttr :: Name -> Doc -> Doc
    setAttr n v = "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> v
    opDef = parens $ hang 0 $ stack $
        "opDef" <+> renderQuotedTFName (parsedOpName pOp)
        -- Renders type parameter arguments.
        : [ setAttr (attrName a) (inferredTypeExpr a) | a <- inferredTypeAttrs pOp ]
        -- Renders mandatory attributes as function parameters.
        ++ [ setAttr n (renderHaskellName n)
           | n <- map attrName (explicitInputAttrs pOp) ]
        -- Renders sizes of tensor list types having number_attr.
        ++ [ setAttr n (renderHaskellName n)
           | n <- map attrName (inferredListSizeAttrs pOp) ]
        ++ ["& op'options & opInputs .~" <+> opInputsVar]
    tensorArgs = renderTensorArg <$> parsedInputs pOp
    renderTensorArg = renderHaskellName . parsedArgName
    -- Expression recovering a type attribute's value from its Haskell type
    -- variable: a list of types for list-valued params, otherwise one type.
    inferredTypeExpr a
        | typeParamIsList $ attrInfo a
            = "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
                <> ")"
        | otherwise = "tensorType (undefined ::" <+> renderHaskellAttrName a
                        <> ")"

-- | Write a comment with the inputs/outputs/attributes in proto format, for
-- debugging.
--
-- Only the @inputArg@, @outputArg@ and @attr@ fields are copied into a
-- default 'OpDef' before printing.  All other fields (such as the op name)
-- are left at their defaults.
--
-- Haskell block comments nest, so any @{-@ or @-}@ inside the proto text
-- (e.g. in a description) would unbalance the surrounding @{- ... -}@ and
-- break the generated module.  Both sequences are split with a space.
extras :: OpDef -> Doc
extras d = enclose "{-\n" "\n-}" $
            strictText $ escapeCommentDelimiters $ Text.pack $
            showMessage ((def :: OpDef)
                        & inputArg .~ (d ^. inputArg)
                        & outputArg .~ (d ^. outputArg)
                        & attr .~ (d ^. attr))
  where
    escapeCommentDelimiters =
        Text.replace "{-" "{ -" . Text.replace "-}" "- }"

-- | The type signature for an op.
-- Of the form:
-- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2)
--      => {pre} Float -> Tensor v1 t1 -> Tensor v2 t2
-- where "Float" is an explicit input attribute, "Tensor v1 t1" is an input, and
-- "Tensor v2 t2" is an output.
--
-- The @forall ... =>@ prefix is omitted when there are no class
-- constraints.  For monadic ops, a @MonadBuild m'@ constraint is added and
-- the result is wrapped as @m' (...)@.  @pre@ is placed between the
-- constraints and the first argument; 'renderOp' passes @empty@ or
-- @"OpParams ->"@.
typeSig :: Doc -> ParsedOp -> Doc
typeSig pre pOp = constraints
            <+/> pre </> signatureFold (map attrInput (explicitInputAttrs pOp)
                                        ++ map tensorArgAndComment (parsedInputs pOp)
                                        ++ [outputs])
  where
    constraints
        | null classConstraints = empty
        | otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints
                        <+> "=>"
    typeParams = [ strictText v
                 | k <- parsedInputs pOp ++ parsedOutputs pOp
                 , ArgSomeTensor v <- [argKind $ parsedArgCase k] ]
                 ++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
                 ++ if parsedOpIsMonadic pOp then ["m'"] else []
    -- Use m' as the type parameter to avoid clashing with an attribute name.
    monadConstraint
        | parsedOpIsMonadic pOp = ["MonadBuild m'"]
        | otherwise = []
    classConstraints = monadConstraint
                       ++ map tensorArgConstraint (inferredTypeAttrs pOp)
    signatureFold = folddoc (\x y -> x </> "->" <+> y)
    attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
    renderAttrType (AttrSingle a) = renderAttrBaseType a
    renderAttrType (AttrList a) = brackets $ renderAttrBaseType a
    renderAttrBaseType = \case
        AttrBytes -> "ByteString"
        AttrInt64 -> "Data.Int.Int64"
        AttrFloat -> "Float"
        AttrBool -> "Bool"
        AttrType -> "DataType"
        AttrShape -> "Shape"
        AttrTensor -> "TensorProto"

    tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
    outputs = case parsedOutputs pOp of
        [] -> wrapOutput "ControlNode"
        -- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a`
        [a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
        as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
    wrapOutput o
        | parsedOpIsMonadic pOp = "m'" <+> parens o
        | otherwise = o

-- | Render an op input or output.
-- For example: "Tensor Ref Int64", "Tensor v t"
--
-- List arguments are rendered in brackets (e.g. "[Tensor v t]").  Mixed
-- lists are rendered as "TensorList (kind) ts".
tensorArg :: ParsedArg -> Doc
tensorArg p = case parsedArgCase p of
    SimpleArg { argType = t, argKind = k } -> tensorType t k
    ListArg { argType = t, argKind = k } -> brackets $ tensorType t k
    MixedListArg { argTypeAttr = t, argKind = k }
        -> "TensorList" <+> parens (kind k) <+> renderHaskellName t
  where
    kind k = case k of
        ArgTensorRef -> "Ref"
        ArgTensorValue -> "Value"
        ArgTensorBuild -> "Build"
        ArgSomeTensor v -> strictText v
    tensorType t k =
        let a = case t of
                  ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
                  ArgTypeAttr n -> renderHaskellName n
        in "Tensor" <+> kind k <+> a

-- | Haddock text for an attribute: its bold TensorFlow name and description.
attrComment :: Attr a -> Doc
attrComment a = argComment' (attrName a) (attrDescription a)

-- | Haddock text for an op input or output: its bold name followed by its
-- description (see 'argComment'').
argComment :: ParsedArg -> Doc
argComment arg = argComment' name desc
  where
    name = parsedArgName arg
    desc = parsedArgDescription arg

-- | Render a name in bold followed by its (possibly multi-line) description.
-- The first description line is prefixed with ":"; later lines become
-- "--" comment lines (via 'splitMultilineText'). An empty description
-- contributes an empty document, leaving just the bold name.
argComment' :: Name -> Text.Text -> Doc
argComment' argName argDesc = boldName <> descLines
  where
    boldName = bold (renderTFName argName)
    descLines = splitMultilineText (":" <+>) argDesc

-- | Wrap a document in haddock bold markers: @__x__@.
bold :: Doc -> Doc
bold inner = marker <> inner <> marker
  where
    marker :: Doc
    marker = "__"

-- | Comment for the outputs of an op.
-- For example:
--   -- ^ (__output1__, __output2__)
--   --
--   -- * __output1__: description1
--   --
--   -- * __output2__: description2
--
-- The summary tuple of bold output names is flattened onto one line; each
-- output then gets a bare "--" line followed by a bullet from 'argComment'.
resultComment :: [ParsedArg] -> Doc
resultComment outputs = stack (summaryLine : map bullet outputs)
  where
    summaryLine = flatten ("-- ^" <+> tuple (map boldName outputs))
    boldName = bold . renderTFName . parsedArgName
    bullet o = stack ["--", "-- *" <+> argComment o]

-- | Constraints for a given type parameter.
-- E.g.: "TensorType t" or "OneOf '[Int64, Float] t"
-- or "TensorTypes ts" or "OneOfs '[..] ts".
tensorArgConstraint :: Attr TypeParam -> Doc
tensorArgConstraint a = case attrInfo a of
    TypeParam isList Nothing ->
        pick isList "TensorType" "TensorTypes" <+> n
    TypeParam isList (Just allowed) ->
        pick isList "OneOf" "OneOfs" <+> typeList allowed <+> n
  where
    n = renderHaskellAttrName a
    -- When the flag is set, use the plural ("...s") constraint class,
    -- which is the variant for a list of types.
    pick isList single plural = if isList then plural else single
    -- Produces a type-level list, e.g.: '[Int32,Int64,Float]
    typeList allowed =
        "'" <> brackets (commasep (map strictText haskellTypes))
      where
        -- Several DataTypes render to the same Haskell type (e.g. DT_INT32
        -- and DT_QINT32), so duplicates are dropped; the Set also sorts.
        haskellTypes =
            Set.toList (Set.fromList (map dtTypeToHaskell (toList allowed)))

-- | The Haskell type (as source text) used for a TensorFlow 'DataType'.
--
-- NOTE: The cases of this function should be kept in sync with
-- TensorFlow.Types.AllTensorTypes.
--
-- NOTE(review): unsupported types yield an "Unsupported type ..." message
-- rather than failing, and that text is not a valid Haskell type if it ends
-- up in generated code. Confirm callers never emit it. Also, DT_QINT8 maps
-- to the unsigned Word8 while DT_INT8 maps to Int8; confirm this matches
-- TensorFlow.Types.
dtTypeToHaskell :: DataType -> Text.Text
dtTypeToHaskell = \case
    DT_BOOL -> "Bool"
    DT_BFLOAT16 -> "Data.Word.Word16"
    DT_COMPLEX128 -> "(Data.Complex.Complex Double)"
    DT_COMPLEX64 -> "(Data.Complex.Complex Float)"
    DT_DOUBLE -> "Double"
    DT_FLOAT -> "Float"
    DT_INT16 -> "Data.Int.Int16"
    DT_INT32 -> "Data.Int.Int32"
    DT_INT64 -> "Data.Int.Int64"
    DT_INT8 -> "Data.Int.Int8"
    DT_QINT32 -> "Data.Int.Int32"  -- TODO(gnezdo): make unique
    DT_QINT8 -> "Data.Word.Word8"  -- TODO(gnezdo): make unique
    DT_QINT16 -> "Data.Int.Int16"  -- TODO(gnezdo): make unique
    DT_QUINT16 -> "Data.Word.Word16"  -- TODO(gnezdo): make unique
    DT_QUINT8 -> "Data.Word.Word8"  -- TODO(gnezdo): make unique
    DT_STRING -> "Data.ByteString.ByteString"
    DT_UINT16 -> "Data.Word.Word16"
    DT_UINT32 -> "Data.Word.Word32"
    DT_UINT64 -> "Data.Word.Word64"
    DT_HALF -> "Data.Word.Word16"  -- TODO(gnezdo): make unique
    DT_UINT8 -> "Data.Word.Word8"
    DT_RESOURCE -> "ResourceHandle"
    DT_VARIANT -> "Variant"
    unsupported ->
        Text.pack ("Unsupported type in dtTypeToHaskell: " ++ show unsupported)

-- | Render a TensorFlow doc string as haddock text.
--
-- No escaping is performed yet; the text is passed straight to 'strictText'.
-- TODO(gnezdo): deal with the markup.
haddockComment :: Text.Text -> Doc
haddockComment docText = strictText docText

-- | Generate a multiline comment.  For example:
--   summary'
--   --
--   -- detail_line1
--   -- detail_line2
--   -- ...
multilineComment :: Text.Text -> Text.Text -> Doc
multilineComment summary' detail = summaryDoc </> detailDoc
  where
    summaryDoc = haddockComment summary'
    detailDoc = splitMultilineText insertParagraphAndComment detail
    -- A bare "--" line before the first detail line separates the details
    -- from the summary as a new haddock paragraph.
    insertParagraphAndComment x = "--" </> "--" <+> x

-- | Converts the given multi-line detail string into
-- a multi-line haddock. Applies the given lead to the
-- first line; every later line is prefixed with "--".
-- Returns an empty document for empty detail.
splitMultilineText :: (Doc -> Doc) -> Text.Text -> Doc
splitMultilineText lead detail = renderLines (Text.lines detail)
  where
    renderLines [] = empty
    renderLines (firstLine : restLines) =
        stack (lead (haddockComment firstLine) : map continuation restLines)
    continuation line = "--" <+> haddockComment line

-- | Indentation width, in columns.
-- NOTE(review): its uses are outside this chunk; presumably the nesting
-- step for generated code — confirm.
indentation :: Int
indentation = 4