[shark] 04/12: added support for Generic and OpenBLAS. compiles

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Thu Dec 10 19:48:13 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch master
in repository shark.

commit 57e553b495692ad26cc3c3fbc0769cafcaf94dd4
Author: Oswin Krause <oswin.krause at di.ku.dk>
Date:   Mon Nov 16 16:32:25 2015 +0100

    added support for Generic and OpenBLAS. compiles
---
 CMakeLists.txt                                       | 17 ++++++++++++++---
 include/shark/LinAlg/BLAS/kernels/atlas/potrf.hpp    |  4 +++-
 .../shark/LinAlg/BLAS/kernels/cblas/cblas_inc.hpp    | 10 +++++++---
 include/shark/LinAlg/BLAS/kernels/cblas/dot.hpp      | 12 ++++++------
 include/shark/LinAlg/BLAS/kernels/cblas/gemm.hpp     | 20 ++++++++++----------
 include/shark/LinAlg/BLAS/kernels/cblas/gemv.hpp     | 20 ++++++++++----------
 .../LinAlg/BLAS/kernels/{atlas => cblas}/tpmv.hpp    | 12 ++++++------
 include/shark/LinAlg/BLAS/kernels/cblas/trmm.hpp     |  6 +++---
 include/shark/LinAlg/BLAS/kernels/cblas/trmv.hpp     |  8 ++++----
 include/shark/LinAlg/BLAS/kernels/cblas/trsm.hpp     | 12 ++++++------
 include/shark/LinAlg/BLAS/kernels/cblas/trsv.hpp     |  8 ++++----
 include/shark/LinAlg/BLAS/kernels/dot.hpp            |  8 ++++----
 12 files changed, 77 insertions(+), 60 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index d4f3997..4262917 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -9,6 +9,7 @@
 project( shark )
 
 set_property(GLOBAL PROPERTY USE_FOLDERS ON)
+INCLUDE (CheckFunctionExists)
 
 cmake_minimum_required( VERSION 2.8 )
 cmake_policy(SET CMP0003 NEW)
@@ -239,6 +240,9 @@ if( ENABLE_CBLAS )
 		set(CBLAS_LIBRARIES "-framework Accelerate" )
 	else()
 		#todo: do a propper vendor check
+		find_library(OPENBLAS_LIBRARY openblas
+			HINTS ${CBLAS_ROOT}/lib /opt/local/lib
+		)
 		find_library(CBLAS_LIBRARY cblas
 			HINTS ${ATLAS_ROOT}/lib ${CBLAS_ROOT}/lib /opt/local/lib /usr/lib64/atlas/
 		)
