Unverified Commit 7e4ac2f1 authored by Akke Viitanen's avatar Akke Viitanen
Browse files

increase image_simulator test coverage

parent cdfeb2f9
Loading
Loading
Loading
Loading
Loading
+29 −14
Original line number Diff line number Diff line
@@ -52,12 +52,14 @@ class ImageSimulator:
            ]
        logger.info(f"First MJD is {self.mjd0=}")

    def get_visit(self, observation_id=None):
    def get_visit(self, observation_id=None, limit=None):
        """Get the visit corresponding to the observation_id."""
        # Build the query
        query = "SELECT * FROM observations"
        if observation_id is not None:
            query += f" WHERE observationId = {observation_id}"
        if limit is not None:
            query += f" LIMIT {limit}"

        # Query the database
        with sqlite3.connect(self.filename_baseline) as conn:
@@ -106,17 +108,29 @@ class ImageSimulator:
        observation_id: int,
        exptime: str = "38 29.2 29.2 29.2 29.2 29.2",
        ra_dec=None,  # override ra_dec?
    ):
        """Write an instance catalog to a file."""
    ) -> list[str]:
        """
        Write an instance catalog to a file.

        Returns
        -------
        filenames: list[str]
            Filenames that were written as output.
        """

        # Get the header and the catalog
        header = self.get_header(observation_id, ra_dec)
        b = "ugrizy"[header["filter"]]
        band = f"lsst-{b}"
        catalog = self.get_catalog(band, header["mjd"])

        # record the output filenames
        filenames = []

        # Write the header and the catalog
        filename_instance_catalog = f"{self.dirname}/{observation_id}/instance_catalog.txt"
        util.create_directory(filename_instance_catalog)
        filenames.append(filename_instance_catalog)

        logger.info(f"Writing {filename_instance_catalog}")
        with open(filename_instance_catalog, "w") as f:
@@ -154,6 +168,10 @@ class ImageSimulator:
            with open(filename_yaml, "w") as f:
                print(imsim_yaml, file=f)

            filenames.append(filename_yaml)

        return filenames

    def get_catalog(self, band, mjd):
        """
        Return the instance catalog corresponding to band and mjd.
@@ -167,8 +185,11 @@ class ImageSimulator:

        """
        is_agn = self.catalog["is_agn"]
        is_galaxy = self.catalog["Z"] > 0
        is_star = ~is_galaxy
        is_galaxy = self.catalog.get_is_galaxy()
        is_star = self.catalog.get_is_star()
        n_galaxy, n_star, n_binary = self.catalog.get_number_galaxy_star_binary()
        assert is_galaxy.sum() == n_galaxy
        assert is_star.sum() == n_star + n_binary

        uid = self.catalog["ID"]
        ra = self.catalog["RA"]
