Unverified Commit fdc2bcfc authored by Akke Viitanen's avatar Akke Viitanen
Browse files

add unit tests

check the coverage with pytest --cov=src --cov-report=html tests/
parent ea492a2c
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -44,6 +44,7 @@ example_*
# build files
_readthedocs
_build
.coverage

# backup
*.bak*
+1 −1
Original line number Diff line number Diff line
@@ -68,7 +68,7 @@ write_to = "src/lsst_inaf_agile/_version.py"
[tool.pytest.ini_options]
testpaths = [
    "tests",
    "src",
    "src/lsst_inaf_agile",
    "docs",
]
addopts = "--doctest-modules --doctest-glob=*.rst"
+82 −102
Original line number Diff line number Diff line
@@ -8,8 +8,8 @@
import logging
import multiprocessing
import os
import sys
import time
from dataclasses import dataclass
from itertools import product

import astropy.units as u
@@ -18,7 +18,7 @@ import numpy as np
import pandas as pd
from mock_catalog_SED.qsogen_4_catalog import qsosed

from lsst_inaf_agile import lightcurve, lusso2010, sed, ueda2014, util, zou2024
from lsst_inaf_agile import lightcurve, lusso2010, sed, util, zou2024
from lsst_inaf_agile.catalog_galaxy import CatalogGalaxy
from lsst_inaf_agile.mbh import get_log_mbh_continuity_new2
from lsst_inaf_agile.merloni2014 import Merloni2014
@@ -29,16 +29,22 @@ logger = logging.getLogger(__name__)
# Eddington ratio in cgs
LOG_LAMBDA_EDD = np.log10(1.26e38)

# Maximum allowed E(B-V)
EBV_MAX = 9.0

# Optimization: store the pre-computed AGN fluxes
FLUXES = {}
# key is (ID, band, is_rest_frame)
FLUXES: dict[tuple[int, str, bool], float] = {}


@dataclass
class QsogenPosteriorDistribution:
    """
    Posterior distribution of the qsogen AGN SED parameters.

    The table is read from ``filename`` when an instance is created, so a
    custom ``filename`` passed to the constructor is honoured and defining the
    class itself does not touch the disk.

    Parameters
    ----------
    filename: str, optional
        Path of the space-separated posterior file.

    Attributes
    ----------
    posterior_distribution: pandas.DataFrame
        Table read from ``filename``.
    parameter_names: pandas.Index
        Column names of ``posterior_distribution``.

    Raises
    ------
    FileNotFoundError
        If ``filename`` does not exist (propagated from ``pandas.read_csv``).

    """

    filename: str = f"{ROOT}/data/posteriors/posterior_frozen.dat"

    def __post_init__(self):
        """Load the posterior table from ``self.filename``."""
        # NOTE: these are deliberately plain instance attributes and not
        # dataclass fields, so that __init__, __eq__ and __repr__ keep
        # depending on `filename` only.
        self.posterior_distribution = pd.read_csv(self.filename, sep=" ")
        self.parameter_names = self.posterior_distribution.columns


try:
    # Pre-load the posterior distribution for the AGN SED
    logger.info("Reading in the posterior distribution...")
    POSTERIOR_DISTRIBUTION = pd.read_csv(f"{ROOT}/data/posteriors/posterior_frozen.dat", sep=" ")
    PARAMETER_NAMES = POSTERIOR_DISTRIBUTION.columns
except FileNotFoundError:
    logger.warning("Posterior distribution not found")
QPODI = QsogenPosteriorDistribution()


class CatalogAGN:
@@ -61,6 +67,8 @@ class CatalogAGN:
        assumed Merloni2014 obscuration model
    filter_db: str, optional
        filename of the EGG filter database (db.dat)
    overwrite: bool, optional
        overwrite an existing AGN catalog

    Attributes
    ----------
@@ -92,6 +100,8 @@ class CatalogAGN:
        seed: int,
        merloni2014: Merloni2014,
        filter_db: str = "data/egg/share/filter-db/db.dat",
        qsogen_posterior_distribution: QsogenPosteriorDistribution = QPODI,
        overwrite: bool = False,
    ):
        """Initialize the AGN catalog."""
        self.dirname = dirname
