[arrayfire] 136/408: BUGFIX/TEST: Fixing bug in rank. Added appropriate tests
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:11:39 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.
commit 8253205dea182f5674d01872fda5d737a614bfe8
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date: Mon Jul 13 13:29:15 2015 -0400
BUGFIX/TEST: Fixing bug in rank. Added appropriate tests
---
src/api/c/rank.cpp | 14 +++++----
test/rank.cpp | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 97 insertions(+), 5 deletions(-)
diff --git a/src/api/c/rank.cpp b/src/api/c/rank.cpp
index c1eeb73..197d2c8 100644
--- a/src/api/c/rank.cpp
+++ b/src/api/c/rank.cpp
@@ -17,6 +17,7 @@
#include <qr.hpp>
#include <reduce.hpp>
#include <logic.hpp>
+#include <complex.hpp>
using af::dim4;
using namespace detail;
@@ -24,21 +25,24 @@ using namespace detail;
template<typename T>
static inline uint rank(const af_array in, double tol)
{
+ typedef typename af::dtype_traits<T>::base_type BT; // element type of |r| (see abs<BT, T> below); presumably the real base type when T is complex
Array<T> In = getArray<T>(in);
- Array<T> r = createEmptyArray<T>(dim4());
+ Array<BT> R = createEmptyArray<BT>(dim4()); // R = |r| from QR(In); the rank is counted from entries of R above tol (OR-reduced over dim 1, then nonzeros counted)
- // Scoping to get rid of q and t as they are not necessary
+ // Scope q, r and t so they go out of scope once |r| has been stored in R
{
Array<T> q = createEmptyArray<T>(dim4());
+ Array<T> r = createEmptyArray<T>(dim4());
Array<T> t = createEmptyArray<T>(dim4());
qr(q, r, t, In);
+
+ R = abs<BT, T>(r); // keep magnitudes so the threshold below also works for complex T
}
- Array<T> val = createValueArray<T>(r.dims(), scalar<T>(tol));
- Array<char> gt = logicOp<T, af_gt_t>(r, val, val.dims());
+ Array<BT> val = createValueArray<BT>(R.dims(), scalar<BT>(tol)); // tol replicated to the shape of R
+ Array<char> gt = logicOp<BT, af_gt_t>(R, val, val.dims()); // nonzero where R > tol
Array<char> at = reduce<af_or_t, char, char>(gt, 1);
-
return reduce_all<af_notzero_t, char, uint>(at);
}
diff --git a/test/rank.cpp b/test/rank.cpp
new file mode 100644
index 0000000..3ecf497
--- /dev/null
+++ b/test/rank.cpp
@@ -0,0 +1,88 @@
+/*******************************************************
+ * Copyright (c) 2014, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+
+#include <gtest/gtest.h>
+#include <arrayfire.h>
+#include <af/dim4.hpp>
+#include <af/defines.h>
+#include <af/traits.hpp>
+#include <vector>
+#include <iostream>
+#include <complex>
+#include <string>
+#include <testHelpers.hpp>
+
+using std::vector;
+using std::string;
+using std::cout;
+using std::endl;
+using af::cfloat;
+using af::cdouble;
+
+// Typed fixture: every TYPED_TEST(Rank, ...) runs once per type in TestTypes.
+template<typename T>
+struct Rank : ::testing::Test {};
+
+// Real and complex element types, in single and double precision.
+typedef ::testing::Types<float, double, af::cfloat, af::cdouble> TestTypes;
+TYPED_TEST_CASE(Rank, TestTypes);
+
+// Hand-picked 3x3 input with determinant -33, so af::rank must report 3.
+template<typename T>
+void rankSmall() {
+    if (noDoubleTests<T>()) { return; }
+    const T values[] = {1, 4, 7,
+                        2, 5, 8,
+                        3, 6, 20};
+    const af::array mat(3, 3, values);
+    ASSERT_EQ(3, static_cast<int>(af::rank(mat)));
+}
+
+// Random square (num x num) and rectangular (num x num/2) inputs must be full rank.
+template<typename T>
+void rankBig(const int num) {
+    if (noDoubleTests<T>()) { return; }
+    const af::dtype type = static_cast<af::dtype>(af::dtype_traits<T>::af_type);
+    const af::array square = af::randu(num, num, type);
+    ASSERT_EQ(num, static_cast<int>(af::rank(square)));
+    const int half = num / 2;
+    const af::array tall = af::randu(num, half, type);
+    ASSERT_EQ(half, static_cast<int>(af::rank(tall)));
+    ASSERT_EQ(half, static_cast<int>(af::rank(af::transpose(tall))));
+}
+
+// Builds [A | B | A + 0.2*B] from random (3*num x num) blocks; expected rank is 2*num.
+template<typename T>
+void rankLow(const int num) {
+    if (noDoubleTests<T>()) { return; }
+    const af::dtype type = static_cast<af::dtype>(af::dtype_traits<T>::af_type);
+    const int rows = 3 * num;
+    const af::array first  = af::randu(rows, num, type);
+    const af::array second = af::randu(rows, num, type);
+    // The third block is a linear combination of the first two, so it adds no rank.
+    const af::array third  = first + 0.2 * second;
+    const af::array in = af::join(1, first, second, third);
+    const int expected = 2 * num;
+    ASSERT_EQ(expected, static_cast<int>(af::rank(in)));
+}
+
+// Fixed, hand-picked 3x3 full-rank input (see rankSmall).
+TYPED_TEST(Rank, small) {
+    rankSmall<TypeParam>();
+}
+
+// Random 1024x1024 and 1024x512 inputs expected to be full rank (see rankBig).
+TYPED_TEST(Rank, big) {
+    rankBig<TypeParam>(1024);
+}
+
+// Rank-deficient input: rankLow builds a 1536x1536 matrix whose expected rank is 1024.
+TYPED_TEST(Rank, low) {
+    rankLow<TypeParam>(512);
+}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git
More information about the debian-science-commits
mailing list