[arrayfire] 02/408: Implement CUDA complex dot product

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:11:01 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.

commit 3053edd8af2a5b747ee0444f940a8263cd3a27f0
Author: Filipe Maia <filipe.c.maia at gmail.com>
Date:   Sat Jun 13 19:03:05 2015 +0200

    Implement CUDA complex dot product
---
 src/api/c/blas.cpp        |  4 ++--
 src/backend/cuda/blas.cpp | 12 ++++++++++++
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/api/c/blas.cpp b/src/api/c/blas.cpp
index 203c1bb..6d5f3aa 100644
--- a/src/api/c/blas.cpp
+++ b/src/api/c/blas.cpp
@@ -116,9 +116,9 @@ af_err af_dot(      af_array *out,
 
         switch(lhs_type) {
         case f32: output = dot<float  >(lhs, rhs, optLhs, optRhs);   break;
-            //case c32: output = dot<cfloat >(lhs, rhs, optLhs, optRhs);   break;
+	case c32: output = dot<cfloat >(lhs, rhs, optLhs, optRhs);   break;
         case f64: output = dot<double >(lhs, rhs, optLhs, optRhs);   break;
-            //case c64: output = dot<cdouble>(lhs, rhs, optLhs, optRhs);   break;
+	case c64: output = dot<cdouble>(lhs, rhs, optLhs, optRhs);   break;
         default:  TYPE_ERROR(1, lhs_type);
         }
         std::swap(*out, output);
diff --git a/src/backend/cuda/blas.cpp b/src/backend/cuda/blas.cpp
index 9ed9cdf..91a0b62 100644
--- a/src/backend/cuda/blas.cpp
+++ b/src/backend/cuda/blas.cpp
@@ -91,6 +91,7 @@ FUNC##_func();
 #define BLAS_FUNC( FUNC, TYPE, PREFIX )         \
 template<> typename FUNC##_func_def_t<TYPE>::FUNC##_func_def       FUNC##_func<TYPE>()  { return &cublas##PREFIX##FUNC; }
 
+
 BLAS_FUNC_DEF(gemm)
 BLAS_FUNC(gemm, float,  S)
 BLAS_FUNC(gemm, cfloat, C)
@@ -113,6 +114,15 @@ BLAS_FUNC(trsm, cfloat, C)
 BLAS_FUNC(trsm, double, D)
 BLAS_FUNC(trsm, cdouble,Z)
 
+
+#undef BLAS_FUNC
+#define BLAS_FUNC( FUNC, TYPE, PREFIX, SUFFIX)         \
+template<> typename FUNC##_func_def_t<TYPE>::FUNC##_func_def       FUNC##_func<TYPE>()  { return &cublas##PREFIX##FUNC##SUFFIX; }
+
+BLAS_FUNC_DEF(dot)
+BLAS_FUNC(dot, cfloat,  C, u)
+BLAS_FUNC(dot, cdouble, Z, u)
+
 using namespace std;
 
 template<typename T>
@@ -227,6 +237,8 @@ INSTANTIATE_BLAS(cdouble)
 
 INSTANTIATE_DOT(float)
 INSTANTIATE_DOT(double)
+INSTANTIATE_DOT(cfloat)
+INSTANTIATE_DOT(cdouble)
 
 #define INSTANTIATE_TRSM(TYPE)                                                          \
     template void trsm<TYPE>(const Array<TYPE> &lhs, Array<TYPE> &rhs,                  \

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list