What is KL-divergence? Kullback-Leibler divergence measures how much
one probability distribution differs from another. It quantifies the "information lost" when one distribution is used to approximate the other.
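For intuition, here is a minimal, illustrative sketch of the discrete form of that quantity (it is not part of the diff below; the dashboard itself only reads precomputed kl_divergence rows from the results CSV):

    import numpy as np

    def kl_divergence(p, q):
        """Discrete D_KL(P || Q) = sum_i p_i * log(p_i / q_i).

        Assumes p and q are probability vectors over the same support,
        with q_i > 0 wherever p_i > 0.
        """
        p = np.asarray(p, dtype=float)
        q = np.asarray(q, dtype=float)
        mask = p > 0  # zero-probability terms contribute nothing
        return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

    print(kl_divergence([0.5, 0.5], [0.5, 0.5]))  # 0.0, identical distributions
    print(kl_divergence([0.9, 0.1], [0.5, 0.5]))  # ~0.368, information lost approximating P by Q
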
diff --git a/microimputation-dashboard/components/VisualizationDashboard.tsx b/microimputation-dashboard/components/VisualizationDashboard.tsx
index 67a1d88..6904329 100644
--- a/microimputation-dashboard/components/VisualizationDashboard.tsx
+++ b/microimputation-dashboard/components/VisualizationDashboard.tsx
@@ -466,8 +466,8 @@ export default function VisualizationDashboard({
Assessment of the quality of the imputations produced by the best-performing (or the only selected) model
-
- {/* Imputed Variables Section */}
+
+ {/* Imputed Variables Section - 1/4 width */}
Imputed Variables
@@ -477,7 +477,7 @@ export default function VisualizationDashboard({
{dataAnalysis.imputedVars.length} variable{dataAnalysis.imputedVars.length !== 1 ? 's' : ''} imputed
-
+ 3 ? 'max-h-32 overflow-y-auto' : ''}`}>
{dataAnalysis.imputedVars.map((variable) => (
-
{variable}
@@ -492,7 +492,7 @@ export default function VisualizationDashboard({
)}
- {/* Best Model Section */}
+ {/* Best Model Section - 1/4 width */}
{dataAnalysis.allMethods.length === 1 ? 'Imputation Model' : 'Best Performing Model'}
@@ -526,6 +526,87 @@ export default function VisualizationDashboard({
)}
+
+ {/* Metrics Section - 1/2 width */}
+
+
+ Performance Metrics
+
+
+ {/* Average Test Losses */}
+ {(() => {
+ const benchmarkData = data.filter(d => d.type === 'benchmark_loss' && d.method === dataAnalysis.bestModel && d.split === 'test');
+
+ // Calculate avg quantile loss
+ const quantileLossData = benchmarkData.filter(
+ d => d.metric_name === 'quantile_loss' &&
+ typeof d.quantile === 'number' &&
+ d.metric_value !== null
+ );
+ const avgQuantileLoss = quantileLossData.length > 0
+ ? quantileLossData.reduce((sum, d) => sum + (d.metric_value ?? 0), 0) / quantileLossData.length
+ : null;
+
+ // Calculate avg log loss
+ const logLossData = benchmarkData.filter(
+ d => d.metric_name === 'log_loss' &&
+ d.metric_value !== null
+ );
+ const avgLogLoss = logLossData.length > 0
+ ? logLossData.reduce((sum, d) => sum + (d.metric_value ?? 0), 0) / logLossData.length
+ : null;
+
+ // Calculate avg Wasserstein distance
+ const wassersteinData = data.filter(
+ d => d.type === 'distribution_distance' &&
+ d.metric_name === 'wasserstein_distance' &&
+ d.metric_value !== null
+ );
+ const avgWasserstein = wassersteinData.length > 0
+ ? wassersteinData.reduce((sum, d) => sum + (d.metric_value ?? 0), 0) / wassersteinData.length
+ : null;
+
+ // Calculate avg KL divergence
+ const klData = data.filter(
+ d => d.type === 'distribution_distance' &&
+ d.metric_name === 'kl_divergence' &&
+ d.metric_value !== null
+ );
+ const avgKL = klData.length > 0
+ ? klData.reduce((sum, d) => sum + (d.metric_value ?? 0), 0) / klData.length
+ : null;
+
+ return (
+ <>
+ {avgQuantileLoss !== null && (
+
+
Avg. test quantile loss
+
{avgQuantileLoss.toFixed(4)}
+
+ )}
+ {avgLogLoss !== null && (
+
+
Avg. test log loss
+
{avgLogLoss.toFixed(4)}
+
+ )}
+ {avgWasserstein !== null && (
+
+
Avg. wasserstein distance
+
{avgWasserstein.toFixed(4)}
+
+ )}
+ {avgKL !== null && (
+
+
Avg. KL divergence
+
{avgKL.toFixed(4)}
+
+ )}
+ </>
+ );
+ })()}
+
+
diff --git a/microimputation-dashboard/public/microimputation_results.csv b/microimputation-dashboard/public/microimputation_results.csv
index 9ef58a5..8029b03 100644
--- a/microimputation-dashboard/public/microimputation_results.csv
+++ b/microimputation-dashboard/public/microimputation_results.csv
@@ -292,3 +292,66 @@ progressive_inclusion,OLSResults,N/A,N/A,cumulative_improvement,0.00121758583691
progressive_inclusion,OLSResults,N/A,N/A,marginal_improvement,0.00012356933219281885,test,"{""step"": 3, ""predictor_added"": ""age""}"
progressive_inclusion,OLSResults,N/A,N/A,cumulative_improvement,0.0011012642793990501,test,"{""step"": 4, ""predictor_added"": ""bp"", ""predictors"": [""sex"", ""bmi"", ""age"", ""bp""]}"
progressive_inclusion,OLSResults,N/A,N/A,marginal_improvement,-0.0001163215575132881,test,"{""step"": 4, ""predictor_added"": ""bp""}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 0, ""bin_start"": -0.12678066991651324, ""bin_end"": -0.11747005557601899, ""donor_height"": 0.32362459546925565, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 1, ""bin_start"": -0.11747005557601899, ""bin_end"": -0.10815944123552473, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 2, ""bin_start"": -0.10815944123552473, ""bin_end"": -0.0988488268950305, ""donor_height"": 0.6472491909385114, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 3, ""bin_start"": -0.0988488268950305, ""bin_end"": -0.08953821255453624, ""donor_height"": 1.2944983818770226, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 4, ""bin_start"": -0.08953821255453624, ""bin_end"": -0.08022759821404199, ""donor_height"": 0.9708737864077669, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 5, ""bin_start"": -0.08022759821404199, ""bin_end"": -0.07091698387354775, ""donor_height"": 2.2653721682847894, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 6, ""bin_start"": -0.07091698387354775, ""bin_end"": -0.0616063695330535, ""donor_height"": 3.2362459546925564, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 7, ""bin_start"": -0.0616063695330535, ""bin_end"": -0.052295755192559246, ""donor_height"": 3.2362459546925564, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 8, ""bin_start"": -0.052295755192559246, ""bin_end"": -0.04298514085206499, ""donor_height"": 5.17799352750809, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 9, ""bin_start"": -0.04298514085206499, ""bin_end"": -0.03367452651157074, ""donor_height"": 8.090614886731393, ""receiver_height"": 2.2556390977443606, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 10, ""bin_start"": -0.03367452651157074, ""bin_end"": -0.0243639121710765, ""donor_height"": 6.796116504854369, ""receiver_height"": 6.7669172932330826, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 11, ""bin_start"": -0.0243639121710765, ""bin_end"": -0.015053297830582249, ""donor_height"": 5.17799352750809, ""receiver_height"": 3.759398496240602, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 12, ""bin_start"": -0.015053297830582249, ""bin_end"": -0.005742683490087996, ""donor_height"": 9.385113268608416, ""receiver_height"": 18.045112781954884, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 13, ""bin_start"": -0.005742683490087996, ""bin_end"": 0.0035679308504062424, ""donor_height"": 10.679611650485437, ""receiver_height"": 24.060150375939852, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 14, ""bin_start"": 0.0035679308504062424, ""bin_end"": 0.012878545190900509, ""donor_height"": 5.501618122977347, ""receiver_height"": 18.796992481203006, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 15, ""bin_start"": 0.012878545190900509, ""bin_end"": 0.022189159531394748, ""donor_height"": 8.09061488673139, ""receiver_height"": 20.30075187969925, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 16, ""bin_start"": 0.022189159531394748, ""bin_end"": 0.031499773871888986, ""donor_height"": 6.4724919093851145, ""receiver_height"": 2.2556390977443606, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 17, ""bin_start"": 0.031499773871888986, ""bin_end"": 0.04081038821238325, ""donor_height"": 5.501618122977347, ""receiver_height"": 3.007518796992481, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 18, ""bin_start"": 0.04081038821238325, ""bin_end"": 0.05012100255287749, ""donor_height"": 3.8834951456310676, ""receiver_height"": 0.7518796992481204, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 19, ""bin_start"": 0.05012100255287749, ""bin_end"": 0.05943161689337176, ""donor_height"": 2.588996763754045, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 20, ""bin_start"": 0.05943161689337176, ""bin_end"": 0.068742231233866, ""donor_height"": 2.2653721682847894, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 21, ""bin_start"": 0.068742231233866, ""bin_end"": 0.07805284557436024, ""donor_height"": 0.9708737864077669, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 22, ""bin_start"": 0.07805284557436024, ""bin_end"": 0.0873634599148545, ""donor_height"": 1.9417475728155338, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 23, ""bin_start"": 0.0873634599148545, ""bin_end"": 0.09667407425534874, ""donor_height"": 2.2653721682847894, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 24, ""bin_start"": 0.09667407425534874, ""bin_end"": 0.10598468859584298, ""donor_height"": 0.6472491909385114, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 25, ""bin_start"": 0.10598468859584298, ""bin_end"": 0.11529530293633725, ""donor_height"": 0.6472491909385113, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 26, ""bin_start"": 0.11529530293633725, ""bin_end"": 0.12460591727683148, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 27, ""bin_start"": 0.12460591727683148, ""bin_end"": 0.13391653161732572, ""donor_height"": 1.2944983818770228, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 28, ""bin_start"": 0.13391653161732572, ""bin_end"": 0.14322714595781996, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s1,N/A,histogram_distribution,,full,"{""bin_index"": 29, ""bin_start"": 0.14322714595781996, ""bin_end"": 0.15253776029831428, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 0, ""bin_start"": -0.0763945037500033, ""bin_end"": -0.06767353884966323, ""donor_height"": 6.796116504854369, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 1, ""bin_start"": -0.06767353884966323, ""bin_end"": -0.05895257394932317, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 2, ""bin_start"": -0.05895257394932317, ""bin_end"": -0.0502316090489831, ""donor_height"": 0.9708737864077671, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 3, ""bin_start"": -0.0502316090489831, ""bin_end"": -0.041510644148643035, ""donor_height"": 0.9708737864077671, ""receiver_height"": 2.255639097744361, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 4, ""bin_start"": -0.041510644148643035, ""bin_end"": -0.03278967924830297, ""donor_height"": 26.537216828478964, ""receiver_height"": 5.263157894736842, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 5, ""bin_start"": -0.03278967924830297, ""bin_end"": -0.0240687143479629, ""donor_height"": 0.6472491909385113, ""receiver_height"": 6.015037593984962, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 6, ""bin_start"": -0.0240687143479629, ""bin_end"": -0.015347749447622835, ""donor_height"": 0.6472491909385113, ""receiver_height"": 11.278195488721805, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 7, ""bin_start"": -0.015347749447622835, ""bin_end"": -0.006626784547282771, ""donor_height"": 1.6181229773462782, ""receiver_height"": 10.526315789473683, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 8, ""bin_start"": -0.006626784547282771, ""bin_end"": 0.0020941803530572928, ""donor_height"": 24.59546925566343, ""receiver_height"": 15.037593984962406, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 9, ""bin_start"": 0.0020941803530572928, ""bin_end"": 0.010815145253397357, ""donor_height"": 0.9708737864077671, ""receiver_height"": 9.774436090225564, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 10, ""bin_start"": 0.010815145253397357, ""bin_end"": 0.01953611015373742, ""donor_height"": 0.9708737864077671, ""receiver_height"": 19.548872180451127, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 11, ""bin_start"": 0.01953611015373742, ""bin_end"": 0.0282570750540775, ""donor_height"": 2.5889967637540456, ""receiver_height"": 6.015037593984963, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 12, ""bin_start"": 0.0282570750540775, ""bin_end"": 0.03697803995441756, ""donor_height"": 17.475728155339805, ""receiver_height"": 3.7593984962406015, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 13, ""bin_start"": 0.03697803995441756, ""bin_end"": 0.04569900485475763, ""donor_height"": 0.32362459546925565, ""receiver_height"": 5.263157894736842, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 14, ""bin_start"": 0.04569900485475763, ""bin_end"": 0.05441996975509768, ""donor_height"": 0.6472491909385113, ""receiver_height"": 1.5037593984962407, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 15, ""bin_start"": 0.05441996975509768, ""bin_end"": 0.06314093465543776, ""donor_height"": 0.3236245954692557, ""receiver_height"": 1.5037593984962407, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 16, ""bin_start"": 0.06314093465543776, ""bin_end"": 0.07186189955577783, ""donor_height"": 7.443365695792881, ""receiver_height"": 2.2556390977443606, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 17, ""bin_start"": 0.07186189955577783, ""bin_end"": 0.08058286445611788, ""donor_height"": 0.32362459546925565, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 18, ""bin_start"": 0.08058286445611788, ""bin_end"": 0.08930382935645796, ""donor_height"": 0.9708737864077669, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 19, ""bin_start"": 0.08930382935645796, ""bin_end"": 0.09802479425679801, ""donor_height"": 0.32362459546925565, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 20, ""bin_start"": 0.09802479425679801, ""bin_end"": 0.10674575915713809, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 21, ""bin_start"": 0.10674575915713809, ""bin_end"": 0.11546672405747814, ""donor_height"": 3.2362459546925564, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 22, ""bin_start"": 0.11546672405747814, ""bin_end"": 0.12418768895781822, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 23, ""bin_start"": 0.12418768895781822, ""bin_end"": 0.1329086538581583, ""donor_height"": 0.32362459546925565, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 24, ""bin_start"": 0.1329086538581583, ""bin_end"": 0.14162961875849833, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 25, ""bin_start"": 0.14162961875849833, ""bin_end"": 0.1503505836588384, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 26, ""bin_start"": 0.1503505836588384, ""bin_end"": 0.1590715485591785, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 27, ""bin_start"": 0.1590715485591785, ""bin_end"": 0.16779251345951857, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 28, ""bin_start"": 0.16779251345951857, ""bin_end"": 0.17651347835985864, ""donor_height"": 0.0, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,s4,N/A,histogram_distribution,,full,"{""bin_index"": 29, ""bin_start"": 0.17651347835985864, ""bin_end"": 0.18523444326019867, ""donor_height"": 0.3236245954692557, ""receiver_height"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133, ""total_bins"": 30}"
+distribution_bins,OLSResults,risk_factor,N/A,categorical_distribution,,full,"{""category"": ""high"", ""donor_proportion"": 34.95145631067961, ""receiver_proportion"": 45.86466165413533, ""n_samples_donor"": 309, ""n_samples_receiver"": 133}"
+distribution_bins,OLSResults,risk_factor,N/A,categorical_distribution,,full,"{""category"": ""low"", ""donor_proportion"": 36.24595469255664, ""receiver_proportion"": 54.13533834586466, ""n_samples_donor"": 309, ""n_samples_receiver"": 133}"
+distribution_bins,OLSResults,risk_factor,N/A,categorical_distribution,,full,"{""category"": ""medium"", ""donor_proportion"": 28.802588996763756, ""receiver_proportion"": 0.0, ""n_samples_donor"": 309, ""n_samples_receiver"": 133}"
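
The new distribution_bins rows above pack their per-bin payload into the additional_info JSON column. As a rough, illustrative sketch (assuming the CSV keeps the header columns listed in format_csv's docstring further below), the s1 histogram could be reconstructed like this:

    import json
    import pandas as pd

    df = pd.read_csv("microimputation-dashboard/public/microimputation_results.csv")
    bins = df[
        (df["type"] == "distribution_bins")
        & (df["variable"] == "s1")
        & (df["metric_name"] == "histogram_distribution")
    ]
    # Each row's additional_info holds bin_index, bin_start/bin_end and the
    # donor/receiver bar heights (already expressed as percentages).
    hist = pd.DataFrame(bins["additional_info"].apply(json.loads).tolist())
    hist = hist.sort_values("bin_index")
    print(hist[["bin_start", "bin_end", "donor_height", "receiver_height"]])
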
diff --git a/microimpute/utils/dashboard_formatter.py b/microimpute/utils/dashboard_formatter.py
index a5602fa..613f461 100644
--- a/microimpute/utils/dashboard_formatter.py
+++ b/microimpute/utils/dashboard_formatter.py
@@ -5,9 +5,281 @@
import json
from typing import Any, Dict, List, Optional, Union
+import numpy as np
import pandas as pd
+def _compute_histogram_data(
+ donor_values: np.ndarray,
+ receiver_values: np.ndarray,
+ variable_name: str,
+ n_bins: int = 30,
+) -> Dict[str, Union[List[float], str, int]]:
+ """
+ Compute histogram bin data for numerical variables.
+
+ Parameters
+ ----------
+ donor_values : np.ndarray
+ Original donor dataset values
+ receiver_values : np.ndarray
+ Imputed receiver dataset values
+ variable_name : str
+ Name of the variable being analyzed
+ n_bins : int
+ Number of histogram bins (default: 30)
+
+ Returns
+ -------
+ Dict containing bin edges and heights for both distributions
+ """
+ # Remove NaN values
+ donor_clean = donor_values[~np.isnan(donor_values)]
+ receiver_clean = receiver_values[~np.isnan(receiver_values)]
+
+ # Determine bin edges based on combined data range using numpy's auto algorithm
+ combined = np.concatenate([donor_clean, receiver_clean])
+ _, bin_edges = np.histogram(combined, bins=n_bins)
+
+ # Compute histogram heights (normalized as densities)
+ donor_heights, _ = np.histogram(donor_clean, bins=bin_edges, density=True)
+ receiver_heights, _ = np.histogram(
+ receiver_clean, bins=bin_edges, density=True
+ )
+
+ # Convert to percentages for easier interpretation
+ # Multiply by bin width to get probability mass per bin
+ bin_widths = np.diff(bin_edges)
+ donor_heights = (donor_heights * bin_widths * 100).tolist()
+ receiver_heights = (receiver_heights * bin_widths * 100).tolist()
+
+ return {
+ "variable": variable_name,
+ "bin_edges": bin_edges.tolist(),
+ "donor_heights": donor_heights,
+ "receiver_heights": receiver_heights,
+ "n_samples_donor": len(donor_clean),
+ "n_samples_receiver": len(receiver_clean),
+ "n_bins": n_bins,
+ }
+
+
+def _compute_categorical_distribution(
+ donor_values: pd.Series,
+ receiver_values: pd.Series,
+ variable_name: str,
+ max_categories: int = 20,
+) -> Dict[str, Union[List, str, bool]]:
+ """
+ Compute distribution data for categorical variables.
+
+ Parameters
+ ----------
+ donor_values : pd.Series
+ Original donor dataset values
+ receiver_values : pd.Series
+ Imputed receiver dataset values
+ variable_name : str
+ Name of the variable
+ max_categories : int
+ Maximum number of categories to include (others grouped as "Other")
+
+ Returns
+ -------
+ Dict containing category labels and proportions
+ """
+ # Get value counts
+ donor_counts = donor_values.value_counts()
+ receiver_counts = receiver_values.value_counts()
+
+ # Get all unique categories
+ all_categories = list(set(donor_counts.index) | set(receiver_counts.index))
+
+ # If too many categories, keep top ones and group rest as "Other"
+ if len(all_categories) > max_categories:
+ # Get top categories by combined frequency
+ combined_counts = donor_counts.add(receiver_counts, fill_value=0)
+ top_categories = combined_counts.nlargest(
+ max_categories - 1
+ ).index.tolist()
+
+ # Calculate "Other" category
+ donor_other = donor_counts[
+ ~donor_counts.index.isin(top_categories)
+ ].sum()
+ receiver_other = receiver_counts[
+ ~receiver_counts.index.isin(top_categories)
+ ].sum()
+
+ categories = top_categories + ["Other"]
+
+ # Get proportions
+ donor_props = [donor_counts.get(cat, 0) for cat in top_categories]
+ donor_props.append(donor_other)
+ donor_props = (
+ pd.Series(donor_props) / donor_values.count() * 100
+ ).tolist()
+
+ receiver_props = [
+ receiver_counts.get(cat, 0) for cat in top_categories
+ ]
+ receiver_props.append(receiver_other)
+ receiver_props = (
+ pd.Series(receiver_props) / receiver_values.count() * 100
+ ).tolist()
+ else:
+ categories = sorted(all_categories)
+ donor_props = [
+ (donor_counts.get(cat, 0) / donor_values.count() * 100)
+ for cat in categories
+ ]
+ receiver_props = [
+ (receiver_counts.get(cat, 0) / receiver_values.count() * 100)
+ for cat in categories
+ ]
+
+ return {
+ "variable": variable_name,
+ "categories": categories,
+ "donor_proportions": donor_props,
+ "receiver_proportions": receiver_props,
+ "n_samples_donor": int(donor_values.count()),
+ "n_samples_receiver": int(receiver_values.count()),
+ "is_categorical": True,
+ }
+
+
+def _format_histogram_rows(
+ histogram_data: Dict[str, Union[List, str, int, bool]], method: str
+) -> List[Dict]:
+ """
+ Convert histogram data to CSV row format.
+
+ Parameters
+ ----------
+ histogram_data : Dict
+ Output from _compute_histogram_data or _compute_categorical_distribution
+ method : str
+ Imputation method name
+
+ Returns
+ -------
+ List of dictionaries ready for CSV formatting
+ """
+ rows = []
+
+ if histogram_data.get("is_categorical", False):
+ # Categorical variable - store as distribution_bins type
+ for i, category in enumerate(histogram_data["categories"]):
+ rows.append(
+ {
+ "type": "distribution_bins",
+ "method": method,
+ "variable": histogram_data["variable"],
+ "quantile": "N/A",
+ "metric_name": "categorical_distribution",
+ "metric_value": None, # Not used for histograms
+ "split": "full",
+ "additional_info": json.dumps(
+ {
+ "category": str(category),
+ "donor_proportion": float(
+ histogram_data["donor_proportions"][i]
+ ),
+ "receiver_proportion": float(
+ histogram_data["receiver_proportions"][i]
+ ),
+ "n_samples_donor": int(
+ histogram_data["n_samples_donor"]
+ ),
+ "n_samples_receiver": int(
+ histogram_data["n_samples_receiver"]
+ ),
+ }
+ ),
+ }
+ )
+ else:
+ # Numerical variable - store bin data
+ n_bins = len(histogram_data["donor_heights"])
+ for i in range(n_bins):
+ rows.append(
+ {
+ "type": "distribution_bins",
+ "method": method,
+ "variable": histogram_data["variable"],
+ "quantile": "N/A",
+ "metric_name": "histogram_distribution",
+ "metric_value": None, # Not used for histograms
+ "split": "full",
+ "additional_info": json.dumps(
+ {
+ "bin_index": int(i),
+ "bin_start": float(histogram_data["bin_edges"][i]),
+ "bin_end": float(
+ histogram_data["bin_edges"][i + 1]
+ ),
+ "donor_height": float(
+ histogram_data["donor_heights"][i]
+ ),
+ "receiver_height": float(
+ histogram_data["receiver_heights"][i]
+ ),
+ "n_samples_donor": int(
+ histogram_data["n_samples_donor"]
+ ),
+ "n_samples_receiver": int(
+ histogram_data["n_samples_receiver"]
+ ),
+ "total_bins": int(n_bins),
+ }
+ ),
+ }
+ )
+
+ return rows
+
+
+def _validate_imputed_variables(
+ donor_data: pd.DataFrame,
+ receiver_data: pd.DataFrame,
+ imputed_variables: List[str],
+) -> None:
+ """
+ Validate that all imputed variables exist in both datasets.
+
+ Parameters
+ ----------
+ donor_data : pd.DataFrame
+ Original donor dataset
+ receiver_data : pd.DataFrame
+ Imputed receiver dataset
+ imputed_variables : List[str]
+ List of variable names that were imputed
+
+ Raises
+ ------
+ ValueError
+ If any imputed variable is missing from either dataset
+ """
+ missing_in_donor = [
+ var for var in imputed_variables if var not in donor_data.columns
+ ]
+ missing_in_receiver = [
+ var for var in imputed_variables if var not in receiver_data.columns
+ ]
+
+ if missing_in_donor:
+ raise ValueError(
+ f"The following imputed variables are missing from donor_data: {missing_in_donor}"
+ )
+
+ if missing_in_receiver:
+ raise ValueError(
+ f"The following imputed variables are missing from receiver_data: {missing_in_receiver}"
+ )
+
+
def format_csv(
output_path: Optional[str] = None,
autoimpute_result: Optional[Dict] = None,
@@ -17,6 +289,10 @@ def format_csv(
predictor_importance_df: Optional[pd.DataFrame] = None,
progressive_inclusion_df: Optional[pd.DataFrame] = None,
best_method_name: Optional[str] = None,
+ donor_data: Optional[pd.DataFrame] = None,
+ receiver_data: Optional[pd.DataFrame] = None,
+ imputed_variables: Optional[List[str]] = None,
+ n_histogram_bins: int = 30,
) -> pd.DataFrame:
"""
Format various imputation outputs into a unified long-format CSV for dashboard visualization.
@@ -54,11 +330,30 @@ def format_csv(
best_method_name : str, optional
Name of the best method to append "_best_method" suffix to.
+ donor_data : pd.DataFrame, optional
+ Original donor dataset for histogram generation. Required if imputed_variables is provided.
+
+ receiver_data : pd.DataFrame, optional
+ Imputed receiver dataset for histogram generation. Required if imputed_variables is provided.
+
+ imputed_variables : List[str], optional
+ List of variable names that were imputed. When provided with donor_data and receiver_data,
+ histogram bin data will be included in the CSV for distribution visualization.
+
+ n_histogram_bins : int, default 30
+ Number of bins to use for numerical variable histograms.
+
Returns
-------
pd.DataFrame
Unified long-format DataFrame with columns:
['type', 'method', 'variable', 'quantile', 'metric_name', 'metric_value', 'split', 'additional_info']
+
+ Raises
+ ------
+ ValueError
+ If imputed_variables is provided but donor_data or receiver_data is missing.
+ If any imputed variable is not present in both donor_data and receiver_data.
"""
rows = []
@@ -339,6 +634,46 @@ def format_csv(
}
)
+ # 7. Process histogram distribution data for imputed variables
+ if imputed_variables is not None:
+ # Validate inputs
+ if donor_data is None or receiver_data is None:
+ raise ValueError(
+ "donor_data and receiver_data are required when imputed_variables is provided"
+ )
+
+ # Validate that all imputed variables exist in both datasets
+ _validate_imputed_variables(
+ donor_data, receiver_data, imputed_variables
+ )
+
+ # Generate histogram data for each imputed variable
+ for var in imputed_variables:
+ # Check if variable is categorical or numerical
+ if donor_data[
+ var
+ ].dtype == "object" or pd.api.types.is_categorical_dtype(
+ donor_data[var]
+ ):
+ # Categorical variable
+ hist_data = _compute_categorical_distribution(
+ donor_data[var], receiver_data[var], var
+ )
+ else:
+ # Numerical variable
+ hist_data = _compute_histogram_data(
+ donor_data[var].values,
+ receiver_data[var].values,
+ var,
+ n_bins=n_histogram_bins,
+ )
+
+ # Format histogram rows and add to main rows list
+ histogram_rows = _format_histogram_rows(
+ hist_data, best_method_name if best_method_name else "N/A"
+ )
+ rows.extend(histogram_rows)
+
# Create DataFrame from rows
if not rows:
# Return empty DataFrame with correct columns if no data
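
For reference, a minimal sketch of how the extended format_csv signature above might be invoked to emit the new distribution_bins rows (file names here are illustrative; the argument shape mirrors the tests further below):

    import pandas as pd
    from microimpute.utils.dashboard_formatter import format_csv

    donor_df = pd.read_csv("donor.csv")        # observed values (illustrative path)
    receiver_df = pd.read_csv("receiver.csv")  # imputed values (illustrative path)

    result = format_csv(
        output_path="microimputation_results.csv",
        donor_data=donor_df,
        receiver_data=receiver_df,
        imputed_variables=["s1", "s4", "risk_factor"],  # must exist in both frames
        best_method_name="OLSResults",
        n_histogram_bins=30,
    )
    # Numerical variables yield histogram_distribution rows, categorical ones
    # yield categorical_distribution rows, all under type == "distribution_bins".
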
diff --git a/microimpute/utils/data.py b/microimpute/utils/data.py
index 938d70e..3aa5e4e 100644
--- a/microimpute/utils/data.py
+++ b/microimpute/utils/data.py
@@ -1,18 +1,20 @@
"""Data preparation and transformation utilities
This module provides comprehensive data preparation functions for imputation workflows,
-including data splitting, normalization, unnormalization, and categorical variable handling.
+including data splitting, normalization, log transformation, and categorical variable handling.
These utilities ensure consistent data preprocessing across different imputation methods.
Key functions:
- - preprocess_data: split and optionally normalize data for training/testing
+ - preprocess_data: split and optionally normalize or log-transform data for training/testing
- unnormalize_predictions: convert normalized predictions back to original scale
+ - unlog_transform_predictions: convert log-transformed predictions back to original scale
- Handle categorical variables through one-hot encoding
"""
import logging
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
+import numpy as np
import pandas as pd
from pydantic import validate_call
from sklearn.model_selection import train_test_split
@@ -30,6 +32,7 @@
@validate_call(config=VALIDATE_CONFIG)
def normalize_data(
data: pd.DataFrame,
+ columns_to_normalize: Optional[List[str]] = None,
) -> Tuple[pd.DataFrame, dict]:
"""Normalize numeric columns in a DataFrame.
@@ -38,6 +41,8 @@ def normalize_data(
Args:
data: DataFrame to normalize.
+ columns_to_normalize: Optional list of specific columns to normalize.
+ If None, all numeric columns will be normalized.
Returns:
Tuple of (normalized_data, normalization_params)
@@ -45,6 +50,7 @@ def normalize_data(
to {"mean": float, "std": float}.
Raises:
+ ValueError: If specified columns don't exist in data.
RuntimeError: If normalization fails.
"""
logger.debug("Normalizing data")
@@ -64,10 +70,39 @@ def normalize_data(
f"Excluding categorical columns from normalization: {categorical_cols}"
)
- # Get only numeric columns for normalization
- numeric_cols = [
- col for col in data.columns if col not in categorical_cols
- ]
+ # Determine which columns to normalize
+ if columns_to_normalize is not None:
+ # Validate that specified columns exist
+ missing_cols = set(columns_to_normalize) - set(data.columns)
+ if missing_cols:
+ error_msg = (
+ f"Columns specified for normalization not found in "
+ f"data: {missing_cols}"
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ # Only normalize specified columns that are not categorical
+ numeric_cols = [
+ col
+ for col in columns_to_normalize
+ if col not in categorical_cols
+ ]
+
+ # Warn if user specified categorical columns
+ specified_categorical = [
+ col for col in columns_to_normalize if col in categorical_cols
+ ]
+ if specified_categorical:
+ logger.warning(
+ f"Skipping normalization for categorical columns: "
+ f"{specified_categorical}"
+ )
+ else:
+ # Get all numeric columns for normalization
+ numeric_cols = [
+ col for col in data.columns if col not in categorical_cols
+ ]
if not numeric_cols:
logger.warning("No numeric columns found for normalization")
@@ -106,6 +141,120 @@ def normalize_data(
raise RuntimeError("Failed to normalize data") from e
+@validate_call(config=VALIDATE_CONFIG)
+def log_transform_data(
+ data: pd.DataFrame,
+ columns_to_transform: Optional[List[str]] = None,
+) -> Tuple[pd.DataFrame, dict]:
+ """Apply log transformation to numeric columns in a DataFrame.
+
+ Categorical and boolean columns are excluded from transformation
+ to prevent issues when they are later encoded as dummy variables.
+
+ Args:
+ data: DataFrame to log transform.
+ columns_to_transform: Optional list of specific columns to
+ log transform. If None, all numeric columns will be transformed.
+
+ Returns:
+ Tuple of (log_transformed_data, log_transform_params)
+ where log_transform_params is a dict mapping column names
+ to an empty parameter dict; only the keys are needed to reverse the transformation.
+
+ Raises:
+ ValueError: If data contains non-positive values in numeric columns
+ or if specified columns don't exist in data.
+ RuntimeError: If log transformation fails.
+ """
+ logger.debug("Applying log transformation to data")
+ try:
+ from microimpute.utils.type_handling import VariableTypeDetector
+
+ # Identify categorical columns to exclude from log transformation
+ detector = VariableTypeDetector()
+ categorical_cols = []
+ for col in data.columns:
+ var_type, _ = detector.categorize_variable(data[col], col, logger)
+ if var_type in ["categorical", "numeric_categorical", "bool"]:
+ categorical_cols.append(col)
+
+ if categorical_cols:
+ logger.info(
+ f"Excluding categorical columns from log transformation: {categorical_cols}"
+ )
+
+ # Determine which columns to transform
+ if columns_to_transform is not None:
+ # Validate that specified columns exist
+ missing_cols = set(columns_to_transform) - set(data.columns)
+ if missing_cols:
+ error_msg = (
+ f"Columns specified for log transformation not found "
+ f"in data: {missing_cols}"
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ # Only transform specified columns that are not categorical
+ numeric_cols = [
+ col
+ for col in columns_to_transform
+ if col not in categorical_cols
+ ]
+
+ # Warn if user specified categorical columns
+ specified_categorical = [
+ col for col in columns_to_transform if col in categorical_cols
+ ]
+ if specified_categorical:
+ logger.warning(
+ f"Skipping log transformation for categorical "
+ f"columns: {specified_categorical}"
+ )
+ else:
+ # Get all numeric columns for log transformation
+ numeric_cols = [
+ col for col in data.columns if col not in categorical_cols
+ ]
+
+ if not numeric_cols:
+ logger.warning("No numeric columns found for log transformation")
+ return data.copy(), {}
+
+ # Check for non-positive values in numeric columns
+ data_copy = data.copy()
+ for col in numeric_cols:
+ min_val = data_copy[col].min()
+ if min_val <= 0:
+ error_msg = (
+ f"Column '{col}' contains non-positive values "
+ f"(min={min_val}). Log transformation requires all "
+ f"positive values."
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ # Apply log transformation only to numeric columns
+ log_transform_params = {}
+ for col in numeric_cols:
+ data_copy[col] = np.log(data_copy[col])
+ log_transform_params[col] = {}
+
+ logger.debug(
+ f"Log transformed {len(numeric_cols)} numeric columns successfully"
+ )
+ logger.debug(f"Log transformation parameters: {log_transform_params}")
+
+ return data_copy, log_transform_params
+
+ except ValueError:
+ # Re-raise ValueError as-is (for non-positive values)
+ raise
+ except (TypeError, AttributeError) as e:
+ logger.error(f"Error during log transformation: {str(e)}")
+ raise RuntimeError("Failed to apply log transformation") from e
+
+
@validate_call(config=VALIDATE_CONFIG)
def preprocess_data(
data: pd.DataFrame,
@@ -113,7 +262,8 @@ def preprocess_data(
train_size: Optional[float] = TRAIN_SIZE,
test_size: Optional[float] = TEST_SIZE,
random_state: Optional[int] = RANDOM_STATE,
- normalize: Optional[bool] = False,
+ normalize: Optional[Union[bool, List[str]]] = False,
+ log_transform: Optional[Union[bool, List[str]]] = False,
) -> Union[
Tuple[pd.DataFrame, dict], # when full_data=True
Tuple[pd.DataFrame, pd.DataFrame, dict], # when full_data=False
@@ -126,17 +276,35 @@ def preprocess_data(
train_size: Proportion of the dataset to include in the train split.
test_size: Proportion of the dataset to include in the test split.
random_state: Random seed for reproducibility.
- normalize: Whether to normalize the data.
+ normalize: Whether to normalize the data. Can be:
+ - True: normalize all numeric columns
+ - List of column names: normalize only those columns
+ - False: no normalization (default)
+ log_transform: Whether to apply log transformation to the data. Can be:
+ - True: transform all numeric columns
+ - List of column names: transform only those columns
+ - False: no transformation (default)
Returns:
- Different tuple formats depending on the value of full_data:
- - If full_data=True: (data, dummy_info)
- - If full_data=False: (X_train, X_test, dummy_info)
-
- Where dummy_info is a dictionary mapping original columns to their resulting dummy columns
+ Different tuple formats depending on parameters:
+ - If full_data=True and transformations applied:
+ (data, transform_params)
+ - If full_data=True and no transformations:
+ data
+ - If full_data=False and transformations applied:
+ (X_train, X_test, transform_params)
+ - If full_data=False and no transformations:
+ (X_train, X_test)
+
+ Where transform_params is a dict with keys:
+ - "normalization": dict of normalization parameters (or empty dict)
+ - "log_transform": dict of log transform parameters (or empty dict)
Raises:
- ValueError: If data is empty or invalid
+ ValueError: If data is empty or invalid, or if both normalize and
+ log_transform would apply to the same columns, or if log_transform
+ is applied to data with non-positive values, or if specified
+ columns don't exist in data.
RuntimeError: If data preprocessing fails
"""
@@ -146,23 +314,100 @@ def preprocess_data(
if data.empty:
raise ValueError("Data must not be None or empty")
+
+ # Check if both normalize and log_transform are requested
+ normalize_requested = normalize is not False and normalize != []
+ log_transform_requested = (
+ log_transform is not False and log_transform != []
+ )
+
+ # Validate that normalize and log_transform don't conflict
+ if normalize_requested and log_transform_requested:
+ # If both are True, they would apply to all numeric columns - conflict
+ if normalize is True and log_transform is True:
+ error_msg = (
+ "Cannot apply both normalization and log transformation to "
+ "all columns. Please specify which columns to transform with "
+ "each approach using lists."
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ # If one is True and the other is a list, conflict
+ if normalize is True or log_transform is True:
+ error_msg = (
+ "Cannot apply both normalization and log transformation. "
+ "When using both, specify column lists for each to ensure "
+ "they apply to different variables."
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ # Both are lists - check for overlap
+ normalize_cols = (
+ set(normalize) if isinstance(normalize, list) else set()
+ )
+ log_transform_cols = (
+ set(log_transform) if isinstance(log_transform, list) else set()
+ )
+ overlap = normalize_cols & log_transform_cols
+
+ if overlap:
+ error_msg = (
+ f"Cannot apply both normalization and log transformation to "
+ f"the same columns: {overlap}. Each column can only have one "
+ f"transformation applied."
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
# Check for missing values
missing_count = data.isna().sum().sum()
if missing_count > 0:
logger.warning(f"Data contains {missing_count} missing values")
- if normalize:
- data, normalization_params = normalize_data(data)
+ # Apply normalization if requested
+ normalization_params = {}
+ if normalize_requested:
+ if isinstance(normalize, bool):
+ # normalize=True means normalize all numeric columns
+ data, normalization_params = normalize_data(data)
+ else:
+ # normalize is a list of specific columns
+ data, normalization_params = normalize_data(
+ data, columns_to_normalize=normalize
+ )
- if full_data and normalize:
- logger.info("Returning full preprocessed dataset")
- return (
- data,
- normalization_params,
- )
- elif full_data:
- logger.info("Returning full preprocessed dataset")
- return data
+ # Apply log transformation if requested
+ log_transform_params = {}
+ if log_transform_requested:
+ if isinstance(log_transform, bool):
+ # log_transform=True means transform all numeric columns
+ data, log_transform_params = log_transform_data(data)
+ else:
+ # log_transform is a list of specific columns
+ data, log_transform_params = log_transform_data(
+ data, columns_to_transform=log_transform
+ )
+
+ # Prepare transformation parameters to return
+ has_transformations = normalize_requested or log_transform_requested
+ if has_transformations:
+ # Merge both parameter dicts, with a key to distinguish them
+ transform_params = {
+ "normalization": normalization_params,
+ "log_transform": log_transform_params,
+ }
+
+ if full_data:
+ if has_transformations:
+ logger.info(
+ "Returning full preprocessed dataset with transformations"
+ )
+ return (data, transform_params)
+ else:
+ logger.info("Returning full preprocessed dataset")
+ return data
else:
logger.debug(
f"Splitting data with train_size={train_size}, test_size={test_size}"
@@ -177,17 +422,10 @@ def preprocess_data(
logger.info(
f"Data split into train ({X_train.shape}) and test ({X_test.shape}) sets"
)
- if normalize:
- return (
- X_train,
- X_test,
- normalization_params,
- )
+ if has_transformations:
+ return (X_train, X_test, transform_params)
else:
- return (
- X_train,
- X_test,
- )
+ return (X_train, X_test)
except (ValueError, TypeError) as e:
logger.error(f"Error in processing data: {str(e)}")
@@ -236,3 +474,56 @@ def unnormalize_predictions(
logger.debug(f"Unnormalized quantile {q} with shape {df_unnorm.shape}")
return unnormalized
+
+
+@validate_call(config=VALIDATE_CONFIG)
+def unlog_transform_predictions(
+ imputations: dict, log_transform_params: dict
+) -> dict:
+ """Reverse log transformation on predictions using stored parameters.
+
+ Args:
+ imputations: Dictionary mapping quantiles to DataFrames of predictions.
+ log_transform_params: Dictionary with column names that were
+ log-transformed.
+
+ Returns:
+ Dictionary with same structure as imputations but with
+ un-log-transformed values.
+
+ Raises:
+ ValueError: If columns in imputations don't match log transformation
+ parameters.
+ """
+ logger.debug(
+ f"Reversing log transformation for {len(imputations)} quantiles"
+ )
+
+ untransformed = {}
+ for q, df in imputations.items():
+ cols = df.columns
+
+ # Check that all columns have log transformation parameters
+ missing_params = [
+ col for col in cols if col not in log_transform_params
+ ]
+ if missing_params:
+ error_msg = (
+ f"Missing log transformation parameters for columns: "
+ f"{missing_params}"
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ # Reverse log transformation: x_original = exp(x_log)
+ df_untransformed = df.copy()
+ for col in cols:
+ df_untransformed[col] = np.exp(df[col])
+ untransformed[q] = df_untransformed
+
+ logger.debug(
+ f"Reversed log transformation for quantile {q} with shape "
+ f"{df_untransformed.shape}"
+ )
+
+ return untransformed
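
A quick, illustrative sketch of the new preprocess_data options added above (data values are made up; note that the normalize and log_transform column lists must not overlap):

    import pandas as pd
    from microimpute.utils.data import preprocess_data, unlog_transform_predictions

    df = pd.DataFrame(
        {
            "income": [30_000.0, 52_000.0, 71_000.0, 44_000.0, 63_000.0, 38_000.0],
            "age": [25.0, 40.0, 62.0, 33.0, 51.0, 47.0],
        }
    )

    # Log-transform the strictly positive column, z-score normalize the other.
    X_train, X_test, transform_params = preprocess_data(
        df,
        full_data=False,
        normalize=["age"],
        log_transform=["income"],
    )
    # transform_params["normalization"] and transform_params["log_transform"]
    # hold the per-column parameters needed to reverse each step later, e.g.
    # unlog_transform_predictions(imputations, transform_params["log_transform"])
    # to map log-scale predictions back to the original income scale.
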
diff --git a/tests/test_dashboard_formatter.py b/tests/test_dashboard_formatter.py
index 41bf210..1992069 100644
--- a/tests/test_dashboard_formatter.py
+++ b/tests/test_dashboard_formatter.py
@@ -19,6 +19,7 @@
VALID_TYPES = {
"benchmark_loss",
"distribution_distance",
+ "distribution_bins",
"predictor_correlation",
"predictor_target_mi",
"predictor_importance",
@@ -625,6 +626,314 @@ def test_saved_csv_preserves_data(self, sample_autoimpute_result):
assert list(df["method"]) == list(loaded_df["method"])
+class TestDistributionBins:
+ """Test distribution_bins type formatting for histogram data."""
+
+ @pytest.fixture
+ def sample_donor_receiver_data(self):
+ """Create sample donor and receiver datasets with imputed variables."""
+ np.random.seed(42)
+
+ # Create donor data
+ donor_data = pd.DataFrame(
+ {
+ "numerical_var1": np.random.normal(100, 15, 200),
+ "numerical_var2": np.random.exponential(2, 200),
+ "categorical_var": np.random.choice(
+ ["A", "B", "C"], 200, p=[0.5, 0.3, 0.2]
+ ),
+ "predictor1": np.random.randn(200),
+ "predictor2": np.random.randn(200),
+ }
+ )
+
+ # Create receiver data (slightly different distributions)
+ receiver_data = pd.DataFrame(
+ {
+ "numerical_var1": np.random.normal(
+ 102, 14, 150
+ ), # Shifted mean
+ "numerical_var2": np.random.exponential(
+ 2.1, 150
+ ), # Different rate
+ "categorical_var": np.random.choice(
+ ["A", "B", "C"], 150, p=[0.4, 0.4, 0.2]
+ ),
+ "predictor1": np.random.randn(150),
+ "predictor2": np.random.randn(150),
+ }
+ )
+
+ return donor_data, receiver_data
+
+ def test_distribution_bins_created(self, sample_donor_receiver_data):
+ """Test that distribution_bins rows are created when histogram data is provided."""
+ donor_data, receiver_data = sample_donor_receiver_data
+ imputed_variables = [
+ "numerical_var1",
+ "numerical_var2",
+ "categorical_var",
+ ]
+
+ with tempfile.NamedTemporaryFile(
+ mode="w", delete=False, suffix=".csv"
+ ) as f:
+ output_path = f.name
+
+ try:
+ result = format_csv(
+ output_path=output_path,
+ donor_data=donor_data,
+ receiver_data=receiver_data,
+ imputed_variables=imputed_variables,
+ best_method_name="TestMethod",
+ n_histogram_bins=20,
+ )
+
+ # Check that distribution_bins type exists
+ dist_bins_rows = result[result["type"] == "distribution_bins"]
+ assert len(dist_bins_rows) > 0
+
+ # Check that all imputed variables have bins
+ variables_with_bins = dist_bins_rows["variable"].unique()
+ assert set(variables_with_bins) == set(imputed_variables)
+
+ # Check numerical variables have histogram_distribution metric
+ numerical_bins = dist_bins_rows[
+ dist_bins_rows["metric_name"] == "histogram_distribution"
+ ]
+ assert "numerical_var1" in numerical_bins["variable"].values
+ assert "numerical_var2" in numerical_bins["variable"].values
+
+ # Check categorical variable has categorical_distribution metric
+ categorical_bins = dist_bins_rows[
+ dist_bins_rows["metric_name"] == "categorical_distribution"
+ ]
+ assert "categorical_var" in categorical_bins["variable"].values
+
+ finally:
+ Path(output_path).unlink()
+
+ def test_numerical_histogram_heights_match_numpy(
+ self, sample_donor_receiver_data
+ ):
+ """Test that histogram heights match numpy's histogram output."""
+ donor_data, receiver_data = sample_donor_receiver_data
+ var_name = "numerical_var1"
+ n_bins = 15
+
+ with tempfile.NamedTemporaryFile(
+ mode="w", delete=False, suffix=".csv"
+ ) as f:
+ output_path = f.name
+
+ try:
+ result = format_csv(
+ output_path=output_path,
+ donor_data=donor_data,
+ receiver_data=receiver_data,
+ imputed_variables=[var_name],
+ best_method_name="TestMethod",
+ n_histogram_bins=n_bins,
+ )
+
+ # Get the distribution bins for our variable
+ dist_bins = result[
+ (result["type"] == "distribution_bins")
+ & (result["variable"] == var_name)
+ & (result["metric_name"] == "histogram_distribution")
+ ]
+
+ # Should have n_bins rows for this variable
+ assert len(dist_bins) == n_bins
+
+ # Extract bin data from additional_info
+ bin_data = []
+ for _, row in dist_bins.iterrows():
+ info = json.loads(row["additional_info"])
+ bin_data.append(info)
+
+ # Sort by bin index
+ bin_data = sorted(bin_data, key=lambda x: x["bin_index"])
+
+ # Manually compute histogram with numpy for comparison
+ donor_values = donor_data[var_name].values
+ receiver_values = receiver_data[var_name].values
+
+ # Remove NaN values
+ donor_clean = donor_values[~np.isnan(donor_values)]
+ receiver_clean = receiver_values[~np.isnan(receiver_values)]
+
+ # Compute bin edges from combined data (same as in the function)
+ combined = np.concatenate([donor_clean, receiver_clean])
+ _, bin_edges = np.histogram(combined, bins=n_bins)
+
+ # Compute histograms
+ donor_heights_np, _ = np.histogram(
+ donor_clean, bins=bin_edges, density=True
+ )
+ receiver_heights_np, _ = np.histogram(
+ receiver_clean, bins=bin_edges, density=True
+ )
+
+ # Convert to percentages (same as in the function)
+ bin_widths = np.diff(bin_edges)
+ donor_heights_expected = donor_heights_np * bin_widths * 100
+ receiver_heights_expected = receiver_heights_np * bin_widths * 100
+
+ # Compare heights
+ for i, data in enumerate(bin_data):
+ assert data["bin_index"] == i
+ # Check bin edges
+ assert np.isclose(data["bin_start"], bin_edges[i], rtol=1e-10)
+ assert np.isclose(
+ data["bin_end"], bin_edges[i + 1], rtol=1e-10
+ )
+ # Check heights match numpy's output
+ assert np.isclose(
+ data["donor_height"], donor_heights_expected[i], rtol=1e-10
+ )
+ assert np.isclose(
+ data["receiver_height"],
+ receiver_heights_expected[i],
+ rtol=1e-10,
+ )
+ # Check sample counts
+ assert data["n_samples_donor"] == len(donor_clean)
+ assert data["n_samples_receiver"] == len(receiver_clean)
+ assert data["total_bins"] == n_bins
+
+ finally:
+ Path(output_path).unlink()
+
+ def test_categorical_distribution_proportions(
+ self, sample_donor_receiver_data
+ ):
+ """Test that categorical distribution proportions are computed correctly."""
+ donor_data, receiver_data = sample_donor_receiver_data
+ var_name = "categorical_var"
+
+ with tempfile.NamedTemporaryFile(
+ mode="w", delete=False, suffix=".csv"
+ ) as f:
+ output_path = f.name
+
+ try:
+ result = format_csv(
+ output_path=output_path,
+ donor_data=donor_data,
+ receiver_data=receiver_data,
+ imputed_variables=[var_name],
+ best_method_name="TestMethod",
+ )
+
+ # Get the distribution bins for categorical variable
+ cat_bins = result[
+ (result["type"] == "distribution_bins")
+ & (result["variable"] == var_name)
+ & (result["metric_name"] == "categorical_distribution")
+ ]
+
+ # Should have one row per category
+ assert len(cat_bins) == 3 # A, B, C
+
+ # Extract category data
+ category_data = {}
+ for _, row in cat_bins.iterrows():
+ info = json.loads(row["additional_info"])
+ category = info["category"]
+ category_data[category] = info
+
+ # Manually compute proportions
+ donor_counts = donor_data[var_name].value_counts()
+ receiver_counts = receiver_data[var_name].value_counts()
+
+ donor_total = donor_data[var_name].count()
+ receiver_total = receiver_data[var_name].count()
+
+ # Check each category
+ for category in ["A", "B", "C"]:
+ assert category in category_data
+ data = category_data[category]
+
+ # Expected proportions
+ expected_donor_prop = (
+ donor_counts.get(category, 0) / donor_total
+ ) * 100
+ expected_receiver_prop = (
+ receiver_counts.get(category, 0) / receiver_total
+ ) * 100
+
+ # Check proportions match
+ assert np.isclose(
+ data["donor_proportion"], expected_donor_prop, rtol=1e-10
+ )
+ assert np.isclose(
+ data["receiver_proportion"],
+ expected_receiver_prop,
+ rtol=1e-10,
+ )
+
+ # Check sample counts
+ assert data["n_samples_donor"] == donor_total
+ assert data["n_samples_receiver"] == receiver_total
+
+ finally:
+ Path(output_path).unlink()
+
+ def test_error_when_missing_data_for_imputed_variables(self):
+ """Test that error is raised when donor/receiver data is missing for histogram generation."""
+ imputed_variables = ["var1", "var2"]
+
+ with pytest.raises(
+ ValueError, match="donor_data and receiver_data are required"
+ ):
+ format_csv(
+ imputed_variables=imputed_variables,
+ donor_data=None,
+ receiver_data=None,
+ )
+
+ # Test with missing donor data
+ receiver_data = pd.DataFrame({"var1": [1, 2, 3], "var2": [4, 5, 6]})
+ with pytest.raises(
+ ValueError, match="donor_data and receiver_data are required"
+ ):
+ format_csv(
+ imputed_variables=imputed_variables,
+ donor_data=None,
+ receiver_data=receiver_data,
+ )
+
+ def test_error_when_variable_missing_from_datasets(self):
+ """Test that error is raised when imputed variable is not in datasets."""
+ donor_data = pd.DataFrame({"var1": [1, 2, 3], "var2": [4, 5, 6]})
+ receiver_data = pd.DataFrame({"var1": [7, 8, 9], "var2": [10, 11, 12]})
+ imputed_variables = ["var1", "var3"] # var3 doesn't exist
+
+ with pytest.raises(
+ ValueError, match="missing from donor_data: \\['var3'\\]"
+ ):
+ format_csv(
+ donor_data=donor_data,
+ receiver_data=receiver_data,
+ imputed_variables=imputed_variables,
+ )
+
+ # Test with variable missing from receiver
+ receiver_data = pd.DataFrame({"var1": [7, 8, 9]}) # Missing var2
+ imputed_variables = ["var1", "var2"]
+
+ with pytest.raises(
+ ValueError, match="missing from receiver_data: \\['var2'\\]"
+ ):
+ format_csv(
+ donor_data=donor_data,
+ receiver_data=receiver_data,
+ imputed_variables=imputed_variables,
+ )
+
+
class TestEdgeCases:
"""Test edge cases and error handling."""
diff --git a/tests/test_data_preprocessing.py b/tests/test_data_preprocessing.py
index 6110e43..ee9ea83 100644
--- a/tests/test_data_preprocessing.py
+++ b/tests/test_data_preprocessing.py
@@ -4,7 +4,12 @@
import pandas as pd
import pytest
-from microimpute.utils.data import normalize_data, preprocess_data
+from microimpute.utils.data import (
+ log_transform_data,
+ normalize_data,
+ preprocess_data,
+ unlog_transform_predictions,
+)
class TestNormalize:
@@ -170,10 +175,13 @@ def test_preprocess_data_excludes_categoricals_from_normalization(self):
}
)
- result, norm_params = preprocess_data(
+ result, transform_params = preprocess_data(
data, full_data=True, normalize=True
)
+ # Extract normalization params from nested dict
+ norm_params = transform_params["normalization"]
+
# Categorical columns should be unchanged
pd.testing.assert_series_equal(result["race"], data["race"])
pd.testing.assert_series_equal(result["is_female"], data["is_female"])
@@ -233,3 +241,662 @@ def test_categorical_columns_dont_get_weird_suffixes_when_dummified(
# Column name should not contain decimal points
assert "." not in col, f"Column {col} has decimal point in name"
+
+
+class TestLogTransform:
+ """Test the log_transform_data function."""
+
+ def test_log_transform_excludes_categorical_columns(self):
+ """Test that categorical columns are not log transformed."""
+ data = pd.DataFrame(
+ {
+ "numeric_col": [1.0, 2.5, 3.7, 4.2, 5.9],
+ "categorical_col": [1, 2, 3, 1, 2],
+ "boolean_col": [0, 1, 0, 1, 0],
+ }
+ )
+
+ log_data, log_params = log_transform_data(data)
+
+ # Categorical and boolean columns should be unchanged
+ pd.testing.assert_series_equal(
+ log_data["categorical_col"], data["categorical_col"]
+ )
+ pd.testing.assert_series_equal(
+ log_data["boolean_col"], data["boolean_col"]
+ )
+
+ # Numeric column should be log transformed
+ assert not np.allclose(
+ log_data["numeric_col"].values, data["numeric_col"].values
+ )
+
+ # Only numeric column should have log transform params
+ assert "numeric_col" in log_params
+ assert "categorical_col" not in log_params
+ assert "boolean_col" not in log_params
+
+ def test_log_transform_correctly_transforms_numeric_columns(self):
+ """Test that numeric columns are correctly log transformed."""
+ data = pd.DataFrame(
+ {
+ "value1": [
+ 1.5,
+ 2.7,
+ 3.2,
+ 4.8,
+ 5.1,
+ 6.3,
+ 7.9,
+ 8.4,
+ 9.6,
+ 10.2,
+ ],
+ "value2": [
+ 15.5,
+ 27.3,
+ 32.1,
+ 48.7,
+ 51.9,
+ 63.2,
+ 79.8,
+ 84.5,
+ 96.1,
+ 102.4,
+ ],
+ "category": [1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
+ }
+ )
+
+ log_data, log_params = log_transform_data(data)
+
+ # Check that numeric columns are log transformed
+ expected_value1 = np.log(data["value1"].values)
+ expected_value2 = np.log(data["value2"].values)
+
+ np.testing.assert_array_almost_equal(
+ log_data["value1"].values, expected_value1
+ )
+ np.testing.assert_array_almost_equal(
+ log_data["value2"].values, expected_value2
+ )
+
+ # Check log transform params are stored
+ assert "value1" in log_params
+ assert "value2" in log_params
+
+ def test_log_transform_rejects_non_positive_values(self):
+ """Test that log transform raises error for non-positive values."""
+ data = pd.DataFrame(
+ {
+ "value": [1.0, 2.0, 0.0, 4.0, 5.0], # Contains zero
+ }
+ )
+
+ with pytest.raises(ValueError, match="non-positive values"):
+ log_transform_data(data)
+
+ data_negative = pd.DataFrame(
+ {
+ "value": [1.0, 2.0, -1.0, 4.0, 5.0], # Contains negative
+ }
+ )
+
+ with pytest.raises(ValueError, match="non-positive values"):
+ log_transform_data(data_negative)
+
+ def test_log_transform_returns_copy(self):
+ """Test that log transform returns a copy."""
+ data = pd.DataFrame(
+ {
+ "value": [1.5, 2.7, 3.2, 4.8, 5.1, 6.3, 7.9, 8.4, 9.6, 10.2],
+ "category": [1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
+ }
+ )
+ original_data = data.copy()
+
+ log_data, _ = log_transform_data(data)
+
+ # Original data should be unchanged
+ pd.testing.assert_frame_equal(data, original_data)
+
+ # Log transformed data should be different
+ assert not log_data["value"].equals(data["value"])
+
+ def test_log_transform_with_no_numeric_columns(self):
+ """Test log transform with only categorical columns."""
+ data = pd.DataFrame({"cat1": [1, 2, 3, 1, 2], "cat2": [0, 1, 0, 1, 0]})
+
+ log_data, log_params = log_transform_data(data)
+
+ # Data should be unchanged
+ pd.testing.assert_frame_equal(log_data, data)
+
+ # No log transform params should be returned
+ assert log_params == {}
+
+
+class TestUnlogTransformPredictions:
+ """Test the unlog_transform_predictions function."""
+
+ def test_unlog_transform_reverses_log_transform(self):
+ """Test that unlog transform correctly reverses log transform."""
+ original = pd.DataFrame(
+ {
+ "value1": [
+ 1.5,
+ 2.7,
+ 3.2,
+ 4.8,
+ 5.1,
+ 6.3,
+ 7.9,
+ 8.4,
+ 9.6,
+ 10.2,
+ ],
+ "value2": [
+ 15.5,
+ 27.3,
+ 32.1,
+ 48.7,
+ 51.9,
+ 63.2,
+ 79.8,
+ 84.5,
+ 96.1,
+ 102.4,
+ ],
+ }
+ )
+
+ # Apply log transform
+ log_data, log_params = log_transform_data(original)
+
+ # Create imputations dict (simulating prediction output)
+ imputations = {0.5: log_data}
+
+ # Reverse log transform
+ reversed_data = unlog_transform_predictions(imputations, log_params)
+
+ # Should match original data
+ pd.testing.assert_frame_equal(
+ reversed_data[0.5], original, check_exact=False, atol=1e-10
+ )
+
+ def test_unlog_transform_raises_error_for_missing_params(self):
+ """Test that unlog transform raises error when params are missing."""
+ imputations = {
+ 0.5: pd.DataFrame(
+ {
+ "value1": [0.0, 0.69, 1.10],
+ "value2": [2.3, 3.0, 3.9],
+ }
+ )
+ }
+
+ # Only have params for value1, not value2
+ log_params = {"value1": {}}
+
+ with pytest.raises(
+ ValueError, match="Missing log transformation parameters"
+ ):
+ unlog_transform_predictions(imputations, log_params)
+
+
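Taken together, TestLogTransform and TestUnlogTransformPredictions pin down a round-trip contract: the forward step is a natural log applied to strictly positive numeric columns, and the inverse recovers the original scale for every quantile frame in the predictions dict. Below is a minimal sketch of that contract; it assumes np.exp as the inverse (consistent with the assertions above) and is not the microimpute implementation.

import numpy as np
import pandas as pd

original = pd.DataFrame({"income": [50_123.45, 60_987.23, 70_456.78]})

logged = np.log(original)    # forward: natural log of strictly positive values
recovered = np.exp(logged)   # inverse: exponentiate to return to the original scale

assert np.allclose(recovered["income"], original["income"])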
+class TestPreprocessDataWithLogTransform:
+ """Test that preprocess_data correctly uses log transformation."""
+
+ def test_preprocess_data_excludes_categoricals_from_log_transform(self):
+ """Test that preprocess_data doesn't log transform categorical columns."""
+ data = pd.DataFrame(
+ {
+ "age": [
+ 25.3,
+ 30.7,
+ 35.2,
+ 40.9,
+ 45.1,
+ 50.6,
+ 55.8,
+ 60.3,
+ 65.7,
+ 70.2,
+ ],
+ "race": [1, 2, 3, 1, 2, 3, 1, 2, 3, 1],
+ "is_female": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
+ "income": [
+ 50123.45,
+ 60987.23,
+ 70456.78,
+ 80234.56,
+ 90876.12,
+ 100543.89,
+ 110234.67,
+ 120789.34,
+ 130456.78,
+ 140987.23,
+ ],
+ }
+ )
+
+ result, transform_params = preprocess_data(
+ data, full_data=True, log_transform=True
+ )
+
+ # Extract log transform params from nested dict
+ log_params = transform_params["log_transform"]
+
+ # Categorical columns should be unchanged
+ pd.testing.assert_series_equal(result["race"], data["race"])
+ pd.testing.assert_series_equal(result["is_female"], data["is_female"])
+
+ # Numeric columns should be log transformed
+ assert not np.allclose(result["age"].values, data["age"].values)
+ assert not np.allclose(result["income"].values, data["income"].values)
+
+ # Only numeric columns in log_params
+ assert "age" in log_params
+ assert "income" in log_params
+ assert "race" not in log_params
+ assert "is_female" not in log_params
+
+ def test_preprocess_data_raises_error_for_both_normalize_and_log(
+ self,
+ ):
+ """Test that preprocess_data raises error if both normalize and log_transform are True."""
+ data = pd.DataFrame(
+ {
+ "value1": [1.5, 2.7, 3.2, 4.8, 5.1, 6.3, 7.9, 8.4, 9.6, 10.2],
+ "value2": [
+ 15.5,
+ 27.3,
+ 32.1,
+ 48.7,
+ 51.9,
+ 63.2,
+ 79.8,
+ 84.5,
+ 96.1,
+ 102.4,
+ ],
+ }
+ )
+
+ with pytest.raises(
+ ValueError,
+ match="Cannot apply both normalization and log transformation",
+ ):
+ preprocess_data(
+ data, full_data=True, normalize=True, log_transform=True
+ )
+
+ def test_preprocess_data_with_log_transform_and_split(self):
+ """Test that preprocess_data correctly splits and log transforms data."""
+ data = pd.DataFrame(
+ {
+ "value1": [
+ 1.5,
+ 2.7,
+ 3.2,
+ 4.8,
+ 5.1,
+ 6.3,
+ 7.9,
+ 8.4,
+ 9.6,
+ 10.2,
+ 11.5,
+ 12.8,
+ 13.3,
+ 14.9,
+ 15.2,
+ ],
+ "value2": [
+ 15.5,
+ 27.3,
+ 32.1,
+ 48.7,
+ 51.9,
+ 63.2,
+ 79.8,
+ 84.5,
+ 96.1,
+ 102.4,
+ 115.7,
+ 128.9,
+ 133.6,
+ 149.2,
+ 152.8,
+ ],
+ }
+ )
+
+ X_train, X_test, transform_params = preprocess_data(
+ data,
+ full_data=False,
+ test_size=0.2,
+ train_size=None,
+ random_state=42,
+ log_transform=True,
+ )
+
+ # Extract log transform params from nested dict
+ log_params = transform_params["log_transform"]
+
+ # Check that data is split
+ assert len(X_train) == 12
+ assert len(X_test) == 3
+
+ # Check that log params are returned
+ assert "value1" in log_params
+ assert "value2" in log_params
+
+ # Check that values are log transformed (compare to original)
+ assert not any(X_train["value1"].isin(data["value1"]))
+ assert not any(X_test["value1"].isin(data["value1"]))
+
+
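Before the selective-transformation tests that follow, note the split-mode calling convention exercised above: with full_data=False the function returns train and test frames plus the nested parameter dict. A usage sketch mirroring those assertions is below; the one-column frame is invented, while the keyword arguments and the "log_transform" key follow the test directly.

import pandas as pd
from microimpute.utils.data import preprocess_data

# Made-up positive-valued column; 15 rows to mirror the split test above.
data = pd.DataFrame({"value1": [float(i) for i in range(1, 16)]})

X_train, X_test, transform_params = preprocess_data(
    data,
    full_data=False,
    test_size=0.2,
    train_size=None,
    random_state=42,
    log_transform=True,
)

log_params = transform_params["log_transform"]  # per-column transform parameters
print(len(X_train), len(X_test))                # expected: 12 and 3, as asserted above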
+class TestPreprocessDataWithSelectiveTransformation:
+ """Test preprocess_data with selective column transformation."""
+
+ def test_normalize_only_specified_columns(self):
+ """Test that only specified columns are normalized."""
+ data = pd.DataFrame(
+ {
+ "age": [
+ 23,
+ 30,
+ 35,
+ 46,
+ 45,
+ 52,
+ 55,
+ 61,
+ 68,
+ 72,
+ ],
+ "income": [
+ 50123.45,
+ 60987.23,
+ 70456.78,
+ 80234.56,
+ 90876.12,
+ 100543.89,
+ 110234.67,
+ 120789.34,
+ 130456.78,
+ 140987.23,
+ ],
+ "wealth": [
+ 150000.5,
+ 250000.3,
+ 350000.7,
+ 450000.2,
+ 550000.9,
+ 650000.1,
+ 750000.4,
+ 850000.8,
+ 950000.6,
+ 1050000.3,
+ ],
+ }
+ )
+
+ # Only normalize income column
+ result, transform_params = preprocess_data(
+ data, full_data=True, normalize=["income"]
+ )
+
+ # Extract normalization params from nested dict
+ norm_params = transform_params["normalization"]
+
+ # Income should be normalized
+ assert not np.allclose(result["income"].values, data["income"].values)
+ assert "income" in norm_params
+
+ # Age and wealth should NOT be normalized
+ pd.testing.assert_series_equal(result["age"], data["age"])
+ pd.testing.assert_series_equal(result["wealth"], data["wealth"])
+ assert "age" not in norm_params
+ assert "wealth" not in norm_params
+
+ def test_log_transform_only_specified_columns(self):
+ """Test that only specified columns are log transformed."""
+ data = pd.DataFrame(
+ {
+ "age": [
+ 23,
+ 30,
+ 35,
+ 46,
+ 45,
+ 52,
+ 55,
+ 61,
+ 68,
+ 72,
+ ],
+ "income": [
+ 50123.45,
+ 60987.23,
+ 70456.78,
+ 80234.56,
+ 90876.12,
+ 100543.89,
+ 110234.67,
+ 120789.34,
+ 130456.78,
+ 140987.23,
+ ],
+ "wealth": [
+ 150000.5,
+ 250000.3,
+ 350000.7,
+ 450000.2,
+ 550000.9,
+ 650000.1,
+ 750000.4,
+ 850000.8,
+ 950000.6,
+ 1050000.3,
+ ],
+ }
+ )
+
+ # Only log transform income column
+ result, transform_params = preprocess_data(
+ data, full_data=True, log_transform=["income"]
+ )
+
+ # Extract log transform params from nested dict
+ log_params = transform_params["log_transform"]
+
+ # Income should be log transformed
+ assert not np.allclose(result["income"].values, data["income"].values)
+ assert "income" in log_params
+
+ # Age and wealth should NOT be transformed
+ pd.testing.assert_series_equal(result["age"], data["age"])
+ pd.testing.assert_series_equal(result["wealth"], data["wealth"])
+ assert "age" not in log_params
+ assert "wealth" not in log_params
+
+ def test_normalize_multiple_specified_columns(self):
+ """Test normalizing multiple specified columns."""
+ data = pd.DataFrame(
+ {
+ "age": [
+ 23,
+ 30,
+ 35,
+ 46,
+ 45,
+ 52,
+ 55,
+ 61,
+ 68,
+ 72,
+ ],
+ "income": [
+ 50123.45,
+ 60987.23,
+ 70456.78,
+ 80234.56,
+ 90876.12,
+ 100543.89,
+ 110234.67,
+ 120789.34,
+ 130456.78,
+ 140987.23,
+ ],
+ "wealth": [
+ 150000.5,
+ 250000.3,
+ 350000.7,
+ 450000.2,
+ 550000.9,
+ 650000.1,
+ 750000.4,
+ 850000.8,
+ 950000.6,
+ 1050000.3,
+ ],
+ }
+ )
+
+ # Normalize income and wealth, but not age
+ result, transform_params = preprocess_data(
+ data, full_data=True, normalize=["income", "wealth"]
+ )
+
+ # Extract normalization params from nested dict
+ norm_params = transform_params["normalization"]
+
+ # Income and wealth should be normalized
+ assert not np.allclose(result["income"].values, data["income"].values)
+ assert not np.allclose(result["wealth"].values, data["wealth"].values)
+ assert "income" in norm_params
+ assert "wealth" in norm_params
+
+ # Age should NOT be normalized
+ pd.testing.assert_series_equal(result["age"], data["age"])
+ assert "age" not in norm_params
+
+ def test_error_on_nonexistent_column_normalize(self):
+ """Test that error is raised when specifying non-existent column."""
+ data = pd.DataFrame(
+ {
+ "age": [25.3, 30.7, 35.2, 40.9, 45.1],
+ "income": [50123.45, 60987.23, 70456.78, 80234.56, 90876.12],
+ }
+ )
+
+ with pytest.raises(ValueError, match="not found in data"):
+ preprocess_data(
+ data, full_data=True, normalize=["income", "nonexistent"]
+ )
+
+ def test_error_on_nonexistent_column_log_transform(self):
+ """Test that error is raised when specifying non-existent column."""
+ data = pd.DataFrame(
+ {
+ "age": [25.3, 30.7, 35.2, 40.9, 45.1],
+ "income": [50123.45, 60987.23, 70456.78, 80234.56, 90876.12],
+ }
+ )
+
+ with pytest.raises(ValueError, match="not found in data"):
+ preprocess_data(
+ data, full_data=True, log_transform=["income", "nonexistent"]
+ )
+
+ def test_error_on_overlapping_columns(self):
+ """Test error when both normalize and log_transform target same columns."""
+ data = pd.DataFrame(
+ {
+ "age": [25.3, 30.7, 35.2, 40.9, 45.1],
+ "income": [50123.45, 60987.23, 70456.78, 80234.56, 90876.12],
+ }
+ )
+
+ # Error when same column is in both lists
+ with pytest.raises(
+ ValueError, match="Cannot apply both normalization and log"
+ ):
+ preprocess_data(
+ data,
+ full_data=True,
+ normalize=["income", "age"],
+ log_transform=["age"],
+ )
+
+ def test_both_transformations_on_different_columns(self):
+ """Test that both transformations work when applied to different columns."""
+ data = pd.DataFrame(
+ {
+ "age": [
+ 23,
+ 30,
+ 35,
+ 46,
+ 45,
+ 52,
+ 55,
+ 61,
+ 68,
+ 72,
+ ],
+ "income": [
+ 50123.45,
+ 60987.23,
+ 70456.78,
+ 80234.56,
+ 90876.12,
+ 100543.89,
+ 110234.67,
+ 120789.34,
+ 130456.78,
+ 140987.23,
+ ],
+ "wealth": [
+ 150000.5,
+ 250000.3,
+ 350000.7,
+ 450000.2,
+ 550000.9,
+ 650000.1,
+ 750000.4,
+ 850000.8,
+ 950000.6,
+ 1050000.3,
+ ],
+ }
+ )
+
+ # Normalize age, log transform income, leave wealth unchanged
+ result, transform_params = preprocess_data(
+ data,
+ full_data=True,
+ normalize=["age"],
+ log_transform=["income"],
+ )
+
+ # Extract both parameter dicts
+ norm_params = transform_params["normalization"]
+ log_params = transform_params["log_transform"]
+
+ # Age should be normalized
+ assert not np.allclose(result["age"].values, data["age"].values)
+ assert "age" in norm_params
+ assert "age" not in log_params
+
+ # Income should be log transformed
+ assert not np.allclose(result["income"].values, data["income"].values)
+ assert "income" in log_params
+ assert "income" not in norm_params
+
+ # Wealth should be unchanged
+ pd.testing.assert_series_equal(result["wealth"], data["wealth"])
+ assert "wealth" not in norm_params
+ assert "wealth" not in log_params