Commit 67e2b24a authored by Mulas, Giacomo's avatar Mulas, Giacomo
Browse files

First implementation of iterative refinement of matrix inversion, in magma and lapack.

parent a1bfb721
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
@@ -34,4 +34,17 @@
 */
void zinvert(dcomplex **mat, np_int n, int &jer);

/*! \brief Invert a complex matrix with double precision elements, using iterative refinement to improve accuracy.
 *
 * Use LAPACKE64 to perform an in-place matrix inversion for a complex
 * matrix with double precision elements.
 *
 * \param mat: Matrix of complex. The matrix to be inverted.
 * \param n: `np_int` The number of rows and columns of the [n x n] matrix.
 * \param jer: `int &` Reference to an integer return flag.
 * \param maxrefiters: `int` Maximum number of refinement iterations.
 * \param accuracygoal: `double` Accuracy to achieve in iterative refinement, defined as the module of the maximum difference between the identity matrix and the matrix product of the (approximate) inverse times the original matrix. On return, it contains the actually achieved accuracy
 */
void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int maxrefiters, double &accuracygoal);

#endif
+3 −2
Original line number Diff line number Diff line
@@ -43,9 +43,10 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id=0);
 * \param mat: Matrix of complex. The matrix to be inverted.
 * \param n: `np_int` The number of rows and columns of the [n x n] matrix.
 * \param jer: `int &` Reference to an integer return flag.
 * \param refiters: `int` integer number of refinement iterations to apply.
 * \param maxrefiters: `int` Maximum number of refinement iterations to apply.
 * \param accuracygoal: `double` Accuracy to achieve in iterative refinement, defined as the module of the maximum difference between the identity matrix and the matrix product of the (approximate) inverse times the original matrix. On return, it contains the actually achieved accuracy
 * \param device_id: `int` ID of the device for matrix inversion offloading.
 */
void magma_zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int refiters, int device_id);
void magma_zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int maxrefiters, double &accuracygoal, int device_id);

#endif
+4 −3
Original line number Diff line number Diff line
@@ -51,14 +51,15 @@ void invert_matrix(dcomplex **mat, np_int size, int &ier, np_int max_size, int t
#ifdef USE_MAGMA
#ifdef USE_REFINEMENT
  // try using the iterative refinement to obtain a more accurate solution
  const int refiters = 3;
  magma_zinvert_and_refine(mat, size, ier, refiters, target_device);
  const int maxrefiters = 6;
  double accuracygoal = 1e-6;
  magma_zinvert_and_refine(mat, size, ier, maxrefiters, accuracygoal, target_device);
#elif
  magma_zinvert(mat, size, ier, target_device);
#endif  
#elif defined USE_LAPACK
#ifdef USE_REFINEMENT
  zinvert_and_refine(mat, size, ier, refiters);
  zinvert_and_refine(mat, size, ier, maxrefiters, accuracygoal);
#elif
  zinvert(mat, size, ier);
#endif
+25 −15
Original line number Diff line number Diff line
@@ -61,7 +61,7 @@ void zinvert(dcomplex **mat, np_int n, int &jer) {
  delete[] IPIV;
}

void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int refiters) {
void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int maxrefiters, double &accuracygoal) {
#ifdef USE_MKL
  extern void zcopy_(np_int *n, MKL_Complex16 *arr1, np_int *inc1, MKL_Complex16 *arr2,
		     np_int *inc2);
@@ -93,7 +93,7 @@ void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int refiters) {
#ifdef USE_MKL
  MKL_Complex16 *arr_orig = new MKL_Complex16[nn];
  MKL_Complex16 *arr_refine = new MKL_Complex16[nn];
  MKL_Complex16 *arr_unref = new MKL_Complex16[nn];
  // MKL_Complex16 *arr_unref = new MKL_Complex16[nn];
  MKL_Complex16 *id = new MKL_Complex16[n];
  for (np_int i=0; i<n ; i++) {
    id[i].real =  1;
@@ -102,7 +102,7 @@ void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int refiters) {
#else
  dcomplex *arr_orig = new dcomplex[nn];
  dcomplex *arr_refine = new dcomplex[nn];
  dcomplex *arr_unref = new dcomplex[nn];
  // dcomplex *arr_unref = new dcomplex[nn];
  dcomplex *id = new dcomplex[n];
  for (np_int i=0; i<n ; i++) id[i] = (dcomplex) 1;
#endif
@@ -113,11 +113,11 @@ void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int refiters) {
  
  LAPACKE_zgetrf(LAPACK_ROW_MAJOR, n, n, arr, n, IPIV);
  LAPACKE_zgetri(LAPACK_ROW_MAJOR, n, arr, n, IPIV);
  zcopy_(&nn, arr, &incx, arr_unref, &incx);
  // zcopy_(&nn, arr, &incx, arr_unref, &incx);
  delete[] IPIV;

  bool iteraterefine = true;
  double oldmax = std::numeric_limits<double>::max();
  // double oldmax = std::numeric_limits<double>::max();
  char transa = 'N';
#ifdef USE_MKL
  MKL_Complex16 dczero;
@@ -134,26 +134,36 @@ void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int refiters) {
  dcomplex dcone = 1;
  dcomplex dcmone = -1;
#endif
  zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr_orig, &n, arr, &n, &dczero, arr_refine, &n);
  // multiply minus the original matrix times the inverse matrix
  // NOTE: factors in zgemm are swapped because zgemm is designed for column-major
  // Fortran-style arrays, whereas our arrays are C-style row-major.
  zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr, &n, arr_orig, &n, &dczero, arr_refine, &n);
  np_int incy = n+1;
  zaxpy_(&n, &dcone, id, &incx, arr_refine, &incy);
  for (int iter=0; (iter<refiters) && iteraterefine; iter++) {
    zgemm_(&transa, &transa, &n, &n, &n, &dcone, arr, &n, arr_refine, &n, &dcone, arr, &n);
    zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr_orig, &n, arr, &n, &dczero, arr_refine, &n);
    zaxpy_(&n, &dcone, id, &incx, arr_refine, &incy);
  np_int maxindex = izamax_(&n, arr_refine, &incx);
