1 #include "CLA_Matrix.h"
5 using std::ostringstream;
9 #ifdef FORTRANUNDERSCORE
// C-linkage prototype for the external DGEMM routine. Every argument is
// passed by pointer: two char flags, three int sizes, then double/int
// operands.
// NOTE(review): presumably mirrors BLAS dgemm (transa, transb, m, n, k,
// alpha, A, lda, B, ldb, beta, C, ldc); the tail of the declaration is not
// shown here -- confirm.
15 extern "C" {
void DGEMM(
char *,
char *,
int *,
int *,
int *,
double *,
16 double *,
int *,
double *,
int *,
double *,
// myGEMM overload for complex matrices: A, B and C are `complex*`, while the
// scalars alpha and beta are plain doubles. All arguments are passed by
// pointer; opA/opB are operation flags and lda/ldb/ldc accompany A/B/C.
// NOTE(review): argument meaning presumably follows the BLAS GEMM
// convention -- confirm against the definition.
19 void myGEMM(
char *opA,
char *opB,
int *m,
int *n,
int *k,
double *alpha,
complex *A,
int *lda,
complex *B,
int *ldb,
double *beta,
complex *C,
int *ldc);
// myGEMM overload for real matrices: same parameter list as the complex
// overload, but A, B and C are `double*`.
20 void myGEMM(
char *opA,
char *opB,
int *m,
int *n,
int *k,
double *alpha,
double *A,
int *lda,
double *B,
int *ldb,
double *beta,
double *C,
int *ldc);
// Reducer type defined elsewhere; CLA_MM3D_multiplier::multiply passes it to
// contribute() when reducing its partial product.
24 extern CkReduction::reducerType sumFastDoubleType;
// Tail of the parameter list of the routine that builds the A, B and C
// chare arrays (the start of the signature is not shown here).
// Visible behaviour:
//   * rejects an out-of-range `algorithm` with ERR_INVALID_ALG;
//   * rejects chunk sizes larger than the matrix with ERR_INVALID_DIM;
//   * creates three CLA_Matrix arrays with anytime migration disabled;
//   * populates them for the 2D or the 3D algorithm.
37 CProxy_ArrayElement bindA, CProxy_ArrayElement bindB, CProxy_ArrayElement bindC,
47 CkCallback cbA, CkCallback cbB, CkCallback cbC,
48 CkGroupID gid,
int algorithm,
int gemmSplitOrtho
// Argument validation.
52 if(algorithm < MM_ALG_MIN || MM_ALG_MAX < algorithm)
53 return ERR_INVALID_ALG;
55 if(m > M || k > K || n > N)
56 return ERR_INVALID_DIM;
59 CkArrayOptions optsA, optsB, optsC;
// Elements of all three arrays are created with anytime migration off.
63 optsA.setAnytimeMigration(
false);
64 optsB.setAnytimeMigration(
false);
65 optsC.setAnytimeMigration(
false);
66 CProxy_CLA_Matrix pa = CProxy_CLA_Matrix::ckNew(optsA);
67 CProxy_CLA_Matrix pb = CProxy_CLA_Matrix::ckNew(optsB);
68 CProxy_CLA_Matrix pc = CProxy_CLA_Matrix::ckNew(optsC);
// Chunk counts per dimension: integer ceiling of total / chunk size.
74 int M_chunks = (M + m - 1) / m;
75 int K_chunks = (K + k - 1) / k;
76 int N_chunks = (N + n - 1) / n;
// 2D algorithm: elements live at strided indices (chunk * stride). The
// insert() arguments follow the parameter order of the 2D CLA_Matrix
// constructor; A and B each receive the other operand's proxy and C's proxy.
86 if(algorithm == MM_ALG_2D)
88 for(
int i = 0; i < M_chunks; i++)
89 for(
int j = 0; j < K_chunks; j++)
90 (A->p(i * strideM, j * strideK)).insert(M, K, N, m, k, n, strideM, strideK, strideN, MULTARG_A, B->p, C->p, cbA, gemmSplitOrtho);
93 for(
int i = 0; i < K_chunks; i++)
94 for(
int j = 0; j < N_chunks; j++)
95 (B->p(i * strideK, j * strideN)).insert(M, K, N, m, k, n, strideM, strideK, strideN, MULTARG_B, A->p, C->p, cbB, gemmSplitOrtho);
98 for(
int i = 0; i < M_chunks; i++)
99 for(
int j = 0; j < N_chunks; j++)
100 (C->p(i * strideM, j * strideN)).insert(M, K, N, m, k, n, strideM, strideK, strideN, MULTARG_C, A->p, B->p, cbC, gemmSplitOrtho);
101 C->p.doneInserting();
// 3D algorithm: first build an (M_chunks x N_chunks x K_chunks) array of
// multiplier chares, then the A, B and C arrays that talk to it.
103 else if(algorithm == MM_ALG_3D)
105 CProxy_CLA_MM3D_multiplier mult = CProxy_CLA_MM3D_multiplier::ckNew();
107 int totpe = CkNumPes();
108 for(
int i = 0; i < M_chunks; i++)
// The `== *_chunks - 1` tests single out the last chunk in each dimension;
// the statements that set mm / nn / kk are not shown here.
111 if(i == M_chunks - 1)
118 for(
int j = 0; j < N_chunks; j++)
121 if(j == N_chunks - 1)
128 for(
int l = 0; l < K_chunks; l++)
131 if(l == K_chunks - 1)
// Multipliers are placed round-robin over all PEs.
137 mult(i, j, l).insert(mm, kk, nn, curpe);
138 curpe = (curpe + 1) % totpe;
142 mult.doneInserting();
// The insert() arguments below follow the parameter order of the 3D
// CLA_Matrix constructor (multiplier proxy first, then sizes, strides, part,
// callback, multicast group id, gemmSplitOrtho).
144 for(
int i = 0; i < M_chunks; i++)
145 for(
int j = 0; j < K_chunks; j++)
146 (A->p(i * strideM, j * strideK)).insert(mult, M, K, N, m, k, n,strideM, strideK, strideN, MULTARG_A, cbA, gid, gemmSplitOrtho);
147 A->p.doneInserting();
149 for(
int i = 0; i < K_chunks; i++)
150 for(
int j = 0; j < N_chunks; j++)
151 (B->p(i * strideK, j * strideN)).insert(mult, M, K, N, m, k, n, strideM, strideK, strideN, MULTARG_B, cbB, gid, gemmSplitOrtho);
152 B->p.doneInserting();
154 for(
int i = 0; i < M_chunks; i++)
155 for(
int j = 0; j < N_chunks; j++)
156 (C->p(i * strideM, j * strideN)).insert(mult, M, K, N, m, k, n, strideM, strideK, strideN, MULTARG_C, cbC, gid, gemmSplitOrtho);
157 C->p.doneInserting();
// Constructor for an element taking part in the 2D algorithm.
//   _M, _K, _N       full matrix dimensions
//   _m, _k, _n       nominal chunk sizes (kept in um / uk / un)
//   strideM/K/N      index stride between neighbouring chunks
//   _part            role of this element (MULTARG_A / _B / _C)
//   _other1/_other2  proxies of the two other matrices
//   ready            callback that receives the empty "constructed" reduction
//   _gemmSplitOrtho  stored as-is
170 CLA_Matrix::CLA_Matrix(
int _M,
int _K,
int _N,
int _m,
int _k,
int _n,
171 int strideM,
int strideK,
int strideN,
int _part,
172 CProxy_CLA_Matrix _other1, CProxy_CLA_Matrix _other2, CkCallback ready,
int _gemmSplitOrtho){
174 this->M = _M; this->K = _K; this->N = _N;
// um/uk/un keep the nominal chunk size; m/k/n start equal to it and are
// adjusted below for the last chunk of a dimension.
175 this->um = _m; this->uk = _k; this->un = _n;
176 this->m = _m; this->k = _k; this->n = _n;
178 this->algorithm = MM_ALG_2D;
179 this->other1 = _other1; this->other2 = _other2;
180 this->M_stride = strideM;
181 this->K_stride = strideK;
182 this->N_stride = strideN;
183 gemmSplitOrtho=_gemmSplitOrtho;
// Chunk counts: integer ceiling of total / chunk size.
184 M_chunks = (_M + _m - 1) / _m;
185 K_chunks = (_K + _k - 1) / _k;
186 N_chunks = (_N + _n - 1) / _n;
195 algorithm = MM_ALG_2D;
197 setMigratable(
false);
// Last-chunk adjustment: an element whose index equals
// (chunks - 1) * stride gets chunk size + remainder in that dimension.
// NOTE(review): for a positive chunk size, `_m + _M % _m` cannot be 0, so
// the `== 0` guards that follow look unreachable -- confirm intent.
199 if(part == MULTARG_A){
200 if(thisIndex.x == (M_chunks - 1) * strideM){
201 this->m = _m + _M % _m;
202 if(this->m == 0) this->m = _m;
205 if(thisIndex.y == (K_chunks - 1) * strideK){
206 this->k = _k + _K % _k;
207 if(this->k == 0) this->k = _k;
211 }
else if(part == MULTARG_B) {
212 if(thisIndex.x == (K_chunks - 1) * strideK){
213 this->k = _k + _K % _k;
214 if(this->k == 0) this->k = _k;
217 if(thisIndex.y == (N_chunks - 1) * strideN){
218 this->n = _n + _N % _n;
219 if(this->n == 0) this->n = _n;
223 }
else if(part == MULTARG_C) {
224 if(thisIndex.x == (M_chunks - 1) * strideM){
225 this->m = _m + _M % _m;
226 if(this->m == 0) this->m = _m;
229 if(thisIndex.y == (N_chunks - 1) * strideN){
230 this->n = _n + _N % _n;
231 if(this->n == 0) this->n = _n;
236 row_count = col_count = 0;
// Communication setup. A and B each build a section of `other2`:
//   A: fixed first index thisIndex.x, second index 0..(N_chunks-1)*strideN
//   B: first index 0..(M_chunks-1)*strideM, fixed second index thisIndex.y
// NOTE(review): the arguments read as (low, high, stride) per dimension --
// confirm against the Charm++ section API.
240 if(part == MULTARG_A){
241 commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, thisIndex.x,
242 thisIndex.x, 1, 0, (N_chunks - 1) * strideN, strideN);
244 }
else if(part == MULTARG_B) {
245 commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, 0,
246 (M_chunks - 1) * strideM, strideM, thisIndex.y, thisIndex.y, 1);
248 }
else if(part == MULTARG_C) {
// C allocates staging buffers: tmpA holds m * K elements, tmpB K * n.
249 tmpA =
new internalType[this->m * K];
250 tmpB =
new internalType[K * this->n];
// Empty reduction to `ready`.
254 contribute(0, NULL, CkReduction::sum_int, ready);
// Constructor for an element taking part in the 3D algorithm.
//   p                proxy of the CLA_MM3D_multiplier array
//   M, K, N          full matrix dimensions
//   m, k, n          nominal chunk sizes (kept in um / uk / un)
//   strideM/K/N      index stride between neighbouring chunks
//   part             role of this element (MULTARG_A / _B / _C)
//   cb               callback for the empty reduction contributed by A and B
//   gid              group id used to obtain the CkMulticastMgr branch
//   _gemmSplitOrtho  stored as-is
258 CLA_Matrix::CLA_Matrix(CProxy_CLA_MM3D_multiplier p,
int M,
int K,
int N,
259 int m,
int k,
int n,
int strideM,
int strideK,
int strideN,
int part,
260 CkCallback cb, CkGroupID gid,
int _gemmSplitOrtho){
262 this->M = M; this->K = K; this->N = N;
263 this->um = m; this->uk = k; this->un = n;
// This value is overwritten with MM_ALG_3D further down.
265 this->algorithm = MM_ALG_2D;
// NOTE(review): the parameter list has no other1/other2, so these look like
// self-assignments of the members -- confirm whether they can be dropped.
266 this->other1 = other1; this->other2 = other2;
267 this->M_stride = strideM;
268 this->K_stride = strideK;
269 this->N_stride = strideN;
270 gemmSplitOrtho=_gemmSplitOrtho;
// Chunk counts: integer ceiling of total / chunk size.
271 M_chunks = (M + m - 1) / m;
272 K_chunks = (K + k - 1) / k;
273 N_chunks = (N + n - 1) / n;
274 got_data = got_start =
false;
276 algorithm = MM_ALG_3D;
278 setMigratable(
false);
// Last-chunk handling: the tests below pick out the element holding the
// final chunk of each dimension; a member that ended up 0 falls back to the
// nominal chunk size.
280 if(part == MULTARG_A){
281 if(thisIndex.x == (M_chunks - 1) * strideM){
283 if(this->m == 0) this->m = m;
286 if(thisIndex.y == (K_chunks - 1) * strideK){
288 if(this->k == 0) this->k = k;
292 }
else if(part == MULTARG_B) {
293 if(thisIndex.x == (K_chunks - 1) * strideK){
295 if(this->k == 0) this->k = k;
298 if(thisIndex.y == (N_chunks - 1) * strideN){
300 if(this->n == 0) this->n = n;
304 }
else if(part == MULTARG_C) {
305 if(thisIndex.x == (M_chunks - 1) * strideM){
307 if(this->m == 0) this->m = m;
310 if(thisIndex.y == (N_chunks - 1) * strideN){
312 if(this->n == 0) this->n = n;
// Communication setup: x and y are this element's chunk coordinates (array
// index divided by stride). Each role builds a section of the multiplier
// array that fixes two coordinates and spans the third:
//   A: (x, 0..N_chunks-1, y)
//   B: (0..M_chunks-1, y, x)
//   C: (x, y, ...) with the third range on a line not shown here
319 if(part == MULTARG_A){
320 int x = thisIndex.x / strideM;
321 int y = thisIndex.y / strideK;
322 commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, 0,
323 N_chunks - 1, 1, y, y, 1);
324 contribute(0, NULL, CkReduction::sum_int, cb);
325 }
else if(part == MULTARG_B) {
326 int x = thisIndex.x / strideK;
327 int y = thisIndex.y / strideN;
328 commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, 0,
329 M_chunks - 1, 1, y, y, 1, x, x, 1);
330 contribute(0, NULL, CkReduction::sum_int, cb);
331 }
else if(part == MULTARG_C) {
333 int x = thisIndex.x / strideM;
334 int y = thisIndex.y / strideN;
335 commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, y, y, 1,
// C delegates its section to the local CkMulticastMgr branch and hands the
// multipliers callbacks that target readyC and mult_done on this element.
337 commGroup3D.ckSectionDelegate(CProxy_CkMulticastMgr(gid).ckLocalBranch());
339 CkCallback(CkIndex_CLA_Matrix::readyC(NULL),
340 thisProxy(thisIndex.x, thisIndex.y)), CkCallback(
341 CkIndex_CLA_Matrix::mult_done(NULL), thisProxy(thisIndex.x,
343 commGroup3D.initialize_reduction(m);
// Destructor: cleanup is chosen by algorithm (2D vs 3D); the per-branch
// statements are not shown here.
347 CLA_Matrix::~CLA_Matrix(){
348 if(algorithm == MM_ALG_2D){
352 else if(algorithm == MM_ALG_3D){
// Pack/unpack routine for this element.
//   p  PUP::er that carries the state.
358 void CLA_Matrix::pup(PUP::er &p){
// The 3D algorithm cannot be migrated: abort instead of packing.
360 if(algorithm == MM_ALG_3D){
361 CmiAbort(
"3D algorithm does not currently support migration.\n");
365 CBase_CLA_Matrix::pup(p);
// State common to both algorithms: dimensions, chunk sizes, chunk counts,
// strides, role and algorithm.
368 p | M; p | K; p | N; p | m; p | k; p | n; p | um; p | uk; p | un;
369 p | M_chunks; p | K_chunks; p | N_chunks;
370 p | M_stride; p | K_stride; p | N_stride;
371 p | part; p | algorithm;
// 2D-only state: arrival counters and the proxies of the other matrices.
375 if(algorithm == MM_ALG_2D){
376 p | row_count; p | col_count;
377 p | other1; p | other2;
// C additionally carries its staging buffers (m * K and K * n elements).
// NOTE(review): the condition guarding the two allocations is not shown
// here -- confirm they only run when unpacking.
378 if(part == MULTARG_C){
380 tmpA =
new internalType[m * K];
381 tmpB =
new internalType[K * n];
383 PUParray(p, tmpA, m * K);
384 PUParray(p, tmpB, K * n);
387 else if(algorithm == MM_ALG_3D){
389 p | got_start; p | got_data;
// Rebuilds the communication sections, using the stored strides
// (M_stride / K_stride / N_stride) instead of constructor arguments.
394 void CLA_Matrix::ResumeFromSync(
void){
// 2D: same sections of `other2` as built in the 2D constructor.
396 if(algorithm == MM_ALG_2D){
397 if(part == MULTARG_A){
398 commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, thisIndex.x,
399 thisIndex.x, 1, 0, (N_chunks - 1) * N_stride, N_stride);
401 }
else if(part == MULTARG_B) {
402 commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, 0,
403 (M_chunks - 1) * M_stride, M_stride, thisIndex.y, thisIndex.y, 1);
406 }
// NOTE(review): this branch uses `p` and `cb`, which are neither parameters
// of this function nor declared in the lines shown here -- confirm whether
// it is compiled at all (pup() aborts for the 3D algorithm).
else if(algorithm == MM_ALG_3D){
408 if(part == MULTARG_A){
409 int x = thisIndex.x / M_stride;
410 int y = thisIndex.y / K_stride;
411 commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, 0,
412 N_chunks - 1, 1, y, y, 1);
413 contribute(0, NULL, CkReduction::sum_int, cb);
414 }
else if(part == MULTARG_B) {
415 int x = thisIndex.x / K_stride;
416 int y = thisIndex.y / N_stride;
417 commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, 0,
418 M_chunks - 1, 1, y, y, 1, x, x, 1);
419 contribute(0, NULL, CkReduction::sum_int, cb);
420 }
else if(part == MULTARG_C) {
422 int x = thisIndex.x / M_stride;
423 int y = thisIndex.y / N_stride;
424 commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, y, y,
425 1, 0, K_chunks - 1, 1);
// Entry point for one multiplication step; behaviour depends on the role.
//   alpha, beta  scaling factors for the update of the destination
//   data         this element's chunk
//   fptr         completion function pointer taking a void*
//   usr_data     pointer saved in user_data on C elements
440 void CLA_Matrix::multiply(
double alpha,
double beta, internalType *data,
441 void (*fptr) (
void*),
void *usr_data){
442 if(algorithm == MM_ALG_2D){
// A and B send `msg` to their section via receiveA / receiveB (the
// construction of `msg` is not shown here).
444 if(part == MULTARG_A){
447 commGroup2D.receiveA(msg);
448 }
else if(part == MULTARG_B){
451 commGroup2D.receiveB(msg);
454 else if(part == MULTARG_C){
456 user_data = usr_data;
// Tests whether both arrival counters have already reached K_chunks.
462 if(row_count == K_chunks && col_count == K_chunks)
466 CmiAbort(
"CLA_Matrix internal error");
467 }
else if(algorithm == MM_ALG_3D){
// A and B send `msg` to their multiplier section.
468 if(part == MULTARG_A){
471 commGroup3D.receiveA(msg);
472 }
else if(part == MULTARG_B){
475 commGroup3D.receiveB(msg);
476 }
else if(part == MULTARG_C){
478 user_data = usr_data;
484 got_start = got_data =
false;
// Takes the payload of res_msg, transposes it, then updates the destination
// element-wise: dest = beta * dest + alpha * data.
// This local `data` reuses the name of the `data` parameter.
486 internalType *data = (internalType*) res_msg->getData();
487 transpose(data, n, m);
488 for(
int i = 0; i < m; i++)
489 for(
int j = 0; j < n; j++)
490 dest[i * n + j] = beta * dest[i * n + j] + alpha * data[i * n + j];
// NOTE(review): presumably the body of the 2D receiveA handler; its header
// is not shown here -- confirm.
// Copies the incoming block row by row into tmpA: row i (msg->d2 elements)
// goes to row i of tmpA (row length K), at column offset
// uk * (msg->fromY / K_stride).
502 for(
int i = 0; i < m; i++)
503 CmiMemcpy(&tmpA[K * i + uk * (msg->fromY / K_stride)], &msg->
data[i * msg->d2],
504 msg->d2 *
sizeof(internalType));
// Tests whether all K_chunks row and column pieces have arrived and
// got_start is set.
508 if(row_count == K_chunks && col_count == K_chunks && got_start)
// NOTE(review): presumably the body of the 2D receiveB handler; its header
// is not shown here -- confirm.
// Copies the whole incoming block (d1 * d2 elements) into tmpB at element
// offset n * uk * (msg->fromX / K_stride).
515 CmiMemcpy(&tmpB[n * uk * (msg->fromX / K_stride)], msg->
data,
516 msg->d1 * msg->d2 *
sizeof(internalType));
// Tests whether all K_chunks row and column pieces have arrived and
// got_start is set.
520 if(row_count == K_chunks && col_count == K_chunks && got_start)
// Local GEMM step that uses the staging buffers tmpA and tmpB. Resets the
// arrival counters, transposes `dest` before the call, runs myGEMM, then
// transposes `dest` back.
527 void CLA_Matrix::multiply()
530 row_count = col_count = 0;
535 transpose(dest, m, n);
539 #define BUNDLE_USER_EVENTS
// Timing for the trace event recorded after the GEMM.
541 #ifdef CMK_TRACE_ENABLED
542 double StartTime=CmiWallTimer();
// Optional dump of the GEMM parameters.
544 #ifdef PRINT_DGEMM_PARAMS
545 CkPrintf(
"CLA_MATRIX DGEMM %c %c %d %d %d %f %f %d %d %d\n", trans, trans, m, n, K, alpha, beta, K, n, m);
// Asserts that every input element is finite (K * m entries of tmpA and
// n * K entries of tmpB).
548 for(
int in=0; in<K; in++)
549 for(
int jn=0; jn<m; jn++)
550 CkAssert(isfinite(tmpA[in*m+jn]));
551 for(
int in=0; in<n; in++)
552 for(
int jn=0; jn<K; jn++)
553 CkAssert(isfinite(tmpB[in*K+jn]));
// GEMM with sizes (m, n, K); the leading dimensions passed are K for tmpA,
// n for tmpB and m for dest.
555 myGEMM(&trans, &trans, &m, &n, &K, &alpha, tmpA, &K, tmpB, &n, &beta, dest, &m);
// Asserts that every element of the m * n result is finite.
557 for(
int in=0; in<m; in++)
558 for(
int jn=0; jn<n; jn++)
559 CkAssert(isfinite(dest[in*n+jn]));
// Records the GEMM interval as user event 401.
561 #ifdef CMK_TRACE_ENABLED
562 traceUserBracketEvent(401, StartTime, CmiWallTimer());
565 transpose(dest, n, m);
// Reduction target taking a CkReductionMsg. Contributes an empty sum_int
// reduction whose callback targets ready() on element (0, 0).
573 void CLA_Matrix::readyC(CkReductionMsg *msg){
574 CkCallback cb(CkIndex_CLA_Matrix::ready(NULL), thisProxy(0, 0));
575 contribute(0, NULL, CkReduction::sum_int, cb);
// Target of the callback that readyC() builds for element (0, 0); the body
// is not shown here.
579 void CLA_Matrix::ready(CkReductionMsg *msg){
// Reduction target for the 3D algorithm: `msg` carries the reduced product.
// Clears both got_* flags, transposes the payload, then updates the
// destination element-wise: dest = beta * dest + alpha * data.
584 void CLA_Matrix::mult_done(CkReductionMsg *msg){
586 got_start = got_data =
false;
588 internalType *data = (internalType*) msg->getData();
589 transpose(data, n, m);
590 for(
int i = 0; i < m; i++)
591 for(
int j = 0; j < n; j++)
592 dest[i * n + j] = beta * dest[i * n + j] + alpha * data[i * n + j];
// Builds a message carrying a d1 x d2 block.
//   data   source buffer; d1 * d2 elements are copied into this->data
//   d1,d2  block dimensions, stored in the message
//   fromX  sender coordinate, stored in the message (fromY is stored too; its
//          declaration falls on a line not shown here)
605 CLA_Matrix_msg::CLA_Matrix_msg(internalType *data,
int d1,
int d2,
int fromX,
607 CmiMemcpy(this->data, data, d1 * d2 *
sizeof(internalType));
608 this->d1 = d1; this->d2 = d2;
609 this->fromX = fromX; this->fromY = fromY;
// Stores the chunk dimensions (m, k, n) this multiplier works on.
614 CLA_MM3D_multiplier::CLA_MM3D_multiplier(
int m,
int k,
int n){
615 this->m = m; this->k = k; this->n = n;
// NOTE(review): presumably the body of initialize_reduction; its header is
// not shown here -- confirm. In these lines `m` is a message pointer, not
// the integer chunk size.
// Saves the reduce callback and section cookie from the message, obtains the
// local CkMulticastMgr branch for m->gid, and contributes an empty sum_int
// reduction whose callback is m->ready.
621 reduce_CB = m->reduce;
622 CkGetSectionInfo(sectionCookie, m);
623 redGrp = CProxy_CkMulticastMgr(m->gid).ckLocalBranch();
624 redGrp->contribute(0, NULL, CkReduction::sum_int, sectionCookie, m->ready);
// NOTE(review): presumably from the multiplier's receiveA handler -- confirm.
// The incoming message is passed as the first operand and the previously
// held data_msg as the second.
631 multiply(msg->
data, data_msg->
data);
// NOTE(review): presumably from the multiplier's receiveB handler -- confirm.
// The previously held data_msg is passed as the first operand and the
// incoming message as the second.
642 multiply(data_msg->
data, msg->
data);
// Multiplies one pair of chunks and contributes the product to the section
// reduction.
//   A, B  operand buffers handed to myGEMM
650 void CLA_MM3D_multiplier::multiply(internalType *A, internalType *B){
// alpha = 1, beta = 0: the GEMM result overwrites C.
651 double alpha = 1, beta = 0;
654 internalType *C =
new internalType[m * n];
655 #ifdef CMK_TRACE_ENABLED
656 double StartTime=CmiWallTimer();
// Asserts that A, B and C are 16-byte aligned.
// NOTE(review): the pointers are cast to `unsigned int`; where pointers are
// wider than that the cast drops high bits (the low bits tested by % 16
// survive) and some compilers reject it -- consider uintptr_t.
659 CkAssert((
unsigned int) A %16==0);
660 CkAssert((
unsigned int) B %16==0);
661 CkAssert((
unsigned int) C %16==0);
// Optional dump of the GEMM parameters.
664 #ifdef PRINT_DGEMM_PARAMS
665 CkPrintf(
"HEY-DGEMM %c %c %d %d %d %f %f %d %d %d\n", trans, trans, m, n, k, alpha, beta, k, n, m);
// GEMM with sizes (m, n, k); the leading dimensions passed are k for A,
// n for B and m for C.
667 myGEMM(&trans, &trans, &m, &n, &k, &alpha, A, &k, B, &n, &beta, C, &m);
// Records the GEMM interval as user event 402.
668 #ifdef CMK_TRACE_ENABLED
669 traceUserBracketEvent(402, StartTime, CmiWallTimer());
671 CmiNetworkProgress();
// Contributes the m * n product through the multicast manager with the
// sumFastDoubleType reducer; the result goes to reduce_CB.
673 redGrp->contribute(m * n *
sizeof(internalType), C, sumFastDoubleType,
674 sectionCookie, reduce_CB);
678 #include "CLA_Matrix.def.h"
internalType * data
// Releases the array held in `data` with delete[].
~CLA_Matrix_msg(){delete [] data;}
Author: Eric J. Bohm. Date created: June 4th, 2006.
Ortho is decomposed by orthoGrainSize.