#include <stdio.h>
#include "allvars.h"
#include "proto.h"
#include <hip/hip_runtime.h>
#include <rccl/rccl.h>

// Global verbosity setting, defined here and initialised to 0.
// NOTE(review): not referenced in the visible part of this file; presumably
// read from other translation units - confirm before removing.
int verbose_level = 0;

/*
 * Hash a NUL-terminated string to 64 bits with the DJB2a scheme:
 * start from 5381 and, for every character, multiply the running value
 * by 33 (modulo 2^64) and XOR the character in.
 *
 * string : NUL-terminated input; must not be NULL.
 * returns: the 64-bit hash (5381 for the empty string).
 */
static uint64_t getHostHash(const char* string) {
  uint64_t hash = 5381;
  const char *cursor = string;
  while (*cursor != '\0') {
    hash = (hash * 33) ^ *cursor;
    cursor++;
  }
  return hash;
}


/*
 * Store the short host name (everything before the first '.') in `hostname`.
 *
 * hostname : destination buffer of at least `maxlen` bytes.
 * maxlen   : size of the buffer in bytes.
 *
 * On return the buffer always holds a NUL-terminated string when
 * maxlen > 0: if gethostname() fails the string is empty, and if the name
 * had to be truncated the last byte is forced to NUL. Nothing is written
 * when `hostname` is NULL or `maxlen` is not positive.
 */
static void getHostName(char* hostname, int maxlen) {
  if (hostname == NULL || maxlen <= 0) {
    return;
  }

  if (gethostname(hostname, (size_t)maxlen) != 0) {
    // Buffer contents are not reliable after a failure: hand back "".
    hostname[0] = '\0';
    return;
  }

  // A truncated name is not guaranteed to be terminated by gethostname().
  hostname[maxlen - 1] = '\0';

  // Cut at the first dot, looking only at the name itself and not at the
  // unspecified bytes that follow its terminator.
  for (int i = 0; hostname[i] != '\0'; i++) {
    if (hostname[i] == '.') {
      hostname[i] = '\0';
      return;
    }
  }
}


/*
 * Driver of the gridding phase.
 *
 * Runs initialize_array() and gridding_data(), synchronises the ranks when
 * USE_MPI is defined, and stores the elapsed CPU time in
 * timing.process_time and the elapsed wall-clock time in
 * timing.process_time1. Before returning, the global `begin` timestamp is
 * re-armed with the current time.
 */
void gridding(){

    if(rank == 0)printf("GRIDDING DATA\n");

    // Start both timers. The wall-clock start is also kept in a local copy:
    // gridding_data() overwrites the global `begin` before it returns, so
    // measuring against the global would only cover the final barrier.
    clock_gettime(CLOCK_MONOTONIC, &begin);
    const struct timespec wall_begin = begin;
    start = clock();

    // Build the per-sector histograms and index lists
    initialize_array();

    // Sector and grid the data
    gridding_data();

    #ifdef USE_MPI
        MPI_Barrier(MPI_COMM_WORLD);
    #endif

    end = clock();
    clock_gettime(CLOCK_MONOTONIC, &finish);
    timing.process_time = ((double) (end - start)) / CLOCKS_PER_SEC;
    timing.process_time1 = (finish.tv_sec - wall_begin.tv_sec);
    timing.process_time1 += (finish.tv_nsec - wall_begin.tv_nsec) / 1000000000.0;

    // Re-arm the global wall-clock reference
    clock_gettime(CLOCK_MONOTONIC, &begin);

}

/*
 * Build the per-sector bookkeeping used by gridding_data().
 *
 * Pass 1 fills the global histogram `histo_send` (nsectors+1 entries): each
 * measurement is counted in the sector selected by its v coordinate
 * (binphi = (int)(vv * nsectors)) and, when it lies closer than
 * `w_supporth` to the upper/lower edge of that sector, also in the
 * neighbouring sector.
 *
 * Pass 2 allocates `sectorarray[sec]` with histo_send[sec] entries and
 * stores in it the indices of the measurements counted for that sector,
 * using exactly the same conditions as pass 1.
 *
 * `histo_send` and `sectorarray` are left allocated for the caller; the
 * temporary fill counters are released here.
 */
