Commit c0e4a780 authored by Giovanni La Mura's avatar Giovanni La Mura
Browse files

Refactor Swap1 interface and revert to serial cfrfme implementation

parent b8184b93
Loading
Loading
Loading
Loading
+10 −10
Original line number Diff line number Diff line
@@ -27,14 +27,14 @@
class Swap1 {
protected:
  //! Index of the last element to be filled.
  int last_index;
  int _last_index;
  //! Number of vector coordinates. QUESTION: correct?
  int nkv;
  int _nkv;
  //! NLMMT = 2 * LM * (LM + 2)
  int nlmmt;
  int _nlmmt;

  //! QUESTION: definition?
  dcomplex *wk;
  dcomplex *_wk;

  /*! \brief Load a Swap1 instance from a HDF5 binary file.
   *
@@ -64,24 +64,24 @@ protected:

public:
  //! \brief Read only view on WK.
  const dcomplex *vec_wk;
  const dcomplex *wk;
  
  /*! \brief Swap1 instance constructor.
   *
   * \param lm: `int` Maximum field expansion order.
   * \param _nkv: `int` Number of vector coordinates. QUESTION: correct?
   */
  Swap1(int lm, int _nkv);
  Swap1(int lm, int nkv);

  /*! \brief Swap1 instance destroyer.
   */
  ~Swap1() { delete[] wk; }
  ~Swap1() { delete[] _wk; }

  /*! \brief Append an element at the end of the vector.
   *
   * \param value: `complex double` The value to be added to the vector.
   */
  void append(dcomplex value) { wk[last_index++] = value; }
  void append(dcomplex value) { _wk[_last_index++] = value; }
  
  /*! \brief Load a Swap1 instance from binary file.
   *
@@ -98,11 +98,11 @@ public:
   * \param _nkv: `int` Number of vector coordinates. QUESTION: correct?
   * \return size: `long` The necessary memory size in bytes.
   */
  static long get_memory_requirement(int lm, int _nkv);
  static long get_memory_requirement(int lm, int nkv);
  
  /*! \brief Bring the pointer to the next element at the start of vector.
   */
  void reset() { last_index = 0; }
  void reset() { _last_index = 0; }