#ifdef USE_MKL
    dcomplex newzmax = arr_refine[maxindex].real + I*arr_refine[maxindex].imag;
  double oldmax = cabs(arr_refine[maxindex].real + I*arr_refine[maxindex].imag);
#elif
  double oldmax = cabs(arr_refine[maxindex];);
#endif
  if (oldmax < accuracygoal) iteraterefine = false;
  for (int iter=0; (iter<maxrefiters) && iteraterefine; iter++) {
    zgemm_(&transa, &transa, &n, &n, &n, &dcone, arr_refine, &n, arr, &n, &dcone, arr, &n);
    zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr, &n, arr_orig, &n, &dczero, arr_refine, &n);
    zaxpy_(&n, &dcone, id, &incx, arr_refine, &incy);
    maxindex = izamax_(&n, arr_refine, &incx);
#ifdef USE_MKL
    double newmax = cabs(arr_refine[maxindex].real + I*arr_refine[maxindex].imag);
#elif
    dcomplex newzmax = arr_refine[maxindex];
    double newmax = cabs(arr_refine[maxindex]);
#endif
    double newmax = cabs(newzmax);
    if (newmax < oldmax) oldmax = newmax; else iteraterefine = false;
    if ((newmax > oldmax)||(newmax < accuracygoal)) iteraterefine = false;
    oldmax = newmax; 
  }
  delete[] id;
  delete[] arr_refine;
  delete[] arr_orig;
  delete[] arr_unref;
  // delete[] arr_unref;
  
}

+35 −24
Original line number Diff line number Diff line
@@ -62,7 +62,7 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
  jer = (int)err;
}

