Skip to content
This repository was archived by the owner on Nov 20, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions R/sufficient_stats_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ initialize_susie_model.ss <- function(data, params, var_y, ...) {
if (params$unmappable_effects %in% c("inf", "ash")) {

# Initialize omega quantities for unmappable effects
omega_res <- compute_omega_quantities(data, tau2 = 0, sigma2 = 1)
omega_res <- compute_omega_quantities(data, tau2 = 0, sigma2 = var_y)
model$omega_var <- omega_res$omega_var
model$predictor_weights <- omega_res$diagXtOmegaX
model$XtOmegay <- data$eigen_vectors %*% (data$VtXty / omega_res$omega_var)
Expand Down Expand Up @@ -291,10 +291,8 @@ neg_loglik.ss <- function(data, params, model, V_param, ser_stats, ...) {
#' @keywords internal
update_fitted_values.ss <- function(data, params, model, l) {
if (params$unmappable_effects != "none") {
model$XtXr <- compute_Xb(data$XtX, colSums(model$alpha * model$mu) + model$theta)
model$XtXr <- as.vector(data$XtX %*% (colSums(model$alpha * model$mu) + model$theta))
} else {
# Fix: Use direct matrix multiplication to match original implementation
# Original: s$XtXr = s$XtXr + XtX %*% (s$alpha[l,] * s$mu[l,])
model$XtXr <- model$fitted_without_l + as.vector(data$XtX %*% (model$alpha[l, ] * model$mu[l, ]))
}
return(model)
Expand Down Expand Up @@ -344,7 +342,7 @@ update_variance_components.ss <- function(data, params, model, ...) {
))
}

# Remove the sparse effects
# Remove the sparse effects to compute residuals for mr.ash
b <- colSums(model$alpha * model$mu)
residuals <- data$y - data$X %*% b

Expand Down
12 changes: 12 additions & 0 deletions R/susie_constructors.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ individual_data_constructor <- function(X, y, L = min(10, ncol(X)),
}
mean_y <- mean(y)

# Force required preprocessing for unmappable effects methods
if (unmappable_effects != "none") {
if (!intercept) {
warning_message("Unmappable effects methods require centered data. Setting intercept=TRUE.")
intercept <- TRUE
}
if (!standardize) {
warning_message("Unmappable effects methods require scaled data. Setting standardize=TRUE.")
standardize <- TRUE
}
}

# Handle null weights
if (is.numeric(null_weight) && null_weight == 0) {
null_weight <- NULL
Expand Down
53 changes: 34 additions & 19 deletions R/susie_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -596,8 +596,8 @@ add_null_effect <- function(model_init, V) {
# and log Bayes factor calculations.
#
# Functions: compute_eigen_decomposition, add_eigen_decomposition,
# compute_omega_quantities, compute_theta_blup, lbf_stabilization,
# compute_posterior_weights, compute_lbf_gradient
# compute_omega_quantities, scale_design_matrix, compute_theta_blup,
# lbf_stabilization, compute_posterior_weights, compute_lbf_gradient
# =============================================================================

# Compute eigenvalue decomposition for unmappable methods
Expand All @@ -617,19 +617,6 @@ compute_eigen_decomposition <- function(XtX, n) {
# Add eigen decomposition to ss data objects for unmappable methods
#' @keywords internal
add_eigen_decomposition <- function(data, params, individual_data = NULL) {
# Standardize y to unit variance for all unmappable effects methods
y_scale_factor <- 1

if (params$unmappable_effects != "none") {
var_y <- data$yty / (data$n - 1)
if (abs(var_y - 1) > 1e-10) {
sd_y <- sqrt(var_y)
data$yty <- data$yty / var_y
data$Xty <- data$Xty / sd_y
y_scale_factor <- sd_y
}
}

# Compute eigen decomposition
eigen_decomp <- compute_eigen_decomposition(data$XtX, data$n)

Expand All @@ -638,19 +625,47 @@ add_eigen_decomposition <- function(data, params, individual_data = NULL) {
data$eigen_values <- eigen_decomp$Dsq
data$VtXty <- t(eigen_decomp$V) %*% data$Xty

# SuSiE.ash requires the X matrix and standardized y vector
if (params$unmappable_effects == "ash") {
if (is.null(individual_data)) {
stop("Adaptive shrinkage (ash) requires individual-level data")
}
data$X <- individual_data$X
data$y <- individual_data$y / y_scale_factor
data$VtXt <- t(data$eigen_vectors) %*% t(individual_data$X)

X_scaled <- scale_design_matrix(
individual_data$X,
center = attr(individual_data$X, "scaled:center"),
scale = attr(individual_data$X, "scaled:scale")
)

data$X <- X_scaled
data$y <- individual_data$y
data$VtXt <- t(data$eigen_vectors) %*% t(X_scaled)
}

return(data)
}

#' Scale design matrix using centering and scaling parameters
#'
#' Applies column-wise centering and scaling to match the space used by
#' compute_XtX() and compute_Xty() for unmappable effects methods.
#'
#' @param X Matrix to scale (n × p)
#' @param center Vector of column means to subtract (length p), or NULL
#' @param scale Vector of column SDs to divide by (length p), or NULL
#'
#' @return Scaled matrix with centered and scaled columns
#'
#' @keywords internal
scale_design_matrix <- function(X, center = NULL, scale = NULL) {
if (is.null(center)) center <- rep(0, ncol(X))
if (is.null(scale)) scale <- rep(1, ncol(X))

X_centered <- sweep(X, 2, center, "-")
X_scaled <- sweep(X_centered, 2, scale, "/")

return(X_scaled)
}

# Compute Omega-weighted quantities for unmappable effects methods
#' @keywords internal
compute_omega_quantities <- function(data, tau2, sigma2) {
Expand Down