[clblas] 37/67: enable big dgemm with split calls
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Tue Oct 27 08:02:13 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch master
in repository clblas.
commit 4d67a9e72f448223b5dff33f2c290644318f780d
Author: Timmy <timmy.liu at amd.com>
Date: Wed Sep 16 15:59:26 2015 -0500
enable big dgemm with split calls
---
src/library/blas/specialCases/GemmSpecialCases.cpp | 309 ++++++++++++++++-----
1 file changed, 237 insertions(+), 72 deletions(-)
diff --git a/src/library/blas/specialCases/GemmSpecialCases.cpp b/src/library/blas/specialCases/GemmSpecialCases.cpp
index 4a5fd82..be11f41 100644
--- a/src/library/blas/specialCases/GemmSpecialCases.cpp
+++ b/src/library/blas/specialCases/GemmSpecialCases.cpp
@@ -30,7 +30,8 @@
printf("OpenCL error %i on line %u\n", RET, __LINE__); \
assert(false); \
}
-
+/*
+template<typename precision>
clblasStatus SGEMM_SPLIT_CALLS(
cl_kernel *ClKernel, clblasOrder order,
unsigned int tile_size, unsigned int WG_size,
@@ -40,10 +41,31 @@ clblasStatus SGEMM_SPLIT_CALLS(
clblasTranspose transA,
clblasTranspose transB,
cl_uint M, cl_uint N, cl_uint K,
- float alpha,
+ precision alpha,
cl_mem A, cl_uint offA, cl_uint lda,
cl_mem B, cl_uint offB, cl_uint ldb,
- float beta,
+ precision beta,
+ cl_mem C, cl_uint offC, cl_uint ldc,
+ cl_uint numCommandQueues,
+ cl_command_queue *commandQueues,
+ cl_uint numEventsInWaitList,
+ const cl_event *eventWaitList,
+ cl_event *events);
+*/
+template<typename precision>
+clblasStatus GEMM_SPLIT_CALLS(
+ cl_kernel *ClKernel, clblasOrder order,
+ unsigned int tile_size, unsigned int WG_size,
+ unsigned int M_split_factor,
+ unsigned int N_split_factor,
+ unsigned int K_split_factor,
+ clblasTranspose transA,
+ clblasTranspose transB,
+ cl_uint M, cl_uint N, cl_uint K,
+ precision alpha,
+ cl_mem A, cl_uint offA, cl_uint lda,
+ cl_mem B, cl_uint offB, cl_uint ldb,
+ precision beta,
cl_mem C, cl_uint offC, cl_uint ldc,
cl_uint numCommandQueues,
cl_command_queue *commandQueues,
@@ -51,7 +73,7 @@ clblasStatus SGEMM_SPLIT_CALLS(
const cl_event *eventWaitList,
cl_event *events)
{
- //for example, when M=N=K=8192
+ //for example, when M=N=K=8192 in GEMM col NT
//we are gonna call 16 GEMMs
//each GEMM has M=N=K=4096
//note are direct GEMM call has a 0.7 TFLOPS performance
@@ -80,98 +102,103 @@ clblasStatus SGEMM_SPLIT_CALLS(
// #15: C22 = a*A23*B23 + 1*C22
// #16: C22 = a*A24*B24 + 1*C22 now we are done with C22
- unsigned int small_M = M / M_split_factor;
- unsigned int small_N = N / N_split_factor;
- unsigned int small_K = K / K_split_factor;
+ if (transA == clblasNoTrans && transB == clblasTrans)
+ {
- size_t GlobalX = ((small_M - 1) / tile_size + 1) * WG_size;
- size_t GlobalY = ((small_N - 1) / tile_size + 1) * WG_size;
- size_t gs[2] = { GlobalX, GlobalY };
- size_t wgsize[2] = { WG_size, WG_size };
- cl_int error = 0;
+ unsigned int small_M = M / M_split_factor;
+ unsigned int small_N = N / N_split_factor;
+ unsigned int small_K = K / K_split_factor;
- cl_float betaone = 1;
+ size_t GlobalX = ((small_M - 1) / tile_size + 1) * WG_size;
+ size_t GlobalY = ((small_N - 1) / tile_size + 1) * WG_size;
+ size_t gs[2] = { GlobalX, GlobalY };
+ size_t wgsize[2] = { WG_size, WG_size };
+ cl_int error = 0;
- error = clSetKernelArg(*ClKernel, 5, sizeof(cl_uint), &small_M);
- assert(error == CL_SUCCESS);
- error = clSetKernelArg(*ClKernel, 6, sizeof(cl_uint), &small_N);
- assert(error == CL_SUCCESS);
- error = clSetKernelArg(*ClKernel, 7, sizeof(cl_uint), &small_K);
- assert(error == CL_SUCCESS);
+ precision betaone = 1;
- for (int M_split_index = 0; M_split_index < M_split_factor; M_split_index++)
- {
- for (int N_split_index = 0; N_split_index < N_split_factor; N_split_index++)
- {
- unsigned int offc_C = ldc*N / N_split_factor * N_split_index + M / M_split_factor * M_split_index + offC;
- error = clSetKernelArg(*ClKernel, 13, sizeof(cl_uint), &offc_C);
- assert(error == CL_SUCCESS);
+ error = clSetKernelArg(*ClKernel, 5, sizeof(cl_uint), &small_M);
+ assert(error == CL_SUCCESS);
+ error = clSetKernelArg(*ClKernel, 6, sizeof(cl_uint), &small_N);
+ assert(error == CL_SUCCESS);
+ error = clSetKernelArg(*ClKernel, 7, sizeof(cl_uint), &small_K);
+ assert(error == CL_SUCCESS);
- for (int K_split_index = 0; K_split_index < K_split_factor; K_split_index++)
+ for (int M_split_index = 0; M_split_index < M_split_factor; M_split_index++)
+ {
+ for (int N_split_index = 0; N_split_index < N_split_factor; N_split_index++)
{
- unsigned int offa_A = (M / M_split_factor * M_split_index) + (lda * K / K_split_factor * K_split_index) + offA;
- unsigned int offb_B = (N / N_split_factor * N_split_index) + (ldb * K / K_split_factor * K_split_index) + offB;
- error = clSetKernelArg(*ClKernel, 11, sizeof(cl_uint), &offa_A);
- assert(error == CL_SUCCESS);
- error = clSetKernelArg(*ClKernel, 12, sizeof(cl_uint), &offb_B);
+ unsigned int offc_C = ldc*N / N_split_factor * N_split_index + M / M_split_factor * M_split_index + offC;
+ error = clSetKernelArg(*ClKernel, 13, sizeof(cl_uint), &offc_C);
assert(error == CL_SUCCESS);
- if (K_split_index == 0)
+ for (int K_split_index = 0; K_split_index < K_split_factor; K_split_index++)
{
- error = clSetKernelArg(*ClKernel, 4, sizeof(cl_float), &(beta));
+ unsigned int offa_A = (M / M_split_factor * M_split_index) + (lda * K / K_split_factor * K_split_index) + offA;
+ unsigned int offb_B = (N / N_split_factor * N_split_index) + (ldb * K / K_split_factor * K_split_index) + offB;
+ error = clSetKernelArg(*ClKernel, 11, sizeof(cl_uint), &offa_A);
+ assert(error == CL_SUCCESS);
+ error = clSetKernelArg(*ClKernel, 12, sizeof(cl_uint), &offb_B);
assert(error == CL_SUCCESS);
- if (M_split_index == 0 && N_split_index == 0)
+ if (K_split_index == 0)
{
- //very first GEMM call
- if ((M_split_factor == 1) && (N_split_factor == 1) && (K_split_factor == 1))
+ error = clSetKernelArg(*ClKernel, 4, sizeof(precision), &(beta));
+ assert(error == CL_SUCCESS);
+
+ if (M_split_index == 0 && N_split_index == 0)
{
- //also very last GEMM call
- error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
- gs, wgsize, numEventsInWaitList, eventWaitList, &events[0]);
- assert(error == CL_SUCCESS);
+ //very first GEMM call
+ if ((M_split_factor == 1) && (N_split_factor == 1) && (K_split_factor == 1))
+ {
+ //also very last GEMM call
+ error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
+ gs, wgsize, numEventsInWaitList, eventWaitList, &events[0]);
+ assert(error == CL_SUCCESS);
+ }
+ else
+ {
+ error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
+ gs, wgsize, numEventsInWaitList, eventWaitList, NULL);
+ assert(error == CL_SUCCESS);
+ }
}
else
{
error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
- gs, wgsize, numEventsInWaitList, eventWaitList, NULL);
+ gs, wgsize, 0, NULL, NULL);
assert(error == CL_SUCCESS);
}
}
else
{
- error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
- gs, wgsize, 0, NULL, NULL);
+ error = clSetKernelArg(*ClKernel, 4, sizeof(precision), &betaone);
assert(error == CL_SUCCESS);
- }
- }
- else
- {
- error = clSetKernelArg(*ClKernel, 4, sizeof(cl_float), &betaone);
- assert(error == CL_SUCCESS);
- if ((M_split_index == (M_split_factor - 1)) && (N_split_index == (N_split_factor - 1)) && (K_split_index == (K_split_factor - 1)))
- {
- //very last GEMM call
- error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
- gs, wgsize, 0, NULL, events);
- assert(error == CL_SUCCESS);
- }
- else
- {
- error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
- gs, wgsize, 0, NULL, NULL);
- assert(error == CL_SUCCESS);
+ if ((M_split_index == (M_split_factor - 1)) && (N_split_index == (N_split_factor - 1)) && (K_split_index == (K_split_factor - 1)))
+ {
+ //very last GEMM call
+ error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
+ gs, wgsize, 0, NULL, events);
+ assert(error == CL_SUCCESS);
+ }
+ else
+ {
+ error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL,
+ gs, wgsize, 0, NULL, NULL);
+ assert(error == CL_SUCCESS);
+ }
}
}
}
}
- }
- return clblasSuccess;
+ return clblasSuccess;
+ }
+ return clblasNotImplemented;
}
-clblasStatus GEMM_mod1024(
+clblasStatus SGEMM_mod1024(
clblasTranspose transA,
clblasTranspose transB,
cl_uint M, cl_uint N, cl_uint K,
@@ -267,7 +294,7 @@ clblasStatus GEMM_mod1024(
err = clSetKernelArg(*tileClKernel, 13, sizeof(cl_uint), &offC);
CL_CHECK(err);
- status = SGEMM_SPLIT_CALLS(
+ status = GEMM_SPLIT_CALLS(
tileClKernel, clblasColumnMajor,
128, 16,
M_split_factor,
@@ -343,7 +370,7 @@ clblasStatus GEMM_mod1024(
CL_CHECK(err);
- status = SGEMM_SPLIT_CALLS(
+ status = GEMM_SPLIT_CALLS(
tileClKernel, clblasColumnMajor,
96, 16,
M_split_factor,
@@ -373,7 +400,7 @@ clblasStatus GEMM_mod1024(
}
-clblasStatus GEMM_SPLIT64_32(
+clblasStatus SGEMM_SPLIT64_32(
clblasTranspose transA,
clblasTranspose transB,
cl_uint M, cl_uint N, cl_uint K,
@@ -508,7 +535,7 @@ clblasStatus GEMM_SPLIT64_32(
return clblasNotImplemented;
}
-clblasStatus GEMM_BRANCH_32(
+clblasStatus SGEMM_BRANCH_32(
clblasTranspose transA,
clblasTranspose transB,
cl_uint M, cl_uint N, cl_uint K,
@@ -682,6 +709,121 @@ clblasStatus GEMM_BRANCH_32(
return clblasNotImplemented;
}
+clblasStatus DGEMM_BIG_MOD48(
+ clblasTranspose transA,
+ clblasTranspose transB,
+ cl_uint M, cl_uint N, cl_uint K,
+ double alpha,
+ cl_mem A, cl_uint offA, cl_uint lda,
+ cl_mem B, cl_uint offB, cl_uint ldb,
+ double beta,
+ cl_mem C, cl_uint offC, cl_uint ldc,
+ cl_uint numCommandQueues,
+ cl_command_queue *commandQueues,
+ cl_uint numEventsInWaitList,
+ const cl_event *eventWaitList,
+ cl_event *events,
+ bool &specialCaseHandled)
+{
+ const char *tileKernelSource = NULL;
+ cl_kernel *tileClKernel = NULL;
+ size_t tileKernelBinarySize = 0;
+ cl_int err;
+
+
+ const unsigned char *tileKernelBinary = NULL;
+
+ clblasStatus status;
+ //split the kernel calls to handle dgemm NT perf drop when matrix sizes are big
+ if ((lda == ldb) && (lda >= 18000) && (lda <= 36000)) // between 18000 and 36000 for now
+ {
+ if (!((transA == clblasNoTrans) && (transB == clblasTrans)))
+ return clblasNotImplemented;
+
+ unsigned int M_split_factor;
+ unsigned int N_split_factor;
+ unsigned int K_split_factor;
+ if ((M % 192 == 0) && (N % 192 == 0) && (K % 192 == 0) && (K > lda / 4))
+ {
+ M_split_factor = 4;
+ N_split_factor = 4;
+ K_split_factor = 4;
+ }
+ else if ((M % 96 == 0) && (N % 96 == 0) && (K % 96 == 0) && (K > lda / 4))
+ {
+ M_split_factor = 2;
+ N_split_factor = 2;
+ K_split_factor = 2;
+ }
+ else
+ {
+ return clblasNotImplemented;
+ }
+
+ tileKernelSource = dgemm_Col_NT_B1_MX048_NX048_KX08_src;
+ tileClKernel = &dgemm_Col_NT_B1_MX048_NX048_KX08_clKernel;
+ tileKernelBinary = dgemm_Col_NT_B1_MX048_NX048_KX08_bin;
+ tileKernelBinarySize = dgemm_Col_NT_B1_MX048_NX048_KX08_binSize;
+
+ makeGemmKernel(tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions);
+
+ err = clSetKernelArg(*tileClKernel, 0, sizeof(cl_mem), &A);
+ CL_CHECK(err);
+ err = clSetKernelArg(*tileClKernel, 1, sizeof(cl_mem), &B);
+ CL_CHECK(err);
+ err = clSetKernelArg(*tileClKernel, 2, sizeof(cl_mem), &C);
+ CL_CHECK(err);
+ err = clSetKernelArg(*tileClKernel, 3, sizeof(cl_double), &alpha);
+ CL_CHECK(err);
+ err = clSetKernelArg(*tileClKernel, 4, sizeof(cl_double), &beta);
+ CL_CHECK(err);
+ err = clSetKernelArg(*tileClKernel, 5, sizeof(cl_uint), &M);
+ CL_CHECK(err);
+ err = clSetKernelArg(*tileClKernel, 6, sizeof(cl_uint), &N);
+ CL_CHECK(err);
+ err = clSetKernelArg(*tileClKernel, 7, sizeof(cl_uint), &K);
+ CL_CHECK(err);
+ err = clSetKernelArg(*tileClKernel, 8, sizeof(cl_uint), &lda);
+ CL_CHECK(err);
+ err = clSetKernelArg(*tileClKernel, 9, sizeof(cl_uint), &ldb);
+ CL_CHECK(err);
+ err = clSetKernelArg(*tileClKernel, 10, sizeof(cl_uint), &ldc);
+ CL_CHECK(err);
+ err = clSetKernelArg(*tileClKernel, 11, sizeof(cl_uint), &offA);
+ CL_CHECK(err);
+ err = clSetKernelArg(*tileClKernel, 12, sizeof(cl_uint), &offB);
+ CL_CHECK(err);
+ err = clSetKernelArg(*tileClKernel, 13, sizeof(cl_uint), &offC);
+ CL_CHECK(err);
+
+ status = GEMM_SPLIT_CALLS(
+ tileClKernel, clblasColumnMajor,
+ 48, 8,
+ M_split_factor,
+ N_split_factor, K_split_factor,
+ transA,
+ transB,
+ M, N, K,
+ alpha,
+ A, offA, lda,
+ B, offB, ldb,
+ beta,
+ C, offC, ldc,
+ numCommandQueues,
+ commandQueues,
+ numEventsInWaitList,
+ eventWaitList,
+ events);
+ if (status == clblasSuccess)
+ specialCaseHandled = true;
+
+ return status;
+ }
+
+
+ return clblasNotImplemented;
+}
+
template<>
clblasStatus
GemmSpecialCases<float>(clblasOrder order,
@@ -707,7 +849,7 @@ bool &specialCaseHandled)
clblasStatus status;
//handles big multiples of 1024
- status = GEMM_mod1024(transA,
+ status = SGEMM_mod1024(transA,
transB,
M, N, K,
alpha,
@@ -725,7 +867,7 @@ bool &specialCaseHandled)
return status;
//handles mod32 but not mod64
- status = GEMM_SPLIT64_32(transA,
+ status = SGEMM_SPLIT64_32(transA,
transB,
M, N, K,
alpha,
@@ -744,7 +886,7 @@ bool &specialCaseHandled)
//handles middle range sgemm (M*N<1080*1080) that are not mod32 (M%32!=0 || N%32!=0)
//use 32x32 micro tile kernels with branch statement within kernels
- status = GEMM_BRANCH_32(transA,
+ status = SGEMM_BRANCH_32(transA,
transB,
M, N, K,
alpha,
@@ -783,6 +925,29 @@ const cl_event *eventWaitList,
cl_event *events,
bool &specialCaseHandled)
{
+ if (order == clblasRowMajor)
+ return clblasNotImplemented;
+
+ clblasStatus status;
+
+ status = DGEMM_BIG_MOD48(transA,
+ transB,
+ M, N, K,
+ alpha,
+ A, offA, lda,
+ B, offB, ldb,
+ beta,
+ C, offC, ldc,
+ numCommandQueues,
+ commandQueues,
+ numEventsInWaitList,
+ eventWaitList,
+ events,
+ specialCaseHandled);
+
+ if (specialCaseHandled)
+ return status;
+
return clblasNotImplemented;
}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/clblas.git
More information about the debian-science-commits
mailing list