From 7e653af91018df7b377e7650b2f7c6761fbd09dd Mon Sep 17 00:00:00 2001 From: alexmccreight <57416850+alexmccreight@users.noreply.github.com> Date: Sat, 1 Nov 2025 19:52:43 -0500 Subject: [PATCH] ash grid scaling factor --- R/sufficient_stats_methods.R | 20 ++++++++------------ R/susie.R | 4 ++-- R/susie_constructors.R | 4 ++-- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/R/sufficient_stats_methods.R b/R/sufficient_stats_methods.R index 922955c5..f200404c 100644 --- a/R/sufficient_stats_methods.R +++ b/R/sufficient_stats_methods.R @@ -334,22 +334,18 @@ update_variance_components.ss <- function(data, params, model, ...) { # Update the sparse effect variance sparse_var <- mean(colSums(model$alpha * model$V)) - - # Update sigma2 and tau2 via MoM - mom_result <- mom_unmappable(data, params, model, omega, sparse_var, est_tau2 = TRUE, est_sigma2 = TRUE) - + # Remove the sparse effects b <- colSums(model$alpha * model$mu) residuals <- data$y - data$X %*% b - # Specify ash grid - if (mom_result$tau2 > 0) { - grid_factors <- exp(seq(log(0.1), log(100), length.out = 20 - 1)) - est_sa2 <- c(0, mom_result$tau2 * grid_factors) + # Build ASH grid based on sparse variance and scaling factor + if (sparse_var * params$ash_scaling_factor > 1e-8) { + grid_factors <- exp(seq(log(0.01), log(10), length.out = 20 - 1)) + est_sa2 <- c(0, sparse_var * params$ash_scaling_factor * grid_factors) } else { - # Fallback if MoM gives tau2 = 0 est_sa2 <- c(0, (2^(0.05*(1:20-1)) - 1)^4) - est_sa2 <- est_sa2 * (0.1 / max(est_sa2)) + est_sa2 <- est_sa2 * (0.01 / max(est_sa2)) } # Call mr.ash with residuals @@ -359,8 +355,8 @@ update_variance_components.ss <- function(data, params, model, ...) { sa2 = est_sa2, intercept = FALSE, standardize = FALSE, - sigma2 = mom_result$sigma2, - update.sigma2 = params$update_ash_sigma2, + sigma2 = model$sigma2, + update.sigma2 = params$estimate_residual_variance, max.iter = 3000 ) diff --git a/R/susie.R b/R/susie.R index e41bc199..1abef5e5 100644 --- a/R/susie.R +++ b/R/susie.R @@ -279,7 +279,7 @@ susie <- function(X, y, L = min(10, ncol(X)), estimate_prior_variance = TRUE, estimate_prior_method = c("optim", "EM", "simple"), unmappable_effects = c("none", "inf", "ash"), - update_ash_sigma2 = FALSE, + ash_scaling_factor = 0.15, check_null_threshold = 0, prior_tol = 1e-9, residual_variance_upperbound = Inf, @@ -311,7 +311,7 @@ susie <- function(X, y, L = min(10, ncol(X)), prior_weights, null_weight, standardize, intercept, estimate_residual_variance, estimate_residual_method, estimate_prior_variance, estimate_prior_method, - unmappable_effects, update_ash_sigma2, check_null_threshold, prior_tol, + unmappable_effects, ash_scaling_factor, check_null_threshold, prior_tol, residual_variance_upperbound, model_init, coverage, min_abs_corr, compute_univariate_zscore, na.rm, max_iter, tol, convergence_method, verbose, track_fit, diff --git a/R/susie_constructors.R b/R/susie_constructors.R index 80c9d8a4..78ae0f22 100644 --- a/R/susie_constructors.R +++ b/R/susie_constructors.R @@ -24,7 +24,7 @@ individual_data_constructor <- function(X, y, L = min(10, ncol(X)), estimate_prior_variance = TRUE, estimate_prior_method = "optim", unmappable_effects = "none", - update_ash_sigma2 = FALSE, + ash_scaling_factor = 0.15, check_null_threshold = 0, prior_tol = 1e-9, residual_variance_upperbound = Inf, @@ -141,7 +141,7 @@ individual_data_constructor <- function(X, y, L = min(10, ncol(X)), estimate_prior_variance = estimate_prior_variance, estimate_prior_method = estimate_prior_method, unmappable_effects = unmappable_effects, - update_ash_sigma2 = update_ash_sigma2, + ash_scaling_factor = ash_scaling_factor, check_null_threshold = check_null_threshold, prior_tol = prior_tol, residual_variance_upperbound = residual_variance_upperbound,