@@ -101,6 +111,11 @@ class CatalogAGN:
        self.seed = seed
        self.merloni2014 = merloni2014
        self.filter_db = filter_db
        self.qsogen_posterior_distribution = qsogen_posterior_distribution

        # NOTE: mock-1000-4000 column is needed in the galaxy catalog
        if "magabs_mock-1000-4000" not in self.catalog_galaxy.catalog.dtype.names:
            raise ValueError("Galaxy catalog must have magabs_mock-1000-4000 flux column")

        # Set the seed
        if self.seed is not None:
@@ -108,28 +123,16 @@ class CatalogAGN:
        np.random.seed(self.seed)

        # Create the catalog
        self.catalog = self.get_catalog()
        self.catalog = self.get_catalog(overwrite)

    def __getitem__(self, key):
        """
        Return AGN catalog column.

        Parameters
        ----------
        key: str
            Name of column.

        Returns
        -------
        Value of the column.

        """
        """Return AGN catalog column corresponding to key."""
        if key not in self.catalog.dtype.names:
            return self.catalog_galaxy[key]
        return self.catalog[key]

    @staticmethod
    def get_columns(bands, rfbands, is_public=False):
    def get_columns(bands: list[str], rfbands: list[str]):
        """
        Return the names, types, descriptions, and units of each column.

@@ -139,8 +142,6 @@ class CatalogAGN:
            egg-style list of apparent magnitude bands
        rfbands: list
            egg-style list of absolute magnitude bands
        is_public: bool
            return a public column listing instead of an internal one

        Returns
        -------
@@ -177,30 +178,31 @@ class CatalogAGN:
            ]
        )

    def get_dtype(self):
    def get_dtype(self) -> np.dtype:
        """Return the numpy dtype corresponding to the column names and types."""
        i = []
        for n, t, _, _ in self.get_columns(self.catalog_galaxy.bands, self.catalog_galaxy.rfbands):
            i += [(n, t)]
        return np.dtype(i)

    def get_catalog(self):
    def get_catalog(self, overwrite):
        """Build the catalog column-by-column and write to FITS."""
        # Short-circuit for existing catalog
        filename = f"{self.dirname}/agn.fits"
        if os.path.exists(filename):
        if os.path.exists(filename) and not overwrite:
            logger.info(f"Returning an existing AGN catalog {filename}")
            return fitsio.read(filename)

        logger.info(f"Creating the AGN catalog {filename}")
        self.catalog = np.empty_like(self.catalog_galaxy["ID"], dtype=self.get_dtype())
        for col, _, _, _ in self.get_columns(self.catalog_galaxy.bands, self.catalog_galaxy.rfbands):
            if col in self.catalog_galaxy.get_dtype().names:
                self.catalog[col] = self.catalog_galaxy[col]
            else:
                continue
            self.catalog[col] = self.get_agn(col)

        # Write the catalog
        fitsio.write(filename, self.catalog, clobber=True)

        return self.catalog

    @staticmethod
@@ -210,6 +212,16 @@ class CatalogAGN:

        The random number seed is generated programmatically from the name of
        the column.

        Examples
        --------
        >>> CatalogAGN._get_seed("foo")
        33048
        >>> CatalogAGN._get_seed("bar")
        30282
        >>> CatalogAGN._get_seed("baz")
        31066

        """
        # NOTE: numpy random seed must be between 0 and 2 ** 32 - 1
        # NOTE: hash is not persistent between python runs. Use a custom hash
@@ -239,15 +251,11 @@ class CatalogAGN:

        return ret

    def get_log_lambda_SAR(self, add_ctk=True, idxs=None):
    def get_log_lambda_SAR(self, add_ctk=True) -> np.ndarray:
        """Return log10 of the specific black hole accretion rate in erg/s/Msun."""
        if self.type_plambda != "zou+2024":
        if self.type_plambda.lower() != "zou+2024":
            raise ValueError

        # Default: process the full catalog
        if idxs is None:
            idxs = np.arange(self["ID"].size)

        # Get the relevant galaxy properties
        m = self["M"]
        z = self["Z"]
