Commit 92f2b350 authored by Akke Viitanen's avatar Akke Viitanen
Browse files

Further refactoring with ruff

parent ed0cf542
Loading
Loading
Loading
Loading
+33 −107
Original line number Diff line number Diff line
@@ -8,44 +8,32 @@
Create Galaxy+AGN mocks for the LSST Italian AGN in-kind contribution
"""

import logging
logger = logging.getLogger(__name__)

from copy import deepcopy
from itertools import product
import glob
import logging
import multiprocessing
import os
import sys
import time
import re

from astropy.time import Time
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
import astropy.units as u
import fitsio
import numpy as np
import pandas as pd
import sqlite3

import lightcurve
import lusso2010
from mbh import get_log_mbh_continuity_new2
from merloni2014 import Merloni2014
from mock_catalog_SED.qsogen_4_catalog import qsosed
import my_lsst
import sed
import smf
import ueda2014
import util
from util import ROOT
import zou2024

logger = logging.getLogger(__name__)

# Optimization: store the pre-computed AGN SEDs and lightcurves
#SEDS = {}
# Optimization: store the pre-computed AGN fluxes and lightcurves
FLUXES = {}
LBOL = {}
LIGHTCURVES = {}

try:
@@ -55,69 +43,13 @@ try:
    #POSTERIOR_DISTRIBUTION = pd.read_csv(f"{ROOT}/data/posteriors/posterior_2024_05_20.dat", sep=' ')
    POSTERIOR_DISTRIBUTION = pd.read_csv(f"{ROOT}/data/posteriors/posterior_frozen.dat", sep=' ')
    PARAMETER_NAMES = POSTERIOR_DISTRIBUTION.columns
except:
except FileNotFoundError:
    logger.warning("Posterior distribution not found")


# Dictionary for pre-loaded observations
DF = {}

def get_band_egg(band):

    band = band.replace("magabs_", "")
    with open(f"{ROOT}/egg/share/filter-db/db.dat") as f:
        for line in f:
            name, filename = line.strip().split('=')
            if name != band:
                continue
            return fitsio.read(f"{ROOT}/egg/share/filter-db/{filename}")
    return None


def _get_lightcurve_agn(args):

    i, filename, band, flux, z, mag_i, logMbh, type2 = args
    logger.debug(f"Estimating AGN lightcurve {i} {band}")

    kwargs = {
        "mjd0": 0,
        "mjd": np.arange(0, 3653, 1),
        "band": band,
        "flux": flux,
        "z": z,
        "mag_i": mag_i,
        "logMbh": logMbh,
        "type2": type2,
        "T": 3653,
        "deltatc": 1,
        "seed": i + 1
    }

    if not os.path.exists(filename):
        dirname = os.path.dirname(filename)
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        lc = lightcurve.get_lightcurve_agn(**kwargs)
        np.save(filename, lc)

    lc = np.load(filename)
    return i, band, lc


class CatalogAGN:

    def __init__(
        self,
        dirname,
        egg,
        type_plambda,
        save_sed,
        seed,
        merloni2014_interpolate,
        merloni2014_extrapolate,
        merloni2014_f_obs_minimum,
        merloni2014_f_obs_maximum,
    ):
    def __init__(self, dirname, egg, type_plambda, save_sed, seed, merloni2014):

        self.dirname = dirname
        self.egg = egg
@@ -125,19 +57,14 @@ class CatalogAGN:
        self.rfbands = egg["RFBANDS"][0]
        self.type_plambda = type_plambda
        self.save_sed = int(save_sed)
        self.seed = seed
        self.merloni2014 = merloni2014

        # Set the seed
        self.seed = seed
        if self.seed is not None:
            self.seed = int(self.seed)
        np.random.seed(self.seed)

        self.merloni2014 = Merloni2014(
            int(merloni2014_interpolate),
            int(merloni2014_extrapolate),
            float(merloni2014_f_obs_minimum),
            float(merloni2014_f_obs_maximum),
        )

        # Create the catalog
        self.catalog = self.get_catalog()
@@ -147,30 +74,32 @@ class CatalogAGN:
            self.get_lightcurve_agn(band=f"lsst-{b}")

    def __getitem__(self, key):
        if not key in self.catalog.dtype.names:
        if key not in self.catalog.dtype.names:
            return self.egg[key][0]
        return self.catalog[key]

    def _get_select(self, z=None, m=None, l=None, t="all"):
    def _get_select(self, redshift=None, mstar=None, luminosity=None, t="all"):

        select = np.ones_like(self["Z"], dtype=bool)

        # Perform the selection in redshift and Mstar
        if z is not None:
            if z not in self.SZ:
                self.SZ[z] =  (z <= self["Z"]) * (self["Z"] < z + self.dz)
            select *= self.SZ[z]
        if redshift is not None:
            if redshift not in self.SZ:
                self.SZ[redshift] =  (redshift <= self["Z"]) * (self["Z"] < redshift + self.dz)
            select *= self.SZ[redshift]

        if m is not None:
            if m not in self.SM:
                self.SM[m] = (m <= self["M"]) * (self["M"] < m + self.dm)
            select *= self.SM[m]
        if mstar is not None:
            if mstar not in self.SM:
                self.SM[mstar] = (mstar <= self["M"]) * (self["M"] < mstar + self.dm)
            select *= self.SM[mstar]

        # Perform the selection in lx if requested
        if l is not None:
            if l not in self.SL:
                self.SL[l] = (l <= self.catalog["log_LX_2_10"]) * (self.catalog["log_LX_2_10"] < l + self.dl)
            select *= self.SL[l]
        if luminosity is not None:
            if luminosity not in self.SL:
                self.SL[luminosity] = \
                    (luminosity <= self.catalog["log_LX_2_10"]) * \
                    (self.catalog["log_LX_2_10"] < luminosity + self.dl)
            select *= self.SL[luminosity]

        # Perform the selection in type
        if t != "all":
@@ -341,7 +270,6 @@ class CatalogAGN:
        Lusso+10 eq. 5 (inverted) Lx = alpha L_opt - beta
        """

        import lusso2010
        return lusso2010.get_log_L_2500(self["log_L_2_keV"], alpha, beta, scatter)

        # 20241001
