[arrayfire] 02/284: Convert CPU blas to use async queues
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Sun Feb 7 18:59:12 UTC 2016
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/experimental
in repository arrayfire.
commit b94c3df4e1c5288eb1bb985f372e3f8e5910fe3d
Author: Umar Arshad <umar at arrayfire.com>
Date: Sun Aug 9 01:19:34 2015 -0400
Convert CPU blas to use async queues
---
src/backend/cpu/blas.cpp | 73 ++++++++++++++++++++++++++----------------------
1 file changed, 40 insertions(+), 33 deletions(-)
diff --git a/src/backend/cpu/blas.cpp b/src/backend/cpu/blas.cpp
index 0bbd399..8887202 100644
--- a/src/backend/cpu/blas.cpp
+++ b/src/backend/cpu/blas.cpp
@@ -13,6 +13,8 @@
#include <cassert>
#include <err_cpu.hpp>
#include <err_common.hpp>
+#include <platform.hpp>
+#include <async_queue.hpp>
namespace cpu
{
@@ -131,36 +133,38 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
int N = rDims[bColDim];
int K = lDims[aColDim];
- //FIXME: Leaks on errors.
- Array<T> out = createEmptyArray<T>(af::dim4(M, N, 1, 1));
- auto alpha = getScale<T, 1>();
- auto beta = getScale<T, 0>();
-
- dim4 lStrides = lhs.strides();
- dim4 rStrides = rhs.strides();
using BT = typename blas_base<T>::type;
using CBT = const typename blas_base<T>::type;
- if(rDims[bColDim] == 1) {
- N = lDims[aColDim];
- gemv_func<T>()(
- CblasColMajor, lOpts,
- lDims[0], lDims[1],
- alpha,
- reinterpret_cast<CBT*>(lhs.get()), lStrides[1],
- reinterpret_cast<CBT*>(rhs.get()), rStrides[0],
- beta,
- reinterpret_cast<BT*>(out.get()), 1);
- } else {
- gemm_func<T>()(
- CblasColMajor, lOpts, rOpts,
- M, N, K,
- alpha,
- reinterpret_cast<CBT*>(lhs.get()), lStrides[1],
- reinterpret_cast<CBT*>(rhs.get()), rStrides[1],
- beta,
- reinterpret_cast<BT*>(out.get()), out.dims()[0]);
- }
+ Array<T> out = createEmptyArray<T>(af::dim4(M, N, 1, 1));
+ auto func = [=] (Array<T> output, const Array<T> left, const Array<T> right) {
+ auto alpha = getScale<T, 1>();
+ auto beta = getScale<T, 0>();
+
+ dim4 lStrides = left.strides();
+ dim4 rStrides = right.strides();
+
+ if(rDims[bColDim] == 1) {
+ gemv_func<T>()(
+ CblasColMajor, lOpts,
+ lDims[0], lDims[1],
+ alpha,
+ reinterpret_cast<CBT*>(left.get()), lStrides[1],
+ reinterpret_cast<CBT*>(right.get()), rStrides[0],
+ beta,
+ reinterpret_cast<BT*>(output.get()), 1);
+ } else {
+ gemm_func<T>()(
+ CblasColMajor, lOpts, rOpts,
+ M, N, K,
+ alpha,
+ reinterpret_cast<CBT*>(left.get()), lStrides[1],
+ reinterpret_cast<CBT*>(right.get()), rStrides[1],
+ beta,
+ reinterpret_cast<BT*>(output.get()), output.dims()[0]);
+ }
+ };
+ getQueue().enqueue(func, out, lhs, rhs);
return out;
}
@@ -172,7 +176,7 @@ template<> cfloat conj<cfloat> (cfloat c) { return std::conj(c); }
template<> cdouble conj<cdouble>(cdouble c) { return std::conj(c); }
template<typename T, bool conjugate, bool both_conjugate>
-Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
+void dot_(Array<T> output, const Array<T> &lhs, const Array<T> &rhs,
af_mat_prop optLhs, af_mat_prop optRhs)
{
int N = lhs.dims()[0];
@@ -186,22 +190,25 @@ Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
if(both_conjugate) out = cpu::conj(out);
- return createValueArray(af::dim4(1), out);
+ *output.get() = out;
+
}
template<typename T>
Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
af_mat_prop optLhs, af_mat_prop optRhs)
{
+ Array<T> out = createEmptyArray<T>(af::dim4(1));
if(optLhs == AF_MAT_CONJ && optRhs == AF_MAT_CONJ) {
- return dot_<T, false, true>(lhs, rhs, optLhs, optRhs);
+ getQueue().enqueue(dot_<T, false, true>, out, lhs, rhs, optLhs, optRhs);
} else if (optLhs == AF_MAT_CONJ && optRhs == AF_MAT_NONE) {
- return dot_<T, true, false>(lhs, rhs, optLhs, optRhs);
+ getQueue().enqueue(dot_<T, true, false>,out, lhs, rhs, optLhs, optRhs);
} else if (optLhs == AF_MAT_NONE && optRhs == AF_MAT_CONJ) {
- return dot_<T, true, false>(rhs, lhs, optRhs, optLhs);
+ getQueue().enqueue(dot_<T, true, false>,out, rhs, lhs, optRhs, optLhs);
} else {
- return dot_<T, false, false>(lhs, rhs, optLhs, optRhs);
+ getQueue().enqueue(dot_<T, false, false>,out, lhs, rhs, optLhs, optRhs);
}
+ return out;
}
#undef BT
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git
More information about the debian-science-commits mailing list