Commit d9eb3dea authored by Giovanni La Mura's avatar Giovanni La Mura
Browse files

Implement identity based refinement iteration

parent 4279501c
Loading
Loading
Loading
Loading
+27 −0
Original line number Diff line number Diff line
@@ -43,6 +43,13 @@ void magma_zinvert(
/**
 * \brief Perform Newton-Schulz iterative refinement of matrix inversion.
 *
 * In this function the residual of the inversion of a matrix A is evaluated as:
 *
 * R = A^-1 A - I
 *
 * and the convergence of refinement is estimated through the largest element
 * modulus left in R.
 *
 * \param rs: `const RuntimeSettings &` Runtime settings instance. [IN]
 * \param a: `magmaDoubleComplex *` Pointer to the first element of the non-inverted matrix on host. [IN]
 * \param m: `const magma_int_t` Number of rows / columns in a. [IN]
@@ -54,6 +61,26 @@ magma_int_t magma_newton(
  magmaDoubleComplex* d_a, magma_queue_t queue
);

/**
 * \brief Perform norm-based Newton-Schulz iterative refinement of matrix inversion.
 *
 * In this function the residual of the inversion of a matrix A is evaluated as:
 *
 * R = A A^-1 A - A
 *
 * and the convergence of refinement is estimated through the norm of R.
 *
 * \param rs: `const RuntimeSettings &` Runtime settings instance. [IN]
 * \param a: `magmaDoubleComplex *` Pointer to the first element of the non-inverted matrix on host. [IN]
 * \param m: `const magma_int_t` Number of rows / columns in a. [IN]
 * \param d_a: `magmaDoubleComplex *` Pointer to the matrix on the GPU. [IN/OUT]
 * \param queue: `magma_queue_t` GPU communication queue. [IN]
 * \return result: `magma_int_t` Status code. NOTE(review): presumably a MAGMA
 * error code (MAGMA_SUCCESS on success) - confirm against the implementation.
 */
magma_int_t magma_newton_norm(
  const RuntimeSettings& rs, magmaDoubleComplex* a, const magma_int_t m,
  magmaDoubleComplex* d_a, magma_queue_t queue
);

/* \brief Invert a complex matrix with double precision elements, applying iterative refinement of the solution
 *
 * Call magma_zinvert1() to perform the first matrix inversion, then magma_refine() to do the refinement (only if maxrefiters is > 0).
+98 −2
Original line number Diff line number Diff line
@@ -263,7 +263,7 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id, const Runt
  jer = (int)err;
}

magma_int_t magma_newton(
magma_int_t magma_newton_norm(
  const RuntimeSettings& rs, magmaDoubleComplex* a, const magma_int_t m,
  magmaDoubleComplex* d_a, magma_queue_t queue
) {
@@ -361,4 +361,100 @@ magma_int_t magma_newton(
  return err;
}

#endif // USE_MAGMA
/**
 * \brief Identity-based Newton-Schulz refinement of a matrix inverse on the GPU.
 *
 * Below, A is the matrix uploaded from `a` and X is the current content of
 * `d_a`; products are written in the operand order handed to zgemm. Every
 * iteration builds E = I - A*X, takes as residue the modulus of the entry of E
 * selected by magma_izamax, and, as long as the residue is below 99% of the
 * previous reference value, applies the correction X <- X + E*X. On the first
 * iteration the reference value is the modulus of the entry of the incoming X
 * selected by magma_izamax. The target residue is 1.0e-07 times that same
 * initial modulus.
 *
 * The loop ends when the residue falls below the target (after applying the
 * correction of that iteration), when the residue stops improving, or after
 * `rs.max_ref_iters` iterations.
 *
 * If any device allocation fails, an error is logged and the process is
 * terminated with exit(1).
 *
 * \param rs: `const RuntimeSettings &` Runtime settings instance (provides `max_ref_iters` and `logger`). [IN]
 * \param a: `magmaDoubleComplex *` Pointer to the first element of the non-inverted matrix on host. [IN]
 * \param m: `const magma_int_t` Number of rows / columns in a. [IN]
 * \param d_a: `magmaDoubleComplex *` Pointer to the inverted matrix on the GPU, refined in place. [IN/OUT]
 * \param queue: `magma_queue_t` GPU communication queue. [IN]
 * \return err: `magma_int_t` Status of the last device allocation; this is
 * MAGMA_SUCCESS whenever the function returns, because allocation failures exit.
 */
magma_int_t magma_newton(
  const RuntimeSettings& rs, magmaDoubleComplex* a, const magma_int_t m,
  magmaDoubleComplex* d_a, magma_queue_t queue
) {
  magma_int_t err;
  string message;
  char buffer[128];
  const int max_ref_iters = rs.max_ref_iters;
  const magmaDoubleComplex magma_zero = MAGMA_Z_MAKE(0.0, 0.0);
  const magmaDoubleComplex magma_one = MAGMA_Z_MAKE(1.0, 0.0);
  const magmaDoubleComplex magma_mone = MAGMA_Z_MAKE(-1.0, 0.0);
  const magma_int_t mm = m * m;
  magmaDoubleComplex *d_a_orig, *d_ax, *d_r;
  magmaDoubleComplex *d_id_diag;

  // Build the diagonal of the identity on host and move it to the device.
  magmaDoubleComplex *h_id_diag = new magmaDoubleComplex[m];
  for (magma_int_t hi = 0; hi < m; hi++)
    h_id_diag[hi] = magma_one;
  err = magma_zmalloc(&d_id_diag, m);
  if (err != MAGMA_SUCCESS) {
    message = "ERROR: could not allocate d_id_diag!\n";
    rs.logger->err(message);
    exit(1);
  }
  magma_zsetvector(m, h_id_diag, 1, d_id_diag, 1, queue);
  delete[] h_id_diag;

  // Device copy of the non-inverted matrix.
  err = magma_zmalloc(&d_a_orig, mm);
  if (err != MAGMA_SUCCESS) {
    message = "ERROR: could not allocate d_a_orig!\n";
    rs.logger->err(message);
    exit(1);
  }
  magma_zsetmatrix(m, m, a, m, d_a_orig, m, queue); // copy a -> d_a_orig

  // Work space for E = I - A*X and for the correction E*X.
  err = magma_zmalloc(&d_ax, mm);
  if (err != MAGMA_SUCCESS) {
    message = "ERROR: could not allocate d_ax!\n";
    rs.logger->err(message);
    exit(1);
  }
  err = magma_zmalloc(&d_r, mm);
  if (err != MAGMA_SUCCESS) {
    message = "ERROR: could not allocate d_r!\n";
    rs.logger->err(message);
    exit(1);
  }

  // Scale the convergence target on the entry of X picked by izamax
  // (izamax returns a 1-based index, hence the "- 1" on the pointer).
  magma_int_t maxindex = magma_izamax(mm, d_a, 1, queue);
  magmaDoubleComplex magmamax = magma_mone;
  magma_zgetvector(1, d_a + maxindex - 1, 1, &magmamax, 1, queue);
  double curmax = cabs(magmamax.x + I * magmamax.y);
  const double target_residue = curmax * 1.0e-07;
  snprintf(
    buffer, sizeof(buffer),
    "INFO: largest matrix value has modulus %.5le; target residue is %.5le.\n",
    curmax, target_residue
  );
  message = buffer;
  rs.logger->log(message);

  for (int ri = 0; ri < max_ref_iters; ri++) {
    const double oldmax = curmax;
    // Compute -A*X
    magmablas_zgemm(
      MagmaNoTrans, MagmaNoTrans, m, m, m, magma_mone, d_a_orig, m,
      d_a, m, magma_zero, d_ax, m, queue
    );
    // Transform -A*X into (I - A*X): a stride of m + 1 walks the diagonal.
    magma_zaxpy(m, magma_one, d_id_diag, 1, d_ax, m + 1, queue);
    maxindex = magma_izamax(mm, d_ax, 1, queue);
    magma_zgetvector(1, d_ax + maxindex - 1, 1, &magmamax, 1, queue);
    curmax = cabs(magmamax.x + I * magmamax.y);
    snprintf(
      buffer, sizeof(buffer),
      "DEBUG: iteration %d has residue %.5le; target residue is %.5le.\n",
      ri, curmax, target_residue
    );
    message = buffer;
    rs.logger->log(message, LOG_DEBG);
    if (curmax < 0.99 * oldmax) {
      // Compute R = (I - A*X)*X
      magmablas_zgemm(
        MagmaNoTrans, MagmaNoTrans, m, m, m, magma_one, d_ax, m,
        d_a, m, magma_zero, d_r, m, queue
      );
      // Set X = X + R
      magma_zaxpy(mm, magma_one, d_r, 1, d_a, 1, queue);
      if (curmax < target_residue) {
        message = "DEBUG: good news - optimal convergence achieved. Stopping.\n";
        rs.logger->log(message, LOG_DEBG);
        break; // ri for
      }
    } else {
      message = "WARN: not so good news - cannot improve further. Stopping.\n";
      rs.logger->log(message, LOG_WARN);
      break; // ri for
    }
  }

  magma_free(d_a_orig);
  magma_free(d_ax);
  magma_free(d_r);
  magma_free(d_id_diag);
  return err;
}

#endif