[arrayfire] 43/408: Added options for dotc and dotu to dot function

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:11:13 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.

commit b62696722ba0983199bc7cb7a711c00e86625e03
Author: Shehzan Mohammed <shehzan at arrayfire.com>
Date:   Thu Jun 25 17:29:44 2015 -0400

    Added options for dotc and dotu to dot function
---
 include/af/blas.h           |  11 +++++
 include/af/defines.h        |   1 +
 src/api/c/blas.cpp          |  16 +++----
 src/backend/cpu/blas.cpp    |  43 ++++++++++++++-----
 src/backend/cuda/blas.cpp   | 100 ++++++++++++++++++++++++++++++--------------
 src/backend/opencl/blas.cpp |  79 ++++++++++++++++++++++++----------
 6 files changed, 179 insertions(+), 71 deletions(-)

diff --git a/include/af/blas.h b/include/af/blas.h
index 8aa59e4..1a89ad2 100644
--- a/include/af/blas.h
+++ b/include/af/blas.h
@@ -138,8 +138,19 @@ namespace af
         af_print(dot(x,y));
         }
 
+        \param[in] lhs The array object on the left hand side
+        \param[in] rhs The array object on the right hand side
+        \param[in] optLhs Options for lhs. Currently only \ref AF_MAT_NONE and
+                   \ref AF_MAT_CONJ are supported.
+        \param[in] optRhs Options for rhs. Currently only \ref AF_MAT_NONE and \ref AF_MAT_CONJ are supported.
+        \return The result of the dot product of lhs, rhs
+
+        \note optLhs and optRhs can only be one of \ref AF_MAT_NONE or \ref AF_MAT_CONJ
+        \note Setting optLhs = \ref AF_MAT_CONJ and optRhs = \ref AF_MAT_NONE computes the conjugate dot product (dotc), i.e. sum(conj(lhs) * rhs).
         \note This function is not supported in GFOR
 
