00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef CLA_Matrix_H
00013 #define CLA_Matrix_H
00014
#include <cstddef>

#include "CLA_Matrix.decl.h"
#include "ckmulticast.h"
00017
00018
00019
00020
00021
00022 class CLA_Matrix : public CBase_CLA_Matrix{
00023 friend class CLA_Matrix_interface;
00024
00025 public:
00026 CLA_Matrix(){}
00027 CLA_Matrix(CkMigrateMessage *m){}
00028 ~CLA_Matrix();
00029 virtual void pup(PUP::er &p);
00030 virtual void ResumeFromSync();
00031 inline void synchronize(void){AtSync();}
00032
00033
00034 CLA_Matrix(int M, int K, int N, int m, int k, int n, int strideM,
00035 int strideN, int strideK, int part,
00036 CProxy_CLA_Matrix other1, CProxy_CLA_Matrix other2, CkCallback ready, int gemmSplitOrtho);
00037 void receiveA(CLA_Matrix_msg *m);
00038 void receiveB(CLA_Matrix_msg *m);
00039
00040
00041 CLA_Matrix(CProxy_CLA_MM3D_multiplier p, int M, int K, int N, int m, int k,
00042 int n, int strideM, int strideK, int strideN, int part, CkCallback cb,
00043 CkGroupID gid, int gemmSplitOrtho);
00044 void ready(CkReductionMsg *m);
00045 void readyC(CkReductionMsg *m);
00046 void mult_done(CkReductionMsg *m);
00047 private:
00048 void multiply(double alpha, double beta, double *data,
00049 void (*fptr) (void *), void *usr_data);
00050 void multiply();
00051
00052
00053 int M, K, N, m, k, n, um, uk, un;
00054 int M_chunks, K_chunks, N_chunks;
00055 int M_stride, K_stride, N_stride;
00056 int part;
00057 int algorithm;
00058 int gemmSplitOrtho;
00059 void (*fcb) (void *obj);
00060 void *user_data;
00061 double alpha, beta;
00062
00063
00064 CProxySection_CLA_Matrix commGroup2D;
00065 double *tmpA, *tmpB, *dest;
00066 int row_count, col_count;
00067 CProxy_CLA_Matrix other1;
00068 CProxy_CLA_Matrix other2;
00069
00070
00071 CProxySection_CLA_MM3D_multiplier commGroup3D;
00072
00073 CkCallback init_cb;
00074 CkReductionMsg *res_msg;
00075 bool got_start, got_data;
00076 };
00077
00078
00079
00080
00081
00082
00083 class CLA_Matrix_interface {
00084 public:
00085 CLA_Matrix_interface(){}
00086 inline void multiply(double alpha, double beta, double *data,
00087 void (*fptr) (void *), void *usr_data, int x, int y){
00088 p(x, y).ckLocal()->multiply(alpha, beta, data, fptr, usr_data);
00089 }
00090 inline void sync(int x, int y){
00091 p(x, y).ckLocal()->synchronize();
00092 }
00093 void pup(PUP::er &per){ per | p; }
00094 private:
00095 CProxy_CLA_Matrix p;
00096 inline void setProxy(CProxy_CLA_Matrix pp){ p = pp; }
00097
00098 friend int make_multiplier(CLA_Matrix_interface *A, CLA_Matrix_interface *B,
00099 CLA_Matrix_interface *C, CProxy_ArrayElement bindA,
00100 CProxy_ArrayElement bindB, CProxy_ArrayElement bindC,
00101 int M, int K, int N, int m, int k, int n, int strideM, int strideK,
00102 int strideN, CkCallback cbA, CkCallback cbB,
00103 CkCallback cbC, CkGroupID gid, int algorithm, int gemmSplitOrtho);
00104 };
00105
00106 class CLA_Matrix_msg : public CkMcastBaseMsg, public CMessage_CLA_Matrix_msg {
00107 public:
00108 CLA_Matrix_msg(double *data, int d1, int d2, int fromX, int fromY);
00109
00110 double *data;
00111 int d1, d2;
00112 int fromX, fromY;
00113 };
00114
00115 class CLA_MM3D_mult_init_msg : public CkMcastBaseMsg,
00116 public CMessage_CLA_MM3D_mult_init_msg {
00117 public:
00118 CLA_MM3D_mult_init_msg(CkGroupID gid, CkCallback ready,
00119 CkCallback reduce){
00120 this->gid = gid;
00121 this->ready = ready;
00122 this->reduce = reduce;
00123 }
00124 CkGroupID gid;
00125 CkCallback ready;
00126 CkCallback reduce;
00127 };
00128
00129 class CLA_MM3D_Map : public CkArrayMap {
00130 public:
00131 CLA_MM3D_Map(int mc, int kc, int nc){
00132 M_chunks = mc;
00133 K_chunks = kc;
00134 N_chunks = nc;
00135 pes = CkNumPes();
00136 }
00137 virtual int procNum(int arrayHdl, const CkArrayIndex &idx){
00138 CkArrayIndex3D idx3d = *(CkArrayIndex3D *) &idx;
00139 return (N_chunks * idx3d.index[0] + N_chunks * M_chunks * idx3d.index[2] +
00140 idx3d.index[1]) % pes;
00141 }
00142 private:
00143 int M_chunks, K_chunks, N_chunks, pes;
00144 };
00145
00146 class CLA_MM3D_multiplier : public CBase_CLA_MM3D_multiplier{
00147 public:
00148 CLA_MM3D_multiplier(){};
00149 CLA_MM3D_multiplier(CkMigrateMessage *m){};
00150 CLA_MM3D_multiplier(int m, int k, int n);
00151 ~CLA_MM3D_multiplier(){};
00152 void initialize_reduction(CLA_MM3D_mult_init_msg *m);
00153 void receiveA(CLA_Matrix_msg *msg);
00154 void receiveB(CLA_Matrix_msg *msg);
00155 void multiply(double *A, double *B);
00156 private:
00157 int m, k, n;
00158 bool gotA, gotB;
00159 CLA_Matrix_msg *data_msg;
00160 CkSectionInfo sectionCookie;
00161 CkCallback reduce_CB;
00162 CkMulticastMgr *redGrp;
00163
00164
00165
00166 };
00167
00168
00169
00170
00171
00172
00173
00174
00175
00176
00177
00178
00179
00180
00181
00182
00183
00184
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194
// Multiplication algorithm selectors (the `algorithm` argument of
// make_multiplier()). NOTE(review): MM_ALG_MIN..MM_ALG_MAX is presumably the
// inclusive range the implementation accepts -- confirm.
#define MM_ALG_2D 1
#define MM_ALG_3D 2
#define MM_ALG_MIN MM_ALG_2D
#define MM_ALG_MAX MM_ALG_3D
00199
00200 int make_multiplier(CLA_Matrix_interface *A, CLA_Matrix_interface *B,
00201 CLA_Matrix_interface *C, CProxy_ArrayElement bindA,
00202 CProxy_ArrayElement bindB, CProxy_ArrayElement bindC,
00203 int M, int K, int N, int m, int k, int n, int strideM, int strideK,
00204 int strideZ, CkCallback cbA, CkCallback cbB,
00205 CkCallback cbC, CkGroupID gid, int algorithm, int gemmSplitOrtho);
00206
00207
// Status codes. NOTE(review): presumably the return values of
// make_multiplier() -- confirm in the implementation. Negative values are
// parenthesized so each macro always expands to a single primary expression.
#define SUCCESS 0
#define ERR_INVALID_ALG (-1)
#define ERR_INVALID_DIM (-2)

// Takes a single buffer and returns nothing, so any result is written back
// into `data`. NOTE(review): presumably an in-place transpose of an m x n
// matrix -- confirm the storage order in the implementation.
void transpose(double *data, int m, int n);
00212 #endif
00213