Unverified Commit b5a591b1 authored by Akke Viitanen's avatar Akke Viitanen
Browse files

increase egg test coverage

parent 7e4ac2f1
Loading
Loading
Loading
Loading
+57 −0
Original line number Diff line number Diff line
[egg]

# egg configuration. Each key value pair corresponds to input arguments to egg.
#
# Run
#   $ egg-gencat help
# in order to find out all the available input arguments.

# print verbose output
verbose: 0

# EGG filename
out: data/tests/test_egg_config/egg.fits

# area in deg2
area: 0.001

# right ascension of center
ra0: +150.11916667

# declination of center
dec0: +2.20583333

# apparent flux bands
bands: [lsst-u,johnson-B,lsst-g,lsst-r,lsst-i,lsst-z,lsst-y,euclid-VIS,euclid-nisp-Y,euclid-nisp-J,euclid-nisp-H,spitzer-irac1,spitzer-irac2,spitzer-irac3,spitzer-irac4,wise-w1,wise-w2,wise-w3,wise-w4,mock-1000-4000,mock-1450,mock-4400,mock-5007,mock-15um,vista-Y,vista-J,vista-H,vista-Ks]

# absolute flux bands
rfbands: [lsst-u,johnson-B,lsst-g,lsst-r,lsst-i,lsst-z,lsst-y,euclid-VIS,euclid-nisp-Y,euclid-nisp-J,euclid-nisp-H,spitzer-irac1,spitzer-irac2,spitzer-irac3,spitzer-irac4,wise-w1,wise-w2,wise-w3,wise-w4,mock-1000-4000,mock-1450,mock-4400,mock-5007,mock-15um,vista-Y,vista-J,vista-H,vista-Ks]

# input stellar mass function
mass_func: data/egg/share/mass_func_cosmos2020_agn_zou24_loglambda32.fits

# minimum redshift
zmin: 0.21

# maximum redshift
zmax: 5.49

# redshift slice resolution
dz: 0.05

# minimum stellar mass in logMsun (Salpeter IMF)
mmin: 8.74

# NOTE: 0.30 deg should cover most cases i.e. r~0-200 Mpc/h at z~0.20-5.50
clust_r0: 0.30
clust_fclust_him: 0.40
clust_fclust_lom: 0.20

# save SED?
save_sed: 1

# random number seed
seed: 20250911

# filter database
filter_db: data/egg/share/filter-db/db.dat
+100 −58
Original line number Diff line number Diff line
@@ -20,9 +20,15 @@ handler.setLevel(logging.DEBUG)
handler.setFormatter(formatter)
logger.addHandler(handler)


# See if EGG has been installed properly
def find_egg():
    if not shutil.which("egg-gencat"):
        raise FileNotFoundError("Could not find EGG binary. Check the EGG installation.")
    return shutil.which("egg-gencat")


find_egg()


