Commit 52b6a488 authored by Giovanni La Mura's avatar Giovanni La Mura
Browse files

Merge branch 'test_SVD' into 'master'

Use back-up guards in iterative matrix inversion refinement

See merge request giacomo.mulas/np_tmcode!109
parents 768361a4 287f322b
Loading
Loading
Loading
Loading
+8 −21
Original line number Diff line number Diff line
@@ -48,7 +48,14 @@ void magma_zinvert(
 * R = A^-1 A - I
 *
 * and the convergence of refinement is estimated through the largest element
 * modulus left in R.
 * modulus left in R. Since Newton-Schulz iterative refinement is dangerous for
 * unstable matrices, the function first creates a host copy of the unrefined
 * inverted matrix and then attempts refinement. If the residue generated by
 * refinement does not improve over the original residue, the function returns
 * an error code that informs the calling function about the danger of
 * corruption and instructs it to use the back-up. In case of successful
 * refinement, instead, the calling function is responsible for getting the
 * refined matrix from the device back to the host.
 *
 * \param rs: `const RuntimeSettings &` Runtime settings instance. [IN]
 * \param a: `magmaDoubleComplex *` Pointer to the first element of the non-inverted matrix on host. [IN]
@@ -62,26 +69,6 @@ magma_int_t magma_newton(
  magmaDoubleComplex* d_a, magma_queue_t queue
);

/**
 * \brief Perform norm-based Newton-Schulz iterative refinement of matrix inversion.
 *
 * In this function the residual of the inversion of a matrix A is evaluated as:
 *
 * R = A A^-1 A - A
 *
 * and the convergence of refinement is estimated through the norm of R,
 * taken relative to the norm of A. Iteration stops when the maximum number of
 * refinement iterations in `rs` is reached, when the norm of R does not drop
 * below 99% of the previously accepted norm, or when the relative residual
 * falls below 1.0e-16.
 *
 * \param rs: `const RuntimeSettings &` Runtime settings instance. [IN]
 * \param a: `magmaDoubleComplex *` Pointer to the first element of the non-inverted matrix on host. [IN]
 * \param m: `const magma_int_t` Number of rows / columns in a. [IN]
 * \param d_a: `magmaDoubleComplex *` Pointer to the matrix on the GPU. [IN/OUT]
 * \param queue: `magma_queue_t` GPU communication queue. [IN]
 * \return result: `magma_int_t` Status of the last device allocation performed
 * by the function. Allocation failures terminate the program, so a value that
 * is actually returned is the success status of that allocation.
 */
magma_int_t magma_newton_norm(
  const RuntimeSettings& rs, magmaDoubleComplex* a, const magma_int_t m,
  magmaDoubleComplex* d_a, magma_queue_t queue
);

/* \brief Invert a complex matrix with double precision elements, applying iterative refinement of the solution
 *
 * call magma_zinvert1() to perform the first matrix inversion, then magma_refine() to do the refinement (only if maxrefiters is >0)
+1 −1
Original line number Diff line number Diff line
@@ -391,7 +391,7 @@ GeometryConfiguration* GeometryConfiguration::from_legacy(const std::string& fil
    }
    if (str_target_size > 13) {
      if (str_target.substr(0, 13).compare("INV_ACCURACY=") == 0) {
	int accuracy_goal = (double)stod(str_target.substr(13, str_target.length()));
	double accuracy_goal = (double)stod(str_target.substr(13, str_target.length()));
	conf->_accuracy_goal = accuracy_goal;
	is_parsed = true;
      }
+71 −52
Original line number Diff line number Diff line
@@ -86,6 +86,7 @@ using namespace std;
void cublas_zinvert(dcomplex **mat, np_int n, int device_id, const RuntimeSettings& rs) {
  char buffer[128];
  string message;
  int ref_err = 0;
  cudacall(cudaSetDevice(device_id));
  cublasHandle_t handle;
  cublascall(cublasCreate_v2(&handle));
@@ -126,6 +127,8 @@ void cublas_zinvert(dcomplex **mat, np_int n, int device_id, const RuntimeSettin
      cuDoubleComplex *d_id;
      // copy the original matrix again to d_a, so I do not need to destroy the old d_a and recreate a new one
      cudacall(cudaMemcpy(d_a, a, m*m*sizeof(cuDoubleComplex),cudaMemcpyHostToDevice)); // from here on, d_a contains the original matrix, for refinement use  
      // copy the unrefined inverted matrix from device to host, so we have a back-up if refinement fails
      cudacall(cudaMemcpy(a, d_c, m*m*sizeof(cuDoubleComplex), cudaMemcpyDeviceToHost));
      cudacall(cudaMalloc<cuDoubleComplex>(&d_a_residual, m*m*sizeof(cuDoubleComplex)));
      cudacall(cudaMalloc<cuDoubleComplex>(&d_a_refine, m*m*sizeof(cuDoubleComplex)));
      // allocate memory for the temporary matrix products
@@ -136,67 +139,83 @@ void cublas_zinvert(dcomplex **mat, np_int n, int device_id, const RuntimeSettin
      cudacall(cudaMalloc<cuDoubleComplex>(&d_id, sizeof(cuDoubleComplex)));
      cudacall(cudaMemcpy(d_id, m_id, sizeof(cuDoubleComplex),cudaMemcpyHostToDevice)); // copy identity to device vector
      delete[] native_id; // free identity vector on host

      // Detect the maximum value of the inverse matrix.
      double oldmax = 2.0e+16, curmax = 1.0e+16;
      np_int maxindex;
      cuDoubleComplex cublasmax;
      cublascall(CUIZAMAX(handle, mm, d_c, 1, &maxindex));
      cudacall(cudaMemcpy(&cublasmax, d_c + maxindex - 1, sizeof(cuDoubleComplex), cudaMemcpyDeviceToHost));
      curmax = cabs( (((double *) &(cublasmax))[0]) + I * (((double *) &(cublasmax))[1]));
      sprintf(buffer, "INFO: largest matrix value has modulus %.5le.\n", curmax);
      message = buffer;
      rs.logger->log(message);

      // Iterative refinement loop
      int max_ref_iters = rs.max_ref_iters;
      for (int ri = 0; ri < max_ref_iters; ri++) {
	oldmax = curmax;
	// multiply minus the original matrix times the inverse matrix
	// NOTE: factors in zgemm are swapped because zgemm is designed for column-major
	// Fortran-style arrays, whereas our arrays are C-style row-major.
	cublascall(CUZGEMM(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, m, m, &cu_mone, d_c, m, d_a, m, &cu_zero, d_a_residual, m));
	// add the identity to the product
	cublascall(CUZAXPY(handle, m, &cu_one, d_id, 0, d_a_residual, m+1));
      double oldmax=0;
      np_int maxindex;
	// find the maximum absolute value of the residual
	cublascall(CUIZAMAX(handle, mm, d_a_residual, 1, &maxindex));
      cuDoubleComplex cublasmax;
	// transfer the maximum value to the host
	cudacall(cudaMemcpy(&cublasmax, d_a_residual + maxindex - 1, sizeof(cuDoubleComplex), cudaMemcpyDeviceToHost));
	// take the module
      oldmax = cabs( (((double *) &(cublasmax))[0]) + I*(((double *) &(cublasmax))[1]));
      //printf("Initial max residue = %g\n", oldmax);
      sprintf(buffer, "INFO: Initial max residue is %.5le.\n", oldmax);
	curmax = cabs( (((double *) &(cublasmax))[0]) + I * (((double *) &(cublasmax))[1]));
	sprintf(buffer, "DEBUG: iteration %d has residue %.5le; target residue is %.5le.\n", ri, curmax, rs.accuracy_goal);
	message = buffer;
      rs.logger->log(message);
      if (oldmax > rs.accuracy_goal) {
	for (int iter = 0; iter < rs.max_ref_iters; iter++) {
	rs.logger->log(message, LOG_DEBG);
	if (curmax < 0.5) { // Safe conditions for Newton-Schultz iteration.
	  if (curmax < 0.95 * oldmax) { // Newton-Schultz iteration is improving and can proceed.
	    // multiply the inverse times the residual, add to the initial inverse
	    cublascall(CUZGEMM(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, m, m, &cu_one, d_a_residual, m, d_c, m, &cu_zero, d_a_refine, m));
	    // add to the initial inverse
	    cublascall(CUZAXPY(handle, mm, &cu_one, d_a_refine, 1, d_c, 1));
	  // multiply minus the original matrix times the new inverse matrix
	  cublascall(CUZGEMM(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, m, m, &cu_mone, d_c, m, d_a, m, &cu_zero, d_a_residual, m));
	  // add the identity to the product
	  cublascall(CUZAXPY(handle, m, &cu_one, d_id, 0, d_a_residual, m+1));
	  // find the maximum absolute value of the residual
	  np_int maxindex;
	  cublascall(CUIZAMAX(handle, mm, d_a_residual, 1, &maxindex));
	  // transfer the maximum value to the host
	  cuDoubleComplex cublasmax;
	  cudacall(cudaMemcpy(&cublasmax, d_a_residual+maxindex-1, sizeof(cuDoubleComplex), cudaMemcpyDeviceToHost));
	  // take the module
	  double newmax = cabs( (((double *) &(cublasmax))[0]) + I*(((double *) &(cublasmax))[1]));
	  sprintf(buffer, "DEBUG: Maximum residue after %d iterations = %.5le\n", iter+1, newmax);
	    if (curmax < rs.accuracy_goal) {
	      message = "DEBUG: good news - optimal convergence achieved. Stopping.\n";
	      rs.logger->log(message, LOG_DEBG);
	      ref_err = 0;
	      break; // ri for
	    }
	  } else {
	    if (curmax > 0.1) {
	      sprintf(buffer, "INFO: iteration %d achieved limiting residue %.5le. Cannot reach goal. Reverting.\n", ri, curmax);
	      message = buffer;
	      rs.logger->log(message);
	  // if the maximum in the residual decreased from the previous iteration,
	  // update oldmax and go on, otherwise no point further iterating refinements
	  // if ((refinemode==2) && ((newmax > oldmax)||(newmax < accuracygoal))) iteraterefine = false;
	  if (newmax < rs.accuracy_goal) {
	    message = "INFO: optimal convergence achieved. Stopping.\n";
	    rs.logger->log(message);
	    break; // iter for
	  }
	  if (newmax > 0.99 * oldmax) {
	    message = "WARN: cannot improve further. Stopping.\n";
	      ref_err = 1;
	    } else {
	      sprintf(buffer, "WARN: iteration %d achieved limiting residue %.5le. Stopping.\n", ri, curmax);
	      message = buffer;
	      rs.logger->log(message, LOG_WARN);
	    break; // iter for
	    }
	  oldmax = newmax;
	    break; // ri for
	  }
	} else { // curmax > 0.5. Newton-Schultz iteration is dangerous.
	  if (curmax < oldmax) {
	    sprintf(buffer, "WARN: iteration %d has residue %.5le. Iterating is dangerous. Stopping.\n", ri, curmax);
	    message = buffer;
	    rs.logger->log(message, LOG_WARN);
	  } else {
	message = "INFO: starting accuracy is already good.\n";
	    ref_err = 1;
	    sprintf(buffer, "INFO: iteration %d has residue %.5le. Reverting to unrefined and stopping.\n", ri, curmax);
	    message = buffer;
	    rs.logger->log(message);
	  }
	  break; // ri for
	}
      } // end of ri for
      cudaFree(d_a_residual);
      cudaFree(d_a_refine);
      cudaFree(d_id);
    }
    if (ref_err == 0) { // Refinement was fine, we can copy refined matrix back to host.
      cudacall(cudaMemcpy(a, d_c, m * m * sizeof(cuDoubleComplex), cudaMemcpyDeviceToHost));
    }
    cudaFree(batch_d_a);
    cudaFree(batch_d_c);
    cudaFree(piv);
+49 −24
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@

#ifdef USE_LAPACK

#include <string>
#include <cstring>

#ifndef INCLUDE_TYPES_H_
#include "../include/types.h"
@@ -136,28 +136,29 @@ lapack_int lapack_newton(
  lcomplex lapack_mone = -1.0 + I * 0.0;
  lapack_int mm = m * m;
  lapack_int incx, incy;
  lcomplex *ax, *r;
  lcomplex *ax, *r, *unrefined;
  lcomplex *id_diag = new lcomplex[m];
  double oldmax = 2.0e+16, curmax = 1.0e+16;
  for (lapack_int hi = 0; hi < m; hi++)
    id_diag[hi] = lapack_one;
  ax = new lcomplex[mm];
  r = new lcomplex[mm];
  double max_residue, target_residue;
  // Create a back-up copy of the unrefined matrix.
  unrefined = new lcomplex[mm];
  memcpy(unrefined, a, mm * sizeof(lcomplex));
  incx = 1;
  lapack_int maxindex = izamax_(&mm, a, &incx) - 1;
  lcomplex lapackmax = a[maxindex];
  curmax = cabs(lapackmax); //cabs(magmamax.x + I * magmamax.y);
  target_residue = curmax * rs.accuracy_goal;
  sprintf(buffer, "INFO: largest matrix value has modulus %.5le; target residue is %.5le.\n", curmax, target_residue);
  curmax = cabs(lapackmax);
  sprintf(buffer, "INFO: largest matrix value has modulus %.5le.\n", curmax);
  message = buffer;
  rs.logger->log(message);
  for (int ri = 0; ri < max_ref_iters; ri++) {
    oldmax = curmax;
    // Compute -A*X
    zgemm_(
      &lapackNoTrans, &lapackNoTrans, &m, &m, &m, &lapack_mone, a, &m,
      a_orig, &m, &lapack_zero, ax, &m
      &lapackNoTrans, &lapackNoTrans, &m, &m, &m, &lapack_mone, a_orig, &m,
      a, &m, &lapack_zero, ax, &m
    );
    // Transform -A*X into (I - A*X)
    incx = 1;
@@ -166,10 +167,11 @@ lapack_int lapack_newton(
    maxindex = izamax_(&mm, ax, &incx) - 1;
    lapackmax = ax[maxindex];
    curmax = cabs(lapackmax);
    sprintf(buffer, "DEBUG: iteration %d has residue %.5le; target residue is %.5le.\n", ri, curmax, target_residue);
    sprintf(buffer, "DEBUG: iteration %d has residue %.5le; target residue is %.5le.\n", ri, curmax, rs.accuracy_goal);
    message = buffer;
    rs.logger->log(message, LOG_DEBG);
    if (curmax < 0.99 * oldmax) {
    if (curmax < 0.5) { // Safe conditions for Newton-Schultz iteration.
      if (curmax < 0.95 * oldmax) { // Newton-Schultz iteration is improving and can proceed.
	// Compute R = (I - A*X)*X
	zgemm_(
          &lapackNoTrans, &lapackNoTrans, &m, &m, &m, &lapack_one, a, &m,
@@ -177,20 +179,43 @@ lapack_int lapack_newton(
        );
	// Set X = X + R
	zaxpy_(&mm, &lapack_one, r, &incx, a, &incx);
      if (curmax < target_residue) {
	if (curmax < rs.accuracy_goal) {
	  message = "DEBUG: good news - optimal convergence achieved. Stopping.\n";
	  rs.logger->log(message, LOG_DEBG);
	  err = LAPACK_SUCCESS;
	  break; // ri for
        }
      } else {
      message = "WARN: not so good news - cannot improve further. Stopping.\n";
	if (curmax > 0.1) {
	  sprintf(buffer, "INFO: iteration %d achieved limiting residue %.5le. Cannot reach goal. Reverting.\n", ri, curmax);
	  message = buffer;
	  rs.logger->log(message);
	  memcpy(a, unrefined, mm * sizeof(lcomplex));
	} else {
	  sprintf(buffer, "WARN: iteration %d achieved limiting residue %.5le. Stopping.\n", ri, curmax);
	  message = buffer;
	  rs.logger->log(message, LOG_WARN);
	}
	break; // ri for
      }
    } else { // curmax > 0.5. Newton-Schultz iteration is dangerous.
      if (curmax < oldmax) {
	sprintf(buffer, "WARN: iteration %d has residue %.5le. Iterating is dangerous. Stopping.\n", ri, curmax);
	message = buffer;
	rs.logger->log(message, LOG_WARN);
      } else {
	sprintf(buffer, "INFO: iteration %d has residue %.5le. Reverting to unrefined and stopping.\n", ri, curmax);
	message = buffer;
	rs.logger->log(message);
	memcpy(a, unrefined, mm * sizeof(lcomplex));
      }
      break; // ri for
    }
  } // end of ri for
  delete[] id_diag;
  delete[] ax;
  delete[] r;
  delete[] unrefined;
  return err;
}

+72 −138
Original line number Diff line number Diff line
@@ -64,7 +64,7 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id, const Runt
  magmaDoubleComplex *a = (magmaDoubleComplex *)mat[0]; // pointer to first element on host
  string message;
  char buffer[128];
  magma_int_t err = MAGMA_SUCCESS;
  magma_int_t err = MAGMA_SUCCESS, ref_err = MAGMA_SUCCESS;
  magma_queue_t queue = NULL;
  magma_device_t dev = (magma_device_t)device_id;
  magma_queue_create(dev, &queue);
@@ -99,9 +99,14 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id, const Runt
    delete[] piv; // delete piv created by magma_zgetrf()
    magma_free(dwork);
    if (rs.use_refinement) {
      err = magma_newton(rs, a, m, d_a, queue);
      // magma_newton makes a back-up copy of the unrefined inverted
      // matrix on the host. If refinement fails, the err flag is set
      // to a non-zero value, to prevent corrupting the inverted matrix.
      ref_err = magma_newton(rs, a, m, d_a, queue);
    }
    if (ref_err == MAGMA_SUCCESS) { // Refinement did improve the inversion accuracy, so we retireve the refined matrix.
      magma_zgetmatrix(m, m, d_a , m, a, m, queue); // copy d_a -> a
    }
    magma_free(d_a);
    // >>> END OF LU INVERSION <<<
  } else if (rs.invert_mode == RuntimeSettings::INV_MODE_GESV) {
@@ -132,9 +137,14 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id, const Runt
    delete[] piv; // free host memory
    magma_free(d_a);
    if (rs.use_refinement) {
      err = magma_newton(rs, a, m, d_id, queue);
      // magma_newton makes a back-up copy of the unrefined inverted
      // matrix on the host. If refinement fails, the err flag is set
      // to a non-zero value, to prevent corrupting the inverted matrix.
      ref_err = magma_newton(rs, a, m, d_id, queue);
    }
    if (ref_err == MAGMA_SUCCESS) { // Refinement did improve the inversion accuracy, so we retireve the refined matrix.
      magma_zgetmatrix(m, m, d_id , m, a, m, queue); // copy d_id -> a
    }
    magma_free(d_id);
    // >>> END OF GESV INVERSION <<<
  } else if (rs.invert_mode == RuntimeSettings::INV_MODE_RBT) {
@@ -193,15 +203,8 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id, const Runt
    delete[] h_rwork;
    delete[] h_iwork;
    delete[] h_a;
    sprintf(buffer, "%.5le", h_s[0]);
    message = "DEBUG: s[0] = ";
    message += buffer;
    message += "; s[";
    message += to_string(m - 1);
    message += "] = ";
    sprintf(buffer, "%.5le", h_s[m - 1]);
    message += buffer;
    message += ".\n";
    sprintf(buffer, "DEBUG: s[0] = %.5le; s[%s] = %.5le.\n", h_s[0], to_string(m - 1).c_str(), h_s[m - 1]);
    message = buffer;
    rs.logger->log(message, LOG_DEBG);
    
    // Step 2: Upload decomposed matrix to GPU
@@ -251,11 +254,16 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id, const Runt

    // Step 5: refine inversion
    if (rs.use_refinement) {
      err = magma_newton(rs, a, m, d_a, queue);
      // magma_newton makes a back-up copy of the unrefined inverted
      // matrix on the host. If refinement fails, the err flag is set
      // to a non-zero value, to prevent corrupting the inverted matrix.
      ref_err = magma_newton(rs, a, m, d_a, queue);
    }

    // Step 6: get result back to host
    if (ref_err == MAGMA_SUCCESS) { // Refinement did improve the inversion accuracy, so we retireve the refined matrix.
      magma_zgetmatrix(m, m, d_a , m, a, m, queue); // copy d_a -> a
    }
    magma_free(d_a);
    // >>> END OF SVD INVERSION <<<
  }
@@ -263,109 +271,11 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id, const Runt
  jer = (int)err;
}

/**
 * Norm-based Newton-Schulz refinement of an inverse stored on the device.
 *
 * On entry `a` (host) holds the original matrix A and `d_a` (device) holds the
 * current estimate X of its inverse. Each sweep evaluates R = A*X*A - A and
 * its norm; while that norm keeps dropping below 99% of the last accepted
 * value, X is replaced in place by X*(2*I - A*X). The loop ends after
 * `rs.max_ref_iters` sweeps, on stagnation, or once the residual norm relative
 * to the norm of A falls below 1.0e-16.
 *
 * The function uploads `a` to a scratch device buffer and overwrites `d_a`;
 * it does not copy anything back to the host. Any failed device allocation is
 * logged and terminates the process.
 *
 * Returns the status of the last device allocation.
 */
magma_int_t magma_newton_norm(
  const RuntimeSettings& rs, magmaDoubleComplex* a, const magma_int_t m,
  magmaDoubleComplex* d_a, magma_queue_t queue
) {
  magma_int_t err;
  string message;
  char buffer[128];
  const magmaDoubleComplex c_zero = MAGMA_Z_MAKE(0.0, 0.0);
  const magmaDoubleComplex c_one = MAGMA_Z_MAKE(1.0, 0.0);
  const magmaDoubleComplex c_minus_one = MAGMA_Z_MAKE(-1.0, 0.0);
  const magmaDoubleComplex c_two = MAGMA_Z_MAKE(2.0, 0.0);
  const magma_int_t n_elems = m * m;
  const int sweep_limit = rs.max_ref_iters;
  magmaDoubleComplex *d_diag, *d_orig, *d_prod, *d_resid;

  // Allocate `count` complex elements on the device, or log and abort.
  auto alloc_or_die = [&](
    magmaDoubleComplex **slot, const magma_int_t count, const char *label
  ) {
    err = magma_zmalloc(slot, count);
    if (err != MAGMA_SUCCESS) {
      message = string("ERROR: could not allocate ") + label + "!\n";
      rs.logger->err(message);
      exit(1);
    }
  };

  // Build the identity diagonal on the host, then move it to the device.
  magmaDoubleComplex *host_diag = new magmaDoubleComplex[m];
  for (magma_int_t di = 0; di < m; di++) {
    host_diag[di] = c_one;
  }
  alloc_or_die(&d_diag, m, "d_id_diag");
  magma_zsetvector(m, host_diag, 1, d_diag, 1, queue);
  delete[] host_diag;

  // Device copy of the original matrix A.
  alloc_or_die(&d_orig, n_elems, "d_a_orig");
  magma_zsetmatrix(m, m, a, m, d_orig, m, queue);

  // Work buffers for A*X and A*X*A.
  alloc_or_die(&d_prod, n_elems, "d_ax");
  alloc_or_die(&d_resid, n_elems, "d_axa");

  // Reference norm for the relative residual; guard against a null matrix.
  double normA = magma_dznrm2(n_elems, d_orig, 1, queue);
  if (normA == 0.0) {
    normA = 1.0;
  }

  double old_norm = 2.0e+16;
  for (int sweep = 0; sweep < sweep_limit; sweep++) {
    // d_prod = A*X
    magmablas_zgemm(
      MagmaNoTrans, MagmaNoTrans, m, m, m, c_one, d_orig, m,
      d_a, m, c_zero, d_prod, m, queue
    );
    // d_resid = (A*X)*A
    magmablas_zgemm(
      MagmaNoTrans, MagmaNoTrans, m, m, m, c_one, d_prod, m,
      d_orig, m, c_zero, d_resid, m, queue
    );
    // d_resid = A*X*A - A, i.e. the residual R
    magma_zaxpy(n_elems, c_minus_one, d_orig, 1, d_resid, 1, queue);

    const double norm = magma_dznrm2(n_elems, d_resid, 1, queue);
    const double relative_res = norm / normA;
    sprintf(buffer, "%.5le", relative_res);
    message = "INFO: refinement iteration " + to_string(sweep) + " achieved relative residual of "
      + buffer + ".\n";
    rs.logger->log(message);

    // Stagnation (or a non-comparable norm): give up without updating X.
    if (!(norm < 0.99 * old_norm)) {
      message = "WARN: not so good news - cannot improve further. Stopping.\n";
      rs.logger->log(message, LOG_WARN);
      break;
    }

    // d_prod = 2*I - A*X: negate everything, then add 2 along the diagonal
    // (stride m + 1 walks the diagonal of an m x m matrix).
    magma_zscal(n_elems, c_minus_one, d_prod, 1, queue);
    magma_zaxpy(m, c_two, d_diag, 1, d_prod, m + 1, queue);
    // d_resid = X*(2*I - A*X), reusing the residual buffer as scratch
    magma_zgemm(
      MagmaNoTrans, MagmaNoTrans, m, m, m, c_one, d_a, m,
      d_prod, m, c_zero, d_resid, m, queue
    );
    // X <- X*(2*I - A*X)
    magma_zcopy(n_elems, d_resid, 1, d_a, 1, queue);
    old_norm = norm;

    if (relative_res < 1.0e-16) {
      message = "DEBUG: good news - optimal convergence achieved. Stopping.\n";
      rs.logger->log(message, LOG_DEBG);
      break;
    }
  }

  magma_free(d_orig);
  magma_free(d_prod);
  magma_free(d_resid);
  magma_free(d_diag);
  return err;
}

magma_int_t magma_newton(
  const RuntimeSettings& rs, magmaDoubleComplex* a, const magma_int_t m,
  magmaDoubleComplex* d_a, magma_queue_t queue
) {
  magma_int_t err;
  magma_int_t err = MAGMA_SUCCESS;
  string message;
  char buffer[128];
  const int max_ref_iters = rs.max_ref_iters;
@@ -395,6 +305,7 @@ magma_int_t magma_newton(
    exit(1);
  }
  magma_zsetmatrix(m, m, a, m, d_a_orig, m, queue); // copy a -> d_a_orig
  magma_zgetmatrix(m, m, d_a, m, a, m, queue); // copy pre-refinement d_a -> a
  err = magma_zmalloc(&d_ax, mm);
  if (err != MAGMA_SUCCESS) {
    message = "ERROR: could not allocate d_ax!\n";
@@ -407,49 +318,72 @@ magma_int_t magma_newton(
    rs.logger->err(message);
    exit(1);
  }
  double max_residue, target_residue;
  magma_int_t maxindex = magma_izamax(mm, d_a, 1, queue) - 1;
  magma_queue_sync(queue);
  magmaDoubleComplex magmamax = magma_mone;
  magma_zgetvector(1, d_a + maxindex, 1, &magmamax, 1, queue);
  curmax = MAGMA_Z_ABS(magmamax); //cabs(magmamax.x + I * magmamax.y);
  target_residue = curmax * rs.accuracy_goal;
  sprintf(buffer, "INFO: largest matrix value has modulus %.5le; target residue is %.5le.\n", curmax, target_residue);
  curmax = MAGMA_Z_ABS(magmamax);
  sprintf(buffer, "INFO: largest matrix value has modulus %.5le.\n", curmax);
  message = buffer;
  rs.logger->log(message);
  for (int ri = 0; ri < max_ref_iters; ri++) {
    oldmax = curmax;
    // Compute -A*X
    magmablas_zgemm(
      MagmaNoTrans, MagmaNoTrans, m, m, m, magma_mone, d_a_orig, m,
      d_a, m, magma_zero, d_ax, m, queue
      MagmaNoTrans, MagmaNoTrans, m, m, m, magma_mone, d_a, m,
      d_a_orig, m, magma_zero, d_ax, m, queue
    );
    // Transform -A*X into (I - A*X)
    magma_zaxpy(m, magma_one, d_id_diag, 1, d_ax, m + 1, queue);
    maxindex = magma_izamax(mm, d_ax, 1, queue) - 1;
    magma_queue_sync(queue);
    magma_zgetvector(1, d_ax + maxindex, 1, &magmamax, 1, queue);
    curmax = cabs(magmamax.x + I * magmamax.y);
    sprintf(buffer, "DEBUG: iteration %d has residue %.5le; target residue is %.5le.\n", ri, curmax, target_residue);
    curmax = MAGMA_Z_ABS(magmamax);
    sprintf(buffer, "DEBUG: iteration %d has residue %.5le; target residue is %.5le.\n", ri, curmax, rs.accuracy_goal);
    message = buffer;
    rs.logger->log(message, LOG_DEBG);
    if (curmax < 0.99 * oldmax) {
      // Compute R = (A*X - I)*X
    if (curmax < 0.5) { // Safe conditions for Newton-Schultz iteration.
      if (curmax < 0.95 * oldmax) { // Newton-Schultz iteration is improving and can proceed.
	// Compute R = X*(A*X - I)
	magmablas_zgemm(
          MagmaNoTrans, MagmaNoTrans, m, m, m, magma_one, d_ax, m,
	  d_a, m, magma_zero, d_r, m, queue
        );
	// Set X = X + R
	magma_zaxpy(mm, magma_one, d_r, 1, d_a, 1, queue);
      if (curmax < target_residue) {
	if (curmax < rs.accuracy_goal) {
	  message = "DEBUG: good news - optimal convergence achieved. Stopping.\n";
	  rs.logger->log(message, LOG_DEBG);
	  err = MAGMA_SUCCESS;
	  break; // ri for
        }
      } else {
      message = "WARN: not so good news - cannot improve further. Stopping.\n";
	if (curmax > 0.1) {
	  sprintf(buffer, "INFO: iteration %d achieved limiting residue %.5le. Cannot reach goal. Reverting.\n", ri, curmax);
	  message = buffer;
	  rs.logger->log(message);
	  err = 1;
	} else {
	  sprintf(buffer, "WARN: iteration %d achieved limiting residue %.5le. Stopping.\n", ri, curmax);
	  message = buffer;
	  rs.logger->log(message, LOG_WARN);
	}
	break; // ri for
      }
    } else { // curmax > 0.5. Newton-Schultz iteration is dangerous.
      if (curmax < oldmax) {
	sprintf(buffer, "WARN: iteration %d has residue %.5le. Iterating is dangerous. Stopping.\n", ri, curmax);
	message = buffer;
	rs.logger->log(message, LOG_WARN);
      } else {
	err = 1;
	sprintf(buffer, "INFO: iteration %d has residue %.5le. Reverting to unrefined and stopping.\n", ri, curmax);
	message = buffer;
	rs.logger->log(message);
      }
      break; // ri for
    }
  } // end of ri for
  magma_free(d_a_orig);
  magma_free(d_ax);
  magma_free(d_r);
Loading