Commit b80dc77f authored by Mulas, Giacomo
Browse files

fix compilation of blas-based iterative refinement

parent c684f060
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -54,13 +54,13 @@ void invert_matrix(dcomplex **mat, np_int size, int &ier, int &maxrefiters, doub
  // accuracygoal is an in/out argument: magma_zinvert_and_refine() receives the
  // requested accuracy and hands back the accuracy actually reached
  magma_zinvert_and_refine(mat, size, ier, maxrefiters, accuracygoal, refinemode, target_device);
#elif
#else
  magma_zinvert(mat, size, ier, target_device);
#endif  
#elif defined USE_LAPACK
#ifdef USE_REFINEMENT
  zinvert_and_refine(mat, size, ier, maxrefiters, accuracygoal, refinemode);
#elif
#else
  zinvert(mat, size, ier);
#endif
#else
+43 −38
Original line number Diff line number Diff line
@@ -39,6 +39,27 @@

#include <limits>

#ifdef USE_MKL
  extern "C" void zcopy_(np_int *n, MKL_Complex16 *arr1, np_int *inc1, MKL_Complex16 *arr2,
		     np_int *inc2);
  extern "C" void zgemm_(char *transa, char *transb, np_int *l, np_int *m, np_int *n,
		     MKL_Complex16 *alpha, MKL_Complex16 *a, np_int *lda,
		     MKL_Complex16 *b, np_int *ldb, MKL_Complex16 *beta,
		     MKL_Complex16 *c, np_int *ldc);
  extern "C" void zaxpy_(np_int *n, MKL_Complex16 *alpha, MKL_Complex16 *arr1, np_int *inc1,
		     MKL_Complex16 *arr2, np_int *inc2);
  extern "C" np_int izamax_(np_int *n, MKL_Complex16 *arr1, np_int *inc1);
#else
  extern "C" void zcopy_(np_int *n, dcomplex *arr1, np_int *inc1, dcomplex *arr2, np_int *inc2);
  extern "C" void zgemm_(char *transa, char *transb, np_int *l, np_int *m, np_int *n,
		     dcomplex *alpha, dcomplex *a, np_int *lda,
		     dcomplex *b, np_int *ldb, dcomplex *beta,
		     dcomplex *c, np_int *ldc);
  extern "C" void zaxpy_(np_int *n, dcomplex *alpha, dcomplex *arr1, np_int *inc1,
		     dcomplex *arr2, np_int *inc2);
  extern "C" np_int izamax_(np_int *n, dcomplex *arr1, np_int *inc1);
#endif

void zinvert(dcomplex **mat, np_int n, int &jer) {
  jer = 0;
  dcomplex *arr = &(mat[0][0]);
@@ -61,27 +82,8 @@ void zinvert(dcomplex **mat, np_int n, int &jer) {
  delete[] IPIV;
}

void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int maxrefiters, double &accuracygoal) {
#ifdef USE_MKL
  extern void zcopy_(np_int *n, MKL_Complex16 *arr1, np_int *inc1, MKL_Complex16 *arr2,
		     np_int *inc2);
  extern void zgemm_(char *transa, char *transb, np_int *l, np_int *m, np_int *n,
		     MKL_Complex16 *alpha, MKL_Complex16 *a, np_int *lda,
		     MKL_Complex16 *b, np_int *ldb, MKL_Complex16 *beta,
		     MKL_Complex16 *c, np_int *ldc);
  extern void zaxpy_(np_int *n, MKL_Complex16 *alpha, MKL_Complex16 *arr1, np_int *inc1,
		     MKL_Complex16 *arr2, np_int *inc2);
  extern np_int izamax_(np_int *n, MKL_Complex16 *arr1, np_int *inc1);
#else
  extern void zcopy_(np_int *n, dcomplex *arr1, np_int *inc1, dcomplex *arr2, np_int *inc2);
  extern void zgemm_(char *transa, char *transb, np_int *l, np_int *m, np_int *n,
		     dcomplex *alpha, dcomplex *a, np_int *lda,
		     dcomplex *b, np_int *ldb, dcomplex *beta,
		     dcomplex *c, np_int *ldc);
  extern void zaxpy_(np_int *n, dcomplex *alpha, dcomplex *arr1, np_int *inc1,
		     dcomplex *arr2, np_int *inc2);
  extern np_int izamax_(np_int *n, dcomplex *arr1, np_int *inc1);
#endif
void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int &maxrefiters, double &accuracygoal, int refinemode) {

  jer = 0;
#ifdef USE_MKL
  MKL_Complex16 *arr = (MKL_Complex16 *) &(mat[0][0]);
@@ -93,7 +95,6 @@ void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int maxrefiters, dou
#ifdef USE_MKL
  MKL_Complex16 *arr_orig = new MKL_Complex16[nn];
  MKL_Complex16 *arr_refine = new MKL_Complex16[nn];
  // MKL_Complex16 *arr_unref = new MKL_Complex16[nn];
  MKL_Complex16 *id = new MKL_Complex16[n];
  for (np_int i=0; i<n ; i++) {
    id[i].real =  1;
@@ -102,7 +103,6 @@ void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int maxrefiters, dou
#else
  dcomplex *arr_orig = new dcomplex[nn];
  dcomplex *arr_refine = new dcomplex[nn];
  // dcomplex *arr_unref = new dcomplex[nn];
  dcomplex *id = new dcomplex[n];
  for (np_int i=0; i<n ; i++) id[i] = (dcomplex) 1;
#endif
@@ -113,11 +113,9 @@ void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int maxrefiters, dou
  
  LAPACKE_zgetrf(LAPACK_ROW_MAJOR, n, n, arr, n, IPIV);
  LAPACKE_zgetri(LAPACK_ROW_MAJOR, n, arr, n, IPIV);
  // zcopy_(&nn, arr, &incx, arr_unref, &incx);
  delete[] IPIV;

  bool iteraterefine = true;
  // double oldmax = std::numeric_limits<double>::max();
  char transa = 'N';
#ifdef USE_MKL
  MKL_Complex16 dczero;
@@ -140,30 +138,37 @@ void zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int maxrefiters, dou
  zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr, &n, arr_orig, &n, &dczero, arr_refine, &n);
  np_int incy = n+1;
  zaxpy_(&n, &dcone, id, &incx, arr_refine, &incy);
  double oldmax = 0;
  if (refinemode >0) {
    np_int maxindex = izamax_(&n, arr_refine, &incx);
#ifdef USE_MKL
  double oldmax = cabs(arr_refine[maxindex].real + I*arr_refine[maxindex].imag);
#elif
  double oldmax = cabs(arr_refine[maxindex];);
    oldmax = cabs(arr_refine[maxindex].real + I*arr_refine[maxindex].imag);
#else
    oldmax = cabs(arr_refine[maxindex]);
#endif
    if (oldmax < accuracygoal) iteraterefine = false;
  for (int iter=0; (iter<maxrefiters) && iteraterefine; iter++) {
  }
  int iter;
  for (iter=0; (iter<maxrefiters) && iteraterefine; iter++) {
    zgemm_(&transa, &transa, &n, &n, &n, &dcone, arr_refine, &n, arr, &n, &dcone, arr, &n);
    zgemm_(&transa, &transa, &n, &n, &n, &dcmone, arr, &n, arr_orig, &n, &dczero, arr_refine, &n);
    zaxpy_(&n, &dcone, id, &incx, arr_refine, &incy);
    maxindex = izamax_(&n, arr_refine, &incx);
    if ((refinemode==2) || ((refinemode==1) && (iter == (maxrefiters-1)))) {
      np_int maxindex = izamax_(&n, arr_refine, &incx);
#ifdef USE_MKL
      double newmax = cabs(arr_refine[maxindex].real + I*arr_refine[maxindex].imag);
#elif
#else
      double newmax = cabs(arr_refine[maxindex]);
#endif
    if ((newmax > oldmax)||(newmax < accuracygoal)) iteraterefine = false;
      if ((refinemode==2) && ((newmax > oldmax)||(newmax < accuracygoal))) iteraterefine = false;
      oldmax = newmax; 
    }
  }
  if (refinemode==2) maxrefiters = iter;
  accuracygoal = oldmax;
  delete[] id;
  delete[] arr_refine;
  delete[] arr_orig;
  // delete[] arr_unref;
  
}

+0 −1
Original line number Diff line number Diff line
@@ -106,7 +106,6 @@ void magma_zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int &maxrefite
    delete[] native_id; // free identity vector on host
  }
  bool iteraterefine = true;
  // double oldmax = std::numeric_limits<double>::max();
  magmaDoubleComplex magma_mone;
  magma_mone.x = -1;
  magma_mone.y = 0;