/*
Load-balancing test program:
  Orion Sky Lawlor, 10/19/1999

  Added more complex comm patterns
  Robert Brunner, 11/3/1999

*/

#include <stdio.h>
#include <unistd.h>
#include <math.h>
#include "charm++.h"
#include "LBDatabase.h"
#include "Topo.h"
#include "CentralLB.h"
#include "RandCentLB.h"
#include "RecBisectBfLB.h"
#include "RefineLB.h"
#include "CommLB.h"
#ifdef USE_METIS
#include "MetisLB.h"
#endif
#include "HeapCentLB.h"
#include "NeighborLB.h"
#include "WSLB.h"
#include "GreedyRefLB.h"
#include "RandRefLB.h"
#include "manager.h"

#include "lb_test.decl.h"

// Program-wide state. Everything below except cycle_count is assigned in main::main.
CkChareID mid;//Handle of the main chare; Lb_array elements use it to call maindone()
CkGroupID topoid;//Result of Topo::Create(); elements fetch their local Topo branch through it
CkArrayID aid;//ID of the Lb_array element array
int n_loadbalance;//<lb-freq> argument: an element calls AtSync() when its step count is a multiple of this

// NOTE(review): nothing in this file references N_LOADBALANCE; the sync period
// actually used is the n_loadbalance variable above.
#define N_LOADBALANCE 500 /*Times around ring until we load balance*/

// element_count: <elements> argument, number of Lb_array elements created
// step_count:    <steps> argument; an element reports to main when it reaches this step
// print_count:   <print-freq> argument; element 0 prints timing every print_count steps
// cycle_count:   NOTE(review): neither assigned nor read anywhere in this file
int cycle_count,element_count,step_count,print_count;
// <min-dur us> and <max-dur us> arguments, handed to Topo::Create()
int min_us,max_us;

/*
 * HiMsg: variable-length message passed between Lb_array elements.
 *
 * The message is one contiguous buffer: the fixed-size fields come first and
 * the payload follows at the next 8-byte boundary.  `data` points at the
 * payload while the message is in local memory; pack() turns it into an
 * offset and unpack() turns it back into a pointer.
 */
class HiMsg : public CMessage_HiMsg {
public:
  int length;//Payload size in bytes, filled in by the sender
  int chksum;//Payload checksum; no active code in this file sets or checks it
  int refnum;//Sender's step number, compared by the receiver with its own step
  char* data;//Start of the payload inside the message buffer

  //Allocate a message with room for array[0] payload bytes behind the
  //fixed-size part (`size` bytes) and point `data` at the payload.
  static void* alloc(int msgnum, int size, int* array, int priobits) {
    //The payload starts at the 8-byte-aligned end of the fixed part, so the
    //alignment padding must be part of the allocation.  Sizing the buffer
    //from the unaligned `size` would leave the tail of the payload outside
    //the buffer whenever `size` is not a multiple of 8.
    const int payload_offset = ALIGN8(size);
    const int totalsize = payload_offset + array[0] * sizeof(char);
    HiMsg* msg = (HiMsg*) CkAllocMsg(msgnum,totalsize,priobits);
    msg->data = (char*) msg + payload_offset;
    return (void*) msg;
  }

  //Replace the payload pointer by its offset from the address of the
  //`data` field itself, so the value stays meaningful in another buffer.
  static HiMsg* pack(HiMsg* in) {
    in->data = (char*)(in->data - (char*)(&in->data));
    return in;
  }

  //Re-create the object in the received buffer and convert the stored
  //offset back into a pointer relative to the `data` field's new address.
  static HiMsg* unpack(void* in) {
    HiMsg* me = new(in) HiMsg;
    me->data = (char*)(&me->data) + (size_t)me->data;
    return me;
  }
};

//Main chare: starts the run from the command line and ends it once every
//array element has reported completion.
class main : public Chare {
public:
  int nDone;//Number of maindone() calls received so far

  main(CkMigrateMessage *m) {}
  main(CkArgMsg* m);

  //Count one completion report; when the count equals element_count,
  //announce the end of the run and call CkExit().
  void maindone(void) {
    CkPrintf("In main done\n");
    ++nDone;
    if (nDone != element_count)
      return;
    CkPrintf("All done\n");
    CkExit();
  }

private:
  //Print the usage text and terminate
  void arg_error(char* argv0);
};

//Routine that creates one load-balancing strategy
typedef void (*LbStrategyFactory)(void);

//One selectable strategy
struct LbStrategyEntry {
  const char *name;         //Name of strategy (on command line)
  const char *description;  //Text description of strategy (shown in the usage text)
  LbStrategyFactory create; //Strategy routine
};