@@ -249,6 +253,7 @@ if( ENABLE_CBLAS )
 			HINTS ${ATLAS_ROOT}/lib ${CBLAS_ROOT}/lib /opt/local/lib /usr/lib64/atlas/
 		)
 		mark_as_advanced(
+			OPENBLAS_LIBRARY
 			CBLAS_LIBRARY
 			CLAPACK_LIBRARY
 			ATLAS_LIBRARY
@@ -256,14 +261,20 @@ if( ENABLE_CBLAS )
 		#find the cblas.h include path
 		if(CBLAS_LIBRARY )
 			get_filename_component(CBLAS_LIB_PATH ${CBLAS_LIBRARY} PATH )
-			
+		elseif( OPENBLAS_LIBRARY)	
+			get_filename_component(CBLAS_LIB_PATH ${OPENBLAS_LIBRARY} PATH )
+		endif()
+		if(CBLAS_LIB_PATH)
 			find_file(CBLAS_INCLUDES cblas.h 
-				PATHS ${CBLAS_LIB_PATH} ${CBLAS_LIB_PATH}../include ${CBLAS_LIB_PATH}../include/atlas
+				PATHS ${CBLAS_LIB_PATH} ${CBLAS_LIB_PATH}/../include ${CBLAS_LIB_PATH}/../include/atlas
 			)
 			get_filename_component(CBLAS_INCLUDES ${CBLAS_INCLUDES} PATH )
 		endif()
 
-		if( CBLAS_LIBRARY AND CLAPACK_LIBRARY AND ATLAS_LIBRARY AND CBLAS_INCLUDES)
+		if( OPENBLAS_LIBRARY AND CBLAS_INCLUDES)
+			set(CBLAS_VENDOR "OpenBLAS")
+			set(CBLAS_LIBRARIES ${OPENBLAS_LIBRARY})
+		elseif( CBLAS_LIBRARY AND CLAPACK_LIBRARY AND ATLAS_LIBRARY AND CBLAS_INCLUDES)
 			set(CBLAS_VENDOR "ATLAS")
 			set(CBLAS_LIBRARIES ${CLAPACK_LIBRARY} ${CBLAS_LIBRARY} ${ATLAS_LIBRARY})
 		elseif( CBLAS_LIBRARY AND CBLAS_INCLUDES)
diff --git a/include/shark/LinAlg/BLAS/kernels/atlas/potrf.hpp b/include/shark/LinAlg/BLAS/kernels/atlas/potrf.hpp
index 524e48b..dbafea9 100644
--- a/include/shark/LinAlg/BLAS/kernels/atlas/potrf.hpp
+++ b/include/shark/LinAlg/BLAS/kernels/atlas/potrf.hpp
@@ -34,7 +34,9 @@
 #define SHARK_LINALG_BLAS_KERNELS_ATLAS_POTRF_H
 
 #include "../cblas/cblas_inc.hpp"
-#include <clapack.h>
+extern "C"{
+	#include <clapack.h>
+}
 
 namespace shark {
 namespace blas {
diff --git a/include/shark/LinAlg/BLAS/kernels/cblas/cblas_inc.hpp b/include/shark/LinAlg/BLAS/kernels/cblas/cblas_inc.hpp
index 3b945e6..5100498 100644
--- a/include/shark/LinAlg/BLAS/kernels/cblas/cblas_inc.hpp
+++ b/include/shark/LinAlg/BLAS/kernels/cblas/cblas_inc.hpp
@@ -28,8 +28,8 @@
  *
  */
 
-#ifndef SHARK_LINALG_BLAS_KERNELS_ATLAS_CBLAS_INC_HPP
-#define SHARK_LINALG_BLAS_KERNELS_ATLAS_CBLAS_INC_HPP
+#ifndef SHARK_LINALG_BLAS_KERNELS_CBLAS_CBLAS_INC_HPP
+#define SHARK_LINALG_BLAS_KERNELS_CBLAS_CBLAS_INC_HPP
 
 #ifdef __APPLE__
 
@@ -95,9 +95,13 @@ template<> struct storage_order<column_major> {
 
 }}}
 
-
+#ifndef OPENBLAS_CONST
 typedef void cblas_float_complex_type;
 typedef void cblas_double_complex_type;
+#else
+typedef float cblas_float_complex_type;
+typedef double cblas_double_complex_type;
+#endif
 
 
 #endif
diff --git a/include/shark/LinAlg/BLAS/kernels/cblas/dot.hpp b/include/shark/LinAlg/BLAS/kernels/cblas/dot.hpp
index f0bfa12..db560ca 100644
--- a/include/shark/LinAlg/BLAS/kernels/cblas/dot.hpp
+++ b/include/shark/LinAlg/BLAS/kernels/cblas/dot.hpp
@@ -58,9 +58,9 @@ inline void dot(int N,
 	std::complex<float>& result
 ) {
 	cblas_cdotu_sub(N, 
-		static_cast<cblas_float_complex_type const* >(x), strideX, 
-		static_cast<cblas_float_complex_type const* >(y), strideY,
-		static_cast<cblas_float_complex_type*>(&result)
+		reinterpret_cast<cblas_float_complex_type const* >(x), strideX, 
+		reinterpret_cast<cblas_float_complex_type const* >(y), strideY,
+		reinterpret_cast<cblas_float_complex_type*>(&result)
 	);
 }
 
@@ -70,9 +70,9 @@ inline void dot(int N,
 	std::complex<double>& result
 ) {
 	cblas_zdotu_sub(N, 
-		static_cast<cblas_double_complex_type const* >(x), strideX,
-		static_cast<cblas_double_complex_type const* >(y), strideY,
-		static_cast<cblas_double_complex_type*>(&result)
+		reinterpret_cast<cblas_double_complex_type const* >(x), strideX,
+		reinterpret_cast<cblas_double_complex_type const* >(y), strideY,
+		reinterpret_cast<cblas_double_complex_type*>(&result)
 	);
 }
 
diff --git a/include/shark/LinAlg/BLAS/kernels/cblas/gemm.hpp b/include/shark/LinAlg/BLAS/kernels/cblas/gemm.hpp
index a9cf430..6adb836 100644
--- a/include/shark/LinAlg/BLAS/kernels/cblas/gemm.hpp
+++ b/include/shark/LinAlg/BLAS/kernels/cblas/gemm.hpp
@@ -84,11 +84,11 @@ inline void gemm(
 	cblas_cgemm(
 		Order, TransA, TransB,
 		M, N, K,
-		static_cast<cblas_double_complex_type const *>(&alphaArg),
-		static_cast<cblas_double_complex_type const *>(A), lda,
-		static_cast<cblas_double_complex_type const *>(B), ldb,
-		static_cast<cblas_double_complex_type const *>(&betaArg),
-		static_cast<cblas_double_complex_type *>(C), ldc
+		reinterpret_cast<cblas_float_complex_type const *>(&alphaArg),
+		reinterpret_cast<cblas_float_complex_type const *>(A), lda,
+		reinterpret_cast<cblas_float_complex_type const *>(B), ldb,
+		reinterpret_cast<cblas_float_complex_type const *>(&betaArg),
+		reinterpret_cast<cblas_float_complex_type *>(C), ldc
 	);
 }
 
@@ -106,11 +106,11 @@ inline void gemm(
 	cblas_zgemm(
 		Order, TransA, TransB,
 		M, N, K,
-		static_cast<void const *>(&alphaArg),
-		static_cast<void const *>(A), lda,
-		static_cast<void const *>(B), ldb,
-		static_cast<void const *>(&betaArg),
-		static_cast<void *>(C), ldc
+		reinterpret_cast<cblas_double_complex_type const *>(&alphaArg),
+		reinterpret_cast<cblas_double_complex_type const *>(A), lda,
+		reinterpret_cast<cblas_double_complex_type const *>(B), ldb,
+		reinterpret_cast<cblas_double_complex_type const *>(&betaArg),
+		reinterpret_cast<cblas_double_complex_type *>(C), ldc
 	);
 }
 
diff --git a/include/shark/LinAlg/BLAS/kernels/cblas/gemv.hpp b/include/shark/LinAlg/BLAS/kernels/cblas/gemv.hpp
index 566a1aa..76d4ec7 100644
--- a/include/shark/LinAlg/BLAS/kernels/cblas/gemv.hpp
+++ b/include/shark/LinAlg/BLAS/kernels/cblas/gemv.hpp
@@ -69,11 +69,11 @@ inline void gemv(CBLAS_ORDER const Order,
 	std::complex<float> alphaArg(alpha,0);
 	std::complex<float> betaArg(beta,0);
 	cblas_cgemv(Order, TransA, M, N,
-	        static_cast<cblas_float_complex_type const *>(&alphaArg),
-	        static_cast<cblas_float_complex_type const *>(A), lda,
-	        static_cast<cblas_float_complex_type const *>(X), incX,
-	        static_cast<cblas_float_complex_type const *>(&betaArg),
-	        static_cast<cblas_float_complex_type *>(Y), incY);
+	        reinterpret_cast<cblas_float_complex_type const *>(&alphaArg),
+	        reinterpret_cast<cblas_float_complex_type const *>(A), lda,
+	        reinterpret_cast<cblas_float_complex_type const *>(X), incX,
+	        reinterpret_cast<cblas_float_complex_type const *>(&betaArg),
+	        reinterpret_cast<cblas_float_complex_type *>(Y), incY);
 }
 
 inline void gemv(CBLAS_ORDER const Order,
@@ -87,11 +87,11 @@ inline void gemv(CBLAS_ORDER const Order,
 	std::complex<double> alphaArg(alpha,0);
 	std::complex<double> betaArg(beta,0);
 	cblas_zgemv(Order, TransA, M, N,
-	        static_cast<cblas_double_complex_type const *>(&alphaArg),
-	        static_cast<cblas_double_complex_type const *>(A), lda,
-	        static_cast<cblas_double_complex_type const *>(X), incX,
-	        static_cast<cblas_double_complex_type const *>(&betaArg),
-	        static_cast<cblas_double_complex_type *>(Y), incY);
+	        reinterpret_cast<cblas_double_complex_type const *>(&alphaArg),
+	        reinterpret_cast<cblas_double_complex_type const *>(A), lda,
+	        reinterpret_cast<cblas_double_complex_type const *>(X), incX,
+	        reinterpret_cast<cblas_double_complex_type const *>(&betaArg),
+	        reinterpret_cast<cblas_double_complex_type *>(Y), incY);
 }
 
 
