Unverified Commit 72545af3 authored by Akke Viitanen's avatar Akke Viitanen
Browse files

fix catalog combined

parent 6d1bf63b
Loading
Loading
Loading
Loading
+25 −11
Original line number Diff line number Diff line
@@ -58,7 +58,7 @@ class CatalogCombined:
    def __init__(
        self,
        dirname: str,
        catalog_galaxy=None,
        catalog_galaxy,
        catalog_agn=None,
        catalog_star=None,
        catalog_binary=None,
@@ -72,13 +72,15 @@ class CatalogCombined:
        self.catalog_binary = catalog_binary

        # NOTE: short-circuit for an existing truth catalog
        if os.path.exists(f := self.get_filename()):
            logger.info(f"Found catalog FITS file {f}")
            self.catalog_combined = fitsio.read(f)
        if os.path.exists(filename := self.get_filename()):
            logger.info(f"Found catalog FITS file {filename}")
            self.catalog_combined = util.read_fits(filename)
            return

        # NOTE: short-circuit for an existing database truth catalog
        if filenames := glob.glob(f"{self.dirname}/db/**/master.db", recursive=True):
        if sql_query is not None and (
            filenames := glob.glob(f"{self.dirname}/db/**/master.db", recursive=True)
        ):
            import sqlite3

            import pandas as pd
@@ -96,6 +98,14 @@ class CatalogCombined:
        self.catalog_combined = self.postprocess()
        self.write()

    def get_catalogs(self):
        """Return the non-None catalogs."""
        ret = []
        for c in self.catalog_galaxy, self.catalog_agn, self.catalog_star, self.catalog_binary:
            if c is not None:
                ret.append(c)
        return ret

    def get_filename(self):
        """Return the FITS filename of the combined catalog."""
        return f"{self.dirname}/catalog.fits"
@@ -110,7 +120,7 @@ class CatalogCombined:

        """
        dtype = {}
        for catalog in (self.catalog_galaxy, self.catalog_agn, self.catalog_star, self.catalog_binary):
        for catalog in self.get_catalogs():
            for name in catalog.catalog.dtype.names:
                dtype[name] = catalog.catalog.dtype[name]

@@ -155,7 +165,7 @@ class CatalogCombined:
        """
        n_galaxy, n_star, n_binary = self.get_number_galaxy_star_binary()
        n_total = n_galaxy + n_star + n_binary
        catalog_combined = np.full(n_total, np.nan, dtype=self.get_dtype())
        catalog_combined = np.zeros(n_total, dtype=self.get_dtype())

        # Set the catalog ID
        catalog_combined["ID"] = np.arange(n_total)
@@ -254,16 +264,20 @@ class CatalogCombined:

        from astropy.table import Table

        table = Table(self.catalog)
        logger.info("Creating the directory ...")
        util.create_directory(filename_database)

        logger.info("Creating the pandas dataframe ...")
        table = Table(self.catalog_combined)
        df = table.to_pandas()
        df.reset_index()

        logger.info("Ingesting ...")
        with sqlite3.connect(filename_database) as con:
            df.to_sql(name_table, con, index=False, if_exists=if_exists, chunksize=1024, method="multi")

    def get_is_galaxy(self):
        """Return boolean vector selecting galaxies."""
        return np.isfinite(self.catalog_combined["Z"])
        return self.catalog_combined["Z"] > 0

    def get_is_star(self):
        """Return boolean vector selecting stars."""
@@ -374,7 +388,7 @@ class CatalogCombined:
    def get_area(self):
        """Return the EGG catalog area in deg2."""
        try:
            cmd = fitsio.read(f"{self.dirname}/egg.fits", columns="CMD")[0]
            cmd = util.read_fits(f"{self.dirname}/egg.fits", columns="CMD")[0]
            area_string = re.findall("area=([0-9.]+)", cmd)[0]
            return float(area_string)
        except OSError:
+17 −0
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ import logging
import os

import astropy.units as u
import fitsio
import numpy as np
from astropy import constants
from astropy.coordinates import SkyCoord
@@ -380,3 +381,19 @@ def get_stellar_mass_completeness_cosmos2020(type: str, redshift: float) -> floa
    }
    f1, f2 = factors[type]
    return f1 * (1 + redshift) + f2 * (1 + redshift) ** 2


