Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,13 @@ celerybeat.pid
*.h5
# MyST build outputs
_build

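# PyTorch Lightning logs and checkpoints (NOTE(review): the *.yaml pattern below matches every YAML file in the repository, not only training output — confirm this is intended)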
lightning_logs/
saved_models/
.pt_tmp/
*.ckpt
*.yaml

# MDN model cache
microimpute_models/
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
added:
- Mixture Density Network (MDN) model for numeric variable imputation and Neural Classifier model for categorical variable imputation.
26 changes: 24 additions & 2 deletions microimpute/comparisons/autoimpute.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,23 @@
except ImportError:
HAS_MATCHING = False


try:
from microimpute.models import MDN

HAS_MDN = True
except ImportError:
HAS_MDN = False

log = logging.getLogger(__name__)

# Internal constants for model compatibility with variable types
_NUMERICAL_MODELS = {"OLS", "QRF", "QuantReg", "Matching"}
_NUMERICAL_MODELS = {"OLS", "QRF", "QuantReg", "Matching", "MDN"}
_CATEGORICAL_MODELS = {
"OLS",
"QRF",
"Matching",
"MDN",
} # QuantReg doesn't support categorical


Expand Down Expand Up @@ -208,7 +217,7 @@ def _evaluate_models_parallel(
if len(result) == 3:
model_name, cv_result, best_params = result
method_results[model_name] = cv_result
if model_name in ["QRF", "Matching"]:
if model_name in ["QRF", "Matching", "MDN"]:
best_hyperparams[model_name] = best_params
else:
model_name, cv_result = result
Expand Down Expand Up @@ -327,6 +336,7 @@ def autoimpute(
random_state: Optional[int] = RANDOM_STATE,
train_size: Optional[float] = TRAIN_SIZE,
k_folds: Optional[int] = 5,
force_retrain: Optional[bool] = False,
log_level: Optional[str] = "WARNING",
) -> AutoImputeResult:
"""Automatically select and apply the best imputation model.
Expand Down Expand Up @@ -364,6 +374,8 @@ def autoimpute(
random_state : Random seed for reproducibility
train_size : Proportion of data to use for training in preprocessing
k_folds : Number of folds for cross-validation. Defaults to 5.
force_retrain : If True, forces MDN models to retrain instead of using
cached models. Defaults to False.
log_level : Logging level for the function. Defaults to "WARNING".

Returns:
Expand Down Expand Up @@ -440,9 +452,19 @@ def autoimpute(
model_classes: List[Type[Imputer]] = [QRF, OLS, QuantReg]
if HAS_MATCHING:
model_classes.append(Matching)
if HAS_MDN:
model_classes.append(MDN)
else:
model_classes = models

# Inject force_retrain for MDN if it's in the model list
if force_retrain and any(m.__name__ == "MDN" for m in model_classes):
if hyperparameters is None:
hyperparameters = {}
if "MDN" not in hyperparameters:
hyperparameters["MDN"] = {}
hyperparameters["MDN"]["force_retrain"] = True

# Log hyperparameter usage
if hyperparameters:
model_names = [
Expand Down
2 changes: 1 addition & 1 deletion microimpute/comparisons/autoimpute_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def fit_and_predict_model(
weight_col=weight_col,
quantiles=[quantile],
)
elif hyperparams and model_name in ["Matching", "QRF"]:
elif hyperparams and model_name in ["Matching", "QRF", "MDN"]:
# Apply hyperparameters for specific models
fitted_model = model.fit(
training_data,
Expand Down
43 changes: 38 additions & 5 deletions microimpute/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,45 @@
# Random state for reproducibility
RANDOM_STATE: int = 42

# Model parameters
# Model parameters (passed via **kwargs to fit() or as __init__ params)
# Default parameters per model, keyed by lowercase model name. Every key
# appears exactly once: a repeated key in a dict literal is silently
# overridden by its last occurrence, which hides the superseded entry.
DEFAULT_MODEL_PARAMS: Dict[str, Dict[str, Any]] = {
    "qrf": {
        # RandomForestQuantileRegressor parameters
        "n_estimators": 100,
        "max_depth": None,
        "min_samples_split": 2,
        "min_samples_leaf": 1,
        "max_features": 1.0,
    },
    "quantreg": {
        # statsmodels QuantReg uses default parameters
    },
    "ols": {
        # statsmodels OLS uses default parameters
        # LogisticRegression params for categorical targets:
        "penalty": "l2",
        "C": 1.0,
        "max_iter": 1000,
    },
    "matching": {
        # StatMatch NND hotdeck default parameters
    },
    "mdn": {
        # Backbone network parameters
        "layers": "128-64-32",
        "activation": "ReLU",
        "dropout": 0.0,
        "use_batch_norm": False,
        # MDN head parameters
        "num_gaussian": 5,
        "softmax_temperature": 1.0,
        "n_samples": 100,
        # Training parameters
        "learning_rate": 1e-3,
        "max_epochs": 100,
        "early_stopping_patience": 10,
        "batch_size": 256,
    },
}

# Plotting configuration
Expand Down
4 changes: 2 additions & 2 deletions microimpute/evaluations/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def _fit_model_for_fold(
return None, None
raise e

# Handle hyperparameter tuning for QRF and Matching
elif tune_hyperparameters and model_name in ["QRF", "Matching"]:
# Handle hyperparameter tuning for QRF, Matching, and MDN
elif tune_hyperparameters and model_name in ["QRF", "Matching", "MDN"]:
log.info(f"Tuning {model_name} hyperparameters during fitting")
fitted_model, fold_tuned_params = model.fit(
train_data,
Expand Down
7 changes: 7 additions & 0 deletions microimpute/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
- QRF: quantile random forest for non-parametric quantile regression
- QuantReg: linear quantile regression model
- Matching: statistical matching/hot-deck imputation (optional, requires rpy2)
- MDN: Mixture Density Network for probabilistic imputation
(optional, requires pytorch-tabular)

Base classes:
- Imputer: abstract base class for all imputation models
Expand All @@ -23,6 +25,11 @@
except ImportError:
pass

try:
from microimpute.models.mdn import MDN
except ImportError:
pass

# Import specific model implementations
from microimpute.models.ols import OLS
from microimpute.models.qrf import QRF
Expand Down
Loading