class Egg:
@@ -51,8 +57,9 @@ class Egg:
        """
        self.egg_kwargs = egg_kwargs

    def get_argument_line(self):
    def get_argument_line(self, exclude=None):
        """Get the EGG argument line for calling it from the terminal."""
        if exclude is None:
            exclude = []
        egg_kwargs = {k: v for k, v in self.egg_kwargs.items()}
        return sorted([f"{k}={str(v)}" for k, v in egg_kwargs.items() if k not in exclude])
@@ -83,17 +90,41 @@ class Egg:
        util.create_directory(filename)
        return os.system(cmd)

    def get_sed(self, i):
        """Return an EGG SED."""
        dirname = os.path.dirname(self.egg_kwargs["save_sed"])
        fname = f"{dirname}/egg-seds-{i}.fits"
        try:
            if not os.path.exists(fname):
                os.system(f"egg-getsed seds={dirname}/egg-seds.dat id={i}")
    def get_sed(self, i, component=None, overwrite=False):
        """
        Return an EGG SED.

        Parameters
        ----------
        i: int
            ID of the galaxy.
        component: str or None
            One of "bulge", "disk" or None (= bulge + disk).
        overwrite: bool
            Run egg-getsed regardless of existing filename.

        Raises
        ------
        ValueError
            If component is invalid.
        """

        if component not in [None, "bulge", "disk"]:
            raise ValueError(f"Invalid {component=}")

        dirname = os.path.dirname(self.egg_kwargs["out"])
        fname = f"{dirname}/egg-seds"
        if component in ["bulge", "disk"]:
            fname += f"-{component}"
        fname += f"-{i}.fits"

        if overwrite or not os.path.exists(fname):
            cmd = f"egg-getsed seds={dirname}/egg-seds.dat id={i}"
            if component is not None:
                cmd += f" component={component}"
            os.system(cmd)
        assert os.path.exists(fname)
        return util.read_fits(fname)
        except FileNotFoundError:
            logger.warning("Could not find galaxy sed")
            return None

    @staticmethod
    def read(filename):
@@ -113,7 +144,6 @@ class Egg:
        EGG writes the FITS files.
        """
        import fitsio
        import numpy as np

        columns = [
            "ID",
@@ -145,45 +175,59 @@ class Egg:
            "RFMAG_BULGE",
            "RFMAG_DISK",
        ]

        try:
        ret = {k: fitsio.read(filename, columns=k) for k in columns}
        return ret
        # except NotImplementedError("FIXME: catch the correct error and remove 'raise'"):
        except:
            raise
            pass

        logger.info("Reading (and writing) the EGG flux files one-by-one")

        # Remove the fluxes
        columns = columns[:-6]
        ret = {k: fitsio.read(filename, columns=k) for k in columns}

        Ngal = ret["ID"][0].size
        Nbands = ret["BANDS"][0].size
        for k1 in "FLUX", "RFMAG":
            for k2 in "", "_BULGE", "_DISK":
                k = k1 + k2
                fin = filename
                fout = filename.replace(".fits", f"_{k}.dat")
                if not os.path.exists(fout):
                    logger.info(f"  Writing {k} from {fin} to {fout}")
                    os.system(f"/home/viitanen/.local/bin/write_fits_column {fin} {fout} {k} {Ngal * Nbands}")
                ret[k] = np.fromfile(fout, dtype=np.float32).reshape((1, Ngal, Nbands))

        return ret
        # NOTE: fossil code below
        # try:
        #    ret = {k: fitsio.read(filename, columns=k) for k in columns}
        #    return ret
        ## except NotImplementedError("FIXME: catch the correct error and remove 'raise'"):
        # except:
        #    raise

        # logger.info("Reading (and writing) the EGG flux files one-by-one")

        ## Remove the fluxes
        # columns = columns[:-6]
        # ret = {k: fitsio.read(filename, columns=k) for k in columns}

        # Ngal = ret["ID"][0].size
        # Nbands = ret["BANDS"][0].size
        # for k1 in "FLUX", "RFMAG":
        #    for k2 in "", "_BULGE", "_DISK":
        #        k = k1 + k2
        #        fin = filename
        #        fout = filename.replace(".fits", f"_{k}.dat")
        #        if not os.path.exists(fout):
        #            logger.info(f"  Writing {k} from {fin} to {fout}")
        #            os.system(
        #               f"/home/viitanen/.local/bin/write_fits_column {fin} {fout} {k} {Ngal * Nbands}"
        #            )
        #        ret[k] = np.fromfile(fout, dtype=np.float32).reshape((1, Ngal, Nbands))

        # return ret

    @staticmethod
    def get_smf(z, key, filename):
        """Read an EGG-like stellar mass function from a file."""
        """
        Read an EGG-like stellar mass function from a file.

        Parameters
        ----------
        z: float
            Redshift of the stellar mass function.
        key: str
            One of "ACTIVE" or "PASSIVE".
        filename: str
            Filename containing the stellar mass function in EGG format.
        """
        import numpy as np

        smf = util.read_fits(filename)
        i = None
        for _i, (zlo, zhi) in enumerate(smf["ZB"][0].T):
            if zlo <= z < zhi:
                print(zlo, z, zhi)
                i = _i
                break
        x = np.mean(smf["MB"][0], axis=0)
