From 5f10f733cf5befba924eb8784f6cd1ff2b9ca346 Mon Sep 17 00:00:00 2001
From: juaristi22
Date: Mon, 3 Nov 2025 16:38:04 +0800
Subject: [PATCH 1/2] add imputed variable histograms to dashboard
---
changelog_entry.yaml | 4 +
examples/demo.ipynb | 1778 +++++++++++++++++
examples/pipeline.py | 3 +
.../components/DistributionOverlay.tsx | 313 +++
.../components/ImputationResults.tsx | 6 +-
.../components/VisualizationDashboard.tsx | 89 +-
.../public/microimputation_results.csv | 63 +
microimpute/utils/dashboard_formatter.py | 335 ++++
tests/test_dashboard_formatter.py | 309 +++
9 files changed, 2895 insertions(+), 5 deletions(-)
create mode 100644 examples/demo.ipynb
create mode 100644 microimputation-dashboard/components/DistributionOverlay.tsx
diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29..c68d72d 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,4 @@
+- bump: minor
+ changes:
+ added:
+ - Distribution comparison histograms of donor and receiver datasets for imputed variables in the dashboard.
diff --git a/examples/demo.ipynb b/examples/demo.ipynb
new file mode 100644
index 0000000..c749976
--- /dev/null
+++ b/examples/demo.ipynb
@@ -0,0 +1,1778 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# MicroImpute Pipeline Demo\n",
+ "\n",
+ "Demonstrates autoimpute, distribution evaluation, predictor analysis, and dashboard formatting."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup: Import libraries"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import warnings\n",
+ "from sklearn.datasets import load_diabetes\n",
+ "\n",
+ "from microimpute.comparisons.autoimpute import autoimpute\n",
+ "from microimpute.comparisons.metrics import compare_distributions\n",
+ "from microimpute.visualizations import method_comparison_results\n",
+ "from microimpute.evaluations.predictor_analysis import (\n",
+ " compute_predictor_correlations,\n",
+ " leave_one_out_analysis,\n",
+ " progressive_predictor_inclusion,\n",
+ ")\n",
+ "from microimpute.utils.dashboard_formatter import format_csv\n",
+ "from microimpute.models import OLS\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 1: Load and prepare data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load the diabetes dataset\n",
+ "diabetes = load_diabetes()\n",
+ "diabetes_data = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)\n",
+ "\n",
+ "# Split into donor and receiver portions (70/30 split)\n",
+ "np.random.seed(42)\n",
+ "donor_indices = np.random.choice(\n",
+ " len(diabetes_data),\n",
+ " size=int(0.7 * len(diabetes_data)),\n",
+ " replace=False,\n",
+ ")\n",
+ "receiver_indices = np.array(\n",
+ " [i for i in range(len(diabetes_data)) if i not in donor_indices]\n",
+ ")\n",
+ "\n",
+ "donor_data = diabetes_data.iloc[donor_indices].reset_index(drop=True)\n",
+ "receiver_data = diabetes_data.iloc[receiver_indices].reset_index(\n",
+ " drop=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a categorical risk_factor variable based on cholesterol levels (s4)\n",
+ "def categorize_risk(s4_value):\n",
+ " if s4_value < -0.02:\n",
+ " return \"low\"\n",
+ " elif s4_value < 0.02:\n",
+ " return \"medium\"\n",
+ " else:\n",
+ " return \"high\"\n",
+ "\n",
+ "\n",
+ "donor_data[\"risk_factor\"] = donor_data[\"s4\"].apply(categorize_risk)\n",
+ "receiver_data[\"risk_factor\"] = receiver_data[\"s4\"].apply(categorize_risk)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Donor data shape: (309, 11)\n",
+ "Receiver data shape: (133, 8)\n",
+ "Predictors: ['age', 'sex', 'bmi', 'bp']\n",
+ "Variables to impute: ['s1', 's4', 'risk_factor']\n",
+ "\n",
+ "Risk factor distribution in donor data:\n",
+ "risk_factor\n",
+ "low 112\n",
+ "high 108\n",
+ "medium 89\n",
+ "Name: count, dtype: int64\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Define predictors and variables to impute\n",
+ "predictors = [\"age\", \"sex\", \"bmi\", \"bp\"]\n",
+ "imputed_variables = [\"s1\", \"s4\", \"risk_factor\"]\n",
+ "\n",
+ "# Remove imputed variables from receiver data\n",
+ "receiver_data = receiver_data.drop(\n",
+ " columns=imputed_variables\n",
+ ")\n",
+ "\n",
+ "print(f\"Donor data shape: {donor_data.shape}\")\n",
+ "print(f\"Receiver data shape: {receiver_data.shape}\")\n",
+ "print(f\"Predictors: {predictors}\")\n",
+ "print(f\"Variables to impute: {imputed_variables}\")\n",
+ "print(f\"\\nRisk factor distribution in donor data:\")\n",
+ "print(donor_data[\"risk_factor\"].value_counts())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 2: Run autoimpute to find the best imputation method"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "70f7282b79674dbb9ce2522e37a117fb",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Evaluating models: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
+ "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 8.3s finished\n",
+ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
+ "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.7s finished\n",
+ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
+ "QuantReg does not support categorical variable 'risk_factor'. Skipping QuantReg for this fold.\n",
+ "[Parallel(n_jobs=-1)]: Batch computation too fast (0.012035846710205078s.) Setting batch_size=2.\n",
+ "QuantReg does not support categorical variable 'risk_factor'. Skipping QuantReg for this fold.\n",
+ "QuantReg does not support categorical variable 'risk_factor'. Skipping QuantReg for this fold.\n",
+ "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.0s finished\n",
+ "QuantReg cannot handle the provided variable types. Returning NaN results.\n",
+ "[Parallel(n_jobs=1)]: Done 1 tasks | elapsed: 3.3s\n",
+ "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 8.3s finished\n",
+ "Model QuantReg cannot handle categorical variable 'risk_factor' (type: categorical). Skipping.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "['_abc_impl',\n",
+ " '_calculate_keys',\n",
+ " '_copy_and_set_values',\n",
+ " '_get_value',\n",
+ " '_iter',\n",
+ " '_setattr_handler',\n",
+ " 'construct',\n",
+ " 'copy',\n",
+ " 'cv_results',\n",
+ " 'dict',\n",
+ " 'fitted_models',\n",
+ " 'from_orm',\n",
+ " 'imputations',\n",
+ " 'json',\n",
+ " 'model_computed_fields',\n",
+ " 'model_config',\n",
+ " 'model_construct',\n",
+ " 'model_copy',\n",
+ " 'model_dump',\n",
+ " 'model_dump_json',\n",
+ " 'model_extra',\n",
+ " 'model_fields',\n",
+ " 'model_fields_set',\n",
+ " 'model_json_schema',\n",
+ " 'model_parametrized_name',\n",
+ " 'model_post_init',\n",
+ " 'model_rebuild',\n",
+ " 'model_validate',\n",
+ " 'model_validate_json',\n",
+ " 'model_validate_strings',\n",
+ " 'parse_file',\n",
+ " 'parse_obj',\n",
+ " 'parse_raw',\n",
+ " 'receiver_data',\n",
+ " 'schema',\n",
+ " 'schema_json',\n",
+ " 'update_forward_refs',\n",
+ " 'validate']"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "autoimpute_results = autoimpute(\n",
+ " donor_data=donor_data,\n",
+ " receiver_data=receiver_data,\n",
+ " predictors=predictors,\n",
+ " imputed_variables=imputed_variables,\n",
+ " tune_hyperparameters=False,\n",
+ " impute_all=True,\n",
+ " k_folds=3,\n",
+ ")\n",
+ "\n",
+ "all_attributes = dir(autoimpute_results)\n",
+ "non_dunder_attributes = [attr for attr in all_attributes if (not (attr.startswith('__') and attr.endswith('__')))]\n",
+ "non_dunder_attributes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Best performing method: OLSResults\n"
+ ]
+ }
+ ],
+ "source": [
+ "best_method_name = autoimpute_results.fitted_models[\n",
+ " \"best_method\"\n",
+ "].__class__.__name__\n",
+ "print(f\"Best performing method: {best_method_name}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 3: Compare model performance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "alignmentgroup": "True",
+ "hovertemplate": "Method=QRF Quantiles=%{x} Test Quantile loss=%{y} ",
+ "legendgroup": "QRF",
+ "marker": {
+ "color": "#636EFA",
+ "pattern": {
+ "shape": ""
+ }
+ },
+ "name": "QRF",
+ "offsetgroup": "QRF",
+ "orientation": "v",
+ "showlegend": true,
+ "textposition": "auto",
+ "type": "bar",
+ "x": [
+ "0.05",
+ "0.05",
+ "0.1",
+ "0.1",
+ "0.15",
+ "0.15",
+ "0.2",
+ "0.2",
+ "0.25",
+ "0.25",
+ "0.3",
+ "0.3",
+ "0.35",
+ "0.35",
+ "0.4",
+ "0.4",
+ "0.45",
+ "0.45",
+ "0.5",
+ "0.5",
+ "0.55",
+ "0.55",
+ "0.6",
+ "0.6",
+ "0.65",
+ "0.65",
+ "0.7",
+ "0.7",
+ "0.75",
+ "0.75",
+ "0.8",
+ "0.8",
+ "0.85",
+ "0.85",
+ "0.9",
+ "0.9",
+ "0.95",
+ "0.95"
+ ],
+ "xaxis": "x",
+ "y": [
+ 0.004976762851613436,
+ 0.004976762851613436,
+ 0.007958901707037789,
+ 0.007958901707037789,
+ 0.011348256884255009,
+ 0.011348256884255009,
+ 0.014479045607066829,
+ 0.014479045607066829,
+ 0.01804729212868635,
+ 0.01804729212868635,
+ 0.020713966247658862,
+ 0.020713966247658862,
+ 0.022742931417508282,
+ 0.022742931417508282,
+ 0.02258721277917371,
+ 0.02258721277917371,
+ 0.023351581275292608,
+ 0.023351581275292608,
+ 0.023625607730980425,
+ 0.023625607730980425,
+ 0.023186470004863385,
+ 0.023186470004863385,
+ 0.022905432991342744,
+ 0.022905432991342744,
+ 0.022230731438156376,
+ 0.022230731438156376,
+ 0.02057637886132135,
+ 0.02057637886132135,
+ 0.01912773338390523,
+ 0.01912773338390523,
+ 0.01739101077609373,
+ 0.01739101077609373,
+ 0.013723503511700213,
+ 0.013723503511700213,
+ 0.010967503180152355,
+ 0.010967503180152355,
+ 0.0071899999204732045,
+ 0.0071899999204732045
+ ],
+ "yaxis": "y"
+ },
+ {
+ "alignmentgroup": "True",
+ "hovertemplate": "Method=OLS Quantiles=%{x} Test Quantile loss=%{y} ",
+ "legendgroup": "OLS",
+ "marker": {
+ "color": "#EF553B",
+ "pattern": {
+ "shape": ""
+ }
+ },
+ "name": "OLS",
+ "offsetgroup": "OLS",
+ "orientation": "v",
+ "showlegend": true,
+ "textposition": "auto",
+ "type": "bar",
+ "x": [
+ "0.05",
+ "0.05",
+ "0.1",
+ "0.1",
+ "0.15",
+ "0.15",
+ "0.2",
+ "0.2",
+ "0.25",
+ "0.25",
+ "0.3",
+ "0.3",
+ "0.35",
+ "0.35",
+ "0.4",
+ "0.4",
+ "0.45",
+ "0.45",
+ "0.5",
+ "0.5",
+ "0.55",
+ "0.55",
+ "0.6",
+ "0.6",
+ "0.65",
+ "0.65",
+ "0.7",
+ "0.7",
+ "0.75",
+ "0.75",
+ "0.8",
+ "0.8",
+ "0.85",
+ "0.85",
+ "0.9",
+ "0.9",
+ "0.95",
+ "0.95"
+ ],
+ "xaxis": "x",
+ "y": [
+ 0.003875743024333408,
+ 0.003875743024333408,
+ 0.006636892853346667,
+ 0.006636892853346667,
+ 0.008964045456920146,
+ 0.008964045456920146,
+ 0.010925268697485854,
+ 0.010925268697485854,
+ 0.012586171018988979,
+ 0.012586171018988979,
+ 0.013971078532560128,
+ 0.013971078532560128,
+ 0.015078371048174203,
+ 0.015078371048174203,
+ 0.015960673093506517,
+ 0.015960673093506517,
+ 0.016611314804344936,
+ 0.016611314804344936,
+ 0.017023427694261576,
+ 0.017023427694261576,
+ 0.017162262823655853,
+ 0.017162262823655853,
+ 0.016994331164241493,
+ 0.016994331164241493,
+ 0.016563877768320707,
+ 0.016563877768320707,
+ 0.01584349745574873,
+ 0.01584349745574873,
+ 0.014830479499612956,
+ 0.014830479499612956,
+ 0.013414418494530157,
+ 0.013414418494530157,
+ 0.011604173888267462,
+ 0.011604173888267462,
+ 0.009116693282835488,
+ 0.009116693282835488,
+ 0.005653910575626517,
+ 0.005653910575626517
+ ],
+ "yaxis": "y"
+ },
+ {
+ "alignmentgroup": "True",
+ "hovertemplate": "Method=Matching Quantiles=%{x} Test Quantile loss=%{y} ",
+ "legendgroup": "Matching",
+ "marker": {
+ "color": "#00CC96",
+ "pattern": {
+ "shape": ""
+ }
+ },
+ "name": "Matching",
+ "offsetgroup": "Matching",
+ "orientation": "v",
+ "showlegend": true,
+ "textposition": "auto",
+ "type": "bar",
+ "x": [
+ "0.05",
+ "0.05",
+ "0.1",
+ "0.1",
+ "0.15",
+ "0.15",
+ "0.2",
+ "0.2",
+ "0.25",
+ "0.25",
+ "0.3",
+ "0.3",
+ "0.35",
+ "0.35",
+ "0.4",
+ "0.4",
+ "0.45",
+ "0.45",
+ "0.5",
+ "0.5",
+ "0.55",
+ "0.55",
+ "0.6",
+ "0.6",
+ "0.65",
+ "0.65",
+ "0.7",
+ "0.7",
+ "0.75",
+ "0.75",
+ "0.8",
+ "0.8",
+ "0.85",
+ "0.85",
+ "0.9",
+ "0.9",
+ "0.95",
+ "0.95"
+ ],
+ "xaxis": "x",
+ "y": [
+ 0.022039382576829306,
+ 0.022039382576829306,
+ 0.022147598558993443,
+ 0.022147598558993443,
+ 0.022255814541157576,
+ 0.022255814541157576,
+ 0.022364030523321716,
+ 0.022364030523321716,
+ 0.022472246505485852,
+ 0.022472246505485852,
+ 0.02258046248764999,
+ 0.02258046248764999,
+ 0.022688678469814125,
+ 0.022688678469814125,
+ 0.02279689445197826,
+ 0.02279689445197826,
+ 0.0229051104341424,
+ 0.0229051104341424,
+ 0.02301332641630653,
+ 0.02301332641630653,
+ 0.02312154239847067,
+ 0.02312154239847067,
+ 0.023229758380634808,
+ 0.023229758380634808,
+ 0.023337974362798945,
+ 0.023337974362798945,
+ 0.023446190344963078,
+ 0.023446190344963078,
+ 0.023554406327127214,
+ 0.023554406327127214,
+ 0.02366262230929135,
+ 0.02366262230929135,
+ 0.023770838291455487,
+ 0.023770838291455487,
+ 0.023879054273619627,
+ 0.023879054273619627,
+ 0.02398727025578376,
+ 0.02398727025578376
+ ],
+ "yaxis": "y"
+ }
+ ],
+ "layout": {
+ "barmode": "group",
+ "height": 600,
+ "legend": {
+ "title": {
+ "text": "Method"
+ },
+ "tracegroupgap": 0
+ },
+ "paper_bgcolor": "#F0F0F0",
+ "plot_bgcolor": "#F0F0F0",
+ "shapes": [
+ {
+ "line": {
+ "color": "#636EFA",
+ "dash": "dot",
+ "width": 2
+ },
+ "name": "QRF Mean",
+ "type": "line",
+ "x0": -0.5,
+ "x1": 18.5,
+ "y0": 0.017217385405120106,
+ "y1": 0.017217385405120106
+ },
+ {
+ "line": {
+ "color": "#EF553B",
+ "dash": "dot",
+ "width": 2
+ },
+ "name": "OLS Mean",
+ "type": "line",
+ "x0": -0.5,
+ "x1": 18.5,
+ "y0": 0.012779822693513777,
+ "y1": 0.012779822693513777
+ },
+ {
+ "line": {
+ "color": "#00CC96",
+ "dash": "dot",
+ "width": 2
+ },
+ "name": "Matching Mean",
+ "type": "line",
+ "x0": -0.5,
+ "x1": 18.5,
+ "y0": 0.023013326416306525,
+ "y1": 0.023013326416306525
+ }
+ ],
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "font": {
+ "size": 14
+ },
+ "text": "Autoimpute method comparison on diabetes dataset"
+ },
+ "width": 750,
+ "xaxis": {
+ "anchor": "y",
+ "domain": [
+ 0,
+ 1
+ ],
+ "showgrid": false,
+ "title": {
+ "font": {
+ "size": 12
+ },
+ "text": "Quantiles"
+ },
+ "zeroline": false
+ },
+ "yaxis": {
+ "anchor": "x",
+ "domain": [
+ 0,
+ 1
+ ],
+ "showgrid": false,
+ "title": {
+ "font": {
+ "size": 12
+ },
+ "text": "Test Quantile loss"
+ },
+ "zeroline": false
+ }
+ }
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Extract the model cross-validation results from the autoimpute results object\n",
+ "comparison_viz = method_comparison_results(\n",
+ " data=autoimpute_results.cv_results,\n",
+ " metric=\"quantile_loss\",\n",
+ " data_format=\"wide\",\n",
+ ")\n",
+ "fig = comparison_viz.plot(\n",
+ " title=\"Autoimpute method comparison on diabetes dataset\",\n",
+ " show_mean=True,\n",
+ ")\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 4: Evaluate distribution preservation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Distribution comparison results:\n",
+ " Variable Metric Distance\n",
+ "0 s1 wasserstein_distance 0.024660\n",
+ "1 s4 wasserstein_distance 0.020422\n",
+ "2 risk_factor kl_divergence 6.033155\n"
+ ]
+ }
+ ],
+ "source": [
+ "distribution_comparison_df = compare_distributions(\n",
+ " donor_data=donor_data,\n",
+ " receiver_data=autoimpute_results.receiver_data,\n",
+ " imputed_variables=imputed_variables,\n",
+ ")\n",
+ "\n",
+ "print(\"Distribution comparison results:\")\n",
+ "print(distribution_comparison_df)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 5: Analyze predictor correlations and mutual information"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Correlation analysis completed:\n",
+ " - Pearson correlation matrix: (4, 4)\n",
+ " - Spearman correlation matrix: (4, 4)\n",
+ " - Mutual information matrix: (4, 4)\n",
+ " - Predictor-target MI: (4, 3)\n"
+ ]
+ }
+ ],
+ "source": [
+ "predictor_correlations = compute_predictor_correlations(\n",
+ " data=donor_data,\n",
+ " predictors=predictors,\n",
+ " imputed_variables=imputed_variables,\n",
+ " method=\"all\",\n",
+ ")\n",
+ "\n",
+ "print(\"Correlation analysis completed:\")\n",
+ "print(\n",
+ " f\" - Pearson correlation matrix: {predictor_correlations['pearson'].shape}\"\n",
+ ")\n",
+ "print(\n",
+ " f\" - Spearman correlation matrix: {predictor_correlations['spearman'].shape}\"\n",
+ ")\n",
+ "print(\n",
+ " f\" - Mutual information matrix: {predictor_correlations['mutual_info'].shape}\"\n",
+ ")\n",
+ "print(\n",
+ " f\" - Predictor-target MI: {predictor_correlations['predictor_target_mi'].shape}\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 6: Assess predictor importance via leave-one-out analysis"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "90a63259512d414aa2a4f65e47c446ef",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Leave-one-out analysis: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Predictor importance results:\n",
+ " predictor_removed relative_impact\n",
+ "1 sex 30.680328\n",
+ "2 bmi 0.032479\n",
+ "0 age 0.004546\n",
+ "3 bp -0.004827\n"
+ ]
+ }
+ ],
+ "source": [
+ "predictor_importance_df = leave_one_out_analysis(\n",
+ " data=donor_data,\n",
+ " predictors=predictors,\n",
+ " imputed_variables=imputed_variables,\n",
+ " model_class=OLS,\n",
+ " quantiles=[0.1, 0.5, 0.9],\n",
+ " train_size=0.7,\n",
+ " n_jobs=1,\n",
+ " random_state=42,\n",
+ ")\n",
+ "\n",
+ "print(\"Predictor importance results:\")\n",
+ "print(predictor_importance_df[[\"predictor_removed\", \"relative_impact\"]])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 7: Analyze impact of variable ordering via progressive inclusion"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c2d73cd76e034edfa3a640b4e719fd6b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Progressive inclusion: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Optimal predictor order: ['sex', 'bmi', 'age', 'bp']\n",
+ "Optimal subset: ['sex', 'bmi', 'age']\n",
+ "Optimal loss: 2.409709\n"
+ ]
+ }
+ ],
+ "source": [
+ "progressive_results = progressive_predictor_inclusion(\n",
+ " data=donor_data,\n",
+ " predictors=predictors,\n",
+ " imputed_variables=imputed_variables,\n",
+ " model_class=OLS,\n",
+ " quantiles=[0.1, 0.5, 0.9],\n",
+ " train_size=0.7,\n",
+ " random_state=42,\n",
+ ")\n",
+ "\n",
+ "progressive_inclusion_df = progressive_results[\"results_df\"]\n",
+ "optimal_subset = progressive_results[\"optimal_subset\"]\n",
+ "optimal_loss = progressive_results[\"optimal_loss\"]\n",
+ "\n",
+ "print(\n",
+ " f\"Optimal predictor order: {progressive_inclusion_df['predictor_added'].tolist()}\"\n",
+ ")\n",
+ "print(f\"Optimal subset: {optimal_subset}\")\n",
+ "print(f\"Optimal loss: {optimal_loss:.6f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 8: Format all results into unified CSV for dashboard visualization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Formatted DataFrame shape: (293, 8)\n",
+ "Result types included: ['benchmark_loss' 'distribution_distance' 'predictor_correlation'\n",
+ " 'predictor_target_mi' 'predictor_importance' 'progressive_inclusion']\n"
+ ]
+ }
+ ],
+ "source": [
+ "output_path = \"microimputation_results.csv\"\n",
+ "\n",
+ "autoimpute_dict = {\"cv_results\": autoimpute_results.cv_results}\n",
+ "\n",
+ "formatted_df = format_csv(\n",
+ " output_path=output_path,\n",
+ " autoimpute_result=autoimpute_dict,\n",
+ " comparison_metrics_df=None,\n",
+ " distribution_comparison_df=distribution_comparison_df,\n",
+ " predictor_correlations=predictor_correlations,\n",
+ " predictor_importance_df=predictor_importance_df,\n",
+ " progressive_inclusion_df=progressive_inclusion_df,\n",
+ " best_method_name=best_method_name,\n",
+ ")\n",
+ "\n",
+ "print(f\"Formatted DataFrame shape: {formatted_df.shape}\")\n",
+ "print(f\"Result types included: {formatted_df['type'].unique()}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 9: Remember to explore your results in the Microimputation Dashboard!"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pe3.13",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/examples/pipeline.py b/examples/pipeline.py
index da848d8..5907a86 100644
--- a/examples/pipeline.py
+++ b/examples/pipeline.py
@@ -242,6 +242,9 @@ def categorize_risk(s4_value):
predictor_importance_df=predictor_importance_df,
progressive_inclusion_df=progressive_inclusion_df,
best_method_name=best_method_name,
+ donor_data=donor_data,
+ receiver_data=autoimpute_results.receiver_data,
+ imputed_variables=imputed_variables,
)
print(f"Formatted DataFrame shape: {formatted_df.shape}")
diff --git a/microimputation-dashboard/components/DistributionOverlay.tsx b/microimputation-dashboard/components/DistributionOverlay.tsx
new file mode 100644
index 0000000..6e621b9
--- /dev/null
+++ b/microimputation-dashboard/components/DistributionOverlay.tsx
@@ -0,0 +1,313 @@
+'use client';
+
+import { useMemo, useState } from 'react';
+import { ImputationDataPoint } from '@/types/imputation';
+import {
+ BarChart,
+ Bar,
+ XAxis,
+ YAxis,
+ CartesianGrid,
+ Tooltip,
+ Legend,
+ ResponsiveContainer,
+} from 'recharts';
+
+interface DistributionOverlayProps {
+ data: ImputationDataPoint[];
+}
+
+interface BinData {
+ binIndex: number;
+ binStart: number;
+ binEnd: number;
+ donorHeight: number;
+ receiverHeight: number;
+ binLabel?: string;
+}
+
+interface CategoryData {
+ category: string;
+ donorProportion: number;
+ receiverProportion: number;
+}
+
+interface VariableDistribution {
+ variable: string;
+ type: 'numerical' | 'categorical';
+ data: BinData[] | CategoryData[];
+ nSamplesDonor: number;
+ nSamplesReceiver: number;
+}
+
+export default function DistributionOverlay({
+ data,
+}: DistributionOverlayProps) {
+ // Extract distribution bins data
+ const distributionBins = useMemo(() => {
+ return data.filter((d) => d.type === 'distribution_bins');
+ }, [data]);
+
+ // Parse and group distribution data by variable
+ const variableDistributions = useMemo(() => {
+ const distributions: Record<string, VariableDistribution> = {};
+
+ distributionBins.forEach((d) => {
+ const variable = d.variable;
+
+ if (!distributions[variable]) {
+ distributions[variable] = {
+ variable,
+ type:
+ d.metric_name === 'histogram_distribution'
+ ? 'numerical'
+ : 'categorical',
+ data: [],
+ nSamplesDonor: 0,
+ nSamplesReceiver: 0,
+ };
+ }
+
+ try {
+ const info = JSON.parse(d.additional_info);
+
+ if (d.metric_name === 'histogram_distribution') {
+ // Numerical variable
+ (distributions[variable].data as BinData[]).push({
+ binIndex: info.bin_index,
+ binStart: info.bin_start,
+ binEnd: info.bin_end,
+ donorHeight: info.donor_height,
+ receiverHeight: info.receiver_height,
+ binLabel: `${info.bin_start.toFixed(2)}-${info.bin_end.toFixed(2)}`,
+ });
+ distributions[variable].nSamplesDonor = info.n_samples_donor;
+ distributions[variable].nSamplesReceiver = info.n_samples_receiver;
+ } else if (d.metric_name === 'categorical_distribution') {
+ // Categorical variable
+ (distributions[variable].data as CategoryData[]).push({
+ category: info.category,
+ donorProportion: info.donor_proportion,
+ receiverProportion: info.receiver_proportion,
+ });
+ distributions[variable].nSamplesDonor = info.n_samples_donor;
+ distributions[variable].nSamplesReceiver = info.n_samples_receiver;
+ }
+ } catch (e) {
+ console.error('Error parsing distribution bin data:', e);
+ }
+ });
+
+ // Sort numerical bins by bin index
+ Object.values(distributions).forEach((dist) => {
+ if (dist.type === 'numerical') {
+ (dist.data as BinData[]).sort((a, b) => a.binIndex - b.binIndex);
+ }
+ });
+
+ return distributions;
+ }, [distributionBins]);
+
+ const variables = Object.keys(variableDistributions);
+ const [selectedVariable, setSelectedVariable] = useState(
+ variables[0] || ''
+ );
+
+ if (variables.length === 0) {
+ return null;
+ }
+
+ const selectedDistribution = variableDistributions[selectedVariable];
+
+ const renderNumericalDistribution = (dist: VariableDistribution) => {
+ const chartData = (dist.data as BinData[]).map((bin) => ({
+ name: bin.binLabel,
+ Donor: bin.donorHeight,
+ Receiver: bin.receiverHeight,
+ binStart: bin.binStart,
+ binEnd: bin.binEnd,
+ }));
+
+ return (
+
+
+
+
+
+
+ [`${value.toFixed(2)}%`, '']}
+ labelFormatter={(label) => `Bin: ${label}`}
+ contentStyle={{ color: '#000000' }}
+ labelStyle={{ color: '#000000' }}
+ />
+
+
+
+
+
+
+ Histogram with {(dist.data as BinData[]).length} bins. Each bin shows the percentage of values falling within that range.
+ Overlapping bars indicate similar distributions.
+
+
+ );
+ };
+
+ const renderCategoricalDistribution = (dist: VariableDistribution) => {
+ const chartData = (dist.data as CategoryData[]).map((cat) => ({
+ category: cat.category,
+ Donor: cat.donorProportion,
+ Receiver: cat.receiverProportion,
+ }));
+
+ return (
+
+
+
+
+
+
+ [`${value.toFixed(2)}%`, '']}
+ contentStyle={{ color: '#000000' }}
+ labelStyle={{ color: '#000000' }}
+ />
+
+
+
+
+
+
+ Side-by-side bars compare the proportion of each category in donor vs receiver data.
+
+
+ );
+ };
+
+ return (
+
+
+
+ Distribution comparison
+
+
+ Visual comparison of donor and receiver distributions for imputed
+ variables. Overlapping distributions indicate successful imputation.
+
+
+ {/* Variable selector */}
+ {variables.length > 1 && (
+
+
+ Select variable:
+
+ setSelectedVariable(e.target.value)}
+ className="px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:border-indigo-500 bg-white text-gray-900"
+ >
+ {variables.map((variable) => (
+
+ {variable} (
+ {variableDistributions[variable].type === 'numerical'
+ ? 'numerical'
+ : 'categorical'}
+ )
+
+ ))}
+
+
+ )}
+
+ {variables.length === 1 && (
+
+ Variable: {' '}
+ {selectedVariable} {' '}
+
+ ({selectedDistribution?.type})
+
+
+ )}
+
+
+ {/* Render appropriate chart */}
+ {selectedDistribution && (
+
+ {selectedDistribution.type === 'numerical'
+ ? renderNumericalDistribution(selectedDistribution)
+ : renderCategoricalDistribution(selectedDistribution)}
+
+ )}
+
+ );
+}
diff --git a/microimputation-dashboard/components/ImputationResults.tsx b/microimputation-dashboard/components/ImputationResults.tsx
index 598b507..57f5387 100644
--- a/microimputation-dashboard/components/ImputationResults.tsx
+++ b/microimputation-dashboard/components/ImputationResults.tsx
@@ -3,6 +3,7 @@
import { useMemo } from 'react';
import { ImputationDataPoint } from '@/types/imputation';
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, Legend, ResponsiveContainer, Cell } from 'recharts';
+import DistributionOverlay from './DistributionOverlay';
interface ImputationResultsProps {
data: ImputationDataPoint[];
@@ -86,6 +87,9 @@ export default function ImputationResults({ data }: ImputationResultsProps) {
+ {/* Distribution Overlay Visualization */}
+ <DistributionOverlay data={data} />
+
{/* Wasserstein Distance Section */}
{hasWasserstein && (
@@ -207,7 +211,7 @@ export default function ImputationResults({ data }: ImputationResultsProps) {
{/* Explanation */}
-
+
What is KL-divergence? Kullback-Leibler divergence measures how much
one probability distribution differs from another. It quantifies the "information lost"
diff --git a/microimputation-dashboard/components/VisualizationDashboard.tsx b/microimputation-dashboard/components/VisualizationDashboard.tsx
index 67a1d88..6904329 100644
--- a/microimputation-dashboard/components/VisualizationDashboard.tsx
+++ b/microimputation-dashboard/components/VisualizationDashboard.tsx
@@ -466,8 +466,8 @@ export default function VisualizationDashboard({
Assessment of the quality of the imputations produced by the best-performing (or the only selected) model
-
- {/* Imputed Variables Section */}
+
+ {/* Imputed Variables Section - 1/4 width */}
Imputed Variables
@@ -477,7 +477,7 @@ export default function VisualizationDashboard({
{dataAnalysis.imputedVars.length} variable{dataAnalysis.imputedVars.length !== 1 ? 's' : ''} imputed
-
+ 3 ? 'max-h-32 overflow-y-auto' : ''}`}>
{dataAnalysis.imputedVars.map((variable) => (
{variable}
@@ -492,7 +492,7 @@ export default function VisualizationDashboard({
)}
- {/* Best Model Section */}
+ {/* Best Model Section - 1/4 width */}
{dataAnalysis.allMethods.length === 1 ? 'Imputation Model' : 'Best Performing Model'}
@@ -526,6 +526,87 @@ export default function VisualizationDashboard({
)}
+
+ {/* Metrics Section - 1/2 width */}
+
+
+ Performance Metrics
+
+
+ {/* Average Test Losses */}
+ {(() => {
+ const benchmarkData = data.filter(d => d.type === 'benchmark_loss' && d.method === dataAnalysis.bestModel && d.split === 'test');
+
+ // Calculate avg quantile loss
+ const quantileLossData = benchmarkData.filter(
+ d => d.metric_name === 'quantile_loss' &&
+ typeof d.quantile === 'number' &&
+ d.metric_value !== null
+ );
+ const avgQuantileLoss = quantileLossData.length > 0
+ ? quantileLossData.reduce((sum, d) => sum + (d.metric_value ?? 0), 0) / quantileLossData.length
+ : null;
+
+ // Calculate avg log loss
+ const logLossData = benchmarkData.filter(
+ d => d.metric_name === 'log_loss' &&
+ d.metric_value !== null
+ );
+ const avgLogLoss = logLossData.length > 0
+ ? logLossData.reduce((sum, d) => sum + (d.metric_value ?? 0), 0) / logLossData.length
+ : null;
+
+ // Calculate avg Wasserstein distance
+ const wassersteinData = data.filter(
+ d => d.type === 'distribution_distance' &&
+ d.metric_name === 'wasserstein_distance' &&
+ d.metric_value !== null
+ );
+ const avgWasserstein = wassersteinData.length > 0
+ ? wassersteinData.reduce((sum, d) => sum + (d.metric_value ?? 0), 0) / wassersteinData.length
+ : null;
+
+ // Calculate avg KL divergence
+ const klData = data.filter(
+ d => d.type === 'distribution_distance' &&
+ d.metric_name === 'kl_divergence' &&
+ d.metric_value !== null
+ );
+ const avgKL = klData.length > 0
+ ? klData.reduce((sum, d) => sum + (d.metric_value ?? 0), 0) / klData.length
+ : null;
+
+ return (
+ <>
+ {avgQuantileLoss !== null && (
+
+
Avg. test quantile loss
+
{avgQuantileLoss.toFixed(4)}
+
+ )}
+ {avgLogLoss !== null && (
+
+
Avg. test log loss
+
{avgLogLoss.toFixed(4)}
+
+ )}
+ {avgWasserstein !== null && (
+
+
Avg. wasserstein distance
+
{avgWasserstein.toFixed(4)}
+
+ )}
+ {avgKL !== null && (
+
+
Avg. KL divergence
+
{avgKL.toFixed(4)}
+
+ )}
+ </>
+ );
+ })()}
+
+
diff --git a/microimputation-dashboard/public/microimputation_results.csv b/microimputation-dashboard/public/microimputation_results.csv
index 9ef58a5..8029b03 100644
--- a/microimputation-dashboard/public/microimputation_results.csv
+++ b/microimputation-dashboard/public/microimputation_results.csv
@@ -292,3 +292,66 @@ progressive_inclusion,OLSResults,N/A,N/A,cumulative_improvement,0.00121758583691
progressive_inclusion,OLSResults,N/A,N/A,marginal_improvement,0.00012356933219281885,test,"{""step"": 3, ""predictor_added"": ""age""}"
progressive_inclusion,OLSResults,N/A,N/A,cumulative_improvement,0.0011012642793990501,test,"{""step"": 4, ""predictor_added"": ""bp"", ""predictors"": [""sex"", ""bmi"", ""age"", ""bp""]}"
progressive_inclusion,OLSResults,N/A,N/A,marginal_improvement,-0.0001163215575132881,test,"{""step"": 4, ""predictor_added"": ""bp""}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 0, ""bin_start"": -0.12678066991651324, ""bin_end"": -0.11747005557601899, ""donor_height"": 0.32362459546925565, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 1, ""bin_start"": -0.11747005557601899, ""bin_end"": -0.10815944123552473, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 2, ""bin_start"": -0.10815944123552473, ""bin_end"": -0.0988488268950305, ""donor_height"": 0.6472491909385114, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 3, ""bin_start"": -0.0988488268950305, ""bin_end"": -0.08953821255453624, ""donor_height"": 1.2944983818770226, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 4, ""bin_start"": -0.08953821255453624, ""bin_end"": -0.08022759821404199, ""donor_height"": 0.9708737864077669, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 5, ""bin_start"": -0.08022759821404199, ""bin_end"": -0.07091698387354775, ""donor_height"": 2.2653721682847894, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 6, ""bin_start"": -0.07091698387354775, ""bin_end"": -0.0616063695330535, ""donor_height"": 3.2362459546925564, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 7, ""bin_start"": -0.0616063695330535, ""bin_end"": -0.052295755192559246, ""donor_height"": 3.2362459546925564, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 8, ""bin_start"": -0.052295755192559246, ""bin_end"": -0.04298514085206499, ""donor_height"": 5.17799352750809, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 9, ""bin_start"": -0.04298514085206499, ""bin_end"": -0.03367452651157074, ""donor_height"": 8.090614886731393, ""receiver_height"": 2.2556390977443606, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 10, ""bin_start"": -0.03367452651157074, ""bin_end"": -0.0243639121710765, ""donor_height"": 6.796116504854369, ""receiver_height"": 6.7669172932330826, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 11, ""bin_start"": -0.0243639121710765, ""bin_end"": -0.015053297830582249, ""donor_height"": 5.17799352750809, ""receiver_height"": 3.759398496240602, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 12, ""bin_start"": -0.015053297830582249, ""bin_end"": -0.005742683490087996, ""donor_height"": 9.385113268608416, ""receiver_height"": 18.045112781954884, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 13, ""bin_start"": -0.005742683490087996, ""bin_end"": 0.0035679308504062424, ""donor_height"": 10.679611650485437, ""receiver_height"": 24.060150375939852, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 14, ""bin_start"": 0.0035679308504062424, ""bin_end"": 0.012878545190900509, ""donor_height"": 5.501618122977347, ""receiver_height"": 18.796992481203006, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 15, ""bin_start"": 0.012878545190900509, ""bin_end"": 0.022189159531394748, ""donor_height"": 8.09061488673139, ""receiver_height"": 20.30075187969925, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 16, ""bin_start"": 0.022189159531394748, ""bin_end"": 0.031499773871888986, ""donor_height"": 6.4724919093851145, ""receiver_height"": 2.2556390977443606, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 17, ""bin_start"": 0.031499773871888986, ""bin_end"": 0.04081038821238325, ""donor_height"": 5.501618122977347, ""receiver_height"": 3.007518796992481, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 18, ""bin_start"": 0.04081038821238325, ""bin_end"": 0.05012100255287749, ""donor_height"": 3.8834951456310676, ""receiver_height"": 0.7518796992481204, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 19, ""bin_start"": 0.05012100255287749, ""bin_end"": 0.05943161689337176, ""donor_height"": 2.588996763754045, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 20, ""bin_start"": 0.05943161689337176, ""bin_end"": 0.068742231233866, ""donor_height"": 2.2653721682847894, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 21, ""bin_start"": 0.068742231233866, ""bin_end"": 0.07805284557436024, ""donor_height"": 0.9708737864077669, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 22, ""bin_start"": 0.07805284557436024, ""bin_end"": 0.0873634599148545, ""donor_height"": 1.9417475728155338, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 23, ""bin_start"": 0.0873634599148545, ""bin_end"": 0.09667407425534874, ""donor_height"": 2.2653721682847894, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 24, ""bin_start"": 0.09667407425534874, ""bin_end"": 0.10598468859584298, ""donor_height"": 0.6472491909385114, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 25, ""bin_start"": 0.10598468859584298, ""bin_end"": 0.11529530293633725, ""donor_height"": 0.6472491909385113, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 26, ""bin_start"": 0.11529530293633725, ""bin_end"": 0.12460591727683148, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 27, ""bin_start"": 0.12460591727683148, ""bin_end"": 0.13391653161732572, ""donor_height"": 1.2944983818770228, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 28, ""bin_start"": 0.13391653161732572, ""bin_end"": 0.14322714595781996, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 29, ""bin_start"": 0.14322714595781996, ""bin_end"": 0.15253776029831428, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 0, ""bin_start"": -0.0763945037500033, ""bin_end"": -0.06767353884966323, ""donor_height"": 6.796116504854369, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 1, ""bin_start"": -0.06767353884966323, ""bin_end"": -0.05895257394932317, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 2, ""bin_start"": -0.05895257394932317, ""bin_end"": -0.0502316090489831, ""donor_height"": 0.9708737864077671, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 3, ""bin_start"": -0.0502316090489831, ""bin_end"": -0.041510644148643035, ""donor_height"": 0.9708737864077671, ""receiver_height"": 2.255639097744361, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 4, ""bin_start"": -0.041510644148643035, ""bin_end"": -0.03278967924830297, ""donor_height"": 26.537216828478964, ""receiver_height"": 5.263157894736842, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 5, ""bin_start"": -0.03278967924830297, ""bin_end"": -0.0240687143479629, ""donor_height"": 0.6472491909385113, ""receiver_height"": 6.015037593984962, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 6, ""bin_start"": -0.0240687143479629, ""bin_end"": -0.015347749447622835, ""donor_height"": 0.6472491909385113, ""receiver_height"": 11.278195488721805, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 7, ""bin_start"": -0.015347749447622835, ""bin_end"": -0.006626784547282771, ""donor_height"": 1.6181229773462782, ""receiver_height"": 10.526315789473683, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 8, ""bin_start"": -0.006626784547282771, ""bin_end"": 0.0020941803530572928, ""donor_height"": 24.59546925566343, ""receiver_height"": 15.037593984962406, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 9, ""bin_start"": 0.0020941803530572928, ""bin_end"": 0.010815145253397357, ""donor_height"": 0.9708737864077671, ""receiver_height"": 9.774436090225564, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 10, ""bin_start"": 0.010815145253397357, ""bin_end"": 0.01953611015373742, ""donor_height"": 0.9708737864077671, ""receiver_height"": 19.548872180451127, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 11, ""bin_start"": 0.01953611015373742, ""bin_end"": 0.0282570750540775, ""donor_height"": 2.5889967637540456, ""receiver_height"": 6.015037593984963, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 12, ""bin_start"": 0.0282570750540775, ""bin_end"": 0.03697803995441756, ""donor_height"": 17.475728155339805, ""receiver_height"": 3.7593984962406015, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 13, ""bin_start"": 0.03697803995441756, ""bin_end"": 0.04569900485475763, ""donor_height"": 0.32362459546925565, ""receiver_height"": 5.263157894736842, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 14, ""bin_start"": 0.04569900485475763, ""bin_end"": 0.05441996975509768, ""donor_height"": 0.6472491909385113, ""receiver_height"": 1.5037593984962407, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 15, ""bin_start"": 0.05441996975509768, ""bin_end"": 0.06314093465543776, ""donor_height"": 0.3236245954692557, ""receiver_height"": 1.5037593984962407, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 16, ""bin_start"": 0.06314093465543776, ""bin_end"": 0.07186189955577783, ""donor_height"": 7.443365695792881, ""receiver_height"": 2.2556390977443606, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 17, ""bin_start"": 0.07186189955577783, ""bin_end"": 0.08058286445611788, ""donor_height"": 0.32362459546925565, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 18, ""bin_start"": 0.08058286445611788, ""bin_end"": 0.08930382935645796, ""donor_height"": 0.9708737864077669, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 19, ""bin_start"": 0.08930382935645796, ""bin_end"": 0.09802479425679801, ""donor_height"": 0.32362459546925565, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 20, ""bin_start"": 0.09802479425679801, ""bin_end"": 0.10674575915713809, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 21, ""bin_start"": 0.10674575915713809, ""bin_end"": 0.11546672405747814, ""donor_height"": 3.2362459546925564, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 22, ""bin_start"": 0.11546672405747814, ""bin_end"": 0.12418768895781822, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 23, ""bin_start"": 0.12418768895781822, ""bin_end"": 0.1329086538581583, ""donor_height"": 0.32362459546925565, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 24, ""bin_start"": 0.1329086538581583, ""bin_end"": 0.14162961875849833, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 25, ""bin_start"": 0.14162961875849833, ""bin_end"": 0.1503505836588384, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 26, ""bin_start"": 0.1503505836588384, ""bin_end"": 0.1590715485591785, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 27, ""bin_start"": 0.1590715485591785, ""bin_end"": 0.16779251345951857, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 28, ""bin_start"": 0.16779251345951857, ""bin_end"": 0.17651347835985864, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 29, ""bin_start"": 0.17651347835985864, ""bin_end"": 0.18523444326019867, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,risk_factor,N/A,categorical_distribution,,full,"{""category"": ""high"", ""donor_proportion"": 34.95145631067961, ""receiver_proportion"": 45.86466165413533, ""n_samples_donor"": 309, ""n_samples_receiver"": 133}"
+distribution_bins,OLSResults,risk_factor,N/A,categorical_distribution,,full,"{""category"": ""low"", ""donor_proportion"": 36.24595469255664, ""receiver_proportion"": 54.13533834586466, ""n_samples_donor"": 309, ""n_samples_receiver"": 133}"
+distribution_bins,OLSResults,risk_factor,N/A,categorical_distribution,,full,"{""category"": ""medium"", ""donor_proportion"": 28.802588996763756, ""receiver_proportion"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133}"
diff --git a/microimpute/utils/dashboard_formatter.py b/microimpute/utils/dashboard_formatter.py
index a5602fa..613f461 100644
--- a/microimpute/utils/dashboard_formatter.py
+++ b/microimpute/utils/dashboard_formatter.py
@@ -5,9 +5,281 @@
import json
from typing import Any, Dict, List, Optional, Union
+import numpy as np
import pandas as pd
+def _compute_histogram_data(
+ donor_values: np.ndarray,
+ receiver_values: np.ndarray,
+ variable_name: str,
+ n_bins: int = 30,
+) -> Dict[str, Union[List[float], str, int]]:
+ """
+ Compute histogram bin data for numerical variables.
+
+ Parameters
+ ----------
+ donor_values : np.ndarray
+ Original donor dataset values
+ receiver_values : np.ndarray
+ Imputed receiver dataset values
+ variable_name : str
+ Name of the variable being analyzed
+ n_bins : int
+ Number of histogram bins (default: 30)
+
+ Returns
+ -------
+ Dict containing bin edges and heights for both distributions
+ """
+ # Remove NaN values
+ donor_clean = donor_values[~np.isnan(donor_values)]
+ receiver_clean = receiver_values[~np.isnan(receiver_values)]
+
+    # Determine shared bin edges over the combined data range (n_bins equal-width bins)
+ combined = np.concatenate([donor_clean, receiver_clean])
+ _, bin_edges = np.histogram(combined, bins=n_bins)
+
+ # Compute histogram heights (normalized as densities)
+ donor_heights, _ = np.histogram(donor_clean, bins=bin_edges, density=True)
+ receiver_heights, _ = np.histogram(
+ receiver_clean, bins=bin_edges, density=True
+ )
+
+    # Convert densities to percentages: density * bin_width is each bin's
+    # probability mass, and scaling by 100 makes the heights sum to ~100
+ bin_widths = np.diff(bin_edges)
+ donor_heights = (donor_heights * bin_widths * 100).tolist()
+ receiver_heights = (receiver_heights * bin_widths * 100).tolist()
+
+ return {
+ "variable": variable_name,
+ "bin_edges": bin_edges.tolist(),
+ "donor_heights": donor_heights,
+ "receiver_heights": receiver_heights,
+ "n_samples_donor": len(donor_clean),
+ "n_samples_receiver": len(receiver_clean),
+ "n_bins": n_bins,
+ }
+
+
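For reference, a minimal sketch of what the helper above returns on toy data (illustrative only, not part of the patch; it assumes the private `_compute_histogram_data` is imported from `microimpute.utils.dashboard_formatter`):

```python
import numpy as np

from microimpute.utils.dashboard_formatter import _compute_histogram_data

# Two toy samples drawn from shifted normal distributions.
donor = np.random.default_rng(0).normal(0.0, 1.0, 500)
receiver = np.random.default_rng(1).normal(0.5, 1.0, 200)

hist = _compute_histogram_data(donor, receiver, "toy_var", n_bins=10)

# Shared bin edges come from the combined sample; heights are the percentage
# of each dataset falling in the bin, so each list sums to roughly 100.
assert len(hist["bin_edges"]) == 11
assert abs(sum(hist["donor_heights"]) - 100) < 1e-6
assert abs(sum(hist["receiver_heights"]) - 100) < 1e-6
```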
+def _compute_categorical_distribution(
+ donor_values: pd.Series,
+ receiver_values: pd.Series,
+ variable_name: str,
+ max_categories: int = 20,
+) -> Dict[str, Union[List, str, bool]]:
+ """
+ Compute distribution data for categorical variables.
+
+ Parameters
+ ----------
+ donor_values : pd.Series
+ Original donor dataset values
+ receiver_values : pd.Series
+ Imputed receiver dataset values
+ variable_name : str
+ Name of the variable
+ max_categories : int
+ Maximum number of categories to include (others grouped as "Other")
+
+ Returns
+ -------
+ Dict containing category labels and proportions
+ """
+ # Get value counts
+ donor_counts = donor_values.value_counts()
+ receiver_counts = receiver_values.value_counts()
+
+ # Get all unique categories
+ all_categories = list(set(donor_counts.index) | set(receiver_counts.index))
+
+ # If too many categories, keep top ones and group rest as "Other"
+ if len(all_categories) > max_categories:
+ # Get top categories by combined frequency
+ combined_counts = donor_counts.add(receiver_counts, fill_value=0)
+ top_categories = combined_counts.nlargest(
+ max_categories - 1
+ ).index.tolist()
+
+ # Calculate "Other" category
+ donor_other = donor_counts[
+ ~donor_counts.index.isin(top_categories)
+ ].sum()
+ receiver_other = receiver_counts[
+ ~receiver_counts.index.isin(top_categories)
+ ].sum()
+
+ categories = top_categories + ["Other"]
+
+ # Get proportions
+ donor_props = [donor_counts.get(cat, 0) for cat in top_categories]
+ donor_props.append(donor_other)
+ donor_props = (
+ pd.Series(donor_props) / donor_values.count() * 100
+ ).tolist()
+
+ receiver_props = [
+ receiver_counts.get(cat, 0) for cat in top_categories
+ ]
+ receiver_props.append(receiver_other)
+ receiver_props = (
+ pd.Series(receiver_props) / receiver_values.count() * 100
+ ).tolist()
+ else:
+ categories = sorted(all_categories)
+ donor_props = [
+ (donor_counts.get(cat, 0) / donor_values.count() * 100)
+ for cat in categories
+ ]
+ receiver_props = [
+ (receiver_counts.get(cat, 0) / receiver_values.count() * 100)
+ for cat in categories
+ ]
+
+ return {
+ "variable": variable_name,
+ "categories": categories,
+ "donor_proportions": donor_props,
+ "receiver_proportions": receiver_props,
+ "n_samples_donor": int(donor_values.count()),
+ "n_samples_receiver": int(receiver_values.count()),
+ "is_categorical": True,
+ }
+
+
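A similar sketch for the categorical helper, assuming the same private import (toy values, illustrative only):

```python
import numpy as np
import pandas as pd

from microimpute.utils.dashboard_formatter import _compute_categorical_distribution

donor = pd.Series(["A", "A", "B", "C"])
receiver = pd.Series(["A", "B", "B", "B", "C"])

dist = _compute_categorical_distribution(donor, receiver, "toy_cat")

# Categories are the sorted union; proportions are per-dataset percentages.
assert dist["categories"] == ["A", "B", "C"]
assert np.allclose(dist["donor_proportions"], [50.0, 25.0, 25.0])
assert np.allclose(dist["receiver_proportions"], [20.0, 60.0, 20.0])
assert dist["is_categorical"] is True
```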
+def _format_histogram_rows(
+ histogram_data: Dict[str, Union[List, str, int, bool]], method: str
+) -> List[Dict]:
+ """
+ Convert histogram data to CSV row format.
+
+ Parameters
+ ----------
+ histogram_data : Dict
+ Output from _compute_histogram_data or _compute_categorical_distribution
+ method : str
+ Imputation method name
+
+ Returns
+ -------
+ List of dictionaries ready for CSV formatting
+ """
+ rows = []
+
+ if histogram_data.get("is_categorical", False):
+ # Categorical variable - store as distribution_bins type
+ for i, category in enumerate(histogram_data["categories"]):
+ rows.append(
+ {
+ "type": "distribution_bins",
+ "method": method,
+ "variable": histogram_data["variable"],
+ "quantile": "N/A",
+ "metric_name": "categorical_distribution",
+ "metric_value": None, # Not used for histograms
+ "split": "full",
+ "additional_info": json.dumps(
+ {
+ "category": str(category),
+ "donor_proportion": float(
+ histogram_data["donor_proportions"][i]
+ ),
+ "receiver_proportion": float(
+ histogram_data["receiver_proportions"][i]
+ ),
+ "n_samples_donor": int(
+ histogram_data["n_samples_donor"]
+ ),
+ "n_samples_receiver": int(
+ histogram_data["n_samples_receiver"]
+ ),
+ }
+ ),
+ }
+ )
+ else:
+ # Numerical variable - store bin data
+ n_bins = len(histogram_data["donor_heights"])
+ for i in range(n_bins):
+ rows.append(
+ {
+ "type": "distribution_bins",
+ "method": method,
+ "variable": histogram_data["variable"],
+ "quantile": "N/A",
+ "metric_name": "histogram_distribution",
+ "metric_value": None, # Not used for histograms
+ "split": "full",
+ "additional_info": json.dumps(
+ {
+ "bin_index": int(i),
+ "bin_start": float(histogram_data["bin_edges"][i]),
+ "bin_end": float(
+ histogram_data["bin_edges"][i + 1]
+ ),
+ "donor_height": float(
+ histogram_data["donor_heights"][i]
+ ),
+ "receiver_height": float(
+ histogram_data["receiver_heights"][i]
+ ),
+ "n_samples_donor": int(
+ histogram_data["n_samples_donor"]
+ ),
+ "n_samples_receiver": int(
+ histogram_data["n_samples_receiver"]
+ ),
+ "total_bins": int(n_bins),
+ }
+ ),
+ }
+ )
+
+ return rows
+
+
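A sketch of how one histogram dict expands into long-format rows; the `additional_info` JSON has the same shape as the `distribution_bins` rows in the sample CSV above (toy values, not taken from the patch):

```python
import json

from microimpute.utils.dashboard_formatter import _format_histogram_rows

# A tiny two-bin histogram dict in the shape produced by _compute_histogram_data.
hist = {
    "variable": "toy_var",
    "bin_edges": [0.0, 0.5, 1.0],
    "donor_heights": [40.0, 60.0],
    "receiver_heights": [55.0, 45.0],
    "n_samples_donor": 100,
    "n_samples_receiver": 80,
    "n_bins": 2,
}

rows = _format_histogram_rows(hist, method="OLSResults")

assert len(rows) == 2
assert all(r["type"] == "distribution_bins" for r in rows)
assert all(r["metric_name"] == "histogram_distribution" for r in rows)

info = json.loads(rows[0]["additional_info"])
assert info["bin_index"] == 0 and info["total_bins"] == 2
assert (info["bin_start"], info["bin_end"]) == (0.0, 0.5)
assert (info["donor_height"], info["receiver_height"]) == (40.0, 55.0)
```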
+def _validate_imputed_variables(
+ donor_data: pd.DataFrame,
+ receiver_data: pd.DataFrame,
+ imputed_variables: List[str],
+) -> None:
+ """
+ Validate that all imputed variables exist in both datasets.
+
+ Parameters
+ ----------
+ donor_data : pd.DataFrame
+ Original donor dataset
+ receiver_data : pd.DataFrame
+ Imputed receiver dataset
+ imputed_variables : List[str]
+ List of variable names that were imputed
+
+ Raises
+ ------
+ ValueError
+ If any imputed variable is missing from either dataset
+ """
+ missing_in_donor = [
+ var for var in imputed_variables if var not in donor_data.columns
+ ]
+ missing_in_receiver = [
+ var for var in imputed_variables if var not in receiver_data.columns
+ ]
+
+ if missing_in_donor:
+ raise ValueError(
+ f"The following imputed variables are missing from donor_data: {missing_in_donor}"
+ )
+
+ if missing_in_receiver:
+ raise ValueError(
+ f"The following imputed variables are missing from receiver_data: {missing_in_receiver}"
+ )
+
+
def format_csv(
output_path: Optional[str] = None,
autoimpute_result: Optional[Dict] = None,
@@ -17,6 +289,10 @@ def format_csv(
predictor_importance_df: Optional[pd.DataFrame] = None,
progressive_inclusion_df: Optional[pd.DataFrame] = None,
best_method_name: Optional[str] = None,
+ donor_data: Optional[pd.DataFrame] = None,
+ receiver_data: Optional[pd.DataFrame] = None,
+ imputed_variables: Optional[List[str]] = None,
+ n_histogram_bins: int = 30,
) -> pd.DataFrame:
"""
Format various imputation outputs into a unified long-format CSV for dashboard visualization.
@@ -54,11 +330,30 @@ def format_csv(
best_method_name : str, optional
Name of the best method to append "_best_method" suffix to.
+ donor_data : pd.DataFrame, optional
+ Original donor dataset for histogram generation. Required if imputed_variables is provided.
+
+ receiver_data : pd.DataFrame, optional
+ Imputed receiver dataset for histogram generation. Required if imputed_variables is provided.
+
+ imputed_variables : List[str], optional
+ List of variable names that were imputed. When provided with donor_data and receiver_data,
+ histogram bin data will be included in the CSV for distribution visualization.
+
+ n_histogram_bins : int, default 30
+ Number of bins to use for numerical variable histograms.
+
Returns
-------
pd.DataFrame
Unified long-format DataFrame with columns:
['type', 'method', 'variable', 'quantile', 'metric_name', 'metric_value', 'split', 'additional_info']
+
+ Raises
+ ------
+ ValueError
+ If imputed_variables is provided but donor_data or receiver_data is missing.
+ If any imputed variable is not present in both donor_data and receiver_data.
"""
rows = []
@@ -339,6 +634,46 @@ def format_csv(
}
)
+ # 7. Process histogram distribution data for imputed variables
+ if imputed_variables is not None:
+ # Validate inputs
+ if donor_data is None or receiver_data is None:
+ raise ValueError(
+ "donor_data and receiver_data are required when imputed_variables is provided"
+ )
+
+ # Validate that all imputed variables exist in both datasets
+ _validate_imputed_variables(
+ donor_data, receiver_data, imputed_variables
+ )
+
+ # Generate histogram data for each imputed variable
+ for var in imputed_variables:
+            # Check whether the variable is categorical or numerical
+            var_dtype = donor_data[var].dtype
+            if var_dtype == "object" or isinstance(
+                var_dtype,
+                pd.CategoricalDtype,
+            ):
+ # Categorical variable
+ hist_data = _compute_categorical_distribution(
+ donor_data[var], receiver_data[var], var
+ )
+ else:
+ # Numerical variable
+ hist_data = _compute_histogram_data(
+ donor_data[var].values,
+ receiver_data[var].values,
+ var,
+ n_bins=n_histogram_bins,
+ )
+
+ # Format histogram rows and add to main rows list
+ histogram_rows = _format_histogram_rows(
+ hist_data, best_method_name if best_method_name else "N/A"
+ )
+ rows.extend(histogram_rows)
+
# Create DataFrame from rows
if not rows:
# Return empty DataFrame with correct columns if no data
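A hedged usage sketch of the extended `format_csv` call, mirroring the new tests below. The donor/receiver frames here are synthetic stand-ins for the pipeline data; only the keyword arguments added in this patch are new:

```python
import numpy as np
import pandas as pd

from microimpute.utils.dashboard_formatter import format_csv

rng = np.random.default_rng(0)
donor_df = pd.DataFrame(
    {
        "s4": rng.normal(0.0, 0.05, 300),
        "risk_factor": rng.choice(["low", "medium", "high"], 300),
    }
)
receiver_df = pd.DataFrame(
    {
        "s4": rng.normal(0.01, 0.05, 130),
        "risk_factor": rng.choice(["low", "high"], 130),
    }
)

dashboard_df = format_csv(
    output_path="microimputation_results.csv",
    donor_data=donor_df,
    receiver_data=receiver_df,
    imputed_variables=["s4", "risk_factor"],
    best_method_name="OLSResults",
    n_histogram_bins=30,
)

# Histogram rows for the dashboard overlay are tagged with the new type.
bins = dashboard_df[dashboard_df["type"] == "distribution_bins"]
assert set(bins["variable"]) == {"s4", "risk_factor"}
```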
diff --git a/tests/test_dashboard_formatter.py b/tests/test_dashboard_formatter.py
index 41bf210..1992069 100644
--- a/tests/test_dashboard_formatter.py
+++ b/tests/test_dashboard_formatter.py
@@ -19,6 +19,7 @@
VALID_TYPES = {
"benchmark_loss",
"distribution_distance",
+ "distribution_bins",
"predictor_correlation",
"predictor_target_mi",
"predictor_importance",
@@ -625,6 +626,314 @@ def test_saved_csv_preserves_data(self, sample_autoimpute_result):
assert list(df["method"]) == list(loaded_df["method"])
+class TestDistributionBins:
+ """Test distribution_bins type formatting for histogram data."""
+
+ @pytest.fixture
+ def sample_donor_receiver_data(self):
+ """Create sample donor and receiver datasets with imputed variables."""
+ np.random.seed(42)
+
+ # Create donor data
+ donor_data = pd.DataFrame(
+ {
+ "numerical_var1": np.random.normal(100, 15, 200),
+ "numerical_var2": np.random.exponential(2, 200),
+ "categorical_var": np.random.choice(
+ ["A", "B", "C"], 200, p=[0.5, 0.3, 0.2]
+ ),
+ "predictor1": np.random.randn(200),
+ "predictor2": np.random.randn(200),
+ }
+ )
+
+ # Create receiver data (slightly different distributions)
+ receiver_data = pd.DataFrame(
+ {
+ "numerical_var1": np.random.normal(
+ 102, 14, 150
+ ), # Shifted mean
+ "numerical_var2": np.random.exponential(
+ 2.1, 150
+ ), # Different rate
+ "categorical_var": np.random.choice(
+ ["A", "B", "C"], 150, p=[0.4, 0.4, 0.2]
+ ),
+ "predictor1": np.random.randn(150),
+ "predictor2": np.random.randn(150),
+ }
+ )
+
+ return donor_data, receiver_data
+
+ def test_distribution_bins_created(self, sample_donor_receiver_data):
+ """Test that distribution_bins rows are created when histogram data is provided."""
+ donor_data, receiver_data = sample_donor_receiver_data
+ imputed_variables = [
+ "numerical_var1",
+ "numerical_var2",
+ "categorical_var",
+ ]
+
+ with tempfile.NamedTemporaryFile(
+ mode="w", delete=False, suffix=".csv"
+ ) as f:
+ output_path = f.name
+
+ try:
+ result = format_csv(
+ output_path=output_path,
+ donor_data=donor_data,
+ receiver_data=receiver_data,
+ imputed_variables=imputed_variables,
+ best_method_name="TestMethod",
+ n_histogram_bins=20,
+ )
+
+ # Check that distribution_bins type exists
+ dist_bins_rows = result[result["type"] == "distribution_bins"]
+ assert len(dist_bins_rows) > 0
+
+ # Check that all imputed variables have bins
+ variables_with_bins = dist_bins_rows["variable"].unique()
+ assert set(variables_with_bins) == set(imputed_variables)
+
+ # Check numerical variables have histogram_distribution metric
+ numerical_bins = dist_bins_rows[
+ dist_bins_rows["metric_name"] == "histogram_distribution"
+ ]
+ assert "numerical_var1" in numerical_bins["variable"].values
+ assert "numerical_var2" in numerical_bins["variable"].values
+
+ # Check categorical variable has categorical_distribution metric
+ categorical_bins = dist_bins_rows[
+ dist_bins_rows["metric_name"] == "categorical_distribution"
+ ]
+ assert "categorical_var" in categorical_bins["variable"].values
+
+ finally:
+ Path(output_path).unlink()
+
+ def test_numerical_histogram_heights_match_numpy(
+ self, sample_donor_receiver_data
+ ):
+ """Test that histogram heights match numpy's histogram output."""
+ donor_data, receiver_data = sample_donor_receiver_data
+ var_name = "numerical_var1"
+ n_bins = 15
+
+ with tempfile.NamedTemporaryFile(
+ mode="w", delete=False, suffix=".csv"
+ ) as f:
+ output_path = f.name
+
+ try:
+ result = format_csv(
+ output_path=output_path,
+ donor_data=donor_data,
+ receiver_data=receiver_data,
+ imputed_variables=[var_name],
+ best_method_name="TestMethod",
+ n_histogram_bins=n_bins,
+ )
+
+ # Get the distribution bins for our variable
+ dist_bins = result[
+ (result["type"] == "distribution_bins")
+ & (result["variable"] == var_name)
+ & (result["metric_name"] == "histogram_distribution")
+ ]
+
+ # Should have n_bins rows for this variable
+ assert len(dist_bins) == n_bins
+
+ # Extract bin data from additional_info
+ bin_data = []
+ for _, row in dist_bins.iterrows():
+ info = json.loads(row["additional_info"])
+ bin_data.append(info)
+
+ # Sort by bin index
+ bin_data = sorted(bin_data, key=lambda x: x["bin_index"])
+
+ # Manually compute histogram with numpy for comparison
+ donor_values = donor_data[var_name].values
+ receiver_values = receiver_data[var_name].values
+
+ # Remove NaN values
+ donor_clean = donor_values[~np.isnan(donor_values)]
+ receiver_clean = receiver_values[~np.isnan(receiver_values)]
+
+ # Compute bin edges from combined data (same as in the function)
+ combined = np.concatenate([donor_clean, receiver_clean])
+ _, bin_edges = np.histogram(combined, bins=n_bins)
+
+ # Compute histograms
+ donor_heights_np, _ = np.histogram(
+ donor_clean, bins=bin_edges, density=True
+ )
+ receiver_heights_np, _ = np.histogram(
+ receiver_clean, bins=bin_edges, density=True
+ )
+
+ # Convert to percentages (same as in the function)
+ bin_widths = np.diff(bin_edges)
+ donor_heights_expected = donor_heights_np * bin_widths * 100
+ receiver_heights_expected = receiver_heights_np * bin_widths * 100
+
+ # Compare heights
+ for i, data in enumerate(bin_data):
+ assert data["bin_index"] == i
+ # Check bin edges
+ assert np.isclose(data["bin_start"], bin_edges[i], rtol=1e-10)
+ assert np.isclose(
+ data["bin_end"], bin_edges[i + 1], rtol=1e-10
+ )
+ # Check heights match numpy's output
+ assert np.isclose(
+ data["donor_height"], donor_heights_expected[i], rtol=1e-10
+ )
+ assert np.isclose(
+ data["receiver_height"],
+ receiver_heights_expected[i],
+ rtol=1e-10,
+ )
+ # Check sample counts
+ assert data["n_samples_donor"] == len(donor_clean)
+ assert data["n_samples_receiver"] == len(receiver_clean)
+ assert data["total_bins"] == n_bins
+
+ finally:
+ Path(output_path).unlink()
+
+ def test_categorical_distribution_proportions(
+ self, sample_donor_receiver_data
+ ):
+ """Test that categorical distribution proportions are computed correctly."""
+ donor_data, receiver_data = sample_donor_receiver_data
+ var_name = "categorical_var"
+
+ with tempfile.NamedTemporaryFile(
+ mode="w", delete=False, suffix=".csv"
+ ) as f:
+ output_path = f.name
+
+ try:
+ result = format_csv(
+ output_path=output_path,
+ donor_data=donor_data,
+ receiver_data=receiver_data,
+ imputed_variables=[var_name],
+ best_method_name="TestMethod",
+ )
+
+ # Get the distribution bins for categorical variable
+ cat_bins = result[
+ (result["type"] == "distribution_bins")
+ & (result["variable"] == var_name)
+ & (result["metric_name"] == "categorical_distribution")
+ ]
+
+ # Should have one row per category
+ assert len(cat_bins) == 3 # A, B, C
+
+ # Extract category data
+ category_data = {}
+ for _, row in cat_bins.iterrows():
+ info = json.loads(row["additional_info"])
+ category = info["category"]
+ category_data[category] = info
+
+ # Manually compute proportions
+ donor_counts = donor_data[var_name].value_counts()
+ receiver_counts = receiver_data[var_name].value_counts()
+
+ donor_total = donor_data[var_name].count()
+ receiver_total = receiver_data[var_name].count()
+
+ # Check each category
+ for category in ["A", "B", "C"]:
+ assert category in category_data
+ data = category_data[category]
+
+ # Expected proportions
+ expected_donor_prop = (
+ donor_counts.get(category, 0) / donor_total
+ ) * 100
+ expected_receiver_prop = (
+ receiver_counts.get(category, 0) / receiver_total
+ ) * 100
+
+ # Check proportions match
+ assert np.isclose(
+ data["donor_proportion"], expected_donor_prop, rtol=1e-10
+ )
+ assert np.isclose(
+ data["receiver_proportion"],
+ expected_receiver_prop,
+ rtol=1e-10,
+ )
+
+ # Check sample counts
+ assert data["n_samples_donor"] == donor_total
+ assert data["n_samples_receiver"] == receiver_total
+
+ finally:
+ Path(output_path).unlink()
+
+ def test_error_when_missing_data_for_imputed_variables(self):
+ """Test that error is raised when donor/receiver data is missing for histogram generation."""
+ imputed_variables = ["var1", "var2"]
+
+ with pytest.raises(
+ ValueError, match="donor_data and receiver_data are required"
+ ):
+ format_csv(
+ imputed_variables=imputed_variables,
+ donor_data=None,
+ receiver_data=None,
+ )
+
+ # Test with missing donor data
+ receiver_data = pd.DataFrame({"var1": [1, 2, 3], "var2": [4, 5, 6]})
+ with pytest.raises(
+ ValueError, match="donor_data and receiver_data are required"
+ ):
+ format_csv(
+ imputed_variables=imputed_variables,
+ donor_data=None,
+ receiver_data=receiver_data,
+ )
+
+ def test_error_when_variable_missing_from_datasets(self):
+ """Test that error is raised when imputed variable is not in datasets."""
+ donor_data = pd.DataFrame({"var1": [1, 2, 3], "var2": [4, 5, 6]})
+ receiver_data = pd.DataFrame({"var1": [7, 8, 9], "var2": [10, 11, 12]})
+ imputed_variables = ["var1", "var3"] # var3 doesn't exist
+
+ with pytest.raises(
+ ValueError, match="missing from donor_data: \\['var3'\\]"
+ ):
+ format_csv(
+ donor_data=donor_data,
+ receiver_data=receiver_data,
+ imputed_variables=imputed_variables,
+ )
+
+ # Test with variable missing from receiver
+ receiver_data = pd.DataFrame({"var1": [7, 8, 9]}) # Missing var2
+ imputed_variables = ["var1", "var2"]
+
+ with pytest.raises(
+ ValueError, match="missing from receiver_data: \\['var2'\\]"
+ ):
+ format_csv(
+ donor_data=donor_data,
+ receiver_data=receiver_data,
+ imputed_variables=imputed_variables,
+ )
+
+
class TestEdgeCases:
"""Test edge cases and error handling."""
From c8d722e00f17ad06b6927114cb20d936e9e9ff8d Mon Sep 17 00:00:00 2001
From: juaristi22
Date: Thu, 6 Nov 2025 12:38:07 +0800
Subject: [PATCH 2/2] enable log transform in data preprocessing
---
changelog_entry.yaml | 1 +
microimpute/utils/data.py | 363 +++++++++++++++--
tests/test_data_preprocessing.py | 671 ++++++++++++++++++++++++++++++-
3 files changed, 997 insertions(+), 38 deletions(-)
diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index c68d72d..f9c460a 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -2,3 +2,4 @@
changes:
added:
- Distribution comparison histogram of donor and receiver datasets for imputed variables (to dashboard).
+ - Log transformation option for numerical variables in data preprocessing.
diff --git a/microimpute/utils/data.py b/microimpute/utils/data.py
index 938d70e..3aa5e4e 100644
--- a/microimpute/utils/data.py
+++ b/microimpute/utils/data.py
@@ -1,18 +1,20 @@
"""Data preparation and transformation utilities
This module provides comprehensive data preparation functions for imputation workflows,
-including data splitting, normalization, unnormalization, and categorical variable handling.
+including data splitting, normalization, log transformation, and categorical variable handling.
These utilities ensure consistent data preprocessing across different imputation methods.
Key functions:
- - preprocess_data: split and optionally normalize data for training/testing
+ - preprocess_data: split and optionally normalize or log-transform data for training/testing
- unnormalize_predictions: convert normalized predictions back to original scale
+ - unlog_transform_predictions: convert log-transformed predictions back to original scale
- Handle categorical variables through one-hot encoding
"""
import logging
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
+import numpy as np
import pandas as pd
from pydantic import validate_call
from sklearn.model_selection import train_test_split
@@ -30,6 +32,7 @@
@validate_call(config=VALIDATE_CONFIG)
def normalize_data(
data: pd.DataFrame,
+ columns_to_normalize: Optional[List[str]] = None,
) -> Tuple[pd.DataFrame, dict]:
"""Normalize numeric columns in a DataFrame.
@@ -38,6 +41,8 @@ def normalize_data(
Args:
data: DataFrame to normalize.
+ columns_to_normalize: Optional list of specific columns to normalize.
+ If None, all numeric columns will be normalized.
Returns:
Tuple of (normalized_data, normalization_params)
@@ -45,6 +50,7 @@ def normalize_data(
to {"mean": float, "std": float}.
Raises:
+ ValueError: If specified columns don't exist in data.
RuntimeError: If normalization fails.
"""
logger.debug("Normalizing data")
@@ -64,10 +70,39 @@ def normalize_data(
f"Excluding categorical columns from normalization: {categorical_cols}"
)
- # Get only numeric columns for normalization
- numeric_cols = [
- col for col in data.columns if col not in categorical_cols
- ]
+ # Determine which columns to normalize
+ if columns_to_normalize is not None:
+ # Validate that specified columns exist
+ missing_cols = set(columns_to_normalize) - set(data.columns)
+ if missing_cols:
+ error_msg = (
+ f"Columns specified for normalization not found in "
+ f"data: {missing_cols}"
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ # Only normalize specified columns that are not categorical
+ numeric_cols = [
+ col
+ for col in columns_to_normalize
+ if col not in categorical_cols
+ ]
+
+ # Warn if user specified categorical columns
+ specified_categorical = [
+ col for col in columns_to_normalize if col in categorical_cols
+ ]
+ if specified_categorical:
+ logger.warning(
+ f"Skipping normalization for categorical columns: "
+ f"{specified_categorical}"
+ )
+ else:
+ # Get all numeric columns for normalization
+ numeric_cols = [
+ col for col in data.columns if col not in categorical_cols
+ ]
if not numeric_cols:
logger.warning("No numeric columns found for normalization")
@@ -106,6 +141,120 @@ def normalize_data(
raise RuntimeError("Failed to normalize data") from e
+@validate_call(config=VALIDATE_CONFIG)
+def log_transform_data(
+ data: pd.DataFrame,
+ columns_to_transform: Optional[List[str]] = None,
+) -> Tuple[pd.DataFrame, dict]:
+ """Apply log transformation to numeric columns in a DataFrame.
+
+ Categorical and boolean columns are excluded from transformation
+ to prevent issues when they are later encoded as dummy variables.
+
+ Args:
+ data: DataFrame to log transform.
+ columns_to_transform: Optional list of specific columns to
+ log transform. If None, all numeric columns will be transformed.
+
+ Returns:
+        Tuple of (log_transformed_data, log_transform_params)
+        where log_transform_params maps each log-transformed column name
+        to an empty dict (no parameters are needed to reverse the transform).
+
+ Raises:
+ ValueError: If data contains non-positive values in numeric columns
+ or if specified columns don't exist in data.
+ RuntimeError: If log transformation fails.
+ """
+ logger.debug("Applying log transformation to data")
+ try:
+ from microimpute.utils.type_handling import VariableTypeDetector
+
+ # Identify categorical columns to exclude from log transformation
+ detector = VariableTypeDetector()
+ categorical_cols = []
+ for col in data.columns:
+ var_type, _ = detector.categorize_variable(data[col], col, logger)
+ if var_type in ["categorical", "numeric_categorical", "bool"]:
+ categorical_cols.append(col)
+
+ if categorical_cols:
+ logger.info(
+ f"Excluding categorical columns from log transformation: {categorical_cols}"
+ )
+
+ # Determine which columns to transform
+ if columns_to_transform is not None:
+ # Validate that specified columns exist
+ missing_cols = set(columns_to_transform) - set(data.columns)
+ if missing_cols:
+ error_msg = (
+ f"Columns specified for log transformation not found "
+ f"in data: {missing_cols}"
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ # Only transform specified columns that are not categorical
+ numeric_cols = [
+ col
+ for col in columns_to_transform
+ if col not in categorical_cols
+ ]
+
+ # Warn if user specified categorical columns
+ specified_categorical = [
+ col for col in columns_to_transform if col in categorical_cols
+ ]
+ if specified_categorical:
+ logger.warning(
+ f"Skipping log transformation for categorical "
+ f"columns: {specified_categorical}"
+ )
+ else:
+ # Get all numeric columns for log transformation
+ numeric_cols = [
+ col for col in data.columns if col not in categorical_cols
+ ]
+
+ if not numeric_cols:
+ logger.warning("No numeric columns found for log transformation")
+ return data.copy(), {}
+
+ # Check for non-positive values in numeric columns
+ data_copy = data.copy()
+ for col in numeric_cols:
+ min_val = data_copy[col].min()
+ if min_val <= 0:
+ error_msg = (
+ f"Column '{col}' contains non-positive values "
+ f"(min={min_val}). Log transformation requires all "
+ f"positive values."
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ # Apply log transformation only to numeric columns
+ log_transform_params = {}
+ for col in numeric_cols:
+ data_copy[col] = np.log(data_copy[col])
+ log_transform_params[col] = {}
+
+ logger.debug(
+ f"Log transformed {len(numeric_cols)} numeric columns successfully"
+ )
+ logger.debug(f"Log transformation parameters: {log_transform_params}")
+
+ return data_copy, log_transform_params
+
+ except ValueError:
+        # Re-raise ValueError as-is (non-positive values or invalid column selection)
+ raise
+ except (TypeError, AttributeError) as e:
+ logger.error(f"Error during log transformation: {str(e)}")
+ raise RuntimeError("Failed to apply log transformation") from e
+
+
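A minimal sketch of `log_transform_data`, shaped like the new tests below. How the 0/1 column is classified ultimately depends on `VariableTypeDetector`, so treat the second assertion as an assumption:

```python
import numpy as np
import pandas as pd

from microimpute.utils.data import log_transform_data

df = pd.DataFrame(
    {
        "income": [1.5, 2.7, 3.2, 4.8, 5.1, 6.3, 7.9, 8.4, 9.6, 10.2],
        "is_female": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
    }
)

log_df, log_params = log_transform_data(df)

# The positive numeric column is replaced by its natural log; the 0/1 column
# should be detected as boolean-like and left untouched.
np.testing.assert_allclose(log_df["income"].values, np.log(df["income"].values))
assert "income" in log_params and "is_female" not in log_params

# Columns with zeros or negatives are rejected up front:
#   log_transform_data(pd.DataFrame({"x": [0.0, 1.0]}))  ->  ValueError
```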
@validate_call(config=VALIDATE_CONFIG)
def preprocess_data(
data: pd.DataFrame,
@@ -113,7 +262,8 @@ def preprocess_data(
train_size: Optional[float] = TRAIN_SIZE,
test_size: Optional[float] = TEST_SIZE,
random_state: Optional[int] = RANDOM_STATE,
- normalize: Optional[bool] = False,
+ normalize: Optional[Union[bool, List[str]]] = False,
+ log_transform: Optional[Union[bool, List[str]]] = False,
) -> Union[
Tuple[pd.DataFrame, dict], # when full_data=True
Tuple[pd.DataFrame, pd.DataFrame, dict], # when full_data=False
@@ -126,17 +276,35 @@ def preprocess_data(
train_size: Proportion of the dataset to include in the train split.
test_size: Proportion of the dataset to include in the test split.
random_state: Random seed for reproducibility.
- normalize: Whether to normalize the data.
+ normalize: Whether to normalize the data. Can be:
+ - True: normalize all numeric columns
+ - List of column names: normalize only those columns
+ - False: no normalization (default)
+ log_transform: Whether to apply log transformation to the data. Can be:
+ - True: transform all numeric columns
+ - List of column names: transform only those columns
+ - False: no transformation (default)
Returns:
- Different tuple formats depending on the value of full_data:
- - If full_data=True: (data, dummy_info)
- - If full_data=False: (X_train, X_test, dummy_info)
-
- Where dummy_info is a dictionary mapping original columns to their resulting dummy columns
+ Different tuple formats depending on parameters:
+ - If full_data=True and transformations applied:
+ (data, transform_params)
+ - If full_data=True and no transformations:
+ data
+ - If full_data=False and transformations applied:
+ (X_train, X_test, transform_params)
+ - If full_data=False and no transformations:
+ (X_train, X_test)
+
+ Where transform_params is a dict with keys:
+ - "normalization": dict of normalization parameters (or empty dict)
+ - "log_transform": dict of log transform parameters (or empty dict)
Raises:
- ValueError: If data is empty or invalid
+ ValueError: If data is empty or invalid, or if both normalize and
+ log_transform would apply to the same columns, or if log_transform
+ is applied to data with non-positive values, or if specified
+ columns don't exist in data.
RuntimeError: If data preprocessing fails
"""
@@ -146,23 +314,100 @@ def preprocess_data(
if data.empty:
raise ValueError("Data must not be None or empty")
+
+ # Check if both normalize and log_transform are requested
+ normalize_requested = normalize is not False and normalize != []
+ log_transform_requested = (
+ log_transform is not False and log_transform != []
+ )
+
+ # Validate that normalize and log_transform don't conflict
+ if normalize_requested and log_transform_requested:
+ # If both are True, they would apply to all numeric columns - conflict
+ if normalize is True and log_transform is True:
+ error_msg = (
+ "Cannot apply both normalization and log transformation to "
+ "all columns. Please specify which columns to transform with "
+ "each approach using lists."
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ # If one is True and the other is a list, conflict
+ if normalize is True or log_transform is True:
+ error_msg = (
+ "Cannot apply both normalization and log transformation. "
+ "When using both, specify column lists for each to ensure "
+ "they apply to different variables."
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ # Both are lists - check for overlap
+ normalize_cols = (
+ set(normalize) if isinstance(normalize, list) else set()
+ )
+ log_transform_cols = (
+ set(log_transform) if isinstance(log_transform, list) else set()
+ )
+ overlap = normalize_cols & log_transform_cols
+
+ if overlap:
+ error_msg = (
+ f"Cannot apply both normalization and log transformation to "
+ f"the same columns: {overlap}. Each column can only have one "
+ f"transformation applied."
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
# Check for missing values
missing_count = data.isna().sum().sum()
if missing_count > 0:
logger.warning(f"Data contains {missing_count} missing values")
- if normalize:
- data, normalization_params = normalize_data(data)
+ # Apply normalization if requested
+ normalization_params = {}
+ if normalize_requested:
+ if isinstance(normalize, bool):
+ # normalize=True means normalize all numeric columns
+ data, normalization_params = normalize_data(data)
+ else:
+ # normalize is a list of specific columns
+ data, normalization_params = normalize_data(
+ data, columns_to_normalize=normalize
+ )
- if full_data and normalize:
- logger.info("Returning full preprocessed dataset")
- return (
- data,
- normalization_params,
- )
- elif full_data:
- logger.info("Returning full preprocessed dataset")
- return data
+ # Apply log transformation if requested
+ log_transform_params = {}
+ if log_transform_requested:
+ if isinstance(log_transform, bool):
+ # log_transform=True means transform all numeric columns
+ data, log_transform_params = log_transform_data(data)
+ else:
+ # log_transform is a list of specific columns
+ data, log_transform_params = log_transform_data(
+ data, columns_to_transform=log_transform
+ )
+
+ # Prepare transformation parameters to return
+ has_transformations = normalize_requested or log_transform_requested
+ if has_transformations:
+ # Merge both parameter dicts, with a key to distinguish them
+ transform_params = {
+ "normalization": normalization_params,
+ "log_transform": log_transform_params,
+ }
+
+ if full_data:
+ if has_transformations:
+ logger.info(
+ "Returning full preprocessed dataset with transformations"
+ )
+ return (data, transform_params)
+ else:
+ logger.info("Returning full preprocessed dataset")
+ return data
else:
logger.debug(
f"Splitting data with train_size={train_size}, test_size={test_size}"
@@ -177,17 +422,10 @@ def preprocess_data(
logger.info(
f"Data split into train ({X_train.shape}) and test ({X_test.shape}) sets"
)
- if normalize:
- return (
- X_train,
- X_test,
- normalization_params,
- )
+ if has_transformations:
+ return (X_train, X_test, transform_params)
else:
- return (
- X_train,
- X_test,
- )
+ return (X_train, X_test)
except (ValueError, TypeError) as e:
logger.error(f"Error in processing data: {str(e)}")
@@ -236,3 +474,56 @@ def unnormalize_predictions(
logger.debug(f"Unnormalized quantile {q} with shape {df_unnorm.shape}")
return unnormalized
+
+
+@validate_call(config=VALIDATE_CONFIG)
+def unlog_transform_predictions(
+ imputations: dict, log_transform_params: dict
+) -> dict:
+ """Reverse log transformation on predictions using stored parameters.
+
+ Args:
+ imputations: Dictionary mapping quantiles to DataFrames of predictions.
+ log_transform_params: Dictionary with column names that were
+ log-transformed.
+
+ Returns:
+ Dictionary with same structure as imputations but with
+ un-log-transformed values.
+
+ Raises:
+ ValueError: If columns in imputations don't match log transformation
+ parameters.
+ """
+ logger.debug(
+ f"Reversing log transformation for {len(imputations)} quantiles"
+ )
+
+ untransformed = {}
+ for q, df in imputations.items():
+ cols = df.columns
+
+ # Check that all columns have log transformation parameters
+ missing_params = [
+ col for col in cols if col not in log_transform_params
+ ]
+ if missing_params:
+ error_msg = (
+ f"Missing log transformation parameters for columns: "
+ f"{missing_params}"
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ # Reverse log transformation: x_original = exp(x_log)
+ df_untransformed = df.copy()
+ for col in cols:
+ df_untransformed[col] = np.exp(df[col])
+ untransformed[q] = df_untransformed
+
+ logger.debug(
+ f"Reversed log transformation for quantile {q} with shape "
+ f"{df_untransformed.shape}"
+ )
+
+ return untransformed
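A round-trip sketch combining `log_transform_data` and `unlog_transform_predictions`; the `0.5` key stands in for a median-quantile prediction (illustrative only):

```python
import numpy as np
import pandas as pd

from microimpute.utils.data import log_transform_data, unlog_transform_predictions

original = pd.DataFrame(
    {"income": [1.5, 2.7, 3.2, 4.8, 5.1, 6.3, 7.9, 8.4, 9.6, 10.2]}
)

log_data, log_params = log_transform_data(original)

# Pretend the median predictions came back on the log scale.
imputations = {0.5: log_data}
restored = unlog_transform_predictions(imputations, log_params)

np.testing.assert_allclose(
    restored[0.5]["income"].values, original["income"].values
)
```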
diff --git a/tests/test_data_preprocessing.py b/tests/test_data_preprocessing.py
index 6110e43..ee9ea83 100644
--- a/tests/test_data_preprocessing.py
+++ b/tests/test_data_preprocessing.py
@@ -4,7 +4,12 @@
import pandas as pd
import pytest
-from microimpute.utils.data import normalize_data, preprocess_data
+from microimpute.utils.data import (
+ log_transform_data,
+ normalize_data,
+ preprocess_data,
+ unlog_transform_predictions,
+)
class TestNormalize:
@@ -170,10 +175,13 @@ def test_preprocess_data_excludes_categoricals_from_normalization(self):
}
)
- result, norm_params = preprocess_data(
+ result, transform_params = preprocess_data(
data, full_data=True, normalize=True
)
+ # Extract normalization params from nested dict
+ norm_params = transform_params["normalization"]
+
# Categorical columns should be unchanged
pd.testing.assert_series_equal(result["race"], data["race"])
pd.testing.assert_series_equal(result["is_female"], data["is_female"])
@@ -233,3 +241,662 @@ def test_categorical_columns_dont_get_weird_suffixes_when_dummified(
# Column name should not contain decimal points
assert "." not in col, f"Column {col} has decimal point in name"
+
+
+class TestLogTransform:
+ """Test the log_transform_data function."""
+
+ def test_log_transform_excludes_categorical_columns(self):
+ """Test that categorical columns are not log transformed."""
+ data = pd.DataFrame(
+ {
+ "numeric_col": [1.0, 2.5, 3.7, 4.2, 5.9],
+ "categorical_col": [1, 2, 3, 1, 2],
+ "boolean_col": [0, 1, 0, 1, 0],
+ }
+ )
+
+ log_data, log_params = log_transform_data(data)
+
+ # Categorical and boolean columns should be unchanged
+ pd.testing.assert_series_equal(
+ log_data["categorical_col"], data["categorical_col"]
+ )
+ pd.testing.assert_series_equal(
+ log_data["boolean_col"], data["boolean_col"]
+ )
+
+ # Numeric column should be log transformed
+ assert not np.allclose(
+ log_data["numeric_col"].values, data["numeric_col"].values
+ )
+
+ # Only numeric column should have log transform params
+ assert "numeric_col" in log_params
+ assert "categorical_col" not in log_params
+ assert "boolean_col" not in log_params
+
+ def test_log_transform_correctly_transforms_numeric_columns(self):
+ """Test that numeric columns are correctly log transformed."""
+ data = pd.DataFrame(
+ {
+ "value1": [
+ 1.5,
+ 2.7,
+ 3.2,
+ 4.8,
+ 5.1,
+ 6.3,
+ 7.9,
+ 8.4,
+ 9.6,
+ 10.2,
+ ],
+ "value2": [
+ 15.5,
+ 27.3,
+ 32.1,
+ 48.7,
+ 51.9,
+ 63.2,
+ 79.8,
+ 84.5,
+ 96.1,
+ 102.4,
+ ],
+ "category": [1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
+ }
+ )
+
+ log_data, log_params = log_transform_data(data)
+
+ # Check that numeric columns are log transformed
+ expected_value1 = np.log(data["value1"].values)
+ expected_value2 = np.log(data["value2"].values)
+
+ np.testing.assert_array_almost_equal(
+ log_data["value1"].values, expected_value1
+ )
+ np.testing.assert_array_almost_equal(
+ log_data["value2"].values, expected_value2
+ )
+
+ # Check log transform params are stored
+ assert "value1" in log_params
+ assert "value2" in log_params
+
+ def test_log_transform_rejects_non_positive_values(self):
+ """Test that log transform raises error for non-positive values."""
+ data = pd.DataFrame(
+ {
+ "value": [1.0, 2.0, 0.0, 4.0, 5.0], # Contains zero
+ }
+ )
+
+ with pytest.raises(ValueError, match="non-positive values"):
+ log_transform_data(data)
+
+ data_negative = pd.DataFrame(
+ {
+ "value": [1.0, 2.0, -1.0, 4.0, 5.0], # Contains negative
+ }
+ )
+
+ with pytest.raises(ValueError, match="non-positive values"):
+ log_transform_data(data_negative)
+
+ def test_log_transform_returns_copy(self):
+ """Test that log transform returns a copy."""
+ data = pd.DataFrame(
+ {
+ "value": [1.5, 2.7, 3.2, 4.8, 5.1, 6.3, 7.9, 8.4, 9.6, 10.2],
+ "category": [1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
+ }
+ )
+ original_data = data.copy()
+
+ log_data, _ = log_transform_data(data)
+
+ # Original data should be unchanged
+ pd.testing.assert_frame_equal(data, original_data)
+
+ # Log transformed data should be different
+ assert not log_data["value"].equals(data["value"])
+
+ def test_log_transform_with_no_numeric_columns(self):
+ """Test log transform with only categorical columns."""
+ data = pd.DataFrame({"cat1": [1, 2, 3, 1, 2], "cat2": [0, 1, 0, 1, 0]})
+
+ log_data, log_params = log_transform_data(data)
+
+ # Data should be unchanged
+ pd.testing.assert_frame_equal(log_data, data)
+
+ # No log transform params should be returned
+ assert log_params == {}
+
+
+class TestUnlogTransformPredictions:
+ """Test the unlog_transform_predictions function."""
+
+ def test_unlog_transform_reverses_log_transform(self):
+ """Test that unlog transform correctly reverses log transform."""
+ original = pd.DataFrame(
+ {
+ "value1": [
+ 1.5,
+ 2.7,
+ 3.2,
+ 4.8,
+ 5.1,
+ 6.3,
+ 7.9,
+ 8.4,
+ 9.6,
+ 10.2,
+ ],
+ "value2": [
+ 15.5,
+ 27.3,
+ 32.1,
+ 48.7,
+ 51.9,
+ 63.2,
+ 79.8,
+ 84.5,
+ 96.1,
+ 102.4,
+ ],
+ }
+ )
+
+ # Apply log transform
+ log_data, log_params = log_transform_data(original)
+
+ # Create imputations dict (simulating prediction output)
+ imputations = {0.5: log_data}
+
+ # Reverse log transform
+ reversed_data = unlog_transform_predictions(imputations, log_params)
+
+ # Should match original data
+ pd.testing.assert_frame_equal(
+ reversed_data[0.5], original, check_exact=False, atol=1e-10
+ )
+
+ def test_unlog_transform_raises_error_for_missing_params(self):
+ """Test that unlog transform raises error when params are missing."""
+ imputations = {
+ 0.5: pd.DataFrame(
+ {
+ "value1": [0.0, 0.69, 1.10],
+ "value2": [2.3, 3.0, 3.9],
+ }
+ )
+ }
+
+ # Only have params for value1, not value2
+ log_params = {"value1": {}}
+
+ with pytest.raises(
+ ValueError, match="Missing log transformation parameters"
+ ):
+ unlog_transform_predictions(imputations, log_params)
+
+
+class TestPreprocessDataWithLogTransform:
+ """Test that preprocess_data correctly uses log transformation."""
+
+ def test_preprocess_data_excludes_categoricals_from_log_transform(self):
+ """Test that preprocess_data doesn't log transform categorical columns."""
+ data = pd.DataFrame(
+ {
+ "age": [
+ 25.3,
+ 30.7,
+ 35.2,
+ 40.9,
+ 45.1,
+ 50.6,
+ 55.8,
+ 60.3,
+ 65.7,
+ 70.2,
+ ],
+ "race": [1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
+ "is_female": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
+ "income": [
+ 50123.45,
+ 60987.23,
+ 70456.78,
+ 80234.56,
+ 90876.12,
+ 100543.89,
+ 110234.67,
+ 120789.34,
+ 130456.78,
+ 140987.23,
+ ],
+ }
+ )
+
+ result, transform_params = preprocess_data(
+ data, full_data=True, log_transform=True
+ )
+
+ # Extract log transform params from nested dict
+ log_params = transform_params["log_transform"]
+
+ # Categorical columns should be unchanged
+ pd.testing.assert_series_equal(result["race"], data["race"])
+ pd.testing.assert_series_equal(result["is_female"], data["is_female"])
+
+ # Numeric columns should be log transformed
+ assert not np.allclose(result["age"].values, data["age"].values)
+ assert not np.allclose(result["income"].values, data["income"].values)
+
+ # Only numeric columns in log_params
+ assert "age" in log_params
+ assert "income" in log_params
+ assert "race" not in log_params
+ assert "is_female" not in log_params
+
+ def test_preprocess_data_raises_error_for_both_normalize_and_log(
+ self,
+ ):
+ """Test that preprocess_data raises error if both normalize and log_transform are True."""
+ data = pd.DataFrame(
+ {
+ "value1": [1.5, 2.7, 3.2, 4.8, 5.1, 6.3, 7.9, 8.4, 9.6, 10.2],
+ "value2": [
+ 15.5,
+ 27.3,
+ 32.1,
+ 48.7,
+ 51.9,
+ 63.2,
+ 79.8,
+ 84.5,
+ 96.1,
+ 102.4,
+ ],
+ }
+ )
+
+ with pytest.raises(
+ ValueError,
+ match="Cannot apply both normalization and log transformation",
+ ):
+ preprocess_data(
+ data, full_data=True, normalize=True, log_transform=True
+ )
+
+ def test_preprocess_data_with_log_transform_and_split(self):
+ """Test that preprocess_data correctly splits and log transforms data."""
+ data = pd.DataFrame(
+ {
+ "value1": [
+ 1.5,
+ 2.7,
+ 3.2,
+ 4.8,
+ 5.1,
+ 6.3,
+ 7.9,
+ 8.4,
+ 9.6,
+ 10.2,
+ 11.5,
+ 12.8,
+ 13.3,
+ 14.9,
+ 15.2,
+ ],
+ "value2": [
+ 15.5,
+ 27.3,
+ 32.1,
+ 48.7,
+ 51.9,
+ 63.2,
+ 79.8,
+ 84.5,
+ 96.1,
+ 102.4,
+ 115.7,
+ 128.9,
+ 133.6,
+ 149.2,
+ 152.8,
+ ],
+ }
+ )
+
+ X_train, X_test, transform_params = preprocess_data(
+ data,
+ full_data=False,
+ test_size=0.2,
+ train_size=None,
+ random_state=42,
+ log_transform=True,
+ )
+
+ # Extract log transform params from nested dict
+ log_params = transform_params["log_transform"]
+
+ # Check that data is split
+ assert len(X_train) == 12
+ assert len(X_test) == 3
+
+ # Check that log params are returned
+ assert "value1" in log_params
+ assert "value2" in log_params
+
+ # Check that values are log transformed (compare to original)
+ assert not any(X_train["value1"].isin(data["value1"]))
+ assert not any(X_test["value1"].isin(data["value1"]))
+
+
+class TestPreprocessDataWithSelectiveTransformation:
+ """Test preprocess_data with selective column transformation."""
+
+ def test_normalize_only_specified_columns(self):
+ """Test that only specified columns are normalized."""
+ data = pd.DataFrame(
+ {
+ "age": [
+ 23,
+ 30,
+ 35,
+ 46,
+ 45,
+ 52,
+ 55,
+ 61,
+ 68,
+ 72,
+ ],
+ "income": [
+ 50123.45,
+ 60987.23,
+ 70456.78,
+ 80234.56,
+ 90876.12,
+ 100543.89,
+ 110234.67,
+ 120789.34,
+ 130456.78,
+ 140987.23,
+ ],
+ "wealth": [
+ 150000.5,
+ 250000.3,
+ 350000.7,
+ 450000.2,
+ 550000.9,
+ 650000.1,
+ 750000.4,
+ 850000.8,
+ 950000.6,
+ 1050000.3,
+ ],
+ }
+ )
+
+ # Only normalize income column
+ result, transform_params = preprocess_data(
+ data, full_data=True, normalize=["income"]
+ )
+
+ # Extract normalization params from nested dict
+ norm_params = transform_params["normalization"]
+
+ # Income should be normalized
+ assert not np.allclose(result["income"].values, data["income"].values)
+ assert "income" in norm_params
+
+ # Age and wealth should NOT be normalized
+ pd.testing.assert_series_equal(result["age"], data["age"])
+ pd.testing.assert_series_equal(result["wealth"], data["wealth"])
+ assert "age" not in norm_params
+ assert "wealth" not in norm_params
+
+ def test_log_transform_only_specified_columns(self):
+ """Test that only specified columns are log transformed."""
+ data = pd.DataFrame(
+ {
+ "age": [
+ 23,
+ 30,
+ 35,
+ 46,
+ 45,
+ 52,
+ 55,
+ 61,
+ 68,
+ 72,
+ ],
+ "income": [
+ 50123.45,
+ 60987.23,
+ 70456.78,
+ 80234.56,
+ 90876.12,
+ 100543.89,
+ 110234.67,
+ 120789.34,
+ 130456.78,
+ 140987.23,
+ ],
+ "wealth": [
+ 150000.5,
+ 250000.3,
+ 350000.7,
+ 450000.2,
+ 550000.9,
+ 650000.1,
+ 750000.4,
+ 850000.8,
+ 950000.6,
+ 1050000.3,
+ ],
+ }
+ )
+
+ # Only log transform income column
+ result, transform_params = preprocess_data(
+ data, full_data=True, log_transform=["income"]
+ )
+
+ # Extract log transform params from nested dict
+ log_params = transform_params["log_transform"]
+
+ # Income should be log transformed
+ assert not np.allclose(result["income"].values, data["income"].values)
+ assert "income" in log_params
+
+ # Age and wealth should NOT be transformed
+ pd.testing.assert_series_equal(result["age"], data["age"])
+ pd.testing.assert_series_equal(result["wealth"], data["wealth"])
+ assert "age" not in log_params
+ assert "wealth" not in log_params
+
+ def test_normalize_multiple_specified_columns(self):
+ """Test normalizing multiple specified columns."""
+ data = pd.DataFrame(
+ {
+ "age": [
+ 23,
+ 30,
+ 35,
+ 46,
+ 45,
+ 52,
+ 55,
+ 61,
+ 68,
+ 72,
+ ],
+ "income": [
+ 50123.45,
+ 60987.23,
+ 70456.78,
+ 80234.56,
+ 90876.12,
+ 100543.89,
+ 110234.67,
+ 120789.34,
+ 130456.78,
+ 140987.23,
+ ],
+ "wealth": [
+ 150000.5,
+ 250000.3,
+ 350000.7,
+ 450000.2,
+ 550000.9,
+ 650000.1,
+ 750000.4,
+ 850000.8,
+ 950000.6,
+ 1050000.3,
+ ],
+ }
+ )
+
+ # Normalize income and wealth, but not age
+ result, transform_params = preprocess_data(
+ data, full_data=True, normalize=["income", "wealth"]
+ )
+
+ # Extract normalization params from nested dict
+ norm_params = transform_params["normalization"]
+
+ # Income and wealth should be normalized
+ assert not np.allclose(result["income"].values, data["income"].values)
+ assert not np.allclose(result["wealth"].values, data["wealth"].values)
+ assert "income" in norm_params
+ assert "wealth" in norm_params
+
+ # Age should NOT be normalized
+ pd.testing.assert_series_equal(result["age"], data["age"])
+ assert "age" not in norm_params
+
+ def test_error_on_nonexistent_column_normalize(self):
+ """Test that error is raised when specifying non-existent column."""
+ data = pd.DataFrame(
+ {
+ "age": [25.3, 30.7, 35.2, 40.9, 45.1],
+ "income": [50123.45, 60987.23, 70456.78, 80234.56, 90876.12],
+ }
+ )
+
+ with pytest.raises(ValueError, match="not found in data"):
+ preprocess_data(
+ data, full_data=True, normalize=["income", "nonexistent"]
+ )
+
+ def test_error_on_nonexistent_column_log_transform(self):
+ """Test that error is raised when specifying non-existent column."""
+ data = pd.DataFrame(
+ {
+ "age": [25.3, 30.7, 35.2, 40.9, 45.1],
+ "income": [50123.45, 60987.23, 70456.78, 80234.56, 90876.12],
+ }
+ )
+
+ with pytest.raises(ValueError, match="not found in data"):
+ preprocess_data(
+ data, full_data=True, log_transform=["income", "nonexistent"]
+ )
+
+ def test_error_on_overlapping_columns(self):
+ """Test error when both normalize and log_transform target same columns."""
+ data = pd.DataFrame(
+ {
+ "age": [25.3, 30.7, 35.2, 40.9, 45.1],
+ "income": [50123.45, 60987.23, 70456.78, 80234.56, 90876.12],
+ }
+ )
+
+ # Error when same column is in both lists
+ with pytest.raises(
+ ValueError, match="Cannot apply both normalization and log"
+ ):
+ preprocess_data(
+ data,
+ full_data=True,
+ normalize=["income", "age"],
+ log_transform=["age"],
+ )
+
+ def test_both_transformations_on_different_columns(self):
+ """Test that both transformations work when applied to different columns."""
+ data = pd.DataFrame(
+ {
+ "age": [
+ 23,
+ 30,
+ 35,
+ 46,
+ 45,
+ 52,
+ 55,
+ 61,
+ 68,
+ 72,
+ ],
+ "income": [
+ 50123.45,
+ 60987.23,
+ 70456.78,
+ 80234.56,
+ 90876.12,
+ 100543.89,
+ 110234.67,
+ 120789.34,
+ 130456.78,
+ 140987.23,
+ ],
+ "wealth": [
+ 150000.5,
+ 250000.3,
+ 350000.7,
+ 450000.2,
+ 550000.9,
+ 650000.1,
+ 750000.4,
+ 850000.8,
+ 950000.6,
+ 1050000.3,
+ ],
+ }
+ )
+
+ # Normalize age, log transform income, leave wealth unchanged
+ result, transform_params = preprocess_data(
+ data,
+ full_data=True,
+ normalize=["age"],
+ log_transform=["income"],
+ )
+
+ # Extract both parameter dicts
+ norm_params = transform_params["normalization"]
+ log_params = transform_params["log_transform"]
+
+ # Age should be normalized
+ assert not np.allclose(result["age"].values, data["age"].values)
+ assert "age" in norm_params
+ assert "age" not in log_params
+
+ # Income should be log transformed
+ assert not np.allclose(result["income"].values, data["income"].values)
+ assert "income" in log_params
+ assert "income" not in norm_params
+
+ # Wealth should be unchanged
+ pd.testing.assert_series_equal(result["wealth"], data["wealth"])
+ assert "wealth" not in norm_params
+ assert "wealth" not in log_params