-
Notifications
You must be signed in to change notification settings - Fork 1
Open
Labels
Description
The following code attempts to fit a marginal density using both pooled and unpooled condensier estimates by way of sl3. The true density is standard normal. The unpooled estimates (red) approximate the true density (blue), but the pooled estimates (green) do not.
# Turn off sl3's verbose option for this reproduction.
options(sl3.verbose = FALSE)
# Packages attached for the script below.
# NOTE(review): attach order can affect which package masks which name;
# keep it as-is unless verified.
library("condensier")
library("sl3")
library("simcausal")
library(ggplot2)
# Data-generating DAG with two nodes:
#   I: Bernoulli draw with prob = 1
#   B: normal draw with mean 0 and sd 1
dag_spec <- DAG.empty() +
  node("I", distr = "rbern", prob = 1) +
  node("B", distr = "rnorm", mean = 0, sd = 1)
D <- set.DAG(dag_spec, n.test = 10)

# Draw 10000 observations with a fixed seed so the run is reproducible.
datO <- sim(D, n = 10000, rndseed = 12345)
# ================================================================================
# sl3 task: outcome B with I as the only covariate.
task <- sl3_Task$new(datO, covariates = c("I"), outcome = "B")

# Unpooled condensier learner (pool = FALSE): 25 equal-length bins, each
# learner given its own fast-GLM binomial bin estimator.
lrn_unpooled <- Lrnr_condensier$new(
  nbins = 25,
  bin_method = "equal.len",
  pool = FALSE,
  bin_estimator = Lrnr_glm_fast$new(family = binomial())
)
lrn_unpooled_fit <- lrn_unpooled$train(task)

# Pooled counterpart: identical settings except pool = TRUE.
lrn_pooled <- Lrnr_condensier$new(
  nbins = 25,
  bin_method = "equal.len",
  pool = TRUE,
  bin_estimator = Lrnr_glm_fast$new(family = binomial())
)
lrn_pooled_fit <- lrn_pooled$train(task)
# ================================================================================
# Evaluate both fits on the training data and compare them with the truth.
datO$pred_fB_unpooled <- lrn_unpooled_fit$predict(task)
datO$pred_fB_pooled <- lrn_pooled_fit$predict(task)
# Reference curve: standard normal density evaluated at the observed B.
datO$fB <- dnorm(datO$B)

# Stack the three density columns into long format (B, variable, value).
# Built with base R so the script does not rely on a `melt()` from a package
# (reshape2 / data.table) that it never attaches. Level order follows
# `density_cols`, so the plot's colour assignment is: unpooled, pooled, truth.
density_cols <- c("pred_fB_unpooled", "pred_fB_pooled", "fB")
long <- data.frame(
  B = rep(datO$B, times = length(density_cols)),
  variable = factor(
    rep(density_cols, each = nrow(datO)),
    levels = density_cols
  ),
  value = unlist(
    lapply(density_cols, function(col) datO[[col]]),
    use.names = FALSE
  )
)

# One point per observation and series, coloured by series.
ggplot(long, aes(x = B, y = value, color = variable)) +
  geom_point() +
  theme_bw()
