#include <stdio.h>

#include "ampi.h"
#include "charm++.h"

/* Number of chunks along each axis of the chunk grid; index1d() and
 * index3d() use NY and NZ to convert between a linear chunk index and
 * (ix, iy, iz) chunk coordinates. */
#define NX 2
#define NY 2
#define NZ 2
/* Total number of chunks in the NX x NY x NZ chunk grid. */
#define TCHUNKS (NX*NY*NZ)

/* Number of interior grid points a single chunk owns along each axis.
 * chunk::t is declared with two extra entries per axis on top of these. */
#define DIMX 20
#define DIMY 20
#define DIMZ 20

/*
 * Per-rank state of the relaxation: the local block of the grid, this
 * chunk's position in the chunk grid, the ranks it exchanges faces with,
 * and the staging buffers used for that exchange.
 */
class chunk {
  public:
    /* Local grid values; each axis has two entries more than the DIM*
     * interior size (indices 0 and DIM+1 on every axis). */
    double t[DIMX+2][DIMY+2][DIMZ+2];

    /* Coordinates of this chunk in the chunk grid. */
    int xidx;
    int yidx;
    int zidx;

    /* Linear indices of the neighbouring chunks in the minus ("m") and
     * plus ("p") direction of each axis. */
    int xm;
    int xp;
    int ym;
    int yp;
    int zm;
    int zp;

    /* Send buffers ("sb"), one per face; sized to hold one face. */
    double sbxm[DIMY*DIMZ], sbxp[DIMY*DIMZ];
    double sbym[DIMX*DIMZ], sbyp[DIMX*DIMZ];
    double sbzm[DIMX*DIMY], sbzp[DIMX*DIMY];

    /* Receive buffers ("rb"), one per face; same sizes as the send side. */
    double rbxm[DIMY*DIMZ], rbxp[DIMY*DIMZ];
    double rbym[DIMX*DIMZ], rbyp[DIMX*DIMZ];
    double rbzm[DIMX*DIMY], rbzp[DIMX*DIMY];
};

/*
 * Pack/unpack routine for a chunk, registered with AMPI_Register().
 *
 * p : the pup handle passed in by the runtime.
 * d : address of the `chunk *` that was registered (i.e. a `chunk **`).
 *
 * When p is unpacking, a fresh chunk is allocated and stored through d
 * before any field is traversed.  When p is deleting, the chunk is freed
 * after all fields have been traversed.  The traversal order of the fields
 * is fixed and must stay the same for packing and unpacking.
 */
void chunk_pup(pup_er p, void *d)
{
  chunk **slot = (chunk **) d;
  if(pup_isUnpacking(p))
    *slot = new chunk;
  chunk *c = *slot;

  /* Scalar fields, in serialization order. */
  int *scalars[] = {
    &c->xidx, &c->yidx, &c->zidx,
    &c->xp, &c->xm,
    &c->yp, &c->ym,
    &c->zp, &c->zm
  };
  const int nscalars = (int)(sizeof(scalars) / sizeof(scalars[0]));
  for(int n = 0; n < nscalars; n++)
    pup_int(p, scalars[n]);

  /* The whole local grid, including the extra layer on every side. */
  pup_doubles(p, &c->t[0][0][0], (DIMX+2)*(DIMY+2)*(DIMZ+2));

  /* Face buffers, in serialization order: per axis, send pair then
   * receive pair. */
  struct { double *buf; int len; } faces[] = {
    { c->sbxm, (DIMY*DIMZ) }, { c->sbxp, (DIMY*DIMZ) },
    { c->rbxm, (DIMY*DIMZ) }, { c->rbxp, (DIMY*DIMZ) },
    { c->sbym, (DIMX*DIMZ) }, { c->sbyp, (DIMX*DIMZ) },
    { c->rbym, (DIMX*DIMZ) }, { c->rbyp, (DIMX*DIMZ) },
    { c->sbzm, (DIMX*DIMY) }, { c->sbzp, (DIMX*DIMY) },
    { c->rbzm, (DIMX*DIMY) }, { c->rbzp, (DIMX*DIMY) }
  };
  const int nfaces = (int)(sizeof(faces) / sizeof(faces[0]));
  for(int n = 0; n < nfaces; n++)
    pup_doubles(p, faces[n].buf, faces[n].len);

  if(pup_isDeleting(p))
    delete c;
}

/* Absolute value of a numeric expression, compared against 0.0.
 * The argument is expanded more than once, so it must not have side
 * effects.  NOTE(review): this shadows any library abs() for the rest of
 * the file. */
#define abs(x) ((x)<0.0 ? -(x) : (x))

