Commit 3e0025ff authored by Giovanni La Mura's avatar Giovanni La Mura
Browse files

Revise FRFME offloaded loop implementation

parent 4dfc33af
Loading
Loading
Loading
Loading
+65 −58
Original line number Diff line number Diff line
@@ -630,13 +630,14 @@ void offload_loop(
) {
  int nvtot = nxv * nyv * nzv;
  int nkvs = nkv * nkv;
  int nkvmo = nkvmo;
  int nkvmo = nkv - 1; 
  int nkvvmo = nkvmo * nkv;
  int nvxy = nxv * nyv;
  const dcomplex uim = 0.0 + I * 1.0;
  dcomplex cc0 = 0.0 + I * 0.0;
  dcomplex uim = 0.0 + I * 1.0;

#pragma omp parallel for simd collapse(2)
 // Inizializza global_vec_w e vec_wsum sulla CPU in parallelo
#pragma omp parallel for collapse(2)
  for (int j80 = jlmf - 1; j80 < jlml; j80++) {
    for (int jxy50 = 0; jxy50 < nkvs; jxy50++) {
      int j80_index = j80 - jlmf + 1;
@@ -649,68 +650,74 @@ void offload_loop(
    } // jxy50 loop
  }
  
#pragma omp parallel for
  for (long i = 0; i < size_vec_wsum; i++) {
    vec_wsum[i] = cc0;
  }

#pragma omp target data map(tofrom: vec_wsum[0:size_vec_wsum]) \
  map(to: global_vec_w[0:size_global_vec_w]) \
  map(to: vec_tt1_wk[0:size_vec_tt1_wk]) \
  map(to: _xv[0:nxv], _yv[0:nyv], _zv[0:nzv]) \
  map(to: vkv[0:nkv], vec_vkzm[0:nkvs])
  {
    // Kernel 1: run the calculation
#pragma omp target teams distribute parallel for collapse(3)
    // Esegue il calcolo principale in un unico kernel GPU
    // con parallelizzazione sui cicli più grandi (nvtot e nkvs)
#pragma omp target teams distribute parallel for collapse(2)
    for (int ixyz = 0; ixyz < nvtot; ixyz++) {
      for (int jy60x55 = 0; jy60x55 < nkvs ; jy60x55++) {
	for (int j80 = jlmf - 1; j80 < jlml; j80++) {
	  int j80_index = j80 - jlmf + 1;
	  dcomplex *vec_w = global_vec_w + nkvs * j80_index;

        int iz75 = ixyz / nvxy;
        int iy70 = (ixyz % nvxy) / nxv;
        int ix65 = ixyz % nxv;
        double z = _zv[iz75] + frsh;
        double y = _yv[iy70];
        double x = _xv[ix65];

        int jy60 = jy60x55 / nkv;
        int jx55 = jy60x55 % nkv;
        int w_index = (jx55 * nkv) + jy60;
        double vky = vkv[jy60];
        double factor = (jy60 == 0 || jy60 == nkvmo) ? 0.5 : 1.0;
	  long wsum_index = (j80_index * nvtot) + ixyz;
        double vkx, vkz;
        dcomplex phas, term;

        if (jx55 == 0) {
	    // jx55 = 0: phasf
	    double vkx = vkv[nkvmo];
	    double vkz = vec_vkzm[jy60];
          vkx = vkv[nkvmo];
          vkz = vec_vkzm[jy60];
          double angle = -vkx * x + vky * y + vkz * z;
          double s, c;
          sincos(angle, &s, &c);
	    dcomplex phasf = c + uim * s;
	    dcomplex term = vec_w[jy60] * phasf * 0.5;
          phas = c + uim * s;
          term = phas * 0.5;
          term *= factor;
	    vec_wsum[wsum_index] += term * delks;
        } else if (jx55 == nkvmo) {
	    // jx55 = nkv - 1: phasl
	    double vkx = vkv[nkvmo];
	    double vkz = vec_vkzm[nkvvmo + jy60];
          vkx = vkv[nkvmo];
          vkz = vec_vkzm[nkvvmo + jy60];
          double angle = vkx * x + vky * y + vkz * z;
          double s, c;
          sincos(angle, &s, &c);
	    dcomplex phasl = c + uim * s;
	    dcomplex term = vec_w[nkvvmo + jy60] * phasl * 0.5;
          phas = c + uim * s;
          term = phas * 0.5;
          term *= factor;
	    vec_wsum[wsum_index] += term * delks;
        } else {
	    // 1 <= jx55 < nkv - 1
	    double vkx = vkv[jx55];
	    double vkz = vec_vkzm[w_index];
          vkx = vkv[jx55];
          vkz = vec_vkzm[w_index];
          double angle = vkx * x + vky * y + vkz * z;
          double s, c;
          sincos(angle, &s, &c);
	    dcomplex phas = c + uim * s;
	    dcomplex term = vec_w[w_index] * phas;
	    term *= factor;
	    vec_wsum[wsum_index] += term * delks;
          phas = c + uim * s;
          term = phas * factor;
        }
        
        // L'ultima loop è ora seriale, garantendo la correttezza dei risultati
        for (int j80 = jlmf - 1; j80 < jlml; j80++) {
          int j80_index = j80 - jlmf + 1;
          dcomplex *vec_w = global_vec_w + nkvs * j80_index;
          long wsum_index = (j80_index * nvtot) + ixyz;
          vec_wsum[wsum_index] += delks * vec_w[w_index] * term;
        }
      } // jy60x55 loop
    } // ixyz loop
    } // j80 loop
  } // target region
}
#endif // USE TARGET_OFFLOAD