[arrayfire] 340/408: API clean up and adding support for complex numbers for SVD

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:12:26 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.

commit 5fc32dcd06cb8ea50d76d7025f2721bd60818d81
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date:   Tue Aug 25 05:19:28 2015 -0400

    API clean up and adding support for complex numbers for SVD
---
 include/af/lapack.h                       |  16 +--
 src/api/c/svd.cpp                         |  68 +++++++----
 src/api/cpp/lapack.cpp                    |   8 +-
 src/backend/cpu/svd.cpp                   |  59 +++++----
 src/backend/cpu/svd.hpp                   |   8 +-
 src/backend/cuda/svd.cu                   | 194 ++++++++++++++++--------------
 src/backend/cuda/svd.hpp                  |   8 +-
 src/backend/opencl/kernel/sort.hpp        |   6 +
 src/backend/opencl/kernel/sort_by_key.hpp |   6 +
 src/backend/opencl/kernel/sort_index.hpp  |   6 +
 src/backend/opencl/set.cpp                |   6 +
 src/backend/opencl/svd.cpp                |  59 +++++----
 src/backend/opencl/svd.hpp                |   9 +-
 13 files changed, 259 insertions(+), 194 deletions(-)

diff --git a/include/af/lapack.h b/include/af/lapack.h
index 7c8e6fe..c70bf88 100644
--- a/include/af/lapack.h
+++ b/include/af/lapack.h
@@ -17,26 +17,26 @@ namespace af
     /**
        C++ Interface for SVD decomposition
 
-       \param[out] s is the output array containing the diagonal values of sigma, (singular values of the input matrix))
        \param[out] u is the output array containing U
+       \param[out] s is the output array containing the diagonal values of sigma, (singular values of the input matrix))
        \param[out] vt is the output array containing V^H
        \param[in] in is the input matrix
 
        \ingroup lapack_factor_func_svd
     */
-    AFAPI void svd(array &s, array &u, array &vt, const array &in);
+    AFAPI void svd(array &u, array &s, array &vt, const array &in);
 
     /**
        C++ Interface for SVD decomposition
 
-       \param[out] s is the output array containing the diagonal values of sigma, (singular values of the input matrix))
        \param[out] u is the output array containing U
+       \param[out] s is the output array containing the diagonal values of sigma, (singular values of the input matrix))
        \param[out] vt is the output array containing V^H
        \param[inout] in is the input matrix and will contain random data after this operation
 
        \ingroup lapack_factor_func_svd
     */
-    AFAPI void svdInPlace(array &s, array &u, array &vt, array &in);
+    AFAPI void svdInPlace(array &u, array &s, array &vt, array &in);
 
     /**
        C++ Interface for LU decomposition in packed format
@@ -243,26 +243,26 @@ extern "C" {
     /**
        C Interface for SVD decomposition
 
-       \param[out] s is the output array containing the diagonal values of sigma, (singular values of the input matrix))
        \param[out] u is the output array containing U
+       \param[out] s is the output array containing the diagonal values of sigma, (singular values of the input matrix))
        \param[out] vt is the output array containing V^H
        \param[in] in is the input matrix
 
        \ingroup lapack_factor_func_svd
     */
-    AFAPI af_err af_svd(af_array *s, af_array *u, af_array *vt, const af_array in);
+    AFAPI af_err af_svd(af_array *u, af_array *s, af_array *vt, const af_array in);
 
     /**
        C Interface for SVD decomposition
 
-       \param[out] s is the output array containing the diagonal values of sigma, (singular values of the input matrix))
        \param[out] u is the output array containing U
+       \param[out] s is the output array containing the diagonal values of sigma, (singular values of the input matrix))
        \param[out] vt is the output array containing V^H
        \param[inout] in is the input matrix that will contain random data after this operation
 
        \ingroup lapack_factor_func_svd
     */
