Commit 5e9ef3f7 authored by Ambra Di Piano's avatar Ambra Di Piano
Browse files

seed to string generalised

parent 99bdd18c
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -11,10 +11,10 @@ from os import system
from os.path import join, dirname, abspath

def make_configuration(jobname_conf, configuration, node_number):
    configuration['simulator']['seed'] = node_number*configuration['simulator']['samples']
    configuration['simulator']['samples'] = node_number*configuration['simulator']['samples'] + 1
    # write new configuration
    with open(jobname_conf, 'w+') as f:
        new_configuration = dump(configuration, f, default_flow_style=False)
        dump(configuration, f, default_flow_style=False)

def make_sh(jobname, slurmconf, jobname_conf, jobname_sh, jobname_log):
    # write sbatch
@@ -34,9 +34,9 @@ def make_sh(jobname, slurmconf, jobname_conf, jobname_sh, jobname_log):

def make_sbatch(jobname, configuration, node_number):
    output = configuration['simulator']['output']
    jobname_sh = join(output, f"job_{jobname}.sh")
    jobname_log = join(output, f"job_{jobname}.log")
    jobname_conf = join(output, f"job_{jobname}.yml")
    jobname_sh = join(output, f"{jobname}.sh")
    jobname_log = join(output, f"{jobname}.slurm")
    jobname_conf = join(output, f"{jobname}.yml")
    make_configuration(jobname_conf, configuration, node_number)
    make_sh(jobname, configuration['slurm'], jobname_conf, jobname_sh, jobname_log)
    system(f"sbatch {jobname_sh}")
+1 −1
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ from astrort.utils.utils import *
@pytest.mark.test_tmp_folder
@pytest.mark.parametrize('samples', [3, 5, 8, 10])
def test_seeds_to_string_formatter(samples, test_tmp_folder):
    name = seeds_to_string_formatter(samples, test_tmp_folder, name='test', seed=1)
    name = seeds_to_string_formatter(samples, test_tmp_folder, name='test', seed=1, ext='fits')

    if samples <= 1e3:
        assert name == f"{test_tmp_folder}/test_001.fits"
+5 −5
Original line number Diff line number Diff line
@@ -8,15 +8,15 @@

from os.path import join

def seeds_to_string_formatter(samples, output, name, seed):
def seeds_to_string_formatter(samples, output, name, seed, ext):
    if samples <= 1e3:
        name = join(output, f"{name}_{seed:03d}.fits")
        name = join(output, f"{name}_{seed:03d}.{ext}")
    elif samples <= 1e5:
        name = join(output, f"{name}_{seed:05d}.fits")
        name = join(output, f"{name}_{seed:05d}.{ext}")
    elif samples <= 1e8:
        name = join(output, f"{name}_{seed:08d}.fits")
        name = join(output, f"{name}_{seed:08d}.{ext}")
    else:
        name = join(output, f"{name}_{seed}.fits")
        name = join(output, f"{name}_{seed}.{ext}")
    return name

def get_instrument_fov(instrument):
+1 −1
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ def configure_simulator_no_visibility(simulator, configuration):
    if '$TEMPLATES$' in configuration['model']:
        configuration['model'] = join(dirname(abspath(__file__)).replace('utils', 'templates'), basename(configuration['model']))
    simulator.model = configuration['model']
    simulator.output = seeds_to_string_formatter(configuration['samples'], configuration['output'], configuration['name'], configuration['seed'])
    simulator.output = seeds_to_string_formatter(configuration['samples'], configuration['output'], configuration['name'], configuration['seed'], 'fits')
    simulator.caldb = configuration['prod']
    simulator.irf = configuration['irf']
    simulator.fov = get_instrument_fov(configuration['array'])