+        \returns out = dot(lhs, rhs)
+
         \ingroup blas_func_dot
     */
     AFAPI array dot   (const array &lhs, const array &rhs,
diff --git a/include/af/defines.h b/include/af/defines.h
index db22b5f..a2aa7a9 100644
--- a/include/af/defines.h
+++ b/include/af/defines.h
@@ -244,6 +244,7 @@ typedef enum {
     AF_MAT_NONE       = 0,    ///< Default
     AF_MAT_TRANS      = 1,    ///< Data needs to be transposed
     AF_MAT_CTRANS     = 2,    ///< Data needs to be conjugate tansposed
+    AF_MAT_CONJ       = 4,    ///< Data needs to be conjugated
     AF_MAT_UPPER      = 32,   ///< Matrix is upper triangular
     AF_MAT_LOWER      = 64,   ///< Matrix is lower triangular
     AF_MAT_DIAG_UNIT  = 128,  ///< Matrix diagonal contains unitary values
diff --git a/src/api/c/blas.cpp b/src/api/c/blas.cpp
index 6d5f3aa..21fb44f 100644
--- a/src/api/c/blas.cpp
+++ b/src/api/c/blas.cpp
@@ -93,11 +93,11 @@ af_err af_dot(      af_array *out,
         ArrayInfo lhsInfo = getInfo(lhs);
         ArrayInfo rhsInfo = getInfo(rhs);
 
-        if (optLhs != AF_MAT_NONE) {
+        if (optLhs != AF_MAT_NONE && optLhs != AF_MAT_CONJ) {
             AF_ERROR("Using this property is not yet supported in dot", AF_ERR_NOT_SUPPORTED);
         }
 
-        if (optRhs != AF_MAT_NONE) {
+        if (optRhs != AF_MAT_NONE && optRhs != AF_MAT_CONJ) {
             AF_ERROR("Using this property is not yet supported in dot", AF_ERR_NOT_SUPPORTED);
         }
 
@@ -105,8 +105,8 @@ af_err af_dot(      af_array *out,
         af_dtype lhs_type = lhsInfo.getType();
         af_dtype rhs_type = rhsInfo.getType();
 
-        if (lhsInfo.ndims() > 2 ||
-            rhsInfo.ndims() > 2) {
+        if (lhsInfo.ndims() > 1 ||
+            rhsInfo.ndims() > 1) {
             AF_ERROR("dot can not be used in batch mode", AF_ERR_BATCH);
         }
 
@@ -115,10 +115,10 @@ af_err af_dot(      af_array *out,
         af_array output = 0;
 
         switch(lhs_type) {
-        case f32: output = dot<float  >(lhs, rhs, optLhs, optRhs);   break;
-	case c32: output = dot<cfloat >(lhs, rhs, optLhs, optRhs);   break;
-        case f64: output = dot<double >(lhs, rhs, optLhs, optRhs);   break;
-	case c64: output = dot<cdouble>(lhs, rhs, optLhs, optRhs);   break;
+        case f32: output = dot<float  >(lhs, rhs, optLhs, optRhs);    break;
+        case c32: output = dot<cfloat >(lhs, rhs, optLhs, optRhs);    break;
+        case f64: output = dot<double >(lhs, rhs, optLhs, optRhs);    break;
+        case c64: output = dot<cdouble>(lhs, rhs, optLhs, optRhs);    break;
         default:  TYPE_ERROR(1, lhs_type);
         }
         std::swap(*out, output);
diff --git a/src/backend/cpu/blas.cpp b/src/backend/cpu/blas.cpp
index 65fc19f..44094e3 100644
--- a/src/backend/cpu/blas.cpp
+++ b/src/backend/cpu/blas.cpp
@@ -102,9 +102,9 @@ toCblasTranspose(af_mat_prop opt)
     CBLAS_TRANSPOSE out = CblasNoTrans;
     switch(opt) {
         case AF_MAT_NONE        : out = CblasNoTrans;   break;
-        case AF_MAT_TRANS           : out = CblasTrans;     break;
-        case AF_MAT_CTRANS : out = CblasConjTrans; break;
-        default                     : AF_ERROR("INVALID af_mat_prop", AF_ERR_ARG);
+        case AF_MAT_TRANS       : out = CblasTrans;     break;
+        case AF_MAT_CTRANS      : out = CblasConjTrans; break;
+        default                 : AF_ERROR("INVALID af_mat_prop", AF_ERR_ARG);
     }
     return out;
 }
@@ -185,9 +185,15 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
     return out;
 }
 
-template<typename T>
-Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
-             af_mat_prop optLhs, af_mat_prop optRhs)
+template<typename T> T
+conj(T  x) { return x; }
+
+template<> cfloat  conj<cfloat> (cfloat  c) { return std::conj(c); }
+template<> cdouble conj<cdouble>(cdouble c) { return std::conj(c); }
+
+template<typename T, bool conjugate, bool both_conjugate>
+Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
+              af_mat_prop optLhs, af_mat_prop optRhs)
 {
     int N = lhs.dims()[0];
 
@@ -196,16 +202,33 @@ Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
     const T *pR = rhs.get();
 
     for(int i = 0; i < N; i++)
-        out += pL[i] * pR[i];
+        out += (conjugate ? cpu::conj(pL[i]) : pL[i]) * pR[i];
+
+    if(both_conjugate) out = cpu::conj(out);
 
     return createValueArray(af::dim4(1), out);
 }
 
+template<typename T>
+Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
+             af_mat_prop optLhs, af_mat_prop optRhs)
+{
+    if(optLhs == AF_MAT_CONJ && optRhs == AF_MAT_CONJ) {
+        return dot_<T, false, true>(lhs, rhs, optLhs, optRhs);
+    } else if (optLhs == AF_MAT_CONJ && optRhs == AF_MAT_NONE) {
+        return dot_<T, true, false>(lhs, rhs, optLhs, optRhs);
+    } else if (optLhs == AF_MAT_NONE && optRhs == AF_MAT_CONJ) {
+        return dot_<T, true, false>(rhs, lhs, optRhs, optLhs);
+    } else {
+        return dot_<T, false, false>(lhs, rhs, optLhs, optRhs);
+    }
+}
+
 #undef BT
 #undef REINTEPRET_CAST
 
 #define INSTANTIATE_BLAS(TYPE)                                                          \
-    template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs,  \
+    template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs,   \
                                       af_mat_prop optLhs, af_mat_prop optRhs);
 
 INSTANTIATE_BLAS(float)
@@ -213,8 +236,8 @@ INSTANTIATE_BLAS(cfloat)
 INSTANTIATE_BLAS(double)
 INSTANTIATE_BLAS(cdouble)
 
-#define INSTANTIATE_DOT(TYPE)                                                       \
-    template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
+#define INSTANTIATE_DOT(TYPE)                                                               \
+    template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs,          \
                                    af_mat_prop optLhs, af_mat_prop optRhs);
 
 INSTANTIATE_DOT(float)
diff --git a/src/backend/cuda/blas.cpp b/src/backend/cuda/blas.cpp
index 91a0b62..aaccc65 100644
--- a/src/backend/cuda/blas.cpp
+++ b/src/backend/cuda/blas.cpp
@@ -60,16 +60,6 @@ struct gemv_func_def_t
 };
 
 template<typename T>
