[arrayfire] 02/408: Implement CUDA complex dot product
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:11:01 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid in repository arrayfire.
commit 3053edd8af2a5b747ee0444f940a8263cd3a27f0
Author: Filipe Maia <filipe.c.maia at gmail.com>
Date: Sat Jun 13 19:03:05 2015 +0200
Implement CUDA complex dot product
---
src/api/c/blas.cpp | 4 ++--
src/backend/cuda/blas.cpp | 12 ++++++++++++
2 files changed, 14 insertions(+), 2 deletions(-)
diff --git a/src/api/c/blas.cpp b/src/api/c/blas.cpp
index 203c1bb..6d5f3aa 100644
--- a/src/api/c/blas.cpp
+++ b/src/api/c/blas.cpp
@@ -116,9 +116,9 @@ af_err af_dot( af_array *out,
switch(lhs_type) {
case f32: output = dot<float >(lhs, rhs, optLhs, optRhs); break;
- //case c32: output = dot<cfloat >(lhs, rhs, optLhs, optRhs); break;
+ case c32: output = dot<cfloat >(lhs, rhs, optLhs, optRhs); break;
case f64: output = dot<double >(lhs, rhs, optLhs, optRhs); break;
- //case c64: output = dot<cdouble>(lhs, rhs, optLhs, optRhs); break;
+ case c64: output = dot<cdouble>(lhs, rhs, optLhs, optRhs); break;
default: TYPE_ERROR(1, lhs_type);
}
std::swap(*out, output);
diff --git a/src/backend/cuda/blas.cpp b/src/backend/cuda/blas.cpp
index 9ed9cdf..91a0b62 100644
--- a/src/backend/cuda/blas.cpp
+++ b/src/backend/cuda/blas.cpp
@@ -91,6 +91,7 @@ FUNC##_func();
#define BLAS_FUNC( FUNC, TYPE, PREFIX ) \
template<> typename FUNC##_func_def_t<TYPE>::FUNC##_func_def FUNC##_func<TYPE>() { return &cublas##PREFIX##FUNC; }
+
BLAS_FUNC_DEF(gemm)
BLAS_FUNC(gemm, float, S)
BLAS_FUNC(gemm, cfloat, C)
@@ -113,6 +114,15 @@ BLAS_FUNC(trsm, cfloat, C)
BLAS_FUNC(trsm, double, D)
BLAS_FUNC(trsm, cdouble,Z)
+
+#undef BLAS_FUNC
+#define BLAS_FUNC( FUNC, TYPE, PREFIX, SUFFIX) \
+template<> typename FUNC##_func_def_t<TYPE>::FUNC##_func_def FUNC##_func<TYPE>() { return &cublas##PREFIX##FUNC##SUFFIX; }
+
+BLAS_FUNC_DEF(dot)
+BLAS_FUNC(dot, cfloat, C, u)
+BLAS_FUNC(dot, cdouble, Z, u)
+
using namespace std;
template<typename T>
@@ -227,6 +237,8 @@ INSTANTIATE_BLAS(cdouble)
INSTANTIATE_DOT(float)
INSTANTIATE_DOT(double)
+INSTANTIATE_DOT(cfloat)
+INSTANTIATE_DOT(cdouble)
#define INSTANTIATE_TRSM(TYPE) \
template void trsm<TYPE>(const Array<TYPE> &lhs, Array<TYPE> &rhs, \
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git
More information about the debian-science-commits mailing list