-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE OverloadedStrings #-}

module TensorFlow.Test
    ( assertAllClose
    ) where

import qualified Data.Vector as V
import Test.HUnit ((@?))
import Test.HUnit.Lang (Assertion)
-- | Asserts that two vectors have the same length and are element-by-element
-- equal within a fixed absolute tolerance of 0.001. If the check fails, the
-- assertion message shows both vectors (and, for a tolerance failure, the
-- tolerance that was exceeded).
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
assertAllClose xs ys = do
    -- 'V.zipWith' stops at the end of the shorter vector, so a length
    -- mismatch must be rejected explicitly or it would pass unnoticed.
    V.length xs == V.length ys @?
        "Length mismatch: \nxs: " ++ show xs ++ "\nys: " ++ show ys
    V.all (<= tol) (V.zipWith absDiff xs ys) @?
        "Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
            ++ "\ntolerance: " ++ show tol
  where
    -- Absolute difference between two corresponding elements.
    absDiff x y = abs (x - y)
    -- Maximum allowed absolute difference per element.
    tol = 0.001 :: Float