Commit d679bb90 authored by Andrea Giannetti's avatar Andrea Giannetti
Browse files

Fixed grid filename duplication and potential concurrency.

parent f8b09919
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
import glob
import os
import uuid
import stg.stg_build_db_structure
from stg.stg_build_db_structure import init_db
from itertools import product, chain
from multiprocessing import Pool
from assets.commons import (load_config_file,
@@ -43,6 +43,7 @@ def compute_grid(tdust, nh2, line, density_keyword, dust_temperature_keyword):

def build_model_grid(run_id: str,
                     cleanup_scratches: bool = True):
    init_db()
    stg_config = load_config_file(os.path.join('stg', 'config', 'config.yml'))
    pl_density_index = float(stg_config['grid']['density_powerlaw_idx'])
    pl_dust_temperature_idx = float(stg_config['grid']['dust_temperature_powerlaw_idx'])
+14 −10
Original line number Diff line number Diff line
@@ -19,8 +19,6 @@ from sqlalchemy.exc import OperationalError
logger = setup_logger(name='DB_SETUP')
Base = declarative_base()

engine = get_pg_engine(logger=logger)


class GridFiles(Base):
    __tablename__ = "grid_files"
@@ -191,6 +189,8 @@ class RatioMaps(Base):
    mom_zero_map_2 = relationship("MomentZeroMaps", foreign_keys=[mom_zero_name_2, run_id])


def init_db():
    engine = get_pg_engine(logger=logger)
    try:
        Base.metadata.create_all(bind=engine)
    except OperationalError:
@@ -199,3 +199,7 @@ except OperationalError:
        Base.metadata.create_all(bind=engine)
    logger.info('Connection successful! DB initialized as needed.')
    engine.dispose()


if __name__ == '__main__':
    init_db()
+27 −9
Original line number Diff line number Diff line
@@ -272,12 +272,14 @@ def write_molecular_number_density_profiles(profiles: dict,
                          grid_metadata=grid_metadata)


def save_fits_grid_profile(quantity,
                           grid_metadata,
                           filename,
                           path=None):
def save_fits_grid_profile(quantity: np.array,
                           filename: str,
                           path: str = None):
    _path = validate_parameter(path, default=os.path.join('prs', 'fits', 'grids'))
    if not os.path.isfile(os.path.join(_path, filename)):
        fits.writeto(os.path.join(_path, filename), quantity)
    else:
        logger.info('Skipping saving of fits grid. File already present!')


def convert_dimensional_unit(value: Union[float, None, str],
@@ -435,8 +437,8 @@ def write_stellar_input_file(stars_metadata: dict,
                             grid_metadata: dict,
                             path: str,
                             wavelengths_micron: np.array):
    star_properties = [' '.join([str(rstar), str(mstar), str(pos[0]), str(pos[1]), str(pos[2])]) for rstar, mstar, pos in
                       zip(stars_metadata['rstars'], stars_metadata['mstars'], stars_metadata['star_positions'])]
    star_properties = [' '.join([str(rstar), str(mstar), str(pos[0]), str(pos[1]), str(pos[2])]) for rstar, mstar, pos
                       in zip(stars_metadata['rstars'], stars_metadata['mstars'], stars_metadata['star_positions'])]
    override_defaults = {
        'iformat': 2,
        'nstars': stars_metadata['nstars'],
@@ -452,6 +454,21 @@ def write_stellar_input_file(stars_metadata: dict,
                      flatten_style='C')


def get_grid_name(method: Union[str, None] = None,
                  zip_filename: Union[str, None] = None,
                  quantity_name: Union[str, None] = None):
    _method = validate_parameter(method, default='uuid')
    allowed_methods = ('uuid', 'composite_grid')
    if method == 'uuid':
        return f'{str(uuid.uuid4())}.fits'
    elif method == 'composite_grid':
        assert ((zip_filename is not None) and (quantity_name is not None))
        return f'{".".join(zip_filename.split(".")[0:-1])}_{quantity_name}.fits'
    else:
        raise NotImplementedError(
            f'The chosen method is not available. Allowed options are: {" ".join(allowed_methods)}')


def main(run_id: str,
         override_config: Union[dict, None] = None,
         path_radmc_files: Union[str, None] = None,
@@ -516,9 +533,10 @@ def main(run_id: str,
                                       grid_metadata=grid_metadata,
                                       run_id=run_id)

    grid_file_name = f'{str(uuid.uuid4())}.fits'
    grid_file_name = get_grid_name(method='composite_grid',
                                   zip_filename=zip_filename,
                                   quantity_name='gas_number_density')
    save_fits_grid_profile(quantity=profiles['gas_number_density'],
                           grid_metadata=grid_metadata,
                           filename=grid_file_name)

    populate_grid_files(quantity='gas_number_density',
+20 −1
Original line number Diff line number Diff line
import os
import numpy as np
from stg.stg_radmc_input_generator import get_solid_body_rotation_y
from stg.stg_radmc_input_generator import (get_solid_body_rotation_y,
                                           get_grid_name)
from astropy import units as u
from unittest import TestCase
from assets.commons import (load_config_file,
@@ -19,6 +20,24 @@ class Test(TestCase):
    def setUp(self):
        self.config_filename = 'config.yml'

    def test_get_grid_filename_composite(self):
        grid_name = get_grid_name(method='composite_grid',
                                  zip_filename='abc.def.zip',
                                  quantity_name='h2_density')
        self.assertEqual(grid_name, 'abc.def_h2_density.fits')

    def test_get_grid_filename_missing_info(self):
        with self.assertRaises(AssertionError):
            grid_name = get_grid_name(method='composite_grid',
                                      quantity_name='h2_density')
        with self.assertRaises(AssertionError):
            grid_name = get_grid_name(method='composite_grid',
                                      zip_filename='abc.def.zip')

    def test_get_grid_filename_undefined_method(self):
        with self.assertRaises(NotImplementedError):
            get_grid_name(method='puzzidilontano')

    def test_get_solid_body_rotation_y(self):
        _config_filename = os.path.join('test_files', self.config_filename)
        grid_size = 3