[clblas] 45/61: adding zgemm kernel for hawaii

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Fri Jul 24 22:49:47 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch master
in repository clblas.

commit 6623809f354ae2dd6d5fee22695320f245b89591
Author: David Tanner <guacamoleo at gmail.com>
Date:   Tue Jun 16 12:57:12 2015 -0500

    adding zgemm kernel for hawaii
---
 src/library/CMakeLists.txt                     |   4 +
 src/library/bingen.cmake                       |   1 +
 src/library/blas/functor/bonaire.cc            |   2 +-
 src/library/blas/functor/gcn_zgemm.cc          | 354 +++++++++++++++++++++++++
 src/library/blas/functor/hawaii.cc             |  23 +-
 src/library/blas/functor/include/gcn_zgemm.h   |  62 +++++
 src/library/blas/functor/include/hawaii.h      |   1 +
 src/library/blas/gens/clTemplates/zgemm_gcn.cl | 314 ++++++++++++++++++++++
 8 files changed, 759 insertions(+), 2 deletions(-)

diff --git a/src/library/CMakeLists.txt b/src/library/CMakeLists.txt
index b5f5d66..e45a2aa 100644
--- a/src/library/CMakeLists.txt
+++ b/src/library/CMakeLists.txt
@@ -76,6 +76,7 @@ set(SRC_BLAS
 	blas/functor/hawaii_sgemmSplit64_32.cc
 	blas/functor/gcn_dgemmCommon.cc
 	blas/functor/gcn_sgemm.cc
+	blas/functor/gcn_zgemm.cc
 	blas/functor/gcn_dgemmSmallMatrices.cc
 	blas/functor/gcn_sgemmSmallMatrices.cc
 	blas/functor/hawaii_sgemmBranchKernel.cc
@@ -107,6 +108,7 @@ set(SRC_BLAS_HEADERS
 	blas/functor/include/hawaii_sgemmSplit64_32.h
 	blas/functor/include/gcn_dgemmCommon.h
 	blas/functor/include/gcn_sgemm.h
+	blas/functor/include/gcn_zgemm.h
 	blas/functor/include/gcn_dgemmSmallMatrices.h
 	blas/functor/include/gcn_sgemmSmallMatrices.h
 	blas/functor/include/hawaii_sgemmBranchKernel.h
@@ -231,6 +233,7 @@ set (SRC_CL_TEMPLATES
 	dgemm_gcn_SmallMatrices.cl
 	sgemm_gcn_SmallMatrices.cl
 	sgemm_gcn.cl
+	zgemm_gcn.cl
 )
 
 set(SRC_CL_TEMPLATES_GEN
@@ -248,6 +251,7 @@ set(SRC_CL_TEMPLATES_GEN
 	sgemm_gcn_SmallMatrices.clTahiti_64.bin.cl
     sgemm_gcn_SmallMatrices.clBonaire_64.bin.cl
 	sgemm_gcn.clHawaii_64.bin.cl
+	zgemm_gcn.clHawaii_64.bin.cl
     sgemm_gcn.clBonaire_64.bin.cl
 	sgemm_gcn.clTahiti_64.bin.cl
 	sgemm_hawaiiSplit64_32.clHawaii_64.bin.cl
diff --git a/src/library/bingen.cmake b/src/library/bingen.cmake
index bb6515d..7b13557 100644
--- a/src/library/bingen.cmake
+++ b/src/library/bingen.cmake
@@ -13,6 +13,7 @@ set (BIN_CL_TEMPLATES_HAWAII_CL2
 ${CLTEMPLATE_PATH}/dgemm_hawaiiSplitKernel.cl
 ${CLTEMPLATE_PATH}/sgemm_hawaiiSplitKernel.cl
 ${CLTEMPLATE_PATH}/sgemm_gcn.cl
+${CLTEMPLATE_PATH}/zgemm_gcn.cl
 ${CLTEMPLATE_PATH}/sgemm_gcn_SmallMatrices.cl
 ${CLTEMPLATE_PATH}/sgemm_hawaiiSplit64_32.cl
 )
diff --git a/src/library/blas/functor/bonaire.cc b/src/library/blas/functor/bonaire.cc
index a97d9b9..4e2989d 100644
--- a/src/library/blas/functor/bonaire.cc
+++ b/src/library/blas/functor/bonaire.cc
@@ -37,7 +37,7 @@ FunctorSelectorBonaire::FunctorSelectorBonaire()
 
 
 
-// The selector function for SGEMM on hawaii 
+// The selector function for SGEMM on bonaire 
 clblasSgemmFunctor * FunctorSelectorBonaire::select_sgemm_specific(clblasSgemmFunctor::Args & args)
 {
 #ifdef CLBLAS_BONAIRE_DYNAMIC_KERNEL
diff --git a/src/library/blas/functor/gcn_zgemm.cc b/src/library/blas/functor/gcn_zgemm.cc
new file mode 100644
index 0000000..72ac9a1
--- /dev/null
+++ b/src/library/blas/functor/gcn_zgemm.cc
@@ -0,0 +1,354 @@
+#include <stdio.h>
+#include <string.h>
+#include <clBLAS.h>
+
+#include <devinfo.h>
+#include "clblas-internal.h"
+#include "solution_seq.h"
+
+#include <functor.h>
+#include <binary_lookup.h>
+#include <iostream>
+
+#include <functor_xgemm.h>
+#include <tahiti.h>
+#include <hawaii.h>
+#include <gcn_zgemm.h>
+
+#include "BinaryBuild.h"
+
+// For the moment only the OpenCL source and the Hawaii CL binary are managed
+
+#if BUILD_KERNEL_FROM_STRING
+#include "zgemm_gcn.clT"
+#else 
+
+#include "zgemm_gcn.clHawaii_64.bin.clT"
+//#include "zgemm_gcn.clBonaire_64.bin.clT"
+
+//#include "zgemm_gcn.clTahiti_64.bin.clT"
+#endif
+
+
+//
+// The name of the 'const char *' providing the kernel OpenCL source
+//
+//  zgemm_TATB_DIVN_DIVM_DIVK_BS0xBS1_NV0xNV1MULT
+//
+// For instance, ZGEMM_SRC_NAME(N,T,64,32,8,16,16,2,4,__ALPHABETA) is zgemm_NT_64_32_8_16x16_2x4__ALPHABETA
+// The _TAHITI/_HAWAII/_BONAIRE forms append _BITS_bin_<Device> to name the CL binary data.
+#define ZGEMM_SRC_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT)               zgemm_##TA##TB##_##DIVN##_##DIVM##_##DIVK##_##BS0##x##BS1##_##NV0##x##NV1##MULT
+#define ZGEMM_SRC_NAME_TAHITI(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT)   zgemm_##TA##TB##_##DIVN##_##DIVM##_##DIVK##_##BS0##x##BS1##_##NV0##x##NV1##MULT##_##BITS##_bin_Tahiti
+#define ZGEMM_SRC_NAME_HAWAII(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT)   zgemm_##TA##TB##_##DIVN##_##DIVM##_##DIVK##_##BS0##x##BS1##_##NV0##x##NV1##MULT##_##BITS##_bin_Hawaii
+#define ZGEMM_SRC_NAME_BONAIRE(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT)  zgemm_##TA##TB##_##DIVN##_##DIVM##_##DIVK##_##BS0##x##BS1##_##NV0##x##NV1##MULT##_##BITS##_bin_Bonaire
+
+//
+// The name of the 'const char []' global variable that contains the SPIR data.
+// That name is the one produced by ZGEMM_SRC_NAME suffixed by _spir
+//
+#define ZGEMM_SPIR_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT)  zgemm_##TA##TB##_##DIVN##_##DIVM##_##DIVK##_##BS0##x##BS1##_##NV0##x##NV1##MULT##_spir
+
+//
+// The 'const char []' global variables that contain the CL binaries data are
+// named by the device-specific ZGEMM_SRC_NAME_<DEVICE> macros above.
+//
+
+
+// The name of the kernel itself.
+// This is the name produced by ZGEMM_SRC_NAME but as a string literal
+//
+#define ZGEMM_KERNEL_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) "zgemm_"  #TA #TB "_" #DIVN "_" #DIVM "_" #DIVK "_" #BS0 "x" #BS1 "_" #NV0 "x" #NV1 #MULT
+
+//
+// Helpers to transform N and T in proper clblas values for the macros above
+//
+#define trans_N clblasNoTrans
+#define trans_T clblasTrans
+
+
+// Fill a variant descriptor using OpenCL source (the trailing 'mult' member is left empty)
+#define ZGEMM_VARIANT_SRC(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) { \
+  ZGEMM_KERNEL_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) ,     \
+  ZGEMM_SRC_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) ,        \
+  NULL, NULL, 0,                                                \
+  trans_##TA, trans_##TB,                                       \
+  DIVN,DIVM,DIVK,                                               \
+{ BS0, BS1 } ,                                                \
+{ NV0, NV1 }                                                  \
+} 
+
+// Fill a variant descriptor using SPIR (not used in this file yet)
+#define ZGEMM_VARIANT_SPIR(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) {  \
+  ZGEMM_KERNEL_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) ,  \
+  NULL , "-x spir -spir-std=1.2",                                 \
+  ZGEMM_SPIR_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT),     \
+  sizeof(ZGEMM_SPIR_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT)), \
+  trans_##TA,trans_##TB,                                          \
+  DIVN,DIVM,DIVK,                                                 \
+{ BS0, BS1 } ,                                                  \
+{ NV0, NV1 }                                                    \
+} 
+
+// Fill a variant descriptor using CL binaries (built without extra options)
+#define ZGEMM_VARIANT_BIN_CL1(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,DEVICE,MULT) {  \
+  ZGEMM_KERNEL_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) ,           \
+  NULL , NULL,                                                        \
+  ZGEMM_SRC_NAME##_##DEVICE(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT),          \
+  sizeof(ZGEMM_SRC_NAME##_##DEVICE(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT)),  \
+  trans_##TA,trans_##TB,                                              \
+  DIVN,DIVM,DIVK,                                                     \
+{ BS0, BS1 } ,                                                      \
+{ NV0, NV1 }                                                        \
+} 
+
+// Same as above, for CL binaries that are built with -cl-std=CL2.0
+#define ZGEMM_VARIANT_BIN_CL2(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,DEVICE,MULT) {  \
+  ZGEMM_KERNEL_NAME(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,MULT) ,           \
+  NULL , "-cl-std=CL2.0",                                                        \
+  ZGEMM_SRC_NAME##_##DEVICE(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT),          \
+  sizeof(ZGEMM_SRC_NAME##_##DEVICE(TA,TB,DIVN,DIVM,DIVK,BS0,BS1,NV0,NV1,BITS,MULT)),  \
+  trans_##TA,trans_##TB,                                              \
+  DIVN,DIVM,DIVK,                                                     \
+{ BS0, BS1 } ,                                                      \
+{ NV0, NV1 }                                                        \
+} 
+
+// Shorter local name for the kernel-variant descriptor
+typedef clblasZgemmFunctorGCN::Variant Variant ;
+
+// Every clblasZgemmFunctorGCN instance lives in this static cache; it is
+// looked up by (context, device, address of a 'static const' Variant).
+typedef clblasFunctorCache<clblasZgemmFunctorGCN, const Variant *> Cache ;
+static Cache cache ;
+
+// Debug switch: set to 1 to print traces of the kernel lookup,
+// build and launch steps performed below.
+#define VERB 0
+
+
+// return true iff a kernel variant is applicable to the specified args
+static bool applicable( const Variant & var, clblasZgemmFunctor::Args & args ) 
+{
+  // The transpose modes are not compared here: select_variant() only
+  // calls this from inside the branch that matches the variant's
+  // transA/transB values.
+  //
+  // Each of N, M and K must be an exact multiple of the variant's divisor.
+  const bool nOk = ( args.N % var.divN ) == 0 ;
+  const bool mOk = ( args.M % var.divM ) == 0 ;
+  const bool kOk = ( args.K % var.divK ) == 0 ;
+
+  // var.mult is not consulted yet (a beta == 0 check against an
+  // __ALPHA-only variant is left disabled).
+  return nOk && mOk && kOk ;
+}
+
+//
+// The goal of this function is to return the Variant to be used 
+// for the ZGEMM specified by 'args', or NULL when no GCN variant fits
+// (the caller then falls back to the generic implementation).
+//
+// The variants are typically tested sequentially from the more 
+// specific to the more generic. Additional conditions can be 
+// placed into the surrounding 'if' (typically that would be 
+// to perform additional tests on M, N and K).
+// Only 64-bit devices are handled (_64BitsUse is the device address width).
+
+static const Variant * select_variant( clblasZgemmFunctor::Args & args, const char* DevName, cl_uint _64BitsUse )
+{
+  if(_64BitsUse!=64)
+  {
+    std::cout<<"we don't support clblas on 32 bits"<< std::endl;
+    // Not fatal: returning NULL lets the caller use the generic path.
+    return NULL;
+  }
+
+  if ( args.transA == clblasNoTrans ) 
+  {
+    if ( args.transB == clblasTrans ) 
+    {
+        //we only manage the binary version here
+        if( DevName != NULL && !strcmp(DevName, "Hawaii") )
+        {
+#ifndef CLBLAS_HAWAII_DYNAMIC_KERNEL
+          //zgemm_NT_64_32_8_16x16_2x4__ALPHABETA_64_bin_Hawaii
+          static const Variant variant = ZGEMM_VARIANT_BIN_CL2(N,T,64,32,8,16,16,2,4,64,HAWAII, __ALPHABETA) ;
+          if ( applicable(variant,args) ) // requires N%64, M%32 and K%8 == 0
+            return &variant;
+#endif
+        }
+    }
+  }
+
+  return NULL ; // No suitable variant ... will use the fallback
+
+}  
+
+/*
+ * constructor: obtain the cl_program for 'variant' on the device behind
+ * args.queue, building it when BinaryLookup does not already have it.
+ * On failure 'err' holds the OpenCL error code and m_program stays 0.
+ */
+clblasZgemmFunctorGCN::clblasZgemmFunctorGCN(Args & args, const Variant * variant, cl_int & err)  :
+  m_program(0) , m_variant(variant)
+{
+  // Resolve the device and the context the caller's queue belongs to.
+  cl_device_id dev;
+  cl_context   ctx;
+
+  err = getDeviceAndContext(args.queue, dev, ctx);
+  if ( err != CL_SUCCESS )
+    return;
+
+  const Variant & v = *this->m_variant;
+
+  if (VERB) printf(" ===> GET KERNEL %s\n", v.kernel_name) ;
+
+  // BinaryLookup gets the context, device and functor class name; the
+  // kernel name (terminating NUL included) is the variant-specific data.
+  BinaryLookup bl(ctx, dev, "clblasZgemmFunctorGCN");
+  bl.variantRaw( v.kernel_name, strlen(v.kernel_name) + 1 ) ;
+
+  if ( bl.found() )
+  {
+    // Program already available: nothing to build.
+    this->m_program = bl.getProgram();
+    return;
+  }
+
+  // Not found: build from the pre-compiled data (SPIR or CL binary) when
+  // the variant carries some, otherwise from the OpenCL source string.
+  if ( v.bin == 0 )
+    err = bl.buildFromSource(v.source);
+  else
+    err = bl.buildFromBinary(v.bin, v.bin_size, v.build_options);
+
+  if ( err != CL_SUCCESS )
+  {
+    if (VERB) printf(" ===> BUILD PROBLEM\n") ;
+    return;
+  }
+
+  // Freshly built: keep the resulting program.
+  this->m_program = bl.getProgram();
+}
+
+clblasStatus clblasZgemmFunctorGCN::execute(Args &args) 
+{
+  cl_int err;
+  cl_command_queue queue = args.queue;
+  // Create the variant's kernel, bind the values from 'args' and enqueue it on args.queue.
+  if (VERB) printf(" ===> EXECUTE KERNEL %s\n", this->m_variant->kernel_name) ;
+
+  cl_kernel kernel = clCreateKernel( this->m_program, this->m_variant->kernel_name,  &err);
+  if (err != CL_SUCCESS) return clblasStatus(err) ;
+
+  if (VERB) printf(" ===> FOUND %s\n", this->m_variant->kernel_name) ;
+
+  int M   = args.M, N = args.N, K = args.K;
+  int lda = args.lda, ldb = args.ldb, ldc = args.ldc;
+  // Element offsets into A, B and C, passed to the kernel as offA/offB/offC.
+  int offsetA = args.offA;
+  int offsetB = args.offB;
+  int offsetC = args.offC;
+
+  int arg=0 ; 
+  // All zgemm kernels shall have the same arguments, in this order:
+  //   (M,N,K,alpha,beta,A,B,C,lda,ldb,ldc,offA,offB,offC)
+  // so the setKernelArg calls below must not be reordered.
+  setKernelArg<int>(kernel, arg++, M);
+  setKernelArg<int>(kernel, arg++, N);
+  setKernelArg<int>(kernel, arg++, K);
+  setKernelArg<DoubleComplex>(kernel, arg++, args.alpha);
+  setKernelArg<DoubleComplex>(kernel, arg++, args.beta);
+
+  setKernelArg<cl_mem>(kernel, arg++, args.A);
+  setKernelArg<cl_mem>(kernel, arg++, args.B);
+  setKernelArg<cl_mem>(kernel, arg++, args.C);
+
+  setKernelArg<int>(kernel, arg++, lda);
+  setKernelArg<int>(kernel, arg++, ldb);
+  setKernelArg<int>(kernel, arg++, ldc);
+
+  setKernelArg<int>(kernel, arg++, offsetA);
+  setKernelArg<int>(kernel, arg++, offsetB);
+  setKernelArg<int>(kernel, arg++, offsetC);
+
+  const size_t * ls  = this->m_variant->ls  ; // Each work group is made of ls[0] x ls[1]  PE
+  const size_t * bwi = this->m_variant->bwi ; // Each PE updates bwi[0] x bwi[1] values
+
+  size_t globalThreads[2];
+
+  unsigned int thx, thy;
+  // Work-items per dimension: ceil(M/bwi[0]) resp. ceil(N/bwi[1]), rounded up to a multiple of ls.
+  thx   = M/bwi[0] + ((M%bwi[0] != 0) ? 1 : 0);   
+  thx   = thx/ls[0] + ((thx%ls[0] != 0) ? 1 : 0); 
+  thx   = ls[0] * thx;
+
+  thy   = N/bwi[1] + ((N%bwi[1] != 0) ? 1 : 0);  
+  thy   = thy/ls[1] + ((thy%ls[1] != 0) ? 1 : 0); 
+  thy   = ls[1] * thy;
+
+  globalThreads[0] = thx;
+  globalThreads[1] = thy;
+
+  err = clEnqueueNDRangeKernel(queue, kernel, 2, NULL,
+    globalThreads, ls , 
+    args.numEventsInWaitList, 
+    args.eventWaitList, 
+    args.events);
+
+  clReleaseKernel(kernel) ;
+
+  if (VERB) printf(" ===> ERR=%d \n",(int)err) ;
+
+  return clblasStatus(err) ;
+}
+
+
+clblasZgemmFunctorGCN * 
+  clblasZgemmFunctorGCN::provide(clblasZgemmFunctor::Args & args, const char* DevName) 
+{
+  // Returns a functor for 'args' shared through the static cache, or NULL when none applies.
+  if ( args.order == clblasRowMajor ) 
+    return NULL ;   // The RowMajor case shall never occur. 
+
+  cl_device_id dev;
+  cl_context   ctxt;
+
+  cl_int err = getDeviceAndContext(args.queue, dev, ctxt);
+  if (err != CL_SUCCESS)
+  {
+    return NULL;
+  }
+
+  // select_variant() needs the device address width (only 64 is supported)
+  cl_uint bitness = getAddressBits(dev);
+
+  const Variant * variant = select_variant( args, DevName, bitness ) ;
+  if ( variant == NULL )  
+    return NULL ; 
+
+  // Reuse the functor already built for this (context, device, variant)
+  // if there is one.
+  Cache::Lookup lookup(cache, ctxt, dev, variant) ;
+
+  if ( lookup.ok() )
+  {
+    clblasZgemmFunctorGCN * functor = lookup.get();
+    functor->retain(); // increment the reference counter to avoid deletion while it is still being used
+    return functor;
+  }
+
+  clblasZgemmFunctorGCN * functor = new clblasZgemmFunctorGCN(args, variant, err);
+  if (err != CL_SUCCESS)
+  {
+    functor->release(); // not cached: drop the initial reference so the functor is freed
+    return NULL;
+  }
+
+  lookup.set(functor) ;
+
+  return functor;
+}
+
diff --git a/src/library/blas/functor/hawaii.cc b/src/library/blas/functor/hawaii.cc
index 643ba9e..0d2b2c0 100644
--- a/src/library/blas/functor/hawaii.cc
+++ b/src/library/blas/functor/hawaii.cc
@@ -25,6 +25,7 @@
 #include "gcn_sgemmSmallMatrices.h"
 #include "hawaii_sgemmBranchKernel.h"
 #include "hawaii_sgemmSplit64_32.h"
