Skip to content
This repository was archived by the owner on Nov 20, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions R/susie_constructors.R
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,6 @@ sufficient_stats_constructor <- function(XtX, Xty, yty, n,
if (any(is.infinite(Xty))) {
stop("Input Xty contains infinite values.")
}
if (!(is.double(XtX) & is.matrix(XtX)) & !inherits(XtX, "sparseMatrix")) {
stop("Input XtX must be a double-precision matrix, or a sparse matrix.")
}
if (anyNA(XtX)) {
stop("Input XtX matrix contains NAs.")
}
Expand Down
3 changes: 2 additions & 1 deletion R/susie_get_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,9 @@ get_cs_correlation <- function(model, X = NULL, Xcorr = NULL, max = FALSE) {
}
if (max) {
cs_corr <- max(abs(cs_corr[upper.tri(cs_corr)]))
} else {
rownames(cs_corr) <- colnames(cs_corr) <- names(model$sets$cs)
}
rownames(cs_corr) <- colnames(cs_corr) <- names(model$sets$cs)
return(cs_corr)
}

Expand Down
181 changes: 181 additions & 0 deletions tests/testthat/test_generic_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,184 @@ test_that("cleanup_model.default removes temporary fields", {
expect_false("residuals" %in% names(result))
expect_false("fitted_without_l" %in% names(result))
})

# =============================================================================
# DEFAULT METHOD ERROR MESSAGES
# =============================================================================

test_that("get_var_y.default throws error for unimplemented class", {
data <- structure(list(y = rnorm(50)), class = "unsupported_class")

expect_error(
get_var_y.default(data),
"get_var_y: no method for class 'unsupported_class'"
)
})

test_that("initialize_susie_model.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
params <- list(L = 5)

expect_error(
initialize_susie_model.default(data, params),
"initialize_susie_model: no method for class 'unsupported_class'"
)
})

test_that("initialize_fitted.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
mat_init <- matrix(0, 5, 10)

expect_error(
initialize_fitted.default(data, mat_init),
"initialize_fitted: no method for class 'unsupported_class'"
)
})

test_that("compute_residuals.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
params <- list()
model <- list(alpha = matrix(1/10, 5, 10), V = rep(1, 5))
l <- 1

expect_error(
compute_residuals.default(data, params, model, l),
"compute_residuals: no method for class 'unsupported_class'"
)
})

test_that("compute_ser_statistics.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
params <- list()
model <- list(alpha = matrix(1/10, 5, 10), residuals = rnorm(50))
l <- 1

expect_error(
compute_ser_statistics.default(data, params, model, l),
"compute_ser_statistics: no method for class 'unsupported_class'"
)
})

test_that("SER_posterior_e_loglik.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
params <- list()
model <- list(alpha = matrix(1/10, 5, 10), lbf_variable = matrix(0, 5, 10))
l <- 1

expect_error(
SER_posterior_e_loglik.default(data, params, model, l),
"SER_posterior_e_loglik: no method for class 'unsupported_class'"
)
})

test_that("calculate_posterior_moments.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
params <- list()
model <- list(alpha = matrix(1/10, 5, 10))
V <- 1.0

expect_error(
calculate_posterior_moments.default(data, params, model, V),
"calculate_posterior_moments: no method for class 'unsupported_class'"
)
})

test_that("get_ER2.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
model <- list(alpha = matrix(1/10, 5, 10), sigma2 = 1)

expect_error(
get_ER2.default(data, model),
"get_ER2: no method for class 'unsupported_class'"
)
})

test_that("Eloglik.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
model <- list(alpha = matrix(1/10, 5, 10), sigma2 = 1)

expect_error(
Eloglik.default(data, model),
"Eloglik: no method for class 'unsupported_class'"
)
})

test_that("loglik.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
params <- list()
model <- list(alpha = matrix(1/10, 5, 10), sigma2 = 1)
V <- 1.0
ser_stats <- list(betahat = rnorm(10), shat2 = rep(1, 10))

expect_error(
loglik.default(data, params, model, V, ser_stats),
"loglik: no method for class 'unsupported_class'"
)
})

test_that("neg_loglik.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
params <- list()
model <- list(alpha = matrix(1/10, 5, 10), sigma2 = 1)
V_param <- 0.0 # log scale
ser_stats <- list(betahat = rnorm(10), shat2 = rep(1, 10))

expect_error(
neg_loglik.default(data, params, model, V_param, ser_stats),
"neg_loglik: no method for class 'unsupported_class'"
)
})

test_that("update_fitted_values.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
params <- list()
model <- list(alpha = matrix(1/10, 5, 10), mu = matrix(0, 5, 10))
l <- 1

expect_error(
update_fitted_values.default(data, params, model, l),
"update_fitted_values: no method for class 'unsupported_class'"
)
})

test_that("get_scale_factors.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
params <- list()

expect_error(
get_scale_factors.default(data, params),
"get_scale_factors: no method for class 'unsupported_class'"
)
})

test_that("get_intercept.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
params <- list()
model <- list(alpha = matrix(1/10, 5, 10), mu = matrix(0, 5, 10))

expect_error(
get_intercept.default(data, params, model),
"get_intercept: no method for class 'unsupported_class'"
)
})

test_that("get_cs.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
params <- list(coverage = 0.95, min_abs_corr = 0.5)
model <- list(alpha = matrix(1/10, 5, 10))

expect_error(
get_cs.default(data, params, model),
"get_cs: no method for class 'unsupported_class'"
)
})

