4 #include "CLA_Matrix.decl.h"
5 #include "ckmulticast.h"
22 virtual void pup(PUP::er &p);
23 virtual void ResumeFromSync();
// Enter the Charm++ AtSync() synchronization point for this array element.
// NOTE(review): the runtime is expected to resume the element through
// ResumeFromSync() (declared just above) once load balancing finishes — confirm.
24 inline void synchronize(
void){AtSync();}
27 CLA_Matrix(
int M,
int K,
int N,
int m,
int k,
int n,
int strideM,
28 int strideN,
int strideK,
int part,
29 CProxy_CLA_Matrix other1, CProxy_CLA_Matrix other2, CkCallback ready,
int gemmSplitOrtho);
34 CLA_Matrix(CProxy_CLA_MM3D_multiplier p,
int M,
int K,
int N,
int m,
int k,
35 int n,
int strideM,
int strideK,
int strideN,
int part, CkCallback cb,
36 CkGroupID gid,
int gemmSplitOrtho);
37 void ready(CkReductionMsg *m);
38 void readyC(CkReductionMsg *m);
39 void mult_done(CkReductionMsg *m);
41 void multiply(
double alpha,
double beta, internalType *data,
42 void (*fptr) (
void *),
void *usr_data);
46 int M, K, N, m, k, n, um, uk, un;
47 int M_chunks, K_chunks, N_chunks;
48 int M_stride, K_stride, N_stride;
52 void (*fcb) (
void *obj);
57 CProxySection_CLA_Matrix commGroup2D;
58 internalType *tmpA, *tmpB, *dest;
59 int row_count, col_count;
60 CProxy_CLA_Matrix other1;
61 CProxy_CLA_Matrix other2;
64 CProxySection_CLA_MM3D_multiplier commGroup3D;
67 CkReductionMsg *res_msg;
68 bool got_start, got_data;
82 inline void multiply(
double alpha,
double beta, internalType *data,
83 void (*fptr) (
void *),
void *usr_data,
int x,
int y){
84 p(x, y).ckLocal()->multiply(alpha, beta, data, fptr, usr_data);
86 inline void sync(
int x,
int y){
87 p(x, y).ckLocal()->synchronize();
// PUP routine: the only state pupped here is the wrapped CLA_Matrix proxy `p`.
89 void pup(PUP::er &per){ per | p; }
// Replace the CLA_Matrix proxy `p`; multiply() and sync() use it to reach the local element at (x, y).
92 inline void setProxy(CProxy_CLA_Matrix pp){ p = pp; }
96 CProxy_ArrayElement bindB, CProxy_ArrayElement bindC,
97 int M,
int K,
int N,
int m,
int k,
int n,
int strideM,
int strideK,
98 int strideN, CkCallback cbA, CkCallback cbB,
99 CkCallback cbC, CkGroupID gid,
int algorithm,
int gemmSplitOrtho);
124 this->reduce = reduce;
143 virtual int procNum(
int arrayHdl,
const CkArrayIndex &idx){
144 CkArrayIndex3D idx3d = *(CkArrayIndex3D *) &idx;
145 return (N_chunks * idx3d.index[0] + N_chunks * M_chunks * idx3d.index[2] +
146 idx3d.index[1]) % pes;
149 int M_chunks, K_chunks, N_chunks, pes;
164 void multiply(internalType *A, internalType *B);
169 CkSectionInfo sectionCookie;
170 CkCallback reduce_CB;
171 CkMulticastMgr *redGrp;
202 CProxy_ArrayElement bindA, CProxy_ArrayElement bindB, CProxy_ArrayElement bindC,
205 int strideM,
int strideK,
int strideZ,
206 CkCallback cbA, CkCallback cbB, CkCallback cbC,
207 CkGroupID gid,
int algorithm,
int gemmSplitOrtho
// Negative error codes.
// NOTE(review): judging by the names, ERR_INVALID_ALG signals an unsupported
// algorithm selector and ERR_INVALID_DIM invalid matrix dimensions — confirm
// against the code that returns them.
226 #define ERR_INVALID_ALG -1
227 #define ERR_INVALID_DIM -2
233 template <
typename T>
234 void transpose(T *data,
int m,
int n)
238 for(
int i = 0; i < m; i++)
239 for(
int j = i + 1; j < n; j++){
240 T tmp = data[i * n + j];
241 data[i * n + j] = data[j * m + i];
242 data[j * m + i] = tmp;
246 T *tmp =
new T[m * n];
247 CmiMemcpy(tmp, data, m * n *
sizeof(T));
248 for(
int i = 0; i < m; i++)
249 for(
int j = 0; j < n; j++)
250 data[j * m + i] = tmp[i * n + j];
internalType * data
/// Destroys the message and releases its `data` array.
/// Assumes `data` was allocated with new[] (or is null); delete[] on null is a no-op.
~CLA_Matrix_msg()
{
  delete [] data;
}