-- Copyright 2020 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}

module TensorFlow.Convolution
    ( Padding(..)
    , DataFormat(..)
    , conv2D
    , conv2D'
    , conv2DBackpropFilter
    , conv2DBackpropFilter'
    , conv2DBackpropInput
    , conv2DBackpropInput'
    , conv3D
    , conv3D'
    , conv3DBackpropFilter
    , conv3DBackpropFilter'
    , conv3DBackpropFilterV2
    , conv3DBackpropFilterV2'
    , conv3DBackpropInput
    , conv3DBackpropInput'
    , conv3DBackpropInputV2
    , conv3DBackpropInputV2'
    , depthwiseConv2dNative
    , depthwiseConv2dNative'
    , depthwiseConv2dNativeBackpropFilter
    , depthwiseConv2dNativeBackpropFilter'
    , depthwiseConv2dNativeBackpropInput
    , depthwiseConv2dNativeBackpropInput'
    ) where

import Data.Word (Word16)
import Data.Int (Int32,Int64)
import Data.ByteString (ByteString)
import Lens.Family2 ((.~))

import qualified TensorFlow.BuildOp as TF
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF

-- TODO: Support other convolution parameters such as stride.

-- | Convolution padding scheme, sent to TensorFlow as the op's @padding@
-- attribute.
data Padding =
        -- | TensorFlow @VALID@ padding. The output spatial size is
        --
        -- > output_spatial_shape[i] = ceil(
        -- >     (input_spatial_shape[i] -
        -- >         (spatial_filter_shape[i] - 1) * dilation_rate[i]) / strides[i])
        PaddingValid
        -- | TensorFlow @SAME@ padding. The output spatial size is
        --
        -- > output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
      | PaddingSame

-- | Render a 'Padding' as the string stored in the op's @padding@ attribute.
paddingToByteString :: Padding -> ByteString
paddingToByteString PaddingValid = "VALID"
paddingToByteString PaddingSame  = "SAME"

-- | Tensor data layout, rendered into the op's @data_format@ attribute.
data DataFormat
    = ChannelLast  -- ^ Channels are the last dimension (e.g. NWC, NHWC, NDHWC).
    | ChannelFirst -- ^ Channels come right after the batch dimension N
                   --   (e.g. NCW, NCHW, NCDHW).

-- TODO: Address 1D convolution.

-- | The @data_format@ attribute value for 2D ops: NHWC when channels are
-- last, NCHW when channels come first.
dataFormat2D :: DataFormat -> ByteString
dataFormat2D ChannelLast  = "NHWC"
dataFormat2D ChannelFirst = "NCHW"

-- | The @data_format@ attribute value for 3D ops: NDHWC when channels are
-- last, NCDHW when channels come first.
dataFormat3D :: DataFormat -> ByteString
dataFormat3D ChannelLast  = "NDHWC"
dataFormat3D ChannelFirst = "NCDHW"

-- | 2D convolution with default parameters: no extra op parameters,
-- 'PaddingValid' and 'ChannelLast' (NHWC) layout.
--
-- Shorthand for @conv2D' id PaddingValid ChannelLast@.
conv2D :: TF.OneOf '[Word16, Double, Float] t
       => TF.Tensor v1 t       -- ^ input
       -> TF.Tensor v2 t       -- ^ filter
       -> TF.Tensor TF.Build t -- ^ output
conv2D = conv2D' id PaddingValid ChannelLast

-- | 2D convolution with explicit op parameters, padding and data format
-- (NHWC or NCHW).
conv2D' :: TF.OneOf '[Word16, Double, Float] t
        => TF.OpParams         -- ^ extra op parameters
        -> Padding             -- ^ padding scheme
        -> DataFormat          -- ^ tensor layout
        -> TF.Tensor v1 t       -- ^ input
        -> TF.Tensor v2 t       -- ^ filter
        -> TF.Tensor TF.Build t -- ^ output
conv2D' params padding dataformat =
    TF.conv2D' opParams (paddingToByteString padding)
  where
    -- The data_format setter runs first; the caller's params are composed
    -- on the outside, so they run afterwards.
    opParams = params . (TF.opAttr "data_format" .~ dataFormat2D dataformat)

