00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015 #include "CLA_Matrix.h"
00016
00017 #if 0
00018 #include <sstream>
00019 using std::ostringstream;
00020 using std::endl;
00021 #endif
00022
00023 #ifdef FORTRANUNDERSCORE
00024 #define DGEMM dgemm_
00025 #else
00026 #define DGEMM dgemm
00027 #endif
00028
00029
00030 extern "C" {void DGEMM(char *, char *, int *, int *, int *, double *,
00031 double *, int *, double *, int *, double *,
00032 double *, int *);}
00033
00034 #define MULTARG_A 0
00035 #define MULTARG_B 1
00036 #define MULTARG_C 2
00037 extern CkReduction::reducerType sumFastDoubleType;
00038
00039
00040
00041
00042 int make_multiplier(CLA_Matrix_interface *A, CLA_Matrix_interface *B,
00043 CLA_Matrix_interface *C, CProxy_ArrayElement bindA,
00044 CProxy_ArrayElement bindB, CProxy_ArrayElement bindC,
00045 int M,
00046 int K,
00047 int N,
00048 int m,
00049 int k,
00050 int n,
00051 int strideM,
00052 int strideK,
00053 int strideN,
00054 CkCallback cbA, CkCallback cbB,
00055 CkCallback cbC, CkGroupID gid, int algorithm, int gemmSplitOrtho){
00056
00057 if(algorithm < MM_ALG_MIN || MM_ALG_MAX < algorithm)
00058 return ERR_INVALID_ALG;
00059
00060 if(m > M || k > K || n > N)
00061 return ERR_INVALID_DIM;
00062
00063 CkArrayOptions optsA(0);
00064 CkArrayOptions optsB(0);
00065 CkArrayOptions optsC(0);
00066
00067 optsA.bindTo(bindA);
00068 optsB.bindTo(bindB);
00069 optsC.bindTo(bindC);
00070 CProxy_CLA_Matrix pa = CProxy_CLA_Matrix::ckNew(optsA);
00071 CProxy_CLA_Matrix pb = CProxy_CLA_Matrix::ckNew(optsB);
00072 CProxy_CLA_Matrix pc = CProxy_CLA_Matrix::ckNew(optsC);
00073 A->setProxy(pa);
00074 B->setProxy(pb);
00075 C->setProxy(pc);
00076
00077
00078 int M_chunks = (M + m - 1) / m;
00079 int K_chunks = (K + k - 1) / k;
00080 int N_chunks = (N + n - 1) / n;
00081 if(M%m!=0)
00082 M_chunks--;
00083 if(K%k!=0)
00084 K_chunks--;
00085 if(N%n!=0)
00086 N_chunks--;
00087
00088
00089
00090 if(algorithm == MM_ALG_2D){
00091 for(int i = 0; i < M_chunks; i++)
00092 for(int j = 0; j < K_chunks; j++)
00093 (A->p(i * strideM, j * strideK)).insert(M, K, N, m, k, n, strideM,
00094 strideK, strideN, MULTARG_A, B->p, C->p, cbA, gemmSplitOrtho);
00095 A->p.doneInserting();
00096
00097 for(int i = 0; i < K_chunks; i++)
00098 for(int j = 0; j < N_chunks; j++)
00099 (B->p(i * strideK, j * strideN)).insert(M, K, N, m, k, n, strideM,
00100 strideK, strideN, MULTARG_B, A->p, C->p, cbB, gemmSplitOrtho);
00101 B->p.doneInserting();
00102
00103 for(int i = 0; i < M_chunks; i++)
00104 for(int j = 0; j < N_chunks; j++)
00105 (C->p(i * strideM, j * strideN)).insert(M, K, N, m, k, n, strideM,
00106 strideK, strideN, MULTARG_C, A->p, B->p, cbC, gemmSplitOrtho);
00107 C->p.doneInserting();
00108 }
00109 else if(algorithm == MM_ALG_3D){
00110 CProxy_CLA_MM3D_multiplier mult = CProxy_CLA_MM3D_multiplier::ckNew();
00111 int curpe = 0;
00112 int totpe = CkNumPes();
00113 for(int i = 0; i < M_chunks; i++){
00114 int mm = m;
00115 if(i == M_chunks - 1) {
00116 mm = M % m;
00117 if(mm == 0)
00118 mm = m;
00119 }
00120 for(int j = 0; j < N_chunks; j++){
00121 int nn = n;
00122 if(j == N_chunks - 1) {
00123 nn = N % n;
00124 if(nn == 0)
00125 nn = n;
00126 }
00127 for(int l = 0; l < K_chunks; l++){
00128 int kk = k;
00129 if(l == K_chunks - 1) {
00130 kk = K % k;
00131 if(kk == 0)
00132 kk = k;
00133 }
00134 mult(i, j, l).insert(mm, kk, nn, curpe);
00135 curpe = (curpe + 1) % totpe;
00136 }
00137 }
00138 }
00139 mult.doneInserting();
00140
00141 for(int i = 0; i < M_chunks; i++)
00142 for(int j = 0; j < K_chunks; j++)
00143 (A->p(i * strideM, j * strideK)).insert(mult, M, K, N, m, k, n,
00144 strideM, strideK, strideN, MULTARG_A, cbA, gid, gemmSplitOrtho);
00145 A->p.doneInserting();
00146
00147 for(int i = 0; i < K_chunks; i++)
00148 for(int j = 0; j < N_chunks; j++)
00149 (B->p(i * strideK, j * strideN)).insert(mult, M, K, N, m, k, n,
00150 strideM, strideK, strideN, MULTARG_B, cbB, gid, gemmSplitOrtho);
00151 B->p.doneInserting();
00152
00153 for(int i = 0; i < M_chunks; i++)
00154 for(int j = 0; j < N_chunks; j++)
00155 (C->p(i * strideM, j * strideN)).insert(mult, M, K, N, m, k, n,
00156 strideM, strideK, strideN, MULTARG_C, cbC, gid, gemmSplitOrtho);
00157 C->p.doneInserting();
00158 }
00159
00160 return SUCCESS;
00161 }
00162
00163
/*
 * transpose: transpose, in place, the row-major m x n matrix held in `data`.
 * On return `data` holds the row-major n x m transpose. Square matrices are
 * swapped across the diagonal; rectangular ones go through a scratch copy.
 */