/*
 * Map chunk coordinates (ix, iy, iz) to a linear chunk index, with iz
 * varying fastest and ix slowest.  Inverse of index3d().
 */
int index1d(int ix, int iy, int iz)
{
  int linear = ix;
  linear = linear*NY + iy;
  linear = linear*NZ + iz;
  return linear;
}

/*
 * Split a linear chunk index into chunk coordinates, written to ix, iy
 * and iz.  iz is the fastest-varying coordinate.  Inverse of index1d().
 */
void index3d(int index, int& ix, int& iy, int& iz)
{
  /* Peel the coordinates off starting from the fastest-varying axis. */
  const int zpart = index % NZ;
  const int rest  = index / NZ;
  const int ypart = rest % NY;
  const int xpart = rest / NY;

  ix = xpart;
  iy = ypart;
  iz = zpart;
}

/*
 * Copy the sub-box t[sx..ex][sy..ey][sz..ez] (all bounds inclusive) into
 * the flat buffer d, packed contiguously with the z index varying fastest.
 * d must have room for (ex-sx+1)*(ey-sy+1)*(ez-sz+1) values.
 */
static void copyout(double *d, double t[DIMX+2][DIMY+2][DIMZ+2],
                    int sx, int ex, int sy, int ey, int sz, int ez)
{
  double *dst = d;
  for(int x = sx; x <= ex; ++x) {
    for(int y = sy; y <= ey; ++y) {
      const double *row = t[x][y];
      for(int z = sz; z <= ez; ++z)
        *dst++ = row[z];
    }
  }
}

/*
 * Fill the sub-box t[sx..ex][sy..ey][sz..ez] (all bounds inclusive) from
 * the flat buffer d, which is read contiguously with the z index varying
 * fastest.  d must hold (ex-sx+1)*(ey-sy+1)*(ez-sz+1) values.
 */
static void copyin(double *d, double t[DIMX+2][DIMY+2][DIMZ+2],
                    int sx, int ex, int sy, int ey, int sz, int ez)
{
  const double *src = d;
  for(int x = sx; x <= ex; ++x) {
    for(int y = sy; y <= ey; ++y) {
      double *row = t[x][y];
      for(int z = sz; z <= ez; ++z)
        row[z] = *src++;
    }
  }
}

