OpenAtom  Version1.5a
CLA_Matrix.h
1 #ifndef CLA_Matrix_H
2 #define CLA_Matrix_H
3 
4 #include "CLA_Matrix.decl.h"
5 #include "ckmulticast.h"
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>
6 
7 /** @addtogroup Ortho
8  @{
9 */
10 
11 /*
12  * The CLA_Matrix class should not be used directly by the user. See comments
13  * below regarding CLA_Matrix_interface
14  */
 /* CLA_Matrix_interface is a friend so that its inline wrappers can reach
  * the private multiply() and the synchronize() helper below. */
16  friend class CLA_Matrix_interface;
17 
18  public:
 /* Default and migration constructors; both bodies are empty. */
19  CLA_Matrix(){}
20  CLA_Matrix(CkMigrateMessage *m){}
21  ~CLA_Matrix();
22  virtual void pup(PUP::er &p);
23  virtual void ResumeFromSync();
 /* Calls AtSync(); users reach it through CLA_Matrix_interface::sync(). */
24  inline void synchronize(void){AtSync();}
25 
26  /* For 2D algorithm */
 /* NOTE(review): this constructor takes its strides in the order
  * (strideM, strideN, strideK), whereas the 3D constructor below and
  * make_multiplier() use (strideM, strideK, strideN). Confirm that every
  * caller passes them in the order declared here. */
27  CLA_Matrix(int M, int K, int N, int m, int k, int n, int strideM,
28  int strideN, int strideK, int part,
29  CProxy_CLA_Matrix other1, CProxy_CLA_Matrix other2, CkCallback ready, int gemmSplitOrtho);
30  void receiveA(CLA_Matrix_msg *m);
31  void receiveB(CLA_Matrix_msg *m);
32 
33  /* For 3D algorithm */
34  CLA_Matrix(CProxy_CLA_MM3D_multiplier p, int M, int K, int N, int m, int k,
35  int n, int strideM, int strideK, int strideN, int part, CkCallback cb,
36  CkGroupID gid, int gemmSplitOrtho);
37  void ready(CkReductionMsg *m);
38  void readyC(CkReductionMsg *m);
39  void mult_done(CkReductionMsg *m);
40  private:
 /* Called from CLA_Matrix_interface::multiply(), which forwards alpha,
  * beta, data, fptr and usr_data unchanged. See the make_multiplier()
  * comment for the C = beta * C + alpha * A * B contract. */
41  void multiply(double alpha, double beta, internalType *data,
42  void (*fptr) (void *), void *usr_data);
43  void multiply();
44 
45  /* shared */
 /* um, uk, un: presumably this element's actual chunk extents, which are
  * smaller than m, k, n for the last chunk when the chunk size does not
  * divide the dimension -- TODO confirm in the implementation. */
46  int M, K, N, m, k, n, um, uk, un;
47  int M_chunks, K_chunks, N_chunks;
48  int M_stride, K_stride, N_stride;
49  int part;
50  int algorithm;
51  int gemmSplitOrtho;
 /* fcb/user_data presumably hold the fptr/usr_data passed to multiply() --
  * confirm in the implementation. */
52  void (*fcb) (void *obj);
53  void *user_data;
54  double alpha, beta;
55 
56  /* For 2D algorithm */
57  CProxySection_CLA_Matrix commGroup2D; // used by A and B
58  internalType *tmpA, *tmpB, *dest; // used by C
59  int row_count, col_count; // used by C
60  CProxy_CLA_Matrix other1; // For A, B. For B, A. For C, A.
61  CProxy_CLA_Matrix other2; // For A, C. For B, C. For C, B.
62 
63  /* For 3D algorithm */
64  CProxySection_CLA_MM3D_multiplier commGroup3D; // used by all
65  /* below used only by C */
66  CkCallback init_cb;
67  CkReductionMsg *res_msg;
68  bool got_start, got_data;
69 };
70 
71 
72 
73 
74 /* This class below and the make_multiplier function below are the only way in
75  * which a user of the library should interact with the libary. Users should
76  * never explicitly create char CLA_Matrix object or call their entry methods.
77  */
78 
80  public:
 /* Forwards a multiply request to the CLA_Matrix element (x, y) through
  * p(x, y).ckLocal(); alpha, beta, data, fptr and usr_data are passed on
  * unchanged. See the make_multiplier() comment for the
  * C = beta * C + alpha * A * B contract.
  * NOTE(review): the ckLocal() result is dereferenced without a null check,
  * so this presumably must be called on the PE that holds element (x, y). */