-struct dot_func_def_t
-{
-    typedef cublasStatus_t (*dot_func_def)(    cublasHandle_t,
-                                                int,
-                                                const T *,  int,
-                                                const T *,  int,
-                                                T *);
-};
-
-template<typename T>
 struct trsm_func_def_t
 {
     typedef cublasStatus_t (*trsm_func_def)(    cublasHandle_t,
@@ -91,7 +81,6 @@ FUNC##_func();
 #define BLAS_FUNC( FUNC, TYPE, PREFIX )         \
 template<> typename FUNC##_func_def_t<TYPE>::FUNC##_func_def       FUNC##_func<TYPE>()  { return &cublas##PREFIX##FUNC; }
 
-
 BLAS_FUNC_DEF(gemm)
 BLAS_FUNC(gemm, float,  S)
 BLAS_FUNC(gemm, cfloat, C)
@@ -104,24 +93,54 @@ BLAS_FUNC(gemv, cfloat, C)
 BLAS_FUNC(gemv, double, D)
 BLAS_FUNC(gemv, cdouble,Z)
 
-BLAS_FUNC_DEF(dot)
-BLAS_FUNC(dot, float,  S)
-BLAS_FUNC(dot, double, D)
-
 BLAS_FUNC_DEF(trsm)
 BLAS_FUNC(trsm, float,  S)
 BLAS_FUNC(trsm, cfloat, C)
 BLAS_FUNC(trsm, double, D)
 BLAS_FUNC(trsm, cdouble,Z)
 
+#undef BLAS_FUNC
+#undef BLAS_FUNC_DEF
+
+template<typename T, bool conjugate>
+struct dot_func_def_t
+{
+    typedef cublasStatus_t (*dot_func_def)(    cublasHandle_t,
+                                                int,
+                                                const T *,  int,
+                                                const T *,  int,
+                                                T *);
+};
+
+#define BLAS_FUNC_DEF( FUNC )                                   \
+template<typename T, bool conjugate>                            \
+typename FUNC##_func_def_t<T, conjugate>::FUNC##_func_def       \
+FUNC##_func();
+
+#define BLAS_FUNC( FUNC, TYPE, CONJUGATE, PREFIX )                           \
+template<> typename FUNC##_func_def_t<TYPE, CONJUGATE>::FUNC##_func_def      \
+FUNC##_func<TYPE, CONJUGATE>()  { return &cublas##PREFIX##FUNC; }
+
+BLAS_FUNC_DEF(dot)
+BLAS_FUNC(dot, float,  true,  S)
+BLAS_FUNC(dot, double, true,  D)
+BLAS_FUNC(dot, float,  false, S)
+BLAS_FUNC(dot, double, false, D)
 
 #undef BLAS_FUNC
-#define BLAS_FUNC( FUNC, TYPE, PREFIX, SUFFIX)         \
-template<> typename FUNC##_func_def_t<TYPE>::FUNC##_func_def       FUNC##_func<TYPE>()  { return &cublas##PREFIX##FUNC##SUFFIX; }
+
+#define BLAS_FUNC( FUNC, TYPE, CONJUGATE, PREFIX, SUFFIX)                \
+template<> typename FUNC##_func_def_t<TYPE, CONJUGATE>::FUNC##_func_def  \
+FUNC##_func<TYPE, CONJUGATE>()  { return &cublas##PREFIX##FUNC##SUFFIX; }
 
 BLAS_FUNC_DEF(dot)
-BLAS_FUNC(dot, cfloat,  C, u)
-BLAS_FUNC(dot, cdouble, Z, u)
+BLAS_FUNC(dot, cfloat,  true , C, c)
+BLAS_FUNC(dot, cdouble, true , Z, c)
+BLAS_FUNC(dot, cfloat,  false, C, u)
+BLAS_FUNC(dot, cdouble, false, Z, u)
+
+#undef BLAS_FUNC
+#undef BLAS_FUNC_DEF
 
 using namespace std;
 
@@ -178,21 +197,40 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
 
 }
 
-template<typename T>
-Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
-             af_mat_prop optLhs, af_mat_prop optRhs)
+template<typename T, bool conjugate, bool both_conjugate>
+Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
+              af_mat_prop optLhs, af_mat_prop optRhs)
 {
     int N = lhs.dims()[0];
 
     T out;
 
-    CUBLAS_CHECK(dot_func<T>()(getHandle(),
-                               N,
-                               lhs.get(), lhs.strides()[0],
-                               rhs.get(), rhs.strides()[0],
-                               &out));
+    CUBLAS_CHECK((dot_func<T, conjugate>()(
+                 getHandle(),
+                 N,
+                 lhs.get(), lhs.strides()[0],
+                 rhs.get(), rhs.strides()[0],
+                 &out)));
+
+    if(both_conjugate)
+        return createValueArray(af::dim4(1), conj(out));
+    else
+        return createValueArray(af::dim4(1), out);
+}
 