void initialize_array(){

    histo_send = (long*) calloc(nsectors+1,sizeof(long));
    double vvh;

    // Pass 1: count how many measurements fall in (or touch) each sector
    for (long iphi = 0; iphi < metaData.Nmeasures; iphi++)
    {
           vvh = data.vv[iphi];
           int binphi = (int)(vvh*nsectors);
           // distance of the point from the upper and lower edge of its sector
           double updist = (double)((binphi+1)*yaxis)*dx - vvh;
           double downdist = vvh - (double)(binphi*yaxis)*dx;

           histo_send[binphi]++;
           // the extra histogram entry (index nsectors) receives the upper
           // neighbour count of the last sector
           if(updist < w_supporth && updist >= 0.0) histo_send[binphi+1]++;
           if(downdist < w_supporth && binphi > 0 && downdist >= 0.0) histo_send[binphi-1]++;
    }

    // One index list per histogram entry, sized by the counts above
    sectorarray = (long**)malloc ((nsectors+1) * sizeof(long*));
    for(int sec=0; sec<(nsectors+1); sec++)
    {
           sectorarray[sec] = (long*)malloc(histo_send[sec]*sizeof(long));
    }

    // Pass 2: fill the lists; counter[sec] is the next free slot of list sec
    long *counter = (long*) calloc(nsectors+1,sizeof(long));
    for (long iphi = 0; iphi < metaData.Nmeasures; iphi++)
    {
           vvh = data.vv[iphi];
           int binphi = (int)(vvh*nsectors);
           double updist = (double)((binphi+1)*yaxis)*dx - vvh;
           double downdist = vvh - (double)(binphi*yaxis)*dx;

           sectorarray[binphi][counter[binphi]] = iphi;
           counter[binphi]++;
           if(updist < w_supporth && updist >= 0.0)
           {
                  sectorarray[binphi+1][counter[binphi+1]] = iphi;
                  counter[binphi+1]++;
           }
           if(downdist < w_supporth && binphi > 0 && downdist >= 0.0)
           {
                  sectorarray[binphi-1][counter[binphi-1]] = iphi;
                  counter[binphi-1]++;
           }
    }

    // The fill counters are only needed while building the lists
    free(counter);

   #ifdef PIPPO
        long iiii = 0;
        for (int j=0; j<nsectors; j++)
        {
                iiii = 0;
                for(long iphi = histo_send[j]-1; iphi>=0; iphi--)
                {
                      printf("%d %d %ld %ld %ld\n",rank,j,iiii,histo_send[j],sectorarray[j][iphi]);
                      iiii++;
                }
        }
   #endif

    #ifdef VERBOSE
        for (int iii=0; iii<nsectors+1; iii++)printf("HISTO %d %d %ld\n",rank, iii, histo_send[iii]);
    #endif
}

/*
 * Grid the visibilities sector by sector.
 *
 * For each of the `nsectors` sectors the function
 *   1. copies the measurements listed in sectorarray[isector] into freshly
 *      allocated per-sector arrays (the v coordinate is shifted by
 *      isector*dx*yaxis),
 *   2. calls wstack() to convolve them onto the local slab `gridss`,
 *   3. without USE_MPI copies the slab into `gridtot`; with USE_MPI reduces
 *      it onto rank (isector % size) through ncclReduce (NCCL_REDUCE) and/or
 *      MPI_Reduce (REDUCE),
 *   4. zeroes `gridss` and frees the per-sector arrays.
 *
 * The compose/kernel/reduce accumulators in `timing` are reset on entry and
 * accumulated per sector. On exit timing.process_time/process_time1 are
 * computed from the global `start`/`begin` references and `begin` is
 * re-armed with the current time.
 */
