[clblas] 20/61: improve sgemm column major TN small matrix perf. some type/bug fixes

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Fri Jul 24 22:49:45 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch master
in repository clblas.

commit a280c968ede336a75705ff368dba073b40a37f3e
Author: Timmy <timmy.liu at amd.com>
Date:   Wed Apr 29 11:27:35 2015 -0500

    improve sgemm column major TN small matrix perf. some type/bug fixes
---
 src/library/blas/functor/hawaii.cc                 |   6 +-
 .../blas/functor/hawaii_sgemmBranchKernel.cc       |  46 ++++++--
 .../gens/clTemplates/sgemm_gcn_SmallMatrices.cl    | 124 +++++++++++++++++++++
 3 files changed, 164 insertions(+), 12 deletions(-)

diff --git a/src/library/blas/functor/hawaii.cc b/src/library/blas/functor/hawaii.cc
index 7f0e58c..3824b91 100644
--- a/src/library/blas/functor/hawaii.cc
+++ b/src/library/blas/functor/hawaii.cc
@@ -101,7 +101,7 @@ clblasSgemmFunctor * FunctorSelectorHawaii::select_sgemm_specific(clblasSgemmFun
   //TODO: the logic below is complicated; Needs cleanup;
   clblasSgemmFunctor * functor;
   bool Not_TT = ((args.transA==clblasNoTrans && args.transB==clblasTrans ) || ( args.transA==clblasNoTrans && args.transB==clblasNoTrans ) || ( args.transA==clblasTrans && args.transB==clblasNoTrans ));
-  bool SmallMatrices = args.M/6*args.N/6<150*150 || ((args.M%64!=0 && args.N%64!=0 && args.M<1900 &&args.N<1900 ) && (args.M%96!=0 && args.N%96!=0 && args.M<1900 &&args.N<1900 ));
+  bool SmallMatrices = args.M/6*args.N/6<180*180 || ((args.M%64!=0 && args.N%64!=0 && args.M<1900 &&args.N<1900 ) && (args.M%96!=0 && args.N%96!=0 && args.M<1900 &&args.N<1900 ));
   bool SmallMatricesMod32= (SmallMatrices && (args.M%32==0&&args.N%32==0)) ;
   SmallMatricesMod32 = SmallMatricesMod32&&Not_TT&&args.K % 16 == 0;
   //SmallMatrices= false;
@@ -109,8 +109,8 @@ clblasSgemmFunctor * FunctorSelectorHawaii::select_sgemm_specific(clblasSgemmFun
   bool useSpliKernel=((args.M%96==0 && args.N%96==0) || !(args.M%64==0 && args.N%64==0&& args.M<4000 &&args.N<4000)) /*&&args.K%16==0*/;
   useSpliKernel=useSpliKernel&&Not_TT;
   
-  //the English translation of below is: if small matrix that is (not mod32) and (NT or NN) and K has to be mod 16
-  if (SmallMatrices && (!SmallMatricesMod32) && (args.transA == clblasNoTrans) && (args.K%16 == 0))
+  //In plain English: take this path for a small matrix where M and N are not both multiples of 32, the transpose combination is not TT, and K is a multiple of 16
+  if (SmallMatrices && (!SmallMatricesMod32) && (Not_TT) && (args.K%16 == 0))
   {
 	  functor = clBlashawaiiSgemmBranchKernelFunctor::provide(args, "Hawaii");
 	  if (functor)
diff --git a/src/library/blas/functor/hawaii_sgemmBranchKernel.cc b/src/library/blas/functor/hawaii_sgemmBranchKernel.cc
index 34c5f33..6192cd3 100644
--- a/src/library/blas/functor/hawaii_sgemmBranchKernel.cc
+++ b/src/library/blas/functor/hawaii_sgemmBranchKernel.cc
@@ -123,7 +123,7 @@ static const Variant * select_variant_BranchKernel(clblasSgemmFunctor::Args & ar
 
 			// ===== sgemm NN ======
 			// sgemm_NN_32_32_16_16x16_2x2__ALPHABETA_BRANCH
-			const char* KName_NT = SGEMM_KERNEL_NAME(N, N, 32, 32, 16, 16, 16, 2, 2, __ALPHABETA, BRANCH);
+			const char* KName_NN = SGEMM_KERNEL_NAME(N, N, 32, 32, 16, 16, 16, 2, 2, __ALPHABETA, BRANCH);
 			const char* KBin_NN64;
 			size_t KBin_NNSize64 = 0;
 #if BUILD_KERNEL_FROM_STRING
@@ -132,14 +132,14 @@ static const Variant * select_variant_BranchKernel(clblasSgemmFunctor::Args & ar
 #else
 			if (!strcmp(DevName, "Hawaii"))
 			{
-				//KBin_NT64             = SGEMM_SRC_NAME_BIN(N, T, 16, __ALPHABETA,  64, HAWAII) ;
+
 				KBin_NN64 = sgemm_NN_32_32_16_16x16_2x2__ALPHABETA_BRANCH_64_bin_Hawaii;
 				KBin_NNSize64 = sizeof(sgemm_NN_32_32_16_16x16_2x2__ALPHABETA_BRANCH_64_bin_Hawaii);
 
 			}
 #endif
 			static const Variant variant = SGEMM_VARIANT_OBJ(N, N, 16, 16, 16, 2, 2, 64, __ALPHABETA,
-				KName_NT,
+				KName_NN,
 				NULL,
 				NULL,
 				KBin_NN64,
@@ -188,21 +188,49 @@ static const Variant * select_variant_BranchKernel(clblasSgemmFunctor::Args & ar
 
 			return &variant;
 		}
-		else
+	}
+	else
+	{
+		if (args.transB == clblasNoTrans)
 		{
-			if (args.transB == clblasNoTrans)
-			{
 
 				// ===== sgemm TN ======
-				// currently not supported
+				//sgemm_TN_32_32_16_16x16_2x2__ALPHABETA_BRANCH
+				const char* KName_TN = SGEMM_KERNEL_NAME(T, N, 32, 32, 16, 16, 16, 2, 2, __ALPHABETA, BRANCH);
+
+
+				const char* KBin_TN64;
+				size_t KBin_TNSize64 = 0;
+
+
+#if BUILD_KERNEL_FROM_STRING
+				//currently not supported
 				return NULL;
-			}
-		}
+#else
+				if (!strcmp(DevName, "Hawaii"))
+				{
+					KBin_TN64 = sgemm_TN_32_32_16_16x16_2x2__ALPHABETA_BRANCH_64_bin_Hawaii;
+					KBin_TNSize64 = sizeof(sgemm_TN_32_32_16_16x16_2x2__ALPHABETA_BRANCH_64_bin_Hawaii);
 
+				}
+#endif
+				// ===== SGEMM TN ======
+				static const Variant variant = SGEMM_VARIANT_OBJ(T, N, 16, 16, 16, 2, 2, 64, __ALPHABETA,
+					KName_TN,
+					NULL,
+					NULL,
+					KBin_TN64,
+					KBin_TNSize64);
+
+				return &variant;
+		}
 		return NULL;
 	}
+
+		return NULL;
 }
 
+
 clBlashawaiiSgemmBranchKernelFunctor::clBlashawaiiSgemmBranchKernelFunctor(Args & args, const Variant * variant, cl_int & err)
 {
 
diff --git a/src/library/blas/gens/clTemplates/sgemm_gcn_SmallMatrices.cl b/src/library/blas/gens/clTemplates/sgemm_gcn_SmallMatrices.cl
index 17aa291..137cb3a 100644
--- a/src/library/blas/gens/clTemplates/sgemm_gcn_SmallMatrices.cl
+++ b/src/library/blas/gens/clTemplates/sgemm_gcn_SmallMatrices.cl
@@ -910,3 +910,127 @@ __kernel void sgemm_TN_32_32_16_16x16_2x2__ALPHA( __global float const * restric
 
 ";
 
+static const char * sgemm_TN_32_32_16_16x16_2x2__ALPHABETA_BRANCH = "
+
+#define  M2x2 \
+            rA[0][0] = lA[offA + 0];				  \
+            rA[0][1] = lA[offA + 16];				  \
+            rB[0][0] = lB[offB + 0];				  \
+            rB[0][1] = lB[offB + 16];				  \
+            offA += 33;								  \
+            offB += 33;								  \
+            rC[0][0]=mad(rA[0][0],rB[0][0],rC[0][0]); \
+            rC[1][0]=mad(rA[0][1],rB[0][0],rC[1][0]); \
+            rC[0][1]=mad(rA[0][0],rB[0][1],rC[0][1]); \
+            rC[1][1]=mad(rA[0][1],rB[0][1],rC[1][1]); \
+			mem_fence(CLK_LOCAL_MEM_FENCE);
+
+__attribute__((reqd_work_group_size(16,16,1)))
+
+__kernel void sgemm_TN_32_32_16_16x16_2x2__ALPHABETA_BRANCH( __global float const * restrict A,
+  __global float const * restrict B,
+  __global float * C,
+  uint const M,
+  uint const N,
+  uint const K,
+  float const alpha,
+  float const beta,
+  uint lda,
+  uint ldb,
+  uint ldc,
+  uint offsetA,
+  uint offsetB,
+  uint offsetC)
+{
+    float rC[2][2]  = {(float)0};
+    float rA[1][2];
+    float rB[1][2];
+    
+    
+    A += offsetA;
+    B += offsetB;
+    C+=offsetC;
+    
+    __local float lA[528];//16*32+16
+    __local float lB[528];
+    
+    uint gidx = get_group_id(0);
+    uint gidy = get_group_id(1);
+    uint idx = get_local_id(0);
+    uint idy = get_local_id(1);
+    
+    int CurrentOffSetA = gidx*32+ idy;
+    int CurrentOffSetB = gidy*32+ idy;
+
+    A +=  (gidx*32+idy)*lda + idx;
+    B +=  (gidy*32+idy)*ldb + idx;
+    
+   
+    uint block_k = K >> 4;
+    do 
+    {
+      __local float* plA = lA + idx*33+idy;
+      __local float* plB = lB + idx*33+idy;
+      barrier(CLK_LOCAL_MEM_FENCE);
+  
+      plB[0]  = CurrentOffSetB>=N?0.0:B[0];
+      plB[16] = CurrentOffSetB+16>=N?0.0:B[16*ldb];
+
+      plA[0]  = CurrentOffSetA>=M?0.0:A[0];
+      plA[16] = CurrentOffSetA+16>=M?0.0:A[16*lda];
+
+
+      barrier(CLK_LOCAL_MEM_FENCE);
+      uint offA = idx;
+      uint offB = idy;
+
+
+        M2x2
+		M2x2
+		M2x2
+		M2x2
+		M2x2
+		M2x2
+		M2x2
+		M2x2
+		M2x2
+		M2x2
+		M2x2
+		M2x2
+		M2x2
+		M2x2
+		M2x2
+		M2x2
+
+      A += 16;
+      B += 16;
+    } while (--block_k > 0);
+
+
+    int offset_x = gidx*32+idx;
+    int offset_y = gidy*32+ idy;
+
+    if(offset_x>=M || offset_y>=N )
+      return;
+
+    C+=offset_x+offset_y*ldc;
+
+    int i = 0;
+    do 
+    {
+      C[0     ] = mad(alpha, rC[i][0], beta*C[0]);
+      if(offset_y+16<N)
+        C[16*ldc] = mad(alpha, rC[i][1], beta*C[16*ldc]);
+
+      C+=16;
+      offset_x+=16;
+      if(offset_x>=M )
+        return;
+
+
+    }
+    while (++i < 2);
+}
+
+";
+

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/clblas.git



More information about the debian-science-commits mailing list