Commit a0a3f870 authored by Giovanni La Mura's avatar Giovanni La Mura
Browse files

Fix nested loop linearization for GPU and CPU codes

parent 17d8bd16
Loading
Loading
Loading
Loading
+116 −72
Original line number Diff line number Diff line
@@ -96,10 +96,10 @@
 * expansion order.
 */
void map_data(
  dcomplex *vec_wsum, int size_vec_wsum, dcomplex *global_vec_w, int size_global_vec_w,
  const dcomplex *vec_tt1_wk, int size_vec_tt1_wk, double *vkv, double *_xv, int nxv,
  double *_yv, int nyv, double *_zv, int nzv, double *vec_vkzm, int jlmf, int jlml,
  int nkv, int nlmmt
  dcomplex *vec_wsum, const int size_vec_wsum, dcomplex *global_vec_w, const int size_global_vec_w,
  const dcomplex *vec_tt1_wk, const int size_vec_tt1_wk, double *vkv, double *_xv, const int nxv,
  double *_yv, int nyv, double *_zv, const int nzv, double *vec_vkzm, const int jlmf, const int jlml,
  const int nkv, const int nlmmt
);

/*! \brief Specialized function to perform GPU-offloaded trapping loop.
@@ -137,10 +137,10 @@ void map_data(
 * \param frsh: `double` Instrumental offset along beam z-axis.
 */
void offload_loop(
  dcomplex *vec_wsum, int size_vec_wsum, dcomplex *global_vec_w, int size_global_vec_w,
  const dcomplex *vec_tt1_wk, int size_vec_tt1_wk, double *vkv, double *_xv, int nxv,
  double *_yv, int nyv, double *_zv, int nzv, double *vec_vkzm, int jlmf, int jlml,
  int nkv, int nlmmt, double delks, double frsh
  dcomplex *vec_wsum, int size_vec_wsum, dcomplex *global_vec_w, const int size_global_vec_w,
  const dcomplex *vec_tt1_wk, const int size_vec_tt1_wk, double *vkv, double *_xv, const int nxv,
  double *_yv, const int nyv, double *_zv, const int nzv, double *vec_vkzm, const int jlmf, const int jlml,
  const int nkv, const int nlmmt, double delks, double frsh
);

/*! \brief Specialized function to clean memory on the device.
@@ -172,10 +172,10 @@ void offload_loop(
 * expansion order.
 */
void unmap_data(
  dcomplex *vec_wsum, int size_vec_wsum, dcomplex *global_vec_w, int size_global_vec_w,
  const dcomplex *vec_tt1_wk, int size_vec_tt1_wk, double *vkv, double *_xv, int nxv,
  double *_yv, int nyv, double *_zv, int nzv, double *vec_vkzm, int jlmf, int jlml,
  int nkv, int nlmmt
  dcomplex *vec_wsum, const int size_vec_wsum, dcomplex *global_vec_w, const int size_global_vec_w,
  const dcomplex *vec_tt1_wk, const int size_vec_tt1_wk, double *vkv, double *_xv, const int nxv,
  double *_yv, const int nyv, double *_zv, const int nzv, double *vec_vkzm, const int jlmf, const int jlml,
  const int nkv, const int nlmmt
);
#endif

