[arrayfire] 342/408: TEST: for SVD

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:12:26 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.

commit b3c5f0fa2f55a1cb2bc04524b942578754406a74
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date:   Tue Aug 25 06:44:34 2015 -0400

    TEST: for SVD
---
 test/svd_dense.cpp | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 93 insertions(+)

diff --git a/test/svd_dense.cpp b/test/svd_dense.cpp
new file mode 100644
index 0000000..7ba8e23
--- /dev/null
+++ b/test/svd_dense.cpp
@@ -0,0 +1,93 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+
+#include <gtest/gtest.h>
+#include <arrayfire.h>
+#include <af/dim4.hpp>
+#include <af/defines.h>
+#include <af/traits.hpp>
+#include <algorithm>
+#include <vector>
+#include <iostream>
+#include <complex>
+#include <string>
+#include <testHelpers.hpp>
+
+using std::vector;
+using std::string;
+using std::cout;
+using std::endl;
+using af::cfloat;
+using af::cdouble;
+
+// Typed GoogleTest fixture for the dense SVD checks. Every TYPED_TEST
+// declared against it runs once per element type listed in TestTypes
+// (currently single-precision real only).
+template<typename T>
+struct svd : public ::testing::Test
+{
+};
+
+typedef ::testing::Types<float> TestTypes;
+TYPED_TEST_CASE(svd, TestTypes);
+
+// Reduces a host-side element to a real scalar for tolerance comparisons.
+// Real element types are simply widened to double.
+template<typename T>
+double get_val(T val)
+{
+    const double widened = val;
+    return widened;
+}
+
+// Single-precision complex: reduce to the magnitude returned by abs().
+template<>
+double get_val(cfloat val)
+{
+    const double magnitude = abs(val);
+    return magnitude;
+}
+
+// Double-precision complex: reduce to the magnitude returned by abs().
+template<>
+double get_val(cdouble val)
+{
+    const double magnitude = abs(val);
+    return magnitude;
+}
+
+// Factorizes a random M x N matrix with af::svd and checks that the thin
+// reconstruction U(:, 0:MN-1) * diag(S) * Vt(0:MN-1, :) reproduces the
+// input element-wise.
+//
+// M, N : dimensions of the random input matrix.
+// eps  : largest accepted magnitude of the per-element reconstruction
+//        error (defaults to the previously hard-coded 1E-3).
+template<typename T>
+void svdTest(const int M, const int N, const double eps = 1E-3)
+{
+    af::dtype ty = (af::dtype)af::dtype_traits<T>::af_type;
+
+    af::array A = af::randu(M, N, ty);
+    af::array U, S, Vt;
+    af::svd(U, S, Vt, A);
+
+    // Only the leading min(M, N) singular vectors are paired with
+    // singular values in the reconstruction.
+    const int MN = std::min(M, N);
+
+    af::array UU = U(af::span, af::seq(MN));
+    // diag(..., extract = false) builds a diagonal matrix from the vector of
+    // singular values; it is cast to the input type before the product.
+    af::array SS = af::diag(S, 0, false).as(ty);
+    af::array VV = Vt(af::seq(MN), af::span);
+
+    af::array AA = matmul(UU, SS, VV);
+
+    std::vector<T> hA(M * N);
+    std::vector<T> hAA(M * N);
+
+    A.host(&hA[0]);
+    AA.host(&hAA[0]);
+
+    for (int i = 0; i < M * N; i++) {
+        // Check the magnitude of the difference, not the difference of
+        // magnitudes: for complex T the latter would accept elements with
+        // the correct modulus but a wrong phase. For real T the two agree.
+        ASSERT_NEAR(0.0, get_val<T>(hA[i] - hAA[i]), eps)
+            << "reconstruction mismatch at element " << i;
+    }
+}
+
+// Square input: as many rows as columns.
+TYPED_TEST(svd, Square)
+{
+    const int rows = 500, cols = 500;
+    svdTest<TypeParam>(rows, cols);
+}
+
+// Tall input: more rows than columns.
+TYPED_TEST(svd, Rect0)
+{
+    const int rows = 500, cols = 300;
+    svdTest<TypeParam>(rows, cols);
+}
+
+// Wide input: more columns than rows.
+TYPED_TEST(svd, Rect1)
+{
+    const int rows = 300, cols = 500;
+    svdTest<TypeParam>(rows, cols);
+}

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list