#include "TCsim_sim.h"
extern int numWth;
extern int netLength;
extern int netHeight;
extern int netWidth;

// Switch implementation
// Default constructor: intentionally empty; no member is explicitly
// initialized here.
Switch::Switch()
{
}

// Linear switch id of torus coordinate (x, y, z) for a network whose ids
// start at idstart and are laid out x-major, then y, then z, with extent
// `height` along y and `width` along z.
static int torusSwitchId(int idstart, int height, int width,
                         int x, int y, int z)
{
  return idstart + x*height*width + y*width + z;
}

// The non-empty constructor (the empty one does nothing).
// Records the switch id and the local routeMap branch, then builds the six
// torus channels in the order +x, -x, +y, -y, +z, -z, each with its own
// freshly labelled Chanlock.  A hop that wraps around an edge of the torus
// costs m->wrapcost; every other hop costs 1.  Consumes (deletes) m.
//
// Neighbour ids are computed arithmetically rather than through a
// heap-allocated netLength x netHeight x netWidth lookup table, so
// construction allocates (and can leak) nothing.
Switch::Switch (SwitchInit *m)
{
  id = m->id;
  curLabel = Label(0, id.id);
  CProxy_routeMap nodemap_proxy(m->groupID);
  groupID=m->groupID;
  nodemap=nodemap_proxy.ckLocalBranch();

  const int sizeX = m->netLength;   // extent along x
  const int sizeY = m->netHeight;   // extent along y
  const int sizeZ = m->netWidth;    // extent along z
  const int px = m->position.x;
  const int py = m->position.y;
  const int pz = m->position.z;

  // one lock per channel, each with its own label
  Chanlock locks[6];
  for (int i=0;i<6;i++) {
    Label lockl=GetLabel();
    locks[i]=Chanlock(lockl);
  }

  // +x: the last x plane wraps to x=0
  if (px == sizeX-1)
    channels[0]=Channel(id,Label(torusSwitchId(m->idstart,sizeY,sizeZ,0,py,pz),0),m->bandwidth,locks[0],m->wrapcost);
  else
    channels[0]=Channel(id,Label(torusSwitchId(m->idstart,sizeY,sizeZ,px+1,py,pz),0),m->bandwidth,locks[0],1);
  // -x: x=0 wraps to the last x plane
  if (px == 0)
    channels[1]=Channel(id,Label(torusSwitchId(m->idstart,sizeY,sizeZ,sizeX-1,py,pz),0),m->bandwidth,locks[1],m->wrapcost);
  else
    channels[1]=Channel(id,Label(torusSwitchId(m->idstart,sizeY,sizeZ,px-1,py,pz),0),m->bandwidth,locks[1],1);
  // +y: the last y row wraps to y=0
  if (py == sizeY-1)
    channels[2]=Channel(id,Label(torusSwitchId(m->idstart,sizeY,sizeZ,px,0,pz),0),m->bandwidth,locks[2],m->wrapcost);
  else
    channels[2]=Channel(id,Label(torusSwitchId(m->idstart,sizeY,sizeZ,px,py+1,pz),0),m->bandwidth,locks[2],1);
  // -y: y=0 wraps to the last y row
  if (py == 0)
    channels[3]=Channel(id,Label(torusSwitchId(m->idstart,sizeY,sizeZ,px,sizeY-1,pz),0),m->bandwidth,locks[3],m->wrapcost);
  else
    channels[3]=Channel(id,Label(torusSwitchId(m->idstart,sizeY,sizeZ,px,py-1,pz),0),m->bandwidth,locks[3],1);
  // +z: the last z column wraps to z=0
  if (pz == sizeZ-1)
    channels[4]=Channel(id,Label(torusSwitchId(m->idstart,sizeY,sizeZ,px,py,0),0),m->bandwidth,locks[4],m->wrapcost);
  else
    channels[4]=Channel(id,Label(torusSwitchId(m->idstart,sizeY,sizeZ,px,py,pz+1),0),m->bandwidth,locks[4],1);
  // -z: z=0 wraps to the last z column
  if (pz == 0)
    channels[5]=Channel(id,Label(torusSwitchId(m->idstart,sizeY,sizeZ,px,py,sizeZ-1),0),m->bandwidth,locks[5],m->wrapcost);
  else
    channels[5]=Channel(id,Label(torusSwitchId(m->idstart,sizeY,sizeZ,px,py,pz-1),0),m->bandwidth,locks[5],1);

  // the queues' default (empty) state is already what we want
  lastChannel = 0;
  delete (m);
}

