Commit 2041e88b authored by Giovanni La Mura's avatar Giovanni La Mura
Browse files

Define magma_ztm()

parent a3aebe12
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -119,6 +119,8 @@ void magma_zinvert_resident(
  magmaDoubleComplex *mat, magma_int_t n, int &jer, magma_queue_t queue, int device_id=0,
  const RuntimeSettings& rs=RuntimeSettings()
);

magma_int_t magma_ztm(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int device_id);
#endif // USE_TARGET_OFFLOAD

/**
+253 −61
Original line number Diff line number Diff line
@@ -148,7 +148,6 @@ magma_int_t magma_cms(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int de
  magmaDoubleComplex *vyhj = (magmaDoubleComplex *)(c1->vyhj);
  magmaDoubleComplex *vj0 = (magmaDoubleComplex *)(c1->vj0);
  magmaDoubleComplex *vyj0 = (magmaDoubleComplex *)(c1->vyj0);
  int dbg_i1, dbg_i2;

  int lut_n1[num_pairs];
  int lut_n2[num_pairs];
@@ -161,7 +160,7 @@ magma_int_t magma_cms(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int de
  }
  
#pragma omp target teams distribute parallel for \
  firstprivate(total_iters, max_litpo, num_pairs, li, le, ndi, ndit, vj, dbg_i1, dbg_i2) \
  firstprivate(total_iters, max_litpo, num_pairs, li, le, ndi, ndit, vj) \
  map(to: lut_n1[0:num_pairs], lut_n2[0:num_pairs]) \
  map(to: rxx[0:nsph], ryy[0:nsph], rzz[0:nsph]) \
  map(to: ind3j[0:ind3j_size], v3j0[0:nv3j], vh[0:ncou*litpo]) \
@@ -216,8 +215,6 @@ magma_int_t magma_cms(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int de
    int j2 = (is_valid_iter) ? in1 + ilm2 : 0;
    int j2e = (is_valid_iter) ? in1 + ilm2e : 0;
    // End of index 2 magnetic quantum numbers
    dbg_i1 = i1;
    dbg_i2 = i2;
    magmaDoubleComplex cgh, cgk, zvalue;
    cgh = (is_valid_iter) ?
      // cz1 : // ghit(0, 0, nbl, l1, m1, l2, m2, c1, rac3j) :
@@ -442,7 +439,7 @@ template<int IHI> magmaDoubleComplex magma_ghit(
      MAGMA_Z_MAKE(sqrt(FOUR_PI * (l1 + l1po) * (l2 + l2 + 1)), 0.0);
    return magma_zprod(result, cr);
  } else { // The project grants that IHI = 2 whenever it is neither 0 nor 1.
    int lmtpo = lm + lm + 1;
    int lmtpo = li + le + 1;
    int lmtpos = lmtpo * lmtpo;
    int nbhj = nblmo * lmtpo;
    int nby = nblmo * lmtpos;
@@ -655,7 +652,9 @@ void magma_zinvert_resident(
  magma_int_t mm = m * m; // size of a
  const magmaDoubleComplex magma_zero = MAGMA_Z_MAKE(0.0, 0.0);
  const magmaDoubleComplex magma_one = MAGMA_Z_MAKE(1.0, 0.0);
  magmaDoubleComplex *h_a = (magmaDoubleComplex *)omp_get_mapped_ptr(a, device_id);
  magmaDoubleComplex *h_a = a;
#pragma omp target data use_device_ptr(a) device(device_id)
  {
    if (rs.invert_mode == RuntimeSettings::INV_MODE_LU) { 
      // >>> LU INVERSION <<<
      magmaDoubleComplex *dwork; // work space pointer on device
@@ -715,9 +714,202 @@ void magma_zinvert_resident(
      }
      // >>> END OF GESV INVERSION <<<
    }
  }
  jer = (int)err;
}

magma_int_t magma_ztm(magmaDoubleComplex *vec_am, ParticleDescriptor *c1, int device_id) {
  magma_int_t result = MAGMA_SUCCESS;
  const magmaDoubleComplex cz0 = MAGMA_Z_MAKE(0.0, 0.0);
  magmaDoubleComplex *gis_v = (magmaDoubleComplex *)(c1->gis[0]);
  magmaDoubleComplex *gls_v = (magmaDoubleComplex *)(c1->gls[0]);
  magmaDoubleComplex *sam_v = (magmaDoubleComplex *)(c1->sam[0]);
  magmaDoubleComplex *vec_am0m = (magmaDoubleComplex *)(c1->am0m[0]);
  const magma_int_t k2max = c1->li * (c1->li + 2);
  const magma_int_t k3max = c1->le * (c1->le + 2);
  const magma_int_t nsph = c1->nsph;
  const magma_int_t ncou = nsph * (nsph - 1);
  const magma_int_t li = c1->li;
  const magma_int_t litpo = li + li + 1;
  const magma_int_t litpos = litpo * litpo;
  const magma_int_t nlim = li * (li + 2);
  const magma_int_t ndi = nsph * li * (li + 2);
  const magma_int_t ndit = ndi + ndi;
  const magma_int_t le = c1->le;
  const magma_int_t nlem = c1->nlem;
  const magma_int_t nlemt = nlem + nlem;
  const magma_int_t lmtpo = c1->lmtpo;
  const magma_int_t lmtpos = c1->lmtpos;
  const magma_int_t ind3j_size = (c1->lm + 1) * c1->lm;
  const magma_int_t nv3j = c1->nv3j;
  magmaDoubleComplex vj = MAGMA_Z_MAKE(real(c1->vj), imag(c1->vj));
  double *rxx = c1->rxx;
  double *ryy = c1->ryy;
  double *rzz = c1->rzz;
  int *ind3j = c1->ind3j[0];
  double *v3j0 = c1->v3j0;
  magmaDoubleComplex *vh = (magmaDoubleComplex *)(c1->vh);
  magmaDoubleComplex *vyhj = (magmaDoubleComplex *)(c1->vyhj);
  magmaDoubleComplex *vj0 = (magmaDoubleComplex *)(c1->vj0);
  magmaDoubleComplex *vyj0 = (magmaDoubleComplex *)(c1->vyj0);

#pragma omp target data use_device_ptr(vec_am) \
  map(to: rxx[0:nsph], ryy[0:nsph], rzz[0:nsph]) \
  map(to: ind3j[0:ind3j_size], v3j0[0:nv3j], vh[0:ncou*litpo]) \
  map(to: vyhj[0:ncou*litpos], vj0[0:nsph*lmtpo], vyj0[0:nsph*lmtpos]) \
  map(tofrom: sam_v[0:ndit*nlem], gis_v[0:ndi*nlem], gls_v[0:ndi*nlem]) \
  map(tofrom: vec_am0m[0:nlemt*nlemt]) \
  device(device_id)
  {  
#pragma omp target teams distribute parallel for collapse(3) \
  firstprivate(k2max, k3max, li, le, ndi, ndit, vj) \
  device(device_id)
    for (magma_int_t n2 = 1; n2 <= nsph; n2++) {
      for (magma_int_t k2 = 1; k2 <= k2max; k2++) {
	for (magma_int_t k3 = 1; k3 <= k3max; k3++) {
	  double rac3j[128];
	  magma_int_t l2 = (magma_int_t)sqrt(k2 + 1);
	  magma_int_t im2 = k2 - (l2 * l2) + 1;
	  if (im2 == 0) {
	    l2--;
	    im2 = 2 * l2+1;
	  }
	  else if (im2 > 2 * l2 + 1) {
	    im2 -= 2 * l2 + 1;
	    l2++;
	  }
	  magma_int_t l3 = (magma_int_t)sqrt(k3 + 1);
	  magma_int_t im3 = k3 - (l3 * l3) + 1;
	  if (im3 == 0) {
	    l3--;
	    im3 = 2 * l3 + 1;
	  }
	  else if (im3 > 2 * l3 + 1) {
	    im3 -= 2 * l3 + 1;
	    l3++;
	  }
	  magma_int_t i2 = (n2 - 1) * li * (li + 2) + l2 * l2 + im2 - 1;
	  magma_int_t m2 = -l2 - 1 + im2;
	  magma_int_t i3 = l3 * l3 + im3 - 1;
	  magma_int_t m3 = -l3 - 1 + im3;
	  magma_int_t vec_index = (i2 - 1) * nlem + i3 - 1;
	  // gis_v[vec_index] = ghit(2, 0, n2, l2, m2, l3, m3, c1, rac3j);
	  gis_v[vec_index] = magma_ghit<2>(
					   0, n2, l2, m2, l3, m3, rxx, ryy, rzz, ind3j,
					   v3j0, vh, vyhj, vj0, vyj0, vj, li, le, rac3j
					   );
	  //gls_v[vec_index] = ghit(2, 1, n2, l2, m2, l3, m3, c1, rac3j);
	  gls_v[vec_index] = magma_ghit<2>(
					   1, n2, l2, m2, l3, m3, rxx, ryy, rzz, ind3j,
					   v3j0, vh, vyhj, vj0, vyj0, vj, li, le, rac3j
					   );
	} // close k3 loop, former l3 + im3 loops
      } // close k2 loop, former l2 + im2 loops
    } // close n2 loop
  
#pragma omp target teams distribute parallel for collapse(2)	\
  firstprivate(ndi, ndit, nlem, cz0) \
  device(device_id)
    for (magma_int_t i1 = 1; i1 <= ndi; i1++) {
      for (magma_int_t i3 = 1; i3 <= nlem; i3++) {
	magmaDoubleComplex sum1 = cz0;
	magmaDoubleComplex sum2 = cz0;
	magmaDoubleComplex sum3 = cz0;
	magmaDoubleComplex sum4 = cz0;
	magma_int_t i1e = i1 + ndi;
	magma_int_t i3e = i3 + nlem;
	for (magma_int_t i2 = 1; i2 <= ndi; i2++) {
	  magmaDoubleComplex pr;
	  magma_int_t i2e = i2 + ndi;
	  magma_int_t vec_ind_g23 = (i2 - 1) * nlem + i3 - 1;
	  magmaDoubleComplex gie = gis_v[vec_ind_g23];
	  magmaDoubleComplex gle = gls_v[vec_ind_g23];
	  magma_int_t vec_ind_a1 = (i1 - 1) * ndit;
	  magma_int_t vec_ind_a1e = (i1 - 1 + ndi) * ndit;
	  magmaDoubleComplex a1 = vec_am[vec_ind_a1 + i2 - 1];
	  magmaDoubleComplex a2 = vec_am[vec_ind_a1 + i2e - 1];
	  magmaDoubleComplex a3 = vec_am[vec_ind_a1e + i2 - 1];
	  magmaDoubleComplex a4 = vec_am[vec_ind_a1e + i2e - 1];
	  // sum1 += (a1 * gie + a2 * gle);
	  pr = magma_zprod(a1, gie);
	  sum1 = magma_zadd(sum1, pr);
	  pr = magma_zprod(a2, gle);
	  sum1 = magma_zadd(sum1, pr);
	  // sum2 += (a1 * gle + a2 * gie);
	  pr = magma_zprod(a1, gle);
	  sum2 = magma_zadd(sum2, pr);
	  pr = magma_zprod(a2, gie);
	  sum2 = magma_zadd(sum2, pr);
	  // sum3 += (a3 * gie + a4 * gle);
	  pr = magma_zprod(a3, gie);
	  sum3 = magma_zadd(sum3, pr);
	  pr = magma_zprod(a4, gle);
	  sum3 = magma_zadd(sum3, pr);
	  // sum4 += (a3 * gle + a4 * gie);
	  pr = magma_zprod(a3, gle);
	  sum4 = magma_zadd(sum4, pr);
	  pr = magma_zprod(a4, gie);
	  sum4 = magma_zadd(sum4, pr);
	} // i2 loop
	magma_int_t vec_ind1 = (i1 - 1) * nlemt;
	magma_int_t vec_ind1e = (i1e - 1) * nlemt;
	sam_v[vec_ind1 + i3 - 1] = sum1;
	sam_v[vec_ind1 + i3e - 1] = sum2;
	sam_v[vec_ind1e + i3 - 1] = sum3;
	sam_v[vec_ind1e + i3e - 1] = sum4;
      } // i3 loop
    } // i1 loop
  
#pragma omp target teams distribute parallel for collapse(2)	\
  firstprivate(ndi, nlem) \
  device(device_id)
    for (magma_int_t i1 = 1; i1 <= ndi; i1++) {
      for (magma_int_t i0 = 1; i0 <= nlem; i0++) {
	magma_int_t vec_index = (i1 - 1) * nlem + i0 - 1;
	gis_v[vec_index] = MAGMA_Z_CONJ(gis_v[vec_index]);
	gls_v[vec_index] = MAGMA_Z_CONJ(gls_v[vec_index]);
      } // i0 loop
    } // i1 loop
  
#pragma omp target teams distribute parallel for collapse(2)	\
  firstprivate(ndi, ndit, nlem, nlemt, cz0) \
  device(device_id)
    for (magma_int_t i0 = 1; i0 <= nlem; i0++) {
      for (magma_int_t i3 = 1; i3 <= nlemt; i3++) {
	magma_int_t i0e = i0 + nlem;
	magmaDoubleComplex sum1 = cz0;
	magmaDoubleComplex sum2 = cz0;
	for (magma_int_t i1 = 1; i1 <= ndi; i1 ++) {
	  magmaDoubleComplex pr;
	  magma_int_t i1e = i1 + ndi;
	  magma_int_t vec_ind1 = (i1 - 1) * nlemt;
	  magma_int_t vec_ind1e = (i1e - 1) * nlemt;
	  magmaDoubleComplex a1 = sam_v[vec_ind1 + i3 - 1];
	  magmaDoubleComplex a2 = sam_v[vec_ind1e + i3 - 1];
	  magma_int_t vec_index = (i1 - 1) * nlem + i0 - 1;
	  magmaDoubleComplex gie = gis_v[vec_index];
	  magmaDoubleComplex gle = gls_v[vec_index];
	  // sum1 += (a1 * gie + a2 * gle);
	  pr = magma_zprod(a1, gie);
	  sum1 = magma_zadd(sum1, pr);
	  pr = magma_zprod(a2, gle);
	  sum1 = magma_zadd(sum1, pr);
	  // sum2 += (a1 * gle + a2 * gie);
	  pr = magma_zprod(a1, gle);
	  sum2 = magma_zadd(sum2, pr);
	  pr = magma_zprod(a2, gie);
	  sum2 = magma_zadd(sum2, pr);
	} // i1 loop
	magma_int_t vec_ind0 = (i0 - 1) * nlemt;
	magma_int_t vec_ind0e = (i0e - 1) * nlemt;
	vec_am0m[vec_ind0 + i3 - 1] = MAGMA_Z_NEGATE(sum1);
	vec_am0m[vec_ind0e + i3 - 1] = MAGMA_Z_NEGATE(sum2);
      } // i3 loop
    } // i0 loop
  }
  return result;
}

#endif // USE_TARGET_OFFLOAD

void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id, const RuntimeSettings& rs) {