diff --git a/R/sufficient_stats_methods.R b/R/sufficient_stats_methods.R index c5078193..e2d364b1 100644 --- a/R/sufficient_stats_methods.R +++ b/R/sufficient_stats_methods.R @@ -45,7 +45,7 @@ initialize_susie_model.ss <- function(data, params, var_y, ...) { if (params$unmappable_effects %in% c("inf", "ash")) { # Initialize omega quantities for unmappable effects - omega_res <- compute_omega_quantities(data, tau2 = 0, sigma2 = 1) + omega_res <- compute_omega_quantities(data, tau2 = 0, sigma2 = var_y) model$omega_var <- omega_res$omega_var model$predictor_weights <- omega_res$diagXtOmegaX model$XtOmegay <- data$eigen_vectors %*% (data$VtXty / omega_res$omega_var) @@ -291,10 +291,8 @@ neg_loglik.ss <- function(data, params, model, V_param, ser_stats, ...) { #' @keywords internal update_fitted_values.ss <- function(data, params, model, l) { if (params$unmappable_effects != "none") { - model$XtXr <- compute_Xb(data$XtX, colSums(model$alpha * model$mu) + model$theta) + model$XtXr <- as.vector(data$XtX %*% (colSums(model$alpha * model$mu) + model$theta)) } else { - # Fix: Use direct matrix multiplication to match original implementation - # Original: s$XtXr = s$XtXr + XtX %*% (s$alpha[l,] * s$mu[l,]) model$XtXr <- model$fitted_without_l + as.vector(data$XtX %*% (model$alpha[l, ] * model$mu[l, ])) } return(model) @@ -344,7 +342,7 @@ update_variance_components.ss <- function(data, params, model, ...) { )) } - # Remove the sparse effects + # Remove the sparse effects to compute residuals for mr.ash b <- colSums(model$alpha * model$mu) residuals <- data$y - data$X %*% b diff --git a/R/susie_constructors.R b/R/susie_constructors.R index 355caa13..8681d3c2 100644 --- a/R/susie_constructors.R +++ b/R/susie_constructors.R @@ -67,6 +67,18 @@ individual_data_constructor <- function(X, y, L = min(10, ncol(X)), } mean_y <- mean(y) + # Force required preprocessing for unmappable effects methods + if (unmappable_effects != "none") { + if (!intercept) { + warning_message("Unmappable effects methods require centered data. Setting intercept=TRUE.") + intercept <- TRUE + } + if (!standardize) { + warning_message("Unmappable effects methods require scaled data. Setting standardize=TRUE.") + standardize <- TRUE + } + } + # Handle null weights if (is.numeric(null_weight) && null_weight == 0) { null_weight <- NULL diff --git a/R/susie_utils.R b/R/susie_utils.R index ad6b48aa..144e00ab 100644 --- a/R/susie_utils.R +++ b/R/susie_utils.R @@ -596,8 +596,8 @@ add_null_effect <- function(model_init, V) { # and log Bayes factor calculations. # # Functions: compute_eigen_decomposition, add_eigen_decomposition, -# compute_omega_quantities, compute_theta_blup, lbf_stabilization, -# compute_posterior_weights, compute_lbf_gradient +# compute_omega_quantities, scale_design_matrix, compute_theta_blup, +# lbf_stabilization, compute_posterior_weights, compute_lbf_gradient # ============================================================================= # Compute eigenvalue decomposition for unmappable methods @@ -617,19 +617,6 @@ compute_eigen_decomposition <- function(XtX, n) { # Add eigen decomposition to ss data objects for unmappable methods #' @keywords internal add_eigen_decomposition <- function(data, params, individual_data = NULL) { - # Standardize y to unit variance for all unmappable effects methods - y_scale_factor <- 1 - - if (params$unmappable_effects != "none") { - var_y <- data$yty / (data$n - 1) - if (abs(var_y - 1) > 1e-10) { - sd_y <- sqrt(var_y) - data$yty <- data$yty / var_y - data$Xty <- data$Xty / sd_y - y_scale_factor <- sd_y - } - } - # Compute eigen decomposition eigen_decomp <- compute_eigen_decomposition(data$XtX, data$n) @@ -638,19 +625,47 @@ add_eigen_decomposition <- function(data, params, individual_data = NULL) { data$eigen_values <- eigen_decomp$Dsq data$VtXty <- t(eigen_decomp$V) %*% data$Xty - # SuSiE.ash requires the X matrix and standardized y vector if (params$unmappable_effects == "ash") { if (is.null(individual_data)) { stop("Adaptive shrinkage (ash) requires individual-level data") } - data$X <- individual_data$X - data$y <- individual_data$y / y_scale_factor - data$VtXt <- t(data$eigen_vectors) %*% t(individual_data$X) + + X_scaled <- scale_design_matrix( + individual_data$X, + center = attr(individual_data$X, "scaled:center"), + scale = attr(individual_data$X, "scaled:scale") + ) + + data$X <- X_scaled + data$y <- individual_data$y + data$VtXt <- t(data$eigen_vectors) %*% t(X_scaled) } return(data) } +#' Scale design matrix using centering and scaling parameters +#' +#' Applies column-wise centering and scaling to match the space used by +#' compute_XtX() and compute_Xty() for unmappable effects methods. +#' +#' @param X Matrix to scale (n × p) +#' @param center Vector of column means to subtract (length p), or NULL +#' @param scale Vector of column SDs to divide by (length p), or NULL +#' +#' @return Scaled matrix with centered and scaled columns +#' +#' @keywords internal +scale_design_matrix <- function(X, center = NULL, scale = NULL) { + if (is.null(center)) center <- rep(0, ncol(X)) + if (is.null(scale)) scale <- rep(1, ncol(X)) + + X_centered <- sweep(X, 2, center, "-") + X_scaled <- sweep(X_centered, 2, scale, "/") + + return(X_scaled) +} + # Compute Omega-weighted quantities for unmappable effects methods #' @keywords internal compute_omega_quantities <- function(data, tau2, sigma2) {