[arrayfire] 153/408: BUGFIX: Fixed issues with mixed real and complex types
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:11:42 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.
commit 5df0554651c6b70b6fa831f1e315dd2822fd3507
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date: Wed Jul 15 18:30:49 2015 -0400
BUGFIX: Fixed issues with mixed real and complex types
- Making sure complex takes precedence over floats when doing binary operations
---
src/api/cpp/array.cpp | 25 +++++++++++++++++++++----
src/api/cpp/complex.cpp | 4 ++--
test/binary.cpp | 43 ++++++++++++++++++++++++++++---------------
3 files changed, 51 insertions(+), 21 deletions(-)
diff --git a/src/api/cpp/array.cpp b/src/api/cpp/array.cpp
index ba46f63..15fdfd2 100644
--- a/src/api/cpp/array.cpp
+++ b/src/api/cpp/array.cpp
@@ -842,12 +842,30 @@ namespace af
#undef ASSIGN_OP
#undef ASSIGN_TYPE
+// Dtype to give a scalar operand when it is combined with an array of array_type.
+static af::dtype implicit_dtype(af::dtype scalar_type, af::dtype array_type)
+{
+    // If same, do not do anything
+    if (scalar_type == array_type) return scalar_type;
+
+    // Complex scalar: stay complex; use c64 only if the array is f64 or c64
+    if (scalar_type == c32 || scalar_type == c64) {
+        if (array_type == f64 || array_type == c64) return c64;
+        return c32;
+    }
+
+    // Floating point array (real or complex): keep the array's type and precision
+    if (array_type == f64 || array_type == c64 ||
+        array_type == f32 || array_type == c32 ) return array_type;
+    // Non-floating arrays: keep the scalar's type and punt to the C api
+    return scalar_type;
+}
+
#define BINARY_TYPE(TY, OP, func, dty) \
array operator OP(const array& plhs, const TY &value) \
{ \
af_array out; \
- af::dtype ty = plhs.type(); \
- af::dtype cty = plhs.isfloating() ? ty : dty; \
+ af::dtype cty = implicit_dtype(dty, plhs.type()); \
array cst = constant(value, plhs.dims(), cty); \
AF_THROW(func(&out, plhs.get(), cst.get(), gforGet())); \
return array(out); \
@@ -856,8 +874,7 @@ namespace af
{ \
const af_array rhs = other.get(); \
af_array out; \
- af::dtype ty = other.type(); \
- af::dtype cty = other.isfloating() ? ty : dty; \
+ af::dtype cty = implicit_dtype(dty, other.type()); \
array cst = constant(value, other.dims(), cty); \
AF_THROW(func(&out, cst.get(), rhs, gforGet())); \
return array(out); \
diff --git a/src/api/cpp/complex.cpp b/src/api/cpp/complex.cpp
index 7ec2c98..4e9b5d4 100644
--- a/src/api/cpp/complex.cpp
+++ b/src/api/cpp/complex.cpp
@@ -83,9 +83,9 @@ cdouble operator/(const cdouble &lhs, const cdouble &rhs)
{ return lhs OP cfloat (rhs); } \
cdouble operator OP(const cdouble &lhs, const double &rhs) \
{ return lhs OP cdouble(rhs); } \
- cfloat operator OP(const double &rhs, const cfloat &lhs) \
+ cfloat operator OP(const double &lhs, const cfloat &rhs) \
{ return cfloat (lhs) OP rhs; } \
- cdouble operator OP(const double &rhs, const cdouble &lhs) \
+ cdouble operator OP(const double &lhs, const cdouble &rhs) \
{ return cdouble(lhs) OP rhs; } \
cdouble operator OP(const cfloat &lhs, const cdouble &rhs) \
{ return cdouble(real(lhs), imag(lhs)) OP rhs; } \
diff --git a/test/binary.cpp b/test/binary.cpp
index 73dbb3f..4777487 100644
--- a/test/binary.cpp
+++ b/test/binary.cpp
@@ -70,7 +70,7 @@ af::array randgen(const int num, af::dtype ty)
Tb h_b = 3.0; \
af::array c = func(a, h_b); \
Ta *h_a = a.host<Ta>(); \
- Tc *h_c = c.host<Tc>(); \
+ Ta *h_c = c.host<Ta>(); \
for (int i = 0; i < num; i++) \
ASSERT_EQ(h_c[i], func(h_a[i], h_b)) << \
"for values: " << h_a[i] << "," << h_b << std::endl; \
@@ -97,8 +97,8 @@ af::array randgen(const int num, af::dtype ty)
} \
-#define BINARY_TESTS_NEAR(Ta, Tb, Tc, func, err) \
- TEST(BinaryTests, Test_##func##_##Ta##_##Tb) \
+#define BINARY_TESTS_NEAR_GENERAL(Ta, Tb, Tc, Td, Te,func, err) \
+ TEST(BinaryTestsFloating, Test_##func##_##Ta##_##Tb) \
{ \
if (noDoubleTests<Ta>()) return; \
if (noDoubleTests<Tb>()) return; \
@@ -120,7 +120,7 @@ af::array randgen(const int num, af::dtype ty)
delete[] h_c; \
} \
\
- TEST(BinaryTests, Test_##func##_##Ta##_##Tb##_left) \
+ TEST(BinaryTestsFloating, Test_##func##_##Ta##_##Tb##_left) \
{ \
if (noDoubleTests<Ta>()) return; \
if (noDoubleTests<Tb>()) return; \
@@ -130,15 +130,15 @@ af::array randgen(const int num, af::dtype ty)
Tb h_b = 0.3; \
af::array c = func(a, h_b); \
Ta *h_a = a.host<Ta>(); \
- Ta *h_c = c.host<Ta>(); \
+ Td *h_d = c.host<Td>(); \
for (int i = 0; i < num; i++) \
- MY_ASSERT_NEAR(h_c[i], func(h_a[i], h_b), err) << \
+ MY_ASSERT_NEAR(h_d[i], func(h_a[i], h_b), err) << \
"for values: " << h_a[i] << "," << h_b << std::endl; \
delete[] h_a; \
- delete[] h_c; \
+ delete[] h_d; \
} \
\
- TEST(BinaryTests, Test_##func##_##Ta##_##Tb##_right) \
+ TEST(BinaryTestsFloating, Test_##func##_##Ta##_##Tb##_right) \
{ \
if (noDoubleTests<Ta>()) return; \
if (noDoubleTests<Tb>()) return; \
@@ -149,14 +149,16 @@ af::array randgen(const int num, af::dtype ty)
af::array b = randgen(num, tb); \
af::array c = func(h_a, b); \
Tb *h_b = b.host<Tb>(); \
- Tb *h_c = c.host<Tb>(); \
+ Te *h_e = c.host<Te>(); \
for (int i = 0; i < num; i++) \
- MY_ASSERT_NEAR(h_c[i], func(h_a, h_b[i]), err) << \
+ MY_ASSERT_NEAR(h_e[i], func(h_a, h_b[i]), err) << \
"for values: " << h_a << "," << h_b[i] << std::endl; \
delete[] h_b; \
- delete[] h_c; \
+ delete[] h_e; \
} \
+#define BINARY_TESTS_NEAR(Ta, Tb, Tc, func, err) BINARY_TESTS_NEAR_GENERAL(Ta, Tb, Tc, Ta, Tc, func, err)
+
#define BINARY_TESTS_FLOAT(func) BINARY_TESTS(float, float, float, func)
#define BINARY_TESTS_DOUBLE(func) BINARY_TESTS(double, double, double, func)
#define BINARY_TESTS_CFLOAT(func) BINARY_TESTS(cfloat, cfloat, cfloat, func)
@@ -226,10 +228,21 @@ BINARY_TESTS_NEAR(cfloat, cdouble, cdouble, sub, 1e-5)
BINARY_TESTS_NEAR(cfloat, cdouble, cdouble, mul, 1e-5)
BINARY_TESTS_NEAR(cfloat, cdouble, cdouble, div, 1e-5)
-BINARY_TESTS_NEAR(cfloat, double, cdouble, add, 1e-5)
-BINARY_TESTS_NEAR(cfloat, double, cdouble, sub, 1e-5)
-BINARY_TESTS_NEAR(cfloat, double, cdouble, mul, 1e-5)
-BINARY_TESTS_NEAR(cfloat, double, cdouble, div, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(float, cfloat, cfloat, cfloat, cfloat, add, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(float, cfloat, cfloat, cfloat, cfloat, sub, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(float, cfloat, cfloat, cfloat, cfloat, mul, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(float, cfloat, cfloat, cfloat, cfloat, div, 1e-5)
+
+BINARY_TESTS_NEAR_GENERAL(double, cfloat, cdouble, cdouble, cfloat, add, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(double, cfloat, cdouble, cdouble, cfloat, sub, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(double, cfloat, cdouble, cdouble, cfloat, mul, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(double, cfloat, cdouble, cdouble, cfloat, div, 1e-5)
+
+BINARY_TESTS_NEAR_GENERAL(cfloat, double, cdouble, cfloat, cdouble, add, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(cfloat, double, cdouble, cfloat, cdouble, sub, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(cfloat, double, cdouble, cfloat, cdouble, mul, 1e-5)
+BINARY_TESTS_NEAR_GENERAL(cfloat, double, cdouble, cfloat, cdouble, div, 1e-5)
+
#define BITOP(func, T, op) \
TEST(BinaryTests, Test_##func##_##T) \
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git
More information about the debian-science-commits
mailing list