Commit 347bc13d authored by Giovanni La Mura's avatar Giovanni La Mura
Browse files

Use back-up guard for LAPACK refinement

parent 0338fbd0
Loading
Loading
Loading
Loading
+49 −24
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@

#ifdef USE_LAPACK

#include <string>
#include <cstring>

#ifndef INCLUDE_TYPES_H_
#include "../include/types.h"
@@ -136,28 +136,29 @@ lapack_int lapack_newton(
  lcomplex lapack_mone = -1.0 + I * 0.0;
  lapack_int mm = m * m;
  lapack_int incx, incy;
  lcomplex *ax, *r;
  lcomplex *ax, *r, *unrefined;
  lcomplex *id_diag = new lcomplex[m];
  double oldmax = 2.0e+16, curmax = 1.0e+16;
  for (lapack_int hi = 0; hi < m; hi++)
    id_diag[hi] = lapack_one;
  ax = new lcomplex[mm];
  r = new lcomplex[mm];
  double max_residue, target_residue;
  // Create a back-up copy of the unrefined matrix.
  unrefined = new lcomplex[mm];
  memcpy(unrefined, a, mm * sizeof(lcomplex));
  incx = 1;
  lapack_int maxindex = izamax_(&mm, a, &incx) - 1;
  lcomplex lapackmax = a[maxindex];
  curmax = cabs(lapackmax); //cabs(magmamax.x + I * magmamax.y);
  target_residue = curmax * rs.accuracy_goal;
  sprintf(buffer, "INFO: largest matrix value has modulus %.5le; target residue is %.5le.\n", curmax, target_residue);
  curmax = cabs(lapackmax);
  sprintf(buffer, "INFO: largest matrix value has modulus %.5le.\n", curmax);
  message = buffer;
  rs.logger->log(message);
  for (int ri = 0; ri < max_ref_iters; ri++) {
    oldmax = curmax;
    // Compute -A*X
    zgemm_(
      &lapackNoTrans, &lapackNoTrans, &m, &m, &m, &lapack_mone, a, &m,
      a_orig, &m, &lapack_zero, ax, &m
      &lapackNoTrans, &lapackNoTrans, &m, &m, &m, &lapack_mone, a_orig, &m,
      a, &m, &lapack_zero, ax, &m
    );
    // Transform -A*X into (I - A*X)
    incx = 1;
@@ -166,10 +167,11 @@ lapack_int lapack_newton(
    maxindex = izamax_(&mm, ax, &incx) - 1;
    lapackmax = ax[maxindex];
    curmax = cabs(lapackmax);
    sprintf(buffer, "DEBUG: iteration %d has residue %.5le; target residue is %.5le.\n", ri, curmax, target_residue);
    sprintf(buffer, "DEBUG: iteration %d has residue %.5le; target residue is %.5le.\n", ri, curmax, rs.accuracy_goal);
    message = buffer;
    rs.logger->log(message, LOG_DEBG);
    if (curmax < 0.95 * oldmax) {
    if (curmax < 0.5) { // Safe conditions for Newton-Schultz iteration.
      if (curmax < 0.95 * oldmax) { // Newton-Schultz iteration is improving and can proceed.
	// Compute R = (I - A*X)*X
	zgemm_(
          &lapackNoTrans, &lapackNoTrans, &m, &m, &m, &lapack_one, a, &m,
@@ -177,20 +179,43 @@ lapack_int lapack_newton(
        );
	// Set X = X + R
	zaxpy_(&mm, &lapack_one, r, &incx, a, &incx);
      if (curmax < target_residue) {
	if (curmax < rs.accuracy_goal) {
	  message = "DEBUG: good news - optimal convergence achieved. Stopping.\n";
	  rs.logger->log(message, LOG_DEBG);
	  err = LAPACK_SUCCESS;
	  break; // ri for
        }
      } else {
      message = "WARN: not so good news - cannot improve further. Stopping.\n";
	if (curmax > 0.1) {
	  sprintf(buffer, "INFO: iteration %d achieved limiting residue %.5le. Cannot reach goal. Reverting.\n", ri, curmax);
	  message = buffer;
	  rs.logger->log(message);
	  memcpy(a, unrefined, mm * sizeof(lcomplex));
	} else {
	  sprintf(buffer, "WARN: iteration %d achieved limiting residue %.5le. Stopping.\n", ri, curmax);
	  message = buffer;
	  rs.logger->log(message, LOG_WARN);
	}
	break; // ri for
      }
    } else { // curmax > 0.5. Newton-Schultz iteration is dangerous.
      if (curmax < oldmax) {
	sprintf(buffer, "WARN: iteration %d has residue %.5le. Iterating is dangerous. Stopping.\n", ri, curmax);
	message = buffer;
	rs.logger->log(message, LOG_WARN);
      } else {
	sprintf(buffer, "INFO: iteration %d has residue %.5le. Reverting to unrefined and stopping.\n", ri, curmax);
	message = buffer;
	rs.logger->log(message);
	memcpy(a, unrefined, mm * sizeof(lcomplex));
      }
      break; // ri for
    }
  } // end of ri for
  delete[] id_diag;
  delete[] ax;
  delete[] r;
  delete[] unrefined;
  return err;
}