[arrayfire] 342/408: TEST: for SVD
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:12:26 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.
commit b3c5f0fa2f55a1cb2bc04524b942578754406a74
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date: Tue Aug 25 06:44:34 2015 -0400
TEST: for SVD
---
test/svd_dense.cpp | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 93 insertions(+)
diff --git a/test/svd_dense.cpp b/test/svd_dense.cpp
new file mode 100644
index 0000000..7ba8e23
--- /dev/null
+++ b/test/svd_dense.cpp
@@ -0,0 +1,93 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+
+#include <gtest/gtest.h>
+#include <arrayfire.h>
+#include <af/dim4.hpp>
+#include <af/defines.h>
+#include <af/traits.hpp>
+#include <algorithm>
+#include <vector>
+#include <iostream>
+#include <complex>
+#include <string>
+#include <testHelpers.hpp>
+
+using std::vector;
+using std::string;
+using std::cout;
+using std::endl;
+using af::cfloat;
+using af::cdouble;
+
+// Typed test fixture: each TYPED_TEST(svd, ...) in this file is instantiated
+// once for every type listed in TestTypes. The fixture holds no state.
+template<typename T>
+struct svd : ::testing::Test
+{
+};
+
+// Element types exercised by the SVD tests (currently single precision only).
+typedef ::testing::Types<float> TestTypes;
+TYPED_TEST_CASE(svd, TestTypes);
+
+// Converts a host-side element to a double so it can be fed to ASSERT_NEAR.
+// Real element types are converted implicitly.
+template<typename T>
+double get_val(T val)
+{
+    const double converted = val;
+    return converted;
+}
+
+// Complex elements are reduced to their magnitude via abs().
+// NOTE(review): magnitude-only comparison ignores phase, so two complex values
+// with equal |z| but different phase compare equal. Consider comparing real
+// and imaginary parts separately if complex types are added to TestTypes.
+template<> double get_val<cfloat>(cfloat val)
+{
+    const double magnitude = abs(val);
+    return magnitude;
+}
+
+// Double-precision complex counterpart of the cfloat overload above.
+template<> double get_val<cdouble>(cdouble val)
+{
+    const double magnitude = abs(val);
+    return magnitude;
+}
+
+/**
+ * Decomposes a random M x N matrix with af::svd and checks that the thin
+ * reconstruction U(:, 0:MN-1) * diag(S) * Vt(0:MN-1, :), with MN = min(M, N),
+ * reproduces the input element-wise within an absolute tolerance of 1e-3.
+ *
+ * @tparam T  host element type; mapped to an af::dtype via af::dtype_traits.
+ * @param  M  number of rows of the test matrix.
+ * @param  N  number of columns of the test matrix.
+ */
+template<typename T>
+void svdTest(const int M, const int N)
+{
+    const af::dtype ty = static_cast<af::dtype>(af::dtype_traits<T>::af_type);
+
+    af::array A = af::randu(M, N, ty);
+    af::array U, S, Vt;
+    af::svd(U, S, Vt, A);
+
+    const int MN = std::min(M, N);
+
+    // Keep only the leading MN columns of U and the leading MN rows of Vt.
+    // diag(S, 0, false) builds a square matrix with S on its main diagonal.
+    // It is cast back to ty because S may not share A's type (presumably
+    // real-valued for complex inputs) -- TODO confirm.
+    af::array UU = U(af::span, af::seq(MN));
+    af::array SS = af::diag(S, 0, false).as(ty);
+    af::array VV = Vt(af::seq(MN), af::span);
+
+    af::array AA = matmul(UU, SS, VV);
+
+    // The reconstruction must have exactly A's shape. Checking this before the
+    // host copy turns a shape bug into a clean test failure instead of letting
+    // host() write past the end of a buffer sized for M * N elements.
+    ASSERT_EQ(A.dims(0), AA.dims(0));
+    ASSERT_EQ(A.dims(1), AA.dims(1));
+    ASSERT_EQ(A.elements(), AA.elements());
+
+    std::vector<T> hA(M * N);
+    std::vector<T> hAA(M * N);
+
+    A.host(&hA[0]);
+    AA.host(&hAA[0]);
+
+    for (int i = 0; i < M * N; i++) {
+        ASSERT_NEAR(get_val(hA[i]), get_val(hAA[i]), 1E-3) << "at element " << i;
+    }
+}
+
+// Square input: 500 x 500.
+TYPED_TEST(svd, Square)
+{
+    const int rows = 500;
+    const int cols = 500;
+    svdTest<TypeParam>(rows, cols);
+}
+
+// Tall input: more rows than columns (500 x 300).
+TYPED_TEST(svd, Rect0)
+{
+    const int rows = 500;
+    const int cols = 300;
+    svdTest<TypeParam>(rows, cols);
+}
+
+// Wide input: more columns than rows (300 x 500).
+TYPED_TEST(svd, Rect1)
+{
+    const int rows = 300;
+    const int cols = 500;
+    svdTest<TypeParam>(rows, cols);
+}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git
More information about the debian-science-commits
mailing list