//All strategies selectable by name, terminated by an all-NULL entry.
static const LbStrategyEntry StratTable[]={
  { "none",
    "none - The null load balancer, collect data, but do nothing",
    CreateCentralLB },
  { "neighbor",
    "neighbor - The neighborhood load balancer",
    CreateNeighborLB },
  { "workstation",
    "workstation - Like neighbor, but for workstation performance",
    CreateWSLB },
  { "random",
    "random - Assign objects to processors randomly",
    CreateRandCentLB },
  { "greedy",
    "greedy - Use the greedy algorithm to place heaviest object on the "
    "least-loaded processor until done",
    CreateHeapCentLB },
#ifdef USE_METIS
  { "metis",
    "metis - Use Metis(tm) to partition object graph",
    CreateMetisLB },
#endif
  { "refine",
    "refine - Move a very few objects away from most heavily-loaded processor",
    CreateRefineLB },
  { "greedyref",
    "greedyref - Apply greedy, then refine",
    CreateGreedyRefLB },
  { "randref",
    "randref - Apply random, then refine",
    CreateRandRefLB },
  { "comm",
    "comm - Greedy with communication",
    CreateCommLB },
  { "recbf",
    "recbf - Recursive partitioning with Breadth first enumeration, with 2 nuclei",
    CreateRecBisectBfLB },

  //End-of-table marker
  { NULL, NULL, NULL }
};

/*
 * Main chare constructor: parses the command line, creates the requested
 * load-balancing strategy, the communication topology and the Lb_array
 * elements, then starts every element.
 *
 * Expected arguments after the program name:
 *   <elements> <steps> <print-freq> <lb-freq> <min-dur us> <max-dur us>
 *   <strategy> <topology>
 * Extra arguments are ignored.  Missing or unusable arguments end the run
 * through arg_error(); an unknown strategy or topology name ends it through
 * CkAbort().
 */
main::main(CkArgMsg *m)
{
  enum { n_required_args = 8 };//Arguments that must follow the program name

  nDone=0;

  manager_init();

  //arg_error() terminates the program, so past this check all eight
  //arguments are known to be present.
  if (m->argc <= n_required_args)
    arg_error(m->argv[0]);

  int cur_arg = 1;
  element_count = atoi(m->argv[cur_arg++]);
  step_count    = atoi(m->argv[cur_arg++]);
  print_count   = atoi(m->argv[cur_arg++]);
  n_loadbalance = atoi(m->argv[cur_arg++]);
  min_us        = atoi(m->argv[cur_arg++]);
  max_us        = atoi(m->argv[cur_arg++]);
  char *strategy = m->argv[cur_arg++];//String name for strategy routine
  char *topology = m->argv[cur_arg++];//String name for communication topology

  //print_count and n_loadbalance are used as modulo divisors by the array
  //elements, and without elements nobody ever reports completion, so refuse
  //such values up front (atoi() also yields 0 for non-numeric text).
  if (element_count < 1 || print_count < 1 || n_loadbalance < 1) {
    CkPrintf("ERROR! <elements>, <print-freq> and <lb-freq> must be "
             "positive integers\n");
    arg_error(m->argv[0]);
  }

  //Look up the user's strategy in table
  int stratNo=0;
  while (StratTable[stratNo].name!=NULL) {
    if (0==strcasecmp(strategy,StratTable[stratNo].name)) {
      //We found the user's chosen strategy!
      StratTable[stratNo].create();//Create this strategy
      break;
    }
    stratNo++;
  }

  CkPrintf("%d processors\n",CkNumPes());
  CkPrintf("%d elements\n",element_count);
  CkPrintf("Print every %d steps\n",print_count);
  CkPrintf("Sync every %d steps\n",n_loadbalance);
  CkPrintf("First node busywaits %d usec; last node busywaits %d\n",
           min_us,max_us);

  mid = thishandle;

  //The lookup loop only stops on the terminating entry when no name matched
  if (StratTable[stratNo].name==NULL)
    //The user's strategy name wasn't in the table-- bad!
    CkAbort("ERROR! Strategy not found!  \n");

  topoid = Topo::Create(element_count,topology,min_us,max_us);
  if (topoid == -1)
    CkAbort("ERROR! Topology not found!  \n");

  aid = CProxy_Lb_array::ckNew(element_count);
  CProxy_Lb_array hproxy(aid);

  //Start everybody computing
  for (int i=0;i<element_count;i++)
    hproxy[i].ForwardMessages();
}

