Commit 8f0e1a2c authored by Giovanni La Mura's avatar Giovanni La Mura
Browse files

Bring magma_ztm() out of target offload context

parent 8ef840de
Loading
Loading
Loading
Loading
+43 −6
Original line number Diff line number Diff line
@@ -886,26 +886,63 @@ int cluster_jxi488_cycle(
  if (rs.use_offload) {
    // whenever rs.use_offload == true, USE_TARGET_OFFLOAD is defined.
#ifdef USE_MAGMA
    const int ind3j_size = (cid->c1->lm + 1) * cid->c1->lm;
    const int nv3j = cid->c1->nv3j;
    const int nsphmo = nsph - 1;
    const int ncou = nsph * nsphmo;
    const int litpo = li + li + 1;
    const int litpos = litpo * litpo;
    const int lmtpo = cid->c1->lmtpo;
    const int lmtpos = cid->c1->lmtpos;
    const int rsize = li * nsph;
    const magma_int_t ndi = nsph * li * (li + 2);
    const magma_int_t nlem = cid->c1->nlem;
    const magma_int_t nlemt = nlem + nlem;
    magmaDoubleComplex* vec_am = (magmaDoubleComplex *)(cid->am[0]);
    magmaDoubleComplex *rmi = (magmaDoubleComplex *)(cid->c1->rmi[0]);
    magmaDoubleComplex *rei = (magmaDoubleComplex *)(cid->c1->rei[0]);
    double *rxx = cid->c1->rxx;
    double *ryy = cid->c1->ryy;
    double *rzz = cid->c1->rzz;
    int *ind3j = cid->c1->ind3j[0];
    double *v3j0 = cid->c1->v3j0;
    magmaDoubleComplex *vh = (magmaDoubleComplex *)(cid->c1->vh);
    magmaDoubleComplex *vyhj = (magmaDoubleComplex *)(cid->c1->vyhj);
    magmaDoubleComplex *vj0 = (magmaDoubleComplex *)(cid->c1->vj0);
    magmaDoubleComplex *vyj0 = (magmaDoubleComplex *)(cid->c1->vyj0);
    magmaDoubleComplex *gis_v = (magmaDoubleComplex *)(cid->c1->gis[0]);
    magmaDoubleComplex *gls_v = (magmaDoubleComplex *)(cid->c1->gls[0]);
    magmaDoubleComplex *sam_v = (magmaDoubleComplex *)(cid->c1->sam[0]);
    magmaDoubleComplex *vec_am0m = (magmaDoubleComplex *)(cid->c1->am0m[0]);
    magma_queue_t queue = NULL;
    magma_device_t device_id = (magma_device_t)(cid->proc_device);
    magma_queue_create(device_id, &queue);
#pragma omp target data map(to: vec_am[0:ndit*ndit]) device(cid->proc_device)
#pragma omp target data map(to: rxx[0:nsph], ryy[0:nsph], rzz[0:nsph]) \
  map(to: ind3j[0:ind3j_size], v3j0[0:nv3j], vh[0:ncou*litpo]) \
  map(to: vyhj[0:ncou*litpos], vj0[0:nsph*lmtpo], vyj0[0:nsph*lmtpos]) \
  map(to: sam_v[0:ndit*nlem], rmi[0:rsize], rei[0:rsize]) \
  map(to: gis_v[0:ndi*nlem], gls_v[0:ndi*nlem]) \
  map(tofrom:vec_am[0:ndit*ndit], vec_am0m[0:nlemt*nlemt]) \
  device(cid->proc_device)
    {
      magma_cms(vec_am, cid->c1, cid->proc_device); // Initialize uninverted matrix on GPU
      
      // Initialize uninverted matrix on GPU
      magma_cms(
        vec_am, cid->c1, rxx, ryy, rzz, ind3j, v3j0, vh, vyhj, vj0, vyj0, rmi, rei, cid->proc_device
      );
      magma_zinvert_resident(vec_am, (magma_int_t)ndit, jer, queue, cid->proc_device, rs);
      magma_queue_sync(queue);
      
      magma_ztm(vec_am, cid->c1, cid->proc_device);
      
      if (jer != 0) {
	sprintf(virtual_line, "ERROR: matrix inversion returned code %d!\n", jer);
	message = virtual_line;
	logger->err(message);
      }
    }
    } // end of target data region
    magma_queue_destroy(queue);
    magma_ztm(
      vec_am, cid->c1, rxx, ryy, rzz, ind3j, v3j0, vh, vyhj, vj0, vyj0, sam_v,
      gis_v, gls_v, vec_am0m, cid->proc_device
    );
#else // NO_USE_MAGMA but USE_TARGET_OFFLOAD
    // TODO: implement full offload pipeline without MAGMA
    cms_flat(cid->am[0], cid->c1);
+44 −2
Original line number Diff line number Diff line
@@ -45,10 +45,26 @@ double magma_cgev(int ipamo, int mu, int l, int m);
 *
 * \param[in,out] vec_am: `magmaDoubleComplex *` Vector form of the matrix.
 * \param[in] c1: `ParticleDescriptor *` Pointer to a ParticleDescriptor instance.
 * \param[in] rxx: `double *` Vector of sphere X coordinates.
 * \param[in] ryy: `double *` Vector of sphere Y coordinates.
 * \param[in] rzz: `double *` Vector of sphere Z coordinates.
 * \param[in] ind3j: `int *` Vector form of the 3J index look-up table.
 * \param[in] v3j0: `double *`
 * \param[in] vh: `magmaDoubleComplex *`
 * \param[in] vyhj: `magmaDoubleComplex *`
 * \param[in] vj0: `magmaDoubleComplex *`
 * \param[in] vyj0: `magmaDoubleComplex *`
 * \param[in] rmi: `magmaDoubleComplex *` Vector of Mie magnetic coefficients.
 * \param[in] rei: `magmaDoubleComplex *` Vector of Mie electric coefficients.
 * \param[in] device_id: `int` ID of the device to build the matrix on.
 * \return result: `magma_int_t` A return code (MAGMA_SUCCESS, if everything is fine).
 */
magma_int_t magma_cms(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int device_id);
magma_int_t magma_cms(
  magmaDoubleComplex *vec_am, ParticleDescriptor *c1, double *rxx, double *ryy, double *rzz,
  int *ind3j, double *v3j0, magmaDoubleComplex *vh, magmaDoubleComplex *vyhj,
  magmaDoubleComplex *vj0, magmaDoubleComplex *vyj0, magmaDoubleComplex *rmi,
  magmaDoubleComplex* rei, int device_id
);

/**
 * \brief Compute the transfer vector from N2 to N1.
@@ -120,7 +136,33 @@ void magma_zinvert_resident(
  const RuntimeSettings& rs=RuntimeSettings()
);

magma_int_t magma_ztm(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int device_id);
/**
 * \brief Compute the monocentered T-matrix
 *
 * \param[in] vec_am: `magmaDoubleComplex *` Vector form of the multi-centered matrix.
 * \param[in] c1: `ParticleDescriptor *` Pointer to a ParticleDescriptor instance.
 * \param[in] rxx: `double *` Vector of sphere X coordinates.
 * \param[in] ryy: `double *` Vector of sphere Y coordinates.
 * \param[in] rzz: `double *` Vector of sphere Z coordinates.
 * \param[in] ind3j: `int *` Vector form of the 3J index look-up table.
 * \param[in] v3j0: `double *`
 * \param[in] vh: `magmaDoubleComplex *`
 * \param[in] vyhj: `magmaDoubleComplex *`
 * \param[in] vj0: `magmaDoubleComplex *`
 * \param[in] vyj0: `magmaDoubleComplex *`
 * \param[in] sam_v: `magmaDoubleComplex *` Vector form of sums from vec_am.
 * \param[in] gis_v: `magmaDoubleComplex *`
 * \param[in] gls_v: `magmaDoubleComplex *`
 * \param[out] vec_am0m: `magmaDoubleComplex *` Vector that receives the values computed by
 * this function (presumably the vector form of the monocentered T-matrix; TODO confirm).
 * \param[in] device_id: `int` ID of the device to build the matrix on.
 * \return result: `magma_int_t` A return code (MAGMA_SUCCESS, if everything is fine).
 */
magma_int_t magma_ztm(
  magmaDoubleComplex *vec_am, ParticleDescriptor *c1, double *rxx, double *ryy,
  double *rzz, int *ind3j, double *v3j0, magmaDoubleComplex *vh, magmaDoubleComplex *vyhj,
  magmaDoubleComplex *vj0, magmaDoubleComplex *vyj0, magmaDoubleComplex *sam_v,
  magmaDoubleComplex *gis_v, magmaDoubleComplex *gls_v, magmaDoubleComplex *vec_am0m,
  int device_id
);
#endif // USE_TARGET_OFFLOAD

/**
+162 −206
Original line number Diff line number Diff line
@@ -114,7 +114,12 @@ double magma_cgev(int ipamo, int mu, int l, int m) {
}
#pragma omp end declare target

magma_int_t magma_cms(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int device_id) {
magma_int_t magma_cms(
  magmaDoubleComplex *vec_am, ParticleDescriptor *c1, double *rxx, double *ryy, double *rzz,
  int *ind3j, double *v3j0, magmaDoubleComplex *vh, magmaDoubleComplex *vyhj,
  magmaDoubleComplex *vj0, magmaDoubleComplex *vyj0, magmaDoubleComplex *rmi,
  magmaDoubleComplex* rei, int device_id
) {
  const magmaDoubleComplex cz0 = MAGMA_Z_MAKE(0.0, 0.0);
  const magmaDoubleComplex cz1 = MAGMA_Z_MAKE(1.0, 0.0); 
  const int nsph = c1->nsph;
@@ -127,8 +132,6 @@ magma_int_t magma_cms(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int de
  const int lmtpo = c1->lmtpo;
  const int lmtpos = c1->lmtpos;
  const int rsize = li * nsph;
  const int ind3j_size = (c1->lm + 1) * c1->lm;
  const int nv3j = c1->nv3j;
  const magma_int_t max_litpo = 2 * li + 1;
  const magma_int_t nlim = li * (li + 2);
  const magma_int_t ndi = nsph * nlim;
@@ -136,18 +139,7 @@ magma_int_t magma_cms(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int de
  const magma_int_t size = ndit * ndit;
  const magma_int_t num_pairs = (nsph * (nsph - 1)) / 2;
  const magma_int_t total_iters = num_pairs * li * max_litpo * li * max_litpo;
  magmaDoubleComplex *rmi = (magmaDoubleComplex *)(c1->rmi[0]);
  magmaDoubleComplex *rei = (magmaDoubleComplex *)(c1->rei[0]);
  magmaDoubleComplex vj = MAGMA_Z_MAKE(real(c1->vj), imag(c1->vj));
  double *rxx = c1->rxx;
  double *ryy = c1->ryy;
  double *rzz = c1->rzz;
  int *ind3j = c1->ind3j[0];
  double *v3j0 = c1->v3j0;
  magmaDoubleComplex *vh = (magmaDoubleComplex *)(c1->vh);
  magmaDoubleComplex *vyhj = (magmaDoubleComplex *)(c1->vyhj);
  magmaDoubleComplex *vj0 = (magmaDoubleComplex *)(c1->vj0);
  magmaDoubleComplex *vyj0 = (magmaDoubleComplex *)(c1->vyj0);

  int lut_n1[num_pairs];
  int lut_n2[num_pairs];
@@ -160,11 +152,7 @@ magma_int_t magma_cms(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int de
  }

#pragma omp target teams distribute parallel for \
  firstprivate(total_iters, max_litpo, num_pairs, li, le, ndi, ndit, vj) \
  map(to: lut_n1[0:num_pairs], lut_n2[0:num_pairs]) \
  map(to: rxx[0:nsph], ryy[0:nsph], rzz[0:nsph]) \
  map(to: ind3j[0:ind3j_size], v3j0[0:nv3j], vh[0:ncou*litpo]) \
  map(to: vyhj[0:ncou*litpos], vj0[0:nsph*lmtpo], vyj0[0:nsph*lmtpos]) \
  device(device_id)
  for (magma_int_t iter = 0; iter < total_iters; ++iter) {
    magma_int_t t = iter;
@@ -217,14 +205,12 @@ magma_int_t magma_cms(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int de
    // End of index 2 magnetic quantum numbers
    magmaDoubleComplex cgh, cgk, zvalue;
    cgh = (is_valid_iter) ?
      // cz1 : // ghit(0, 0, nbl, l1, m1, l2, m2, c1, rac3j) :
	magma_ghit<0>(
	  0, nbl, l1, m1, l2, m2, rxx, ryy, rzz, ind3j, v3j0, vh,
	  vyhj, vj0, vyj0, vj, li, le, rac3j
      ) :
	cz0;
    cgk = (is_valid_iter) ?
      // cz1 : // ghit(0, 1, nbl, l1, m1, l2, m2, c1, rac3j) :
	magma_ghit<0>(
        1, nbl, l1, m1, l2, m2, rxx, ryy, rzz, ind3j, v3j0, vh,
	  vyhj, vj0, vyj0, vj, li, le, rac3j
@@ -246,8 +232,6 @@ magma_int_t magma_cms(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int de
  
  magma_int_t diag_iters = nsph * li * max_litpo;
#pragma omp target teams distribute parallel for \
  firstprivate(diag_iters, max_litpo, li, nsph, ndi, ndit) \
  map(to: rmi[0:rsize], rei[0:rsize]) \
  device(device_id)
  for(magma_int_t iter = 0; iter < diag_iters; ++iter) {
    magma_int_t t = iter;
@@ -714,17 +698,19 @@ void magma_zinvert_resident(
      }
      // >>> END OF GESV INVERSION <<<
    }
  }
  } // end of target region
  jer = (int)err;
}

magma_int_t magma_ztm(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int device_id) {
magma_int_t magma_ztm(
  magmaDoubleComplex *vec_am, ParticleDescriptor *c1, double *rxx, double *ryy,
  double *rzz, int *ind3j, double *v3j0, magmaDoubleComplex *vh, magmaDoubleComplex *vyhj,
  magmaDoubleComplex *vj0, magmaDoubleComplex *vyj0, magmaDoubleComplex *sam_v,
  magmaDoubleComplex *gis_v, magmaDoubleComplex *gls_v, magmaDoubleComplex *vec_am0m,
  int device_id
) {
  magma_int_t result = MAGMA_SUCCESS;
  const magmaDoubleComplex cz0 = MAGMA_Z_MAKE(0.0, 0.0);
  magmaDoubleComplex *gis_v = (magmaDoubleComplex *)(c1->gis[0]);
  magmaDoubleComplex *gls_v = (magmaDoubleComplex *)(c1->gls[0]);
  magmaDoubleComplex *sam_v = (magmaDoubleComplex *)(c1->sam[0]);
  magmaDoubleComplex *vec_am0m = (magmaDoubleComplex *)(c1->am0m[0]);
  const magma_int_t k2max = c1->li * (c1->li + 2);
  const magma_int_t k3max = c1->le * (c1->le + 2);
  const magma_int_t nsph = c1->nsph;
@@ -743,27 +729,7 @@ magma_int_t magma_ztm(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int de
  const magma_int_t ind3j_size = (c1->lm + 1) * c1->lm;
  const magma_int_t nv3j = c1->nv3j;
  magmaDoubleComplex vj = MAGMA_Z_MAKE(real(c1->vj), imag(c1->vj));
  double *rxx = c1->rxx;
  double *ryy = c1->ryy;
  double *rzz = c1->rzz;
  int *ind3j = c1->ind3j[0];
  double *v3j0 = c1->v3j0;
  magmaDoubleComplex *vh = (magmaDoubleComplex *)(c1->vh);
  magmaDoubleComplex *vyhj = (magmaDoubleComplex *)(c1->vyhj);
  magmaDoubleComplex *vj0 = (magmaDoubleComplex *)(c1->vj0);
  magmaDoubleComplex *vyj0 = (magmaDoubleComplex *)(c1->vyj0);

#pragma omp target data use_device_ptr(vec_am) \
  map(to: rxx[0:nsph], ryy[0:nsph], rzz[0:nsph]) \
  map(to: ind3j[0:ind3j_size], v3j0[0:nv3j], vh[0:ncou*litpo]) \
  map(to: vyhj[0:ncou*litpos], vj0[0:nsph*lmtpo], vyj0[0:nsph*lmtpos]) \
  map(tofrom: sam_v[0:ndit*nlem], gis_v[0:ndi*nlem], gls_v[0:ndi*nlem]) \
  map(tofrom: vec_am0m[0:nlemt*nlemt]) \
  device(device_id)
  {  
#pragma omp target teams distribute parallel for collapse(3) \
  firstprivate(k2max, k3max, li, le, ndi, ndit, vj) \
  device(device_id)
  for (magma_int_t n2 = 1; n2 <= nsph; n2++) {
    for (magma_int_t k2 = 1; k2 <= k2max; k2++) {
      for (magma_int_t k3 = 1; k3 <= k3max; k3++) {
@@ -807,9 +773,6 @@ magma_int_t magma_ztm(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int de
    } // close k2 loop, former l2 + im2 loops
  } // close n2 loop

#pragma omp target teams distribute parallel for collapse(2)	\
  firstprivate(ndi, ndit, nlem, cz0) \
  device(device_id)
  for (magma_int_t i1 = 1; i1 <= ndi; i1++) {
    for (magma_int_t i3 = 1; i3 <= nlem; i3++) {
      magmaDoubleComplex sum1 = cz0;
@@ -860,9 +823,6 @@ magma_int_t magma_ztm(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int de
    } // i3 loop
  } // i1 loop
  
#pragma omp target teams distribute parallel for collapse(2)	\
  firstprivate(ndi, nlem) \
  device(device_id)
  for (magma_int_t i1 = 1; i1 <= ndi; i1++) {
    for (magma_int_t i0 = 1; i0 <= nlem; i0++) {
      magma_int_t vec_index = (i1 - 1) * nlem + i0 - 1;
@@ -871,9 +831,6 @@ magma_int_t magma_ztm(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int de
    } // i0 loop
  } // i1 loop
  
#pragma omp target teams distribute parallel for collapse(2)	\
  firstprivate(ndi, ndit, nlem, nlemt, cz0) \
  device(device_id)
  for (magma_int_t i0 = 1; i0 <= nlem; i0++) {
    for (magma_int_t i3 = 1; i3 <= nlemt; i3++) {
      magma_int_t i0e = i0 + nlem;
@@ -906,7 +863,6 @@ magma_int_t magma_ztm(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int de
      vec_am0m[vec_ind0e + i3 - 1] = MAGMA_Z_NEGATE(sum2);
    } // i3 loop
  } // i0 loop
  }
  return result;
}