Commit c091ad23 authored by Ambra Di Piano's avatar Ambra Di Piano
Browse files

one make_sh with mode for operation

parent 356a9eca
Loading
Loading
Loading
Loading
+10 −7
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ def make_simulator_conf(jobname_conf, configuration, node_number):
    with open(jobname_conf, 'w+') as f:
        dump(configuration, f, default_flow_style=False)

def make_simulator_sh(jobname, slurmconf, jobname_conf, jobname_sh, jobname_log):
def make_sh(jobname, slurmconf, jobname_conf, jobname_sh, jobname_log, mode='simulator'):
    # write sbatch
    with open(jobname_sh, 'w+') as f:
        f.write("#!/bin/bash")
@@ -34,15 +34,18 @@ def make_simulator_sh(jobname, slurmconf, jobname_conf, jobname_sh, jobname_log)
        f.write(f"\n#SBATCH --partition={slurmconf['partition']}")
        f.write(f"\n")
        f.write(f"\nsource activate {slurmconf['environment']}")
        if mode == 'simulator':
            f.write(f"\npython {join(dirname(abspath(__file__)).replace('configure', 'simulator'), 'base_simulator.py')} -f {jobname_conf}\n")
        else:
            raise ValueError(f"Invalid 'mode' {mode}")

def make_simulator_sbatch(jobname, configuration, node_number):
    output = configuration['simulator']['output']
    jobname_sh = join(output, f"{jobname}.sh")
    jobname_log = join(output, f"{jobname}.slurm")
    jobname_conf = join(output, f"{jobname}.yml")
    make_simulator_conf(jobname_conf, configuration, node_number)
    make_simulator_sh(jobname, configuration['slurm'], jobname_conf, jobname_sh, jobname_log)
    jobname_sh = join(output, f"{jobname}_simulator.sh")
    jobname_log = join(output, f"{jobname}_simulator.slurm")
    jobname_conf = join(output, f"{jobname}_simulator.yml")
    make_simulator_conf(jobname_conf, configuration, node_number, mode='simulator')
    make_sh(jobname, configuration['slurm'], jobname_conf, jobname_sh, jobname_log)
    system(f"sbatch {jobname_sh}")
    
def make_mapper_sh():
+3 −3
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ import pytest
from shutil import rmtree
from os import listdir, makedirs
from os.path import isfile, join
from astrort.configure.slurmjobs import make_simulator_conf, make_simulator_sh, make_simulator_sbatch
from astrort.configure.slurmjobs import make_simulator_conf, make_sh, make_simulator_sbatch
from astrort.utils.wrap import load_yaml_conf

@pytest.mark.test_conf_file
@@ -34,7 +34,7 @@ def test_make_simulator_conf(test_conf_file):
    assert found_configurations == expected_configurations, f"Expected {expected_configurations} simulations, found {found_configurations}"

@pytest.mark.test_conf_file
def test_make_simulator_sh(test_conf_file):
def test_make_sh(test_conf_file):

    # clean output
    conf = load_yaml_conf(test_conf_file)
@@ -49,7 +49,7 @@ def test_make_simulator_sh(test_conf_file):
        jobname_sh = join(output, f"job_{jobname}.sh")
        jobname_log = join(output, f"job_{jobname}.log")
        jobname_conf = join(output, f"job_{jobname}.yml")
        make_simulator_sh(jobname, conf['slurm'], jobname_conf, jobname_sh, jobname_log)
        make_sh(jobname, conf['slurm'], jobname_conf, jobname_sh, jobname_log, mode='simulator')

    # check output
    expected_sh = conf['slurm']['nodes']
+36 −0
Original line number Diff line number Diff line
# *****************************************************************************
# Copyright (C) 2023 INAF
# This software is distributed under the terms of the BSD-3-Clause license
#
# Authors:
# Ambra Di Piano <ambra.dipiano@inaf.it>
# *****************************************************************************

import pytest
from shutil import rmtree
from os import listdir
from os.path import isfile, join
from astrort.simulator.base_simulator import base_simulator
from astrort.simulator.base_mapper import base_mapper
from astrort.utils.wrap import load_yaml_conf

@pytest.mark.skip('#TODO')
@pytest.mark.test_conf_file
@pytest.mark.parametrize('seeds', [None, list([1,2])])
def test_base_mapper(test_conf_file, seeds):

    # clean output
    conf = load_yaml_conf(test_conf_file)
    rmtree(conf['mapper']['output'])

    # run simulator
    base_simulator(test_conf_file)
    base_mapper(test_conf_file, seeds)

    # check output
    expected_maps = conf['simulator']['samples']
    found_maps = len([f for f in listdir(conf['mapper']['output']) if isfile(join(conf['mapper']['output'], f)) and '.fits' in f and conf['mapper']['name'] in f])
    assert found_maps == expected_maps, f"Expected {expected_maps} maps, found {found_maps}"