CLA_Matrix.h

Go to the documentation of this file.
00001 /*****************************************************************************
00002  * $Source: /cvsroot/leanCP/src_charm_driver/main/CLA_Matrix.h,v $
00003  * $Author: bhatele $
00004  * $Date: 2007/12/05 08:32:47 $
00005  * $Revision: 1.9 $
00006  *****************************************************************************/
00007 
00008 /** \file CLA_Matrix.h
00009  *  
00010  */
00011 
00012 #ifndef CLA_Matrix_H
00013 #define CLA_Matrix_H
00014 
00015 #include "CLA_Matrix.decl.h"
00016 #include "ckmulticast.h"
00017 
00018 /* Class definitions */
00019 
00020 /* The CLA_Matrix class should not be used directy by the user. See comments
00021  * below regarding CLA_Matrix_interface. */
00022 class CLA_Matrix : public CBase_CLA_Matrix{
00023   friend class CLA_Matrix_interface;
00024 
00025   public:
00026     CLA_Matrix(){}
00027     CLA_Matrix(CkMigrateMessage *m){}
00028     ~CLA_Matrix();
00029     virtual void pup(PUP::er &p);
00030     virtual void ResumeFromSync();
00031     inline void synchronize(void){AtSync();}
00032 
00033     /* For 2D algorihtm */
00034     CLA_Matrix(int M, int K, int N, int m, int k, int n, int strideM,
00035      int strideN, int strideK, int part,
00036      CProxy_CLA_Matrix other1, CProxy_CLA_Matrix other2, CkCallback ready, int gemmSplitOrtho);
00037     void receiveA(CLA_Matrix_msg *m);
00038     void receiveB(CLA_Matrix_msg *m);
00039 
00040     /* For 3D algorithm */
00041     CLA_Matrix(CProxy_CLA_MM3D_multiplier p, int M, int K, int N, int m, int k,
00042      int n, int strideM, int strideK, int strideN, int part, CkCallback cb,
00043      CkGroupID gid, int gemmSplitOrtho);
00044     void ready(CkReductionMsg *m);
00045     void readyC(CkReductionMsg *m);
00046     void mult_done(CkReductionMsg *m);
00047   private:
00048     void multiply(double alpha, double beta, double *data,
00049      void (*fptr) (void *), void *usr_data);
00050     void multiply();
00051 
00052     /* shared */
00053     int M, K, N, m, k, n, um, uk, un;
00054     int M_chunks, K_chunks, N_chunks;
00055     int M_stride, K_stride, N_stride;
00056     int part;
00057     int algorithm;
00058     int gemmSplitOrtho;
00059     void (*fcb) (void *obj);
00060     void *user_data;
00061     double alpha, beta;
00062 
00063     /* For 2D algorithm */
00064     CProxySection_CLA_Matrix commGroup2D; // used by A and B
00065     double *tmpA, *tmpB, *dest; // used by C
00066     int row_count, col_count; // used by C
00067     CProxy_CLA_Matrix other1; // For A, B. For B, A. For C, A.
00068     CProxy_CLA_Matrix other2; // For A, C. For B, C. For C, B.
00069 
00070     /* For 3D algorithm */
00071     CProxySection_CLA_MM3D_multiplier commGroup3D; // used by all
00072     /* below used only by C */
00073     CkCallback init_cb;
00074     CkReductionMsg *res_msg;
00075     bool got_start, got_data;
00076 };
00077 
00078 /* This class below and the make_multiplier function below are the only way in
00079  * which a user of the library should interact with the libary. Users should
00080  * never explicitly create char CLA_Matrix object or call their entry methods.
00081  */
00082 
00083 class CLA_Matrix_interface {
00084   public:
00085     CLA_Matrix_interface(){}
00086     inline void multiply(double alpha, double beta, double *data,
00087      void (*fptr) (void *), void *usr_data, int x, int y){
00088       p(x, y).ckLocal()->multiply(alpha, beta, data, fptr, usr_data);
00089     }
00090     inline void sync(int x, int y){
00091       p(x, y).ckLocal()->synchronize();
00092     }
00093     void pup(PUP::er &per){ per | p; }
00094   private:
00095     CProxy_CLA_Matrix p;
00096     inline void setProxy(CProxy_CLA_Matrix pp){ p = pp; }
00097 
00098   friend int make_multiplier(CLA_Matrix_interface *A, CLA_Matrix_interface *B,
00099    CLA_Matrix_interface *C, CProxy_ArrayElement bindA,
00100    CProxy_ArrayElement bindB, CProxy_ArrayElement bindC,
00101    int M, int K, int N, int m, int k, int n, int strideM, int strideK,
00102    int strideN, CkCallback cbA, CkCallback cbB,
00103    CkCallback cbC, CkGroupID gid, int algorithm, int gemmSplitOrtho);
00104 };
00105 
00106 class CLA_Matrix_msg : public CkMcastBaseMsg, public CMessage_CLA_Matrix_msg {
00107   public:
00108     CLA_Matrix_msg(double *data, int d1, int d2, int fromX, int fromY);
00109 ///    ~CLA_Matrix_msg(){delete [] data;}
00110     double *data;
00111     int d1, d2;
00112     int fromX, fromY;
00113 };
00114 
00115 class CLA_MM3D_mult_init_msg : public CkMcastBaseMsg,
00116  public CMessage_CLA_MM3D_mult_init_msg {
00117   public:
00118     CLA_MM3D_mult_init_msg(CkGroupID gid, CkCallback ready,
00119        CkCallback reduce){
00120       this->gid = gid;
00121       this->ready = ready;
00122       this->reduce = reduce;
00123     }
00124     CkGroupID gid;
00125     CkCallback ready;
00126     CkCallback reduce;
00127 };
00128 
00129 class CLA_MM3D_Map : public CkArrayMap {
00130   public:
00131     CLA_MM3D_Map(int mc, int kc, int nc){
00132       M_chunks = mc;
00133       K_chunks = kc;
00134       N_chunks = nc;
00135       pes = CkNumPes();
00136     }
00137     virtual int procNum(int arrayHdl, const CkArrayIndex &idx){
00138       CkArrayIndex3D idx3d = *(CkArrayIndex3D *) &idx;
00139       return (N_chunks * idx3d.index[0] + N_chunks * M_chunks * idx3d.index[2] +
00140        idx3d.index[1]) % pes;
00141     }
00142   private:
00143     int M_chunks, K_chunks, N_chunks, pes;
00144 };
00145 
00146 class CLA_MM3D_multiplier : public CBase_CLA_MM3D_multiplier{
00147   public:
00148     CLA_MM3D_multiplier(){};
00149     CLA_MM3D_multiplier(CkMigrateMessage *m){};
00150     CLA_MM3D_multiplier(int m, int k, int n);
00151     ~CLA_MM3D_multiplier(){};
00152     void initialize_reduction(CLA_MM3D_mult_init_msg *m);
00153     void receiveA(CLA_Matrix_msg *msg);
00154     void receiveB(CLA_Matrix_msg *msg);
00155     void multiply(double *A, double *B);
00156   private:
00157     int m, k, n;
00158     bool gotA, gotB;
00159     CLA_Matrix_msg *data_msg;
00160     CkSectionInfo sectionCookie;
00161     CkCallback reduce_CB;
00162     CkMulticastMgr *redGrp;
00163 /*
00164     double *A, *B, *C;
00165 */
00166 };
00167 
00168 /* Function below creates the necessary interfaces so that
00169  * C = beta * C + alpha * A * B
00170  * can be computed. X will be bound to bindX for
00171  * X in {A, B, C}. A is M x K, B is K x N, C is M x N. M, K, and N are
00172  * decomposed into chunks of size m, k, and n, respectively. Along a given
00173  * dimension, the array must be of size ceil(1.0 * Y / y) (Y in {M, K, N},
00174  * y in {m, k, n}). If y does not divide Y, the last element must have only
00175  * Y % y elements along the given dimension. The variables strideY indicate
00176  * the stride at which elements are to be placed along that dimension.
00177  * When matrix X is ready, it does a
00178  * callback to cbX. A CKGroupID gid to be used by the library must be passed
00179  * in (this can be a simple CProxy_CkMulticastMgr::ckNew()). The algorithm
00180  * used to multiply the matrices is determined by the value of 'algorithm',
00181  * which can take the values defined below.
00182  *
00183  * Return value: Zero is returned upon success. A negative value is returned
00184  * if an error occurs.
00185  *  ERR_INVALID_ALG: an invalid algorithm was selected
00186  *  ERR_INVALID_DIM: invalid dimensions were given
00187  */
00188 
00189 /* Valid values for 'algorithm' are given below. As new ones are added,
00190  * they should be given incremental numbers. MM_ALG_MIN should not be changed,
00191  * and MM_ALG_MAX should have the value of the greatest defined algorithm.
00192  * MM_ALG_MIN MM_ALG_MAX are used to validate user input, so they should always
00193  * be updated as new algorithms are added.
00194  */
00195 #define MM_ALG_MIN      1
00196 #define MM_ALG_2D       1
00197 #define MM_ALG_3D       2
00198 #define MM_ALG_MAX      2
00199 
00200 int make_multiplier(CLA_Matrix_interface *A, CLA_Matrix_interface *B,
00201  CLA_Matrix_interface *C, CProxy_ArrayElement bindA,
00202  CProxy_ArrayElement bindB, CProxy_ArrayElement bindC,
00203  int M, int K, int N, int m, int k, int n, int strideM, int strideK,
00204  int strideZ, CkCallback cbA, CkCallback cbB,
00205  CkCallback cbC, CkGroupID gid, int algorithm, int gemmSplitOrtho);
00206 
/* Error codes returned by make_multiplier().
 * NOTE(review): SUCCESS is a very generic macro name and may collide with
 * other headers; the replacement lists are left unchanged on purpose. */
#define SUCCESS 0          /* multiplier set up successfully */
#define ERR_INVALID_ALG -1 /* 'algorithm' outside [MM_ALG_MIN, MM_ALG_MAX] */
#define ERR_INVALID_DIM -2 /* invalid matrix / chunk dimensions */
/* Defined elsewhere. NOTE(review): presumably transposes the m x n matrix
 * stored in 'data' in place — confirm the layout against the definition. */
void transpose(double *data,
               int m,
               int n);
00212 #endif
00213 

Generated on Thu Dec 6 18:25:28 2007 for leanCP by  doxygen 1.5.3