#include <stdlib.h>
#include <math.h>
#include "fem.h"

/*
 * init: exported with C linkage; takes no arguments and only logs that it
 * ran. Presumably invoked by the FEM framework before the driver starts --
 * confirm against fem.h.
 */
extern "C" void init(void)
{
  CkPrintf("init called\n");
}

/*
 * Per-node state used by driver().
 *
 * `temp` must remain the first member: driver() registers a one-double
 * field at offset 0 with a stride of sizeof(Node), so the framework
 * exchanges exactly this member.
 */
struct Node
{
  double temp;  /* current nodal value */
  double prev;  /* copy of `temp` saved at the start of each iteration */
};

/*
 * Per-element state used by driver().
 */
struct Element
{
  double temp;  /* current value: mean of the element's node values */
  double prev;  /* copy of `temp` saved at the start of each iteration */
};

extern "C" void
driver(int nnodes, int *nnums, int nelems, int *enums, int npere, int *conn)
{
  int N = nnodes;
  int E = nelems;
  int NE = npere;
  Node *nodes = new Node[N];
  Element *elems = new Element[E];
  int i,j;
  for(i=0;i<N;i++) {
    nodes[i].temp = drand48()/8.0;
    nodes[i].prev = 0.0;
  }
  for(i=0;i<E;i++) {
    elems[i].temp = 0.0;
    elems[i].prev = 0.0;
  }
  int fid = FEM_Create_Field(FEM_DOUBLE, 1, 0, sizeof(Node));
  FEM_Update_Field(fid, nodes);
  int converged = 0;
  while (!converged) {
    for(i=0; i<E; i++) {
      elems[i].prev = elems[i].temp;
      // temp of the element is average of temps of its nodes
      elems[i].temp = 0.0;
      for(j=0;j<NE;j++) {
        elems[i].temp += (nodes[conn[i*NE+j]].temp);
      }
      elems[i].temp /= NE;
    }
    for(i = 0; i<N; i++) {
      nodes[i].prev = nodes[i].temp;
      nodes[i].temp = 0.0;
    }
    for(i=0; i<E; i++) {
      for(j=0;j<NE;j++) {
        nodes[conn[i*NE+j]].temp += (elems[i].temp/8.0);
      }
    }
    FEM_Update_Field(fid, nodes);
    double diff = 0.0;
    for(i=0; i<E; i++) {
      diff += fabs(elems[i].prev - elems[i].temp);
    }
    FEM_Reduce(FEM_DOUBLE, &diff, &diff, FEM_SUM);
    if(FEM_My_Partition() == 0) {
      CkPrintf("Error = %lf\n", diff);
    }
    converged = (diff < 1e-8);
  }
}

/*
 * finalize: exported with C linkage; takes no arguments and only logs that
 * it ran. Presumably invoked by the FEM framework after the driver has
 * finished -- confirm against fem.h.
 */
extern "C" void finalize(void)
{
  CkPrintf("finalize called\n");
}