+#include "gcn_zgemm.h"
 
 FunctorSelectorHawaii FunctorSelectorHawaii::instance ;
 
@@ -93,7 +94,7 @@ clblasDgemmFunctor * FunctorSelectorHawaii::select_dgemm_specific(clblasDgemmFun
 }
 
 
-// The selector function for SGEMM on hawaii 
+// The selector function for SGEMM on hawaii
 clblasSgemmFunctor * FunctorSelectorHawaii::select_sgemm_specific(clblasSgemmFunctor::Args & args)
 {
 #ifdef CLBLAS_HAWAII_DYNAMIC_KERNEL
@@ -163,6 +164,26 @@ clblasSgemmFunctor * FunctorSelectorHawaii::select_sgemm_specific(clblasSgemmFun
 }
 
 
+// The selector function for ZGEMM on hawaii
+clblasZgemmFunctor * FunctorSelectorHawaii::select_zgemm_specific(clblasZgemmFunctor::Args & args)
+{
+  // Only the NT column-major case with M%32, N%64 and K%8 == 0 has a dedicated
+  // GCN kernel; anything else (or a NULL from provide) uses the default selector.
+  clblasZgemmFunctor * functor = NULL;
+
+  if ( args.M%32==0
+    && args.N%64==0
+    && args.K%8==0
+    && args.transA==clblasNoTrans
+    && args.transB==clblasTrans
+    && args.order==clblasColumnMajor) {
+    functor = clblasZgemmFunctorGCN::provide(args, "Hawaii");
+    if (functor) return functor;
+  }
+
+  return this->clblasFunctorSelector::select_zgemm_specific(args);
+}
+
 // The selector function for DTRSM on hawaii
 //
 clblasDtrsmFunctor * FunctorSelectorHawaii::select_dtrsm_specific(clblasDtrsmFunctor::Args & args)
diff --git a/src/library/blas/functor/include/gcn_zgemm.h b/src/library/blas/functor/include/gcn_zgemm.h
new file mode 100644
index 0000000..8da218b
--- /dev/null
+++ b/src/library/blas/functor/include/gcn_zgemm.h
@@ -0,0 +1,62 @@
+#ifndef CLBLASZGEMMFUNCTORGCN
+#define CLBLASZGEMMFUNCTORGCN
+#include <string>
+
+class clblasZgemmFunctorGCN : public clblasZgemmFunctor 
+{
+public:
+  
+  //
+  // A structure that describes a kernel variant. 
+  //
+  // It is important that all instances of those structures shall 
+  // be const and static because their addresses are used as keys 
+  // in the internal functor cache. 
+  // 
+  // Also, they shall all have a unique kernel name.
+  // 
+  struct Variant 
+  {    
+    const char *    kernel_name ;   // name of the __kernel inside the program
+    const char *    source ;        // OpenCL source, used only when bin is NULL (shall be unique)
+    const char *    build_options;  // options passed when building from bin (may be NULL)
+    const char *    bin ;           // pre-compiled SPIR or CL binary data, or NULL
+    size_t          bin_size ;      // size of bin in bytes
+    clblasTranspose transA ;        // transpose mode of A handled by this variant
+    clblasTranspose transB ;        // transpose mode of B handled by this variant
+    unsigned        divN ;     // Required divisor of N  (use 1 when N can be of any value) 
+    unsigned        divM ;     // Required divisor of M  (use 1 when M can be of any value) 
+    unsigned        divK ;     // Required divisor of K  (use 1 when K can be of any value) 
+    size_t          ls[2]  ;   // Local size (the work-group size)
+    size_t          bwi[2] ;   // Block size work-item:  Number of elements calculated by each work items 
+                               // So basically each kernel is computing a block of
+                               //   (ls[0]*bwi[0]) x (ls[1]*bwi[1])  
+                               // elements of C. 
+    std::string mult;          // kernel flavour suffix such as __ALPHABETA (may be left empty)
+  } ;
+  
+private:  // Constructor: instances are obtained through provide()
+
+  clblasZgemmFunctorGCN(Args & args, const Variant * variant, cl_int & err) ;
+  
+public:
+
+  // Provide a suitable clblasZgemmFunctorGCN for the specified args
+  // or NULL if none 
+  static clblasZgemmFunctorGCN * provide(clblasZgemmFunctor::Args & args, const char* DevName) ;
+
+public: // inherited member from clblasZgemmFunctor
+
+  virtual clblasStatus execute(Args &args) ;
+
+protected:
+  // We need a default constructor as we derive this class, but we can't
+  // use the specific constructor as the arguments won't be the same (variant!!!).
+  // Both members start out null so a derived class never reads garbage.
+  clblasZgemmFunctorGCN() : m_program(0), m_variant(0) {}
+  cl_program      m_program ;
+protected:
+  const Variant * m_variant ; // Pointer to a 'const static' object describing the kernel variant. 
+} ; 
+
+#endif
diff --git a/src/library/blas/functor/include/hawaii.h b/src/library/blas/functor/include/hawaii.h
index b75ef63..bf5471a 100644
--- a/src/library/blas/functor/include/hawaii.h
+++ b/src/library/blas/functor/include/hawaii.h
@@ -33,6 +33,7 @@ public:
     // and incx==1
     virtual clblasDgemmFunctor * select_dgemm_specific(clblasDgemmFunctor::Args & args);
 	virtual clblasSgemmFunctor * select_sgemm_specific(clblasSgemmFunctor::Args & args);
+	virtual clblasZgemmFunctor * select_zgemm_specific(clblasZgemmFunctor::Args & args);
     virtual clblasDtrsmFunctor * select_dtrsm_specific(clblasDtrsmFunctor::Args & args);
 
 };