void gridding_data(){

  double shift = (double)(dx*yaxis);

 #ifndef USE_MPI
  file.pFile1 = fopen (out.outfile1,"w");
 #endif

  timing.kernel_time = 0.0;
  timing.kernel_time1 = 0.0;
  timing.reduce_time = 0.0;
  timing.reduce_time1 = 0.0;
  timing.compose_time = 0.0;
  timing.compose_time1 = 0.0;

  // calculate the resolution in radians
  resolution = 1.0/MAX(fabs(metaData.uvmin),fabs(metaData.uvmax));

  // calculate the resolution in arcsec
  double resolution_asec = (3600.0*180.0)/MAX(fabs(metaData.uvmin),fabs(metaData.uvmax))/PI;
  if ( rank == 0 )
    printf("RESOLUTION = %f rad, %f arcsec\n", resolution, resolution_asec);

  // Initialize nccl
 #ifdef NCCL_REDUCE

  double * grid_gpu, *gridss_gpu;
  int local_rank = 0;

  // local_rank = number of lower ranks whose host-name hash equals ours;
  // it is used below as the device index for this rank.
  uint64_t hostHashs[size];
  char hostname[1024];
  getHostName(hostname, 1024);
  hostHashs[rank] = getHostHash(hostname);
  MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hostHashs, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD);
  for (int p=0; p<size; p++) {
     if (p == rank) break;
     if (hostHashs[p] == hostHashs[rank]) local_rank++;
  }

  ncclUniqueId id;
  ncclComm_t comm;
  hipStream_t stream_reduce;

  // Rank 0 creates the communicator id and shares it with every rank
  if (rank == 0) ncclGetUniqueId(&id);
  MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);

  hipSetDevice(local_rank);

  hipMalloc(&grid_gpu, 2*param.num_w_planes*xaxis*yaxis * sizeof(double));
  hipMalloc(&gridss_gpu, 2*param.num_w_planes*xaxis*yaxis * sizeof(double));
  hipStreamCreate(&stream_reduce);

  ncclCommInitRank(&comm, size, id, rank);
 #endif

  for (long isector = 0; isector < nsectors; isector++)
    {
      clock_gettime(CLOCK_MONOTONIC, &begink);
      startk = clock();

      // allocate sector arrays
      long    Nsec       = histo_send[isector];
      double *uus        = (double*) malloc(Nsec*sizeof(double));
      double *vvs        = (double*) malloc(Nsec*sizeof(double));
      double *wws        = (double*) malloc(Nsec*sizeof(double));
      long    Nweightss  = Nsec*metaData.polarisations;
      long    Nvissec    = Nweightss*metaData.freq_per_chan;
      float *weightss    = (float*) malloc(Nweightss*sizeof(float));
      float *visreals    = (float*) malloc(Nvissec*sizeof(float));
      float *visimgs     = (float*) malloc(Nvissec*sizeof(float));

      // select data for this sector; the index list is walked back to front
      long icount = 0;
      long ip = 0;
      long inu = 0;

      for(long iphi = Nsec-1; iphi>=0; iphi--)
        {
          long ilocal = sectorarray[isector][iphi];
          uus[icount] = data.uu[ilocal];
          vvs[icount] = data.vv[ilocal]-isector*shift;
          wws[icount] = data.ww[ilocal];
          for (long ipol=0; ipol<metaData.polarisations; ipol++)
            {
              weightss[ip] = data.weights[ilocal*metaData.polarisations+ipol];
              ip++;
            }
          for (long ifreq=0; ifreq<metaData.polarisations*metaData.freq_per_chan; ifreq++)
            {
              visreals[inu] = data.visreal[ilocal*metaData.polarisations*metaData.freq_per_chan+ifreq];
              visimgs[inu] = data.visimg[ilocal*metaData.polarisations*metaData.freq_per_chan+ifreq];
              inu++;
            }
          icount++;
        }

      clock_gettime(CLOCK_MONOTONIC, &finishk);
      endk = clock();
      timing.compose_time += ((double) (endk - startk)) / CLOCKS_PER_SEC;
      // whole seconds are added exactly once, then the nanosecond remainder
      timing.compose_time1 += (finishk.tv_sec - begink.tv_sec);
      timing.compose_time1 += (finishk.tv_nsec - begink.tv_nsec) / 1000000000.0;

     #ifndef USE_MPI
      // per-sector extrema of the selected u and (shifted) v coordinates
      double uumin = 1e20;
      double vvmin = 1e20;
      double uumax = -1e20;
      double vvmax = -1e20;

      for (long ipart=0; ipart<Nsec; ipart++)
        {
          uumin = MIN(uumin,uus[ipart]);
          uumax = MAX(uumax,uus[ipart]);
          vvmin = MIN(vvmin,vvs[ipart]);
          vvmax = MAX(vvmax,vvs[ipart]);

          // NOTE(review): this writes to file.pFile, while the stream opened
          // and closed by this function is file.pFile1 - confirm the target.
          if(ipart%10 == 0)fprintf (file.pFile, "%ld %f %f %f\n",isector,uus[ipart],vvs[ipart]+isector*shift,wws[ipart]);
        }

      printf("UU, VV, min, max = %f %f %f %f\n", uumin, uumax, vvmin, vvmax);
     #endif

      // Make convolution on the grid

     #ifdef VERBOSE
      printf("Processing sector %ld\n",isector);
     #endif
      clock_gettime(CLOCK_MONOTONIC, &begink);
      startk = clock();

      // We have to call different GPUs per MPI task!!! [GL]
      wstack(param.num_w_planes,
             Nsec,
             metaData.freq_per_chan,
             metaData.polarisations,
             uus,
             vvs,
             wws,
             visreals,
             visimgs,
             weightss,
             dx,
             dw,
             param.w_support,
             xaxis,
             yaxis,
             gridss,
             param.num_threads,
             rank);

     #ifdef NCCL_REDUCE
      // Stage the freshly gridded slab on the device; the copy is queued on
      // stream_reduce and waited for right before the reduction below.
      hipMemcpyAsync(gridss_gpu, gridss, 2*param.num_w_planes*xaxis*yaxis*sizeof(double), hipMemcpyHostToDevice, stream_reduce);
     #endif

      clock_gettime(CLOCK_MONOTONIC, &finishk);
      endk = clock();
      timing.kernel_time += ((double) (endk - startk)) / CLOCKS_PER_SEC;
      timing.kernel_time1 += (finishk.tv_sec - begink.tv_sec);
      timing.kernel_time1 += (finishk.tv_nsec - begink.tv_nsec) / 1000000000.0;
     #ifdef VERBOSE
      printf("Processed sector %ld\n",isector);
     #endif


      clock_gettime(CLOCK_MONOTONIC, &begink);
      startk = clock();

     #ifndef USE_MPI
      // Serial build: append this sector's slab to the full grid. The slab
      // length matches the one used when gridss is cleared below.
      long stride = isector*2*xaxis*yaxis*param.num_w_planes;
      for (long iii=0; iii<2*xaxis*yaxis*param.num_w_planes; iii++)
        gridtot[stride+iii] = gridss[iii];
     #endif

      // Write grid in the corresponding remote slab
     #ifdef USE_MPI
      // sectors are assigned to ranks round-robin, so nsectors may exceed size
      int target_rank = (int)(isector % size);
     #ifdef NCCL_REDUCE
      hipStreamSynchronize(stream_reduce);

      ncclReduce(gridss_gpu, grid_gpu, size_of_grid, ncclDouble,
                 ncclSum, target_rank, comm, stream_reduce);
      hipStreamSynchronize(stream_reduce);
     #endif

     #ifdef REDUCE
      MPI_Reduce(gridss,grid,size_of_grid,MPI_DOUBLE,MPI_SUM,target_rank,MPI_COMM_WORLD);
     #endif //REDUCE

      //Let's use now the new implementation (ring in shmem and Ired inter-nodes)
     #endif //USE_MPI

      clock_gettime(CLOCK_MONOTONIC, &finishk);
      endk = clock();
      timing.reduce_time += ((double) (endk - startk)) / CLOCKS_PER_SEC;
      timing.reduce_time1 += (finishk.tv_sec - begink.tv_sec);
      timing.reduce_time1 += (finishk.tv_nsec - begink.tv_nsec) / 1000000000.0;

      // Go to next sector: clear the local slab
      for (long inull=0; inull<2*param.num_w_planes*xaxis*yaxis; inull++)gridss[inull] = 0.0;

      // Deallocate all sector arrays
      free(uus);
      free(vvs);
      free(wws);
      free(weightss);
      free(visreals);
      free(visimgs);
      // End of loop over sector
    }

  //Copy data back from device to host (to be deleted in next steps)

 #ifdef NCCL_REDUCE
  hipMemcpyAsync(grid, grid_gpu, 2*param.num_w_planes*xaxis*yaxis*sizeof(double), hipMemcpyDeviceToHost, stream_reduce);
 #endif
    #ifndef USE_MPI
        fclose(file.pFile1);
    #endif


    #ifdef USE_MPI
        MPI_Barrier(MPI_COMM_WORLD);
    #endif

    end = clock();
    clock_gettime(CLOCK_MONOTONIC, &finish);
    timing.process_time = ((double) (end - start)) / CLOCKS_PER_SEC;
    timing.process_time1 = (finish.tv_sec - begin.tv_sec);
    timing.process_time1 += (finish.tv_nsec - begin.tv_nsec) / 1000000000.0;
    clock_gettime(CLOCK_MONOTONIC, &begin);

   #ifdef NCCL_REDUCE
    // Wait for the device-to-host copy queued above before releasing buffers
    hipStreamSynchronize(stream_reduce);
    hipFree(gridss_gpu);
    hipFree(grid_gpu);

    hipStreamDestroy(stream_reduce);

    ncclCommDestroy(comm);
   #endif
}

