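/*
A very simple 1D FEM mesh test: init() builds a chain of elements, driver()
repeatedly reassembles and repartitions it with FEM_Update_mesh, and
mesh_updated() checks the reassembled node and connectivity data.
*/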
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include "charm++.h"
#include "fem.h"

extern "C" void mesh_updated(int param);

/* Mesh layout parameters (file-local, like any namespace-scope const in C++). */
static const int np=4;     /* Nodes per element: init() lists each element's two nodes twice */
static const int nd=3;     /* Coordinates per node */
static const int tetPer=1; /* Elements per node: init() uses dim/tetPer+1 nodes for dim elements */

/// Flat, heap-allocated copy of one FEM mesh: element type 0's connectivity
/// (np 0-based node indices per element) and the node data array
/// (nd doubles per node).  The object owns both arrays.
class myMesh {
public:
	int nelems, nnodes;
	int *conn; //Connectivity: maps each element to its np nodes (0-based)
	double *nodes; //Node data: nd coordinates per node
	
private:
	/// Transfer conn and nodes between this object and `mesh`.
	/// The same calls serve both directions: the int constructor uses this on
	/// a read mesh, and write() uses it on a write mesh.
	void readWriteData(int mesh) {
		FEM_Mesh_data(mesh,FEM_NODE,FEM_DATA+0,
			nodes, 0,nnodes, FEM_DOUBLE,nd);
		FEM_Mesh_data(mesh,FEM_ELEM+0,FEM_CONN,
			conn, 0,nelems, FEM_INDEX_0,np);
	}
	
	/// Record the sizes and allocate both arrays (contents uninitialized).
	void allocate(int ne,int nn) {
		nelems=ne; nnodes=nn;
		conn=new int[np*nelems];
		nodes=new double[nd*nnodes];
	}
	void deallocate(void) {
		delete[] conn;
		delete[] nodes;
	}
	
	// A memberwise copy would leave two objects owning the same conn/nodes
	// arrays, and both destructors would delete[] them. Copying is therefore
	// declared but never defined, so any accidental copy fails to build.
	myMesh(const myMesh &);
	myMesh &operator=(const myMesh &);
	
public: 
	/// Build this mesh with ne elements and nn nodes (data left uninitialized):
	myMesh(int ne,int nn) {
		allocate(ne,nn);
	}
	/// Build this mesh from this FEM mesh structure.
	/// explicit: an int mesh handle must not silently convert to a myMesh.
	explicit myMesh(int mesh) {
		allocate(FEM_Mesh_get_length(mesh,FEM_ELEM+0),
		         FEM_Mesh_get_length(mesh,FEM_NODE));
		readWriteData(mesh);
	}
	
	~myMesh() {deallocate();}
	
	/// Write this mesh's data to this FEM mesh data structure:
	void write(int mesh) {
		readWriteData(mesh);
	}
	
	/// Consistency check hook (defined out of line).
	void check(void);
};

/* Number of elements init() generates (one million). */
int dim = 1000 * 1000;

/* Describe one layer of node-adjacent ghost elements to the framework. */
void pushGhost(void) {
     const int nodesPerTuple=1; /* adjacency is through a single shared node */
     const int addGhostNodes=1;
     FEM_Add_ghost_layer(nodesPerTuple,addGhostNodes);

     /* Element type 0 contributes two one-node tuples: its local nodes 0 and 1 */
     static const int tupleNodes[]={0,1};
     FEM_Add_ghost_elem(0,2,tupleNodes);
}

/*
 Serial setup: build a 1D chain of dim elements over dim/tetPer+1 nodes,
 hand it to the framework as the mesh to partition, and register one
 node-adjacent ghost layer.
*/
extern "C" void
init(void)
{
  FEM_Print("--------- init called -----------");
  
  //Prepare a new mesh:
  myMesh m(/*nelems=*/ dim,  /*nnodes=*/ dim/tetPer+1);
  for (int n=0;n<m.nnodes;n++) {
    // All nd coordinates of node n are n/nnodes. mesh_updated() recomputes
    // the same float-precision expression to check the reassembled mesh.
    for (int c=0;c<nd;c++) 
      m.nodes[n*nd+c]=n/(float)m.nnodes;
  }
  for (int e=0;e<m.nelems;e++) {
    // Element e joins nodes ne and ne+1; its np slots list that pair twice.
    int ne=e/tetPer;
    m.conn[e*np+0]=ne;
    m.conn[e*np+1]=ne+1;
    m.conn[e*np+2]=ne;
    m.conn[e*np+3]=ne+1;
  }
  
  //Push the new mesh into the framework, then register the ghost layer:
  m.write(FEM_Mesh_default_write());
  pushGhost();
  
#if 0
  // Disabled serial split/assemble self-test. Its locals live in this
  // block so they are not unused variables when it is compiled out.
  int nChunks=FEM_Num_partitions();

  //Test out FEM_Serial_split (immediately writes out output files)
  FEM_Print("Calling serial split");
  FEM_Serial_split(nChunks);
  for (int c=0;c<nChunks;c++) {
    FEM_Serial_begin(c);
    myMesh chunk(FEM_Mesh_default_read());
    CkPrintf("  serial split chunk %d> %d nodes, %d elements\n",c,chunk.nnodes,chunk.nelems);
  }
 
  //Test out FEM_Serial_assemble (reads files written by FEM_Serial_split)
  FEM_Print("Calling serial join");
  for (int c=0;c<nChunks;c++) {
    FEM_Serial_read(c,nChunks);
    myMesh chunk(FEM_Mesh_default_read());
    CkPrintf("  serial join chunk %d> %d nodes, %d elements\n",c,chunk.nnodes,chunk.nelems);
    chunk.write(FEM_Mesh_default_write());
  }
  FEM_Serial_assemble();
  mesh_updated(123);
#endif
  FEM_Print("---------- end of init -------------");
}

/*
 Check that `is` matches `shouldBe` to within 1e-6. On a mismatch, print
 the chunk, the label `what`, both values and the PE, then call CkAbort.
*/
void testEqual(double is,double shouldBe,const char *what) {
	const bool withinTolerance=fabs(is-shouldBe)<0.000001;
	if (withinTolerance)
		return;

	CkPrintf("[chunk %d] %s test FAILED-- expected %f, got %f (pe %d)\n",
		FEM_My_partition(),what,shouldBe,is,CkMyPe());
	CkAbort("FEM Test failed\n");
}


/* Consistency-check hook for a chunk's mesh; currently performs no checks. */
void myMesh::check(void)
{
	return;
}

extern "C" void
driver(void)
{
  int myID=FEM_My_partition();
  if (myID==0) FEM_Print("----------- begin driver ------------");
  for (int loop=0;loop<2;loop++) {
    {
      //Read this mesh out of the framework:
      myMesh m(FEM_Mesh_default_read());

      m.check();
      if (myID==0)
      CkPrintf("    loop %d: chunk %d> %d nodes, %d elements\n",
    	loop,FEM_My_partition(),m.nnodes,m.nelems);
    
      //Prepare mesh to be updated:
      m.write(FEM_Mesh_default_write());
    }
  
    double before=CkWallTimer();
    FEM_Update_mesh(mesh_updated,123,FEM_MESH_UPDATE);
    double elapsed=CkWallTimer()-before;
    if (myID%32==0)
      CkPrintf("Elapsed assembly/partitioning time: %.3f seconds\n",elapsed);
    
    if (0) { // loop%3==0) { 
      if (myID==0) FEM_Print("----- migrating -----");
      FEM_Migrate();
    }
  }
  if (myID==0) FEM_Print("----------- end driver ------------");
}

extern "C" void
mesh_updated(int param)
{
  CkPrintf("mesh_updated(%d) called.\n",param);
  testEqual(param,123,"mesh_updated param");
  myMesh m(FEM_Mesh_default_read());
  
  CkPrintf("mesh_updated> %d nodes, %d elements\n",m.nnodes,m.nelems);

  for (int n=0;n<m.nnodes;n++) {
    testEqual(m.nodes[n*nd+0],n/(float)m.nnodes,"node data");
  }
  for (int e=0;e<m.nelems;e++) {
    testEqual(m.conn[e*np+0],e,"element connectivity (col 0)");
    testEqual( m.conn[e*np+1],e+1,"element connectivity (col 1)");
  }
  
  m.write(FEM_Mesh_default_write());
  // pushGhost(); //< FIXME: causes duplicate layers!
}