@@ -258,7 +266,7 @@ class CatalogAGN:
        if not add_ctk:
            logger.info("Assigning log_lambda_SAR for CTN...")
            for t in ["star-forming", "quiescent"]:
                select = self._get_select(t=t)
                select = p == (1 if t == "quiescent" else 0)
                U = np.random.rand(select.sum())
                log_lambda_SAR[select] = zou2024.get_inv_log_Plambda(np.log10(U), m[select], z[select], t)
            return log_lambda_SAR
@@ -278,7 +286,7 @@ class CatalogAGN:
        t1 = time.time()
        logger.info(f"Assigned {len(log_lambda_SAR)} objects in {t1 - t0} seconds")

        return log_lambda_SAR
        return np.array(log_lambda_SAR)

    def get_log_LX_2_10(self):
        """Return log10 of the intrinsic 2-10 keV X-ray luminosity in erg/s."""
@@ -359,7 +367,7 @@ class CatalogAGN:
        alpha_2=11.6133635,
        n_2=1.42972,
        mu_type_2=0.3,
        ebv_ctk=9.0,
        ebv_ctk=EBV_MAX,
    ):
        """
        Return AGN reddening E(B-V) in ABmag.
@@ -400,54 +408,6 @@ class CatalogAGN:

        return ebv

    @staticmethod
    def get_plambda_ctk(loglambda, mstar, z, t):
        r"""
        Return the accretion rate distribution of the CTK AGN population.

        The CTN accretion rate distribution (zou2024.get_log_plambda) is
        scaled by the CTK-to-CTN ratio derived from the ueda2014.get_f
        fractions:

            p_CTK = f_CTK_AGN / (1 - f_CTK_AGN) * p_CTN

        where the CTK AGN fraction is defined as

            f_CTK_AGN \equiv N_CTK / (N_CTN + N_CTK)

        and is evaluated as the sum of the fractions for the two highest bins
        (24 and 25) divided by the sum over all six bins (20 to 25).

        Parameters
        ----------
        loglambda: array_like
            Grid of accretion rates; also the integration variable of the
            normalization check.
        mstar: float or array_like
            Added to ``loglambda`` before calling ueda2014.get_f (presumably
            log10 of the stellar mass -- TODO confirm).
        z: float or array_like
            Passed through to zou2024 and ueda2014 (presumably redshift).
        t: str
            Passed through to zou2024.get_log_plambda (presumably the galaxy
            type -- TODO confirm).

        Returns
        -------
        p_CTK evaluated on ``loglambda``.

        Raises
        ------
        AssertionError
            If the integral of p_CTN + p_CTK over ``loglambda`` exceeds 1.

        """
        p_ctn = 10 ** zou2024.get_log_plambda(loglambda, mstar, z, t)

        # One fraction per bin 20..25 (presumably log10 N_H bins -- TODO
        # confirm against ueda2014.get_f).
        log_lx = loglambda + mstar
        fractions = [ueda2014.get_f(log_lx, z, nh_bin) for nh_bin in range(20, 26)]
        total = np.sum(fractions, axis=0)

        # The two highest bins (24, 25) make up the CTK population.
        frac_ctk_agn = (fractions[4] + fractions[5]) / total
        p_ctk = p_ctn * frac_ctk_agn / (1 - frac_ctk_agn)

        # NOTE: safety assertion that the combined CTN + CTK distribution is
        # still a probability measure over the loglambda grid
        integral = np.trapezoid(p_ctn + p_ctk, loglambda)
        assert np.all(integral <= 1.00), (integral, mstar, z, t)
        return p_ctk

    @staticmethod
    def get_plambda_ctk_prime(loglambda, mstar, z, t, dloglambda):
        """
        Return the reduced CTK AGN accretion rate distribution.

        "Reduced" refers to a galaxy population from which the CTN AGN have
        already been removed, i.e. Ngal' = Ngal - Nctn, so that the total AGN
        number is still given by CTN + CTK.

        The resulting p(lambda) is given by

            p_ctk_prime = p_ctk / (1 - p_ctn * dloglambda)

        Parameters
        ----------
        loglambda, mstar, z, t:
            Forwarded unchanged to CatalogAGN.get_plambda_ctk and
            zou2024.get_log_plambda.
        dloglambda: float or array_like
            Multiplies p_ctn in the denominator (presumably the bin width of
            the loglambda grid -- TODO confirm).

        Returns
        -------
        p_ctk_prime evaluated on ``loglambda``.

        """
        args = (loglambda, mstar, z, t)
        p_ctk = CatalogAGN.get_plambda_ctk(*args)
        p_ctn = 10 ** zou2024.get_log_plambda(*args)
        return p_ctk / (1 - p_ctn * dloglambda)

    def get_is_agn_ctn(self):
        """Render a random sample of AGN "active" according to the duty cycle."""
        # NOTE: avoid non-AGN and CTK AGN when assigning this flag
@@ -489,7 +449,7 @@ class CatalogAGN:
        fits = fitsio.read(filename)
        return fits["LAMBDA"][0], fits["FLUX"][0]

    def _get_sed(self, i, ratio_max=0.90, dlog_wav=7.65e-4):
    def _get_sed(self, i, ratio_max: float = 0.90, dlog_wav: float = 7.65e-4):
        """
        Generate an AGN SED.