// event handling

// Entry point for messages into the network.
// Stamps the message with a fresh label (on both the envelope and the inner
// DataMesg), appends it to outMesgQ and starts a connection attempt for it;
// once the chanlock succeeds, transmission is started from the Reply handler.
// A message addressed to this very switch is a usage error and aborts.
void Switch::initiate(DataMesgEnv *m)
{
  if (m->bag.dest == id) {
    CkPrintf("ERROR: Don't place local messages on the network!\n");
    CkAbort("Kerplooey\n");
  }

  Label fresh = GetLabel();
  m->bag.mid = fresh;
  m->bag.mesg.id = fresh;   // inner message shares the envelope's label
  outMesgQ.push_back(m->bag);
  netconnect(fresh);
}


// The method by which other nodes (and this switch itself, see netconnect)
// send us channel commands.  Only chanlock is handled; any other command is
// reported through CommitError and aborts.  Does not consume bandwidth.
//
// chanlock handling:
//   * bag.dest == id: this switch is the final destination, so an expected
//     DataMesg (mid, packsize, numDataPackets, td) is queued in inMesgQ and
//     the lock succeeds.
//   * otherwise: fail if packetroute already holds a route for this mid (the
//     command has looped back to us); else pick a channel with
//     findTheChannel(dest, sender) and try to lock it.  On success the route
//     (channel, sender) is recorded and a copy of the command, with
//     sender = id, is passed to that channel's far end after the channel's
//     cost.
// A reply is generated here when the lock failed or when this switch is the
// destination; a successful destination reply is also kept in replyQ so the
// final "done" reply can be built from it later.
void Switch::recv(Command *m)
{
  cmdOperation thisop = m->bag.command.cmdop;
  bool success = false;
  if ((m->bag.dest == id)&& (thisop == chanlock)) {
    // we are the destination: remember what to expect and accept the lock
    DataMesg expect(m->bag.command.mid, m->bag.packsize, m->bag.numDataPackets, m->bag.td);
    inMesgQ.push_back (expect);
    success=true;
  }
  else if (thisop == chanlock) {
    // only read below when success is true, which requires the assignment
    // from findTheChannel()
    int dchannel;
    map < Label, OutnBack >::iterator miter=packetroute.find(m->bag.command.mid);
    // already has a route starting from this switch, have a loop back
    if((miter!=packetroute.end()) && (miter->first ==m->bag.command.mid)) {
      success=false;
    }
    else {
      // 6 means "no usable channel"; otherwise try to grab its lock
      dchannel = findTheChannel(m->bag.dest,m->bag.sender);
      if(dchannel<6) {
	success = channels[dchannel].chanlock.testAndSet (m->bag.command.mid);
      }
    }
    if(success) {
      // add route to packetroute
      //	  pair < Label, OutnBack > melem;
      //	  melem.first=m->bag.command.mid;
      //	  OutnBack back(dchannel,m->bag.sender.idAsInt());
      // melem.second=back;
      packetroute.insert(map < Label, OutnBack>::value_type(Label(m->bag.command.mid),OutnBack(dchannel,m->bag.sender)));
      //CkPrintf("Switch %s, added route to %d for mid %s at time %d\n", id.sdump(),channels[dchannel].dest.idAsInt(),m->bag.command.mid.sdump(),m->timestamp);
      // pass the command one hop further, naming ourselves as the sender
      Command *passcommand =new Command (m->bag);
      passcommand->bag.sender=id;
      POSE_invoke (recv (passcommand), Switch, channels[dchannel].dest.idAsInt(), channels[dchannel].cost);
    }
  }
  else {
    parent->CommitError("ERROR Switch %s unchanlock or unknown command id %s %d\n",id.sdump(),m->bag.id.sdump(),m->timestamp);
    CmiAbort("");
  }
  if(!(success) || (m->bag.dest == id)) {
    // we need to generate our own reply here and now rather than pass along 
    //a reply later
    // it may be possible that this reply is sent to itself
    //               reply,src, dest,        intermed, from,   mid,     cid,done
    ReplyBag newbag(success,id,m->bag.src,m->bag.sender,id,m->bag.command.mid,m->bag.id,false);
    Reply *myreply = new Reply (newbag);
    if((success) && (m->bag.dest == id))
      replyQ.push_back(newbag); // we'll want this later to send the done reply
    POSE_invoke (recv (myreply), Switch, m->bag.sender.idAsInt(), getcost(m->bag.sender));
  }
}


