[clblas] 37/67: enable big dgemm with split calls

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Tue Oct 27 08:02:13 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch master
in repository clblas.

commit 4d67a9e72f448223b5dff33f2c290644318f780d
Author: Timmy <timmy.liu at amd.com>
Date:   Wed Sep 16 15:59:26 2015 -0500

    enable big dgemm with split calls
---
 src/library/blas/specialCases/GemmSpecialCases.cpp | 309 ++++++++++++++++-----
 1 file changed, 237 insertions(+), 72 deletions(-)

diff --git a/src/library/blas/specialCases/GemmSpecialCases.cpp b/src/library/blas/specialCases/GemmSpecialCases.cpp
index 4a5fd82..be11f41 100644
--- a/src/library/blas/specialCases/GemmSpecialCases.cpp
+++ b/src/library/blas/specialCases/GemmSpecialCases.cpp
@@ -30,7 +30,8 @@
     printf("OpenCL error %i on line %u\n", RET, __LINE__); \
     assert(false); \
     }
-
+/*
+template<typename precision>
 clblasStatus SGEMM_SPLIT_CALLS(
 	cl_kernel *ClKernel, clblasOrder order,
 	unsigned int tile_size, unsigned int WG_size,
@@ -40,10 +41,31 @@ clblasStatus SGEMM_SPLIT_CALLS(
 	clblasTranspose transA,
 	clblasTranspose transB,
 	cl_uint M, cl_uint N, cl_uint K,
-	float alpha,
+	precision alpha,
 	cl_mem A, cl_uint offA, cl_uint lda,
 	cl_mem B, cl_uint offB, cl_uint ldb,
-	float beta,
+	precision beta,
+	cl_mem C, cl_uint offC, cl_uint ldc,
+	cl_uint numCommandQueues,
+	cl_command_queue *commandQueues,
+	cl_uint numEventsInWaitList,
+	const cl_event *eventWaitList,
+	cl_event *events);
+*/
+template<typename precision>
+clblasStatus GEMM_SPLIT_CALLS(
+	cl_kernel *ClKernel, clblasOrder order,
+	unsigned int tile_size, unsigned int WG_size,
+	unsigned int M_split_factor,
+	unsigned int N_split_factor,
+	unsigned int K_split_factor,
+	clblasTranspose transA,
+	clblasTranspose transB,
+	cl_uint M, cl_uint N, cl_uint K,
+	precision alpha,
+	cl_mem A, cl_uint offA, cl_uint lda,
+	cl_mem B, cl_uint offB, cl_uint ldb,
+	precision beta,
 	cl_mem C, cl_uint offC, cl_uint ldc,
 	cl_uint numCommandQueues,
 	cl_command_queue *commandQueues,
@@ -51,7 +73,7 @@ clblasStatus SGEMM_SPLIT_CALLS(
 	const cl_event *eventWaitList,
 	cl_event *events)
 {
-	//for example, when M=N=K=8192
+	//for example, when M=N=K=8192 in GEMM col NT
 	//we are gonna call 16 GEMMs 
 	//each GEMM has M=N=K=4096
 	//note are direct GEMM call has a 0.7 TFLOPS performance
@@ -80,98 +102,103 @@ clblasStatus SGEMM_SPLIT_CALLS(
 	// #15: C22 = a*A23*B23 + 1*C22
 	// #16: C22 = a*A24*B24 + 1*C22 now we are done with C22
 
-	unsigned int small_M = M / M_split_factor;
-	unsigned int small_N = N / N_split_factor;
-	unsigned int small_K = K / K_split_factor;
+	if (transA == clblasNoTrans && transB == clblasTrans)
+	{
 
-	size_t GlobalX = ((small_M - 1) / tile_size + 1) * WG_size;
-	size_t GlobalY = ((small_N - 1) / tile_size + 1) * WG_size;
-	size_t gs[2] = { GlobalX, GlobalY };
-	size_t wgsize[2] = { WG_size, WG_size };
-	cl_int error = 0;
+		unsigned int small_M = M / M_split_factor;
+		unsigned int small_N = N / N_split_factor;
+		unsigned int small_K = K / K_split_factor;
 
-	cl_float betaone = 1;
+		size_t GlobalX = ((small_M - 1) / tile_size + 1) * WG_size;
+		size_t GlobalY = ((small_N - 1) / tile_size + 1) * WG_size;
+		size_t gs[2] = { GlobalX, GlobalY };
+		size_t wgsize[2] = { WG_size, WG_size };
+		cl_int error = 0;
 
-	error = clSetKernelArg(*ClKernel, 5, sizeof(cl_uint), &small_M);
-	assert(error == CL_SUCCESS);
-	error = clSetKernelArg(*ClKernel, 6, sizeof(cl_uint), &small_N);
-	assert(error == CL_SUCCESS);
-	error = clSetKernelArg(*ClKernel, 7, sizeof(cl_uint), &small_K);
-	assert(error == CL_SUCCESS);
+		precision betaone = 1;
 
-	for (int M_split_index = 0; M_split_index < M_split_factor; M_split_index++)
-	{
-		for (int N_split_index = 0; N_split_index < N_split_factor; N_split_index++)
-		{
-			unsigned int offc_C = ldc*N / N_split_factor * N_split_index + M / M_split_factor * M_split_index + offC;
-			error = clSetKernelArg(*ClKernel, 13, sizeof(cl_uint), &offc_C);
-			assert(error == CL_SUCCESS);
+		error = clSetKernelArg(*ClKernel, 5, sizeof(cl_uint), &small_M);
+		assert(error == CL_SUCCESS);
+		error = clSetKernelArg(*ClKernel, 6, sizeof(cl_uint), &small_N);
+		assert(error == CL_SUCCESS);
+		error = clSetKernelArg(*ClKernel, 7, sizeof(cl_uint), &small_K);
+		assert(error == CL_SUCCESS);
 
-			for (int K_split_index = 0; K_split_index < K_split_factor; K_split_index++)
+		for (int M_split_index = 0; M_split_index < M_split_factor; M_split_index++)
+		{
+			for (int N_split_index = 0; N_split_index < N_split_factor; N_split_index++)
 			{
-				unsigned int offa_A = (M / M_split_factor * M_split_index) + (lda * K / K_split_factor * K_split_index) + offA;
-				unsigned int offb_B = (N / N_split_factor * N_split_index) + (ldb * K / K_split_factor * K_split_index) + offB;
-				error = clSetKernelArg(*ClKernel, 11, sizeof(cl_uint), &offa_A);
-				assert(error == CL_SUCCESS);
-				error = clSetKernelArg(*ClKernel, 12, sizeof(cl_uint), &offb_B);
+				unsigned int offc_C = ldc*N / N_split_factor * N_split_index + M / M_split_factor * M_split_index + offC;
+				error = clSetKernelArg(*ClKernel, 13, sizeof(cl_uint), &offc_C);
 				assert(error == CL_SUCCESS);
 
-				if (K_split_index == 0)
+				for (int K_split_index = 0; K_split_index < K_split_factor; K_split_index++)
 				{
-					error = clSetKernelArg(*ClKernel, 4, sizeof(cl_float), &(beta));
+					unsigned int offa_A = (M / M_split_factor * M_split_index) + (lda * K / K_split_factor * K_split_index) + offA;
+					unsigned int offb_B = (N / N_split_factor * N_split_index) + (ldb * K / K_split_factor * K_split_index) + offB;
+					error = clSetKernelArg(*ClKernel, 11, sizeof(cl_uint), &offa_A);
+					assert(error == CL_SUCCESS);
+					error = clSetKernelArg(*ClKernel, 12, sizeof(cl_uint), &offb_B);
 					assert(error == CL_SUCCESS);
 
-					if (M_split_index == 0 && N_split_index == 0)
+					if (K_split_index == 0)
 					{
-						//very first GEMM call
-						if ((M_split_factor == 1) && (N_split_factor == 1) && (K_split_factor == 1))
+						error = clSetKernelArg(*ClKernel, 4, sizeof(precision), &(beta));
+						assert(error == CL_SUCCESS);
+
+						if (M_split_index == 0 && N_split_index == 0)
 						{
-							//also very last GEMM call
-							error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
-								gs, wgsize, numEventsInWaitList, eventWaitList, &events[0]);
-							assert(error == CL_SUCCESS);
+							//very first GEMM call
+							if ((M_split_factor == 1) && (N_split_factor == 1) && (K_split_factor == 1))
+							{
+								//also very last GEMM call
+								error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
+									gs, wgsize, numEventsInWaitList, eventWaitList, &events[0]);
+								assert(error == CL_SUCCESS);
+							}
+							else
+							{
+								error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
+									gs, wgsize, numEventsInWaitList, eventWaitList, NULL);
+								assert(error == CL_SUCCESS);
+							}
 						}
 						else
 						{
 							error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
-								gs, wgsize, numEventsInWaitList, eventWaitList, NULL);
+								gs, wgsize, 0, NULL, NULL);
 							assert(error == CL_SUCCESS);
 						}
 					}
 					else
 					{
-						error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
-							gs, wgsize, 0, NULL, NULL);
+						error = clSetKernelArg(*ClKernel, 4, sizeof(precision), &betaone);
 						assert(error == CL_SUCCESS);
-					}
-				}
-				else
-				{
-					error = clSetKernelArg(*ClKernel, 4, sizeof(cl_float), &betaone);
-					assert(error == CL_SUCCESS);
 
-					if ((M_split_index == (M_split_factor - 1)) && (N_split_index == (N_split_factor - 1)) && (K_split_index == (K_split_factor - 1)))
-					{
-						//very last GEMM call
-						error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
-							gs, wgsize, 0, NULL, events);
-						assert(error == CL_SUCCESS);
-					}
-					else
-					{
-						error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
-							gs, wgsize, 0, NULL, NULL);
-						assert(error == CL_SUCCESS);
+						if ((M_split_index == (M_split_factor - 1)) && (N_split_index == (N_split_factor - 1)) && (K_split_index == (K_split_factor - 1)))
+						{
+							//very last GEMM call
+							error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
+								gs, wgsize, 0, NULL, events);
+							assert(error == CL_SUCCESS);
+						}
+						else
+						{
+							error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
+								gs, wgsize, 0, NULL, NULL);
+							assert(error == CL_SUCCESS);
+						}
 					}
 				}
 			}
 		}