//Print the usage text, including one line per entry of StratTable and of
//TopoTable, then terminate the program with abort().
//  argv0: program name shown in the usage line
void main::arg_error(char* argv0)
{
  CkPrintf("Usage: %s \n"
    "<elements> <steps> <print-freq> <lb-freq> <min-dur us> <max-dur us>\n"
    "<strategy> <topology>\n"
    "<strategy> is the load-balancing strategy:\n",argv0);

  //Both tables end with an entry whose name is null
  for (int s=0; StratTable[s].name!=NULL; s++)
    CkPrintf("  %s\n",StratTable[s].description);

  CkPrintf("<topology> is the object connection topology:\n");
  for (int t=0; TopoTable[t].name; t++)
    CkPrintf("  %s\n",TopoTable[t].desc);

  CkPrintf("\n"
           " The program creates a ring of element_count array elements,\n"
           "which all compute and send to their neighbor cycle_count.\n"
           "Computation proceeds across the entire ring simultaniously.\n"
           "Orion Sky Lawlor, olawlor@acm.org, PPL, 10/14/1999\n");
  abort();
}

class Lb_array : public ArrayElement1D {
public:
  Lb_array(void) {
    //    CkPrintf("Element %d created\n",thisIndex);

    //Find out who to send to, and how many to receive
    TopoMap = CProxy_Topo::ckLocalBranch(topoid);
    send_count = TopoMap->SendCount(thisIndex);
    send_to = new Topo::MsgInfo[send_count];
    TopoMap->SendTo(thisIndex,send_to);
    recv_count = TopoMap->RecvCount(thisIndex)+1;
    
    // Benchmark the work function
    work_per_sec = CalibrateWork();

    //Create massive load imbalance by making load
    // linear in processor number.
    usec = TopoMap->Work(thisIndex);
//    CkPrintf("Element %d working for %d ms\n",thisIndex,usec);

    //msec=meanms+(devms-meanms)*thisIndex/(element_count-1);

    // Initialize some more variables
    nTimes=0;
    sendTime=0;
    lastTime=CmiWallTimer();
    n_received = 0;
    resumed = 1;
    busywork = (int)(usec*1e-6*work_per_sec);
    
    int i;
    for(i=0; i < future_bufsz; i++)
      future_receives[i]=0;
	
	usesAtSync=CmiTrue;
  }

  //Packing/migration utilities
  Lb_array(CkMigrateMessage *m) {
	   CkPrintf("Migrated element %d to processor %d\n",thisIndex,CkMyPe());
    TopoMap = CProxy_Topo::ckLocalBranch(topoid);
    //Find out who to send to, and how many to receive
    send_count = TopoMap->SendCount(thisIndex);
    send_to = new Topo::MsgInfo[send_count];
    TopoMap->SendTo(thisIndex,send_to);
    recv_count = TopoMap->RecvCount(thisIndex)+1;
    resumed = 0;
  }

  virtual void pup(PUP::er &p)
  {
	ArrayElement1D::pup(p);//<- pack our superclass
	p(nTimes);p(sendTime);
	p(usec);p(lastTime);
	p(work_per_sec);
	p(busywork);
	p(n_received);
	p(future_receives,future_bufsz);
  }

  void Compute(HiMsg *m) { 
    //Perform computation

    if (m->refnum > nTimes) {
      //      CkPrintf("[%d] Future message received %d %d\n",
      //      	       thisIndex,nTimes,m->refnum);
      int future_indx = m->refnum - nTimes - 1;
      if (future_indx >= future_bufsz) {
	CkPrintf("[%d] future_indx is too far in the future %d, expecting %d, got %d\n",
		 thisIndex,future_indx,nTimes,m->refnum);
	(CProxy_Lb_array(thisArrayID))[thisIndex].Compute(m);
      } else {
	future_receives[future_indx]++;
	delete m;
      }
    } else if (m->refnum < nTimes) {
      CkPrintf("[%d] Stale message received %d %d\n",
	       thisIndex,nTimes,m->refnum);
      delete m;
    } else {
      n_received++;

      //      CkPrintf("[%d] %d n_received=%d of %d\n",
      //      	       CkMyPe(),thisIndex,n_received,recv_count);
      if (n_received == recv_count) {
	//	CkPrintf("[%d] %d computing %d\n",CkMyPe(),thisIndex,nTimes);

	if (nTimes % print_count == 0) {
	  //Write out the current time
	  if (thisIndex==0) {
	    double now = CmiWallTimer();
	    CkPrintf("GREP00\t%d\t%lf\t%lf\n",
		     nTimes,now,now-lastTime);
	    lastTime=now;
	  }
	}

	n_received = future_receives[0];

	// Move all the future_receives down one slot
	int i;
	for(i=1;i<future_bufsz;i++)
	  future_receives[i-1] = future_receives[i];
	future_receives[future_bufsz-1] = 0;

	nTimes++;//Increment our "times around"	

	double startTime=CmiWallTimer();
	// First check contents of message
	//     int chksum = 0;
	//     for(int i=0; i < m->length; i++)
	//       chksum += m->data[i];
      
	//     if (chksum != m->chksum)
	//       CkPrintf("Checksum mismatch! %d %d\n",chksum,m->chksum);

	//Do Computation:
	work(busywork,&result);
	
	int loadbalancing = 0;
	if (nTimes == step_count) {
	  //We're done-- send a message to main telling it to die
	  CProxy_main mproxy(mid);
	  mproxy.maindone();
	} else if (nTimes % n_loadbalance == 0) {
	  //We're not done yet...
	  //Either load balance, or send a message to the next guy
	  CkPrintf("Element %d AtSync on PE %d\n",thisIndex,CkMyPe());
	  AtSync();
	  loadbalancing = 1;
	} else ForwardMessages();
      }
      delete m;
    }
  }