@@ -497,9 +457,9 @@ class CatalogAGN:
        ----------
        i: int
            ID of AGN for which SED is returned.
        ratio_max: float
        ratio_max: float, optional
            Maximum allowed ratio flux_agn / flux_total for type2 AGN.
        dlog_wav: float
        dlog_wav: float, optional
            SED wavelength resolution in dex.

        """
@@ -526,12 +486,18 @@ class CatalogAGN:
                add_NL=self["is_optical_type2"][i],
                NL_normalization="lamastra",
                Av_lines=self.catalog_galaxy["AVLINES_BULGE"][i] + self.catalog_galaxy["AVLINES_DISK"][i],
                **dict(zip(PARAMETER_NAMES, *POSTERIOR_DISTRIBUTION.sample().values, strict=False)),
                **dict(
                    zip(
                        self.qsogen_posterior_distribution.parameter_names,
                        *self.qsogen_posterior_distribution.posterior_distribution.sample().values,
                        strict=False,
                    )
                ),
            )

            # Check for type2 AGN flux NOT exceeding the host galaxy flux by
            # some limit
            if self["is_optical_type2"][i] and self["E_BV"][i] <= 9.0:
            if self["is_optical_type2"][i] and self["E_BV"][i] <= EBV_MAX:
                lam, flux_agn = util.luminosity_to_flux(
                    agn_sed.wavlen.value,
                    agn_sed.lum.value,
@@ -546,13 +512,12 @@ class CatalogAGN:
                if ratio > ratio_max:
                    logger.info(
                        "AGN 1000-4000 angstrom >90% of the total... Incrementing E(B-V) by 0.10...",
                        f"{self['ID']:6d}",
                        f"{self['ID'][i]:6d}",
                        f"{np.log10(flux_agn):.2f}",
                        f"{np.log10(flux_gal):.2f}",
                        f"{ratio:.2f}",
                        f"{self['E_BV'][i]:.2f}",
                        f"{(self['E_BV'][i] + 0.10):.2f}",
                        file=sys.stderr,
                    )
                    self["E_BV"][i] += 0.10
                    continue
@@ -734,7 +699,7 @@ class CatalogAGN:

        return get_occupation_fraction(self["M"])

    def get_lightcurve(self, i, band, mjd=None, *args, **kwargs):
    def get_lightcurve(self, i, band, *args, **kwargs):
        """
        Return an AGN lightcurve.

@@ -746,18 +711,21 @@ class CatalogAGN:
            ID of AGN for which SED is returned.
        band: str
            EGG-style passband name.
        mjd: float or array_like or None
            MJD of observation.

        """
        if mjd is None:
            mjd = util.get_mjd_vec()

        # Short-circuit for non-AGN
        is_agn = self["is_agn"][i]
        if not is_agn:
            return None

        filename = f"{self.dirname}/lightcurves/agn/{i}/{band}.npy"
        if os.path.exists(filename):
            logger.info(f"Reading AGN lightcurve {filename}")
            return np.load(filename)

        mjd = util.get_mjd_vec()

        kwargs.update(
            {
                "mjd0": 0.0,
@@ -804,6 +772,18 @@ class CatalogAGN:
        -------
        The value of the lightcurve at given MJD in microjanskies.

        Raises
        ------
        ValueError
            If mjd - mjd0 is not within [0, 10 * 365.25), i.e. ten years.

        """

        is_in_range = np.atleast_1d((mjd - mjd0 >= 0) & (mjd - mjd0 < 10 * 365.25))
        if not np.all(is_in_range):
            raise ValueError("mjd + mjd0 must be within [0, 10 year]")

        lc = self.get_lightcurve(i, band)
        if lc is None:
            return None
        return np.interp(mjd - mjd0, util.get_mjd_vec(), lc)
