Commit f24f42dd authored by Mulas, Giacomo
Browse files

first working implementation of iterative refinement to assess and improve...

First working implementation of iterative refinement to assess and improve the numerical stability of the numerical inversion of the matrix am
parent b80dc77f
Loading
Loading
Loading
Loading
+13 −1
Original line number Diff line number Diff line
@@ -38,6 +38,8 @@
#ifdef USE_MAGMA
#include <cuda_runtime.h>
#endif
// define by hand for a first test
#define USE_REFINEMENT 1

#ifndef INCLUDE_TYPES_H_
#include "../include/types.h"
@@ -348,7 +350,7 @@ void cluster(const string& config_file, const string& data_file, const string& o
      /* for the next iterations, just always do maxiter iterations, assuming the accuracy is good enough */
      cid->refinemode = 0;
      /* add an extra iteration for margin, if this does not exceed initialmaxrefiters */
      if (cid->maxrefiters < initialmaxrefiters) cid->maxrefiters++;
      // if (cid->maxrefiters < initialmaxrefiters) cid->maxrefiters++;
      if (jer != 0) {
	// First loop failed. Halt the calculation.
	fclose(timing_file);
@@ -821,6 +823,16 @@ int cluster_jxi488_cycle(int jxi488, ScattererConfiguration *sconf, GeometryConf
  double actualaccuracy = cid->accuracygoal;
  invert_matrix(cid->am, ndit, jer, cid->maxrefiters, actualaccuracy, cid->refinemode, mxndm, cid->proc_device);
  // in principle, we should check whether the returned actualaccuracy is indeed lower than the accuracygoal, and do something about it if not
#ifdef USE_REFINEMENT
  if (cid->refinemode==2) {
    message = "INFO: calibration obtained accuracy " + to_string(actualaccuracy) + " (" + to_string(cid->accuracygoal) + " requested) in " + to_string(cid->maxrefiters) + " refinement iterations\n";
    logger->log(message);
    if (actualaccuracy > 1e-2) {
      printf("Accuracy worse than 0.01, stopping");
      exit(1);
    }
  }
#endif
#ifdef DEBUG_AM
  VirtualAsciiFile *outam2 = new VirtualAsciiFile();
  string outam2_name = output_path + "/c_AM2_JXI" + to_string(jxi488) + ".txt";
+1 −1
Original line number Diff line number Diff line
@@ -45,6 +45,6 @@ void zinvert(dcomplex **mat, np_int n, int &jer);
 * \param maxrefiters: `int` Maximum number of refinement iterations.
 * \param accuracygoal: `double` Accuracy to achieve in iterative refinement, defined as the modulus of the maximum difference between the identity matrix and the matrix product of the (approximate) inverse times the original matrix. On return, it contains the actually achieved accuracy.
 */
void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int maxrefiters, double &accuracygoal);
void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int &maxrefiters, double &accuracygoal, int refinemode);

#endif
+1 −1
Original line number Diff line number Diff line
@@ -457,7 +457,7 @@ void ScattererConfiguration::mpibcast(const mixMPI *mpidata) {
  MPI_Bcast(&itemp, 1, MPI_INT, 0, MPI_COMM_WORLD);
  char *ctemp = strdup(_reference_variable_name.c_str());
  MPI_Bcast(ctemp, itemp, MPI_CHAR, 0, MPI_COMM_WORLD);
  delete[] ctemp;
  free(ctemp);
  MPI_Bcast(&_idfc, 1, MPI_INT, 0, MPI_COMM_WORLD);
  itemp = sizeof(bool);
  char *ptemp = (char *) &_use_external_sphere;
+2 −0
Original line number Diff line number Diff line
@@ -23,6 +23,8 @@
#endif

#ifdef USE_LAPACK
// define by hand for a first test
#define USE_REFINEMENT 1
#ifndef INCLUDE_LAPACK_CALLS_H_
#include "../include/lapack_calls.h"
#endif
+77 −57
Original line number Diff line number Diff line
@@ -82,7 +82,7 @@ void zinvert(dcomplex **mat, np_int n, int &jer) {
  delete[] IPIV;
}

void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int &maxrefiters, double &accuracygoal, int refinemode) {
void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int &maxiters, double &accuracygoal, int refinemode) {

  jer = 0;
#ifdef USE_MKL
@@ -93,20 +93,35 @@ void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int &maxrefiters, do
  np_int nn = n*n;
  np_int incx = 1;
#ifdef USE_MKL
  MKL_Complex16 *arr_orig = new MKL_Complex16[nn];
  MKL_Complex16 *arr_refine = new MKL_Complex16[nn];
  MKL_Complex16 *id = new MKL_Complex16[n];
  MKL_Complex16 *arr_orig = NULL;
    MKL_Complex16 *arr_residual = NULL;
    MKL_Complex16 *arr_refine = NULL;
    MKL_Complex16 *id = NULL;
#else
    dcomplex *arr_orig = NULL;
    dcomplex *arr_residual = NULL;
    dcomplex *arr_refine = NULL;
    dcomplex *id = NULL;
#endif
  if (maxiters>0) {
#ifdef USE_MKL
    arr_orig = new MKL_Complex16[nn];
    arr_residual = new MKL_Complex16[nn];
    arr_refine = new MKL_Complex16[nn];
    id = new MKL_Complex16[n];
    for (np_int i=0; i<n ; i++) {
      id[i].real =  1;
      id[i].imag =  0;
    }
#else
  dcomplex *arr_orig = new dcomplex[nn];
  dcomplex *arr_refine = new dcomplex[nn];
  dcomplex *id = new dcomplex[n];
    arr_orig = new dcomplex[nn];
    arr_residual = new dcomplex[nn];
    arr_refine = new dcomplex[nn];
    id = new dcomplex[n];
    for (np_int i=0; i<n ; i++) id[i] = (dcomplex) 1;
#endif
    zcopy_(&nn, arr, &incx, arr_orig, &incx);
  }
  const dcomplex uim = 0.0 + 1.0 * I;
  
  np_int* IPIV = new np_int[n]();
@@ -115,6 +130,7 @@ void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int &maxrefiters, do
  LAPACKE_zgetri(LAPACK_ROW_MAJOR, n, arr, n, IPIV);
  delete[] IPIV;

  if (maxiters>0) {
    bool iteraterefine = true;
    char transa = 'N';
#ifdef USE_MKL
@@ -135,40 +151,44 @@ void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int &maxrefiters, do
    // multiply minus the original matrix times the inverse matrix
    // NOTE: factors in zgemm are swapped because zgemm is designed for column-major
    // Fortran-style arrays, whereas our arrays are C-style row-major.
  zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr, &n, arr_orig, &n, &dczero, arr_refine, &n);
    zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr, &n, arr_orig, &n, &dczero, arr_residual, &n);
    np_int incy = n+1;
  zaxpy_(&n, &dcone, id, &incx, arr_refine, &incy);
    zaxpy_(&n, &dcone, id, &incx, arr_residual, &incy);
    double oldmax = 0;
    if (refinemode >0) {
    np_int maxindex = izamax_(&n, arr_refine, &incx);
      np_int maxindex = izamax_(&n, arr_residual, &incx);
#ifdef USE_MKL
    oldmax = cabs(arr_refine[maxindex].real + I*arr_refine[maxindex].imag);
      oldmax = cabs(arr_residual[maxindex].real + I*arr_residual[maxindex].imag);
#else
    oldmax = cabs(arr_refine[maxindex]);
      oldmax = cabs(arr_residual[maxindex]);
#endif
      if (oldmax < accuracygoal) iteraterefine = false;
    }
    int iter;
  for (iter=0; (iter<maxrefiters) && iteraterefine; iter++) {
    zgemm_(&transa, &transa, &n, &n, &n, &dcone, arr_refine, &n, arr, &n, &dcone, arr, &n);
    zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr, &n, arr_orig, &n, &dczero, arr_refine, &n);
    zaxpy_(&n, &dcone, id, &incx, arr_refine, &incy);
    if ((refinemode==2) || ((refinemode==1) && (iter == (maxrefiters-1)))) {
      np_int maxindex = izamax_(&n, arr_refine, &incx);
    for (iter=0; (iter<maxiters) && iteraterefine; iter++) {
      zgemm_(&transa, &transa, &n, &n, &n, &dcone, arr_residual, &n, arr, &n, &dczero, arr_refine, &n);
      zaxpy_(&nn, &dcone, arr_refine, &incx, arr, &incx);
	// zcopy_(&nn, arr_refine, &incx, arr, &incx);
      zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr, &n, arr_orig, &n, &dczero, arr_residual, &n);
      zaxpy_(&n, &dcone, id, &incx, arr_residual, &incy);
      if ((refinemode==2) || ((refinemode==1) && (iter == (maxiters-1)))) {
	np_int maxindex = izamax_(&n, arr_residual, &incx);
#ifdef USE_MKL
      double newmax = cabs(arr_refine[maxindex].real + I*arr_refine[maxindex].imag);
	double newmax = cabs(arr_residual[maxindex].real + I*arr_residual[maxindex].imag);
#else
      double newmax = cabs(arr_refine[maxindex]);
	double newmax = cabs(arr_residual[maxindex]);
#endif
	if ((refinemode==2) && ((newmax > oldmax)||(newmax < accuracygoal))) iteraterefine = false;
	oldmax = newmax; 
      }
    }
  if (refinemode==2) maxrefiters = iter;
    if (refinemode==2) maxiters = iter;
    accuracygoal = oldmax;
    delete[] id;
    delete[] arr_refine;
    delete[] arr_orig;
    delete[] arr_residual;
  }

}

Loading