[arrayfire] 136/408: BUGFIX/TEST: Fixing bug in rank. Added appropriate tests

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:11:39 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.

commit 8253205dea182f5674d01872fda5d737a614bfe8
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date:   Mon Jul 13 13:29:15 2015 -0400

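    BUGFIX/TEST: Fix rank() to compare |R| (in the real base type) against the tolerance, so negative and complex entries of R are counted. Added appropriate tests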
---
 src/api/c/rank.cpp | 14 +++++----
 test/rank.cpp      | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 97 insertions(+), 5 deletions(-)

diff --git a/src/api/c/rank.cpp b/src/api/c/rank.cpp
index c1eeb73..197d2c8 100644
--- a/src/api/c/rank.cpp
+++ b/src/api/c/rank.cpp
@@ -17,6 +17,7 @@
 #include <qr.hpp>
 #include <reduce.hpp>
 #include <logic.hpp>
+#include <complex.hpp>
 
 using af::dim4;
 using namespace detail;
@@ -24,21 +25,24 @@ using namespace detail;
 template<typename T>
 static inline uint rank(const af_array in, double tol)
 {
+    typedef typename af::dtype_traits<T>::base_type BT;  // real scalar type underlying T (presumably T itself for float/double, the component type for complex) -- TODO confirm
     Array<T> In = getArray<T>(in);

-    Array<T> r = createEmptyArray<T>(dim4());
+    Array<BT> R = createEmptyArray<BT>(dim4());  // receives |r| below; real-valued so it can be compared with tol

-    // Scoping to get rid of q and t as they are not necessary
+    // Scope q, r and t so they are destroyed once |r| has been stored in R
     {
         Array<T> q = createEmptyArray<T>(dim4());
+        Array<T> r = createEmptyArray<T>(dim4());  // R factor from qr(); has type T (complex for complex inputs)
         Array<T> t = createEmptyArray<T>(dim4());
         qr(q, r, t, In);
+        // Threshold magnitudes in the real type BT, so negative entries count and complex T compares by modulus.
+        R = abs<BT, T>(r);
     }

-    Array<T> val = createValueArray<T>(r.dims(), scalar<T>(tol));
-    Array<char> gt = logicOp<T, af_gt_t>(r, val, val.dims());
+    Array<BT> val = createValueArray<BT>(R.dims(), scalar<BT>(tol));  // tol replicated to the shape of R
+    Array<char> gt = logicOp<BT, af_gt_t>(R, val, val.dims());  // 1 where |r| > tol; rows holding any such entry are counted below
     Array<char> at = reduce<af_or_t, char, char>(gt, 1);
-
     return reduce_all<af_notzero_t, char, uint>(at);
 }
 
diff --git a/test/rank.cpp b/test/rank.cpp
new file mode 100644
index 0000000..3ecf497
--- /dev/null
+++ b/test/rank.cpp
@@ -0,0 +1,88 @@
+/*******************************************************
+ * Copyright (c) 2014, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+
+#include <gtest/gtest.h>
+#include <arrayfire.h>
+#include <af/dim4.hpp>
+#include <af/defines.h>
+#include <af/traits.hpp>
+#include <vector>
+#include <iostream>
+#include <complex>
+#include <string>
+#include <testHelpers.hpp>
+
+using std::vector;
+using std::string;
+using std::cout;
+using std::endl;
+using af::cfloat;
+using af::cdouble;
+
+template<typename T>  // T: element type under test, drawn from TestTypes
+class Rank : public ::testing::Test  // empty fixture; the typed tests only use TypeParam
+{
+};
+
+typedef ::testing::Types<float, double, af::cfloat, af::cdouble> TestTypes;  // real and complex element types
+TYPED_TEST_CASE(Rank, TestTypes);  // instantiate every Rank test once per type in TestTypes
+
+template<typename T>
+void rankSmall()  // rank of a small, hand-written full-rank matrix
+{
+    if (noDoubleTests<T>()) return;  // presumably skips double types on devices without double support -- TODO confirm
+    // Non-singular 3x3 input (det = -33), so rank() must report 3.
+    T ha[] = {1, 4, 7, 2, 5, 8, 3, 6, 20};
+    af::array a(3, 3, ha);
+
+    ASSERT_EQ(3, (int)af::rank(a));
+}
+
+template<typename T>
+void rankBig(const int num)  // rank of random square, tall and wide matrices
+{
+    if (noDoubleTests<T>()) return;
+    af::dtype dt = (af::dtype)af::dtype_traits<T>::af_type;  // af::dtype matching T
+    af::array a = af::randu(num, num, dt);
+    ASSERT_EQ(num, (int)af::rank(a));
+    // Random matrices are almost surely full rank: num x num/2 (and its transpose) has rank num/2.
+    af::array b = af::randu(num, num/2, dt);
+    ASSERT_EQ(num/2, (int)af::rank(b));
+    ASSERT_EQ(num/2, (int)af::rank(transpose(b)));
+}
+
+template<typename T>
+void rankLow(const int num)  // rank of a deliberately rank-deficient matrix
+{
+    if (noDoubleTests<T>()) return;
+    af::dtype dt = (af::dtype)af::dtype_traits<T>::af_type;
+    // in = [a, b, a + 0.2*b] joined along dimension 1: 3*num columns, only 2*num independent.
+    af::array a = af::randu(3 * num, num, dt);
+    af::array b = af::randu(3 * num, num, dt);
+    af::array c = a + 0.2 * b;
+    af::array in = join(1, a, b, c);
+
+    // The last third (c) is a linear combination of the first two thirds, so the rank is 2 * num
+    ASSERT_EQ(2 * num, (int)af::rank(in));
+}
+
+TYPED_TEST(Rank, small)
+{
+    rankSmall<TypeParam>();
+}
+
+TYPED_TEST(Rank, big)
+{
+    rankBig<TypeParam>(1024);
+}
+// Rank-deficient case: must call rankLow, since rankBig only builds full-rank inputs.
+TYPED_TEST(Rank, low)
+{
+    rankLow<TypeParam>(512);
+}

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list