Commit 113b526c authored by Andrea Giannetti's avatar Andrea Giannetti
Browse files

Updated DB constraints; inserted inspection into pipeline.

parent 693dda57
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ from assets.commons import (load_config_file,
from stg.stg_radmc_input_generator import main as stg_main
from mdl.mdl_execute_radmc_command import main as execute_radmc_script
from prs.prs_compute_integrated_fluxes_and_ratios import main as prs_main
from prs.prs_inspect_results import main as prs_inspection_main


def build_model_grid():
@@ -32,7 +33,7 @@ def build_model_grid():
        ratio_fits = []
        for line in lines:
            mdl_overrides = {
                'lines': {},
                'grid_lines': overrides,
                'model': {
                    'radmc_observation': {
                        'iline': line
@@ -42,6 +43,7 @@ def build_model_grid():
            cube_fits.append(execute_radmc_script(grid_tarfile=grid_tarfiles[-1],
                                                  override_config=mdl_overrides))
        ratio_fits.append(prs_main(cube_fits_list=cube_fits))
    prs_inspection_main()


if __name__ == '__main__':
+3 −2
Original line number Diff line number Diff line
@@ -20,6 +20,8 @@ from assets.commons import (load_config_file,
from assets.constants import (radmc_options_mapping,
                              radmc_lines_mode_mapping)

logger = setup_logger(name='MDL')


def write_radmc_main_input_file(config_mdl: dict,
                                config_lines: dict,
@@ -88,10 +90,9 @@ def main(grid_tarfile: str,
    # This is necessary, because the lines_mode is needed both in the lines.inp and radmc3d.inp files
    # The reason for splitting the main input file from the rest is that some parameters can be changed
    # independently of the grid for the modeling. The mdl hash should depend on all the mdl parameters, not a subset
    logger = setup_logger(name='MDL')
    executed_on = datetime.now()
    config_stg = load_config_file(os.path.join('stg', 'config', 'config.yml'),
                                  override_config=override_config['lines'])
                                  override_config=override_config['grid_lines'])
    config_lines = config_stg['lines']
    config_mdl = load_config_file(os.path.join('mdl', 'config', 'config.yml'),
                                  override_config=override_config['model'])
+2 −1
Original line number Diff line number Diff line
@@ -14,6 +14,8 @@ from assets.constants import aggregation_function_mapping
from stg.stg_build_db_structure import (MomentZeroMaps,
                                        RatioMaps)

logger = setup_logger(name='PRS')


def populate_mom_zero_table(config_prs: dict,
                            fits_cube_name: str,
@@ -177,7 +179,6 @@ def main(cube_fits_list: List[str],
    _mom0_out_cube2 = validate_parameter(mom0_out_cube2, default=cube_fits_list[1].replace('.fits', '_mom0.fits'))
    config_prs = load_config_file(os.path.join('prs', 'config', 'config.yml'))['flux_computation']
    executed_on = datetime.now()
    logger = setup_logger(name='PRS')
    engine = get_pg_engine(logger=logger)
    config_prs.update({
        'cube_fits_list': cube_fits_list,
+68 −34
Original line number Diff line number Diff line
@@ -5,24 +5,60 @@ import xarray as xr
from typing import Union
from itertools import product
from astropy.io import fits
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from stg.stg_build_db_structure import (GridPars,
                                        ModelPars,
                                        RatioMaps,
                                        MomentZeroMaps)
from assets.commons import (load_config_file,
                            parse_grid_overrides,
                            setup_logger)
                            setup_logger,
                            get_pg_engine)

logger = setup_logger(name='PRS - INSPECT')


# TODO: connect to DB to retrieve filename, given parameters or recompute it if easier
def get_fitsfile_name(dust_temperature: float,
                      gas_density: float,
                      lines: Union[list, tuple]) -> str:
    return f'ratio_td_{str(int(dust_temperature))}_nh2_{str(round(gas_density, 1))}_lines_{"-".join(lines)}.fits'


def aggregate_image(data:np.array):
    return np.nanmean(data)
def get_fitsfile_name_from_db(
        dust_temperature: float,
        gas_density: float,
        lines: Union[list, tuple],
        session: Session) -> str:
    results = session.query(GridPars, RatioMaps).join(ModelPars).join(MomentZeroMaps).join(RatioMaps,
                                                                                           or_(RatioMaps.mom_zero_map_1,
                                                                                               RatioMaps.mom_zero_map_2)).filter(
        and_(GridPars.dust_temperature == dust_temperature,
             GridPars.central_density == gas_density, or_(ModelPars.iline.in_(lines)))).all()
    assert len(results) == 1
    return results[0][1].ratio_map_name


logger = setup_logger(name='PRS - INSPECT')
def get_aggregated_ratio_from_db(
        dust_temperature: float,
        gas_density: float,
        lines: Union[list, tuple],
        session: Session) -> float:
    results = session.query(GridPars, RatioMaps).join(ModelPars).join(MomentZeroMaps).join(RatioMaps,
                                                                                           or_(RatioMaps.mom_zero_map_1,
                                                                                               RatioMaps.mom_zero_map_2)).filter(
        and_(GridPars.dust_temperature == dust_temperature,
             GridPars.central_density == gas_density, or_(ModelPars.iline.in_(lines)))).all()
    assert len(results) == 1
    return results[0][1].aggregated_ratio


def main():
    engine = get_pg_engine(logger=logger)
    config = load_config_file(os.path.join('config', 'config.yml'))

    with Session(engine) as session:
        # grid definition
        dust_temperatures = parse_grid_overrides(par_name='dust_temperature',
                                                 config=config)
@@ -36,18 +72,16 @@ results = xr.DataArray(np.empty(shape=[len(dust_temperatures), len(central_densi
                                   'dust_temperature': dust_temperatures,
                                   'gas_density': central_densities
                               })
for (tdust, nH2) in product(dust_temperatures, central_densities):
    overrides = {
        'grid': {
            'dust_temperature': tdust,
            'central_density': nH2,
        }
    }
    filename = get_fitsfile_name(dust_temperature=tdust, gas_density=nH2, lines=lines)
    logger.debug(get_fitsfile_name(dust_temperature=tdust, gas_density=nH2, lines=lines))
    hdu = fits.open(os.path.join('prs', 'fits', 'ratios', filename))
    aggregated_ratio = aggregate_image(hdu[0].data)
    results.loc[tdust, nH2] = aggregated_ratio
    logger.debug(aggregated_ratio)
        for (tdust, nh2) in product(dust_temperatures, central_densities):
            aggregated_ratio = get_aggregated_ratio_from_db(dust_temperature=tdust,
                                                 gas_density=nh2,
                                                 lines=lines,
                                                 session=session)
            results.loc[tdust, nh2] = aggregated_ratio
            logger.debug(f'The aggregated ratio for lines {lines}, using {nh2}, {tdust} is: {aggregated_ratio}')
    results.plot(x='dust_temperature', y='gas_density', yscale='log')
plt.savefig(f'ratio_grid_lines_{"-".join(lines)}.png')
    plt.savefig(os.path.join('prs', 'output', f'ratio_grid_lines_{"-".join(lines)}.png'))


if __name__ == '__main__':
    main()
+2 −11
Original line number Diff line number Diff line
@@ -87,7 +87,6 @@ class ModelPars(Base):
class MomentZeroMaps(Base):
    __tablename__ = "moment_zero_maps"
    mom_zero_name = Column(String(150), primary_key=True)
    ratio_map_name = relationship("MomentZeroMaps", cascade="all, delete-orphan")
    fits_cube_name = Column(String(150), ForeignKey('model_parameters.fits_cube_name'), nullable=False)
    integration_limit_low = Column(Float)
    integration_limit_high = Column(Float)
@@ -101,19 +100,11 @@ class RatioMaps(Base):
    ratio_map_name = Column(String(150), primary_key=True)
    mom_zero_name_1 = Column(String(150), ForeignKey('moment_zero_maps.mom_zero_name'), nullable=False)
    mom_zero_name_2 = Column(String(150), ForeignKey('moment_zero_maps.mom_zero_name'), nullable=False)
    mom_zero_map_1 = relationship("MomentZeroMaps", foreign_keys=mom_zero_name_1)
    mom_zero_map_2 = relationship("MomentZeroMaps", foreign_keys=mom_zero_name_2)
    aggregated_ratio = Column(Float)
    aggregation_function = Column(String(20))
    created_on = Column(DateTime)


# class Address(Base):
#     __tablename__ = "address"
#     id = Column(Integer, primary_key=True)
#     email_address = Column(String, nullable=False)
#     user_id = Column(Integer, ForeignKey("user_account.id"), nullable=False)
#     user = relationship("User", back_populates="addresses")
#     def __repr__(self):
#         return f"Address(id={self.id!r}, email_address={self.email_address!r})"


Base.metadata.create_all(engine)
Loading