-- | 2D convolution filter backpropagation with default parameters: no extra
-- op parameters, 'PaddingValid' and 'ChannelLast' (NHWC) layout.
--
-- Shorthand for @conv2DBackpropFilter' id PaddingValid ChannelLast@.
conv2DBackpropFilter :: TF.OneOf '[Word16, Double, Float] t
                     => TF.Tensor v1 t        -- ^ input
                     -> TF.Tensor v2 Int32    -- ^ filter_sizes
                     -> TF.Tensor v3 t        -- ^ out_backprop
                     -> TF.Tensor TF.Build t  -- ^ output
conv2DBackpropFilter = conv2DBackpropFilter' id PaddingValid ChannelLast

-- | 2D convolution filter backpropagation with explicit op parameters,
-- padding and data format (NHWC or NCHW).
conv2DBackpropFilter' :: TF.OneOf '[Word16, Double, Float] t
                      => TF.OpParams          -- ^ extra op parameters
                      -> Padding              -- ^ padding scheme
                      -> DataFormat           -- ^ tensor layout
                      -> TF.Tensor v1 t        -- ^ input
                      -> TF.Tensor v2 Int32    -- ^ filter_sizes
                      -> TF.Tensor v3 t        -- ^ out_backprop
                      -> TF.Tensor TF.Build t  -- ^ output
conv2DBackpropFilter' params padding dataformat =
    TF.conv2DBackpropFilter' opParams (paddingToByteString padding)
  where
    -- The data_format setter runs first; the caller's params are composed
    -- on the outside, so they run afterwards.
    opParams = params . (TF.opAttr "data_format" .~ dataFormat2D dataformat)

-- | 2D convolution input backpropagation with default parameters: no extra
-- op parameters, 'PaddingValid' and 'ChannelLast' (NHWC) layout.
--
-- Shorthand for @conv2DBackpropInput' id PaddingValid ChannelLast@.
conv2DBackpropInput :: TF.OneOf '[Word16, Double, Float] t
                    => TF.Tensor v1 Int32    -- ^ input_sizes
                    -> TF.Tensor v2 t        -- ^ filter
                    -> TF.Tensor v3 t        -- ^ out_backprop
                    -> TF.Tensor TF.Build t  -- ^ output
conv2DBackpropInput = conv2DBackpropInput' id PaddingValid ChannelLast

-- | 2D convolution input backpropagation with explicit op parameters,
-- padding and data format (NHWC or NCHW).
conv2DBackpropInput' :: TF.OneOf '[Word16, Double, Float] t
                     => TF.OpParams          -- ^ extra op parameters
                     -> Padding              -- ^ padding scheme
                     -> DataFormat           -- ^ tensor layout
                     -> TF.Tensor v1 Int32    -- ^ input_sizes
                     -> TF.Tensor v2 t        -- ^ filter
                     -> TF.Tensor v3 t        -- ^ out_backprop
                     -> TF.Tensor TF.Build t  -- ^ output
conv2DBackpropInput' params padding dataformat =
    TF.conv2DBackpropInput' opParams (paddingToByteString padding)
  where
    -- The data_format setter runs first; the caller's params are composed
    -- on the outside, so they run afterwards.
    opParams = params . (TF.opAttr "data_format" .~ dataFormat2D dataformat)

-- | 3D convolution with default parameters: no extra op parameters,
-- 'PaddingValid' and 'ChannelLast' (NDHWC) layout.
--
-- Shorthand for @conv3D' id PaddingValid ChannelLast@.
conv3D :: TF.OneOf '[Word16, Double, Float] t
       => TF.Tensor v1 t       -- ^ input
       -> TF.Tensor v2 t       -- ^ filter
       -> TF.Tensor TF.Build t -- ^ output
conv3D = conv3D' id PaddingValid ChannelLast

-- | 3D convolution with explicit op parameters, padding and data format
-- (NDHWC or NCDHW).
conv3D' :: TF.OneOf '[Word16, Double, Float] t
        => TF.OpParams         -- ^ extra op parameters
        -> Padding             -- ^ padding scheme
        -> DataFormat          -- ^ tensor layout
        -> TF.Tensor v1 t       -- ^ input
        -> TF.Tensor v2 t       -- ^ filter
        -> TF.Tensor TF.Build t -- ^ output
