Commit f695cdf9 authored by Andrea Giannetti's avatar Andrea Giannetti
Browse files

Added checkpointing to avoid costly recomputation; parallelized main; added DB queue.

parent d6a96b24
Loading
Loading
Loading
Loading
+18 −0
Original line number Diff line number Diff line
@@ -393,3 +393,21 @@ def get_value_if_specified(parameters_dict: dict,
        return parameters_dict[key]
    except KeyError:
        return None


def parse_input_main():
    """Load the stg and top-level configuration files and derive the grid.

    Returns a 5-tuple ``(tdust_model_type, model_type, dust_temperatures,
    densities, line_pairs)`` describing the model grid to compute.
    """
    stg_config = load_config_file(os.path.join('stg', 'config', 'config.yml'))
    density_powerlaw = float(stg_config['grid']['density_powerlaw_idx'])
    tdust_powerlaw = float(stg_config['grid']['dust_temperature_powerlaw_idx'])
    # A zero power-law index means the quantity is constant with radius.
    _model_type = 'homogeneous' if density_powerlaw == 0 else 'spherical'
    _tdust_model_type = 'isothermal' if tdust_powerlaw == 0 else 'heated'

    config = load_config_file(os.path.join('config', 'config.yml'))
    # Grid definition: axis values can be overridden from the top-level config.
    dust_temperatures = parse_grid_overrides(par_name='dust_temperature', config=config)
    densities = parse_grid_overrides(par_name='gas_density', config=config)
    line_pairs = config['overrides']['lines_to_process']
    return _tdust_model_type, _model_type, dust_temperatures, densities, line_pairs

+159 −30
Original line number Diff line number Diff line
import glob
import os
import uuid
from stg.stg_build_db_structure import init_db
import sqlalchemy
import argparse
import sys
from stg.stg_build_db_structure import init_db, TmpExecutionQueue
from itertools import product, chain
from multiprocessing import Pool
from assets.commons import (load_config_file,
                            parse_grid_overrides,
                            cleanup_directory,
from typing import Union
from assets.commons import (cleanup_directory,
                            setup_logger,
                            get_pg_engine)
                            get_pg_engine,
                            upsert,
                            parse_input_main,
                            validate_parameter)
from stg.stg_radmc_input_generator import main as stg_main
from mdl.mdl_execute_radmc_command import main as execute_radmc_script
from prs.prs_compute_integrated_fluxes_and_ratios import main as prs_main
from prs.prs_inspect_results import main as prs_inspection_main



def compute_grid(tdust, nh2, line, density_keyword, dust_temperature_keyword):
    scratch_dir = os.path.join('mdl', 'scratches', str(uuid.uuid4()))
    stg_overrides = {
@@ -38,40 +43,148 @@ def compute_grid(tdust, nh2, line, density_keyword, dust_temperature_keyword):
                                          override_config=overrides,
                                          radmc_input_path=scratch_dir,
                                          run_id=run_id)
    return nh2, tdust, line, cube_fits_name
    return cube_fits_name


def build_model_grid(run_id: str,
                     cleanup_scratches: bool = True):
def initialize_queue(engine,
                     run_id,
                     run_arguments):
    """Populate the execution queue with one row per grid point.

    Uses an upsert keyed on the full parameter combination so that
    re-initializing the same run is idempotent.
    """
    # Conflict target mirrors the table's composite primary key.
    conflict_columns = [TmpExecutionQueue.run_id,
                        TmpExecutionQueue.dust_temperature,
                        TmpExecutionQueue.density,
                        TmpExecutionQueue.line,
                        TmpExecutionQueue.density_keyword,
                        TmpExecutionQueue.dust_temperature_keyword]
    for tdust, density, line, density_kw, tdust_kw in run_arguments:
        queue_row = {'run_id': run_id,
                     'dust_temperature': tdust,
                     'density': density,
                     'line': line,
                     'density_keyword': density_kw,
                     'dust_temperature_keyword': tdust_kw,
                     'done': False}
        upsert(table_object=TmpExecutionQueue,
               row_dict=queue_row,
               conflict_keys=conflict_columns,
               engine=engine)


def get_run_pars(engine: sqlalchemy.engine,
                 run_id: str):
    """Atomically claim one pending grid point of *run_id* and return its row.

    Marks the first `done = false` row as done and returns the full row via
    ``RETURNING *`` (column order follows the ``tmp_execution_queue``
    definition), or None when the queue for this run is exhausted.  The
    ``pg_try_advisory_xact_lock`` guard lets concurrent workers skip rows
    another transaction is already claiming, so each worker pops a distinct
    grid point.
    """
    # FIX: bind run_id as a parameter instead of f-string interpolation —
    # avoids SQL injection and quoting problems with arbitrary ids.
    sql_query = sqlalchemy.text("""UPDATE tmp_execution_queue 
                SET done = true 
                WHERE row_id = (SELECT row_id 
                                   FROM tmp_execution_queue
                                   WHERE done=false
                                   AND run_id=:run_id
                                   AND pg_try_advisory_xact_lock(row_id) 
                                   LIMIT 1 FOR UPDATE) 
                RETURNING *""")
    return engine.execution_options(autocommit=True).execute(
        sql_query, {'run_id': run_id}).first()


def insert_fits_name(engine: sqlalchemy.engine,
                     row_id: int,
                     fits_cube_name: str):
    """Record the FITS cube produced for the queue row *row_id*.

    :param engine: live SQLAlchemy engine to the results database
    :param row_id: primary row identifier of the claimed queue entry
    :param fits_cube_name: filename of the computed cube to store
    """
    # FIX: bind values instead of f-string interpolation — a cube name
    # containing a quote would previously have broken (or injected into)
    # the statement.
    sql_query = sqlalchemy.text("""UPDATE tmp_execution_queue 
                SET fits_cube_name = :fits_cube_name
                WHERE row_id = :row_id""")
    engine.execution_options(autocommit=True).execute(
        sql_query, {'fits_cube_name': fits_cube_name, 'row_id': row_id})


def compute_grid_elements(run_id: str):
    """Initialize the DB and enqueue every grid point of this run.

    Expands the configured grid (dust temperature x density x line) into one
    ``tmp_execution_queue`` row per combination, so independent workers can
    later claim and compute them.
    """
    init_db()
    # Grid definition is delegated to the shared helper instead of
    # re-parsing the config files inline (removes the duplicated parsing
    # left over from the pre-parallel version).
    _tdust_model_type, _model_type, dust_temperatures, densities, line_pairs = parse_input_main()
    # A line can appear in several ratio pairs; compute each line only once.
    line_set = set(chain.from_iterable(line_pairs))

    density_keyword = 'central_density' if _model_type == 'homogeneous' else 'density_at_reference'
    # BUG FIX: the isothermal check must use the dust-temperature model type;
    # _model_type is only ever 'spherical'/'homogeneous', so the old
    # comparison could never select 'dust_temperature'.
    dust_temperature_keyword = 'dust_temperature' if _tdust_model_type == 'isothermal' else 'dust_temperature_at_reference'

    parallel_args = product(dust_temperatures, densities, line_set, [density_keyword], [dust_temperature_keyword])
    engine = get_pg_engine(logger=logger)
    initialize_queue(engine=engine,
                     run_id=run_id,
                     run_arguments=parallel_args)
    engine.dispose()


def compute_model(run_id: str):
    """Claim one queued grid point, compute it, and store the cube name.

    Pops the next pending row for *run_id* from the execution queue, runs
    the radiative-transfer computation for it, and writes the resulting
    FITS cube filename back to the row.  No-op when the queue is empty.
    """
    engine = get_pg_engine(logger=logger)
    queue_row = get_run_pars(engine=engine,
                             run_id=run_id)
    if queue_row is None:
        logger.info('All models were completed.')
    else:
        # Positional access follows the tmp_execution_queue column order:
        # row_id, run_id, dust_temperature, density, line,
        # density_keyword, dust_temperature_keyword, ...
        cube_name = compute_grid(tdust=queue_row[2],
                                 nh2=queue_row[3],
                                 line=queue_row[4],
                                 density_keyword=queue_row[5],
                                 dust_temperature_keyword=queue_row[6])
        insert_fits_name(engine=engine,
                         row_id=queue_row[0],
                         fits_cube_name=cube_name)
    engine.dispose()


def initialize_run():
    """Return the run identifier, creating and enqueueing a new run if needed.

    Reuses ``--run_id`` when it was supplied on the command line; otherwise
    generates a fresh UUID and expands the grid into the execution queue.
    The id is echoed to stdout so the calling shell script can capture it.
    """
    run_id = args.run_id
    if run_id is None:
        logger.info('Generating new run_id')
        run_id = str(uuid.uuid4())
        compute_grid_elements(run_id=run_id)
    sys.stdout.write(run_id)
    return run_id


def compute_remaining_models(run_id: Union[None, str] = None) -> int:
    """Count grid points of a run still flagged ``done = false``.

    Falls back to the ``run_id`` environment variable when no argument is
    given.  Writes the count to stdout (consumed by the SLURM driver
    script) and returns it.

    :param run_id: run identifier; None to use the environment variable
    :return: number of queue rows not yet computed
    """
    _run_id = validate_parameter(run_id, default=os.getenv('run_id'))
    logger.info(_run_id)
    engine = get_pg_engine(logger=logger)
    # BUG FIX: the query previously interpolated the raw ``run_id`` argument,
    # silently ignoring the env-var fallback (_run_id) when run_id was None;
    # also bind the value instead of f-string interpolation.
    sql_query = sqlalchemy.text("""SELECT count(*)
                                    FROM tmp_execution_queue
                                    WHERE run_id = :run_id
                                        AND done = false""")
    n_models = engine.execution_options(autocommit=True).execute(
        sql_query, {'run_id': _run_id}).first()[0]
    engine.dispose()
    sys.stdout.write(str(n_models))
    return n_models


def get_results_mapping(engine: sqlalchemy.engine,
                        run_id: str):
    """Fetch (dust_temperature, density, line, fits_cube_name) for a run.

    :param engine: live SQLAlchemy engine to the results database
    :param run_id: run whose computed cubes should be listed
    :return: list of result rows, one per computed grid point
    """
    # FIX: bind run_id as a parameter rather than f-string interpolation.
    sql_query = sqlalchemy.text("""SELECT dust_temperature
                                           , density
                                           , line
                                           , fits_cube_name
                                    FROM tmp_execution_queue
                                    WHERE run_id = :run_id""")
    return engine.execution_options(autocommit=True).execute(
        sql_query, {'run_id': run_id}).all()


def cleanup_tmp_table(run_id: str,
                      engine: sqlalchemy.engine):
    """Delete all queue rows belonging to *run_id* after results are summarized.

    :param run_id: run whose temporary queue entries should be removed
    :param engine: live SQLAlchemy engine to the results database
    """
    # FIX: bind run_id as a parameter rather than f-string interpolation.
    sql_query = sqlalchemy.text("""DELETE
                                    FROM tmp_execution_queue
                                    WHERE run_id = :run_id""")
    return engine.execution_options(autocommit=True).execute(
        sql_query, {'run_id': run_id})


def main_presentation_step(run_id: str,
         cleanup_scratches: bool = True):
    _tdust_model_type, _model_type, dust_temperatures, densities, line_pairs = parse_input_main()

    engine = get_pg_engine(logger=logger)
    results = get_results_mapping(engine=engine,
                                  run_id=run_id)

    results_map = {}
    for (nh2, tdust, line, cube_fits_name) in results:
    for (tdust, nh2, line, cube_fits_name) in results:
        results_map[f'{str(nh2)}_{str(tdust)}_{line}'] = cube_fits_name

    engine = get_pg_engine(logger=logger)
    for line_pair in line_pairs:
        for tdust, nh2 in product(dust_temperatures, densities):
            prs_main(cube_fits_list=[results_map[f'{str(nh2)}_{str(tdust)}_{line_pair[0]}'],
@@ -87,9 +200,25 @@ def build_model_grid(run_id: str,
    prs_inspection_main(run_id=run_id,
                        is_isothermal=_tdust_model_type == 'isothermal',
                        engine=engine)
    cleanup_tmp_table(run_id=run_id,
                      engine=engine)
    engine.dispose()


# Module-level CLI setup: run_id selects an existing run to continue;
# cleanup_scratches controls scratch-directory removal in the final step.
logger = setup_logger(name='MAIN')

parser = argparse.ArgumentParser()
parser.add_argument('--run_id')
parser.add_argument('--cleanup_scratches')
args = parser.parse_args()

if __name__ == '__main__':
    run_id = initialize_run()
    # FIX: use a real exception rather than assert, which is stripped
    # under `python -O`.
    if run_id is None:
        raise RuntimeError('initialize_run returned no run_id')
    compute_model(run_id=run_id)
    # FIX: query the queue once and reuse the count instead of issuing
    # two identical DB round-trips (debug log + condition).
    remaining = compute_remaining_models(run_id)
    logger.debug(remaining)
    if remaining == 0:
        logger.info('All grid points processed. Summarizing results.')
        # NOTE(review): argparse yields a string here, so any non-empty
        # value (including "False") is truthy — confirm intended semantics.
        _cleanup = validate_parameter(args.cleanup_scratches,
                                      default=True)
        main_presentation_step(run_id=run_id,
                               cleanup_scratches=_cleanup)
+15 −12
Original line number Diff line number Diff line
@@ -199,7 +199,12 @@ def main(grid_zipfile: str,
    if engine is None:
        engine = get_pg_engine(logger=logger, engine_kwargs={'pool_size': 3})

    # Execute radmc
    config_full = config_mdl.copy()
    config_full.update(config_stg)
    cube_filename = f'{compute_unique_hash_filename(config=config_full)}.fits'

    # Execute radmc if not done already
    if not os.path.isfile(os.path.join('prs', 'fits', 'cubes', cube_filename)):
        logger.debug(f'Executing command: {radmc_command}')
        execution_dir = os.getcwd()
        os.chdir(_radmc_input_path)
@@ -208,13 +213,11 @@ def main(grid_zipfile: str,
        logger.debug(f'Checking presence of file: {os.path.join(_radmc_input_path, "image.out")}')
        assert os.path.isfile(os.path.join(_radmc_input_path, 'image.out'))

    config_full = config_mdl.copy()
    config_full.update(config_stg)
    cube_filename = f'{compute_unique_hash_filename(config=config_full)}.fits'

        save_cube_as_fits(cube_out_name=cube_filename,
                          cube_out_path=os.path.join('prs', 'fits', 'cubes'),
                          path_radmc_files=radmc_input_path)
    else:
        logger.info('Computation performed already! Skipping...')

    populate_line_table(config_lines=config_lines,
                        engine=engine,

etl/slurm_template.sh

0 → 100644
+28 −0
Original line number Diff line number Diff line
#!/bin/bash
#SBATCH --job-name=rdp05
#SBATCH --output=pleiadi_05.txt
##SBATCH --time=240:00:00
##SBATCH --partition arc
#
##SBATCH --ntasks=1
#SBATCH --cpus-per-task=2
##SBATCH --mem-per-cpu=800

pwd
cd /homes/agianne/agianne/sak/swiss_army_knife_lr_ple05/etl
pwd
apptainer-setup 1.0.1-centos7

# initialize_run() writes the run id as the last whitespace-separated token
# of its output; strip everything before it.
run_id_output=$(python -c'import main_parallel; main_parallel.initialize_run()')
run_id=$(echo $run_id_output | rev | cut -d" " -f1 | rev)

# BUG FIX: $(( ... )) is arithmetic expansion and cannot execute a pipeline;
# plain command substitution is needed to capture the task count.
remaining_tasks_output=$(python -c'import main_parallel; main_parallel.compute_remaining_models("'$run_id'")')
remaining_tasks=$(echo $remaining_tasks_output | rev | cut -d" " -f1 | rev)

# batch=10
# jobs_to_submit=$(( remaining_tasks < batch ? remaining_tasks : batch ))

# BUG FIX: brace expansion {1..$var} does not expand variables in bash,
# so the loop ran exactly once with i='{1..N}'; use seq instead.
for i in $(seq 1 "$remaining_tasks")
do
  srun singularity run -B .:$HOME swiss_army_knife_latest.sif > run.log
done
+15 −0
Original line number Diff line number Diff line
@@ -5,9 +5,11 @@ import time
from sqlalchemy import (Column,
                        ForeignKey,
                        Integer,
                        Sequence,
                        String,
                        Float,
                        DateTime,
                        Boolean,
                        ARRAY,
                        ForeignKeyConstraint)
from assets.commons import (get_pg_engine,
@@ -189,6 +191,19 @@ class RatioMaps(Base):
    mom_zero_map_2 = relationship("MomentZeroMaps", foreign_keys=[mom_zero_name_2, run_id])


class TmpExecutionQueue(Base):
    """Work queue of grid points: one row per (run, tdust, density, line) job.

    Rows are claimed and marked done by workers via raw SQL on
    ``tmp_execution_queue``; NOTE: column declaration order matters because
    callers read ``RETURNING *`` rows positionally — do not reorder.
    """
    __tablename__ = "tmp_execution_queue"
    # Sequence-backed id used for per-row advisory locks and direct updates.
    row_id = Column(Integer, Sequence('row_id_seq'))
    # Composite primary key: the full parameter combination identifies a job,
    # which makes queue initialization upserts idempotent.
    run_id = Column(String, primary_key=True)
    dust_temperature = Column(Float, primary_key=True)
    density = Column(Float, primary_key=True)
    line = Column(Integer, primary_key=True)
    density_keyword = Column(String, primary_key=True)
    dust_temperature_keyword = Column(String, primary_key=True)
    # Filled in once the model has been computed.
    fits_cube_name = Column(String)
    # False until a worker claims and completes this grid point.
    done = Column(Boolean)


def init_db():
    engine = get_pg_engine(logger=logger)
    try: