Commit 3fc30cbe authored by Andrea Giannetti's avatar Andrea Giannetti
Browse files

Refactored input validation and default filling; inserted default overriding to main.

parent 09aec9bd
Loading
Loading
Loading
Loading
+12 −2
Original line number Diff line number Diff line
@@ -11,14 +11,24 @@ from assets.constants import (radmc_grid_map,
                              leiden_url_mapping)


def load_config_file(config_file_path: str) -> dict:
def validate_parameter(param_to_validate,
                       default):
    return param_to_validate if param_to_validate is not None else default


def load_config_file(config_file_path: str,
                     override_config: Union[dict, None] = None) -> dict:
    """
    Load the information in the YAML configuration file into a python dictionary
    :param config_file_path: path to the configuration file
    :param override_config: parameters of the input file to override (e.g. for grid creation)
    :return: a dictionary with the parsed information
    """
    _override_config = validate_parameter(override_config, default={})
    with open(config_file_path) as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    for key in _override_config:
        config[key] = _override_config[key]
    return config


@@ -40,7 +50,7 @@ def compute_power_law_radial_profile(
        the grid
    :return: the distance matrix
    """
    _value_at_reference = value_at_reference if value_at_reference is not None else central_value
    _value_at_reference = validate_parameter(value_at_reference, default=central_value)
    _distance_matrix = np.where(distance_matrix == 0, 1,
                                distance_matrix) if fill_reference_pixel is True else distance_matrix
    profile = _value_at_reference * (_distance_matrix / distance_reference) ** power_law_index
+4 −2
Original line number Diff line number Diff line
import os
from typing import Union
from astropy import units as u
from assets.commons import (load_config_file,
                            convert_frequency_to_wavelength)


def main():
    config = load_config_file(os.path.join('mdl', 'config', 'config.yml'))
def main(override_config: Union[dict, None] = None):
    config = load_config_file(os.path.join('mdl', 'config', 'config.yml'),
                              override_config=override_config)
    central_frequency = convert_frequency_to_wavelength(
        frequency=float(config['radmc']['central_frequency']) * u.Unit(config['radmc']['frequency_units']),
        output_units=u.Unit("micron"))
+8 −7
Original line number Diff line number Diff line
import os
from astropy.io import fits
from typing import Union
from assets.commons import validate_parameter


def compute_moment_zero(cube: Union[str, fits.PrimaryHDU],
@@ -8,9 +9,9 @@ def compute_moment_zero(cube: Union[str, fits.PrimaryHDU],
                        moment_zero_path: Union[str, None] = None,
                        moment_zero_fits_name: Union[str, None] = None,
                        hdu_idx: int = 0):
    _cube_path = cube_path if cube_path is not None else os.path.join('prs', 'fits', 'cubes')
    _moment_zero_path = moment_zero_path if moment_zero_path is not None else os.path.join('prs', 'fits', 'moments')
    _moment_zero_fits_name = moment_zero_fits_name if moment_zero_fits_name is not None else 'test_mom0.fits'
    _cube_path = validate_parameter(cube_path, default=os.path.join('prs', 'fits', 'cubes'))
    _moment_zero_path = validate_parameter(moment_zero_path, default=os.path.join('prs', 'fits', 'moments'))
    _moment_zero_fits_name = validate_parameter(moment_zero_fits_name, default='test_mom0.fits')
    fitsfile = open_fits_file_duck_typing(fitsfile=cube, fits_path=_cube_path)
    header = fitsfile[hdu_idx].header.copy()
    data = fitsfile[hdu_idx].data
@@ -26,7 +27,7 @@ def compute_moment_zero(cube: Union[str, fits.PrimaryHDU],

def open_fits_file_duck_typing(fitsfile: Union[str, fits.PrimaryHDU],
                               fits_path: str = None) -> fits.PrimaryHDU:
    _fits_path = fits_path if fits_path is not None else '.'
    _fits_path = validate_parameter(fits_path, default='.')
    try:
        hdu = fits.open(os.path.join(_fits_path, fitsfile))
    except TypeError:
@@ -41,9 +42,9 @@ def compute_image_ratios(fits1: str,
                         ratio_fits_name: Union[str, None] = None,
                         hdu1_idx: int = 0,
                         hdu2_idx: int = 0):
    _fits_path = fits_path if fits_path is not None else os.path.join('prs', 'fits', 'moments')
    _ratio_fits_path = ratio_fits_path if ratio_fits_path is not None else os.path.join('prs', 'fits', 'ratios')
    _ratio_fits_name = ratio_fits_name if ratio_fits_name is not None else 'ratio.fits'
    _fits_path = validate_parameter(fits_path, default=os.path.join('prs', 'fits', 'moments'))
    _ratio_fits_path = validate_parameter(ratio_fits_path, default=os.path.join('prs', 'fits', 'ratios'))
    _ratio_fits_name = validate_parameter(ratio_fits_name, default='ratio.fits')
    hdu1 = open_fits_file_duck_typing(fitsfile=fits1,
                                      fits_path=_fits_path)
    hdu2 = open_fits_file_duck_typing(fitsfile=fits2,
+8 −6
Original line number Diff line number Diff line
@@ -4,7 +4,8 @@ from typing import Union
from assets.commons import (compute_power_law_radial_profile,
                            extract_grid_metadata,
                            load_config_file,
                            get_moldata)
                            get_moldata,
                            validate_parameter)
from assets.constants import (mean_molecular_mass,
                              radmc_input_headers)
from astropy import units as u
@@ -28,7 +29,7 @@ def write_radmc_input(filename,
        'active_axes': ' '.join(['1' for i in range(len(grid_metadata['grid_shape']))]),
        'continuum_lambdas': 100,  ##
    }
    _path = '.' if path is None else path
    _path = validate_parameter(path, default='.')
    _filename = 'numberdens_mol.inp' if filename.startswith('numberdens_') else filename
    with open(os.path.join(_path, filename), "w") as outfile:
        for key in radmc_input_headers[_filename]:
@@ -42,7 +43,7 @@ def write_radmc_input(filename,

def write_radmc_lines_input(line_config: dict,
                            path: Union[None, str] = None):
    _path = '.' if path is None else path
    _path = validate_parameter(path, default='.')

    if line_config['lines_mode'] == 1:
        assert len(line_config['collision_partners']) == 0
@@ -147,7 +148,7 @@ def write_molecular_number_density_profiles(profiles: dict,

def write_radmc_main_input_file(config: dict,
                                path: Union[None, str] = None):
    _path = '.' if path is None else path
    _path = validate_parameter(path, default='.')
    with open(os.path.join(_path, 'radmc3d.inp'), "w") as outfile:
        outfile.write(f'nphot = {config["radmc"]["nphotons"]}\n')
        outfile.write(f'scattering_mode_max = {config["radmc"]["scattering_mode_max"]}\n')
@@ -156,8 +157,9 @@ def write_radmc_main_input_file(config: dict,
        outfile.write(f'lines_mode = {config["lines"]["lines_mode"]}\n')


def main():
    config = load_config_file(os.path.join('stg', 'config', 'config.yml'))
def main(override_config: Union[dict, None] = None):
    config = load_config_file(os.path.join('stg', 'config', 'config.yml'),
                              override_config=override_config)
    grid_metadata = extract_grid_metadata(config=config)
    profiles = get_profiles(grid_metadata=grid_metadata)
    write_grid_input_files(grid_metadata=grid_metadata,