diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 6a82c25..bab5de8 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -5,7 +5,33 @@ on:
branches: [ main ]
jobs:
+ Check-MDN-Changes:
+ runs-on: ubuntu-latest
+ outputs:
+ mdn_changed: ${{ steps.check.outputs.mdn_changed }}
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: Check for MDN-related file changes
+ id: check
+ run: |
+ # Get list of changed files in this push
+ CHANGED_FILES=$(git diff --name-only HEAD~1 HEAD)
+ echo "Changed files:"
+ echo "$CHANGED_FILES"
+
+ # Check if any MDN-related files were changed
+ if echo "$CHANGED_FILES" | grep -qE "(mdn|MDN)"; then
+ echo "mdn_changed=true" >> $GITHUB_OUTPUT
+ echo "MDN-related files were changed"
+ else
+ echo "mdn_changed=false" >> $GITHUB_OUTPUT
+ echo "No MDN-related files were changed"
+ fi
+
Test:
+ needs: Check-MDN-Changes
runs-on: ubuntu-latest
strategy:
matrix:
@@ -30,10 +56,14 @@ jobs:
run: |
sudo Rscript -e 'install.packages("StatMatch", repos="https://cloud.r-project.org")'
sudo Rscript -e 'install.packages("clue", repos="https://cloud.r-project.org")'
- - name: Install full dependencies (Python 3.13)
- if: matrix.python-version == '3.13'
+ - name: Install full dependencies without MDN (Python 3.13)
+ if: matrix.python-version == '3.13' && needs.Check-MDN-Changes.outputs.mdn_changed != 'true'
run: |
uv pip install -e ".[dev,docs,matching,images]" --system
+ - name: Install full dependencies with MDN (Python 3.13)
+ if: matrix.python-version == '3.13' && needs.Check-MDN-Changes.outputs.mdn_changed == 'true'
+ run: |
+ uv pip install -e ".[dev,docs,matching,mdn,images]" --system
- name: Install minimal dependencies (Python 3.12)
if: matrix.python-version == '3.12'
run: |
diff --git a/.github/workflows/pr_code_changes.yaml b/.github/workflows/pr_code_changes.yaml
index c8e2ce8..6a23e32 100644
--- a/.github/workflows/pr_code_changes.yaml
+++ b/.github/workflows/pr_code_changes.yaml
@@ -20,7 +20,33 @@ jobs:
- name: Check code formatting
run: make check-format
+ Check-MDN-Changes:
+ runs-on: ubuntu-latest
+ outputs:
+ mdn_changed: ${{ steps.check.outputs.mdn_changed }}
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+ - name: Check for MDN-related file changes
+ id: check
+ run: |
+ # Get list of changed files in this PR
+ CHANGED_FILES=$(git diff --name-only origin/${{ github.base_ref }}...HEAD)
+ echo "Changed files:"
+ echo "$CHANGED_FILES"
+
+ # Check if any MDN-related files were changed
+ if echo "$CHANGED_FILES" | grep -qE "(mdn|MDN)"; then
+ echo "mdn_changed=true" >> $GITHUB_OUTPUT
+ echo "MDN-related files were changed"
+ else
+ echo "mdn_changed=false" >> $GITHUB_OUTPUT
+ echo "No MDN-related files were changed"
+ fi
+
Test:
+ needs: Check-MDN-Changes
strategy:
matrix:
os: [ ubuntu-latest ]
@@ -49,10 +75,14 @@ jobs:
run: |
sudo Rscript -e 'install.packages("StatMatch", repos="https://cloud.r-project.org")'
sudo Rscript -e 'install.packages("clue", repos="https://cloud.r-project.org")'
- - name: Install full test dependencies (Python 3.13)
- if: matrix.python-version == '3.13'
+ - name: Install full test dependencies without MDN (Python 3.13)
+ if: matrix.python-version == '3.13' && needs.Check-MDN-Changes.outputs.mdn_changed != 'true'
run: |
uv pip install -e ".[dev,matching]" --system
+ - name: Install full test dependencies with MDN (Python 3.13)
+ if: matrix.python-version == '3.13' && needs.Check-MDN-Changes.outputs.mdn_changed == 'true'
+ run: |
+ uv pip install -e ".[dev,matching,mdn]" --system
- name: Install minimal test dependencies (Python 3.12)
if: matrix.python-version == '3.12'
run: |
diff --git a/.github/workflows/pr_docs_changes.yaml b/.github/workflows/pr_docs_changes.yaml
index 4481634..81de3a7 100644
--- a/.github/workflows/pr_docs_changes.yaml
+++ b/.github/workflows/pr_docs_changes.yaml
@@ -36,7 +36,7 @@ jobs:
sudo Rscript -e 'install.packages("clue", repos="https://cloud.r-project.org")'
- name: Install dependencies
run: |
- uv pip install -e ".[dev,docs,matching,images]" --system
+ uv pip install -e ".[dev,docs,matching,mdn,images]" --system
- name: Install JB
run: uv pip install "jupyter-book>=2.0.0a0" --system
- name: Test documentation builds
diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29..1dffb00 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,5 @@
+- bump: minor
+ changes:
+ added:
+ - Asinh transformation preprocessing option for numeric variables.
+ - Documentation for MDN model and preprocessing options.
diff --git a/docs/_config.yml b/docs/_config.yml
index bce2361..678d015 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -4,7 +4,7 @@ logo: logo.png
execute:
execute_notebooks: force
- timeout: 120
+ timeout: 360
repository:
url: https://github.com/policyengine/microimpute
diff --git a/docs/_toc.yml b/docs/_toc.yml
index a8b558d..dc6d677 100644
--- a/docs/_toc.yml
+++ b/docs/_toc.yml
@@ -1,6 +1,6 @@
format: jb-book
root: index
-parts:
+parts:
- caption: Models
chapters:
- file: models/imputer/index
@@ -18,19 +18,25 @@ parts:
- file: models/quantreg/index
sections:
- file: models/quantreg/quantreg-imputation
+ - file: models/mdn/index
+ sections:
+ - file: models/mdn/mdn-imputation
- caption: Imputation and benchmarking
chapters:
- file: imputation-benchmarking/index
sections:
+ - file: imputation-benchmarking/preprocessing
+ - file: imputation-benchmarking/cross-validation
+ - file: imputation-benchmarking/metrics
+ - file: imputation-benchmarking/visualizations
- file: imputation-benchmarking/benchmarking-methods
- - file: imputation-benchmarking/imputing-across-surveys
- caption: AutoImpute
chapters:
- file: autoimpute/index
sections:
- file: autoimpute/autoimpute
- - caption: SCF to CPS example
+ - caption: Use cases
chapters:
- - file: examples/scf_to_cps/index
+ - file: use_cases/index
sections:
- - file: examples/scf_to_cps/imputing-from-scf-to-cps
+ - file: use_cases/scf_to_cps/imputing-from-scf-to-cps
diff --git a/docs/autoimpute/autoimpute.ipynb b/docs/autoimpute/autoimpute.ipynb
index 904b855..988a9aa 100644
--- a/docs/autoimpute/autoimpute.ipynb
+++ b/docs/autoimpute/autoimpute.ipynb
@@ -6,27 +6,91 @@
"source": [
"# Autoimputation\n",
"\n",
- "This notebook demonstrates the functionality of the `autoimpute` module, which provides an automated approach to selecting and applying optimal imputation methods for missing data. Rather than manually testing different approaches, `autoimpute` evaluates multiple methods (tuning their hyperparameters to the specific dataset), identifies which performs best for your specific data, and applies it to generate high-quality imputations."
+ "This notebook demonstrates the functionality of the `autoimpute` module, which provides an automated approach to selecting and applying optimal imputation methods for missing data. Rather than manually testing different approaches, `autoimpute` evaluates multiple methods (tuning their hyperparameters to the specific dataset), identifies which performs best for your specific data, and applies it to generate high-quality imputations.\n",
+ "\n",
+ "### autoimpute function\n",
+ "\n",
+ "```python\n",
+ "def autoimpute(\n",
+ " donor_data: pd.DataFrame,\n",
+ " receiver_data: pd.DataFrame,\n",
+ " predictors: List[str],\n",
+ " imputed_variables: List[str],\n",
+ " weight_col: Optional[str] = None,\n",
+ " models: Optional[List[Type]] = None,\n",
+ " imputation_quantiles: Optional[List[float]] = None,\n",
+ " hyperparameters: Optional[Dict[str, Dict[str, Any]]] = None,\n",
+ " tune_hyperparameters: Optional[bool] = False,\n",
+ " preprocessing: Optional[Dict[str, str]] = None,\n",
+ " impute_all: Optional[bool] = False,\n",
+ " metric_priority: Optional[str] = \"auto\",\n",
+ " random_state: Optional[int] = RANDOM_STATE,\n",
+ " train_size: Optional[float] = TRAIN_SIZE,\n",
+ " k_folds: Optional[int] = 5,\n",
+ " force_retrain: Optional[bool] = False,\n",
+ " log_level: Optional[str] = \"WARNING\",\n",
+ ") -> AutoImputeResult\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default used | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| *donor_data* | pd.DataFrame | - | DataFrame with predictor and target variables for training |\n",
+ "| *receiver_data* | pd.DataFrame | - | DataFrame where imputed values will be generated |\n",
+ "| *predictors* | List[str] | - | Column names of predictor variables |\n",
+ "| *imputed_variables* | List[str] | - | Column names of variables to impute |\n",
+ "| *weight_col* | str | None | Column name for sampling weights |\n",
+ "| *models* | List[Type] | [QRF, OLS, QuantReg, Matching, MDN] | List of imputer classes to compare. |\n",
+ "| *imputation_quantiles* | List[float] | [0.05 to 0.95 in steps of 0.05] | Quantiles at which to predict |\n",
+ "| *hyperparameters* | Dict | None | Model-specific hyperparameters (e.g., {\"QRF\": {\"n_estimators\": 200}}) |\n",
+ "| *tune_hyperparameters* | bool | False | Enable automatic hyperparameter tuning |\n",
+ "| *preprocessing* | Dict[str, str] | None | Variable transformations: {\"var\": \"normalize\"/\"log\"/\"asinh\"} |\n",
+ "| *impute_all* | bool | False | Return imputations for all models, not just the best |\n",
+ "| *metric_priority* | str | \"auto\" | Model selection strategy: \"auto\", \"numerical\", \"categorical\", \"combined\" |\n",
+ "| *random_state* | int | 42 | Random seed for reproducibility |\n",
+ "| *train_size* | float | 0.8 | Proportion of donor data for training in cross-validation |\n",
+ "| *k_folds* | int | 5 | Number of cross-validation folds |\n",
+ "| *force_retrain* | bool | False | Force MDN model retraining (bypass cache) |\n",
+ "| *log_level* | str | \"WARNING\" | Logging verbosity level |\n",
+ "\n",
+ "### AutoImputeResult\n",
+ "\n",
+ "The function returns an `AutoImputeResult` object with the following attributes:\n",
+ "\n",
+ "| Attribute | Type | Description |\n",
+ "|-----------|------|-------------|\n",
+ "| *imputations* | Dict | Maps model names to quantile → DataFrame of imputed values |\n",
+ "| *receiver_data* | pd.DataFrame | Receiver data with imputed values integrated |\n",
+ "| *fitted_models* | Dict | Maps model names to fitted ImputerResults objects (the best model only, unless `impute_all=True`, in which case all fitted models are included) |\n",
+ "| *cv_results* | Dict | Cross-validation metrics per model (quantile_loss, log_loss) |\n",
+ "\n",
+ "Access the best model's imputations using `AutoImputeResult.imputations[\"best_method\"]`."
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "import logging\n",
+ "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
+ "logging.getLogger(\"pytorch_tabular\").setLevel(logging.ERROR)\n",
+ "logging.getLogger(\"joblib\").setLevel(logging.ERROR)\n",
+ "\n",
"import pandas as pd\n",
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"from sklearn.datasets import load_diabetes\n",
- "import warnings\n",
"\n",
- "# Set pandas display options to limit table width\n",
"pd.set_option(\"display.width\", 600)\n",
"pd.set_option(\"display.max_columns\", 10)\n",
"pd.set_option(\"display.expand_frame_repr\", False)\n",
"\n",
"from microimpute.comparisons.autoimpute import autoimpute\n",
+ "from microimpute.models import OLS, QuantReg, QRF, Matching\n",
"from microimpute.visualizations.comparison_plots import method_comparison_results"
]
},
@@ -41,7 +105,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -156,7 +220,7 @@
"4 0.005383 -0.044642 -0.036385 0.021872 0.003935 0.015596 0.008142 -0.002592 -0.031988 -0.046641"
]
},
- "execution_count": 4,
+ "execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@@ -179,7 +243,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -229,13 +293,13 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "52000f07ebe4466d941a4e2bdbc5b49a",
+ "model_id": "0fb0cd8da8ad4d65bf84cb943ebaf1c8",
"version_major": 2,
"version_minor": 0
},
@@ -251,12 +315,18 @@
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 4.5s finished\n",
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.5s finished\n",
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.7s finished\n",
- "[Parallel(n_jobs=1)]: Done 1 tasks | elapsed: 4.8s\n"
+ "R callback write-console: Warning: \n",
+ "R callback write-console: failed to download mirrors file (cannot open URL 'https://cran.r-project.org/CRAN_mirrors.csv'); using local file '/opt/homebrew/Cellar/r/4.5.0/lib/R/doc/CRAN_mirrors.csv'\n",
+ " \n",
+ "[Parallel(n_jobs=1)]: Done 1 tasks | elapsed: 1.6min\n",
+ "R callback write-console: Warning: \n",
+ "R callback write-console: failed to download mirrors file (cannot open URL 'https://cran.r-project.org/CRAN_mirrors.csv'); using local file '/opt/homebrew/Cellar/r/4.5.0/lib/R/doc/CRAN_mirrors.csv'\n",
+ " \n",
+ "R callback write-console: Warning: \n",
+ "R callback write-console: failed to download mirrors file (cannot open URL 'https://cran.r-project.org/CRAN_mirrors.csv'); using local file '/opt/homebrew/Cellar/r/4.5.0/lib/R/doc/CRAN_mirrors.csv'\n",
+ " \n"
]
},
{
@@ -271,7 +341,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 10.7s finished\n"
+ "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 5.2min finished\n"
]
}
],
@@ -284,8 +354,9 @@
" receiver_data=receiver_data_without_targets,\n",
" predictors=predictors,\n",
" imputed_variables=imputed_variables,\n",
- " tune_hyperparameters=False, # enable automated hyperparameter tuning if desired\n",
- " k_folds=3, # Number of cross-validation folds\n",
+ " models=[OLS, QuantReg, QRF, Matching], # MDN model excluded for efficiency\n",
+ " tune_hyperparameters=False,\n",
+ " k_folds=3,\n",
")\n",
"\n",
"print(\n",
@@ -299,17 +370,12 @@
"source": [
"## Understanding the results\n",
"\n",
- "The `autoimpute` function returns four key objects that provide comprehensive information about the imputation process:\n",
- "\n",
- "- `imputations`: A dictionary where keys are quantiles used for imputation and values are DataFrames containing the imputed values at each quantile\n",
- "- `receiver_data`: The receiver dataset with imputed values integrated into it\n",
- "- `fitted_models`: The best-performing imputation model, already fitted on the donor data (if `impute_all=True` this also includes all other models fitted on the data)\n",
- "- `cv_results`: A dictionary with detailed performance metrics for all evaluated imputation methods"
+ "The `autoimpute` function returns an `AutoImputeResult` object that provides comprehensive information about the imputation process:"
]
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -349,7 +415,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -376,7 +442,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -1740,7 +1806,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -1830,7 +1896,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -1974,7 +2040,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -4853,7 +4919,7 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -5027,7 +5093,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -5076,6 +5142,7 @@
" receiver_data=receiver_data_without_targets,\n",
" predictors=predictors,\n",
" imputed_variables=imputed_variables,\n",
+ " models=[OLS, QuantReg, QRF, Matching], # MDN model excluded for efficiency\n",
" tune_hyperparameters=False,\n",
" impute_all=True,\n",
" k_folds=3,\n",
diff --git a/docs/autoimpute/index.md b/docs/autoimpute/index.md
index d1e07c9..1600464 100644
--- a/docs/autoimpute/index.md
+++ b/docs/autoimpute/index.md
@@ -1,5 +1,5 @@
# Autoimpute
-This documentation describes how the `autoimpute` function works to automize the entire method comparison, selection, and imputation pipeline in a single function.
+This chapter describes how the `autoimpute` function works to automate the entire method comparison, selection, and imputation pipeline in a single function.
The pipeline begins with input validation to ensure all necessary columns exist and quantiles are properly specified. It then preprocesses the donor and receiver datasets to prepare them for model training and evaluation. The function supports imputing numerical, categorical and boolean variable types, internally selecting the method corresponding to each variable type. At its core, `autoimpute` employs cross-validation on the donor data to evaluate multiple imputation methods. Each model is assessed on its ability to accurately predict known values using two different metrics: quantile loss for numerical imputation and log loss for categorical imputation. The method with the lowest average loss (with different metrics combined with a weighted-rank approach) across target variables is automatically selected as the optimal approach for the specific dataset and imputation task. The chosen model is then trained on the complete donor dataset and applied to generate imputations for the missing values in the receiver data. Finally, the pipeline reintegrates these imputed values back into the original receiver dataset, producing a complete dataset ready for downstream analysis.
diff --git a/docs/examples/scf_to_cps/index.md b/docs/examples/scf_to_cps/index.md
deleted file mode 100644
index 29da8c7..0000000
--- a/docs/examples/scf_to_cps/index.md
+++ /dev/null
@@ -1,5 +0,0 @@
-# Example: imputing wealth on the CPS data set
-
-This example demonstrates how to use the microimpute package to impute wealth variables from the Survey of Consumer Finances (SCF) onto the Current Population Survey (CPS).
-
-The Survey of Consumer Finances (SCF) contains detailed information about household wealth, but the Current Population Survey (CPS) does not. By using the microimpute package's `autoimpute` function, we can leverage the relationships between common variables in both surveys to impute wealth variables onto the CPS dataset.
diff --git a/docs/imputation-benchmarking/benchmarking-methods.ipynb b/docs/imputation-benchmarking/benchmarking-methods.ipynb
index 3ff1e54..9f9c3aa 100644
--- a/docs/imputation-benchmarking/benchmarking-methods.ipynb
+++ b/docs/imputation-benchmarking/benchmarking-methods.ipynb
@@ -7,12 +7,16 @@
"source": [
"# Benchmarking methods\n",
"\n",
- "This document provides a comprehensive guide to benchmarking different imputation methods using Microimpute. The examples below illustrate the workflow for comparing various imputation approaches and evaluating their performance."
+ "This notebook provides a guide to benchmarking different imputation methods using Microimpute through two example data sets.\n",
+ "\n",
+ "### With sklearn's Diabetes data set\n",
+ "\n",
+ "For details on this data set, see the [scikit-learn toy datasets documentation](https://scikit-learn.org/stable/datasets/toy_dataset.html)."
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 2,
"id": "45bc35df",
"metadata": {},
"outputs": [
@@ -1512,7 +1516,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "e96cf34166f04ac780929e9dd307fb37",
+ "model_id": "757ba887bc594faba40db29be449fc8a",
"version_major": 2,
"version_minor": 0
},
@@ -1526,7 +1530,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "3eb2a3940398412d9a5d1d8522ef70e1",
+ "model_id": "c4ac70f725d34a4ab138535bfefbf164",
"version_major": 2,
"version_minor": 0
},
@@ -1541,21 +1545,23 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Inclusion order: ['bp', 'age', 'bmi', 'sex']\n",
- "Predictor: bp \n",
- "Loss reduction: inf\n",
- "Predictor: age \n",
- "Loss reduction: 0.0005895300868971366\n",
- "Predictor: bmi \n",
- "Loss reduction: 0.0006466259122890027\n",
- "Predictor: sex \n",
- "Loss reduction: 0.0003358259619754521\n"
+ "Optimal subset: ['bp', 'age', 'bmi', 'sex']\n",
+ "Optimal loss: 0.016242569699135692\n",
+ " step predictor_added predictors_included avg_quantile_loss avg_log_loss \\\n",
+ "0 1 bp [bp] 0.017815 0 \n",
+ "1 2 age [bp, age] 0.017225 0 \n",
+ "2 3 bmi [bp, age, bmi] 0.016578 0 \n",
+ "3 4 sex [bp, age, bmi, sex] 0.016243 0 \n",
+ "\n",
+ " cumulative_improvement marginal_improvement \n",
+ "0 0.000000 0.000000 \n",
+ "1 0.000590 0.000590 \n",
+ "2 0.001236 0.000647 \n",
+ "3 0.001572 0.000336 \n"
]
}
],
"source": [
- "# On the Diabetes Dataset (Numerical Variables)\n",
- "\n",
"from typing import List, Type\n",
"\n",
"import pandas as pd\n",
@@ -1628,9 +1634,21 @@
" QRF\n",
")\n",
"\n",
- "print(f\"Inclusion order: {predictor_inclusion_results['inclusion_order']}\")\n",
- "for element in predictor_inclusion_results[\"predictor_impacts\"]:\n",
- " print(f\"Predictor: {element['predictor']} \\nLoss reduction: {element['loss_reduction']}\")"
+ "print(f\"Optimal subset: {predictor_inclusion_results['optimal_subset']}\")\n",
+ "print(f\"Optimal loss: {predictor_inclusion_results['optimal_loss']}\")\n",
+ "\n",
+ "# For step-by-step details:\n",
+ "print(predictor_inclusion_results['results_df'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c85837d6",
+ "metadata": {},
+ "source": [
+ "### With the US Federal Reserve Board's Survey of Consumer Finances\n",
+ "\n",
+ "For details on this data set, see the [Federal Reserve's SCF page](https://www.federalreserve.gov/econres/scfindex.htm)."
]
},
{
@@ -3210,197 +3228,6 @@
" imputed_variables=IMPUTED_VARIABLES,\n",
")}\")"
]
- },
- {
- "cell_type": "markdown",
- "id": "2a024ddb",
- "metadata": {},
- "source": [
- "## Data preparation\n",
- "\n",
- "The data preparation phase establishes the foundation for meaningful benchmarking comparisons. The `load_scf()` function downloads data from user-specified survey years, carefully selecting relevant predictor and target variables that capture the essential relationships for imputation.\n",
- "\n",
- "The `preprocess_data()` applies normalization techniques to the features when normalize=True, ensuring that variables with different scales don't unduly influence the imputation models. This preprocessing step is crucial, particularly if introducing additional for methods like nearest neighbor matching that rely on distance calculations. Finally, the function splits the data into training and testing sets, maintaining the statistical properties of both sets while creating an appropriate evaluation framework. If you would like to normalizing the data set without splitting it (for example in the event of performing cross-validation), set the full_data parameter to False.\n",
- "\n",
- "```python\n",
- "# Normalizing\n",
- "processed_data = preprocess_data(dataset, full_data=True)\n",
- "\n",
- "# Normalizing and splitting\n",
- "X_train, X_test = preprocess_data(dataset)\n",
- "```\n",
- "\n",
- "## Imputation generation\n",
- "\n",
- "The imputation generation process serves as the core operational phase of the benchmarking framework. The `get_imputations()` function orchestrates this process with remarkable efficiency, handling all aspects of model training and prediction generation. It systematically trains each specified model on identical training data, ensuring a fair comparison across different imputation approaches.\n",
- "\n",
- "After training, the function generates predictions at user-specified quantiles, allowing for evaluation across different parts of the conditional distribution. The quantile-based approach provides insights not just into central tendency (as with mean-based methods) but into the entire shape of the imputed distributions. This comprehensive prediction generation creates a rich dataset for subsequent evaluation.\n",
- "\n",
- "The function organizes all results into a consistent, structured format designed for straightforward comparison. The returned nested dictionary architecture provides intuitive access to predictions from different models at different quantiles:\n",
- "\n",
- "```\n",
- "{\n",
- " \"ModelName1\": {\n",
- " 0.1: DataFrame of predictions at 10th percentile,\n",
- " 0.5: DataFrame of predictions at 50th percentile,\n",
- " 0.9: DataFrame of predictions at 90th percentile.\n",
- " },\n",
- " \"ModelName2\": {\n",
- " 0.1: DataFrame of predictions at 10th percentile,\n",
- " ...\n",
- " },\n",
- " ...\n",
- "}\n",
- "```\n",
- "\n",
- "This well-designed data structure simplifies downstream analysis and visualization, allowing researchers to focus on interpreting results rather than managing data formats.\n",
- "\n",
- "At this stage, a model object can only handle the imputation of on variable at a time, meaning that to impute multiple variables from a data set, a new model object must be created for each of them. \n",
- "\n",
- "## Evaluation metrics for different variable types\n",
- "\n",
- "Microimpute employs evaluation metrics tailored to the type of variable being imputed. The framework automatically selects the appropriate metric based on whether the imputed variable is numerical or categorical, ensuring meaningful performance assessment across different data types.\n",
- "\n",
- "### Quantile loss for numerical imputation\n",
- "\n",
- "The evaluation of numerical imputation employs quantile loss to assess imputation quality. This approach provides a more nuanced evaluation than traditional metrics like mean squared error, particularly for capturing performance across different parts of the distribution.\n",
- "\n",
- "At the foundation of this evaluation lies the `quantile_loss()` function, which implements the standard quantile loss formulation: \n",
- "\n",
- "$$L(y, f, q) = \\text{max}(q \\cdot (y - f), (q - 1)(y - f))$$\n",
- "\n",
- "where $q$ is the quantile to be evaluated, $y$ represents the true value and $f$ is the imputed value. \n",
- "\n",
- "This mathematical formulation creates an asymmetric loss function that penalizes under-prediction more heavily for higher quantiles and over-prediction more heavily for lower quantiles. This asymmetry aligns perfectly with the interpretation of quantiles—a 90th percentile prediction should rarely be below the true value, while a 10th percentile prediction should rarely exceed it.\n",
- "\n",
- "### Log loss for categorical imputation\n",
- "\n",
- "When imputing categorical variables, the framework switches to log loss (also known as cross-entropy loss), which is specifically designed for evaluating probabilistic predictions of categorical outcomes. Log loss measures the performance of a classification model where the prediction output is a probability value between 0 and 1.\n",
- "\n",
- "The log loss metric is calculated using the formula:\n",
- "\n",
- "$$\\text{LogLoss} = -\\frac{1}{N}\\sum_{i=1}^{N}\\sum_{j=1}^{M} y_{ij} \\log(p_{ij})$$\n",
- "\n",
- "where:\n",
- "- $N$ is the number of samples\n",
- "- $M$ is the number of classes\n",
- "- $y_{ij}$ is 1 if sample $i$ belongs to class $j$, and 0 otherwise\n",
- "- $p_{ij}$ is the predicted probability of sample $i$ belonging to class $j$\n",
- "\n",
- "Unlike quantile loss which evaluates numerical predictions at different percentiles, log loss evaluates the quality of probability estimates for categorical predictions. A perfect classifier would have a log loss of 0, while worse predictions yield increasingly higher values. The metric heavily penalizes confident misclassifications, predicting a class with high probability when it's incorrect results in a large loss value.\n",
- "\n",
- "This distinction is crucial for proper model evaluation:\n",
- "- **Quantile loss** is used for continuous numerical variables where we care about the distribution of predicted values\n",
- "- **Log loss** is used for categorical variables where we care about the accuracy of class probability predictions\n",
- "\n",
- "The framework automatically detects the variable type and applies the appropriate metric. For models that handle both types of variables (like OLS and QRF), the evaluation will produce separate quantile loss results for numerical variables and log loss results for categorical variables.\n",
- "\n",
- "### Unified evaluation framework\n",
- "\n",
- "The integration of these complementary metrics culminates in the `compute_loss()` and `compare_loss()` functions, which systematically evaluate multiple methods using the appropriate metric for each variable type. When dealing with mixed datasets containing both numerical and categorical variables, the framework produces separate evaluation results for each metric type, allowing researchers to assess model performance comprehensively across all variable types.\n",
- "\n",
- "### Distribution similarity metrics\n",
- "\n",
- "The `compare_distributions()` function evaluates how well the imputed values preserve the distributional characteristics of the original data. It automatically selects the appropriate metric based on the variable type: Wasserstein distance for continuous numerical variables and Kullback-Leibler (KL) divergence for discrete categorical and boolean variables.\n",
- "\n",
- "#### Wasserstein distance for numerical variables\n",
- "\n",
- "For continuous numerical variables, the framework uses the Wasserstein distance (also known as Earth Mover's Distance) to quantify the difference between distributions. The Wasserstein distance between two probability distributions $P$ and $Q$ is defined as:\n",
- "\n",
- "$$W_p(P, Q) = \\left(\\inf_{\\gamma \\in \\Pi(P, Q)} \\int_{X \\times Y} d(x, y)^p d\\gamma(x, y)\\right)^{1/p}$$\n",
- "\n",
- "where $\\Pi(P, Q)$ denotes the set of all joint distributions whose marginals are $P$ and $Q$ respectively.\n",
- "\n",
- "The Wasserstein distance measures the minimum \"work\" required to transform one distribution into another, where work is defined as the amount of distribution mass moved times the distance it's moved. Lower values indicate better preservation of the original distribution's shape. In the SCF example, QRF shows the lowest Wasserstein distance (1.2e7), indicating it best preserves the distribution of net worth values, while QuantReg shows the highest distance (2.8e7), suggesting greater distributional distortion.\n",
- "\n",
- "#### Kullback-Leibler divergence for categorical and boolean variables\n",
- "\n",
- "For discrete distributions (categorical and boolean variables), the framework employs KL divergence, an information-theoretic measure that quantifies how one probability distribution diverges from a reference distribution. The KL divergence from distribution $Q$ to distribution $P$ is defined as:\n",
- "\n",
- "$$D_{KL}(P||Q) = \\sum_{x \\in \\mathcal{X}} P(x) \\log\\left(\\frac{P(x)}{Q(x)}\\right)$$\n",
- "\n",
- "where:\n",
- "- $P$ is the reference distribution (original data)\n",
- "- $Q$ is the approximation (imputed data)\n",
- "- $\\mathcal{X}$ is the set of all possible categorical values\n",
- "\n",
- "In the context of imputation evaluation, KL divergence measures how much information is lost when using the imputed distribution $Q$ to approximate the true distribution $P$. Lower KL divergence values indicate better preservation of the original categorical distribution.\n",
- "\n",
- "## Predictor analysis and sensitivity evaluation\n",
- "\n",
- "Beyond comparing imputation methods, understanding the relationship between predictors and target variables, as well as the sensitivity of imputation quality to predictor selection, provides crucial insights for model optimization and feature engineering.\n",
- "\n",
- "### Predictor-target mutual information\n",
- "\n",
- "The `compute_predictor_correlations()` function with the `imputed_variables` parameter computes normalized mutual information between each predictor and the target variables. Mutual information measures the reduction in uncertainty about one variable given knowledge of another, making it particularly valuable for mixed data types (numeric and categorical). Unlike correlation coefficients that capture only linear relationships, mutual information detects any statistical dependency.\n",
- "\n",
- "For discrete random variables $X$ and $Y$, the mutual information is defined as:\n",
- "\n",
- "$$I(X; Y) = \\sum_{x \\in X} \\sum_{y \\in Y} p(x, y) \\log\\left(\\frac{p(x, y)}{p(x)p(y)}\\right)$$\n",
- "\n",
- "where:\n",
- "- $p(x, y)$ is the joint probability distribution of $X$ and $Y$\n",
- "- $p(x)$ and $p(y)$ are the marginal probability distributions\n",
- "\n",
- "For continuous variables, the summations are replaced by integrals:\n",
- "\n",
- "$$I(X; Y) = \\int_{X} \\int_{Y} p(x, y) \\log\\left(\\frac{p(x, y)}{p(x)p(y)}\\right) dx dy$$\n",
- "\n",
- "The normalized mutual information (NMI) used in the implementation is:\n",
- "\n",
- "$$\\text{NMI}(X; Y) = \\frac{I(X; Y)}{\\sqrt{H(X) \\cdot H(Y)}}$$\n",
- "\n",
- "where $H(X)$ and $H(Y)$ are the entropies of $X$ and $Y$ respectively.\n",
- "\n",
- "The normalized values range from 0 (no relationship) to 1 (perfect dependency), allowing direct comparison of predictor importance across different variable types. In the SCF example, we see that `income` (MI=0.156) and `wageinc` (MI=0.122) have the strongest relationships with `networth`, while `lf` (labor force status, MI=0.011) has minimal information content.\n",
- "\n",
- "### Leave-one-out predictor analysis\n",
- "\n",
- "The `leave_one_out_analysis()` function systematically evaluates the importance of each predictor by removing it from the model and measuring the resulting performance degradation. This approach provides a direct measure of each predictor's contribution to imputation quality.\n",
- "\n",
- "The analysis returns several key metrics:\n",
- "- **Loss increase**: The absolute increase in loss when the predictor is removed\n",
- "- **Relative impact**: The percentage increase in loss, indicating the predictor's relative importance\n",
- "\n",
- "Predictors with high relative impact are essential for accurate imputation and should not be removed from the model. Conversely, predictors with minimal impact might be candidates for removal to simplify the model and reduce computational costs.\n",
- "\n",
- "### Progressive predictor inclusion\n",
- "\n",
- "The `progressive_predictor_inclusion()` function takes a complementary approach by starting with no predictors and iteratively adding the one that provides the greatest performance improvement. This greedy forward selection reveals:\n",
- "- **Optimal inclusion order**: The sequence in which predictors should be added for maximum benefit\n",
- "- **Marginal contribution**: The performance improvement from adding each predictor\n",
- "- **Optimal subset**: The minimal set of predictors achieving near-optimal performance\n",
- "\n",
- "In the diabetes dataset example, the inclusion order [`bp`, `age`, `bmi`, `sex`] shows that blood pressure provides the most information, followed by age and BMI, with sex providing minimal additional benefit. The diminishing returns in loss reduction (from 0.0006 to 0.0003) illustrate when additional predictors provide negligible improvement.\n",
- "\n",
- "## Visualization\n",
- "\n",
- "The `method_comparison_results.plot()` function generates bar charts that present benchmarking results grouping results by model, allowing quickly identifying patterns and trends in performance across different methods and different parts of the distribution. The `metric` parameter allows specifying \"quantile_loss\", \"log_loss\" or \"combined\" as the metric of choice for the visualziation.\n",
- "\n",
- "The function employs color coding to visually distinguish between different imputation models, making it easy to track the performance of a single method. Along the horizontal axis, when visualizing quantile loss the chart displays different quantiles (such as the 10th, 25th, 50th percentiles), allowing assessment across the entire distribution of interest. The vertical axis represents average loss, with lower values indicating better performance (both for quantile loss and log loss), giving an immediate visual indication of which models are performing well.\n",
- "\n",
- "## Extending the benchmarking framework\n",
- "\n",
- "The Microimpute benchmarking framework was designed with extensibility as a core principle, allowing researchers to easily integrate and evaluate new imputation approaches. To incorporate your own custom imputation model into this evaluation framework, you can follow a straightforward process.\n",
- "\n",
- "First, implement your custom model by extending the `Imputer` abstract base class, following the design patterns and interface requirements documented in the [implement-new-model.md](../models/imputer/implement-new-model.md) file. This structured approach ensures your model will interact correctly with the rest of the benchmarking system. Once your model implementation is complete, simply include your model class in the `model_classes` list alongside the built-in models you wish to compare against. Finally, execute the benchmarking process as described previously, and your custom model will be evaluated using the same rigorous methodology applied to the built-in models.\n",
- "\n",
- "This seamless integration is possible because all models that implement the `Imputer` interface share a common API, allowing the benchmarking framework to interact with them in a consistent manner regardless of their internal implementation details. This architectural decision makes the framework inherently extensible while maintaining a clean separation between the benchmarking logic and the specific imputation methods being evaluated.\n",
- "\n",
- "## Best practices\n",
- "\n",
- "Robust evaluation and benchmarking requires testing models across multiple diverse datasets rather than relying on a single test case. This approach helps identify which models perform consistently well across different data scenarios and which may be sensitive to particular data characteristics. By examining performance across varied contexts, you can make more confident generalizations about a method's effectiveness.\n",
- "\n",
- "A comprehensive evaluation should assess performance across different quantiles rather than focusing solely on central measures like the median. Many applications care about the tails of distributions, and models that perform well at the median might struggle with extreme quantiles. Evaluating across the full spectrum of quantiles provides a more complete picture of each method's strengths and limitations.\n",
- "\n",
- "While statistical performance is critical, practical considerations should not be overlooked. Different imputation methods can vary dramatically in their computational requirements, including training time, memory usage, and prediction speed. In many applications, a slightly less accurate method that runs orders of magnitude faster may be preferable. Consider these trade-offs explicitly in your evaluation framework.\n",
- "\n",
- "For particularly important decisions, enhance the reliability of your performance estimates through cross-validation techniques. Cross-validation provides a more stable estimate of model performance by averaging results across multiple train-test splits, reducing the impact of any particular data division. This approach is especially valuable when working with smaller datasets where a single train-test split might not be representative.\n",
- "\n",
- "The package also supports detailed assessment of model behavior through train-test performance comparisons via the `model_performance_results()` function. This visualization tool helps identify potential overfitting or underfitting issues by contrasting a model's performance on training data with its performance on held-out test data. Significant disparities between training and testing performance can reveal important limitations in a model's generalization capabilities.\n",
- "\n",
- "For specialized applications with particular interest in certain parts of the distribution, the framework accommodates custom quantile sets for targeted evaluation. Rather than using the default (random) quantiles, researchers can specify exactly which quantiles to evaluate, allowing focused assessment of performance in regions of particular interest. This flexibility enables tailored evaluations that align precisely with application-specific requirements and priorities."
- ]
}
],
"metadata": {
diff --git a/docs/imputation-benchmarking/cross-validation.md b/docs/imputation-benchmarking/cross-validation.md
new file mode 100644
index 0000000..134772d
--- /dev/null
+++ b/docs/imputation-benchmarking/cross-validation.md
@@ -0,0 +1,150 @@
+# Cross-validation and model imputation comparison
+
+This page documents the cross-validation utilities for evaluating imputation model performance. Cross-validation provides robust
+estimates of how well a model will generalize to unseen data by training and testing on multiple data splits. Functions like `get_imputations` then build upon it to standardize evaluation across all models, making fair comparison possible through a consistent experimental setup.
+
+Microimpute's cross-validation automatically selects the appropriate metric based on variable type. Numerical variables are evaluated using quantile loss, which measures prediction accuracy across the conditional distribution. Categorical variables are evaluated using log loss (cross-entropy), which penalizes confident but incorrect predictions. See the [Metrics page](./metrics.md) for more details.
+
+## Cross-validation
+
+Cross-validation provides robust estimates of how well a model will generalize to unseen data by training and testing on multiple data splits. Microimpute's cross-validation automatically selects the appropriate metric based on variable type: quantile loss for numerical variables and log loss for categorical variables.
+
+### cross_validate_model
+
+```python
+def cross_validate_model(
+ model_class: Type,
+ data: pd.DataFrame,
+ predictors: List[str],
+ imputed_variables: List[str],
+ weight_col: Optional[str] = None,
+ quantiles: Optional[List[float]] = QUANTILES,
+ n_splits: Optional[int] = 5,
+ random_state: Optional[int] = RANDOM_STATE,
+ model_hyperparams: Optional[dict] = None,
+ tune_hyperparameters: Optional[bool] = False,
+) -> Union[Dict[str, Any], Tuple[Dict[str, Any], Dict]]
+```
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| model_class | Type | - | Imputer class to evaluate (e.g., QRF, OLS, MDN) |
+| data | pd.DataFrame | - | Complete dataset for cross-validation |
+| predictors | List[str] | - | Column names of predictor variables |
+| imputed_variables | List[str] | - | Column names of variables to impute |
+| weight_col | str | None | Column name for sampling weights |
+| quantiles | List[float] | [0.05 to 0.95 in steps of 0.05] | Quantiles to evaluate |
+| n_splits | int | 5 | Number of cross-validation folds |
+| random_state | int | 42 | Random seed for reproducibility |
+| model_hyperparams | dict | None | Hyperparameters to pass to the model |
+| tune_hyperparameters | bool | False | Enable hyperparameter tuning |
+
+Returns a dictionary containing separate results for each metric type:
+
+```python
+{
+ "quantile_loss": {
+ "results": pd.DataFrame, # rows: ["train", "test"], cols: quantiles
+ "mean_train": float,
+ "mean_test": float,
+ "variables": List[str] # numerical variables evaluated
+ },
+ "log_loss": {
+ "results": pd.DataFrame, # rows: ["train", "test"], cols: quantiles
+ "mean_train": float,
+ "mean_test": float,
+ "variables": List[str] # categorical variables evaluated
+ }
+}
+```
+
+If `tune_hyperparameters=True`, returns a tuple of `(results_dict, best_hyperparameters)`.
+
+## Example usage
+
+```python
+from microimpute.evaluations import cross_validate_model
+from microimpute.models import QRF
+
+# Run 5-fold cross-validation
+results = cross_validate_model(
+ model_class=QRF,
+ data=diabetes_df,
+ predictors=["age", "sex", "bmi", "bp"],
+ imputed_variables=["s1", "s4"],
+ n_splits=5
+)
+
+# Check performance for numerical variables
+print(f"Mean test quantile loss: {results['quantile_loss']['mean_test']:.4f}")
+
+# View detailed results by quantile
+print(results["quantile_loss"]["results"])
+```
+
+## Interpreting results
+
+The results DataFrame shows loss values for each quantile, with rows for train and test splits. Lower values indicate better performance. Comparing train and test loss helps identify overfitting: a large gap suggests the model may not generalize well.
+
+For model selection, focus on the test loss (`mean_test`). When comparing multiple models, the `autoimpute()` function automates this comparison and selects the best-performing model using a rank-based approach that handles mixed variable types.
+
+## Imputation generation for model comparison
+
+The `get_imputations` function generates imputations for multiple model classes in a single call, using the supplied training data to produce imputations on the test data and organizing results in a consistent format for downstream comparison and evaluation.
+
+### get_imputations
+
+```python
+def get_imputations(
+ model_classes: List[Type],
+ X_train: pd.DataFrame,
+ X_test: pd.DataFrame,
+ predictors: List[str],
+ imputed_variables: List[str],
+ quantiles: Optional[List[float]] = QUANTILES,
+) -> Dict[str, Dict[float, pd.DataFrame]]
+```
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| model_classes | List[Type] | - | List of model classes to use (e.g., QRF, OLS, QuantReg, Matching) |
+| X_train | pd.DataFrame | - | Training data containing predictors and variables to impute |
+| X_test | pd.DataFrame | - | Test data on which to generate imputations |
+| predictors | List[str] | - | Column names of predictor variables |
+| imputed_variables | List[str] | - | Column names of variables to impute |
+| quantiles | List[float] | [0.05 to 0.95 in steps of 0.05] | List of quantiles to predict |
+
+Returns a nested dictionary mapping method names to dictionaries of quantile-indexed DataFrames:
+
+```python
+{
+ "QRF": {
+ 0.1: pd.DataFrame, # predictions at 10th percentile
+ 0.5: pd.DataFrame, # predictions at 50th percentile
+ 0.9: pd.DataFrame, # predictions at 90th percentile
+ },
+ "OLS": {
+ 0.1: pd.DataFrame,
+ ...
+ },
+}
+```
+
+### Example usage
+
+```python
+from microimpute.comparisons import get_imputations
+from microimpute.models import QRF, OLS, QuantReg, Matching
+
+# Generate imputations from multiple models
+method_imputations = get_imputations(
+ model_classes=[QRF, OLS, QuantReg, Matching],
+ X_train=train_data,
+ X_test=test_data,
+ predictors=["age", "sex", "bmi"],
+ imputed_variables=["income", "wealth"],
+)
+
+# Access predictions for a specific model and quantile
+qrf_median = method_imputations["QRF"][0.5]
+```
\ No newline at end of file
diff --git a/docs/imputation-benchmarking/index.md b/docs/imputation-benchmarking/index.md
index bf04ded..337e1c5 100644
--- a/docs/imputation-benchmarking/index.md
+++ b/docs/imputation-benchmarking/index.md
@@ -1,5 +1,5 @@
# Benchmarking different imputation methods
-This documentation describes how the Microimpute package allows you to compare different imputation methods using quantile loss metrics.
+This chapter describes how the Microimpute package allows you to compare different imputation methods using preprocessing, cross-validation, metric comparison, and evaluation tools.
-The benchmarking functionality enables systematically comparing multiple imputation models using a common dataset, allowing for robust evaluation of their performance. By assessing accuracy across various quantiles, you gain a comprehensive understanding of how each method performs across different levels of the distribution. This process is further supported by visualizations that highlight differences between approaches, making it easy to identify which imputation methods perform best under specific conditions. Ultimately, this empowers you to make data-driven decisions regarding the most suitable imputation approach for your analysis or application.
+The benchmarking functionality enables systematically comparing multiple imputation models using a common dataset, allowing for robust evaluation of their performance. It supports cross-validation to diagnose overfitting and measure performance on training data leveraging the availability of ground truth. By assessing accuracy of numeric imputation across various quantiles, it is possible to gain a more comprehensive understanding of how each method performs across different levels of the distribution. Categorical imputation is assessed with log loss. This process is further supported by visualizations that highlight differences between approaches, making it easy to identify which imputation methods perform best under specific conditions. Predictor evaluation tools are also available to inform decision-making when setting up the imputation task.
diff --git a/docs/imputation-benchmarking/metrics.md b/docs/imputation-benchmarking/metrics.md
new file mode 100644
index 0000000..fce0c40
--- /dev/null
+++ b/docs/imputation-benchmarking/metrics.md
@@ -0,0 +1,265 @@
+# Metrics and evaluation
+
+This page documents the evaluation metrics and predictor analysis tools available for assessing imputation quality. These utilities help understand model performance, compare methods, and analyze the contribution of individual predictors.
+
+## Loss metrics
+
+Microimpute employs evaluation metrics tailored to the type of variable being imputed. The framework automatically selects the appropriate metric based on whether the imputed variable is numerical or categorical, ensuring meaningful performance assessment across different data types.
+
+### Quantile loss
+
+Quantile loss assesses imputation quality for numerical variables. This approach provides a more nuanced evaluation than traditional metrics like mean squared error, particularly for capturing performance across different parts of the distribution.
+
+The quantile loss implements the standard pinball loss formulation:
+
+$$L_q(y, f) = \max(q(y-f), (q-1)(y-f))$$
+
+where $q$ is the quantile being evaluated, $y$ represents the true value, and $f$ is the imputed value. This asymmetric loss function penalizes under-prediction more heavily for higher quantiles and over-prediction more heavily for lower quantiles. The asymmetry aligns with the interpretation of quantiles: a 90th percentile prediction should rarely fall below the true value, while a 10th percentile prediction should rarely exceed it.
+
+```python
+def quantile_loss(q: float, y: np.ndarray, f: np.ndarray) -> np.ndarray
+```
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| q | float | Quantile to evaluate (e.g., 0.5 for median) |
+| y | np.ndarray | True values |
+| f | np.ndarray | Predicted values |
+
+Returns an array of element-wise quantile losses.
+
+### Log loss
+
+Log loss (cross-entropy) evaluates probabilistic predictions of categorical outcomes. It measures the performance of a classification model where the prediction output is a probability value between 0 and 1.
+
+The log loss metric is calculated as:
+
+$$\text{LogLoss} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{M} y_{ij} \log(p_{ij})$$
+
+where $N$ is the number of samples, $M$ is the number of classes, $y_{ij}$ is 1 if sample $i$ belongs to class $j$ and 0 otherwise, and $p_{ij}$ is the predicted probability of sample $i$ belonging to class $j$.
+
+A perfect classifier achieves a log loss of 0, while worse predictions yield increasingly higher values. The metric heavily penalizes confident misclassifications: predicting a class with high probability when incorrect results in a large loss value.
+
+```python
+def log_loss(
+ y_true: np.ndarray,
+ y_pred: np.ndarray,
+ normalize: bool = True,
+ labels: Optional[np.ndarray] = None,
+) -> float
+```
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| y_true | np.ndarray | - | True class labels |
+| y_pred | np.ndarray | - | Predicted probabilities or class labels |
+| normalize | bool | True | If True, return mean loss; if False, return sum |
+| labels | np.ndarray | None | List of possible label values |
+
+Returns the log loss value (float).
+
+When predictions are class labels rather than probabilities, the function converts them to high-confidence probabilities (0.99/0.01) with a warning. For more accurate evaluation, use probability predictions when available.
+
+### compute_loss
+
+A unified function that selects the appropriate loss metric based on the specified type, providing a consistent interface for both numerical and categorical evaluation.
+
+```python
+def compute_loss(
+ test_y: np.ndarray,
+ imputations: np.ndarray,
+ metric: Literal["quantile_loss", "log_loss"],
+ q: float = 0.5,
+ labels: Optional[np.ndarray] = None,
+) -> Tuple[np.ndarray, float]
+```
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| test_y | np.ndarray | - | True values |
+| imputations | np.ndarray | - | Predicted/imputed values |
+| metric | str | - | "quantile_loss" or "log_loss" |
+| q | float | 0.5 | Quantile (for quantile_loss only) |
+| labels | np.ndarray | None | Class labels (for log_loss only) |
+
+Returns a tuple of `(element_wise_losses, mean_loss)`.
+
+### compare_metrics
+
+Compares metrics across multiple imputation methods, automatically detecting variable types and applying the appropriate metric. For models that handle both numerical and categorical variables, the evaluation produces separate results for each metric type.
+
+```python
+def compare_metrics(
+ test_y: pd.DataFrame,
+ method_imputations: Dict[str, Dict[float, pd.DataFrame]],
+ imputed_variables: List[str],
+) -> pd.DataFrame
+```
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| test_y | pd.DataFrame | DataFrame containing true values |
+| method_imputations | Dict | Nested dict: method → quantile → DataFrame |
+| imputed_variables | List[str] | Variables to evaluate |
+
+Returns a DataFrame with columns `Method`, `Imputed Variable`, `Percentile`, `Loss`, and `Metric`.
+
+## Distribution comparison
+
+Beyond point-wise loss metrics, evaluating how well imputed values preserve distributional characteristics provides insight into whether the imputation maintains the statistical properties of the original data.
+
+### Wasserstein distance
+
+For continuous numerical variables, the Wasserstein distance (Earth Mover's Distance) quantifies the difference between distributions:
+
+$$W_p(P, Q) = \left(\inf_{\gamma \in \Pi(P, Q)} \int_{X \times Y} d(x, y)^p d\gamma(x, y)\right)^{1/p}$$
+
+where $\Pi(P, Q)$ denotes the set of all joint distributions whose marginals are $P$ and $Q$ respectively. The Wasserstein distance measures the minimum "work" required to transform one distribution into another, where work is the amount of distribution mass moved times the distance moved. Lower values indicate better preservation of the original distribution's shape.
+
+### Kullback-Leibler divergence
+
+For discrete distributions (categorical and boolean variables), KL divergence quantifies how one probability distribution diverges from a reference:
+
+$$D_{KL}(P||Q) = \sum_{x \in \mathcal{X}} P(x) \log\left(\frac{P(x)}{Q(x)}\right)$$
+
+where $P$ is the reference distribution (original data), $Q$ is the approximation (imputed data), and $\mathcal{X}$ is the set of all possible categorical values. KL divergence measures how much information is lost when using the imputed distribution to approximate the true distribution. Lower values indicate better preservation of the original categorical distribution.
+
+### compare_distributions
+
+```python
+def compare_distributions(
+ donor_data: pd.DataFrame,
+ receiver_data: pd.DataFrame,
+ imputed_variables: List[str],
+) -> pd.DataFrame
+```
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| donor_data | pd.DataFrame | Original donor data |
+| receiver_data | pd.DataFrame | Receiver data with imputations |
+| imputed_variables | List[str] | Variables to compare |
+
+Returns a DataFrame with columns `Variable`, `Metric`, and `Distance`. The function automatically selects Wasserstein distance for numerical variables and KL divergence for categorical variables.
+
+## Predictor analysis
+
+Understanding which predictors contribute most to imputation quality helps with feature selection and model interpretation. These tools analyze predictor-target relationships and evaluate sensitivity to predictor selection.
+
+### Mutual information
+
+Mutual information measures the reduction in uncertainty about one variable given knowledge of another. Unlike correlation coefficients that capture only linear relationships, mutual information detects any statistical dependency, making it valuable for mixed data types.
+
+For discrete random variables $X$ and $Y$:
+
+$$I(X; Y) = \sum_{x \in X} \sum_{y \in Y} p(x, y) \log\left(\frac{p(x, y)}{p(x)p(y)}\right)$$
+
+For continuous variables, the summations are replaced by integrals. The normalized mutual information (NMI) used in the implementation is:
+
+$$\text{NMI}(X; Y) = \frac{I(X; Y)}{\sqrt{H(X) \cdot H(Y)}}$$
+
+where $H(X)$ and $H(Y)$ are the entropies of $X$ and $Y$ respectively. Normalized values range from 0 (no relationship) to 1 (perfect dependency), allowing direct comparison of predictor importance across different variable types.
+
+### compute_predictor_correlations
+
+```python
+def compute_predictor_correlations(
+ data: pd.DataFrame,
+ predictors: List[str],
+ imputed_variables: List[str],
+) -> Dict[str, pd.DataFrame]
+```
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| data | pd.DataFrame | Dataset containing predictors and target variables |
+| predictors | List[str] | Column names of predictor variables |
+| imputed_variables | List[str] | Column names of target variables |
+
+Returns a dictionary containing a `predictor_target_mi` DataFrame with mutual information scores.
+
+### Leave-one-out analysis
+
+Leave-one-out predictor analysis evaluates model performance when each predictor is excluded. By comparing loss with and without each predictor, you can assess its contribution to imputation quality. Predictors whose removal causes large increases in loss are most important, while those with minimal impact might be candidates for removal to simplify the model.
+
+### leave_one_out_analysis
+
+```python
+def leave_one_out_analysis(
+ data: pd.DataFrame,
+ predictors: List[str],
+ imputed_variables: List[str],
+ model_class: Type,
+ quantiles: Optional[List[float]] = QUANTILES,
+) -> Dict[str, Any]
+```
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| data | pd.DataFrame | - | Complete dataset |
+| predictors | List[str] | - | Column names of predictor variables |
+| imputed_variables | List[str] | - | Column names of variables to impute |
+| model_class | Type | - | Imputer class to evaluate |
+| quantiles | List[float] | [0.05 to 0.95 in steps of 0.05] | Quantiles to evaluate |
+
+Returns a dictionary containing loss increase and relative impact for each predictor.
+
+### Progressive predictor inclusion
+
+Progressive inclusion analysis adds predictors one at a time in order of their mutual information with the target. This greedy forward selection reveals the optimal inclusion order, marginal contribution of each predictor, and the minimal set of predictors achieving near-optimal performance. Diminishing returns in loss reduction indicate when additional predictors provide negligible improvement.
+
+### progressive_predictor_inclusion
+
+```python
+def progressive_predictor_inclusion(
+ data: pd.DataFrame,
+ predictors: List[str],
+ imputed_variables: List[str],
+ model_class: Type,
+ quantiles: Optional[List[float]] = QUANTILES,
+) -> Dict[str, Any]
+```
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| data | pd.DataFrame | - | Complete dataset |
+| predictors | List[str] | - | Column names of predictor variables |
+| imputed_variables | List[str] | - | Column names of variables to impute |
+| model_class | Type | - | Imputer class to evaluate |
+| quantiles | List[float] | [0.05 to 0.95 in steps of 0.05] | Quantiles to evaluate |
+
+Returns a dictionary containing `inclusion_order` (list of predictors in optimal order) and `predictor_impacts` (list of dicts with predictor name and loss reduction).
+
+## Example usage
+
+```python
+from microimpute.comparisons.metrics import compare_metrics, compare_distributions
+from microimpute.evaluations import (
+ compute_predictor_correlations,
+ leave_one_out_analysis,
+ progressive_predictor_inclusion,
+)
+from microimpute.models import QRF
+
+# Compare methods
+metrics_df = compare_metrics(
+ test_y=test_data[imputed_variables],
+ method_imputations={
+ "QRF": qrf_imputations,
+ "OLS": ols_imputations,
+ },
+ imputed_variables=imputed_variables
+)
+
+# Evaluate distributional match
+dist_df = compare_distributions(
+ donor_data=donor,
+ receiver_data=receiver_with_imputations,
+ imputed_variables=imputed_variables
+)
+
+# Analyze predictor importance
+mi_scores = compute_predictor_correlations(data, predictors, imputed_variables)
+loo_results = leave_one_out_analysis(data, predictors, imputed_variables, QRF)
+inclusion_results = progressive_predictor_inclusion(data, predictors, imputed_variables, QRF)
+```
diff --git a/docs/imputation-benchmarking/preprocessing.md b/docs/imputation-benchmarking/preprocessing.md
new file mode 100644
index 0000000..caa352c
--- /dev/null
+++ b/docs/imputation-benchmarking/preprocessing.md
@@ -0,0 +1,131 @@
+# Data preprocessing
+
+Preprocessing transformations can improve model performance by normalizing scale differences or handling skewed distributions. These are supported by `preprocess_data` and transformation-specific functions.
+
+## Transformation options
+
+Microimpute supports three transformation types that can be applied to numeric columns before training. Each transformation automatically excludes categorical and boolean columns to prevent encoding issues.
+
+**Normalization (z-score)** standardizes data to have mean 0 and standard deviation 1. This transformation is useful when predictors have different scales, ensuring that all features contribute equally to distance-based or gradient-based models.
+
+**Log transformation** applies the natural logarithm to values. This is effective for right-skewed distributions common in financial data like income or wealth. The transformation requires all values to be strictly positive.
+
+**Asinh transformation** applies the inverse hyperbolic sine function, which behaves like $\log(2x)$ for large positive values and $-\log(-2x)$ for large negative values, while remaining approximately linear near zero. Unlike log transformation, asinh handles zero and negative values, making it suitable for variables like net worth that can take any real value.
+
+### preprocess_data
+
+The main entry point for data preparation, combining splitting and transformation.
+
+```python
+def preprocess_data(
+ data: pd.DataFrame,
+ full_data: Optional[bool] = False,
+ train_size: Optional[float] = TRAIN_SIZE,
+ test_size: Optional[float] = TEST_SIZE,
+ random_state: Optional[int] = RANDOM_STATE,
+ normalize: Optional[Union[bool, List[str]]] = False,
+ log_transform: Optional[Union[bool, List[str]]] = False,
+ asinh_transform: Optional[Union[bool, List[str]]] = False,
+) -> Union[Tuple[pd.DataFrame, dict], Tuple[pd.DataFrame, pd.DataFrame, dict]]
+```
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| data | pd.DataFrame | - | DataFrame containing the data to preprocess |
+| full_data | bool | False | If True, return complete dataset without splitting |
+| train_size | float | 0.8 | Proportion of data for training split |
+| test_size | float | 0.2 | Proportion of data for test split |
+| random_state | int | 42 | Random seed for reproducibility |
+| normalize | bool or List[str] | False | True for all numeric columns, or list of specific columns |
+| log_transform | bool or List[str] | False | True for all numeric columns, or list of specific columns |
+| asinh_transform | bool or List[str] | False | True for all numeric columns, or list of specific columns |
+
+The return type depends on the parameters. If `full_data=True` and transformations are applied, it returns `(data, transform_params)`. If `full_data=False` with transformations, it returns `(X_train, X_test, transform_params)`. Without transformations, the `transform_params` dict is omitted.
+
+### normalize_data
+
+```python
+def normalize_data(
+ data: pd.DataFrame,
+ columns_to_normalize: Optional[List[str]] = None,
+) -> Tuple[pd.DataFrame, dict]
+```
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| data | pd.DataFrame | - | DataFrame to normalize |
+| columns_to_normalize | List[str] | None | Specific columns to normalize. If None, all numeric columns |
+
+It returns a tuple of `(normalized_data, normalization_params)` where `normalization_params` maps column names to `{"mean": float, "std": float}`.
+
+### log_transform_data
+
+```python
+def log_transform_data(
+ data: pd.DataFrame,
+ columns_to_transform: Optional[List[str]] = None,
+) -> Tuple[pd.DataFrame, dict]
+```
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| data | pd.DataFrame | - | DataFrame to transform |
+| columns_to_transform | List[str] | None | Specific columns to transform. If None, all numeric columns |
+
+Returns a tuple of `(log_transformed_data, log_transform_params)`.
+
+Note: Raises `ValueError` if any values are non-positive.
+
+### asinh_transform_data
+
+```python
+def asinh_transform_data(
+ data: pd.DataFrame,
+ columns_to_transform: Optional[List[str]] = None,
+) -> Tuple[pd.DataFrame, dict]
+```
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| data | pd.DataFrame | - | DataFrame to transform |
+| columns_to_transform | List[str] | None | Specific columns to transform. If None, all numeric columns |
+
+Returns a tuple of `(asinh_transformed_data, asinh_transform_params)`.
+
+## Reversal functions
+
+After imputation, predictions can be converted back to the original scale using the stored parameters.
+
+```python
+def unnormalize_predictions(imputations: dict, normalization_params: dict) -> dict
+def unlog_transform_predictions(imputations: dict, log_transform_params: dict) -> dict
+def un_asinh_transform_predictions(imputations: dict, asinh_transform_params: dict) -> dict
+```
+
+Each function takes the imputation dictionary (mapping quantiles to DataFrames) and the parameter dictionary returned by the corresponding transform function, returning imputations in the original scale.
+
+## Usage with autoimpute
+
+The `autoimpute()` function accepts a `preprocessing` parameter that specifies transformations per variable:
+
+```python
+from microimpute.comparisons.autoimpute import autoimpute
+
+result = autoimpute(
+ donor_data=donor,
+ receiver_data=receiver,
+ predictors=["age", "education"],
+ imputed_variables=["income", "wealth"],
+ preprocessing={
+ "income": "log", # Log transform (positive values only)
+ "wealth": "asinh", # Asinh transform (handles zeros/negatives)
+ "age": "normalize" # Z-score normalization
+ }
+)
+```
+
+The transformations are applied automatically before model training and reversed after prediction, so the returned imputations are in the original scale.
+
+## Constraints
+
+Each column can only have one transformation applied. Attempting to apply multiple transformations to the same column raises a `ValueError`. When specifying transformations as `True` (apply to all), only one transformation type can be used. For different transformations on different columns, use the list format to specify columns explicitly.
diff --git a/docs/imputation-benchmarking/visualizations.md b/docs/imputation-benchmarking/visualizations.md
new file mode 100644
index 0000000..86867ac
--- /dev/null
+++ b/docs/imputation-benchmarking/visualizations.md
@@ -0,0 +1,174 @@
+# Visualizations
+
+The following visualization utilities create informative plots for comparing methods and understanding model performance.
+
+## Method comparison
+
+The `MethodComparisonResults` class and `method_comparison_results()` factory function provide visualization for comparing multiple imputation methods.
+
+### method_comparison_results
+
+```python
+def method_comparison_results(
+ data: Union[pd.DataFrame, Dict[str, Dict[str, Dict]]],
+ metric_name: Optional[str] = None,
+ metric: str = "quantile_loss",
+ data_format: str = "wide",
+) -> MethodComparisonResults
+```
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| data | DataFrame or Dict | - | Comparison data from cross-validation results |
+| metric_name | str | None | Deprecated, use `metric` instead |
+| metric | str | "quantile_loss" | Metric to visualize: "quantile_loss", "log_loss", or "combined" |
+| data_format | str | "wide" | Input format: "wide", "long", or "dual_metrics" |
+
+Returns: `MethodComparisonResults` object for visualization.
+
+### MethodComparisonResults class
+
+```python
+class MethodComparisonResults:
+ def plot(
+ self,
+ title: Optional[str] = None,
+ save_path: Optional[str] = None,
+ show_mean: bool = True,
+ figsize: Tuple[int, int] = (PLOT_CONFIG["width"], PLOT_CONFIG["height"]),
+ plot_type: str = "bar",
+ ) -> go.Figure
+
+ def summary(self, format: str = "wide") -> pd.DataFrame
+
+ def get_best_method(self, criterion: str = "mean") -> str
+```
+
+`MethodComparisonResults.plot()` parameters:
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| title | str | None | Custom plot title |
+| save_path | str | None | Path to save the plot |
+| show_mean | bool | True | Show horizontal lines for mean loss |
+| figsize | tuple | (width, height) | Figure dimensions in pixels |
+| plot_type | str | "bar" | Plot type: "bar" for grouped bars, "stacked" for contribution analysis |
+
+The `"stacked"` plot type shows rank-based contribution scores, useful for understanding how each variable contributes to overall model performance.
+
+`MethodComparisonResults.summary()` parameters:
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| format | str | "wide" | Output format: "wide" for methods as columns, "long" for stacked |
+
+`MethodComparisonResults.get_best_method()` parameters:
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| criterion | str | "mean" | Selection criterion: "mean" or "median" |
+
+### Example usage
+
+```python
+from microimpute.visualizations import method_comparison_results
+
+# Create comparison visualization from cross-validation results
+comparison_viz = method_comparison_results(
+ data=cv_results,
+ metric="quantile_loss",
+ data_format="wide",
+)
+
+# Generate plot
+fig = comparison_viz.plot(
+ title="Method comparison",
+ show_mean=True,
+ plot_type="bar"
+)
+fig.show()
+
+# Get summary statistics
+summary_df = comparison_viz.summary(format="wide")
+
+# Identify best method
+best = comparison_viz.get_best_method(criterion="mean")
+```
+
+## Individual model performance
+
+The `PerformanceResults` class and `model_performance_results()` factory function visualize single model performance.
+
+### model_performance_results
+
+```python
+def model_performance_results(
+ results: Union[pd.DataFrame, Dict[str, Dict[str, any]]],
+ model_name: Optional[str] = None,
+ method_name: Optional[str] = None,
+ metric: str = "quantile_loss",
+ class_probabilities: Optional[Dict[str, pd.DataFrame]] = None,
+ y_true: Optional[Dict[str, np.ndarray]] = None,
+ y_pred: Optional[Dict[str, np.ndarray]] = None,
+) -> PerformanceResults
+```
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| results | DataFrame or Dict | - | Performance results from cross-validation |
+| model_name | str | None | Name of the model |
+| method_name | str | None | Name of the imputation method |
+| metric | str | "quantile_loss" | Metric to visualize: "quantile_loss", "log_loss", or "combined" |
+| class_probabilities | Dict | None | Class probability DataFrames for categorical |
+| y_true | Dict | None | True values for confusion matrix |
+| y_pred | Dict | None | Predicted values for confusion matrix |
+
+Returns: `PerformanceResults` object for visualization.
+
+### PerformanceResults class
+
+```python
+class PerformanceResults:
+ def plot(
+ self,
+ title: Optional[str] = None,
+ save_path: Optional[str] = None,
+ figsize: Tuple[int, int] = (PLOT_CONFIG["width"], PLOT_CONFIG["height"]),
+ ) -> go.Figure
+
+ def summary(self) -> pd.DataFrame
+```
+
+`PerformanceResults.plot()` parameters:
+
+| Parameter | Type | Default used | Description |
+|-----------|------|---------|-------------|
+| title | str | None | Custom plot title |
+| save_path | str | None | Path to save the plot |
+| figsize | tuple | (width, height) | Figure dimensions in pixels |
+
+For quantile loss, the plot shows train and test loss across quantiles as grouped bars. For log loss, the plot includes the loss bars and optionally confusion matrix and class probability distribution subplots. For combined metrics, both are shown in subplots.
+
+### Example usage
+
+```python
+from microimpute.visualizations import model_performance_results
+
+# Visualize cross-validation results for a single model
+perf_viz = model_performance_results(
+ results=cv_results["quantile_loss"]["results"],
+ model_name="QRF",
+ method_name="Cross-validation",
+ metric="quantile_loss"
+)
+
+fig = perf_viz.plot(title="QRF performance")
+fig.show()
+
+# Get summary statistics
+summary = perf_viz.summary()
+```
+
+## Plot customization
+
+All plots are created using Plotly and return `go.Figure` objects that can be further customized using the standard Plotly API. Plots share a light gray background (`#F0F0F0`) and the Plotly qualitative color palette so that figures look consistent across the documentation.
diff --git a/docs/index.md b/docs/index.md
index 391c20c..354005d 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -7,16 +7,17 @@ The framework currently supports the following imputation methods:
- Ordinary Least Squares Linear Regression
- Quantile Random Forests
- Quantile Regression
+- Mixture Density Networks
This is a work in progress that may evolve over time, including new statistical imputation methods and features.
-## Microimputation dashboard
+## Microimputation dashboard
Users can visualize imputation and benchmarking results at https://microimpute-dashboard.vercel.app/.
To use the dashboard for visualization, CSV files must contain the following columns in this exact order:
- `type`: Type of metric (e.g., "benchmark_loss", "distribution_distance", "predictor_correlation")
-- `method`: Imputation method name (e.g., "QRF", "OLS", "QuantReg", "Matching")
+- `method`: Imputation method name (e.g., "QRF", "OLS", "QuantReg", "Matching", "MDN")
- `variable`: Variable being imputed or analyzed
- `quantile`: Quantile level (numeric value, "mean", or "N/A")
- `metric_name`: Name of the metric (e.g., "quantile_loss", "log_loss")
diff --git a/docs/models/imputer/index.md b/docs/models/imputer/index.md
index 95d8db9..3cf0dc0 100644
--- a/docs/models/imputer/index.md
+++ b/docs/models/imputer/index.md
@@ -1,6 +1,6 @@
# The Imputer class
-The `Imputer` class serves as an abstract base class that defines the common interface for all imputation models within the Microimpute framework. It establishes a foundational structure with essential methods for data validation, model fitting, and prediction. Every specialized imputation model in the system inherits from this class and implements the required abstract methods to provide its unique functionality.
+The `Imputer` class serves as an abstract base class that defines the common interface for all imputation models within the Microimpute framework. It establishes a structure with essential methods for data validation, model fitting, and prediction. Every specialized imputation model in the system inherits from this class and implements the required abstract methods to provide its unique functionality.
## Key features
diff --git a/docs/models/matching/matching-imputation.ipynb b/docs/models/matching/matching-imputation.ipynb
index 07f73f4..1cdace2 100644
--- a/docs/models/matching/matching-imputation.ipynb
+++ b/docs/models/matching/matching-imputation.ipynb
@@ -12,7 +12,67 @@
"\n",
"The matching model can handle any variable type—numerical, categorical, boolean, or mixed. Since it transfers actual observed values from similar donor records rather than generating model-based predictions, it naturally preserves the original data type and distribution of each variable.\n",
"\n",
- "The matching model supports iterative imputation with a single object and workflow. Pass a list of `imputed_variables` with all variables that you hope to impute for and the model will do so without needing to fit and predict for each separately."
+ "The matching model supports iterative imputation with a single object and workflow. Pass a list of `imputed_variables` containing every variable you want to impute, and the model will impute them all without needing to fit and predict for each one separately.\n",
+ "\n",
+ "### Matching class\n",
+ "\n",
+ "```python\n",
+ "class Matching(\n",
+ " matching_hotdeck: MatchingHotdeckFn = nnd_hotdeck_using_rpy2,\n",
+ " log_level: Optional[str] = \"WARNING\"\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default used | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| matching_hotdeck | MatchingHotdeckFn | nnd_hotdeck_using_rpy2 | Function that performs hot deck matching |\n",
+ "| log_level | str | \"WARNING\" | Logging verbosity level |\n",
+ "\n",
+ "### fit() method\n",
+ "\n",
+ "```python\n",
+ "def fit(\n",
+ " X_train: pd.DataFrame,\n",
+ " predictors: List[str],\n",
+ " imputed_variables: List[str],\n",
+ " weight_col: Optional[str] = None,\n",
+ " tune_hyperparameters: bool = False,\n",
+ " constrained: bool = False,\n",
+ " constr_alg: str = \"lpSolve\",\n",
+ " dist_fun: str = \"Gower\",\n",
+ " k: int = 5,\n",
+ ") -> Union[MatchingResults, Tuple[MatchingResults, Dict]]\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default used | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| X_train | pd.DataFrame | - | Donor data with predictors and target variables |\n",
+ "| predictors | List[str] | - | Column names to use as matching variables |\n",
+ "| imputed_variables | List[str] | - | Column names to impute |\n",
+ "| weight_col | str | None | Column name for sampling weights |\n",
+ "| tune_hyperparameters | bool | False | Enable automatic hyperparameter tuning |\n",
+ "| constrained | bool | False | Use constrained matching (each donor used at most once) |\n",
+ "| constr_alg | str | \"lpSolve\" | Algorithm for constrained matching: \"lpSolve\" or \"hungarian\" |\n",
+ "| dist_fun | str | \"Gower\" | Distance function: \"Gower\", \"Manhattan\", \"Euclidean\", etc. |\n",
+ "| k | int | 5 | Number of nearest neighbors to consider |\n",
+ "\n",
+ "It returns a `MatchingResults` object (or a tuple with the best hyperparameters if tuning is enabled).\n",
+ "\n",
+ "### MatchingResults.predict() method\n",
+ "\n",
+ "```python\n",
+ "def predict(\n",
+ " X_test: pd.DataFrame,\n",
+ " quantiles: Optional[List[float]] = None\n",
+ ") -> Dict[float, pd.DataFrame]\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default used | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| X_test | pd.DataFrame | - | Data to impute (with predictors) |\n",
+ "| quantiles | List[float] | None | Quantiles for output format (values are identical across quantiles) |\n",
+ "\n",
+ "It returns a dictionary mapping quantiles to DataFrames of imputed values. Note that matching returns the same values for all quantiles since it transfers actual donor values rather than generating distributional predictions."
]
},
{
@@ -22,6 +82,12 @@
"outputs": [],
"source": [
"# Import needed libraries and setup R environment\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "import logging\n",
+ "logging.getLogger(\"rpy2\").setLevel(logging.ERROR)\n",
+ "\n",
"import sys\n",
"import os\n",
"import pandas as pd\n",
@@ -31,11 +97,7 @@
"from rpy2.robjects import pandas2ri\n",
"from sklearn.datasets import load_diabetes\n",
"from sklearn.model_selection import train_test_split\n",
- "import warnings\n",
- "\n",
- "warnings.filterwarnings(\"ignore\")\n",
"\n",
- "# Set pandas display options to limit table width\n",
"pd.set_option(\"display.width\", 600)\n",
"pd.set_option(\"display.max_columns\", 10)\n",
"pd.set_option(\"display.expand_frame_repr\", False)\n",
@@ -3900,7 +3962,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "This plot displays the prediction intervals produced by a Matching model for total serum cholesterol values across ten data records. Each red dot indicates the imputed median value (Q50) for a missing observation, while black dots represent the corresponding true values. Light blue and dark blue vertical bars would represent the 80% (Q10–Q90) and 40% (Q30–Q70) prediction intervals, respectively. Unlike model-based methods, all records lack visible interval bars entirely. This reflects the limited variability inherent in matching methods, where each imputed value is drawn from a single matched donor or a small set of similar units. As a result, the model cannot capture the full uncertainty of the imputed values, as all quantile estimates collapse to the same value. Additionally, many of the imputed medians lie far from the actual values. This highlights a key limitation of matching-based imputation: while simple and interpretable, it may lack the flexibility to accurately quantify uncertainty or represent the underlying distribution, especially in complex or high-variance data."
+ "This plot displays the prediction intervals produced by a Matching model for total serum cholesterol values across ten data records. Each red dot indicates the imputed median value (Q50) for a missing observation, while black dots represent the corresponding true values. Light gray and dark gray vertical bars would represent the 80% (Q10–Q90) and 40% (Q30–Q70) prediction intervals, respectively. Unlike model-based methods, all records lack visible interval bars entirely. This reflects the limited variability inherent in matching methods, where each imputed value is drawn from a single matched donor or a small set of similar units. As a result, the model cannot capture the full uncertainty of the imputed values, as all quantile estimates collapse to the same value. Additionally, many of the imputed medians lie far from the actual values. This highlights a key limitation of matching-based imputation: while simple and interpretable, it may lack the flexibility to accurately quantify uncertainty or represent the underlying distribution, especially in complex or high-variance data."
]
},
{
@@ -4981,7 +5043,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Tuning the Matching model\n",
+ "## Tuning the Matching model\n",
"\n",
"The Matching imputer supports various parameters that can be adjusted to improve performance. To set specific values you know increase performance for your specific dataset see below. Additionally, automatic hyperparameter tunning specific to the target dataset is enabled by setting the parameter `tune_hyperparameters` to True. "
]
diff --git a/docs/models/mdn/index.md b/docs/models/mdn/index.md
new file mode 100644
index 0000000..c48797c
--- /dev/null
+++ b/docs/models/mdn/index.md
@@ -0,0 +1,29 @@
+# Mixture Density Network
+
+The `MDN` model uses deep neural networks with mixture density outputs to predict missing values by learning complex, potentially multi-modal conditional distributions. Built on PyTorch Tabular, this approach combines the flexibility of neural networks with the probabilistic richness of Gaussian mixture models.
+
+## Variable type support
+
+MDN automatically adapts to your target variable types. For numerical variables, it uses a mixture density network head that models the full conditional distribution as a mixture of Gaussians, enabling it to capture complex, multi-modal relationships. For categorical and boolean variables, it switches to a neural classifier with appropriate output layers. This automatic detection means you can pass mixed variable types without manual configuration.
+
+## How it works
+
+The MDN imputer uses a shared backbone architecture (configurable dense layers with dropout and batch normalization options) that feeds into specialized output heads. For numerical targets, the mixture density head outputs parameters for a mixture of Gaussian distributions: mixing coefficients, means, and variances for each component. Predictions are generated by stochastically sampling from this learned distribution rather than returning point estimates.
+
+For categorical and boolean targets, the model uses a standard classification head with softmax outputs. Predictions are made by sampling from the predicted probability distribution, preserving the stochastic nature of imputation.
+
+The model supports automatic caching based on data hashes, avoiding redundant retraining when the same data is encountered again. Hyperparameter tuning via Optuna is available for optimizing the number of Gaussian components and learning rate.
+
+## Key features
+
+MDN offers several advantages for complex imputation tasks. The mixture density approach can model multi-modal distributions that simpler methods cannot capture, making it suitable for variables with complex conditional distributions. The neural network backbone can learn non-linear relationships without requiring explicit feature engineering.
+
+Training leverages GPU acceleration when available and includes early stopping to prevent overfitting. The automatic model caching system speeds up repeated analyses on the same dataset.
+
+## Installation note
+
+MDN requires the `pytorch-tabular` package, which is an optional dependency. Install it with:
+
+```bash
+pip install pytorch_tabular
+```
diff --git a/docs/models/mdn/mdn-imputation.ipynb b/docs/models/mdn/mdn-imputation.ipynb
new file mode 100644
index 0000000..674f716
--- /dev/null
+++ b/docs/models/mdn/mdn-imputation.ipynb
@@ -0,0 +1,5630 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "cell-0",
+ "metadata": {},
+ "source": [
+ "# Mixture Density Network (MDN) imputation\n",
+ "\n",
+ "This notebook demonstrates how to use Microimpute's `MDN` imputer to impute values using mixture density networks. MDN models the full conditional distribution of a target variable as a mixture of Gaussians, enabling it to capture complex, multi-modal relationships.\n",
+ "\n",
+ "## Variable type support\n",
+ "\n",
+ "The MDN model automatically handles both numerical and categorical variables. For numerical targets, it applies a mixture density network that learns the parameters of a Gaussian mixture model. For categorical targets (strings, booleans, or numerically-encoded categorical variables), it switches to using a neural classifier. This automatic adaptation happens internally without requiring any manual configuration."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cell-api",
+ "metadata": {},
+ "source": [
+ "### MDN class\n",
+ "\n",
+ "```python\n",
+ "class MDN(\n",
+ " layers: str = \"128-64-32\",\n",
+ " activation: str = \"ReLU\",\n",
+ " dropout: float = 0.0,\n",
+ " use_batch_norm: bool = False,\n",
+ " num_gaussian: int = 5,\n",
+ " softmax_temperature: float = 1.0,\n",
+ " n_samples: int = 100,\n",
+ " learning_rate: float = 1e-3,\n",
+ " max_epochs: int = 100,\n",
+ " early_stopping_patience: int = 10,\n",
+ " batch_size: int = 256,\n",
+ " model_dir: str = \"./microimpute_models\",\n",
+ " force_retrain: bool = False,\n",
+ " seed: Optional[int] = RANDOM_STATE,\n",
+ " log_level: Optional[str] = \"WARNING\"\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| layers | str | \"128-64-32\" | Network architecture as hyphen-separated layer sizes |\n",
+ "| activation | str | \"ReLU\" | Activation function (ReLU, LeakyReLU, SELU, etc.) |\n",
+ "| dropout | float | 0.0 | Dropout probability for regularization |\n",
+ "| use_batch_norm | bool | False | Whether to use batch normalization |\n",
+ "| num_gaussian | int | 5 | Number of Gaussian components in the mixture |\n",
+ "| softmax_temperature | float | 1.0 | Temperature for mixture weight softmax |\n",
+ "| n_samples | int | 100 | Number of samples for MDN prediction |\n",
+ "| learning_rate | float | 1e-3 | Learning rate for Adam optimizer |\n",
+ "| max_epochs | int | 100 | Maximum training epochs |\n",
+ "| early_stopping_patience | int | 10 | Epochs to wait before early stopping |\n",
+ "| batch_size | int | 256 | Training batch size |\n",
+ "| model_dir | str | \"./microimpute_models\" | Directory for caching trained models |\n",
+ "| force_retrain | bool | False | If True, skip cache and always retrain |\n",
+ "| seed | int | 42 | Random seed for reproducibility |\n",
+ "| log_level | str | \"WARNING\" | Logging level (DEBUG, INFO, WARNING, ERROR) |\n",
+ "\n",
+ "### fit() method\n",
+ "\n",
+ "```python\n",
+ "def fit(\n",
+ " X_train: pd.DataFrame,\n",
+ " predictors: List[str],\n",
+ " imputed_variables: List[str],\n",
+ " weight_col: Optional[str] = None,\n",
+ " tune_hyperparameters: bool = False,\n",
+ " n_trials: int = 10,\n",
+ " cv_folds: int = 3\n",
+ ") -> Union[MDNResults, Tuple[MDNResults, Dict]]\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| X_train | pd.DataFrame | - | Training data with predictors and target variables |\n",
+ "| predictors | List[str] | - | Column names to use as predictors |\n",
+ "| imputed_variables | List[str] | - | Column names of variables to impute |\n",
+ "| weight_col | str | None | Column name for sampling weights |\n",
+ "| tune_hyperparameters | bool | False | Enable Optuna-based hyperparameter tuning |\n",
+ "| n_trials | int | 10 | Number of Optuna trials for tuning |\n",
+ "| cv_folds | int | 3 | Number of cross-validation folds for tuning |\n",
+ "\n",
+ "It returns an `MDNResults` object (or a tuple of the `MDNResults` object and the best hyperparameters if tuning is enabled).\n",
+ "\n",
+ "### MDNResults.predict() method\n",
+ "\n",
+ "```python\n",
+ "def predict(\n",
+ " X_test: pd.DataFrame,\n",
+ " quantiles: Optional[List[float]] = None\n",
+ ") -> Dict[float, pd.DataFrame]\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| X_test | pd.DataFrame | - | Data to impute (with predictors) |\n",
+ "| quantiles | List[float] | QUANTILES | Quantiles at which to return predictions |\n",
+ "\n",
+ "It returns a dictionary mapping quantiles to DataFrames of imputed values."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cell-1",
+ "metadata": {},
+ "source": [
+ "## Setup and data preparation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "cell-2",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <21BE8260-F4D5-3597-9DD0-6953BC4DDF3D> /Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n",
+ "Trying to import in ABI mode.\n"
+ ]
+ }
+ ],
+ "source": [
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "import logging\n",
+ "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
+ "logging.getLogger(\"pytorch_tabular\").setLevel(logging.ERROR)\n",
+ "logging.getLogger(\"lightning\").setLevel(logging.ERROR)\n",
+ "\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import plotly.express as px\n",
+ "import plotly.graph_objects as go\n",
+ "from sklearn.datasets import load_diabetes\n",
+ "\n",
+ "pd.set_option(\"display.width\", 600)\n",
+ "pd.set_option(\"display.max_columns\", 10)\n",
+ "pd.set_option(\"display.expand_frame_repr\", False)\n",
+ "\n",
+ "from microimpute.utils.data import preprocess_data\n",
+ "from microimpute.evaluations import cross_validate_model\n",
+ "from microimpute.models import MDN\n",
+ "from microimpute.config import QUANTILES\n",
+ "from microimpute.visualizations import model_performance_results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "cell-3",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " age | \n",
+ " sex | \n",
+ " bmi | \n",
+ " bp | \n",
+ " s1 | \n",
+ " s2 | \n",
+ " s3 | \n",
+ " s4 | \n",
+ " s5 | \n",
+ " s6 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 0.038076 | \n",
+ " 0.050680 | \n",
+ " 0.061696 | \n",
+ " 0.021872 | \n",
+ " -0.044223 | \n",
+ " -0.034821 | \n",
+ " -0.043401 | \n",
+ " -0.002592 | \n",
+ " 0.019907 | \n",
+ " -0.017646 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " -0.001882 | \n",
+ " -0.044642 | \n",
+ " -0.051474 | \n",
+ " -0.026328 | \n",
+ " -0.008449 | \n",
+ " -0.019163 | \n",
+ " 0.074412 | \n",
+ " -0.039493 | \n",
+ " -0.068332 | \n",
+ " -0.092204 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 0.085299 | \n",
+ " 0.050680 | \n",
+ " 0.044451 | \n",
+ " -0.005670 | \n",
+ " -0.045599 | \n",
+ " -0.034194 | \n",
+ " -0.032356 | \n",
+ " -0.002592 | \n",
+ " 0.002861 | \n",
+ " -0.025930 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " -0.089063 | \n",
+ " -0.044642 | \n",
+ " -0.011595 | \n",
+ " -0.036656 | \n",
+ " 0.012191 | \n",
+ " 0.024991 | \n",
+ " -0.036038 | \n",
+ " 0.034309 | \n",
+ " 0.022688 | \n",
+ " -0.009362 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 0.005383 | \n",
+ " -0.044642 | \n",
+ " -0.036385 | \n",
+ " 0.021872 | \n",
+ " 0.003935 | \n",
+ " 0.015596 | \n",
+ " 0.008142 | \n",
+ " -0.002592 | \n",
+ " -0.031988 | \n",
+ " -0.046641 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " age sex bmi bp s1 s2 s3 s4 s5 s6\n",
+ "0 0.038076 0.050680 0.061696 0.021872 -0.044223 -0.034821 -0.043401 -0.002592 0.019907 -0.017646\n",
+ "1 -0.001882 -0.044642 -0.051474 -0.026328 -0.008449 -0.019163 0.074412 -0.039493 -0.068332 -0.092204\n",
+ "2 0.085299 0.050680 0.044451 -0.005670 -0.045599 -0.034194 -0.032356 -0.002592 0.002861 -0.025930\n",
+ "3 -0.089063 -0.044642 -0.011595 -0.036656 0.012191 0.024991 -0.036038 0.034309 0.022688 -0.009362\n",
+ "4 0.005383 -0.044642 -0.036385 0.021872 0.003935 0.015596 0.008142 -0.002592 -0.031988 -0.046641"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Load the diabetes dataset\n",
+ "diabetes = load_diabetes()\n",
+ "df = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)\n",
+ "\n",
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "cell-4",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training set size: 353 records\n",
+ "Testing set size: 89 records\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Define variables for the model\n",
+ "predictors = [\"age\", \"sex\", \"bmi\", \"bp\"]\n",
+ "imputed_variables = [\"s1\", \"s4\"]\n",
+ "\n",
+ "# Create a subset with only needed columns\n",
+ "diabetes_df = df[predictors + imputed_variables]\n",
+ "\n",
+ "# Split data into training and testing sets\n",
+ "X_train, X_test = preprocess_data(diabetes_df)\n",
+ "\n",
+ "print(f\"Training set size: {X_train.shape[0]} records\")\n",
+ "print(f\"Testing set size: {X_test.shape[0]} records\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cell-5",
+ "metadata": {},
+ "source": [
+ "## Simulating missing data\n",
+ "\n",
+ "For this example, we simulate missing data in our test set by removing the values we want to impute."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "cell-6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " age | \n",
+ " sex | \n",
+ " bmi | \n",
+ " bp | \n",
+ " s1 | \n",
+ " s4 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 287 | \n",
+ " 0.045341 | \n",
+ " -0.044642 | \n",
+ " -0.006206 | \n",
+ " -0.015999 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " | 211 | \n",
+ " 0.092564 | \n",
+ " -0.044642 | \n",
+ " 0.036907 | \n",
+ " 0.021872 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " | 72 | \n",
+ " 0.063504 | \n",
+ " 0.050680 | \n",
+ " -0.004050 | \n",
+ " -0.012556 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " | 321 | \n",
+ " 0.096197 | \n",
+ " -0.044642 | \n",
+ " 0.051996 | \n",
+ " 0.079265 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " | 73 | \n",
+ " 0.012648 | \n",
+ " 0.050680 | \n",
+ " -0.020218 | \n",
+ " -0.002228 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " age sex bmi bp s1 s4\n",
+ "287 0.045341 -0.044642 -0.006206 -0.015999 NaN NaN\n",
+ "211 0.092564 -0.044642 0.036907 0.021872 NaN NaN\n",
+ "72 0.063504 0.050680 -0.004050 -0.012556 NaN NaN\n",
+ "321 0.096197 -0.044642 0.051996 0.079265 NaN NaN\n",
+ "73 0.012648 0.050680 -0.020218 -0.002228 NaN NaN"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Create a copy of the test set with missing values\n",
+ "X_test_missing = X_test.copy()\n",
+ "\n",
+ "# Store the actual values for later comparison\n",
+ "actual_values = X_test_missing[imputed_variables].copy()\n",
+ "\n",
+ "# Remove the values to be imputed\n",
+ "X_test_missing[imputed_variables] = np.nan\n",
+ "\n",
+ "X_test_missing.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cell-7",
+ "metadata": {},
+ "source": [
+ "## Training and using the MDN imputer\n",
+ "\n",
+ "Now we train the MDN imputer and use it to impute the missing values in our test set. The MDN model automatically caches trained models based on a hash of the input data, speeding up repeated analyses."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "cell-8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialize the MDN imputer\n",
+ "mdn_imputer = MDN(\n",
+ " layers=\"64-32\",\n",
+ " num_gaussian=3,\n",
+ " max_epochs=30,\n",
+ " early_stopping_patience=5,\n",
+ ")\n",
+ "\n",
+ "# Fit the model\n",
+ "fitted_mdn_imputer = mdn_imputer.fit(\n",
+ " X_train,\n",
+ " predictors,\n",
+ " imputed_variables,\n",
+ " tune_hyperparameters=False, # Hyperparameter tuning may be enabled, note it will have high computational cost\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "cell-9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " s1 | \n",
+ " s4 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 287 | \n",
+ " -0.239280 | \n",
+ " -0.307822 | \n",
+ "
\n",
+ " \n",
+ " | 211 | \n",
+ " 0.339311 | \n",
+ " 0.358265 | \n",
+ "
\n",
+ " \n",
+ " | 72 | \n",
+ " -0.276636 | \n",
+ " -0.277714 | \n",
+ "
\n",
+ " \n",
+ " | 321 | \n",
+ " 0.540901 | \n",
+ " 0.551054 | \n",
+ "
\n",
+ " \n",
+ " | 73 | \n",
+ " 0.460601 | \n",
+ " 0.446176 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " s1 s4\n",
+ "287 -0.239280 -0.307822\n",
+ "211 0.339311 0.358265\n",
+ "72 -0.276636 -0.277714\n",
+ "321 0.540901 0.551054\n",
+ "73 0.460601 0.446176"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Impute values in the test set\n",
+ "imputed_values = fitted_mdn_imputer.predict(X_test_missing, QUANTILES)\n",
+ "\n",
+ "# Display the first few imputed values at the median\n",
+ "imputed_values[0.5].head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cell-10",
+ "metadata": {},
+ "source": [
+ "## Evaluating the imputation results\n",
+ "\n",
+ "Now let's compare the imputed values with the actual values to evaluate the performance of our imputer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "cell-11",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "hovertemplate": "Actual=%{x}
ClosestPrediction=%{y}",
+ "legendgroup": "",
+ "marker": {
+ "color": "#636efa",
+ "opacity": 0.7,
+ "symbol": "circle"
+ },
+ "mode": "markers",
+ "name": "",
+ "orientation": "v",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 0.12501870313429186,
+ 0.03430885887772673,
+ -0.0249601584096303,
+ -0.03949338287409329,
+ 0.10300345740307394,
+ -0.002592261998183278,
+ 0.05484510736603471,
+ 0.14132210941786577,
+ 0.03833367306762126,
+ 0.03430885887772673,
+ 0.09887559882847057,
+ -0.002592261998183278,
+ 0.030077955918414535,
+ 0.03430885887772673,
+ 0.10988322169407955,
+ 0.03430885887772673,
+ -0.0249601584096303,
+ -0.002592261998183278,
+ 0.03695772020942014,
+ -0.002592261998183278,
+ 0.04658939021682799,
+ -0.024732934523729287,
+ -0.06761469701386505,
+ -0.002592261998183278,
+ -0.05523112129005496,
+ -0.0763945037500033,
+ 0.014942474478202204,
+ 0.03430885887772673,
+ -0.027712064126032544,
+ -0.03949338287409329,
+ -0.07311850844666953,
+ -0.03949338287409329,
+ 0.03833367306762126,
+ 0.03430885887772673,
+ 0.020446285911006685,
+ 0.07120997975363674,
+ -0.12678066991651324,
+ -0.047980640675552584,
+ 0.013566521620001083,
+ 0.03430885887772673,
+ -0.012576582685820214,
+ -0.002592261998183278,
+ 0.045213437358626866,
+ -0.002592261998183278,
+ -0.007072771253015731,
+ -0.03949338287409329,
+ 0.016318427336403322,
+ -0.002592261998183278,
+ -0.009824676969417972,
+ -0.03949338287409329,
+ -0.030463969842434782,
+ -0.002592261998183278,
+ -0.02220825269322806,
+ -0.002592261998183278,
+ -0.051103262715451604,
+ 0.03430885887772673,
+ -0.0249601584096303,
+ -0.0763945037500033,
+ 0.0342058144930179,
+ -0.002592261998183278,
+ 0.0080627101871966,
+ 0.021024455362399115,
+ 0.0025588987543921156,
+ -0.002592261998183278,
+ 0.0025588987543921156,
+ -0.002592261998183278,
+ -0.016704441260423575,
+ 0.03430885887772673,
+ 0.045213437358626866,
+ 0.03615391492152222,
+ 0.07823630595545376,
+ -0.002592261998183278,
+ -0.011200629827619093,
+ -0.002592261998183278,
+ 0.03145390877661565,
+ 0.019917421736121838,
+ 0.024574144485610048,
+ 0.03430885887772673,
+ -0.001568959820211247,
+ -0.03949338287409329,
+ -0.00019300696201012598,
+ -0.03949338287409329,
+ -0.00019300696201012598,
+ -0.05056371913686628,
+ -0.06623874415566393,
+ -0.002592261998183278,
+ -0.004320865536613489,
+ 0.07120997975363674,
+ 0.04383748450042574,
+ -0.014400620678474476,
+ 0.03282986163481677,
+ -0.03949338287409329,
+ -0.038719686991641515,
+ -0.03949338287409329,
+ -0.04422349842444599,
+ -0.0763945037500033,
+ -0.035967781275239266,
+ -0.05167075276314359,
+ -0.007072771253015731,
+ -0.002592261998183278,
+ -0.07311850844666953,
+ -0.0763945037500033,
+ -0.019456346976825818,
+ -0.03949338287409329,
+ -0.007072771253015731,
+ 0.07120997975363674,
+ -0.008448724111216851,
+ -0.03949338287409329,
+ 0.08924392882106273,
+ 0.10811110062954676,
+ -0.0249601584096303,
+ -0.03949338287409329,
+ 0.03282986163481677,
+ -0.002592261998183278,
+ -0.04422349842444599,
+ -0.002592261998183278,
+ -0.0029449126784123676,
+ -0.03949338287409329,
+ -0.033215875558837024,
+ -0.0763945037500033,
+ 0.08236416453005713,
+ 0.07120997975363674,
+ -0.0318399227006359,
+ 0.0029429061332032365,
+ -0.04972730985725048,
+ -0.03949338287409329,
+ 0.010814615903598841,
+ -0.03949338287409329,
+ -0.005696818394814609,
+ 0.03430885887772673,
+ 0.06172487165704031,
+ -0.002592261998183278,
+ 0.05622106022423583,
+ 0.07120997975363674,
+ -0.007072771253015731,
+ -0.002592261998183278,
+ -0.06348683843926169,
+ -0.03949338287409329,
+ -0.015328488402222454,
+ 0.03430885887772673,
+ -0.019456346976825818,
+ 0.03430885887772673,
+ 0.039709625925822375,
+ 0.07120997975363674,
+ 0.045213437358626866,
+ 0.07120997975363674,
+ -0.04972730985725048,
+ 0.01585829843977173,
+ -0.026336111267831423,
+ -0.03949338287409329,
+ 0.03833367306762126,
+ 0.10811110062954676,
+ 0.08511607024645937,
+ 0.03430885887772673,
+ 0.016318427336403322,
+ 0.02655962349378563,
+ 0.020446285911006685,
+ -0.002592261998183278,
+ 0.01219056876179996,
+ 0.10811110062954676,
+ -0.0029449126784123676,
+ -0.03949338287409329,
+ -0.046975404140848234,
+ -0.03949338287409329,
+ -0.0029449126784123676,
+ -0.047242618258034386,
+ 0.04658939021682799,
+ -0.03949338287409329,
+ -0.007072771253015731,
+ -0.03949338287409329,
+ -0.030463969842434782,
+ -0.0763945037500033,
+ -0.08962994274508297,
+ -0.0763945037500033,
+ -0.05935897986465832,
+ -0.03949338287409329,
+ -0.030463969842434782,
+ -0.002592261998183278
+ ],
+ "xaxis": "x",
+ "y": [
+ 0.15996141731739044,
+ 0.032140739262104034,
+ 0.023270580917596817,
+ -0.07243411242961884,
+ 0.06383831799030304,
+ 0.0471138134598732,
+ 0.08792641758918762,
+ 0.183455690741539,
+ 0.034883130341768265,
+ -0.020984657108783722,
+ 0.1328243613243103,
+ 0.002175786066800356,
+ 0.8124674558639526,
+ 0.9326153993606567,
+ 0.1008240208029747,
+ 0.06365294754505157,
+ 0.3842753469944,
+ 0.46691128611564636,
+ 0.017988024279475212,
+ -0.016335798427462578,
+ 0.06642946600914001,
+ -0.008764325641095638,
+ -0.0118333101272583,
+ 0.010635918006300926,
+ -0.012584746815264225,
+ -0.0853499099612236,
+ -0.002956878859549761,
+ -0.024943694472312927,
+ -0.033975280821323395,
+ -0.07795815169811249,
+ -0.07504265010356903,
+ -0.03126005083322525,
+ 0.03940160199999809,
+ 0.053814880549907684,
+ -0.04383951798081398,
+ 0.019542427733540535,
+ -0.09670525044202805,
+ -0.0822581872344017,
+ 0.06511250138282776,
+ 0.09403425455093384,
+ 0.00344745977781713,
+ -0.00915258377790451,
+ 0.0950278714299202,
+ -0.048670198768377304,
+ 0.039458371698856354,
+ -0.07556705176830292,
+ -0.006443846970796585,
+ 0.020815376192331314,
+ -0.036735083907842636,
+ -0.044321704655885696,
+ -0.08580169081687927,
+ -0.059006646275520325,
+ -0.019963562488555908,
+ 0.009087007492780685,
+ -0.130485400557518,
+ -0.0006790757179260254,
+ -0.037002887576818466,
+ -0.09501223266124725,
+ 0.012499166652560234,
+ 0.02030245214700699,
+ 0.014375532045960426,
+ -0.02604287676513195,
+ 0.057647883892059326,
+ 0.05878327041864395,
+ -0.04572222754359245,
+ -0.03820172697305679,
+ -0.06333865225315094,
+ 0.04185344651341438,
+ 0.05389529839158058,
+ 0.01455413457006216,
+ 0.09297429770231247,
+ -0.0037352601066231728,
+ -0.014067748561501503,
+ 0.026063010096549988,
+ 0.007550985552370548,
+ -0.025209808722138405,
+ -0.04939274489879608,
+ -0.030361073091626167,
+ 0.03330788016319275,
+ -0.014227045699954033,
+ 0.0016633199993520975,
+ -0.08225220441818237,
+ 0.010647851973772049,
+ -0.08228402584791183,
+ -0.07383066415786743,
+ 0.061680346727371216,
+ -0.5339921116828918,
+ -0.6321333646774292,
+ 0.025442680343985558,
+ -0.022219836711883545,
+ -0.026420146226882935,
+ -0.10970157384872437,
+ -0.041401173919439316,
+ -0.0358879491686821,
+ -0.003495529294013977,
+ -0.04557208716869354,
+ 0.013635829091072083,
+ -0.043398745357990265,
+ 0.026523813605308533,
+ 0.00031686751754023135,
+ -0.09354585409164429,
+ -0.05077875778079033,
+ -0.03905845806002617,
+ -0.02609078772366047,
+ -0.006901499815285206,
+ 0.05072978138923645,
+ -0.0048773703165352345,
+ -0.01315099187195301,
+ 0.083067886531353,
+ 0.11018361896276474,
+ -0.042783450335264206,
+ -0.042475901544094086,
+ 0.007468030788004398,
+ 0.03531891107559204,
+ -0.09278491139411926,
+ -0.008651846088469028,
+ 0.03183288499712944,
+ -0.08680566400289536,
+ -0.01793188974261284,
+ -0.08094043284654617,
+ 0.08602085709571838,
+ 0.06315568834543228,
+ -0.020230960100889206,
+ 0.003904549404978752,
+ -0.03103073686361313,
+ -0.06506631523370743,
+ 0.036732763051986694,
+ -0.07117997109889984,
+ -0.0006138663738965988,
+ 0.03311414271593094,
+ 0.06245873123407364,
+ 0.002597265876829624,
+ 0.01770462840795517,
+ 0.058917708694934845,
+ -0.06402486562728882,
+ 0.048038020730018616,
+ -0.03813735395669937,
+ -0.028543436899781227,
+ -0.0075743068009614944,
+ 0.03675296530127525,
+ -0.014315500855445862,
+ 0.040746308863162994,
+ 0.07579956203699112,
+ 0.1126079186797142,
+ 0.08204692602157593,
+ 0.07413098216056824,
+ -0.0642552450299263,
+ 0.002614614088088274,
+ -0.033612947911024094,
+ -0.015575357712805271,
+ 0.08901567757129669,
+ 0.12694990634918213,
+ 0.12089148908853531,
+ 0.0254050400108099,
+ -0.005140587687492371,
+ 0.06322228908538818,
+ 0.017512474209070206,
+ 0.013693263754248619,
+ 0.059920668601989746,
+ 0.04606965184211731,
+ -0.03190714493393898,
+ -0.01763644628226757,
+ -0.03793055936694145,
+ -0.0012233610032126307,
+ 0.012116450816392899,
+ -0.026453981176018715,
+ 0.0048493873327970505,
+ -0.008992618881165981,
+ 0.05157541483640671,
+ -0.028200700879096985,
+ -0.05151493847370148,
+ -0.05033167079091072,
+ -0.11441226303577423,
+ -0.11985128372907639,
+ -0.07205710560083389,
+ 0.0047546266578137875,
+ -0.018449068069458008,
+ -0.006497040390968323
+ ],
+ "yaxis": "y"
+ },
+ {
+ "line": {
+ "color": "red",
+ "dash": "dash"
+ },
+ "mode": "lines",
+ "name": "Perfect prediction",
+ "type": "scatter",
+ "x": [
+ -1.969916820526123,
+ 1.6161797046661377
+ ],
+ "y": [
+ -1.969916820526123,
+ 1.6161797046661377
+ ]
+ }
+ ],
+ "layout": {
+ "height": 600,
+ "legend": {
+ "tracegroupgap": 0
+ },
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "white",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "white",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "#C8D4E3",
+ "linecolor": "#C8D4E3",
+ "minorgridcolor": "#C8D4E3",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "#C8D4E3",
+ "linecolor": "#C8D4E3",
+ "minorgridcolor": "#C8D4E3",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "white",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "#C8D4E3"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "white",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "#EBF0F8",
+ "linecolor": "#EBF0F8",
+ "ticks": ""
+ },
+ "bgcolor": "white",
+ "radialaxis": {
+ "gridcolor": "#EBF0F8",
+ "linecolor": "#EBF0F8",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "white",
+ "gridcolor": "#DFE8F3",
+ "gridwidth": 2,
+ "linecolor": "#EBF0F8",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "#EBF0F8"
+ },
+ "yaxis": {
+ "backgroundcolor": "white",
+ "gridcolor": "#DFE8F3",
+ "gridwidth": 2,
+ "linecolor": "#EBF0F8",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "#EBF0F8"
+ },
+ "zaxis": {
+ "backgroundcolor": "white",
+ "gridcolor": "#DFE8F3",
+ "gridwidth": 2,
+ "linecolor": "#EBF0F8",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "#EBF0F8"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "#DFE8F3",
+ "linecolor": "#A2B1C6",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "#DFE8F3",
+ "linecolor": "#A2B1C6",
+ "ticks": ""
+ },
+ "bgcolor": "white",
+ "caxis": {
+ "gridcolor": "#DFE8F3",
+ "linecolor": "#A2B1C6",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "#EBF0F8",
+ "linecolor": "#EBF0F8",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "#EBF0F8",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "#EBF0F8",
+ "linecolor": "#EBF0F8",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "#EBF0F8",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "Comparison of actual vs. imputed values using MDN"
+ },
+ "width": 750,
+ "xaxis": {
+ "anchor": "y",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "Actual values"
+ }
+ },
+ "yaxis": {
+ "anchor": "x",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "Imputed values"
+ }
+ }
+ }
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Get quantiles and create prediction matrix\n",
+ "quantiles = list(imputed_values.keys())\n",
+ "\n",
+ "# Convert imputed_values dict to array: (n_samples, n_quantiles)\n",
+ "pred_matrix = np.stack(\n",
+ " [imputed_values[q].values.flatten() for q in quantiles], axis=1\n",
+ ")\n",
+ "\n",
+ "# Actual values flattened\n",
+ "actual = actual_values.values.flatten()\n",
+ "\n",
+ "# Compute absolute error matrix\n",
+ "abs_error = np.abs(pred_matrix - actual[:, None])\n",
+ "\n",
+ "# Find closest prediction for each sample\n",
+ "closest_indices = abs_error.argmin(axis=1)\n",
+ "closest_predictions = np.array(\n",
+ " [pred_matrix[i, idx] for i, idx in enumerate(closest_indices)]\n",
+ ")\n",
+ "\n",
+ "# Create DataFrame for plotting\n",
+ "closest_df = pd.DataFrame({\n",
+ " \"Actual\": actual,\n",
+ " \"ClosestPrediction\": closest_predictions,\n",
+ "})\n",
+ "\n",
+ "# Extract median predictions\n",
+ "median_predictions = imputed_values[0.5]\n",
+ "\n",
+ "# Calculate plot bounds\n",
+ "min_val = min(actual_values.min().min(), median_predictions.min().min())\n",
+ "max_val = max(actual_values.max().max(), median_predictions.max().max())\n",
+ "\n",
+ "# Create scatter plot\n",
+ "fig = px.scatter(\n",
+ " closest_df,\n",
+ " x=\"Actual\",\n",
+ " y=\"ClosestPrediction\",\n",
+ " opacity=0.7,\n",
+ " title=\"Comparison of actual vs. imputed values using MDN\",\n",
+ ")\n",
+ "\n",
+ "# Add diagonal line (perfect prediction)\n",
+ "fig.add_trace(\n",
+ " go.Scatter(\n",
+ " x=[min_val, max_val],\n",
+ " y=[min_val, max_val],\n",
+ " mode=\"lines\",\n",
+ " line=dict(color=\"red\", dash=\"dash\"),\n",
+ " name=\"Perfect prediction\",\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "fig.update_layout(\n",
+ " xaxis_title=\"Actual values\",\n",
+ " yaxis_title=\"Imputed values\",\n",
+ " width=750,\n",
+ " height=600,\n",
+ " template=\"plotly_white\",\n",
+ ")\n",
+ "\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cell-12",
+ "metadata": {},
+ "source": [
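+ "This scatter plot compares actual observed values with those imputed by the MDN model. Each point represents a data record, with the x-axis showing the true value and the y-axis showing the model's prediction closest to the true value across all quantiles. The red dashed line represents the ideal 1:1 relationship. Points clustering around this line indicate that the MDN model effectively captures the underlying structure of the data. Note that because the plotted prediction is selected using the true value, this view is optimistic: it shows whether the true value lies near one of the predicted quantiles, not how accurate a single point prediction (such as the median) would be."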
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cell-13",
+ "metadata": {},
+ "source": [
+ "## Examining quantile predictions\n",
+ "\n",
+ "MDN provides predictions at different quantiles by sampling from the learned mixture distribution, allowing us to capture the entire conditional distribution of the missing values."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "cell-14",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Actual | \n",
+ " Q10 | \n",
+ " Q25 | \n",
+ " Q50 | \n",
+ " Q75 | \n",
+ " Q90 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 0.125019 | \n",
+ " -1.208731 | \n",
+ " -0.793742 | \n",
+ " -0.239280 | \n",
+ " 0.260678 | \n",
+ " 0.750109 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " -0.024960 | \n",
+ " -0.430616 | \n",
+ " -0.080356 | \n",
+ " 0.339311 | \n",
+ " 0.740931 | \n",
+ " 1.117335 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 0.103003 | \n",
+ " -1.449445 | \n",
+ " -0.907212 | \n",
+ " -0.276636 | \n",
+ " 0.273373 | \n",
+ " 0.844860 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 0.054845 | \n",
+ " -0.039490 | \n",
+ " 0.249726 | \n",
+ " 0.540901 | \n",
+ " 0.896468 | \n",
+ " 1.195148 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 0.038334 | \n",
+ " -0.671013 | \n",
+ " -0.092434 | \n",
+ " 0.460601 | \n",
+ " 1.070943 | \n",
+ " 1.577428 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Actual Q10 Q25 Q50 Q75 Q90\n",
+ "0 0.125019 -1.208731 -0.793742 -0.239280 0.260678 0.750109\n",
+ "1 -0.024960 -0.430616 -0.080356 0.339311 0.740931 1.117335\n",
+ "2 0.103003 -1.449445 -0.907212 -0.276636 0.273373 0.844860\n",
+ "3 0.054845 -0.039490 0.249726 0.540901 0.896468 1.195148\n",
+ "4 0.038334 -0.671013 -0.092434 0.460601 1.070943 1.577428"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Compare predictions at different quantiles for the first 5 records\n",
+ "quantiles_to_show = [0.1, 0.25, 0.5, 0.75, 0.9]\n",
+ "comparison_df = pd.DataFrame(index=range(5))\n",
+ "\n",
+ "# Add actual values\n",
+ "comparison_df[\"Actual\"] = actual_values.iloc[:5, 0].values\n",
+ "\n",
+ "# Add quantile predictions\n",
+ "for q in quantiles_to_show:\n",
+ " comparison_df[f\"Q{int(q*100)}\"] = imputed_values[q].iloc[:5, 0].values\n",
+ "\n",
+ "comparison_df"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cell-15",
+ "metadata": {},
+ "source": [
+ "## Visualizing prediction intervals\n",
+ "\n",
+ "By visualizing the prediction intervals we can better understand the uncertainty in our imputed values."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "cell-16",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(173, 216, 230, 0.3)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 0,
+ 0
+ ],
+ "y": [
+ -1.208730936050415,
+ 0.7501086592674255
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(173, 216, 230, 0.3)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 1,
+ 1
+ ],
+ "y": [
+ -0.4306159019470215,
+ 1.1173346042633057
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(173, 216, 230, 0.3)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 2,
+ 2
+ ],
+ "y": [
+ -1.4494445323944092,
+ 0.8448596596717834
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(173, 216, 230, 0.3)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 3,
+ 3
+ ],
+ "y": [
+ -0.03949018567800522,
+ 1.1951476335525513
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(173, 216, 230, 0.3)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 4,
+ 4
+ ],
+ "y": [
+ -0.6710134148597717,
+ 1.577427625656128
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(173, 216, 230, 0.3)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 5,
+ 5
+ ],
+ "y": [
+ -1.388054609298706,
+ 0.7144660353660583
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(173, 216, 230, 0.3)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 6,
+ 6
+ ],
+ "y": [
+ 0.9335368275642395,
+ 1.994225263595581
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(173, 216, 230, 0.3)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 7,
+ 7
+ ],
+ "y": [
+ -1.2012914419174194,
+ 0.6220993399620056
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(173, 216, 230, 0.3)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 8,
+ 8
+ ],
+ "y": [
+ 0.4870220720767975,
+ 1.1506199836730957
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(173, 216, 230, 0.3)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 9,
+ 9
+ ],
+ "y": [
+ -1.1105694770812988,
+ 0.017988024279475212
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(70, 130, 180, 0.5)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 0,
+ 0
+ ],
+ "y": [
+ -0.6731794476509094,
+ 0.15996141731739044
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(70, 130, 180, 0.5)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 1,
+ 1
+ ],
+ "y": [
+ 0.023270580917596817,
+ 0.658467710018158
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(70, 130, 180, 0.5)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 2,
+ 2
+ ],
+ "y": [
+ -0.769422709941864,
+ 0.17357832193374634
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(70, 130, 180, 0.5)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 3,
+ 3
+ ],
+ "y": [
+ 0.3068358898162842,
+ 0.8011546730995178
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(70, 130, 180, 0.5)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 4,
+ 4
+ ],
+ "y": [
+ 0.034883130341768265,
+ 0.9577965140342712
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(70, 130, 180, 0.5)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 5,
+ 5
+ ],
+ "y": [
+ -0.7857669591903687,
+ 0.1328243613243103
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(70, 130, 180, 0.5)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 6,
+ 6
+ ],
+ "y": [
+ 1.253185749053955,
+ 1.6908645629882812
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(70, 130, 180, 0.5)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 7,
+ 7
+ ],
+ "y": [
+ -0.629646360874176,
+ 0.1008240208029747
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(70, 130, 180, 0.5)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 8,
+ 8
+ ],
+ "y": [
+ 0.6771109104156494,
+ 0.9602589011192322
+ ]
+ },
+ {
+ "hoverinfo": "none",
+ "line": {
+ "color": "rgba(70, 130, 180, 0.5)",
+ "width": 10
+ },
+ "mode": "lines",
+ "showlegend": false,
+ "type": "scatter",
+ "x": [
+ 9,
+ 9
+ ],
+ "y": [
+ -0.7771245241165161,
+ -0.29959893226623535
+ ]
+ },
+ {
+ "marker": {
+ "color": "black",
+ "size": 8
+ },
+ "mode": "markers",
+ "name": "Actual",
+ "type": "scatter",
+ "x": [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9
+ ],
+ "y": [
+ 0.12501870313429186,
+ -0.0249601584096303,
+ 0.10300345740307394,
+ 0.05484510736603471,
+ 0.03833367306762126,
+ 0.09887559882847057,
+ 0.030077955918414535,
+ 0.10988322169407955,
+ -0.0249601584096303,
+ 0.03695772020942014
+ ]
+ },
+ {
+ "marker": {
+ "color": "red",
+ "size": 8
+ },
+ "mode": "markers",
+ "name": "Median (Q50)",
+ "type": "scatter",
+ "x": [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9
+ ],
+ "y": [
+ -0.23928046226501465,
+ 0.339310884475708,
+ -0.276636004447937,
+ 0.540900707244873,
+ 0.4606007933616638,
+ -0.34471726417541504,
+ 1.4770541191101074,
+ -0.30302998423576355,
+ 0.8187563419342041,
+ -0.5293239951133728
+ ]
+ },
+ {
+ "line": {
+ "color": "rgba(173, 216, 230, 0.3)",
+ "width": 10
+ },
+ "mode": "lines",
+ "name": "80% PI (Q10-Q90)",
+ "type": "scatter",
+ "x": [
+ -1,
+ -1
+ ],
+ "y": [
+ 0,
+ 0
+ ]
+ },
+ {
+ "line": {
+ "color": "rgba(70, 130, 180, 0.5)",
+ "width": 10
+ },
+ "mode": "lines",
+ "name": "40% PI (Q30-Q70)",
+ "type": "scatter",
+ "x": [
+ -1,
+ -1
+ ],
+ "y": [
+ 0,
+ 0
+ ]
+ }
+ ],
+ "layout": {
+ "height": 600,
+ "legend": {
+ "x": 0.99,
+ "xanchor": "right",
+ "y": 0.99,
+ "yanchor": "top"
+ },
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "white",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "white",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "#C8D4E3",
+ "linecolor": "#C8D4E3",
+ "minorgridcolor": "#C8D4E3",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "#C8D4E3",
+ "linecolor": "#C8D4E3",
+ "minorgridcolor": "#C8D4E3",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "white",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "#C8D4E3"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "white",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "#EBF0F8",
+ "linecolor": "#EBF0F8",
+ "ticks": ""
+ },
+ "bgcolor": "white",
+ "radialaxis": {
+ "gridcolor": "#EBF0F8",
+ "linecolor": "#EBF0F8",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "white",
+ "gridcolor": "#DFE8F3",
+ "gridwidth": 2,
+ "linecolor": "#EBF0F8",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "#EBF0F8"
+ },
+ "yaxis": {
+ "backgroundcolor": "white",
+ "gridcolor": "#DFE8F3",
+ "gridwidth": 2,
+ "linecolor": "#EBF0F8",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "#EBF0F8"
+ },
+ "zaxis": {
+ "backgroundcolor": "white",
+ "gridcolor": "#DFE8F3",
+ "gridwidth": 2,
+ "linecolor": "#EBF0F8",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "#EBF0F8"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "#DFE8F3",
+ "linecolor": "#A2B1C6",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "#DFE8F3",
+ "linecolor": "#A2B1C6",
+ "ticks": ""
+ },
+ "bgcolor": "white",
+ "caxis": {
+ "gridcolor": "#DFE8F3",
+ "linecolor": "#A2B1C6",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "#EBF0F8",
+ "linecolor": "#EBF0F8",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "#EBF0F8",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "#EBF0F8",
+ "linecolor": "#EBF0F8",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "#EBF0F8",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "MDN imputation prediction intervals"
+ },
+ "width": 750,
+ "xaxis": {
+ "showgrid": true,
+ "title": {
+ "text": "Data record index"
+ }
+ },
+ "yaxis": {
+ "showgrid": true,
+ "title": {
+ "text": "Value (s1)"
+ }
+ }
+ }
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Create prediction interval plot for first 10 records\n",
+ "n_records = 10\n",
+ "\n",
+ "# Prepare data\n",
+ "records = list(range(n_records))\n",
+ "actuals = actual_values.iloc[:n_records, 0].values\n",
+ "medians = imputed_values[0.5].iloc[:n_records, 0].values\n",
+ "q10 = imputed_values[0.1].iloc[:n_records, 0].values\n",
+ "q90 = imputed_values[0.9].iloc[:n_records, 0].values\n",
+ "q30 = imputed_values[0.3].iloc[:n_records, 0].values\n",
+ "q70 = imputed_values[0.7].iloc[:n_records, 0].values\n",
+ "\n",
+ "# Create figure\n",
+ "fig = go.Figure()\n",
+ "\n",
+ "# Add 80% prediction interval (Q10-Q90)\n",
+ "for i in range(n_records):\n",
+ " fig.add_trace(\n",
+ " go.Scatter(\n",
+ " x=[i, i],\n",
+ " y=[q10[i], q90[i]],\n",
+ " mode=\"lines\",\n",
+ " line=dict(width=10, color=\"rgba(173, 216, 230, 0.3)\"),\n",
+ " hoverinfo=\"none\",\n",
+ " showlegend=False,\n",
+ " )\n",
+ " )\n",
+ "\n",
+ "# Add 40% prediction interval (Q30-Q70)\n",
+ "for i in range(n_records):\n",
+ " fig.add_trace(\n",
+ " go.Scatter(\n",
+ " x=[i, i],\n",
+ " y=[q30[i], q70[i]],\n",
+ " mode=\"lines\",\n",
+ " line=dict(width=10, color=\"rgba(70, 130, 180, 0.5)\"),\n",
+ " hoverinfo=\"none\",\n",
+ " showlegend=False,\n",
+ " )\n",
+ " )\n",
+ "\n",
+ "# Add actual values\n",
+ "fig.add_trace(\n",
+ " go.Scatter(\n",
+ " x=records,\n",
+ " y=actuals,\n",
+ " mode=\"markers\",\n",
+ " marker=dict(color=\"black\", size=8),\n",
+ " name=\"Actual\",\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "# Add median predictions\n",
+ "fig.add_trace(\n",
+ " go.Scatter(\n",
+ " x=records,\n",
+ " y=medians,\n",
+ " mode=\"markers\",\n",
+ " marker=dict(color=\"red\", size=8),\n",
+ " name=\"Median (Q50)\",\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "# Legend entries for intervals\n",
+ "fig.add_trace(\n",
+ " go.Scatter(\n",
+ " x=[-1, -1], y=[0, 0],\n",
+ " mode=\"lines\",\n",
+ " line=dict(color=\"rgba(173, 216, 230, 0.3)\", width=10),\n",
+ " name=\"80% PI (Q10-Q90)\",\n",
+ " )\n",
+ ")\n",
+ "fig.add_trace(\n",
+ " go.Scatter(\n",
+ " x=[-1, -1], y=[0, 0],\n",
+ " mode=\"lines\",\n",
+ " line=dict(color=\"rgba(70, 130, 180, 0.5)\", width=10),\n",
+ " name=\"40% PI (Q30-Q70)\",\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "fig.update_layout(\n",
+ " title=\"MDN imputation prediction intervals\",\n",
+ " xaxis=dict(title=\"Data record index\", showgrid=True),\n",
+ " yaxis=dict(title=\"Value (s1)\", showgrid=True),\n",
+ " width=750,\n",
+ " height=600,\n",
+ " template=\"plotly_white\",\n",
+ " legend=dict(yanchor=\"top\", y=0.99, xanchor=\"right\", x=0.99),\n",
+ ")\n",
+ "\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cell-17",
+ "metadata": {},
+ "source": [
+ "This plot visualizes the prediction intervals produced by the MDN model. Each vertical bar represents an 80% (light blue, Q10-Q90) or 40% (darker blue, Q30-Q70) prediction interval. Red dots mark the model's median predictions (Q50), while black dots show the actual observed values. In most cases, the true values fall within the prediction intervals, indicating that the MDN model appropriately captures uncertainty in its imputation."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cell-18",
+ "metadata": {},
+ "source": [
+ "## Assessing the method's performance\n",
+ "\n",
+ "To verify our model doesn't overfit and ensure robust results, we perform cross-validation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "cell-19",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
+ "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 7.6s remaining: 11.5s\n",
+ "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 7.6s remaining: 5.1s\n",
+ "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 7.7s finished\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Quantile loss results:\n",
+ " 0.05 0.10 0.15 0.20 0.25 ... 0.75 0.80 0.85 0.90 0.95\n",
+ "train 0.052342 0.080491 0.100151 0.114462 0.124877 ... 0.116854 0.105562 0.090379 0.070140 0.043119\n",
+ "test 0.060333 0.088338 0.109102 0.123207 0.134303 ... 0.127896 0.114887 0.099006 0.077486 0.047284\n",
+ "\n",
+ "[2 rows x 19 columns]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Run cross-validation\n",
+ "mdn_results = cross_validate_model(\n",
+ " MDN, diabetes_df, predictors, imputed_variables,\n",
+ " model_hyperparams={\"layers\": \"64-32\", \"num_gaussian\": 3, \"max_epochs\": 30}\n",
+ ")\n",
+ "\n",
+ "if \"quantile_loss\" in mdn_results:\n",
+ " print(\"Quantile loss results:\")\n",
+ " print(mdn_results[\"quantile_loss\"][\"results\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "cell-20",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "marker": {
+ "color": "#00CC96"
+ },
+ "name": "Train",
+ "type": "bar",
+ "x": [
+ "0.05",
+ "0.1",
+ "0.15",
+ "0.2",
+ "0.25",
+ "0.3",
+ "0.35",
+ "0.4",
+ "0.45",
+ "0.5",
+ "0.55",
+ "0.6",
+ "0.65",
+ "0.7",
+ "0.75",
+ "0.8",
+ "0.85",
+ "0.9",
+ "0.95"
+ ],
+ "y": [
+ 0.05234219330354317,
+ 0.08049099191203442,
+ 0.10015070706030556,
+ 0.11446162331927275,
+ 0.12487722321825685,
+ 0.1324132843109902,
+ 0.13768463404454853,
+ 0.14056006658781792,
+ 0.14185292803744393,
+ 0.1415992977069213,
+ 0.14005457428331253,
+ 0.13682163676619796,
+ 0.13220755848574667,
+ 0.12584912155335878,
+ 0.11685420327991669,
+ 0.10556158423365573,
+ 0.09037880944328178,
+ 0.07014017488138428,
+ 0.04311920349617806
+ ]
+ },
+ {
+ "marker": {
+ "color": "#AB63FA"
+ },
+ "name": "Test",
+ "type": "bar",
+ "x": [
+ "0.05",
+ "0.1",
+ "0.15",
+ "0.2",
+ "0.25",
+ "0.3",
+ "0.35",
+ "0.4",
+ "0.45",
+ "0.5",
+ "0.55",
+ "0.6",
+ "0.65",
+ "0.7",
+ "0.75",
+ "0.8",
+ "0.85",
+ "0.9",
+ "0.95"
+ ],
+ "y": [
+ 0.060333009487974085,
+ 0.08833805422856528,
+ 0.10910236343419596,
+ 0.12320670350109023,
+ 0.13430269442274195,
+ 0.14292452496806163,
+ 0.14938276914149362,
+ 0.15449110540720765,
+ 0.1571806630200133,
+ 0.1578150543631282,
+ 0.15689612315245363,
+ 0.15328965501428699,
+ 0.14702280821997657,
+ 0.13845547628705684,
+ 0.12789622672114154,
+ 0.11488700831441748,
+ 0.099006403363506,
+ 0.07748605801852056,
+ 0.047284249954029064
+ ]
+ }
+ ],
+ "layout": {
+ "barmode": "group",
+ "height": 600,
+ "legend": {
+ "x": 0.99,
+ "xanchor": "right",
+ "y": 0.99,
+ "yanchor": "top"
+ },
+ "margin": {
+ "b": 50,
+ "l": 50,
+ "r": 50,
+ "t": 80
+ },
+ "paper_bgcolor": "#F0F0F0",
+ "plot_bgcolor": "#F0F0F0",
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "MDN cross-validation performance"
+ },
+ "width": 750,
+ "xaxis": {
+ "showgrid": false,
+ "title": {
+ "text": "Quantile"
+ },
+ "zeroline": false
+ },
+ "yaxis": {
+ "showgrid": false,
+ "title": {
+ "text": "Average Quantile Loss"
+ },
+ "zeroline": false
+ }
+ }
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Plot the results\n",
+ "if \"quantile_loss\" in mdn_results:\n",
+ " perf_results_viz = model_performance_results(\n",
+ " results=mdn_results[\"quantile_loss\"][\"results\"],\n",
+ " model_name=\"MDN\",\n",
+ " method_name=\"Cross-validation quantile loss average\",\n",
+ " )\n",
+ " fig = perf_results_viz.plot(\n",
+ " title=\"MDN cross-validation performance\",\n",
+ " )\n",
+ " fig.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cell-21",
+ "metadata": {},
+ "source": [
+ "## Hyperparameter tuning\n",
+ "\n",
+ "MDN supports automatic hyperparameter tuning using Optuna. Tuning searches over the number of Gaussian components and the learning rate to optimize performance on your specific dataset. To enable it, pass `tune_hyperparameters=True` to `fit()`; the call then returns two values: the fitted model and a dictionary containing the best configuration found.\n",
+ "\n",
+ "```\n",
+ "fitted_tuned, best_parameters = mdn_tuned.fit(\n",
+ " X_train,\n",
+ " predictors,\n",
+ " imputed_variables,\n",
+ " tune_hyperparameters=True,\n",
+ " n_trials=10,\n",
+ " cv_folds=3,\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "Note that tuning can take significant time, because multiple networks are trained while searching for the best hyperparameters."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cell-23",
+ "metadata": {},
+ "source": [
+ "## Categorical variable imputation\n",
+ "\n",
+ "MDN automatically handles categorical variables through neural classification. Let's evaluate its performance on categorical imputation tasks."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "cell-24",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Categorical variable distribution:\n",
+ "risk_level\n",
+ "low 148\n",
+ "high 148\n",
+ "medium 146\n",
+ "Name: count, dtype: int64\n",
+ "\n",
+ "Training set size: 353 records\n",
+ "Testing set size: 89 records\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Create a dataset with categorical variables\n",
+ "np.random.seed(42)\n",
+ "\n",
+ "df_categorical = pd.DataFrame()\n",
+ "df_categorical['age'] = df['age']\n",
+ "df_categorical['sex'] = df['sex']\n",
+ "df_categorical['bmi'] = df['bmi']\n",
+ "df_categorical['bp'] = df['bp']\n",
+ "df_categorical['risk_level'] = pd.qcut(\n",
+ " df['s1'], q=3, labels=['low', 'medium', 'high']\n",
+ ").astype(str)\n",
+ "\n",
+ "print(\"Categorical variable distribution:\")\n",
+ "print(df_categorical['risk_level'].value_counts())\n",
+ "\n",
+ "# Split the data\n",
+ "X_train_cat, X_test_cat = preprocess_data(df_categorical)\n",
+ "\n",
+ "print(f\"\\nTraining set size: {X_train_cat.shape[0]} records\")\n",
+ "print(f\"Testing set size: {X_test_cat.shape[0]} records\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "cell-25",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MDN model fitted for categorical variable imputation\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Fit MDN model for categorical imputation\n",
+ "predictors_cat = [\"age\", \"sex\", \"bmi\", \"bp\"]\n",
+ "imputed_variables_cat = [\"risk_level\"]\n",
+ "\n",
+ "mdn_cat_imputer = MDN(layers=\"64-32\", max_epochs=50)\n",
+ "fitted_mdn_cat = mdn_cat_imputer.fit(X_train_cat, predictors_cat, imputed_variables_cat)\n",
+ "\n",
+ "print(\"MDN model fitted for categorical variable imputation\")\n",
+ "\n",
+ "# Create test set with missing categorical values\n",
+ "X_test_cat_missing = X_test_cat.copy()\n",
+ "actual_cat_values = X_test_cat_missing[imputed_variables_cat].copy()\n",
+ "X_test_cat_missing[imputed_variables_cat] = np.nan\n",
+ "\n",
+ "# Impute the categorical values\n",
+ "imputed_cat_values = fitted_mdn_cat.predict(X_test_cat_missing, [0.5])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "cell-26",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Categorical imputation accuracy: 31.46%\n",
+ "\n",
+ "Confusion matrix:\n",
+ " Predicted: low Predicted: medium Predicted: high\n",
+ "Actual: low 10 8 12\n",
+ "Actual: medium 14 11 4\n",
+ "Actual: high 7 16 7\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Evaluate categorical imputation accuracy\n",
+ "from sklearn.metrics import accuracy_score, confusion_matrix\n",
+ "\n",
+ "predicted = imputed_cat_values[0.5]['risk_level'].values\n",
+ "actual = actual_cat_values['risk_level'].values\n",
+ "\n",
+ "accuracy = accuracy_score(actual, predicted)\n",
+ "print(f\"Categorical imputation accuracy: {accuracy:.2%}\")\n",
+ "\n",
+ "conf_matrix = pd.DataFrame(\n",
+ " confusion_matrix(actual, predicted),\n",
+ " index=['Actual: low', 'Actual: medium', 'Actual: high'],\n",
+ " columns=['Predicted: low', 'Predicted: medium', 'Predicted: high']\n",
+ ")\n",
+ "print(\"\\nConfusion matrix:\")\n",
+ "print(conf_matrix)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "cell-27",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
+ "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 0.8s remaining: 1.2s\n",
+ "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 0.9s remaining: 0.6s\n",
+ "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 0.9s finished\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Categorical imputation cross-validation results (log loss):\n",
+ "Mean train log loss: 0.9622\n",
+ "Mean test log loss: 1.0936\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Run cross-validation for categorical variables\n",
+ "mdn_categorical_results = cross_validate_model(\n",
+ " MDN, df_categorical, predictors_cat, imputed_variables_cat,\n",
+ " model_hyperparams={\"layers\": \"64-32\", \"max_epochs\": 30}\n",
+ ")\n",
+ "\n",
+ "print(\"Categorical imputation cross-validation results (log loss):\")\n",
+ "print(f\"Mean train log loss: {mdn_categorical_results['log_loss']['mean_train']:.4f}\")\n",
+ "print(f\"Mean test log loss: {mdn_categorical_results['log_loss']['mean_test']:.4f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "cell-28",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "marker": {
+ "color": "#00CC96"
+ },
+ "name": "Train",
+ "showlegend": true,
+ "type": "bar",
+ "x": [
+ "Train"
+ ],
+ "xaxis": "x",
+ "y": [
+ 0.9622434565816542
+ ],
+ "yaxis": "y"
+ },
+ {
+ "marker": {
+ "color": "#AB63FA"
+ },
+ "name": "Test",
+ "showlegend": true,
+ "type": "bar",
+ "x": [
+ "Test"
+ ],
+ "xaxis": "x",
+ "y": [
+ 1.093641015261423
+ ],
+ "yaxis": "y"
+ }
+ ],
+ "layout": {
+ "annotations": [
+ {
+ "font": {
+ "size": 16
+ },
+ "showarrow": false,
+ "text": "Log loss performance",
+ "x": 0.5,
+ "xanchor": "center",
+ "xref": "paper",
+ "y": 1,
+ "yanchor": "bottom",
+ "yref": "paper"
+ }
+ ],
+ "height": 420,
+ "paper_bgcolor": "#F0F0F0",
+ "plot_bgcolor": "#F0F0F0",
+ "showlegend": true,
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "MDN categorical imputation cross-validation performance"
+ },
+ "width": 750,
+ "xaxis": {
+ "anchor": "y",
+ "domain": [
+ 0,
+ 1
+ ]
+ },
+ "yaxis": {
+ "anchor": "x",
+ "domain": [
+ 0,
+ 1
+ ]
+ }
+ }
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Plot categorical performance\n",
+ "cat_perf_results_viz = model_performance_results(\n",
+ " results=mdn_categorical_results,\n",
+ " model_name=\"MDN\",\n",
+ " method_name=\"Cross-validation log loss average\",\n",
+ " metric=\"log_loss\",\n",
+ ")\n",
+ "fig = cat_perf_results_viz.plot(\n",
+ " title=\"MDN categorical imputation cross-validation performance\",\n",
+ ")\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cell-29",
+ "metadata": {},
+ "source": [
+ "## Model caching\n",
+ "\n",
+ "MDN automatically caches trained models based on a hash of the input data. When you fit the model on the same data again, it loads from cache rather than retraining, significantly speeding up repeated analyses. Use `force_retrain=True` to bypass the cache if needed."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pe3.13",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/models/ols/ols-imputation.ipynb b/docs/models/ols/ols-imputation.ipynb
index 6f192e0..0502f68 100644
--- a/docs/models/ols/ols-imputation.ipynb
+++ b/docs/models/ols/ols-imputation.ipynb
@@ -12,7 +12,54 @@
"\n",
"The OLS model intelligently handles both numerical and categorical variables. When imputing numerical targets, it uses standard linear regression. For categorical targets (strings, booleans, or numerically-encoded categorical variables), it automatically switches to logistic regression classification internally. You don't need to specify variable types—the model detects and adapts automatically.\n",
"\n",
- "The OLS model supports iterative imputation with a single object and workflow. Pass a list of `imputed_variables` with all variables that you hope to impute for and the model will do so without needing to fit and predict for each separately."
+ "The OLS model supports iterative imputation with a single object and workflow. Pass a list of `imputed_variables` containing every variable you want to impute, and the model will impute them all without needing to fit and predict for each separately.\n",
+ "### OLS class\n",
+ "\n",
+ "```python\n",
+ "class OLS(\n",
+ " log_level: Optional[str] = \"WARNING\"\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default used | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| log_level | str | \"WARNING\" | Logging verbosity level |\n",
+ "\n",
+ "### fit() method\n",
+ "\n",
+ "```python\n",
+ "def fit(\n",
+ " X_train: pd.DataFrame,\n",
+ " predictors: List[str],\n",
+ " imputed_variables: List[str],\n",
+ " weight_col: Optional[str] = None,\n",
+ ") -> OLSResults\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default used | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| X_train | pd.DataFrame | - | Training data with predictors and target variables |\n",
+ "| predictors | List[str] | - | Column names to use as predictors |\n",
+ "| imputed_variables | List[str] | - | Column names to impute |\n",
+ "| weight_col | str | None | Column name for sampling weights |\n",
+ "\n",
+ "It returns an `OLSResults` object for prediction. Internally, it uses linear regression for numerical targets and logistic regression for categorical/boolean targets.\n",
+ "\n",
+ "### OLSResults.predict() method\n",
+ "\n",
+ "```python\n",
+ "def predict(\n",
+ " X_test: pd.DataFrame,\n",
+ " quantiles: Optional[List[float]] = None\n",
+ ") -> Dict[float, pd.DataFrame]\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| X_test | pd.DataFrame | - | Data to impute (with predictors) |\n",
+ "| quantiles | List[float] | None | Quantiles at which to return predictions |\n",
+ "\n",
+ "It returns a dictionary mapping quantiles to DataFrames of imputed values. For numerical variables, quantile predictions are based on the normal distribution assumption. For categorical variables, predictions are sampled from the predicted probability distribution."
]
},
{
@@ -29,16 +76,18 @@
"outputs": [],
"source": [
"# Import necessary libraries\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "import logging\n",
+ "logging.getLogger(\"joblib\").setLevel(logging.ERROR)\n",
+ "\n",
"import pandas as pd\n",
"import numpy as np\n",
"import plotly.express as px\n",
"import plotly.graph_objects as go\n",
"from sklearn.datasets import load_diabetes\n",
- "import warnings\n",
- "\n",
- "warnings.filterwarnings(\"ignore\")\n",
"\n",
- "# Set pandas display options to limit table width\n",
"pd.set_option(\"display.width\", 600)\n",
"pd.set_option(\"display.max_columns\", 10)\n",
"pd.set_option(\"display.expand_frame_repr\", False)\n",
@@ -3554,7 +3603,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "This plot illustrates the prediction intervals generated by an OLS Linear Regression model for imputing total serum cholesterol values across ten records. For each observation, the red dot represents the median prediction (Q50), while the black dot indicates the true observed value. Vertical bars depict the model’s 40% prediction interval (Q30–Q70) in dark blue and the 80% prediction interval (Q10–Q90) in light blue. The intervals convey the model’s estimation of uncertainty, with wider intervals indicating less certainty about the imputed value. In some cases, the actual value falls within the 80% interval, suggesting that the OLS model is reasonably well-calibrated. However, the intervals tend to be vertically symmetric and relatively wide, sometimes missing the real values altogether. This reflects the linear nature of OLS: less responsive to local heteroskedasticity or skewness, and possibly limited in imputing power. Compared to Quantile Regression Forests, which can produce more adaptive and asymmetric intervals, the intervals here are more uniform in shape and spread. Overall, this plot shows that OLS is capable of performing fairly well on homocesdastic and simple linear datasets, though the fit may be quite limited in highly nonlinear settings."
+ "This plot illustrates the prediction intervals generated by an OLS Linear Regression model for imputing total serum cholesterol values across ten records. For each observation, the red dot represents the median prediction (Q50), while the black dot indicates the true observed value. Vertical bars depict the model’s 40% prediction interval (Q30–Q70) in dark gray and the 80% prediction interval (Q10–Q90) in light gray. The intervals convey the model’s estimation of uncertainty, with wider intervals indicating less certainty about the imputed value. In some cases, the actual value falls within the 80% interval, suggesting that the OLS model is reasonably well-calibrated. However, the intervals tend to be vertically symmetric and relatively wide, sometimes missing the real values altogether. This reflects the linear nature of OLS: less responsive to local heteroskedasticity or skewness, and possibly limited in imputing power. Compared to Quantile Regression Forests, which can produce more adaptive and asymmetric intervals, the intervals here are more uniform in shape and spread. Overall, this plot shows that OLS is capable of performing fairly well on homoskedastic and simple linear datasets, though the fit may be quite limited in highly nonlinear settings."
]
},
{
diff --git a/docs/models/qrf/index.md b/docs/models/qrf/index.md
index 43c093e..7e15838 100644
--- a/docs/models/qrf/index.md
+++ b/docs/models/qrf/index.md
@@ -4,7 +4,7 @@ The `QRF` model harnesses the power of ensemble learning by utilizing multiple d
## Variable type support
-QRF seamlessly handles both numerical and categorical variables. For numerical targets, it uses quantile regression forests. For categorical targets (strings, booleans, numerically-encoded categorical variables), it automatically employs a Random Forest Classifier. The model detects variable types automatically and applies the appropriate method internally, requiring no manual specification from users.
+QRF handles both numerical and categorical variables. For numerical targets, it uses quantile regression forests. For categorical targets (strings, booleans, numerically-encoded categorical variables), it automatically employs a Random Forest Classifier. The model detects variable types automatically and applies the appropriate method internally, requiring no manual specification from users.
## How it works
diff --git a/docs/models/qrf/qrf-imputation.ipynb b/docs/models/qrf/qrf-imputation.ipynb
index 9915680..c59a999 100644
--- a/docs/models/qrf/qrf-imputation.ipynb
+++ b/docs/models/qrf/qrf-imputation.ipynb
@@ -19,7 +19,73 @@
"1. **Variable 1**: Uses only the original predictors\n",
"2. **Variable 2**: Uses original predictors + Variable 1's imputed values \n",
"3. **Variable 3**: Uses original predictors + Variables 1 & 2's imputed values\n",
- "4. And so on..."
+ "4. And so on...\n",
+ "\n",
+ "### QRF class\n",
+ "\n",
+ "```python\n",
+ "class QRF(\n",
+ " log_level: Optional[str] = \"WARNING\",\n",
+ " memory_efficient: bool = False,\n",
+ " batch_size: Optional[int] = None,\n",
+ " cleanup_interval: int = 10\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default used | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| log_level | str | \"WARNING\" | Logging verbosity level |\n",
+ "| memory_efficient | bool | False | Enable memory-efficient mode for large datasets |\n",
+ "| batch_size | int | None | Batch size for memory-efficient processing |\n",
+ "| cleanup_interval | int | 10 | Interval for garbage collection in memory-efficient mode |\n",
+ "\n",
+ "### fit() method\n",
+ "\n",
+ "```python\n",
+ "def fit(\n",
+ " X_train: pd.DataFrame,\n",
+ " predictors: List[str],\n",
+ " imputed_variables: List[str],\n",
+ " weight_col: Optional[str] = None,\n",
+ " tune_hyperparameters: bool = False,\n",
+ " n_estimators: int = 100,\n",
+ " max_depth: Optional[int] = None,\n",
+ " min_samples_split: int = 2,\n",
+ " min_samples_leaf: int = 1,\n",
+ " max_features: Union[str, float] = 1.0,\n",
+ ") -> Union[QRFResults, Tuple[QRFResults, Dict]]\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default used | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| X_train | pd.DataFrame | - | Training data with predictors and target variables |\n",
+ "| predictors | List[str] | - | Column names to use as predictors |\n",
+ "| imputed_variables | List[str] | - | Column names to impute |\n",
+ "| weight_col | str | None | Column name for sampling weights |\n",
+ "| tune_hyperparameters | bool | False | Enable automatic hyperparameter tuning |\n",
+ "| n_estimators | int | 100 | Number of trees in the forest |\n",
+ "| max_depth | int | None | Maximum depth of trees (None = unlimited) |\n",
+ "| min_samples_split | int | 2 | Minimum samples to split a node |\n",
+ "| min_samples_leaf | int | 1 | Minimum samples at a leaf node |\n",
+ "| max_features | str or float | 1.0 | Features to consider for best split |\n",
+ "\n",
+ "It returns a `QRFResults` object, or a tuple of the `QRFResults` object and a dictionary of the best hyperparameters if tuning is enabled.\n",
+ "\n",
+ "### QRFResults.predict() method\n",
+ "\n",
+ "```python\n",
+ "def predict(\n",
+ " X_test: pd.DataFrame,\n",
+ " quantiles: Optional[List[float]] = None\n",
+ ") -> Dict[float, pd.DataFrame]\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default used | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| X_test | pd.DataFrame | - | Data to impute (with predictors) |\n",
+ "| quantiles | List[float] | None | Quantiles at which to return predictions |\n",
+ "\n",
+ "It returns a dictionary mapping quantiles to DataFrames of imputed values. QRF learns the full conditional distribution empirically, providing flexible quantile predictions that can capture skewness and heteroskedasticity."
]
},
{
@@ -33,28 +99,21 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <21BE8260-F4D5-3597-9DD0-6953BC4DDF3D> /Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n",
- "Trying to import in ABI mode.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"# Import necessary libraries\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "import logging\n",
+ "logging.getLogger(\"joblib\").setLevel(logging.ERROR)\n",
+ "\n",
"import pandas as pd\n",
"import numpy as np\n",
"import plotly.express as px\n",
"import plotly.graph_objects as go\n",
"from sklearn.datasets import load_diabetes\n",
- "import warnings\n",
- "\n",
- "warnings.filterwarnings(\"ignore\")\n",
"\n",
- "# Set pandas display options to limit table width\n",
"pd.set_option(\"display.width\", 600)\n",
"pd.set_option(\"display.max_columns\", 10)\n",
"pd.set_option(\"display.expand_frame_repr\", False)\n",
@@ -3577,7 +3636,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "This plot visualizes the prediction intervals produced by the Quantile Random Forest (QRF) model for imputing total serum cholesterol values across ten data records. Each vertical bar represents an 80% (light blue) or 40% (dark blue) prediction interval, capturing the model's estimated range of plausible values based on the Q10–Q90 and Q30–Q70 quantiles, respectively. Red dots mark the model's median predictions (Q50), while black dots show the actual observed values. In most cases, the true values fall within the wider intervals, indicating that the QRF model is appropriately capturing uncertainty in its imputation. The fact that the intervals are sometimes asymmetrical around the median reflects the model’s flexibility in estimating skewed or heteroskedastic distributions. Overall, the plot demonstrates that the QRF model not only provides accurate point estimates but also yields informative prediction intervals that account for uncertainty in the imputed values."
+ "This plot visualizes the prediction intervals produced by the Quantile Regression Forest (QRF) model for imputing total serum cholesterol values across ten data records. Each vertical bar represents an 80% (light gray) or 40% (dark gray) prediction interval, capturing the model's estimated range of plausible values based on the Q10–Q90 and Q30–Q70 quantiles, respectively. Red dots mark the model's median predictions (Q50), while black dots show the actual observed values. In most cases, the true values fall within the wider intervals, indicating that the QRF model is appropriately capturing uncertainty in its imputation. The fact that the intervals are sometimes asymmetrical around the median reflects the model’s flexibility in estimating skewed or heteroskedastic distributions. Overall, the plot demonstrates that the QRF model not only provides accurate point estimates but also yields informative prediction intervals that account for uncertainty in the imputed values."
]
},
{
diff --git a/docs/models/quantreg/quantreg-imputation.ipynb b/docs/models/quantreg/quantreg-imputation.ipynb
index 044e202..79308a9 100644
--- a/docs/models/quantreg/quantreg-imputation.ipynb
+++ b/docs/models/quantreg/quantreg-imputation.ipynb
@@ -12,7 +12,57 @@
"\n",
"QuantReg supports numerical variables only. The model will raise an error if you attempt to impute categorical variables (including strings, and numerically-encoded categorical variables). For datasets with categorical targets, use OLS or QRF models which automatically handle both types through internal classification.\n",
"\n",
- "The QuantReg model supports iterative imputation with a single object and workflow. Pass a list of `imputed_variables` with all numerical variables that you hope to impute for and the model will do so without needing to fit and predict for each separately."
+ "The QuantReg model supports iterative imputation with a single object and workflow. Pass a list of `imputed_variables` containing all the numerical variables you want to impute, and the model will impute them without needing a separate fit and predict call for each.\n",
+ "\n",
+ "### QuantReg class\n",
+ "\n",
+ "```python\n",
+ "class QuantReg(\n",
+ " log_level: Optional[str] = \"WARNING\"\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default used | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| log_level | str | \"WARNING\" | Logging verbosity level |\n",
+ "\n",
+ "### fit() method\n",
+ "\n",
+ "```python\n",
+ "def fit(\n",
+ " X_train: pd.DataFrame,\n",
+ " predictors: List[str],\n",
+ " imputed_variables: List[str],\n",
+ " weight_col: Optional[str] = None,\n",
+ " quantiles: Optional[List[float]] = QUANTILES,\n",
+ ") -> QuantRegResults\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default used | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| X_train | pd.DataFrame | - | Training data with predictors and target variables |\n",
+ "| predictors | List[str] | - | Column names to use as predictors |\n",
+ "| imputed_variables | List[str] | - | Column names to impute (numerical only) |\n",
+ "| weight_col | str | None | Column name for sampling weights |\n",
+ "| quantiles | List[float] | [0.05 to 0.95 in steps of 0.05] | Quantiles to fit separate regression models for |\n",
+ "\n",
+ "It returns a `QuantRegResults` object for prediction. A separate linear quantile regression model is fit for each specified quantile.\n",
+ "\n",
+ "### QuantRegResults.predict() method\n",
+ "\n",
+ "```python\n",
+ "def predict(\n",
+ " X_test: pd.DataFrame,\n",
+ " quantiles: Optional[List[float]] = None\n",
+ ") -> Dict[float, pd.DataFrame]\n",
+ "```\n",
+ "\n",
+ "| Parameter | Type | Default used | Description |\n",
+ "|-----------|------|---------|-------------|\n",
+ "| X_test | pd.DataFrame | - | Data to impute (with predictors) |\n",
+ "| quantiles | List[float] | None | Quantiles at which to return predictions |\n",
+ "\n",
+ "It returns a dictionary mapping quantiles to DataFrames of imputed values. QuantReg directly models each conditional quantile, making no distributional assumptions beyond linearity at each quantile."
]
},
{
@@ -26,28 +76,22 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <21BE8260-F4D5-3597-9DD0-6953BC4DDF3D> /Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n",
- "Trying to import in ABI mode.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"# Import necessary libraries\n",
+ "\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "import logging\n",
+ "logging.getLogger(\"joblib\").setLevel(logging.ERROR)\n",
+ "\n",
"import pandas as pd\n",
"import numpy as np\n",
"import plotly.express as px\n",
"import plotly.graph_objects as go\n",
"from sklearn.datasets import load_diabetes\n",
- "import warnings\n",
- "\n",
- "warnings.filterwarnings(\"ignore\")\n",
"\n",
- "# Set pandas display options to limit table width\n",
"pd.set_option(\"display.width\", 600)\n",
"pd.set_option(\"display.max_columns\", 10)\n",
"pd.set_option(\"display.expand_frame_repr\", False)\n",
@@ -3567,7 +3611,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "This plot illustrates the prediction intervals generated by a Quantile Regression (QuantReg) model for imputing total serum cholesterol values across ten records. For each observation, the red dot represents the median prediction (Q50), while the black dot indicates the true observed value. Vertical bars depict the model’s 40% prediction interval (Q30–Q70) in dark blue and the 80% prediction interval (Q10–Q90) in light blue. The intervals convey the model’s estimation of uncertainty, with wider intervals indicating less certainty about the imputed value. In many cases, the actual value falls within the 80% interval, suggesting that the QuantReg model is reasonably well-calibrated. However, the intervals tend to be vertically symmetric and relatively wide, which reflects the linear nature of Quantile Regression: less responsive to local heteroskedasticity or skewness. Compared to Quantile Regression Forests, which can produce more adaptive and asymmetric intervals, the intervals here are more uniform in shape and spread. Overall, this plot shows that QuantReg is capable of capturing uncertainty around its median predictions, though the fit may be somewhat conservative or limited in highly nonlinear settings."
+ "This plot illustrates the prediction intervals generated by a Quantile Regression (QuantReg) model for imputing total serum cholesterol values across ten records. For each observation, the red dot represents the median prediction (Q50), while the black dot indicates the true observed value. Vertical bars depict the model’s 40% prediction interval (Q30–Q70) in dark gray and the 80% prediction interval (Q10–Q90) in light gray. The intervals convey the model’s estimation of uncertainty, with wider intervals indicating less certainty about the imputed value. In many cases, the actual value falls within the 80% interval, suggesting that the QuantReg model is reasonably well-calibrated. However, the intervals tend to be vertically symmetric and relatively wide, which reflects the linear nature of Quantile Regression: less responsive to local heteroskedasticity or skewness. Compared to Quantile Regression Forests, which can produce more adaptive and asymmetric intervals, the intervals here are more uniform in shape and spread. Overall, this plot shows that QuantReg is capable of capturing uncertainty around its median predictions, though the fit may be somewhat conservative or limited in highly nonlinear settings."
]
},
{
diff --git a/docs/imputation-benchmarking/imputing-across-surveys.md b/docs/use_cases/index.md
similarity index 89%
rename from docs/imputation-benchmarking/imputing-across-surveys.md
rename to docs/use_cases/index.md
index 3f8b6c4..b7ea6d4 100644
--- a/docs/imputation-benchmarking/imputing-across-surveys.md
+++ b/docs/use_cases/index.md
@@ -1,8 +1,8 @@
-# Imputing across surveys
+# Imputing full variables across surveys
This document explains what the workflow for imputing variables from one survey to another using Microimpute may look like. We'll use the example of imputing wealth data from the Survey of Consumer Finances (SCF) into the Current Population Survey (CPS).
- ## Identifying receiver and donor datasets
+## Identifying receiver and donor datasets
The first step is to identify your donor and receiver datasets. The Donor dataset is that containing the variable you want to impute (e.g., SCF contains wealth data). The Receiver dataset will receive the imputed variable (e.g., CPS which originally did not contain wealth data but will after our imputation is completed). It is important for these two datasets to have predictor variables in common for the imputation to be succcesful. For example, both the SCF and CPS surveys contain demographic and financial data that may help us understand how wealth values may be distributed.
@@ -20,7 +20,7 @@ cps_data = pd.read_csv("cps_data.csv")
## Cleaning and aligning variables
Before imputation, you need to ensure both datasets have compatible variables. Identify common variables present in both datasets
-and standardize their variable formats, units, and categories so that Python can recognize they indeed represent the same the same data. Remember to also handle missing values in common variables. Lastly, identify the target variables in the donor dataset which will directly inform the values of the imputed variables in the receiver dataset.
+and standardize their variable formats, units, and categories so that Python can recognize they indeed represent the same data. Remember to also handle missing values in common variables. Lastly, identify the target variables in the donor dataset which will directly inform the values of the imputed variables in the receiver dataset. For details on data preprocessing options, refer to the [Data preprocessing page](./preprocessing.md).
```python
# Identify common variables
@@ -49,7 +49,7 @@ target_variable = ['networth']
## Performing imputation
-Microimpute offers several methods for imputation across surveys. The approach under the hood will differ based on the method chosen, although the workflow will remain constant. Let us see this for two different example methods.
+Microimpute offers several methods for imputation across surveys, which are described in the [Models Chapter](../models). The approach under the hood will differ based on the method chosen, although the workflow will remain constant. Let us see this for two different example methods.
### Matching imputation
@@ -89,7 +89,7 @@ cps_data_with_wealth_ols = ols_imputer.impute(cps_data)
## Evaluating imputation quality
-Evaluating imputation quality across surveys can be challenging since the true values aren't known in the receiver dataset. Comparing the distribution of the target variables in the donor dataset to the distribution of the variables we imputed in the receiver dataset may give us an understanding of the imputation quality for different sections of the distribution. We may want to pay particular attention to obtaining accurate prediction not only for mean or median values but also look at the performance at the distribution tails. This can be achieved computing the quantile loss supported by Microimpute. Additionally, if we have performed imputation accross multiple methods we may want to compare across them. Microimpute supports this through the easy workflow described in the [benchmarking-methods.ipynb](./benchmarking-methods.ipynb) file.
+Evaluating imputation quality across surveys can be challenging since the true values aren't known in the receiver dataset. Comparing the distribution of the target variables in the donor dataset to the distribution of the variables we imputed in the receiver dataset may give us an understanding of the imputation quality for different sections of the distribution. We may want to pay particular attention to obtaining accurate predictions not only for mean or median values but also at the distribution tails. This can be achieved by computing the quantile loss supported by Microimpute. Additionally, if we have performed imputation across multiple methods, we may want to compare them. Microimpute supports this through multiple easy-to-use metrics described in the [Metrics page](./metrics.md).
```python
# Ensure all imputations are in a dictionary mapping quantiles to dataframes containing imputed values
@@ -117,4 +117,4 @@ final_imputed_dataset.to_csv("cps_with_imputed_wealth.csv", index=False)
## Key considerations
-Model selection plays a critical role in this workflow because different imputation methods have unique strengths. For example, a Quantile Regression Forest (QRF) often performs better when capturing complex relationships between variables, while a Matching approach may be more effective at preserving the original distributional properties of the data. Additionally, not all models can impute categorical data. For example, atching is able to match any value regardless of its data type, but QuantReg does not support categorical imputation. OLS and QRF will use logistic regression and random forest classification methods under the hood, respectively. Variable selection is equally important, since the common predictors across datasets should have strong power for explaining the target variable to ensure a reliable imputation. Because the ground truth is typically unknown in the receiver dataset, validation can involve simulation studies or comparing imputed values against known aggregate statistics. Finally, it is crucial to maintain documentation of the imputation process, from the choice of model to any adjustments made, so that the analysis remains transparent and reproducible for others.
+Model selection plays a critical role in this workflow because different imputation methods have unique strengths. For example, a Quantile Regression Forest (QRF) often performs better when capturing complex relationships between variables, while a Matching approach may be more effective at preserving the original distributional properties of the data. Additionally, not all models can impute categorical data. For example, Matching is able to match any value regardless of its data type, but QuantReg does not support categorical imputation. OLS and QRF will use logistic regression and random forest classification methods under the hood, respectively. Variable selection is equally important, since the common predictors across datasets should have strong power for explaining the target variable to ensure a reliable imputation. Because the ground truth is typically unknown in the receiver dataset, validation can involve simulation studies or comparing imputed values against known aggregate statistics. Finally, it is crucial to maintain documentation of the imputation process, from the choice of model to any adjustments made, so that the analysis remains transparent and reproducible for others. For the full pipeline details on SCF-to-CPS net worth imputation, refer to the following [notebook](./scf_to_cps/imputing-from-scf-to-cps.md).
diff --git a/docs/examples/scf_to_cps/autoimpute_best_model_imputations.png b/docs/use_cases/scf_to_cps/autoimpute_best_model_imputations.png
similarity index 100%
rename from docs/examples/scf_to_cps/autoimpute_best_model_imputations.png
rename to docs/use_cases/scf_to_cps/autoimpute_best_model_imputations.png
diff --git a/docs/examples/scf_to_cps/autoimpute_model_comparison.png b/docs/use_cases/scf_to_cps/autoimpute_model_comparison.png
similarity index 100%
rename from docs/examples/scf_to_cps/autoimpute_model_comparison.png
rename to docs/use_cases/scf_to_cps/autoimpute_model_comparison.png
diff --git a/docs/examples/scf_to_cps/by_income_decile_comparisons.png b/docs/use_cases/scf_to_cps/by_income_decile_comparisons.png
similarity index 100%
rename from docs/examples/scf_to_cps/by_income_decile_comparisons.png
rename to docs/use_cases/scf_to_cps/by_income_decile_comparisons.png
diff --git a/docs/examples/scf_to_cps/imputations_model_comparison.png b/docs/use_cases/scf_to_cps/imputations_model_comparison.png
similarity index 100%
rename from docs/examples/scf_to_cps/imputations_model_comparison.png
rename to docs/use_cases/scf_to_cps/imputations_model_comparison.png
diff --git a/docs/examples/scf_to_cps/imputing-from-scf-to-cps.md b/docs/use_cases/scf_to_cps/imputing-from-scf-to-cps.md
similarity index 99%
rename from docs/examples/scf_to_cps/imputing-from-scf-to-cps.md
rename to docs/use_cases/scf_to_cps/imputing-from-scf-to-cps.md
index eb24d54..18e09c7 100644
--- a/docs/examples/scf_to_cps/imputing-from-scf-to-cps.md
+++ b/docs/use_cases/scf_to_cps/imputing-from-scf-to-cps.md
@@ -1,6 +1,6 @@
-# Imputing wealth from the SCF to the CPS
+# Example: imputing wealth from the SCF to the CPS
-This notebook demonstrates how to use the `microimpute` package and specifically the `autoimpute` function to impute wealth variables from the Survey of Consumer Finances to the Current Population Survey.
+This notebook demonstrates a full pipeline powered by the `microimpute` package and specifically the `autoimpute` function used to impute wealth variables from the Survey of Consumer Finances to the Current Population Survey.
The Survey of Consumer Finances (SCF) is a triennial survey conducted by the Federal Reserve that collects detailed information on U.S. families' balance sheets, income, and demographic characteristics, with a special focus on wealth measures. The Current Population Survey (CPS) is a monthly survey conducted by the Census Bureau that provides comprehensive data on the labor force, employment, unemployment, and demographic characteristics, but lacks detailed wealth information.
diff --git a/microimpute/comparisons/autoimpute.py b/microimpute/comparisons/autoimpute.py
index c561542..e5ca952 100644
--- a/microimpute/comparisons/autoimpute.py
+++ b/microimpute/comparisons/autoimpute.py
@@ -27,7 +27,11 @@
VALIDATE_CONFIG,
)
from microimpute.models import OLS, QRF, Imputer, QuantReg
-from microimpute.utils.data import unnormalize_predictions
+from microimpute.utils.data import (
+ un_asinh_transform_predictions,
+ unlog_transform_predictions,
+ unnormalize_predictions,
+)
from microimpute.utils.type_handling import VariableTypeDetector
try:
@@ -47,6 +51,57 @@
log = logging.getLogger(__name__)
+
+def _reverse_transformations(
+ imputations: Dict[float, pd.DataFrame],
+ transform_params: Optional[Dict[str, Any]],
+) -> Dict[float, pd.DataFrame]:
+ """Reverse preprocessing transformations on imputed predictions.
+
+ Args:
+ imputations: Dict mapping quantiles to DataFrames of predictions.
+ transform_params: Dict with 'type' and 'params' from prepare_data_for_imputation.
+
+ Returns:
+ Dict with same structure but with reversed transformations.
+ """
+ if not transform_params:
+ return imputations
+
+ transform_type = transform_params.get("type")
+ params = transform_params.get("params", {})
+
+ if transform_type == "normalize":
+ # Legacy normalize_data=True format
+ return unnormalize_predictions(imputations, params)
+
+ elif transform_type == "preprocessing":
+ # New preprocessing format with multiple transformation types
+ result = imputations
+
+ # Reverse normalization if any
+ if params.get("normalization"):
+ result = unnormalize_predictions(result, params["normalization"])
+
+ # Reverse log transform if any
+ if params.get("log_transform"):
+ result = unlog_transform_predictions(
+ result, params["log_transform"]
+ )
+
+ # Reverse asinh transform if any
+ if params.get("asinh_transform"):
+ result = un_asinh_transform_predictions(
+ result, params["asinh_transform"]
+ )
+
+ return result
+
+ else:
+ log.warning(f"Unknown transform type: {transform_type}")
+ return imputations
+
+
# Internal constants for model compatibility with variable types
_NUMERICAL_MODELS = {"OLS", "QRF", "QuantReg", "Matching", "MDN"}
_CATEGORICAL_MODELS = {
@@ -238,11 +293,11 @@ def _generate_imputations_for_all_models(
imputed_variables: List[str],
weight_col: Optional[str],
imputation_q: float,
- normalize_data: bool,
train_size: float,
tune_hyperparameters: bool,
hyperparams: Optional[Dict[str, Any]],
log_level: str,
+ preprocessing: Optional[Dict[str, str]] = None,
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, Any]]:
"""Generate imputations for all models when impute_all=True.
@@ -274,16 +329,16 @@ def _generate_imputations_for_all_models(
log.info(f"Generating imputations with {model_name}.")
# Preprocess data fresh for this model
- training_data, imputing_data, normalizing_params = (
+ training_data, imputing_data, transform_params = (
prepare_data_for_imputation(
donor_data,
receiver_data,
predictors,
imputed_variables,
weight_col,
- normalize_data,
train_size,
1 - train_size,
+ preprocessing=preprocessing,
)
)
@@ -305,13 +360,10 @@ def _generate_imputations_for_all_models(
log_level,
)
- # Unnormalize if needed
- if normalize_data and normalizing_params:
- final_imputations = unnormalize_predictions(
- imputations, normalizing_params
- )
- else:
- final_imputations = imputations
+ # Reverse transformations if needed
+ final_imputations = _reverse_transformations(
+ imputations, transform_params
+ )
final_imputations_dict[model_name] = final_imputations[imputation_q]
fitted_models_dict[model_name] = fitted_model
@@ -330,7 +382,7 @@ def autoimpute(
imputation_quantiles: Optional[List[float]] = None,
hyperparameters: Optional[Dict[str, Dict[str, Any]]] = None,
tune_hyperparameters: Optional[bool] = False,
- normalize_data: Optional[bool] = False,
+ preprocessing: Optional[Dict[str, str]] = None,
impute_all: Optional[bool] = False,
metric_priority: Optional[str] = "auto",
random_state: Optional[int] = RANDOM_STATE,
@@ -363,7 +415,13 @@ def autoimpute(
with model names as keys. Defaults to None and uses default model hyperparameters then.
tune_hyperparameters : Whether to tune hyperparameters for the models.
Defaults to False.
- normalize_data : If True, will normalize the data before imputation.
+ preprocessing : Dictionary mapping variable names (predictors or imputed_variables)
+ to transformation type. Supported transformations:
+ - "normalize": z-score normalization (mean=0, std=1)
+ - "log": natural log transformation (requires positive values)
+ - "asinh": inverse hyperbolic sine transformation (handles zero/negative values)
+ Example: {"income": "asinh", "age": "normalize"}
+ If a variable is not in this dict, no transformation is applied.
impute_all : If True, will return final imputations for all models not
just the best one.
metric_priority : Strategy for model selection when both metrics are present:
@@ -429,16 +487,16 @@ def autoimpute(
# Keep track of original imputed variable names
original_imputed_variables = imputed_variables.copy()
- training_data, imputing_data, normalizing_params = (
+ training_data, imputing_data, transform_params = (
prepare_data_for_imputation(
donor_data,
receiver_data,
predictors,
imputed_variables,
weight_col,
- normalize_data,
train_size,
1 - train_size,
+ preprocessing=preprocessing,
)
)
@@ -556,13 +614,10 @@ def autoimpute(
log_level,
)
- # Unnormalize if needed
- if normalize_data and normalizing_params:
- final_imputations = unnormalize_predictions(
- imputations, normalizing_params
- )
- else:
- final_imputations = imputations
+ # Reverse transformations if needed
+ final_imputations = _reverse_transformations(
+ imputations, transform_params
+ )
log.info(
f"Imputation generation completed for {len(receiver_data)} samples "
@@ -601,11 +656,11 @@ def autoimpute(
original_imputed_variables,
weight_col,
imputation_q,
- normalize_data,
train_size,
tune_hyperparameters,
best_hyperparams,
log_level,
+ preprocessing=preprocessing,
)
)
final_imputations_dict.update(other_imputations)
diff --git a/microimpute/comparisons/autoimpute_helpers.py b/microimpute/comparisons/autoimpute_helpers.py
index 1df9307..9723817 100644
--- a/microimpute/comparisons/autoimpute_helpers.py
+++ b/microimpute/comparisons/autoimpute_helpers.py
@@ -85,11 +85,11 @@ def prepare_data_for_imputation(
predictors: List[str],
imputed_variables: List[str],
weight_col: Optional[str],
- normalize_data: bool,
train_size: float,
test_size: float,
+ preprocessing: Optional[Dict[str, str]] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[dict]]:
- """Prepare training and imputing data, optionally with normalization.
+ """Prepare training and imputing data, optionally with transformations.
Args:
donor_data: Original donor data.
@@ -97,12 +97,15 @@ def prepare_data_for_imputation(
predictors: Predictor columns.
imputed_variables: Variables to impute.
weight_col: Optional weight column.
- normalize_data: Whether to normalize.
train_size: Training data proportion.
test_size: Test data proportion.
+ preprocessing: Dict mapping variable names to transformation type.
+ Supported: "normalize", "log", "asinh". If None, no transformation.
Returns:
- Tuple of (training_data, imputing_data, normalization_params or None)
+ Tuple of (training_data, imputing_data, transform_params or None)
+ transform_params contains info needed to reverse transformations on
+ imputed variables.
"""
# Remove imputed variables from receiver if present
receiver_data = receiver_data.drop(
@@ -112,43 +115,106 @@ def prepare_data_for_imputation(
training_data = donor_data.copy()
imputing_data = receiver_data.copy()
- if normalize_data:
- # Normalize predictors and imputed variables together for consistency
+ if preprocessing:
all_training_cols = predictors + imputed_variables
- normalized_training, norm_params = preprocess_data(
+ all_cols = set(all_training_cols)
+
+ # Validate preprocessing keys
+ invalid_cols = set(preprocessing.keys()) - all_cols
+ if invalid_cols:
+ error_msg = (
+ f"Preprocessing specified for unknown columns: {invalid_cols}. "
+ f"Valid columns are: {all_cols}"
+ )
+ log.error(error_msg)
+ raise ValueError(error_msg)
+
+ # Validate transformation types
+ valid_transforms = {"normalize", "log", "asinh"}
+ for col, transform in preprocessing.items():
+ if transform not in valid_transforms:
+ error_msg = (
+ f"Invalid transformation '{transform}' for column '{col}'. "
+ f"Valid transformations are: {valid_transforms}"
+ )
+ log.error(error_msg)
+ raise ValueError(error_msg)
+
+ # Group columns by transformation type
+ normalize_cols = [
+ col for col, t in preprocessing.items() if t == "normalize"
+ ]
+ log_cols = [col for col, t in preprocessing.items() if t == "log"]
+ asinh_cols = [col for col, t in preprocessing.items() if t == "asinh"]
+
+ # Apply transformations to training data
+ transformed_training, transform_result = preprocess_data(
training_data[all_training_cols],
full_data=True,
train_size=train_size,
test_size=test_size,
- normalize=True,
+ normalize=normalize_cols if normalize_cols else False,
+ log_transform=log_cols if log_cols else False,
+ asinh_transform=asinh_cols if asinh_cols else False,
)
- # Normalize imputing data predictors using same parameters
- imputing_predictors, _ = preprocess_data(
+ # Apply same transformations to predictors in imputing data
+ predictor_normalize = [c for c in normalize_cols if c in predictors]
+ predictor_log = [c for c in log_cols if c in predictors]
+ predictor_asinh = [c for c in asinh_cols if c in predictors]
+
+ transformed_imputing, _ = preprocess_data(
imputing_data[predictors],
full_data=True,
train_size=train_size,
test_size=test_size,
- normalize=True,
+ normalize=predictor_normalize if predictor_normalize else False,
+ log_transform=predictor_log if predictor_log else False,
+ asinh_transform=predictor_asinh if predictor_asinh else False,
)
- # Reconstruct training data with normalized values
- training_data = normalized_training
+ training_data = transformed_training
if weight_col:
training_data[weight_col] = donor_data[weight_col]
- imputing_data = imputing_predictors
-
- # Extract normalization params only for imputed variables
- imputed_norm_params = {
- col: norm_params[col]
- for col in imputed_variables
- if col in norm_params
+ imputing_data = transformed_imputing
+
+ # Extract transform params only for imputed variables
+ imputed_transform_params = {
+ "normalization": {
+ col: transform_result["normalization"].get(col, {})
+ for col in imputed_variables
+ if col in transform_result.get("normalization", {})
+ },
+ "log_transform": {
+ col: transform_result["log_transform"].get(col, {})
+ for col in imputed_variables
+ if col in transform_result.get("log_transform", {})
+ },
+ "asinh_transform": {
+ col: transform_result["asinh_transform"].get(col, {})
+ for col in imputed_variables
+ if col in transform_result.get("asinh_transform", {})
+ },
}
- return training_data, imputing_data, imputed_norm_params
+ # Only return params if there are transformations to reverse
+ has_transforms = any(
+ imputed_transform_params[key]
+ for key in ["normalization", "log_transform", "asinh_transform"]
+ )
+
+ if has_transforms:
+ transform_params = {
+ "type": "preprocessing",
+ "params": imputed_transform_params,
+ }
+ return training_data, imputing_data, transform_params
+ else:
+ return training_data, imputing_data, None
+
else:
- # No normalization needed
+ # No transformation needed
training_data = preprocess_data(
training_data[predictors + imputed_variables],
full_data=True,
diff --git a/microimpute/models/mdn.py b/microimpute/models/mdn.py
index d2fa6d4..09d4e00 100644
--- a/microimpute/models/mdn.py
+++ b/microimpute/models/mdn.py
@@ -1,6 +1,7 @@
"""Mixture Density Network (MDN) imputation model using PyTorch Tabular."""
import hashlib
+import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -19,6 +20,20 @@
# PyTorch Tabular imports
try:
+ # Set environment variables to suppress logging BEFORE imports
+ # pytorch_tabular uses PT_LOGLEVEL to set its log level
+ os.environ["PT_LOGLEVEL"] = "ERROR"
+
+ # Suppress lightning rank_zero logging
+ for _logger_name in [
+ "lightning_utilities.core.rank_zero",
+ "pytorch_lightning",
+ "pytorch_lightning.utilities.rank_zero",
+ ]:
+ logging.getLogger(_logger_name).setLevel(logging.ERROR)
+
+ # Import rank_zero_module here; its logger is set to ERROR once the imports below finish
+ from lightning_fabric.utilities.rank_zero import rank_zero_module
from pytorch_tabular import TabularModel
from pytorch_tabular.config import (
DataConfig,
@@ -27,11 +42,39 @@
)
from pytorch_tabular.models import CategoryEmbeddingModelConfig, MDNConfig
+ rank_zero_module.log.setLevel(logging.ERROR)
+
PYTORCH_TABULAR_AVAILABLE = True
except ImportError:
PYTORCH_TABULAR_AVAILABLE = False
+def _suppress_pytorch_logging() -> None:
+    """Silence the loggers used by pytorch_tabular and the Lightning stack.
+
+    Every logger named below is raised to CRITICAL, detached from its
+    ancestor loggers, and left with a single no-op handler. Loggers that
+    are not in this list, including microimpute's own, are not modified.
+    """
+    noisy_logger_names = (
+        "pytorch_tabular",
+        "pytorch_tabular.tabular_model",
+        "pytorch_tabular.config",
+        "pytorch_tabular.config.config",
+        "pytorch_tabular.tabular_datamodule",
+        "pytorch_lightning",
+        "pytorch_lightning.utilities.rank_zero",
+        "lightning",
+        "lightning.pytorch",
+        "lightning.pytorch.utilities.rank_zero",
+        "lightning_fabric",
+    )
+    for name in noisy_logger_names:
+        noisy = logging.getLogger(name)
+        noisy.setLevel(logging.CRITICAL)
+        # Stop records from bubbling up to handlers on ancestor loggers.
+        noisy.propagate = False
+        # Drop whatever handlers were attached, then install one that
+        # discards every record.
+        noisy.handlers = []
+        noisy.addHandler(logging.NullHandler())
+
+
def _generate_data_hash(X: pd.DataFrame, y: pd.Series) -> str:
"""Generate a hash from the training data for cache identification.
@@ -142,6 +185,7 @@ def fit(
X: Feature DataFrame (predictors are already dummy-encoded).
y: Target Series.
"""
+ _suppress_pytorch_logging()
self.output_column = y.name
# Combine X and y for PyTorch Tabular
@@ -209,6 +253,8 @@ def fit(
model_config=model_config,
optimizer_config=optimizer_config,
trainer_config=trainer_config,
+ verbose=False,
+ suppress_lightning_logger=True,
)
self.model.fit(train=train_data)
@@ -311,6 +357,7 @@ def fit(
var_type: Type of variable ("boolean" or "categorical").
categories: List of categories for categorical variables.
"""
+ _suppress_pytorch_logging()
self.output_column = y.name
self.var_type = var_type
@@ -385,6 +432,8 @@ def fit(
model_config=model_config,
optimizer_config=optimizer_config,
trainer_config=trainer_config,
+ verbose=False,
+ suppress_lightning_logger=True,
)
self.model.fit(train=train_data)
@@ -539,19 +588,23 @@ def _predict(
X_test: pd.DataFrame,
quantiles: Optional[List[float]] = None,
return_probs: bool = False,
+ n_samples: int = 1000,
) -> Dict[float, pd.DataFrame]:
"""Predict imputed values using stochastic sampling.
- For MDN models, samples are drawn from the learned mixture distribution.
- For classifier models, samples are drawn from the predicted probability
- distribution.
+ For MDN models, many samples are drawn from the learned mixture
+ distribution and empirical quantiles are computed. For classifier
+ models, samples are drawn from the predicted probability distribution.
Args:
X_test: DataFrame containing the test data.
- quantiles: List of quantiles (used to determine number of
- independent samples to draw).
+ quantiles: List of quantiles to compute from the sampled
+ distribution.
return_probs: If True, return probability distributions for
categorical variables.
+ n_samples: Number of samples to draw for computing quantiles
+ (default 1000). More samples give more accurate quantile
+ estimates but increase computation time.
Returns:
Dictionary mapping quantiles to imputed DataFrames.
@@ -561,12 +614,24 @@ def _predict(
imputations: Dict[float, pd.DataFrame] = {}
prob_results = {} if return_probs else None
- # Determine how many independent samples to draw
if quantiles:
quantiles_to_use = quantiles
else:
- quantiles_to_use = [0.5] # Default single sample
+ quantiles_to_use = [0.5]
+
+ # Pre-compute samples for MDN models (draw once, compute all
+ # quantiles)
+ mdn_samples: Dict[str, np.ndarray] = {}
+ for variable in self.imputed_variables:
+ model = self.models[variable]
+ if isinstance(model, _MDNModel):
+ # Draw n_samples for each observation
+ samples = model.predict(
+ X_test[self.predictors], n_samples=n_samples
+ )
+ mdn_samples[variable] = samples
+ # Compute quantiles from the samples
for q in quantiles_to_use:
imputed_df = pd.DataFrame(index=X_test.index)
@@ -591,11 +656,10 @@ def _predict(
)
elif isinstance(model, _MDNModel):
- # Stochastic sampling from MDN mixture distribution
- samples = model.predict(
- X_test[self.predictors], n_samples=1
- )
- imputed_df[variable] = samples.flatten()
+ # Compute empirical quantile from samples
+ samples = mdn_samples[variable]
+ quantile_values = np.quantile(samples, q, axis=1)
+ imputed_df[variable] = quantile_values
else:
raise ValueError(
diff --git a/microimpute/utils/data.py b/microimpute/utils/data.py
index 3aa5e4e..f0c404a 100644
--- a/microimpute/utils/data.py
+++ b/microimpute/utils/data.py
@@ -1,13 +1,15 @@
"""Data preparation and transformation utilities
This module provides comprehensive data preparation functions for imputation workflows,
-including data splitting, normalization, log transformation, and categorical variable handling.
+including data splitting, normalization, log transformation, asinh transformation,
+and categorical variable handling.
These utilities ensure consistent data preprocessing across different imputation methods.
Key functions:
- - preprocess_data: split and optionally normalize or log-transform data for training/testing
+ - preprocess_data: split and optionally transform data for training/testing
- unnormalize_predictions: convert normalized predictions back to original scale
- unlog_transform_predictions: convert log-transformed predictions back to original scale
+ - un_asinh_transform_predictions: convert asinh-transformed predictions back to original scale
- Handle categorical variables through one-hot encoding
"""
@@ -255,6 +257,113 @@ def log_transform_data(
raise RuntimeError("Failed to apply log transformation") from e
+@validate_call(config=VALIDATE_CONFIG)
+def asinh_transform_data(
+    data: pd.DataFrame,
+    columns_to_transform: Optional[List[str]] = None,
+) -> Tuple[pd.DataFrame, dict]:
+    """Apply the inverse hyperbolic sine (asinh) to numeric columns.
+
+    Unlike the natural log, asinh is defined for zero and negative inputs:
+    it is approximately x near zero, log(2x) for large positive x, and
+    -log(-2x) for large negative x.
+
+    Columns that the variable type detector classifies as "categorical",
+    "numeric_categorical" or "bool" are never transformed, even when they
+    are requested explicitly; a warning is logged for those instead.
+
+    Args:
+        data: DataFrame to transform. It is not modified; a copy is
+            returned.
+        columns_to_transform: Names of the columns to transform. When
+            None, every column not detected as categorical is transformed.
+
+    Returns:
+        Tuple of (transformed_data, asinh_transform_params), where
+        asinh_transform_params maps each transformed column name to an
+        empty dict. If no column qualifies, an unchanged copy of data and
+        an empty dict are returned.
+
+    Raises:
+        ValueError: If a requested column is not present in data.
+        RuntimeError: If a TypeError or AttributeError occurs while
+            transforming.
+    """
+    logger.debug("Applying asinh transformation to data")
+    try:
+        from microimpute.utils.type_handling import VariableTypeDetector
+
+        # Classify every column so categorical-like ones can be left alone
+        type_detector = VariableTypeDetector()
+        excluded_types = ["categorical", "numeric_categorical", "bool"]
+        categorical_cols = []
+        for name in data.columns:
+            detected_type, _ = type_detector.categorize_variable(
+                data[name], name, logger
+            )
+            if detected_type in excluded_types:
+                categorical_cols.append(name)
+
+        if categorical_cols:
+            logger.info(
+                f"Excluding categorical columns from asinh transformation: "
+                f"{categorical_cols}"
+            )
+
+        if columns_to_transform is None:
+            # No explicit selection: take every non-categorical column
+            targets = [
+                name for name in data.columns if name not in categorical_cols
+            ]
+        else:
+            # Explicit selection: every requested column must exist
+            unknown = set(columns_to_transform) - set(data.columns)
+            if unknown:
+                error_msg = (
+                    f"Columns specified for asinh transformation not found "
+                    f"in data: {unknown}"
+                )
+                logger.error(error_msg)
+                raise ValueError(error_msg)
+
+            # Split the request into transformable and skipped columns,
+            # keeping the order the caller gave
+            targets = []
+            skipped = []
+            for name in columns_to_transform:
+                if name in categorical_cols:
+                    skipped.append(name)
+                else:
+                    targets.append(name)
+
+            if skipped:
+                logger.warning(
+                    f"Skipping asinh transformation for categorical "
+                    f"columns: {skipped}"
+                )
+
+        if not targets:
+            logger.warning("No numeric columns found for asinh transformation")
+            return data.copy(), {}
+
+        # Work on a copy so the caller's DataFrame stays untouched
+        transformed = data.copy()
+        for name in targets:
+            transformed[name] = np.arcsinh(transformed[name])
+        asinh_transform_params = {name: {} for name in targets}
+
+        logger.debug(
+            f"Asinh transformed {len(targets)} numeric columns successfully"
+        )
+        logger.debug(
+            f"Asinh transformation parameters: {asinh_transform_params}"
+        )
+
+        return transformed, asinh_transform_params
+
+    except ValueError:
+        # Validation errors are passed to the caller unchanged
+        raise
+    except (TypeError, AttributeError) as e:
+        logger.error(f"Error during asinh transformation: {str(e)}")
+        raise RuntimeError("Failed to apply asinh transformation") from e
+
+
@validate_call(config=VALIDATE_CONFIG)
def preprocess_data(
data: pd.DataFrame,
@@ -264,6 +373,7 @@ def preprocess_data(
random_state: Optional[int] = RANDOM_STATE,
normalize: Optional[Union[bool, List[str]]] = False,
log_transform: Optional[Union[bool, List[str]]] = False,
+ asinh_transform: Optional[Union[bool, List[str]]] = False,
) -> Union[
Tuple[pd.DataFrame, dict], # when full_data=True
Tuple[pd.DataFrame, pd.DataFrame, dict], # when full_data=False
@@ -284,6 +394,12 @@ def preprocess_data(
- True: transform all numeric columns
- List of column names: transform only those columns
- False: no transformation (default)
+ asinh_transform: Whether to apply asinh transformation to the data.
+ The asinh transformation handles zero and negative values unlike
+ log. Can be:
+ - True: transform all numeric columns
+ - List of column names: transform only those columns
+ - False: no transformation (default)
Returns:
Different tuple formats depending on parameters:
@@ -299,12 +415,13 @@ def preprocess_data(
Where transform_params is a dict with keys:
- "normalization": dict of normalization parameters (or empty dict)
- "log_transform": dict of log transform parameters (or empty dict)
+ - "asinh_transform": dict of asinh transform parameters (or empty dict)
Raises:
- ValueError: If data is empty or invalid, or if both normalize and
- log_transform would apply to the same columns, or if log_transform
- is applied to data with non-positive values, or if specified
- columns don't exist in data.
+ ValueError: If data is empty or invalid, or if multiple transformations
+ would apply to the same columns, or if log_transform is applied
+ to data with non-positive values, or if specified columns don't
+ exist in data.
RuntimeError: If data preprocessing fails
"""
@@ -315,51 +432,56 @@ def preprocess_data(
if data.empty:
raise ValueError("Data must not be None or empty")
- # Check if both normalize and log_transform are requested
+ # Check which transformations are requested
normalize_requested = normalize is not False and normalize != []
log_transform_requested = (
log_transform is not False and log_transform != []
)
+ asinh_transform_requested = (
+ asinh_transform is not False and asinh_transform != []
+ )
- # Validate that normalize and log_transform don't conflict
- if normalize_requested and log_transform_requested:
- # If both are True, they would apply to all numeric columns - conflict
- if normalize is True and log_transform is True:
- error_msg = (
- "Cannot apply both normalization and log transformation to "
- "all columns. Please specify which columns to transform with "
- "each approach using lists."
- )
- logger.error(error_msg)
- raise ValueError(error_msg)
-
- # If one is True and the other is a list, conflict
- if normalize is True or log_transform is True:
+ # Collect transformation settings for conflict checking
+ transforms = []
+ if normalize_requested:
+ transforms.append(("normalize", normalize))
+ if log_transform_requested:
+ transforms.append(("log_transform", log_transform))
+ if asinh_transform_requested:
+ transforms.append(("asinh_transform", asinh_transform))
+
+ # Validate that multiple transformations don't conflict
+ if len(transforms) > 1:
+ # Check if any are True (apply to all columns)
+ any_true = any(t[1] is True for t in transforms)
+ if any_true:
+ names = [t[0] for t in transforms]
error_msg = (
- "Cannot apply both normalization and log transformation. "
- "When using both, specify column lists for each to ensure "
- "they apply to different variables."
+ f"Cannot apply multiple transformations ({', '.join(names)}) "
+ f"to all columns. Please specify column lists for each to "
+ f"ensure they apply to different variables."
)
logger.error(error_msg)
raise ValueError(error_msg)
- # Both are lists - check for overlap
- normalize_cols = (
- set(normalize) if isinstance(normalize, list) else set()
- )
- log_transform_cols = (
- set(log_transform) if isinstance(log_transform, list) else set()
- )
- overlap = normalize_cols & log_transform_cols
-
- if overlap:
- error_msg = (
- f"Cannot apply both normalization and log transformation to "
- f"the same columns: {overlap}. Each column can only have one "
- f"transformation applied."
- )
- logger.error(error_msg)
- raise ValueError(error_msg)
+ # All are lists - check for overlaps between any pair
+ col_sets = {}
+ for name, cols in transforms:
+ col_sets[name] = set(cols) if isinstance(cols, list) else set()
+
+ # Check all pairs for overlap
+ transform_names = list(col_sets.keys())
+ for i, name1 in enumerate(transform_names):
+ for name2 in transform_names[i + 1 :]:
+ overlap = col_sets[name1] & col_sets[name2]
+ if overlap:
+ error_msg = (
+ f"Cannot apply both {name1} and {name2} to the same "
+ f"columns: {overlap}. Each column can only have one "
+ f"transformation applied."
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
# Check for missing values
missing_count = data.isna().sum().sum()
@@ -390,13 +512,30 @@ def preprocess_data(
data, columns_to_transform=log_transform
)
+ # Apply asinh transformation if requested
+ asinh_transform_params = {}
+ if asinh_transform_requested:
+ if isinstance(asinh_transform, bool):
+ # asinh_transform=True means transform all numeric columns
+ data, asinh_transform_params = asinh_transform_data(data)
+ else:
+ # asinh_transform is a list of specific columns
+ data, asinh_transform_params = asinh_transform_data(
+ data, columns_to_transform=asinh_transform
+ )
+
# Prepare transformation parameters to return
- has_transformations = normalize_requested or log_transform_requested
+ has_transformations = (
+ normalize_requested
+ or log_transform_requested
+ or asinh_transform_requested
+ )
if has_transformations:
- # Merge both parameter dicts, with a key to distinguish them
+ # Merge parameter dicts, with a key to distinguish them
transform_params = {
"normalization": normalization_params,
"log_transform": log_transform_params,
+ "asinh_transform": asinh_transform_params,
}
if full_data:
@@ -527,3 +666,56 @@ def unlog_transform_predictions(
)
return untransformed
+
+
+@validate_call(config=VALIDATE_CONFIG)
+def un_asinh_transform_predictions(
+    imputations: dict, asinh_transform_params: dict
+) -> dict:
+    """Reverse asinh transformation on predictions using stored parameters.
+
+    Only the columns listed in asinh_transform_params are mapped back with
+    sinh. Any other column is passed through unchanged, so a prediction
+    frame in which only some variables were asinh-transformed can be
+    reversed without error.
+
+    Args:
+        imputations: Dictionary mapping quantiles to DataFrames of predictions.
+        asinh_transform_params: Dictionary keyed by the names of the columns
+            that were asinh-transformed.
+
+    Returns:
+        Dictionary with the same structure as imputations, holding copies of
+        the DataFrames with sinh applied to the asinh-transformed columns.
+
+    Raises:
+        ValueError: If a prediction DataFrame has columns but none of them
+            appears in asinh_transform_params.
+    """
+    logger.debug(
+        f"Reversing asinh transformation for {len(imputations)} quantiles"
+    )
+
+    untransformed = {}
+    for q, df in imputations.items():
+        cols = list(df.columns)
+        asinh_cols = [col for col in cols if col in asinh_transform_params]
+        missing_params = [
+            col for col in cols if col not in asinh_transform_params
+        ]
+
+        # A frame sharing no column with the parameters means the wrong
+        # parameters were supplied; fail loudly rather than return the
+        # predictions on the transformed scale.
+        if cols and not asinh_cols:
+            error_msg = (
+                f"Missing asinh transformation parameters for columns: "
+                f"{missing_params}"
+            )
+            logger.error(error_msg)
+            raise ValueError(error_msg)
+
+        if missing_params:
+            logger.debug(
+                f"Leaving columns without asinh transformation parameters "
+                f"unchanged: {missing_params}"
+            )
+
+        # Reverse asinh transformation: x_original = sinh(x_asinh)
+        df_untransformed = df.copy()
+        for col in asinh_cols:
+            df_untransformed[col] = np.sinh(df[col])
+        untransformed[q] = df_untransformed
+
+        logger.debug(
+            f"Reversed asinh transformation for quantile {q} with shape "
+            f"{df_untransformed.shape}"
+        )
+
+    return untransformed
diff --git a/microimpute/utils/statmatch_hotdeck.py b/microimpute/utils/statmatch_hotdeck.py
index 1f66ff9..fb67bd4 100644
--- a/microimpute/utils/statmatch_hotdeck.py
+++ b/microimpute/utils/statmatch_hotdeck.py
@@ -25,37 +25,16 @@
z.vars: A character vector with the names of the variables available only in data.don that should be "donated" to data.rec.
"""
-import contextlib
import os
-import sys
-
-# Set the env var early, before rpy2 does anything
-os.environ["RPY2_USE_PYTHON_API"] = "0"
-
-
-@contextlib.contextmanager
-def suppress_rpy2_stderr():
- """Temporarily suppress rpy2's stderr output (e.g., dlopen warnings)."""
- stderr_fd = sys.stderr.fileno()
- with os.fdopen(os.dup(stderr_fd), "w") as old_stderr:
- with open(os.devnull, "w") as devnull:
- os.dup2(devnull.fileno(), stderr_fd)
- try:
- yield
- finally:
- os.dup2(old_stderr.fileno(), stderr_fd)
-
-
-with suppress_rpy2_stderr():
- import rpy2.robjects as ro
- from rpy2.robjects import (
- conversion,
- default_converter,
- numpy2ri,
- pandas2ri,
- )
- from rpy2.robjects.conversion import localconverter
- from rpy2.robjects.packages import importr
+
+# Set env vars early, before rpy2 does anything
+# RPY2_CFFI_MODE=ABI skips API mode and avoids the dlopen warning
+os.environ["RPY2_CFFI_MODE"] = "ABI"
+
+import rpy2.robjects as ro
+from rpy2.robjects import conversion, default_converter, numpy2ri, pandas2ri
+from rpy2.robjects.conversion import localconverter
+from rpy2.robjects.packages import importr
@validate_call(config=VALIDATE_CONFIG)
diff --git a/tests/test_autoimpute.py b/tests/test_autoimpute.py
index 5a9777d..017c77c 100644
--- a/tests/test_autoimpute.py
+++ b/tests/test_autoimpute.py
@@ -7,6 +7,7 @@
from microimpute.comparisons.autoimpute import autoimpute, AutoImputeResult
from microimpute.visualizations import *
+from microimpute.models import QRF, QuantReg, OLS
# Check if Matching is available
try:
@@ -16,6 +17,13 @@
except ImportError:
HAS_MATCHING = False
+# Check if MDN is available
+try:
+ from microimpute.models import MDN
+
+ HAS_MDN = True
+except ImportError:
+ HAS_MDN = False
# === Fixtures ===
@@ -81,6 +89,7 @@ def test_autoimpute_basic_structure(
receiver_data=diabetes_receiver,
predictors=predictors,
imputed_variables=imputed_variables,
+ models=[QRF, Matching, QuantReg, OLS] if not HAS_MDN else None,
hyperparameters={
"QRF": {"n_estimators": 50},
"Matching": {"constrained": True},
@@ -132,7 +141,7 @@ def test_autoimpute_all_models(
receiver_data=diabetes_receiver,
predictors=predictors,
imputed_variables=imputed_variables,
- models=None, # Use all available models
+ models=[QRF, Matching, QuantReg, OLS] if not HAS_MDN else None,
impute_all=True, # Return results for all models
log_level="WARNING",
)
@@ -201,6 +210,7 @@ def test_autoimpute_with_hyperparameters(simple_data: tuple) -> None:
receiver_data=receiver,
predictors=["x1", "x2"],
imputed_variables=["y1"],
+ models=[QRF, Matching, QuantReg, OLS] if not HAS_MDN else None,
hyperparameters=hyperparameters,
log_level="WARNING",
)
@@ -222,6 +232,7 @@ def test_autoimpute_multiple_imputed_variables(simple_data: tuple) -> None:
receiver_data=receiver,
predictors=["x1", "x2"],
imputed_variables=["y1", "y2"], # Multiple variables
+ models=[QRF, Matching, QuantReg, OLS] if not HAS_MDN else None,
log_level="WARNING",
)
@@ -243,6 +254,7 @@ def test_autoimpute_large_receiver() -> None:
receiver_data=receiver,
predictors=["x"],
imputed_variables=["y"],
+ models=[QRF, Matching, QuantReg, OLS] if not HAS_MDN else None,
log_level="WARNING",
)
@@ -263,6 +275,7 @@ def test_autoimpute_best_method_selection(simple_data: tuple) -> None:
receiver_data=receiver,
predictors=["x1", "x2"],
imputed_variables=["y1"],
+ models=[QRF, Matching, QuantReg, OLS] if not HAS_MDN else None,
log_level="WARNING",
)
@@ -321,6 +334,7 @@ def test_autoimpute_cv_results_structure(simple_data: tuple) -> None:
receiver_data=receiver,
predictors=["x1", "x2"],
imputed_variables=["y1"],
+ models=[QRF, Matching, QuantReg, OLS] if not HAS_MDN else None,
log_level="WARNING",
)
@@ -344,22 +358,6 @@ def test_autoimpute_cv_results_structure(simple_data: tuple) -> None:
assert "test" in ql_results["results"].index
-# === Visualization Compatibility ===
-
-
-def test_autoimpute_visualization_compatibility(simple_data: tuple) -> None:
- """Test that autoimpute results work with visualization functions."""
- donor, receiver = simple_data
-
- results = autoimpute(
- donor_data=donor,
- receiver_data=receiver,
- predictors=["x1", "x2"],
- imputed_variables=["y1"],
- log_level="WARNING",
- )
-
-
# === Error Handling ===
@@ -388,6 +386,7 @@ def test_autoimpute_missing_predictors() -> None:
receiver_data=receiver,
predictors=["x1", "x2"], # x2 not in receiver
imputed_variables=["y"],
+ models=[QRF, Matching, QuantReg, OLS] if not HAS_MDN else None,
log_level="WARNING",
)
@@ -425,6 +424,7 @@ def test_autoimpute_consistency(simple_data: tuple) -> None:
receiver_data=receiver,
predictors=["x1", "x2"],
imputed_variables=["y1"],
+ models=[QRF, Matching, QuantReg, OLS] if not HAS_MDN else None,
log_level="WARNING",
)
@@ -433,6 +433,7 @@ def test_autoimpute_consistency(simple_data: tuple) -> None:
receiver_data=receiver,
predictors=["x1", "x2"],
imputed_variables=["y1"],
+ models=[QRF, Matching, QuantReg, OLS] if not HAS_MDN else None,
log_level="WARNING",
)
diff --git a/tests/test_data_preprocessing.py b/tests/test_data_preprocessing.py
index ee9ea83..47b3242 100644
--- a/tests/test_data_preprocessing.py
+++ b/tests/test_data_preprocessing.py
@@ -5,9 +5,11 @@
import pytest
from microimpute.utils.data import (
+ asinh_transform_data,
log_transform_data,
normalize_data,
preprocess_data,
+ un_asinh_transform_predictions,
unlog_transform_predictions,
)
@@ -525,7 +527,7 @@ def test_preprocess_data_raises_error_for_both_normalize_and_log(
with pytest.raises(
ValueError,
- match="Cannot apply both normalization and log transformation",
+ match="Cannot apply multiple transformations",
):
preprocess_data(
data, full_data=True, normalize=True, log_transform=True
@@ -822,7 +824,7 @@ def test_error_on_overlapping_columns(self):
# Error when same column is in both lists
with pytest.raises(
- ValueError, match="Cannot apply both normalization and log"
+ ValueError, match="Cannot apply both .* to the same columns"
):
preprocess_data(
data,
@@ -900,3 +902,348 @@ def test_both_transformations_on_different_columns(self):
pd.testing.assert_series_equal(result["wealth"], data["wealth"])
assert "wealth" not in norm_params
assert "wealth" not in log_params
+
+
+class TestAsinhTransform:
+    """Tests for asinh_transform_data applied directly to a DataFrame."""
+
+    def test_asinh_transform_excludes_categorical_columns(self):
+        """Test that categorical and boolean columns are left unchanged."""
+        data = pd.DataFrame(
+            {
+                "numeric_col": [1.0, 2.5, 3.7, 4.2, 5.9],
+                "categorical_col": [1, 2, 3, 1, 2],
+                "boolean_col": [0, 1, 0, 1, 0],
+            }
+        )
+
+        asinh_data, asinh_params = asinh_transform_data(data)
+
+        # Categorical and boolean columns should be unchanged
+        pd.testing.assert_series_equal(
+            asinh_data["categorical_col"], data["categorical_col"]
+        )
+        pd.testing.assert_series_equal(
+            asinh_data["boolean_col"], data["boolean_col"]
+        )
+
+        # Numeric column should come back changed by the transform
+        assert not np.allclose(
+            asinh_data["numeric_col"].values, data["numeric_col"].values
+        )
+
+        # Only the numeric column should have an entry in the params dict
+        assert "numeric_col" in asinh_params
+        assert "categorical_col" not in asinh_params
+        assert "boolean_col" not in asinh_params
+
+    def test_asinh_transform_correctly_transforms_numeric_columns(self):
+        """Test that numeric columns match np.arcsinh of the input."""
+        data = pd.DataFrame(
+            {
+                "value1": [
+                    -10.5,
+                    -2.7,
+                    0.0,
+                    2.8,
+                    10.1,
+                    100.3,
+                    1000.9,
+                    10000.4,
+                    100000.6,
+                    1000000.2,
+                ],
+                "value2": [
+                    -1000.5,
+                    -100.3,
+                    -10.1,
+                    0.0,
+                    10.9,
+                    100.2,
+                    1000.8,
+                    10000.5,
+                    100000.1,
+                    1000000.7,
+                ],
+                "category": [1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
+            }
+        )
+
+        asinh_data, asinh_params = asinh_transform_data(data)
+
+        # Expected values come straight from np.arcsinh (no extra scaling)
+        expected_value1 = np.arcsinh(data["value1"].values)
+        expected_value2 = np.arcsinh(data["value2"].values)
+
+        np.testing.assert_array_almost_equal(
+            asinh_data["value1"].values, expected_value1
+        )
+        np.testing.assert_array_almost_equal(
+            asinh_data["value2"].values, expected_value2
+        )
+
+        # Both numeric columns must have an entry in the returned params
+        assert "value1" in asinh_params
+        assert "value2" in asinh_params
+
+    def test_asinh_transform_handles_negative_and_zero_values(self):
+        """Test that asinh transform accepts negative and zero inputs."""
+        data = pd.DataFrame(
+            {
+                "value": [-100.0, -10.0, -1.0, 0.0, 1.0, 10.0, 100.0],
+            }
+        )
+
+        # Should NOT raise an error (unlike log transform)
+        asinh_data, asinh_params = asinh_transform_data(data)
+
+        # asinh is odd (f(-x) == -f(x)): mirrored inputs give negated outputs
+        assert np.isclose(
+            asinh_data["value"].iloc[0], -asinh_data["value"].iloc[6]
+        )
+        assert np.isclose(
+            asinh_data["value"].iloc[1], -asinh_data["value"].iloc[5]
+        )
+        assert np.isclose(
+            asinh_data["value"].iloc[2], -asinh_data["value"].iloc[4]
+        )
+        assert np.isclose(asinh_data["value"].iloc[3], 0.0)  # asinh(0) == 0
+
+    def test_asinh_transform_returns_copy(self):
+        """Test that asinh_transform_data does not mutate its input."""
+        data = pd.DataFrame(
+            {
+                "value": [-10.5, -2.7, 0.0, 2.8, 10.1, 100.3, 1000.9, 10000.4],
+                "category": [1, 2, 1, 2, 1, 2, 1, 2],
+            }
+        )
+        original_data = data.copy()
+
+        asinh_data, _ = asinh_transform_data(data)
+
+        # Original data should be unchanged
+        pd.testing.assert_frame_equal(data, original_data)
+
+        # Asinh transformed data should be different
+        assert not asinh_data["value"].equals(data["value"])
+
+
+class TestUnAsinhTransformPredictions:
+    """Tests for un_asinh_transform_predictions, the inverse asinh step."""
+
+    def test_un_asinh_transform_reverses_asinh_transform(self):
+        """Test that asinh followed by un_asinh recovers the input."""
+        original = pd.DataFrame(
+            {
+                "value1": [
+                    -100.5,
+                    -10.7,
+                    0.0,
+                    10.8,
+                    100.1,
+                    1000.3,
+                ],
+                "value2": [
+                    -1000.5,
+                    -100.3,
+                    -10.1,
+                    0.0,
+                    100.9,
+                    1000.2,
+                ],
+            }
+        )
+
+        # Apply asinh transform
+        asinh_data, asinh_params = asinh_transform_data(original)
+
+        # Wrap in a {0.5: DataFrame} dict (presumably keyed by quantile)
+        imputations = {0.5: asinh_data}
+
+        # Reverse asinh transform
+        reversed_data = un_asinh_transform_predictions(
+            imputations, asinh_params
+        )
+
+        # Should match original data up to floating-point error (atol=1e-10)
+        pd.testing.assert_frame_equal(
+            reversed_data[0.5], original, check_exact=False, atol=1e-10
+        )
+
+    def test_un_asinh_transform_raises_error_for_missing_params(self):
+        """Test that a column without asinh params raises ValueError."""
+        imputations = {
+            0.5: pd.DataFrame(
+                {
+                    "value1": [0.0, 0.69, 1.10],
+                    "value2": [2.3, 3.0, 3.9],
+                }
+            )
+        }
+
+        # Only have params for value1, not value2
+        asinh_params = {"value1": {}}
+
+        with pytest.raises(
+            ValueError, match="Missing asinh transformation parameters"
+        ):
+            un_asinh_transform_predictions(imputations, asinh_params)
+
+
+class TestPreprocessDataWithAsinhTransform:
+    """Tests for the asinh_transform option of preprocess_data."""
+
+    def test_preprocess_data_excludes_categoricals_from_asinh_transform(self):
+        """Test that asinh_transform=True skips categorical columns."""
+        data = pd.DataFrame(
+            {
+                "age": [
+                    -5.3,
+                    0.0,
+                    35.2,
+                    40.9,
+                    45.1,
+                    50.6,
+                    55.8,
+                    60.3,
+                    65.7,
+                    70.2,
+                ],
+                "race": [1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
+                "is_female": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
+                "income": [
+                    -10000.0,
+                    0.0,
+                    70456.78,
+                    80234.56,
+                    90876.12,
+                    100543.89,
+                    110234.67,
+                    120789.34,
+                    130456.78,
+                    140987.23,
+                ],
+            }
+        )
+
+        result, transform_params = preprocess_data(
+            data, full_data=True, asinh_transform=True
+        )
+
+        # Asinh params are nested under the "asinh_transform" key
+        asinh_params = transform_params["asinh_transform"]
+
+        # Categorical columns should be unchanged
+        pd.testing.assert_series_equal(result["race"], data["race"])
+        pd.testing.assert_series_equal(result["is_female"], data["is_female"])
+
+        # Numeric columns should come back changed by the transform
+        assert not np.allclose(result["age"].values, data["age"].values)
+        assert not np.allclose(result["income"].values, data["income"].values)
+
+        # Only numeric columns in asinh_params
+        assert "age" in asinh_params
+        assert "income" in asinh_params
+        assert "race" not in asinh_params
+        assert "is_female" not in asinh_params
+
+    def test_preprocess_data_with_asinh_transform_on_specific_columns(self):
+        """Test that a column list limits asinh to the named columns."""
+        data = pd.DataFrame(
+            {
+                "age": [
+                    23,
+                    30,
+                    35,
+                    46,
+                    45,
+                    52,
+                    55,
+                    61,
+                    68,
+                    72,
+                ],
+                "income": [
+                    -10000.0,
+                    0.0,
+                    70456.78,
+                    80234.56,
+                    90876.12,
+                    100543.89,
+                    110234.67,
+                    120789.34,
+                    130456.78,
+                    140987.23,
+                ],
+                "wealth": [
+                    -50000.5,
+                    0.0,
+                    350000.7,
+                    450000.2,
+                    550000.9,
+                    650000.1,
+                    750000.4,
+                    850000.8,
+                    950000.6,
+                    1050000.3,
+                ],
+            }
+        )
+
+        # Passing a list restricts the asinh transform to those columns
+        result, transform_params = preprocess_data(
+            data, full_data=True, asinh_transform=["income"]
+        )
+
+        # Asinh params are nested under the "asinh_transform" key
+        asinh_params = transform_params["asinh_transform"]
+
+        # Income should come back changed and be recorded in the params
+        assert not np.allclose(result["income"].values, data["income"].values)
+        assert "income" in asinh_params
+
+        # Age and wealth should NOT be transformed
+        pd.testing.assert_series_equal(result["age"], data["age"])
+        pd.testing.assert_series_equal(result["wealth"], data["wealth"])
+        assert "age" not in asinh_params
+        assert "wealth" not in asinh_params
+
+    def test_preprocess_data_raises_error_for_asinh_and_log_overlap(self):
+        """Test ValueError when asinh and log name the same column."""
+        data = pd.DataFrame(
+            {
+                "age": [25.3, 30.7, 35.2, 40.9, 45.1],
+                "income": [50123.45, 60987.23, 70456.78, 80234.56, 90876.12],
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match="Cannot apply both .* to the same columns"
+        ):
+            preprocess_data(
+                data,
+                full_data=True,
+                asinh_transform=["income"],
+                log_transform=["income"],
+            )
+
+    def test_preprocess_data_raises_error_for_asinh_and_normalize_overlap(
+        self,
+    ):
+        """Test ValueError when asinh and normalize name the same column."""
+        data = pd.DataFrame(
+            {
+                "age": [25.3, 30.7, 35.2, 40.9, 45.1],
+                "income": [50123.45, 60987.23, 70456.78, 80234.56, 90876.12],
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match="Cannot apply both .* to the same columns"
+        ):
+            preprocess_data(
+                data,
+                full_data=True,
+                asinh_transform=["income"],
+                normalize=["income"],
+            )
diff --git a/tests/test_models/test_imputers.py b/tests/test_models/test_imputers.py
index e9f87f7..18548b8 100644
--- a/tests/test_models/test_imputers.py
+++ b/tests/test_models/test_imputers.py
@@ -658,11 +658,15 @@ def test_missing_predictors_in_test(model_class: Type[Imputer]) -> None:
@pytest.mark.parametrize(
- "model_class", ALL_IMPUTER_MODELS, ids=lambda cls: cls.__name__
+ "model_class",
+ [OLS, QuantReg, QRF, Matching],
+ ids=lambda cls: cls.__name__,
)
def test_reproducibility(
model_class: Type[Imputer], simple_data: pd.DataFrame
) -> None:
     """Test that models produce reproducible results."""
+    # Note: MDN is excluded because PyTorch MPS (Apple Silicon) doesn't support
+    # deterministic operations, making reproducibility tests unreliable.
X_train, X_test = preprocess_data(simple_data)