00001
00002
00003
00004
00005
00006
00007
00008
00009 #include "ampiimpl.h"
00010 #include "tcharm.h"
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00058
00060
#ifndef MAX
/* Larger of two ints; yields b when the two are equal.
 * Only compiled when no MAX macro is already visible at this point. */
int MAX(int a, int b){
  return (b < a) ? a : b;
}
#endif
00069
#if 0
/* Disabled implementation of MPI_Pack_size: reports the packed size of
 * `incount` elements of `type` as incount times the datatype's size, looked
 * up in the derived-datatype table of the AMPI instance for `comm`. */
int MPI_Pack_size(int incount, MPI_Datatype type, MPI_Comm comm, int *size)
{
  int elementSize = getAmpiInstance(comm)->getDDT()->getType(type)->getSize();
  *size = elementSize * incount;
  return MPI_SUCCESS;
}
#endif
00079
00080
00081
00082
00083 void MPICH_Localcopy(void *sendbuf, int sendcount, MPI_Datatype sendtype,
00084 void *recvbuf, int recvcount, MPI_Datatype recvtype)
00085 {
00086 int rank;
00087
00088 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
00089 getAmpiInstance(MPI_COMM_WORLD)->sendrecv ( sendbuf, sendcount, sendtype,
00090 rank, MPI_ATA_TAG,
00091 recvbuf, recvcount, recvtype,
00092 rank, MPI_ATA_TAG,
00093 MPI_COMM_WORLD, MPI_STATUS_IGNORE);
00094 }
00095
00096
00097 inline void MPID_Datatype_get_extent_macro(MPI_Datatype &type, MPI_Aint &extent){
00098 CkDDT_DataType *ddt = getAmpiInstance(MPI_COMM_WORLD)->getDDT()->getType(type);
00099 extent = ddt->getExtent();
00100 }
00101
00102 inline void MPID_Datatype_get_size_macro(MPI_Datatype &type, int &size){
00103 CkDDT_DataType *ddt = getAmpiInstance(MPI_COMM_WORLD)->getDDT()->getType(type);
00104 size = ddt->getSize();
00105 }
00106
00107
00109
00111
00112
00113
00114
00115
00116
00117 int AMPI_Alltoall_long(
00118 void *sendbuf,
00119 int sendcount,
00120 MPI_Datatype sendtype,
00121 void *recvbuf,
00122 int recvcount,
00123 MPI_Datatype recvtype,
00124 MPI_Comm comm )
00125 {
00126
00127 int comm_size, i, pof2;
00128 MPI_Aint sendtype_extent, recvtype_extent;
00129
00130 int src, dst, rank, nbytes;
00131 MPI_Status status;
00132 int sendtype_size;
00133
00134 if (sendcount == 0) return MPI_SUCCESS;
00135
00136 MPI_Comm_rank (MPI_COMM_WORLD, &rank);
00137 MPI_Comm_size (MPI_COMM_WORLD, &comm_size);
00138
00139
00140
00141 MPID_Datatype_get_extent_macro(recvtype, recvtype_extent);
00142 MPID_Datatype_get_extent_macro(sendtype, sendtype_extent);
00143
00144 MPID_Datatype_get_size_macro(sendtype, sendtype_size);
00145 nbytes = sendtype_size * sendcount;
00146
00147
00148
00149 MPICH_Localcopy(((char *)sendbuf +
00150 rank*sendcount*sendtype_extent),
00151 sendcount, sendtype,
00152 ((char *)recvbuf +
00153 rank*recvcount*recvtype_extent),
00154 recvcount, recvtype);
00155
00156
00157
00158 i = 1;
00159 while (i < comm_size)
00160 i *= 2;
00161 if (i == comm_size)
00162 pof2 = 1;
00163 else
00164 pof2 = 0;
00165
00166
00167 for (i=1; i<comm_size; i++) {
00168 if (pof2 == 1) {
00169
00170 src = dst = rank ^ i;
00171 }
00172 else {
00173 src = (rank - i + comm_size) % comm_size;
00174 dst = (rank + i) % comm_size;
00175 }
00176
00177 getAmpiInstance(comm)->sendrecv(((char *)sendbuf +
00178 dst*sendcount*sendtype_extent),
00179 sendcount, sendtype, dst,
00180 MPI_ATA_TAG,
00181 ((char *)recvbuf +
00182 src*recvcount*recvtype_extent),
00183 recvcount, recvtype, src,
00184 MPI_ATA_TAG, comm, &status);
00185 }
00186
00187 return MPI_SUCCESS;
00188 }
00189
00190
00192
00194
#if 0
/* Bruck-style all-to-all (the "_short" variant). Currently disabled.
 *
 * Phase 1: rotate the local sendbuf blocks up by `rank` blocks into recvbuf.
 * Phase 2: for each power of two pof2 < comm_size, pack every recvbuf block
 *          whose index has bit pof2 set, send the packed data to
 *          (rank + pof2), and receive the matching blocks from (rank - pof2)
 *          in place.
 * Phase 3: rotate recvbuf by (rank + 1) blocks through a temporary buffer,
 *          then reverse the block order so block j holds rank j's data.
 *
 * Returns MPI_SUCCESS, the first non-zero code returned by an MPI call, or
 * MPI_ERR_OTHER when a temporary allocation fails. Temporaries and the
 * derived datatype are released on every exit path.
 */
