From 38bab8f021d622ccefba2e0294ccc3fb90b659ba Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 14:53:35 -0500 Subject: [PATCH 01/15] Initial commit From f2c861abc3d361bf305eee918b2d91ca71005915 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 15:02:39 -0500 Subject: [PATCH 02/15] introduce lg_n_thresh in n_mag_opt To allow flexibility in setting the lowest lg_n_data used for fitting --- diffhtwo/experimental/n_mag_opt.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/diffhtwo/experimental/n_mag_opt.py b/diffhtwo/experimental/n_mag_opt.py index 2988249..b0c28a1 100644 --- a/diffhtwo/experimental/n_mag_opt.py +++ b/diffhtwo/experimental/n_mag_opt.py @@ -39,8 +39,8 @@ @jjit -def _mse_w(lg_n_pred, lg_n_target, lg_n_target_err): - mask = lg_n_target > -8.0 +def _mse_w(lg_n_pred, lg_n_target, lg_n_target_err, lg_n_thresh): + mask = lg_n_target > lg_n_thresh nbins = jnp.maximum(jnp.sum(mask), 1) resid = lg_n_pred - lg_n_target @@ -197,6 +197,7 @@ def _mse_w(lg_n_pred, lg_n_target, lg_n_target_err): def _loss_kern( u_theta, lg_n_target, + lg_n_thresh, ran_key, lc_z_obs, lc_t_obs, @@ -273,7 +274,7 @@ def _loss_kern( fb, ) - return _mse_w(lg_n_model, lg_n_target[0], lg_n_target[1]) + return _mse_w(lg_n_model, lg_n_target[0], lg_n_target[1], lg_n_thresh) loss_and_grad = jjit(value_and_grad(_loss_kern)) @@ -283,6 +284,7 @@ def _loss_kern( def fit_n( u_theta_init, lg_n_target, + lg_n_thresh, ran_key, lc_z_obs, lc_t_obs, @@ -310,6 +312,7 @@ def fit_n( other = ( lg_n_target, + lg_n_thresh, ran_key, lc_z_obs, lc_t_obs, @@ -349,6 +352,7 @@ def _opt_update(opt_state, i): None, 0, None, + None, 0, 0, 0, @@ -389,6 +393,7 @@ def fit_n_multi_z( u_theta_init, trainable, lg_n_target, + lg_n_thresh, ran_key, lc_z_obs, lc_t_obs, @@ -416,6 +421,7 @@ def fit_n_multi_z( other = ( lg_n_target, + lg_n_thresh, ran_key, lc_z_obs, lc_t_obs, From 5aa2cc3710bf8570738a8f5e263f034bf41256f3 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 15:11:25 -0500 Subject: [PATCH 03/15] Update test_n_mag_opt.py --- diffhtwo/experimental/tests/test_n_mag_opt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/diffhtwo/experimental/tests/test_n_mag_opt.py b/diffhtwo/experimental/tests/test_n_mag_opt.py index ec2040b..d884d01 100644 --- a/diffhtwo/experimental/tests/test_n_mag_opt.py +++ b/diffhtwo/experimental/tests/test_n_mag_opt.py @@ -67,6 +67,7 @@ bin_edges = np.arange(18.0 - dmag / 2, 26.0, dmag) bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2 bin_centers = bin_centers.reshape(bin_centers.size, 1) +lg_n_thresh = 1e-8 ran_key, n_key = jran.split(ran_key, 2) lg_n_true, lg_n_avg_err_true = n_mag.n_mag_kern( @@ -100,6 +101,7 @@ loss_hist, grad_hist, u_theta_fit = n_mag_opt.fit_n( u_diffstarpop_theta_default, lg_n_true, + lg_n_thresh, fit_n_key, jnp.array(lc_halopop["z_obs"]), lc_halopop["t_obs"], From 7cd9a5593f57e43d63e433b3f0eaba2f1a80bcf6 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 15:52:37 -0500 Subject: [PATCH 04/15] Update test_n_mag_opt.py --- diffhtwo/experimental/tests/test_n_mag_opt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffhtwo/experimental/tests/test_n_mag_opt.py b/diffhtwo/experimental/tests/test_n_mag_opt.py index d884d01..f7ee86a 100644 --- a/diffhtwo/experimental/tests/test_n_mag_opt.py +++ b/diffhtwo/experimental/tests/test_n_mag_opt.py @@ -67,7 +67,7 @@ bin_edges = np.arange(18.0 - dmag / 2, 26.0, dmag) bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2 bin_centers = bin_centers.reshape(bin_centers.size, 1) -lg_n_thresh = 1e-8 +lg_n_thresh = -8 ran_key, n_key = jran.split(ran_key, 2) lg_n_true, lg_n_avg_err_true = n_mag.n_mag_kern( From f6bf362b13b5d73aa5e6de7c59b4cbf07a8288b5 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 16:19:27 -0500 Subject: [PATCH 05/15] Update plot_mag_color_1d_hist.py --- .../diagnostics/plot_mag_color_1d_hist.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py index 41c8380..a6a94c1 100644 --- a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py +++ b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py @@ -3,11 +3,18 @@ from diffsky.experimental import lc_phot_kern from diffsky.experimental import mc_lightcone_halos as mclh from diffsky.experimental import precompute_ssp_phot as psspp +from diffsky.experimental.scatter import DEFAULT_SCATTER_PARAMS +from diffsky.param_utils.spspop_param_utils import get_unbounded_spspop_params_tw_dust +from diffsky.ssp_err_model.defaults import get_unbounded_ssperr_params from diffstar.defaults import FB, T_TABLE_MIN +from diffstar.diffstarpop import get_unbounded_diffstarpop_params from dsps.cosmology import flat_wcdm from dsps.cosmology.defaults import DEFAULT_COSMOLOGY +from dsps.metallicity.umzr import DEFAULT_MZR_PARAMS from jax import random as jran +from jax.flatten_util import ravel_pytree +from .. import n_mag_opt from ..utils import zbin_volume try: @@ -40,6 +47,8 @@ def plot_n_ugriz( label1, label2, saveAs, + lh_centroids, + lg_n_data_err_lh, lgmp_min=10.0, sky_area_degsq=0.25, cosmo_params=DEFAULT_COSMOLOGY, @@ -217,6 +226,64 @@ def plot_n_ugriz( plt.savefig(saveAs) plt.show() + # Output loss based on lh_centroids, not 1D histograms as above + lg_n_thresh = -8 + lc_nhalos = np.ones_like(lh_centroids[0, :, 0]) + ran_key, n_key = jran.split(ran_key, 2) + + # 1 + u_diffstarpop_params1 = get_unbounded_diffstarpop_params(diffstarpop_params1) + u_diffstarpop_theta1, u_diffstarpop_unravel = ravel_pytree(u_diffstarpop_params1) + + u_spspop_params1 = get_unbounded_spspop_params_tw_dust(spspop_params1) + u_spspop_theta1, u_spspop_unravel = ravel_pytree(u_spspop_params1) + + u_ssp_err_pop_params1 = get_unbounded_ssperr_params(ssp_err_pop_params1) + u_ssp_err_pop_theta1, u_ssp_err_pop_unravel = ravel_pytree(u_ssp_err_pop_params1) + + u_theta1 = (u_diffstarpop_theta1, u_spspop_theta1, u_ssp_err_pop_theta1) + + # 2 + u_diffstarpop_params2 = get_unbounded_diffstarpop_params(diffstarpop_params2) + u_diffstarpop_theta2, u_diffstarpop_unravel = ravel_pytree(u_diffstarpop_params2) + + u_spspop_params2 = get_unbounded_spspop_params_tw_dust(spspop_params2) + u_spspop_theta2, u_spspop_unravel = ravel_pytree(u_spspop_params2) + + u_ssp_err_pop_params2 = get_unbounded_ssperr_params(ssp_err_pop_params2) + u_ssp_err_pop_theta2, u_ssp_err_pop_unravel = ravel_pytree(u_ssp_err_pop_params2) + + u_theta1 = (u_diffstarpop_theta1, u_spspop_theta1, u_ssp_err_pop_theta1) + u_theta2 = (u_diffstarpop_theta2, u_spspop_theta2, u_ssp_err_pop_theta2) + + args = ( + lg_n_thresh, + n_key, + lc_halopop["z_obs"], + lc_halopop["t_obs"], + lc_halopop["mah_params"], + lc_halopop["logmp0"], + lc_nhalos, + lc_vol_mpc3, + t_table, + ssp_data, + precomputed_ssp_mag_table, + z_phot_table, + wave_eff_table, + DEFAULT_MZR_PARAMS, + DEFAULT_SCATTER_PARAMS, + lh_centroids, + dmag, + mag_column, + DEFAULT_COSMOLOGY, + FB, + ) + + loss1 = n_mag_opt._loss_kern(u_theta1, lg_n_data_err_lh, *args) + loss2 = n_mag_opt._loss_kern(u_theta2, lg_n_data_err_lh, *args) + + print(loss1, loss2) + def get_obs_colors_mag(lc_phot, mag_column): num_halos, n_bands = lc_phot.obs_mags_q.shape From 9ef12de76fec98be5ae7d72226aaee4e7e0f182d Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 16:23:39 -0500 Subject: [PATCH 06/15] Update plot_mag_color_1d_hist.py --- diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py index a6a94c1..a60e5e5 100644 --- a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py +++ b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py @@ -50,7 +50,7 @@ def plot_n_ugriz( lh_centroids, lg_n_data_err_lh, lgmp_min=10.0, - sky_area_degsq=0.25, + sky_area_degsq=0.1, cosmo_params=DEFAULT_COSMOLOGY, fb=FB, ): From 9e36829e04f17c29196bbe07760dbd1981464b0e Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 16:26:51 -0500 Subject: [PATCH 07/15] Update plot_mag_color_1d_hist.py --- diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py index a60e5e5..c7571b7 100644 --- a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py +++ b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py @@ -60,6 +60,7 @@ def plot_n_ugriz( lc_halopop = mclh.mc_lightcone_host_halo_diffmah(*lc_args) lc_vol_mpc3 = zbin_volume(sky_area_degsq, zlow=zmin, zhigh=zmax).value data_vol_mpc3 = zbin_volume(data_sky_area_degsq, zlow=zmin, zhigh=zmax).value + print(lc_halopop) n_z_phot_table = 15 From a1e3336ef9472843c9bf890f1cb071bb74a27cc3 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 16:30:22 -0500 Subject: [PATCH 08/15] Update plot_mag_color_1d_hist.py --- diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py index c7571b7..9275373 100644 --- a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py +++ b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py @@ -60,7 +60,6 @@ def plot_n_ugriz( lc_halopop = mclh.mc_lightcone_host_halo_diffmah(*lc_args) lc_vol_mpc3 = zbin_volume(sky_area_degsq, zlow=zmin, zhigh=zmax).value data_vol_mpc3 = zbin_volume(data_sky_area_degsq, zlow=zmin, zhigh=zmax).value - print(lc_halopop) n_z_phot_table = 15 @@ -229,7 +228,7 @@ def plot_n_ugriz( # Output loss based on lh_centroids, not 1D histograms as above lg_n_thresh = -8 - lc_nhalos = np.ones_like(lh_centroids[0, :, 0]) + lc_nhalos = np.ones(lc_halopop["logmp0"].shape) ran_key, n_key = jran.split(ran_key, 2) # 1 From 80efc03aeb16f26b1d4a9db4194a95066566b351 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 16:35:04 -0500 Subject: [PATCH 09/15] Update plot_mag_color_1d_hist.py --- diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py index 9275373..a225c29 100644 --- a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py +++ b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py @@ -282,7 +282,8 @@ def plot_n_ugriz( loss1 = n_mag_opt._loss_kern(u_theta1, lg_n_data_err_lh, *args) loss2 = n_mag_opt._loss_kern(u_theta2, lg_n_data_err_lh, *args) - print(loss1, loss2) + print(f"default loss = {loss1:.2f}") + print(f"fit loss = {loss2:.2f}") def get_obs_colors_mag(lc_phot, mag_column): From 6129a7f452e485a2d6a486257414cd7ed7b937e0 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 17:47:58 -0500 Subject: [PATCH 10/15] test if _loss_kern_multi_z == _loss_kern --- ..._kern_multi_z.py => test_n_mag_multi_z.py} | 93 +++++++++++++++++-- 1 file changed, 84 insertions(+), 9 deletions(-) rename diffhtwo/experimental/tests/{test_n_mag_kern_multi_z.py => test_n_mag_multi_z.py} (72%) diff --git a/diffhtwo/experimental/tests/test_n_mag_kern_multi_z.py b/diffhtwo/experimental/tests/test_n_mag_multi_z.py similarity index 72% rename from diffhtwo/experimental/tests/test_n_mag_kern_multi_z.py rename to diffhtwo/experimental/tests/test_n_mag_multi_z.py index ca1bd06..366cf73 100644 --- a/diffhtwo/experimental/tests/test_n_mag_kern_multi_z.py +++ b/diffhtwo/experimental/tests/test_n_mag_multi_z.py @@ -6,10 +6,16 @@ from diffsky.experimental import mc_lightcone_halos as mclh from diffsky.experimental import precompute_ssp_phot as psspp from diffsky.experimental.scatter import DEFAULT_SCATTER_PARAMS -from diffsky.param_utils.spspop_param_utils import DEFAULT_SPSPOP_PARAMS -from diffsky.ssp_err_model.ssp_err_model import ZERO_SSPERR_PARAMS +from diffsky.param_utils.spspop_param_utils import ( + DEFAULT_SPSPOP_PARAMS, + DEFAULT_SPSPOP_U_PARAMS, +) +from diffsky.ssp_err_model.ssp_err_model import ZERO_SSPERR_PARAMS, ZERO_SSPERR_U_PARAMS from diffstar.defaults import FB, T_TABLE_MIN -from diffstar.diffstarpop.defaults import DEFAULT_DIFFSTARPOP_PARAMS +from diffstar.diffstarpop.defaults import ( + DEFAULT_DIFFSTARPOP_PARAMS, + DEFAULT_DIFFSTARPOP_U_PARAMS, +) from diffstar.diffstarpop.kernels.params.params_diffstarpopfits_mgash import ( DiffstarPop_Params_Diffstarpopfits_mgash, ) @@ -24,6 +30,8 @@ from diffhtwo.experimental.data_loaders import retrieve_tcurves from diffhtwo.experimental.utils import zbin_volume +from .. import n_mag_opt + TEST_DIR = os.path.dirname(os.path.abspath(__file__)) DATA_PATH = os.path.join(TEST_DIR, "..", "data_loaders") @@ -92,7 +100,7 @@ lc_halopop_mah_params_multi_z = [] lc_halopop_nhalos_multi_z = [] lc_halopop_logmp0_multi_z = [] -lc_halopop_lc_vol_mpc3_multi_z = [] +lc_halopop_vol_mpc3_multi_z = [] t_table_multi_z = [] precomputed_ssp_mag_table_multi_z = [] @@ -135,7 +143,7 @@ lc_halopop_mah_params_multi_z.append(lc_halopop["mah_params"]) lc_halopop_logmp0_multi_z.append(lc_halopop["logmp0"]) lc_halopop_nhalos_multi_z.append(lc_halopop["nhalos"]) - lc_halopop_lc_vol_mpc3_multi_z.append(lc_halopop["lc_vol_Mpc3"]) + lc_halopop_vol_mpc3_multi_z.append(lc_halopop["lc_vol_Mpc3"]) t_table_multi_z.append(t_table) precomputed_ssp_mag_table_multi_z.append(precomputed_ssp_mag_table) z_phot_table_multi_z.append(z_phot_table) @@ -147,7 +155,7 @@ lc_halopop_mah_params_multi_z = jnp.asarray(lc_halopop_mah_params_multi_z) lc_halopop_logmp0_multi_z = jnp.asarray(lc_halopop_logmp0_multi_z) lc_halopop_nhalos_multi_z = jnp.asarray(lc_halopop_nhalos_multi_z) -lc_halopop_lc_vol_mpc3_multi_z = jnp.asarray(lc_halopop_lc_vol_mpc3_multi_z) +lc_halopop_vol_mpc3_multi_z = jnp.asarray(lc_halopop_vol_mpc3_multi_z) t_table_multi_z = jnp.asarray(t_table_multi_z) precomputed_ssp_mag_table_multi_z = jnp.asarray(precomputed_ssp_mag_table_multi_z) z_phot_table_multi_z = jnp.asarray(z_phot_table_multi_z) @@ -157,7 +165,6 @@ ran_key, n_key = jran.split(ran_key, 2) n_args_multi_z = ( - DIFFSTARPOP_UM_plus_exsitu, DEFAULT_SPSPOP_PARAMS, n_key, lc_halopop_z_obs_multi_z, @@ -165,7 +172,7 @@ lc_halopop_mah_params_multi_z, lc_halopop_logmp0_multi_z, lc_halopop_nhalos_multi_z, - lc_halopop_lc_vol_mpc3_multi_z, + lc_halopop_vol_mpc3_multi_z, t_table_multi_z, ssp_data, precomputed_ssp_mag_table_multi_z, @@ -181,8 +188,47 @@ FB, ) -lg_n_multi_z, lg_n_avg_err_multi_z = n_mag.n_mag_kern_multi_z(*n_args_multi_z) +lg_n_multi_z, lg_n_avg_err_multi_z = n_mag.n_mag_kern_multi_z( + DIFFSTARPOP_UM_plus_exsitu, *n_args_multi_z +) +lg_n_data_err_lh_multi_z = jnp.stack((lg_n_multi_z, lg_n_avg_err_multi_z), axis=1) + +lg_n_multi_z2, lg_n_avg_err_multi_z2 = n_mag.n_mag_kern_multi_z( + DEFAULT_DIFFSTARPOP_PARAMS, *n_args_multi_z +) +# loss w/ DEFAULT_DIFFSTARPOP when DIFFSTARPOP_UM_plus_exsitu is the target data +u_diffstarpop_theta2, u_diffstarpop_unravel = ravel_pytree(DEFAULT_DIFFSTARPOP_U_PARAMS) +u_spspop_theta2, u_spspop_unravel = ravel_pytree(DEFAULT_SPSPOP_U_PARAMS) +u_ssp_err_pop_theta2, u_ssp_err_pop_unravel = ravel_pytree(ZERO_SSPERR_U_PARAMS) +u_theta2 = (u_diffstarpop_theta2, u_spspop_theta2, u_ssp_err_pop_theta2) + +lg_n_thresh = -10 +loss_args_multi_z = ( + lg_n_thresh, + n_key, + lc_halopop_z_obs_multi_z, + lc_halopop_t_obs_multi_z, + lc_halopop_mah_params_multi_z, + lc_halopop_logmp0_multi_z, + lc_halopop_nhalos_multi_z, + lc_halopop_vol_mpc3_multi_z, + t_table_multi_z, + ssp_data, + precomputed_ssp_mag_table_multi_z, + z_phot_table_multi_z, + wave_eff_table_multi_z, + DEFAULT_MZR_PARAMS, + DEFAULT_SCATTER_PARAMS, + lh_centroids_multi_z, + dmag, + mag_column, + DEFAULT_COSMOLOGY, + FB, +) +loss_multi_z = n_mag_opt._loss_kern_multi_z( + u_theta2, lg_n_data_err_lh_multi_z, *loss_args_multi_z +) for zbin in range(0, len(zbins)): zmin = zbins[zbin][0] @@ -240,3 +286,32 @@ ) lg_n_single_z, lg_n_avg_err_single_z = n_mag.n_mag_kern(*n_args_single_z) assert np.allclose(lg_n_multi_z[zbin], lg_n_single_z) + + loss_args_single_z = ( + lg_n_thresh, + n_key, + jnp.array(lc_halopop["z_obs"]), + lc_halopop["t_obs"], + lc_halopop["mah_params"], + lc_halopop["logmp0"], + lc_halopop["nhalos"], + lc_halopop["lc_vol_Mpc3"], + t_table, + ssp_data, + precomputed_ssp_mag_table, + z_phot_table, + wave_eff_table, + DEFAULT_MZR_PARAMS, + DEFAULT_SCATTER_PARAMS, + lh_centroids, + dmag, + mag_column, + DEFAULT_COSMOLOGY, + FB, + ) + lg_n_data_err_lh_single_z = jnp.vstack((lg_n_single_z, lg_n_avg_err_single_z)) + + loss_single_z = n_mag_opt._loss_kern( + u_theta2, lg_n_data_err_lh_single_z, *loss_args_single_z + ) + assert np.isclose(loss_multi_z[zbin], loss_single_z) From 44d169e1949b2624ad41acd354a95389a6bc6d72 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 17:55:49 -0500 Subject: [PATCH 11/15] Update test_n_mag_multi_z.py --- diffhtwo/experimental/tests/test_n_mag_multi_z.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffhtwo/experimental/tests/test_n_mag_multi_z.py b/diffhtwo/experimental/tests/test_n_mag_multi_z.py index 366cf73..596798f 100644 --- a/diffhtwo/experimental/tests/test_n_mag_multi_z.py +++ b/diffhtwo/experimental/tests/test_n_mag_multi_z.py @@ -10,7 +10,7 @@ DEFAULT_SPSPOP_PARAMS, DEFAULT_SPSPOP_U_PARAMS, ) -from diffsky.ssp_err_model.ssp_err_model import ZERO_SSPERR_PARAMS, ZERO_SSPERR_U_PARAMS +from diffsky.ssp_err_model.defaults import ZERO_SSPERR_PARAMS, ZERO_SSPERR_U_PARAMS from diffstar.defaults import FB, T_TABLE_MIN from diffstar.diffstarpop.defaults import ( DEFAULT_DIFFSTARPOP_PARAMS, From 17be3ea59bc759a1825205587d51a8519fe3584f Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 18:05:10 -0500 Subject: [PATCH 12/15] Update plot_mag_color_1d_hist.py --- diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py index a225c29..b81cf7a 100644 --- a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py +++ b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py @@ -49,6 +49,7 @@ def plot_n_ugriz( saveAs, lh_centroids, lg_n_data_err_lh, + lg_n_thresh, lgmp_min=10.0, sky_area_degsq=0.1, cosmo_params=DEFAULT_COSMOLOGY, @@ -227,7 +228,6 @@ def plot_n_ugriz( plt.show() # Output loss based on lh_centroids, not 1D histograms as above - lg_n_thresh = -8 lc_nhalos = np.ones(lc_halopop["logmp0"].shape) ran_key, n_key = jran.split(ran_key, 2) From 8d836d232a33d611111c6a1ddc5dd7a4da51e422 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 19:55:34 -0500 Subject: [PATCH 13/15] Update plot_mag_color_1d_hist.py --- .../diagnostics/plot_mag_color_1d_hist.py | 109 ++++++++++-------- 1 file changed, 59 insertions(+), 50 deletions(-) diff --git a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py index b81cf7a..e83a94a 100644 --- a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py +++ b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py @@ -47,9 +47,9 @@ def plot_n_ugriz( label1, label2, saveAs, - lh_centroids, - lg_n_data_err_lh, - lg_n_thresh, + lh_centroids=None, + lg_n_data_err_lh=None, + lg_n_thresh=None, lgmp_min=10.0, sky_area_degsq=0.1, cosmo_params=DEFAULT_COSMOLOGY, @@ -228,62 +228,71 @@ def plot_n_ugriz( plt.show() # Output loss based on lh_centroids, not 1D histograms as above - lc_nhalos = np.ones(lc_halopop["logmp0"].shape) - ran_key, n_key = jran.split(ran_key, 2) - - # 1 - u_diffstarpop_params1 = get_unbounded_diffstarpop_params(diffstarpop_params1) - u_diffstarpop_theta1, u_diffstarpop_unravel = ravel_pytree(u_diffstarpop_params1) - - u_spspop_params1 = get_unbounded_spspop_params_tw_dust(spspop_params1) - u_spspop_theta1, u_spspop_unravel = ravel_pytree(u_spspop_params1) + if lh_centroids is not None: + lc_nhalos = np.ones(lc_halopop["logmp0"].shape) + ran_key, n_key = jran.split(ran_key, 2) + + # 1 + u_diffstarpop_params1 = get_unbounded_diffstarpop_params(diffstarpop_params1) + u_diffstarpop_theta1, u_diffstarpop_unravel = ravel_pytree( + u_diffstarpop_params1 + ) - u_ssp_err_pop_params1 = get_unbounded_ssperr_params(ssp_err_pop_params1) - u_ssp_err_pop_theta1, u_ssp_err_pop_unravel = ravel_pytree(u_ssp_err_pop_params1) + u_spspop_params1 = get_unbounded_spspop_params_tw_dust(spspop_params1) + u_spspop_theta1, u_spspop_unravel = ravel_pytree(u_spspop_params1) - u_theta1 = (u_diffstarpop_theta1, u_spspop_theta1, u_ssp_err_pop_theta1) + u_ssp_err_pop_params1 = get_unbounded_ssperr_params(ssp_err_pop_params1) + u_ssp_err_pop_theta1, u_ssp_err_pop_unravel = ravel_pytree( + u_ssp_err_pop_params1 + ) - # 2 - u_diffstarpop_params2 = get_unbounded_diffstarpop_params(diffstarpop_params2) - u_diffstarpop_theta2, u_diffstarpop_unravel = ravel_pytree(u_diffstarpop_params2) + u_theta1 = (u_diffstarpop_theta1, u_spspop_theta1, u_ssp_err_pop_theta1) - u_spspop_params2 = get_unbounded_spspop_params_tw_dust(spspop_params2) - u_spspop_theta2, u_spspop_unravel = ravel_pytree(u_spspop_params2) + # 2 + u_diffstarpop_params2 = get_unbounded_diffstarpop_params(diffstarpop_params2) + u_diffstarpop_theta2, u_diffstarpop_unravel = ravel_pytree( + u_diffstarpop_params2 + ) - u_ssp_err_pop_params2 = get_unbounded_ssperr_params(ssp_err_pop_params2) - u_ssp_err_pop_theta2, u_ssp_err_pop_unravel = ravel_pytree(u_ssp_err_pop_params2) + u_spspop_params2 = get_unbounded_spspop_params_tw_dust(spspop_params2) + u_spspop_theta2, u_spspop_unravel = ravel_pytree(u_spspop_params2) - u_theta1 = (u_diffstarpop_theta1, u_spspop_theta1, u_ssp_err_pop_theta1) - u_theta2 = (u_diffstarpop_theta2, u_spspop_theta2, u_ssp_err_pop_theta2) + u_ssp_err_pop_params2 = get_unbounded_ssperr_params(ssp_err_pop_params2) + u_ssp_err_pop_theta2, u_ssp_err_pop_unravel = ravel_pytree( + u_ssp_err_pop_params2 + ) - args = ( - lg_n_thresh, - n_key, - lc_halopop["z_obs"], - lc_halopop["t_obs"], - lc_halopop["mah_params"], - lc_halopop["logmp0"], - lc_nhalos, - lc_vol_mpc3, - t_table, - ssp_data, - precomputed_ssp_mag_table, - z_phot_table, - wave_eff_table, - DEFAULT_MZR_PARAMS, - DEFAULT_SCATTER_PARAMS, - lh_centroids, - dmag, - mag_column, - DEFAULT_COSMOLOGY, - FB, - ) + u_theta1 = (u_diffstarpop_theta1, u_spspop_theta1, u_ssp_err_pop_theta1) + u_theta2 = (u_diffstarpop_theta2, u_spspop_theta2, u_ssp_err_pop_theta2) + + loss_args = ( + lg_n_thresh, + n_key, + lc_halopop["z_obs"], + lc_halopop["t_obs"], + lc_halopop["mah_params"], + lc_halopop["logmp0"], + lc_nhalos, + lc_vol_mpc3, + t_table, + ssp_data, + precomputed_ssp_mag_table, + z_phot_table, + wave_eff_table, + DEFAULT_MZR_PARAMS, + DEFAULT_SCATTER_PARAMS, + lh_centroids, + dmag, + mag_column, + DEFAULT_COSMOLOGY, + FB, + ) - loss1 = n_mag_opt._loss_kern(u_theta1, lg_n_data_err_lh, *args) - loss2 = n_mag_opt._loss_kern(u_theta2, lg_n_data_err_lh, *args) + loss1 = n_mag_opt._loss_kern(u_theta1, lg_n_data_err_lh, *loss_args) + loss2 = n_mag_opt._loss_kern(u_theta2, lg_n_data_err_lh, *loss_args) - print(f"default loss = {loss1:.2f}") - print(f"fit loss = {loss2:.2f}") + print(f"default loss = {loss1:.2f}") + print(f"fit loss = {loss2:.2f}") def get_obs_colors_mag(lc_phot, mag_column): From 43b8fadca73b7ec2a88ef75d51096f09d5b45a0f Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 21:58:45 -0500 Subject: [PATCH 14/15] N_floor=1e-3 --> N_floor=1e-6 --- diffhtwo/experimental/n_mag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffhtwo/experimental/n_mag.py b/diffhtwo/experimental/n_mag.py index 6b78589..39ae86c 100644 --- a/diffhtwo/experimental/n_mag.py +++ b/diffhtwo/experimental/n_mag.py @@ -378,7 +378,7 @@ def Gehrels_low_eq12(Ngal): @jjit -def get_n_data_err(N, vol, N_floor=1e-3): +def get_n_data_err(N, vol, N_floor=1e-6): N = jnp.where(N > N_floor, N, N_floor) lg_n = jnp.log10(N / vol) From d10585b41b9928b15eab8dea381f68553d55c611 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 11 Jan 2026 22:23:47 -0500 Subject: [PATCH 15/15] N_floor=1e-6 --> N_floor=1e-12 --- diffhtwo/experimental/n_mag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffhtwo/experimental/n_mag.py b/diffhtwo/experimental/n_mag.py index 39ae86c..839ad89 100644 --- a/diffhtwo/experimental/n_mag.py +++ b/diffhtwo/experimental/n_mag.py @@ -378,7 +378,7 @@ def Gehrels_low_eq12(Ngal): @jjit -def get_n_data_err(N, vol, N_floor=1e-6): +def get_n_data_err(N, vol, N_floor=1e-12): N = jnp.where(N > N_floor, N, N_floor) lg_n = jnp.log10(N / vol)