Commit 8ac63378 authored by Ambra Di Piano's avatar Ambra Di Piano
Browse files

replicate dataset option

parent d5733abb
Loading
Loading
Loading
Loading
+13 −1
Original line number Diff line number Diff line
@@ -7,9 +7,10 @@
# *****************************************************************************

import argparse
import pandas as pd
from time import time
from rtasci.lib.RTACtoolsSimulation import RTACtoolsSimulation
from astrort.utils.wrap import load_yaml_conf, configure_simulator_no_visibility, write_simulation_info, set_pointing
from astrort.utils.wrap import load_yaml_conf, configure_simulator_no_visibility, write_simulation_info, set_pointing, set_irf
from astrort.configure.logging import set_logger, get_log_level, get_logfile
from astrort.configure.slurmjobs import slurm_submission

@@ -25,11 +26,22 @@ def base_simulator(configuration_file):
    #makedirs(configuration['simulator']['output'], exist_ok=True)
    # start simulations
    log.info(f"\n {'-'*17} \n| START SIMULATOR | \n {'-'*17} \n")
    if configuration['simulator']['replicate'] is not None:
        replica = pd.read_csv(configuration['simulator']['replicate'], sep=' ', header=0)
        log.info(f"Replicate pointing and IRF from {configuration['simulator']['replicate']}")
    else:
        replica = None
    # loop seeds
    for i in range(configuration['simulator']['samples']):
        clock_sim = time()
        simulator = RTACtoolsSimulation()
        # check pointing option
        if replica is not None:
            configuration['simulator']['pointing'] = {'ra': replica[replica['seed']==configuration['simulator']['seed']]['point_ra'].values[0],  
                                                      'dec': replica[replica['seed']==configuration['simulator']['seed']]['point_dec'].values[0]}
            configuration['simulator']['irf'] = replica[replica['seed']==configuration['simulator']['seed']]['irf'].values[0]   
        simulator, point = set_pointing(simulator, configuration['simulator'], log)
        simulator.irf = set_irf(configuration['simulator'], log)
        # complete configuration
        simulator = configure_simulator_no_visibility(simulator, configuration['simulator'], log)
        simulator.run_simulation()
+25 −3
Original line number Diff line number Diff line
@@ -7,26 +7,48 @@
# *****************************************************************************

import pytest
import yaml
import pandas as pd
import numpy as np
from shutil import rmtree
from os import listdir
from os import listdir, makedirs
from os.path import isfile, join
from astrort.simulator.base_simulator import base_simulator
from astrort.utils.wrap import load_yaml_conf

@pytest.mark.test_conf_file
def test_base_simulator(test_conf_file):
@pytest.mark.test_data_folder
@pytest.mark.test_tmp_folder
@pytest.mark.parametrize('replicate', [None, 'test_simulator.dat'])
def test_base_simulator(test_conf_file, replicate, test_data_folder, test_tmp_folder):

    # clean output
    conf = load_yaml_conf(test_conf_file)
    conf['simulator']['replicate'] = join(test_data_folder, replicate) if replicate is not None else replicate
    rmtree(conf['simulator']['output'], ignore_errors=True)

    # create tmp
    makedirs(test_tmp_folder, exist_ok=True)
    tmp_conf_file = join(test_tmp_folder, 'test.yml')
    with open(tmp_conf_file, 'w+') as f:
        yaml.dump(conf, f)

    # run simulator
    base_simulator(test_conf_file)
    base_simulator(tmp_conf_file)

    # check output
    expected_simulations = conf['simulator']['samples']
    found_simulations = len([f for f in listdir(conf['simulator']['output']) if isfile(join(conf['simulator']['output'], f)) and '.fits' in f and conf['simulator']['name'] in f])
    assert found_simulations == expected_simulations, f"Expected {expected_simulations} simulations, found {found_simulations}"

    if replicate is not None:
        data_new = pd.read_csv(join(test_tmp_folder, replicate), sep=' ', header=0)
        data_old = pd.read_csv(join(test_data_folder, replicate), sep=' ', header=0)
        for i in range(conf['simulator']['samples']):
            assert data_new['irf'][i] == data_old['irf'][i]
            assert np.round(data_new['point_ra'][i], decimals=3) == np.round(data_old['point_ra'][i], decimals=3)
            assert np.round(data_new['point_dec'][i], decimals=3) == np.round(data_old['point_dec'][i], decimals=3)