  /*! \brief Write a Swap1 instance to binary file.
   *
+37 −37
Original line number Diff line number Diff line
@@ -47,13 +47,13 @@
using namespace std;

// >>> START OF Swap1 CLASS IMPLEMENTATION <<<
Swap1::Swap1(int lm, int _nkv) {
  nkv = _nkv;
  nlmmt = 2 * lm * (lm + 2);
  const int size = nkv * nkv * nlmmt;
  wk = new dcomplex[size]();
  vec_wk = wk;
  last_index = 0;
Swap1::Swap1(int lm, int nkv) {
  _nkv = nkv;
  _nlmmt = 2 * lm * (lm + 2);
  const int size = nkv * nkv * _nlmmt;
  _wk = new dcomplex[size]();
  wk = _wk;
  _last_index = 0;
}

Swap1* Swap1::from_binary(const std::string& file_name, const std::string& mode) {
@@ -76,21 +76,21 @@ Swap1* Swap1::from_hdf5(const std::string& file_name) {
  herr_t status = hdf_file->get_status();
  double *elements;
  string str_type;
  int _nlmmt, _nkv, lm, num_elements, index;
  int nlmmt, nkv, lm, num_elements, index;
  dcomplex value;
  if (status == 0) {
    status = hdf_file->read("NLMMT", "INT32", &_nlmmt);
    status = hdf_file->read("NKV", "INT32", &_nkv);
    lm = (int)(sqrt(4.0 + 2.0 * _nlmmt) / 2.0) - 1;
    num_elements = 2 * _nlmmt * _nkv * _nkv;
    instance = new Swap1(lm, _nkv);
    status = hdf_file->read("NLMMT", "INT32", &nlmmt);
    status = hdf_file->read("NKV", "INT32", &nkv);
    lm = (int)(sqrt(4.0 + 2.0 * nlmmt) / 2.0) - 1;
    num_elements = 2 * nlmmt * nkv * nkv;
    instance = new Swap1(lm, nkv);
    elements = new double[num_elements]();
    str_type = "FLOAT64_(" + to_string(num_elements) + ")";
    status = hdf_file->read("WK", str_type, elements);
    for (int wi = 0; wi < num_elements / 2; wi++) {
      index = 2 * wi;
      value = elements[index] + elements[index + 1] * I;
      instance->wk[wi] = value;
      instance->_wk[wi] = value;
    } // wi loop
    delete[] elements;
    status = hdf_file->close();
@@ -102,19 +102,19 @@ Swap1* Swap1::from_hdf5(const std::string& file_name) {
Swap1* Swap1::from_legacy(const std::string& file_name) {
  fstream input;
  Swap1 *instance = NULL;
  int _nlmmt, _nkv, lm;
  int nlmmt, nkv, lm;
  double rval, ival;
  input.open(file_name.c_str(), ios::in | ios::binary);
  if (input.is_open()) {
    input.read(reinterpret_cast<char *>(&_nlmmt), sizeof(int));
    lm = (int)(sqrt(4.0 + 2.0 * _nlmmt) / 2.0) - 1;
    input.read(reinterpret_cast<char *>(&_nkv), sizeof(int));
    instance = new Swap1(lm, _nkv);
    int num_elements = _nlmmt * _nkv * _nkv;
    input.read(reinterpret_cast<char *>(&nlmmt), sizeof(int));
    lm = (int)(sqrt(4.0 + 2.0 * nlmmt) / 2.0) - 1;
    input.read(reinterpret_cast<char *>(&nkv), sizeof(int));
    instance = new Swap1(lm, nkv);
    int num_elements = nlmmt * nkv * nkv;
    for (int j = 0; j < num_elements; j++) {
      input.read(reinterpret_cast<char *>(&rval), sizeof(double));
      input.read(reinterpret_cast<char *>(&ival), sizeof(double));
      instance->wk[j] = rval + ival * I;
      instance->_wk[j] = rval + ival * I;
    }
    input.close();
  } else {
@@ -123,9 +123,9 @@ Swap1* Swap1::from_legacy(const std::string& file_name) {
  return instance;
}

long Swap1::get_memory_requirement(int lm, int _nkv) {
long Swap1::get_memory_requirement(int lm, int nkv) {
  long size = (long)(3 * sizeof(int));
  size += (long)(sizeof(dcomplex) * 2 * lm * (lm + 2) * _nkv * _nkv);
  size += (long)(sizeof(dcomplex) * 2 * lm * (lm + 2) * nkv * nkv);
  return size;
}

@@ -146,20 +146,20 @@ void Swap1::write_hdf5(const std::string& file_name) {
  List<void *> rec_ptr_list(1);
  herr_t status;
  string str_type;
  int num_elements = 2 * nlmmt * nkv * nkv;
  int num_elements = 2 * _nlmmt * _nkv * _nkv;
  rec_name_list.set(0, "NLMMT");
  rec_type_list.set(0, "INT32_(1)");
  rec_ptr_list.set(0, &nlmmt);
  rec_ptr_list.set(0, &_nlmmt);
  rec_name_list.append("NKV");
  rec_type_list.append("INT32_(1)");
  rec_ptr_list.append(&nkv);
  rec_ptr_list.append(&_nkv);
  rec_name_list.append("WK");
  str_type = "FLOAT64_(" + to_string(num_elements) + ")";
  rec_type_list.append(str_type);
  double *ptr_elements = new double[num_elements]();
  for (int wi = 0; wi < num_elements / 2; wi++) {
    ptr_elements[2 * wi] = real(wk[wi]);
    ptr_elements[2 * wi + 1] = imag(wk[wi]);
    ptr_elements[2 * wi] = real(_wk[wi]);
    ptr_elements[2 * wi + 1] = imag(_wk[wi]);
  }
  rec_ptr_list.append(ptr_elements);

@@ -185,12 +185,12 @@ void Swap1::write_legacy(const std::string& file_name) {
  double rval, ival;
  output.open(file_name.c_str(), ios::out | ios::binary);
  if (output.is_open()) {
    int num_elements = nlmmt * nkv * nkv;
    output.write(reinterpret_cast<char *>(&nlmmt), sizeof(int));
    output.write(reinterpret_cast<char *>(&nkv), sizeof(int));
    int num_elements = _nlmmt * _nkv * _nkv;
    output.write(reinterpret_cast<char *>(&_nlmmt), sizeof(int));
    output.write(reinterpret_cast<char *>(&_nkv), sizeof(int));
    for (int j = 0; j < num_elements; j++) {
      rval = real(wk[j]);
      ival = imag(wk[j]);
      rval = real(_wk[j]);
      ival = imag(_wk[j]);
      output.write(reinterpret_cast<char *>(&rval), sizeof(double));
      output.write(reinterpret_cast<char *>(&ival), sizeof(double));
    }
@@ -201,15 +201,15 @@ void Swap1::write_legacy(const std::string& file_name) {
}

bool Swap1::operator ==(Swap1 &other) {
  if (nlmmt != other.nlmmt) {
  if (_nlmmt != other._nlmmt) {
    return false;
  }
  if (nkv != other.nkv) {
  if (_nkv != other._nkv) {
    return false;
  }
  const int num_elements = nlmmt * nkv * nkv;
  const int num_elements = _nlmmt * _nkv * _nkv;
  for (int i = 0; i < num_elements; i++) {
    if (wk[i] != other.wk[i]) {
    if (_wk[i] != other._wk[i]) {
      return false;
    }
  }
+7 −8
Original line number Diff line number Diff line
@@ -374,18 +374,17 @@ void frfme(string data_file, string output_path) {
#ifdef USE_NVTX
	  nvtxRangePop();
#endif
	  dcomplex *vec_w = new dcomplex[nkv * nkv]();
	  dcomplex **w = new dcomplex*[nkv];
	  for (int wi = 0; wi < nkv; wi++) w[wi] = vec_w + wi * nkv;
#ifdef USE_NVTX
	  nvtxRangePush("j80 loop");
#endif
#pragma omp parallel for
	  for (int j80 = jlmf; j80 <= jlml; j80++) {
	    dcomplex *vec_w = new dcomplex[nkv * nkv]();
	    dcomplex **w = new dcomplex*[nkv];
	    for (int wi = 0; wi < nkv; wi++) w[wi] = vec_w + wi * nkv;
	    int wk_index = (j80 - jlmf) * nkv * nkv;
	    int wk_index = 0;
	    for (int jy50 = 0; jy50 < nkv; jy50++) {
	      for (int jx50 = 0; jx50 < nkv; jx50++) {
		for (int wi = 0; wi < nlmmt; wi++) wk[wi] = tt1->vec_wk[wk_index++];
		for (int wi = 0; wi < nlmmt; wi++) wk[wi] = tt1->wk[wk_index++];
		w[jx50][jy50] = wk[j80 - 1];
	      } // jx50
	    } // jy50 loop
@@ -420,9 +419,9 @@ void frfme(string data_file, string output_path) {
		} // ix65 loop
	      } // iy70 loop
	    } // iz75 loop
	  } // j80 loop
	  delete[] vec_w;
	  delete[] w;
	  } // j80 loop
#ifdef USE_NVTX
	  nvtxRangePop();
#endif