-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}

module TensorFlow.NN
    ( sigmoidCrossEntropyWithLogits
    ) where

import Prelude hiding           ( log
                                , exp
                                )
import TensorFlow.Build         ( MonadBuild
                                , withNameScope
                                )
import TensorFlow.GenOps.Core   ( greaterEqual
                                , select
                                , log
                                , exp
                                )
import TensorFlow.Tensor        ( Tensor(..)
                                , render
                                , Value
                                )
import TensorFlow.Types         ( TensorType(..)
                                , OneOf
                                )
import TensorFlow.Ops           ( zerosLike
                                , add
                                , mul
                                , neg
                                )

-- | Computes sigmoid cross entropy given `logits`.
--
-- Measures the probability error in discrete classification tasks in which each
-- class is independent and not mutually exclusive.  For instance, one could
-- perform multilabel classification where a picture can contain both an elephant
-- and a dog at the same time.
--
-- For brevity, let `x = logits`, `z = targets`.  The logistic loss is
--
--        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
--      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
--      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
--      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
--      = (1 - z) * x + log(1 + exp(-x))
--      = x - x * z + log(1 + exp(-x))
--
--  For x < 0, to avoid overflow in exp(-x), we reformulate the above
--
--        x - x * z + log(1 + exp(-x))
--      = log(exp(x)) - x * z + log(1 + exp(-x))
--      = - x * z + log(1 + exp(x))
--
--  Hence, to ensure stability and avoid overflow, the implementation uses this
--  equivalent formulation
--
--      max(x, 0) - x * z + log(1 + exp(-abs(x)))
--
--  `logits` and `targets` must have the same type and shape.
sigmoidCrossEntropyWithLogits
  :: (MonadBuild m, OneOf '[Float, Double] a, TensorType a, Num a)
     => Tensor Value a          -- ^ __logits__
     -> Tensor Value a          -- ^ __targets__
     -> m (Tensor Value a)
sigmoidCrossEntropyWithLogits logits targets = do
    let -- A tensor of zeros with the same shape as `logits`.
        zeros = zerosLike logits
        -- Elementwise mask: True where x >= 0. Used to branch between the
        -- two numerically stable formulations described above.
        cond = logits `greaterEqual` zeros
        -- max(x, 0): keep x where x >= 0, otherwise 0.
        relu_logits = select cond logits zeros
        -- -abs(x): -x where x >= 0, otherwise x. Feeding -abs(x) to exp
        -- avoids overflow for large-magnitude logits.
        neg_abs_logits = select cond (neg logits) logits
    withNameScope "logistic_loss" $ do
        -- max(x, 0) - x * z
        left <- render $ relu_logits - logits `mul` targets
        -- log(1 + exp(-abs(x)))
        right <- render $ log (1 + exp neg_abs_logits)
        -- max(x, 0) - x * z + log(1 + exp(-abs(x)))
        withNameScope "sigmoid_add" $ render $ left `add` right