void magma_zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int refiters, int device_id) {
void magma_zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int maxrefiters, double &accuracygoal, int device_id) {
  // magma_int_t result = magma_init();
  magma_int_t err = MAGMA_SUCCESS;
  magma_queue_t queue = NULL;
@@ -79,7 +79,7 @@ void magma_zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int refiters,
  magmaDoubleComplex *d_a_refine; // pointer to residual array on device
  ldwork = m * magma_get_zgetri_nb(m); // optimal block size
  // allocate matrices
  magmaDoubleComplex *a_unref = new magmaDoubleComplex[mm]; 
  // magmaDoubleComplex *a_unref = new magmaDoubleComplex[mm]; 
  err = magma_zmalloc(&d_a, mm); // device memory for a, will contain the inverse after call to zgetri
  err = magma_zmalloc(&d_a_orig, mm); // device memory for copy of a
  err = magma_zmalloc(&dwork, ldwork); // dev. mem. for ldwork
@@ -90,21 +90,23 @@ void magma_zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int refiters,
  magma_zgetrf_gpu(m, m, d_a, m, piv, &info);
  // do the in-place inversion, after which d_a contains the (first approx) inverse
  magma_zgetri_gpu(m, d_a, m, piv, dwork, ldwork, &info);
  magma_zgetmatrix(m, m, d_a , m, a_unref, m, queue); // copy unrefined d_a -> a_unref
  // magma_zgetmatrix(m, m, d_a , m, a_unref, m, queue); // copy unrefined d_a -> a_unref
  magma_free(dwork); // free dwork, it was only needed by zgetri
  // allocate memory for the temporary matrix product
  err = magma_zmalloc(&d_a_refine, mm); // device memory for iterative correction of inverse of a
  // allocate memory for the identity vector on the host
  magmaDoubleComplex *d_id;
  {
    dcomplex *native_id = new dcomplex[m];
    for (magma_int_t i=0; i<m; i++) native_id[i] = 1;
    magmaDoubleComplex *id = (magmaDoubleComplex *) &(native_id[0]);
    // fill it with 1
  magmaDoubleComplex *d_id;
    err = magma_zmalloc(&d_id, m);
    magma_zsetvector(m, id, 1, d_id, 1, queue); // copy identity to device vector
    delete[] native_id; // free identity vector on host
  }
  bool iteraterefine = true;
  double oldmax = std::numeric_limits<double>::max();
  // double oldmax = std::numeric_limits<double>::max();
  magmaDoubleComplex magma_mone;
  magma_mone.x = -1;
  magma_mone.y = 0;
@@ -115,37 +117,46 @@ void magma_zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int refiters,
  magma_zero.x = 0;
  magma_zero.y = 0;
  // multiply minus the original matrix times the inverse matrix
  magma_zgemm(MagmaNoTrans, MagmaNoTrans, m, m, m,  magma_mone, d_a_orig, m, d_a, m, magma_zero, d_a_refine, m, queue);
  // NOTE: factors in zgemm are swapped because zgemm is designed for column-major
  // Fortran-style arrays, whereas our arrays are C-style row-major.
  magma_zgemm(MagmaNoTrans, MagmaNoTrans, m, m, m,  magma_mone, d_a, m, d_a_orig, m, magma_zero, d_a_refine, m, queue);
  // add the identity to the product
  magma_zaxpy (m, magma_one, d_id, 1, d_a_refine, m+1, queue);
  // begin correction loop (should iterate refiters times)
  for (int iter=0; (iter<refiters) && iteraterefine; iter++) {
  // find the maximum absolute value of the residual
  magma_int_t maxindex = magma_izamax(mm, d_a_refine, 1, queue);
  magmaDoubleComplex magmamax;
  // transfer the maximum value to the host
  magma_zgetvector(1, d_a_refine+maxindex, 1, &magmamax, 1, queue);
  // take the module
  double oldmax = cabs(magmamax.x + I*magmamax.y);
  if (oldmax < accuracygoal) iteraterefine = false;
  // begin correction loop (should iterate maxrefiters times)
  for (int iter=0; (iter<maxrefiters) && iteraterefine; iter++) {
    // multiply the inverse times the residual, add to the initial inverse
    magma_zgemm(MagmaNoTrans, MagmaNoTrans, m, m, m, magma_one, d_a, m, d_a_refine, m, magma_one, d_a, m, queue);
    magma_zgemm(MagmaNoTrans, MagmaNoTrans, m, m, m, magma_one, d_a_refine, m, d_a, m, magma_one, d_a, m, queue);
    // multiply minus the original matrix times the new inverse matrix
    magma_zgemm(MagmaNoTrans, MagmaNoTrans, m, m, m, magma_mone, d_a_orig, m, d_a, m, magma_zero, d_a_refine, m, queue);
    magma_zgemm(MagmaNoTrans, MagmaNoTrans, m, m, m, magma_mone, d_a, m, d_a_orig, m, magma_zero, d_a_refine, m, queue);
    // add the identity to the product
    magma_zaxpy (m, magma_one, d_id, 1, d_a_refine, m+1, queue);
    // find the maximum absolute value of the residual
    magma_int_t maxindex = magma_izamax(mm, d_a_refine, 1, queue);
    magmaDoubleComplex magmamax;
    maxindex = magma_izamax(mm, d_a_refine, 1, queue);
    // transfer the maximum value to the host
    magma_zgetvector(1, d_a_refine+maxindex, 1, &magmamax, 1, queue);
    dcomplex newzmax = magmamax.x + I*magmamax.y;
    // take the module
    double newmax = cabs(newzmax);
    double newmax = cabs(magmamax.x + I*magmamax.y);
    // if the maximum in the residual decreased from the previous iteration,
    // update oldmax and go on, otherwise no point further iterating refinements
    if (newmax < oldmax) oldmax = newmax; else iteraterefine = false;
    if ((newmax > oldmax)||(newmax < accuracygoal)) iteraterefine = false;
    oldmax = newmax;
  }
  accuracygoal = oldmax;
  // end correction loop
  // free temporary device arrays 
  magma_free(d_id);
  magma_free(d_a_refine);
  magma_zgetmatrix(m, m, d_a , m, a, m, queue); // copy final refined d_a -> a
  // I should probably do some meaningful check / comparison between a and a_unref
  delete[] piv; // free host memory
  delete[] a_unref;
  // delete[] a_unref;
  magma_free(d_id);
  magma_free(d_a); // free device memory
  magma_free(d_a_orig); // free device memory
  magma_free(d_a_refine); // free device memory