Commit 01c9dcb7 authored by Giovanni La Mura
Browse files

Use refinement back-up guard in cuBLAS calls

parent da039659
Loading
Loading
Loading
Loading
+71 −52
Original line number Diff line number Diff line
@@ -86,6 +86,7 @@ using namespace std;
void cublas_zinvert(dcomplex **mat, np_int n, int device_id, const RuntimeSettings& rs) {
  char buffer[128];
  string message;
  int ref_err = 0;
  cudacall(cudaSetDevice(device_id));
  cublasHandle_t handle;
  cublascall(cublasCreate_v2(&handle));
@@ -126,6 +127,8 @@ void cublas_zinvert(dcomplex **mat, np_int n, int device_id, const RuntimeSettin
      cuDoubleComplex *d_id;
      // copy the original matrix again to d_a, so I do not need to destroy the old d_a and recreate a new one
      cudacall(cudaMemcpy(d_a, a, m*m*sizeof(cuDoubleComplex),cudaMemcpyHostToDevice)); // from here on, d_a contains the original matrix, for refinement use  
      // copy the unrefined inverted matrix from device to host, so we have a back-up if refinement fails
      cudacall(cudaMemcpy(a, d_c, m*m*sizeof(cuDoubleComplex), cudaMemcpyDeviceToHost));
      cudacall(cudaMalloc<cuDoubleComplex>(&d_a_residual, m*m*sizeof(cuDoubleComplex)));
      cudacall(cudaMalloc<cuDoubleComplex>(&d_a_refine, m*m*sizeof(cuDoubleComplex)));
      // allocate memory for the temporary matrix products
@@ -136,67 +139,83 @@ void cublas_zinvert(dcomplex **mat, np_int n, int device_id, const RuntimeSettin
      cudacall(cudaMalloc<cuDoubleComplex>(&d_id, sizeof(cuDoubleComplex)));
      cudacall(cudaMemcpy(d_id, m_id, sizeof(cuDoubleComplex),cudaMemcpyHostToDevice)); // copy identity to device vector
      delete[] native_id; // free identity vector on host

      // Detect the maximum value of the inverse matrix.
      double oldmax = 2.0e+16, curmax = 1.0e+16;
      np_int maxindex;
      cuDoubleComplex cublasmax;
      cublascall(CUIZAMAX(handle, mm, d_c, 1, &maxindex));
      cudacall(cudaMemcpy(&cublasmax, d_c + maxindex - 1, sizeof(cuDoubleComplex), cudaMemcpyDeviceToHost));
      curmax = cabs( (((double *) &(cublasmax))[0]) + I * (((double *) &(cublasmax))[1]));
      sprintf(buffer, "INFO: largest matrix value has modulus %.5le.\n", curmax);
      message = buffer;
      rs.logger->log(message);

      // Iterative refinement loop
      int max_ref_iters = rs.max_ref_iters;
      for (int ri = 0; ri < max_ref_iters; ri++) {
	oldmax = curmax;
	// multiply minus the original matrix times the inverse matrix
	// NOTE: factors in zgemm are swapped because zgemm is designed for column-major
	// Fortran-style arrays, whereas our arrays are C-style row-major.
	cublascall(CUZGEMM(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, m, m, &cu_mone, d_c, m, d_a, m, &cu_zero, d_a_residual, m));
	// add the identity to the product
	cublascall(CUZAXPY(handle, m, &cu_one, d_id, 0, d_a_residual, m+1));
      double oldmax=0;
      np_int maxindex;
	// find the maximum absolute value of the residual
	cublascall(CUIZAMAX(handle, mm, d_a_residual, 1, &maxindex));
      cuDoubleComplex cublasmax;
	// transfer the maximum value to the host
	cudacall(cudaMemcpy(&cublasmax, d_a_residual + maxindex - 1, sizeof(cuDoubleComplex), cudaMemcpyDeviceToHost));
	// take the module
      oldmax = cabs( (((double *) &(cublasmax))[0]) + I*(((double *) &(cublasmax))[1]));
      //printf("Initial max residue = %g\n", oldmax);
      sprintf(buffer, "INFO: Initial max residue is %.5le.\n", oldmax);
	curmax = cabs( (((double *) &(cublasmax))[0]) + I * (((double *) &(cublasmax))[1]));
	sprintf(buffer, "DEBUG: iteration %d has residue %.5le; target residue is %.5le.\n", ri, curmax, rs.accuracy_goal);
	message = buffer;
      rs.logger->log(message);
      if (oldmax > rs.accuracy_goal) {
	for (int iter = 0; iter < rs.max_ref_iters; iter++) {
	rs.logger->log(message, LOG_DEBG);
	if (curmax < 0.5) { // Safe conditions for Newton-Schultz iteration.
	  if (curmax < 0.95 * oldmax) { // Newton-Schultz iteration is improving and can proceed.
	    // multiply the inverse times the residual, add to the initial inverse
	    cublascall(CUZGEMM(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, m, m, &cu_one, d_a_residual, m, d_c, m, &cu_zero, d_a_refine, m));
	    // add to the initial inverse
	    cublascall(CUZAXPY(handle, mm, &cu_one, d_a_refine, 1, d_c, 1));
	  // multiply minus the original matrix times the new inverse matrix
	  cublascall(CUZGEMM(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, m, m, &cu_mone, d_c, m, d_a, m, &cu_zero, d_a_residual, m));
	  // add the identity to the product
	  cublascall(CUZAXPY(handle, m, &cu_one, d_id, 0, d_a_residual, m+1));
	  // find the maximum absolute value of the residual
	  np_int maxindex;
	  cublascall(CUIZAMAX(handle, mm, d_a_residual, 1, &maxindex));
	  // transfer the maximum value to the host
	  cuDoubleComplex cublasmax;
	  cudacall(cudaMemcpy(&cublasmax, d_a_residual+maxindex-1, sizeof(cuDoubleComplex), cudaMemcpyDeviceToHost));
	  // take the module
	  double newmax = cabs( (((double *) &(cublasmax))[0]) + I*(((double *) &(cublasmax))[1]));
	  sprintf(buffer, "DEBUG: Maximum residue after %d iterations = %.5le\n", iter+1, newmax);
	    if (curmax < rs.accuracy_goal) {
	      message = "DEBUG: good news - optimal convergence achieved. Stopping.\n";
	      rs.logger->log(message, LOG_DEBG);
	      ref_err = 0;
	      break; // ri for
	    }
	  } else {
	    if (curmax > 0.1) {
	      sprintf(buffer, "INFO: iteration %d achieved limiting residue %.5le. Cannot reach goal. Reverting.\n", ri, curmax);
	      message = buffer;
	      rs.logger->log(message);
	  // if the maximum in the residual decreased from the previous iteration,
	  // update oldmax and go on, otherwise no point further iterating refinements
	  // if ((refinemode==2) && ((newmax > oldmax)||(newmax < accuracygoal))) iteraterefine = false;
	  if (newmax < rs.accuracy_goal) {
	    message = "INFO: optimal convergence achieved. Stopping.\n";
	    rs.logger->log(message);
	    break; // iter for
	  }
	  if (newmax > 0.95 * oldmax) {
	    message = "WARN: cannot improve further. Stopping.\n";
	      ref_err = 1;
	    } else {
	      sprintf(buffer, "WARN: iteration %d achieved limiting residue %.5le. Stopping.\n", ri, curmax);
	      message = buffer;
	      rs.logger->log(message, LOG_WARN);
	    break; // iter for
	    }
	  oldmax = newmax;
	    break; // ri for
	  }
	} else { // curmax > 0.5. Newton-Schultz iteration is dangerous.
	  if (curmax < oldmax) {
	    sprintf(buffer, "WARN: iteration %d has residue %.5le. Iterating is dangerous. Stopping.\n", ri, curmax);
	    message = buffer;
	    rs.logger->log(message, LOG_WARN);
	  } else {
	message = "INFO: starting accuracy is already good.\n";
	    ref_err = 1;
	    sprintf(buffer, "INFO: iteration %d has residue %.5le. Reverting to unrefined and stopping.\n", ri, curmax);
	    message = buffer;
	    rs.logger->log(message);
	  }
	  break; // ri for
	}
      } // end of ri for
      cudaFree(d_a_residual);
      cudaFree(d_a_refine);
      cudaFree(d_id);
    }
    if (ref_err == 0) { // Refinement was fine, we can copy refined matrix back to host.
      cudacall(cudaMemcpy(a, d_c, m * m * sizeof(cuDoubleComplex), cudaMemcpyDeviceToHost));
    }
    cudaFree(batch_d_a);
    cudaFree(batch_d_c);
    cudaFree(piv);