  void ResumeFromSync(void) { //Called by Load-balancing framework
    resumed = 1;
	CkPrintf("Element %d resumeFromSync on PE %d\n",thisIndex,CkMyPe());
    CProxy_Lb_array hproxy(aid);
    hproxy[thisIndex].ForwardMessages();
  }

  void ForwardMessages(void) { //Pass it on
    if (resumed != 1)
      CkPrintf("[%d] %d forwarding %d %d %d\n",CkMyPe(),thisIndex,
	       sendTime,nTimes,resumed);
    for(int s=0; s < send_count; s++) {
      int msgbytes = send_to[s].bytes;
      if (msgbytes != 1000)
	CkPrintf("[%d] %d forwarding %d bytes (%d,%d,%p) obj %p to %d\n",
		 CkMyPe(),thisIndex,msgbytes,s,send_count,send_to,
		 this,send_to[s].obj);
      HiMsg* msg = new(&msgbytes,1) HiMsg;
      msg->length = msgbytes;
      //      msg->chksum = 0;
      //      for(int i=0; i < msgbytes; i++) {
      //	msg->data[i] = i;
      //	msg->chksum += msg->data[i];
      //      }
      msg->refnum = sendTime;

      //      CkPrintf("[%d] %d sending to %d at %d:%d\n",
      //	       CkMyPe(),thisIndex,send_to[s].obj,nTimes,nCycles);
      CProxy_Lb_array hproxy(aid);
      hproxy[send_to[s].obj].Compute(msg);
    }
    int mybytes=1;
    HiMsg* msg = new(&mybytes,1) HiMsg;
    msg->length = mybytes;
    msg->refnum = sendTime;
    (CProxy_Lb_array(aid))[thisIndex].Compute(msg);

    sendTime++;
  }

private:
  double CalibrateWork() {
    static double calibrated=-1.;

    if (calibrated != -1) return calibrated;

    double wps = 0;
    // First, count how many iterations for 1 second.
    // Since we are doing lots of function calls, this will be rough
    const double end_time = CmiWallTimer()+1;
    wps = 0;
    while(CmiWallTimer() < end_time) {
      work(100,&result);
      wps+=100;
    }

    // Now we have a rough idea of how many iterations there are per
    // second, so just perform a few cycles of correction by
    // running for what we think is 1 second.  Then correct
    // the number of iterations per second to make it closer
    // to the correct value

    for(int i=0; i < 2; i++) {
      const double start_time = CmiWallTimer();
      work(wps,&result);
      const double end_time = CmiWallTimer();
      const double correction = 1. / (end_time-start_time);
      wps *= correction;
    }

    // If necessary, do a check now
    //    const double start_time3 = CmiWallTimer();
    //    work(msec * 1e-3 * wps);
    //    const double end_time3 = CmiWallTimer();
    //    CkPrintf("[%d] Work block size is %d %d %f\n",
    //	     thisIndex,wps,msec,1.e3*(end_time3-start_time3));
    calibrated = wps;
    return wps;
  };

  void work(int iter_block,int* _result) {
    *_result=0;
    for(int i=0; i < iter_block; i++) {
      *_result=(int)(sqrt(1+cos(*_result*1.57)));
    }
  };

public:
  enum { future_bufsz = 50 };

private:
  int nTimes;//Number of times I've been called
  int sendTime;//Step number for sending (in case I finish receiving
               //before sending
  int usec;//Milliseconds to "compute"
  double lastTime;//Last time recorded
  int work_per_sec;
  int busywork;
  int result;

  Topo* TopoMap;
  int send_count;
  int recv_count;
  Topo::MsgInfo* send_to;
  int n_received;
  int future_receives[future_bufsz];
  int resumed;
};

#include "lb_test.def.h"