82  inline void multiply(double alpha, double beta, internalType *data,
83  void (*fptr) (void *), void *usr_data, int x, int y){
84  p(x, y).ckLocal()->multiply(alpha, beta, data, fptr, usr_data);
85  }
 /* Calls synchronize() on the local element (x, y); the same ckLocal()
  * caveat as multiply() applies. */
86  inline void sync(int x, int y){
87  p(x, y).ckLocal()->synchronize();
88  }
 /* Only the CLA_Matrix array proxy is serialized. */
89  void pup(PUP::er &per){ per | p; }
90  private:
91  CProxy_CLA_Matrix p;
 /* Private; presumably called by make_multiplier(), which is declared a
  * friend below. */
92  inline void setProxy(CProxy_CLA_Matrix pp){ p = pp; }
93 
94  friend int make_multiplier(CLA_Matrix_interface *A, CLA_Matrix_interface *B,
95  CLA_Matrix_interface *C, CProxy_ArrayElement bindA,
96  CProxy_ArrayElement bindB, CProxy_ArrayElement bindC,
97  int M, int K, int N, int m, int k, int n, int strideM, int strideK,
98  int strideN, CkCallback cbA, CkCallback cbB,
99  CkCallback cbC, CkGroupID gid, int algorithm, int gemmSplitOrtho);
100 };
101 
102 
103 
104 
106  public:
 /* Message carrying a block of matrix data between CLA_Matrix and
  * CLA_MM3D_multiplier elements (both declare receiveA/receiveB taking it). */
107  CLA_Matrix_msg(internalType *data, int d1, int d2, int fromX, int fromY);
108 // Disabled destructor, kept as a plain "//" comment (not "///") so that Doxygen does not attach it as the documentation of 'data': ~CLA_Matrix_msg(){delete [] data;}
 /* Presumably d1 x d2 elements -- TODO confirm against the constructor. */
109  internalType *data;
110  int d1, d2;
 /* Presumably the index of the sending element -- TODO confirm. */
111  int fromX, fromY;
112 };
113 
114 
115 
116 
119  public:
 /* Init message consumed by CLA_MM3D_multiplier::initialize_reduction().
  * The constructor copies each argument into the field of the same name. */
120  CLA_MM3D_mult_init_msg(CkGroupID gid, CkCallback ready,
121  CkCallback reduce){
122  this->gid = gid;
123  this->ready = ready;
124  this->reduce = reduce;
125  }
126  CkGroupID gid;
127  CkCallback ready;
128  CkCallback reduce;
129 };
130 
131 
132 
133 
134 
135 class CLA_MM3D_Map : public CkArrayMap {
136  public:
137  CLA_MM3D_Map(int mc, int kc, int nc){
138  M_chunks = mc;
139  K_chunks = kc;
140  N_chunks = nc;
141  pes = CkNumPes();
142  }
143  virtual int procNum(int arrayHdl, const CkArrayIndex &idx){
144  CkArrayIndex3D idx3d = *(CkArrayIndex3D *) &idx;
145  return (N_chunks * idx3d.index[0] + N_chunks * M_chunks * idx3d.index[2] +
146  idx3d.index[1]) % pes;
147  }
148  private:
149  int M_chunks, K_chunks, N_chunks, pes;
150 };
151 
152 
153 
154 
156  public:
 /* Migration constructor and destructor; both bodies are empty. */
158  CLA_MM3D_multiplier(CkMigrateMessage *m){};
159  CLA_MM3D_multiplier(int m, int k, int n);
160  ~CLA_MM3D_multiplier(){};
 /* Takes the init message carrying the multicast group id and the
  * ready/reduce callbacks. */
161  void initialize_reduction(CLA_MM3D_mult_init_msg *m);
162  void receiveA(CLA_Matrix_msg *msg);
163  void receiveB(CLA_Matrix_msg *msg);
164  void multiply(internalType *A, internalType *B);
165  private:
166  int m, k, n;
 /* Presumably record whether receiveA()/receiveB() have delivered their
  * operand yet -- TODO confirm in the implementation. */
167  bool gotA, gotB;
168  CLA_Matrix_msg *data_msg;
169  CkSectionInfo sectionCookie;
170  CkCallback reduce_CB;
171  CkMulticastMgr *redGrp;
 /* Disabled members, kept commented out. */