test_that("get_variable_names.default throws error for unimplemented class", {
data <- structure(list(n = 50, p = 10), class = "unsupported_class")
model <- list(alpha = matrix(1/10, 5, 10))

expect_error(
get_variable_names.default(data, model),
"get_variable_names: no method for class 'unsupported_class'"
)
})
18 changes: 18 additions & 0 deletions tests/testthat/test_individual_data_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,24 @@ test_that("get_zscore.individual returns default when compute_univariate_zscore=
expect_null(z)
})

test_that("get_zscore.individual warns when X is not a matrix (sparse/trend filtering)", {
setup <- setup_individual_data()
setup$params$compute_univariate_zscore <- TRUE

# Convert X to sparse matrix
setup$data$X <- Matrix::Matrix(setup$data$X, sparse = TRUE)

# Should produce warning about slow computation
expect_message(
z <- get_zscore.individual(setup$data, setup$params, setup$model),
"Calculation of univariate regression z-scores is not implemented specifically for sparse or trend filtering matrices"
)

# Should still compute z-scores
expect_length(z, setup$data$p)
expect_type(z, "double")
})

test_that("cleanup_model.individual removes temporary fields", {
setup <- setup_individual_data()

Expand Down
107 changes: 107 additions & 0 deletions tests/testthat/test_plotting.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ test_that("susie_plot with z-scores requires compute_univariate_zscore", {
)
})

test_that("susie_plot with z_original also requires z-scores", {
set.seed(51)
dat <- simulate_regression(n = 100, p = 50, k = 3)
fit <- susie(dat$X, dat$y, L = 5, compute_univariate_zscore = FALSE, verbose = FALSE)

# Should error when trying to plot z_original without z-scores
expect_error(
susie_plot(fit, "z_original"),
"z-scores are not available"
)
})

test_that("susie_plot with z-scores works when available", {
set.seed(3)
dat <- simulate_regression(n = 100, p = 50, k = 3)
Expand Down Expand Up @@ -411,10 +423,105 @@ test_that("susie_plot with b parameter highlights specific positions", {
)
})

test_that("susie_plot sets x0 and y1 to NULL when CS filtered by max_cs", {
set.seed(52)
dat <- simulate_regression(n = 200, p = 100, k = 3, signal_sd = 2)
fit <- susie(dat$X, dat$y, L = 10, verbose = FALSE)

# Get CS with purity info
fit$sets <- susie_get_cs(fit, X = dat$X, coverage = 0.95)

if (!is.null(fit$sets$cs) && length(fit$sets$cs) > 0) {
# Use very strict max_cs filter (size < 1) to exclude CS
# This should trigger the else branch: x0 <- NULL; y1 <- NULL
expect_error(
susie_plot(fit, "PIP", max_cs = 1, add_legend = TRUE), # Only CS with size < 1
NA
)

# Also test with very high purity threshold (max_cs as purity)
expect_error(
susie_plot(fit, "PIP", max_cs = 0.999, add_legend = TRUE), # Very high purity
NA
)
} else {
skip("No CS found for max_cs filter test")
}
})

test_that("susie_plot skips CS when x0 is NULL (next statement)", {
set.seed(53)
dat <- simulate_regression(n = 200, p = 100, k = 3, signal_sd = 2)
fit <- susie(dat$X, dat$y, L = 10, verbose = FALSE)

# Get CS
fit$sets <- susie_get_cs(fit, X = dat$X, coverage = 0.95)

if (!is.null(fit$sets$cs) && length(fit$sets$cs) > 0) {
# Use max_cs to filter out large CS, causing is.null(x0) to be TRUE
# This should trigger the next statement to skip those CS
expect_error(
susie_plot(fit, "PIP", max_cs = 2), # Skip CS with > 2 variables
NA
)
} else {
skip("No CS found for next statement test")
}
})

test_that("susie_plot uses cs_index when available (else uses cs_idx)", {
set.seed(54)
dat <- simulate_regression(n = 200, p = 100, k = 3, signal_sd = 2)
fit <- susie(dat$X, dat$y, L = 10, verbose = FALSE)

# Get CS which should populate cs_index
fit$sets <- susie_get_cs(fit, X = dat$X, coverage = 0.95)

if (!is.null(fit$sets$cs) && length(fit$sets$cs) > 0) {
# When cs_index exists, should use it
expect_true(!is.null(fit$sets$cs_index))

# Plot with legend to see cs_index values
expect_error(
susie_plot(fit, "PIP", add_legend = TRUE),
NA
)

# Test the else branch: remove cs_index to force use of cs_idx
fit_no_index <- fit
fit_no_index$sets$cs_index <- NULL

expect_error(
susie_plot(fit_no_index, "PIP", add_legend = TRUE),
NA
)
} else {
skip("No CS found for cs_index test")
}
})

# =============================================================================
# SUSIE_PLOT_ITERATION
# =============================================================================

test_that("susie_plot_iteration uses tempdir when file_prefix missing", {
set.seed(55)
dat <- simulate_regression(n = 100, p = 50, k = 3)
fit <- susie(dat$X, dat$y, L = 5, track_fit = FALSE, verbose = FALSE)

# Don't provide file_prefix - should use tempdir()
result <- invisible(capture.output({
suppressMessages(susie_plot_iteration(fit, L = 5))
}, type = "output"))

# Check that file was created in tempdir
expected_path <- file.path(tempdir(), "susie_plot.pdf")
expect_true(file.exists(expected_path))

# Clean up
if (file.exists(expected_path)) file.remove(expected_path)
})

test_that("susie_plot_iteration with track_fit=FALSE uses final fit only", {
set.seed(19)
dat <- simulate_regression(n = 100, p = 50, k = 3)
Expand Down
Loading