// The method by which other nodes reply to commands or report error
// conditions.  Does not consume bandwidth.
//
// Replies addressed to this switch (bag.dest == id):
//   * done: the message reached its destination.  If a route for it is
//     recorded here, and the reply came over that route's channel, the
//     channel lock and route are released and channelAvailable() retries a
//     buffered connect request.  The message is then dropped from outMesgQ.
//   * reply to a command in cmdMesgQ (non-null cid):
//       - success of a chanlock: the route is established, so sendMesg();
//       - failure of a chanlock: release any local route and netconnect()
//         again.
//     The command is removed from cmdMesgQ afterwards (except on the early
//     return taken when the reply arrives from an unexpected neighbour).
// Replies addressed elsewhere are forwarded to the switch recorded as the
// "back" hop in packetroute; failure and done replies release this hop's
// channel lock and route before forwarding.
void Switch::recv(Reply *m)
{
/*
  ???? gzheng
  map < Label, OutnBack >::const_iterator miter=packetroute.find(m->bag.mid);
  if ((miter!=packetroute.end()) && (miter->first==m->bag.mid)) {
    // see if reply come from the packroute channel expected
    if (m->bag.from.id != channels[miter->second.first].dest.id)  {
      parent->CommitError("ERROR: [%d] Reply not valid from:%s channel dest:%s\n", id.id, m->bag.from.sdump(), channels[miter->second.first].dest.sdump());
      return;
    }
  }
  else {	// not found a packet route
    if (m->bag.intermed != m->bag.from)  {
      parent->CommitError("ERROR: [%d]  no packetroute for msg [%s].\n", id.id, m->bag.mid.sdump());
      return;
    }
  }
*/

  if(m->bag.dest == id) { // this reply is for us
    //CkPrintf("Switch %s got reply %d for command %s mid %s at time %d\n", id.sdump(), m->bag.reply,m->bag.cid.sdump(), m->bag.mid.sdump(),m->timestamp);

    if(m->bag.done) { // we're done with the message
      map < Label, OutnBack >::const_iterator miter=packetroute.find(m->bag.mid);
      if ((miter!=packetroute.end()) && (miter->first==m->bag.mid)) {
        // the done reply must come back over the channel this route uses
        if (m->bag.from.id != channels[miter->second.first].dest.id)  {
          parent->CommitError("ERROR: [%d] Reply not valid from:%s channel dest:%s\n", id.id, m->bag.from.sdump(), channels[miter->second.first].dest.sdump());
          return;
        }
        channels[miter->second.first].chanlock.unchanlock(m->bag.mid);
        packetroute.erase(m->bag.mid);
        channelAvailable();   // a channel was freed: retry a buffered request
      }

      DataMesgEnvBag mymesg(m->bag.mid);
      deque <DataMesgEnvBag>::iterator msgiter=find(outMesgQ.begin(),
                                                    outMesgQ.end(), mymesg);
      if ((msgiter!=outMesgQ.end())&&(mymesg.mid==msgiter->mid))
        outMesgQ.erase(msgiter);
    }
    else if (!m->bag.cid.isnull()) { // is a reply to a command
      CommandBag mycommand(m->bag.cid);
      deque <CommandBag>::iterator cmditer=find(cmdMesgQ.begin(),
                                                cmdMesgQ.end(), mycommand);
      if (cmditer==cmdMesgQ.end()) {
        // cmditer is end() here and must not be dereferenced, so the
        // command type cannot be reported.
        parent->CommitError("ERROR: [%d] No such command [%s] for msg:%d in cmdMesgQ recovery:%d.\n", id.id, m->bag.cid.sdump(), m->bag.mid.id, CpvAccess(stateRecovery));
        // probably ripped out during loop avoidance
/*   ???????? gzheng
        map < Label, OutnBack >::const_iterator miter=packetroute.find(m->bag.mid);
        if ((miter!=packetroute.end()) && (miter->first==m->bag.mid)) {
          channels[miter->second.first].chanlock.unchanlock(m->bag.mid);
          packetroute.erase(m->bag.mid);
        }
*/
      }
      else if (m->bag.reply) {
        if (cmditer->command.cmdop == chanlock) { // send the message
          map < Label, OutnBack >::const_iterator miter=packetroute.find(m->bag.mid);
          // Validate the sender only when a route is recorded; without one
          // sendMesg() below reports the missing route, instead of this code
          // dereferencing packetroute.end().
          if ((miter!=packetroute.end()) &&
              (m->bag.from.id != channels[miter->second.first].dest.id))  {
            parent->CommitError("ERROR: [%d] Reply not valid from:%s channel dest:%s\n", id.id, m->bag.from.sdump(), channels[miter->second.first].dest.sdump());
            return;
          }
          if (!sendMesg(m->bag.mid))
            parent->CommitError("ERROR sendMesg failed cmd%s mesg%s time%d\n",
                                m->bag.cid.sdump(), m->bag.mid.sdump(),
                                m->timestamp);
        }
        else {
          // print the command code as an integer, as the failure branch does
          parent->CommitError("ERROR: Our %s command [%s] was %d not chanlock\n", id.sdump(), m->bag.cid.sdump(), (int)cmditer->command.cmdop);
        }
        cmdMesgQ.erase(cmditer);  // we're done with this command
      }
      else { // our command failed try again
        if (cmditer->command.cmdop==chanlock) {
          map < Label, OutnBack >::const_iterator miter=packetroute.find(m->bag.mid);
          if ((miter!=packetroute.end()) && (miter->first==m->bag.mid)) {
            channels[miter->second.first].chanlock.unchanlock(m->bag.mid);
            packetroute.erase(m->bag.mid);
          }
          cmdMesgQ.erase(cmditer);  //we're done with this command
          netconnect(m->bag.mid);
          //channelAvailable();		// ??
        }
        else {
          // if chanrecvdata fails it means the netconnection broke and we 
          // have to rebuild the netconnection and resend failed packets
          // first cut implementation assumes netconnections can't break
          parent->CommitError("ERROR: [%s] unknown command type command [%s] was %d not chanlock\n", id.sdump(), m->bag.cid.sdump(), (int)cmditer->command.cmdop);
          cmdMesgQ.erase(cmditer);  // we're done with this command
        }
      }
    }
    else { // a chanrecvdata they can't fail so do nothing
      parent->CommitError("ERROR: [%s] Reply wasn't to a command at all\n",
                          id.sdump());
    }
  }
  else {  	// not for us, forwarding
    map < Label, OutnBack >::const_iterator miter=packetroute.find(m->bag.mid);
    if ((miter!=packetroute.end()) && (miter->first==m->bag.mid)) {
      Label back=miter->second.second;   // copied before any erase below
      if ((! m->bag.reply)||(m->bag.done)) { // failure or done: unchanlock before forwarding
        channels[miter->second.first].chanlock.unchanlock(m->bag.mid);
        packetroute.erase(m->bag.mid);
      }
      //              reply,   src, dest,       intermed, from, mid,    cid,done
      ReplyBag newbag(m->bag.reply, m->bag.src, m->bag.dest, back, id, m->bag.mid,
                      m->bag.cid,m->bag.done);
      Reply *myreply = new Reply (newbag);
      POSE_invoke(recv(myreply), Switch, back.idAsInt(), getcost(Label(back)));
      // CmiPrintf("[forward reply: status:%d src:%s dest:%s intermed:%s cid:%s]\n", m->bag.reply, m->bag.src.sdump(), m->bag.dest.sdump(), back.sdump(), m->bag.cid.sdump());
    }
    else {
      parent->CommitError("ERROR: [%d]  no packetroute for msg [%s].\n", id.id, m->bag.mid.sdump());
    }
  }
}