diff --git a/include/shark/LinAlg/BLAS/kernels/atlas/tpmv.hpp b/include/shark/LinAlg/BLAS/kernels/cblas/tpmv.hpp
similarity index 92%
rename from include/shark/LinAlg/BLAS/kernels/atlas/tpmv.hpp
rename to include/shark/LinAlg/BLAS/kernels/cblas/tpmv.hpp
index 413bee9..68a6544 100644
--- a/include/shark/LinAlg/BLAS/kernels/atlas/tpmv.hpp
+++ b/include/shark/LinAlg/BLAS/kernels/cblas/tpmv.hpp
@@ -29,8 +29,8 @@
  *
  */
 //===========================================================================
-#ifndef SHARK_LINALG_BLAS_KERNELS_ATLAS_TPMV_HPP
-#define SHARK_LINALG_BLAS_KERNELS_ATLAS_TPMV_HPP
+#ifndef SHARK_LINALG_BLAS_KERNELS_CBLAS_TPMV_HPP
+#define SHARK_LINALG_BLAS_KERNELS_CBLAS_TPMV_HPP
 
 #include "cblas_inc.hpp"
 #include "../../matrix_proxy.hpp"
@@ -80,8 +80,8 @@ inline void tpmv(
         std::complex<float>* X, int const incX
 ) {
 	cblas_ctpmv(Order, uplo, transA, unit, N, 
-		static_cast<void const *>(A),
-	        static_cast<void *>(X), incX
+		reinterpret_cast<cblas_float_complex_type const *>(A),
+	        reinterpret_cast<cblas_float_complex_type *>(X), incX
 	);
 }
 