def read_fits(filename, *args, **kwargs):
    """
    Read a FITS filename with supressed error messages
    """
    import warnings

    from astropy.units import UnitsWarning

    logger.info(f"Reading {filename}")
    ret = None
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UnitsWarning)
        ret = fitsio.read(filename, *args, **kwargs)
    return ret
+206 −36
Original line number Diff line number Diff line
@@ -5,8 +5,10 @@

"""Test the combined catalog of AGNs, galaxies, and stars."""

import os
from unittest import TestCase

import numpy as np
from lsst_inaf_agile.catalog_agn import CatalogAGN
from lsst_inaf_agile.catalog_combined import CatalogCombined
from lsst_inaf_agile.catalog_galaxy import CatalogGalaxy
@@ -15,9 +17,7 @@ from lsst_inaf_agile.egg import Egg
from lsst_inaf_agile.merloni2014 import Merloni2014


def create_combined_catalog():
    dirname = "data/tests/test_catalog_combined"

def create_combined_catalog(dirname="data/tests/test_catalog_combined"):
    # Galaxy
    filename = f"{dirname}/egg.fits"
    egg_kwargs = Egg.get_example_egg_kwargs(filename)
@@ -47,87 +47,257 @@ def create_combined_catalog():


class TestCatalogCombined(TestCase):
    def setUp(self):
        self.catalog_combined = create_combined_catalog()

    def test_init(self):
        # non-existing init
        # existing truth catalog
        # existing database
        ...

        # cleanup
        dirname = "data/tests/test_catalog_combined"
        os.system(f"rm -rfv {dirname}")

        # create the catalog
        _ = create_combined_catalog(dirname)

        # test existing FITS file
        c2 = create_combined_catalog(dirname)

        # test existing master.db
        os.remove(c2.get_filename())
        os.makedirs(f"{dirname}/db", exist_ok=True)
        os.system(f"touch {dirname}/db/master.db")

        # test non-existing table
        import pandas

        with self.assertRaises(pandas.errors.DatabaseError):
            _ = CatalogCombined(
                dirname, catalog_galaxy=c2.catalog_galaxy, sql_query="SELECT * FROM NonExisting"
            )

        # test existing table after ingesting
        self.catalog_combined.ingest(f"{dirname}/db/master.db", "Truth")
        _ = CatalogCombined(dirname, catalog_galaxy=c2.catalog_galaxy, sql_query="SELECT * FROM Truth")

        # cleanup
        os.system(f"rm -rfv {dirname}")

    def test_get_filename(self):
        # make sure makes sense
        ...
        filename = self.catalog_combined.get_filename()
        self.assertTrue(filename.endswith(".fits"))
        self.assertTrue(os.path.exists(filename))

    def test_get_dtype(self):
        # check for some obvious columns
        ...
        dtype = self.catalog_combined.get_dtype()
        keys = [k for k, v in dtype]
        self.assertIn("Z", keys)
        self.assertIn("M", keys)
        self.assertIn("SFR", keys)
        self.assertIn("log_lambda_SAR", keys)
        self.assertIn("log_LX_2_10", keys)
        self.assertTrue(any("_total" in k for k in keys))

    def test_get_number_galaxy_star_binary(self):
        # not much to check here honestly
        ...
        n1, n2, n3 = self.catalog_combined.get_number_galaxy_star_binary()
        select1 = self.catalog_combined["Z"] > 0
        select2 = ~select1
        self.assertGreaterEqual(n1, select1.sum())
        self.assertGreaterEqual(n2 + n3, select2.sum())

    def test_get_catalog_combined(self):
        c = self.catalog_combined.get_catalog_combined()
        # check id
        self.assertTrue(np.all(c["ID"] >= 0))
        self.assertTrue(0 in c["ID"])
        self.assertTrue(c["ID"].size == np.unique(c["ID"]).size)
        # check galaxy columns
        self.assertTrue(np.all(c["Z"] >= 0))
        self.assertTrue(np.any(c["Z"] > 0))
        self.assertTrue(np.all(c["M"] >= 0))
        self.assertTrue(np.any(c["M"] > 0))
        self.assertTrue(np.all(c["SFR"] >= 0))
        # check agn columns
        self.assertTrue(np.all(c["log_lambda_SAR"] >= -100))
        self.assertTrue(np.any(c["is_agn"] == 1))
        # check stellar flags are false
        ...
        is_star = ~(c["Z"] > 0)
        self.assertTrue(np.all(c["is_agn"][is_star] == 0))
        self.assertTrue(np.all(c["is_optical_type2"][is_star] == 0))
        # NOTE: at this stage total fluxes are not calculated. Assert they exist
        self.assertTrue(np.all(c["lsst-u_total"] == 0))
        self.assertTrue(np.all(c["lsst-g_total"] == 0))
        self.assertTrue(np.all(c["lsst-r_total"] == 0))
        self.assertTrue(np.all(c["lsst-i_total"] == 0))
        self.assertTrue(np.all(c["lsst-z_total"] == 0))
        self.assertTrue(np.all(c["lsst-y_total"] == 0))

    def test_postprocess(self):
        self.catalog_combined.postprocess()
        c = self.catalog_combined.catalog_combined

        # check has _total fluxes
        self.assertTrue(np.any(c["lsst-u_total"] > 0))
        self.assertTrue(np.any(c["lsst-g_total"] > 0))
        self.assertTrue(np.any(c["lsst-r_total"] > 0))
        self.assertTrue(np.any(c["lsst-i_total"] > 0))
        self.assertTrue(np.any(c["lsst-z_total"] > 0))
        self.assertTrue(np.any(c["lsst-y_total"] > 0))

        # check has _total magabs
        # range check fluxed and magabs
        ...
        self.assertTrue(np.any(c["magabs_lsst-u_total"] < 0.0))
        self.assertTrue(np.any(c["magabs_lsst-g_total"] < 0.0))
        self.assertTrue(np.any(c["magabs_lsst-r_total"] < 0.0))
        self.assertTrue(np.any(c["magabs_lsst-i_total"] < 0.0))
        self.assertTrue(np.any(c["magabs_lsst-z_total"] < 0.0))
        self.assertTrue(np.any(c["magabs_lsst-y_total"] < 0.0))

        # range check fluxes and magabs
        self.assertTrue(np.all(c["lsst-u_total"] < 10000.0))
        self.assertTrue(np.all(c["lsst-g_total"] < 10000.0))
        self.assertTrue(np.all(c["lsst-r_total"] < 10000.0))
        self.assertTrue(np.all(c["lsst-i_total"] < 10000.0))
        self.assertTrue(np.all(c["lsst-z_total"] < 10000.0))
        self.assertTrue(np.all(c["lsst-y_total"] < 10000.0))

    def test_get_flux_total(self):
        # bands
        # observed flux
        # restframe absolute magnitude
        ...
        for b in "ugrizy":
            f1 = self.catalog_combined[f"lsst-{b}_bulge"]
            f2 = self.catalog_combined[f"lsst-{b}_disk"]
            f3 = self.catalog_combined[f"lsst-{b}_point"]
            ftot = self.catalog_combined.get_flux_total(f"lsst-{b}")
            self.assertTrue(np.allclose(f1 + f2 + f3, ftot))

        # rest-frame magnitudes
        m1 = self.catalog_combined[f"magabs_lsst-{b}"]
        m2 = self.catalog_combined[f"magabs_lsst-{b}_point"]
        mtot = self.catalog_combined.get_flux_total(f"lsst-{b}", rest_frame=True)
        self.assertTrue(np.allclose(-2.5 * np.log10(10 ** (-0.4 * m1) + 10 ** (-0.4 * m2)), mtot))

    def test_write(self):
        # check exists
        ...
        filename = self.catalog_combined.get_filename()
        if os.path.exists(filename):
            os.remove(filename)
        self.catalog_combined.write()
        self.assertTrue(os.path.exists(filename))

        # read the file
        from lsst_inaf_agile import util

        fits = util.read_fits(filename)
        self.assertGreater(len(fits), 0)

    def test_ingest(self):
        filename = "data/tests/test_catalog_combined/db/master.db"

        # ingest to a database and a table
        self.catalog_combined.ingest(filename, "test1")

        # ingest to another table
        self.catalog_combined.ingest(filename, "test2")

        # check few if_exists
        ...
        import sqlite3

        import pandas as pd

        with sqlite3.connect(filename) as con:
            df1 = pd.read_sql_query("SELECT * FROM test1", con)
            df2 = pd.read_sql_query("SELECT * FROM test2", con)
            self.assertTrue(np.all(df1["ID"] == self.catalog_combined["ID"]))
            self.assertTrue(np.all(df1["ID"] == df2["ID"]))

    def test_get_is_galaxy(self):
        # check ID
        # check Ngalaxy
        ...
        is_galaxy = self.catalog_combined.get_is_galaxy()
        self.assertEqual(self.catalog_combined["ID"][is_galaxy][0], 0)
        self.assertEqual(is_galaxy.sum(), len(self.catalog_combined.catalog_galaxy.catalog))
        self.assertTrue(np.all(self.catalog_combined["Z"][is_galaxy] > 0))

    def test_get_is_star(self):
        # check ID
        # check Nstar
        ...
        is_star = self.catalog_combined.get_is_star()
        self.assertEqual(
            is_star.sum(),
            len(self.catalog_combined.catalog_star.catalog)
            + len(self.catalog_combined.catalog_binary.catalog),
        )
        self.assertTrue(~np.any(self.catalog_combined["Z"][is_star] > 0))

    def test_get_index_star(self):
        # check against star table
        ...
        ngal = len(self.catalog_combined.catalog_galaxy.catalog)
        nstar = len(self.catalog_combined.catalog_star.catalog)

        idx0 = self.catalog_combined.get_index_star(0)
        idxs = self.catalog_combined.get_index_star(ngal)
        idxb = self.catalog_combined.get_index_star(ngal + nstar)

        self.assertTrue(idx0 is None)
        self.assertEqual(idxs, 0)
        self.assertEqual(idxb, 0)

    def test_write_reference_catalog(self):
        filename = "data/tests/test_catalog_combined/reference_catalog.csv"

        # write reference catalog
        self.catalog_combined.write_reference_catalog(filename, 24, "lsst-r")

        # write it again -> NOP
        self.catalog_combined.write_reference_catalog(filename, 24, "lsst-r")

        # redo the above with an existing reference_catalog dirname
        # make sure rm -rfv is called
        ...
        os.makedirs("data/tests/test_catalog_combined/reference_catalog", exist_ok=True)
        os.system(f"rm -fv {filename}")
        self.catalog_combined.write_reference_catalog(filename, 24, "lsst-r")

        # test the created reference catalog
        from astropy.table import Table

        self.assertTrue(os.path.exists(filename))
        reference_catalog = Table.read(filename, format="csv")
        select = np.isin(self.catalog_combined["ID"], reference_catalog["source_id"])
        self.assertTrue(np.all(self.catalog_combined["ID"][select] == reference_catalog["source_id"]))

    def test_get_area(self):
        # try with egg.fits
        # try without egg.fts
        ...
        area = self.catalog_combined.get_area()
        self.assertAlmostEqual(area, 0.001)

        # try without egg.fits -> defaults to DR1 area
        os.system("rm -fv data/tests/test_catalog_combined/egg.fits")
        area = self.catalog_combined.get_area()
        self.assertAlmostEqual(area, 24.0)

    def test_getitem(self, key):
        # try items compared against their catalog values
        ...
    def test_getitem(self):
        self.assertTrue(np.all(self.catalog_combined.catalog_combined["ID"] == self.catalog_combined["ID"]))

    def test_get_luminosity_function(self):
        # try select None or some vector
        # try values None or some vector
        # try deredden None or some vector
        # try occupation fraction None or some vector
        # try overflowing z ranges
        ...
        key = "M"
        bins = np.arange(9.5, 12.1, 0.1)
        zmin = 0.5
        zmax = 1.0
        kwargs = dict(key=key, bins=bins, zmin=zmin, zmax=zmax)
        select = self.catalog_combined["Z"] > 0.55
        values = np.random.rand(select.size)

        # create some luminosity functions
        lf0 = self.catalog_combined.get_luminosity_function(**kwargs)[0]
        lf1 = self.catalog_combined.get_luminosity_function(**kwargs, select=select)[0]
        lf2 = self.catalog_combined.get_luminosity_function(**kwargs, values=values)[0]
        lf3 = self.catalog_combined.get_luminosity_function(**kwargs, deredden=True)[0]
        lf4 = self.catalog_combined.get_luminosity_function(
            key, bins, zmin, zmax, use_occupation_fraction=True
        )[0]
        lf5 = self.catalog_combined.get_luminosity_function(key, bins, 0.0, 99.99)[0]

        # do some very basic testing
        self.assertTrue(np.all(lf0 >= lf1))
        self.assertTrue(np.all(lf0 >= lf2))
        self.assertTrue(np.all(lf0 <= lf3))
        self.assertTrue(np.all(lf0 >= lf4))
        self.assertTrue(np.all(lf0 >= lf5))