Commit 730486ef authored by Giovanni La Mura's avatar Giovanni La Mura
Browse files

Refactor the offloaded nested loop structure to take advantage from teams

parent a0a3f870
Loading
Loading
Loading
Loading
+44 −20
Original line number Diff line number Diff line
@@ -191,6 +191,7 @@ void frfme(string data_file, string output_path) {
  nvtxRangePush("Running frfme()");
#endif
  chrono::time_point<chrono::high_resolution_clock> t_start = chrono::high_resolution_clock::now();
  chrono::time_point<chrono::high_resolution_clock> t_end;
  chrono::duration<double> elapsed;
  char buffer[256];
  string message = "INIT";
@@ -437,6 +438,13 @@ void frfme(string data_file, string output_path) {
	sprintf(buffer, "TOTAL: %.2lg MB\n", size_mb);
	message = string(buffer);
	logger.log(message);
	sprintf(buffer, "INFO: nxv = %d, nyv = %d, nzv = %d, nrvc = %d\n", nxv, nyv, nzv, (nxv * nyv * nzv));
	message = string(buffer);
	logger.log(message);
	sprintf(buffer, "INFO: nlmmt = %d, nkv = %d\n", nlmmt, nkv);
	message = string(buffer);
	logger.log(message);
	
	tfrfme = new TFRFME(lmode, lm, nkv, nxv, nyv, nzv);
	double *_xv = tfrfme->get_x();
	double *_yv = tfrfme->get_y();
@@ -518,6 +526,7 @@ void frfme(string data_file, string output_path) {
	  double *vec_vkzm = vkzm[0];
	  dcomplex *global_vec_w = new dcomplex[size_global_vec_w];
#ifdef USE_TARGET_OFFLOAD
	  t_start = chrono::high_resolution_clock::now();
	  message = "INFO: Mapping data to device.\n";
	  logger.log(message);
	  map_data(
@@ -525,6 +534,12 @@ void frfme(string data_file, string output_path) {
	    size_vec_tt1_wk, vkv, _xv, nxv, _yv, nyv, _zv, nzv, vec_vkzm, jlmf, jlml,
	    nkv, nlmmt
          );
	  t_end = chrono::high_resolution_clock::now();
	  elapsed = t_end - t_start;
	  sprintf(buffer, "INFO: copying data to device took %lfs.\n", elapsed.count());
	  message = string(buffer);
	  logger.log(message);
	  t_start = chrono::high_resolution_clock::now();
	  message = "INFO: computing loop.\n";
	  logger.log(message);
	  offload_loop(
@@ -532,6 +547,12 @@ void frfme(string data_file, string output_path) {
	    size_vec_tt1_wk, vkv, _xv, nxv, _yv, nyv, _zv, nzv, vec_vkzm, jlmf, jlml,
	    nkv, nlmmt, delks, frsh
          );
	  t_end = chrono::high_resolution_clock::now();
	  elapsed = t_end - t_start;
	  sprintf(buffer, "INFO: loop calculation took %lfs.\n", elapsed.count());
	  message = string(buffer);
	  logger.log(message);
	  t_start = chrono::high_resolution_clock::now();
	  message = "INFO: cleaning device memory.\n";
	  logger.log(message);
	  unmap_data(
@@ -539,6 +560,11 @@ void frfme(string data_file, string output_path) {
	    size_vec_tt1_wk, vkv, _xv, nxv, _yv, nyv, _zv, nzv, vec_vkzm, jlmf, jlml,
	    nkv, nlmmt
          );
	  t_end = chrono::high_resolution_clock::now();
	  elapsed = t_end - t_start;
	  sprintf(buffer, "INFO: result recovery and device memory clean-up took %lfs.\n", elapsed.count());
	  message = string(buffer);
	  logger.log(message);
#else
#pragma omp parallel for
	  for (int j80 = jlmf - 1; j80 < jlml; j80++) {
@@ -687,24 +713,24 @@ void offload_loop(
  const int nkv, const int nlmmt, double delks, double frsh
) {
  int nrvc = nxv * nyv * nzv;
  const int nkvs = nkv * nkv;
  const dcomplex cc0 = 0.0 + I * 0.0;
  const dcomplex uim = 0.0 + I * 1.0;
#pragma omp target teams distribute parallel for \
  map(to: jlmf, jlml, nkv, nkvs, nrvc, nlmmt, nxv, nyv, nzv)	\
  map(to: delks, frsh) map(to: cc0, uim)
  int nkvs = nkv * nkv;
  int nkvmo = nkv - 1;
  int nkvvmo = (nkv - 1) * nkv;
  int nvxy = nxv * nyv;
  dcomplex cc0 = 0.0 + I * 0.0;
  dcomplex uim = 0.0 + I * 1.0;
#pragma omp target
  for (int j80 = jlmf - 1; j80 < jlml; j80++) {
    dcomplex *vec_w = global_vec_w + nkvs * (j80 - jlmf + 1);
#pragma omp parallel for simd
    int j80_index = j80 - jlmf + 1;
    dcomplex *vec_w = global_vec_w + nkvs * j80_index;
    for (int jxy50 = 0; jxy50 < nkvs; jxy50++) {
      int wk_index = nlmmt * jxy50;
      dcomplex wk_value = vec_tt1_wk[wk_index + j80];
      dcomplex wk_value = vec_tt1_wk[wk_index + j80_index];
      int jy50 = jxy50 / nkv;
      int jx50 = jxy50 % nkv;
      vec_w[(nkv * jx50) + jy50] = wk_value;
    } // jxy50 loop
    int nvxy = nxv * nyv;
#pragma omp parallel for
#pragma omp teams distribute parallel for
    for (int ixyz = 0; ixyz < nrvc; ixyz++) {
      int iz75 = ixyz / nvxy;
      int iy70 = (ixyz % nvxy) / nxv;
@@ -720,43 +746,41 @@ void offload_loop(
	int jx55 = jy60x55 % nkv;
	int w_index = (jx55 * nkv) + jy60;
	double vky = vkv[jy60];
	double factor = (jy60 == 0 || jy60 == nkv - 1) ? 0.5 : 1.0;
	if (jx55 == 0) {
	  // jx55 = 0: phasf
	  double vkx = vkv[nkv - 1];
	  double vkx = vkv[nkvmo];
	  double vkz = vec_vkzm[jy60];
	  double angle = -vkx * x + vky * y + vkz * z;
	  dcomplex phasf = cos(angle) + uim * sin(angle);
	  dcomplex term = vec_w[jy60] * phasf * 0.5;
	  double factor = (jy60 == 0 || jy60 == nkv - 1) ? 0.5 : 1.0;
	  term *= factor;
	  rsumy += (real(term));
	  isumy += (imag(term));
	} else if (jx55 == nkv - 1) {
	  // jx55 = nkv - 1: phasl
	  double vkx = vkv[nkv - 1];
	  double vkz = vec_vkzm[(nkv - 1) * nkv + jy60];
	  double vkz = vec_vkzm[nkvvmo + jy60];
	  double angle = vkx * x + vky * y + vkz * z;
	  dcomplex phasl = cos(angle) + uim * sin(angle);
	  dcomplex term = vec_w[(nkv - 1) * nkv + jy60] * phasl * 0.5;
	  double factor = (jy60 == 0 || jy60 == nkv - 1) ? 0.5 : 1.0;
	  dcomplex term = vec_w[nkvvmo + jy60] * phasl * 0.5;
	  term *= factor;
	  rsumy += (real(term));
	  isumy += (imag(term));
	} else {
	  // 1 <= jx55 < nkv - 1
	  double vkx = vkv[jx55];
	  double vkz = vec_vkzm[(jx55) * nkv + jy60];
	  double vkz = vec_vkzm[w_index];
	  double angle = vkx * x + vky * y + vkz * z;
	  dcomplex phas = cos(angle) + uim * sin(angle);
	  dcomplex term = vec_w[(jx55) * nkv + jy60] * phas;
	  double factor = (jy60 == 0 || jy60 == nkv - 1) ? 0.5 : 1.0;
	  dcomplex term = vec_w[w_index] * phas;
	  term *= factor;
	  rsumy += (real(term));
	  isumy += (imag(term));
	}
      } // jy60x55 loop
      dcomplex sumy = rsumy + uim * isumy;
      vec_wsum[(j80 * nrvc) + ixyz] = sumy * delks;
      vec_wsum[(j80_index * nrvc) + ixyz] = sumy * delks;
    } // ixyz loop
  } // j80 loop
}