Commit 084e3c97 authored by Giovanni La Mura's avatar Giovanni La Mura
Browse files

Implement LAPACK based LU inversion and refinement

parent abe72174
Loading
Loading
Loading
Loading
+30 −12
Original line number Diff line number Diff line
@@ -23,7 +23,18 @@
#ifndef INCLUDE_LAPACK_CALLS_H_
#define INCLUDE_LAPACK_CALLS_H_

/*! \brief Invert a complex matrix with double precision elements.
#ifndef LAPACK_SUCCESS
#define LAPACK_SUCCESS 0
#endif

#ifdef USE_MKL
typedef MKL_Complex16 lcomplex;
#else
typedef dcomplex lcomplex;
#endif // USE_MKL

/**
 * \brief Invert a complex matrix with double precision elements.
 *
 * Use LAPACKE64 to perform an in-place matrix inversion for a complex
 * matrix with double precision elements.
@@ -31,21 +42,28 @@
 * \param mat: Matrix of complex. The matrix to be inverted.
 * \param n: `np_int` The number of rows and columns of the [n x n] matrix.
 * \param jer: `int &` Reference to an integer return flag.
 * \param rs: `const RuntimeSettings &` Runtime settings instance.
 */
void zinvert(dcomplex **mat, np_int n, int &jer);
void zinvert(dcomplex **mat, np_int n, int &jer, const RuntimeSettings& rs=RuntimeSettings());

/*! \brief Invert a complex matrix with double precision elements, using iterative refinement to improve accuracy.
/**
 * \brief Perform Newton-Schulz iterative refinement of matrix inversion.
 *
 * Use LAPACKE64 to perform an in-place matrix inversion for a complex
 * matrix with double precision elements.
 * In this function the residual of the inversion of a matrix A is evaluated as:
 *
 * \param mat: Matrix of complex. The matrix to be inverted.
 * \param n: `np_int` The number of rows and columns of the [n x n] matrix.
 * \param jer: `int &` Reference to an integer return flag.
 * \param maxrefiters: `int` Maximum number of refinement iterations.
 * \param accuracygoal: `double &` Accuracy to achieve in iterative refinement, defined as the module of the maximum difference between the identity matrix and the matrix product of the (approximate) inverse times the original matrix. On return, it contains the actually achieved accuracy.
 * \param refinemode: `int` Flag to control the refinement mode.
 * R = A^-1 A - I
 *
 * and the convergence of refinement is estimated through the largest element
 * modulus left in R.
 *
 * \param rs: `const RuntimeSettings &` Runtime settings instance. [IN]
 * \param a_orig: `lcomplex *` Pointer to the first element of the non-inverted matrix. [IN]
 * \param m: `const lapack_int` Number of rows / columns in a. [IN]
 * \param a: `lcomplex *` Pointer to the inverted matrix. [IN/OUT]
 * \return err: `lapack_int` An error code (LAPACK_SUCCESS, if everything was fine).
 */
void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int &maxrefiters, double &accuracygoal, int refinemode);
lapack_int lapack_newton(
  const RuntimeSettings& rs, lcomplex* a_orig, const lapack_int m, lcomplex* a
);

#endif
+7 −11
Original line number Diff line number Diff line
@@ -26,12 +26,6 @@ using namespace std;
#include "../include/types.h"
#endif

#ifdef USE_LAPACK
#ifndef INCLUDE_LAPACK_CALLS_H_
#include "../include/lapack_calls.h"
#endif
#endif

#ifndef INCLUDE_LOGGING_H_
#include "../include/logging.h"
#endif
@@ -40,6 +34,12 @@ using namespace std;
#include "../include/Configuration.h"
#endif

#ifdef USE_LAPACK
#ifndef INCLUDE_LAPACK_CALLS_H_
#include "../include/lapack_calls.h"
#endif
#endif

