diff --git a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py index 41c8380..e83a94a 100644 --- a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py +++ b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py @@ -3,11 +3,18 @@ from diffsky.experimental import lc_phot_kern from diffsky.experimental import mc_lightcone_halos as mclh from diffsky.experimental import precompute_ssp_phot as psspp +from diffsky.experimental.scatter import DEFAULT_SCATTER_PARAMS +from diffsky.param_utils.spspop_param_utils import get_unbounded_spspop_params_tw_dust +from diffsky.ssp_err_model.defaults import get_unbounded_ssperr_params from diffstar.defaults import FB, T_TABLE_MIN +from diffstar.diffstarpop import get_unbounded_diffstarpop_params from dsps.cosmology import flat_wcdm from dsps.cosmology.defaults import DEFAULT_COSMOLOGY +from dsps.metallicity.umzr import DEFAULT_MZR_PARAMS from jax import random as jran +from jax.flatten_util import ravel_pytree +from .. import n_mag_opt from ..utils import zbin_volume try: @@ -40,8 +47,11 @@ def plot_n_ugriz( label1, label2, saveAs, + lh_centroids=None, + lg_n_data_err_lh=None, + lg_n_thresh=None, lgmp_min=10.0, - sky_area_degsq=0.25, + sky_area_degsq=0.1, cosmo_params=DEFAULT_COSMOLOGY, fb=FB, ): @@ -217,6 +227,73 @@ def plot_n_ugriz( plt.savefig(saveAs) plt.show() + # Output loss based on lh_centroids, not 1D histograms as above + if lh_centroids is not None: + lc_nhalos = np.ones(lc_halopop["logmp0"].shape) + ran_key, n_key = jran.split(ran_key, 2) + + # 1 + u_diffstarpop_params1 = get_unbounded_diffstarpop_params(diffstarpop_params1) + u_diffstarpop_theta1, u_diffstarpop_unravel = ravel_pytree( + u_diffstarpop_params1 + ) + + u_spspop_params1 = get_unbounded_spspop_params_tw_dust(spspop_params1) + u_spspop_theta1, u_spspop_unravel = ravel_pytree(u_spspop_params1) + + u_ssp_err_pop_params1 = get_unbounded_ssperr_params(ssp_err_pop_params1) + u_ssp_err_pop_theta1, u_ssp_err_pop_unravel = ravel_pytree( + u_ssp_err_pop_params1 + ) + + u_theta1 = (u_diffstarpop_theta1, u_spspop_theta1, u_ssp_err_pop_theta1) + + # 2 + u_diffstarpop_params2 = get_unbounded_diffstarpop_params(diffstarpop_params2) + u_diffstarpop_theta2, u_diffstarpop_unravel = ravel_pytree( + u_diffstarpop_params2 + ) + + u_spspop_params2 = get_unbounded_spspop_params_tw_dust(spspop_params2) + u_spspop_theta2, u_spspop_unravel = ravel_pytree(u_spspop_params2) + + u_ssp_err_pop_params2 = get_unbounded_ssperr_params(ssp_err_pop_params2) + u_ssp_err_pop_theta2, u_ssp_err_pop_unravel = ravel_pytree( + u_ssp_err_pop_params2 + ) + + u_theta1 = (u_diffstarpop_theta1, u_spspop_theta1, u_ssp_err_pop_theta1) + u_theta2 = (u_diffstarpop_theta2, u_spspop_theta2, u_ssp_err_pop_theta2) + + loss_args = ( + lg_n_thresh, + n_key, + lc_halopop["z_obs"], + lc_halopop["t_obs"], + lc_halopop["mah_params"], + lc_halopop["logmp0"], + lc_nhalos, + lc_vol_mpc3, + t_table, + ssp_data, + precomputed_ssp_mag_table, + z_phot_table, + wave_eff_table, + DEFAULT_MZR_PARAMS, + DEFAULT_SCATTER_PARAMS, + lh_centroids, + dmag, + mag_column, + DEFAULT_COSMOLOGY, + FB, + ) + + loss1 = n_mag_opt._loss_kern(u_theta1, lg_n_data_err_lh, *loss_args) + loss2 = n_mag_opt._loss_kern(u_theta2, lg_n_data_err_lh, *loss_args) + + print(f"default loss = {loss1:.2f}") + print(f"fit loss = {loss2:.2f}") + def get_obs_colors_mag(lc_phot, mag_column): num_halos, n_bands = lc_phot.obs_mags_q.shape diff --git a/diffhtwo/experimental/n_mag.py b/diffhtwo/experimental/n_mag.py index 6b78589..839ad89 100644 --- a/diffhtwo/experimental/n_mag.py +++ b/diffhtwo/experimental/n_mag.py @@ -378,7 +378,7 @@ def Gehrels_low_eq12(Ngal): @jjit -def get_n_data_err(N, vol, N_floor=1e-3): +def get_n_data_err(N, vol, N_floor=1e-12): N = jnp.where(N > N_floor, N, N_floor) lg_n = jnp.log10(N / vol) diff --git a/diffhtwo/experimental/n_mag_opt.py b/diffhtwo/experimental/n_mag_opt.py index 2988249..b0c28a1 100644 --- a/diffhtwo/experimental/n_mag_opt.py +++ b/diffhtwo/experimental/n_mag_opt.py @@ -39,8 +39,8 @@ @jjit -def _mse_w(lg_n_pred, lg_n_target, lg_n_target_err): - mask = lg_n_target > -8.0 +def _mse_w(lg_n_pred, lg_n_target, lg_n_target_err, lg_n_thresh): + mask = lg_n_target > lg_n_thresh nbins = jnp.maximum(jnp.sum(mask), 1) resid = lg_n_pred - lg_n_target @@ -197,6 +197,7 @@ def _mse_w(lg_n_pred, lg_n_target, lg_n_target_err): def _loss_kern( u_theta, lg_n_target, + lg_n_thresh, ran_key, lc_z_obs, lc_t_obs, @@ -273,7 +274,7 @@ def _loss_kern( fb, ) - return _mse_w(lg_n_model, lg_n_target[0], lg_n_target[1]) + return _mse_w(lg_n_model, lg_n_target[0], lg_n_target[1], lg_n_thresh) loss_and_grad = jjit(value_and_grad(_loss_kern)) @@ -283,6 +284,7 @@ def _loss_kern( def fit_n( u_theta_init, lg_n_target, + lg_n_thresh, ran_key, lc_z_obs, lc_t_obs, @@ -310,6 +312,7 @@ def fit_n( other = ( lg_n_target, + lg_n_thresh, ran_key, lc_z_obs, lc_t_obs, @@ -349,6 +352,7 @@ def _opt_update(opt_state, i): None, 0, None, + None, 0, 0, 0, @@ -389,6 +393,7 @@ def fit_n_multi_z( u_theta_init, trainable, lg_n_target, + lg_n_thresh, ran_key, lc_z_obs, lc_t_obs, @@ -416,6 +421,7 @@ def fit_n_multi_z( other = ( lg_n_target, + lg_n_thresh, ran_key, lc_z_obs, lc_t_obs, diff --git a/diffhtwo/experimental/tests/test_n_mag_kern_multi_z.py b/diffhtwo/experimental/tests/test_n_mag_multi_z.py similarity index 71% rename from diffhtwo/experimental/tests/test_n_mag_kern_multi_z.py rename to diffhtwo/experimental/tests/test_n_mag_multi_z.py index ca1bd06..596798f 100644 --- a/diffhtwo/experimental/tests/test_n_mag_kern_multi_z.py +++ b/diffhtwo/experimental/tests/test_n_mag_multi_z.py @@ -6,10 +6,16 @@ from diffsky.experimental import mc_lightcone_halos as mclh from diffsky.experimental import precompute_ssp_phot as psspp from diffsky.experimental.scatter import DEFAULT_SCATTER_PARAMS -from diffsky.param_utils.spspop_param_utils import DEFAULT_SPSPOP_PARAMS -from diffsky.ssp_err_model.ssp_err_model import ZERO_SSPERR_PARAMS +from diffsky.param_utils.spspop_param_utils import ( + DEFAULT_SPSPOP_PARAMS, + DEFAULT_SPSPOP_U_PARAMS, +) +from diffsky.ssp_err_model.defaults import ZERO_SSPERR_PARAMS, ZERO_SSPERR_U_PARAMS from diffstar.defaults import FB, T_TABLE_MIN -from diffstar.diffstarpop.defaults import DEFAULT_DIFFSTARPOP_PARAMS +from diffstar.diffstarpop.defaults import ( + DEFAULT_DIFFSTARPOP_PARAMS, + DEFAULT_DIFFSTARPOP_U_PARAMS, +) from diffstar.diffstarpop.kernels.params.params_diffstarpopfits_mgash import ( DiffstarPop_Params_Diffstarpopfits_mgash, ) @@ -24,6 +30,8 @@ from diffhtwo.experimental.data_loaders import retrieve_tcurves from diffhtwo.experimental.utils import zbin_volume +from .. import n_mag_opt + TEST_DIR = os.path.dirname(os.path.abspath(__file__)) DATA_PATH = os.path.join(TEST_DIR, "..", "data_loaders") @@ -92,7 +100,7 @@ lc_halopop_mah_params_multi_z = [] lc_halopop_nhalos_multi_z = [] lc_halopop_logmp0_multi_z = [] -lc_halopop_lc_vol_mpc3_multi_z = [] +lc_halopop_vol_mpc3_multi_z = [] t_table_multi_z = [] precomputed_ssp_mag_table_multi_z = [] @@ -135,7 +143,7 @@ lc_halopop_mah_params_multi_z.append(lc_halopop["mah_params"]) lc_halopop_logmp0_multi_z.append(lc_halopop["logmp0"]) lc_halopop_nhalos_multi_z.append(lc_halopop["nhalos"]) - lc_halopop_lc_vol_mpc3_multi_z.append(lc_halopop["lc_vol_Mpc3"]) + lc_halopop_vol_mpc3_multi_z.append(lc_halopop["lc_vol_Mpc3"]) t_table_multi_z.append(t_table) precomputed_ssp_mag_table_multi_z.append(precomputed_ssp_mag_table) z_phot_table_multi_z.append(z_phot_table) @@ -147,7 +155,7 @@ lc_halopop_mah_params_multi_z = jnp.asarray(lc_halopop_mah_params_multi_z) lc_halopop_logmp0_multi_z = jnp.asarray(lc_halopop_logmp0_multi_z) lc_halopop_nhalos_multi_z = jnp.asarray(lc_halopop_nhalos_multi_z) -lc_halopop_lc_vol_mpc3_multi_z = jnp.asarray(lc_halopop_lc_vol_mpc3_multi_z) +lc_halopop_vol_mpc3_multi_z = jnp.asarray(lc_halopop_vol_mpc3_multi_z) t_table_multi_z = jnp.asarray(t_table_multi_z) precomputed_ssp_mag_table_multi_z = jnp.asarray(precomputed_ssp_mag_table_multi_z) z_phot_table_multi_z = jnp.asarray(z_phot_table_multi_z) @@ -157,7 +165,6 @@ ran_key, n_key = jran.split(ran_key, 2) n_args_multi_z = ( - DIFFSTARPOP_UM_plus_exsitu, DEFAULT_SPSPOP_PARAMS, n_key, lc_halopop_z_obs_multi_z, @@ -165,7 +172,7 @@ lc_halopop_mah_params_multi_z, lc_halopop_logmp0_multi_z, lc_halopop_nhalos_multi_z, - lc_halopop_lc_vol_mpc3_multi_z, + lc_halopop_vol_mpc3_multi_z, t_table_multi_z, ssp_data, precomputed_ssp_mag_table_multi_z, @@ -181,8 +188,47 @@ FB, ) -lg_n_multi_z, lg_n_avg_err_multi_z = n_mag.n_mag_kern_multi_z(*n_args_multi_z) +lg_n_multi_z, lg_n_avg_err_multi_z = n_mag.n_mag_kern_multi_z( + DIFFSTARPOP_UM_plus_exsitu, *n_args_multi_z +) +lg_n_data_err_lh_multi_z = jnp.stack((lg_n_multi_z, lg_n_avg_err_multi_z), axis=1) + +lg_n_multi_z2, lg_n_avg_err_multi_z2 = n_mag.n_mag_kern_multi_z( + DEFAULT_DIFFSTARPOP_PARAMS, *n_args_multi_z +) +# loss w/ DEFAULT_DIFFSTARPOP when DIFFSTARPOP_UM_plus_exsitu is the target data +u_diffstarpop_theta2, u_diffstarpop_unravel = ravel_pytree(DEFAULT_DIFFSTARPOP_U_PARAMS) +u_spspop_theta2, u_spspop_unravel = ravel_pytree(DEFAULT_SPSPOP_U_PARAMS) +u_ssp_err_pop_theta2, u_ssp_err_pop_unravel = ravel_pytree(ZERO_SSPERR_U_PARAMS) +u_theta2 = (u_diffstarpop_theta2, u_spspop_theta2, u_ssp_err_pop_theta2) + +lg_n_thresh = -10 +loss_args_multi_z = ( + lg_n_thresh, + n_key, + lc_halopop_z_obs_multi_z, + lc_halopop_t_obs_multi_z, + lc_halopop_mah_params_multi_z, + lc_halopop_logmp0_multi_z, + lc_halopop_nhalos_multi_z, + lc_halopop_vol_mpc3_multi_z, + t_table_multi_z, + ssp_data, + precomputed_ssp_mag_table_multi_z, + z_phot_table_multi_z, + wave_eff_table_multi_z, + DEFAULT_MZR_PARAMS, + DEFAULT_SCATTER_PARAMS, + lh_centroids_multi_z, + dmag, + mag_column, + DEFAULT_COSMOLOGY, + FB, +) +loss_multi_z = n_mag_opt._loss_kern_multi_z( + u_theta2, lg_n_data_err_lh_multi_z, *loss_args_multi_z +) for zbin in range(0, len(zbins)): zmin = zbins[zbin][0] @@ -240,3 +286,32 @@ ) lg_n_single_z, lg_n_avg_err_single_z = n_mag.n_mag_kern(*n_args_single_z) assert np.allclose(lg_n_multi_z[zbin], lg_n_single_z) + + loss_args_single_z = ( + lg_n_thresh, + n_key, + jnp.array(lc_halopop["z_obs"]), + lc_halopop["t_obs"], + lc_halopop["mah_params"], + lc_halopop["logmp0"], + lc_halopop["nhalos"], + lc_halopop["lc_vol_Mpc3"], + t_table, + ssp_data, + precomputed_ssp_mag_table, + z_phot_table, + wave_eff_table, + DEFAULT_MZR_PARAMS, + DEFAULT_SCATTER_PARAMS, + lh_centroids, + dmag, + mag_column, + DEFAULT_COSMOLOGY, + FB, + ) + lg_n_data_err_lh_single_z = jnp.vstack((lg_n_single_z, lg_n_avg_err_single_z)) + + loss_single_z = n_mag_opt._loss_kern( + u_theta2, lg_n_data_err_lh_single_z, *loss_args_single_z + ) + assert np.isclose(loss_multi_z[zbin], loss_single_z) diff --git a/diffhtwo/experimental/tests/test_n_mag_opt.py b/diffhtwo/experimental/tests/test_n_mag_opt.py index ec2040b..f7ee86a 100644 --- a/diffhtwo/experimental/tests/test_n_mag_opt.py +++ b/diffhtwo/experimental/tests/test_n_mag_opt.py @@ -67,6 +67,7 @@ bin_edges = np.arange(18.0 - dmag / 2, 26.0, dmag) bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2 bin_centers = bin_centers.reshape(bin_centers.size, 1) +lg_n_thresh = -8 ran_key, n_key = jran.split(ran_key, 2) lg_n_true, lg_n_avg_err_true = n_mag.n_mag_kern( @@ -100,6 +101,7 @@ loss_hist, grad_hist, u_theta_fit = n_mag_opt.fit_n( u_diffstarpop_theta_default, lg_n_true, + lg_n_thresh, fit_n_key, jnp.array(lc_halopop["z_obs"]), lc_halopop["t_obs"],