Commit 7995f9b5 authored by Andrea Giannetti's avatar Andrea Giannetti
Browse files

Added parameter to choose between distributed computing and multiprocessing.

parent f695cdf9
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -408,6 +408,7 @@ def parse_input_main():
    densities = parse_grid_overrides(par_name='gas_density',
                                     config=config)
    line_pairs = config['overrides']['lines_to_process']
    return _tdust_model_type, _model_type, dust_temperatures, densities, line_pairs
    n_processes = validate_parameter(config['computation']['threads'], default=10)
    return _tdust_model_type, _model_type, dust_temperatures, densities, line_pairs, n_processes

+67 −20
Original line number Diff line number Diff line
@@ -4,9 +4,10 @@ import uuid
import sqlalchemy
import argparse
import sys
from multiprocessing import Pool
from stg.stg_build_db_structure import init_db, TmpExecutionQueue
from itertools import product, chain
from typing import Union
from typing import Union, Tuple, Iterator
from assets.commons import (cleanup_directory,
                            setup_logger,
                            get_pg_engine,
@@ -19,7 +20,6 @@ from prs.prs_compute_integrated_fluxes_and_ratios import main as prs_main
from prs.prs_inspect_results import main as prs_inspection_main



def compute_grid(tdust, nh2, line, density_keyword, dust_temperature_keyword):
    scratch_dir = os.path.join('mdl', 'scratches', str(uuid.uuid4()))
    stg_overrides = {
@@ -46,6 +46,32 @@ def compute_grid(tdust, nh2, line, density_keyword, dust_temperature_keyword):
    return cube_fits_name


def compute_full_grid(tdust, nh2, line, density_keyword, dust_temperature_keyword):
    """Run the full pipeline for one grid point and return its results.

    Builds the override configuration for a single (dust temperature, density,
    line) combination, generates the model inputs via ``stg_main``, runs the
    RADMC observation script, and returns the inputs together with the name of
    the produced FITS cube.
    NOTE(review): relies on the module-level ``run_id`` being set before call.
    """
    # Fresh scratch directory per grid point so parallel workers never collide.
    work_dir = os.path.join('mdl', 'scratches', str(uuid.uuid4()))
    grid_section = {}
    grid_section[dust_temperature_keyword] = tdust
    grid_section[density_keyword] = nh2
    config_overrides = {
        'grid_lines': {'grid': grid_section},
        'model': {'radmc_observation': {'iline': line}},
    }
    archive_name = stg_main(override_config=config_overrides,
                            path_radmc_files=work_dir,
                            run_id=run_id)
    fits_cube = execute_radmc_script(grid_zipfile=archive_name,
                                     override_config=config_overrides,
                                     radmc_input_path=work_dir,
                                     run_id=run_id)
    return tdust, nh2, line, fits_cube


def initialize_queue(engine,
                     run_id,
                     run_arguments):
@@ -97,13 +123,7 @@ def insert_fits_name(engine: sqlalchemy.engine,

def compute_grid_elements(run_id: str):
    init_db()
    _tdust_model_type, _model_type, dust_temperatures, densities, line_pairs = parse_input_main()
    line_set = set(chain.from_iterable(line_pairs))

    density_keyword = 'central_density' if _model_type == 'homogeneous' else 'density_at_reference'
    dust_temperature_keyword = 'dust_temperature' if _model_type == 'isothermal' else 'dust_temperature_at_reference'

    parallel_args = product(dust_temperatures, densities, line_set, [density_keyword], [dust_temperature_keyword])
    parallel_args, _ = get_parallel_args_and_nprocesses()
    engine = get_pg_engine(logger=logger)
    initialize_queue(engine=engine,
                     run_id=run_id,
@@ -111,6 +131,15 @@ def compute_grid_elements(run_id: str):
    engine.dispose()


def get_parallel_args_and_nprocesses() -> Tuple[Iterator, int]:
    """Build the argument tuples for the grid computation and the pool size.

    Reads the run configuration, derives the density and dust-temperature
    keyword names from the model type, and returns the cartesian product of
    all grid parameters together with the configured number of worker
    processes.
    """
    (_tdust_model_type, _model_type, dust_temperatures, densities,
     line_pairs, n_processes) = parse_input_main()
    # Flatten the configured line pairs into the unique set of lines to model.
    unique_lines = set(chain.from_iterable(line_pairs))
    if _model_type == 'homogeneous':
        density_keyword = 'central_density'
    else:
        density_keyword = 'density_at_reference'
    if _model_type == 'isothermal':
        dust_temperature_keyword = 'dust_temperature'
    else:
        dust_temperature_keyword = 'dust_temperature_at_reference'
    grid_args = product(dust_temperatures, densities, unique_lines,
                        [density_keyword], [dust_temperature_keyword])
    return grid_args, n_processes


def compute_model(run_id: str):
    engine = get_pg_engine(logger=logger)
    parameters_set = get_run_pars(engine=engine,
@@ -154,7 +183,7 @@ def compute_remaining_models(run_id: Union[None, str] = None) -> int:
    return n_models


def get_results_mapping(engine: sqlalchemy.engine,
def get_results(engine: sqlalchemy.engine,
                run_id: str):
    sql_query = sqlalchemy.text(f"""SELECT dust_temperature
                                           , density
@@ -174,15 +203,17 @@ def cleanup_tmp_table(run_id: str,


def main_presentation_step(run_id: str,
         cleanup_scratches: bool = True):
    _tdust_model_type, _model_type, dust_temperatures, densities, line_pairs = parse_input_main()
                           cleanup_scratches: bool = True,
                           results_dict: Union[dict, None] = None):
    _tdust_model_type, _model_type, dust_temperatures, densities, line_pairs, n_processes = parse_input_main()

    engine = get_pg_engine(logger=logger)
    results = get_results_mapping(engine=engine,
                                  run_id=run_id)
    _results_dict = validate_parameter(results_dict,
                                       default=get_results(engine=engine,
                                                           run_id=run_id))

    results_map = {}
    for (tdust, nh2, line, cube_fits_name) in results:
    for (tdust, nh2, line, cube_fits_name) in _results_dict:
        results_map[f'{str(nh2)}_{str(tdust)}_{line}'] = cube_fits_name

    for line_pair in line_pairs:
@@ -205,20 +236,36 @@ def main_presentation_step(run_id: str,
    engine.dispose()


def process_models(distributed: bool = False) -> Tuple[Union[None, list], int]:
    """Compute the model grid, either via a distributed queue or locally.

    :param distributed: when True, process models through the shared execution
        queue (results are persisted externally, so ``None`` is returned and
        the number of still-unprocessed models is queried); when False, run
        the whole grid with a local multiprocessing ``Pool`` and return the
        list of ``(tdust, nh2, line, cube_fits_name)`` tuples produced by
        ``compute_full_grid``.
    :return: the local results (``None`` in distributed mode) and the number
        of models still to process (always 0 in local mode).
    NOTE(review): relies on the module-level ``run_id`` being set before call.
    """
    if distributed:  # idiomatic truthiness check instead of `is True`
        compute_model(run_id=run_id)
        results = None
        remaining_models = compute_remaining_models(run_id)
    else:
        parallel_args, n_processes = get_parallel_args_and_nprocesses()
        with Pool(n_processes) as pool:
            # starmap returns a list of result tuples — the original
            # annotation (dict) was incorrect.
            results = pool.starmap(compute_full_grid, parallel_args)
        remaining_models = 0
    return results, remaining_models


# Module-level logger and CLI argument parsing, shared by the functions above.
logger = setup_logger(name='MAIN')
parser = argparse.ArgumentParser()
parser.add_argument('--run_id')  # identifier of the run; presumably optional — verify against initialize_run
parser.add_argument('--cleanup_scratches')  # consumed via validate_parameter with default True — TODO confirm expected string format
parser.add_argument('--distributed')  # the string 'true' (case-insensitive) selects the distributed-queue path
args = parser.parse_args()


if __name__ == '__main__':
    run_id = initialize_run()
    assert run_id is not None
    compute_model(run_id=run_id)
    logger.debug(compute_remaining_models(run_id))
    if compute_remaining_models(run_id) == 0:
    _distributed = validate_parameter(args.distributed, default='false').lower() == 'true'
    results, remaining_models = process_models(distributed=_distributed)
    if remaining_models == 0:
        logger.info('All grid points processed. Summarizing results.')
        _cleanup = validate_parameter(args.cleanup_scratches,
                                      default=True)
        main_presentation_step(run_id=run_id,
                               cleanup_scratches=_cleanup)
                               cleanup_scratches=_cleanup,
                               results_dict=results)