[arrayfire] 70/408: STYLE: Remove macros; Simplify templates;

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:11:18 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.

commit 1b315f95535d6a55ad2e53c52fba16600db09303
Author: Umar Arshad <umar at arrayfire.com>
Date:   Wed Jul 1 20:38:10 2015 -0400

    STYLE: Remove macros; Simplify templates;
    
    Works on my machine(TM); Not tested on Windows
---
 src/backend/cpu/blas.cpp | 155 ++++++++++++++++++++---------------------------
 1 file changed, 67 insertions(+), 88 deletions(-)

diff --git a/src/backend/cpu/blas.cpp b/src/backend/cpu/blas.cpp
index 39fa32b..0bbd399 100644
--- a/src/backend/cpu/blas.cpp
+++ b/src/backend/cpu/blas.cpp
@@ -24,77 +24,81 @@ namespace cpu
     using std::remove_const;
     using std::conditional;
 
-template<typename T, typename BT>
+  // Some implementations of BLAS require void* for complex pointers while
+  // others use float*/double*
+#if (defined(OS_WIN) && !defined(USE_MKL)) || defined(IS_OPENBLAS)
+  static const bool cplx_void_ptr = false;
+#else
+  static const bool cplx_void_ptr = true;
+#endif
+
+template<typename T, class Enable = void>
+struct blas_base {
+    using type = typename dtype_traits<T>::base_type;
+};
+
+template<typename T>
+struct blas_base <T, typename enable_if<is_complex<T>::value && cplx_void_ptr>::type> {
+    using type = void;
+};
+
+
+template<typename T>
 using cptr_type     =   typename conditional<   is_complex<T>::value,
-                                                const BT *,
+                                                const typename blas_base<T>::type *,
                                                 const T*>::type;
-template<typename T, typename BT>
+template<typename T>
 using ptr_type     =    typename conditional<   is_complex<T>::value,
-                                                BT *,
+                                                typename blas_base<T>::type *,
                                                 T*>::type;
-template<typename T, typename BT>
+template<typename T>
 using scale_type   =    typename conditional<   is_complex<T>::value,
-                                                const BT *,
+                                                const typename blas_base<T>::type *,
                                                 const T>::type;
 
-template<typename T, typename BT>
+template<typename T>
 using gemm_func_def = void (*)( const CBLAS_ORDER, const CBLAS_TRANSPOSE, const CBLAS_TRANSPOSE,
                                 const blasint, const blasint, const blasint,
-                                scale_type<T, BT>, cptr_type<T, BT>, const blasint,
-                                cptr_type<T, BT>, const blasint,
-                                scale_type<T, BT>, ptr_type<T, BT>, const blasint);
+                                scale_type<T>, cptr_type<T>, const blasint,
+                                cptr_type<T>, const blasint,
+                                scale_type<T>, ptr_type<T>, const blasint);
 
-template<typename T, typename BT>
+template<typename T>
 using gemv_func_def = void (*)( const CBLAS_ORDER, const CBLAS_TRANSPOSE,
                                 const blasint, const blasint,
-                                scale_type<T, BT>, cptr_type<T, BT>, const blasint,
-                                cptr_type<T, BT>, const blasint,
-                                scale_type<T, BT>, ptr_type<T, BT>, const blasint);
+                                scale_type<T>, cptr_type<T>, const blasint,
+                                cptr_type<T>, const blasint,
+                                scale_type<T>, ptr_type<T>, const blasint);
 
-#define BLAS_FUNC_DEF( FUNC )                                                      \
-template<typename T, typename BT> FUNC##_func_def<T, BT> FUNC##_func();
+#define BLAS_FUNC_DEF( FUNC )                           \
+template<typename T> FUNC##_func_def<T> FUNC##_func();
 