@@ -190,21 +211,15 @@ class ImageSimulator:
            # Handle star
            if is_star[i]:
                # NOTE: dirty
                n_galaxy, n_star, n_binary = self.catalog.get_number_galaxy_star_binary()
                idx = self.catalog.get_index_star(i)

                # Handle single star
                if n_galaxy <= i < n_galaxy + n_star:
                    # Handle single star
                    lc = self.catalog.catalog_star.get_lightcurve_mjd(idx, band, mjd, self.mjd0)

                else:
                    # Handle binary star
                elif n_galaxy + n_star <= i:
                    lc = self.catalog.catalog_binary.get_lightcurve_mjd(idx, band, mjd, self.mjd0)

                # Handle something else
                else:
                    raise ValueError

                mag_point[i] = util.flux_to_mag(lc)
                catalog += [
                    f"object {uid[i]} {ra[i]} {dec[i]} {mag_point[i]} "
+2 −2
Original line number Diff line number Diff line
@@ -341,8 +341,8 @@ def get_galaxy_ab(reff, ratio):
    #   ratio == b / a
    #   b = ratio * a
    #   ellip = 1 - ratio
    a = reff / np.sqrt(ratio)
    b = reff * np.sqrt(ratio)
    a = np.ma.true_divide(reff, np.ma.sqrt(ratio))
    b = reff * np.ma.sqrt(ratio)
    return a, b


+114 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2025-07-04 20:39:49

from unittest import TestCase

import numpy as np
from lsst_inaf_agile.image_simulator import ImageSimulator

from tests.lsst_inaf_agile.test_catalog_combined import create_combined_catalog


class TestImageSimulator(TestCase):
    def setUp(self):
        self.image_simulator = ImageSimulator(
            "data/tests/test_image_simulator",
            create_combined_catalog(),
            "data/baseline/baseline_v4.0_10yrs.db",
        )

    def test__init__(self):
        self.assertTrue("image_simulator" in self.image_simulator.dirname)
        self.assertGreaterEqual(self.image_simulator.mjd0, 60000.0)

    def test_get_visit(self):
        visit0 = self.image_simulator.get_visit(None, limit=100)
        self.assertEqual(len(visit0), 100)

        visit1 = self.image_simulator.get_visit(0)
        visit2 = self.image_simulator.get_visit(10)
        visit3 = self.image_simulator.get_visit(100)

        for visit in visit1, visit2, visit3:
            self.assertTrue("observationId" in visit)
            self.assertTrue("observationStartMJD" in visit)

        self.assertGreaterEqual(visit2["observationStartMJD"], visit1["observationStartMJD"])
        self.assertGreaterEqual(visit3["observationStartMJD"], visit2["observationStartMJD"])

    def test_get_header(self):
        header1 = self.image_simulator.get_header(0)
        header2 = self.image_simulator.get_header(0, ra_dec=(0, 1))
        for k in header1:
            if k in ("rightascension", "declination"):
                self.assertNotEqual(header1[k], header2[k])
            else:
                self.assertEqual(header1[k], header2[k])

    def test_write_instance_catalog(self):
        finst, *fyaml = self.image_simulator.write_instance_catalog(0)

        # Test some instance catalog lines
        has_ra = False
        has_dec = False
        has_object = False
        with open(finst, "r", encoding="utf-8") as f:
            for line in f:
                has_ra += "rightascension" in line
                has_dec += "declination" in line
                has_object += "object" in line
        self.assertTrue(has_ra)
        self.assertTrue(has_dec)
        self.assertTrue(has_object)

        # Test some yaml lines
        for fy in fyaml:
            has_fexptime = False
            has_band = False
            with open(fy, "r", encoding="utf-8") as f:
                for line in f:
                    has_fexptime += "fexptime" in line
                    has_band += "sband" in line
            self.assertTrue(has_fexptime)
            self.assertTrue(has_band)

        # Test some instcat files
        finst2, *_ = self.image_simulator.write_instance_catalog(1)
        finst3, *_ = self.image_simulator.write_instance_catalog(1, exptime="0 0 0 0 0 0")
        finst4, *_ = self.image_simulator.write_instance_catalog(1, ra_dec=(0, 1))

        lines = []
        for ff in finst, finst2, finst3, finst4:
            with open(ff, "r", encoding="utf-8") as f:
                lines.append("".join(f))
        lines1, lines2, lines3, lines4 = lines

        self.assertNotEqual(lines1, lines2)
        # NOTE: exptime enters the yaml file
        self.assertEqual(lines2, lines3)
        # NOTE: ra, dec enters the instcat
        self.assertNotEqual(lines3, lines4)

    def test_get_catalog(self):
        with self.assertRaises(ValueError):
            for b in "ugrizy":
                _ = self.image_simulator.get_catalog(f"lsst-{b}", 0)

        bands = "ugrizy"
        c1 = [self.image_simulator.get_catalog(f"lsst-{b}", self.image_simulator.mjd0) for b in bands]
        c2 = [self.image_simulator.get_catalog(f"lsst-{b}", self.image_simulator.mjd0 + 100) for b in bands]
        for i in range(len(bands)):
            # test that bands differ
            self.assertTrue(np.any(c1[i] != c2[(i + 1) % len(bands)]))
            # test that mjds differ
            self.assertTrue(np.any(c1[i] != c2[i]))

    def test_simulate_image(self):
        # TODO: fix the environment to be able to actually run galsim... maybe
        # through pyproject.toml
        self.image_simulator.simulate_image(0, 94)
        self.image_simulator.simulate_image(1, 94)
        self.image_simulator.simulate_image(1, 93)
        self.image_simulator.simulate_image(1, 95)