// The method by which other nodes communicate with us for data packets.
//*  consumes bandwidth (packet spacing is computed by the sender, sendMesg)
//
// If packetroute holds a route for the packet's message, the packet is
// forwarded over that route's channel after the channel's cost.  Otherwise
// the packet is counted against the matching inMesgQ entry.  When the count
// reaches the expected number of packets:
//   * the ReplyBag stored in replyQ by the chanlock handler is looked up; if
//     found, a "done" reply is sent to its intermed hop and the local BGnode
//     (PID switchPIDToNodePID(parent->thisIndex)) is sent recvIncomingMsg;
//   * the inMesgQ entry is dropped (whether or not the reply was found).
void Switch::recv (Packet *m)
{ // try to pass it along; if the chanlock succeeded then the route was 
  // established so just get the route
  Label key=m->bag.mid;
  map < Label, OutnBack >::const_iterator miter=packetroute.find(key);
  if ((miter!=packetroute.end()) &&(miter->first==key)) {
    // intermediate hop: forward unchanged along the locked route
    Packet *passpack = new Packet(m->bag);
    POSE_invoke(recv(passpack), Switch, channels[miter->second.first].dest.idAsInt(), channels[miter->second.first].cost);
  }
  else { //   add to message
    //   issue print of message receipt
    //   if message complete issue the done
    DataMesg mymesg(m->bag.mid);
    deque <DataMesg>::iterator mesgiter=find(inMesgQ.begin(),inMesgQ.end(),mymesg);
    if (mesgiter==inMesgQ.end()) {
      parent->CommitError("ERROR: Switch %s CANNOT find mesg id %s to append packet %d\n",id.sdump(),m->bag.mid.sdump(),m->bag.sequence);
      // the sky it is a falling upon us
      //   printf error
      //   next implementation we'll reply with failure 
    }
    else { //insert the packet
      //CkPrintf("Switch %s got mid %s packet %d.\n",id.sdump(),m->bag.mid.sdump(),m->bag.sequence);
      // only the count is tracked; packet contents are not stored
      mesgiter->rcvdpacks++;
      if (mesgiter->rcvdpacks>=mesgiter->numpackets) {
	// pull up the stored reply to the lock
	// it will have the correct enveloping info
	ReplyBag myreply(true,m->bag.mid);
	deque <ReplyBag>::iterator repliter=find(replyQ.begin(),replyQ.end(),myreply);
	if (repliter==replyQ.end()) {
	  parent->CommitError("ERROR: cannot find reply for %s in replyQ\n",m->bag.mid.sdump());
	  //		  while(repliter !=replyQ.end())
	  //{repliter->dump();repliter++;}
	}
	else {
	  // same envelope as the stored lock reply, but from = id and done = true
	  ReplyBag newbag(true,repliter->src,repliter->dest,repliter->intermed,id,m->bag.mid,repliter->cid,true);
	  Reply *evreply=new Reply(newbag);
	  POSE_invoke (recv (evreply), Switch, repliter->intermed.idAsInt(), getcost(repliter->intermed));
	  replyQ.erase(repliter);
	  
	  // SEND TO BGPROC HERE
	  // mesgIter contains that TaskData we need to talk to the
	  // appropriate BGprocs
	  // Build TaskMsg
	  int myNodePID = switchPIDToNodePID(parent->thisIndex);
	  int myNode = switchPIDToNode(parent->thisIndex);
	  TaskMsg *tm = new TaskMsg(mesgiter->td.srcNode, mesgiter->td.msgID,
		     mesgiter->td.index, mesgiter->td.recvTime,
		     mesgiter->td.msgSize, myNode,
 		     mesgiter->td.destNode, mesgiter->td.destTID);
	  // Now we need to figure out where this thing goes
//	  int myNode = parent->thisIndex - (netLength*netHeight*netWidth)*numWth;
          POSE_invoke(recvIncomingMsg(tm), BGnode, myNodePID, 0);
	  // CkPrintf("[NetSim: message received at final destination switch.\n...myNodePID=%d destNode=%d destTID=%d numWth=%d]\n", myNodePID, mesgiter->td.destNode, mesgiter->td.destTID, numWth);
        }
	// message complete: stop tracking it
	inMesgQ.erase(mesgiter);
      }
    }
  }
}