/*
 * Write the gridded data to disk (compiled in only when WRITE_DATA is
 * defined; otherwise the function does nothing).
 *
 * Rank 0 opens out.outfile2 (real parts) and out.outfile3 (imaginary parts)
 * in binary mode.
 *   - With USE_MPI it fetches, for every sector, the slab exposed by rank
 *     `isector` through MPI_Get, splits the interleaved values into
 *     gridss_real / gridss_img (even index -> real, odd index -> imaginary)
 *     and writes them at the sector's offset inside each file.
 *   - Without USE_MPI it streams `gridtot` sequentially, real parts to one
 *     file and imaginary parts to the other.
 * If either file cannot be opened, an error is printed on stderr and the
 * write is skipped instead of passing a NULL stream to the stdio calls.
 *
 * With USE_MPI every rank finally takes part in MPI_Win_fence on `slabwin`.
 *
 * NOTE(review): the MPI_Get target is `isector` itself, whereas
 * gridding_data() reduces sector data onto rank (isector % size); the two
 * agree only while nsectors <= size - confirm.
 */
void write_grided_data()
{

   #ifdef WRITE_DATA
     // Write results
     if (rank == 0)
     {
        printf("WRITING GRIDDED DATA\n");
        file.pFilereal = fopen (out.outfile2,"wb");
        file.pFileimg = fopen (out.outfile3,"wb");
        if (file.pFilereal == NULL || file.pFileimg == NULL)
        {
           // Nothing can be written: release whichever stream did open and
           // fall through to the collective fence below.
           fprintf(stderr, "ERROR: cannot open output files for gridded data, skipping write\n");
           if (file.pFilereal != NULL) { fclose(file.pFilereal); file.pFilereal = NULL; }
           if (file.pFileimg != NULL) { fclose(file.pFileimg); file.pFileimg = NULL; }
        }
        else
        {
        #ifdef USE_MPI
           for (int isector=0; isector<nsectors; isector++)
           {
            #ifdef RING //Let the MPI_Get copy from the right location (Results must be checked!) [GL]
              MPI_Get(gridss,size_of_grid,MPI_DOUBLE,isector,0,size_of_grid,MPI_DOUBLE,Me.win.win);
            #else
              MPI_Win_lock(MPI_LOCK_SHARED,isector,0,slabwin);
              MPI_Get(gridss,size_of_grid,MPI_DOUBLE,isector,0,size_of_grid,MPI_DOUBLE,slabwin);
              MPI_Win_unlock(isector,slabwin);
            #endif
              // de-interleave the slab into separate real / imaginary buffers
              for (long i=0; i<size_of_grid/2; i++)
              {
                      gridss_real[i] = gridss[2*i];
                      gridss_img[i] = gridss[2*i+1];
              }
              if (param.num_w_planes > 1)
              {
                      // several w planes: place every cell individually at its
                      // position in the global (x, y, w) layout
                      for (int iw=0; iw<param.num_w_planes; iw++)
                        for (int iv=0; iv<yaxis; iv++)
                          for (int iu=0; iu<xaxis; iu++)
                          {
                               long global_index = (iu + (iv+isector*yaxis)*xaxis + iw*param.grid_size_x*param.grid_size_y)*sizeof(double);
                               long index = iu + iv*xaxis + iw*xaxis*yaxis;
                               fseek(file.pFilereal, global_index, SEEK_SET);
                               fwrite(&gridss_real[index], 1, sizeof(double), file.pFilereal);
                          }
                      for (int iw=0; iw<param.num_w_planes; iw++)
                        for (int iv=0; iv<yaxis; iv++)
                          for (int iu=0; iu<xaxis; iu++)
                          {
                               long global_index = (iu + (iv+isector*yaxis)*xaxis + iw*param.grid_size_x*param.grid_size_y)*sizeof(double);
                               long index = iu + iv*xaxis + iw*xaxis*yaxis;
                               fseek(file.pFileimg, global_index, SEEK_SET);
                               fwrite(&gridss_img[index], 1, sizeof(double), file.pFileimg);
                          }

              }
              else
              {
                      // single w plane: the sector is one contiguous run of
                      // xaxis*yaxis values in each file
                      for (int iw=0; iw<param.num_w_planes; iw++)
                      {
                          long global_index = (xaxis*isector*yaxis + iw*param.grid_size_x*param.grid_size_y)*sizeof(double);
                          long index = iw*xaxis*yaxis;
                          fseek(file.pFilereal, global_index, SEEK_SET);
                          fwrite(&gridss_real[index], xaxis*yaxis, sizeof(double), file.pFilereal);
                          fseek(file.pFileimg, global_index, SEEK_SET);
                          fwrite(&gridss_img[index], xaxis*yaxis, sizeof(double), file.pFileimg);
                      }
              }
           }
        #else
           for (int iw=0; iw<param.num_w_planes; iw++)
             for (int iv=0; iv<param.grid_size_y; iv++)
               for (int iu=0; iu<param.grid_size_x; iu++)
               {
                      // gridtot stores (real, imaginary) pairs
                      long index = 2*(iu + iv*param.grid_size_x + iw*param.grid_size_x*param.grid_size_y);
                      fwrite(&gridtot[index], 1, sizeof(double), file.pFilereal);
                      fwrite(&gridtot[index+1], 1, sizeof(double), file.pFileimg);
               }
        #endif
           fclose(file.pFilereal);
           fclose(file.pFileimg);
        }
     }

     #ifdef USE_MPI
        MPI_Win_fence(0,slabwin);
     #endif

   #endif //WRITE_DATA

}