#ifdef USE_MAGMA
#ifndef INCLUDE_MAGMA_CALLS_H_
#include "../include/magma_calls.h"
@@ -82,11 +82,7 @@ void invert_matrix(
  cublas_zinvert(mat, size, target_device);
#endif
#elif defined USE_LAPACK
#ifdef USE_REFINEMENT
  zinvert_and_refine(mat, size, ier, maxrefiters, accuracygoal, refinemode);
#else
  zinvert(mat, size, ier);
#endif
  zinvert(mat, size, ier, rs);
#else
  lucin(mat, size, size, ier);
#endif
+109 −144
Original line number Diff line number Diff line
@@ -18,6 +18,9 @@
 *
 * \brief Implementation of the interface with LAPACK libraries.
 */

#include <string>

#ifndef INCLUDE_TYPES_H_
#include "../include/types.h"
#endif
@@ -33,164 +36,126 @@

#ifdef USE_LAPACK

#ifndef INCLUDE_LAPACK_CALLS_H_
#include "../include/lapack_calls.h"
#endif

#include <limits>

#ifdef USE_MKL
  extern "C" void zcopy_(np_int *n, MKL_Complex16 *arr1, np_int *inc1, MKL_Complex16 *arr2,
		     np_int *inc2);
  extern "C" void zgemm_(char *transa, char *transb, np_int *l, np_int *m, np_int *n,
		     MKL_Complex16 *alpha, MKL_Complex16 *a, np_int *lda,
		     MKL_Complex16 *b, np_int *ldb, MKL_Complex16 *beta,
		     MKL_Complex16 *c, np_int *ldc);
  extern "C" void zaxpy_(np_int *n, MKL_Complex16 *alpha, MKL_Complex16 *arr1, np_int *inc1,
		     MKL_Complex16 *arr2, np_int *inc2);
  extern "C" np_int izamax_(np_int *n, MKL_Complex16 *arr1, np_int *inc1);
#else
  extern "C" void zcopy_(np_int *n, dcomplex *arr1, np_int *inc1, dcomplex *arr2, np_int *inc2);
  extern "C" void zgemm_(char *transa, char *transb, np_int *l, np_int *m, np_int *n,
		     dcomplex *alpha, dcomplex *a, np_int *lda,
		     dcomplex *b, np_int *ldb, dcomplex *beta,
		     dcomplex *c, np_int *ldc);
  extern "C" void zaxpy_(np_int *n, dcomplex *alpha, dcomplex *arr1, np_int *inc1,
		     dcomplex *arr2, np_int *inc2);
  extern "C" np_int izamax_(np_int *n, dcomplex *arr1, np_int *inc1);
#ifndef INCLUDE_LOGGING_H_
#include "../include/logging.h"
#endif

void zinvert(dcomplex **mat, np_int n, int &jer) {
  jer = 0;
  dcomplex *arr = &(mat[0][0]);
  const dcomplex uim = 0.0 + 1.0 * I;

#ifdef USE_MKL
  MKL_Complex16 *arr2 = (MKL_Complex16 *) arr;
#ifndef INCLUDE_CONFIGURATION_H_
#include "../include/Configuration.h"
#endif

  np_int* IPIV = new np_int[n]();
  
#ifdef USE_MKL
  LAPACKE_zgetrf(LAPACK_ROW_MAJOR, n, n, arr2, n, IPIV);
  LAPACKE_zgetri(LAPACK_ROW_MAJOR, n, arr2, n, IPIV);
#else
  LAPACKE_zgetrf(LAPACK_ROW_MAJOR, n, n, arr, n, IPIV);
  LAPACKE_zgetri(LAPACK_ROW_MAJOR, n, arr, n, IPIV);
#ifndef INCLUDE_LAPACK_CALLS_H_
#include "../include/lapack_calls.h"
#endif

  delete[] IPIV;
}
extern "C" void zcopy_(np_int *n, lcomplex *arr1, np_int *inc1, lcomplex *arr2, np_int *inc2);
extern "C" void zgemm_(
  char *transa, char *transb, np_int *l, np_int *m, np_int *n, lcomplex *alpha, lcomplex *a,
  np_int *lda, lcomplex *b, np_int *ldb, lcomplex *beta, lcomplex *c, np_int *ldc
);
extern "C" void zaxpy_(
  np_int *n, lcomplex *alpha, lcomplex *arr1, np_int *inc1, lcomplex *arr2, np_int *inc2
);
extern "C" np_int izamax_(np_int *n, lcomplex *arr1, np_int *inc1);

void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int &maxiters, double &accuracygoal, int refinemode) {
using namespace std;

void zinvert(dcomplex **mat, np_int n, int &jer, const RuntimeSettings& rs) {
  jer = 0;
#ifdef USE_MKL
  MKL_Complex16 *arr = (MKL_Complex16 *) &(mat[0][0]);
#else
  dcomplex *arr = &(mat[0][0]);
#endif
  char buffer[128];
  string message;
  lapack_int info;
  lcomplex *arr = &(mat[0][0]);
  lcomplex *arr_orig;
  np_int nn = n * n;
  np_int incx = 1;
  np_int incx0 = 0;
#ifdef USE_MKL
  MKL_Complex16 *arr_orig = NULL;
  MKL_Complex16 *arr_residual = NULL;
  MKL_Complex16 *arr_refine = NULL;
  MKL_Complex16 *id = NULL;
#else
  dcomplex *arr_orig = NULL;
  dcomplex *arr_residual = NULL;
  dcomplex *arr_refine = NULL;
  dcomplex *id = NULL;
#endif
  if (maxiters>0) {
#ifdef USE_MKL
    arr_orig = new MKL_Complex16[nn];
    arr_residual = new MKL_Complex16[nn];
    arr_refine = new MKL_Complex16[nn];
    id = new MKL_Complex16[1];
    id[0].real =  1;
    id[0].imag =  0;
#else
    arr_orig = new dcomplex[nn];
    arr_residual = new dcomplex[nn];
    arr_refine = new dcomplex[nn];
    id = new dcomplex[1];
    id[0] = (dcomplex) 1;
#endif
    zcopy_(&nn, arr, &incx, arr_orig, &incx);
  if (rs.use_refinement) {
    lapack_int inc1 = 1;
    arr_orig = new lcomplex[nn];
    zcopy_(&nn, arr, &inc1, arr_orig, &inc1);
  }
  // const dcomplex uim = 0.0 + 1.0 * I;
  
  np_int* IPIV = new np_int[n]();
  
  if (rs.invert_mode == RuntimeSettings::INV_MODE_LU) {
    np_int* IPIV = new np_int[n];
    LAPACKE_zgetrf(LAPACK_ROW_MAJOR, n, n, arr, n, IPIV);
    LAPACKE_zgetri(LAPACK_ROW_MAJOR, n, arr, n, IPIV);
    delete[] IPIV;

  if (maxiters>0) {
    bool iteraterefine = true;
    char transa = 'N';
#ifdef USE_MKL
    MKL_Complex16 dczero;
    dczero.real = 0;
    dczero.imag = 0;
    MKL_Complex16 dcone;
    dcone.real = 1;
    dcone.imag = 0;
    MKL_Complex16 dcmone;
    dcmone.real = -1;
    dcmone.imag = 0;
#else
    dcomplex dczero = 0;
    dcomplex dcone = 1;
    dcomplex dcmone = -1;
#endif
    // multiply minus the original matrix times the inverse matrix
    // NOTE: factors in zgemm are swapped because zgemm is designed for column-major
    // Fortran-style arrays, whereas our arrays are C-style row-major.
    zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr, &n, arr_orig, &n, &dczero, arr_residual, &n);
    np_int incy = n+1;
    zaxpy_(&n, &dcone, id, &incx0, arr_residual, &incy);
    double oldmax = 0;
    if (refinemode >0) {
      np_int maxindex = izamax_(&nn, arr_residual, &incx);
#ifdef USE_MKL
      oldmax = cabs(arr_residual[maxindex].real + I*arr_residual[maxindex].imag);
#else
      oldmax = cabs(arr_residual[maxindex]);
#endif
      printf("Initial max residue = %g\n", oldmax);
      if (oldmax < accuracygoal) iteraterefine = false;
    }
    int iter;
    for (iter=0; (iter<maxiters) && iteraterefine; iter++) {
      zgemm_(&transa, &transa, &n, &n, &n, &dcone, arr_residual, &n, arr, &n, &dczero, arr_refine, &n);
      zaxpy_(&nn, &dcone, arr_refine, &incx, arr, &incx);
	// zcopy_(&nn, arr_refine, &incx, arr, &incx);
      zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr, &n, arr_orig, &n, &dczero, arr_residual, &n);
      zaxpy_(&n, &dcone, id, &incx0, arr_residual, &incy);
      if ((refinemode==2) || ((refinemode==1) && (iter == (maxiters-1)))) {
	np_int maxindex = izamax_(&nn, arr_residual, &incx);
#ifdef USE_MKL
	double newmax = cabs(arr_residual[maxindex].real + I*arr_residual[maxindex].imag);
#else
	double newmax = cabs(arr_residual[maxindex]);
#endif
	printf("Max residue after %d iterations = %g\n", iter+1, newmax);
	if ((refinemode==2) && ((newmax > oldmax)||(newmax < accuracygoal))) iteraterefine = false;
	oldmax = newmax; 
      }
    if (rs.use_refinement) {
      info = lapack_newton(rs, arr_orig, n, arr);
    }
    if (refinemode==2) maxiters = iter;
    accuracygoal = oldmax;
    delete[] id;
    delete[] arr_refine;
  } // inversion mode switch
  if (rs.use_refinement) {
    delete[] arr_orig;
    delete[] arr_residual;
  }
}

lapack_int lapack_newton(
  const RuntimeSettings& rs, lcomplex* a_orig, lapack_int m, lcomplex* a
) {
  lapack_int err = LAPACK_SUCCESS;
  string message;
  char buffer[128];
  char lapackNoTrans = 'N';
  const int max_ref_iters = rs.max_ref_iters;
  lcomplex lapack_zero = 0.0 + I * 0.0;
  lcomplex lapack_one = 1.0 + I * 0.0;
  lcomplex lapack_mone = -1.0 + I * 0.0;
  lapack_int mm = m * m;
  lapack_int incx, incy;
  lcomplex *ax, *r;
  lcomplex *id_diag = new lcomplex[m];
  double oldmax = 2.0e+16, curmax = 1.0e+16;
  for (lapack_int hi = 0; hi < m; hi++)
    id_diag[hi] = lapack_one;
  ax = new lcomplex[mm];
  r = new lcomplex[mm];
  double max_residue, target_residue;
  incx = 1;
  lapack_int maxindex = izamax_(&mm, a, &incx) - 1;
  lcomplex lapackmax = a[maxindex];
  curmax = cabs(lapackmax); //cabs(magmamax.x + I * magmamax.y);
  target_residue = curmax * rs.accuracy_goal;
  sprintf(buffer, "INFO: largest matrix value has modulus %.5le; target residue is %.5le.\n", curmax, target_residue);
  message = buffer;
  rs.logger->log(message);
  for (int ri = 0; ri < max_ref_iters; ri++) {
    oldmax = curmax;
    // Compute -A*X
    zgemm_(
      &lapackNoTrans, &lapackNoTrans, &m, &m, &m, &lapack_mone, a, &m,
      a_orig, &m, &lapack_zero, ax, &m
    );
    // Transform -A*X into (I - A*X)
    incx = 1;
    incy = m + 1;
    zaxpy_(&m, &lapack_one, id_diag, &incx, ax, &incy);
    maxindex = izamax_(&mm, ax, &incx) - 1;
    lapackmax = ax[maxindex];
    curmax = cabs(lapackmax);
    sprintf(buffer, "DEBUG: iteration %d has residue %.5le; target residue is %.5le.\n", ri, curmax, target_residue);
    message = buffer;
    rs.logger->log(message, LOG_DEBG);
    if (curmax < 0.99 * oldmax) {
      // Compute R = (I - A*X)*X
      zgemm_(
        &lapackNoTrans, &lapackNoTrans, &m, &m, &m, &lapack_one, a, &m,
	ax, &m, &lapack_zero, r, &m
      );
      // Set X = X + R
      zaxpy_(&mm, &lapack_one, r, &incx, a, &incx);
      if (curmax < target_residue) {
	message = "DEBUG: good news - optimal convergence achieved. Stopping.\n";
	rs.logger->log(message, LOG_DEBG);
	break; // ri for
      }
    } else {
      message = "WARN: not so good news - cannot improve further. Stopping.\n";
      rs.logger->log(message, LOG_WARN);
      break; // ri for
    }
  }
  delete[] id_diag;
  delete[] ax;
  delete[] r;
  return err;
}

#endif