@@ -214,7 +214,7 @@ void frfme(string data_file, string output_path) {
  str_target = m.suffix().str();
  regex_search(str_target, m, re);
  int jlml = stoi(m.str());
  int lmode = 0, lm = 0, nks = 0, nkv = 0;
  int lmode = 0, lm = 0, nks = 0;
  double vk = 0.0, exri = 0.0, an = 0.0, ff = 0.0, tra = 0.0;
  double exdc = 0.0, wp = 0.0, xip = 0.0, xi = 0.0;
  int idfc = 0, nxi = 0;
@@ -236,7 +236,8 @@ void frfme(string data_file, string output_path) {
    if (tfrfme != NULL) {
      lmode = tfrfme->lmode;
      lm = tfrfme->lm;
      nkv = tfrfme->nkv;
      const int nkv = tfrfme->nkv;
      nks = nkv - 1;
      nxv = tfrfme->nxv;
      nyv = tfrfme->nyv;
      nzv = tfrfme->nzv;
@@ -274,7 +275,6 @@ void frfme(string data_file, string output_path) {
      message = "ERROR: could not open TFRFME file.\n";
      logger.err(message);
    }
    nks = nkv - 1;
#ifdef USE_NVTX
    nvtxRangePop();
#endif
@@ -394,7 +394,7 @@ void frfme(string data_file, string output_path) {
#endif
	nlmmt = lm * (lm + 2) * 2;
	nks = nksh * 2;
	nkv = nks + 1;
	const int nkv = nks + 1;
	// Array initialization
	long swap1_size, swap2_size, tfrfme_size;
	double size_mb;
@@ -509,7 +509,7 @@ void frfme(string data_file, string output_path) {
#ifdef USE_NVTX
	  nvtxRangePush("j80 loop");
#endif
	  int nkvs = nkv * nkv;
	  const int nkvs = nkv * nkv;
	  int size_vec_wsum = nlmmt * nrvc;
	  int size_global_vec_w = nkvs * (jlml - jlmf + 1);
	  int size_vec_tt1_wk = nkvs * nlmmt;
@@ -551,36 +551,59 @@ void frfme(string data_file, string output_path) {
	      int jx50 = jxy50 % nkv;
	      vec_w[(nkv * jx50) + jy50] = wk_value;
	    } // jxy50 loop
#pragma omp parallel for simd
	    for (int wj = 0; wj < nrvc; wj++) vec_wsum[(j80 * nrvc) + wj] = cc0;
	    int nvtot = nxv * nyv * nzv;
	    int nvxy = nxv * nyv;
#pragma omp parallel for
	    for (int ixyz = 0; ixyz < nvtot; ixyz++) {
	    for (int ixyz = 0; ixyz < nrvc; ixyz++) {
	      int iz75 = ixyz / nvxy;
	      int iy70 = (ixyz % nvxy) / nxv;
	      int ix65 = ixyz % nxv;
	      double z = _zv[iz75] + frsh;
	      double y = _yv[iy70];
	      double x = _xv[ix65];
	      dcomplex sumy = cc0;
#pragma omp parallel for simd reduction(+:sumy)
	      double rsumy = 0.0;
	      double isumy = 0.0;
#pragma omp parallel for simd reduction(+:rsumy, isumy)
	      for (int jy60x55 = 0; jy60x55 < nkvs ; jy60x55++) {
		int jy60 = jy60x55 / nkv;
		int jx55 = jy60x55 % nkv;
		int w_index = (jx55 * nkv) + jy60;
		double vky = vkv[jy60];
		double vkx = (jx55 == 0) ? vkv[nkv - 1] : vkv[jx55];
		double vkz = vec_vkzm[(jx55 * nkv) + jy60];
		dcomplex phas = (jx55 == 0) ?
		  cexp(uim * (-vkx * x + vky * y + vkz * z)):
		  cexp(uim * (vkx * x + vky * y + vkz * z));
		dcomplex sumx = vec_w[w_index] * phas;
		double factor1 = ((jx55 == 0) || (jx55 == (nkv - 1))) ? 0.5 : 1.0;
		double factor2 = ((jy60 == 0) || (jy60 == (nkv - 1))) ? 0.5 : 1.0;
		sumx *= factor1*factor2;
		sumy += sumx;
		if (jx55 == 0) {
		  // jx55 = 0: phasf
		  double vkx = vkv[nkv - 1];
		  double vkz = vec_vkzm[jy60];
		  double angle = -vkx * x + vky * y + vkz * z;
		  dcomplex phasf = cos(angle) + uim * sin(angle);
		  dcomplex term = vec_w[jy60] * phasf * 0.5;
		  double factor = (jy60 == 0 || jy60 == nkv - 1) ? 0.5 : 1.0;
		  term *= factor;
		  rsumy += (real(term));
		  isumy += (imag(term));
		} else if (jx55 == nkv - 1) {
		  // jx55 = nkv - 1: phasl
		  double vkx = vkv[nkv - 1];
		  double vkz = vec_vkzm[(nkv - 1) * nkv + jy60];
		  double angle = vkx * x + vky * y + vkz * z;
		  dcomplex phasl = cos(angle) + uim * sin(angle);
		  dcomplex term = vec_w[(nkv - 1) * nkv + jy60] * phasl * 0.5;
		  double factor = (jy60 == 0 || jy60 == nkv - 1) ? 0.5 : 1.0;
		  term *= factor;
		  rsumy += (real(term));
		  isumy += (imag(term));
		} else {
		  // 1 <= jx55 < nkv - 1
		  double vkx = vkv[jx55];
		  double vkz = vec_vkzm[(jx55) * nkv + jy60];
		  double angle = vkx * x + vky * y + vkz * z;
		  dcomplex phas = cos(angle) + uim * sin(angle);
		  dcomplex term = vec_w[(jx55) * nkv + jy60] * phas;
		  double factor = (jy60 == 0 || jy60 == nkv - 1) ? 0.5 : 1.0;
		  term *= factor;
		  rsumy += (real(term));
		  isumy += (imag(term));
		}
	      } // jy60x55 loop
	      dcomplex sumy = rsumy + uim * isumy;
	      vec_wsum[(j80 * nrvc) + ixyz] = sumy * delks;
	    } // ixyz loop
	  } // j80 loop
@@ -644,12 +667,12 @@ void frfme(string data_file, string output_path) {

#ifdef USE_TARGET_OFFLOAD
void map_data(
  dcomplex *vec_wsum, int size_vec_wsum, dcomplex *global_vec_w, int size_global_vec_w,
  const dcomplex *vec_tt1_wk, int size_vec_tt1_wk, double *vkv, double *_xv, int nxv,
  double *_yv, int nyv, double *_zv, int nzv, double *vec_vkzm, int jlmf, int jlml,
  int nkv, int nlmmt
  dcomplex *vec_wsum, const int size_vec_wsum, dcomplex *global_vec_w, const int size_global_vec_w,
  const dcomplex *vec_tt1_wk, const int size_vec_tt1_wk, double *vkv, double *_xv, const int nxv,
  double *_yv, const int nyv, double *_zv, const int nzv, double *vec_vkzm, const int jlmf, const int jlml,
  const int nkv, const int nlmmt
) {
  int nkvs = nkv * nkv;
  const int nkvs = nkv * nkv;
#pragma omp target enter data map(to: vec_wsum[0:size_vec_wsum]) \
  map(alloc: global_vec_w[0:size_global_vec_w]) \
  map(to: vec_tt1_wk[0:size_vec_tt1_wk]) \
@@ -658,13 +681,13 @@ void map_data(
}

void offload_loop(
  dcomplex *vec_wsum, int size_vec_wsum, dcomplex *global_vec_w, int size_global_vec_w,
  const dcomplex *vec_tt1_wk, int size_vec_tt1_wk, double *vkv, double *_xv, int nxv,
  double *_yv, int nyv, double *_zv, int nzv, double *vec_vkzm, int jlmf, int jlml,
  int nkv, int nlmmt, double delks, double frsh
  dcomplex *vec_wsum, const int size_vec_wsum, dcomplex *global_vec_w, const int size_global_vec_w,
  const dcomplex *vec_tt1_wk, const int size_vec_tt1_wk, double *vkv, double *_xv, const int nxv,
  double *_yv, const int nyv, double *_zv, const int nzv, double *vec_vkzm, const int jlmf, const int jlml,
  const int nkv, const int nlmmt, double delks, double frsh
) {
  int nrvc = nxv * nyv * nzv;
  int nkvs = nkv * nkv;
  const int nkvs = nkv * nkv;
  const dcomplex cc0 = 0.0 + I * 0.0;
  const dcomplex uim = 0.0 + I * 1.0;
#pragma omp target teams distribute parallel for \
@@ -680,50 +703,71 @@ void offload_loop(
      int jx50 = jxy50 % nkv;
      vec_w[(nkv * jx50) + jy50] = wk_value;
    } // jxy50 loop
#pragma omp parallel for simd
    for (int wj = 0; wj < nrvc; wj++) vec_wsum[(j80 * nrvc) + wj] = cc0;
    int nvtot = nxv * nyv * nzv;
    int nvxy = nxv * nyv;
#pragma omp parallel for
    for (int ixyz = 0; ixyz < nvtot; ixyz++) {
    for (int ixyz = 0; ixyz < nrvc; ixyz++) {
      int iz75 = ixyz / nvxy;
      int iy70 = (ixyz % nvxy) / nxv;
      int ix65 = ixyz % nxv;
      double z = _zv[iz75] + frsh;
      double y = _yv[iy70];
      double x = _xv[ix65];
      dcomplex sumy = cc0;
#pragma omp parallel for simd reduction(+:sumy)
      double rsumy = 0.0;
      double isumy = 0.0;
#pragma omp parallel for reduction(+:rsumy, isumy)
      for (int jy60x55 = 0; jy60x55 < nkvs ; jy60x55++) {
	int jy60 = jy60x55 / nkv;
	int jx55 = jy60x55 % nkv;
	int w_index = (jx55 * nkv) + jy60;
	double vky = vkv[jy60];
	double vkx = (jx55 == 0) ? vkv[nkv - 1] : vkv[jx55];
	double vkz = vec_vkzm[(jx55 * nkv) + jy60];
	double sign = (jx55 == 0) ? -1.0 : 1.0;
	double rpart = cos(vkx * x + vky * y + vkz * z);
	double ipart = sin(sign * vkx * x + vky * y + vkz * z);
	dcomplex phas = rpart + uim * ipart;
	// dcomplex phas = cexp(uim * (sign * vkx * x + vky * y + vkz * z));
	dcomplex sumx = vec_w[w_index] * phas;
	double factor1 = ((jx55 == 0) || (jx55 == (nkv - 1))) ? 0.5 : 1.0;
	double factor2 = ((jy60 == 0) || (jy60 == (nkv - 1))) ? 0.5 : 1.0;
	sumx *= factor1*factor2;
	sumy += sumx;
	if (jx55 == 0) {
	  // jx55 = 0: phasf
	  double vkx = vkv[nkv - 1];
	  double vkz = vec_vkzm[jy60];
	  double angle = -vkx * x + vky * y + vkz * z;
	  dcomplex phasf = cos(angle) + uim * sin(angle);
	  dcomplex term = vec_w[jy60] * phasf * 0.5;
	  double factor = (jy60 == 0 || jy60 == nkv - 1) ? 0.5 : 1.0;
	  term *= factor;
	  rsumy += (real(term));
	  isumy += (imag(term));
	} else if (jx55 == nkv - 1) {
	  // jx55 = nkv - 1: phasl
	  double vkx = vkv[nkv - 1];
	  double vkz = vec_vkzm[(nkv - 1) * nkv + jy60];
	  double angle = vkx * x + vky * y + vkz * z;
	  dcomplex phasl = cos(angle) + uim * sin(angle);
	  dcomplex term = vec_w[(nkv - 1) * nkv + jy60] * phasl * 0.5;
	  double factor = (jy60 == 0 || jy60 == nkv - 1) ? 0.5 : 1.0;
	  term *= factor;
	  rsumy += (real(term));
	  isumy += (imag(term));
	} else {
	  // 1 <= jx55 < nkv - 1
	  double vkx = vkv[jx55];
	  double vkz = vec_vkzm[(jx55) * nkv + jy60];
	  double angle = vkx * x + vky * y + vkz * z;
	  dcomplex phas = cos(angle) + uim * sin(angle);
	  dcomplex term = vec_w[(jx55) * nkv + jy60] * phas;
	  double factor = (jy60 == 0 || jy60 == nkv - 1) ? 0.5 : 1.0;
	  term *= factor;
	  rsumy += (real(term));
	  isumy += (imag(term));
	}
      } // jy60x55 loop
      dcomplex sumy = rsumy + uim * isumy;
      vec_wsum[(j80 * nrvc) + ixyz] = sumy * delks;
    } // ixyz loop
  } // j80 loop
}

void unmap_data(
  dcomplex *vec_wsum, int size_vec_wsum, dcomplex *global_vec_w, int size_global_vec_w,
  const dcomplex *vec_tt1_wk, int size_vec_tt1_wk, double *vkv, double *_xv, int nxv,
  double *_yv, int nyv, double *_zv, int nzv, double *vec_vkzm, int jlmf, int jlml,
  int nkv, int nlmmt
  dcomplex *vec_wsum, const int size_vec_wsum, dcomplex *global_vec_w, const int size_global_vec_w,
  const dcomplex *vec_tt1_wk, const int size_vec_tt1_wk, double *vkv, double *_xv, const int nxv,
  double *_yv, const int nyv, double *_zv, const int nzv, double *vec_vkzm, const int jlmf, const int jlml,
  const int nkv, const int nlmmt
) {
  int nkvs = nkv * nkv;
  const int nkvs = nkv * nkv;
#pragma omp target exit data map(from: vec_wsum[0:size_vec_wsum]) \
  map(delete: global_vec_w[0:size_global_vec_w]) \
  map(delete: vec_tt1_wk[0:size_vec_tt1_wk]) \