-    AFAPI af_err af_svd_inplace(af_array *s, af_array *u, af_array *vt, af_array in);
+    AFAPI af_err af_svd_inplace(af_array *u, af_array *s, af_array *vt, af_array in);
 
     /**
        C Interface for LU decomposition
diff --git a/src/api/c/svd.cpp b/src/api/c/svd.cpp
index 84f859c..80b11f7 100644
--- a/src/api/c/svd.cpp
+++ b/src/api/c/svd.cpp
@@ -29,12 +29,14 @@ static inline void svd(af_array *s, af_array *u, af_array *vt, const af_array in
     int M = dims[0];
     int N = dims[1];
 
+    typedef typename af::dtype_traits<T>::base_type Tr;
+
     //Allocate output arrays
-    Array<T> sA = createEmptyArray<T>(af::dim4(min(M, N)));
-    Array<T> uA = createEmptyArray<T>(af::dim4(M, M));
-    Array<T> vtA = createEmptyArray<T>(af::dim4(N, N));
+    Array<Tr> sA  = createEmptyArray<Tr>(af::dim4(min(M, N)));
+    Array<T > uA  = createEmptyArray<T >(af::dim4(M, M));
+    Array<T > vtA = createEmptyArray<T >(af::dim4(N, N));
 
-    svd<T>(sA, uA, vtA, getArray<T>(in));
+    svd<T, Tr>(sA, uA, vtA, getArray<T>(in));
 
     *s = getHandle(sA);
     *u = getHandle(uA);
@@ -49,19 +51,21 @@ static inline void svdInPlace(af_array *s, af_array *u, af_array *vt, af_array i
     int M = dims[0];
     int N = dims[1];
 
+    typedef typename af::dtype_traits<T>::base_type Tr;
+
     //Allocate output arrays
-    Array<T> sA = createEmptyArray<T>(af::dim4(min(M, N)));
-    Array<T> uA = createEmptyArray<T>(af::dim4(M, M));
-    Array<T> vtA = createEmptyArray<T>(af::dim4(N, N));
+    Array<Tr> sA  = createEmptyArray<Tr>(af::dim4(min(M, N)));
+    Array<T > uA  = createEmptyArray<T >(af::dim4(M, M));
+    Array<T > vtA = createEmptyArray<T >(af::dim4(N, N));
 
-    svdInPlace<T>(sA, uA, vtA, getWritableArray<T>(in));
+    svdInPlace<T, Tr>(sA, uA, vtA, getWritableArray<T>(in));
 
     *s = getHandle(sA);
     *u = getHandle(uA);
     *vt = getHandle(vtA);
 }
 
-af_err af_svd(af_array *s, af_array *u, af_array *vt, const af_array in)
+af_err af_svd(af_array *u, af_array *s, af_array *vt, const af_array in)
 {
     try {
         ArrayInfo info = getInfo(in);
@@ -71,21 +75,27 @@ af_err af_svd(af_array *s, af_array *u, af_array *vt, const af_array in)
         af_dtype type = info.getType();
 
         switch (type) {
-            case f64:
-                svd<double>(s, u, vt, in);
-                break;
-            case f32:
-                svd<float>(s, u, vt, in);
-                break;
-            default:
-                TYPE_ERROR(1, type);
+        case f64:
+            svd<double>(s, u, vt, in);
+            break;
+        case f32:
+            svd<float>(s, u, vt, in);
+            break;
+        case c64:
+            svd<cdouble>(s, u, vt, in);
+            break;
+        case c32:
+            svd<cfloat>(s, u, vt, in);
+            break;
+        default:
+            TYPE_ERROR(1, type);
         }
     }
     CATCHALL;
     return AF_SUCCESS;
 }
 
-af_err af_svd_inplace(af_array *s, af_array *u, af_array *vt, af_array in)
+af_err af_svd_inplace(af_array *u, af_array *s, af_array *vt, af_array in)
 {
     try {
         ArrayInfo info = getInfo(in);
@@ -97,14 +107,20 @@ af_err af_svd_inplace(af_array *s, af_array *u, af_array *vt, af_array in)
         af_dtype type = info.getType();
 
         switch (type) {
-            case f64:
-                svdInPlace<double>(s, u, vt, in);
-                break;
-            case f32:
-                svdInPlace<float>(s, u, vt, in);
-                break;
-            default:
-                TYPE_ERROR(1, type);
+        case f64:
+            svdInPlace<double>(s, u, vt, in);
+            break;
+        case f32:
+            svdInPlace<float>(s, u, vt, in);
+            break;
+        case c64:
+            svdInPlace<cdouble>(s, u, vt, in);
+            break;
+        case c32:
+            svdInPlace<cfloat>(s, u, vt, in);
+            break;
+        default:
+            TYPE_ERROR(1, type);
         }
     }
     CATCHALL;
diff --git a/src/api/cpp/lapack.cpp b/src/api/cpp/lapack.cpp
index 2b03dba..cf9b3ec 100644
--- a/src/api/cpp/lapack.cpp
+++ b/src/api/cpp/lapack.cpp
@@ -13,19 +13,19 @@
 
 namespace af
 {
-    void svd(array &s, array &u, array &vt, const array &in)
+    void svd(array &u, array &s, array &vt, const array &in)
     {
         af_array sl = 0, ul = 0, vtl = 0;
-        AF_THROW(af_svd(&sl, &ul, &vtl, in.get()));
+        AF_THROW(af_svd(&ul, &sl, &vtl, in.get()));
         s = array(sl);
         u = array(ul);
         vt = array(vtl);
     }
 
-    void svdInPlace(array &s, array &u, array &vt, array &in)
+    void svdInPlace(array &u, array &s, array &vt, array &in)
     {
         af_array sl = 0, ul = 0, vtl = 0;
-        AF_THROW(af_svd_inplace(&sl, &ul, &vtl, in.get()));
+        AF_THROW(af_svd_inplace(&ul, &sl, &vtl, in.get()));
         s = array(sl);
         u = array(ul);
         vt = array(vtl);
diff --git a/src/backend/cpu/svd.cpp b/src/backend/cpu/svd.cpp
index 6c590d6..4b74c7f 100644
--- a/src/backend/cpu/svd.cpp
+++ b/src/backend/cpu/svd.cpp
@@ -13,10 +13,6 @@
 
 #include <err_cpu.hpp>
 
-#define INSTANTIATE_SVD(T)                                                               \
-    template void svd<T>(Array<T> & s, Array<T> & u, Array<T> & vt, const Array<T> &in); \
-    template void svdInPlace<T>(Array<T> & s, Array<T> & u, Array<T> & vt, Array<T> &in);
-
 #if defined(WITH_CPU_LINEAR_ALGEBRA)
 #include <lapack_helper.hpp>
 #include <copy.hpp>
@@ -25,34 +21,40 @@ namespace cpu
 {
 
 #define SVD_FUNC_DEF( FUNC )                                \
-    template<typename T> FUNC##_func_def<T> FUNC##_func();
+    template<typename T,typename Tr> FUNC##_func_def<T, Tr> FUNC##_func();
 
-#define SVD_FUNC( FUNC, TYPE, PREFIX )                          \
-    template<> FUNC##_func_def<TYPE>     FUNC##_func<TYPE>()    \
+#define SVD_FUNC( FUNC, T, Tr, PREFIX )                         \
+    template<> FUNC##_func_def<T, Tr>     FUNC##_func<T, Tr>()  \
     { return & LAPACK_NAME(PREFIX##FUNC); }
 
-    template<typename T>
-    using gesdd_func_def = int (*)(ORDER_TYPE, char jobz, int m, int n, T* in,
-                                     int ldin, T* s, T* u, int ldu,
-                                     T* vt, int ldvt);
+    template<typename T, typename Tr>
+    using gesdd_func_def = int (*)(ORDER_TYPE,
+                                   char jobz,
+                                   int m, int n,
+                                   T* in, int ldin,
+                                   Tr* s,
+                                   T* u, int ldu,
+                                   T* vt, int ldvt);
 
     SVD_FUNC_DEF( gesdd )
-    SVD_FUNC(gesdd, float, s)
-    SVD_FUNC(gesdd, double, d)
+    SVD_FUNC(gesdd, float  , float , s)
+    SVD_FUNC(gesdd, double , double, d)
+    SVD_FUNC(gesdd, cfloat , float , c)
+    SVD_FUNC(gesdd, cdouble, double, z)
 
-    template <typename T>
-    void svdInPlace(Array<T> &s, Array<T> &u, Array<T> &vt, Array<T> &in)
+    template <typename T, typename Tr>
+    void svdInPlace(Array<Tr> &s, Array<T> &u, Array<T> &vt, Array<T> &in)
     {
         dim4 iDims = in.dims();
         int M = iDims[0];
         int N = iDims[1];
 
-        gesdd_func<T>()(AF_LAPACK_COL_MAJOR, 'A', M, N, in.get(), in.strides()[1],
-                        s.get(), u.get(), u.strides()[1], vt.get(), vt.strides()[1]);
+        gesdd_func<T, Tr>()(AF_LAPACK_COL_MAJOR, 'A', M, N, in.get(), in.strides()[1],
+                            s.get(), u.get(), u.strides()[1], vt.get(), vt.strides()[1]);
     }
 
-    template <typename T>
-    void svd(Array<T> &s, Array<T> &u, Array<T> &vt, const Array<T> &in)
+    template <typename T, typename Tr>
+    void svd(Array<Tr> &s, Array<T> &u, Array<T> &vt, const Array<T> &in)
     {
         Array<T> in_copy = copyArray<T>(in);
         svdInPlace(s, u, vt, in_copy);
@@ -63,14 +65,14 @@ namespace cpu
 
 namespace cpu
 {
-    template <typename T>
-    void svd(Array<T> &s, Array<T> &u, Array<T> &vt, const Array<T> &in)
+    template <typename T, typename Tr>
+    void svd(Array<Tr> &s, Array<T> &u, Array<T> &vt, const Array<T> &in)
     {
         AF_ERROR("Linear Algebra is disabled on CPU", AF_ERR_NOT_CONFIGURED);
     }
 
-    template <typename T>
-    void svdInPlace(Array<T> &s, Array<T> &u, Array<T> &vt, Array<T> &in)
+    template <typename T, typename Tr>
+    void svdInPlace(Array<Tr> &s, Array<T> &u, Array<T> &vt, Array<T> &in)
     {
         AF_ERROR("Linear Algebra is disabled on CPU", AF_ERR_NOT_CONFIGURED);
     }
@@ -79,6 +81,13 @@ namespace cpu
 #endif
 
 namespace cpu {
-    INSTANTIATE_SVD(float)
-    INSTANTIATE_SVD(double)
+
+#define INSTANTIATE_SVD(T, Tr)                                          \
+    template void svd<T, Tr>(Array<Tr> & s, Array<T> & u, Array<T> & vt, const Array<T> &in); \
+    template void svdInPlace<T, Tr>(Array<Tr> & s, Array<T> & u, Array<T> & vt, Array<T> &in);
+
+    INSTANTIATE_SVD(float  , float )
+    INSTANTIATE_SVD(double , double)
+    INSTANTIATE_SVD(cfloat , float )
+    INSTANTIATE_SVD(cdouble, double)
 }
diff --git a/src/backend/cpu/svd.hpp b/src/backend/cpu/svd.hpp
index d465e8c..e9934ce 100644
--- a/src/backend/cpu/svd.hpp
+++ b/src/backend/cpu/svd.hpp
@@ -12,9 +12,9 @@
 
 namespace cpu
 {
-    template <typename T>
-    void svd(Array<T> &s, Array<T> &u, Array<T> &vt, const Array<T> &in);
+    template<typename T, typename Tr>
+    void svd(Array<Tr> &s, Array<T> &u, Array<T> &vt, const Array<T> &in);
 
-    template <typename T>
-    void svdInPlace(Array<T> &s, Array<T> &u, Array<T> &vt, Array<T> &in);
+    template<typename T, typename Tr>
+    void svdInPlace(Array<Tr> &s, Array<T> &u, Array<T> &vt, Array<T> &in);
 }
diff --git a/src/backend/cuda/svd.cu b/src/backend/cuda/svd.cu
index 8e7b3ac..fed35cf 100644
--- a/src/backend/cuda/svd.cu
+++ b/src/backend/cuda/svd.cu
@@ -10,8 +10,6 @@
 #include <svd.hpp>
 #include <err_common.hpp>
 
-#if defined(WITH_CUDA_LINEAR_ALGEBRA)
-
 #include <cusolverDnManager.hpp>
 #include "transpose.hpp"
 #include <memory.hpp>
@@ -21,123 +19,135 @@
 
 namespace cuda
 {
-    using cusolver::getDnHandle;
 
-    template <typename T>
-    struct gesvd_func_def_t {
-        typedef cusolverStatus_t (*gesvd_func_def)(cusolverDnHandle_t, char, char, int,
-                                                   int, T *, int, T *, T *, int, T *, int,
-                                                   T *, int, T *, int *);
-    };
+#if defined(WITH_CUDA_LINEAR_ALGEBRA)
+
+#include <cusolverDnManager.hpp>
+
+    using cusolver::getDnHandle;
 
     template<typename T>
-    struct gesvd_buf_func_def_t {
-        typedef cusolverStatus_t (*gesvd_buf_func_def)(cusolverDnHandle_t, int, int,
-                                                       int *);
-    };
-
-#define SVD_FUNC_DEF(FUNC)                                                              \
-    template <typename T>                                                               \
-    typename FUNC##_func_def_t<T>::FUNC##_func_def FUNC##_func();                       \
-                                                                                        \
-    template<typename T>                                                                \
-    typename FUNC##_buf_func_def_t<T>::FUNC##_buf_func_def                              \
-    FUNC##_buf_func();
-
-#define SVD_FUNC(FUNC, TYPE, PREFIX)                                                    \
-    template <>                                                                         \
-    typename FUNC##_func_def_t<TYPE>::FUNC##_func_def FUNC##_func<TYPE>()               \
-    {                                                                                   \
-        return (FUNC##_func_def_t<TYPE>::FUNC##_func_def) & cusolverDn##PREFIX##FUNC;   \
-    }                                                                                   \
-                                                                                        \
-    template<> typename FUNC##_buf_func_def_t<TYPE>::FUNC##_buf_func_def                \
-    FUNC##_buf_func<TYPE>()                                                             \
-    {                                                                                   \
-        return (FUNC##_buf_func_def_t<TYPE>::FUNC##_buf_func_def) &                     \
-               cusolverDn##PREFIX##FUNC##_bufferSize;                                   \
+    cusolverStatus_t gesvd_buf_func(cusolverDnHandle_t handle, int m, int n, int *Lwork)
+    {
+        return CUSOLVER_STATUS_ARCH_MISMATCH;
     }
 
-    SVD_FUNC_DEF(gesvd)
-    SVD_FUNC(gesvd, float, S)
-    SVD_FUNC(gesvd, double, D)
-//SVD_FUNC(gesvd , cfloat , C)
-//SVD_FUNC(gesvd , cdouble, Z)
+    template<typename T, typename Tr>
+    cusolverStatus_t gesvd_func(cusolverDnHandle_t handle, char jobu, char jobvt,
+                                int m, int n,
+                                T *A, int lda,
+                                Tr *S,
+                                T *U, int ldu,
+                                T *VT, int ldvt,
+                                T *Work, int Lwork,
+                                Tr *rwork, int *devInfo)
+    {
+        return CUSOLVER_STATUS_ARCH_MISMATCH;
+    }
 
-    template <typename T>
-    void svdInPlace(Array<T> &s, Array<T> &u, Array<T> &vt, Array<T> &in)
+#define SVD_SPECIALIZE(T, Tr, X)                                        \
+    template<> cusolverStatus_t                                         \
+    gesvd_buf_func<T>(cusolverDnHandle_t handle,                        \
+                      int m, int n, int *Lwork)                         \
+    {                                                                   \
+        return cusolverDn##X##gesvd_bufferSize(handle, m, n, Lwork);    \
+    }                                                                   \
+
+SVD_SPECIALIZE(float  , float , S);
+SVD_SPECIALIZE(double , double, D);
+SVD_SPECIALIZE(cfloat , float , C);
+SVD_SPECIALIZE(cdouble, double, Z);
+
+#undef SVD_SPECIALIZE
+
+#define SVD_SPECIALIZE(T, Tr, X)                                        \
+    template<> cusolverStatus_t                                         \
+    gesvd_func<T, Tr>(cusolverDnHandle_t handle,                        \
+                      char jobu, char jobvt,                            \
+                      int m, int n,                                     \
+                      T *A, int lda,                                    \
+                      Tr *S,                                            \
+                      T *U, int ldu,                                    \
+                      T *VT, int ldvt,                                  \
+                      T *Work, int Lwork,                               \
+                      Tr *rwork, int *devInfo)                          \
+    {                                                                   \
+        return cusolverDn##X##gesvd(handle, jobu, jobvt,                \
+                                    m, n, A, lda, S, U, ldu, VT, ldvt,  \
+                                    Work, Lwork, rwork, devInfo);       \
+    }                                                                   \
+
+SVD_SPECIALIZE(float  , float , S);
+SVD_SPECIALIZE(double , double, D);
+SVD_SPECIALIZE(cfloat , float , C);
+SVD_SPECIALIZE(cdouble, double, Z);
+
+    template <typename T, typename Tr>
+    void svdInPlace(Array<Tr> &s, Array<T> &u, Array<T> &vt, Array<T> &in)
     {
         dim4 iDims = in.dims();
         int M = iDims[0];
         int N = iDims[1];
 
-        // cuSolver(cuda 7.0) doesn't have support for M<N
-        bool flip_and_transpose = M < N;
+        int lwork = 0;
 
-        if (flip_and_transpose) {
-            std::swap(M, N);
-            std::swap(vt, u);
-        }
+        CUSOLVER_CHECK(gesvd_buf_func<T>(getDnHandle(), M, N, &lwork));
 
-        int lwork = 0;
-        CUSOLVER_CHECK(gesvd_buf_func<T>()(getDnHandle(), M, N, &lwork));
-        T *lWorkspace = memAlloc<T>(lwork);
-        //complex numbers would need rWorkspace
-        //T *rWorkspace = memAlloc<T>(lwork);
+        T  *lWorkspace = memAlloc<T >(lwork);
+        Tr *rWorkspace = memAlloc<Tr>(5 * std::min(M, N));
 
         int *info = memAlloc<int>(1);
 
-        if (flip_and_transpose) {
-            transpose_inplace(in, true);
-            CUSOLVER_CHECK(gesvd_func<T>()(getDnHandle(), 'A', 'A', M, N, in.get(),
-                                           M, s.get(), u.get(), M, vt.get(), N,
-                                           lWorkspace, lwork, NULL, info));
-            std::swap(u, vt);
-            transpose_inplace(vt, true);
-        } else {
-            Array<T> inCopy = copyArray<T>(in);
-            CUSOLVER_CHECK(gesvd_func<T>()(getDnHandle(), 'A', 'A', M, N, in.get(),
-                                           M, s.get(), u.get(), M, vt.get(), N,
-                                           lWorkspace, lwork, NULL, info));
-        }
+        gesvd_func<T, Tr>(getDnHandle(), 'A', 'A', M, N, in.get(),
+                          M, s.get(), u.get(), M, vt.get(), N,
+                          lWorkspace, lwork, rWorkspace, info);
+
         memFree(info);
         memFree(lWorkspace);
-        //memFree(rWorkspace);
+        memFree(rWorkspace);
     }
 
-    template <typename T>
-    void svd(Array<T> &s, Array<T> &u, Array<T> &vt, const Array<T> &in)
+    template <typename T, typename Tr>
+    void svd(Array<Tr> &s, Array<T> &u, Array<T> &vt, const Array<T> &in)
     {
-        Array<T> inCopy = copyArray<T>(in);
-        svdInPlace(s, u, vt, inCopy);
+        dim4 iDims = in.dims();
+        int M = iDims[0];
+        int N = iDims[1];
+
+        if (M <= N) {
+            Array<T> in_copy = copyArray(in);
+            return svdInPlace(s, u, vt, in_copy);
+        } else {
+            Array<T> in_trans = transpose(in, true);
+            return svdInPlace(s, vt, u, in_trans);
+        }
     }
 
-#define INSTANTIATE_SVD(T)                                                              \
-    template void svd<T>(Array<T> &s, Array<T> &u, Array<T> &vt, const Array<T> &in);   \
-    template void svdInPlace<T>(Array<T> &s, Array<T> &u, Array<T> &vt, Array<T> &in);   \
+#else
 
-    INSTANTIATE_SVD(float)
-    //INSTANTIATE_SVD(cfloat)
-    INSTANTIATE_SVD(double)
-    //INSTANTIATE_SVD(cdouble)
+template<typename T, typename Tr>
+void svd(Array<Tr> &s, Array<T> &u, Array<T> &vt, const Array<T> &in)
+{
+    AF_ERROR("CUDA cusolver not available. Linear Algebra is disabled",
+             AF_ERR_NOT_CONFIGURED);
 }
 
-#else
-namespace cuda
+template<typename T, typename Tr>
+void svdInPlace(Array<Tr> &s, Array<T> &u, Array<T> &vt, Array<T> &in)
 {
-    template <typename T>
-    void svd(Array<T> &s, Array<T> &u, Array<T> &vt, const Array<T> &in)
-    {
-        AF_ERROR("CUDA cusolver not available. Linear Algebra is disabled",
-                 AF_ERR_NOT_CONFIGURED);
-    }
+    AF_ERROR("CUDA cusolver not available. Linear Algebra is disabled",
+             AF_ERR_NOT_CONFIGURED);
+}
 
-#define INSTANTIATE_SVD(T)                                                              \
-    template void svd<T>(Array<T> &s, Array<T> &u, Array<T> &vt, const Array<T> &in);   \
+#endif
+
+#define INSTANTIATE(T, Tr)                                              \
+    template void svd<T, Tr>(Array<Tr> &s, Array<T> &u, Array<T> &vt, const Array<T> &in); \
+    template void svdInPlace<T, Tr>(Array<Tr> &s, Array<T> &u, Array<T> &vt, Array<T> &in);
+
+INSTANTIATE(float, float)
+INSTANTIATE(double, double)
+INSTANTIATE(cfloat, float)
+INSTANTIATE(cdouble, double)
 
-    INSTANTIATE_SVD(float)
-    //INSTANTIATE_SVD(cfloat)
-    INSTANTIATE_SVD(double)
-    //INSTANTIATE_SVD(cdouble)
 }
-#endif
diff --git a/src/backend/cuda/svd.hpp b/src/backend/cuda/svd.hpp
index 44de4c2..5a833ce 100644
--- a/src/backend/cuda/svd.hpp
+++ b/src/backend/cuda/svd.hpp
@@ -12,9 +12,9 @@
 
 namespace cuda
 {
-    template <typename T>
-    void svd(Array<T> &s, Array<T> &u, Array<T> &vt, const Array<T> &in);
+    template<typename T, typename Tr>
+    void svd(Array<Tr> &s, Array<T> &u, Array<T> &vt, const Array<T> &in);
 
-    template <typename T>
-    void svdInPlace(Array<T> &s, Array<T> &u, Array<T> &vt, Array<T> &in);
+    template<typename T, typename Tr>
+    void svdInPlace(Array<Tr> &s, Array<T> &u, Array<T> &vt, Array<T> &in);
 }
diff --git a/src/backend/opencl/kernel/sort.hpp b/src/backend/opencl/kernel/sort.hpp
index c3976a0..58345b9 100644
--- a/src/backend/opencl/kernel/sort.hpp
+++ b/src/backend/opencl/kernel/sort.hpp
@@ -15,6 +15,10 @@
 #include <dispatch.hpp>
 #include <Param.hpp>
 #include <debug_opencl.hpp>
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+
 #include <boost/compute/core.hpp>
 #include <boost/compute/algorithm/stable_sort.hpp>
 #include <boost/compute/functional/operator.hpp>
@@ -77,3 +81,5 @@ namespace opencl
         }
     }
 }
+
+#pragma GCC diagnostic pop
diff --git a/src/backend/opencl/kernel/sort_by_key.hpp b/src/backend/opencl/kernel/sort_by_key.hpp
index 07086ca..1ea2a48 100644
--- a/src/backend/opencl/kernel/sort_by_key.hpp
+++ b/src/backend/opencl/kernel/sort_by_key.hpp
@@ -15,6 +15,10 @@
 #include <dispatch.hpp>
 #include <Param.hpp>
 #include <debug_opencl.hpp>
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+
 #include <boost/compute/core.hpp>
 #include <boost/compute/algorithm/sort_by_key.hpp>
 #include <boost/compute/functional/operator.hpp>
@@ -79,3 +83,5 @@ namespace opencl
         }
     }
 }
+
+#pragma GCC diagnostic pop
diff --git a/src/backend/opencl/kernel/sort_index.hpp b/src/backend/opencl/kernel/sort_index.hpp
index 8504b78..5595b8c 100644
--- a/src/backend/opencl/kernel/sort_index.hpp
+++ b/src/backend/opencl/kernel/sort_index.hpp
@@ -15,6 +15,10 @@
 #include <dispatch.hpp>
 #include <Param.hpp>
 #include <debug_opencl.hpp>
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+
 #include <boost/compute/core.hpp>
 #include <boost/compute/algorithm/iota.hpp>
 #include <boost/compute/algorithm/sort_by_key.hpp>
@@ -81,3 +85,5 @@ namespace opencl
         }
     }
 }
+
+#pragma GCC diagnostic pop
diff --git a/src/backend/opencl/set.cpp b/src/backend/opencl/set.cpp
index f725a23..665ffdf 100644
--- a/src/backend/opencl/set.cpp
+++ b/src/backend/opencl/set.cpp
@@ -15,6 +15,10 @@
 #include <copy.hpp>
 #include <sort.hpp>
 #include <err_opencl.hpp>
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+
 #include <boost/compute/algorithm/set_intersection.hpp>
 #include <boost/compute/algorithm/set_union.hpp>
 #include <boost/compute/algorithm/sort.hpp>
@@ -148,3 +152,5 @@ namespace opencl
     INSTANTIATE(char)
     INSTANTIATE(uchar)
 }
+
+#pragma GCC diagnostic pop
diff --git a/src/backend/opencl/svd.cpp b/src/backend/opencl/svd.cpp
index 1467d9e..3d90444 100644
--- a/src/backend/opencl/svd.cpp
+++ b/src/backend/opencl/svd.cpp
@@ -13,6 +13,7 @@
 #include <reduce.hpp>
 #include <copy.hpp>
 #include <blas.hpp>
+#include <transpose.hpp>
 
 #include <magma/magma.h>
 #include <magma/magma_cpu_lapack.h>
@@ -113,8 +114,10 @@ void svd(Array<T > &arrU,
     int ncvt = 0;
 
     std::vector<T> A(m * n);
+    std::vector<T> tauq(min_mn), taup(min_mn);
+    std::vector<T> work(lwork);
     std::vector<Tr> s0(min_mn), s1(min_mn - 1);
-    std::vector<T> tauq(min_mn), taup(min_mn), work(lwork);
+    std::vector<Tr> rwork(5 * min_mn);
 
     int info = 0;
 
@@ -167,7 +170,7 @@ void svd(Array<T > &arrU,
     // (RWorkspace: need BDSPAC)
     LAPACKE_CHECK(cpu_lapack_bdsqr_work('U', n, ncvt, nru, izero,
                                         &s0[0], &s1[0], &VT[0], ldvt, &U[0], ldu,
-                                        &cdummy[0], ione, &work[0]));
+                                        &cdummy[0], ione, &rwork[0]));
 
 
     if (want_vectors) {
@@ -189,47 +192,51 @@ void svd(Array<T > &arrU,
 }
 
 
-template<typename T>
-void svdInPlace(Array<T> &s, Array<T> &u, Array<T> &vt, Array<T> &in)
+template<typename T, typename Tr>
+void svdInPlace(Array<Tr> &s, Array<T> &u, Array<T> &vt, Array<T> &in)
 {
     initBlas();
-    dim4 iDims = in.dims();
-    int M = iDims[0];
-    int N = iDims[1];
-
-    if (M < N) OPENCL_NOT_SUPPORTED();
-
-    typedef typename af::dtype_traits<T>::base_type Tr;
     svd<T, Tr>(u, s, vt, in, true);
 }
 
-template<typename T>
-void svd(Array<T> &s, Array<T> &u, Array<T> &vt, const Array<T> &in)
+template<typename T, typename Tr>
+void svd(Array<Tr> &s, Array<T> &u, Array<T> &vt, const Array<T> &in)
 {
-    Array<T> in_copy = copyArray(in);
-    return svdInPlace(s, u, vt, in_copy);
+    dim4 iDims = in.dims();
+    int M = iDims[0];
+    int N = iDims[1];
+
+    if (M <= N) {
+        Array<T> in_copy = copyArray(in);
+        return svdInPlace(s, u, vt, in_copy);
+    } else {
+        Array<T> in_trans = transpose(in, true);
+        return svdInPlace(s, vt, u, in_trans);
+    }
 }
 
 #else
 
-template<typename T>
-void svd(Array<T> &s, Array<T> &u, Array<T> &vt, const Array<T> &in)
+template<typename T, typename Tr>
+void svd(Array<Tr> &s, Array<T> &u, Array<T> &vt, const Array<T> &in)
 {
-    OPENCL_NOT_SUPPORTED();
+    AF_ERROR("Linear Algebra is disabled on OpenCL", AF_ERR_NOT_CONFIGURED);
 }
 
-template<typename T>
-void svdInPlace(Array<T> &s, Array<T> &u, Array<T> &vt, Array<T> &in)
+template<typename T, typename Tr>
+void svdInPlace(Array<Tr> &s, Array<T> &u, Array<T> &vt, Array<T> &in)
 {
-    OPENCL_NOT_SUPPORTED();
+    AF_ERROR("Linear Algebra is disabled on OpenCL", AF_ERR_NOT_CONFIGURED);
 }
 #endif
 
-#define INSTANTIATE(T)                                                  \
-    template void svd(Array<T> &s, Array<T> &u, Array<T> &vt, const Array<T> &in); \
-    template void svdInPlace(Array<T> &s, Array<T> &u, Array<T> &vt, Array<T> &in);
+#define INSTANTIATE(T, Tr)                                              \
+    template void svd<T, Tr>(Array<Tr> &s, Array<T> &u, Array<T> &vt, const Array<T> &in); \
+    template void svdInPlace<T, Tr>(Array<Tr> &s, Array<T> &u, Array<T> &vt, Array<T> &in);
 
-INSTANTIATE(float)
-INSTANTIATE(double)
+INSTANTIATE(float, float)
+INSTANTIATE(double, double)
+INSTANTIATE(cfloat, float)
+INSTANTIATE(cdouble, double)
 
 }
diff --git a/src/backend/opencl/svd.hpp b/src/backend/opencl/svd.hpp
index 0bf4ded..06f6901 100644
--- a/src/backend/opencl/svd.hpp
+++ b/src/backend/opencl/svd.hpp
@@ -11,10 +11,9 @@
 
 namespace opencl
 {
-    template<typename T>
-    void svd(Array<T> &s, Array<T> &u, Array<T> &vt, const Array<T> &in);
+    template<typename T, typename Tr>
+    void svd(Array<Tr> &s, Array<T> &u, Array<T> &vt, const Array<T> &in);
 
-    template<typename T>
-    void svdInPlace(Array<T> &s, Array<T> &u, Array<T> &vt, Array<T> &in);
+    template<typename T, typename Tr>
+    void svdInPlace(Array<Tr> &s, Array<T> &u, Array<T> &vt, Array<T> &in);
 }
-

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list