diff --git a/src/library/blas/gens/clTemplates/zgemm_gcn.cl b/src/library/blas/gens/clTemplates/zgemm_gcn.cl
new file mode 100644
index 0000000..fc28caf
--- /dev/null
+++ b/src/library/blas/gens/clTemplates/zgemm_gcn.cl
@@ -0,0 +1,314 @@
+/*******************************************************************************
+ * Notes:
+ * for column major, id(0) is row so C data is coalesced
+ * for row major, id(0) is col
+ ******************************************************************************/
+
+static const char * zgemm_NT_64_32_8_16x16_2x4__ALPHABETA = "
+
+
+// numeric ids for DATA_TYPE so that it can be tested with #if
+#define _S_ 1
+#define _D_ 2
+#define _C_ 3
+#define _Z_ 4
+
+/*******************************************************************************
+ * Pre-Processor "Strings" (bare tokens: inside #if they simply evaluate to 0)
+ ******************************************************************************/
+#define COLUMN_MAJOR_STR      ColMajor
+#define ROW_MAJOR_STR         RowMajor
+
+/*******************************************************************************
+ * Kernel PreProcessor Definitions
+ ******************************************************************************/
+#define WG_NUM_ROWS           16   // work-items per work-group along dim 0 (rows of C)
+#define WG_NUM_COLS           16   // work-items per work-group along dim 1 (cols of C)
+#define MICRO_TILE_NUM_ROWS   2    // rows of C computed by each work-item
+#define MICRO_TILE_NUM_COLS   4    // cols of C computed by each work-item
+#define NUM_UNROLL_ITER       8    // k values consumed per main-loop iteration
+#define ORDER                 ColMajor
+#define TRANSPOSE_A           N
+#define TRANSPOSE_B           T
+#define DATA_TYPE             _Z_  // complex double
+// block of C per work-group: WG_NUM_ROWS*MICRO_TILE_NUM_ROWS x WG_NUM_COLS*MICRO_TILE_NUM_COLS
+#define MACRO_TILE_NUM_ROWS   32
+#define MACRO_TILE_NUM_COLS   64
+// each local row (localB) lengthened by this amount
+#define LOCAL_ROW_PAD         1
+// each local col (localA) lengthened by this amount
+#define LOCAL_COL_PAD         1
+
+
+/*******************************************************************************
+ * Global Memory Indices
+ * ORDER and TRANSPOSE_A/B are bare tokens that #if cannot compare (each unknown
+ * identifier evaluates to 0), so they are turned into 0/1 by token pasting.
+ ******************************************************************************/
+#define GEMM_CAT_(a,b) a##b
+#define GEMM_CAT(a,b) GEMM_CAT_(a,b)
+#define GEMM_IS_COL_ColMajor 1
+#define GEMM_IS_N_N 1
+#define GEMM_ORDER_IS_COL GEMM_CAT(GEMM_IS_COL_, ORDER)
+
+/* col-major non-transposed
+ * row-major transposed */
+#define GET_GLOBAL_INDEX_N(ROW,COL,STRIDE) ((COL)*(STRIDE)+(ROW))
+
+/* col-major transposed
+ * row-major non-transposed */
+#define GET_GLOBAL_INDEX_T(ROW,COL,STRIDE) ((ROW)*(STRIDE)+(COL))
+
+// global A (ROW indexes M, COL indexes K): N-indexing iff col-major == non-transposed
+#if GEMM_ORDER_IS_COL == GEMM_CAT(GEMM_IS_N_, TRANSPOSE_A)
+#define GET_GLOBAL_INDEX_A(ROW,COL) GET_GLOBAL_INDEX_N((ROW),(COL),(lda))
+#else
+#define GET_GLOBAL_INDEX_A(ROW,COL) GET_GLOBAL_INDEX_T((ROW),(COL),(lda))
+#endif
+
+// global B (ROW indexes K, COL indexes N): same rule as A
+#if GEMM_ORDER_IS_COL == GEMM_CAT(GEMM_IS_N_, TRANSPOSE_B)
+#define GET_GLOBAL_INDEX_B(ROW,COL) GET_GLOBAL_INDEX_N((ROW),(COL),(ldb))
+#else
+#define GET_GLOBAL_INDEX_B(ROW,COL) GET_GLOBAL_INDEX_T((ROW),(COL),(ldb))
+#endif
+
+// global C
+#if GEMM_ORDER_IS_COL
+#define GET_GLOBAL_INDEX_C(ROW,COL) GET_GLOBAL_INDEX_N((ROW),(COL),(ldc))
+#else
+#define GET_GLOBAL_INDEX_C(ROW,COL) GET_GLOBAL_INDEX_T((ROW),(COL),(ldc))
+#endif
+
+/* Local Memory Indices (tiles padded by LOCAL_COL_PAD and LOCAL_ROW_PAD) */
+// localA: column-major MACRO_TILE_NUM_ROWS x NUM_UNROLL_ITER tile
+#define GET_LOCAL_INDEX_A(ROW,COL) (ROW + COL*(MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD) )
+#define GET_LOCAL_STEP_A ( ((MACRO_TILE_NUM_ROWS)+(LOCAL_COL_PAD)) * ((WG_NUM_ROWS)*(WG_NUM_COLS)/(MACRO_TILE_NUM_ROWS)) )
+// localB: row-major NUM_UNROLL_ITER x MACRO_TILE_NUM_COLS tile
+#define GET_LOCAL_INDEX_B(ROW,COL) ((COL) + (ROW)*((MACRO_TILE_NUM_COLS)+(LOCAL_ROW_PAD)) )
+#define GET_LOCAL_STEP_B ( ((MACRO_TILE_NUM_COLS)+(LOCAL_ROW_PAD)) * ((WG_NUM_ROWS)*(WG_NUM_COLS)/(MACRO_TILE_NUM_COLS)) )
+
+/*******************************************************************************
+ * Data Types (DATA_TYPE_CHAR is also the first letter of the kernel name)
+ ******************************************************************************/
+// TYPE_MAD(MUL0,MUL1,DST): DST += MUL0*MUL1    TYPE_MAD2(DST,ALPHA,REG,BETA): DST = ALPHA*REG + BETA*DST
+// single precision
+#if DATA_TYPE==_S_
+#define DATA_TYPE_STR         float
+#define DATA_TYPE_CHAR        s
+#define TYPE_MAD(MUL0,MUL1,DST) DST = mad(MUL0,MUL1,DST);
+#define TYPE_MAD2( DST, ALPHA, REG, BETA ) DST = (ALPHA)*(REG) + (BETA)*(DST);
+
+// double precision
+#elif DATA_TYPE==_D_
+#define DATA_TYPE_STR         double
+#define DATA_TYPE_CHAR        d
+#define TYPE_MAD(MUL0,MUL1,DST) DST = mad(MUL0,MUL1,DST);
+#define TYPE_MAD2( DST, ALPHA, REG, BETA ) DST = (ALPHA)*(REG) + (BETA)*(DST);
+
+// complex single precision (.s0 = real part, .s1 = imaginary part)
+#elif DATA_TYPE==_C_
+#define DATA_TYPE_STR         float2
+#define DATA_TYPE_CHAR        c
+#define TYPE_MAD(MUL0,MUL1,DST) \
+  DST.s0 = mad(  MUL0.s0, MUL1.s0, DST.s0 ); \
+  DST.s0 = mad( -MUL0.s1, MUL1.s1, DST.s0 ); \
+  DST.s1 = mad(  MUL0.s0, MUL1.s1, DST.s1 ); \
+  DST.s1 = mad(  MUL0.s1, MUL1.s0, DST.s1 );
+#define TYPE_MAD2( DST, ALPHA, REG, BETA ) \
+  /* (1) REG = ALPHA*REG, using the caller-declared type_mad2_tmp */ \
+  type_mad2_tmp = REG.s0; \
+  REG.s0 *= ALPHA.s0; \
+  REG.s0 = mad( -ALPHA.s1, REG.s1, REG.s0 ); \
+  REG.s1 *= ALPHA.s0; \
+  REG.s1 = mad(  ALPHA.s1, type_mad2_tmp, REG.s1 ); \
+  /* (2) REG += BETA*DST */ \
+  REG.s0 = mad(  BETA.s0, DST.s0, REG.s0 ); \
+  REG.s0 = mad( -BETA.s1, DST.s1, REG.s0 ); \
+  REG.s1 = mad(  BETA.s1, DST.s0, REG.s1 ); \
+  REG.s1 = mad(  BETA.s0, DST.s1, REG.s1 ); \
+  /* (3) store the result (REG itself has been overwritten) */ \
+  DST = REG;
+
+// complex double precision (.s0 = real part, .s1 = imaginary part)
+#else
+#define DATA_TYPE_STR         double2
+#define DATA_TYPE_CHAR        z
+#define TYPE_MAD(MUL0,MUL1,DST) \
+  DST.s0 = mad(  MUL0.s0, MUL1.s0, DST.s0 ); \
+  DST.s0 = mad( -MUL0.s1, MUL1.s1, DST.s0 ); \
+  DST.s1 = mad(  MUL0.s0, MUL1.s1, DST.s1 ); \
+  DST.s1 = mad(  MUL0.s1, MUL1.s0, DST.s1 );
+#define TYPE_MAD2( DST, ALPHA, REG, BETA ) \
+  /* (1) REG = ALPHA*REG, using the caller-declared type_mad2_tmp */ \
+  type_mad2_tmp = REG.s0; \
+  REG.s0 *= ALPHA.s0; \
+  REG.s0 = mad( -ALPHA.s1, REG.s1, REG.s0 ); \
+  REG.s1 *= ALPHA.s0; \
+  REG.s1 = mad(  ALPHA.s1, type_mad2_tmp, REG.s1 ); \
+  /* (2) REG += BETA*DST */ \
+  REG.s0 = mad(  BETA.s0, DST.s0, REG.s0 ); \
+  REG.s0 = mad( -BETA.s1, DST.s1, REG.s0 ); \
+  REG.s1 = mad(  BETA.s1, DST.s0, REG.s1 ); \
+  REG.s1 = mad(  BETA.s0, DST.s1, REG.s1 ); \
+  /* (3) store the result (REG itself has been overwritten) */ \
+  DST = REG;
+
+#endif
+
+/*******************************************************************************
+ * 2x4 micro tile: one k step, rC[i][j] += rA[i]*rB[j] for i in 0..1, j in 0..3
+ ******************************************************************************/
+#define MAD2x4 \
+  rA[0] = localA[offA + 0*WG_NUM_ROWS]; \
+  rA[1] = localA[offA + 1*WG_NUM_ROWS]; \
+  rB[0] = localB[offB + 0*WG_NUM_COLS]; \
+  rB[1] = localB[offB + 1*WG_NUM_COLS]; \
+  rB[2] = localB[offB + 2*WG_NUM_COLS]; \
+  rB[3] = localB[offB + 3*WG_NUM_COLS]; \
+  offA += (MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD); /* next k column of localA */ \
+  offB += (MACRO_TILE_NUM_COLS+LOCAL_ROW_PAD); /* next k row of localB */ \
+  TYPE_MAD(rA[0],rB[0],rC[0][0]); \
+  TYPE_MAD(rA[1],rB[0],rC[1][0]); \
+  TYPE_MAD(rA[0],rB[1],rC[0][1]); \
+  TYPE_MAD(rA[1],rB[1],rC[1][1]); \
+  TYPE_MAD(rA[0],rB[2],rC[0][2]); \
+  TYPE_MAD(rA[1],rB[2],rC[1][2]); \
+  TYPE_MAD(rA[0],rB[3],rC[0][3]); \
+  TYPE_MAD(rA[1],rB[3],rC[1][3]); \
+  mem_fence(CLK_LOCAL_MEM_FENCE); /* orders this work-item's local memory accesses */
+
+// concatenate kernel name from its parameters; with the values above this is
+// zgemm_NT_64_32_8_16x16_2x4__ALPHABETA (must match ZGEMM_KERNEL_NAME on the host)
+#define CONCAT_NAME(DT,TA,TB,TILE_COLS,TILE_ROWS,NUI,WGR,WGC,MTR,MTC) \
+  DT ## gemm_ ## TA ## TB ## _ ## TILE_COLS ## _ ## TILE_ROWS ## _ ## NUI ## _ ## WGR ## x ## WGC ## _ ## MTR ## x ## MTC ## __ALPHABETA
+#define KERNEL_NAME(DT,TA,TB,TILE_COLS,TILE_ROWS,NUI,WGR,WGC,MTR,MTC) CONCAT_NAME(DT,TA,TB,TILE_COLS,TILE_ROWS,NUI,WGR,WGC,MTR,MTC) // extra level so arguments are macro-expanded before ## pasting
+
+/*******************************************************************************
+ * Kernel: C = alpha*A*B^T + beta*C, one MACRO_TILE_NUM_ROWS x MACRO_TILE_NUM_COLS block of C per work-group
+ ******************************************************************************/
+__attribute__((reqd_work_group_size(WG_NUM_ROWS,WG_NUM_COLS,1)))
+__kernel void KERNEL_NAME(DATA_TYPE_CHAR,TRANSPOSE_A,TRANSPOSE_B,MACRO_TILE_NUM_COLS,MACRO_TILE_NUM_ROWS,NUM_UNROLL_ITER,WG_NUM_ROWS,WG_NUM_COLS,MICRO_TILE_NUM_ROWS,MICRO_TILE_NUM_COLS) ( 
+  uint const M,
+  uint const N,
+  uint const K,
+  DATA_TYPE_STR const alpha,
+  DATA_TYPE_STR const beta,
+  __global DATA_TYPE_STR const * restrict A,
+  __global DATA_TYPE_STR const * restrict B,
+  __global DATA_TYPE_STR       *          C,
+  uint const lda,
+  uint const ldb,
+  uint const ldc,
+  uint const offA,
+  uint const offB,
+  uint const offC )
+{
+  // registers
+  DATA_TYPE_STR rC[MICRO_TILE_NUM_ROWS][MICRO_TILE_NUM_COLS]  = {0};
+  DATA_TYPE_STR rA[MICRO_TILE_NUM_ROWS];
+  DATA_TYPE_STR rB[MICRO_TILE_NUM_COLS];
+    
+  // local memory
+  __local DATA_TYPE_STR localA[NUM_UNROLL_ITER*(MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD)];
+  __local DATA_TYPE_STR localB[NUM_UNROLL_ITER*(MACRO_TILE_NUM_COLS+LOCAL_ROW_PAD)];
+
+/*  
+ * for coalesced C writing
+ * if column major, id(0) is row
+ * if row major, id(0) is col
+ */
+  uint groupRow = get_group_id(0);
+  uint groupCol = get_group_id(1);
+  uint localRow = get_local_id(0);
+  uint localCol = get_local_id(1);
+  uint localSerial = localRow + localCol*WG_NUM_ROWS;
+
+  /*****************************************************************************
+   * global indices being loaded
+   ****************************************************************************/
+  // which gAij is this thread responsible for loading?
+#define globalARow (groupRow*MACRO_TILE_NUM_ROWS + localSerial%MACRO_TILE_NUM_ROWS)
+#define globalACol (localSerial/MACRO_TILE_NUM_ROWS)
+#define globalAIdx (GET_GLOBAL_INDEX_A( globalARow, globalACol ) )
+  A += offA + globalAIdx; // offA: caller's element offset into A
+  // which gBij is this thread responsible for loading?
+#define globalBRow (localSerial/MACRO_TILE_NUM_COLS)
+#define globalBCol (groupCol*MACRO_TILE_NUM_COLS + localSerial%MACRO_TILE_NUM_COLS)
+#define globalBIdx (GET_GLOBAL_INDEX_B( globalBRow, globalBCol ) )
+  B += offB + globalBIdx; // offB: caller's element offset into B
+
+  uint block_k = K / NUM_UNROLL_ITER; // K >= NUM_UNROLL_ITER assumed: the do/while runs at least once
+#pragma nounroll
+  do {
+
+    /***************************************************************************
+     * local indices being written
+     **************************************************************************/
+    // which lAij is this thread responsible for writing?
+#define localARow (localSerial % MACRO_TILE_NUM_ROWS)
+#define localACol (localSerial / MACRO_TILE_NUM_ROWS)
+#define localAStride ( (MACRO_TILE_NUM_ROWS+LOCAL_COL_PAD) * (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_ROWS) )
+#define globalAStride ( GET_GLOBAL_INDEX_A(0, (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_ROWS) ) )
+#define localAIdx ( GET_LOCAL_INDEX_A(localARow, localACol) )
+    __local DATA_TYPE_STR *lA = localA + localAIdx;
+    // which lBij is this thread responsible for writing?
+#define localBRow ( localSerial / MACRO_TILE_NUM_COLS )
+#define localBCol ( localSerial % MACRO_TILE_NUM_COLS )
+#define localBIdx ( GET_LOCAL_INDEX_B(localBRow, localBCol) )
+#define localBStride  ( (MACRO_TILE_NUM_COLS+LOCAL_ROW_PAD) * (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_COLS) )
+#define globalBStride ( GET_GLOBAL_INDEX_B( (WG_NUM_ROWS*WG_NUM_COLS/MACRO_TILE_NUM_COLS), 0 ) )
+    __local DATA_TYPE_STR *lB = localB + localBIdx;
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    /***************************************************************************
+     * Load global -> local
+     * num loads = num threads / total loads
+     **************************************************************************/
+    // 2x4 uTile x 8unroll
+    lA[ 0*localAStride ] = A[ 0*globalAStride ];
+    lB[ 0*localBStride ] = B[ 0*globalBStride ];
+    lB[ 1*localBStride ] = B[ 1*globalBStride ];
+    barrier(CLK_LOCAL_MEM_FENCE);
+    // offA/offB below shadow the kernel parameters: they index localA/localB
+    uint offA = localRow;
+    uint offB = localCol;
+
+    /***************************************************************************
+     * do mads in registers
+     **************************************************************************/
+    MAD2x4
+    MAD2x4
+    MAD2x4
+    MAD2x4
+    MAD2x4
+    MAD2x4
+    MAD2x4
+    MAD2x4
+
+    // fully shift
+    A += lda*NUM_UNROLL_ITER; // b/c N
+    B += ldb*NUM_UNROLL_ITER; // b/c T
+
+  } while (--block_k > 0);
+  C += offC; // caller's element offset into C
+  // which global Cij is this thread responsible for computing?
+  uint globalCRow = groupRow * MACRO_TILE_NUM_ROWS + localRow;
+  uint globalCCol = groupCol * MACRO_TILE_NUM_COLS + localCol;
+
+  /***************************************************************************
+   * write data
+   **************************************************************************/
+  double type_mad2_tmp; // used in TYPE_MAD2
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+0*WG_NUM_COLS) ], alpha, rC[0][0], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+1*WG_NUM_COLS) ], alpha, rC[0][1], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+2*WG_NUM_COLS) ], alpha, rC[0][2], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+0*WG_NUM_ROWS, globalCCol+3*WG_NUM_COLS) ], alpha, rC[0][3], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+0*WG_NUM_COLS) ], alpha, rC[1][0], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+1*WG_NUM_COLS) ], alpha, rC[1][1], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+2*WG_NUM_COLS) ], alpha, rC[1][2], beta )
+  TYPE_MAD2( C[ GET_GLOBAL_INDEX_C( globalCRow+1*WG_NUM_ROWS, globalCCol+3*WG_NUM_COLS) ], alpha, rC[1][3], beta )
+   
+}
+
+";

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/clblas.git



More information about the debian-science-commits mailing list