[arrayfire] 153/408: BUGFIX: Fixed issues with mixed real and complex types

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:11:42 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.

commit 5df0554651c6b70b6fa831f1e315dd2822fd3507
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date:   Wed Jul 15 18:30:49 2015 -0400

    BUGFIX: Fixed issues with mixed real and complex types
    
    - Making sure complex takes precedence over floats when doing binary operations
---
 src/api/cpp/array.cpp   | 25 +++++++++++++++++++++----
 src/api/cpp/complex.cpp |  4 ++--
 test/binary.cpp         | 43 ++++++++++++++++++++++++++++---------------
 3 files changed, 51 insertions(+), 21 deletions(-)

diff --git a/src/api/cpp/array.cpp b/src/api/cpp/array.cpp
index ba46f63..15fdfd2 100644
--- a/src/api/cpp/array.cpp
+++ b/src/api/cpp/array.cpp
@@ -842,12 +842,30 @@ namespace af
 #undef ASSIGN_OP
 #undef ASSIGN_TYPE
 
+af::dtype implicit_dtype(af::dtype scalar_type, af::dtype array_type)
+{   // Picks the dtype of the constant array that a scalar operand is expanded into
+    // Same type on both sides: nothing to promote
+    if (scalar_type == array_type) return scalar_type;
+
+    // Complex scalar: result is complex; c64 only when the array is f64 or c64
+    if (scalar_type == c32 || scalar_type == c64) {
+        if (array_type == f64 || array_type == c64) return c64;
+        return c32; // NOTE(review): a c64 scalar on a non-64-bit array is narrowed to c32 - confirm intended
+    }
+
+    // Floating-point array (real or complex, 32- or 64-bit): keep the array's type
+    if (array_type == f64 || array_type == c64 ||
+        array_type == f32 || array_type == c32 ) return array_type;
+
+    // Any other array type: use the scalar's type; further promotion is left to the C API call
+    return scalar_type;
+}
+
 #define BINARY_TYPE(TY, OP, func, dty)                          \
     array operator OP(const array& plhs, const TY &value)       \
     {                                                           \
         af_array out;                                           \
-        af::dtype ty = plhs.type();                             \
-        af::dtype cty = plhs.isfloating() ? ty : dty;           \
+        af::dtype cty = implicit_dtype(dty, plhs.type());       \
         array cst = constant(value, plhs.dims(), cty);          \
         AF_THROW(func(&out, plhs.get(), cst.get(), gforGet())); \
         return array(out);                                      \
@@ -856,8 +874,7 @@ namespace af
     {                                                           \
         const af_array rhs = other.get();                       \
         af_array out;                                           \
-        af::dtype ty = other.type();                            \
-        af::dtype cty = other.isfloating() ? ty : dty;          \
+        af::dtype cty = implicit_dtype(dty, other.type());      \
         array cst = constant(value, other.dims(), cty);         \
         AF_THROW(func(&out, cst.get(), rhs, gforGet()));        \
         return array(out);                                      \
diff --git a/src/api/cpp/complex.cpp b/src/api/cpp/complex.cpp
index 7ec2c98..4e9b5d4 100644
--- a/src/api/cpp/complex.cpp
+++ b/src/api/cpp/complex.cpp
@@ -83,9 +83,9 @@ cdouble operator/(const cdouble &lhs, const cdouble &rhs)
     { return lhs OP cfloat (rhs); }                             \
     cdouble operator OP(const cdouble &lhs, const double  &rhs) \
     { return lhs OP cdouble(rhs); }                             \
-    cfloat  operator OP(const double  &rhs, const cfloat  &lhs) \
+    cfloat  operator OP(const double  &lhs, const cfloat  &rhs) \
     { return cfloat (lhs) OP rhs; }                             \
-    cdouble operator OP(const double  &rhs, const cdouble &lhs) \
+    cdouble operator OP(const double  &lhs, const cdouble &rhs) \
     { return cdouble(lhs) OP rhs; }                             \
     cdouble operator OP(const cfloat  &lhs, const cdouble &rhs) \
     { return cdouble(real(lhs), imag(lhs)) OP rhs; }            \
diff --git a/test/binary.cpp b/test/binary.cpp
index 73dbb3f..4777487 100644
--- a/test/binary.cpp
+++ b/test/binary.cpp
@@ -70,7 +70,7 @@ af::array randgen(const int num, af::dtype ty)
         Tb h_b = 3.0;                                                   \
         af::array c = func(a, h_b);                                     \
         Ta *h_a = a.host<Ta>();                                         \
-        Tc *h_c = c.host<Tc>();                                         \
+        Ta *h_c = c.host<Ta>();                                         \
         for (int i = 0; i < num; i++)                                   \
             ASSERT_EQ(h_c[i], func(h_a[i], h_b)) <<                     \
                 "for values: " << h_a[i]  << "," << h_b << std::endl;   \
@@ -97,8 +97,8 @@ af::array randgen(const int num, af::dtype ty)
     }                                                                   \
 
 
-#define BINARY_TESTS_NEAR(Ta, Tb, Tc, func, err)                        \
-    TEST(BinaryTests, Test_##func##_##Ta##_##Tb)                        \
+#define BINARY_TESTS_NEAR_GENERAL(Ta, Tb, Tc, Td, Te,func, err)         \
+    TEST(BinaryTestsFloating, Test_##func##_##Ta##_##Tb)                \
     {                                                                   \
         if (noDoubleTests<Ta>()) return;                                \
         if (noDoubleTests<Tb>()) return;                                \
@@ -120,7 +120,7 @@ af::array randgen(const int num, af::dtype ty)
         delete[] h_c;                                                   \
     }                                                                   \
                                                                         \
-    TEST(BinaryTests, Test_##func##_##Ta##_##Tb##_left)                 \
+    TEST(BinaryTestsFloating, Test_##func##_##Ta##_##Tb##_left)         \
     {                                                                   \
         if (noDoubleTests<Ta>()) return;                                \
         if (noDoubleTests<Tb>()) return;                                \
@@ -130,15 +130,15 @@ af::array randgen(const int num, af::dtype ty)
         Tb h_b = 0.3;                                                   \
         af::array c = func(a, h_b);                                     \
         Ta *h_a = a.host<Ta>();                                         \
-        Ta *h_c = c.host<Ta>();                                         \
+        Td *h_d = c.host<Td>();                                         \
         for (int i = 0; i < num; i++)                                   \
-            MY_ASSERT_NEAR(h_c[i], func(h_a[i], h_b), err) <<           \
+            MY_ASSERT_NEAR(h_d[i], func(h_a[i], h_b), err) <<           \
                 "for values: " << h_a[i]  << "," << h_b << std::endl;   \
         delete[] h_a;                                                   \
-        delete[] h_c;                                                   \
+        delete[] h_d;                                                   \
     }                                                                   \
                                                                         \
-    TEST(BinaryTests, Test_##func##_##Ta##_##Tb##_right)                \
+    TEST(BinaryTestsFloating, Test_##func##_##Ta##_##Tb##_right)        \
     {                                                                   \
         if (noDoubleTests<Ta>()) return;                                \
         if (noDoubleTests<Tb>()) return;                                \
@@ -149,14 +149,16 @@ af::array randgen(const int num, af::dtype ty)
         af::array b = randgen(num, tb);                                 \
         af::array c = func(h_a, b);                                     \
         Tb *h_b = b.host<Tb>();                                         \
-        Tb *h_c = c.host<Tb>();                                         \
+        Te *h_e = c.host<Te>();                                         \
         for (int i = 0; i < num; i++)                                   \
-            MY_ASSERT_NEAR(h_c[i], func(h_a, h_b[i]), err) <<           \
+            MY_ASSERT_NEAR(h_e[i], func(h_a, h_b[i]), err) <<           \
                 "for values: " << h_a  << "," << h_b[i] << std::endl;   \
         delete[] h_b;                                                   \
-        delete[] h_c;                                                   \
+        delete[] h_e;                                                   \
     }                                                                   \
 
+#define BINARY_TESTS_NEAR(Ta, Tb, Tc, func, err) BINARY_TESTS_NEAR_GENERAL(Ta, Tb, Tc, Ta, Tc, func, err) // _left result read as Ta, _right result read as Tc
+
 #define BINARY_TESTS_FLOAT(func) BINARY_TESTS(float, float, float, func)
 #define BINARY_TESTS_DOUBLE(func) BINARY_TESTS(double, double, double, func)
 #define BINARY_TESTS_CFLOAT(func) BINARY_TESTS(cfloat, cfloat, cfloat, func)
@@ -226,10 +228,21 @@ BINARY_TESTS_NEAR(cfloat, cdouble, cdouble, sub, 1e-5)
 BINARY_TESTS_NEAR(cfloat, cdouble, cdouble, mul, 1e-5)
 BINARY_TESTS_NEAR(cfloat, cdouble, cdouble, div, 1e-5)
 
-BINARY_TESTS_NEAR(cfloat, double, cdouble, add, 1e-5)
-BINARY_TESTS_NEAR(cfloat, double, cdouble, sub, 1e-5)
-BINARY_TESTS_NEAR(cfloat, double, cdouble, mul, 1e-5)
-BINARY_TESTS_NEAR(cfloat, double, cdouble, div, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(float, cfloat, cfloat, cfloat, cfloat, add, 1e-5) // _left: f32 array, cfloat scalar -> cfloat; _right: float scalar, c32 array -> cfloat
+BINARY_TESTS_NEAR_GENERAL(float, cfloat, cfloat, cfloat, cfloat, sub, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(float, cfloat, cfloat, cfloat, cfloat, mul, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(float, cfloat, cfloat, cfloat, cfloat, div, 1e-5)
+
+BINARY_TESTS_NEAR_GENERAL(double, cfloat, cdouble, cdouble, cfloat, add, 1e-5) // _left: f64 array, cfloat scalar -> cdouble; _right: double scalar, c32 array -> cfloat
+BINARY_TESTS_NEAR_GENERAL(double, cfloat, cdouble, cdouble, cfloat, sub, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(double, cfloat, cdouble, cdouble, cfloat, mul, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(double, cfloat, cdouble, cdouble, cfloat, div, 1e-5)
+
+BINARY_TESTS_NEAR_GENERAL(cfloat, double, cdouble, cfloat, cdouble, add, 1e-5) // _left: c32 array, double scalar -> cfloat; _right: cfloat scalar, f64 array -> cdouble
+BINARY_TESTS_NEAR_GENERAL(cfloat, double, cdouble, cfloat, cdouble, sub, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(cfloat, double, cdouble, cfloat, cdouble, mul, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(cfloat, double, cdouble, cfloat, cdouble, div, 1e-5)
+
 
 #define BITOP(func, T, op)                                  \
     TEST(BinaryTests, Test_##func##_##T)                    \

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list