From 643da56c606d20d5e9832f046664570c75b82d1e Mon Sep 17 00:00:00 2001 From: Jaxen Godfrey Date: Fri, 18 Apr 2025 08:09:50 -0700 Subject: [PATCH 1/4] gwinferno/pipeline/utils.py --- gwinferno/pipeline/utils.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/gwinferno/pipeline/utils.py b/gwinferno/pipeline/utils.py index d0bc56f2..bf033234 100644 --- a/gwinferno/pipeline/utils.py +++ b/gwinferno/pipeline/utils.py @@ -8,7 +8,7 @@ import xarray as xr from gwinferno.interpolation import LogXLogYBSpline -from gwinferno.interpolation import LogYBSpline +from gwinferno.interpolation import LogYBSpline, LogXBSpline from gwinferno.models.bsplines.separable import BSplineIIDSpinMagnitudes from gwinferno.models.bsplines.separable import BSplineIIDSpinTilts from gwinferno.models.bsplines.separable import BSplineIndependentSpinMagnitudes @@ -26,11 +26,11 @@ def load_base_parser(): ) parser.add_argument("--run-label", type=str) parser.add_argument("--result-dir", type=str) - parser.add_argument("--m-nsplines", type=str, default=50) - parser.add_argument("--q-nsplines", type=str, default=30) - parser.add_argument("--a-nsplines", type=str, default=16) - parser.add_argument("--tilt-nsplines", type=str, default=16) - parser.add_argument("--z-nsplines", type=str, default=20) + parser.add_argument("--m-nsplines", type=int, default=50) + parser.add_argument("--q-nsplines", type=int, default=30) + parser.add_argument("--a-nsplines", type=int, default=16) + parser.add_argument("--tilt-nsplines", type=int, default=16) + parser.add_argument("--z-nsplines", type=int, default=20) parser.add_argument("--mmin", type=float, default=3.0) parser.add_argument("--mmax", type=float, default=100.0) parser.add_argument("--chains", type=int, default=1) @@ -48,7 +48,7 @@ def load_base_parser(): """ -def load_pe_and_injections_as_dict(file, ignore=None): +def load_pe_and_injections_as_dict(file, ignore=[]): """Load PE and injection file created by `gwinferno.preprocess.data_collection.save_posterior_samples_and_injection_datasets_as_idata()`. Parameters @@ -74,7 +74,7 @@ def load_pe_and_injections_as_dict(file, ignore=None): data = az.from_netcdf(file) print(f"data file {file} loaded") - if ignore is not None: + if ignore: sel = np.zeros(data.pe_data["event"].values.shape, dtype=bool) for gw in ignore: sel += data.pe_data["event"] == gw @@ -90,7 +90,6 @@ def load_pe_and_injections_as_dict(file, ignore=None): total_inj = data.inj_data.attrs["total_generated"] obs_time = data.inj_data.attrs["analysis_time"] nObs = data.pe_data.posteriors.shape[0] - constants = {"total_inj": total_inj, "obs_time": obs_time, "nObs": nObs} return pedict, injdict, constants, param_names @@ -145,13 +144,13 @@ def setup_bspline_spin_models(pedict, injdict, a1_nsplines, ct1_nsplines, IID=Fa return mag_model, tilt_model - -def setup_powerlaw_spline_redshift_model(pedict, injdict, z_nsplines): +def setup_powerlaw_spline_redshift_model(pedict, injdict, z_nsplines, basis = LogXBSpline): print("initializing redshift model") return PowerlawSplineRedshiftModel( z_nsplines, pedict["redshift"], injdict["redshift"], + basis = basis ) @@ -210,8 +209,7 @@ def bspline_spin_prior(a_nsplines=None, ct_nsplines=None, a_tau=None, ct_tau=Non def bspline_redshift_prior(z_nsplines=None, z_tau=None, name=None, z_cs_sig=1, z_deg=2): name = "_" + name if name is not None else "" - z_cs = numpyro.sample("z_cs" + name, dist.Normal(0, z_cs_sig), sample_shape=(z_nsplines - 1,)) - z_cs = jnp.concatenate([jnp.zeros(1), z_cs]) + z_cs = numpyro.sample("z_cs" + name, dist.Normal(0, z_cs_sig), sample_shape=(z_nsplines,)) numpyro.factor("z_smoothing_prior" + name, apply_difference_prior(z_cs, z_tau, degree=z_deg)) return z_cs @@ -236,7 +234,7 @@ def pdf_dict_to_xarray(pdf_dict, param_dict, n_samples, subpop_names=None): xr_dict = xr_dict | pdfs else: z = {"redshift_pdfs": (["draw", "redshift"], pdf_dict["redshift"])} - xr_dict | z + xr_dict = xr_dict | z del pdf_dict["redshift"] for i, nm in enumerate(subpop_names): single = {f"{nm}_{key}_pdfs": (["draw", key], item[i]) for key, item in pdf_dict.items()} From c9aa6d35cf8ff819aeab0810d1ba3db0cb1faf86 Mon Sep 17 00:00:00 2001 From: Jaxen Godfrey Date: Fri, 18 Apr 2025 08:12:36 -0700 Subject: [PATCH 2/4] remove line --- gwinferno/postprocess/calculations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gwinferno/postprocess/calculations.py b/gwinferno/postprocess/calculations.py index 939982cc..7847e753 100644 --- a/gwinferno/postprocess/calculations.py +++ b/gwinferno/postprocess/calculations.py @@ -267,7 +267,6 @@ def calculate_powerlaw_spline_rate_of_z_ppds(lamb, z_cs, rate, z_model, pop_frac rs = np.zeros((len(lamb), len(zs))) def calc_rz(cs, la, r, f): - cs = jnp.concatenate([jnp.array([0]), cs]) return r * f * jnp.power(1.0 + zs, la) * jnp.exp(z_model.interpolator.project(z_model.norm_design_matrix, cs)) calc_rz = jit(calc_rz) From 6e904f27f064d52a71ac247f2f95afb17fa52c54 Mon Sep 17 00:00:00 2001 From: Jaxen Godfrey Date: Fri, 18 Apr 2025 08:13:46 -0700 Subject: [PATCH 3/4] support for subpopulations --- gwinferno/postprocess/plot.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/gwinferno/postprocess/plot.py b/gwinferno/postprocess/plot.py index d897fa24..edea4109 100644 --- a/gwinferno/postprocess/plot.py +++ b/gwinferno/postprocess/plot.py @@ -15,7 +15,7 @@ def plot_pdf(x, pdf, label, color="blue", loglog=True, alpha=1.0): plt.fill_between(x, low, high, color=color, alpha=0.1) -def plot_mass_pdfs(mpdfs, qpdfs, m1, q, names, label, result_dir, save=True, colors=["red", "blue", "green"]): +def plot_mass_pdfs(mpdfs, qpdfs, m1, q, names, label, result_dir, save=True, colors=["red", "blue", "green"], alt_label = ''): plt.figure(figsize=(15, 5)) for i in range(len(mpdfs)): @@ -39,11 +39,11 @@ def plot_mass_pdfs(mpdfs, qpdfs, m1, q, names, label, result_dir, save=True, col plt.xlim(0, 1) plt.show() if save: - plt.savefig(result_dir + f"/mass_ratio_pdf_{label}.png", dpi=100) + plt.savefig(result_dir + f"/mass_ratio_pdf_{label}_{alt_label}.png", dpi=100) plt.close() -def plot_spin_pdfs(a_pdfs, tilt_pdfs, aa, cc, names, label, result_dir, save=True, colors=["red", "blue", "green"], secondary=False): +def plot_spin_pdfs(a_pdfs, tilt_pdfs, aa, cc, names, label, result_dir, save=True, colors=["red", "blue", "green"], secondary=False, alt_label = ''): if secondary: comp = "2" @@ -59,7 +59,7 @@ def plot_spin_pdfs(a_pdfs, tilt_pdfs, aa, cc, names, label, result_dir, save=Tru plt.xlim(0, 1) plt.show() if save: - plt.savefig(result_dir + f"/spin_mag{comp}_pdf_{label}.png", dpi=100) + plt.savefig(result_dir + f"/spin_mag{comp}_pdf_{label}_{alt_label}.png", dpi=100) plt.close() plt.figure(figsize=(10, 7)) @@ -71,18 +71,20 @@ def plot_spin_pdfs(a_pdfs, tilt_pdfs, aa, cc, names, label, result_dir, save=Tru plt.xlim(-1, 1) plt.show() if save: - plt.savefig(result_dir + f"/cos_tilt{comp}_pdf_{label}.png", dpi=100) + plt.savefig(result_dir + f"/cos_tilt{comp}_pdf_{label}_{alt_label}.png", dpi=100) plt.close() -def plot_rate_of_z_pdfs(z_pdfs, z, label, result_dir, save=True): +def plot_rate_of_z_pdfs(z_pdfs, z, names, label, result_dir, colors = ['red', 'blue', 'green'], save=True): plt.figure(figsize=(10, 7)) - plot_pdf(z, z_pdfs, "redshift") + for i in range(len(z_pdfs)): + plot_pdf(z, z_pdfs[i], names[i], color = colors[i], loglog=False) plt.xlabel("z") plt.ylabel("R(z)") plt.legend() - plt.xlim(z[0], 1.5) + plt.yscale('log') + plt.xlim(0, 1.5) plt.ylim(5, 1e3) plt.show() if save: From 7f18044f391fe4238358acee5c8fce03d5477d02 Mon Sep 17 00:00:00 2001 From: Jaxen Godfrey Date: Fri, 18 Apr 2025 08:46:53 -0700 Subject: [PATCH 4/4] pre-commit --- gwinferno/pipeline/utils.py | 13 +++++-------- gwinferno/postprocess/plot.py | 12 ++++++------ 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/gwinferno/pipeline/utils.py b/gwinferno/pipeline/utils.py index bf033234..77e4207f 100644 --- a/gwinferno/pipeline/utils.py +++ b/gwinferno/pipeline/utils.py @@ -7,8 +7,9 @@ import numpyro.distributions as dist import xarray as xr +from gwinferno.interpolation import LogXBSpline from gwinferno.interpolation import LogXLogYBSpline -from gwinferno.interpolation import LogYBSpline, LogXBSpline +from gwinferno.interpolation import LogYBSpline from gwinferno.models.bsplines.separable import BSplineIIDSpinMagnitudes from gwinferno.models.bsplines.separable import BSplineIIDSpinTilts from gwinferno.models.bsplines.separable import BSplineIndependentSpinMagnitudes @@ -144,14 +145,10 @@ def setup_bspline_spin_models(pedict, injdict, a1_nsplines, ct1_nsplines, IID=Fa return mag_model, tilt_model -def setup_powerlaw_spline_redshift_model(pedict, injdict, z_nsplines, basis = LogXBSpline): + +def setup_powerlaw_spline_redshift_model(pedict, injdict, z_nsplines, basis=LogXBSpline): print("initializing redshift model") - return PowerlawSplineRedshiftModel( - z_nsplines, - pedict["redshift"], - injdict["redshift"], - basis = basis - ) + return PowerlawSplineRedshiftModel(z_nsplines, pedict["redshift"], injdict["redshift"], basis=basis) """ diff --git a/gwinferno/postprocess/plot.py b/gwinferno/postprocess/plot.py index edea4109..6a4d0f12 100644 --- a/gwinferno/postprocess/plot.py +++ b/gwinferno/postprocess/plot.py @@ -15,7 +15,7 @@ def plot_pdf(x, pdf, label, color="blue", loglog=True, alpha=1.0): plt.fill_between(x, low, high, color=color, alpha=0.1) -def plot_mass_pdfs(mpdfs, qpdfs, m1, q, names, label, result_dir, save=True, colors=["red", "blue", "green"], alt_label = ''): +def plot_mass_pdfs(mpdfs, qpdfs, m1, q, names, label, result_dir, save=True, colors=["red", "blue", "green"], alt_label=""): plt.figure(figsize=(15, 5)) for i in range(len(mpdfs)): @@ -43,7 +43,7 @@ def plot_mass_pdfs(mpdfs, qpdfs, m1, q, names, label, result_dir, save=True, col plt.close() -def plot_spin_pdfs(a_pdfs, tilt_pdfs, aa, cc, names, label, result_dir, save=True, colors=["red", "blue", "green"], secondary=False, alt_label = ''): +def plot_spin_pdfs(a_pdfs, tilt_pdfs, aa, cc, names, label, result_dir, save=True, colors=["red", "blue", "green"], secondary=False, alt_label=""): if secondary: comp = "2" @@ -75,15 +75,15 @@ def plot_spin_pdfs(a_pdfs, tilt_pdfs, aa, cc, names, label, result_dir, save=Tru plt.close() -def plot_rate_of_z_pdfs(z_pdfs, z, names, label, result_dir, colors = ['red', 'blue', 'green'], save=True): +def plot_rate_of_z_pdfs(z_pdfs, z, names, label, result_dir, colors=["red", "blue", "green"], save=True): plt.figure(figsize=(10, 7)) - for i in range(len(z_pdfs)): - plot_pdf(z, z_pdfs[i], names[i], color = colors[i], loglog=False) + for i in range(len(z_pdfs)): + plot_pdf(z, z_pdfs[i], names[i], color=colors[i], loglog=False) plt.xlabel("z") plt.ylabel("R(z)") plt.legend() - plt.yscale('log') + plt.yscale("log") plt.xlim(0, 1.5) plt.ylim(5, 1e3) plt.show()