@@ -95,8 +95,8 @@ inline void tpmv(
         std::complex<double>* X, int const incX
 ) {
 	cblas_ztpmv(Order, uplo, transA, unit, N, 
-		static_cast<void const *>(A),
-	        static_cast<void *>(X), incX
+		reinterpret_cast<cblas_double_complex_type const *>(A),
+	        reinterpret_cast<cblas_double_complex_type *>(X), incX
 	);
 }
 
diff --git a/include/shark/LinAlg/BLAS/kernels/cblas/trmm.hpp b/include/shark/LinAlg/BLAS/kernels/cblas/trmm.hpp
index a9e8d6f..9f4da55 100644
--- a/include/shark/LinAlg/BLAS/kernels/cblas/trmm.hpp
+++ b/include/shark/LinAlg/BLAS/kernels/cblas/trmm.hpp
@@ -108,9 +108,9 @@ inline void trmm(
 ) {
 	std::complex<double> alpha = 1.0;
 	cblas_ztrmm(order, side, uplo, transA, unit, M, N, 
-		static_cast<cblas_double_complex_type const *>(&alpha),
-		static_cast<cblas_double_complex_type const *>(A), lda,
-	        static_cast<cblas_double_complex_type *>(B), incB
+		reinterpret_cast<cblas_double_complex_type const *>(&alpha),
+		reinterpret_cast<cblas_double_complex_type const *>(A), lda,
+	        reinterpret_cast<cblas_double_complex_type *>(B), incB
 	);
 }
 
diff --git a/include/shark/LinAlg/BLAS/kernels/cblas/trmv.hpp b/include/shark/LinAlg/BLAS/kernels/cblas/trmv.hpp
index f385700..6a0115c 100644
--- a/include/shark/LinAlg/BLAS/kernels/cblas/trmv.hpp
+++ b/include/shark/LinAlg/BLAS/kernels/cblas/trmv.hpp
@@ -80,8 +80,8 @@ inline void trmv(
         std::complex<float>* X, int const incX
 ) {
 	cblas_ctrmv(Order, uplo, transA, unit, N, 
-		static_cast<cblas_float_complex_type const *>(A), lda,
-	        static_cast<cblas_float_complex_type *>(X), incX
+		reinterpret_cast<cblas_float_complex_type const *>(A), lda,
+	        reinterpret_cast<cblas_float_complex_type *>(X), incX
 	);
 }
 
@@ -95,8 +95,8 @@ inline void trmv(
         std::complex<double>* X, int const incX
 ) {
 	cblas_ztrmv(Order, uplo, transA, unit, N, 
-		static_cast<cblas_double_complex_type const *>(A), lda,
-	        static_cast<cblas_double_complex_type *>(X), incX
+		reinterpret_cast<cblas_double_complex_type const *>(A), lda,
+	        reinterpret_cast<cblas_double_complex_type *>(X), incX
 	);
 }
 