-    return createValueArray(af::dim4(1), out);
+template<typename T>
+Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
+             af_mat_prop optLhs, af_mat_prop optRhs)
+{
+    if(optLhs == AF_MAT_CONJ && optRhs == AF_MAT_CONJ) {
+        return dot_<T, false, true>(lhs, rhs, optLhs, optRhs);
+    } else if (optLhs == AF_MAT_CONJ && optRhs == AF_MAT_NONE) {
+        return dot_<T, true, false>(lhs, rhs, optLhs, optRhs);
+    } else if (optLhs == AF_MAT_NONE && optRhs == AF_MAT_CONJ) {
+        return dot_<T, true, false>(rhs, lhs, optRhs, optLhs);
+    } else {
+        return dot_<T, false, false>(lhs, rhs, optLhs, optRhs);
+    }
 }
 
 template<typename T>
@@ -223,7 +261,7 @@ void trsm(const Array<T> &lhs, Array<T> &rhs, af_mat_prop trans,
 
 
 #define INSTANTIATE_BLAS(TYPE)                                                          \
-    template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs,  \
+    template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs,   \
                                       af_mat_prop optLhs, af_mat_prop optRhs);
 
 INSTANTIATE_BLAS(float)
@@ -231,8 +269,8 @@ INSTANTIATE_BLAS(cfloat)
 INSTANTIATE_BLAS(double)
 INSTANTIATE_BLAS(cdouble)
 
-#define INSTANTIATE_DOT(TYPE)                                                       \
-    template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
+#define INSTANTIATE_DOT(TYPE)                                                           \
+    template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs,      \
                                    af_mat_prop optLhs, af_mat_prop optRhs);
 
 INSTANTIATE_DOT(float)
diff --git a/src/backend/opencl/blas.cpp b/src/backend/opencl/blas.cpp
index bd811db..6173a68 100644
--- a/src/backend/opencl/blas.cpp
+++ b/src/backend/opencl/blas.cpp
@@ -18,6 +18,7 @@
 #include <err_common.hpp>
 #include <err_clblas.hpp>
 #include <math.hpp>
