Commit a1bfb721 authored by Mulas, Giacomo's avatar Mulas, Giacomo
Browse files

First attempt at implementing iterative refinement of inverse matrix

both in magma and lapack. Still to test, so far it just compiles
without errors.
parent 46b48d81
Loading
Loading
Loading
Loading
+14 −1
Original line number Diff line number Diff line
@@ -25,7 +25,7 @@

/*! \brief Invert a complex matrix with double precision elements.
 *
 * Use LAPACKE64 to perform an in-place matrix inversion for a complex
 * Use MAGMA to perform an in-place matrix inversion for a complex
 * matrix with double precision elements.
 *
 * \param mat: Matrix of complex. The matrix to be inverted.
@@ -35,4 +35,17 @@
 */
void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id=0);

/*! \brief Invert a complex matrix with double precision elements, applying iterative refinement of the solution
 *
 * Use MAGMA to perform an in-place matrix inversion for a complex
 * matrix with double precision elements, then iteratively correct the
 * computed inverse on the device using the residual of its product with
 * the original matrix. Refinement stops after `refiters` iterations, or
 * earlier, as soon as the largest residual entry stops decreasing.
 *
 * Unlike `magma_zinvert()`, this declaration gives `device_id` no default
 * value, so callers must always pass it explicitly.
 *
 * \param mat: Matrix of complex. The matrix to be inverted; it is overwritten
 * with the result.
 * \param n: `np_int` The number of rows and columns of the [n x n] matrix.
 * \param jer: `int &` Reference to an integer return flag, set from the MAGMA
 * status code.
 * \param refiters: `int` Maximum number of refinement iterations to apply.
 * \param device_id: `int` ID of the device for matrix inversion offloading.
 */
void magma_zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int refiters, int device_id);

#endif
+12 −0
Original line number Diff line number Diff line
@@ -29,6 +29,8 @@
#endif