int AMPI_Alltoall_short(
  void *sendbuf,
  int sendcount,
  MPI_Datatype sendtype,
  void *recvbuf,
  int recvcount,
  MPI_Datatype recvtype,
  MPI_Comm comm )
{
  int comm_size, rank, i, pof2, src, dst, block, position, count, pack_size;
  int mpi_errno = MPI_SUCCESS;
  MPI_Aint sendtype_extent, recvtype_extent;
  MPI_Aint recvtype_true_extent, recvtype_true_lb, recvbuf_extent;
  MPI_Datatype newtype;
  void *pack_buf = NULL;   /* packing buffer for phase 2 */
  int *displs = NULL;      /* block displacements for the indexed type */
  char *rot_base = NULL;   /* pointer returned by malloc for phase 3 */
  char *rot_buf;           /* rot_base shifted by the true lower bound */

  if (sendcount == 0) return MPI_SUCCESS;

  /* Ranks used below are ranks in `comm`. */
  MPI_Comm_rank(comm, &rank);
  MPI_Comm_size(comm, &comm_size);

  MPID_Datatype_get_extent_macro(recvtype, recvtype_extent);
  MPID_Datatype_get_extent_macro(sendtype, sendtype_extent);

  mpi_errno = MPI_Pack_size(recvcount * comm_size, recvtype, comm, &pack_size);
  if (mpi_errno) return mpi_errno;

  pack_buf = malloc(pack_size);
  displs = (int *) malloc(comm_size * sizeof(int));
  if (!pack_buf || !displs) { mpi_errno = MPI_ERR_OTHER; goto cleanup; }

  /* Phase 1: local rotation of the send blocks by `rank`. */
  MPICH_Localcopy((char *) sendbuf + (MPI_Aint)rank * sendcount * sendtype_extent,
                  (comm_size - rank) * sendcount, sendtype, recvbuf,
                  (comm_size - rank) * recvcount, recvtype);
  MPICH_Localcopy(sendbuf, rank * sendcount, sendtype,
                  (char *) recvbuf + (MPI_Aint)(comm_size - rank) * recvcount * recvtype_extent,
                  rank * recvcount, recvtype);

  /* Phase 2: log2(comm_size) exchange rounds. */
  for (pof2 = 1; pof2 < comm_size; pof2 *= 2) {
    dst = (rank + pof2) % comm_size;
    src = (rank - pof2 + comm_size) % comm_size;

    /* Select the blocks whose index has bit `pof2` set. */
    count = 0;
    for (block = 1; block < comm_size; block++)
      if (block & pof2)
        displs[count++] = block * recvcount;

    mpi_errno = MPI_Type_create_indexed_block(count, recvcount, displs,
                                              recvtype, &newtype);
    if (mpi_errno) goto cleanup;

    mpi_errno = MPI_Type_commit(&newtype);
    if (mpi_errno) { MPI_Type_free(&newtype); goto cleanup; }

    /* Check the pack result before sending anything. */
    position = 0;
    mpi_errno = MPI_Pack(recvbuf, 1, newtype, pack_buf, pack_size,
                         &position, comm);
    if (mpi_errno) { MPI_Type_free(&newtype); goto cleanup; }

    getAmpiInstance(comm)->sendrecv(pack_buf, position, MPI_PACKED, dst,
                                    MPI_ATA_TAG, recvbuf, 1, newtype,
                                    src, MPI_ATA_TAG, comm,
                                    MPI_STATUS_IGNORE);

    mpi_errno = MPI_Type_free(&newtype);
    if (mpi_errno) goto cleanup;
  }

  /* The packing buffer is no longer needed; release it before phase 3. */
  free(pack_buf);
  pack_buf = NULL;

  /* Phase 3: rotate by (rank + 1) blocks and reverse the block order. */
  mpi_errno = MPI_Type_get_true_extent(recvtype, &recvtype_true_lb,
                                       &recvtype_true_extent);
  if (mpi_errno) goto cleanup;

  /* Compare in MPI_Aint; MAX(int, int) would truncate large extents. */
  recvbuf_extent = (MPI_Aint)recvcount * comm_size *
      (recvtype_true_extent > recvtype_extent ? recvtype_true_extent
                                              : recvtype_extent);
  rot_base = (char *) malloc(recvbuf_extent);
  if (!rot_base) { mpi_errno = MPI_ERR_OTHER; goto cleanup; }

  /* Shift so data starting at the type's true lower bound lands at rot_base. */
  rot_buf = rot_base - recvtype_true_lb;

  MPICH_Localcopy((char *) recvbuf + (MPI_Aint)(rank + 1) * recvcount * recvtype_extent,
                  (comm_size - rank - 1) * recvcount, recvtype, rot_buf,
                  (comm_size - rank - 1) * recvcount, recvtype);
  MPICH_Localcopy(recvbuf, (rank + 1) * recvcount, recvtype,
                  rot_buf + (MPI_Aint)(comm_size - rank - 1) * recvcount * recvtype_extent,
                  (rank + 1) * recvcount, recvtype);

  /* Blocks are now in reverse order; store them in order in recvbuf. */
  for (i = 0; i < comm_size; i++)
    MPICH_Localcopy(rot_buf + (MPI_Aint)i * recvcount * recvtype_extent,
                    recvcount, recvtype,
                    (char *) recvbuf + (MPI_Aint)(comm_size - i - 1) * recvcount * recvtype_extent,
                    recvcount, recvtype);

  mpi_errno = MPI_SUCCESS;

cleanup:
  free(rot_base);
  free(displs);
  free(pack_buf);
  return mpi_errno;
}
#endif
00352
00354
00356
00357 int AMPI_Alltoall_medium(
00358 void *sendbuf,
00359 int sendcount,
00360 MPI_Datatype sendtype,
00361 void *recvbuf,
00362 int recvcount,
00363 MPI_Datatype recvtype,
00364 MPI_Comm comm )
00365 {
00366
00367 int comm_size, i;
00368 MPI_Aint sendtype_extent, recvtype_extent;
00369
00370 int mpi_errno=MPI_SUCCESS, dst, rank, nbytes;
00371 int sendtype_size;
00372
00373 MPI_Request *reqarray;
00374 MPI_Status *starray;
00375
00376 if (sendcount == 0) return MPI_SUCCESS;
00377
00378 MPI_Comm_rank (MPI_COMM_WORLD, &rank);
00379 MPI_Comm_size (MPI_COMM_WORLD, &comm_size);
00380
00381
00382 MPID_Datatype_get_extent_macro(recvtype, recvtype_extent);
00383 MPID_Datatype_get_extent_macro(sendtype, sendtype_extent);
00384
00385 MPID_Datatype_get_size_macro(sendtype, sendtype_size);
00386 nbytes = sendtype_size * sendcount;
00387
00388
00389
00390 reqarray = (MPI_Request *) malloc(2*comm_size*sizeof(MPI_Request));
00391
00392 if (!reqarray)
00393 return MPI_ERR_OTHER;
00394
00395 starray = (MPI_Status *) malloc(2*comm_size*sizeof(MPI_Status));
00396 if (!starray) {
00397 free(reqarray);
00398 return MPI_ERR_OTHER;
00399 }
00400
00401
00402 ampi *ptr = getAmpiInstance(comm);
00403 for ( i=0; i<comm_size; i++ ) {
00404 dst = (rank+i) % comm_size;
00405 ptr->irecv((char *)recvbuf + dst*recvcount*recvtype_extent, recvcount, recvtype, dst,
00406 MPI_ATA_TAG, comm, &reqarray[i]);
00407 }
00408
00409 for ( i=0; i<comm_size; i++ ) {
00410 dst = (rank+i) % comm_size;
00411
00412
00413 ptr->send(MPI_ATA_TAG, getAmpiInstance(comm)->getRank(),
00414 (char *)sendbuf + dst*sendcount*sendtype_extent,
00415 sendcount, sendtype, dst, comm);
00416 reqarray[i+comm_size] = MPI_REQUEST_NULL;
00417 }
00418
00419
00420 mpi_errno = MPI_Waitall(2*comm_size,reqarray,starray);
00421
00422
00423
00424
00425
00426
00427
00428
00429
00430
00431 free(starray);
00432 free(reqarray);
00433
00434 return mpi_errno;
00435 }
00436
00437
00438
00440
00442
00443