diff --git a/diffhalos/ccshmf/tests/test_ccshmf.py b/diffhalos/ccshmf/tests/test_ccshmf.py deleted file mode 100644 index 59ceda2..0000000 --- a/diffhalos/ccshmf/tests/test_ccshmf.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -""" -import os -from glob import glob - -import numpy as np - -from ..ccshmf_model import DEFAULT_CCSHMF_PARAMS, predict_ccshmf - -_THIS_DRNAME = os.path.dirname(os.path.abspath(__file__)) -TESTING_DATA_DRN = os.path.join(_THIS_DRNAME, "testing_data") -BNPAT = "smdpl_cshmf_cuml_redshift_{0:.2f}_lgmhost_{1:.2f}.txt" - - -def _mse(pred, target): - diff = pred - target - return np.mean(diff**2) - - -def _mae(pred, target): - diff = pred - target - return np.mean(np.abs(diff)) - - -def infer_redshift_from_bname(bn): - return float(bn.split("_")[4]) - - -def infer_logmhost_from_bname(bn): - return float(bn.split("_")[-1][:-4]) - - -def test_predict_ccshmf_returns_finite_valued_expected_shape(): - lgmhost = 13.0 - nsubs = 100 - lgmuarr = np.linspace(-5, 0, nsubs) - pred = predict_ccshmf(DEFAULT_CCSHMF_PARAMS, lgmhost, lgmuarr) - assert pred.shape == lgmuarr.shape - assert np.all(np.isfinite(pred)) - - lgmu = -2.0 - pred = predict_ccshmf(DEFAULT_CCSHMF_PARAMS, lgmhost, lgmu) - assert pred.shape == () - assert np.all(np.isfinite(pred)) - - nhosts = 5 - lgmhostarr = np.linspace(12, 15, nhosts) - pred = predict_ccshmf(DEFAULT_CCSHMF_PARAMS, lgmhostarr, lgmu) - assert pred.shape == (nhosts,) - assert np.all(np.isfinite(pred)) - - nhosts = 5 - nsubs = nhosts - lgmhostarr = np.linspace(12, 15, nhosts) - lgmuarr = np.linspace(-5, 0, nsubs) - pred = predict_ccshmf(DEFAULT_CCSHMF_PARAMS, lgmhostarr, lgmuarr) - assert pred.shape == (nhosts,) - assert np.all(np.isfinite(pred)) - - -def test_predict_ccshmf_accurately_approximates_simulation_data(): - """This test loads some pretabulated CCSHMF data computed from SMDPL - and compares the simulation results against the predict_ccshmf prediction""" - fname_list = glob(os.path.join(TESTING_DATA_DRN, "smdpl_cshmf_*.txt")) - bname_list = [os.path.basename(fn) for fn in fname_list] - - zlist = np.unique([infer_redshift_from_bname(bn) for bn in bname_list]) - - for redshift in zlist: - zpat = "redshift_{:.2f}".format(redshift) - bname_list_z = [bn for bn in bname_list if zpat in bn] - lgmh_list_z = np.array([infer_logmhost_from_bname(bn) for bn in bname_list_z]) - lgmh_list_z = lgmh_list_z[lgmh_list_z > 12] - lgmhost_targets = np.sort(lgmh_list_z) - - for itarget in range(lgmhost_targets.size): - target_lgmhost = lgmhost_targets[itarget] - bn_sample = BNPAT.format(redshift, target_lgmhost) - cshmf_data_sample = np.loadtxt(os.path.join(TESTING_DATA_DRN, bn_sample)) - target_lgmu_bins, target_lg_ccshmf = ( - cshmf_data_sample[:, 0], - cshmf_data_sample[:, 1], - ) - pred_lg_ccshmf = predict_ccshmf( - DEFAULT_CCSHMF_PARAMS, target_lgmhost, target_lgmu_bins - ) - - loss_sq = _mse(pred_lg_ccshmf, target_lg_ccshmf) - assert np.sqrt(loss_sq) < 0.15 - - loss_mae = _mae(pred_lg_ccshmf, target_lg_ccshmf) - assert loss_mae < 0.06
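For reference, the deleted test above pinned down the broadcasting behavior of predict_ccshmf. A minimal usage sketch, assuming the diffhalos package is still importable; the grid values below are illustrative and not taken from the testing data:

import numpy as np
from diffhalos.ccshmf.ccshmf_model import DEFAULT_CCSHMF_PARAMS, predict_ccshmf

# predict_ccshmf broadcasts over host mass and subhalo mass ratio:
# scalar/scalar, scalar/array, array/scalar, and matched array/array all work
lgmhost = 13.0  # log10 host halo mass [Msun]
lgmu_grid = np.linspace(-5, 0, 100)  # log10(Msub/Mhost) grid
lg_ccshmf = predict_ccshmf(DEFAULT_CCSHMF_PARAMS, lgmhost, lgmu_grid)
assert lg_ccshmf.shape == lgmu_grid.shape and np.all(np.isfinite(lg_ccshmf))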
diff --git a/diffhalos/diffmahpop/__init__.py b/diffhalos/diffmahpop/__init__.py deleted file mode 100644 index a8306b5..0000000 --- a/diffhalos/diffmahpop/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""""" - -# flake8: noqa - -from .diffmahnet_utils import DEFAULT_MAH_PARAMS, DEFAULT_MAH_U_PARAMS diff --git a/diffhalos/diffmahpop/diffmahnet_utils.py b/diffhalos/diffmahpop/diffmahnet_utils.py deleted file mode 100644 index e300ee6..0000000 --- a/diffhalos/diffmahpop/diffmahnet_utils.py +++ /dev/null @@ -1,344 +0,0 @@ -""" -Useful diffmahnet functions -See https://diffmahnet.readthedocs.io/en/latest/installation.html -""" - -import numpy as np -import os -import pathlib -import glob - -import jax -import jax.numpy as jnp - -from diffmah import mah_halopop -from diffmah import DEFAULT_MAH_PARAMS -from diffmah.diffmah_kernels import DEFAULT_MAH_U_PARAMS -from diffmah.diffmah_kernels import ( - get_unbounded_mah_params, - get_bounded_mah_params, -) - -import diffmahnet - -from .utils import rescale_mah_parameters - -DEFAULT_MAH_UPARAMS = get_unbounded_mah_params(DEFAULT_MAH_PARAMS) - -T_GRID_MIN = 0.5 -T_GRID_MAX = jnp.log10(13.8) -N_T_GRID = 100 - -__all__ = ( - "mc_mah_cenpop", - "get_mean_and_std_of_mah", - "get_mah_from_unbounded_params", - "load_diffmahnet_training_data", - "get_available_models", -) - - -def mc_mah_cenpop( - m_obs, - t_obs, - randkey, - n_sample=1, - centrals_model_key="cenflow_v2_0.eqx", - t_min=T_GRID_MIN, - t_max=T_GRID_MAX, - n_t=N_T_GRID, - return_mah_params=False, -): - """ - Generate MC realizations of central halo populations - using the ``diffmahnet`` code for ``diffmahpop``. - This function takes in a grid of halo mass and - cosmic time at observation and, for each pair, - generates MC samples given a random key - - Parameters - ---------- - m_obs: ndarray of shape (n_halo, ) - grid of base-10 log of mass of the halos at observation, in Msun - - t_obs: ndarray of shape (n_halo, ) - grid of base-10 log of cosmic time at observation of each halo, in Gyr - - randkey: key - JAX random key - - n_sample: int - number of MC samples per (m_obs, t_obs) pair - - centrals_model_key: str - model name for centrals - - t_min: float - base-10 log of minimum value for time grid - at which to compute mah, in Gyr - - t_max: float - base-10 log of maximum value for time grid, - at which to compute mah, in Gyr - - n_t: int - number of points in time grid - - return_mah_params: bool - if True the MAH parameters from the normalizing flow - will also be returned - - Returns - ------- - cen_mah: ndarray of shape (n_sample*n_m_obs*n_t_obs, n_t) - base-10 log of halo mass assembly histories, - for all MC realizations, in Msun - - t_grid: ndarray of shape (n_sample*n_m_obs*n_t_obs, n_t) - cosmic time grid on which to compute MAHs, - for all MC realizations, in Gyr - - cenflow_diffmahparams: namedtuple - diffmah parameters from the normalizing flow, - each parameter is an ndarray of shape (n_sample*n_m_obs*n_t_obs, ) - """ - # create diffmahnet model for centrals - centrals_model = diffmahnet.load_pretrained_model(centrals_model_key) - mc_diffmahnet_cenpop = centrals_model.make_mc_diffmahnet() - - # get a list of (m_obs, t_obs) for each MC realization - m_vals, t_vals = [ - jnp.repeat(x.flatten(), n_sample) - for x in jnp.meshgrid( - m_obs, - t_obs, - ) - ] - - # get diffmah parameters from the normalizing flow - keys = jax.random.split(randkey, 2) - cenflow_diffmahparams = mc_diffmahnet_cenpop( - centrals_model.get_params(), m_vals, t_vals, keys[0] - ) - - # construct time grids for each halo, given observation time - t_grid = jnp.linspace(t_min, t_vals, n_t).T - - # compute the uncorrected predicted observed halo masses - logm_obs_uncorrected = diffmahnet.log_mah_kern( - cenflow_diffmahparams, - t_grid, - t_max, - )[:, -1] - - # rescale the mah parameters to the correct logm0 - cenflow_diffmahparams = rescale_mah_parameters( - cenflow_diffmahparams, - m_vals, - logm_obs_uncorrected, - ) - - # compute mah with corrected parameters - cen_mah = diffmahnet.log_mah_kern( - cenflow_diffmahparams, - t_grid, - t_max, - ) - - if return_mah_params: - return cen_mah, t_grid, cenflow_diffmahparams - - return cen_mah, t_grid
- - -def get_mean_and_std_of_mah(mah): - """ - Helper function to get the mean and 1-sigma - standard deviation of a sample of mah realizations - - Parameters - ---------- - mah: ndarray of shape (n_halo, n_t) - MAH of the population of halos - - Returns - ------- - mah_mean: ndarray of shape (n_t, ) - mean of mah at each time - - mah_max: ndarray of shape (n_t, ) - upper bound of the 1-sigma band around the mean - - mah_min: ndarray of shape (n_t, ) - lower bound of the 1-sigma band around the mean - """ - n_t = mah.shape[1] - - mah_mean = np.zeros(n_t) - mah_max = np.zeros(n_t) - mah_min = np.zeros(n_t) - for t in range(n_t): - _mah = mah[:, t] - mah_mean[t] = np.mean(_mah) - _std = np.std(_mah) - mah_max[t] = mah_mean[t] + _std - mah_min[t] = mah_mean[t] - _std - - return mah_mean, mah_max, mah_min - - -def get_mah_from_unbounded_params( - mah_params_unbound, - logt0, - t_grid, - logm_obs, -): - """ - Helper function to generate the MAH from - a set of diffmah unbounded parameters, - for a population of halos - - Parameters - ---------- - mah_params_unbound: ndarray of shape (n_halo, n_mah_param) - unbounded ``diffmah`` parameters - (logm0, logtc, early_index, late_index, t_peak) - - logt0: float - base-10 log of the age of the Universe at z=0, in Gyr - - t_grid: ndarray of shape (n_t, ) - cosmic time grid at which to compute the MAH - - logm_obs: float - base-10 log of observed halo mass, in Msun - - Returns - ------- - log_mah: ndarray of shape (n_halo, n_t) - base-10 log of MAH, in Msun - """ - mah_params_bound = jnp.array( - [ - *get_bounded_mah_params( - DEFAULT_MAH_U_PARAMS._make(mah_params_unbound.T), - ) - ] - ) - - mah_params_uncorrected = DEFAULT_MAH_PARAMS._make(mah_params_bound) - - _, logm_obs_uncorrected = mah_halopop( - mah_params_uncorrected, - t_grid, - logt0, - ) - - # rescale the mah parameters to the correct logm0 - mah_params = rescale_mah_parameters( - mah_params_uncorrected, - logm_obs, - logm_obs_uncorrected[:, -1], - ) - - _, log_mah = mah_halopop(mah_params, t_grid, logt0) - - return log_mah - - -def load_diffmahnet_training_data( - path=None, - is_test: bool | str = False, - is_cens=True, -): - """ - Convenience function to load the data - used to train ``diffmahnet`` - - Parameters - ---------- - path: str - path to the training data folder; - if not provided, the environment variable - ``DIFFMAHNET_TRAINING_DATA`` will be used instead - - is_test: bool or str - if True, load the held-out test split; - if False, load the training split; - if "both", load the full dataset - - is_cens: bool - if True, data for centrals will be loaded, - if False, data for satellites will be loaded - - Returns - ------- - x_unbound: ndarray of shape (n_samples, n_pdf_var) - unbounded PDF variables - - u: ndarray of shape (n_samples, n_cond_var) - conditional variables - """ - if path is None: - try: - path = os.environ["DIFFMAHNET_TRAINING_DATA"] - except KeyError: - msg = ( - "Since you did not pass the 'path' argument\n" - "you must have the 'DIFFMAHNET_TRAINING_DATA' environment variable set.\n" - "First run 'export DIFFMAHNET_TRAINING_DATA=path_to_data_folder'" - ) - raise ValueError(msg) - - # Parse available training data files - tdata_files = glob.glob(str(pathlib.Path(path) / "*")) - filenames = [x.split("/")[-1] for x in tdata_files] - lgm_vals = np.array([float(x.split("_")[1]) for x in filenames]) - t_vals = np.array([float(x.split("_")[3]) for x in filenames]) - is_cens_vals = np.array([x.split(".")[-2] == "cens" for x in filenames]) - fileinfo = list( - zip( - tdata_files, - lgm_vals.tolist(), - t_vals.tolist(), - is_cens_vals.tolist(), - ) - ) - cen_file_inds = np.where(is_cens_vals)[0] - sat_file_inds = np.where(~is_cens_vals)[0] - - # Load data - test_train_file_split = 80 # about 25:75 test-train split ratio - if is_test == "both": - test_train_file_split = None - inds = cen_file_inds if is_cens else sat_file_inds - test_train_slice = slice(None, test_train_file_split) - if is_test: - test_train_slice = slice(test_train_file_split, None) - inds = inds[test_train_slice] - - x = [] # PDF variables - u = [] # conditional variables - for i in inds: - filename, lgm, t, is_cens_val = fileinfo[i] - assert is_cens == is_cens_val - x.append(np.load(filename)) - u.append(np.tile(np.array([[lgm, t]]), (x[-1].shape[0], 1))) - - x = jnp.concatenate(x, axis=0) - u = jnp.concatenate(u, axis=0) - - # Transform x parameters from bounded to unbounded space - x_unbound = jnp.array( - [ - *get_unbounded_mah_params( - DEFAULT_MAH_PARAMS._make(x.T), - ) - ] - ).T - - isfinite = np.all((jnp.isfinite(x_unbound)), axis=1) - return x_unbound[isfinite], u[isfinite] - - -def get_available_models(): - available_names = diffmahnet.pretrained_model_names - print(available_names)
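The mc_mah_cenpop above is the diffmahnet-based sampler; a minimal calling sketch consistent with its docstring (this assumes the pretrained cenflow_v2_0.eqx model ships with diffmahnet; the mass/time values are illustrative, and shapes follow the n_sample * n_m_obs * n_t_obs convention documented above):

import jax
import jax.numpy as jnp
from diffhalos.diffmahpop.diffmahnet_utils import mc_mah_cenpop

m_obs = jnp.array([12.0, 13.0])  # log10 halo mass at observation [Msun]
t_obs = jnp.log10(jnp.array([10.0, 13.8]))  # log10 cosmic time at observation [Gyr]
# 2 masses x 2 times x 5 samples -> 20 MAHs, each tabulated on an n_t=100 time grid
cen_mah, t_grid = mc_mah_cenpop(m_obs, t_obs, jax.random.key(0), n_sample=5)
print(cen_mah.shape, t_grid.shape)  # (20, 100) (20, 100)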
diff --git a/diffhalos/diffmahpop/diffmahpop_utils.py b/diffhalos/diffmahpop/diffmahpop_utils.py deleted file mode 100644 index 6fa320f..0000000 --- a/diffhalos/diffmahpop/diffmahpop_utils.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -Useful diffmahpop functions -See https://github.com/ArgonneCPAC/diffmah/tree/main/diffmah/diffmahpop_kernels -""" - -import jax.numpy as jnp -from jax import jit as jjit -from jax import vmap - -from diffmah.diffmahpop_kernels.mc_bimod_cens import mc_cenpop -from diffmah.diffmah_kernels import _log_mah_kern -from diffmah.diffmahpop_kernels.bimod_censat_params import ( - DEFAULT_DIFFMAHPOP_PARAMS, -) - -from .utils import rescale_mah_parameters - -T_GRID_MIN = 0.5 -T_GRID_MAX = 13.8 -N_T_GRID = 100 - -__all__ = ("mc_mah_cenpop",) - -log_mah_kern_vmap = jjit(vmap(_log_mah_kern, in_axes=(0, None, None))) - - -def mc_mah_cenpop( - m_obs, - t_obs, - randkey, - logt0, - n_sample=1, - params=DEFAULT_DIFFMAHPOP_PARAMS, - t_min=T_GRID_MIN, - t_max=T_GRID_MAX, - n_t=N_T_GRID, - return_mah_params_and_der=False, -): - """ - Diffmahpop predictions for populations of halo MAHs - - Parameters - ---------- - m_obs: ndarray of shape (n_halo, ) - grid of base-10 log of mass of the halos at observation, in Msun - - t_obs: ndarray of shape (n_halo, ) - grid of cosmic time at observation of each halo, in Gyr - - randkey: key - JAX random key - - logt0: float - base-10 log of the age of the Universe at z=0, in Gyr - - n_sample: int - number of MC samples per (m_obs, t_obs) pair - - params: namedtuple - diffmahpop parameters - - t_min: float - minimum value of the time grid - on which to compute the mah, in Gyr - - t_max: float - maximum value of the time grid - on which to compute the mah, in Gyr - - n_t: int - number of points in time grid - - return_mah_params_and_der: bool - if True the MAH parameters - will also be returned - - Returns - ------- - log_mah: ndarray of shape (n_halo, n_t) - base-10 log of mah for each halo, in Msun - - t_grid: ndarray of shape (n_t, ) - cosmic time grid on which to compute MAHs, in Gyr - - mah_params: namedtuple of ndarrays of shape (n_halo,) - mah parameters for all halos in the population, - only returned when ``return_mah_params_and_der`` is True - """ - # get a list of (m_obs, t_obs) for each MC realization - m_vals, t_vals = [
jnp.repeat(x.flatten(), n_sample) - for x in jnp.meshgrid( - m_obs, - t_obs, - ) - ] - - # construct time grids for each halo, given observation time - t_grid = jnp.linspace(t_min, t_max, n_t) - - # predict uncorrected MAHs - mah_params_uncorrected, _, log_mah_uncorrected = mc_cenpop( - params, - t_grid, - m_vals, - t_vals, - randkey, - logt0, - ) - - # rescale the mah parameters to the correct logm0 - mah_params = rescale_mah_parameters( - mah_params_uncorrected, - m_vals, - log_mah_uncorrected[:, -1], - ) - - # get the corrected MAHs - log_mah = log_mah_kern_vmap(mah_params, t_grid, logt0) - - if return_mah_params_and_der: - return log_mah, t_grid, mah_params - - return log_mah, t_grid diff --git a/diffhalos/diffmahpop/scripts/.gitignore b/diffhalos/diffmahpop/scripts/.gitignore deleted file mode 100644 index 643cb18..0000000 --- a/diffhalos/diffmahpop/scripts/.gitignore +++ /dev/null @@ -1 +0,0 @@ -runs/ \ No newline at end of file diff --git a/diffhalos/diffmahpop/scripts/__init__.py b/diffhalos/diffmahpop/scripts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/diffhalos/diffmahpop/scripts/flowjax_train.py b/diffhalos/diffmahpop/scripts/flowjax_train.py deleted file mode 100644 index 5a0feef..0000000 --- a/diffhalos/diffmahpop/scripts/flowjax_train.py +++ /dev/null @@ -1,126 +0,0 @@ -import pathlib -import argparse - -import jax -import diffmahnet -from diffmahnet import datatools - -SAVE_DIR = pathlib.Path("./data/") -TRAIN_DATA_DIR = pathlib.Path( - "/lcrc/project/halotools/diffmahpop_data/NM_12_NT_9_ISTART_0_IEND_576/" -) - -NN_DEPTH = 2 -NN_WIDTH = 50 -FLOW_LAYERS = 8 -SAMPLE_FRAC = 1.0 - - -parser = argparse.ArgumentParser( - description="Train a DiffMahNet normalizing flow model." -) -parser.add_argument("SAVE_FILENAME", help="Filename to save the trained model.") -parser.add_argument( - "--save-dir", - type=str, - default=SAVE_DIR, - help="Directory to save the trained model.", -) -parser.add_argument( - "--train-data-dir", - type=str, - default=TRAIN_DATA_DIR, - help="Directory containing the training data.", -) -parser.add_argument( - "--initial-model", - type=str, - default=None, - help="Optional filename of an initial model to load.", -) -parser.add_argument("--sats", action="store_true") -parser.add_argument( - "--nn-depth", type=int, default=NN_DEPTH, help="Depth of the hidden neural network." -) -parser.add_argument( - "--nn-width", type=int, default=NN_WIDTH, help="Width of the hidden neural network." -) -parser.add_argument( - "--flow-layers", type=int, default=FLOW_LAYERS, help="Number of flow layers." -) -parser.add_argument( - "--include-test", action="store_true", help="Include test data in the training set." -) -parser.add_argument( - "--max-epochs", type=int, default=50, help="Number of training epochs." -) -parser.add_argument( - "--learning-rate", - type=float, - default=5e-4, - help="Learning rate for the built-in flowjax optimizer.", -) -parser.add_argument("--max-patience", type=float, default=10.0) -parser.add_argument( - "--sample-frac", - type=float, - default=SAMPLE_FRAC, - help="Fraction of training data to load.", -) -parser.add_argument( - "--seed", type=int, default=0, help="Random seed for reproducibility." 
-) - -if __name__ == "__main__": - # Parse arguments - args = parser.parse_args() - save_dir = pathlib.Path(args.save_dir) - save_dir.mkdir(parents=True, exist_ok=True) - save_filename = args.SAVE_FILENAME - train_data_dir = pathlib.Path(args.train_data_dir) - is_cens = not args.sats - nn_depth = args.nn_depth - nn_width = args.nn_width - flow_layers = args.flow_layers - is_test = "both" if args.include_test else False - initial_model = args.initial_model - max_epochs = args.max_epochs - sample_frac = args.sample_frac - - key = jax.random.key(args.seed) - key1, key2 = jax.random.split(key) - - # Load training data and flow model - train_data = datatools.DataHolder( - train_data_dir, - is_cens=is_cens, - is_test=is_test, - sample_frac=sample_frac, - randkey=key1, - ) - if initial_model is not None: - initial_model = save_dir / initial_model - flow = diffmahnet.DiffMahFlow.load(initial_model) - else: - flow = diffmahnet.DiffMahFlow( - scaler=train_data.scaler, - nn_depth=nn_depth, - nn_width=nn_width, - flow_layers=flow_layers, - ) - print("Number of parameters =", flow.get_params().size) - - # Train the flow model - if max_epochs > 0: - print("Training data shapes:", train_data.x.shape, train_data.u.shape) - flow.init_fit( - train_data.x, - train_data.u, - randkey=key2, - max_epochs=max_epochs, - learning_rate=args.learning_rate, - max_patience=args.max_patience, - ) - - # Save the trained model - flow.save(save_dir / save_filename) diff --git a/diffhalos/diffmahpop/scripts/flowjax_train_float64.py b/diffhalos/diffmahpop/scripts/flowjax_train_float64.py deleted file mode 100644 index e92e252..0000000 --- a/diffhalos/diffmahpop/scripts/flowjax_train_float64.py +++ /dev/null @@ -1,131 +0,0 @@ -# flake8: noqa: E402 -from jax import config - -config.update("jax_enable_x64", True) - -import pathlib -import argparse - -import jax -import diffmahnet -from diffmahnet import datatools - -SAVE_DIR = pathlib.Path("./data/") -TRAIN_DATA_DIR = pathlib.Path( - "/lcrc/project/halotools/diffmahpop_data/NM_12_NT_9_ISTART_0_IEND_576/" -) - -NN_DEPTH = 2 -NN_WIDTH = 50 -FLOW_LAYERS = 8 -SAMPLE_FRAC = 1.0 - - -parser = argparse.ArgumentParser( - description="Train a DiffMahNet normalizing flow model." -) -parser.add_argument("SAVE_FILENAME", help="Filename to save the trained model.") -parser.add_argument( - "--save-dir", - type=str, - default=SAVE_DIR, - help="Directory to save the trained model.", -) -parser.add_argument( - "--train-data-dir", - type=str, - default=TRAIN_DATA_DIR, - help="Directory containing the training data.", -) -parser.add_argument( - "--initial-model", - type=str, - default=None, - help="Optional filename of an initial model to load.", -) -parser.add_argument("--sats", action="store_true") -parser.add_argument( - "--nn-depth", type=int, default=NN_DEPTH, help="Depth of the hidden neural network." -) -parser.add_argument( - "--nn-width", type=int, default=NN_WIDTH, help="Width of the hidden neural network." -) -parser.add_argument( - "--flow-layers", type=int, default=FLOW_LAYERS, help="Number of flow layers." -) -parser.add_argument( - "--include-test", action="store_true", help="Include test data in the training set." -) -parser.add_argument( - "--max-epochs", type=int, default=50, help="Number of training epochs." 
-) -parser.add_argument( - "--learning-rate", - type=float, - default=5e-4, - help="Learning rate for the built-in flowjax optimizer.", -) -parser.add_argument("--max-patience", type=float, default=10.0) -parser.add_argument( - "--sample-frac", - type=float, - default=SAMPLE_FRAC, - help="Fraction of training data to load.", -) -parser.add_argument( - "--seed", type=int, default=0, help="Random seed for reproducibility." -) - -if __name__ == "__main__": - # Parse arguments - args = parser.parse_args() - save_dir = pathlib.Path(args.save_dir) - save_dir.mkdir(parents=True, exist_ok=True) - save_filename = args.SAVE_FILENAME - train_data_dir = pathlib.Path(args.train_data_dir) - is_cens = not args.sats - nn_depth = args.nn_depth - nn_width = args.nn_width - flow_layers = args.flow_layers - is_test = "both" if args.include_test else False - initial_model = args.initial_model - max_epochs = args.max_epochs - sample_frac = args.sample_frac - - key = jax.random.key(args.seed) - key1, key2 = jax.random.split(key) - - # Load training data and flow model - train_data = datatools.DataHolder( - train_data_dir, - is_cens=is_cens, - is_test=is_test, - sample_frac=sample_frac, - randkey=key1, - ) - if initial_model is not None: - initial_model = save_dir / initial_model - flow = diffmahnet.DiffMahFlow.load(initial_model) - else: - flow = diffmahnet.DiffMahFlow( - scaler=train_data.scaler, - nn_depth=nn_depth, - nn_width=nn_width, - flow_layers=flow_layers, - ) - print("Number of parameters =", flow.get_params().size) - - # Train the flow model - if max_epochs > 0: - print("Training data shapes:", train_data.x.shape, train_data.u.shape) - flow.init_fit( - train_data.x, - train_data.u, - randkey=key2, - max_epochs=max_epochs, - learning_rate=args.learning_rate, - max_patience=args.max_patience, - ) - - # Save the trained model - flow.save(save_dir / save_filename) diff --git a/diffhalos/diffmahpop/scripts/kdecent_train.py b/diffhalos/diffmahpop/scripts/kdecent_train.py deleted file mode 100644 index b810806..0000000 --- a/diffhalos/diffmahpop/scripts/kdecent_train.py +++ /dev/null @@ -1,264 +0,0 @@ -import pathlib -import argparse - -import jax -import jax.numpy as jnp -import equinox as eqx - -from diffopt import kdescent -import diffmahnet -from diffmahnet import datatools - -SAVE_DIR = pathlib.Path("./data/") -TRAIN_DATA_DIR = pathlib.Path( - "/lcrc/project/halotools/diffmahpop_data/NM_12_NT_9_ISTART_0_IEND_576/" -) - -NN_DEPTH = 2 -NN_WIDTH = 50 -FLOW_LAYERS = 8 -SAMPLE_FRAC = 1.0 -NUM_KERNELS = 20 -NUM_FOURIER_KERNELS = 0 -LEARNING_RATE = 1e-4 - - -class KDescentLoss: - """ - Custom loss function to fit flowjax model - """ - - def __init__( - self, - train_data, - sample_size=None, - randkey=None, - num_kernels=20, - num_fourier_kernels=0, - ): - randkey = jax.random.key(0) if randkey is None else randkey - # t0 = 13.8 - # self.logt0 = np.log10(t0) - self.logt0 = train_data.logt0 - self.sample_size = sample_size - self.xscaler = train_data.x_scaler - self.uscaler = train_data.u_scaler - - self.tgrids, self.log_mah = train_data.get_tgrid_and_log_mah(randkey) - self.m_obs = train_data.m_obs - self.t_obs = train_data.t_obs - assert self.log_mah.ndim == self.tgrids.ndim == 2 - assert self.m_obs.ndim == self.t_obs.ndim == 1 - assert ( - self.log_mah.shape[0] - == self.m_obs.shape[0] - == self.t_obs.shape[0] - == self.tgrids.shape[0] - ) - self.condition = jnp.array([self.m_obs, self.t_obs]).T - - # Combine m and t with condition (m_obs, t_obs), since we always have - # an equivalent sampling of the 
conditional variables - # and this saves us from having to generate many separate - # KCalc instances at different conditional value bins - self.training_combined = ( - jnp.array( - [ - self.log_mah, - self.tgrids, - self.tile(self.m_obs), - self.tile(self.t_obs), - ] - ) - .reshape((4, -1)) - .T - ) - self.kde = kdescent.KCalc( - self.training_combined, - num_kernels=num_kernels, - num_fourier_kernels=num_fourier_kernels, - ) - - def tile(self, arr): - return jnp.tile(arr[..., None], (1, self.tgrids.shape[1])) - - @eqx.filter_jit - def __call__(self, diffmahflow, randkey): - """Compute the loss using kdescent""" - key0, key1, key2, key3 = jax.random.split(randkey, 4) - if self.sample_size is None: - tsamp = slice(None) - else: - tsamp = jax.random.choice( - key0, - self.training_combined.shape[0], - (self.sample_size,), - replace=False, - ) - mah_params = diffmahflow.sample( - self.condition[tsamp], randkey=key1, asparams=True - ) - log_mah = diffmahnet.log_mah_kern(mah_params, self.tgrids[tsamp], self.logt0) - model_combined = ( - jnp.array( - [ - log_mah, - self.tgrids[tsamp], - self.tile(self.m_obs[tsamp]), - self.tile(self.t_obs[tsamp]), - ] - ) - .reshape((4, -1)) - .T - ) - - if self.kde.num_fourier_kernels: - counts_model, counts_truth = self.kde.compare_fourier_counts( - key2, model_combined - ) - ecf_model = counts_model / model_combined.shape[0] - ecf_truth = counts_truth / self.training_combined.shape[0] - loss = jnp.sum(jnp.abs(ecf_model - ecf_truth) ** 2) - else: - loss = 0.0 - - counts_model, counts_truth = self.kde.compare_kde_counts(key3, model_combined) - pdf_model = counts_model / model_combined.shape[0] - pdf_truth = counts_truth / self.training_combined.shape[0] - loss += jnp.sum((pdf_model - pdf_truth) ** 2) - - # Optionally divide by total number of kernels to get MSE loss - # loss /= (self.kde.num_kernels + self.kde.num_fourier_kernels) - jax.debug.print("loss = {loss}", loss=loss) - - return loss - - -parser = argparse.ArgumentParser( - description="Train a DiffMahNet normalizing flow model." -) -parser.add_argument("SAVE_FILENAME", help="Filename to save the trained model.") -parser.add_argument( - "--save-dir", - type=str, - default=SAVE_DIR, - help="Directory to save the trained model.", -) -parser.add_argument( - "--train-data-dir", - type=str, - default=TRAIN_DATA_DIR, - help="Directory containing the training data.", -) -parser.add_argument( - "--initial-model", - type=str, - default=None, - help="Optional filename of an initial model to load.", -) -parser.add_argument("--sats", action="store_true") -parser.add_argument( - "--nn-depth", type=int, default=NN_DEPTH, help="Depth of the hidden neural network." -) -parser.add_argument( - "--nn-width", type=int, default=NN_WIDTH, help="Width of the hidden neural network." -) -parser.add_argument( - "--flow-layers", type=int, default=FLOW_LAYERS, help="Number of flow layers." -) -parser.add_argument( - "--include-test", action="store_true", help="Include test data in the training set." -) -parser.add_argument("--steps", type=int, default=100, help="Number of adam iterations.") -parser.add_argument( - "--learning-rate", - type=float, - default=LEARNING_RATE, - help="Initial adam learning rate.", -) -parser.add_argument( - "--num-kernels", type=int, default=NUM_KERNELS, help="Number of kdescent kernels." 
-) -parser.add_argument( - "--num-fourier-kernels", - type=int, - default=NUM_FOURIER_KERNELS, - help="Number of kdescent fourier kernels.", -) -parser.add_argument( - "--sample-frac", - type=float, - default=SAMPLE_FRAC, - help="Fraction of training data to load.", -) -parser.add_argument( - "--seed", type=int, default=0, help="Random seed for reproducibility." -) -parser.add_argument( - "--plot-loss-curve", - action="store_true", - help="Plot the loss curve during training.", -) - - -if __name__ == "__main__": - # Parse arguments - args = parser.parse_args() - save_dir = pathlib.Path(args.save_dir) - save_dir.mkdir(parents=True, exist_ok=True) - save_filename = args.SAVE_FILENAME - train_data_dir = pathlib.Path(args.train_data_dir) - is_cens = not args.sats - nn_depth = args.nn_depth - nn_width = args.nn_width - flow_layers = args.flow_layers - is_test = "both" if args.include_test else False - initial_model = args.initial_model - steps = args.steps - sample_frac = args.sample_frac - - key = jax.random.key(args.seed) - key1, key2 = jax.random.split(key) - - # Load training data and flow model - train_data = datatools.DataHolder( - train_data_dir, - is_cens=is_cens, - is_test=is_test, - sample_frac=sample_frac, - randkey=key1, - ) - if initial_model is not None: - initial_model = save_dir / initial_model - flow = diffmahnet.DiffMahFlow.load(initial_model) - else: - flow = diffmahnet.DiffMahFlow( - scaler=train_data.scaler, - nn_depth=nn_depth, - nn_width=nn_width, - flow_layers=flow_layers, - ) - print("Number of parameters =", flow.get_params().size) - - # Train the flow model - if steps > 0: - loss_func = KDescentLoss( - train_data, - num_kernels=args.num_kernels, - num_fourier_kernels=args.num_fourier_kernels, - ) - params, losses = flow.adam_fit( - loss_func, randkey=key2, nsteps=steps, learning_rate=args.learning_rate - ) - if args.plot_loss_curve: - import matplotlib.pyplot as plt - - plt.semilogy(losses) - plt.xlabel("Iteration") - plt.ylabel("Loss") - plot_filename = save_filename.removesuffix(".eqx") + ".png" - plt.savefig(save_dir / plot_filename) - plt.close() - - # Save the trained model - flow.save(save_dir / save_filename) diff --git a/diffhalos/diffmahpop/scripts/kdecent_train_float64.py b/diffhalos/diffmahpop/scripts/kdecent_train_float64.py deleted file mode 100644 index a1edfa8..0000000 --- a/diffhalos/diffmahpop/scripts/kdecent_train_float64.py +++ /dev/null @@ -1,269 +0,0 @@ -# flake8: noqa: E402 -from jax import config - -config.update("jax_enable_x64", True) - -import pathlib -import argparse - -import jax -import jax.numpy as jnp -import equinox as eqx - -from diffopt import kdescent -import diffmahnet -from diffmahnet import datatools - -SAVE_DIR = pathlib.Path("./data/") -TRAIN_DATA_DIR = pathlib.Path( - "/lcrc/project/halotools/diffmahpop_data/NM_12_NT_9_ISTART_0_IEND_576/" -) - -NN_DEPTH = 2 -NN_WIDTH = 50 -FLOW_LAYERS = 8 -SAMPLE_FRAC = 1.0 -NUM_KERNELS = 20 -NUM_FOURIER_KERNELS = 0 -LEARNING_RATE = 1e-4 - - -class KDescentLoss: - """ - Custom loss function to fit flowjax model - """ - - def __init__( - self, - train_data, - sample_size=None, - randkey=None, - num_kernels=20, - num_fourier_kernels=0, - ): - randkey = jax.random.key(0) if randkey is None else randkey - # t0 = 13.8 - # self.logt0 = np.log10(t0) - self.logt0 = train_data.logt0 - self.sample_size = sample_size - self.xscaler = train_data.x_scaler - self.uscaler = train_data.u_scaler - - self.tgrids, self.log_mah = train_data.get_tgrid_and_log_mah(randkey) - self.m_obs = train_data.m_obs - 
self.t_obs = train_data.t_obs - assert self.log_mah.ndim == self.tgrids.ndim == 2 - assert self.m_obs.ndim == self.t_obs.ndim == 1 - assert ( - self.log_mah.shape[0] - == self.m_obs.shape[0] - == self.t_obs.shape[0] - == self.tgrids.shape[0] - ) - self.condition = jnp.array([self.m_obs, self.t_obs]).T - - # Combine m and t with condition (m_obs, t_obs), since we always have - # an equivalent sampling of the conditional variables - # and this saves us from having to generate many separate - # KCalc instances at different conditional value bins - self.training_combined = ( - jnp.array( - [ - self.log_mah, - self.tgrids, - self.tile(self.m_obs), - self.tile(self.t_obs), - ] - ) - .reshape((4, -1)) - .T - ) - self.kde = kdescent.KCalc( - self.training_combined, - num_kernels=num_kernels, - num_fourier_kernels=num_fourier_kernels, - ) - - def tile(self, arr): - return jnp.tile(arr[..., None], (1, self.tgrids.shape[1])) - - @eqx.filter_jit - def __call__(self, diffmahflow, randkey): - """Compute the loss using kdescent""" - key0, key1, key2, key3 = jax.random.split(randkey, 4) - if self.sample_size is None: - tsamp = slice(None) - else: - tsamp = jax.random.choice( - key0, - self.training_combined.shape[0], - (self.sample_size,), - replace=False, - ) - mah_params = diffmahflow.sample( - self.condition[tsamp], randkey=key1, asparams=True - ) - log_mah = diffmahnet.log_mah_kern(mah_params, self.tgrids[tsamp], self.logt0) - model_combined = ( - jnp.array( - [ - log_mah, - self.tgrids[tsamp], - self.tile(self.m_obs[tsamp]), - self.tile(self.t_obs[tsamp]), - ] - ) - .reshape((4, -1)) - .T - ) - - if self.kde.num_fourier_kernels: - counts_model, counts_truth = self.kde.compare_fourier_counts( - key2, model_combined - ) - ecf_model = counts_model / model_combined.shape[0] - ecf_truth = counts_truth / self.training_combined.shape[0] - loss = jnp.sum(jnp.abs(ecf_model - ecf_truth) ** 2) - else: - loss = 0.0 - - counts_model, counts_truth = self.kde.compare_kde_counts(key3, model_combined) - pdf_model = counts_model / model_combined.shape[0] - pdf_truth = counts_truth / self.training_combined.shape[0] - loss += jnp.sum((pdf_model - pdf_truth) ** 2) - - # Optionally divide by total number of kernels to get MSE loss - # loss /= (self.kde.num_kernels + self.kde.num_fourier_kernels) - jax.debug.print("loss = {loss}", loss=loss) - - return loss - - -parser = argparse.ArgumentParser( - description="Train a DiffMahNet normalizing flow model." -) -parser.add_argument("SAVE_FILENAME", help="Filename to save the trained model.") -parser.add_argument( - "--save-dir", - type=str, - default=SAVE_DIR, - help="Directory to save the trained model.", -) -parser.add_argument( - "--train-data-dir", - type=str, - default=TRAIN_DATA_DIR, - help="Directory containing the training data.", -) -parser.add_argument( - "--initial-model", - type=str, - default=None, - help="Optional filename of an initial model to load.", -) -parser.add_argument("--sats", action="store_true") -parser.add_argument( - "--nn-depth", type=int, default=NN_DEPTH, help="Depth of the hidden neural network." -) -parser.add_argument( - "--nn-width", type=int, default=NN_WIDTH, help="Width of the hidden neural network." -) -parser.add_argument( - "--flow-layers", type=int, default=FLOW_LAYERS, help="Number of flow layers." -) -parser.add_argument( - "--include-test", action="store_true", help="Include test data in the training set." 
-) -parser.add_argument("--steps", type=int, default=100, help="Number of adam iterations.") -parser.add_argument( - "--learning-rate", - type=float, - default=LEARNING_RATE, - help="Initial adam learning rate.", -) -parser.add_argument( - "--num-kernels", type=int, default=NUM_KERNELS, help="Number of kdescent kernels." -) -parser.add_argument( - "--num-fourier-kernels", - type=int, - default=NUM_FOURIER_KERNELS, - help="Number of kdescent fourier kernels.", -) -parser.add_argument( - "--sample-frac", - type=float, - default=SAMPLE_FRAC, - help="Fraction of training data to load.", -) -parser.add_argument( - "--seed", type=int, default=0, help="Random seed for reproducibility." -) -parser.add_argument( - "--plot-loss-curve", - action="store_true", - help="Plot the loss curve during training.", -) - - -if __name__ == "__main__": - # Parse arguments - args = parser.parse_args() - save_dir = pathlib.Path(args.save_dir) - save_dir.mkdir(parents=True, exist_ok=True) - save_filename = args.SAVE_FILENAME - train_data_dir = pathlib.Path(args.train_data_dir) - is_cens = not args.sats - nn_depth = args.nn_depth - nn_width = args.nn_width - flow_layers = args.flow_layers - is_test = "both" if args.include_test else False - initial_model = args.initial_model - steps = args.steps - sample_frac = args.sample_frac - - key = jax.random.key(args.seed) - key1, key2 = jax.random.split(key) - - # Load training data and flow model - train_data = datatools.DataHolder( - train_data_dir, - is_cens=is_cens, - is_test=is_test, - sample_frac=sample_frac, - randkey=key1, - ) - if initial_model is not None: - initial_model = save_dir / initial_model - flow = diffmahnet.DiffMahFlow.load(initial_model) - else: - flow = diffmahnet.DiffMahFlow( - scaler=train_data.scaler, - nn_depth=nn_depth, - nn_width=nn_width, - flow_layers=flow_layers, - ) - print("Number of parameters =", flow.get_params().size) - - # Train the flow model - if steps > 0: - loss_func = KDescentLoss( - train_data, - num_kernels=args.num_kernels, - num_fourier_kernels=args.num_fourier_kernels, - ) - params, losses = flow.adam_fit( - loss_func, randkey=key2, nsteps=steps, learning_rate=args.learning_rate - ) - if args.plot_loss_curve: - import matplotlib.pyplot as plt - - plt.semilogy(losses) - plt.xlabel("Iteration") - plt.ylabel("Loss") - plot_filename = save_filename.removesuffix(".eqx") + ".png" - plt.savefig(save_dir / plot_filename) - plt.close() - - # Save the trained model - flow.save(save_dir / save_filename) diff --git a/diffhalos/diffmahpop/scripts/train_diffmahne.txt b/diffhalos/diffmahpop/scripts/train_diffmahne.txt deleted file mode 100644 index d49d62a..0000000 --- a/diffhalos/diffmahpop/scripts/train_diffmahne.txt +++ /dev/null @@ -1,32 +0,0 @@ -"Scripts to train diffmahnet using float64, modifying Alan's original scripts" - -- 2025/11/17: Modified to use float64. 
- - Model names: - cenflow_v2_0_float64.eqx - satflow_v2_0_float64.eqx - - Respective creation scripts: - mpiexec -n 1 $watchmemory python -W ignore ~/local/diffhalos/diffmahpop/scripts/kdescent_train_float64.py cenflow_v2_0_float64.eqx --num-kernels 20 --num-fourier-kernels 20 --plot-loss-curve --seed 2 --include-test --initial-model cenflow_v1_0-kde_20_20train.eqx --steps 1200 --learning-rate 3e-5 - - mpiexec -n 1 $watchmemory python -W ignore ~/local/diffhalos/diffmahpop/scripts/kdescent_train_float64.py satflow_v2_0_float64.eqx --num-kernels 20 --num-fourier-kernels 20 --plot-loss-curve --seed 2 --include-test --initial-model satflow_v1_0-kde_20_20train.eqx --steps 1200 --learning-rate 3e-5 --sats - -- 2025/05/15: Finished v2 models. Trained using v1 as the initial model, and -performing kdescent fitting to the 4D space {log(M(t)), t, log(M_obs), t_obs}. -- 2025/04/27: Finished v1 models, using full training data. Trained using the -flowjax fitter only. - - Model names: - cenflow_v1_0_float64.eqx - satflow_v1_0_float64.eqx - - Respective creation scripts: - mpiexec -n 1 $watchmemory python -W ignore ~/local/diffmahnet/scripts/flowjax_train_float64.py cenflow_v1_0_float64.eqx --max-epochs 200 --max-patience 25 --learning-rate 8e-5 --nn-depth 8 --nn-width 48 --seed 111 --include-test --initial-model cenflow_v1_0train_float64.eqx - - mpiexec -n 1 $watchmemory python -W ignore ~/local/diffmahnet/scripts/flowjax_train_float64.py satflow_v1_0_float64.eqx --max-epochs 200 --max-patience 25 --learning-rate 8e-5 --nn-depth 8 --nn-width 48 --seed 111 --include-test --initial-model cenflow_v1_0train_float64.eqx --sats - -- 2025/04/21: Preliminary versions of v1 models, from the "train split" (~70% -of the full training data) only. Trained using the flowjax fitter only. 
- - Model names: - cenflow_v1_0train_float64.eqx - satflow_v1_0train_float64.eqx - - Respective creation scripts: - mpiexec -n 1 $watchmemory python -W ignore ~/local/diffmahnet/scripts/flowjax_train_float64.py cenflow_v1_0train_float64.eqx --max-epochs 400 --max-patience 30 --learning-rate 8e-5 --nn-depth 8 --nn-width 48 --seed 110 - - mpiexec -n 1 $watchmemory python -W ignore ~/local/diffmahnet/scripts/flowjax_train_float64.py satflow_v1_0train_float64.eqx --max-epochs 400 --max-patience 30 --learning-rate 8e-5 --nn-depth 8 --nn-width 48 --seed 110 --sats \ No newline at end of file diff --git a/diffhalos/diffmahpop/scripts/validation_plots.py b/diffhalos/diffmahpop/scripts/validation_plots.py deleted file mode 100644 index a172c28..0000000 --- a/diffhalos/diffmahpop/scripts/validation_plots.py +++ /dev/null @@ -1,317 +0,0 @@ -import pathlib -import argparse - -import jax -import jax.numpy as jnp -import matplotlib.pyplot as plt -import numpy as np -import tqdm -import corner - -import diffmahnet -from diffmahnet import datatools - -TRAIN_DATA_DIR = pathlib.Path( - "/lcrc/project/halotools/diffmahpop_data/NM_12_NT_9_ISTART_0_IEND_576/" -) -SAVE_DIR = pathlib.Path("./data/") -PLOT_DIR = pathlib.Path("./plots/") -SAMPLE_FRAC = 1.0 - - -def plot_mah_hists( - flow, - data, - data_desc=None, - flow_desc=None, - tfrac=1.0, - tobs_ranges=None, - mobs_ranges=None, - title="", - randkey=jax.random.key(0), -): - data_desc = data_desc + " data" if data_desc is not None else "Data" - flow_desc = flow_desc + " flow samples" if flow_desc is not None else "" - if mobs_ranges is None: - mobs_ranges = [(11.45, 11.75), (12.3, 12.7), (13.15, 13.55), (14.2, 14.57)] - if tobs_ranges is None: - tobs_ranges = [(3, 4), (6, 7), (9, 10), (13.7, 14.1)] - key1, key2 = jax.random.split(randkey, 2) - - mah_params = data.diffmahparams - flow_mah_params = flow.sample(data.u, randkey=key1, asparams=True) - - tgrid = diffmahnet.gen_time_grids(key2, data.u[:, 1]) - tfrac_grid = tgrid / data.u[:, 1, None] - tgrid_ind = jnp.minimum( - jax.vmap(jnp.searchsorted, in_axes=(0, None))(tfrac_grid, tfrac), - tfrac_grid.shape[1] - 1, - ) - log_mah = diffmahnet.log_mah_kern(mah_params, tgrid, data.logt0)[ - np.arange(tgrid.shape[0]), tgrid_ind - ] - flow_log_mah = diffmahnet.log_mah_kern(flow_mah_params, tgrid, data.logt0)[ - np.arange(tgrid.shape[0]), tgrid_ind - ] - - fig, axes = plt.subplots(2, 2, figsize=(8, 6)) - for i_tobs, ax in tqdm.tqdm(enumerate(axes.ravel()), leave=False): - tobs_min, tobs_max = tobs_ranges[i_tobs] - tobs_cut = (tobs_min < data.u[:, 1]) & (data.u[:, 1] < tobs_max) - for i_mobs in tqdm.trange(len(mobs_ranges), leave=False): - mobs_min, mobs_max = mobs_ranges[i_mobs] - mobs_cut = (mobs_min < data.u[:, 0]) & (data.u[:, 0] < mobs_max) - - cut = tobs_cut & mobs_cut - flow_hist_dat = flow_log_mah[cut] - hist_dat = log_mah[cut] - all_dat = np.concatenate([flow_hist_dat, hist_dat]) - if len(all_dat): - mean = all_dat.mean() - bins = np.linspace(mean - 3, mean + 3, 70) - else: - bins = 70 - color = f"C{i_mobs}" - ax.hist( - flow_hist_dat, - bins=bins, - linewidth=2, - color=color, - histtype="step", - density=True, - ) - ax.hist( - hist_dat, - bins=bins, - linestyle="--", - color=color, - histtype="step", - density=True, - ) - ax.hist([], alpha=0, label=f"t_obs~{np.mean(tobs_ranges[i_tobs]):.1f}") - ax = axes.ravel()[0] - - for i_mobs in range(len(mobs_ranges)): - color = f"C{i_mobs}" - m = np.mean(mobs_ranges[i_mobs]) - if flow_desc: - label = flow_desc + f" (M_obs~{m:.1f})" - else: - label = f"M_obs~{m:.1f}" - 
ax.hist([], linewidth=2, color=color, histtype="step", label=label) - ax.hist([], linestyle="--", color="k", histtype="step", label=data_desc) - for ax in axes.ravel(): - ax.set_xlim(left=9) - if ax in axes[-1, :]: - ax.set_xlabel("$\\rm M_h (t)$") - ax.legend(frameon=False, fontsize=10) - if title: - fig.suptitle(title) - plt.show() - - -def plot_mah_residual( - flow, data, tobs_ranges=None, mobs_ranges=None, title="", randkey=jax.random.key(0) -): - if mobs_ranges is None: - mobs_ranges = [(11.45, 11.75), (12.3, 12.7), (13.15, 13.55), (14.2, 14.57)] - if tobs_ranges is None: - tobs_ranges = [(3, 4), (6, 7), (9, 10), (13.7, 14.1)] - key1, key2 = jax.random.split(randkey, 2) - - mah_params = data.diffmahparams - flow_mah_params = flow.sample(data.u, randkey=key1, asparams=True) - - tgrid = diffmahnet.gen_time_grids(key2, data.u[:, 1], n_tgrid=100) - log_mah = diffmahnet.log_mah_kern(mah_params, tgrid, data.logt0) - flow_log_mah = diffmahnet.log_mah_kern(flow_mah_params, tgrid, data.logt0) - - fig, axes = plt.subplots(2, 2, figsize=(8, 6), sharey=True) - cmap = plt.matplotlib.colormaps["rainbow"] - colors = cmap(np.linspace(0, 1, len(mobs_ranges))) - for i_tobs, ax in enumerate(axes.ravel()): - tobs_min, tobs_max = tobs_ranges[i_tobs] - tobs_cut = (tobs_min < data.u[:, 1]) & (data.u[:, 1] < tobs_max) - ax.axhline(0, color="k", ls="--") - for i_mobs in range(len(mobs_ranges)): - mobs_min, mobs_max = mobs_ranges[i_mobs] - mobs_cut = (mobs_min < data.u[:, 0]) & (data.u[:, 0] < mobs_max) - - cut = tobs_cut & mobs_cut - if np.any(cut): - flow_mean = np.mean(flow_log_mah[cut], axis=0) - dat_mean = np.mean(log_mah[cut], axis=0) - tgrid_mean = np.mean(tgrid[cut], axis=0) - color = colors[i_mobs] - ax.plot(tgrid_mean, flow_mean - dat_mean, linewidth=2, color=color) - ax.plot([], [], alpha=0, label=f"t_obs~{np.mean(tobs_ranges[i_tobs]):.1f}") - - for i_mobs in range(len(mobs_ranges)): - color = colors[i_mobs] - m = np.mean(mobs_ranges[i_mobs]) - ax.plot([], [], linewidth=2, color=color, label=f"M_obs~{m:.1f}") - for ax in axes.ravel(): - # ax.set_xlim(left=9) - if ax in axes[:, 0]: - ax.set_ylabel("$\\rm\\Delta\\langle\\log M_h (t)\\rangle$") - if ax in axes[-1, :]: - ax.set_xlabel("t") - ax.legend(frameon=False, fontsize=10) - if title: - fig.suptitle(title) - plt.show() - - -def plot_mah_corner( - flow, data, data_desc=None, flow_desc=None, randkey=jax.random.key(0) -): - data_desc = data_desc + " data" if data_desc is not None else "Data" - flow_desc = flow_desc + " flow samples" if flow_desc is not None else "Flow samples" - key1, key2 = jax.random.split(randkey, 2) - - mah_params = data.diffmahparams - flow_mah_params = flow.sample(data.u, randkey=key1, asparams=True) - - tgrid = diffmahnet.gen_time_grids(key2, data.u[:, 1]) - log_mah = diffmahnet.log_mah_kern(mah_params, tgrid, data.logt0) - flow_log_mah = diffmahnet.log_mah_kern(flow_mah_params, tgrid, data.logt0) - - # Dimension indices correspond to: (object, variable, time snapshot) - broadcasted_u = np.tile(data.u[:, :, None], (1, 1, tgrid.shape[-1])) - combined_data = np.array( - [ - broadcasted_u[:, 0, :].flatten(), - broadcasted_u[:, 1, :].flatten(), - tgrid.flatten(), - log_mah.flatten(), - ] - ).T - combined_flow = np.array( - [ - broadcasted_u[:, 0, :].flatten(), - broadcasted_u[:, 1, :].flatten(), - tgrid.flatten(), - flow_log_mah.flatten(), - ] - ).T - - labels = ["logM_obs", "t_obs", "t", "logMAH"] - hist_kwargs = {"density": True} - ranges = [1.0, 1.0, 1.0, (5, 15)] - fig = corner.corner( - combined_flow, - color="C2", - 
plot_datapoints=False, - hist_kwargs={**hist_kwargs, "color": "C2"}, - labels=labels, - range=ranges, - plot_density=False, - ) - corner.corner( - combined_data, - fig=fig, - color="C1", - plot_datapoints=False, - hist_kwargs={**hist_kwargs, "color": "C1"}, - labels=labels, - range=ranges, - plot_density=False, - ) - fig.axes[1].text(0, 0.1, data_desc, color="C1") - fig.axes[1].text(0, 0, flow_desc, color="C2") - plt.show() - - -parser = argparse.ArgumentParser(description="Plot DiffMahNet validation plots.") -parser.add_argument("SAVE_FILENAME", help="Filename the trained model is saved as.") -parser.add_argument( - "PLOT_DIR", help=f"Directory to save the plots, relative to {PLOT_DIR}" -) -parser.add_argument( - "--save_dir", - type=str, - default=SAVE_DIR, - help="Directory the trained model is saved in.", -) -parser.add_argument( - "--train_data_dir", - type=str, - default=TRAIN_DATA_DIR, - help="Directory containing the training data.", -) -parser.add_argument("--sats", action="store_true") -parser.add_argument("--not-test", action="store_true") -parser.add_argument( - "--sample-frac", - type=float, - default=SAMPLE_FRAC, - help="Fraction of training data to load.", -) -parser.add_argument( - "--seed", type=int, default=0, help="Random seed for reproducibility." -) - -if __name__ == "__main__": - args = parser.parse_args() - save_dir = pathlib.Path(args.save_dir) - save_filename = args.SAVE_FILENAME - train_data_dir = pathlib.Path(args.train_data_dir) - is_cens = not args.sats - is_test = not args.not_test - plot_dir = PLOT_DIR / args.PLOT_DIR - plot_dir.mkdir(parents=True, exist_ok=True) - sample_frac = args.sample_frac - - key = jax.random.key(args.seed) - key1, *keys = jax.random.split(key, num=5) - - flow = diffmahnet.DiffMahFlow.load(save_dir / save_filename) - data = datatools.DataHolder( - train_data_dir, - is_test=is_test, - is_cens=is_cens, - sample_frac=sample_frac, - randkey=key1, - ) - - if is_test: - tobs_ranges = [(1.5, 4), (4, 7), (7.5, 8.5), (10, 12)] - mobs_ranges = [(10.8, 11.33), (11.6, 12.5), (12.5, 13.0), (14.0, 14.8)] - else: - tobs_ranges = None - mobs_ranges = None - - plot_mah_hists( - flow, - data, - title="t = t_obs", - tobs_ranges=tobs_ranges, - mobs_ranges=mobs_ranges, - randkey=keys[0], - ) - plt.savefig(plot_dir / "mah_hist_t1p0.png", bbox_inches="tight") - plot_mah_hists( - flow, - data, - title="t = 0.6 * t_obs", - tfrac=0.6, - tobs_ranges=tobs_ranges, - mobs_ranges=mobs_ranges, - randkey=keys[1], - ) - plt.savefig(plot_dir / "mah_hist_t0p6.png", bbox_inches="tight") - plot_mah_hists( - flow, - data, - title="t = 0.3 * t_obs", - tfrac=0.3, - tobs_ranges=tobs_ranges, - mobs_ranges=mobs_ranges, - randkey=keys[2], - ) - plt.savefig(plot_dir / "mah_hist_t0p3.png", bbox_inches="tight") - - plot_mah_residual( - flow, data, tobs_ranges=tobs_ranges, mobs_ranges=mobs_ranges, randkey=keys[3] - ) - plt.savefig(plot_dir / "mah_residual.png", bbox_inches="tight") diff --git a/diffhalos/diffmahpop/utils.py b/diffhalos/diffmahpop/utils.py deleted file mode 100644 index 2d8519d..0000000 --- a/diffhalos/diffmahpop/utils.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Useful utilities for diffmahpop""" - -__all__ = ("rescale_mah_parameters",) - - -def rescale_mah_parameters( - mah_params_uncorrected, - logm_obs, - logm_obs_uncorrected, -): - """ - Corrects the mah model parameters, so that - logm0 is rescaled to the value that results in - mah's that agree with the observed halo mass - - Parameters - ---------- - mah_params_uncorrected: namedtuple - mah parameters (logm0, logtc, 
early_index, late_index, t_peak) - where each parameter is an ndarray of shape (n_halo, ) - - logm_obs: ndarray of shape (n_halo, ) - base-10 log of true observed halo masses, in Msun - - logm_obs_uncorrected: ndarray of shape (n_halo, ) - base-10 log of uncorrected observed halo masses, in Msun - - Returns - ------- - mah_params: namedtuple - mah parameters after rescaling, of same shape as ``mah_params_uncorrected`` - """ - delta_logm_obs = logm_obs_uncorrected - logm_obs - logm0_rescaled = mah_params_uncorrected.logm0 - delta_logm_obs - mah_params = mah_params_uncorrected._replace(logm0=logm0_rescaled) - - return mah_params diff --git a/diffhalos/mah/scripts/flowjax_train_float64.py b/diffhalos/mah/scripts/flowjax_train_float64.py deleted file mode 100644 index e92e252..0000000 --- a/diffhalos/mah/scripts/flowjax_train_float64.py +++ /dev/null @@ -1,131 +0,0 @@ -# flake8: noqa: E402 -from jax import config - -config.update("jax_enable_x64", True) - -import pathlib -import argparse - -import jax -import diffmahnet -from diffmahnet import datatools - -SAVE_DIR = pathlib.Path("./data/") -TRAIN_DATA_DIR = pathlib.Path( - "/lcrc/project/halotools/diffmahpop_data/NM_12_NT_9_ISTART_0_IEND_576/" -) - -NN_DEPTH = 2 -NN_WIDTH = 50 -FLOW_LAYERS = 8 -SAMPLE_FRAC = 1.0 - - -parser = argparse.ArgumentParser( - description="Train a DiffMahNet normalizing flow model." -) -parser.add_argument("SAVE_FILENAME", help="Filename to save the trained model.") -parser.add_argument( - "--save-dir", - type=str, - default=SAVE_DIR, - help="Directory to save the trained model.", -) -parser.add_argument( - "--train-data-dir", - type=str, - default=TRAIN_DATA_DIR, - help="Directory containing the training data.", -) -parser.add_argument( - "--initial-model", - type=str, - default=None, - help="Optional filename of an initial model to load.", -) -parser.add_argument("--sats", action="store_true") -parser.add_argument( - "--nn-depth", type=int, default=NN_DEPTH, help="Depth of the hidden neural network." -) -parser.add_argument( - "--nn-width", type=int, default=NN_WIDTH, help="Width of the hidden neural network." -) -parser.add_argument( - "--flow-layers", type=int, default=FLOW_LAYERS, help="Number of flow layers." -) -parser.add_argument( - "--include-test", action="store_true", help="Include test data in the training set." -) -parser.add_argument( - "--max-epochs", type=int, default=50, help="Number of training epochs." -) -parser.add_argument( - "--learning-rate", - type=float, - default=5e-4, - help="Learning rate for the built-in flowjax optimizer.", -) -parser.add_argument("--max-patience", type=float, default=10.0) -parser.add_argument( - "--sample-frac", - type=float, - default=SAMPLE_FRAC, - help="Fraction of training data to load.", -) -parser.add_argument( - "--seed", type=int, default=0, help="Random seed for reproducibility." 
-) - -if __name__ == "__main__": - # Parse arguments - args = parser.parse_args() - save_dir = pathlib.Path(args.save_dir) - save_dir.mkdir(parents=True, exist_ok=True) - save_filename = args.SAVE_FILENAME - train_data_dir = pathlib.Path(args.train_data_dir) - is_cens = not args.sats - nn_depth = args.nn_depth - nn_width = args.nn_width - flow_layers = args.flow_layers - is_test = "both" if args.include_test else False - initial_model = args.initial_model - max_epochs = args.max_epochs - sample_frac = args.sample_frac - - key = jax.random.key(args.seed) - key1, key2 = jax.random.split(key) - - # Load training data and flow model - train_data = datatools.DataHolder( - train_data_dir, - is_cens=is_cens, - is_test=is_test, - sample_frac=sample_frac, - randkey=key1, - ) - if initial_model is not None: - initial_model = save_dir / initial_model - flow = diffmahnet.DiffMahFlow.load(initial_model) - else: - flow = diffmahnet.DiffMahFlow( - scaler=train_data.scaler, - nn_depth=nn_depth, - nn_width=nn_width, - flow_layers=flow_layers, - ) - print("Number of parameters =", flow.get_params().size) - - # Train the flow model - if max_epochs > 0: - print("Training data shapes:", train_data.x.shape, train_data.u.shape) - flow.init_fit( - train_data.x, - train_data.u, - randkey=key2, - max_epochs=max_epochs, - learning_rate=args.learning_rate, - max_patience=args.max_patience, - ) - - # Save the trained model - flow.save(save_dir / save_filename) diff --git a/diffhalos/mah/scripts/kdecent_train_float64.py b/diffhalos/mah/scripts/kdecent_train_float64.py deleted file mode 100644 index a1edfa8..0000000 --- a/diffhalos/mah/scripts/kdecent_train_float64.py +++ /dev/null @@ -1,269 +0,0 @@ -# flake8: noqa: E402 -from jax import config - -config.update("jax_enable_x64", True) - -import pathlib -import argparse - -import jax -import jax.numpy as jnp -import equinox as eqx - -from diffopt import kdescent -import diffmahnet -from diffmahnet import datatools - -SAVE_DIR = pathlib.Path("./data/") -TRAIN_DATA_DIR = pathlib.Path( - "/lcrc/project/halotools/diffmahpop_data/NM_12_NT_9_ISTART_0_IEND_576/" -) - -NN_DEPTH = 2 -NN_WIDTH = 50 -FLOW_LAYERS = 8 -SAMPLE_FRAC = 1.0 -NUM_KERNELS = 20 -NUM_FOURIER_KERNELS = 0 -LEARNING_RATE = 1e-4 - - -class KDescentLoss: - """ - Custom loss function to fit flowjax model - """ - - def __init__( - self, - train_data, - sample_size=None, - randkey=None, - num_kernels=20, - num_fourier_kernels=0, - ): - randkey = jax.random.key(0) if randkey is None else randkey - # t0 = 13.8 - # self.logt0 = np.log10(t0) - self.logt0 = train_data.logt0 - self.sample_size = sample_size - self.xscaler = train_data.x_scaler - self.uscaler = train_data.u_scaler - - self.tgrids, self.log_mah = train_data.get_tgrid_and_log_mah(randkey) - self.m_obs = train_data.m_obs - self.t_obs = train_data.t_obs - assert self.log_mah.ndim == self.tgrids.ndim == 2 - assert self.m_obs.ndim == self.t_obs.ndim == 1 - assert ( - self.log_mah.shape[0] - == self.m_obs.shape[0] - == self.t_obs.shape[0] - == self.tgrids.shape[0] - ) - self.condition = jnp.array([self.m_obs, self.t_obs]).T - - # Combine m and t with condition (m_obs, t_obs), since we always have - # an equivalent sampling of the conditional variables - # and this saves us from having to generate many separate - # KCalc instances at different conditional value bins - self.training_combined = ( - jnp.array( - [ - self.log_mah, - self.tgrids, - self.tile(self.m_obs), - self.tile(self.t_obs), - ] - ) - .reshape((4, -1)) - .T - ) - self.kde = kdescent.KCalc( - 
self.training_combined, - num_kernels=num_kernels, - num_fourier_kernels=num_fourier_kernels, - ) - - def tile(self, arr): - return jnp.tile(arr[..., None], (1, self.tgrids.shape[1])) - - @eqx.filter_jit - def __call__(self, diffmahflow, randkey): - """Compute the loss using kdescent""" - key0, key1, key2, key3 = jax.random.split(randkey, 4) - if self.sample_size is None: - tsamp = slice(None) - else: - tsamp = jax.random.choice( - key0, - self.training_combined.shape[0], - (self.sample_size,), - replace=False, - ) - mah_params = diffmahflow.sample( - self.condition[tsamp], randkey=key1, asparams=True - ) - log_mah = diffmahnet.log_mah_kern(mah_params, self.tgrids[tsamp], self.logt0) - model_combined = ( - jnp.array( - [ - log_mah, - self.tgrids[tsamp], - self.tile(self.m_obs[tsamp]), - self.tile(self.t_obs[tsamp]), - ] - ) - .reshape((4, -1)) - .T - ) - - if self.kde.num_fourier_kernels: - counts_model, counts_truth = self.kde.compare_fourier_counts( - key2, model_combined - ) - ecf_model = counts_model / model_combined.shape[0] - ecf_truth = counts_truth / self.training_combined.shape[0] - loss = jnp.sum(jnp.abs(ecf_model - ecf_truth) ** 2) - else: - loss = 0.0 - - counts_model, counts_truth = self.kde.compare_kde_counts(key3, model_combined) - pdf_model = counts_model / model_combined.shape[0] - pdf_truth = counts_truth / self.training_combined.shape[0] - loss += jnp.sum((pdf_model - pdf_truth) ** 2) - - # Optionally divide by total number of kernels to get MSE loss - # loss /= (self.kde.num_kernels + self.kde.num_fourier_kernels) - jax.debug.print("loss = {loss}", loss=loss) - - return loss - - -parser = argparse.ArgumentParser( - description="Train a DiffMahNet normalizing flow model." -) -parser.add_argument("SAVE_FILENAME", help="Filename to save the trained model.") -parser.add_argument( - "--save-dir", - type=str, - default=SAVE_DIR, - help="Directory to save the trained model.", -) -parser.add_argument( - "--train-data-dir", - type=str, - default=TRAIN_DATA_DIR, - help="Directory containing the training data.", -) -parser.add_argument( - "--initial-model", - type=str, - default=None, - help="Optional filename of an initial model to load.", -) -parser.add_argument("--sats", action="store_true") -parser.add_argument( - "--nn-depth", type=int, default=NN_DEPTH, help="Depth of the hidden neural network." -) -parser.add_argument( - "--nn-width", type=int, default=NN_WIDTH, help="Width of the hidden neural network." -) -parser.add_argument( - "--flow-layers", type=int, default=FLOW_LAYERS, help="Number of flow layers." -) -parser.add_argument( - "--include-test", action="store_true", help="Include test data in the training set." -) -parser.add_argument("--steps", type=int, default=100, help="Number of adam iterations.") -parser.add_argument( - "--learning-rate", - type=float, - default=LEARNING_RATE, - help="Initial adam learning rate.", -) -parser.add_argument( - "--num-kernels", type=int, default=NUM_KERNELS, help="Number of kdescent kernels." -) -parser.add_argument( - "--num-fourier-kernels", - type=int, - default=NUM_FOURIER_KERNELS, - help="Number of kdescent fourier kernels.", -) -parser.add_argument( - "--sample-frac", - type=float, - default=SAMPLE_FRAC, - help="Fraction of training data to load.", -) -parser.add_argument( - "--seed", type=int, default=0, help="Random seed for reproducibility." 
-)
-parser.add_argument(
-    "--plot-loss-curve",
-    action="store_true",
-    help="Plot the loss curve during training.",
-)
-
-
-if __name__ == "__main__":
-    # Parse arguments
-    args = parser.parse_args()
-    save_dir = pathlib.Path(args.save_dir)
-    save_dir.mkdir(parents=True, exist_ok=True)
-    save_filename = args.SAVE_FILENAME
-    train_data_dir = pathlib.Path(args.train_data_dir)
-    is_cens = not args.sats
-    nn_depth = args.nn_depth
-    nn_width = args.nn_width
-    flow_layers = args.flow_layers
-    is_test = "both" if args.include_test else False
-    initial_model = args.initial_model
-    steps = args.steps
-    sample_frac = args.sample_frac
-
-    key = jax.random.key(args.seed)
-    key1, key2 = jax.random.split(key)
-
-    # Load training data and flow model
-    train_data = datatools.DataHolder(
-        train_data_dir,
-        is_cens=is_cens,
-        is_test=is_test,
-        sample_frac=sample_frac,
-        randkey=key1,
-    )
-    if initial_model is not None:
-        initial_model = save_dir / initial_model
-        flow = diffmahnet.DiffMahFlow.load(initial_model)
-    else:
-        flow = diffmahnet.DiffMahFlow(
-            scaler=train_data.scaler,
-            nn_depth=nn_depth,
-            nn_width=nn_width,
-            flow_layers=flow_layers,
-        )
-    print("Number of parameters =", flow.get_params().size)
-
-    # Train the flow model
-    if steps > 0:
-        loss_func = KDescentLoss(
-            train_data,
-            num_kernels=args.num_kernels,
-            num_fourier_kernels=args.num_fourier_kernels,
-        )
-        params, losses = flow.adam_fit(
-            loss_func, randkey=key2, nsteps=steps, learning_rate=args.learning_rate
-        )
-        if args.plot_loss_curve:
-            import matplotlib.pyplot as plt
-
-            plt.semilogy(losses)
-            plt.xlabel("Iteration")
-            plt.ylabel("Loss")
-            plot_filename = save_filename.removesuffix(".eqx") + ".png"
-            plt.savefig(save_dir / plot_filename)
-            plt.close()
-
-    # Save the trained model
-    flow.save(save_dir / save_filename)
diff --git a/diffhalos/mah/scripts/train_diffmahne.txt b/diffhalos/mah/scripts/train_diffmahne.txt
deleted file mode 100644
index 22a0d4d..0000000
--- a/diffhalos/mah/scripts/train_diffmahne.txt
+++ /dev/null
@@ -1,31 +0,0 @@
-"Scripts to train diffmahnet using float64, modifying Alan's original scripts"
-
-- 2025/11/17: Modified to use float64.
-  - Model names:
-    cenflow_v2_0_float64.eqx
-    satflow_v2_0_float64.eqx
-  - Respective creation scripts:
-    mpiexec -n 1 $watchmemory python -W ignore ~/local/diffhalos/diffmahpop/scripts/kdescent_train_float64.py cenflow_v2_0_float64.eqx --num-kernels 20 --num-fourier-kernels 20 --plot-loss-curve --seed 2 --include-test --initial-model cenflow_v1_0-kde_20_20train.eqx --steps 1200 --learning-rate 3e-5
-
-    mpiexec -n 1 $watchmemory python -W ignore ~/local/diffhalos/diffmahpop/scripts/kdescent_train_float64.py satflow_v2_0_float64.eqx --num-kernels 20 --num-fourier-kernels 20 --plot-loss-curve --seed 2 --include-test --initial-model satflow_v1_0-kde_20_20train.eqx --steps 1200 --learning-rate 3e-5 --sats
-
-- 2025/05/15: Finished v2 models. Trained using v1 as the initial model,
-performing kdescent fitting in the 4D space {log(M(t)), t, log(M_obs), t_obs}.
-- 2025/04/27: Finished v1 models, using the full training data. Trained using
-the flowjax fitter only.
-  - Model names:
-    cenflow_v1_0_float64.eqx
-    satflow_v1_0_float64.eqx
-  - Respective creation scripts:
-    mpiexec -n 1 $watchmemory python -W ignore ~/local/diffmahnet/scripts/flowjax_train_float64.py cenflow_v1_0_float64.eqx --max-epochs 200 --max-patience 25 --learning-rate 8e-5 --nn-depth 8 --nn-width 48 --seed 111 --include-test --initial-model cenflow_v1_0train_float64.eqx
-
-    mpiexec -n 1 $watchmemory python -W ignore ~/local/diffmahnet/scripts/flowjax_train_float64.py satflow_v1_0_float64.eqx --max-epochs 200 --max-patience 25 --learning-rate 8e-5 --nn-depth 8 --nn-width 48 --seed 111 --include-test --initial-model cenflow_v1_0train_float64.eqx --sats
-
-- Preliminary versions of the v1 models, trained on the "train split" (~70% of
-the full training data) using the flowjax fitter only.
-  - Model names:
-    cenflow_v1_0train_float64.eqx
-    satflow_v1_0train_float64.eqx
-  - Respective creation scripts:
-    mpiexec -n 1 $watchmemory python -W ignore ~/local/diffmahnet/scripts/flowjax_train_float64.py cenflow_v1_0train_float64.eqx --max-epochs 400 --max-patience 30 --learning-rate 8e-5 --nn-depth 8 --nn-width 48 --seed 110
-
-    mpiexec -n 1 $watchmemory python -W ignore ~/local/diffmahnet/scripts/flowjax_train_float64.py satflow_v1_0train_float64.eqx --max-epochs 400 --max-patience 30 --learning-rate 8e-5 --nn-depth 8 --nn-width 48 --seed 110 --sats
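For reference, below is a minimal sketch of how one of the trained float64 models listed above might be loaded and evaluated, mirroring the calls that appear in the deleted training scripts (DiffMahFlow.load, sample with asparams=True, and log_mah_kern). The model path, the condition values, and the log10(Gyr) time convention are illustrative assumptions based on those scripts, not a documented API example:

    # flake8: noqa: E402
    from jax import config

    config.update("jax_enable_x64", True)  # match the float64 training setup

    import jax
    import jax.numpy as jnp

    import diffmahnet

    # Hypothetical path: wherever flow.save() wrote the trained model
    flow = diffmahnet.DiffMahFlow.load("./data/cenflow_v2_0_float64.eqx")

    # Conditions are (m_obs, t_obs) pairs, matching the (n, 2) layout of
    # KDescentLoss.condition in the training script; values are placeholders
    condition = jnp.array([[13.0, 1.0], [12.5, 0.9]])

    # Draw one set of diffmah parameters per condition row
    mah_params = flow.sample(condition, randkey=jax.random.key(0), asparams=True)

    # Evaluate log10(M(t)) on a shared time grid, assuming the log10(Gyr)
    # convention of the training data, with t0 = 13.8 Gyr
    logt0 = jnp.log10(13.8)
    tgrids = jnp.tile(jnp.linspace(0.5, logt0, 50), (condition.shape[0], 1))
    log_mah = diffmahnet.log_mah_kern(mah_params, tgrids, logt0)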