[arrayfire] 70/408: STYLE: Remove macros; Simplify templates;
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:11:18 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.
commit 1b315f95535d6a55ad2e53c52fba16600db09303
Author: Umar Arshad <umar at arrayfire.com>
Date: Wed Jul 1 20:38:10 2015 -0400
STYLE: Remove macros; Simplify templates;
Works on my machine(TM); Not tested on Windows
---
src/backend/cpu/blas.cpp | 155 ++++++++++++++++++++---------------------------
1 file changed, 67 insertions(+), 88 deletions(-)
diff --git a/src/backend/cpu/blas.cpp b/src/backend/cpu/blas.cpp
index 39fa32b..0bbd399 100644
--- a/src/backend/cpu/blas.cpp
+++ b/src/backend/cpu/blas.cpp
@@ -24,77 +24,81 @@ namespace cpu
using std::remove_const;
using std::conditional;
-template<typename T, typename BT>
+ // Some implementations of BLAS require void* for complex pointers while
+ // others use float*/double*
+#if (defined(OS_WIN) && !defined(USE_MKL)) || defined(IS_OPENBLAS)
+ static const bool cplx_void_ptr = false;
+#else
+ static const bool cplx_void_ptr = true;
+#endif
+
+template<typename T, class Enable = void>
+struct blas_base {
+ using type = typename dtype_traits<T>::base_type;
+};
+
+template<typename T>
+struct blas_base <T, typename enable_if<is_complex<T>::value && cplx_void_ptr>::type> {
+ using type = void;
+};
+
+
+template<typename T>
using cptr_type = typename conditional< is_complex<T>::value,
- const BT *,
+ const typename blas_base<T>::type *,
const T*>::type;
-template<typename T, typename BT>
+template<typename T>
using ptr_type = typename conditional< is_complex<T>::value,
- BT *,
+ typename blas_base<T>::type *,
T*>::type;
-template<typename T, typename BT>
+template<typename T>
using scale_type = typename conditional< is_complex<T>::value,
- const BT *,
+ const typename blas_base<T>::type *,
const T>::type;
-template<typename T, typename BT>
+template<typename T>
using gemm_func_def = void (*)( const CBLAS_ORDER, const CBLAS_TRANSPOSE, const CBLAS_TRANSPOSE,
const blasint, const blasint, const blasint,
- scale_type<T, BT>, cptr_type<T, BT>, const blasint,
- cptr_type<T, BT>, const blasint,
- scale_type<T, BT>, ptr_type<T, BT>, const blasint);
+ scale_type<T>, cptr_type<T>, const blasint,
+ cptr_type<T>, const blasint,
+ scale_type<T>, ptr_type<T>, const blasint);
-template<typename T, typename BT>
+template<typename T>
using gemv_func_def = void (*)( const CBLAS_ORDER, const CBLAS_TRANSPOSE,
const blasint, const blasint,
- scale_type<T, BT>, cptr_type<T, BT>, const blasint,
- cptr_type<T, BT>, const blasint,
- scale_type<T, BT>, ptr_type<T, BT>, const blasint);
+ scale_type<T>, cptr_type<T>, const blasint,
+ cptr_type<T>, const blasint,
+ scale_type<T>, ptr_type<T>, const blasint);
-#define BLAS_FUNC_DEF( FUNC ) \
-template<typename T, typename BT> FUNC##_func_def<T, BT> FUNC##_func();
+#define BLAS_FUNC_DEF( FUNC ) \
+template<typename T> FUNC##_func_def<T> FUNC##_func();
-
-#define BLAS_FUNC( FUNC, TYPE, BASE_TYPE, PREFIX ) \
-template<> FUNC##_func_def<TYPE, BASE_TYPE> FUNC##_func<TYPE, BASE_TYPE>() \
+#define BLAS_FUNC( FUNC, TYPE, PREFIX ) \
+ template<> FUNC##_func_def<TYPE> FUNC##_func<TYPE>() \
{ return &cblas_##PREFIX##FUNC; }
BLAS_FUNC_DEF( gemm )
-#if (defined(OS_WIN) && !defined(USE_MKL)) || defined(IS_OPENBLAS)
-BLAS_FUNC(gemm , float , float , s)
-BLAS_FUNC(gemm , double , double , d)
-BLAS_FUNC(gemm , cfloat , float , c)
-BLAS_FUNC(gemm , cdouble , double , z)
-#else
-BLAS_FUNC(gemm , float , float , s)
-BLAS_FUNC(gemm , double , double, d)
-BLAS_FUNC(gemm , cfloat , void, c)
-BLAS_FUNC(gemm , cdouble , void, z)
-#endif
+BLAS_FUNC(gemm , float , s)
+BLAS_FUNC(gemm , double , d)
+BLAS_FUNC(gemm , cfloat , c)
+BLAS_FUNC(gemm , cdouble , z)
BLAS_FUNC_DEF(gemv)
-#if (defined(OS_WIN) && !defined(USE_MKL)) || defined(IS_OPENBLAS)
-BLAS_FUNC(gemv , float , float , s)
-BLAS_FUNC(gemv , double , double, d)
-BLAS_FUNC(gemv , cfloat , float , c)
-BLAS_FUNC(gemv , cdouble , double, z)
-#else
-BLAS_FUNC(gemv , float , float, s)
-BLAS_FUNC(gemv , double , double, d)
-BLAS_FUNC(gemv , cfloat , void, c)
-BLAS_FUNC(gemv , cdouble , void, z)
-#endif
+BLAS_FUNC(gemv , float , s)
+BLAS_FUNC(gemv , double , d)
+BLAS_FUNC(gemv , cfloat , c)
+BLAS_FUNC(gemv , cdouble , z)
-template<typename T, typename BT, int value>
-typename enable_if<is_floating_point<T>::value, scale_type<T,BT>>::type
+template<typename T, int value>
+typename enable_if<is_floating_point<T>::value, scale_type<T>>::type
getScale() { return T(value); }
-template<typename T, typename BT, int value>
-typename enable_if<is_complex<T>::value, scale_type<T,BT>>::type
+template<typename T, int value>
+typename enable_if<is_complex<T>::value, scale_type<T>>::type
getScale()
{
static T val(value);
- return (const BT *)&val;
+ return (const typename blas_base<T>::type *)&val;
}
CBLAS_TRANSPOSE
@@ -110,38 +114,6 @@ toCblasTranspose(af_mat_prop opt)
return out;
}
-using namespace std;
-
-
-#if (defined(OS_WIN) && !defined(USE_MKL)) || defined(IS_OPENBLAS)
-#define BT typename af::dtype_traits<T>::base_type
-#define REINTERPRET_CAST(PTR_TYPE, X) reinterpret_cast<PTR_TYPE>((X))
-#else
-template<typename T> struct cblas_types;
-
-template<>
-struct cblas_types<float> {
- typedef float base_type;
-};
-
-template<>
-struct cblas_types<cfloat> {
- typedef void base_type;
-};
-
-template<>
-struct cblas_types<double> {
- typedef double base_type;
-};
-
-template<>
-struct cblas_types<cdouble> {
- typedef void base_type;
-};
-#define BT typename cblas_types<T>::base_type
-#define REINTERPRET_CAST(PTR_TYPE, X) (X)
-#endif
-
template<typename T>
Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
af_mat_prop optLhs, af_mat_prop optRhs)
@@ -161,26 +133,33 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
//FIXME: Leaks on errors.
Array<T> out = createEmptyArray<T>(af::dim4(M, N, 1, 1));
- auto alpha = getScale<T, BT, 1>();
- auto beta = getScale<T, BT, 0>();
+ auto alpha = getScale<T, 1>();
+ auto beta = getScale<T, 0>();
dim4 lStrides = lhs.strides();
dim4 rStrides = rhs.strides();
+ using BT = typename blas_base<T>::type;
+ using CBT = const typename blas_base<T>::type;
+
if(rDims[bColDim] == 1) {
N = lDims[aColDim];
- gemv_func<T, BT>()(
+ gemv_func<T>()(
CblasColMajor, lOpts,
lDims[0], lDims[1],
- alpha, REINTERPRET_CAST(const BT*, lhs.get()), lStrides[1],
- REINTERPRET_CAST(const BT*, rhs.get()), rStrides[0],
- beta, REINTERPRET_CAST(BT*, out.get()), 1);
+ alpha,
+ reinterpret_cast<CBT*>(lhs.get()), lStrides[1],
+ reinterpret_cast<CBT*>(rhs.get()), rStrides[0],
+ beta,
+ reinterpret_cast<BT*>(out.get()), 1);
} else {
- gemm_func<T, BT>()(
+ gemm_func<T>()(
CblasColMajor, lOpts, rOpts,
M, N, K,
- alpha, REINTERPRET_CAST(const BT*, lhs.get()), lStrides[1],
- REINTERPRET_CAST(const BT*, rhs.get()), rStrides[1],
- beta, REINTERPRET_CAST(BT*, out.get()), out.dims()[0]);
+ alpha,
+ reinterpret_cast<CBT*>(lhs.get()), lStrides[1],
+ reinterpret_cast<CBT*>(rhs.get()), rStrides[1],
+ beta,
+ reinterpret_cast<BT*>(out.get()), out.dims()[0]);
}
return out;
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git
More information about the debian-science-commits mailing list