[arrayfire] 43/408: Added options for dotc and dotu to dot function
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:11:13 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.
commit b62696722ba0983199bc7cb7a711c00e86625e03
Author: Shehzan Mohammed <shehzan at arrayfire.com>
Date: Thu Jun 25 17:29:44 2015 -0400
Added options for dotc and dotu to dot function
---
include/af/blas.h | 11 +++++
include/af/defines.h | 1 +
src/api/c/blas.cpp | 16 +++----
src/backend/cpu/blas.cpp | 43 ++++++++++++++-----
src/backend/cuda/blas.cpp | 100 ++++++++++++++++++++++++++++++--------------
src/backend/opencl/blas.cpp | 79 ++++++++++++++++++++++++----------
6 files changed, 179 insertions(+), 71 deletions(-)
diff --git a/include/af/blas.h b/include/af/blas.h
index 8aa59e4..1a89ad2 100644
--- a/include/af/blas.h
+++ b/include/af/blas.h
@@ -138,8 +138,19 @@ namespace af
af_print(dot(x,y));
}
+ \param[in] lhs The array object on the left hand side
+ \param[in] rhs The array object on the right hand side
+ \param[in] optLhs Options for lhs. Currently only \ref AF_MAT_NONE and
+ \ref AF_MAT_CONJ are supported.
+ \param[in] optRhs Options for rhs. Currently only \ref AF_MAT_NONE and \ref AF_MAT_CONJ are supported.
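+ \return out = dot(lhs, rhs), the dot product of lhs and rhs as a single-element array
+
+ \note optLhs and optRhs can only be one of \ref AF_MAT_NONE or \ref AF_MAT_CONJ
+ \note optLhs = \ref AF_MAT_CONJ and optRhs = \ref AF_MAT_NONE computes the conjugate dot product, sum(conj(lhs) * rhs).
\note This function is not supported in GFOR
+ \note optLhs = \ref AF_MAT_NONE and optRhs = \ref AF_MAT_CONJ computes sum(lhs * conj(rhs)); setting both to \ref AF_MAT_CONJ computes conj(sum(lhs * rhs)).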
+
\ingroup blas_func_dot
*/
AFAPI array dot (const array &lhs, const array &rhs,
diff --git a/include/af/defines.h b/include/af/defines.h
index db22b5f..a2aa7a9 100644
--- a/include/af/defines.h
+++ b/include/af/defines.h
@@ -244,6 +244,7 @@ typedef enum {
AF_MAT_NONE = 0, ///< Default
AF_MAT_TRANS = 1, ///< Data needs to be transposed
AF_MAT_CTRANS = 2, ///< Data needs to be conjugate tansposed
+ AF_MAT_CONJ = 4, ///< Data needs to be conjugate
AF_MAT_UPPER = 32, ///< Matrix is upper triangular
AF_MAT_LOWER = 64, ///< Matrix is lower triangular
AF_MAT_DIAG_UNIT = 128, ///< Matrix diagonal contains unitary values
diff --git a/src/api/c/blas.cpp b/src/api/c/blas.cpp
index 6d5f3aa..21fb44f 100644
--- a/src/api/c/blas.cpp
+++ b/src/api/c/blas.cpp
@@ -93,11 +93,11 @@ af_err af_dot( af_array *out,
ArrayInfo lhsInfo = getInfo(lhs);
ArrayInfo rhsInfo = getInfo(rhs);
- if (optLhs != AF_MAT_NONE) {
+ if (optLhs != AF_MAT_NONE && optLhs != AF_MAT_CONJ) {
AF_ERROR("Using this property is not yet supported in dot", AF_ERR_NOT_SUPPORTED);
}
- if (optRhs != AF_MAT_NONE) {
+ if (optRhs != AF_MAT_NONE && optRhs != AF_MAT_CONJ) {
AF_ERROR("Using this property is not yet supported in dot", AF_ERR_NOT_SUPPORTED);
}
@@ -105,8 +105,8 @@ af_err af_dot( af_array *out,
af_dtype lhs_type = lhsInfo.getType();
af_dtype rhs_type = rhsInfo.getType();
- if (lhsInfo.ndims() > 2 ||
- rhsInfo.ndims() > 2) {
+ if (lhsInfo.ndims() > 1 ||
+ rhsInfo.ndims() > 1) {
AF_ERROR("dot can not be used in batch mode", AF_ERR_BATCH);
}
@@ -115,10 +115,10 @@ af_err af_dot( af_array *out,
af_array output = 0;
switch(lhs_type) {
- case f32: output = dot<float >(lhs, rhs, optLhs, optRhs); break;
- case c32: output = dot<cfloat >(lhs, rhs, optLhs, optRhs); break;
- case f64: output = dot<double >(lhs, rhs, optLhs, optRhs); break;
- case c64: output = dot<cdouble>(lhs, rhs, optLhs, optRhs); break;
+ case f32: output = dot<float >(lhs, rhs, optLhs, optRhs); break;
+ case c32: output = dot<cfloat >(lhs, rhs, optLhs, optRhs); break;
+ case f64: output = dot<double >(lhs, rhs, optLhs, optRhs); break;
+ case c64: output = dot<cdouble>(lhs, rhs, optLhs, optRhs); break;
default: TYPE_ERROR(1, lhs_type);
}
std::swap(*out, output);
diff --git a/src/backend/cpu/blas.cpp b/src/backend/cpu/blas.cpp
index 65fc19f..44094e3 100644
--- a/src/backend/cpu/blas.cpp
+++ b/src/backend/cpu/blas.cpp
@@ -102,9 +102,9 @@ toCblasTranspose(af_mat_prop opt)
CBLAS_TRANSPOSE out = CblasNoTrans;
switch(opt) {
case AF_MAT_NONE : out = CblasNoTrans; break;
- case AF_MAT_TRANS : out = CblasTrans; break;
- case AF_MAT_CTRANS : out = CblasConjTrans; break;
- default : AF_ERROR("INVALID af_mat_prop", AF_ERR_ARG);
+ case AF_MAT_TRANS : out = CblasTrans; break;
+ case AF_MAT_CTRANS : out = CblasConjTrans; break;
+ default : AF_ERROR("INVALID af_mat_prop", AF_ERR_ARG);
}
return out;
}
@@ -185,9 +185,15 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
return out;
}
-template<typename T>
-Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
- af_mat_prop optLhs, af_mat_prop optRhs)
+template<typename T> T // Conjugate helper: identity for every T except the cfloat/cdouble specializations below
+conj(T x) { return x; }
+
+template<> cfloat conj<cfloat> (cfloat c) { return std::conj(c); } // complex types defer to std::conj
+template<> cdouble conj<cdouble>(cdouble c) { return std::conj(c); }
+
+template<typename T, bool conjugate, bool both_conjugate>
+Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
+ af_mat_prop optLhs, af_mat_prop optRhs)
{
int N = lhs.dims()[0];
@@ -196,16 +202,33 @@ Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
const T *pR = rhs.get();
for(int i = 0; i < N; i++)
- out += pL[i] * pR[i];
+ out += (conjugate ? cpu::conj(pL[i]) : pL[i]) * pR[i];
+
+ if(both_conjugate) out = cpu::conj(out);
return createValueArray(af::dim4(1), out);
}
+template<typename T> // Maps the (optLhs, optRhs) conjugation flags onto the dot_<T, conjugate, both_conjugate> kernel
+Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
+ af_mat_prop optLhs, af_mat_prop optRhs)
+{
+ if(optLhs == AF_MAT_CONJ && optRhs == AF_MAT_CONJ) { // sum(conj(l) * conj(r)) == conj(sum(l * r))
+ return dot_<T, false, true>(lhs, rhs, optLhs, optRhs);
+ } else if (optLhs == AF_MAT_CONJ && optRhs == AF_MAT_NONE) { // sum(conj(l) * r)
+ return dot_<T, true, false>(lhs, rhs, optLhs, optRhs);
+ } else if (optLhs == AF_MAT_NONE && optRhs == AF_MAT_CONJ) { // sum(l * conj(r)) == sum(conj(r) * l): swap operands
+ return dot_<T, true, false>(rhs, lhs, optRhs, optLhs);
+ } else { // both AF_MAT_NONE: plain sum(l * r)
+ return dot_<T, false, false>(lhs, rhs, optLhs, optRhs);
+ }
+}
+
#undef BT
#undef REINTEPRET_CAST
#define INSTANTIATE_BLAS(TYPE) \
- template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
+ template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
af_mat_prop optLhs, af_mat_prop optRhs);
INSTANTIATE_BLAS(float)
@@ -213,8 +236,8 @@ INSTANTIATE_BLAS(cfloat)
INSTANTIATE_BLAS(double)
INSTANTIATE_BLAS(cdouble)
-#define INSTANTIATE_DOT(TYPE) \
- template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
+#define INSTANTIATE_DOT(TYPE) \
+ template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
af_mat_prop optLhs, af_mat_prop optRhs);
INSTANTIATE_DOT(float)
diff --git a/src/backend/cuda/blas.cpp b/src/backend/cuda/blas.cpp
index 91a0b62..aaccc65 100644
--- a/src/backend/cuda/blas.cpp
+++ b/src/backend/cuda/blas.cpp
@@ -60,16 +60,6 @@ struct gemv_func_def_t
};
template<typename T>
-struct dot_func_def_t
-{
- typedef cublasStatus_t (*dot_func_def)( cublasHandle_t,
- int,
- const T *, int,
- const T *, int,
- T *);
-};
-
-template<typename T>
struct trsm_func_def_t
{
typedef cublasStatus_t (*trsm_func_def)( cublasHandle_t,
@@ -91,7 +81,6 @@ FUNC##_func();
#define BLAS_FUNC( FUNC, TYPE, PREFIX ) \
template<> typename FUNC##_func_def_t<TYPE>::FUNC##_func_def FUNC##_func<TYPE>() { return &cublas##PREFIX##FUNC; }
-
BLAS_FUNC_DEF(gemm)
BLAS_FUNC(gemm, float, S)
BLAS_FUNC(gemm, cfloat, C)
@@ -104,24 +93,54 @@ BLAS_FUNC(gemv, cfloat, C)
BLAS_FUNC(gemv, double, D)
BLAS_FUNC(gemv, cdouble,Z)
-BLAS_FUNC_DEF(dot)
-BLAS_FUNC(dot, float, S)
-BLAS_FUNC(dot, double, D)
-
BLAS_FUNC_DEF(trsm)
BLAS_FUNC(trsm, float, S)
BLAS_FUNC(trsm, cfloat, C)
BLAS_FUNC(trsm, double, D)
BLAS_FUNC(trsm, cdouble,Z)
+#undef BLAS_FUNC
+#undef BLAS_FUNC_DEF
+
+template<typename T, bool conjugate>
+struct dot_func_def_t
+{
+ typedef cublasStatus_t (*dot_func_def)( cublasHandle_t,
+ int,
+ const T *, int,
+ const T *, int,
+ T *);
+};
+
+#define BLAS_FUNC_DEF( FUNC ) \
+template<typename T, bool conjugate> \
+typename FUNC##_func_def_t<T, conjugate>::FUNC##_func_def \
+FUNC##_func();
+
+#define BLAS_FUNC( FUNC, TYPE, CONJUGATE, PREFIX ) \
+template<> typename FUNC##_func_def_t<TYPE, CONJUGATE>::FUNC##_func_def \
+FUNC##_func<TYPE, CONJUGATE>() { return &cublas##PREFIX##FUNC; }
+
+BLAS_FUNC_DEF(dot)
+BLAS_FUNC(dot, float, true, S)
+BLAS_FUNC(dot, double, true, D)
+BLAS_FUNC(dot, float, false, S)
+BLAS_FUNC(dot, double, false, D)
#undef BLAS_FUNC
-#define BLAS_FUNC( FUNC, TYPE, PREFIX, SUFFIX) \
-template<> typename FUNC##_func_def_t<TYPE>::FUNC##_func_def FUNC##_func<TYPE>() { return &cublas##PREFIX##FUNC##SUFFIX; }
+
+#define BLAS_FUNC( FUNC, TYPE, CONJUGATE, PREFIX, SUFFIX) \
+template<> typename FUNC##_func_def_t<TYPE, CONJUGATE>::FUNC##_func_def \
+FUNC##_func<TYPE, CONJUGATE>() { return &cublas##PREFIX##FUNC##SUFFIX; }
BLAS_FUNC_DEF(dot)
-BLAS_FUNC(dot, cfloat, C, u)
-BLAS_FUNC(dot, cdouble, Z, u)
+BLAS_FUNC(dot, cfloat, true , C, c)
+BLAS_FUNC(dot, cdouble, true , Z, c)
+BLAS_FUNC(dot, cfloat, false, C, u)
+BLAS_FUNC(dot, cdouble, false, Z, u)
+
+#undef BLAS_FUNC
+#undef BLAS_FUNC_DEF
using namespace std;
@@ -178,21 +197,40 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
}
-template<typename T>
-Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
- af_mat_prop optLhs, af_mat_prop optRhs)
+template<typename T, bool conjugate, bool both_conjugate> // conjugate: pick the conjugating routine; both_conjugate: conjugate the result
+Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
+ af_mat_prop optLhs, af_mat_prop optRhs)
{
int N = lhs.dims()[0];
T out;
- CUBLAS_CHECK(dot_func<T>()(getHandle(),
- N,
- lhs.get(), lhs.strides()[0],
- rhs.get(), rhs.strides()[0],
- &out));
+ CUBLAS_CHECK((dot_func<T, conjugate>()( // complex T: ?dotc if conjugate, else ?dotu; real T: ?dot either way
+ getHandle(),
+ N,
+ lhs.get(), lhs.strides()[0],
+ rhs.get(), rhs.strides()[0],
+ &out)));
+
+ if(both_conjugate) // sum(conj(l) * conj(r)) == conj(sum(l * r))
+ return createValueArray(af::dim4(1), conj(out));
+ else
+ return createValueArray(af::dim4(1), out);
+}
- return createValueArray(af::dim4(1), out);
+template<typename T> // Maps the (optLhs, optRhs) conjugation flags onto the dot_<T, conjugate, both_conjugate> kernel
+Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
+ af_mat_prop optLhs, af_mat_prop optRhs)
+{
+ if(optLhs == AF_MAT_CONJ && optRhs == AF_MAT_CONJ) { // sum(conj(l) * conj(r)) == conj(sum(l * r))
+ return dot_<T, false, true>(lhs, rhs, optLhs, optRhs);
+ } else if (optLhs == AF_MAT_CONJ && optRhs == AF_MAT_NONE) { // sum(conj(l) * r)
+ return dot_<T, true, false>(lhs, rhs, optLhs, optRhs);
+ } else if (optLhs == AF_MAT_NONE && optRhs == AF_MAT_CONJ) { // sum(l * conj(r)) == sum(conj(r) * l): swap operands
+ return dot_<T, true, false>(rhs, lhs, optRhs, optLhs);
+ } else { // both AF_MAT_NONE: plain sum(l * r)
+ return dot_<T, false, false>(lhs, rhs, optLhs, optRhs);
+ }
}
template<typename T>
@@ -223,7 +261,7 @@ void trsm(const Array<T> &lhs, Array<T> &rhs, af_mat_prop trans,
#define INSTANTIATE_BLAS(TYPE) \
- template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
+ template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
af_mat_prop optLhs, af_mat_prop optRhs);
INSTANTIATE_BLAS(float)
@@ -231,8 +269,8 @@ INSTANTIATE_BLAS(cfloat)
INSTANTIATE_BLAS(double)
INSTANTIATE_BLAS(cdouble)
-#define INSTANTIATE_DOT(TYPE) \
- template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
+#define INSTANTIATE_DOT(TYPE) \
+ template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
af_mat_prop optLhs, af_mat_prop optRhs);
INSTANTIATE_DOT(float)
diff --git a/src/backend/opencl/blas.cpp b/src/backend/opencl/blas.cpp
index bd811db..6173a68 100644
--- a/src/backend/opencl/blas.cpp
+++ b/src/backend/opencl/blas.cpp
@@ -18,6 +18,7 @@
#include <err_common.hpp>
#include <err_clblas.hpp>
#include <math.hpp>
+#include <transpose.hpp>
namespace opencl
{
@@ -34,10 +35,10 @@ toClblasTranspose(af_mat_prop opt)
{
clblasTranspose out = clblasNoTrans;
switch(opt) {
- case AF_MAT_NONE : out = clblasNoTrans; break;
- case AF_MAT_TRANS : out = clblasTrans; break;
- case AF_MAT_CTRANS : out = clblasConjTrans; break;
- default : AF_ERROR("INVALID af_mat_prop", AF_ERR_ARG);
+ case AF_MAT_NONE : out = clblasNoTrans; break;
+ case AF_MAT_TRANS : out = clblasTrans; break;
+ case AF_MAT_CTRANS : out = clblasConjTrans; break;
+ default : AF_ERROR("INVALID af_mat_prop", AF_ERR_ARG);
}
return out;
}
@@ -67,24 +68,43 @@ BLAS_FUNC(gemv, double, D)
BLAS_FUNC(gemv, cfloat, C)
BLAS_FUNC(gemv, cdouble, Z)
+#undef BLAS_FUNC_DEF
+#undef BLAS_FUNC
+
+#define BLAS_FUNC_DEF(NAME) \
+template<typename T, bool conjugate> \
+struct NAME##_func;
+
+#define BLAS_FUNC(NAME, TYPE, CONJUGATE, PREFIX) \
+template<> \
+struct NAME##_func<TYPE, CONJUGATE> \
+{ \
+ template<typename... Args> \
+ clblasStatus \
+ operator() (Args... args) { return clblas##PREFIX##NAME(args...); } \
+};
+
BLAS_FUNC_DEF( dot )
-BLAS_FUNC(dot, float, S)
-BLAS_FUNC(dot, double, D)
+BLAS_FUNC(dot, float, false, S)
+BLAS_FUNC(dot, double, false, D)
+BLAS_FUNC(dot, float, true , S)
+BLAS_FUNC(dot, double, true , D)
-#undef BLAS_FUNC_DEF
#undef BLAS_FUNC
-#define BLAS_FUNC(NAME, TYPE, PREFIX, SUFFIX) \
+#define BLAS_FUNC(NAME, TYPE, CONJUGATE, PREFIX, SUFFIX) \
template<> \
-struct NAME##_func<TYPE> \
+struct NAME##_func<TYPE, CONJUGATE> \
{ \
template<typename... Args> \
clblasStatus \
operator() (Args... args) { return clblas##PREFIX##NAME##SUFFIX(args...); } \
};
-BLAS_FUNC(dot, cfloat, C, u)
-BLAS_FUNC(dot, cdouble, Z, u)
+BLAS_FUNC(dot, cfloat, true , C, c)
+BLAS_FUNC(dot, cdouble, true , Z, c)
+BLAS_FUNC(dot, cfloat, false, C, u)
+BLAS_FUNC(dot, cdouble, false, Z, u)
#undef BLAS_FUNC_DEF
#undef BLAS_FUNC
@@ -148,16 +168,16 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
return out;
}
-template<typename T>
-Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
- af_mat_prop optLhs, af_mat_prop optRhs)
+template<typename T, bool conjugate, bool both_conjugate>
+Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
+ af_mat_prop optLhs, af_mat_prop optRhs)
{
initBlas();
int N = lhs.dims()[0];
- dot_func<T> dot;
+ dot_func<T, conjugate> dot;
cl::Event event;
- auto out = createEmptyArray<T>(af::dim4(1));
+ Array<T> out = createEmptyArray<T>(af::dim4(1));
cl::Buffer scratch(getContext(), CL_MEM_READ_WRITE, sizeof(T) * N);
CLBLAS_CHECK(
dot(N,
@@ -167,11 +187,30 @@ Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
scratch(),
1, &getQueue()(), 0, nullptr, &event())
);
+
+ if(both_conjugate)
+ transpose_inplace<T>(out, true);
+
return out;
}
+template<typename T> // Maps the (optLhs, optRhs) conjugation flags onto the dot_<T, conjugate, both_conjugate> kernel
+Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
+ af_mat_prop optLhs, af_mat_prop optRhs)
+{
+ if(optLhs == AF_MAT_CONJ && optRhs == AF_MAT_CONJ) { // sum(conj(l) * conj(r)) == conj(sum(l * r))
+ return dot_<T, false, true>(lhs, rhs, optLhs, optRhs);
+ } else if (optLhs == AF_MAT_CONJ && optRhs == AF_MAT_NONE) { // sum(conj(l) * r)
+ return dot_<T, true, false>(lhs, rhs, optLhs, optRhs);
+ } else if (optLhs == AF_MAT_NONE && optRhs == AF_MAT_CONJ) { // sum(l * conj(r)) == sum(conj(r) * l): swap operands
+ return dot_<T, true, false>(rhs, lhs, optRhs, optLhs);
+ } else { // both AF_MAT_NONE: plain sum(l * r)
+ return dot_<T, false, false>(lhs, rhs, optLhs, optRhs);
+ }
+}
+
#define INSTANTIATE_BLAS(TYPE) \
- template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
+ template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
af_mat_prop optLhs, af_mat_prop optRhs);
INSTANTIATE_BLAS(float)
@@ -180,13 +219,9 @@ INSTANTIATE_BLAS(double)
INSTANTIATE_BLAS(cdouble)
#define INSTANTIATE_DOT(TYPE) \
- template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
+ template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
af_mat_prop optLhs, af_mat_prop optRhs);
-template<typename T>
-Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
- af_mat_prop optLhs, af_mat_prop optRhs);
-
INSTANTIATE_DOT(float)
INSTANTIATE_DOT(double)
INSTANTIATE_DOT(cfloat)
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git
More information about the debian-science-commits
mailing list