-
-#define BLAS_FUNC( FUNC, TYPE, BASE_TYPE, PREFIX )                                 \
-template<> FUNC##_func_def<TYPE, BASE_TYPE>     FUNC##_func<TYPE, BASE_TYPE>()     \
+#define BLAS_FUNC( FUNC, TYPE, PREFIX )                 \
+  template<> FUNC##_func_def<TYPE> FUNC##_func<TYPE>()  \
 { return &cblas_##PREFIX##FUNC; }
 
 BLAS_FUNC_DEF( gemm )
-#if (defined(OS_WIN) && !defined(USE_MKL)) || defined(IS_OPENBLAS)
-BLAS_FUNC(gemm , float   , float  , s)
-BLAS_FUNC(gemm , double  , double , d)
-BLAS_FUNC(gemm , cfloat  , float  , c)
-BLAS_FUNC(gemm , cdouble , double , z)
-#else
-BLAS_FUNC(gemm , float   , float , s)
-BLAS_FUNC(gemm , double  , double, d)
-BLAS_FUNC(gemm , cfloat  ,   void, c)
-BLAS_FUNC(gemm , cdouble ,   void, z)
-#endif
+BLAS_FUNC(gemm , float   , s)
+BLAS_FUNC(gemm , double  , d)
+BLAS_FUNC(gemm , cfloat  , c)
+BLAS_FUNC(gemm , cdouble , z)
 
 BLAS_FUNC_DEF(gemv)
-#if (defined(OS_WIN) && !defined(USE_MKL)) || defined(IS_OPENBLAS)
-BLAS_FUNC(gemv , float   ,  float , s)
-BLAS_FUNC(gemv , double  ,  double, d)
-BLAS_FUNC(gemv , cfloat  ,  float , c)
-BLAS_FUNC(gemv , cdouble ,  double, z)
-#else
-BLAS_FUNC(gemv , float   ,  float, s)
-BLAS_FUNC(gemv , double  , double, d)
-BLAS_FUNC(gemv , cfloat  ,   void, c)
-BLAS_FUNC(gemv , cdouble ,   void, z)
-#endif
+BLAS_FUNC(gemv , float   , s)
+BLAS_FUNC(gemv , double  , d)
+BLAS_FUNC(gemv , cfloat  , c)
+BLAS_FUNC(gemv , cdouble , z)
 
-template<typename T, typename BT, int value>
-typename enable_if<is_floating_point<T>::value, scale_type<T,BT>>::type
+template<typename T, int value>
+typename enable_if<is_floating_point<T>::value, scale_type<T>>::type
 getScale() { return T(value); }
 
-template<typename T, typename BT, int value>
-typename enable_if<is_complex<T>::value, scale_type<T,BT>>::type
+template<typename T, int value>
+typename enable_if<is_complex<T>::value, scale_type<T>>::type
 getScale()
 {
     static T val(value);
-    return (const BT *)&val;
+    return (const typename blas_base<T>::type *)&val;
 }
 
 CBLAS_TRANSPOSE
@@ -110,38 +114,6 @@ toCblasTranspose(af_mat_prop opt)
     return out;
 }
 
-using namespace std;
-
-
-#if (defined(OS_WIN) && !defined(USE_MKL)) || defined(IS_OPENBLAS)
-#define BT typename af::dtype_traits<T>::base_type
-#define REINTERPRET_CAST(PTR_TYPE, X) reinterpret_cast<PTR_TYPE>((X))
-#else
-template<typename T> struct cblas_types;
-
-template<>
-struct cblas_types<float> {
-    typedef float base_type;
-};
-
-template<>
-struct cblas_types<cfloat> {
-    typedef void base_type;
-};
-
-template<>
-struct cblas_types<double> {
-    typedef double base_type;
-};
-
-template<>
-struct cblas_types<cdouble> {
-    typedef void base_type;
-};
-#define BT typename cblas_types<T>::base_type
-#define REINTERPRET_CAST(PTR_TYPE, X) (X)
-#endif
-
 template<typename T>
 Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
                 af_mat_prop optLhs, af_mat_prop optRhs)
@@ -161,26 +133,33 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
 
     //FIXME: Leaks on errors.
     Array<T> out = createEmptyArray<T>(af::dim4(M, N, 1, 1));
-    auto alpha = getScale<T, BT, 1>();
-    auto beta  = getScale<T, BT, 0>();
+    auto alpha = getScale<T, 1>();
+    auto beta  = getScale<T, 0>();
 
     dim4 lStrides = lhs.strides();
     dim4 rStrides = rhs.strides();
+    using BT  =       typename blas_base<T>::type;
+    using CBT = const typename blas_base<T>::type;
+
     if(rDims[bColDim] == 1) {
         N = lDims[aColDim];
-        gemv_func<T, BT>()(
+        gemv_func<T>()(
             CblasColMajor, lOpts,
             lDims[0], lDims[1],
-            alpha, REINTERPRET_CAST(const BT*, lhs.get()), lStrides[1],
-            REINTERPRET_CAST(const BT*, rhs.get()), rStrides[0],
-            beta, REINTERPRET_CAST(BT*, out.get()), 1);
+            alpha,
+            reinterpret_cast<CBT*>(lhs.get()), lStrides[1],
+            reinterpret_cast<CBT*>(rhs.get()), rStrides[0],
+            beta,
+            reinterpret_cast<BT*>(out.get()), 1);
     } else {
-        gemm_func<T, BT>()(
+        gemm_func<T>()(
             CblasColMajor, lOpts, rOpts,
             M, N, K,
-            alpha, REINTERPRET_CAST(const BT*, lhs.get()), lStrides[1],
-            REINTERPRET_CAST(const BT*, rhs.get()), rStrides[1],
-            beta, REINTERPRET_CAST(BT*, out.get()), out.dims()[0]);
+            alpha,
+            reinterpret_cast<CBT*>(lhs.get()), lStrides[1],
+            reinterpret_cast<CBT*>(rhs.get()), rStrides[1],
+            beta,
+            reinterpret_cast<BT*>(out.get()), out.dims()[0]);
     }
 
     return out;

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list