@@ -191,19 +235,26 @@ class Egg:
        return x, y

    @staticmethod
    def run_config(config):
    def run_config(filename):
        """
        Run EGG for the given configuration file.

        Returns the EGG catalog using Egg.read.
        Parameters
        ----------
        filename: str
            Path to configuration file.

        Returns
        -------
        The EGG catalog using Egg.read.

        """
        logger.info(f"Reading {config}")
        logger.info(f"Reading {filename}")
        config = ConfigParser()
        config.read("etc/config_egg.ini")
        config.read(filename)
        assert "egg" in config
        egg = Egg(config["egg"])
        egg.run()

        return Egg.read(config["egg"]["out"])

    @staticmethod
@@ -221,12 +272,3 @@ class Egg:
            save_sed=1,
            verbose=1,
        )


def main():
    """Run EGG using an example config file."""
    Egg.run_config("etc/config_egg.ini")


if __name__ == "__main__":
    main()
+100 −6
Original line number Diff line number Diff line
@@ -7,9 +7,11 @@

import os
from unittest import TestCase
from unittest.mock import patch

import numpy as np
from lsst_inaf_agile.egg import Egg
from lsst_inaf_agile import util
from lsst_inaf_agile.egg import Egg, find_egg


class TestEgg(TestCase):
@@ -20,17 +22,109 @@ class TestEgg(TestCase):
        self.egg.run()
        self.egg_catalog = self.egg.read(self.filename)

    def test_run(self):
        self.assertTrue(os.path.exists(self.filename))
    def test__init__(self):
        self.assertIn("out", self.egg.egg_kwargs)

    @patch("shutil.which")
    def test_egg_binary(self, mock_which):
        mock_which.return_value = "/some/path/to/egg-gencat"
        self.assertTrue("egg-gencat" in find_egg())

        mock_which.return_value = None
        with self.assertRaises(FileNotFoundError):
            find_egg()

    def test_get_filename(self):
        filename = self.egg.get_filename()
        self.assertIn("egg.fits", filename)

        self.egg.egg_kwargs.pop("out")
        filename = self.egg.get_filename()
        self.assertIn("area", filename)
        self.assertIn("save_sed", filename)
        self.assertIn("bands", filename)
        self.assertIn("rfbands", filename)

    def test_get_argument_line(self):
        argument_line = self.egg.get_argument_line()
        self.assertIn("area=0.001", argument_line)
        self.assertIn("mmin=9.5", argument_line)

    def test_has_column(self):
        for column in ["ID", "Z", "M", "SFR", "PASSIVE"]:
            self.assertIn(column, self.egg_catalog)
    def test_get_area(self):
        self.assertEqual(self.egg.get_area(), 0.001)

    def test_run(self):
        # run EGG
        os.remove(self.filename)
        self.egg_kwargs["verbose"] = 0
        cmd = self.egg.run()
        self.assertEqual(cmd, 0)
        self.assertTrue(os.path.exists(self.filename))

        # try to read the fits file
        fits = util.read_fits(self.filename)
        self.assertIn("Z", fits.dtype.names)
        self.assertIn("M", fits.dtype.names)
        self.assertIn("SFR", fits.dtype.names)
        self.assertIn("PASSIVE", fits.dtype.names)

    def test_get_sed(self):
        seds = []
        seds += [self.egg.get_sed(0)]
        seds += [self.egg.get_sed(0, overwrite=True)]
        self.assertTrue(np.all(seds[-2]["FLUX"] == seds[-1]["FLUX"]))
        seds.pop()

        seds += [self.egg.get_sed(1)]
        seds += [self.egg.get_sed(1, "disk")]
        seds += [self.egg.get_sed(1, "disk", overwrite=True)]
        self.assertTrue(np.all(seds[-2]["FLUX"] == seds[-1]["FLUX"]))
        seds.pop()

        seds += [self.egg.get_sed(1, "bulge")]
        with self.assertRaises(ValueError):
            self.egg.get_sed(1, "agn")

        for sed in seds:
            self.assertIn("LAMBDA", sed.dtype.names)
            self.assertIn("FLUX", sed.dtype.names)
            self.assertTrue(np.all(sed["LAMBDA"] >= 0.0))
            self.assertTrue(np.all(sed["FLUX"] >= 0.0))

        lam = np.linspace(1, 2)
        f1 = np.interp(lam, seds[1]["LAMBDA"][0], seds[1]["FLUX"][0])
        f2 = np.interp(lam, seds[2]["LAMBDA"][0], seds[2]["FLUX"][0])
        f3 = np.interp(lam, seds[3]["LAMBDA"][0], seds[3]["FLUX"][0])
        self.assertTrue(np.allclose(f1, f2 + f3))

    def test_read(self):
        ret = self.egg.read(self.egg.get_filename())
        self.assertIn("Z", ret)
        self.assertIn("M", ret)
        self.assertIn("SFR", ret)

    def test_get_smf(self):
        x1, y1 = self.egg.get_smf(0.0, "ACTIVE", "data/egg/share/mass_func_candels.fits")
        x2, y2 = self.egg.get_smf(0.0, "PASSIVE", "data/egg/share/mass_func_candels.fits")
        x3, y3 = self.egg.get_smf(1.0, "ACTIVE", "data/egg/share/mass_func_candels.fits")
        self.assertLessEqual(np.mean(y2), np.mean(y1))
        self.assertLessEqual(np.mean(y3), np.mean(y1))

    def test_run_config(self):
        filename = "data/tests/test_egg_config/egg.fits"
        if os.path.exists(filename):
            os.remove(filename)
        util.create_directory(filename)
        self.egg.run_config("etc/config_test_egg.ini")
        self.assertTrue(os.path.exists(filename))

    def test_get_example_egg_kwargs(self):
        filename = "data/tests/test_egg_kwargs/egg.fits"
        kwargs = self.egg.get_example_egg_kwargs(filename)
        kwargs["verbose"] = 0
        self.egg.egg_kwargs = kwargs
        self.egg.run()
        self.assertTrue(os.path.exists(filename))

    def test_column_is_positive(self):
        self.assertTrue(np.all(self.egg_catalog["Z"] > 0))
