Commit 82c9cddf authored by Ambra Di Piano's avatar Ambra Di Piano
Browse files

add merge_data_info and test

parent c00bbf62
Loading
Loading
Loading
Loading
+25 −2
Original line number Original line Diff line number Diff line
@@ -11,7 +11,6 @@ import logging
import numpy as np
import numpy as np
from shutil import rmtree
from shutil import rmtree
from astrort.utils.wrap import *
from astrort.utils.wrap import *
from astrort.utils.utils import seeds_to_string_formatter_files
from astrort.configure.logging import set_logger
from astrort.configure.logging import set_logger
from astrort.simulator.base_simulator import base_simulator
from astrort.simulator.base_simulator import base_simulator
from rtasci.lib.RTACtoolsSimulation import RTACtoolsSimulation
from rtasci.lib.RTACtoolsSimulation import RTACtoolsSimulation
@@ -92,11 +91,35 @@ def test_merge_simulation_info(test_conf_file, test_tmp_folder):
    assert isfile(join(conf['simulator']['output'], 'merged_sim_data.dat'))
    assert isfile(join(conf['simulator']['output'], 'merged_sim_data.dat'))
    del sim
    del sim


@pytest.mark.test_tmp_folder
@pytest.mark.test_conf_file
@pytest.mark.parametrize('mode', ['simulator', 'mapper'])
def test_merge_data_info(test_conf_file, test_tmp_folder, mode):
    conf = load_yaml_conf(test_conf_file)
    conf['simulator']['pointing'] = {'ra': 1, 'dec': 1}
    pointing = get_point_source_info(conf['simulator'])
    sim = RTACtoolsSimulation()
    clock = 1
    for i in range(5):
        sim.seed = i
        datfile = join(conf['simulator']['output'], f'job_{i}_simulator.dat')
        write_simulation_info(sim, conf['simulator'], pointing, datfile, clock)
        assert isfile(datfile)

        datfile = join(conf['mapper']['output'], f'job_{i}_mapper.dat')
        write_mapping_info(conf, datfile, clock)
        assert isfile(datfile)
    
    log = set_logger(logging.CRITICAL, join(test_tmp_folder, 'test_set_logger.log'))
    merge_data_info(conf[mode], mode, log)
    assert isfile(join(conf[mode]['output'], f'merged_{mode}_data.dat'))
    del sim

@pytest.mark.test_conf_file
@pytest.mark.test_conf_file
def test_write_mapping_info(test_conf_file):
def test_write_mapping_info(test_conf_file):
    conf = load_yaml_conf(test_conf_file)
    conf = load_yaml_conf(test_conf_file)
    conf['simulator']['pointing'] = {'ra': 1, 'dec': 1}
    conf['simulator']['pointing'] = {'ra': 1, 'dec': 1}
    datfile = join(conf['simulator']['output'], 'simulator.dat')
    datfile = join(conf['mapper']['output'], 'mapper.dat')
    clock = 1
    clock = 1
    write_mapping_info(conf, datfile, clock)
    write_mapping_info(conf, datfile, clock)
    assert isfile(datfile)
    assert isfile(datfile)
+21 −0
Original line number Original line Diff line number Diff line
@@ -141,6 +141,27 @@ def write_mapping_info(configuration, datfile, clock):
    with open(datfile, 'a') as f:
    with open(datfile, 'a') as f:
        f.write(f'{name} {seed} {exposure} {center_type} {pixelsize} {smooth} {clock}\n')
        f.write(f'{name} {seed} {exposure} {center_type} {pixelsize} {smooth} {clock}\n')


def merge_data_info(configuration, mode, log):
    folder = configuration['output']
    datfiles = [join(folder, f) for f in listdir(folder) if '.dat' in f and 'job' in f and mode in f]
    merger = join(folder, f'merged_{mode}_data.dat')
    # check merger file
    if isfile(merger):
        log.warning(f"Merger output already exists, overwrite {merger}")
        f = open(merger, 'w+')
        f.close()
    # collect data
    for i, datfile in enumerate(datfiles):
        log.info(f"Collect data from {datfile}")
        data = pd.read_csv(join(datfile), sep=' ')
        if i == 0:
            table = data
        else:
            table = pd.concat([table, data], ignore_index=True)
        log.info(f"Lines in data: {len(table)}")
    # write merger file
    table.to_csv(merger, index=False, header=True, sep=' ', na_rep=np.nan)

def plot_map(fitsmap, log):
def plot_map(fitsmap, log):
    plotmap = fitsmap.replace('.fits', '.png')
    plotmap = fitsmap.replace('.fits', '.png')
    plot = Plotter(log)
    plot = Plotter(log)