CLA_Matrix.C

Go to the documentation of this file.
00001 /*****************************************************************************
00002  * $Source: /cvsroot/leanCP/src_charm_driver/main/CLA_Matrix.C,v $
00003  * $Author: bhatele $
00004  * $Date: 2007/12/05 08:32:46 $
00005  * $Revision: 1.22 $
00006  *****************************************************************************/
00007 
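/** \file CLA_Matrix.C
 *  Implementation of CLA_Matrix: a distributed dense multiply
 *  C = alpha * A * B + beta * C built on chare arrays, using either a 2D
 *  decomposition or a 3D decomposition (CLA_MM3D_multiplier).
 */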
00011 /** \ingroup main
00012  *
00013  */
00014 //@{
00015 #include "CLA_Matrix.h"
00016 
00017 #if 0
00018 #include <sstream>
00019 using std::ostringstream;
00020 using std::endl;
00021 #endif
00022 
00023 #ifdef FORTRANUNDERSCORE
00024 #define DGEMM dgemm_
00025 #else
00026 #define DGEMM dgemm
00027 #endif
00028 //#define PRINT_DGEMM_PARAMS
00029 
00030 extern "C" {void  DGEMM(char *, char *, int *, int *, int *, double *,
00031                            double *, int *, double *, int *, double *,
00032                            double *, int *);}
00033 
00034 #define MULTARG_A       0
00035 #define MULTARG_B       1
00036 #define MULTARG_C       2
00037 extern CkReduction::reducerType sumFastDoubleType;
00038 /******************************************************************************/
00039 /* helper functions */
00040 
00041 /* Should be called by user to create matrices. Documented in header file. */
00042 int make_multiplier(CLA_Matrix_interface *A, CLA_Matrix_interface *B,
00043  CLA_Matrix_interface *C, CProxy_ArrayElement bindA,
00044  CProxy_ArrayElement bindB, CProxy_ArrayElement bindC,
00045                     int M, //nstates
00046                     int K, //nstates
00047                     int N, //nstates
00048                     int m, //orthograinsize
00049                     int k, //orthograinsize
00050                     int n, //orthograinsize
00051                     int strideM, // 1
00052                     int strideK, // 1
00053                     int strideN, // 1 
00054                     CkCallback cbA, CkCallback cbB,
00055  CkCallback cbC, CkGroupID gid, int algorithm, int gemmSplitOrtho){
00056   /* validate arguments */
00057   if(algorithm < MM_ALG_MIN || MM_ALG_MAX < algorithm)
00058     return ERR_INVALID_ALG;
00059 
00060   if(m > M || k > K || n > N)
00061     return ERR_INVALID_DIM;
00062   /* create arrays */
00063   CkArrayOptions optsA(0);
00064   CkArrayOptions optsB(0);
00065   CkArrayOptions optsC(0);
00066   
00067   optsA.bindTo(bindA);
00068   optsB.bindTo(bindB);
00069   optsC.bindTo(bindC);
00070   CProxy_CLA_Matrix pa = CProxy_CLA_Matrix::ckNew(optsA);
00071   CProxy_CLA_Matrix pb = CProxy_CLA_Matrix::ckNew(optsB);
00072   CProxy_CLA_Matrix pc = CProxy_CLA_Matrix::ckNew(optsC);
00073   A->setProxy(pa);
00074   B->setProxy(pb);
00075   C->setProxy(pc);
00076   
00077   /* populate arrays */
00078   int M_chunks = (M + m - 1) / m; // same as ceil(1.0 * M / m)
00079   int K_chunks = (K + k - 1) / k; // same as ceil(1.0 * K / k)
00080   int N_chunks = (N + n - 1) / n; // same as ceil(1.0 * N / n)
00081   if(M%m!=0)
00082     M_chunks--;
00083   if(K%k!=0)
00084     K_chunks--;
00085   if(N%n!=0)
00086     N_chunks--;
00087   //  correct for number of chunks
00088   
00089   // just the size of the border elements
00090   if(algorithm == MM_ALG_2D){
00091     for(int i = 0; i < M_chunks; i++)
00092       for(int j = 0; j < K_chunks; j++)
00093         (A->p(i * strideM, j * strideK)).insert(M, K, N, m, k, n, strideM,
00094          strideK, strideN, MULTARG_A, B->p, C->p, cbA, gemmSplitOrtho);
00095     A->p.doneInserting();
00096 
00097     for(int i = 0; i < K_chunks; i++)
00098       for(int j = 0; j < N_chunks; j++)
00099         (B->p(i * strideK, j * strideN)).insert(M, K, N, m, k, n, strideM,
00100          strideK, strideN, MULTARG_B, A->p, C->p, cbB, gemmSplitOrtho);
00101     B->p.doneInserting();
00102 
00103     for(int i = 0; i < M_chunks; i++)
00104       for(int j = 0; j < N_chunks; j++)
00105         (C->p(i * strideM, j * strideN)).insert(M, K, N, m, k, n, strideM,
00106          strideK, strideN, MULTARG_C, A->p, B->p, cbC, gemmSplitOrtho);
00107     C->p.doneInserting();
00108   }
00109   else if(algorithm == MM_ALG_3D){
00110     CProxy_CLA_MM3D_multiplier mult = CProxy_CLA_MM3D_multiplier::ckNew();
00111     int curpe = 0;
00112     int totpe = CkNumPes();
00113     for(int i = 0; i < M_chunks; i++){
00114       int mm = m;
00115       if(i == M_chunks - 1) {
00116         mm = M % m;
00117         if(mm == 0)
00118           mm = m;
00119       }
00120       for(int j = 0; j < N_chunks; j++){
00121         int nn = n;
00122         if(j == N_chunks - 1) {
00123           nn = N % n;
00124           if(nn == 0)
00125             nn = n;
00126         }
00127         for(int l = 0; l < K_chunks; l++){
00128           int kk = k;
00129           if(l == K_chunks - 1) {
00130             kk = K % k;
00131             if(kk == 0)
00132               kk = k;
00133           }
00134           mult(i, j, l).insert(mm, kk, nn, curpe);
00135           curpe = (curpe + 1) % totpe;
00136         }
00137       }
00138     }
00139     mult.doneInserting();
00140 
00141     for(int i = 0; i < M_chunks; i++)
00142       for(int j = 0; j < K_chunks; j++)
00143         (A->p(i * strideM, j * strideK)).insert(mult, M, K, N, m, k, n,
00144          strideM, strideK, strideN, MULTARG_A, cbA, gid, gemmSplitOrtho);
00145     A->p.doneInserting();
00146 
00147     for(int i = 0; i < K_chunks; i++)
00148       for(int j = 0; j < N_chunks; j++)
00149         (B->p(i * strideK, j * strideN)).insert(mult, M, K, N, m, k, n,
00150          strideM, strideK, strideN, MULTARG_B, cbB, gid, gemmSplitOrtho);
00151     B->p.doneInserting();
00152 
00153     for(int i = 0; i < M_chunks; i++)
00154       for(int j = 0; j < N_chunks; j++)
00155         (C->p(i * strideM, j * strideN)).insert(mult, M, K, N, m, k, n,
00156          strideM, strideK, strideN, MULTARG_C, cbC, gid, gemmSplitOrtho);
00157     C->p.doneInserting();
00158   }
00159 
00160   return SUCCESS;
00161 }
00162 
/* Transpose data, which has dimension m x n.
 *
 * The buffer is read as a row-major m x n matrix and overwritten with its
 * row-major n x m transpose. Square input is swapped in place; otherwise the
 * input is first copied into a temporary buffer.
 */
void transpose(double *data, int m, int n){
  if(m != n){
    const int total = m * n;
    double *scratch = new double[total];
    for(int idx = 0; idx < total; idx++)
      scratch[idx] = data[idx];
    for(int row = 0; row < m; row++){
      const double *src = scratch + row * n;
      for(int col = 0; col < n; col++)
        data[col * m + row] = src[col];
    }
    delete [] scratch;
    return;
  }

  /* square case: swap every element above the diagonal with its mirror */
  for(int row = 0; row < m; row++){
    for(int col = row + 1; col < m; col++){
      const double upper = data[row * m + col];
      data[row * m + col] = data[col * m + row];
      data[col * m + row] = upper;
    }
  }
}
00183 
00184 /******************************************************************************/
00185 /* CLA_Matrix */
00186 
00187 /* constructor for 2D algorithm */
00188 CLA_Matrix::CLA_Matrix(int M, int K, int N, int m, int k, int n,
00189  int strideM, int strideK, int strideN, int part,
00190  CProxy_CLA_Matrix other1, CProxy_CLA_Matrix other2, CkCallback ready, int _gemmSplitOrtho){
00191   /* initialize simple members */
00192   this->M = M; this->K = K; this->N = N;
00193   this->um = m; this->uk = k; this->un = n;
00194   this->part = part;
00195   this->algorithm = MM_ALG_2D;
00196   this->other1 = other1; this->other2 = other2;
00197   this->M_stride = strideM;
00198   this->K_stride = strideK;
00199   this->N_stride = strideN;
00200   gemmSplitOrtho=_gemmSplitOrtho;
00201   M_chunks = (M + m - 1) / m;
00202   K_chunks = (K + k - 1) / k;
00203   N_chunks = (N + n - 1) / n;
00204   if(M%m!=0)
00205     M_chunks--;
00206   if(K%k!=0)
00207     K_chunks--;
00208   if(N%n!=0)
00209     N_chunks--;
00210   //  correct for number of chunks
00211 
00212   algorithm = MM_ALG_2D;
00213   usesAtSync = CmiFalse;
00214   setMigratable(false);
00215   /* figure out size of our sections */
00216   if(part == MULTARG_A){
00217     if(thisIndex.x == (M_chunks - 1) * strideM){
00218       this->m = m + M % m;
00219       if(this->m == 0) this->m = m;
00220     }
00221     else this->m = m;
00222     if(thisIndex.y == (K_chunks - 1) * strideK){
00223       this->k = k + K % k;
00224       if(this->k == 0) this->k = k;
00225     }
00226     else this->k = k;
00227     this->n = n;
00228   } else if(part == MULTARG_B) {
00229     if(thisIndex.x == (K_chunks - 1) * strideK){
00230       this->k = k + K % k;
00231       if(this->k == 0) this->k = k;
00232     }
00233     else this->k = k;
00234     if(thisIndex.y == (N_chunks - 1) * strideN){
00235       this->n = n + N % n;
00236       if(this->n == 0) this->n = n;
00237     }
00238     else this->n = n;
00239     this->m = m;
00240   } else if(part == MULTARG_C) {
00241     if(thisIndex.x == (M_chunks - 1) * strideM){
00242       this->m = m + M % m;
00243       if(this->m == 0) this->m = m;
00244     }
00245     else this->m = m;
00246     if(thisIndex.y == (N_chunks - 1) * strideN){
00247       this->n = n + N % n;
00248       if(this->n == 0) this->n = n;
00249     }
00250     else this->n = n;
00251     this->k = k;
00252     got_start = false;
00253     row_count = col_count = 0;
00254   }
00255 
00256   /* make communication group for A, B, destination arrays for C */
00257   if(part == MULTARG_A){
00258     commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, thisIndex.x,
00259      thisIndex.x, 1, 0, (N_chunks - 1) * strideN, strideN);
00260     tmpA = tmpB = NULL;
00261   } else if(part == MULTARG_B) {
00262     commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, 0,
00263      (M_chunks - 1) * strideM, strideM, thisIndex.y, thisIndex.y, 1);
00264     tmpA = tmpB = NULL;
00265   } else if(part == MULTARG_C) {
00266     tmpA = new double[this->m * K];
00267     tmpB = new double[K * this->n];
00268   }
00269 
00270   /* let the caller know we are ready */
00271   contribute(0, NULL, CkReduction::sum_int, ready);
00272 }
00273 
00274 /* constructor for 3D algorithm */
00275 CLA_Matrix::CLA_Matrix(CProxy_CLA_MM3D_multiplier p, int M, int K, int N,
00276  int m, int k, int n, int strideM, int strideK, int strideN, int part,
00277  CkCallback cb, CkGroupID gid, int _gemmSplitOrtho){
00278   /* set up easy variable */
00279   this->M = M; this->K = K; this->N = N;
00280   this->um = m; this->uk = k; this->un = n;
00281   this->part = part;
00282   this->algorithm = MM_ALG_2D;
00283   this->other1 = other1; this->other2 = other2;
00284   this->M_stride = strideM;
00285   this->K_stride = strideK;
00286   this->N_stride = strideN;
00287   gemmSplitOrtho=_gemmSplitOrtho;
00288   M_chunks = (M + m - 1) / m;
00289   K_chunks = (K + k - 1) / k;
00290   N_chunks = (N + n - 1) / n;
00291   got_data = got_start = false;
00292   res_msg = NULL;
00293   algorithm = MM_ALG_3D;
00294   usesAtSync = CmiFalse;
00295   setMigratable(false);
00296   /* figure out size of our sections */
00297   if(part == MULTARG_A){
00298     if(thisIndex.x == (M_chunks - 1) * strideM){
00299       this->m = M % m;
00300       if(this->m == 0) this->m = m;
00301     }
00302     else this->m = m;
00303     if(thisIndex.y == (K_chunks - 1) * strideK){
00304       this->k = K % k;
00305       if(this->k == 0) this->k = k;
00306     }
00307     else this->k = k;
00308     this->n = n;
00309   } else if(part == MULTARG_B) {
00310     if(thisIndex.x == (K_chunks - 1) * strideK){
00311       this->k = K % k;
00312       if(this->k == 0) this->k = k;
00313     }
00314     else this->k = k;
00315     if(thisIndex.y == (N_chunks - 1) * strideN){
00316       this->n = N % n;
00317       if(this->n == 0) this->n = n;
00318     }
00319     else this->n = n;
00320     this->m = m;
00321   } else if(part == MULTARG_C) {
00322     if(thisIndex.x == (M_chunks - 1) * strideM){
00323       this->m = M % m;
00324       if(this->m == 0) this->m = m;
00325     }
00326     else this->m = m;
00327     if(thisIndex.y == (N_chunks - 1) * strideN){
00328       this->n = N % n;
00329       if(this->n == 0) this->n = n;
00330     }
00331     else this->n = n;
00332     this->k = k;
00333   }
00334 
00335   /* make communication groups, C also has to initialize reduction sections */
00336   if(part == MULTARG_A){
00337     int x = thisIndex.x / strideM;
00338     int y = thisIndex.y / strideK;
00339     commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, 0,
00340      N_chunks - 1, 1, y, y, 1);
00341     contribute(0, NULL, CkReduction::sum_int, cb);
00342   } else if(part == MULTARG_B) {
00343     int x = thisIndex.x / strideK;
00344     int y = thisIndex.y / strideN;
00345     commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, 0,
00346      M_chunks - 1, 1, y, y, 1, x, x, 1);
00347     contribute(0, NULL, CkReduction::sum_int, cb);
00348   } else if(part == MULTARG_C) {
00349     init_cb = cb;
00350     int x = thisIndex.x / strideM;
00351     int y = thisIndex.y / strideN;
00352     commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, y, y, 1,
00353      0, K_chunks - 1, 1);
00354     commGroup3D.ckSectionDelegate(CProxy_CkMulticastMgr(gid).ckLocalBranch());
00355     CLA_MM3D_mult_init_msg *m = new CLA_MM3D_mult_init_msg(gid,
00356      CkCallback(CkIndex_CLA_Matrix::readyC(NULL),
00357      thisProxy(thisIndex.x, thisIndex.y)), CkCallback(
00358      CkIndex_CLA_Matrix::mult_done(NULL), thisProxy(thisIndex.x,
00359      thisIndex.y)));
00360     commGroup3D.initialize_reduction(m);
00361   }
00362 }
00363 
00364 CLA_Matrix::~CLA_Matrix(){
00365   if(algorithm == MM_ALG_2D){
00366     delete [] tmpA;
00367     delete [] tmpB;
00368   }
00369   else if(algorithm == MM_ALG_3D){
00370     if(res_msg != NULL)
00371       delete res_msg;
00372   }
00373 }
00374 
00375 void CLA_Matrix::pup(PUP::er &p){
00376   /* make sure we are not using the 3D algorithm */
00377   if(algorithm == MM_ALG_3D){
00378     CmiAbort("3D algorithm does not currently support migration.\n");
00379   }
00380 
00381   /* pup super class */
00382   CBase_CLA_Matrix::pup(p);
00383 
00384   /* pup shared vars */
00385   p | M; p | K; p | N; p | m; p | k; p | n;  p | um; p | uk; p | un;
00386   p | M_chunks; p | K_chunks; p | N_chunks;
00387   p | M_stride; p | K_stride; p | N_stride;
00388   p | part; p | algorithm;
00389   p | alpha; p | beta;
00390   p | gemmSplitOrtho;
00391   /* pup vars used by each algorithm */
00392   if(algorithm == MM_ALG_2D){
00393     p | row_count; p | col_count;
00394     p | other1; p | other2;
00395     if(part == MULTARG_C){
00396       if(p.isUnpacking()){
00397         tmpA = new double[m * K];
00398         tmpB = new double[K * n];
00399       }
00400       PUParray(p, tmpA, m * K);
00401       PUParray(p, tmpB, K * n);
00402     }
00403   }
00404   else if(algorithm == MM_ALG_3D){
00405     p | init_cb;
00406     p | got_start; p | got_data;
00407     p | commGroup3D;
00408   }
00409 }
00410 
00411 void CLA_Matrix::ResumeFromSync(void){
00412   /* recreate the section proxies */
00413   if(algorithm == MM_ALG_2D){
00414     if(part == MULTARG_A){
00415       commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, thisIndex.x,
00416        thisIndex.x, 1, 0, (N_chunks - 1) * N_stride, N_stride);
00417       tmpA = tmpB = NULL;
00418     } else if(part == MULTARG_B) {
00419       commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, 0,
00420        (M_chunks - 1) * M_stride, M_stride, thisIndex.y, thisIndex.y, 1);
00421       tmpA = tmpB = NULL;
00422     }
00423   } else if(algorithm == MM_ALG_3D){
00424 #if 0
00425     if(part == MULTARG_A){
00426       int x = thisIndex.x / M_stride;
00427       int y = thisIndex.y / K_stride;
00428       commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, 0,
00429        N_chunks - 1, 1, y, y, 1);
00430       contribute(0, NULL, CkReduction::sum_int, cb);
00431     } else if(part == MULTARG_B) {
00432       int x = thisIndex.x / K_stride;
00433       int y = thisIndex.y / N_stride;
00434       commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, 0,
00435        M_chunks - 1, 1, y, y, 1, x, x, 1);
00436       contribute(0, NULL, CkReduction::sum_int, cb);
00437     } else if(part == MULTARG_C) {
00438       init_cb = cb;
00439       int x = thisIndex.x / M_stride;
00440       int y = thisIndex.y / N_stride;
00441       commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, y, y,
00442        1, 0, K_chunks - 1, 1);
00443 /*
00444       commGroup3D.ckSectionDelegate(CProxy_CkMulticastMgr(gid).ckLocalBranch());
00445       CLA_MM3D_mult_init_msg *m = new CLA_MM3D_mult_init_msg(gid,
00446        CkCallback(CkIndex_CLA_Matrix::readyC(NULL),
00447        thisProxy(thisIndex.x, thisIndex.y)), CkCallback(
00448        CkIndex_CLA_Matrix::mult_done(NULL), thisProxy(thisIndex.x,
00449        thisIndex.y)));
00450       commGroup3D.initialize_reduction(m);
00451 */
00452     }
00453 #endif
00454   }
00455 }
00456 
00457 void CLA_Matrix::multiply(double alpha, double beta, double *data,
00458  void (*fptr) (void*), void *usr_data){
00459   if(algorithm == MM_ALG_2D){
00460     // A and B send out their chunks, ignoring alpha, beta, ftpr, and usr_data
00461     if(part == MULTARG_A){
00462       CLA_Matrix_msg *msg = new (m * k) CLA_Matrix_msg(data, m, k, thisIndex.x,
00463        thisIndex.y);
00464       commGroup2D.receiveA(msg);
00465     } else if(part == MULTARG_B){
00466       CLA_Matrix_msg *msg = new (k * n) CLA_Matrix_msg(data, k, n, thisIndex.x,
00467        thisIndex.y);
00468       commGroup2D.receiveB(msg);
00469     }
00470     /* C stores the paramters for the multiplication */
00471     else if(part == MULTARG_C){
00472       fcb = fptr;
00473       user_data = usr_data;
00474       dest = data;
00475       this->alpha = alpha;
00476       this->beta = beta;
00477       got_start = true;
00478       /* Check if we were slow to arrive */
00479       if(row_count == K_chunks && col_count == K_chunks)
00480         multiply();
00481     }
00482     else
00483       CmiAbort("CLA_Matrix internal error");
00484   } else if(algorithm == MM_ALG_3D){
00485     if(part == MULTARG_A){
00486       CLA_Matrix_msg *msg = new (m * k) CLA_Matrix_msg(data, m, k, thisIndex.x,
00487        thisIndex.y);
00488       commGroup3D.receiveA(msg);
00489     } else if(part == MULTARG_B){
00490       CLA_Matrix_msg *msg = new (k * n) CLA_Matrix_msg(data, k, n, thisIndex.x,
00491        thisIndex.y);
00492       commGroup3D.receiveB(msg);
00493     } else if(part == MULTARG_C){
00494       fcb = fptr;
00495       user_data = usr_data;
00496       dest = data;
00497       this->alpha = alpha;
00498       this->beta = beta;
00499       got_start = true;
00500       if(got_data){
00501         got_start = got_data = false;
00502         /* transpose reduction msg and do the alpha and beta multiplications */
00503         double *data = (double*) res_msg->getData();
00504         transpose(data, n, m);
00505         for(int i = 0; i < m; i++)
00506           for(int j = 0; j < n; j++)
00507             dest[i * n + j] = beta * dest[i * n + j] + alpha * data[i * n + j];
00508         delete res_msg;
00509         res_msg = NULL;
00510         (*fcb)(user_data);
00511       }
00512     }
00513   }
00514 }
00515 
00516 void CLA_Matrix::receiveA(CLA_Matrix_msg *msg){
00517   /* store current part */
00518   row_count++;
00519   for(int i = 0; i < m; i++)
00520     CmiMemcpy(&tmpA[K * i + uk * (msg->fromY / K_stride)], &msg->data[i * msg->d2],
00521      msg->d2 * sizeof(double));
00522   delete msg;
00523 
00524   /* If we have all the parts, multiply */
00525   if(row_count == K_chunks && col_count == K_chunks && got_start)
00526     multiply();
00527 }
00528 
00529 void CLA_Matrix::receiveB(CLA_Matrix_msg *msg){
00530   /* store current part */
00531   col_count++;
00532   CmiMemcpy(&tmpB[n * uk * (msg->fromX / K_stride)], msg->data,
00533    msg->d1 * msg->d2 * sizeof(double));
00534   delete msg;
00535 
00536   /* If we have all the parts, multiply */
00537   if(row_count == K_chunks && col_count == K_chunks && got_start)
00538     multiply();
00539 }
00540 
00541 void CLA_Matrix::multiply(){
00542   /* reset counters */
00543   row_count = col_count = 0;
00544   got_start = false;
00545 
00546   /* transpose result matrix (if beta != 0) */
00547   if(beta != 0)
00548     transpose(dest, m, n);
00549   /* multiply */
00550   char trans = 'T';
00551 #define ORTHO_DGEMM_SPLIT 
00552 
00553 #define BUNDLE_USER_EVENTS
00554 #ifdef ORTHO_DGEMM_SPLIT 
00555   double betap = 1.0;
00556   int Ksplit_m=gemmSplitOrtho;  
00557   int Ksplit   = (K > Ksplit_m) ? Ksplit_m : K;
00558   int Krem     = (K % Ksplit);
00559   int Kloop    = K/Ksplit-1;
00560 
00561 #ifndef CMK_OPTIMIZE
00562   double StartTime=CmiWallTimer();
00563 #endif
00564 
00565 #ifdef TEST_ALIGN
00566   CkAssert((unsigned int) tmpA %16==0);
00567   CkAssert((unsigned int) tmpB %16==0);
00568   CkAssert((unsigned int) dest %16==0);
00569 #endif
00570 
00571 #ifdef PRINT_DGEMM_PARAMS
00572   CkPrintf("HEY-DGEMM %c %c %d %d %d %f %f %d %d %d\n", trans, trans, m, n, Ksplit, alpha, beta, K, n, m);
00573 #endif
00574 
00575 #ifdef _NAN_CHECK_
00576   for(int in=0; in<Ksplit; in++)
00577     for(int jn=0; jn<m; jn++)
00578       CkAssert(finite(tmpA[in*m+jn]));
00579 
00580   for(int in=0; in<n; in++)
00581     for(int jn=0; jn<Ksplit; jn++)
00582       CkAssert(finite(tmpB[in*K+jn]));
00583 #endif
00584 
00585   DGEMM(&trans, &trans, &m, &n, &Ksplit, &alpha, tmpA, &K, tmpB, &n, &beta,
00586         dest,&m);
00587 
00588 #ifdef _NAN_CHECK_
00589   for(int in=0; in<m; in++)
00590     for(int jn=0; jn<n; jn++)
00591       CkAssert(finite(dest[in*n+jn]));
00592 #endif
00593 
00594 #ifndef BUNDLE_USER_EVENTS
00595 #ifndef CMK_OPTIMIZE
00596   traceUserBracketEvent(401, StartTime, CmiWallTimer());
00597 #endif
00598 #endif
00599 
00600   CmiNetworkProgress();
00601   
00602   for(int i=1;i<=Kloop;i++){
00603     int aoff = Ksplit*i;
00604     int boff = n*i*Ksplit;
00605     if(i==Kloop){Ksplit+=Krem;}
00606 #ifndef BUNDLE_USER_EVENTS
00607 #ifndef CMK_OPTIMIZE
00608     StartTime=CmiWallTimer();
00609 #endif
00610 #endif
00611 
00612 #ifdef TEST_ALIGN
00613     CkAssert((unsigned int) &(tmpA[aoff]) %16==0);
00614     CkAssert((unsigned int) &(tmpB[boff]) %16==0);
00615     CkAssert((unsigned int) dest %16==0);
00616 #endif
00617 
00618 #ifdef PRINT_DGEMM_PARAMS
00619     CkPrintf("HEY-DGEMM %c %c %d %d %d %f %f %d %d %d\n", trans, trans, m, n, Ksplit, alpha, betap, K, n, m);
00620 #endif
00621 
00622 #ifdef _NAN_CHECK_
00623     for(int in=0; in<Ksplit; in++)
00624       for(int jn=0; jn<m; jn++)
00625         CkAssert(finite(tmpA[aoff+in*m+jn]));
00626 
00627     for(int in=0; in<n; in++)
00628       for(int jn=0; jn<Ksplit; jn++)
00629         CkAssert(finite(tmpB[boff+in*Ksplit+jn]));
00630 #endif
00631 
00632     DGEMM(&trans, &trans, &m, &n, &Ksplit, &alpha, &tmpA[aoff], &K, 
00633           &tmpB[boff], &n, &betap, dest, &m);
00634 
00635 #ifdef _NAN_CHECK_
00636     for(int in=0; in<m; in++)
00637       for(int jn=0; jn<n; jn++)
00638         CkAssert(finite(dest[in*n+jn]));
00639 #endif
00640 
00641 #ifndef BUNDLE_USER_EVENTS
00642 #ifndef CMK_OPTIMIZE
00643     traceUserBracketEvent(401, StartTime, CmiWallTimer());
00644 #endif
00645 #endif
00646     CmiNetworkProgress();
00647   }//endfor
00648 
00649 #ifdef BUNDLE_USER_EVENTS
00650 #ifndef CMK_OPTIMIZE
00651     traceUserBracketEvent(401, StartTime, CmiWallTimer());
00652 #endif
00653 #endif
00654 
00655 #else
00656   /* old unsplit version */
00657 #ifndef CMK_OPTIMIZE
00658     double StartTime=CmiWallTimer();
00659 #endif
00660 
00661 #ifdef PRINT_DGEMM_PARAMS
00662   CkPrintf("CLA_MATRIX DGEMM %c %c %d %d %d %f %f %d %d %d\n", trans, trans, m, n, K, alpha, beta, K, n, m);
00663 #endif
00664 
00665 #ifdef _NAN_CHECK_
00666   for(int in=0; in<K; in++)
00667     for(int jn=0; jn<m; jn++)
00668       CkAssert(finite(tmpA[in*m+jn]));
00669 
00670   for(int in=0; in<n; in++)
00671     for(int jn=0; jn<K; jn++)
00672       CkAssert(finite(tmpB[in*K+jn]));
00673 #endif
00674 
00675      DGEMM(&trans, &trans, &m, &n, &K, &alpha, tmpA, &K, tmpB, &n, &beta,
00676    dest, &m);
00677 
00678 #ifdef _NAN_CHECK_
00679   for(int in=0; in<m; in++)
00680     for(int jn=0; jn<n; jn++)
00681       CkAssert(finite(dest[in*n+jn]));
00682 #endif
00683 
00684 #ifndef CMK_OPTIMIZE
00685     traceUserBracketEvent(401, StartTime, CmiWallTimer());
00686 #endif
00687 
00688 #endif
00689   /* transpose result */
00690   transpose(dest, n, m);
00691 
00692   /* tell caller we are done */
00693   fcb(user_data);
00694 }
00695 
00696 void CLA_Matrix::readyC(CkReductionMsg *msg){
00697   CkCallback cb(CkIndex_CLA_Matrix::ready(NULL), thisProxy(0, 0));
00698   contribute(0, NULL, CkReduction::sum_int, cb);
00699   delete msg;
00700 }
00701 
00702 void CLA_Matrix::ready(CkReductionMsg *msg){
00703   init_cb.send();
00704   delete msg;
00705 }
00706 
00707 void CLA_Matrix::mult_done(CkReductionMsg *msg){
00708   if(got_start){
00709     got_start = got_data = false;
00710     /* transpose reduction msg and do the alpha and beta multiplications */
00711     double *data = (double*) msg->getData();
00712     transpose(data, n, m);
00713     for(int i = 0; i < m; i++)
00714       for(int j = 0; j < n; j++)
00715         dest[i * n + j] = beta * dest[i * n + j] + alpha * data[i * n + j];
00716     delete msg;
00717     msg = NULL;
00718     (*fcb)(user_data);
00719   }
00720   else{
00721     got_data = true;
00722     res_msg = msg;
00723   }
00724 }
00725 
00726 /******************************************************************************/
00727 /* CLA_Matrix_msg */
00728 CLA_Matrix_msg::CLA_Matrix_msg(double *data, int d1, int d2, int fromX,
00729  int fromY){
00730   CmiMemcpy(this->data, data, d1 * d2 * sizeof(double));
00731   this->d1 = d1; this->d2 = d2;
00732   this->fromX = fromX; this->fromY = fromY;
00733 }
00734 
00735 /******************************************************************************/
00736 /* CLA_MM3D_multiplier */
00737 CLA_MM3D_multiplier::CLA_MM3D_multiplier(int m, int k, int n){
00738   this->m = m; this->k = k; this->n = n;
00739   data_msg = NULL;
00740   gotA = gotB = false;
00741 }
00742 
00743 void CLA_MM3D_multiplier::initialize_reduction(CLA_MM3D_mult_init_msg *m){
00744   reduce_CB = m->reduce;
00745   CkGetSectionInfo(sectionCookie, m);
00746   redGrp = CProxy_CkMulticastMgr(m->gid).ckLocalBranch();
00747   redGrp->contribute(0, NULL, CkReduction::sum_int, sectionCookie, m->ready);
00748   delete m;
00749 }
00750 
00751 void CLA_MM3D_multiplier::receiveA(CLA_Matrix_msg *msg){
00752   gotA = true;
00753   if(gotB){
00754     multiply(msg->data, data_msg->data);
00755     delete msg;
00756     delete data_msg;
00757   }
00758   else
00759     data_msg = msg;
00760 }
00761 
00762 void CLA_MM3D_multiplier::receiveB(CLA_Matrix_msg *msg){
00763   gotB = true;
00764   if(gotA){
00765     multiply(data_msg->data, msg->data);
00766     delete msg;
00767     delete data_msg;
00768   }
00769   else
00770     data_msg = msg;
00771 }
00772 
00773 void CLA_MM3D_multiplier::multiply(double *A, double *B){
00774   double alpha = 1, beta = 0;
00775   gotA = gotB = false;
00776   char trans = 'T';
00777   double *C = new double[m * n];
00778 #ifndef CMK_OPTIMIZE
00779   double  StartTime=CmiWallTimer();
00780 #endif
00781 #ifdef TEST_ALIGN
00782   CkAssert((unsigned int) A %16==0);
00783   CkAssert((unsigned int) B %16==0);
00784   CkAssert((unsigned int) C %16==0);
00785 #endif
00786 
00787 #ifdef PRINT_DGEMM_PARAMS
00788   CkPrintf("HEY-DGEMM %c %c %d %d %d %f %f %d %d %d\n", trans, trans, m, n, k, alpha, beta, k, n, m);
00789 #endif
00790   DGEMM(&trans, &trans, &m, &n, &k, &alpha, A, &k, B, &n, &beta, C, &m);
00791 #ifndef CMK_OPTIMIZE
00792     traceUserBracketEvent(402, StartTime, CmiWallTimer());
00793 #endif
00794   CmiNetworkProgress();
00795   //    redGrp->contribute(m * n * sizeof(double), C, CkReduction::sum_double,
00796   redGrp->contribute(m * n * sizeof(double), C, sumFastDoubleType,
00797    sectionCookie, reduce_CB);
00798   delete [] C;
00799 }
00800 
00801 //@}
00802 #include "CLA_Matrix.def.h"
00803 

Generated on Thu Dec 6 18:25:28 2007 for leanCP by  doxygen 1.5.3