+10 −11
Original line number Diff line number Diff line
@@ -11,6 +11,8 @@ import os
import fitsio
import numpy as np

from lsst_inaf_agile import util

logger = logging.getLogger(__name__)


@@ -25,27 +27,24 @@ class CatalogGalaxy:
        self.rfbands = [s.strip() for s in self.egg["RFBANDS"][0]]
        self.catalog = self.get_catalog()

    def get_catalog(self):
    def get_catalog(self, overwrite=False):
        """Generate the galaxy catalog and writes it to disk."""
        # Create the Galaxy catalog
        self.catalog = np.empty_like(self.egg["RA"][0], dtype=self.get_dtype())
        filename = f"{self.dirname}/galaxy.fits"

        if not os.path.exists(filename):
        if not os.path.exists(filename) or overwrite:
            # column, type, description, unit
            for c, t, _, _ in self.get_columns(self.bands, self.rfbands):
                self.catalog[c] = self.get_galaxy(c).astype(t)

            if not os.path.exists(os.path.dirname(filename)):
                os.makedirs(os.path.dirname(filename))

            util.create_directory(filename)
            fitsio.write(filename, self.catalog, clobber=True)

        self.catalog = fitsio.read(filename)
        return self.catalog

    @staticmethod
    def get_columns(bands: list[str], rfbands: list[str], is_public: bool = False) -> list[tuple]:
    def get_columns(bands: list[str], rfbands: list[str]) -> list[tuple]:
        """
        Return the galaxy catalog columns, types, and descriptions.

@@ -54,7 +53,7 @@ class CatalogGalaxy:
        """
        ret = (
            [
                ("ID", np.int64, "unique ID", ""),
                ("ID", np.int64, "Unique ID", ""),
                ("RA", np.float64, "Right ascenscion", "deg"),
                ("DEC", np.float64, "Declination", "deg"),
                ("Z", np.float64, "Cosmological redshift", ""),
@@ -85,7 +84,7 @@ class CatalogGalaxy:

        return ret

    def get_dtype(self):
    def get_dtype(self) -> np.dtype:
        """Return the numpy dtype corresponding to the columns."""
        dtype = []
        for n, t, _, _ in self.get_columns(self.bands, self.rfbands):
@@ -100,7 +99,7 @@ class CatalogGalaxy:
        key_without_suffix = key.replace("_disk", "").replace("_bulge", "")

        try:
            if key_without_suffix in _bands:
            if key_without_suffix in self.bands:
                # 20250514
                # idx = _bands.index(key_without_suffix)
                idx = self.bands.index(key_without_suffix)
@@ -115,7 +114,7 @@ class CatalogGalaxy:
                idx = self.rfbands.index(key.replace("magabs_", ""))
                return self.egg["RFMAG"][0, :, idx]

        except IndexError:
        except ValueError:
            # Fluxes not generated... return zerovalues
            return np.full_like(self.egg["Z"][0], np.inf if "magabs_" in key else 0.0)

+5 −4
Original line number Diff line number Diff line
@@ -100,10 +100,11 @@ class CatalogStar:
            table = "lsst_sim.simdr2_binary"
            selection = "c3_" + selection

        ra_min = self.catalog_galaxy["RA"].min()
        ra_max = self.catalog_galaxy["RA"].max()
        dec_min = self.catalog_galaxy["DEC"].min()
        dec_max = self.catalog_galaxy["DEC"].max()
        ra, dec = (self.catalog_galaxy[k] for k in ("RA", "DEC"))
        ra_min = ra.min()
        ra_max = ra.max()
        dec_min = dec.min()
        dec_max = dec.max()

        query = f"""
                SELECT * FROM {table}
Loading