diff --git a/include/shark/LinAlg/BLAS/kernels/cblas/trsm.hpp b/include/shark/LinAlg/BLAS/kernels/cblas/trsm.hpp
index d45f25a..7fa3c09 100644
--- a/include/shark/LinAlg/BLAS/kernels/cblas/trsm.hpp
+++ b/include/shark/LinAlg/BLAS/kernels/cblas/trsm.hpp
@@ -62,9 +62,9 @@ inline void trsm(
 ) {
 	std::complex<float> alpha(1.0,0);
 	cblas_ctrsm(order, side, uplo, transA, unit,n, nRHS,
-		static_cast<cblas_float_complex_type const *>(&alpha),
-	        static_cast<cblas_float_complex_type const *>(A), lda,
-	        static_cast<cblas_float_complex_type *>(B), ldb);
+		reinterpret_cast<cblas_float_complex_type const *>(&alpha),
+	        reinterpret_cast<cblas_float_complex_type const *>(A), lda,
+	        reinterpret_cast<cblas_float_complex_type *>(B), ldb);
 }
 inline void trsm(
 	CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA, 
@@ -74,9 +74,9 @@ inline void trsm(
 ) {
 	std::complex<double> alpha(1.0,0);
 	cblas_ztrsm(order, side, uplo, transA, unit,n, nRHS,
-		static_cast<cblas_double_complex_type const *>(&alpha),
-	        static_cast<cblas_double_complex_type const *>(A), lda,
-	        static_cast<cblas_double_complex_type *>(B), ldb);
+		reinterpret_cast<cblas_double_complex_type const *>(&alpha),
+	        reinterpret_cast<cblas_double_complex_type const *>(A), lda,
+	        reinterpret_cast<cblas_double_complex_type *>(B), ldb);
 }
 
 // trsm(): solves A system of linear equations A * X = B
diff --git a/include/shark/LinAlg/BLAS/kernels/cblas/trsv.hpp b/include/shark/LinAlg/BLAS/kernels/cblas/trsv.hpp
index 1af8b4e..41fe34a 100644
--- a/include/shark/LinAlg/BLAS/kernels/cblas/trsv.hpp
+++ b/include/shark/LinAlg/BLAS/kernels/cblas/trsv.hpp
@@ -61,8 +61,8 @@ inline void trsv(
 	std::complex<float> const *A, int lda, std::complex<float> *b, int strideX
 ){
 	cblas_ctrsv(order, uplo, transA, unit,n,
-	        static_cast<cblas_float_complex_type const *>(A), lda,
-	        static_cast<cblas_float_complex_type *>(b), strideX);
+	        reinterpret_cast<cblas_float_complex_type const *>(A), lda,
+	        reinterpret_cast<cblas_float_complex_type *>(b), strideX);
 }
 inline void trsv(
 	CBLAS_ORDER order, CBLAS_UPLO uplo,
@@ -71,8 +71,8 @@ inline void trsv(
 	std::complex<double> const *A, int lda, std::complex<double> *b, int strideX
 ){
 	cblas_ztrsv(order, uplo, transA, unit,n,
-	        static_cast<cblas_double_complex_type const *>(A), lda,
-	        static_cast<cblas_double_complex_type *>(b), strideX);
+	        reinterpret_cast<cblas_double_complex_type const *>(A), lda,
+	        reinterpret_cast<cblas_double_complex_type *>(b), strideX);
 }
 
 // trsv(): solves A system of linear equations A * x = b
diff --git a/include/shark/LinAlg/BLAS/kernels/dot.hpp b/include/shark/LinAlg/BLAS/kernels/dot.hpp
index b2f05ed..773df7a 100644
--- a/include/shark/LinAlg/BLAS/kernels/dot.hpp
+++ b/include/shark/LinAlg/BLAS/kernels/dot.hpp
@@ -32,9 +32,9 @@
 
 #include "default/dot.hpp"
 
-#ifdef SHARK_USE_CBLAS
-#include "cblas/dot.hpp"
-#else
+//~ #ifdef SHARK_USE_CBLAS 
+//~ #include "cblas/dot.hpp" // not faster in many cases, the compiler optimizes quite well.
+//~ #else
 // if no bindings are included, we have to provide the default has_optimized_dot
 // otherwise the binding will take care of this
 namespace shark { namespace blas { namespace bindings{
@@ -42,7 +42,7 @@ template<class V1, class V2,class result_type>
 struct  has_optimized_dot
 : public boost::mpl::false_{};
 }}}
-#endif
+//~ #endif
 	
 namespace shark { namespace blas {namespace kernels{
 	

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/shark.git



More information about the debian-science-commits mailing list