-	}
 
-	return clblasSuccess;
+		return clblasSuccess;
+	}
+	return clblasNotImplemented;
 }
 
-clblasStatus GEMM_mod1024(
+clblasStatus SGEMM_mod1024(
 	clblasTranspose transA,
 	clblasTranspose transB,
 	cl_uint M, cl_uint N, cl_uint K,
@@ -267,7 +294,7 @@ clblasStatus GEMM_mod1024(
 					err = clSetKernelArg(*tileClKernel, 13, sizeof(cl_uint), &offC);
 					CL_CHECK(err);
 
-					status = SGEMM_SPLIT_CALLS(
+					status = GEMM_SPLIT_CALLS(
 						tileClKernel, clblasColumnMajor,
 						128, 16,
 						M_split_factor,
@@ -343,7 +370,7 @@ clblasStatus GEMM_mod1024(
 					CL_CHECK(err);
 
 
-					status = SGEMM_SPLIT_CALLS(
+					status = GEMM_SPLIT_CALLS(
 						tileClKernel, clblasColumnMajor,
 						96, 16,
 						M_split_factor,
@@ -373,7 +400,7 @@ clblasStatus GEMM_mod1024(
 }
 
 
-clblasStatus GEMM_SPLIT64_32(
+clblasStatus SGEMM_SPLIT64_32(
 	clblasTranspose transA,
 	clblasTranspose transB,
 	cl_uint M, cl_uint N, cl_uint K,
@@ -508,7 +535,7 @@ clblasStatus GEMM_SPLIT64_32(
 	return clblasNotImplemented;
 }
 
-clblasStatus GEMM_BRANCH_32(
+clblasStatus SGEMM_BRANCH_32(
 	clblasTranspose transA,
 	clblasTranspose transB,
 	cl_uint M, cl_uint N, cl_uint K,
@@ -682,6 +709,121 @@ clblasStatus GEMM_BRANCH_32(
 	return clblasNotImplemented;
 }
 
+clblasStatus DGEMM_BIG_MOD48( // DGEMM special case: big NT problems computed as several smaller GEMM launches
+	clblasTranspose transA,
+	clblasTranspose transB,
+	cl_uint M, cl_uint N, cl_uint K,
+	double alpha,
+	cl_mem A, cl_uint offA, cl_uint lda,
+	cl_mem B, cl_uint offB, cl_uint ldb,
+	double beta,
+	cl_mem C, cl_uint offC, cl_uint ldc,
+	cl_uint numCommandQueues,
+	cl_command_queue *commandQueues,
+	cl_uint numEventsInWaitList,
+	const cl_event *eventWaitList,
+	cl_event *events,
+	bool &specialCaseHandled)
+{
+	const char *tileKernelSource = NULL;
+	cl_kernel  *tileClKernel = NULL;
+	size_t tileKernelBinarySize = 0;
+	cl_int err;
+	// Returns clblasNotImplemented (specialCaseHandled left untouched) unless lda == ldb lies in [18000, 36000],
+	// the op is NoTrans x Trans, and M, N, K satisfy the split rules below. Order is not checked: Col_NT kernel is used.
+	const unsigned char *tileKernelBinary = NULL;
+	// Otherwise returns GEMM_SPLIT_CALLS's status and sets specialCaseHandled = true only when that is clblasSuccess.
+	clblasStatus status;
+	// Split the kernel calls to avoid the DGEMM NT performance drop observed when matrix sizes are big.
+	if ((lda == ldb) && (lda >= 18000) && (lda <= 36000)) // between 18000 and 36000 for now
+	{
+		if (!((transA == clblasNoTrans) && (transB == clblasTrans))) // only op(A) = A, op(B) = B^T (NT) is supported
+			return clblasNotImplemented;
+		// Pick the split so every sub-GEMM size (M/f, N/f, K/f) is a multiple of 48, the kernel tile (192/4 = 96/2 = 48).
+		unsigned int M_split_factor;
+		unsigned int N_split_factor;
+		unsigned int K_split_factor;
+		if ((M % 192 == 0) && (N % 192 == 0) && (K % 192 == 0) && (K > lda / 4))
+		{
+			M_split_factor = 4; // 4*4*4 = 64 sub-GEMM launches in GEMM_SPLIT_CALLS
+			N_split_factor = 4;
+			K_split_factor = 4;
+		}
+		else if ((M % 96 == 0) && (N % 96 == 0) && (K % 96 == 0) && (K > lda / 4))
+		{
+			M_split_factor = 2; // 2*2*2 = 8 sub-GEMM launches in GEMM_SPLIT_CALLS
+			N_split_factor = 2;
+			K_split_factor = 2;
+		}
+		else
+		{
+			return clblasNotImplemented; // sizes not multiples of 96, or K <= lda / 4
+		}
+		// Fixed 48x48 column-major NT kernel; "B1" presumably denotes the beta-reading variant -- TODO confirm.
+		tileKernelSource = dgemm_Col_NT_B1_MX048_NX048_KX08_src;
+		tileClKernel = &dgemm_Col_NT_B1_MX048_NX048_KX08_clKernel;
+		tileKernelBinary = dgemm_Col_NT_B1_MX048_NX048_KX08_bin;
+		tileKernelBinarySize = dgemm_Col_NT_B1_MX048_NX048_KX08_binSize;
+		// NOTE(review): only commandQueues[0] is used here, and makeGemmKernel's outcome is not checked.
+		makeGemmKernel(tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions);
+		// Bind full-problem args; GEMM_SPLIT_CALLS re-sets arg 4 (beta), 5-7 (M,N,K) and 11-13 (offsets) per launch.
+		err = clSetKernelArg(*tileClKernel, 0, sizeof(cl_mem), &A);
+		CL_CHECK(err);
+		err = clSetKernelArg(*tileClKernel, 1, sizeof(cl_mem), &B);
+		CL_CHECK(err);
+		err = clSetKernelArg(*tileClKernel, 2, sizeof(cl_mem), &C);
+		CL_CHECK(err);
+		err = clSetKernelArg(*tileClKernel, 3, sizeof(cl_double), &alpha);
+		CL_CHECK(err);
+		err = clSetKernelArg(*tileClKernel, 4, sizeof(cl_double), &beta);
+		CL_CHECK(err);
+		err = clSetKernelArg(*tileClKernel, 5, sizeof(cl_uint), &M);
+		CL_CHECK(err);
+		err = clSetKernelArg(*tileClKernel, 6, sizeof(cl_uint), &N);
+		CL_CHECK(err);
+		err = clSetKernelArg(*tileClKernel, 7, sizeof(cl_uint), &K);
+		CL_CHECK(err);
+		err = clSetKernelArg(*tileClKernel, 8, sizeof(cl_uint), &lda);
+		CL_CHECK(err);
+		err = clSetKernelArg(*tileClKernel, 9, sizeof(cl_uint), &ldb);
+		CL_CHECK(err);
+		err = clSetKernelArg(*tileClKernel, 10, sizeof(cl_uint), &ldc);
+		CL_CHECK(err);
+		err = clSetKernelArg(*tileClKernel, 11, sizeof(cl_uint), &offA);
+		CL_CHECK(err);
+		err = clSetKernelArg(*tileClKernel, 12, sizeof(cl_uint), &offB);
+		CL_CHECK(err);
+		err = clSetKernelArg(*tileClKernel, 13, sizeof(cl_uint), &offC);
+		CL_CHECK(err);
+		// NOTE(review): CL_CHECK appears to only print and assert, so an arg error is not reported via the return value.
+		status = GEMM_SPLIT_CALLS(
+			tileClKernel, clblasColumnMajor,
+			48, 8, // tile_size = 48 (matches MX048/NX048), WG_size = 8
+			M_split_factor,
+			N_split_factor, K_split_factor,
+			transA,
+			transB,
+			M, N, K,
+			alpha,
+			A, offA, lda,
+			B, offB, ldb,
+			beta,
+			C, offC, ldc,
+			numCommandQueues,
+			commandQueues,
+			numEventsInWaitList,
+			eventWaitList,
+			events);
+		if (status == clblasSuccess)
+			specialCaseHandled = true;
+
+		return status;
+	}
+	// lda/ldb outside the supported range (or lda != ldb): not handled by this special case.
+
+	return clblasNotImplemented;
+}
+
 template<>
 clblasStatus
 GemmSpecialCases<float>(clblasOrder order,
@@ -707,7 +849,7 @@ bool &specialCaseHandled)
 	clblasStatus status;
 
 	//handles big multiples of 1024
-	status = GEMM_mod1024(transA,
+	status = SGEMM_mod1024(transA,
 		transB,
 		M, N, K,
 		alpha,
@@ -725,7 +867,7 @@ bool &specialCaseHandled)
 		return status;
 
 	//handles mod32 but not mod64
-	status = GEMM_SPLIT64_32(transA,
+	status = SGEMM_SPLIT64_32(transA,
 		transB,
 		M, N, K,
 		alpha,
@@ -744,7 +886,7 @@ bool &specialCaseHandled)
 
 	//handles middle range sgemm (M*N<1080*1080) that are not mod32 (M%32!=0 || N%32!=0)
 	//use 32x32 micro tile kernels with branch statement within kernels
-	status = GEMM_BRANCH_32(transA,
+	status = SGEMM_BRANCH_32(transA,
 		transB,
 		M, N, K,
 		alpha,
@@ -783,6 +925,29 @@ const cl_event *eventWaitList,
 cl_event *events,
 bool &specialCaseHandled)
 {
+	if (order == clblasRowMajor)
+		return clblasNotImplemented;
+
+	clblasStatus status;
+	
+	status = DGEMM_BIG_MOD48(transA,
+		transB,
+		M, N, K,
+		alpha,
+		A, offA, lda,
+		B, offB, ldb,
+		beta,
+		C, offC, ldc,
+		numCommandQueues,
+		commandQueues,
+		numEventsInWaitList,
+		eventWaitList,
+		events,
+		specialCaseHandled);
+		
+	if (specialCaseHandled)
+		return status;
+
 	return clblasNotImplemented;
 }
 

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/clblas.git



More information about the debian-science-commits mailing list