+#include <transpose.hpp>
 
 namespace opencl
 {
@@ -34,10 +35,10 @@ toClblasTranspose(af_mat_prop opt)
 {
     clblasTranspose out = clblasNoTrans;
     switch(opt) {
-        case AF_MAT_NONE        : out = clblasNoTrans;   break;
-        case AF_MAT_TRANS           : out = clblasTrans;     break;
-        case AF_MAT_CTRANS : out = clblasConjTrans; break;
-        default                     : AF_ERROR("INVALID af_mat_prop", AF_ERR_ARG);
+        case AF_MAT_NONE    : out = clblasNoTrans;   break;
+        case AF_MAT_TRANS   : out = clblasTrans;     break;
+        case AF_MAT_CTRANS  : out = clblasConjTrans; break;
+        default             : AF_ERROR("INVALID af_mat_prop", AF_ERR_ARG);
     }
     return out;
 }
@@ -67,24 +68,43 @@ BLAS_FUNC(gemv, double,     D)
 BLAS_FUNC(gemv, cfloat,     C)
 BLAS_FUNC(gemv, cdouble,    Z)
 
+#undef BLAS_FUNC_DEF
+#undef BLAS_FUNC
+
+#define BLAS_FUNC_DEF(NAME)                                             \
+template<typename T, bool conjugate>                                    \
+struct NAME##_func;
+
+#define BLAS_FUNC(NAME, TYPE, CONJUGATE, PREFIX)                        \
+template<>                                                              \
+struct NAME##_func<TYPE, CONJUGATE>                                     \
+{                                                                       \
+    template<typename... Args>                                          \
+    clblasStatus                                                        \
+    operator() (Args... args) { return clblas##PREFIX##NAME(args...); } \
+};
+
 BLAS_FUNC_DEF( dot )
-BLAS_FUNC(dot, float,       S)
-BLAS_FUNC(dot, double,      D)
+BLAS_FUNC(dot, float,  false, S)
+BLAS_FUNC(dot, double, false, D)
+BLAS_FUNC(dot, float,  true , S)
+BLAS_FUNC(dot, double, true , D)
 
-#undef BLAS_FUNC_DEF
 #undef BLAS_FUNC
 
-#define BLAS_FUNC(NAME, TYPE, PREFIX, SUFFIX)					\
+#define BLAS_FUNC(NAME, TYPE, CONJUGATE, PREFIX, SUFFIX)                \
 template<>                                                              \
-struct NAME##_func<TYPE>                                                \
+struct NAME##_func<TYPE, CONJUGATE>                                     \
 {                                                                       \
     template<typename... Args>                                          \
     clblasStatus                                                        \
     operator() (Args... args) { return clblas##PREFIX##NAME##SUFFIX(args...); } \
 };
 
-BLAS_FUNC(dot, cfloat,       C, u)
-BLAS_FUNC(dot, cdouble,      Z, u)
+BLAS_FUNC(dot, cfloat,  true , C, c)
+BLAS_FUNC(dot, cdouble, true , Z, c)
+BLAS_FUNC(dot, cfloat,  false, C, u)
+BLAS_FUNC(dot, cdouble, false, Z, u)
 
 #undef BLAS_FUNC_DEF
 #undef BLAS_FUNC
@@ -148,16 +168,16 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
     return out;
 }
 
-template<typename T>
-Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
-             af_mat_prop optLhs, af_mat_prop optRhs)
+template<typename T, bool conjugate, bool both_conjugate>
+Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
+              af_mat_prop optLhs, af_mat_prop optRhs)
 {
     initBlas();
 
     int N = lhs.dims()[0];
-    dot_func<T> dot;
+    dot_func<T, conjugate> dot;
     cl::Event event;
-    auto out = createEmptyArray<T>(af::dim4(1));
+    Array<T> out = createEmptyArray<T>(af::dim4(1));
     cl::Buffer scratch(getContext(), CL_MEM_READ_WRITE, sizeof(T) * N);
     CLBLAS_CHECK(
         dot(N,
@@ -167,11 +187,30 @@ Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
             scratch(),
             1, &getQueue()(), 0, nullptr, &event())
         );
+
+    if(both_conjugate)
+        transpose_inplace<T>(out, true);
+
     return out;
 }
 
+template<typename T>
+Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
+             af_mat_prop optLhs, af_mat_prop optRhs)
+{
+    if(optLhs == AF_MAT_CONJ && optRhs == AF_MAT_CONJ) {
+        return dot_<T, false, true>(lhs, rhs, optLhs, optRhs);
+    } else if (optLhs == AF_MAT_CONJ && optRhs == AF_MAT_NONE) {
+        return dot_<T, true, false>(lhs, rhs, optLhs, optRhs);
+    } else if (optLhs == AF_MAT_NONE && optRhs == AF_MAT_CONJ) {
+        return dot_<T, true, false>(rhs, lhs, optRhs, optLhs);
+    } else {
+        return dot_<T, false, false>(lhs, rhs, optLhs, optRhs);
+    }
+}
+
 #define INSTANTIATE_BLAS(TYPE)                                                          \
-    template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs,  \
+    template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs,   \
                     af_mat_prop optLhs, af_mat_prop optRhs);
 
 INSTANTIATE_BLAS(float)
@@ -180,13 +219,9 @@ INSTANTIATE_BLAS(double)
 INSTANTIATE_BLAS(cdouble)
 
 #define INSTANTIATE_DOT(TYPE)                                                       \
-    template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
+    template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs,  \
                                    af_mat_prop optLhs, af_mat_prop optRhs);
 
-template<typename T>
-Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
-              af_mat_prop optLhs, af_mat_prop optRhs);
-
 INSTANTIATE_DOT(float)
 INSTANTIATE_DOT(double)
 INSTANTIATE_DOT(cfloat)

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list