Commit f3264886 authored by Giovanni La Mura's avatar Giovanni La Mura
Browse files

Use GPU matrix multiplication in SVD recombination

parent d7d7dc8c
Loading
Loading
Loading
Loading
+19 −24
Original line number Diff line number Diff line
@@ -97,35 +97,30 @@ void magma_svd_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
    );

    magma_free_cpu(work);
    for (magma_int_t si = 0; si < n; si++)
      s[si] = (s[si] == 0.0) ? 0.0 : 1.0 / s[si];

    for (magma_int_t ri = 0; ri < n; ri++) {
      for (magma_int_t rj = 0; rj < n; rj++) {
    	u[n * ri + rj] = MAGMA_Z_MAKE(s[ri] * real(u[n * ri + rj]), -s[ri] * imag(u[n * ri + rj]));
    	vt[n * ri + rj] = MAGMA_Z_MAKE(real(vt[n * ri + rj]), -imag(vt[n * ri + rj]));
      }
    }

    magmaDoubleComplex value;
    for (magma_int_t mi = 0; mi < n; mi++) {
      for (magma_int_t mj = 0; mj < n; mj++) {
	value = MAGMA_Z_MAKE(0.0, 0.0);
	for (magma_int_t mk = 0; mk < n; mk++) {
	  magmaDoubleComplex elem1 = vt[n * mi + mk];
	  magmaDoubleComplex elem2 = u[n * mk + mj];
	  value = MAGMA_Z_ADD(value, MAGMA_Z_MUL(elem1, elem2));
	}
	a[n * mi + mj] = value;
    double rpart, ipart;
    for (magma_int_t si = 0; si < n; si++) {
      s[si] = (s[si] == 0.0) ? 0.0 : 1.0 / s[si];
      for (magma_int_t sj = 0; sj < n; sj++) {
	rpart = s[si] * real(u[n * si + sj] );
	ipart = s[si] * imag(u[n * si + sj] );
	u[n * si + sj] = MAGMA_Z_MAKE(rpart, -ipart);
	rpart = real(vt[n * si + sj] );
	ipart = imag(vt[n * si + sj] );
	vt[n * si + sj] = MAGMA_Z_MAKE(rpart, -ipart);
      }
    }

    magmaDoubleComplex *d_a;
    magmaDoubleComplex *d_a, *d_u, *d_vt;
    magma_zmalloc(&d_a, n * n);
    magma_zsetmatrix(n, n, a, n, d_a , n, queue);
    magmablas_ztranspose_inplace(n, d_a, n, queue);
    magma_zmalloc(&d_u, n * n);
    magma_zmalloc(&d_vt, n * n);
    magma_zsetmatrix(n, n, u, n, d_u, n, queue);
    magma_zsetmatrix(n, n, vt, n, d_vt, n, queue);
    magmablas_zgemm(MagmaTrans, MagmaTrans, n, n, n, cc1, d_vt, n, d_u, n, cc0, d_a, n, queue);
    magma_zgetmatrix(n, n, d_a, n, a, n, queue);
    magma_free(d_a);
    magma_free(d_u);
    magma_free(d_vt);
  } else {
    jer = (int)err;
  }