#ifdef USE_MAGMA
// define by hand for a first test
#define USE_REFINEMENT 1
#ifndef INCLUDE_MAGMA_CALLS_H_
#include "../include/magma_calls.h"
#endif
@@ -47,9 +49,19 @@ using namespace std;
void invert_matrix(dcomplex **mat, np_int size, int &ier, np_int max_size, int target_device) {
  ier = 0;
#ifdef USE_MAGMA
#ifdef USE_REFINEMENT
  // try using the iterative refinement to obtain a more accurate solution
  const int refiters = 3;
  magma_zinvert_and_refine(mat, size, ier, refiters, target_device);
#elif
  magma_zinvert(mat, size, ier, target_device);
#endif  
#elif defined USE_LAPACK
#ifdef USE_REFINEMENT
  zinvert_and_refine(mat, size, ier, refiters);
#elif
  zinvert(mat, size, ier);
#endif
#else
  lucin(mat, max_size, size, ier);
#endif
+100 −0
Original line number Diff line number Diff line
@@ -32,10 +32,13 @@
*/

#ifdef USE_LAPACK

#ifndef INCLUDE_LAPACK_CALLS_H_
#include "../include/lapack_calls.h"
#endif

#include <limits>

void zinvert(dcomplex **mat, np_int n, int &jer) {
  jer = 0;
  dcomplex *arr = &(mat[0][0]);
@@ -57,4 +60,101 @@ void zinvert(dcomplex **mat, np_int n, int &jer) {

  delete[] IPIV;
}

/*! \brief Invert a complex matrix in place with LAPACKE and refine the inverse iteratively.
 *
 * The matrix is first inverted through an LU factorisation (`LAPACKE_zgetrf`
 * followed by `LAPACKE_zgetri`). The approximate inverse X is then corrected
 * with at most `refiters` iterations of
 *
 *     R = I - X A ,   X <- X + R X
 *
 * where A is a saved copy of the input matrix. The iteration stops early as
 * soon as the largest entry of the residual R (as selected by `izamax_`) no
 * longer decreases with respect to the previous iteration.
 *
 * The buffers are stored row-major (they are handed to LAPACKE with
 * `LAPACK_ROW_MAJOR`), while the Fortran BLAS routines read them as
 * column-major, i.e. as the transposed matrices. The operand order of the
 * `zgemm_` calls below accounts for that.
 *
 * \param mat: Matrix of complex. The matrix to be inverted, overwritten with
 * the (refined) inverse. Its storage must be contiguous from `mat[0][0]`.
 * \param n: `np_int` The number of rows and columns of the [n x n] matrix.
 * \param jer: `int &` Reference to an integer return flag. It is 0 on success,
 * otherwise it carries the non-zero `info` code returned by LAPACKE.
 * \param refiters: `int` Maximum number of refinement iterations to apply.
 */
void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int refiters) {
  // NOTE(review): these block-scope declarations carry no explicit
  // `extern "C"` linkage. Unless an included header already declares the same
  // BLAS symbols with C linkage and matching signatures, the calls may fail
  // to link -- to be confirmed against the project headers.
#ifdef USE_MKL
  extern void zcopy_(np_int *n, MKL_Complex16 *arr1, np_int *inc1, MKL_Complex16 *arr2,
                     np_int *inc2);
  extern void zgemm_(char *transa, char *transb, np_int *l, np_int *m, np_int *n,
                     MKL_Complex16 *alpha, MKL_Complex16 *a, np_int *lda,
                     MKL_Complex16 *b, np_int *ldb, MKL_Complex16 *beta,
                     MKL_Complex16 *c, np_int *ldc);
  extern void zaxpy_(np_int *n, MKL_Complex16 *alpha, MKL_Complex16 *arr1, np_int *inc1,
                     MKL_Complex16 *arr2, np_int *inc2);
  extern np_int izamax_(np_int *n, MKL_Complex16 *arr1, np_int *inc1);
#else
  extern void zcopy_(np_int *n, dcomplex *arr1, np_int *inc1, dcomplex *arr2, np_int *inc2);
  extern void zgemm_(char *transa, char *transb, np_int *l, np_int *m, np_int *n,
                     dcomplex *alpha, dcomplex *a, np_int *lda,
                     dcomplex *b, np_int *ldb, dcomplex *beta,
                     dcomplex *c, np_int *ldc);
  extern void zaxpy_(np_int *n, dcomplex *alpha, dcomplex *arr1, np_int *inc1,
                     dcomplex *arr2, np_int *inc2);
  extern np_int izamax_(np_int *n, dcomplex *arr1, np_int *inc1);
#endif
  jer = 0;
  np_int nn = n*n;
  np_int incx = 1;
  // Keep a copy of the original matrix: it is needed to build the residual.
#ifdef USE_MKL
  MKL_Complex16 *arr = (MKL_Complex16 *) &(mat[0][0]);
  MKL_Complex16 *arr_orig = new MKL_Complex16[nn];
#else
  dcomplex *arr = &(mat[0][0]);
  dcomplex *arr_orig = new dcomplex[nn];
#endif
  zcopy_(&nn, arr, &incx, arr_orig, &incx);

  // First approximation of the inverse: LU factorisation, then inversion.
  np_int* IPIV = new np_int[n]();
  np_int info = LAPACKE_zgetrf(LAPACK_ROW_MAJOR, n, n, arr, n, IPIV);
  if (info == 0) info = LAPACKE_zgetri(LAPACK_ROW_MAJOR, n, arr, n, IPIV);
  delete[] IPIV;

  if (info != 0) {
    // Factorisation or inversion failed: report it, nothing to refine.
    jer = (int)info;
    delete[] arr_orig;
    return;
  }
  if (refiters <= 0 || nn <= 0) {
    // No refinement requested: the plain LAPACKE inverse is the result.
    delete[] arr_orig;
    return;
  }

#ifdef USE_MKL
  MKL_Complex16 *arr_refine = new MKL_Complex16[nn]; // residual R
  MKL_Complex16 *arr_corr = new MKL_Complex16[nn];   // correction R X
  MKL_Complex16 *id = new MKL_Complex16[n];          // diagonal of the identity
  for (np_int i=0; i<n ; i++) {
    id[i].real =  1;
    id[i].imag =  0;
  }
  MKL_Complex16 dczero;
  dczero.real = 0;
  dczero.imag = 0;
  MKL_Complex16 dcone;
  dcone.real = 1;
  dcone.imag = 0;
  MKL_Complex16 dcmone;
  dcmone.real = -1;
  dcmone.imag = 0;
#else
  dcomplex *arr_refine = new dcomplex[nn]; // residual R
  dcomplex *arr_corr = new dcomplex[nn];   // correction R X
  dcomplex *id = new dcomplex[n];          // diagonal of the identity
  for (np_int i=0; i<n ; i++) id[i] = (dcomplex) 1;
  dcomplex dczero = 0;
  dcomplex dcone = 1;
  dcomplex dcmone = -1;
#endif

  bool iteraterefine = true;
  double oldmax = std::numeric_limits<double>::max();
  char transa = 'N';
  // Stride n+1 walks along the diagonal of an [n x n] matrix.
  np_int incy = n+1;

  // Initial residual: minus the product with the original matrix, plus identity.
  zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr_orig, &n, arr, &n, &dczero, arr_refine, &n);
  zaxpy_(&n, &dcone, id, &incx, arr_refine, &incy);
  for (int iter=0; (iter<refiters) && iteraterefine; iter++) {
    // BLAS does not allow the output of zgemm to alias one of its inputs, so
    // the correction goes to a scratch buffer and is then added to the inverse.
    zgemm_(&transa, &transa, &n, &n, &n, &dcone, arr, &n, arr_refine, &n, &dczero, arr_corr, &n);
    zaxpy_(&nn, &dcone, arr_corr, &incx, arr, &incx);
    // Residual of the updated inverse.
    zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr_orig, &n, arr, &n, &dczero, arr_refine, &n);
    zaxpy_(&n, &dcone, id, &incx, arr_refine, &incy);
    // Search the whole residual matrix (nn elements); izamax_ follows the
    // Fortran convention and returns a 1-based position.
    np_int maxindex = izamax_(&nn, arr_refine, &incx) - 1;
#ifdef USE_MKL
    dcomplex newzmax = arr_refine[maxindex].real + I*arr_refine[maxindex].imag;
#else
    dcomplex newzmax = arr_refine[maxindex];
#endif
    double newmax = cabs(newzmax);
    // Go on only while the residual keeps shrinking.
    if (newmax < oldmax) oldmax = newmax; else iteraterefine = false;
  }
  delete[] id;
  delete[] arr_corr;
  delete[] arr_refine;
  delete[] arr_orig;
}