// normal methods
// Send outbound message `mid` as a train of Packet events over the channel
// recorded for it in packetroute.
// PRE:  the route for `mid` has been chanlocked, i.e. packetroute holds it.
// POST: one Packet event per packet has been scheduled to the next-hop
//       switch; the last one carries last=true.
// Returns false, after reporting through CommitError, when `mid` is not in
// outMesgQ or has no recorded route; true otherwise.
//
// Packet timing: when a packet fits in the channel bandwidth,
// packperunittime = bandwidth / maxpacksize packets leave per time unit.
// When it does not, packperunittime is 0 and packets are spaced
// timeperpacket = maxpacksize / bandwidth time units apart.  That integer
// quotient is 1 when bandwidth < maxpacksize < 2*bandwidth, so the spacing
// branch is selected by packperunittime == 0 too; seq/packperunittime
// would otherwise divide by zero.
bool Switch::sendMesg(Label  mid)
{
  DataMesgEnvBag mymesg(mid);
  deque <DataMesgEnvBag>::iterator mesgiter=find(outMesgQ.begin(),outMesgQ.end(),mymesg);
  if (mesgiter == outMesgQ.end()) {
    parent->CommitError("ERROR: Switch[%s]:sendMesg has no mesg matching %s in outMesgQ \n", id.sdump(), mid.sdump());
    return false;
  }
  map < Label, OutnBack >::const_iterator miter=packetroute.find(mid);
  if ((miter == packetroute.end()) || (!(miter->first == mid))) {
    parent->CommitError("ERROR: Switch[%s]:sendMesg has no route to send %s\n", id.sdump(), mid.sdump());
    return false;
  }

  //    CkPrintf("Switch[%s]:sendMesg sending %s\n", id.sdump(), mid.sdump());
  int dchannel=miter->second.first;
  int dest=channels[dchannel].dest.idAsInt();
  unsigned long packperunittime=channels[dchannel].bandwidth/mesgiter->mesg.maxpacksize;
  unsigned long timeperpacket=1;
  if (mesgiter->mesg.maxpacksize>channels[dchannel].bandwidth)
    // packets larger than the channel bandwidth take several time units each
    timeperpacket = mesgiter->mesg.maxpacksize/channels[dchannel].bandwidth;

  // invent the packets
  bool last=false;
  for (unsigned long seq=0; seq<mesgiter->mesg.numpackets; seq++) {
    Label npid=GetLabel();
    if (seq+1 == mesgiter->mesg.numpackets)
      last=true;
    PacketBag newbag(npid,mid,seq,last);
    Packet *npacket= new Packet(newbag);
    if ((timeperpacket>1) || (packperunittime==0))
      POSE_invoke (recv (npacket), Switch, dest, channels[dchannel].cost+ (int) (timeperpacket*seq) );
    else
      POSE_invoke (recv (npacket), Switch, dest, channels[dchannel].cost+ (int) (seq/packperunittime) );
  }
  return true;
}