conv3D' params padding dataformat =
    TF.conv3D' opParams (paddingToByteString padding)
  where
    -- The data_format setter runs first; the caller's params are composed
    -- on the outside, so they run afterwards.
    opParams = params . (TF.opAttr "data_format" .~ dataFormat3D dataformat)

-- | 3D convolution filter backpropagation (V1 op) with default parameters:
-- no extra op parameters, 'PaddingValid' and 'ChannelLast' (NDHWC) layout.
--
-- Shorthand for @conv3DBackpropFilter' id PaddingValid ChannelLast@.
conv3DBackpropFilter :: TF.OneOf '[Word16, Double, Float] t
                     => TF.Tensor v1 t        -- ^ input
                     -> TF.Tensor v2 t        -- ^ filter
                     -> TF.Tensor v3 t        -- ^ out_backprop
                     -> TF.Tensor TF.Build t  -- ^ output
conv3DBackpropFilter = conv3DBackpropFilter' id PaddingValid ChannelLast

-- | 3D convolution filter backpropagation (V1 op) with explicit op
-- parameters, padding and data format (NDHWC or NCDHW).
--
-- NOTE(review): TensorFlow's deprecated V1 @Conv3DBackpropFilter@ op may not
-- declare a @data_format@ attribute (it appears on the V2 op, see
-- 'conv3DBackpropFilterV2''); confirm that setting it is accepted.
conv3DBackpropFilter' :: TF.OneOf '[Word16, Double, Float] t
                      => TF.OpParams          -- ^ extra op parameters
                      -> Padding              -- ^ padding scheme
                      -> DataFormat           -- ^ tensor layout
                      -> TF.Tensor v1 t        -- ^ input
                      -> TF.Tensor v2 t        -- ^ filter
                      -> TF.Tensor v3 t        -- ^ out_backprop
                      -> TF.Tensor TF.Build t  -- ^ output
conv3DBackpropFilter' params padding dataformat =
    TF.conv3DBackpropFilter' opParams (paddingToByteString padding)
  where
    -- The data_format setter runs first; the caller's params are composed
    -- on the outside, so they run afterwards.
    opParams = params . (TF.opAttr "data_format" .~ dataFormat3D dataformat)

-- | 3D convolution filter backpropagation (V2 op, takes @filter_sizes@)
-- with default parameters: no extra op parameters, 'PaddingValid' and
-- 'ChannelLast' (NDHWC) layout.
--
-- Shorthand for @conv3DBackpropFilterV2' id PaddingValid ChannelLast@.
conv3DBackpropFilterV2 :: TF.OneOf '[Word16, Double, Float] t
                       => TF.Tensor v1 t        -- ^ input
                       -> TF.Tensor v2 Int32    -- ^ filter_sizes
                       -> TF.Tensor v3 t        -- ^ out_backprop
                       -> TF.Tensor TF.Build t  -- ^ output
conv3DBackpropFilterV2 = conv3DBackpropFilterV2' id PaddingValid ChannelLast

-- | 3D convolution filter backpropagation (V2 op, takes @filter_sizes@)
-- with explicit op parameters, padding and data format (NDHWC or NCDHW).
conv3DBackpropFilterV2' :: TF.OneOf '[Word16, Double, Float] t
                        => TF.OpParams          -- ^ extra op parameters
                        -> Padding              -- ^ padding scheme
                        -> DataFormat           -- ^ tensor layout
                        -> TF.Tensor v1 t        -- ^ input
                        -> TF.Tensor v2 Int32    -- ^ filter_sizes
                        -> TF.Tensor v3 t        -- ^ out_backprop
                        -> TF.Tensor TF.Build t  -- ^ output
conv3DBackpropFilterV2' params padding dataformat =
    TF.conv3DBackpropFilterV2' opParams (paddingToByteString padding)
  where
    -- The data_format setter runs first; the caller's params are composed
    -- on the outside, so they run afterwards.
    opParams = params . (TF.opAttr "data_format" .~ dataFormat3D dataformat)

-- | 3D convolution input backpropagation (V1 op) with default parameters:
-- no extra op parameters, 'PaddingValid' and 'ChannelLast' (NDHWC) layout.
--
-- Shorthand for @conv3DBackpropInput' id PaddingValid ChannelLast@.
conv3DBackpropInput :: TF.OneOf '[Word16, Double, Float] t
                    => TF.Tensor v1 t        -- ^ input
                    -> TF.Tensor v2 t        -- ^ filter
                    -> TF.Tensor v3 t        -- ^ out_backprop
                    -> TF.Tensor TF.Build t  -- ^ output
conv3DBackpropInput = conv3DBackpropInput' id PaddingValid ChannelLast

-- | 3D convolution input backpropagation (V1 op) with explicit op
-- parameters, padding and data format (NDHWC or NCDHW).
--
-- NOTE(review): TensorFlow's deprecated V1 @Conv3DBackpropInput@ op may not
-- declare a @data_format@ attribute (it appears on the V2 op, see
-- 'conv3DBackpropInputV2''); confirm that setting it is accepted.
conv3DBackpropInput' :: TF.OneOf '[Word16, Double, Float] t
                     => TF.OpParams          -- ^ extra op parameters
                     -> Padding              -- ^ padding scheme
                     -> DataFormat           -- ^ tensor layout
                     -> TF.Tensor v1 t        -- ^ input
                     -> TF.Tensor v2 t        -- ^ filter
                     -> TF.Tensor v3 t        -- ^ out_backprop
                     -> TF.Tensor TF.Build t  -- ^ output
conv3DBackpropInput' params padding dataformat =
    TF.conv3DBackpropInput' opParams (paddingToByteString padding)
  where
    -- The data_format setter runs first; the caller's params are composed
    -- on the outside, so they run afterwards.
    opParams = params . (TF.opAttr "data_format" .~ dataFormat3D dataformat)

-- | 3D convolution input backpropagation (V2 op, takes @input_sizes@) with
-- default parameters: no extra op parameters, 'PaddingValid' and
-- 'ChannelLast' (NDHWC) layout.
--
-- Shorthand for @conv3DBackpropInputV2' id PaddingValid ChannelLast@.
conv3DBackpropInputV2
    :: (TF.OneOf '[Word16, Double, Float] t, TF.OneOf '[Int32, Int64] tshape)
    => TF.Tensor v1 tshape   -- ^ input_sizes
    -> TF.Tensor v2 t        -- ^ filter
    -> TF.Tensor v3 t        -- ^ out_backprop
    -> TF.Tensor TF.Build t  -- ^ output
conv3DBackpropInputV2 = conv3DBackpropInputV2' id PaddingValid ChannelLast

-- | 3D convolution input backpropagation (V2 op, takes @input_sizes@) with
-- explicit op parameters, padding and data format (NDHWC or NCDHW).
conv3DBackpropInputV2'
    :: (TF.OneOf '[Word16, Double, Float] t, TF.OneOf '[Int32, Int64] tshape)
    => TF.OpParams          -- ^ extra op parameters
    -> Padding              -- ^ padding scheme
    -> DataFormat           -- ^ tensor layout
    -> TF.Tensor v1 tshape   -- ^ input_sizes
    -> TF.Tensor v2 t        -- ^ filter
    -> TF.Tensor v3 t        -- ^ out_backprop
    -> TF.Tensor TF.Build t  -- ^ output
conv3DBackpropInputV2' params padding dataformat =
    TF.conv3DBackpropInputV2' opParams (paddingToByteString padding)
  where
    -- The data_format setter runs first; the caller's params are composed
    -- on the outside, so they run afterwards.
    opParams = params . (TF.opAttr "data_format" .~ dataFormat3D dataformat)

-- | Depth-wise 2D convolution with default parameters: no extra op
-- parameters, 'PaddingValid' and 'ChannelLast' (NHWC) layout.
--
-- Shorthand for @depthwiseConv2dNative' id PaddingValid ChannelLast@.
depthwiseConv2dNative :: TF.OneOf '[Word16, Double, Float] t
                      => TF.Tensor v1 t       -- ^ input
                      -> TF.Tensor v2 t       -- ^ filter
                      -> TF.Tensor TF.Build t -- ^ output
depthwiseConv2dNative = depthwiseConv2dNative' id PaddingValid ChannelLast

-- | Depth-wise 2D convolution with explicit op parameters, padding and data
-- format (NHWC or NCHW).
depthwiseConv2dNative' :: TF.OneOf '[Word16, Double, Float] t
                       => TF.OpParams         -- ^ extra op parameters
                       -> Padding             -- ^ padding scheme
                       -> DataFormat          -- ^ tensor layout
                       -> TF.Tensor v1 t       -- ^ input
                       -> TF.Tensor v2 t       -- ^ filter
                       -> TF.Tensor TF.Build t -- ^ output
depthwiseConv2dNative' params padding dataformat =
    TF.depthwiseConv2dNative' opParams (paddingToByteString padding)
  where
    -- The data_format setter runs first; the caller's params are composed
    -- on the outside, so they run afterwards.
    opParams = params . (TF.opAttr "data_format" .~ dataFormat2D dataformat)

-- | Depth-wise 2D convolution filter backpropagation with default
-- parameters: no extra op parameters, 'PaddingValid' and 'ChannelLast'
-- (NHWC) layout.
--
-- Shorthand for
-- @depthwiseConv2dNativeBackpropFilter' id PaddingValid ChannelLast@.
depthwiseConv2dNativeBackpropFilter :: TF.OneOf '[Word16, Double, Float] t
                                    => TF.Tensor v1 t        -- ^ input
                                    -> TF.Tensor v2 Int32    -- ^ filter_sizes
                                    -> TF.Tensor v3 t        -- ^ out_backprop
                                    -> TF.Tensor TF.Build t  -- ^ output
depthwiseConv2dNativeBackpropFilter =
    depthwiseConv2dNativeBackpropFilter' id PaddingValid ChannelLast

-- | Depth-wise 2D convolution filter backpropagation with explicit op
-- parameters, padding and data format (NHWC or NCHW).
depthwiseConv2dNativeBackpropFilter' :: TF.OneOf '[Word16, Double, Float] t
                                     => TF.OpParams          -- ^ extra op parameters
                                     -> Padding              -- ^ padding scheme
                                     -> DataFormat           -- ^ tensor layout
                                     -> TF.Tensor v1 t        -- ^ input
                                     -> TF.Tensor v2 Int32    -- ^ filter_sizes
                                     -> TF.Tensor v3 t        -- ^ out_backprop
                                     -> TF.Tensor TF.Build t  -- ^ output
depthwiseConv2dNativeBackpropFilter' params padding dataformat =
    TF.depthwiseConv2dNativeBackpropFilter' opParams (paddingToByteString padding)
  where
    -- The data_format setter runs first; the caller's params are composed
    -- on the outside, so they run afterwards.
    opParams = params . (TF.opAttr "data_format" .~ dataFormat2D dataformat)

-- | Depth-wise 2D convolution input backpropagation with default
-- parameters: no extra op parameters, 'PaddingValid' and 'ChannelLast'
-- (NHWC) layout.
--
-- Shorthand for
-- @depthwiseConv2dNativeBackpropInput' id PaddingValid ChannelLast@.
depthwiseConv2dNativeBackpropInput :: TF.OneOf '[Word16, Double, Float] t
                                   => TF.Tensor v1 Int32    -- ^ input_sizes
                                   -> TF.Tensor v2 t        -- ^ filter
                                   -> TF.Tensor v3 t        -- ^ out_backprop
                                   -> TF.Tensor TF.Build t  -- ^ output
depthwiseConv2dNativeBackpropInput =
    depthwiseConv2dNativeBackpropInput' id PaddingValid ChannelLast


-- | Depth-wise 2D convolution input backpropagation with explicit op
-- parameters, padding and data format (NHWC or NCHW).
depthwiseConv2dNativeBackpropInput' :: TF.OneOf '[Word16, Double, Float] t
                                    => TF.OpParams          -- ^ extra op parameters
                                    -> Padding              -- ^ padding scheme
                                    -> DataFormat           -- ^ tensor layout
                                    -> TF.Tensor v1 Int32    -- ^ input_sizes
                                    -> TF.Tensor v2 t        -- ^ filter
                                    -> TF.Tensor v3 t        -- ^ out_backprop
                                    -> TF.Tensor TF.Build t  -- ^ output
depthwiseConv2dNativeBackpropInput' params padding dataformat =
    TF.depthwiseConv2dNativeBackpropInput' opParams (paddingToByteString padding)
  where
    -- The data_format setter runs first; the caller's params are composed
    -- on the outside, so they run afterwards.
    opParams = params . (TF.opAttr "data_format" .~ dataFormat2D dataformat)