Commit 65fb2825 authored by Andrea Giannetti's avatar Andrea Giannetti
Browse files

Created script to reproduce comparison figures for the results. Updated documentation.md.

parent b0da2f6c
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -55,7 +55,7 @@ overrides are included into the [etl/config/config.yml](../etl/config/config.yml
4. **Additional files**:
The script [prs_analytical_representations.py](../etl/prs/prs_analytical_representations.py) provides a convenient way of checking the analytical representations of the ratio vs. density curves.
The file [prs_check_biases_poc_sample.py](../etl/prs/prs_check_biases_poc_sample.py) checks for biases in the massive clump sample used in the proof-of-concept.
The scripts [prs_poc_figures.py](../etl/prs/prs_poc_figures.py), and [prs_poc_latex_table.py](../etl/prs/prs_poc_latex_table.py) can be used to reproduce the content of the paper, regarding the POC.
The scripts [prs_poc_figures.py](../etl/prs/prs_poc_figures.py), [prs_make_comparison_figures.py](../etl/prs/prs_make_comparison_figures.py), and [prs_poc_latex_table.py](../etl/prs/prs_poc_latex_table.py) can be used to reproduce some of the content of the paper, in terms of figures and tables.

### Running the pipeline

+2 −2
Original line number Diff line number Diff line
@@ -599,9 +599,9 @@ def get_inference_data(use_model_for_inference: str, limit_rows: int):
    return data, line_pairs


if __name__ == '__main__':
    external_input = load_config_file(config_file_path='config/density_inference_input.yml')
logger = setup_logger(name='PRS - DENSITY INFERENCE')
if __name__ == '__main__':
    external_input = load_config_file(config_file_path=os.path.join('config', 'density_inference_input.yml'))
    try:
        limit_rows = external_input['limit_rows']
    except KeyError:
+149 −0
Original line number Diff line number Diff line
import pandas as pd
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
from assets.commons import (load_config_file,
                            validate_parameter,
                            setup_logger,
                            get_postprocessed_data)
from assets.constants import line_ratio_mapping
from prs.prs_density_inference import get_inference_data
from typing import Tuple, Union, List


filename_root_map = {
    'isothermal_p15': 'comparison_isothermal_fiducial',
    # 'constant_abundance_p15_q05_x01': 'comparison_lowabundance_fiducial',
    # 'constant_abundance_p15_q05_x10': 'comparison_highabundance_fiducial',
    'hot_core_p15_q05': 'comparison_hotcore_fiducial',
}


def plot_kde_ratio_nh2(grid: np.array,
                       values_on_grid: np.array,
                       comparison_grid: np.array,
                       comparison_values_on_grid: np.array,
                       ratio_string: str,
                       data: pd.DataFrame,
                       comparison_data: pd.DataFrame,
                       root_outfile: str = None,
                       ratio_limits: Union[None, list] = None):
    """
        Plot the Kernel Density Estimate (KDE) of a ratio against average H2 density along the line-of-sight and save
         the plot as a PNG file.

        :param grid: The grid of x and y values used for the KDE.
        :param comparison_grid: The grid of x and y values used for the KDE of the comparison model.
        :param values_on_grid: The computed KDE values on the grid.
        :param comparison_values_on_grid: The computed KDE values on the comparison model grid.
        :param ratio_string: The ratio string indicating which ratio of the training data to plot.
        :param data: The DataFrame containing the data for the scatterplot.
        :param data: The DataFrame containing the comparison model data for the scatterplot.
        :param root_outfile: The root of the filename used to save the figures.
        :param ratio_limits: Optional. The limits for the ratio axis. Defaults to None, which auto-scales the axis.
        :return: None. Saves the plot as a PNG file in the specified folder.
    """
    plt.rcParams.update({'font.size': 20})
    plt.clf()
    plt.figure(figsize=(8, 6))
    plt.scatter(data['avg_nh2'], data[f'ratio_{ratio_string}'], marker='+', alpha=0.1,
                facecolor='grey')
    plt.scatter(comparison_data['avg_nh2'], comparison_data[f'ratio_{ratio_string}'], marker='x', alpha=0.01,
                facecolor='green')

    plt.contour(10 ** grid[0], grid[1], values_on_grid, levels=np.arange(0.05, 0.95, 0.15),
                colors='black')
    plt.contour(10 ** comparison_grid[0], comparison_grid[1],
                comparison_values_on_grid,
                levels=np.arange(0.05, 0.95, 0.15),
                colors='lightgreen')
    plt.semilogx()
    plt.xlabel(r'<$n$(H$_2$)> [cm$^{-3}$]')
    plt.ylabel(f'Ratio {line_ratio_mapping[ratio_string]}')
    plt.ylim(ratio_limits)
    plt.tight_layout()
    plt.savefig(os.path.join(
        'prs',
        'output',
        'comparison_figures',
        f'{root_outfile}_{ratio_string}.png'))


def main(ratio_list: list,
         comparison_model: str,
         points_per_axis: int = 200,
         limit_rows: Union[None, int] = None,
         best_bandwidths: Union[None, dict] = None,
         ratio_limits: Union[None, dict] = None):
    _use_model_for_inference = 'constant_abundance_p15_q05'
    _model_root_folder = os.path.join('prs', 'output', 'run_type', _use_model_for_inference)
    data, _ = get_inference_data(use_model_for_inference=_use_model_for_inference,
                                          limit_rows=limit_rows)
    comparison_data, _ = get_inference_data(use_model_for_inference=comparison_model,
                                          limit_rows=limit_rows)

    for ratio_string in ratio_list:
        with open(
            os.path.join(_model_root_folder, 'trained_model', f'ratio_density_kde_{ratio_string}.pickle'), 'rb'
        ) as infile:
            kde_dict = pickle.load(infile)
        with open(
            os.path.join(os.path.join('prs', 'output', 'run_type', comparison_model),
                         'trained_model', f'ratio_density_kde_{ratio_string}.pickle'), 'rb'
        ) as infile:
            comparison_kde_dict = pickle.load(infile)

        grids = []
        for kde in [kde_dict, comparison_kde_dict]:
            grids.append(get_grid(best_bandwidths, kde, ratio_string))

        plot_kde_ratio_nh2(grid=grids[0],
                           values_on_grid=kde_dict['values_rt_only'].reshape(points_per_axis, points_per_axis),
                           ratio_string=ratio_string,
                           root_outfile=filename_root_map[comparison_model],
                           data=data[data['source'] == 'RT'],
                           comparison_data=comparison_data[comparison_data['source'] == 'RT'],
                           comparison_grid=grids[1],
                           comparison_values_on_grid=comparison_kde_dict['values_rt_only'].reshape(points_per_axis, points_per_axis),
                           ratio_limits=ratio_limits[ratio_string])


def get_grid(best_bandwidths: dict,
             kde_dict: dict,
             ratio_string: str) -> np.array:
    """
    Computes the KDE grid for plotting.
    :param best_bandwidths: The values of the ratio bandwidth used for the KDE computation.
    :param kde_dict: The dictionary used to persist the KDE results, containing the grid point (x, y) and the values of
     the PDF at those points.
    :param ratio_string: The line ratio to be used for plotting.
    :return: The array with the scaled grid for plotting.
    """
    scaled_grid = np.meshgrid(kde_dict['x'], kde_dict['y'], indexing='ij')
    grid = scaled_grid.copy()
    grid[0] = scaled_grid[0] * 0.2
    grid[1] = scaled_grid[1] * best_bandwidths[ratio_string]
    return grid


if __name__ == '__main__':
    external_input = load_config_file(config_file_path=os.path.join('config', 'density_inference_input.yml'))
    logger = setup_logger(name='PRS - FIGURES')
    try:
        limit_rows = external_input['limit_rows']
    except KeyError:
        limit_rows = None
    try:
        points_per_axis = external_input['points_per_axis']
    except KeyError:
        points_per_axis = 200

    for comparison_model in filename_root_map.keys():
        logger.info(f'Producing figure for {comparison_model}.')
        main(ratio_list=external_input['ratios_to_include'],
             points_per_axis=points_per_axis,
             limit_rows=limit_rows,
             comparison_model=comparison_model,
             best_bandwidths=external_input['best_bandwidths'],
             ratio_limits=external_input['ratio_limits'])