172 /*
173  double *A, *B, *C;
174 */
175 };
176 
177 
178 
179 
180 /* Function below creates the necessary interfaces so that
181  * C = beta * C + alpha * A * B
182  * can be computed. X will be bound to bindX for
183  * X in {A, B, C}. A is M x K, B is K x N, C is M x N. M, K, and N are
184  * decomposed into chunks of size m, k, and n, respectively. Along a given
185  * dimension, the array must be of size ceil(1.0 * Y / y) (Y in {M, K, N},
186  * y in {m, k, n}). If y does not divide Y, the last element must have only
187  * Y % y elements along the given dimension. The variables strideY indicate
188  * the stride at which elements are to be placed along that dimension.
189  * When matrix X is ready, it does a
190  * callback to cbX. A CkGroupID gid to be used by the library must be passed
191  * in (this can be a simple CProxy_CkMulticastMgr::ckNew()). The algorithm
192  * used to multiply the matrices is determined by the value of 'algorithm',
193  * which can take the values defined below.
194  *
195  * Return value: Zero is returned upon success. A negative value is returned
196  * if an error occurs.
197  * ERR_INVALID_ALG: an invalid algorithm was selected
198  * ERR_INVALID_DIM: invalid dimensions were given
199  */
 /* Declaration of the library entry point documented above. 'algorithm' is
  * one of the MM_ALG_* values below; the return value is SUCCESS or one of
  * the ERR_* codes below. */
200 int make_multiplier(
202  CProxy_ArrayElement bindA, CProxy_ArrayElement bindB, CProxy_ArrayElement bindC,
203  int M, int K, int N,
204  int m, int k, int n,
 /* NOTE(review): strideZ is the stride along N; the friend declaration in
  * CLA_Matrix_interface names this parameter strideN. */
205  int strideM, int strideK, int strideZ,
206  CkCallback cbA, CkCallback cbB, CkCallback cbC,
207  CkGroupID gid, int algorithm, int gemmSplitOrtho
208  );
209 
210 
211 
212 
213 /* Valid values for 'algorithm' are given below. As new ones are added,
214  * they should be given incremental numbers. MM_ALG_MIN should not be changed,
215  * and MM_ALG_MAX should have the value of the greatest defined algorithm.
216  * MM_ALG_MIN and MM_ALG_MAX are used to validate user input, so they should always
217  * be updated as new algorithms are added.
218  */
219 #define MM_ALG_MIN 1 /* smallest valid 'algorithm' value */
220 #define MM_ALG_2D 1 /* 2D algorithm */
221 #define MM_ALG_3D 2 /* 3D algorithm */
222 #define MM_ALG_MAX 2 /* largest valid 'algorithm' value */
223 
224 /* Error codes */
 /* Return values of make_multiplier(). NOTE(review): these are unprefixed
  * global macros (SUCCESS in particular is a very common name) and may
  * collide with other headers. */
225 #define SUCCESS 0
226 #define ERR_INVALID_ALG -1
227 #define ERR_INVALID_DIM -2
228 
229 
230 
231 
/* Transpose data in place. On entry, data holds an m x n matrix stored
 * row-major (element (i, j) at data[i * n + j]). On return, it holds the
 * n x m transpose, also row-major (element (j, i) at data[j * m + i]).
 *
 *   data - buffer of at least m * n elements (may be NULL when m or n is 0)
 *   m, n - number of rows and columns of the input matrix
 *
 * Square matrices are transposed by swapping elements in place. Other shapes
 * are rebuilt from a scratch copy. The copy is element-wise (std::vector), so
 * T need not be trivially copyable, and the scratch buffer is released even
 * if copying a T throws.
 */
template <typename T>
void transpose(T *data, int m, int n)
{
  /* A single row or column (or an empty matrix) has exactly the same memory
   * layout as its transpose, so there is nothing to do. */
  if(m <= 1 || n <= 1)
    return;

  /* Index in size_t so that m * n beyond INT_MAX does not overflow. */
  const std::size_t rows = static_cast<std::size_t>(m);
  const std::size_t cols = static_cast<std::size_t>(n);

  if(rows == cols){
    /* Square: swap each element above the diagonal with its mirror. */
    for(std::size_t i = 0; i < rows; i++)
      for(std::size_t j = i + 1; j < cols; j++)
        std::swap(data[i * cols + j], data[j * rows + i]);
    return;
  }

  /* Non-square: the layout changes completely, so read from a copy. */
  const std::vector<T> src(data, data + rows * cols);
  for(std::size_t i = 0; i < rows; i++)
    for(std::size_t j = 0; j < cols; j++)
      data[j * rows + i] = src[i * cols + j];
}
254 /*@}*/
255 #endif
internalType * data
~CLA_Matrix_msg(){delete [] data;}
Definition: CLA_Matrix.h:109