void transpose(double *data, int m, int n){
  if(m != n){
    const int total = m * n;
    double *scratch = new double[total];
    for(int idx = 0; idx < total; ++idx)
      scratch[idx] = data[idx];
    for(int row = 0; row < m; ++row){
      const double *srcRow = scratch + row * n;
      for(int col = 0; col < n; ++col)
        data[col * m + row] = srcRow[col];
    }
    delete [] scratch;
    return;
  }

  // Square case: exchange each upper-triangle entry with its mirror.
  for(int row = 0; row < m; ++row)
    for(int col = row + 1; col < m; ++col){
      double held = data[row * m + col];
      data[row * m + col] = data[col * m + row];
      data[col * m + row] = held;
    }
}
00183
00184
00185
00186
00187
00188 CLA_Matrix::CLA_Matrix(int M, int K, int N, int m, int k, int n,
00189 int strideM, int strideK, int strideN, int part,
00190 CProxy_CLA_Matrix other1, CProxy_CLA_Matrix other2, CkCallback ready, int _gemmSplitOrtho){
00191
00192 this->M = M; this->K = K; this->N = N;
00193 this->um = m; this->uk = k; this->un = n;
00194 this->part = part;
00195 this->algorithm = MM_ALG_2D;
00196 this->other1 = other1; this->other2 = other2;
00197 this->M_stride = strideM;
00198 this->K_stride = strideK;
00199 this->N_stride = strideN;
00200 gemmSplitOrtho=_gemmSplitOrtho;
00201 M_chunks = (M + m - 1) / m;
00202 K_chunks = (K + k - 1) / k;
00203 N_chunks = (N + n - 1) / n;
00204 if(M%m!=0)
00205 M_chunks--;
00206 if(K%k!=0)
00207 K_chunks--;
00208 if(N%n!=0)
00209 N_chunks--;
00210
00211
00212 algorithm = MM_ALG_2D;
00213 usesAtSync = CmiFalse;
00214 setMigratable(false);
00215
00216 if(part == MULTARG_A){
00217 if(thisIndex.x == (M_chunks - 1) * strideM){
00218 this->m = m + M % m;
00219 if(this->m == 0) this->m = m;
00220 }
00221 else this->m = m;
00222 if(thisIndex.y == (K_chunks - 1) * strideK){
00223 this->k = k + K % k;
00224 if(this->k == 0) this->k = k;
00225 }
00226 else this->k = k;
00227 this->n = n;
00228 } else if(part == MULTARG_B) {
00229 if(thisIndex.x == (K_chunks - 1) * strideK){
00230 this->k = k + K % k;
00231 if(this->k == 0) this->k = k;
00232 }
00233 else this->k = k;
00234 if(thisIndex.y == (N_chunks - 1) * strideN){
00235 this->n = n + N % n;
00236 if(this->n == 0) this->n = n;
00237 }
00238 else this->n = n;
00239 this->m = m;
00240 } else if(part == MULTARG_C) {
00241 if(thisIndex.x == (M_chunks - 1) * strideM){
00242 this->m = m + M % m;
00243 if(this->m == 0) this->m = m;
00244 }
00245 else this->m = m;
00246 if(thisIndex.y == (N_chunks - 1) * strideN){
00247 this->n = n + N % n;
00248 if(this->n == 0) this->n = n;
00249 }
00250 else this->n = n;
00251 this->k = k;
00252 got_start = false;
00253 row_count = col_count = 0;
00254 }
00255
00256
00257 if(part == MULTARG_A){
00258 commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, thisIndex.x,
00259 thisIndex.x, 1, 0, (N_chunks - 1) * strideN, strideN);
00260 tmpA = tmpB = NULL;
00261 } else if(part == MULTARG_B) {
00262 commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, 0,
00263 (M_chunks - 1) * strideM, strideM, thisIndex.y, thisIndex.y, 1);
00264 tmpA = tmpB = NULL;
00265 } else if(part == MULTARG_C) {
00266 tmpA = new double[this->m * K];
00267 tmpB = new double[K * this->n];
00268 }
00269
00270
00271 contribute(0, NULL, CkReduction::sum_int, ready);
00272 }
00273
00274
00275 CLA_Matrix::CLA_Matrix(CProxy_CLA_MM3D_multiplier p, int M, int K, int N,
00276 int m, int k, int n, int strideM, int strideK, int strideN, int part,
00277 CkCallback cb, CkGroupID gid, int _gemmSplitOrtho){
00278
00279 this->M = M; this->K = K; this->N = N;
00280 this->um = m; this->uk = k; this->un = n;
00281 this->part = part;
00282 this->algorithm = MM_ALG_2D;
00283 this->other1 = other1; this->other2 = other2;
00284 this->M_stride = strideM;
00285 this->K_stride = strideK;
00286 this->N_stride = strideN;
00287 gemmSplitOrtho=_gemmSplitOrtho;
00288 M_chunks = (M + m - 1) / m;
00289 K_chunks = (K + k - 1) / k;
00290 N_chunks = (N + n - 1) / n;
00291 got_data = got_start = false;
00292 res_msg = NULL;
00293 algorithm = MM_ALG_3D;
00294 usesAtSync = CmiFalse;
00295 setMigratable(false);
00296
00297 if(part == MULTARG_A){
00298 if(thisIndex.x == (M_chunks - 1) * strideM){
00299 this->m = M % m;
00300 if(this->m == 0) this->m = m;
00301 }
00302 else this->m = m;
00303 if(thisIndex.y == (K_chunks - 1) * strideK){
00304 this->k = K % k;
00305 if(this->k == 0) this->k = k;
00306 }
00307 else this->k = k;
00308 this->n = n;
00309 } else if(part == MULTARG_B) {
00310 if(thisIndex.x == (K_chunks - 1) * strideK){
00311 this->k = K % k;
00312 if(this->k == 0) this->k = k;
00313 }
00314 else this->k = k;
00315 if(thisIndex.y == (N_chunks - 1) * strideN){
00316 this->n = N % n;
00317 if(this->n == 0) this->n = n;
00318 }
00319 else this->n = n;
00320 this->m = m;
00321 } else if(part == MULTARG_C) {
00322 if(thisIndex.x == (M_chunks - 1) * strideM){
00323 this->m = M % m;
00324 if(this->m == 0) this->m = m;
00325 }
00326 else this->m = m;
00327 if(thisIndex.y == (N_chunks - 1) * strideN){
00328 this->n = N % n;
00329 if(this->n == 0) this->n = n;
00330 }
00331 else this->n = n;
00332 this->k = k;
00333 }
00334
00335
00336 if(part == MULTARG_A){
00337 int x = thisIndex.x / strideM;
00338 int y = thisIndex.y / strideK;
00339 commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, 0,
00340 N_chunks - 1, 1, y, y, 1);
00341 contribute(0, NULL, CkReduction::sum_int, cb);
00342 } else if(part == MULTARG_B) {
00343 int x = thisIndex.x / strideK;
00344 int y = thisIndex.y / strideN;
00345 commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, 0,
00346 M_chunks - 1, 1, y, y, 1, x, x, 1);
00347 contribute(0, NULL, CkReduction::sum_int, cb);
00348 } else if(part == MULTARG_C) {
00349 init_cb = cb;
00350 int x = thisIndex.x / strideM;
00351 int y = thisIndex.y / strideN;
00352 commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, y, y, 1,
00353 0, K_chunks - 1, 1);
00354 commGroup3D.ckSectionDelegate(CProxy_CkMulticastMgr(gid).ckLocalBranch());
00355 CLA_MM3D_mult_init_msg *m = new CLA_MM3D_mult_init_msg(gid,
00356 CkCallback(CkIndex_CLA_Matrix::readyC(NULL),
00357 thisProxy(thisIndex.x, thisIndex.y)), CkCallback(
00358 CkIndex_CLA_Matrix::mult_done(NULL), thisProxy(thisIndex.x,
00359 thisIndex.y)));
00360 commGroup3D.initialize_reduction(m);
00361 }
00362 }
00363
00364 CLA_Matrix::~CLA_Matrix(){
00365 if(algorithm == MM_ALG_2D){
00366 delete [] tmpA;
00367 delete [] tmpB;
00368 }
00369 else if(algorithm == MM_ALG_3D){
00370 if(res_msg != NULL)
00371 delete res_msg;
00372 }
00373 }
00374
00375 void CLA_Matrix::pup(PUP::er &p){
00376
00377 if(algorithm == MM_ALG_3D){
00378 CmiAbort("3D algorithm does not currently support migration.\n");
00379 }
00380
00381
00382 CBase_CLA_Matrix::pup(p);
00383
00384
00385 p | M; p | K; p | N; p | m; p | k; p | n; p | um; p | uk; p | un;
00386 p | M_chunks; p | K_chunks; p | N_chunks;
00387 p | M_stride; p | K_stride; p | N_stride;
00388 p | part; p | algorithm;
00389 p | alpha; p | beta;
00390 p | gemmSplitOrtho;
00391
00392 if(algorithm == MM_ALG_2D){
00393 p | row_count; p | col_count;
00394 p | other1; p | other2;
00395 if(part == MULTARG_C){
00396 if(p.isUnpacking()){
00397 tmpA = new double[m * K];
00398 tmpB = new double[K * n];
00399 }
00400 PUParray(p, tmpA, m * K);
00401 PUParray(p, tmpB, K * n);
00402 }
00403 }
00404 else if(algorithm == MM_ALG_3D){
00405 p | init_cb;
00406 p | got_start; p | got_data;
00407 p | commGroup3D;
00408 }
00409 }
00410
00411 void CLA_Matrix::ResumeFromSync(void){
00412
00413 if(algorithm == MM_ALG_2D){
00414 if(part == MULTARG_A){
00415 commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, thisIndex.x,
00416 thisIndex.x, 1, 0, (N_chunks - 1) * N_stride, N_stride);
00417 tmpA = tmpB = NULL;
00418 } else if(part == MULTARG_B) {
00419 commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, 0,
00420 (M_chunks - 1) * M_stride, M_stride, thisIndex.y, thisIndex.y, 1);
00421 tmpA = tmpB = NULL;
00422 }
00423 } else if(algorithm == MM_ALG_3D){
00424 #if 0
00425 if(part == MULTARG_A){
00426 int x = thisIndex.x / M_stride;
00427 int y = thisIndex.y / K_stride;
00428 commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, 0,
00429 N_chunks - 1, 1, y, y, 1);
00430 contribute(0, NULL, CkReduction::sum_int, cb);
00431 } else if(part == MULTARG_B) {
00432 int x = thisIndex.x / K_stride;
00433 int y = thisIndex.y / N_stride;
00434 commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, 0,
00435 M_chunks - 1, 1, y, y, 1, x, x, 1);
00436 contribute(0, NULL, CkReduction::sum_int, cb);
00437 } else if(part == MULTARG_C) {
00438 init_cb = cb;
00439 int x = thisIndex.x / M_stride;
00440 int y = thisIndex.y / N_stride;
00441 commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, y, y,
00442 1, 0, K_chunks - 1, 1);
00443
00444
00445
00446
00447
00448
00449
00450
00451
00452 }
00453 #endif
00454 }
00455 }
00456
00457 void CLA_Matrix::multiply(double alpha, double beta, double *data,
00458 void (*fptr) (void*), void *usr_data){
00459 if(algorithm == MM_ALG_2D){
00460
00461 if(part == MULTARG_A){
00462 CLA_Matrix_msg *msg = new (m * k) CLA_Matrix_msg(data, m, k, thisIndex.x,
00463 thisIndex.y);
00464 commGroup2D.receiveA(msg);
00465 } else if(part == MULTARG_B){
00466 CLA_Matrix_msg *msg = new (k * n) CLA_Matrix_msg(data, k, n, thisIndex.x,
00467 thisIndex.y);
00468 commGroup2D.receiveB(msg);
00469 }
00470
00471 else if(part == MULTARG_C){
00472 fcb = fptr;
00473 user_data = usr_data;
00474 dest = data;
00475 this->alpha = alpha;
00476 this->beta = beta;
00477 got_start = true;
00478
00479 if(row_count == K_chunks && col_count == K_chunks)
00480 multiply();
00481 }
00482 else
00483 CmiAbort("CLA_Matrix internal error");
00484 } else if(algorithm == MM_ALG_3D){
00485 if(part == MULTARG_A){
00486 CLA_Matrix_msg *msg = new (m * k) CLA_Matrix_msg(data, m, k, thisIndex.x,
00487 thisIndex.y);
00488 commGroup3D.receiveA(msg);
00489 } else if(part == MULTARG_B){
00490 CLA_Matrix_msg *msg = new (k * n) CLA_Matrix_msg(data, k, n, thisIndex.x,
00491 thisIndex.y);
00492 commGroup3D.receiveB(msg);
00493 } else if(part == MULTARG_C){
00494 fcb = fptr;
00495 user_data = usr_data;
00496 dest = data;
00497 this->alpha = alpha;
00498 this->beta = beta;
00499 got_start = true;
00500 if(got_data){
00501 got_start = got_data = false;
00502
00503 double *data = (double*) res_msg->getData();
00504 transpose(data, n, m);
00505 for(int i = 0; i < m; i++)
00506 for(int j = 0; j < n; j++)
00507 dest[i * n + j] = beta * dest[i * n + j] + alpha * data[i * n + j];
00508 delete res_msg;
00509 res_msg = NULL;
00510 (*fcb)(user_data);
00511 }
00512 }
00513 }
00514 }
00515
00516 void CLA_Matrix::receiveA(CLA_Matrix_msg *msg){
00517
00518 row_count++;
00519 for(int i = 0; i < m; i++)
00520 CmiMemcpy(&tmpA[K * i + uk * (msg->fromY / K_stride)], &msg->data[i * msg->d2],
00521 msg->d2 * sizeof(double));
00522 delete msg;
00523
00524
00525 if(row_count == K_chunks && col_count == K_chunks && got_start)
00526 multiply();
00527 }
00528
00529 void CLA_Matrix::receiveB(CLA_Matrix_msg *msg){
00530
00531 col_count++;
00532 CmiMemcpy(&tmpB[n * uk * (msg->fromX / K_stride)], msg->data,
00533 msg->d1 * msg->d2 * sizeof(double));
00534 delete msg;
00535
00536
00537 if(row_count == K_chunks && col_count == K_chunks && got_start)
00538 multiply();
00539 }
00540
00541 void CLA_Matrix::multiply(){
00542
00543 row_count = col_count = 0;
00544 got_start = false;
00545
00546
00547 if(beta != 0)
00548 transpose(dest, m, n);
00549
00550 char trans = 'T';
00551 #define ORTHO_DGEMM_SPLIT
00552
00553 #define BUNDLE_USER_EVENTS
00554 #ifdef ORTHO_DGEMM_SPLIT
00555 double betap = 1.0;
00556 int Ksplit_m=gemmSplitOrtho;
00557 int Ksplit = (K > Ksplit_m) ? Ksplit_m : K;
00558 int Krem = (K % Ksplit);
00559 int Kloop = K/Ksplit-1;
00560
00561 #ifndef CMK_OPTIMIZE
00562 double StartTime=CmiWallTimer();
00563 #endif
00564
00565 #ifdef TEST_ALIGN
00566 CkAssert((unsigned int) tmpA %16==0);
00567 CkAssert((unsigned int) tmpB %16==0);
00568 CkAssert((unsigned int) dest %16==0);
00569 #endif
00570
00571 #ifdef PRINT_DGEMM_PARAMS
00572 CkPrintf("HEY-DGEMM %c %c %d %d %d %f %f %d %d %d\n", trans, trans, m, n, Ksplit, alpha, beta, K, n, m);
00573 #endif
00574
00575 #ifdef _NAN_CHECK_
00576 for(int in=0; in<Ksplit; in++)
00577 for(int jn=0; jn<m; jn++)
00578 CkAssert(finite(tmpA[in*m+jn]));
00579
00580 for(int in=0; in<n; in++)
00581 for(int jn=0; jn<Ksplit; jn++)
00582 CkAssert(finite(tmpB[in*K+jn]));
00583 #endif
00584
00585 DGEMM(&trans, &trans, &m, &n, &Ksplit, &alpha, tmpA, &K, tmpB, &n, &beta,
00586 dest,&m);
00587
00588 #ifdef _NAN_CHECK_
00589 for(int in=0; in<m; in++)
00590 for(int jn=0; jn<n; jn++)
00591 CkAssert(finite(dest[in*n+jn]));
00592 #endif
00593
00594 #ifndef BUNDLE_USER_EVENTS
00595 #ifndef CMK_OPTIMIZE
00596 traceUserBracketEvent(401, StartTime, CmiWallTimer());
00597 #endif
00598 #endif
00599
00600 CmiNetworkProgress();
00601
00602 for(int i=1;i<=Kloop;i++){
00603 int aoff = Ksplit*i;
00604 int boff = n*i*Ksplit;
00605 if(i==Kloop){Ksplit+=Krem;}
00606 #ifndef BUNDLE_USER_EVENTS
00607 #ifndef CMK_OPTIMIZE
00608 StartTime=CmiWallTimer();
00609 #endif
00610 #endif
00611
00612 #ifdef TEST_ALIGN
00613 CkAssert((unsigned int) &(tmpA[aoff]) %16==0);
00614 CkAssert((unsigned int) &(tmpB[boff]) %16==0);
00615 CkAssert((unsigned int) dest %16==0);
00616 #endif
00617
00618 #ifdef PRINT_DGEMM_PARAMS
00619 CkPrintf("HEY-DGEMM %c %c %d %d %d %f %f %d %d %d\n", trans, trans, m, n, Ksplit, alpha, betap, K, n, m);
00620 #endif
00621
00622 #ifdef _NAN_CHECK_
00623 for(int in=0; in<Ksplit; in++)
00624 for(int jn=0; jn<m; jn++)
00625 CkAssert(finite(tmpA[aoff+in*m+jn]));
00626
00627 for(int in=0; in<n; in++)
00628 for(int jn=0; jn<Ksplit; jn++)
00629 CkAssert(finite(tmpB[boff+in*Ksplit+jn]));
00630 #endif
00631
00632 DGEMM(&trans, &trans, &m, &n, &Ksplit, &alpha, &tmpA[aoff], &K,
00633 &tmpB[boff], &n, &betap, dest, &m);
00634
00635 #ifdef _NAN_CHECK_
00636 for(int in=0; in<m; in++)
00637 for(int jn=0; jn<n; jn++)
00638 CkAssert(finite(dest[in*n+jn]));
00639 #endif
00640
00641 #ifndef BUNDLE_USER_EVENTS
00642 #ifndef CMK_OPTIMIZE
00643 traceUserBracketEvent(401, StartTime, CmiWallTimer());
00644 #endif
00645 #endif
00646 CmiNetworkProgress();
00647 }
00648
00649 #ifdef BUNDLE_USER_EVENTS
00650 #ifndef CMK_OPTIMIZE
00651 traceUserBracketEvent(401, StartTime, CmiWallTimer());
00652 #endif
00653 #endif
00654
00655 #else
00656
00657 #ifndef CMK_OPTIMIZE
00658 double StartTime=CmiWallTimer();
00659 #endif
00660
00661 #ifdef PRINT_DGEMM_PARAMS
00662 CkPrintf("CLA_MATRIX DGEMM %c %c %d %d %d %f %f %d %d %d\n", trans, trans, m, n, K, alpha, beta, K, n, m);
00663 #endif
00664
00665 #ifdef _NAN_CHECK_
00666 for(int in=0; in<K; in++)
00667 for(int jn=0; jn<m; jn++)
00668 CkAssert(finite(tmpA[in*m+jn]));
00669
00670 for(int in=0; in<n; in++)
00671 for(int jn=0; jn<K; jn++)
00672 CkAssert(finite(tmpB[in*K+jn]));
00673 #endif
00674
00675 DGEMM(&trans, &trans, &m, &n, &K, &alpha, tmpA, &K, tmpB, &n, &beta,
00676 dest, &m);
00677
00678 #ifdef _NAN_CHECK_
00679 for(int in=0; in<m; in++)
00680 for(int jn=0; jn<n; jn++)
00681 CkAssert(finite(dest[in*n+jn]));
00682 #endif
00683
00684 #ifndef CMK_OPTIMIZE
00685 traceUserBracketEvent(401, StartTime, CmiWallTimer());
00686 #endif
00687
00688 #endif
00689
00690 transpose(dest, n, m);
00691
00692
00693 fcb(user_data);
00694 }
00695
00696 void CLA_Matrix::readyC(CkReductionMsg *msg){
00697 CkCallback cb(CkIndex_CLA_Matrix::ready(NULL), thisProxy(0, 0));
00698 contribute(0, NULL, CkReduction::sum_int, cb);
00699 delete msg;
00700 }
00701
00702 void CLA_Matrix::ready(CkReductionMsg *msg){
00703 init_cb.send();
00704 delete msg;
00705 }
00706
00707 void CLA_Matrix::mult_done(CkReductionMsg *msg){
00708 if(got_start){
00709 got_start = got_data = false;
00710
00711 double *data = (double*) msg->getData();
00712 transpose(data, n, m);
00713 for(int i = 0; i < m; i++)
00714 for(int j = 0; j < n; j++)
00715 dest[i * n + j] = beta * dest[i * n + j] + alpha * data[i * n + j];
00716 delete msg;
00717 msg = NULL;
00718 (*fcb)(user_data);
00719 }
00720 else{
00721 got_data = true;
00722 res_msg = msg;
00723 }
00724 }
00725
00726
00727
00728 CLA_Matrix_msg::CLA_Matrix_msg(double *data, int d1, int d2, int fromX,
00729 int fromY){
00730 CmiMemcpy(this->data, data, d1 * d2 * sizeof(double));
00731 this->d1 = d1; this->d2 = d2;
00732 this->fromX = fromX; this->fromY = fromY;
00733 }
00734
00735
00736
00737 CLA_MM3D_multiplier::CLA_MM3D_multiplier(int m, int k, int n){
00738 this->m = m; this->k = k; this->n = n;
00739 data_msg = NULL;
00740 gotA = gotB = false;
00741 }
00742
00743 void CLA_MM3D_multiplier::initialize_reduction(CLA_MM3D_mult_init_msg *m){
00744 reduce_CB = m->reduce;
00745 CkGetSectionInfo(sectionCookie, m);
00746 redGrp = CProxy_CkMulticastMgr(m->gid).ckLocalBranch();
00747 redGrp->contribute(0, NULL, CkReduction::sum_int, sectionCookie, m->ready);
00748 delete m;
00749 }
00750
00751 void CLA_MM3D_multiplier::receiveA(CLA_Matrix_msg *msg){
00752 gotA = true;
00753 if(gotB){
00754 multiply(msg->data, data_msg->data);
00755 delete msg;
00756 delete data_msg;
00757 }
00758 else
00759 data_msg = msg;
00760 }
00761
00762 void CLA_MM3D_multiplier::receiveB(CLA_Matrix_msg *msg){
00763 gotB = true;
00764 if(gotA){
00765 multiply(data_msg->data, msg->data);
00766 delete msg;
00767 delete data_msg;
00768 }
00769 else
00770 data_msg = msg;
00771 }
00772
00773 void CLA_MM3D_multiplier::multiply(double *A, double *B){
00774 double alpha = 1, beta = 0;
00775 gotA = gotB = false;
00776 char trans = 'T';
00777 double *C = new double[m * n];
00778 #ifndef CMK_OPTIMIZE
00779 double StartTime=CmiWallTimer();
00780 #endif
00781 #ifdef TEST_ALIGN
00782 CkAssert((unsigned int) A %16==0);
00783 CkAssert((unsigned int) B %16==0);
00784 CkAssert((unsigned int) C %16==0);
00785 #endif
00786
00787 #ifdef PRINT_DGEMM_PARAMS
00788 CkPrintf("HEY-DGEMM %c %c %d %d %d %f %f %d %d %d\n", trans, trans, m, n, k, alpha, beta, k, n, m);
00789 #endif
00790 DGEMM(&trans, &trans, &m, &n, &k, &alpha, A, &k, B, &n, &beta, C, &m);
00791 #ifndef CMK_OPTIMIZE
00792 traceUserBracketEvent(402, StartTime, CmiWallTimer());
00793 #endif
00794 CmiNetworkProgress();
00795
00796 redGrp->contribute(m * n * sizeof(double), C, sumFastDoubleType,
00797 sectionCookie, reduce_CB);
00798 delete [] C;
00799 }
00800
00801
00802 #include "CLA_Matrix.def.h"
00803