Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,18 @@
from diffsky.experimental import lc_phot_kern
from diffsky.experimental import mc_lightcone_halos as mclh
from diffsky.experimental import precompute_ssp_phot as psspp
from diffsky.experimental.scatter import DEFAULT_SCATTER_PARAMS
from diffsky.param_utils.spspop_param_utils import get_unbounded_spspop_params_tw_dust
from diffsky.ssp_err_model.defaults import get_unbounded_ssperr_params
from diffstar.defaults import FB, T_TABLE_MIN
from diffstar.diffstarpop import get_unbounded_diffstarpop_params
from dsps.cosmology import flat_wcdm
from dsps.cosmology.defaults import DEFAULT_COSMOLOGY
from dsps.metallicity.umzr import DEFAULT_MZR_PARAMS
from jax import random as jran
from jax.flatten_util import ravel_pytree

from .. import n_mag_opt
from ..utils import zbin_volume

try:
Expand Down Expand Up @@ -40,8 +47,11 @@ def plot_n_ugriz(
label1,
label2,
saveAs,
lh_centroids=None,
lg_n_data_err_lh=None,
lg_n_thresh=None,
lgmp_min=10.0,
sky_area_degsq=0.25,
sky_area_degsq=0.1,
cosmo_params=DEFAULT_COSMOLOGY,
fb=FB,
):
Expand Down Expand Up @@ -217,6 +227,73 @@ def plot_n_ugriz(
plt.savefig(saveAs)
plt.show()

# Output loss based on lh_centroids, not 1D histograms as above
if lh_centroids is not None:
lc_nhalos = np.ones(lc_halopop["logmp0"].shape)
ran_key, n_key = jran.split(ran_key, 2)

# 1
u_diffstarpop_params1 = get_unbounded_diffstarpop_params(diffstarpop_params1)
u_diffstarpop_theta1, u_diffstarpop_unravel = ravel_pytree(
u_diffstarpop_params1
)

u_spspop_params1 = get_unbounded_spspop_params_tw_dust(spspop_params1)
u_spspop_theta1, u_spspop_unravel = ravel_pytree(u_spspop_params1)

u_ssp_err_pop_params1 = get_unbounded_ssperr_params(ssp_err_pop_params1)
u_ssp_err_pop_theta1, u_ssp_err_pop_unravel = ravel_pytree(
u_ssp_err_pop_params1
)

u_theta1 = (u_diffstarpop_theta1, u_spspop_theta1, u_ssp_err_pop_theta1)

# 2
u_diffstarpop_params2 = get_unbounded_diffstarpop_params(diffstarpop_params2)
u_diffstarpop_theta2, u_diffstarpop_unravel = ravel_pytree(
u_diffstarpop_params2
)

u_spspop_params2 = get_unbounded_spspop_params_tw_dust(spspop_params2)
u_spspop_theta2, u_spspop_unravel = ravel_pytree(u_spspop_params2)

u_ssp_err_pop_params2 = get_unbounded_ssperr_params(ssp_err_pop_params2)
u_ssp_err_pop_theta2, u_ssp_err_pop_unravel = ravel_pytree(
u_ssp_err_pop_params2
)

u_theta1 = (u_diffstarpop_theta1, u_spspop_theta1, u_ssp_err_pop_theta1)
u_theta2 = (u_diffstarpop_theta2, u_spspop_theta2, u_ssp_err_pop_theta2)

loss_args = (
lg_n_thresh,
n_key,
lc_halopop["z_obs"],
lc_halopop["t_obs"],
lc_halopop["mah_params"],
lc_halopop["logmp0"],
lc_nhalos,
lc_vol_mpc3,
t_table,
ssp_data,
precomputed_ssp_mag_table,
z_phot_table,
wave_eff_table,
DEFAULT_MZR_PARAMS,
DEFAULT_SCATTER_PARAMS,
lh_centroids,
dmag,
mag_column,
DEFAULT_COSMOLOGY,
FB,
)

loss1 = n_mag_opt._loss_kern(u_theta1, lg_n_data_err_lh, *loss_args)
loss2 = n_mag_opt._loss_kern(u_theta2, lg_n_data_err_lh, *loss_args)

print(f"default loss = {loss1:.2f}")
print(f"fit loss = {loss2:.2f}")


def get_obs_colors_mag(lc_phot, mag_column):
num_halos, n_bands = lc_phot.obs_mags_q.shape
Expand Down
2 changes: 1 addition & 1 deletion diffhtwo/experimental/n_mag.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def Gehrels_low_eq12(Ngal):


@jjit
def get_n_data_err(N, vol, N_floor=1e-3):
def get_n_data_err(N, vol, N_floor=1e-12):
N = jnp.where(N > N_floor, N, N_floor)
lg_n = jnp.log10(N / vol)

Expand Down
12 changes: 9 additions & 3 deletions diffhtwo/experimental/n_mag_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@


@jjit
def _mse_w(lg_n_pred, lg_n_target, lg_n_target_err):
mask = lg_n_target > -8.0
def _mse_w(lg_n_pred, lg_n_target, lg_n_target_err, lg_n_thresh):
mask = lg_n_target > lg_n_thresh
nbins = jnp.maximum(jnp.sum(mask), 1)

resid = lg_n_pred - lg_n_target
Expand Down Expand Up @@ -197,6 +197,7 @@ def _mse_w(lg_n_pred, lg_n_target, lg_n_target_err):
def _loss_kern(
u_theta,
lg_n_target,
lg_n_thresh,
ran_key,
lc_z_obs,
lc_t_obs,
Expand Down Expand Up @@ -273,7 +274,7 @@ def _loss_kern(
fb,
)

return _mse_w(lg_n_model, lg_n_target[0], lg_n_target[1])
return _mse_w(lg_n_model, lg_n_target[0], lg_n_target[1], lg_n_thresh)


loss_and_grad = jjit(value_and_grad(_loss_kern))
Expand All @@ -283,6 +284,7 @@ def _loss_kern(
def fit_n(
u_theta_init,
lg_n_target,
lg_n_thresh,
ran_key,
lc_z_obs,
lc_t_obs,
Expand Down Expand Up @@ -310,6 +312,7 @@ def fit_n(

other = (
lg_n_target,
lg_n_thresh,
ran_key,
lc_z_obs,
lc_t_obs,
Expand Down Expand Up @@ -349,6 +352,7 @@ def _opt_update(opt_state, i):
None,
0,
None,
None,
0,
0,
0,
Expand Down Expand Up @@ -389,6 +393,7 @@ def fit_n_multi_z(
u_theta_init,
trainable,
lg_n_target,
lg_n_thresh,
ran_key,
lc_z_obs,
lc_t_obs,
Expand Down Expand Up @@ -416,6 +421,7 @@ def fit_n_multi_z(

other = (
lg_n_target,
lg_n_thresh,
ran_key,
lc_z_obs,
lc_t_obs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@
from diffsky.experimental import mc_lightcone_halos as mclh
from diffsky.experimental import precompute_ssp_phot as psspp
from diffsky.experimental.scatter import DEFAULT_SCATTER_PARAMS
from diffsky.param_utils.spspop_param_utils import DEFAULT_SPSPOP_PARAMS
from diffsky.ssp_err_model.ssp_err_model import ZERO_SSPERR_PARAMS
from diffsky.param_utils.spspop_param_utils import (
DEFAULT_SPSPOP_PARAMS,
DEFAULT_SPSPOP_U_PARAMS,
)
from diffsky.ssp_err_model.defaults import ZERO_SSPERR_PARAMS, ZERO_SSPERR_U_PARAMS
from diffstar.defaults import FB, T_TABLE_MIN
from diffstar.diffstarpop.defaults import DEFAULT_DIFFSTARPOP_PARAMS
from diffstar.diffstarpop.defaults import (
DEFAULT_DIFFSTARPOP_PARAMS,
DEFAULT_DIFFSTARPOP_U_PARAMS,
)
from diffstar.diffstarpop.kernels.params.params_diffstarpopfits_mgash import (
DiffstarPop_Params_Diffstarpopfits_mgash,
)
Expand All @@ -24,6 +30,8 @@
from diffhtwo.experimental.data_loaders import retrieve_tcurves
from diffhtwo.experimental.utils import zbin_volume

from .. import n_mag_opt

TEST_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_PATH = os.path.join(TEST_DIR, "..", "data_loaders")

Expand Down Expand Up @@ -92,7 +100,7 @@
lc_halopop_mah_params_multi_z = []
lc_halopop_nhalos_multi_z = []
lc_halopop_logmp0_multi_z = []
lc_halopop_lc_vol_mpc3_multi_z = []
lc_halopop_vol_mpc3_multi_z = []

t_table_multi_z = []
precomputed_ssp_mag_table_multi_z = []
Expand Down Expand Up @@ -135,7 +143,7 @@
lc_halopop_mah_params_multi_z.append(lc_halopop["mah_params"])
lc_halopop_logmp0_multi_z.append(lc_halopop["logmp0"])
lc_halopop_nhalos_multi_z.append(lc_halopop["nhalos"])
lc_halopop_lc_vol_mpc3_multi_z.append(lc_halopop["lc_vol_Mpc3"])
lc_halopop_vol_mpc3_multi_z.append(lc_halopop["lc_vol_Mpc3"])
t_table_multi_z.append(t_table)
precomputed_ssp_mag_table_multi_z.append(precomputed_ssp_mag_table)
z_phot_table_multi_z.append(z_phot_table)
Expand All @@ -147,7 +155,7 @@
lc_halopop_mah_params_multi_z = jnp.asarray(lc_halopop_mah_params_multi_z)
lc_halopop_logmp0_multi_z = jnp.asarray(lc_halopop_logmp0_multi_z)
lc_halopop_nhalos_multi_z = jnp.asarray(lc_halopop_nhalos_multi_z)
lc_halopop_lc_vol_mpc3_multi_z = jnp.asarray(lc_halopop_lc_vol_mpc3_multi_z)
lc_halopop_vol_mpc3_multi_z = jnp.asarray(lc_halopop_vol_mpc3_multi_z)
t_table_multi_z = jnp.asarray(t_table_multi_z)
precomputed_ssp_mag_table_multi_z = jnp.asarray(precomputed_ssp_mag_table_multi_z)
z_phot_table_multi_z = jnp.asarray(z_phot_table_multi_z)
Expand All @@ -157,15 +165,14 @@

ran_key, n_key = jran.split(ran_key, 2)
n_args_multi_z = (
DIFFSTARPOP_UM_plus_exsitu,
DEFAULT_SPSPOP_PARAMS,
n_key,
lc_halopop_z_obs_multi_z,
lc_halopop_t_obs_multi_z,
lc_halopop_mah_params_multi_z,
lc_halopop_logmp0_multi_z,
lc_halopop_nhalos_multi_z,
lc_halopop_lc_vol_mpc3_multi_z,
lc_halopop_vol_mpc3_multi_z,
t_table_multi_z,
ssp_data,
precomputed_ssp_mag_table_multi_z,
Expand All @@ -181,8 +188,47 @@
FB,
)

lg_n_multi_z, lg_n_avg_err_multi_z = n_mag.n_mag_kern_multi_z(*n_args_multi_z)
lg_n_multi_z, lg_n_avg_err_multi_z = n_mag.n_mag_kern_multi_z(
DIFFSTARPOP_UM_plus_exsitu, *n_args_multi_z
)
lg_n_data_err_lh_multi_z = jnp.stack((lg_n_multi_z, lg_n_avg_err_multi_z), axis=1)

lg_n_multi_z2, lg_n_avg_err_multi_z2 = n_mag.n_mag_kern_multi_z(
DEFAULT_DIFFSTARPOP_PARAMS, *n_args_multi_z
)

# loss w/ DEFAULT_DIFFSTARPOP when DIFFSTARPOP_UM_plus_exsitu is the target data
u_diffstarpop_theta2, u_diffstarpop_unravel = ravel_pytree(DEFAULT_DIFFSTARPOP_U_PARAMS)
u_spspop_theta2, u_spspop_unravel = ravel_pytree(DEFAULT_SPSPOP_U_PARAMS)
u_ssp_err_pop_theta2, u_ssp_err_pop_unravel = ravel_pytree(ZERO_SSPERR_U_PARAMS)
u_theta2 = (u_diffstarpop_theta2, u_spspop_theta2, u_ssp_err_pop_theta2)

lg_n_thresh = -10
loss_args_multi_z = (
lg_n_thresh,
n_key,
lc_halopop_z_obs_multi_z,
lc_halopop_t_obs_multi_z,
lc_halopop_mah_params_multi_z,
lc_halopop_logmp0_multi_z,
lc_halopop_nhalos_multi_z,
lc_halopop_vol_mpc3_multi_z,
t_table_multi_z,
ssp_data,
precomputed_ssp_mag_table_multi_z,
z_phot_table_multi_z,
wave_eff_table_multi_z,
DEFAULT_MZR_PARAMS,
DEFAULT_SCATTER_PARAMS,
lh_centroids_multi_z,
dmag,
mag_column,
DEFAULT_COSMOLOGY,
FB,
)
loss_multi_z = n_mag_opt._loss_kern_multi_z(
u_theta2, lg_n_data_err_lh_multi_z, *loss_args_multi_z
)

for zbin in range(0, len(zbins)):
zmin = zbins[zbin][0]
Expand Down Expand Up @@ -240,3 +286,32 @@
)
lg_n_single_z, lg_n_avg_err_single_z = n_mag.n_mag_kern(*n_args_single_z)
assert np.allclose(lg_n_multi_z[zbin], lg_n_single_z)

loss_args_single_z = (
lg_n_thresh,
n_key,
jnp.array(lc_halopop["z_obs"]),
lc_halopop["t_obs"],
lc_halopop["mah_params"],
lc_halopop["logmp0"],
lc_halopop["nhalos"],
lc_halopop["lc_vol_Mpc3"],
t_table,
ssp_data,
precomputed_ssp_mag_table,
z_phot_table,
wave_eff_table,
DEFAULT_MZR_PARAMS,
DEFAULT_SCATTER_PARAMS,
lh_centroids,
dmag,
mag_column,
DEFAULT_COSMOLOGY,
FB,
)
lg_n_data_err_lh_single_z = jnp.vstack((lg_n_single_z, lg_n_avg_err_single_z))

loss_single_z = n_mag_opt._loss_kern(
u_theta2, lg_n_data_err_lh_single_z, *loss_args_single_z
)
assert np.isclose(loss_multi_z[zbin], loss_single_z)
2 changes: 2 additions & 0 deletions diffhtwo/experimental/tests/test_n_mag_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
bin_edges = np.arange(18.0 - dmag / 2, 26.0, dmag)
bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
bin_centers = bin_centers.reshape(bin_centers.size, 1)
lg_n_thresh = -8

ran_key, n_key = jran.split(ran_key, 2)
lg_n_true, lg_n_avg_err_true = n_mag.n_mag_kern(
Expand Down Expand Up @@ -100,6 +101,7 @@
loss_hist, grad_hist, u_theta_fit = n_mag_opt.fit_n(
u_diffstarpop_theta_default,
lg_n_true,
lg_n_thresh,
fit_n_key,
jnp.array(lc_halopop["z_obs"]),
lc_halopop["t_obs"],
Expand Down
Loading