@@ -360,7 +288,7 @@ class CatalogAGN:
    def get_is_optical_type2(self, func_get_f_obs=None, use_f_obs_for_ctk_agn=False):

        # Get obscured AGN fraction from Merloni+2014 as a function of z, LX
        select_ctn = self["is_agn_ctn"]
        #select_ctn = self["is_agn_ctn"]
        select_ctk = self["is_agn_ctk"]

        if func_get_f_obs is None:
@@ -405,8 +333,8 @@ class CatalogAGN:
        type_1_optical = ~self["is_optical_type2"]
        type_2_optical =  self["is_optical_type2"]

        N_type_1 = np.sum(type_1_optical)
        N_type_2 = np.sum(type_2_optical)
        #N_type_1 = np.sum(type_1_optical)
        #N_type_2 = np.sum(type_2_optical)

        ebv[type_1_optical] = sample_ebv(type_1_optical.sum(), hopkins04, type_1_ebv, alpha_1, n_1)
        ebv[type_2_optical] = sample_ebv(type_2_optical.sum(), hopkins04, type_2_ebv, alpha_2, n_2) + mu_type_2
@@ -462,7 +390,7 @@ class CatalogAGN:
            p_ctk_prime = p_ctk / (1 - p_ctn * dloglambda)
        """

        p_ctk = CatalogGalaxyAGN.get_plambda_ctk(loglambda, mstar, z, t)
        p_ctk = CatalogAGN.get_plambda_ctk(loglambda, mstar, z, t)
        p_ctn = 10 ** zou2024.get_log_plambda(loglambda, mstar, z, t)
        p_ctk_prime = p_ctk / (1 - p_ctn * dloglambda)
        return p_ctk_prime
@@ -487,21 +415,19 @@ class CatalogAGN:

    def get_sed(self, i):

        filename = f"{dirname}/seds/agn-seds-{i}.fits"
        filename = f"{self.dirname}/seds/agn-seds-{i}.fits"
        if not os.path.exists(filename):
            return None

        fits = fitsio.read(filename)
        return fits["LAMBDA"][0], fits["FLUX"][0]

    def _get_sed(self, i, ratio_max=0.90):
    def _get_sed(self, i, ratio_max=0.90, dlog_wav=7.65e-4):

        logger.debug(f"Getting AGN SED {i} {self.catalog.size}")

        # Get the wavlen in angstrom
        # NOTE: Add some interesting wavelengths for greater accuracy
        #dlog_wav = 7.65e-5
        dlog_wav = 7.65e-4
        wavlen = 10 ** np.arange(np.log10(500), np.log10(250000) + dlog_wav, dlog_wav)
        wavlen = np.append(wavlen, [1450, 4400, 5007, 150000])
        wavlen = np.sort(wavlen)
@@ -616,7 +542,7 @@ class CatalogAGN:

            # NOTE: fluxes should have been initialized by now
            if (idx, band, rest_frame) not in FLUXES:
                my_sed = self._get_sed(idx)
                _ = self._get_sed(idx)
            flux[idx] = FLUXES[idx, band, rest_frame]

        return flux
@@ -760,11 +686,11 @@ class CatalogAGN:
        if scatter:
            A = np.random.normal(-0.62, 0.16, N)
            B = np.random.normal(0.70, 0.085, N)
            C = np.random.normal(0.74, 0.06, N)
            #C = np.random.normal(0.74, 0.06, N)
        else:
            A = np.random.normal(-0.62, 0.0, N)
            B = np.random.normal(0.70, 0.0, N)
            C = np.random.normal(0.74, 0.0, N)
            #C = np.random.normal(0.74, 0.0, N)

        R = A + B * X + mu
        return R + 38
+10 −13
Original line number Diff line number Diff line
@@ -9,22 +9,19 @@ Test AGN catalog
"""

import logging
logger = logging.getLogger()


from egg import Egg
from catalog_agn import CatalogAGN

from astropy.io import fits
from merloni2014 import Merloni2014

logger = logging.getLogger()
egg = fits.open("data/catalog/test/egg.fits")[1].data
merloni2014 = Merloni2014(1, 0, 0.05, 0.95)
c = CatalogAGN(
    "data/test/test_agn",
    egg,
    "zou+2024",
    1,
    20250621,
    1,
    0,
    0.05,
    0.95
    dirname="data/test/test_agn",
    egg=egg,
    type_plambda="zou+2024",
    save_sed=1,
    seed=20250621,
    merloni2014=merloni2014,
)