Commit 48eb1490 authored by Ambra Di Piano's avatar Ambra Di Piano
Browse files

fix seed to format for name and files

parent cfe8e177
Loading
Loading
Loading
Loading
+15 −2
Original line number Diff line number Diff line
@@ -11,8 +11,8 @@ from astrort.utils.utils import *

@pytest.mark.test_tmp_folder
@pytest.mark.parametrize('samples', [3, 5, 8, 10])
def test_seeds_to_string_formatter(samples, test_tmp_folder):
    name = seeds_to_string_formatter(samples, test_tmp_folder, name='test', seed=1, ext='fits')
def test_seeds_to_string_formatter_files(samples, test_tmp_folder):
    name = seeds_to_string_formatter_files(samples, test_tmp_folder, name='test', seed=1, ext='fits')

    if samples <= 1e3:
        assert name == f"{test_tmp_folder}/test_001.fits"
@@ -23,6 +23,19 @@ def test_seeds_to_string_formatter(samples, test_tmp_folder):
    else:
        assert name == f"{test_tmp_folder}/test_1.fits"

@pytest.mark.parametrize('samples', [3, 5, 8, 10])
def test_seeds_to_string_formatter(samples):
    name = seeds_to_string_formatter(samples, name='test', seed=1)

    if samples <= 1e3:
        assert name == f"test_001"
    elif samples <= 1e5:
        assert name == f"test_00001"
    elif samples <= 1e8:
        assert name == f"test_00000001"
    else:
        assert name == f"test_1"

@pytest.mark.parametrize('array', ['lst', 'mst', 'sst', 'cta', 'north', 'south'])
def test_get_instrument_fov(array):
    fov = get_instrument_fov(array)
+12 −1
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@

from os.path import join

def seeds_to_string_formatter(samples, output, name, seed, ext):
def seeds_to_string_formatter_files(samples, output, name, seed, ext):
    if samples <= 1e3:
        name = join(output, f"{name}_{seed:03d}.{ext}")
    elif samples <= 1e5:
@@ -19,6 +19,17 @@ def seeds_to_string_formatter(samples, output, name, seed, ext):
        name = join(output, f"{name}_{seed}.{ext}")
    return name

def seeds_to_string_formatter(samples, name, seed):
    if samples <= 1e3:
        name = join(f"{name}_{seed:03d}")
    elif samples <= 1e5:
        name = join(f"{name}_{seed:05d}")
    elif samples <= 1e8:
        name = join(f"{name}_{seed:08d}")
    else:
        name = join(f"{name}_{seed}")
    return name

def get_instrument_fov(instrument):
    if instrument == 'lst':
        fov = 2.5
+3 −3
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ import yaml
import numpy as np
import astropy.units as u
from os.path import dirname, abspath, join, basename, isfile
from astrort.utils.utils import seeds_to_string_formatter, get_instrument_fov
from astrort.utils.utils import seeds_to_string_formatter_files, get_instrument_fov, seeds_to_string_formatter
from astrort.configure.check_configuration import CheckConfiguration
from rtasci.lib.RTAManageXml import ManageXml
from astropy.coordinates import SkyCoord 
@@ -25,7 +25,7 @@ def configure_simulator_no_visibility(simulator, configuration):
    if '$TEMPLATES$' in configuration['model']:
        configuration['model'] = join(dirname(abspath(__file__)).replace('utils', 'templates'), basename(configuration['model']))
    simulator.model = configuration['model']
    simulator.output = seeds_to_string_formatter(configuration['samples'], configuration['output'], configuration['name'], configuration['seed'], 'fits')
    simulator.output = seeds_to_string_formatter_files(configuration['samples'], configuration['output'], configuration['name'], configuration['seed'], 'fits')
    simulator.caldb = configuration['prod']
    simulator.irf = configuration['irf']
    simulator.fov = get_instrument_fov(configuration['array'])
@@ -72,7 +72,7 @@ def get_point_source_info(simulator):
    return {'point_ra': pointing.ra.deg, 'point_dec': pointing.dec.deg, 'offset': separation.value, 'source_ra': source.ra.deg, 'source_dec': source.dec.deg}

def write_simulation_info(simulator, configuration, pointing, datfile):
    name = seeds_to_string_formatter(configuration['samples'], configuration['output'], configuration['name'], configuration['seed'], '')
    name = seeds_to_string_formatter(configuration['samples'], configuration['name'], configuration['seed'])
    seed = simulator.seed
    tstart, tstop = simulator.t
    duration = configuration['duration']