#endif
+98 −0
Original line number Diff line number Diff line
@@ -23,10 +23,13 @@
#endif

#ifdef USE_MAGMA

#ifndef INCLUDE_MAGMA_CALLS_H_
#include "../include/magma_calls.h"
#endif

#include <limits>

void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
  // magma_int_t result = magma_init();
  magma_int_t err = MAGMA_SUCCESS;
@@ -53,8 +56,103 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
  magma_zgetmatrix(m, m, d_a , m, a, m, queue); // copy d_a -> a
  delete[] piv; // free host memory
  magma_free(d_a); // free device memory
  magma_free(dwork); // free device memory
  magma_queue_destroy(queue); // destroy queue
  // result = magma_finalize();
  jer = (int)err;
}

/*! \brief Invert a complex matrix in place with MAGMA and refine the inverse iteratively.
 *
 * The matrix is copied to the device, LU-factorised (`magma_zgetrf_gpu`) and
 * inverted (`magma_zgetri_gpu`). The approximate inverse X is then corrected
 * on the device with at most `refiters` iterations of
 *
 *     R = I - X A ,   X <- X + R X
 *
 * where A is a device copy of the input matrix. The iteration stops early as
 * soon as the largest entry of the residual R (as selected by `magma_izamax`)
 * no longer decreases with respect to the previous iteration.
 *
 * The host matrix is stored row-major, whereas MAGMA reads it as
 * column-major, i.e. as its transpose. The operand order of the
 * `magma_zgemm` calls below accounts for that.
 *
 * \param mat: Matrix of complex. The matrix to be inverted; on success it is
 * overwritten with the (refined) inverse. Its storage must be contiguous
 * from `mat[0][0]`.
 * \param n: `np_int` The number of rows and columns of the [n x n] matrix.
 * \param jer: `int &` Reference to an integer return flag. It receives
 * `MAGMA_SUCCESS` on success, otherwise the first failing MAGMA allocation
 * status or the non-zero `info` code of the factorisation / inversion.
 * \param refiters: `int` Maximum number of refinement iterations to apply.
 * \param device_id: `int` ID of the device for matrix inversion offloading.
 */