void Switch::channelAvailable()
{
  if (connectReqQ2.empty()) return;
  const ConnectRequest &req = connectReqQ2.top();
  Label mid = req.mid;
  connectReqQ2.pop();
  netconnect(mid);
}

// Try to open a route for the outbound message labelled `mid`.
// Does nothing if the message is not in outMesgQ.  If findTheChannel()
// finds no usable channel, the request is parked in connectReqQ2 (with the
// message's ltimestamp) for channelAvailable() to retry.  Otherwise a
// chanlock command is recorded in cmdMesgQ and delivered to this switch
// itself, which starts the hop-by-hop locking.
void Switch::netconnect(Label  mid)
{
  DataMesgEnvBag key(mid);
  deque <DataMesgEnvBag>::iterator it = find(outMesgQ.begin(), outMesgQ.end(), key);
  if ((it == outMesgQ.end()) || !(key.mid == it->mid))
    return;   // not (or no longer) queued for sending

  int chan = findTheChannel(it->dest, id);
  if (chan >= 6) {
    // every channel is busy or unsuitable: buffer the request
    connectReqQ2.push(ConnectRequest(mid, it->ltimestamp));
    return;
  }

  cmdOp lockOp(chanlock, it->mid);
  Label cmdId = GetLabel();
  //             id, src, dest, sender, command, numDataPackets, packsize, td
  CommandBag cmd(cmdId, id, it->dest, id, lockOp, it->mesg.numpackets, it->mesg.maxpacksize, it->mesg.td);
  Command *selfcmd = new Command(cmd);
  cmdMesgQ.push_back(cmd);
  POSE_local_invoke(recv(selfcmd), 1);
}

