[clblas] 45/61: adding zgemm kernel for hawaii
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Fri Jul 24 22:49:47 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch master
in repository clblas.
commit 6623809f354ae2dd6d5fee22695320f245b89591
Author: David Tanner <guacamoleo at gmail.com>
Date: Tue Jun 16 12:57:12 2015 -0500
adding zgemm kernel for hawaii
---
src/library/CMakeLists.txt | 4 +
src/library/bingen.cmake | 1 +
src/library/blas/functor/bonaire.cc | 2 +-
src/library/blas/functor/gcn_zgemm.cc | 354 +++++++++++++++++++++++++
src/library/blas/functor/hawaii.cc | 23 +-
src/library/blas/functor/include/gcn_zgemm.h | 62 +++++
src/library/blas/functor/include/hawaii.h | 1 +
src/library/blas/gens/clTemplates/zgemm_gcn.cl | 314 ++++++++++++++++++++++
8 files changed, 759 insertions(+), 2 deletions(-)
diff --git a/src/library/CMakeLists.txt b/src/library/CMakeLists.txt
index b5f5d66..e45a2aa 100644
--- a/src/library/CMakeLists.txt
+++ b/src/library/CMakeLists.txt
@@ -76,6 +76,7 @@ set(SRC_BLAS
blas/functor/hawaii_sgemmSplit64_32.cc
blas/functor/gcn_dgemmCommon.cc
blas/functor/gcn_sgemm.cc
+ blas/functor/gcn_zgemm.cc
blas/functor/gcn_dgemmSmallMatrices.cc
blas/functor/gcn_sgemmSmallMatrices.cc
blas/functor/hawaii_sgemmBranchKernel.cc
@@ -107,6 +108,7 @@ set(SRC_BLAS_HEADERS
blas/functor/include/hawaii_sgemmSplit64_32.h
blas/functor/include/gcn_dgemmCommon.h
blas/functor/include/gcn_sgemm.h
+ blas/functor/include/gcn_zgemm.h
blas/functor/include/gcn_dgemmSmallMatrices.h
blas/functor/include/gcn_sgemmSmallMatrices.h
blas/functor/include/hawaii_sgemmBranchKernel.h
@@ -231,6 +233,7 @@ set (SRC_CL_TEMPLATES
dgemm_gcn_SmallMatrices.cl
sgemm_gcn_SmallMatrices.cl
sgemm_gcn.cl
+ zgemm_gcn.cl
)
set(SRC_CL_TEMPLATES_GEN
@@ -248,6 +251,7 @@ set(SRC_CL_TEMPLATES_GEN
sgemm_gcn_SmallMatrices.clTahiti_64.bin.cl
sgemm_gcn_SmallMatrices.clBonaire_64.bin.cl
sgemm_gcn.clHawaii_64.bin.cl
+ zgemm_gcn.clHawaii_64.bin.cl
sgemm_gcn.clBonaire_64.bin.cl
sgemm_gcn.clTahiti_64.bin.cl
sgemm_hawaiiSplit64_32.clHawaii_64.bin.cl
diff --git a/src/library/bingen.cmake b/src/library/bingen.cmake
index bb6515d..7b13557 100644
--- a/src/library/bingen.cmake
+++ b/src/library/bingen.cmake
@@ -13,6 +13,7 @@ set (BIN_CL_TEMPLATES_HAWAII_CL2
${CLTEMPLATE_PATH}/dgemm_hawaiiSplitKernel.cl
${CLTEMPLATE_PATH}/sgemm_hawaiiSplitKernel.cl
${CLTEMPLATE_PATH}/sgemm_gcn.cl
+${CLTEMPLATE_PATH}/zgemm_gcn.cl
${CLTEMPLATE_PATH}/sgemm_gcn_SmallMatrices.cl
${CLTEMPLATE_PATH}/sgemm_hawaiiSplit64_32.cl
)
diff --git a/src/library/blas/functor/bonaire.cc b/src/library/blas/functor/bonaire.cc
index a97d9b9..4e2989d 100644
--- a/src/library/blas/functor/bonaire.cc
+++ b/src/library/blas/functor/bonaire.cc
@@ -37,7 +37,7 @@ FunctorSelectorBonaire::FunctorSelectorBonaire()
-// The selector function for SGEMM on hawaii
+// The selector function for SGEMM on bonaire
clblasSgemmFunctor * FunctorSelectorBonaire::select_sgemm_specific(clblasSgemmFunctor::Args & args)
{
#ifdef CLBLAS_BONAIRE_DYNAMIC_KERNEL
diff --git a/src/library/blas/functor/gcn_zgemm.cc b/src/library/blas/functor/gcn_zgemm.cc
new file mode 100644
index 0000000..72ac9a1
--- /dev/null
+++ b/src/library/blas/functor/gcn_zgemm.cc
@@ -0,0 +1,354 @@
+#include <stdio.h>
+#include <string.h>
+#include <clBLAS.h>
+
+#include <devinfo.h>
+#include "clblas-internal.h"
+#include "solution_seq.h"
+
+#include <functor.h>
+#include <binary_lookup.h>
+#include <iostream>
+
+#include <functor_xgemm.h>
+#include <tahiti.h>
+#include <hawaii.h>
+// The header added by this commit is gcn_zgemm.h (lower case); the name must
+// match exactly on case-sensitive file systems.
+#include <gcn_zgemm.h>
+
+#include "BinaryBuild.h"
+
+//for the moment only managing source code and cl binary
+
+#if BUILD_KERNEL_FROM_STRING
+#include "zgemm_gcn.clT"
+#else
+
+#include "zgemm_gcn.clHawaii_64.bin.clT"
+//#include "zgemm_gcn.clBonaire_64.bin.clT"
+
+//#include "zgemm_gcn.clTahiti_64.bin.clT"
+#endif
+
+
+//
+// The name of the 'const char *' providing the kernel OpenCL source
+//
+// zgemm_TATB_DIVN_DIVM_DIVK_BS0xBS1_NV0xNV1MULT
+//
+// For instance, ZGEMM_SRC_NAME(N,T,64,32,8,16,16,2,4,__ALPHABETA) is
+// zgemm_NT_64_32_8_16x16_2x4__ALPHABETA
+//
+#define ZGEMM_SRC_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) zgemm_##TA##TB##_##DIVN##_##DIVM##_##DIVK##_##BS0##x##BS1##_##NV0##x##NV1##MULT
+
+//
+// The name of the 'const char []' global variable that contains the CL binary
+// data for a given device: the name produced by ZGEMM_SRC_NAME suffixed by
+// _BITS_bin_Device (e.g. zgemm_NT_64_32_8_16x16_2x4__ALPHABETA_64_bin_Hawaii).
+//
+#define ZGEMM_SRC_NAME_TAHITI(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT) zgemm_##TA##TB##_##DIVN##_##DIVM##_##DIVK##_##BS0##x##BS1##_##NV0##x##NV1##MULT##_##BITS##_bin_Tahiti
+#define ZGEMM_SRC_NAME_HAWAII(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT) zgemm_##TA##TB##_##DIVN##_##DIVM##_##DIVK##_##BS0##x##BS1##_##NV0##x##NV1##MULT##_##BITS##_bin_Hawaii
+#define ZGEMM_SRC_NAME_BONAIRE(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT) zgemm_##TA##TB##_##DIVN##_##DIVM##_##DIVK##_##BS0##x##BS1##_##NV0##x##NV1##MULT##_##BITS##_bin_Bonaire
+
+//
+// The name of the 'const char []' global variable that contains the SPIR data.
+// That name is the one produced by ZGEMM_SRC_NAME suffixed by _spir.
+// (Every component, including NV1 and MULT, must be token-pasted with ##.)
+//
+#define ZGEMM_SPIR_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) zgemm_##TA##TB##_##DIVN##_##DIVM##_##DIVK##_##BS0##x##BS1##_##NV0##x##NV1##MULT##_spir
+
+//
+// The name of the kernel itself.
+// This is the name produced by ZGEMM_SRC_NAME, as a string literal.
+//
+#define ZGEMM_KERNEL_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) "zgemm_" #TA #TB "_" #DIVN "_" #DIVM "_" #DIVK "_" #BS0 "x" #BS1 "_" #NV0 "x" #NV1 #MULT
+
+//
+// Helpers to transform N and T in proper clblas values for the macros above
+//
+#define trans_N clblasNoTrans
+#define trans_T clblasTrans
+
+//
+// The ZGEMM_VARIANT_* macros expand to a brace initializer for
+// clblasZgemmFunctorGCN::Variant, in member order:
+//   kernel_name, source, build_options, bin, bin_size,
+//   transA, transB, divN, divM, divK, ls[2], bwi[2]
+// (the trailing 'mult' member is left default-initialized).
+//
+
+// Fill a variant descriptor using OpenCL source
+#define ZGEMM_VARIANT_SRC(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) { \
+  ZGEMM_KERNEL_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) , \
+  ZGEMM_SRC_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) , \
+  NULL, NULL, 0, \
+  trans_##TA, trans_##TB, \
+  DIVN,DIVM,DIVK, \
+  { BS0, BS1 } , \
+  { NV0, NV1 } \
+}
+
+// Fill a variant descriptor using SPIR
+// (all helper macros take the same 10 arguments, MULT included)
+#define ZGEMM_VARIANT_SPIR(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) { \
+  ZGEMM_KERNEL_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) , \
+  NULL , "-x spir -spir-std=1.2", \
+  ZGEMM_SPIR_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT), \
+  sizeof(ZGEMM_SPIR_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT)), \
+  trans_##TA,trans_##TB, \
+  DIVN,DIVM,DIVK, \
+  { BS0, BS1 } , \
+  { NV0, NV1 } \
+}
+
+// Fill a variant descriptor using CL Binaries (no build options)
+#define ZGEMM_VARIANT_BIN_CL1(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,DEVICE,MULT) { \
+  ZGEMM_KERNEL_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) , \
+  NULL , NULL, \
+  ZGEMM_SRC_NAME##_##DEVICE(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT), \
+  sizeof(ZGEMM_SRC_NAME##_##DEVICE(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT)), \
+  trans_##TA,trans_##TB, \
+  DIVN,DIVM,DIVK, \
+  { BS0, BS1 } , \
+  { NV0, NV1 } \
+}
+
+// Fill a variant descriptor using CL Binaries built with -cl-std=CL2.0
+#define ZGEMM_VARIANT_BIN_CL2(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,DEVICE,MULT) { \
+  ZGEMM_KERNEL_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) , \
+  NULL , "-cl-std=CL2.0", \
+  ZGEMM_SRC_NAME##_##DEVICE(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT), \
+  sizeof(ZGEMM_SRC_NAME##_##DEVICE(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT)), \
+  trans_##TA,trans_##TB, \
+  DIVN,DIVM,DIVK, \
+  { BS0, BS1 } , \
+  { NV0, NV1 } \
+}
+
+// Set VERB to 1 to get extra debug printf output from this file
+#define VERB 0
+
+// Short alias, the fully qualified variant type name is long
+typedef clblasZgemmFunctorGCN::Variant Variant ;
+
+//
+// File-static cache holding every clblasZgemmFunctorGCN created here;
+// entries are looked up by (context, device, variant pointer).
+//
+typedef clblasFunctorCache< clblasZgemmFunctorGCN, Variant const * > Cache ;
+static Cache cache ;
+
+
+// Return true iff the kernel variant 'var' can compute the ZGEMM described
+// by 'args'.
+//
+// The transposition flags are not checked here: select_variant() tests
+// args.transA/args.transB before picking a variant.
+//
+// Each dimension must be a multiple of the variant's divisor. Empty problems
+// are rejected as well: the kernel's K loop is a do/while that counts down an
+// unsigned block count (K == 0 would underflow it), and a zero global work
+// size cannot be enqueued in OpenCL 1.2. Such cases go to the generic path.
+//
+// TODO(review): the kernel always reads C, even when beta is zero; reference
+// BLAS does not read C in that case.
+static bool applicable( const Variant & var, clblasZgemmFunctor::Args & args )
+{
+  if ( args.M == 0 || args.N == 0 || args.K == 0 )
+    return false ;
+
+  return ( args.N % var.divN == 0 )
+      && ( args.M % var.divM == 0 )
+      && ( args.K % var.divK == 0 ) ;
+}
+
+//
+// Return the Variant to be used for the ZGEMM specified by 'args', or NULL
+// when no specialized kernel applies (the caller then uses the fallback).
+//
+// The variants are tested from the more specific to the more generic.
+// For now there is a single one: the pre-compiled Hawaii binary for
+// transA == clblasNoTrans and transB == clblasTrans. It is only available
+// when the library is built without CLBLAS_HAWAII_DYNAMIC_KERNEL.
+//
+// DevName    : device family name, compared exactly against "Hawaii".
+// _64BitsUse : device address bits; only 64-bit devices are supported.
+//
+static const Variant * select_variant( clblasZgemmFunctor::Args & args, const char* DevName, cl_uint _64BitsUse )
+{
+  if ( _64BitsUse != 64 )
+  {
+    // Only 64-bit binaries are provided; let the caller fall back.
+    std::cout<<"we don't support clblas on 32 bits"<< std::endl;
+    return NULL;
+  }
+
+  const bool isNT = ( args.transA == clblasNoTrans && args.transB == clblasTrans );
+
+  // we only manage the binary version here
+  if ( isNT && !strcmp(DevName, "Hawaii") )
+  {
+#ifndef CLBLAS_HAWAII_DYNAMIC_KERNEL
+    // kernel zgemm_NT_64_32_8_16x16_2x4__ALPHABETA,
+    // binary zgemm_NT_64_32_8_16x16_2x4__ALPHABETA_64_bin_Hawaii
+    static const Variant variant = ZGEMM_VARIANT_BIN_CL2(N,T,64,32,8,16,16,2,4,64,HAWAII, __ALPHABETA) ;
+    if ( applicable(variant,args) )
+      return &variant;
+#endif
+  }
+
+  return NULL ; // No suitable variant ... will use the fallback
+}
+
+/*
+ * Constructor: obtain the OpenCL program containing the kernel of 'variant'
+ * for the device/context of args.queue, either from the BinaryLookup cache
+ * or by building it (from a pre-compiled binary when variant->bin is set,
+ * otherwise from source). On failure 'err' holds the OpenCL error code and
+ * m_program stays 0.
+ */
+clblasZgemmFunctorGCN::clblasZgemmFunctorGCN(Args & args, const Variant * variant, cl_int & err) :
+  m_program(0) , m_variant(variant)
+{
+  cl_device_id device;
+  cl_context   context;
+
+  err = getDeviceAndContext(args.queue, device, context);
+  if ( err != CL_SUCCESS )
+    return;
+
+  const char * kernelName = m_variant->kernel_name;
+
+  if (VERB) printf(" ===> GET KERNEL %s\n", kernelName) ;
+
+  // Lookup entry for this functor, refined by the kernel name
+  // (presumably the binary cache key - confirm in binary_lookup.h)
+  BinaryLookup bl(context, device, "clblasZgemmFunctorGCN");
+  bl.variantRaw( kernelName, strlen(kernelName)+1 ) ;
+
+  if ( !bl.found() ) // may create empty file or may wait until file is ready
+  {
+    if ( m_variant->bin == 0 )
+    {
+      // no pre-compiled data: build directly from the source string
+      err = bl.buildFromSource(m_variant->source);
+    }
+    else
+    {
+      // build from a pre-compiled version of the kernel (SPIR or cl binaries)
+      err = bl.buildFromBinary(m_variant->bin, m_variant->bin_size, m_variant->build_options);
+    }
+
+    if ( err != CL_SUCCESS )
+    {
+      if (VERB) printf(" ===> BUILD PROBLEM\n") ;
+      return;
+    }
+  }
+
+  m_program = bl.getProgram();
+}
+
+// Number of 'divisor'-sized chunks needed to cover 'value' (rounded up),
+// narrowed to unsigned int.
+static unsigned int zgemmGCNCeilDiv(size_t value, size_t divisor)
+{
+  return static_cast<unsigned int>( value/divisor + ((value%divisor != 0) ? 1 : 0) );
+}
+
+/*
+ * Create the kernel from m_program, set its arguments from 'args' and
+ * enqueue it on args.queue. Returns the OpenCL status converted to
+ * clblasStatus (CL_SUCCESS on success).
+ */
+clblasStatus clblasZgemmFunctorGCN::execute(Args &args)
+{
+  const char * kernelName = this->m_variant->kernel_name;
+
+  if (VERB) printf(" ===> EXECUTE KERNEL %s\n", kernelName) ;
+
+  cl_int err;
+  cl_kernel kernel = clCreateKernel( this->m_program, kernelName, &err);
+  if (err != CL_SUCCESS)
+    return clblasStatus(err) ;
+
+  if (VERB) printf(" ===> FOUND %s\n", kernelName) ;
+
+  const int M = args.M;
+  const int N = args.N;
+  const int K = args.K;
+  const int lda = args.lda;
+  const int ldb = args.ldb;
+  const int ldc = args.ldc;
+  const int offsetA = args.offA;
+  const int offsetB = args.offB;
+  const int offsetC = args.offC;
+
+  // The argument order must match the kernel signature:
+  // (M, N, K, alpha, beta, A, B, C, lda, ldb, ldc, offA, offB, offC)
+  int arg = 0 ;
+  setKernelArg<int>(kernel, arg++, M);
+  setKernelArg<int>(kernel, arg++, N);
+  setKernelArg<int>(kernel, arg++, K);
+  setKernelArg<DoubleComplex>(kernel, arg++, args.alpha);
+  setKernelArg<DoubleComplex>(kernel, arg++, args.beta);
+  setKernelArg<cl_mem>(kernel, arg++, args.A);
+  setKernelArg<cl_mem>(kernel, arg++, args.B);
+  setKernelArg<cl_mem>(kernel, arg++, args.C);
+  setKernelArg<int>(kernel, arg++, lda);
+  setKernelArg<int>(kernel, arg++, ldb);
+  setKernelArg<int>(kernel, arg++, ldc);
+  setKernelArg<int>(kernel, arg++, offsetA);
+  setKernelArg<int>(kernel, arg++, offsetB);
+  setKernelArg<int>(kernel, arg++, offsetC);
+
+  const size_t * ls  = this->m_variant->ls ;  // work-group size: ls[0] x ls[1] work-items
+  const size_t * bwi = this->m_variant->bwi ; // each work-item updates bwi[0] x bwi[1] values
+
+  // Enough work-items to cover M x N, rounded up to whole work-groups
+  const unsigned int thx = static_cast<unsigned int>( ls[0] * zgemmGCNCeilDiv( zgemmGCNCeilDiv(M, bwi[0]), ls[0] ) );
+  const unsigned int thy = static_cast<unsigned int>( ls[1] * zgemmGCNCeilDiv( zgemmGCNCeilDiv(N, bwi[1]), ls[1] ) );
+
+  size_t globalThreads[2];
+  globalThreads[0] = thx;
+  globalThreads[1] = thy;
+
+  err = clEnqueueNDRangeKernel(args.queue, kernel, 2, NULL,
+                               globalThreads, ls ,
+                               args.numEventsInWaitList,
+                               args.eventWaitList,
+                               args.events);
+
+  clReleaseKernel(kernel) ;
+
+  if (VERB) printf(" ===> ERR=%d \n",(int)err) ;
+
+  return clblasStatus(err) ;
+}
+
+
+/*
+ * Return a functor able to run the ZGEMM described by 'args' on the device
+ * of args.queue, or NULL when no GCN variant applies or its kernel cannot be
+ * built (the caller is then expected to use another implementation).
+ *
+ * Functors are cached per (context, device, variant); a cached functor is
+ * retained before being returned.
+ */
+clblasZgemmFunctorGCN *
+  clblasZgemmFunctorGCN::provide(clblasZgemmFunctor::Args & args, const char* DevName)
+{
+  if ( args.order == clblasRowMajor )
+    return NULL ; // The RowMajor case shall never occur.
+
+  cl_device_id dev;
+  cl_context ctxt;
+
+  cl_int err = getDeviceAndContext(args.queue, dev, ctxt);
+  if (err != CL_SUCCESS)
+    return NULL;
+
+  cl_uint bitness = getAddressBits(dev);
+
+  const Variant * variant = select_variant( args, DevName, bitness ) ;
+  if ( variant == NULL )
+    return NULL ;
+
+  Cache::Lookup lookup(cache, ctxt, dev, variant) ;
+
+  if ( lookup.ok() )
+  {
+    clblasZgemmFunctorGCN * functor = lookup.get();
+    functor->retain(); // increment the reference counter to avoid deletion while it is still being used
+    return functor;
+  }
+
+  clblasZgemmFunctorGCN * functor = new clblasZgemmFunctorGCN(args, variant, err);
+  if (err != CL_SUCCESS)
+  {
+    // The kernel could not be built. This functor was never cached or handed
+    // out, so release it here instead of leaking it.
+    functor->release();
+    return NULL;
+  }
+
+  lookup.set(functor) ;
+
+  return functor;
+}
+
diff --git a/src/library/blas/functor/hawaii.cc b/src/library/blas/functor/hawaii.cc
index 643ba9e..0d2b2c0 100644
--- a/src/library/blas/functor/hawaii.cc
+++ b/src/library/blas/functor/hawaii.cc
@@ -25,6 +25,7 @@
#include "gcn_sgemmSmallMatrices.h"
#include "hawaii_sgemmBranchKernel.h"
#include "hawaii_sgemmSplit64_32.h"
+#include "gcn_zgemm.h"
FunctorSelectorHawaii FunctorSelectorHawaii::instance ;
@@ -93,7 +94,7 @@ clblasDgemmFunctor * FunctorSelectorHawaii::select_dgemm_specific(clblasDgemmFun
}
-// The selector function for SGEMM on hawaii
+// The selector function for SGEMM on hawaii
clblasSgemmFunctor * FunctorSelectorHawaii::select_sgemm_specific(clblasSgemmFunctor::Args & args)
{
#ifdef CLBLAS_HAWAII_DYNAMIC_KERNEL
@@ -163,6 +164,26 @@ clblasSgemmFunctor * FunctorSelectorHawaii::select_sgemm_specific(clblasSgemmFun
}
+// The selector function for ZGEMM on hawaii
+//
+// Use the GCN-optimized kernel when the problem matches the variant compiled
+// for Hawaii (column-major, transA == NoTrans, transB == Trans, M%32 == 0,
+// N%64 == 0, K%8 == 0). Otherwise, or when that kernel cannot be provided
+// (clblasZgemmFunctorGCN::provide returns NULL), use the generic selector.
+clblasZgemmFunctor * FunctorSelectorHawaii::select_zgemm_specific(clblasZgemmFunctor::Args & args)
+{
+  const bool matchesGCNKernel = args.M%32==0
+                             && args.N%64==0
+                             && args.K%8==0
+                             && args.transA==clblasNoTrans
+                             && args.transB==clblasTrans
+                             && args.order==clblasColumnMajor;
+
+  if ( matchesGCNKernel )
+  {
+    clblasZgemmFunctor * functor = clblasZgemmFunctorGCN::provide(args, "Hawaii");
+    if ( functor != NULL )
+      return functor;
+  }
+
+  // Not handled by the Hawaii-specific kernel
+  return this->clblasFunctorSelector::select_zgemm_specific(args);
+}
+
// The selector function for DTRSM on hawaii
//
clblasDtrsmFunctor * FunctorSelectorHawaii::select_dtrsm_specific(clblasDtrsmFunctor::Args & args)
diff --git a/src/library/blas/functor/include/gcn_zgemm.h b/src/library/blas/functor/include/gcn_zgemm.h
new file mode 100644
index 0000000..8da218b
--- /dev/null
+++ b/src/library/blas/functor/include/gcn_zgemm.h
@@ -0,0 +1,62 @@
+#ifndef CLBLASZGEMMFUNCTORGCN
+#define CLBLASZGEMMFUNCTORGCN
+#include <string>
+
+// NOTE(review): this header does not declare clblasZgemmFunctor itself; it
+// must be included after the header that does (gcn_zgemm.cc includes
+// functor_xgemm.h first - presumably the same holds for hawaii.cc, confirm).
+
+//
+// ZGEMM functor running one of the pre-defined GCN kernel variants
+// (see select_variant() in gcn_zgemm.cc for the available ones).
+//
+class clblasZgemmFunctorGCN : public clblasZgemmFunctor
+{
+public:
+
+  //
+  // A structure that describes a kernel variant.
+  //
+  // It is important that all instances of those structures shall
+  // be const and static because their addresses are used as keys
+  // in the internal functor cache.
+  //
+  // Also, they shall all have a unique kernel name.
+  //
+  struct Variant
+  {
+    const char * kernel_name ;
+    const char * source ;        // the kernel source (shall be unique); NULL when built from 'bin'
+    const char * build_options;  // options passed when building from 'bin' (may be NULL)
+    const char * bin ;           // pre-compiled kernel (SPIR or CL binary), or NULL
+    size_t bin_size ;            // size of 'bin' in bytes
+    clblasTranspose transA ;     // transposition of A handled by the kernel
+    clblasTranspose transB ;     // transposition of B handled by the kernel
+    unsigned divN ;              // Required divisor of N (use 1 when N can be of any value)
+    unsigned divM ;              // Required divisor of M (use 1 when M can be of any value)
+    unsigned divK ;              // Required divisor of K (use 1 when K can be of any value)
+    size_t ls[2] ;               // Local size (the work-group size)
+    size_t bwi[2] ;              // Block size work-item: Number of elements calculated by each work items
+                                 // So basically each kernel is computing a block of
+                                 //   (ls[0]*bwi[0]) x (ls[1]*bwi[1])
+                                 // elements of C.
+    std::string mult;            // kernel-name multiplier suffix (e.g. __ALPHABETA);
+                                 // not filled in by the ZGEMM_VARIANT_* macros
+  } ;
+
+private: // Constructor & Destructor
+
+  clblasZgemmFunctorGCN(Args & args, const Variant * variant, cl_int & err) ;
+
+public:
+
+  // Provide a suitable clblasZgemmFunctorGCN for the specified args
+  // or NULL if none
+  static clblasZgemmFunctorGCN * provide(clblasZgemmFunctor::Args & args, const char* DevName) ;
+
+public: // inherited member from clblasZgemmFunctor
+
+  virtual clblasStatus execute(Args &args) ;
+
+protected:
+  // We need a default constructor as we derive this class, but we can't use
+  // the specific constructor as the arguments won't be the same (variant!!!).
+  // Both members start out null so that a derived class which does not set
+  // them never reads indeterminate values.
+  // Maybe it worth revisiting this class to have something cleaner
+  clblasZgemmFunctorGCN() : m_program(0), m_variant(0) {}
+  cl_program m_program ;       // program holding the kernel (0 until successfully built)
+protected:
+  const Variant * m_variant ;  // Pointer to a 'const static' object describing the kernel variant.
+} ;
+
+#endif
diff --git a/src/library/blas/functor/include/hawaii.h b/src/library/blas/functor/include/hawaii.h
index b75ef63..bf5471a 100644
--- a/src/library/blas/functor/include/hawaii.h
+++ b/src/library/blas/functor/include/hawaii.h
@@ -33,6 +33,7 @@ public:
// and incx==1
virtual clblasDgemmFunctor * select_dgemm_specific(clblasDgemmFunctor::Args & args);
virtual clblasSgemmFunctor * select_sgemm_specific(clblasSgemmFunctor::Args & args);
+ virtual clblasZgemmFunctor * select_zgemm_specific(clblasZgemmFunctor::Args & args);
virtual clblasDtrsmFunctor * select_dtrsm_specific(clblasDtrsmFunctor::Args & args);
};
diff --git a/src/library/blas/gens/clTemplates/zgemm_gcn.cl b/src/library/blas/gens/clTemplates/zgemm_gcn.cl
new file mode 100644
index 0000000..fc28caf
--- /dev/null
+++ b/src/library/blas/gens/clTemplates/zgemm_gcn.cl
@@ -0,0 +1,314 @@
+/*******************************************************************************
+ * Notes:
+ * for column major, id(0) is row so C data is coalesced
+ * for row major, id(0) is col
+ ******************************************************************************/
+
+static const char * zgemm_NT_64_32_8_16x16_2x4__ALPHABETA = "
+
+
+// convert preprocs to ints for comparison
+// (DATA_TYPE below is set to one of these and tested with #if/#elif)
+#define _S_ 1
+#define _D_ 2
+#define _C_ 3
+#define _Z_ 4
+
+/*******************************************************************************
+ * Pre-Processor "Strings"
+ ******************************************************************************/
+#define COLUMN_MAJOR_STR ColMajor
+#define ROW_MAJOR_STR RowMajor
+
+/*******************************************************************************
+ * Kernel PreProcessor Definitions
+ ******************************************************************************/
+// WG_NUM_ROWS x WG_NUM_COLS work-items per work-group; each one computes a
+// MICRO_TILE_NUM_ROWS x MICRO_TILE_NUM_COLS micro tile of C; the K dimension
+// is consumed NUM_UNROLL_ITER values per loop iteration.
+#define WG_NUM_ROWS 16
+#define WG_NUM_COLS 16
+#define MICRO_TILE_NUM_ROWS 2
+#define MICRO_TILE_NUM_COLS 4
+#define NUM_UNROLL_ITER 8
+#define ORDER ColMajor
+#define TRANSPOSE_A N
+#define TRANSPOSE_B T
+#define DATA_TYPE _Z_
+
+// Tile of C computed by one work-group:
+// WG_NUM_ROWS*MICRO_TILE_NUM_ROWS x WG_NUM_COLS*MICRO_TILE_NUM_COLS
+#define MACRO_TILE_NUM_ROWS 32
+#define MACRO_TILE_NUM_COLS 64
+// each row lengthened by this amount
+#define LOCAL_ROW_PAD 1
+// each col lengthened by this amount
+#define LOCAL_COL_PAD 1
+
+
+/*******************************************************************************
+ * Global Memory Indices
+ * Note: (a==b)==(c==d) means if both are true or neither is true
+ ******************************************************************************/
+
+/* col-major non-transposed
+ * row-major transposed */
+#define GET_GLOBAL_INDEX_N(ROW,COL,STRIDE) ((COL)*(STRIDE)+(ROW))
+
+/* col-major transposed
+ * row-major non-transposed */
+#define GET_GLOBAL_INDEX_T(ROW,COL,STRIDE) ((ROW)*(STRIDE)+(COL))
+
+// NOTE(review): inside an #if expression, identifiers that are not macros
+// are replaced by 0. ColMajor, RowMajor, N and T are not macros, so
+// (ORDER==COLUMN_MAJOR_STR), (TRANSPOSE_A==N) and (TRANSPOSE_B==N) all reduce
+// to 0==0 and are always true, whatever the configuration above. A therefore
+// always uses GET_GLOBAL_INDEX_N, B always GET_GLOBAL_INDEX_T and C always
+// GET_GLOBAL_INDEX_N. For this column-major NT instantiation those are the
+// right layouts (A(row,k) = A[k*lda+row], B^T(k,col) = B[k*ldb+col]), but
+// changing ORDER or TRANSPOSE_A/B will NOT switch the macros. The B branches
+// also appear swapped relative to A, so fixing the comparisons alone would
+// make this NT instantiation pick the wrong B layout.
+
+// global A
+#if (ORDER==COLUMN_MAJOR_STR) == (TRANSPOSE_A==N)
+#define GET_GLOBAL_INDEX_A(ROW,COL) GET_GLOBAL_INDEX_N((ROW),(COL),(lda))
+#else
+#define GET_GLOBAL_INDEX_A(ROW,COL) GET_GLOBAL_INDEX_T((ROW),(COL),(lda))
+#endif
+
+// global B
+#if (ORDER==COLUMN_MAJOR_STR) == (TRANSPOSE_B==N)
+#define GET_GLOBAL_INDEX_B(ROW,COL) GET_GLOBAL_INDEX_T((ROW),(COL),(ldb))
+#else
+#define GET_GLOBAL_INDEX_B(ROW,COL) GET_GLOBAL_INDEX_N((ROW),(COL),(ldb))
+#endif
+
+// global C
+#if (ORDER==COLUMN_MAJOR_STR)
+#define GET_GLOBAL_INDEX_C(ROW,COL) GET_GLOBAL_INDEX_N((ROW),(COL),(ldc))
+#else
+#define GET_GLOBAL_INDEX_C(ROW,COL) GET_GLOBAL_INDEX_T((ROW),(COL),(ldc))
+#endif
+
+/*******************************************************************************
+ * Local Memory Indices
+ ******************************************************************************/
+
+// localA - rotated 90 degrees from B but use same accessor unless slow
+// (the A tile is stored k-major: element (ROW,k) at ROW + k*(MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD))
+#define GET_LOCAL_INDEX_A(ROW,COL) (ROW + COL*(MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD) )
+// NOTE(review): GET_LOCAL_STEP_A and GET_LOCAL_STEP_B are not used by the
+// kernel below (it uses localAStride/localBStride), and GET_LOCAL_STEP_A is
+// written with the B tile dimensions.
+#define GET_LOCAL_STEP_A ( ((MACRO_TILE_NUM_COLS)+(LOCAL_ROW_PAD)) \
+ * ((WG_NUM_ROWS)*(WG_NUM_COLS)/(MACRO_TILE_NUM_COLS))
+
+// localB
+// (element (k,COL) at COL + k*(MACRO_TILE_NUM_COLS+LOCAL_ROW_PAD))
+#define GET_LOCAL_INDEX_B(ROW,COL) ((COL) + (ROW)*((MACRO_TILE_NUM_COLS)+(LOCAL_ROW_PAD)) )
+#define GET_LOCAL_STEP_B ( ((MACRO_TILE_NUM_COLS)+(LOCAL_ROW_PAD)) \
+ * ((WG_NUM_ROWS)*(WG_NUM_COLS)/(MACRO_TILE_NUM_COLS))
+
+/*******************************************************************************
+ * Data Types
+ ******************************************************************************/
+
+// TYPE_MAD(MUL0,MUL1,DST): DST += MUL0*MUL1 (complex product for c/z).
+// TYPE_MAD2(DST,ALPHA,REG,BETA): DST = ALPHA*REG + BETA*DST. The complex
+// versions use REG as scratch (it is overwritten) and need a variable named
+// type_mad2_tmp of the real component type declared by the caller.
+// NOTE(review): DST (an element of C) is always read, even when BETA is zero;
+// reference BLAS does not read C when beta is zero, so NaN/Inf left in an
+// uninitialized C propagate into the result here.
+
+// single precision
+#if DATA_TYPE==_S_
+#define DATA_TYPE_STR float
+#define DATA_TYPE_CHAR s
+#define TYPE_MAD(MUL0,MUL1,DST) DST = mad(MUL0,MUL1,DST);
+#define TYPE_MAD2( DST, ALPHA, REG, BETA ) DST = (ALPHA)*(REG) + (BETA)*(DST);
+
+// double precision
+#elif DATA_TYPE==_D_
+#define DATA_TYPE_STR double
+#define DATA_TYPE_CHAR d
+#define TYPE_MAD(MUL0,MUL1,DST) DST = mad(MUL0,MUL1,DST);
+#define TYPE_MAD2( DST, ALPHA, REG, BETA ) DST = (ALPHA)*(REG) + (BETA)*(DST);
+
+// complex single precision
+// step (1): REG = ALPHA*REG, step (2): REG += BETA*DST, step (3): DST = REG
+#elif DATA_TYPE==_C_
+#define DATA_TYPE_STR float2
+#define DATA_TYPE_CHAR c
+#define TYPE_MAD(MUL0,MUL1,DST) \
+ DST.s0 = mad( MUL0.s0, MUL1.s0, DST.s0 ); \
+ DST.s0 = mad( -MUL0.s1, MUL1.s1, DST.s0 ); \
+ DST.s1 = mad( MUL0.s0, MUL1.s1, DST.s1 ); \
+ DST.s1 = mad( MUL0.s1, MUL1.s0, DST.s1 );
+#define TYPE_MAD2( DST, ALPHA, REG, BETA ) \
+ /* (1) */ \
+ type_mad2_tmp = REG.s0; \
+ REG.s0 *= ALPHA.s0; \
+ REG.s0 = mad( -ALPHA.s1, REG.s1, REG.s0 ); \
+ REG.s1 *= ALPHA.s0; \
+ REG.s1 = mad( ALPHA.s1, type_mad2_tmp, REG.s1 ); \
+ /* (2) */ \
+ REG.s0 = mad( BETA.s0, DST.s0, REG.s0 ); \
+ REG.s0 = mad( -BETA.s1, DST.s1, REG.s0 ); \
+ REG.s1 = mad( BETA.s1, DST.s0, REG.s1 ); \
+ REG.s1 = mad( BETA.s0, DST.s1, REG.s1 ); \
+ /* (3) */ \
+ DST = REG;
+
+// complex double precision
+// same steps as the complex single precision version above
+#else
+#define DATA_TYPE_STR double2
+#define DATA_TYPE_CHAR z
+#define TYPE_MAD(MUL0,MUL1,DST) \
+ DST.s0 = mad( MUL0.s0, MUL1.s0, DST.s0 ); \
+ DST.s0 = mad( -MUL0.s1, MUL1.s1, DST.s0 ); \
+ DST.s1 = mad( MUL0.s0, MUL1.s1, DST.s1 ); \
+ DST.s1 = mad( MUL0.s1, MUL1.s0, DST.s1 );
+#define TYPE_MAD2( DST, ALPHA, REG, BETA ) \
+ /* (1) */ \
+ type_mad2_tmp = REG.s0; \
+ REG.s0 *= ALPHA.s0; \
+ REG.s0 = mad( -ALPHA.s1, REG.s1, REG.s0 ); \
+ REG.s1 *= ALPHA.s0; \
+ REG.s1 = mad( ALPHA.s1, type_mad2_tmp, REG.s1 ); \
+ /* (2) */ \
+ REG.s0 = mad( BETA.s0, DST.s0, REG.s0 ); \
+ REG.s0 = mad( -BETA.s1, DST.s1, REG.s0 ); \
+ REG.s1 = mad( BETA.s1, DST.s0, REG.s1 ); \
+ REG.s1 = mad( BETA.s0, DST.s1, REG.s1 ); \
+ /* (3) */ \
+ DST = REG;
+
+#endif
+
+/*******************************************************************************
+ * 2x4 micro tile
+ ******************************************************************************/
+// MAD2x4: for one value of k, load 2 values of the local A tile and 4 values
+// of the local B tile into rA/rB, advance offA/offB to the next k row of the
+// padded local tiles, and accumulate the 2x4 outer product into rC with
+// TYPE_MAD. Expects rA, rB, rC, localA, localB, offA and offB in scope.
+#define MAD2x4 \
+ rA[0] = localA[offA + 0*WG_NUM_ROWS]; \
+ rA[1] = localA[offA + 1*WG_NUM_ROWS]; \
+ rB[0] = localB[offB + 0*WG_NUM_COLS]; \
+ rB[1] = localB[offB + 1*WG_NUM_COLS]; \
+ rB[2] = localB[offB + 2*WG_NUM_COLS]; \
+ rB[3] = localB[offB + 3*WG_NUM_COLS]; \
+ offA += (MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD); \
+ offB += (MACRO_TILE_NUM_COLS+LOCAL_ROW_PAD); \
+ TYPE_MAD(rA[0],rB[0],rC[0][0]); \
+ TYPE_MAD(rA[1],rB[0],rC[1][0]); \
+ TYPE_MAD(rA[0],rB[1],rC[0][1]); \
+ TYPE_MAD(rA[1],rB[1],rC[1][1]); \
+ TYPE_MAD(rA[0],rB[2],rC[0][2]); \
+ TYPE_MAD(rA[1],rB[2],rC[1][2]); \
+ TYPE_MAD(rA[0],rB[3],rC[0][3]); \
+ TYPE_MAD(rA[1],rB[3],rC[1][3]); \
+ mem_fence(CLK_LOCAL_MEM_FENCE);
+
+// concatenate kernel name
+// zgemm_NT_64_32_8_16x16_2x4__ALPHABETA
+// KERNEL_NAME forwards to CONCAT_NAME so that its macro arguments are
+// expanded to their values before ## pastes them together.
+#define CONCAT_NAME(DT,TA,TB,TILE_COLS,TILE_ROWS,NUI,WGR,WGC,MTR,MTC) \
+ DT ## gemm_ ## TA ## TB ## _ ## TILE_COLS ## _ ## TILE_ROWS ## _ ## NUI ## _ ## WGR ## x ## WGC ## _ ## MTR ## x ## MTC ## __ALPHABETA
+#define KERNEL_NAME(DT,TA,TB,TILE_COLS,TILE_ROWS,NUI,WGR,WGC,MTR,MTC) CONCAT_NAME(DT,TA,TB,TILE_COLS,TILE_ROWS,NUI,WGR,WGC,MTR,MTC)
+
+/*******************************************************************************
+ * Kernel
+ *
+ * Computes C = alpha*op(A)*op(B) + beta*C (this template is instantiated with
+ * TRANSPOSE_A=N, TRANSPOSE_B=T). Each work-group computes one
+ * MACRO_TILE_NUM_ROWS x MACRO_TILE_NUM_COLS tile of C; each work-item
+ * accumulates a MICRO_TILE_NUM_ROWS x MICRO_TILE_NUM_COLS micro tile in
+ * registers. M and N must be multiples of the macro tile, and K must be a
+ * non-zero multiple of NUM_UNROLL_ITER (the K loop is a do/while).
+ * offA, offB and offC are element offsets into A, B and C.
+ ******************************************************************************/
+__attribute__((reqd_work_group_size(WG_NUM_COLS,WG_NUM_ROWS,1)))
+__kernel void KERNEL_NAME(DATA_TYPE_CHAR,TRANSPOSE_A,TRANSPOSE_B,MACRO_TILE_NUM_COLS,MACRO_TILE_NUM_ROWS,NUM_UNROLL_ITER,WG_NUM_ROWS,WG_NUM_COLS,MICRO_TILE_NUM_ROWS,MICRO_TILE_NUM_COLS) (
+  uint const M,
+  uint const N,
+  uint const K,
+  DATA_TYPE_STR const alpha,
+  DATA_TYPE_STR const beta,
+  __global DATA_TYPE_STR const * restrict A,
+  __global DATA_TYPE_STR const * restrict B,
+  __global DATA_TYPE_STR * C,
+  uint const lda,
+  uint const ldb,
+  uint const ldc,
+  uint const offA,
+  uint const offB,
+  uint const offC )
+{
+  // registers
+  DATA_TYPE_STR rC[MICRO_TILE_NUM_ROWS][MICRO_TILE_NUM_COLS] = {0};
+  DATA_TYPE_STR rA[MICRO_TILE_NUM_ROWS];
+  DATA_TYPE_STR rB[MICRO_TILE_NUM_COLS];
+
+  // local memory
+  __local DATA_TYPE_STR localA[NUM_UNROLL_ITER*(MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD)];
+  __local DATA_TYPE_STR localB[NUM_UNROLL_ITER*(MACRO_TILE_NUM_COLS+LOCAL_ROW_PAD)];
+
+/*
+ * for coalesced C writing
+ * if column major, id(0) is row
+ * if row major, id(0) is col
+ */
+  uint groupRow = get_group_id(0);
+  uint groupCol = get_group_id(1);
+  uint localRow = get_local_id(0);
+  uint localCol = get_local_id(1);
+  uint localSerial = localRow + localCol*WG_NUM_ROWS;
+
+  // Apply the caller's element offsets. The host always passes offA, offB
+  // and offC; without this, sub-matrices starting at a non-zero offset would
+  // be read and written at the wrong place. This must happen here: inside
+  // the K loop, offA and offB are redeclared as the local-memory read offsets
+  // used by MAD2x4.
+  A += offA;
+  B += offB;
+  C += offC;
+
+  /*****************************************************************************
+   * global indices being loaded
+   ****************************************************************************/
+  // which gAij is this thread responsible for loading?
+#define globalARow (groupRow*MACRO_TILE_NUM_ROWS + localSerial%MACRO_TILE_NUM_ROWS)
+#define globalACol (localSerial/MACRO_TILE_NUM_ROWS)
+#define globalAIdx (GET_GLOBAL_INDEX_A( globalARow, globalACol ) )
+  A += globalAIdx;
+  // which gBij is this thread responsible for loading?
+#define globalBRow (localSerial/MACRO_TILE_NUM_COLS)
+#define globalBCol (groupCol*MACRO_TILE_NUM_COLS + localSerial%MACRO_TILE_NUM_COLS)
+#define globalBIdx (GET_GLOBAL_INDEX_B( globalBRow, globalBCol ) )
+  B += globalBIdx;
+
+  uint block_k = K / NUM_UNROLL_ITER;
+#pragma nounroll
+  do {
+
+    /***************************************************************************
+     * local indices being written
+     **************************************************************************/
+    // which lAij is this thread responsible for writing?
+#define localARow (localSerial % MACRO_TILE_NUM_ROWS)
+#define localACol (localSerial / MACRO_TILE_NUM_ROWS)
+#define localAStride ( (MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD) * (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_ROWS) )
+#define globalAStride ( GET_GLOBAL_INDEX_A(0, (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_ROWS) ) )
+#define localAIdx ( GET_LOCAL_INDEX_A(localARow, localACol) )
+    __local DATA_TYPE_STR *lA = localA + localAIdx;
+    // which lBij is this thread responsible for writing?
+#define localBRow ( localSerial / MACRO_TILE_NUM_COLS )
+#define localBCol ( localSerial % MACRO_TILE_NUM_COLS )
+#define localBIdx ( GET_LOCAL_INDEX_B(localBRow, localBCol) )
+#define localBStride ( (MACRO_TILE_NUM_COLS+LOCAL_ROW_PAD) * (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_COLS) )
+#define globalBStride ( GET_GLOBAL_INDEX_B( (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_COLS), 0 ) )
+    __local DATA_TYPE_STR *lB = localB + localBIdx;
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    /***************************************************************************
+     * Load global -> local
+     * num loads = num threads / total loads
+     **************************************************************************/
+    // 2x4 uTile x 8unroll
+    lA[ 0*localAStride ] = A[ 0*globalAStride ];
+    lB[ 0*localBStride ] = B[ 0*globalBStride ];
+    lB[ 1*localBStride ] = B[ 1*globalBStride ];
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    // local-memory read offsets consumed and advanced by MAD2x4
+    // (these shadow the offA/offB kernel parameters inside the loop)
+    uint offA = localRow;
+    uint offB = localCol;
+
+    /***************************************************************************
+     * do mads in registers
+     **************************************************************************/
+    MAD2x4
+    MAD2x4
+    MAD2x4
+    MAD2x4
+    MAD2x4
+    MAD2x4
+    MAD2x4
+    MAD2x4
+
+    // fully shift
+    A += lda*NUM_UNROLL_ITER; // b/c N
+    B += ldb*NUM_UNROLL_ITER; // b/c T
+
+  } while (--block_k > 0);
+
+  // which global Cij is this thread responsible for computing?
+  uint globalCRow = groupRow * MACRO_TILE_NUM_ROWS + localRow;
+  uint globalCCol = groupCol * MACRO_TILE_NUM_COLS + localCol;
+
+  /***************************************************************************
+   * write data
+   **************************************************************************/
+  double type_mad2_tmp; // used in TYPE_MAD2
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+0*WG_NUM_COLS) ], alpha, rC[0][0], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+1*WG_NUM_COLS) ], alpha, rC[0][1], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+2*WG_NUM_COLS) ], alpha, rC[0][2], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+3*WG_NUM_COLS) ], alpha, rC[0][3], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+0*WG_NUM_COLS) ], alpha, rC[1][0], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+1*WG_NUM_COLS) ], alpha, rC[1][1], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+2*WG_NUM_COLS) ], alpha, rC[1][2], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+3*WG_NUM_COLS) ], alpha, rC[1][3], beta )
+
+}
+
+";
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/clblas.git
More information about the debian-science-commits
mailing list