OpenAtom Version 1.5a
CLA_Matrix.C
#include "CLA_Matrix.h"

#if 0
#include <sstream>
using std::ostringstream;
using std::endl;
#endif

#ifdef FORTRANUNDERSCORE
#define DGEMM dgemm_
#else
#define DGEMM dgemm
#endif

extern "C" {void DGEMM(char *, char *, int *, int *, int *, double *,
                       double *, int *, double *, int *, double *,
                       double *, int *);}

void myGEMM(char *opA, char *opB, int *m, int *n, int *k, double *alpha, complex *A, int *lda, complex *B, int *ldb, double *beta, complex *C, int *ldc);
void myGEMM(char *opA, char *opB, int *m, int *n, int *k, double *alpha, double *A, int *lda, double *B, int *ldb, double *beta, double *C, int *ldc);
#define MULTARG_A 0
#define MULTARG_B 1
#define MULTARG_C 2
extern CkReduction::reducerType sumFastDoubleType;
#include "load_balance/MapTable.h"
#include "orthog_ctrl/ortho.h"
/** @addtogroup Ortho
  @{
*/

/******************************************************************************/
/* helper functions */
/* Should be called by user to create matrices. Documented in header file. */
int make_multiplier(
    CLA_Matrix_interface *A, CLA_Matrix_interface *B, CLA_Matrix_interface *C,
    CProxy_ArrayElement bindA, CProxy_ArrayElement bindB, CProxy_ArrayElement bindC,
    int M, // nstates
    int K, // nstates
    int N, // nstates
    int m, // orthograinsize
    int k, // orthograinsize
    int n, // orthograinsize
    int strideM, // 1
    int strideK, // 1
    int strideN, // 1
    CkCallback cbA, CkCallback cbB, CkCallback cbC,
    CkGroupID gid, int algorithm, int gemmSplitOrtho
    )
{
  /* validate arguments */
  if(algorithm < MM_ALG_MIN || MM_ALG_MAX < algorithm)
    return ERR_INVALID_ALG;

  if(m > M || k > K || n > N)
    return ERR_INVALID_DIM;

  /* create arrays */
  CkArrayOptions optsA, optsB, optsC;
  optsA.bindTo(bindA);
  optsB.bindTo(bindB);
  optsC.bindTo(bindC);
  optsA.setAnytimeMigration(false);
  optsB.setAnytimeMigration(false);
  optsC.setAnytimeMigration(false);
  CProxy_CLA_Matrix pa = CProxy_CLA_Matrix::ckNew(optsA);
  CProxy_CLA_Matrix pb = CProxy_CLA_Matrix::ckNew(optsB);
  CProxy_CLA_Matrix pc = CProxy_CLA_Matrix::ckNew(optsC);
  A->setProxy(pa);
  B->setProxy(pb);
  C->setProxy(pc);

  /* populate arrays */
  int M_chunks = (M + m - 1) / m; // same as ceil(1.0 * M / m)
  int K_chunks = (K + k - 1) / k; // same as ceil(1.0 * K / k)
  int N_chunks = (N + n - 1) / n; // same as ceil(1.0 * N / n)
  if(M % m != 0)
    M_chunks--;
  if(K % k != 0)
    K_chunks--;
  if(N % n != 0)
    N_chunks--;
  // Correct the chunk counts: a dimension that does not divide evenly does not
  // get an extra chunk; the remainder only enlarges the border chunks.
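  // Worked example (illustrative): with M = 10 and m = 3, (10 + 3 - 1) / 3 = 4,
  // and since 10 % 3 != 0 the count drops to 3; the chunks then cover rows
  // 0-2, 3-5 and 6-9, the border chunk absorbing the extra row (size m + M%m = 4).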
  if(algorithm == MM_ALG_2D)
  {
    for(int i = 0; i < M_chunks; i++)
      for(int j = 0; j < K_chunks; j++)
        (A->p(i * strideM, j * strideK)).insert(M, K, N, m, k, n, strideM, strideK, strideN, MULTARG_A, B->p, C->p, cbA, gemmSplitOrtho);
    A->p.doneInserting();

    for(int i = 0; i < K_chunks; i++)
      for(int j = 0; j < N_chunks; j++)
        (B->p(i * strideK, j * strideN)).insert(M, K, N, m, k, n, strideM, strideK, strideN, MULTARG_B, A->p, C->p, cbB, gemmSplitOrtho);
    B->p.doneInserting();

    for(int i = 0; i < M_chunks; i++)
      for(int j = 0; j < N_chunks; j++)
        (C->p(i * strideM, j * strideN)).insert(M, K, N, m, k, n, strideM, strideK, strideN, MULTARG_C, A->p, B->p, cbC, gemmSplitOrtho);
    C->p.doneInserting();
  }
  else if(algorithm == MM_ALG_3D)
  {
    CProxy_CLA_MM3D_multiplier mult = CProxy_CLA_MM3D_multiplier::ckNew();
    int curpe = 0;
    int totpe = CkNumPes();
    for(int i = 0; i < M_chunks; i++)
    {
      int mm = m;
      if(i == M_chunks - 1)
      {
        mm = M % m;
        if(mm == 0)
          mm = m;
      }

      for(int j = 0; j < N_chunks; j++)
      {
        int nn = n;
        if(j == N_chunks - 1)
        {
          nn = N % n;
          if(nn == 0)
            nn = n;
        }

        for(int l = 0; l < K_chunks; l++)
        {
          int kk = k;
          if(l == K_chunks - 1)
          {
            kk = K % k;
            if(kk == 0)
              kk = k;
          }
          mult(i, j, l).insert(mm, kk, nn, curpe);
          curpe = (curpe + 1) % totpe;
        }
      }
    }
    mult.doneInserting();

    for(int i = 0; i < M_chunks; i++)
      for(int j = 0; j < K_chunks; j++)
        (A->p(i * strideM, j * strideK)).insert(mult, M, K, N, m, k, n, strideM, strideK, strideN, MULTARG_A, cbA, gid, gemmSplitOrtho);
    A->p.doneInserting();

    for(int i = 0; i < K_chunks; i++)
      for(int j = 0; j < N_chunks; j++)
        (B->p(i * strideK, j * strideN)).insert(mult, M, K, N, m, k, n, strideM, strideK, strideN, MULTARG_B, cbB, gid, gemmSplitOrtho);
    B->p.doneInserting();

    for(int i = 0; i < M_chunks; i++)
      for(int j = 0; j < N_chunks; j++)
        (C->p(i * strideM, j * strideN)).insert(mult, M, K, N, m, k, n, strideM, strideK, strideN, MULTARG_C, cbC, gid, gemmSplitOrtho);
    C->p.doneInserting();
  }

  return SUCCESS;
}
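/* Illustrative usage sketch (not part of this file): one way a caller might set
 * up the three matrix interfaces for the 2D algorithm. The names orthoA/orthoB/
 * orthoC, doneA/doneB/doneC, mCastGrpId, nstates and grainSize are hypothetical
 * placeholders, shown only to make the parameter order above concrete.
 *
 *   CLA_Matrix_interface matA, matB, matC;
 *   int err = make_multiplier(&matA, &matB, &matC,
 *                             orthoA, orthoB, orthoC,          // bound chare arrays
 *                             nstates, nstates, nstates,       // M, K, N
 *                             grainSize, grainSize, grainSize, // m, k, n
 *                             1, 1, 1,                         // strides
 *                             doneA, doneB, doneC,             // "ready" callbacks
 *                             mCastGrpId, MM_ALG_2D, gemmSplitOrtho);
 *   if(err != SUCCESS)
 *     CkAbort("make_multiplier failed");
 */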

/******************************************************************************/
/* CLA_Matrix */

/* constructor for 2D algorithm */
CLA_Matrix::CLA_Matrix(int _M, int _K, int _N, int _m, int _k, int _n,
    int strideM, int strideK, int strideN, int _part,
    CProxy_CLA_Matrix _other1, CProxy_CLA_Matrix _other2, CkCallback ready, int _gemmSplitOrtho){
  /* initialize simple members */
  this->M = _M; this->K = _K; this->N = _N;
  this->um = _m; this->uk = _k; this->un = _n;
  this->m = _m; this->k = _k; this->n = _n;
  this->part = _part;
  this->algorithm = MM_ALG_2D;
  this->other1 = _other1; this->other2 = _other2;
  this->M_stride = strideM;
  this->K_stride = strideK;
  this->N_stride = strideN;
  gemmSplitOrtho = _gemmSplitOrtho;
  M_chunks = (_M + _m - 1) / _m;
  K_chunks = (_K + _k - 1) / _k;
  N_chunks = (_N + _n - 1) / _n;
  if(M % m != 0)
    M_chunks--;
  if(K % k != 0)
    K_chunks--;
  if(N % n != 0)
    N_chunks--;
  // as above, the remainder only enlarges the border chunks
  algorithm = MM_ALG_2D;
  usesAtSync = false;
  setMigratable(false);
  /* figure out size of our sections */
  if(part == MULTARG_A){
    if(thisIndex.x == (M_chunks - 1) * strideM){
      this->m = _m + _M % _m;
      if(this->m == 0) this->m = _m;
    }
    else this->m = _m;
    if(thisIndex.y == (K_chunks - 1) * strideK){
      this->k = _k + _K % _k;
      if(this->k == 0) this->k = _k;
    }
    else this->k = _k;
    this->n = _n;
  } else if(part == MULTARG_B) {
    if(thisIndex.x == (K_chunks - 1) * strideK){
      this->k = _k + _K % _k;
      if(this->k == 0) this->k = _k;
    }
    else this->k = _k;
    if(thisIndex.y == (N_chunks - 1) * strideN){
      this->n = _n + _N % _n;
      if(this->n == 0) this->n = _n;
    }
    else this->n = _n;
    this->m = _m;
  } else if(part == MULTARG_C) {
    if(thisIndex.x == (M_chunks - 1) * strideM){
      this->m = _m + _M % _m;
      if(this->m == 0) this->m = _m;
    }
    else this->m = _m;
    if(thisIndex.y == (N_chunks - 1) * strideN){
      this->n = _n + _N % _n;
      if(this->n == 0) this->n = _n;
    }
    else this->n = _n;
    this->k = _k;
    got_start = false;
    row_count = col_count = 0;
  }

  /* make communication group for A, B, destination arrays for C */
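  /* Note on the section shapes: an A(i,j) element multicasts its block to
   * row i of C (every column-block index), a B(i,j) element multicasts to
   * column j of C, and each C element allocates staging buffers big enough
   * for a full m x K strip of A and a full K x n strip of B. */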
  if(part == MULTARG_A){
    commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, thisIndex.x,
        thisIndex.x, 1, 0, (N_chunks - 1) * strideN, strideN);
    tmpA = tmpB = NULL;
  } else if(part == MULTARG_B) {
    commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, 0,
        (M_chunks - 1) * strideM, strideM, thisIndex.y, thisIndex.y, 1);
    tmpA = tmpB = NULL;
  } else if(part == MULTARG_C) {
    tmpA = new internalType[this->m * K];
    tmpB = new internalType[K * this->n];
  }

  /* let the caller know we are ready */
  contribute(0, NULL, CkReduction::sum_int, ready);
}

/* constructor for 3D algorithm */
CLA_Matrix::CLA_Matrix(CProxy_CLA_MM3D_multiplier p, int M, int K, int N,
    int m, int k, int n, int strideM, int strideK, int strideN, int part,
    CkCallback cb, CkGroupID gid, int _gemmSplitOrtho){
  /* set up easy variable */
  this->M = M; this->K = K; this->N = N;
  this->um = m; this->uk = k; this->un = n;
  this->part = part;
  this->algorithm = MM_ALG_2D;
  this->other1 = other1; this->other2 = other2;
  this->M_stride = strideM;
  this->K_stride = strideK;
  this->N_stride = strideN;
  gemmSplitOrtho = _gemmSplitOrtho;
  M_chunks = (M + m - 1) / m;
  K_chunks = (K + k - 1) / k;
  N_chunks = (N + n - 1) / n;
  got_data = got_start = false;
  res_msg = NULL;
  algorithm = MM_ALG_3D;
  usesAtSync = false;
  setMigratable(false);
  /* figure out size of our sections */
  if(part == MULTARG_A){
    if(thisIndex.x == (M_chunks - 1) * strideM){
      this->m = M % m;
      if(this->m == 0) this->m = m;
    }
    else this->m = m;
    if(thisIndex.y == (K_chunks - 1) * strideK){
      this->k = K % k;
      if(this->k == 0) this->k = k;
    }
    else this->k = k;
    this->n = n;
  } else if(part == MULTARG_B) {
    if(thisIndex.x == (K_chunks - 1) * strideK){
      this->k = K % k;
      if(this->k == 0) this->k = k;
    }
    else this->k = k;
    if(thisIndex.y == (N_chunks - 1) * strideN){
      this->n = N % n;
      if(this->n == 0) this->n = n;
    }
    else this->n = n;
    this->m = m;
  } else if(part == MULTARG_C) {
    if(thisIndex.x == (M_chunks - 1) * strideM){
      this->m = M % m;
      if(this->m == 0) this->m = m;
    }
    else this->m = m;
    if(thisIndex.y == (N_chunks - 1) * strideN){
      this->n = N % n;
      if(this->n == 0) this->n = n;
    }
    else this->n = n;
    this->k = k;
  }

  /* make communication groups, C also has to initialize reduction sections */
  if(part == MULTARG_A){
    int x = thisIndex.x / strideM;
    int y = thisIndex.y / strideK;
    commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, 0,
        N_chunks - 1, 1, y, y, 1);
    contribute(0, NULL, CkReduction::sum_int, cb);
  } else if(part == MULTARG_B) {
    int x = thisIndex.x / strideK;
    int y = thisIndex.y / strideN;
    commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, 0,
        M_chunks - 1, 1, y, y, 1, x, x, 1);
    contribute(0, NULL, CkReduction::sum_int, cb);
  } else if(part == MULTARG_C) {
    init_cb = cb;
    int x = thisIndex.x / strideM;
    int y = thisIndex.y / strideN;
    commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, y, y, 1,
        0, K_chunks - 1, 1);
    commGroup3D.ckSectionDelegate(CProxy_CkMulticastMgr(gid).ckLocalBranch());
    CLA_MM3D_mult_init_msg *m = new CLA_MM3D_mult_init_msg(gid,
        CkCallback(CkIndex_CLA_Matrix::readyC(NULL),
          thisProxy(thisIndex.x, thisIndex.y)), CkCallback(
          CkIndex_CLA_Matrix::mult_done(NULL), thisProxy(thisIndex.x,
          thisIndex.y)));
    commGroup3D.initialize_reduction(m);
  }
}
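/* Note on the 3D sections built above: A(i,k) multicasts its block to the
 * multiplier elements (i, *, k), B(k,j) multicasts to (*, j, k), and C(i,j)
 * owns a reduction section over (i, j, *), i.e. over the K dimension. */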

CLA_Matrix::~CLA_Matrix(){
  if(algorithm == MM_ALG_2D){
    delete [] tmpA;
    delete [] tmpB;
  }
  else if(algorithm == MM_ALG_3D){
    if(res_msg != NULL)
      delete res_msg;
  }
}

void CLA_Matrix::pup(PUP::er &p){
  /* make sure we are not using the 3D algorithm */
  if(algorithm == MM_ALG_3D){
    CmiAbort("3D algorithm does not currently support migration.\n");
  }

  /* pup super class */
  CBase_CLA_Matrix::pup(p);

  /* pup shared vars */
  p | M; p | K; p | N; p | m; p | k; p | n; p | um; p | uk; p | un;
  p | M_chunks; p | K_chunks; p | N_chunks;
  p | M_stride; p | K_stride; p | N_stride;
  p | part; p | algorithm;
  p | alpha; p | beta;
  p | gemmSplitOrtho;
  /* pup vars used by each algorithm */
  if(algorithm == MM_ALG_2D){
    p | row_count; p | col_count;
    p | other1; p | other2;
    if(part == MULTARG_C){
      if(p.isUnpacking()){
        tmpA = new internalType[m * K];
        tmpB = new internalType[K * n];
      }
      PUParray(p, tmpA, m * K);
      PUParray(p, tmpB, K * n);
    }
  }
  else if(algorithm == MM_ALG_3D){
    p | init_cb;
    p | got_start; p | got_data;
    p | commGroup3D;
  }
}

void CLA_Matrix::ResumeFromSync(void){
  /* recreate the section proxies */
  if(algorithm == MM_ALG_2D){
    if(part == MULTARG_A){
      commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, thisIndex.x,
          thisIndex.x, 1, 0, (N_chunks - 1) * N_stride, N_stride);
      tmpA = tmpB = NULL;
    } else if(part == MULTARG_B) {
      commGroup2D = CProxySection_CLA_Matrix::ckNew(other2, 0,
          (M_chunks - 1) * M_stride, M_stride, thisIndex.y, thisIndex.y, 1);
      tmpA = tmpB = NULL;
    }
  } else if(algorithm == MM_ALG_3D){
#if 0
    if(part == MULTARG_A){
      int x = thisIndex.x / M_stride;
      int y = thisIndex.y / K_stride;
      commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, 0,
          N_chunks - 1, 1, y, y, 1);
      contribute(0, NULL, CkReduction::sum_int, cb);
    } else if(part == MULTARG_B) {
      int x = thisIndex.x / K_stride;
      int y = thisIndex.y / N_stride;
      commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, 0,
          M_chunks - 1, 1, y, y, 1, x, x, 1);
      contribute(0, NULL, CkReduction::sum_int, cb);
    } else if(part == MULTARG_C) {
      init_cb = cb;
      int x = thisIndex.x / M_stride;
      int y = thisIndex.y / N_stride;
      commGroup3D = CProxySection_CLA_MM3D_multiplier::ckNew(p, x, x, 1, y, y,
          1, 0, K_chunks - 1, 1);
/*
      commGroup3D.ckSectionDelegate(CProxy_CkMulticastMgr(gid).ckLocalBranch());
      CLA_MM3D_mult_init_msg *m = new CLA_MM3D_mult_init_msg(gid,
        CkCallback(CkIndex_CLA_Matrix::readyC(NULL),
        thisProxy(thisIndex.x, thisIndex.y)), CkCallback(
        CkIndex_CLA_Matrix::mult_done(NULL), thisProxy(thisIndex.x,
        thisIndex.y)));
      commGroup3D.initialize_reduction(m);
*/
    }
#endif
  }
}

void CLA_Matrix::multiply(double alpha, double beta, internalType *data,
    void (*fptr) (void*), void *usr_data){
  if(algorithm == MM_ALG_2D){
    // A and B send out their chunks, ignoring alpha, beta, fptr, and usr_data
    if(part == MULTARG_A){
      CLA_Matrix_msg *msg = new (m * k) CLA_Matrix_msg(data, m, k, thisIndex.x,
          thisIndex.y);
      commGroup2D.receiveA(msg);
    } else if(part == MULTARG_B){
      CLA_Matrix_msg *msg = new (k * n) CLA_Matrix_msg(data, k, n, thisIndex.x,
          thisIndex.y);
      commGroup2D.receiveB(msg);
    }
    /* C stores the parameters for the multiplication */
    else if(part == MULTARG_C){
      fcb = fptr;
      user_data = usr_data;
      dest = data;
      this->alpha = alpha;
      this->beta = beta;
      got_start = true;
      /* Check if we were slow to arrive */
      if(row_count == K_chunks && col_count == K_chunks)
        multiply();
    }
    else
      CmiAbort("CLA_Matrix internal error");
  } else if(algorithm == MM_ALG_3D){
    if(part == MULTARG_A){
      CLA_Matrix_msg *msg = new (m * k) CLA_Matrix_msg(data, m, k, thisIndex.x,
          thisIndex.y);
      commGroup3D.receiveA(msg);
    } else if(part == MULTARG_B){
      CLA_Matrix_msg *msg = new (k * n) CLA_Matrix_msg(data, k, n, thisIndex.x,
          thisIndex.y);
      commGroup3D.receiveB(msg);
    } else if(part == MULTARG_C){
      fcb = fptr;
      user_data = usr_data;
      dest = data;
      this->alpha = alpha;
      this->beta = beta;
      got_start = true;
      if(got_data){
        got_start = got_data = false;
        /* transpose reduction msg and do the alpha and beta multiplications */
        internalType *data = (internalType*) res_msg->getData();
        transpose(data, n, m);
        for(int i = 0; i < m; i++)
          for(int j = 0; j < n; j++)
            dest[i * n + j] = beta * dest[i * n + j] + alpha * data[i * n + j];
        delete res_msg;
        res_msg = NULL;
        (*fcb)(user_data);
      }
    }
  }
}
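/* Note: a complete product is driven by calling multiply() on all three
 * arrays. The A and B elements only forward their blocks; the C elements
 * remember alpha, beta, the destination buffer and the completion callback,
 * and finish once every contribution has arrived, whichever of the data or
 * the start request shows up last. */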

void CLA_Matrix::receiveA(CLA_Matrix_msg *msg){
  /* store current part */
  row_count++;
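  /* Note: the incoming m x d2 chunk of A is copied row by row into tmpA, an
   * m x K row-major staging buffer, at column offset uk * (fromY / K_stride);
   * once every column-block has arrived, tmpA holds the full strip of A this
   * C element needs. */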
  for(int i = 0; i < m; i++)
    CmiMemcpy(&tmpA[K * i + uk * (msg->fromY / K_stride)], &msg->data[i * msg->d2],
        msg->d2 * sizeof(internalType));
  delete msg;

  /* If we have all the parts, multiply */
  if(row_count == K_chunks && col_count == K_chunks && got_start)
    multiply();
}

void CLA_Matrix::receiveB(CLA_Matrix_msg *msg){
  /* store current part */
  col_count++;
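  /* Note: the incoming d1 x n chunk of B is already contiguous in the K x n
   * row-major staging buffer tmpB, so a single copy at row offset
   * uk * (fromX / K_stride) is enough. */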
  CmiMemcpy(&tmpB[n * uk * (msg->fromX / K_stride)], msg->data,
      msg->d1 * msg->d2 * sizeof(internalType));
  delete msg;

  /* If we have all the parts, multiply */
  if(row_count == K_chunks && col_count == K_chunks && got_start)
    multiply();
}

void CLA_Matrix::multiply()
{
  // Reset counters
  row_count = col_count = 0;
  got_start = false;

  // Transpose result matrix (if beta != 0)
  if(beta != 0)
    transpose(dest, m, n);
  // Multiply
  char trans = 'T';

  #define BUNDLE_USER_EVENTS

  #ifdef CMK_TRACE_ENABLED
  double StartTime = CmiWallTimer();
  #endif
  #ifdef PRINT_DGEMM_PARAMS
  CkPrintf("CLA_MATRIX DGEMM %c %c %d %d %d %f %f %d %d %d\n", trans, trans, m, n, K, alpha, beta, K, n, m);
  #endif
  #ifdef _NAN_CHECK_
  for(int in = 0; in < K; in++)
    for(int jn = 0; jn < m; jn++)
      CkAssert(isfinite(tmpA[in * m + jn]));
  for(int in = 0; in < n; in++)
    for(int jn = 0; jn < K; jn++)
      CkAssert(isfinite(tmpB[in * K + jn]));
  #endif
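  /* Note on the transposes: tmpA, tmpB and dest are row-major, while the
   * underlying GEMM routine is column-major. Passing both operands with 'T'
   * and the row-major leading dimensions lets the product be computed straight
   * from the row-major buffers (the transpose above similarly puts dest into
   * column-major form when beta != 0 so the accumulation lines up); the
   * column-major m x n result written into dest is transposed back just below. */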
  myGEMM(&trans, &trans, &m, &n, &K, &alpha, tmpA, &K, tmpB, &n, &beta, dest, &m);
  #ifdef _NAN_CHECK_
  for(int in = 0; in < m; in++)
    for(int jn = 0; jn < n; jn++)
      CkAssert(isfinite(dest[in * n + jn]));
  #endif
  #ifdef CMK_TRACE_ENABLED
  traceUserBracketEvent(401, StartTime, CmiWallTimer());
  #endif
  // Transpose the result
  transpose(dest, n, m);
  // Tell caller we are done
  fcb(user_data);
}

void CLA_Matrix::readyC(CkReductionMsg *msg){
  CkCallback cb(CkIndex_CLA_Matrix::ready(NULL), thisProxy(0, 0));
  contribute(0, NULL, CkReduction::sum_int, cb);
  delete msg;
}

void CLA_Matrix::ready(CkReductionMsg *msg){
  init_cb.send();
  delete msg;
}

void CLA_Matrix::mult_done(CkReductionMsg *msg){
  if(got_start){
    got_start = got_data = false;
    /* transpose reduction msg and do the alpha and beta multiplications */
    internalType *data = (internalType*) msg->getData();
    transpose(data, n, m);
    for(int i = 0; i < m; i++)
      for(int j = 0; j < n; j++)
        dest[i * n + j] = beta * dest[i * n + j] + alpha * data[i * n + j];
    delete msg;
    msg = NULL;
    (*fcb)(user_data);
  }
  else{
    got_data = true;
    res_msg = msg;
  }
}
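/* Note: in the 3D algorithm the per-plane partial products are summed over the
 * K dimension by a section reduction. mult_done() receives the reduced block
 * (n x m, row-major), transposes it and applies the alpha/beta scaling, either
 * immediately or, if the user has not yet called multiply(), later via the
 * got_data path in multiply(). */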

/******************************************************************************/
/* CLA_Matrix_msg */
CLA_Matrix_msg::CLA_Matrix_msg(internalType *data, int d1, int d2, int fromX,
    int fromY){
  CmiMemcpy(this->data, data, d1 * d2 * sizeof(internalType));
  this->d1 = d1; this->d2 = d2;
  this->fromX = fromX; this->fromY = fromY;
}

/******************************************************************************/
/* CLA_MM3D_multiplier */
CLA_MM3D_multiplier::CLA_MM3D_multiplier(int m, int k, int n){
  this->m = m; this->k = k; this->n = n;
  data_msg = NULL;
  gotA = gotB = false;
}

void CLA_MM3D_multiplier::initialize_reduction(CLA_MM3D_mult_init_msg *m){
  reduce_CB = m->reduce;
  CkGetSectionInfo(sectionCookie, m);
  redGrp = CProxy_CkMulticastMgr(m->gid).ckLocalBranch();
  redGrp->contribute(0, NULL, CkReduction::sum_int, sectionCookie, m->ready);
  delete m;
}

void CLA_MM3D_multiplier::receiveA(CLA_Matrix_msg *msg){
  gotA = true;
  if(gotB){
    multiply(msg->data, data_msg->data);
    delete msg;
    delete data_msg;
  }
  else
    data_msg = msg;
}

void CLA_MM3D_multiplier::receiveB(CLA_Matrix_msg *msg){
  gotB = true;
  if(gotA){
    multiply(data_msg->data, msg->data);
    delete msg;
    delete data_msg;
  }
  else
    data_msg = msg;
}
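/* Note: receiveA()/receiveB() form a simple rendezvous; whichever message
 * arrives first is parked in data_msg, and the local GEMM fires as soon as
 * the partner block shows up. */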

void CLA_MM3D_multiplier::multiply(internalType *A, internalType *B){
  double alpha = 1, beta = 0;
  gotA = gotB = false;
  char trans = 'T';
  internalType *C = new internalType[m * n];
#ifdef CMK_TRACE_ENABLED
  double StartTime = CmiWallTimer();
#endif
#ifdef TEST_ALIGN
  CkAssert((unsigned int) A % 16 == 0);
  CkAssert((unsigned int) B % 16 == 0);
  CkAssert((unsigned int) C % 16 == 0);
#endif

#ifdef PRINT_DGEMM_PARAMS
  CkPrintf("HEY-DGEMM %c %c %d %d %d %f %f %d %d %d\n", trans, trans, m, n, k, alpha, beta, k, n, m);
#endif
  myGEMM(&trans, &trans, &m, &n, &k, &alpha, A, &k, B, &n, &beta, C, &m);
#ifdef CMK_TRACE_ENABLED
  traceUserBracketEvent(402, StartTime, CmiWallTimer());
#endif
  CmiNetworkProgress();
  // redGrp->contribute(m * n * sizeof(double), C, CkReduction::sum_double,
  redGrp->contribute(m * n * sizeof(internalType), C, sumFastDoubleType,
      sectionCookie, reduce_CB);
  delete [] C;
}
/*@}*/
#include "CLA_Matrix.def.h"