+11 −9
Original line number Diff line number Diff line
@@ -75,15 +75,16 @@ class TestImageSimulator(TestCase):
            self.assertTrue(has_band)

        # Test some instcat files
        finst2, *_ = self.image_simulator.write_instance_catalog(1)
        finst3, *_ = self.image_simulator.write_instance_catalog(1, exptime="0 0 0 0 0 0")
        finst4, *_ = self.image_simulator.write_instance_catalog(1, ra_dec=(0, 1))
        def get_lines(*args, **kwargs):
            finst, *_ = self.image_simulator.write_instance_catalog(*args, **kwargs)
            with open(finst, "r", encoding="utf-8") as f:
                lines = "".join(f)
            return lines

        lines = []
        for ff in finst, finst2, finst3, finst4:
            with open(ff, "r", encoding="utf-8") as f:
                lines.append("".join(f))
        lines1, lines2, lines3, lines4 = lines
        lines1 = get_lines(0)
        lines2 = get_lines(1)
        lines3 = get_lines(1, exptime="0 0 0 0 0 0")
        lines4 = get_lines(1, ra_dec=(0, 1))

        self.assertNotEqual(lines1, lines2)
        # NOTE: exptime enters the yaml file
@@ -101,7 +102,8 @@ class TestImageSimulator(TestCase):
        c2 = [self.image_simulator.get_catalog(f"lsst-{b}", self.image_simulator.mjd0 + 100) for b in bands]
        for i in range(len(bands)):
            # test that bands differ
            self.assertTrue(np.any(c1[i] != c2[(i + 1) % len(bands)]))
            if i < len(bands) - 1:
                self.assertTrue(np.any(c1[i] != c2[i + 1]))
            # test that mjds differ
            self.assertTrue(np.any(c1[i] != c2[i]))