// Choose the outgoing channel (0-5) for the next hop toward `dest`.
//   * A channel whose lock is held, or whose far end is `back` (no reversing
//     course), is never chosen.
//   * Otherwise a channel whose far end is `dest` itself is returned at once.
//   * Remaining channels with a non-null endpoint (and a non-null dest) are
//     weighted by nodemap->netdistance(id, endpoint, dest); all others are
//     excluded.
//   * Among the channels tied for the minimum distance, one is picked
//     round-robin using lastChannel.
// Returns 6 when no channel is usable.
int Switch::findTheChannel(Label dest,Label back)
{
  Chandoublepair weights[6];
  int avail=0;   // channels that received a real netdistance weight
  for (int i=0;i<6;i++) {
    if(channels[i].chanlock.test()) { // this channel is locked: route around it
      weights[i].first=i;
      weights[i].second=INT_MAX;
    }
    else if (channels[i].dest == back) { // you may not reverse course
      weights[i].first=i;
      weights[i].second=INT_MAX;
    }
    else if (channels[i].dest==dest) { // the channel leads to dest: take it
      return i;
    }
    else if((!(channels[i].dest.isnull()))&& (!(dest.isnull()))) {
      weights[i].first=i;
      weights[i].second=nodemap->netdistance(id,channels[i].dest,dest);
      avail ++;
    }
    else {
      weights[i].first=i;
      weights[i].second=INT_MAX;
    }
  }
  if (avail == 0) return 6;

  // Order by distance; the channels tied for the minimum form a prefix.
  sort(&weights[0],&weights[6],PairSecondComp());
  double minnetdistance=weights[0].second;
  if(minnetdistance==INT_MAX)
    return 6;

  typedef pair<Chandoublepair*,Chandoublepair*> dppair;
  dppair matches=equal_range(&weights[0],&weights[6],weights[0],PairSecondComp());
  // After sorting with the same comparator, matches.first is &weights[0],
  // so indices [0, range) are exactly the minimum-distance channels.
  // Selection is round-robin (no rand() call: its result was never used).
  int range = matches.second-matches.first;
  int pick = lastChannel%range;
  lastChannel++;  if (lastChannel==6) lastChannel=0;
  CmiAssert(pick>=0 && pick<6);
  return(weights[pick].first);
}