extern "C" void AMPI_Main(int ac, char** av)
{
  int i,j,k,m,cidx;
  int iter, niter;
  AMPI_Status status;
  double error, tval, maxerr, starttime, endtime, itertime;
  chunk *cp;
  int thisIndex, ierr, nblocks;

  AMPI_Init(&ac, &av);
  AMPI_Comm_rank(AMPI_COMM_WORLD, &thisIndex);
  AMPI_Comm_size(AMPI_COMM_WORLD, &nblocks);

  if(thisIndex == 0)
    niter = 140;

  AMPI_Bcast(&niter, 1, AMPI_INT, 0, AMPI_COMM_WORLD);

  cp = new chunk;
  AMPI_Register((void*)&cp, (AMPI_PupFn) chunk_pup);

  index3d(thisIndex, cp->xidx, cp->yidx, cp->zidx);
  cp->xp = index1d((cp->xidx+1)%NX,cp->yidx,cp->zidx);
  cp->xm = index1d((cp->xidx+NX-1)%NX,cp->yidx,cp->zidx);
  cp->yp = index1d(cp->xidx,(cp->yidx+1)%NY,cp->zidx);
  cp->ym = index1d(cp->xidx,(cp->yidx+NY-1)%NY,cp->zidx);
  cp->zp = index1d(cp->xidx,cp->yidx,(cp->zidx+1)%NZ);
  cp->zm = index1d(cp->xidx,cp->yidx,(cp->zidx+NZ-1)%NZ);
  for(i=1; i<=DIMZ; i++)
    for(j=1; j<=DIMY; j++)
      for(k=1; k<=DIMX; k++)
        cp->t[k][j][i] = DIMY*DIMX*(i-1) + DIMX*(j-2) + (k-1);

  AMPI_Barrier(AMPI_COMM_WORLD);
  starttime = AMPI_Wtime();

  maxerr = 0.0;
  for(iter=1; iter<=niter; iter++) {
    maxerr = 0.0;
    copyout(cp->sbxm, cp->t, 1, 1, 1, DIMY, 1, DIMZ);
    copyout(cp->sbxp, cp->t, DIMX, DIMX, 1, DIMY, 1, DIMZ);
    copyout(cp->sbym, cp->t, 1, DIMX, 1, 1, 1, DIMZ);
    copyout(cp->sbyp, cp->t, 1, DIMX, DIMY, DIMY, 1, DIMZ);
    copyout(cp->sbzm, cp->t, 1, DIMX, 1, DIMY, 1, 1);
    copyout(cp->sbzp, cp->t, 1, DIMX, 1, DIMY, DIMZ, DIMZ);

    AMPI_Send(cp->sbxm, DIMY*DIMZ, AMPI_DOUBLE, cp->xm, 0, AMPI_COMM_WORLD);
    AMPI_Send(cp->sbxp, DIMY*DIMZ, AMPI_DOUBLE, cp->xp, 1, AMPI_COMM_WORLD);
    AMPI_Send(cp->sbym, DIMX*DIMZ, AMPI_DOUBLE, cp->ym, 2, AMPI_COMM_WORLD);
    AMPI_Send(cp->sbyp, DIMX*DIMZ, AMPI_DOUBLE, cp->yp, 3, AMPI_COMM_WORLD);
    AMPI_Send(cp->sbzm, DIMX*DIMY, AMPI_DOUBLE, cp->zm, 4, AMPI_COMM_WORLD);
    AMPI_Send(cp->sbzp, DIMX*DIMY, AMPI_DOUBLE, cp->zp, 5, AMPI_COMM_WORLD);
    AMPI_Recv(cp->rbxm, DIMY*DIMZ, AMPI_DOUBLE, cp->xm, 1, AMPI_COMM_WORLD, &status);
    AMPI_Recv(cp->rbxp, DIMY*DIMZ, AMPI_DOUBLE, cp->xp, 0, AMPI_COMM_WORLD, &status);
    AMPI_Recv(cp->rbym, DIMX*DIMZ, AMPI_DOUBLE, cp->ym, 3, AMPI_COMM_WORLD, &status);
    AMPI_Recv(cp->rbyp, DIMX*DIMZ, AMPI_DOUBLE, cp->yp, 2, AMPI_COMM_WORLD, &status);
    AMPI_Recv(cp->rbzm, DIMX*DIMY, AMPI_DOUBLE, cp->zm, 5, AMPI_COMM_WORLD, &status);
    AMPI_Recv(cp->rbzp, DIMX*DIMY, AMPI_DOUBLE, cp->zp, 4, AMPI_COMM_WORLD, &status);

    copyin(cp->sbxm, cp->t, 0, 0, 1, DIMY, 1, DIMZ);
    copyin(cp->sbxp, cp->t, DIMX+1, DIMX+1, 1, DIMY, 1, DIMZ);
    copyin(cp->sbym, cp->t, 1, DIMX, 0, 0, 1, DIMZ);
    copyin(cp->sbyp, cp->t, 1, DIMX, DIMY+1, DIMY+1, 1, DIMZ);
    copyin(cp->sbzm, cp->t, 1, DIMX, 1, DIMY, 0, 0);
    copyin(cp->sbzp, cp->t, 1, DIMX, 1, DIMY, DIMZ+1, DIMZ+1);
    if(iter > 25 &&  iter < 85 && thisIndex == 35)
      m = 9;
    else
      m = 1;
    for(; m>0; m--)
      for(i=1; i<=DIMZ; i++)
        for(j=1; j<=DIMY; j++)
          for(k=1; k<=DIMX; k++) {
            tval = (cp->t[k][j][i] + cp->t[k][j][i+1] +
                 cp->t[k][j][i-1] + cp->t[k][j+1][i]+ 
                 cp->t[k][j-1][i] + cp->t[k+1][j][i] + cp->t[k-1][j][i])/7.0;
            error = abs(tval-cp->t[k][j][i]);
            cp->t[k][j][i] = tval;
            if (error > maxerr) maxerr = error;
          }
    AMPI_Allreduce(&maxerr, &maxerr, 1, AMPI_DOUBLE, AMPI_MAX, 
                   AMPI_COMM_WORLD);
    endtime = AMPI_Wtime();
    itertime = endtime - starttime;
    AMPI_Allreduce(&itertime, &itertime, 1, AMPI_DOUBLE, AMPI_SUM,
                   AMPI_COMM_WORLD);
    itertime = itertime/nblocks;
    if (thisIndex == 0)
      CkPrintf("iter %d time: %lf maxerr: %lf\n", iter, itertime, maxerr);
    starttime = AMPI_Wtime();
    if(iter%40 == 0) {
      char dname[10];
      sprintf(dname, "%d", iter);
      CkPrintf("[%d] checkpointing at iter %d\n", thisIndex, iter);
      AMPI_Checkpoint(dname);
      CkPrintf("[%d] resuming...\n", thisIndex);
    }
    if(iter%20 == 10) {
      AMPI_Migrate();
    }
  }
  AMPI_Finalize();
}