void magma_zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int refiters, int device_id) {
  magma_int_t err = MAGMA_SUCCESS;
  magma_int_t info = 0;
  magma_queue_t queue = NULL;
  magma_device_t dev = (magma_device_t)device_id;
  magma_queue_create(dev, &queue);
  magma_int_t m = (magma_int_t)n; // a is an [m x m] matrix
  magma_int_t mm = m * m; // number of elements of a
  magmaDoubleComplex *a = (magmaDoubleComplex *)&(mat[0][0]); // first element on host
  magmaDoubleComplex *d_a = NULL; // matrix on device, holds the inverse after zgetri
  magmaDoubleComplex *d_a_orig = NULL; // copy of the original matrix on device
  magmaDoubleComplex *d_a_refine = NULL; // residual R on device
  magmaDoubleComplex *d_a_corr = NULL; // correction R X on device
  magmaDoubleComplex *d_id = NULL; // diagonal of the identity on device
  magmaDoubleComplex *dwork = NULL; // zgetri workspace
  magma_int_t ldwork = m * magma_get_zgetri_nb(m); // workspace size for the optimal block size
  magma_int_t *piv = new magma_int_t[m]; // pivot indices, host memory
  bool inverted = false;

  // Allocate the device buffers needed by the inversion, stopping at the
  // first failure so that its status code is the one reported.
  err = magma_zmalloc(&d_a, mm);
  if (err == MAGMA_SUCCESS) err = magma_zmalloc(&d_a_orig, mm);
  if (err == MAGMA_SUCCESS) err = magma_zmalloc(&dwork, ldwork);
  if (err == MAGMA_SUCCESS) {
    magma_zsetmatrix(m, m, a, m, d_a, m, queue); // copy a -> d_a
    magma_zcopy(mm, d_a, 1, d_a_orig, 1, queue); // copy d_a -> d_a_orig on gpu
    // Make sure the copy has completed before the factorisation, which is not
    // bound to this queue, starts overwriting d_a.
    magma_queue_sync(queue);
    // LU factorisation, then in-place inversion: d_a holds the first
    // approximation of the inverse afterwards.
    magma_zgetrf_gpu(m, m, d_a, m, piv, &info);
    if (info == 0) magma_zgetri_gpu(m, d_a, m, piv, dwork, ldwork, &info);
    if (info == 0) inverted = true; else err = info;
  }
  if (dwork != NULL) {
    magma_free(dwork); // only needed by zgetri
    dwork = NULL;
  }

  if (inverted && refiters > 0 && mm > 0) {
    err = magma_zmalloc(&d_a_refine, mm);
    if (err == MAGMA_SUCCESS) err = magma_zmalloc(&d_a_corr, mm);
    if (err == MAGMA_SUCCESS) err = magma_zmalloc(&d_id, m);
    if (err == MAGMA_SUCCESS) {
      magmaDoubleComplex magma_mone;
      magma_mone.x = -1;
      magma_mone.y = 0;
      magmaDoubleComplex magma_one;
      magma_one.x = 1;
      magma_one.y = 0;
      magmaDoubleComplex magma_zero;
      magma_zero.x = 0;
      magma_zero.y = 0;
      // Build a vector of ones on the host and upload it: added with stride
      // m+1 it lands on the diagonal of an [m x m] matrix.
      magmaDoubleComplex *id = new magmaDoubleComplex[m];
      for (magma_int_t i=0; i<m; i++) id[i] = magma_one;
      magma_zsetvector(m, id, 1, d_id, 1, queue);
      delete[] id;

      bool iteraterefine = true;
      double oldmax = std::numeric_limits<double>::max();
      // Initial residual: minus the product with the original matrix, plus identity.
      magma_zgemm(MagmaNoTrans, MagmaNoTrans, m, m, m, magma_mone, d_a_orig, m, d_a, m, magma_zero, d_a_refine, m, queue);
      magma_zaxpy(m, magma_one, d_id, 1, d_a_refine, m+1, queue);
      for (int iter=0; (iter<refiters) && iteraterefine; iter++) {
        // The output of a gemm must not alias its inputs, so the correction
        // goes to a scratch buffer and is then added to the inverse.
        magma_zgemm(MagmaNoTrans, MagmaNoTrans, m, m, m, magma_one, d_a, m, d_a_refine, m, magma_zero, d_a_corr, m, queue);
        magma_zaxpy(mm, magma_one, d_a_corr, 1, d_a, 1, queue);
        // Residual of the updated inverse.
        magma_zgemm(MagmaNoTrans, MagmaNoTrans, m, m, m, magma_mone, d_a_orig, m, d_a, m, magma_zero, d_a_refine, m, queue);
        magma_zaxpy(m, magma_one, d_id, 1, d_a_refine, m+1, queue);
        // Locate the largest residual entry; magma_izamax returns a 1-based
        // position, hence the shift to a 0-based offset.
        magma_int_t maxindex = magma_izamax(mm, d_a_refine, 1, queue) - 1;
        magmaDoubleComplex magmamax;
        // Transfer that single entry to the host and take its modulus.
        magma_zgetvector(1, d_a_refine+maxindex, 1, &magmamax, 1, queue);
        dcomplex newzmax = magmamax.x + I*magmamax.y;
        double newmax = cabs(newzmax);
        // Go on only while the residual keeps shrinking.
        if (newmax < oldmax) oldmax = newmax; else iteraterefine = false;
      }
    }
  }

  // Copy the inverse back to the host. If only the refinement buffers could
  // not be allocated, the unrefined inverse is returned and jer reports the
  // allocation failure.
  if (inverted) magma_zgetmatrix(m, m, d_a, m, a, m, queue);

  // Release every resource exactly once.
  delete[] piv;
  if (d_id != NULL) magma_free(d_id);
  if (d_a_corr != NULL) magma_free(d_a_corr);
  if (d_a_refine != NULL) magma_free(d_a_refine);
  if (d_a_orig != NULL) magma_free(d_a_orig);
  if (d_a != NULL) magma_free(d_a);
  magma_queue_destroy(queue);
  jer = (int)err;
}


#endif