// Wrap a task-level message description into a heap-allocated DataMesgEnv
// suitable for Switch::initiate().
//   sN, mI, id, dN, dT, rT, mS - forwarded verbatim to TaskData
//   mxPktSz                    - maximum packet size; must be > 0
//   ts                         - first DataMesgEnvBag argument (presumably
//                                the message timestamp; confirm in header)
//   inSrc / realDest           - wrapped as Label(inSrc, 0) / Label(realDest, 0)
// The message is split into ceil(mS / mxPktSz) packets, but never fewer than
// one.  A zero-packet message would never reach the destination's
// completion check (it runs on packet arrival), so no "done" reply would
// ever release its route and it would never be delivered.
// The message label is a default-constructed Label; initiate() assigns the
// real one.
DataMesgEnv *ConvertAndInitiate(int sN, int mI, int id, int dN, int dT, int rT,
				int mS, int mxPktSz, int ts, int inSrc, int realDest)
{
  TaskData td(sN, mI, id, dN, dT, rT, mS);
  int pktnum = mS / mxPktSz;
  if(mS % mxPktSz)
    pktnum += 1;
  if (pktnum < 1)   // empty message: still send one packet so it completes
    pktnum = 1;

  Label src(inSrc, 0);
  Label dest(realDest, 0);
//  Label nmid=GetLabel();
  Label nmid= Label();
  DataMesg dm(nmid, mxPktSz, pktnum, td);
  DataMesgEnvBag dmeb(ts, src, dest, dm);
  DataMesgEnv *dme = new DataMesgEnv(dmeb);
  return dme;
}

// POSE anti-method paired with recv(Packet *); its only action is
// restore(this).  Presumably invoked by POSE when the matching event is
// rolled back — TODO confirm against the POSE manual.
void Switch::recv_anti(Packet *m)
{
  restore(this);
}
                                                                                
// POSE anti-method paired with recv(Command *); its only action is
// restore(this).  Presumably invoked by POSE when the matching event is
// rolled back — TODO confirm against the POSE manual.
void Switch::recv_anti(Command *m)
{
  restore(this);
}
                                                                                
// POSE anti-method paired with recv(Reply *); its only action is
// restore(this).  Presumably invoked by POSE when the matching event is
// rolled back — TODO confirm against the POSE manual.
void Switch::recv_anti(Reply *m)
{
  restore(this);
}

// several small utility functions

// Translate a switch PID into the PID of the BGnode it delivers to, by
// stepping back one network's worth (netLength*netHeight*netWidth) of PIDs.
int Switch::switchPIDToNodePID(int switchpid)
{
  const int switchesPerNet = netLength*netHeight*netWidth;
  return switchpid - switchesPerNet;
}

// PID of the switch serving node `destnode`: switch PIDs start after
// (numWth+1) blocks of netLength*netHeight*netWidth PIDs.
int nodeToSwitchPID(int destnode)
{
  const int perNet = netLength*netHeight*netWidth;
  const int firstSwitchPID = perNet*(numWth+1);
  return firstSwitchPID + destnode;
}

// Inverse of nodeToSwitchPID(): the node index served by switch `switchpid`.
int switchPIDToNode(int switchpid)
{
  const int perNet = netLength*netHeight*netWidth;
  const int firstSwitchPID = perNet*(numWth+1);
  return switchpid - firstSwitchPID;
}


