From 5f10f733cf5befba924eb8784f6cd1ff2b9ca346 Mon Sep 17 00:00:00 2001 From: juaristi22 Date: Mon, 3 Nov 2025 16:38:04 +0800 Subject: [PATCH 1/2] add imputed variable histograms to dashboard --- changelog_entry.yaml | 4 + examples/demo.ipynb | 1778 +++++++++++++++++ examples/pipeline.py | 3 + .../components/DistributionOverlay.tsx | 313 +++ .../components/ImputationResults.tsx | 6 +- .../components/VisualizationDashboard.tsx | 89 +- .../public/microimputation_results.csv | 63 + microimpute/utils/dashboard_formatter.py | 335 ++++ tests/test_dashboard_formatter.py | 309 +++ 9 files changed, 2895 insertions(+), 5 deletions(-) create mode 100644 examples/demo.ipynb create mode 100644 microimputation-dashboard/components/DistributionOverlay.tsx diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29..c68d72d 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + added: + - Distribution comparison histogram of donor and receiver datasets for imputed variables (to dashboard). diff --git a/examples/demo.ipynb b/examples/demo.ipynb new file mode 100644 index 0000000..c749976 --- /dev/null +++ b/examples/demo.ipynb @@ -0,0 +1,1778 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MicroImpute Pipeline Demo\n", + "\n", + "Demonstrates autoimpute, distribution evaluation, predictor analysis, and dashboard formatting." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup: Import libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import warnings\n", + "from sklearn.datasets import load_diabetes\n", + "\n", + "from microimpute.comparisons.autoimpute import autoimpute\n", + "from microimpute.comparisons.metrics import compare_distributions\n", + "from microimpute.visualizations import method_comparison_results\n", + "from microimpute.evaluations.predictor_analysis import (\n", + " compute_predictor_correlations,\n", + " leave_one_out_analysis,\n", + " progressive_predictor_inclusion,\n", + ")\n", + "from microimpute.utils.dashboard_formatter import format_csv\n", + "from microimpute.models import OLS\n", + "\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Load and prepare data" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the diabetes dataset\n", + "diabetes = load_diabetes()\n", + "diabetes_data = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)\n", + "\n", + "# Split into donor and receiver portions (70/30 split)\n", + "np.random.seed(42)\n", + "donor_indices = np.random.choice(\n", + " len(diabetes_data),\n", + " size=int(0.7 * len(diabetes_data)),\n", + " replace=False,\n", + ")\n", + "receiver_indices = np.array(\n", + " [i for i in range(len(diabetes_data)) if i not in donor_indices]\n", + ")\n", + "\n", + "donor_data = diabetes_data.iloc[donor_indices].reset_index(drop=True)\n", + "receiver_data = diabetes_data.iloc[receiver_indices].reset_index(\n", + " drop=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a categorical risk_factor variable based on cholesterol levels (s4)\n", + "def categorize_risk(s4_value):\n", + " if s4_value < -0.02:\n", + " return \"low\"\n", + " elif s4_value < 
0.02:\n", + " return \"medium\"\n", + " else:\n", + " return \"high\"\n", + "\n", + "\n", + "donor_data[\"risk_factor\"] = donor_data[\"s4\"].apply(categorize_risk)\n", + "receiver_data[\"risk_factor\"] = receiver_data[\"s4\"].apply(categorize_risk)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Donor data shape: (309, 11)\n", + "Receiver data shape: (133, 8)\n", + "Predictors: ['age', 'sex', 'bmi', 'bp']\n", + "Variables to impute: ['s1', 's4', 'risk_factor']\n", + "\n", + "Risk factor distribution in donor data:\n", + "risk_factor\n", + "low 112\n", + "high 108\n", + "medium 89\n", + "Name: count, dtype: int64\n" + ] + } + ], + "source": [ + "# Define predictors and variables to impute\n", + "predictors = [\"age\", \"sex\", \"bmi\", \"bp\"]\n", + "imputed_variables = [\"s1\", \"s4\", \"risk_factor\"]\n", + "\n", + "# Remove imputed variables from receiver data\n", + "receiver_data = receiver_data.drop(\n", + " columns=imputed_variables\n", + ")\n", + "\n", + "print(f\"Donor data shape: {donor_data.shape}\")\n", + "print(f\"Receiver data shape: {receiver_data.shape}\")\n", + "print(f\"Predictors: {predictors}\")\n", + "print(f\"Variables to impute: {imputed_variables}\")\n", + "print(f\"\\nRisk factor distribution in donor data:\")\n", + "print(donor_data[\"risk_factor\"].value_counts())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Run autoimpute to find the best imputation method" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "70f7282b79674dbb9ce2522e37a117fb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Evaluating models: 0%| | 0/4 [00:00Quantiles=%{x}
Test Quantile loss=%{y}", + "legendgroup": "QRF", + "marker": { + "color": "#636EFA", + "pattern": { + "shape": "" + } + }, + "name": "QRF", + "offsetgroup": "QRF", + "orientation": "v", + "showlegend": true, + "textposition": "auto", + "type": "bar", + "x": [ + "0.05", + "0.05", + "0.1", + "0.1", + "0.15", + "0.15", + "0.2", + "0.2", + "0.25", + "0.25", + "0.3", + "0.3", + "0.35", + "0.35", + "0.4", + "0.4", + "0.45", + "0.45", + "0.5", + "0.5", + "0.55", + "0.55", + "0.6", + "0.6", + "0.65", + "0.65", + "0.7", + "0.7", + "0.75", + "0.75", + "0.8", + "0.8", + "0.85", + "0.85", + "0.9", + "0.9", + "0.95", + "0.95" + ], + "xaxis": "x", + "y": [ + 0.004976762851613436, + 0.004976762851613436, + 0.007958901707037789, + 0.007958901707037789, + 0.011348256884255009, + 0.011348256884255009, + 0.014479045607066829, + 0.014479045607066829, + 0.01804729212868635, + 0.01804729212868635, + 0.020713966247658862, + 0.020713966247658862, + 0.022742931417508282, + 0.022742931417508282, + 0.02258721277917371, + 0.02258721277917371, + 0.023351581275292608, + 0.023351581275292608, + 0.023625607730980425, + 0.023625607730980425, + 0.023186470004863385, + 0.023186470004863385, + 0.022905432991342744, + 0.022905432991342744, + 0.022230731438156376, + 0.022230731438156376, + 0.02057637886132135, + 0.02057637886132135, + 0.01912773338390523, + 0.01912773338390523, + 0.01739101077609373, + 0.01739101077609373, + 0.013723503511700213, + 0.013723503511700213, + 0.010967503180152355, + 0.010967503180152355, + 0.0071899999204732045, + 0.0071899999204732045 + ], + "yaxis": "y" + }, + { + "alignmentgroup": "True", + "hovertemplate": "Method=OLS
Quantiles=%{x}
Test Quantile loss=%{y}", + "legendgroup": "OLS", + "marker": { + "color": "#EF553B", + "pattern": { + "shape": "" + } + }, + "name": "OLS", + "offsetgroup": "OLS", + "orientation": "v", + "showlegend": true, + "textposition": "auto", + "type": "bar", + "x": [ + "0.05", + "0.05", + "0.1", + "0.1", + "0.15", + "0.15", + "0.2", + "0.2", + "0.25", + "0.25", + "0.3", + "0.3", + "0.35", + "0.35", + "0.4", + "0.4", + "0.45", + "0.45", + "0.5", + "0.5", + "0.55", + "0.55", + "0.6", + "0.6", + "0.65", + "0.65", + "0.7", + "0.7", + "0.75", + "0.75", + "0.8", + "0.8", + "0.85", + "0.85", + "0.9", + "0.9", + "0.95", + "0.95" + ], + "xaxis": "x", + "y": [ + 0.003875743024333408, + 0.003875743024333408, + 0.006636892853346667, + 0.006636892853346667, + 0.008964045456920146, + 0.008964045456920146, + 0.010925268697485854, + 0.010925268697485854, + 0.012586171018988979, + 0.012586171018988979, + 0.013971078532560128, + 0.013971078532560128, + 0.015078371048174203, + 0.015078371048174203, + 0.015960673093506517, + 0.015960673093506517, + 0.016611314804344936, + 0.016611314804344936, + 0.017023427694261576, + 0.017023427694261576, + 0.017162262823655853, + 0.017162262823655853, + 0.016994331164241493, + 0.016994331164241493, + 0.016563877768320707, + 0.016563877768320707, + 0.01584349745574873, + 0.01584349745574873, + 0.014830479499612956, + 0.014830479499612956, + 0.013414418494530157, + 0.013414418494530157, + 0.011604173888267462, + 0.011604173888267462, + 0.009116693282835488, + 0.009116693282835488, + 0.005653910575626517, + 0.005653910575626517 + ], + "yaxis": "y" + }, + { + "alignmentgroup": "True", + "hovertemplate": "Method=Matching
Quantiles=%{x}
Test Quantile loss=%{y}", + "legendgroup": "Matching", + "marker": { + "color": "#00CC96", + "pattern": { + "shape": "" + } + }, + "name": "Matching", + "offsetgroup": "Matching", + "orientation": "v", + "showlegend": true, + "textposition": "auto", + "type": "bar", + "x": [ + "0.05", + "0.05", + "0.1", + "0.1", + "0.15", + "0.15", + "0.2", + "0.2", + "0.25", + "0.25", + "0.3", + "0.3", + "0.35", + "0.35", + "0.4", + "0.4", + "0.45", + "0.45", + "0.5", + "0.5", + "0.55", + "0.55", + "0.6", + "0.6", + "0.65", + "0.65", + "0.7", + "0.7", + "0.75", + "0.75", + "0.8", + "0.8", + "0.85", + "0.85", + "0.9", + "0.9", + "0.95", + "0.95" + ], + "xaxis": "x", + "y": [ + 0.022039382576829306, + 0.022039382576829306, + 0.022147598558993443, + 0.022147598558993443, + 0.022255814541157576, + 0.022255814541157576, + 0.022364030523321716, + 0.022364030523321716, + 0.022472246505485852, + 0.022472246505485852, + 0.02258046248764999, + 0.02258046248764999, + 0.022688678469814125, + 0.022688678469814125, + 0.02279689445197826, + 0.02279689445197826, + 0.0229051104341424, + 0.0229051104341424, + 0.02301332641630653, + 0.02301332641630653, + 0.02312154239847067, + 0.02312154239847067, + 0.023229758380634808, + 0.023229758380634808, + 0.023337974362798945, + 0.023337974362798945, + 0.023446190344963078, + 0.023446190344963078, + 0.023554406327127214, + 0.023554406327127214, + 0.02366262230929135, + 0.02366262230929135, + 0.023770838291455487, + 0.023770838291455487, + 0.023879054273619627, + 0.023879054273619627, + 0.02398727025578376, + 0.02398727025578376 + ], + "yaxis": "y" + } + ], + "layout": { + "barmode": "group", + "height": 600, + "legend": { + "title": { + "text": "Method" + }, + "tracegroupgap": 0 + }, + "paper_bgcolor": "#F0F0F0", + "plot_bgcolor": "#F0F0F0", + "shapes": [ + { + "line": { + "color": "#636EFA", + "dash": "dot", + "width": 2 + }, + "name": "QRF Mean", + "type": "line", + "x0": -0.5, + "x1": 18.5, + "y0": 0.017217385405120106, + "y1": 0.017217385405120106 + }, + { + "line": { + "color": "#EF553B", + "dash": "dot", + "width": 2 + }, + "name": "OLS Mean", + "type": "line", + "x0": -0.5, + "x1": 18.5, + "y0": 0.012779822693513777, + "y1": 0.012779822693513777 + }, + { + "line": { + "color": "#00CC96", + "dash": "dot", + "width": 2 + }, + "name": "Matching Mean", + "type": "line", + "x0": -0.5, + "x1": 18.5, + "y0": 0.023013326416306525, + "y1": 0.023013326416306525 + } + ], + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 
0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + 
"outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + 
"gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "font": { + "size": 14 + }, + "text": "Autoimpute method comparison on diabetes dataset" + }, + "width": 750, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "showgrid": false, + "title": { + "font": { + "size": 12 + }, + "text": "Quantiles" + }, + "zeroline": false + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "showgrid": false, + "title": { + "font": { + "size": 12 + }, + "text": "Test Quantile loss" + }, + "zeroline": false + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Extract the model cross-validation results from the autoimpute results object\n", + "comparison_viz = method_comparison_results(\n", + " data=autoimpute_results.cv_results,\n", + " metric=\"quantile_loss\",\n", + " data_format=\"wide\",\n", + ")\n", + "fig = comparison_viz.plot(\n", + " title=\"Autoimpute method comparison on diabetes dataset\",\n", + " show_mean=True,\n", + ")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Evaluate distribution preservation" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Distribution comparison results:\n", + " Variable Metric Distance\n", + "0 s1 wasserstein_distance 0.024660\n", + "1 s4 wasserstein_distance 0.020422\n", + "2 risk_factor kl_divergence 6.033155\n" + ] + } + ], + "source": [ + "distribution_comparison_df = compare_distributions(\n", + " donor_data=donor_data,\n", + " receiver_data=autoimpute_results.receiver_data,\n", + " imputed_variables=imputed_variables,\n", + ")\n", + "\n", + "print(\"Distribution comparison results:\")\n", + "print(distribution_comparison_df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Analyze predictor correlations and mutual information" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Correlation analysis completed:\n", + " - Pearson correlation matrix: (4, 4)\n", + 
" - Spearman correlation matrix: (4, 4)\n", + " - Mutual information matrix: (4, 4)\n", + " - Predictor-target MI: (4, 3)\n" + ] + } + ], + "source": [ + "predictor_correlations = compute_predictor_correlations(\n", + " data=donor_data,\n", + " predictors=predictors,\n", + " imputed_variables=imputed_variables,\n", + " method=\"all\",\n", + ")\n", + "\n", + "print(\"Correlation analysis completed:\")\n", + "print(\n", + " f\" - Pearson correlation matrix: {predictor_correlations['pearson'].shape}\"\n", + ")\n", + "print(\n", + " f\" - Spearman correlation matrix: {predictor_correlations['spearman'].shape}\"\n", + ")\n", + "print(\n", + " f\" - Mutual information matrix: {predictor_correlations['mutual_info'].shape}\"\n", + ")\n", + "print(\n", + " f\" - Predictor-target MI: {predictor_correlations['predictor_target_mi'].shape}\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Assess predictor importance via leave-one-out analysis" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "90a63259512d414aa2a4f65e47c446ef", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Leave-one-out analysis: 0%| | 0/4 [00:00 { + return data.filter((d) => d.type === 'distribution_bins'); + }, [data]); + + // Parse and group distribution data by variable + const variableDistributions = useMemo(() => { + const distributions: Record = {}; + + distributionBins.forEach((d) => { + const variable = d.variable; + + if (!distributions[variable]) { + distributions[variable] = { + variable, + type: + d.metric_name === 'histogram_distribution' + ? 'numerical' + : 'categorical', + data: [], + nSamplesDonor: 0, + nSamplesReceiver: 0, + }; + } + + try { + const info = JSON.parse(d.additional_info); + + if (d.metric_name === 'histogram_distribution') { + // Numerical variable + (distributions[variable].data as BinData[]).push({ + binIndex: info.bin_index, + binStart: info.bin_start, + binEnd: info.bin_end, + donorHeight: info.donor_height, + receiverHeight: info.receiver_height, + binLabel: `${info.bin_start.toFixed(2)}-${info.bin_end.toFixed(2)}`, + }); + distributions[variable].nSamplesDonor = info.n_samples_donor; + distributions[variable].nSamplesReceiver = info.n_samples_receiver; + } else if (d.metric_name === 'categorical_distribution') { + // Categorical variable + (distributions[variable].data as CategoryData[]).push({ + category: info.category, + donorProportion: info.donor_proportion, + receiverProportion: info.receiver_proportion, + }); + distributions[variable].nSamplesDonor = info.n_samples_donor; + distributions[variable].nSamplesReceiver = info.n_samples_receiver; + } + } catch (e) { + console.error('Error parsing distribution bin data:', e); + } + }); + + // Sort numerical bins by bin index + Object.values(distributions).forEach((dist) => { + if (dist.type === 'numerical') { + (dist.data as BinData[]).sort((a, b) => a.binIndex - b.binIndex); + } + }); + + return distributions; + }, [distributionBins]); + + const variables = Object.keys(variableDistributions); + const [selectedVariable, setSelectedVariable] = useState( + variables[0] || '' + ); + + if (variables.length === 0) { + return null; + } + + const selectedDistribution = variableDistributions[selectedVariable]; + + const renderNumericalDistribution = (dist: VariableDistribution) => { + const chartData = (dist.data as BinData[]).map((bin) => ({ + name: bin.binLabel, + 
Donor: bin.donorHeight, + Receiver: bin.receiverHeight, + binStart: bin.binStart, + binEnd: bin.binEnd, + })); + + return ( +
+ + + + + + [`${value.toFixed(2)}%`, '']} + labelFormatter={(label) => `Bin: ${label}`} + contentStyle={{ color: '#000000' }} + labelStyle={{ color: '#000000' }} + /> + + + + + +
+ Histogram with {(dist.data as BinData[]).length} bins. Each bin shows the percentage of values falling within that range. + Overlapping bars indicate similar distributions. +
+ ); + }; + + const renderCategoricalDistribution = (dist: VariableDistribution) => { + const chartData = (dist.data as CategoryData[]).map((cat) => ({ + category: cat.category, + Donor: cat.donorProportion, + Receiver: cat.receiverProportion, + })); + + return ( +
+ + + + + + [`${value.toFixed(2)}%`, '']} + contentStyle={{ color: '#000000' }} + labelStyle={{ color: '#000000' }} + /> + + + + + +
+ Side-by-side bars compare the proportion of each category in donor vs receiver data. +
+ ); + }; + + return ( +
+ Distribution comparison +
+ Visual comparison of donor and receiver distributions for imputed + variables. Closely overlapping distributions suggest the imputation preserved the donor distribution. +
+ + {/* Variable selector */} + {variables.length > 1 && ( +
+ + +
+ )} + + {variables.length === 1 && ( +
+ Variable:{' '} + {selectedVariable}{' '} + + ({selectedDistribution?.type}) + +
+ )} +
+ + {/* Render appropriate chart */} + {selectedDistribution && ( +
+ {selectedDistribution.type === 'numerical' + ? renderNumericalDistribution(selectedDistribution) + : renderCategoricalDistribution(selectedDistribution)} +
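+ {/* Numerical variables are rendered as histogram bars, categorical variables as per-category proportion bars */}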
+ )} +
+ ); +} diff --git a/microimputation-dashboard/components/ImputationResults.tsx b/microimputation-dashboard/components/ImputationResults.tsx index 598b507..57f5387 100644 --- a/microimputation-dashboard/components/ImputationResults.tsx +++ b/microimputation-dashboard/components/ImputationResults.tsx @@ -3,6 +3,7 @@ import { useMemo } from 'react'; import { ImputationDataPoint } from '@/types/imputation'; import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, Legend, ResponsiveContainer, Cell } from 'recharts'; +import DistributionOverlay from './DistributionOverlay'; interface ImputationResultsProps { data: ImputationDataPoint[]; @@ -86,6 +87,9 @@ export default function ImputationResults({ data }: ImputationResultsProps) {
+ {/* Distribution Overlay Visualization */} + + {/* Wasserstein Distance Section */} {hasWasserstein && (
@@ -207,7 +211,7 @@ export default function ImputationResults({ data }: ImputationResultsProps) { {/* Explanation */} -
What is KL-divergence? Kullback-Leibler divergence measures how much one probability distribution differs from another. It quantifies the "information lost" diff --git a/microimputation-dashboard/components/VisualizationDashboard.tsx b/microimputation-dashboard/components/VisualizationDashboard.tsx index 67a1d88..6904329 100644 --- a/microimputation-dashboard/components/VisualizationDashboard.tsx +++ b/microimputation-dashboard/components/VisualizationDashboard.tsx @@ -466,8 +466,8 @@ export default function VisualizationDashboard({ Assessment of the quality of the imputations produced by the best-performing (or the only selected) model
- {/* Imputed Variables Section */} +
+ {/* Imputed Variables Section - 1/4 width */}
Imputed Variables @@ -477,7 +477,7 @@ export default function VisualizationDashboard({
{dataAnalysis.imputedVars.length} variable{dataAnalysis.imputedVars.length !== 1 ? 's' : ''} imputed
      3 ? 'max-h-32 overflow-y-auto' : ''}`}> {dataAnalysis.imputedVars.map((variable) => (
    • {variable} @@ -492,7 +492,7 @@ export default function VisualizationDashboard({ )}
- {/* Best Model Section */} + {/* Best Model Section - 1/4 width */}
{dataAnalysis.allMethods.length === 1 ? 'Imputation Model' : 'Best Performing Model'} @@ -526,6 +526,87 @@ export default function VisualizationDashboard({
)}
+ + {/* Metrics Section - 1/2 width */} +
+ Performance Metrics +
+ {/* Average Test Losses */} + {(() => { + const benchmarkData = data.filter(d => d.type === 'benchmark_loss' && d.method === dataAnalysis.bestModel && d.split === 'test'); + + // Calculate avg quantile loss + const quantileLossData = benchmarkData.filter( + d => d.metric_name === 'quantile_loss' && + typeof d.quantile === 'number' && + d.metric_value !== null + ); + const avgQuantileLoss = quantileLossData.length > 0 + ? quantileLossData.reduce((sum, d) => sum + (d.metric_value ?? 0), 0) / quantileLossData.length + : null; + + // Calculate avg log loss + const logLossData = benchmarkData.filter( + d => d.metric_name === 'log_loss' && + d.metric_value !== null + ); + const avgLogLoss = logLossData.length > 0 + ? logLossData.reduce((sum, d) => sum + (d.metric_value ?? 0), 0) / logLossData.length + : null; + + // Calculate avg Wasserstein distance + const wassersteinData = data.filter( + d => d.type === 'distribution_distance' && + d.metric_name === 'wasserstein_distance' && + d.metric_value !== null + ); + const avgWasserstein = wassersteinData.length > 0 + ? wassersteinData.reduce((sum, d) => sum + (d.metric_value ?? 0), 0) / wassersteinData.length + : null; + + // Calculate avg KL divergence + const klData = data.filter( + d => d.type === 'distribution_distance' && + d.metric_name === 'kl_divergence' && + d.metric_value !== null + ); + const avgKL = klData.length > 0 + ? klData.reduce((sum, d) => sum + (d.metric_value ?? 0), 0) / klData.length + : null; + + return ( + <> + {avgQuantileLoss !== null && ( +
Avg. test quantile loss
{avgQuantileLoss.toFixed(4)}
+ )} + {avgLogLoss !== null && ( +
Avg. test log loss
{avgLogLoss.toFixed(4)}
+ )} + {avgWasserstein !== null && ( +
Avg. Wasserstein distance
{avgWasserstein.toFixed(4)}
+ )} + {avgKL !== null && ( +
Avg. KL divergence
{avgKL.toFixed(4)}
+ )} + + ); + })()} +
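+ {/* Each card is a simple mean over the matching rows: quantile and log loss over the best method's test results, Wasserstein distance and KL divergence over the imputed variables */}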
diff --git a/microimputation-dashboard/public/microimputation_results.csv b/microimputation-dashboard/public/microimputation_results.csv index 9ef58a5..8029b03 100644 --- a/microimputation-dashboard/public/microimputation_results.csv +++ b/microimputation-dashboard/public/microimputation_results.csv @@ -292,3 +292,66 @@ progressive_inclusion,OLSResults,N/A,N/A,cumulative_improvement,0.00121758583691 progressive_inclusion,OLSResults,N/A,N/A,marginal_improvement,0.00012356933219281885,test,"{""step"": 3, ""predictor_added"": ""age""}" progressive_inclusion,OLSResults,N/A,N/A,cumulative_improvement,0.0011012642793990501,test,"{""step"": 4, ""predictor_added"": ""bp"", ""predictors"": [""sex"", ""bmi"", ""age"", ""bp""]}" progressive_inclusion,OLSResults,N/A,N/A,marginal_improvement,-0.0001163215575132881,test,"{""step"": 4, ""predictor_added"": ""bp""}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 0, ""bin_start"": -0.12678066991651324, ""bin_end"": -0.11747005557601899, ""donor_height"": 0.32362459546925565, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 1, ""bin_start"": -0.11747005557601899, ""bin_end"": -0.10815944123552473, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 2, ""bin_start"": -0.10815944123552473, ""bin_end"": -0.0988488268950305, ""donor_height"": 0.6472491909385114, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 3, ""bin_start"": -0.0988488268950305, ""bin_end"": -0.08953821255453624, ""donor_height"": 1.2944983818770226, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 4, ""bin_start"": -0.08953821255453624, ""bin_end"": -0.08022759821404199, ""donor_height"": 0.9708737864077669, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 5, ""bin_start"": -0.08022759821404199, ""bin_end"": -0.07091698387354775, ""donor_height"": 2.2653721682847894, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 6, ""bin_start"": -0.07091698387354775, ""bin_end"": -0.0616063695330535, ""donor_height"": 3.2362459546925564, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 7, ""bin_start"": -0.0616063695330535, ""bin_end"": -0.052295755192559246, ""donor_height"": 3.2362459546925564, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 8, ""bin_start"": -0.052295755192559246, ""bin_end"": -0.04298514085206499, ""donor_height"": 5.17799352750809, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" 
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 9, ""bin_start"": -0.04298514085206499, ""bin_end"": -0.03367452651157074, ""donor_height"": 8.090614886731393, ""receiver_height"": 2.2556390977443606, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 10, ""bin_start"": -0.03367452651157074, ""bin_end"": -0.0243639121710765, ""donor_height"": 6.796116504854369, ""receiver_height"": 6.7669172932330826, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 11, ""bin_start"": -0.0243639121710765, ""bin_end"": -0.015053297830582249, ""donor_height"": 5.17799352750809, ""receiver_height"": 3.759398496240602, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 12, ""bin_start"": -0.015053297830582249, ""bin_end"": -0.005742683490087996, ""donor_height"": 9.385113268608416, ""receiver_height"": 18.045112781954884, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 13, ""bin_start"": -0.005742683490087996, ""bin_end"": 0.0035679308504062424, ""donor_height"": 10.679611650485437, ""receiver_height"": 24.060150375939852, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 14, ""bin_start"": 0.0035679308504062424, ""bin_end"": 0.012878545190900509, ""donor_height"": 5.501618122977347, ""receiver_height"": 18.796992481203006, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 15, ""bin_start"": 0.012878545190900509, ""bin_end"": 0.022189159531394748, ""donor_height"": 8.09061488673139, ""receiver_height"": 20.30075187969925, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 16, ""bin_start"": 0.022189159531394748, ""bin_end"": 0.031499773871888986, ""donor_height"": 6.4724919093851145, ""receiver_height"": 2.2556390977443606, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 17, ""bin_start"": 0.031499773871888986, ""bin_end"": 0.04081038821238325, ""donor_height"": 5.501618122977347, ""receiver_height"": 3.007518796992481, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 18, ""bin_start"": 0.04081038821238325, ""bin_end"": 0.05012100255287749, ""donor_height"": 3.8834951456310676, ""receiver_height"": 0.7518796992481204, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 19, ""bin_start"": 0.05012100255287749, ""bin_end"": 0.05943161689337176, ""donor_height"": 2.588996763754045, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 20, ""bin_start"": 0.05943161689337176, 
""bin_end"": 0.068742231233866, ""donor_height"": 2.2653721682847894, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 21, ""bin_start"": 0.068742231233866, ""bin_end"": 0.07805284557436024, ""donor_height"": 0.9708737864077669, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 22, ""bin_start"": 0.07805284557436024, ""bin_end"": 0.0873634599148545, ""donor_height"": 1.9417475728155338, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 23, ""bin_start"": 0.0873634599148545, ""bin_end"": 0.09667407425534874, ""donor_height"": 2.2653721682847894, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 24, ""bin_start"": 0.09667407425534874, ""bin_end"": 0.10598468859584298, ""donor_height"": 0.6472491909385114, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 25, ""bin_start"": 0.10598468859584298, ""bin_end"": 0.11529530293633725, ""donor_height"": 0.6472491909385113, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 26, ""bin_start"": 0.11529530293633725, ""bin_end"": 0.12460591727683148, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 27, ""bin_start"": 0.12460591727683148, ""bin_end"": 0.13391653161732572, ""donor_height"": 1.2944983818770228, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 28, ""bin_start"": 0.13391653161732572, ""bin_end"": 0.14322714595781996, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 29, ""bin_start"": 0.14322714595781996, ""bin_end"": 0.15253776029831428, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 0, ""bin_start"": -0.0763945037500033, ""bin_end"": -0.06767353884966323, ""donor_height"": 6.796116504854369, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 1, ""bin_start"": -0.06767353884966323, ""bin_end"": -0.05895257394932317, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 2, ""bin_start"": -0.05895257394932317, ""bin_end"": 
-0.0502316090489831, ""donor_height"": 0.9708737864077671, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 3, ""bin_start"": -0.0502316090489831, ""bin_end"": -0.041510644148643035, ""donor_height"": 0.9708737864077671, ""receiver_height"": 2.255639097744361, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 4, ""bin_start"": -0.041510644148643035, ""bin_end"": -0.03278967924830297, ""donor_height"": 26.537216828478964, ""receiver_height"": 5.263157894736842, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 5, ""bin_start"": -0.03278967924830297, ""bin_end"": -0.0240687143479629, ""donor_height"": 0.6472491909385113, ""receiver_height"": 6.015037593984962, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 6, ""bin_start"": -0.0240687143479629, ""bin_end"": -0.015347749447622835, ""donor_height"": 0.6472491909385113, ""receiver_height"": 11.278195488721805, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 7, ""bin_start"": -0.015347749447622835, ""bin_end"": -0.006626784547282771, ""donor_height"": 1.6181229773462782, ""receiver_height"": 10.526315789473683, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 8, ""bin_start"": -0.006626784547282771, ""bin_end"": 0.0020941803530572928, ""donor_height"": 24.59546925566343, ""receiver_height"": 15.037593984962406, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 9, ""bin_start"": 0.0020941803530572928, ""bin_end"": 0.010815145253397357, ""donor_height"": 0.9708737864077671, ""receiver_height"": 9.774436090225564, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 10, ""bin_start"": 0.010815145253397357, ""bin_end"": 0.01953611015373742, ""donor_height"": 0.9708737864077671, ""receiver_height"": 19.548872180451127, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 11, ""bin_start"": 0.01953611015373742, ""bin_end"": 0.0282570750540775, ""donor_height"": 2.5889967637540456, ""receiver_height"": 6.015037593984963, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 12, ""bin_start"": 0.0282570750540775, ""bin_end"": 0.03697803995441756, ""donor_height"": 17.475728155339805, ""receiver_height"": 3.7593984962406015, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 13, ""bin_start"": 0.03697803995441756, ""bin_end"": 0.04569900485475763, ""donor_height"": 0.32362459546925565, ""receiver_height"": 5.263157894736842, ""n_samples_donor"": 309, 
""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 14, ""bin_start"": 0.04569900485475763, ""bin_end"": 0.05441996975509768, ""donor_height"": 0.6472491909385113, ""receiver_height"": 1.5037593984962407, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 15, ""bin_start"": 0.05441996975509768, ""bin_end"": 0.06314093465543776, ""donor_height"": 0.3236245954692557, ""receiver_height"": 1.5037593984962407, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 16, ""bin_start"": 0.06314093465543776, ""bin_end"": 0.07186189955577783, ""donor_height"": 7.443365695792881, ""receiver_height"": 2.2556390977443606, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 17, ""bin_start"": 0.07186189955577783, ""bin_end"": 0.08058286445611788, ""donor_height"": 0.32362459546925565, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 18, ""bin_start"": 0.08058286445611788, ""bin_end"": 0.08930382935645796, ""donor_height"": 0.9708737864077669, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 19, ""bin_start"": 0.08930382935645796, ""bin_end"": 0.09802479425679801, ""donor_height"": 0.32362459546925565, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 20, ""bin_start"": 0.09802479425679801, ""bin_end"": 0.10674575915713809, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 21, ""bin_start"": 0.10674575915713809, ""bin_end"": 0.11546672405747814, ""donor_height"": 3.2362459546925564, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 22, ""bin_start"": 0.11546672405747814, ""bin_end"": 0.12418768895781822, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 23, ""bin_start"": 0.12418768895781822, ""bin_end"": 0.1329086538581583, ""donor_height"": 0.32362459546925565, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 24, ""bin_start"": 0.1329086538581583, ""bin_end"": 0.14162961875849833, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 25, ""bin_start"": 0.14162961875849833, ""bin_end"": 0.1503505836588384, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, 
""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 26, ""bin_start"": 0.1503505836588384, ""bin_end"": 0.1590715485591785, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 27, ""bin_start"": 0.1590715485591785, ""bin_end"": 0.16779251345951857, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 28, ""bin_start"": 0.16779251345951857, ""bin_end"": 0.17651347835985864, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 29, ""bin_start"": 0.17651347835985864, ""bin_end"": 0.18523444326019867, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}" +distribution_bins,OLSResults,risk_factor,N/A,categorical_distribution,,full,"{""category"": ""high"", ""donor_proportion"": 34.95145631067961, ""receiver_proportion"": 45.86466165413533, ""n_samples_donor"": 309, ""n_samples_receiver"": 133}" +distribution_bins,OLSResults,risk_factor,N/A,categorical_distribution,,full,"{""category"": ""low"", ""donor_proportion"": 36.24595469255664, ""receiver_proportion"": 54.13533834586466, ""n_samples_donor"": 309, ""n_samples_receiver"": 133}" +distribution_bins,OLSResults,risk_factor,N/A,categorical_distribution,,full,"{""category"": ""medium"", ""donor_proportion"": 28.802588996763756, ""receiver_proportion"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133}" diff --git a/microimpute/utils/dashboard_formatter.py b/microimpute/utils/dashboard_formatter.py index a5602fa..613f461 100644 --- a/microimpute/utils/dashboard_formatter.py +++ b/microimpute/utils/dashboard_formatter.py @@ -5,9 +5,281 @@ import json from typing import Any, Dict, List, Optional, Union +import numpy as np import pandas as pd +def _compute_histogram_data( + donor_values: np.ndarray, + receiver_values: np.ndarray, + variable_name: str, + n_bins: int = 30, +) -> Dict[str, Union[List[float], str, int]]: + """ + Compute histogram bin data for numerical variables. 
+ + Parameters + ---------- + donor_values : np.ndarray + Original donor dataset values + receiver_values : np.ndarray + Imputed receiver dataset values + variable_name : str + Name of the variable being analyzed + n_bins : int + Number of histogram bins (default: 30) + + Returns + ------- + Dict containing bin edges and heights for both distributions + """ + # Remove NaN values + donor_clean = donor_values[~np.isnan(donor_values)] + receiver_clean = receiver_values[~np.isnan(receiver_values)] + + # Determine bin edges based on combined data range using numpy's auto algorithm + combined = np.concatenate([donor_clean, receiver_clean]) + _, bin_edges = np.histogram(combined, bins=n_bins) + + # Compute histogram heights (normalized as densities) + donor_heights, _ = np.histogram(donor_clean, bins=bin_edges, density=True) + receiver_heights, _ = np.histogram( + receiver_clean, bins=bin_edges, density=True + ) + + # Convert to percentages for easier interpretation + # Multiply by bin width to get probability mass per bin + bin_widths = np.diff(bin_edges) + donor_heights = (donor_heights * bin_widths * 100).tolist() + receiver_heights = (receiver_heights * bin_widths * 100).tolist() + + return { + "variable": variable_name, + "bin_edges": bin_edges.tolist(), + "donor_heights": donor_heights, + "receiver_heights": receiver_heights, + "n_samples_donor": len(donor_clean), + "n_samples_receiver": len(receiver_clean), + "n_bins": n_bins, + } + + +def _compute_categorical_distribution( + donor_values: pd.Series, + receiver_values: pd.Series, + variable_name: str, + max_categories: int = 20, +) -> Dict[str, Union[List, str, bool]]: + """ + Compute distribution data for categorical variables. + + Parameters + ---------- + donor_values : pd.Series + Original donor dataset values + receiver_values : pd.Series + Imputed receiver dataset values + variable_name : str + Name of the variable + max_categories : int + Maximum number of categories to include (others grouped as "Other") + + Returns + ------- + Dict containing category labels and proportions + """ + # Get value counts + donor_counts = donor_values.value_counts() + receiver_counts = receiver_values.value_counts() + + # Get all unique categories + all_categories = list(set(donor_counts.index) | set(receiver_counts.index)) + + # If too many categories, keep top ones and group rest as "Other" + if len(all_categories) > max_categories: + # Get top categories by combined frequency + combined_counts = donor_counts.add(receiver_counts, fill_value=0) + top_categories = combined_counts.nlargest( + max_categories - 1 + ).index.tolist() + + # Calculate "Other" category + donor_other = donor_counts[ + ~donor_counts.index.isin(top_categories) + ].sum() + receiver_other = receiver_counts[ + ~receiver_counts.index.isin(top_categories) + ].sum() + + categories = top_categories + ["Other"] + + # Get proportions + donor_props = [donor_counts.get(cat, 0) for cat in top_categories] + donor_props.append(donor_other) + donor_props = ( + pd.Series(donor_props) / donor_values.count() * 100 + ).tolist() + + receiver_props = [ + receiver_counts.get(cat, 0) for cat in top_categories + ] + receiver_props.append(receiver_other) + receiver_props = ( + pd.Series(receiver_props) / receiver_values.count() * 100 + ).tolist() + else: + categories = sorted(all_categories) + donor_props = [ + (donor_counts.get(cat, 0) / donor_values.count() * 100) + for cat in categories + ] + receiver_props = [ + (receiver_counts.get(cat, 0) / receiver_values.count() * 100) + for cat in categories + ] 
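+    # Both branches express proportions as percentages of non-missing values (Series.count() excludes NaN).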
+ + return { + "variable": variable_name, + "categories": categories, + "donor_proportions": donor_props, + "receiver_proportions": receiver_props, + "n_samples_donor": int(donor_values.count()), + "n_samples_receiver": int(receiver_values.count()), + "is_categorical": True, + } + + +def _format_histogram_rows( + histogram_data: Dict[str, Union[List, str, int, bool]], method: str +) -> List[Dict]: + """ + Convert histogram data to CSV row format. + + Parameters + ---------- + histogram_data : Dict + Output from _compute_histogram_data or _compute_categorical_distribution + method : str + Imputation method name + + Returns + ------- + List of dictionaries ready for CSV formatting + """ + rows = [] + + if histogram_data.get("is_categorical", False): + # Categorical variable - store as distribution_bins type + for i, category in enumerate(histogram_data["categories"]): + rows.append( + { + "type": "distribution_bins", + "method": method, + "variable": histogram_data["variable"], + "quantile": "N/A", + "metric_name": "categorical_distribution", + "metric_value": None, # Not used for histograms + "split": "full", + "additional_info": json.dumps( + { + "category": str(category), + "donor_proportion": float( + histogram_data["donor_proportions"][i] + ), + "receiver_proportion": float( + histogram_data["receiver_proportions"][i] + ), + "n_samples_donor": int( + histogram_data["n_samples_donor"] + ), + "n_samples_receiver": int( + histogram_data["n_samples_receiver"] + ), + } + ), + } + ) + else: + # Numerical variable - store bin data + n_bins = len(histogram_data["donor_heights"]) + for i in range(n_bins): + rows.append( + { + "type": "distribution_bins", + "method": method, + "variable": histogram_data["variable"], + "quantile": "N/A", + "metric_name": "histogram_distribution", + "metric_value": None, # Not used for histograms + "split": "full", + "additional_info": json.dumps( + { + "bin_index": int(i), + "bin_start": float(histogram_data["bin_edges"][i]), + "bin_end": float( + histogram_data["bin_edges"][i + 1] + ), + "donor_height": float( + histogram_data["donor_heights"][i] + ), + "receiver_height": float( + histogram_data["receiver_heights"][i] + ), + "n_samples_donor": int( + histogram_data["n_samples_donor"] + ), + "n_samples_receiver": int( + histogram_data["n_samples_receiver"] + ), + "total_bins": int(n_bins), + } + ), + } + ) + + return rows + + +def _validate_imputed_variables( + donor_data: pd.DataFrame, + receiver_data: pd.DataFrame, + imputed_variables: List[str], +) -> None: + """ + Validate that all imputed variables exist in both datasets. 
+ + Parameters + ---------- + donor_data : pd.DataFrame + Original donor dataset + receiver_data : pd.DataFrame + Imputed receiver dataset + imputed_variables : List[str] + List of variable names that were imputed + + Raises + ------ + ValueError + If any imputed variable is missing from either dataset + """ + missing_in_donor = [ + var for var in imputed_variables if var not in donor_data.columns + ] + missing_in_receiver = [ + var for var in imputed_variables if var not in receiver_data.columns + ] + + if missing_in_donor: + raise ValueError( + f"The following imputed variables are missing from donor_data: {missing_in_donor}" + ) + + if missing_in_receiver: + raise ValueError( + f"The following imputed variables are missing from receiver_data: {missing_in_receiver}" + ) + + def format_csv( output_path: Optional[str] = None, autoimpute_result: Optional[Dict] = None, @@ -17,6 +289,10 @@ def format_csv( predictor_importance_df: Optional[pd.DataFrame] = None, progressive_inclusion_df: Optional[pd.DataFrame] = None, best_method_name: Optional[str] = None, + donor_data: Optional[pd.DataFrame] = None, + receiver_data: Optional[pd.DataFrame] = None, + imputed_variables: Optional[List[str]] = None, + n_histogram_bins: int = 30, ) -> pd.DataFrame: """ Format various imputation outputs into a unified long-format CSV for dashboard visualization. @@ -54,11 +330,30 @@ def format_csv( best_method_name : str, optional Name of the best method to append "_best_method" suffix to. + donor_data : pd.DataFrame, optional + Original donor dataset for histogram generation. Required if imputed_variables is provided. + + receiver_data : pd.DataFrame, optional + Imputed receiver dataset for histogram generation. Required if imputed_variables is provided. + + imputed_variables : List[str], optional + List of variable names that were imputed. When provided with donor_data and receiver_data, + histogram bin data will be included in the CSV for distribution visualization. + + n_histogram_bins : int, default 30 + Number of bins to use for numerical variable histograms. + Returns ------- pd.DataFrame Unified long-format DataFrame with columns: ['type', 'method', 'variable', 'quantile', 'metric_name', 'metric_value', 'split', 'additional_info'] + + Raises + ------ + ValueError + If imputed_variables is provided but donor_data or receiver_data is missing. + If any imputed variable is not present in both donor_data and receiver_data. """ rows = [] @@ -339,6 +634,46 @@ def format_csv( } ) + # 7. 
Process histogram distribution data for imputed variables + if imputed_variables is not None: + # Validate inputs + if donor_data is None or receiver_data is None: + raise ValueError( + "donor_data and receiver_data are required when imputed_variables is provided" + ) + + # Validate that all imputed variables exist in both datasets + _validate_imputed_variables( + donor_data, receiver_data, imputed_variables + ) + + # Generate histogram data for each imputed variable + for var in imputed_variables: + # Check if variable is categorical or numerical + if donor_data[ + var + ].dtype == "object" or pd.api.types.is_categorical_dtype( + donor_data[var] + ): + # Categorical variable + hist_data = _compute_categorical_distribution( + donor_data[var], receiver_data[var], var + ) + else: + # Numerical variable + hist_data = _compute_histogram_data( + donor_data[var].values, + receiver_data[var].values, + var, + n_bins=n_histogram_bins, + ) + + # Format histogram rows and add to main rows list + histogram_rows = _format_histogram_rows( + hist_data, best_method_name if best_method_name else "N/A" + ) + rows.extend(histogram_rows) + # Create DataFrame from rows if not rows: # Return empty DataFrame with correct columns if no data diff --git a/tests/test_dashboard_formatter.py b/tests/test_dashboard_formatter.py index 41bf210..1992069 100644 --- a/tests/test_dashboard_formatter.py +++ b/tests/test_dashboard_formatter.py @@ -19,6 +19,7 @@ VALID_TYPES = { "benchmark_loss", "distribution_distance", + "distribution_bins", "predictor_correlation", "predictor_target_mi", "predictor_importance", @@ -625,6 +626,314 @@ def test_saved_csv_preserves_data(self, sample_autoimpute_result): assert list(df["method"]) == list(loaded_df["method"]) +class TestDistributionBins: + """Test distribution_bins type formatting for histogram data.""" + + @pytest.fixture + def sample_donor_receiver_data(self): + """Create sample donor and receiver datasets with imputed variables.""" + np.random.seed(42) + + # Create donor data + donor_data = pd.DataFrame( + { + "numerical_var1": np.random.normal(100, 15, 200), + "numerical_var2": np.random.exponential(2, 200), + "categorical_var": np.random.choice( + ["A", "B", "C"], 200, p=[0.5, 0.3, 0.2] + ), + "predictor1": np.random.randn(200), + "predictor2": np.random.randn(200), + } + ) + + # Create receiver data (slightly different distributions) + receiver_data = pd.DataFrame( + { + "numerical_var1": np.random.normal( + 102, 14, 150 + ), # Shifted mean + "numerical_var2": np.random.exponential( + 2.1, 150 + ), # Different rate + "categorical_var": np.random.choice( + ["A", "B", "C"], 150, p=[0.4, 0.4, 0.2] + ), + "predictor1": np.random.randn(150), + "predictor2": np.random.randn(150), + } + ) + + return donor_data, receiver_data + + def test_distribution_bins_created(self, sample_donor_receiver_data): + """Test that distribution_bins rows are created when histogram data is provided.""" + donor_data, receiver_data = sample_donor_receiver_data + imputed_variables = [ + "numerical_var1", + "numerical_var2", + "categorical_var", + ] + + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".csv" + ) as f: + output_path = f.name + + try: + result = format_csv( + output_path=output_path, + donor_data=donor_data, + receiver_data=receiver_data, + imputed_variables=imputed_variables, + best_method_name="TestMethod", + n_histogram_bins=20, + ) + + # Check that distribution_bins type exists + dist_bins_rows = result[result["type"] == "distribution_bins"] + assert len(dist_bins_rows) > 
0 + + # Check that all imputed variables have bins + variables_with_bins = dist_bins_rows["variable"].unique() + assert set(variables_with_bins) == set(imputed_variables) + + # Check numerical variables have histogram_distribution metric + numerical_bins = dist_bins_rows[ + dist_bins_rows["metric_name"] == "histogram_distribution" + ] + assert "numerical_var1" in numerical_bins["variable"].values + assert "numerical_var2" in numerical_bins["variable"].values + + # Check categorical variable has categorical_distribution metric + categorical_bins = dist_bins_rows[ + dist_bins_rows["metric_name"] == "categorical_distribution" + ] + assert "categorical_var" in categorical_bins["variable"].values + + finally: + Path(output_path).unlink() + + def test_numerical_histogram_heights_match_numpy( + self, sample_donor_receiver_data + ): + """Test that histogram heights match numpy's histogram output.""" + donor_data, receiver_data = sample_donor_receiver_data + var_name = "numerical_var1" + n_bins = 15 + + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".csv" + ) as f: + output_path = f.name + + try: + result = format_csv( + output_path=output_path, + donor_data=donor_data, + receiver_data=receiver_data, + imputed_variables=[var_name], + best_method_name="TestMethod", + n_histogram_bins=n_bins, + ) + + # Get the distribution bins for our variable + dist_bins = result[ + (result["type"] == "distribution_bins") + & (result["variable"] == var_name) + & (result["metric_name"] == "histogram_distribution") + ] + + # Should have n_bins rows for this variable + assert len(dist_bins) == n_bins + + # Extract bin data from additional_info + bin_data = [] + for _, row in dist_bins.iterrows(): + info = json.loads(row["additional_info"]) + bin_data.append(info) + + # Sort by bin index + bin_data = sorted(bin_data, key=lambda x: x["bin_index"]) + + # Manually compute histogram with numpy for comparison + donor_values = donor_data[var_name].values + receiver_values = receiver_data[var_name].values + + # Remove NaN values + donor_clean = donor_values[~np.isnan(donor_values)] + receiver_clean = receiver_values[~np.isnan(receiver_values)] + + # Compute bin edges from combined data (same as in the function) + combined = np.concatenate([donor_clean, receiver_clean]) + _, bin_edges = np.histogram(combined, bins=n_bins) + + # Compute histograms + donor_heights_np, _ = np.histogram( + donor_clean, bins=bin_edges, density=True + ) + receiver_heights_np, _ = np.histogram( + receiver_clean, bins=bin_edges, density=True + ) + + # Convert to percentages (same as in the function) + bin_widths = np.diff(bin_edges) + donor_heights_expected = donor_heights_np * bin_widths * 100 + receiver_heights_expected = receiver_heights_np * bin_widths * 100 + + # Compare heights + for i, data in enumerate(bin_data): + assert data["bin_index"] == i + # Check bin edges + assert np.isclose(data["bin_start"], bin_edges[i], rtol=1e-10) + assert np.isclose( + data["bin_end"], bin_edges[i + 1], rtol=1e-10 + ) + # Check heights match numpy's output + assert np.isclose( + data["donor_height"], donor_heights_expected[i], rtol=1e-10 + ) + assert np.isclose( + data["receiver_height"], + receiver_heights_expected[i], + rtol=1e-10, + ) + # Check sample counts + assert data["n_samples_donor"] == len(donor_clean) + assert data["n_samples_receiver"] == len(receiver_clean) + assert data["total_bins"] == n_bins + + finally: + Path(output_path).unlink() + + def test_categorical_distribution_proportions( + self, sample_donor_receiver_data + ): + 
"""Test that categorical distribution proportions are computed correctly.""" + donor_data, receiver_data = sample_donor_receiver_data + var_name = "categorical_var" + + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".csv" + ) as f: + output_path = f.name + + try: + result = format_csv( + output_path=output_path, + donor_data=donor_data, + receiver_data=receiver_data, + imputed_variables=[var_name], + best_method_name="TestMethod", + ) + + # Get the distribution bins for categorical variable + cat_bins = result[ + (result["type"] == "distribution_bins") + & (result["variable"] == var_name) + & (result["metric_name"] == "categorical_distribution") + ] + + # Should have one row per category + assert len(cat_bins) == 3 # A, B, C + + # Extract category data + category_data = {} + for _, row in cat_bins.iterrows(): + info = json.loads(row["additional_info"]) + category = info["category"] + category_data[category] = info + + # Manually compute proportions + donor_counts = donor_data[var_name].value_counts() + receiver_counts = receiver_data[var_name].value_counts() + + donor_total = donor_data[var_name].count() + receiver_total = receiver_data[var_name].count() + + # Check each category + for category in ["A", "B", "C"]: + assert category in category_data + data = category_data[category] + + # Expected proportions + expected_donor_prop = ( + donor_counts.get(category, 0) / donor_total + ) * 100 + expected_receiver_prop = ( + receiver_counts.get(category, 0) / receiver_total + ) * 100 + + # Check proportions match + assert np.isclose( + data["donor_proportion"], expected_donor_prop, rtol=1e-10 + ) + assert np.isclose( + data["receiver_proportion"], + expected_receiver_prop, + rtol=1e-10, + ) + + # Check sample counts + assert data["n_samples_donor"] == donor_total + assert data["n_samples_receiver"] == receiver_total + + finally: + Path(output_path).unlink() + + def test_error_when_missing_data_for_imputed_variables(self): + """Test that error is raised when donor/receiver data is missing for histogram generation.""" + imputed_variables = ["var1", "var2"] + + with pytest.raises( + ValueError, match="donor_data and receiver_data are required" + ): + format_csv( + imputed_variables=imputed_variables, + donor_data=None, + receiver_data=None, + ) + + # Test with missing donor data + receiver_data = pd.DataFrame({"var1": [1, 2, 3], "var2": [4, 5, 6]}) + with pytest.raises( + ValueError, match="donor_data and receiver_data are required" + ): + format_csv( + imputed_variables=imputed_variables, + donor_data=None, + receiver_data=receiver_data, + ) + + def test_error_when_variable_missing_from_datasets(self): + """Test that error is raised when imputed variable is not in datasets.""" + donor_data = pd.DataFrame({"var1": [1, 2, 3], "var2": [4, 5, 6]}) + receiver_data = pd.DataFrame({"var1": [7, 8, 9], "var2": [10, 11, 12]}) + imputed_variables = ["var1", "var3"] # var3 doesn't exist + + with pytest.raises( + ValueError, match="missing from donor_data: \\['var3'\\]" + ): + format_csv( + donor_data=donor_data, + receiver_data=receiver_data, + imputed_variables=imputed_variables, + ) + + # Test with variable missing from receiver + receiver_data = pd.DataFrame({"var1": [7, 8, 9]}) # Missing var2 + imputed_variables = ["var1", "var2"] + + with pytest.raises( + ValueError, match="missing from receiver_data: \\['var2'\\]" + ): + format_csv( + donor_data=donor_data, + receiver_data=receiver_data, + imputed_variables=imputed_variables, + ) + + class TestEdgeCases: """Test edge cases and error 
handling.""" From c8d722e00f17ad06b6927114cb20d936e9e9ff8d Mon Sep 17 00:00:00 2001 From: juaristi22 Date: Thu, 6 Nov 2025 12:38:07 +0800 Subject: [PATCH 2/2] enable log transform in data preprocessing --- changelog_entry.yaml | 1 + microimpute/utils/data.py | 363 +++++++++++++++-- tests/test_data_preprocessing.py | 671 ++++++++++++++++++++++++++++++- 3 files changed, 997 insertions(+), 38 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index c68d72d..f9c460a 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -2,3 +2,4 @@ changes: added: - Distribution comparison histogram of donor and receiver datasets for imputed variables (to dashboard). + - Log transformation option for numerical variables in data preprocessing. diff --git a/microimpute/utils/data.py b/microimpute/utils/data.py index 938d70e..3aa5e4e 100644 --- a/microimpute/utils/data.py +++ b/microimpute/utils/data.py @@ -1,18 +1,20 @@ """Data preparation and transformation utilities This module provides comprehensive data preparation functions for imputation workflows, -including data splitting, normalization, unnormalization, and categorical variable handling. +including data splitting, normalization, log transformation, and categorical variable handling. These utilities ensure consistent data preprocessing across different imputation methods. Key functions: - - preprocess_data: split and optionally normalize data for training/testing + - preprocess_data: split and optionally normalize or log-transform data for training/testing - unnormalize_predictions: convert normalized predictions back to original scale + - unlog_transform_predictions: convert log-transformed predictions back to original scale - Handle categorical variables through one-hot encoding """ import logging -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union +import numpy as np import pandas as pd from pydantic import validate_call from sklearn.model_selection import train_test_split @@ -30,6 +32,7 @@ @validate_call(config=VALIDATE_CONFIG) def normalize_data( data: pd.DataFrame, + columns_to_normalize: Optional[List[str]] = None, ) -> Tuple[pd.DataFrame, dict]: """Normalize numeric columns in a DataFrame. @@ -38,6 +41,8 @@ def normalize_data( Args: data: DataFrame to normalize. + columns_to_normalize: Optional list of specific columns to normalize. + If None, all numeric columns will be normalized. Returns: Tuple of (normalized_data, normalization_params) @@ -45,6 +50,7 @@ def normalize_data( to {"mean": float, "std": float}. Raises: + ValueError: If specified columns don't exist in data. RuntimeError: If normalization fails. 
""" logger.debug("Normalizing data") @@ -64,10 +70,39 @@ def normalize_data( f"Excluding categorical columns from normalization: {categorical_cols}" ) - # Get only numeric columns for normalization - numeric_cols = [ - col for col in data.columns if col not in categorical_cols - ] + # Determine which columns to normalize + if columns_to_normalize is not None: + # Validate that specified columns exist + missing_cols = set(columns_to_normalize) - set(data.columns) + if missing_cols: + error_msg = ( + f"Columns specified for normalization not found in " + f"data: {missing_cols}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + + # Only normalize specified columns that are not categorical + numeric_cols = [ + col + for col in columns_to_normalize + if col not in categorical_cols + ] + + # Warn if user specified categorical columns + specified_categorical = [ + col for col in columns_to_normalize if col in categorical_cols + ] + if specified_categorical: + logger.warning( + f"Skipping normalization for categorical columns: " + f"{specified_categorical}" + ) + else: + # Get all numeric columns for normalization + numeric_cols = [ + col for col in data.columns if col not in categorical_cols + ] if not numeric_cols: logger.warning("No numeric columns found for normalization") @@ -106,6 +141,120 @@ def normalize_data( raise RuntimeError("Failed to normalize data") from e +@validate_call(config=VALIDATE_CONFIG) +def log_transform_data( + data: pd.DataFrame, + columns_to_transform: Optional[List[str]] = None, +) -> Tuple[pd.DataFrame, dict]: + """Apply log transformation to numeric columns in a DataFrame. + + Categorical and boolean columns are excluded from transformation + to prevent issues when they are later encoded as dummy variables. + + Args: + data: DataFrame to log transform. + columns_to_transform: Optional list of specific columns to + log transform. If None, all numeric columns will be transformed. + + Returns: + Tuple of (log_transformed_data, log_transform_params) + where log_transform_params is a dict mapping column names + to {} for reversing the transformation. + + Raises: + ValueError: If data contains non-positive values in numeric columns + or if specified columns don't exist in data. + RuntimeError: If log transformation fails. 
+ """ + logger.debug("Applying log transformation to data") + try: + from microimpute.utils.type_handling import VariableTypeDetector + + # Identify categorical columns to exclude from log transformation + detector = VariableTypeDetector() + categorical_cols = [] + for col in data.columns: + var_type, _ = detector.categorize_variable(data[col], col, logger) + if var_type in ["categorical", "numeric_categorical", "bool"]: + categorical_cols.append(col) + + if categorical_cols: + logger.info( + f"Excluding categorical columns from log transformation: {categorical_cols}" + ) + + # Determine which columns to transform + if columns_to_transform is not None: + # Validate that specified columns exist + missing_cols = set(columns_to_transform) - set(data.columns) + if missing_cols: + error_msg = ( + f"Columns specified for log transformation not found " + f"in data: {missing_cols}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + + # Only transform specified columns that are not categorical + numeric_cols = [ + col + for col in columns_to_transform + if col not in categorical_cols + ] + + # Warn if user specified categorical columns + specified_categorical = [ + col for col in columns_to_transform if col in categorical_cols + ] + if specified_categorical: + logger.warning( + f"Skipping log transformation for categorical " + f"columns: {specified_categorical}" + ) + else: + # Get all numeric columns for log transformation + numeric_cols = [ + col for col in data.columns if col not in categorical_cols + ] + + if not numeric_cols: + logger.warning("No numeric columns found for log transformation") + return data.copy(), {} + + # Check for non-positive values in numeric columns + data_copy = data.copy() + for col in numeric_cols: + min_val = data_copy[col].min() + if min_val <= 0: + error_msg = ( + f"Column '{col}' contains non-positive values " + f"(min={min_val}). Log transformation requires all " + f"positive values." + ) + logger.error(error_msg) + raise ValueError(error_msg) + + # Apply log transformation only to numeric columns + log_transform_params = {} + for col in numeric_cols: + data_copy[col] = np.log(data_copy[col]) + log_transform_params[col] = {} + + logger.debug( + f"Log transformed {len(numeric_cols)} numeric columns successfully" + ) + logger.debug(f"Log transformation parameters: {log_transform_params}") + + return data_copy, log_transform_params + + except ValueError: + # Re-raise ValueError as-is (for non-positive values) + raise + except (TypeError, AttributeError) as e: + logger.error(f"Error during log transformation: {str(e)}") + raise RuntimeError("Failed to apply log transformation") from e + + @validate_call(config=VALIDATE_CONFIG) def preprocess_data( data: pd.DataFrame, @@ -113,7 +262,8 @@ def preprocess_data( train_size: Optional[float] = TRAIN_SIZE, test_size: Optional[float] = TEST_SIZE, random_state: Optional[int] = RANDOM_STATE, - normalize: Optional[bool] = False, + normalize: Optional[Union[bool, List[str]]] = False, + log_transform: Optional[Union[bool, List[str]]] = False, ) -> Union[ Tuple[pd.DataFrame, dict], # when full_data=True Tuple[pd.DataFrame, pd.DataFrame, dict], # when full_data=False @@ -126,17 +276,35 @@ def preprocess_data( train_size: Proportion of the dataset to include in the train split. test_size: Proportion of the dataset to include in the test split. random_state: Random seed for reproducibility. - normalize: Whether to normalize the data. + normalize: Whether to normalize the data. 
Can be: + - True: normalize all numeric columns + - List of column names: normalize only those columns + - False: no normalization (default) + log_transform: Whether to apply log transformation to the data. Can be: + - True: transform all numeric columns + - List of column names: transform only those columns + - False: no transformation (default) Returns: - Different tuple formats depending on the value of full_data: - - If full_data=True: (data, dummy_info) - - If full_data=False: (X_train, X_test, dummy_info) - - Where dummy_info is a dictionary mapping original columns to their resulting dummy columns + Different tuple formats depending on parameters: + - If full_data=True and transformations applied: + (data, transform_params) + - If full_data=True and no transformations: + data + - If full_data=False and transformations applied: + (X_train, X_test, transform_params) + - If full_data=False and no transformations: + (X_train, X_test) + + Where transform_params is a dict with keys: + - "normalization": dict of normalization parameters (or empty dict) + - "log_transform": dict of log transform parameters (or empty dict) Raises: - ValueError: If data is empty or invalid + ValueError: If data is empty or invalid, or if both normalize and + log_transform would apply to the same columns, or if log_transform + is applied to data with non-positive values, or if specified + columns don't exist in data. RuntimeError: If data preprocessing fails """ @@ -146,23 +314,100 @@ def preprocess_data( if data.empty: raise ValueError("Data must not be None or empty") + + # Check if both normalize and log_transform are requested + normalize_requested = normalize is not False and normalize != [] + log_transform_requested = ( + log_transform is not False and log_transform != [] + ) + + # Validate that normalize and log_transform don't conflict + if normalize_requested and log_transform_requested: + # If both are True, they would apply to all numeric columns - conflict + if normalize is True and log_transform is True: + error_msg = ( + "Cannot apply both normalization and log transformation to " + "all columns. Please specify which columns to transform with " + "each approach using lists." + ) + logger.error(error_msg) + raise ValueError(error_msg) + + # If one is True and the other is a list, conflict + if normalize is True or log_transform is True: + error_msg = ( + "Cannot apply both normalization and log transformation. " + "When using both, specify column lists for each to ensure " + "they apply to different variables." + ) + logger.error(error_msg) + raise ValueError(error_msg) + + # Both are lists - check for overlap + normalize_cols = ( + set(normalize) if isinstance(normalize, list) else set() + ) + log_transform_cols = ( + set(log_transform) if isinstance(log_transform, list) else set() + ) + overlap = normalize_cols & log_transform_cols + + if overlap: + error_msg = ( + f"Cannot apply both normalization and log transformation to " + f"the same columns: {overlap}. Each column can only have one " + f"transformation applied." 
+ ) + logger.error(error_msg) + raise ValueError(error_msg) + # Check for missing values missing_count = data.isna().sum().sum() if missing_count > 0: logger.warning(f"Data contains {missing_count} missing values") - if normalize: - data, normalization_params = normalize_data(data) + # Apply normalization if requested + normalization_params = {} + if normalize_requested: + if isinstance(normalize, bool): + # normalize=True means normalize all numeric columns + data, normalization_params = normalize_data(data) + else: + # normalize is a list of specific columns + data, normalization_params = normalize_data( + data, columns_to_normalize=normalize + ) - if full_data and normalize: - logger.info("Returning full preprocessed dataset") - return ( - data, - normalization_params, - ) - elif full_data: - logger.info("Returning full preprocessed dataset") - return data + # Apply log transformation if requested + log_transform_params = {} + if log_transform_requested: + if isinstance(log_transform, bool): + # log_transform=True means transform all numeric columns + data, log_transform_params = log_transform_data(data) + else: + # log_transform is a list of specific columns + data, log_transform_params = log_transform_data( + data, columns_to_transform=log_transform + ) + + # Prepare transformation parameters to return + has_transformations = normalize_requested or log_transform_requested + if has_transformations: + # Merge both parameter dicts, with a key to distinguish them + transform_params = { + "normalization": normalization_params, + "log_transform": log_transform_params, + } + + if full_data: + if has_transformations: + logger.info( + "Returning full preprocessed dataset with transformations" + ) + return (data, transform_params) + else: + logger.info("Returning full preprocessed dataset") + return data else: logger.debug( f"Splitting data with train_size={train_size}, test_size={test_size}" @@ -177,17 +422,10 @@ def preprocess_data( logger.info( f"Data split into train ({X_train.shape}) and test ({X_test.shape}) sets" ) - if normalize: - return ( - X_train, - X_test, - normalization_params, - ) + if has_transformations: + return (X_train, X_test, transform_params) else: - return ( - X_train, - X_test, - ) + return (X_train, X_test) except (ValueError, TypeError) as e: logger.error(f"Error in processing data: {str(e)}") @@ -236,3 +474,56 @@ def unnormalize_predictions( logger.debug(f"Unnormalized quantile {q} with shape {df_unnorm.shape}") return unnormalized + + +@validate_call(config=VALIDATE_CONFIG) +def unlog_transform_predictions( + imputations: dict, log_transform_params: dict +) -> dict: + """Reverse log transformation on predictions using stored parameters. + + Args: + imputations: Dictionary mapping quantiles to DataFrames of predictions. + log_transform_params: Dictionary with column names that were + log-transformed. + + Returns: + Dictionary with same structure as imputations but with + un-log-transformed values. + + Raises: + ValueError: If columns in imputations don't match log transformation + parameters. 
+ """ + logger.debug( + f"Reversing log transformation for {len(imputations)} quantiles" + ) + + untransformed = {} + for q, df in imputations.items(): + cols = df.columns + + # Check that all columns have log transformation parameters + missing_params = [ + col for col in cols if col not in log_transform_params + ] + if missing_params: + error_msg = ( + f"Missing log transformation parameters for columns: " + f"{missing_params}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + + # Reverse log transformation: x_original = exp(x_log) + df_untransformed = df.copy() + for col in cols: + df_untransformed[col] = np.exp(df[col]) + untransformed[q] = df_untransformed + + logger.debug( + f"Reversed log transformation for quantile {q} with shape " + f"{df_untransformed.shape}" + ) + + return untransformed diff --git a/tests/test_data_preprocessing.py b/tests/test_data_preprocessing.py index 6110e43..ee9ea83 100644 --- a/tests/test_data_preprocessing.py +++ b/tests/test_data_preprocessing.py @@ -4,7 +4,12 @@ import pandas as pd import pytest -from microimpute.utils.data import normalize_data, preprocess_data +from microimpute.utils.data import ( + log_transform_data, + normalize_data, + preprocess_data, + unlog_transform_predictions, +) class TestNormalize: @@ -170,10 +175,13 @@ def test_preprocess_data_excludes_categoricals_from_normalization(self): } ) - result, norm_params = preprocess_data( + result, transform_params = preprocess_data( data, full_data=True, normalize=True ) + # Extract normalization params from nested dict + norm_params = transform_params["normalization"] + # Categorical columns should be unchanged pd.testing.assert_series_equal(result["race"], data["race"]) pd.testing.assert_series_equal(result["is_female"], data["is_female"]) @@ -233,3 +241,662 @@ def test_categorical_columns_dont_get_weird_suffixes_when_dummified( # Column name should not contain decimal points assert "." 
not in col, f"Column {col} has decimal point in name" + + +class TestLogTransform: + """Test the log_transform_data function.""" + + def test_log_transform_excludes_categorical_columns(self): + """Test that categorical columns are not log transformed.""" + data = pd.DataFrame( + { + "numeric_col": [1.0, 2.5, 3.7, 4.2, 5.9], + "categorical_col": [1, 2, 3, 1, 2], + "boolean_col": [0, 1, 0, 1, 0], + } + ) + + log_data, log_params = log_transform_data(data) + + # Categorical and boolean columns should be unchanged + pd.testing.assert_series_equal( + log_data["categorical_col"], data["categorical_col"] + ) + pd.testing.assert_series_equal( + log_data["boolean_col"], data["boolean_col"] + ) + + # Numeric column should be log transformed + assert not np.allclose( + log_data["numeric_col"].values, data["numeric_col"].values + ) + + # Only numeric column should have log transform params + assert "numeric_col" in log_params + assert "categorical_col" not in log_params + assert "boolean_col" not in log_params + + def test_log_transform_correctly_transforms_numeric_columns(self): + """Test that numeric columns are correctly log transformed.""" + data = pd.DataFrame( + { + "value1": [ + 1.5, + 2.7, + 3.2, + 4.8, + 5.1, + 6.3, + 7.9, + 8.4, + 9.6, + 10.2, + ], + "value2": [ + 15.5, + 27.3, + 32.1, + 48.7, + 51.9, + 63.2, + 79.8, + 84.5, + 96.1, + 102.4, + ], + "category": [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + } + ) + + log_data, log_params = log_transform_data(data) + + # Check that numeric columns are log transformed + expected_value1 = np.log(data["value1"].values) + expected_value2 = np.log(data["value2"].values) + + np.testing.assert_array_almost_equal( + log_data["value1"].values, expected_value1 + ) + np.testing.assert_array_almost_equal( + log_data["value2"].values, expected_value2 + ) + + # Check log transform params are stored + assert "value1" in log_params + assert "value2" in log_params + + def test_log_transform_rejects_non_positive_values(self): + """Test that log transform raises error for non-positive values.""" + data = pd.DataFrame( + { + "value": [1.0, 2.0, 0.0, 4.0, 5.0], # Contains zero + } + ) + + with pytest.raises(ValueError, match="non-positive values"): + log_transform_data(data) + + data_negative = pd.DataFrame( + { + "value": [1.0, 2.0, -1.0, 4.0, 5.0], # Contains negative + } + ) + + with pytest.raises(ValueError, match="non-positive values"): + log_transform_data(data_negative) + + def test_log_transform_returns_copy(self): + """Test that log transform returns a copy.""" + data = pd.DataFrame( + { + "value": [1.5, 2.7, 3.2, 4.8, 5.1, 6.3, 7.9, 8.4, 9.6, 10.2], + "category": [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + } + ) + original_data = data.copy() + + log_data, _ = log_transform_data(data) + + # Original data should be unchanged + pd.testing.assert_frame_equal(data, original_data) + + # Log transformed data should be different + assert not log_data["value"].equals(data["value"]) + + def test_log_transform_with_no_numeric_columns(self): + """Test log transform with only categorical columns.""" + data = pd.DataFrame({"cat1": [1, 2, 3, 1, 2], "cat2": [0, 1, 0, 1, 0]}) + + log_data, log_params = log_transform_data(data) + + # Data should be unchanged + pd.testing.assert_frame_equal(log_data, data) + + # No log transform params should be returned + assert log_params == {} + + +class TestUnlogTransformPredictions: + """Test the unlog_transform_predictions function.""" + + def test_unlog_transform_reverses_log_transform(self): + """Test that unlog transform correctly reverses log 
transform.""" + original = pd.DataFrame( + { + "value1": [ + 1.5, + 2.7, + 3.2, + 4.8, + 5.1, + 6.3, + 7.9, + 8.4, + 9.6, + 10.2, + ], + "value2": [ + 15.5, + 27.3, + 32.1, + 48.7, + 51.9, + 63.2, + 79.8, + 84.5, + 96.1, + 102.4, + ], + } + ) + + # Apply log transform + log_data, log_params = log_transform_data(original) + + # Create imputations dict (simulating prediction output) + imputations = {0.5: log_data} + + # Reverse log transform + reversed_data = unlog_transform_predictions(imputations, log_params) + + # Should match original data + pd.testing.assert_frame_equal( + reversed_data[0.5], original, check_exact=False, atol=1e-10 + ) + + def test_unlog_transform_raises_error_for_missing_params(self): + """Test that unlog transform raises error when params are missing.""" + imputations = { + 0.5: pd.DataFrame( + { + "value1": [0.0, 0.69, 1.10], + "value2": [2.3, 3.0, 3.9], + } + ) + } + + # Only have params for value1, not value2 + log_params = {"value1": {}} + + with pytest.raises( + ValueError, match="Missing log transformation parameters" + ): + unlog_transform_predictions(imputations, log_params) + + +class TestPreprocessDataWithLogTransform: + """Test that preprocess_data correctly uses log transformation.""" + + def test_preprocess_data_excludes_categoricals_from_log_transform(self): + """Test that preprocess_data doesn't log transform categorical columns.""" + data = pd.DataFrame( + { + "age": [ + 25.3, + 30.7, + 35.2, + 40.9, + 45.1, + 50.6, + 55.8, + 60.3, + 65.7, + 70.2, + ], + "race": [1, 2, 3, 1, 2, 3, 1, 2, 3, 1], + "is_female": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + "income": [ + 50123.45, + 60987.23, + 70456.78, + 80234.56, + 90876.12, + 100543.89, + 110234.67, + 120789.34, + 130456.78, + 140987.23, + ], + } + ) + + result, transform_params = preprocess_data( + data, full_data=True, log_transform=True + ) + + # Extract log transform params from nested dict + log_params = transform_params["log_transform"] + + # Categorical columns should be unchanged + pd.testing.assert_series_equal(result["race"], data["race"]) + pd.testing.assert_series_equal(result["is_female"], data["is_female"]) + + # Numeric columns should be log transformed + assert not np.allclose(result["age"].values, data["age"].values) + assert not np.allclose(result["income"].values, data["income"].values) + + # Only numeric columns in log_params + assert "age" in log_params + assert "income" in log_params + assert "race" not in log_params + assert "is_female" not in log_params + + def test_preprocess_data_raises_error_for_both_normalize_and_log( + self, + ): + """Test that preprocess_data raises error if both normalize and log_transform are True.""" + data = pd.DataFrame( + { + "value1": [1.5, 2.7, 3.2, 4.8, 5.1, 6.3, 7.9, 8.4, 9.6, 10.2], + "value2": [ + 15.5, + 27.3, + 32.1, + 48.7, + 51.9, + 63.2, + 79.8, + 84.5, + 96.1, + 102.4, + ], + } + ) + + with pytest.raises( + ValueError, + match="Cannot apply both normalization and log transformation", + ): + preprocess_data( + data, full_data=True, normalize=True, log_transform=True + ) + + def test_preprocess_data_with_log_transform_and_split(self): + """Test that preprocess_data correctly splits and log transforms data.""" + data = pd.DataFrame( + { + "value1": [ + 1.5, + 2.7, + 3.2, + 4.8, + 5.1, + 6.3, + 7.9, + 8.4, + 9.6, + 10.2, + 11.5, + 12.8, + 13.3, + 14.9, + 15.2, + ], + "value2": [ + 15.5, + 27.3, + 32.1, + 48.7, + 51.9, + 63.2, + 79.8, + 84.5, + 96.1, + 102.4, + 115.7, + 128.9, + 133.6, + 149.2, + 152.8, + ], + } + ) + + X_train, X_test, transform_params = 
preprocess_data( + data, + full_data=False, + test_size=0.2, + train_size=None, + random_state=42, + log_transform=True, + ) + + # Extract log transform params from nested dict + log_params = transform_params["log_transform"] + + # Check that data is split + assert len(X_train) == 12 + assert len(X_test) == 3 + + # Check that log params are returned + assert "value1" in log_params + assert "value2" in log_params + + # Check that values are log transformed (compare to original) + assert not any(X_train["value1"].isin(data["value1"])) + assert not any(X_test["value1"].isin(data["value1"])) + + +class TestPreprocessDataWithSelectiveTransformation: + """Test preprocess_data with selective column transformation.""" + + def test_normalize_only_specified_columns(self): + """Test that only specified columns are normalized.""" + data = pd.DataFrame( + { + "age": [ + 23, + 30, + 35, + 46, + 45, + 52, + 55, + 61, + 68, + 72, + ], + "income": [ + 50123.45, + 60987.23, + 70456.78, + 80234.56, + 90876.12, + 100543.89, + 110234.67, + 120789.34, + 130456.78, + 140987.23, + ], + "wealth": [ + 150000.5, + 250000.3, + 350000.7, + 450000.2, + 550000.9, + 650000.1, + 750000.4, + 850000.8, + 950000.6, + 1050000.3, + ], + } + ) + + # Only normalize income column + result, transform_params = preprocess_data( + data, full_data=True, normalize=["income"] + ) + + # Extract normalization params from nested dict + norm_params = transform_params["normalization"] + + # Income should be normalized + assert not np.allclose(result["income"].values, data["income"].values) + assert "income" in norm_params + + # Age and wealth should NOT be normalized + pd.testing.assert_series_equal(result["age"], data["age"]) + pd.testing.assert_series_equal(result["wealth"], data["wealth"]) + assert "age" not in norm_params + assert "wealth" not in norm_params + + def test_log_transform_only_specified_columns(self): + """Test that only specified columns are log transformed.""" + data = pd.DataFrame( + { + "age": [ + 23, + 30, + 35, + 46, + 45, + 52, + 55, + 61, + 68, + 72, + ], + "income": [ + 50123.45, + 60987.23, + 70456.78, + 80234.56, + 90876.12, + 100543.89, + 110234.67, + 120789.34, + 130456.78, + 140987.23, + ], + "wealth": [ + 150000.5, + 250000.3, + 350000.7, + 450000.2, + 550000.9, + 650000.1, + 750000.4, + 850000.8, + 950000.6, + 1050000.3, + ], + } + ) + + # Only log transform income column + result, transform_params = preprocess_data( + data, full_data=True, log_transform=["income"] + ) + + # Extract log transform params from nested dict + log_params = transform_params["log_transform"] + + # Income should be log transformed + assert not np.allclose(result["income"].values, data["income"].values) + assert "income" in log_params + + # Age and wealth should NOT be transformed + pd.testing.assert_series_equal(result["age"], data["age"]) + pd.testing.assert_series_equal(result["wealth"], data["wealth"]) + assert "age" not in log_params + assert "wealth" not in log_params + + def test_normalize_multiple_specified_columns(self): + """Test normalizing multiple specified columns.""" + data = pd.DataFrame( + { + "age": [ + 23, + 30, + 35, + 46, + 45, + 52, + 55, + 61, + 68, + 72, + ], + "income": [ + 50123.45, + 60987.23, + 70456.78, + 80234.56, + 90876.12, + 100543.89, + 110234.67, + 120789.34, + 130456.78, + 140987.23, + ], + "wealth": [ + 150000.5, + 250000.3, + 350000.7, + 450000.2, + 550000.9, + 650000.1, + 750000.4, + 850000.8, + 950000.6, + 1050000.3, + ], + } + ) + + # Normalize income and wealth, but not age + result, 
transform_params = preprocess_data( + data, full_data=True, normalize=["income", "wealth"] + ) + + # Extract normalization params from nested dict + norm_params = transform_params["normalization"] + + # Income and wealth should be normalized + assert not np.allclose(result["income"].values, data["income"].values) + assert not np.allclose(result["wealth"].values, data["wealth"].values) + assert "income" in norm_params + assert "wealth" in norm_params + + # Age should NOT be normalized + pd.testing.assert_series_equal(result["age"], data["age"]) + assert "age" not in norm_params + + def test_error_on_nonexistent_column_normalize(self): + """Test that error is raised when specifying non-existent column.""" + data = pd.DataFrame( + { + "age": [25.3, 30.7, 35.2, 40.9, 45.1], + "income": [50123.45, 60987.23, 70456.78, 80234.56, 90876.12], + } + ) + + with pytest.raises(ValueError, match="not found in data"): + preprocess_data( + data, full_data=True, normalize=["income", "nonexistent"] + ) + + def test_error_on_nonexistent_column_log_transform(self): + """Test that error is raised when specifying non-existent column.""" + data = pd.DataFrame( + { + "age": [25.3, 30.7, 35.2, 40.9, 45.1], + "income": [50123.45, 60987.23, 70456.78, 80234.56, 90876.12], + } + ) + + with pytest.raises(ValueError, match="not found in data"): + preprocess_data( + data, full_data=True, log_transform=["income", "nonexistent"] + ) + + def test_error_on_overlapping_columns(self): + """Test error when both normalize and log_transform target same columns.""" + data = pd.DataFrame( + { + "age": [25.3, 30.7, 35.2, 40.9, 45.1], + "income": [50123.45, 60987.23, 70456.78, 80234.56, 90876.12], + } + ) + + # Error when same column is in both lists + with pytest.raises( + ValueError, match="Cannot apply both normalization and log" + ): + preprocess_data( + data, + full_data=True, + normalize=["income", "age"], + log_transform=["age"], + ) + + def test_both_transformations_on_different_columns(self): + """Test that both transformations work when applied to different columns.""" + data = pd.DataFrame( + { + "age": [ + 23, + 30, + 35, + 46, + 45, + 52, + 55, + 61, + 68, + 72, + ], + "income": [ + 50123.45, + 60987.23, + 70456.78, + 80234.56, + 90876.12, + 100543.89, + 110234.67, + 120789.34, + 130456.78, + 140987.23, + ], + "wealth": [ + 150000.5, + 250000.3, + 350000.7, + 450000.2, + 550000.9, + 650000.1, + 750000.4, + 850000.8, + 950000.6, + 1050000.3, + ], + } + ) + + # Normalize age, log transform income, leave wealth unchanged + result, transform_params = preprocess_data( + data, + full_data=True, + normalize=["age"], + log_transform=["income"], + ) + + # Extract both parameter dicts + norm_params = transform_params["normalization"] + log_params = transform_params["log_transform"] + + # Age should be normalized + assert not np.allclose(result["age"].values, data["age"].values) + assert "age" in norm_params + assert "age" not in log_params + + # Income should be log transformed + assert not np.allclose(result["income"].values, data["income"].values) + assert "income" in log_params + assert "income" not in norm_params + + # Wealth should be unchanged + pd.testing.assert_series_equal(result["wealth"], data["wealth"]) + assert "wealth" not in norm_params + assert "wealth" not in log_params
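
Usage sketch (reviewer note, not part of the patch): how the new donor_data / receiver_data / imputed_variables arguments to format_csv are expected to be wired up. One "distribution_bins" row is emitted per histogram bin (numerical variables) or per category (categorical variables), with the bin geometry serialized into additional_info as JSON. The column name "income", the method label "MyMethod", and the output path are illustrative assumptions; the function signature, row schema, and JSON keys are taken from this patch.

    import json
    import numpy as np
    import pandas as pd
    from microimpute.utils.dashboard_formatter import format_csv

    # Donor and receiver with slightly different distributions of one imputed variable.
    donor = pd.DataFrame({"income": np.random.lognormal(10.0, 0.5, 500)})
    receiver = pd.DataFrame({"income": np.random.lognormal(10.1, 0.5, 300)})

    long_df = format_csv(
        output_path="dashboard_demo.csv",   # illustrative; per the docstring output_path is optional
        donor_data=donor,
        receiver_data=receiver,
        imputed_variables=["income"],
        best_method_name="MyMethod",        # hypothetical method label
        n_histogram_bins=25,
    )

    # Each bin row carries its edges and percentage heights in additional_info.
    bins = long_df[long_df["type"] == "distribution_bins"]
    first_bin = json.loads(bins.iloc[0]["additional_info"])
    print(first_bin["bin_start"], first_bin["donor_height"], first_bin["receiver_height"])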
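
A second sketch, for the new normalize / log_transform options on preprocess_data: True applies the transformation to every numeric column, a list restricts it to the named columns, and the two may only be combined on disjoint column sets. The column names below are illustrative assumptions; the nested transform_params keys ("normalization", "log_transform") follow the patch.

    import numpy as np
    import pandas as pd
    from microimpute.utils.data import preprocess_data

    df = pd.DataFrame(
        {
            "age": np.random.uniform(20, 70, 200),       # continuous, can be normalized
            "income": np.random.lognormal(10.0, 0.6, 200),  # strictly positive, can be logged
        }
    )

    # Normalize age, log-transform income; with full_data=True and transformations
    # requested, the function returns (data, transform_params).
    prepared, params = preprocess_data(
        df, full_data=True, normalize=["age"], log_transform=["income"]
    )
    assert set(params) == {"normalization", "log_transform"}
    assert "income" in params["log_transform"]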
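
Finally, a round-trip sketch for log_transform_data and unlog_transform_predictions, mirroring how quantile-keyed predictions made on the log scale would be mapped back to the original scale. The quantile key 0.5 and the column name "wealth" are assumptions for illustration; the dict-of-DataFrames shape matches the patch.

    import numpy as np
    import pandas as pd
    from microimpute.utils.data import log_transform_data, unlog_transform_predictions

    original = pd.DataFrame({"wealth": np.random.lognormal(12.0, 0.8, 100)})
    logged, log_params = log_transform_data(original)

    # Pretend a model produced these log-scale values for the median quantile,
    # then reverse the transformation with exp() via the stored parameters.
    imputations = {0.5: logged}
    restored = unlog_transform_predictions(imputations, log_params)
    assert np.allclose(restored[0.5]["wealth"].values, original["wealth"].values)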