From 1e562b3bdd974fe01931ba59f413684340a825de Mon Sep 17 00:00:00 2001 From: juaristi22 Date: Fri, 5 Dec 2025 10:44:48 +0800 Subject: [PATCH] some nits --- .github/workflows/main.yml | 34 +- README.md | 1 + changelog_entry.yaml | 4 + docs/imputation-benchmarking/metrics.md | 52 ++- docs/myst.yml | 14 +- microimpute/comparisons/autoimpute_helpers.py | 23 +- microimpute/comparisons/metrics.py | 95 ++++- microimpute/models/mdn.py | 35 +- myst.yml | 14 +- tests/test_metrics.py | 370 +++++++++++++++++- 10 files changed, 552 insertions(+), 90 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index bab5de8..90e08ff 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -5,33 +5,7 @@ on: branches: [ main ] jobs: - Check-MDN-Changes: - runs-on: ubuntu-latest - outputs: - mdn_changed: ${{ steps.check.outputs.mdn_changed }} - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - name: Check for MDN-related file changes - id: check - run: | - # Get list of changed files in this push - CHANGED_FILES=$(git diff --name-only HEAD~1 HEAD) - echo "Changed files:" - echo "$CHANGED_FILES" - - # Check if any MDN-related files were changed - if echo "$CHANGED_FILES" | grep -qE "(mdn|MDN)"; then - echo "mdn_changed=true" >> $GITHUB_OUTPUT - echo "MDN-related files were changed" - else - echo "mdn_changed=false" >> $GITHUB_OUTPUT - echo "No MDN-related files were changed" - fi - Test: - needs: Check-MDN-Changes runs-on: ubuntu-latest strategy: matrix: @@ -56,12 +30,8 @@ jobs: run: | sudo Rscript -e 'install.packages("StatMatch", repos="https://cloud.r-project.org")' sudo Rscript -e 'install.packages("clue", repos="https://cloud.r-project.org")' - - name: Install full dependencies without MDN (Python 3.13) - if: matrix.python-version == '3.13' && needs.Check-MDN-Changes.outputs.mdn_changed != 'true' - run: | - uv pip install -e ".[dev,docs,matching,images]" --system - - name: Install full dependencies with MDN (Python 3.13) - if: matrix.python-version == '3.13' && needs.Check-MDN-Changes.outputs.mdn_changed == 'true' + - name: Install full dependencies (Python 3.13) + if: matrix.python-version == '3.13' run: | uv pip install -e ".[dev,docs,matching,mdn,images]" --system - name: Install minimal dependencies (Python 3.12) diff --git a/README.md b/README.md index e9b83bb..8b055cf 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ Microimpute enables variable imputation through a variety of statistical methods - **Ordinary Least Squares (OLS)**: Linear regression-based imputation - **Quantile Regression**: Distribution-aware regression imputation - **Quantile Random Forests (QRF)**: Non-parametric forest-based approach +- **Mixture Density Networks (MDN)**: Neural network with Gaussian mixture approximation head ### Automated method selection - **AutoImpute**: Automatically compares and selects the best imputation method for your data diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29..7eb6399 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + added: + - Updates to documentation and Myst deployment. 
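The README bullet above introduces Mixture Density Networks: the network's output head parameterizes a Gaussian mixture, and imputed values are drawn by sampling from that mixture. As a purely illustrative sketch of the sampling step (plain numpy, not the package's pytorch_tabular-based implementation; all names here are made up):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical MDN head output for a single observation:
# mixture weights (pi), component means (mu), and standard deviations (sigma).
pi = np.array([0.2, 0.5, 0.3])
mu = np.array([10.0, 25.0, 40.0])
sigma = np.array([2.0, 5.0, 3.0])

# Impute by first picking a mixture component, then sampling its Gaussian.
component = rng.choice(len(pi), p=pi)
imputed_value = rng.normal(mu[component], sigma[component])
```

Sampling from the mixture, rather than taking its mean, is what lets the method reproduce the conditional distribution of the target rather than only its expectation.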
diff --git a/docs/imputation-benchmarking/metrics.md b/docs/imputation-benchmarking/metrics.md index fce0c40..cf2f6a3 100644 --- a/docs/imputation-benchmarking/metrics.md +++ b/docs/imputation-benchmarking/metrics.md @@ -116,6 +116,8 @@ $$W_p(P, Q) = \left(\inf_{\gamma \in \Pi(P, Q)} \int_{X \times Y} d(x, y)^p d\ga where $\Pi(P, Q)$ denotes the set of all joint distributions whose marginals are $P$ and $Q$ respectively. The Wasserstein distance measures the minimum "work" required to transform one distribution into another, where work is the amount of distribution mass moved times the distance moved. Lower values indicate better preservation of the original distribution's shape. +When sample weights are provided, the weighted Wasserstein distance accounts for varying observation importance, which is essential when comparing survey data with different sampling designs. We use scipy's `wasserstein_distance` implementation, which supports sample weights via the `u_weights` and `v_weights` parameters. + ### Kullback-Leibler divergence For discrete distributions (categorical and boolean variables), KL divergence quantifies how one probability distribution diverges from a reference: @@ -124,24 +126,56 @@ $$D_{KL}(P||Q) = \sum_{x \in \mathcal{X}} P(x) \log\left(\frac{P(x)}{Q(x)}\right where $P$ is the reference distribution (original data), $Q$ is the approximation (imputed data), and $\mathcal{X}$ is the set of all possible categorical values. KL divergence measures how much information is lost when using the imputed distribution to approximate the true distribution. Lower values indicate better preservation of the original categorical distribution. +When sample weights are provided, the probability distributions are computed as weighted proportions rather than simple counts, ensuring proper comparison of weighted survey data. + +### kl_divergence + +Computes the Kullback-Leibler divergence between two categorical distributions, with optional sample weights. + +```python +def kl_divergence( + donor_values: np.ndarray, + receiver_values: np.ndarray, + donor_weights: Optional[np.ndarray] = None, + receiver_weights: Optional[np.ndarray] = None, +) -> float +``` + +| Parameter | Type | Default used | Description | +|-----------|------|---------|-------------| +| donor_values | np.ndarray | - | Categorical values from donor data (reference distribution) | +| receiver_values | np.ndarray | - | Categorical values from receiver data (approximation) | +| donor_weights | np.ndarray | None | Optional sample weights for donor values | +| receiver_weights | np.ndarray | None | Optional sample weights for receiver values | + +Returns KL divergence value (float >= 0), where 0 indicates identical distributions. + ### compare_distributions +Compares distributions between donor and receiver data, automatically selecting the appropriate metric based on variable type and supporting sample weights for survey data. 
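+
+Under the hood, `compare_distributions` applies the weighted computations described above to each variable. A small standalone illustration of those computations (array names here are illustrative, not part of the package API):
+
+```python
+import numpy as np
+from scipy.stats import wasserstein_distance
+
+# Weighted Wasserstein distance for a numeric variable
+donor_vals = np.array([40_000.0, 55_000.0, 70_000.0])
+receiver_vals = np.array([42_000.0, 58_000.0, 65_000.0])
+donor_w = np.array([2.0, 1.0, 1.0])     # survey weights, donor rows
+receiver_w = np.array([1.0, 1.0, 2.0])  # survey weights, receiver rows
+d_wass = wasserstein_distance(
+    donor_vals, receiver_vals, u_weights=donor_w, v_weights=receiver_w
+)
+
+# Weighted KL divergence for a categorical variable
+donor_cat = np.array(["A", "A", "B"])
+receiver_cat = np.array(["A", "B", "B"])
+categories = np.union1d(donor_cat, receiver_cat)
+p = np.array([donor_w[donor_cat == c].sum() for c in categories]) / donor_w.sum()
+q = np.array([receiver_w[receiver_cat == c].sum() for c in categories]) / receiver_w.sum()
+eps = 1e-10  # guard against log(0) for categories missing from one sample
+d_kl = float(np.sum(p * np.log((p + eps) / (q + eps))))
+```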
+ ```python def compare_distributions( donor_data: pd.DataFrame, receiver_data: pd.DataFrame, imputed_variables: List[str], + donor_weights: Optional[Union[pd.Series, np.ndarray]] = None, + receiver_weights: Optional[Union[pd.Series, np.ndarray]] = None, ) -> pd.DataFrame ``` -| Parameter | Type | Description | -|-----------|------|-------------| -| donor_data | pd.DataFrame | Original donor data | -| receiver_data | pd.DataFrame | Receiver data with imputations | -| imputed_variables | List[str] | Variables to compare | +| Parameter | Type | Default used | Description | +|-----------|------|---------|-------------| +| donor_data | pd.DataFrame | - | Original donor data | +| receiver_data | pd.DataFrame | - | Receiver data with imputations | +| imputed_variables | List[str] | - | Variables to compare | +| donor_weights | pd.Series or np.ndarray | None | Sample weights for donor data (must match donor_data length) | +| receiver_weights | pd.Series or np.ndarray | None | Sample weights for receiver data (must match receiver_data length) | Returns a DataFrame with columns `Variable`, `Metric`, and `Distance`. The function automatically selects Wasserstein distance for numerical variables and KL divergence for categorical variables. +Note that data must not contain null or infinite values. If your data contains such values, filter them before calling this function. + ## Predictor analysis Understanding which predictors contribute most to imputation quality helps with feature selection and model interpretation. These tools analyze predictor-target relationships and evaluate sensitivity to predictor selection. @@ -251,11 +285,13 @@ metrics_df = compare_metrics( imputed_variables=imputed_variables ) -# Evaluate distributional match -dist_df = compare_distributions( +# Evaluate distributional match with survey weights +dist_df_weighted = compare_distributions( donor_data=donor, receiver_data=receiver_with_imputations, - imputed_variables=imputed_variables + imputed_variables=imputed_variables, + donor_weights=donor["sample_weight"], + receiver_weights=receiver["sample_weight"], ) # Analyze predictor importance diff --git a/docs/myst.yml b/docs/myst.yml index c0e500a..e7f484c 100644 --- a/docs/myst.yml +++ b/docs/myst.yml @@ -28,22 +28,28 @@ project: - file: models/quantreg/index children: - file: models/quantreg/quantreg-imputation + - file: models/mdn/index + children: + - file: models/mdn/mdn-imputation - title: Imputation and benchmarking children: - file: imputation-benchmarking/index children: + - file: imputation-benchmarking/preprocessing + - file: imputation-benchmarking/cross-validation + - file: imputation-benchmarking/metrics + - file: imputation-benchmarking/visualizations - file: imputation-benchmarking/benchmarking-methods - - file: imputation-benchmarking/imputing-across-surveys - title: AutoImpute children: - file: autoimpute/index children: - file: autoimpute/autoimpute - - title: SCF to CPS example + - title: Use cases children: - - file: examples/scf_to_cps/index + - file: use_cases/index children: - - file: examples/scf_to_cps/imputing-from-scf-to-cps + - file: use_cases/scf_to_cps/imputing-from-scf-to-cps site: options: logo: logo.png diff --git a/microimpute/comparisons/autoimpute_helpers.py b/microimpute/comparisons/autoimpute_helpers.py index 9723817..fe4fe6a 100644 --- a/microimpute/comparisons/autoimpute_helpers.py +++ b/microimpute/comparisons/autoimpute_helpers.py @@ -163,15 +163,20 @@ def prepare_data_for_imputation( predictor_log = [c for c in log_cols if c in 
predictors] predictor_asinh = [c for c in asinh_cols if c in predictors] - transformed_imputing, _ = preprocess_data( - imputing_data[predictors], - full_data=True, - train_size=train_size, - test_size=test_size, - normalize=predictor_normalize if predictor_normalize else False, - log_transform=predictor_log if predictor_log else False, - asinh_transform=predictor_asinh if predictor_asinh else False, - ) + if predictor_normalize or predictor_log or predictor_asinh: + transformed_imputing, _ = preprocess_data( + imputing_data[predictors], + full_data=True, + train_size=train_size, + test_size=test_size, + normalize=( + predictor_normalize if predictor_normalize else False + ), + log_transform=predictor_log if predictor_log else False, + asinh_transform=predictor_asinh if predictor_asinh else False, + ) + else: + transformed_imputing = imputing_data[predictors].copy() training_data = transformed_training if weight_col: diff --git a/microimpute/comparisons/metrics.py b/microimpute/comparisons/metrics.py index 15d1227..c963962 100644 --- a/microimpute/comparisons/metrics.py +++ b/microimpute/comparisons/metrics.py @@ -8,7 +8,7 @@ """ import logging -from typing import Dict, List, Literal, Optional, Tuple +from typing import Dict, List, Literal, Optional, Tuple, Union import numpy as np import pandas as pd @@ -497,7 +497,10 @@ def compare_metrics( def kl_divergence( - donor_values: np.ndarray, receiver_values: np.ndarray + donor_values: np.ndarray, + receiver_values: np.ndarray, + donor_weights: Optional[np.ndarray] = None, + receiver_weights: Optional[np.ndarray] = None, ) -> float: """Calculate Kullback-Leibler (KL) Divergence between two categorical distributions. @@ -512,6 +515,10 @@ def kl_divergence( Args: donor_values: Array of categorical values from donor data (reference distribution P). receiver_values: Array of categorical values from receiver data (approximation Q). + donor_weights: Optional weights for donor values. If provided, computes + weighted probability distribution. + receiver_weights: Optional weights for receiver values. If provided, + computes weighted probability distribution. 
Returns: KL divergence value >= 0, where 0 indicates identical distributions @@ -536,9 +543,30 @@ def kl_divergence( np.unique(donor_values), np.unique(receiver_values) ) - # Calculate probability distributions - donor_counts = pd.Series(donor_values).value_counts(normalize=True) - receiver_counts = pd.Series(receiver_values).value_counts(normalize=True) + # Calculate probability distributions (weighted if weights provided) + if donor_weights is not None: + # Compute weighted probabilities + donor_df = pd.DataFrame( + {"value": donor_values, "weight": donor_weights} + ) + donor_grouped = donor_df.groupby("value")["weight"].sum() + donor_total = donor_grouped.sum() + donor_counts = donor_grouped / donor_total + else: + donor_counts = pd.Series(donor_values).value_counts(normalize=True) + + if receiver_weights is not None: + # Compute weighted probabilities + receiver_df = pd.DataFrame( + {"value": receiver_values, "weight": receiver_weights} + ) + receiver_grouped = receiver_df.groupby("value")["weight"].sum() + receiver_total = receiver_grouped.sum() + receiver_counts = receiver_grouped / receiver_total + else: + receiver_counts = pd.Series(receiver_values).value_counts( + normalize=True + ) # Create probability arrays for all categories p_donor = np.array([donor_counts.get(cat, 0.0) for cat in all_categories]) @@ -563,6 +591,8 @@ def compare_distributions( donor_data: pd.DataFrame, receiver_data: pd.DataFrame, imputed_variables: List[str], + donor_weights: Optional[Union[pd.Series, np.ndarray]] = None, + receiver_weights: Optional[Union[pd.Series, np.ndarray]] = None, ) -> pd.DataFrame: """Compare distributions between donor and receiver data for imputed variables. @@ -574,6 +604,10 @@ def compare_distributions( donor_data: DataFrame containing original donor data. receiver_data: DataFrame containing receiver data with imputations. imputed_variables: List of variable names to compare. + donor_weights: Optional array or Series of sample weights for donor data. + Must have same length as donor_data. + receiver_weights: Optional array or Series of sample weights for receiver + data. Must have same length as receiver_data. Returns: DataFrame with columns 'Variable', 'Metric', and 'Distance' containing @@ -608,14 +642,45 @@ def compare_distributions( receiver_data, imputed_variables, "receiver_data" ) + # Convert weights to numpy arrays if provided + donor_weights_arr = None + receiver_weights_arr = None + if donor_weights is not None: + donor_weights_arr = np.asarray(donor_weights) + if len(donor_weights_arr) != len(donor_data): + raise ValueError( + f"donor_weights length ({len(donor_weights_arr)}) must match " + f"donor_data length ({len(donor_data)})" + ) + if receiver_weights is not None: + receiver_weights_arr = np.asarray(receiver_weights) + if len(receiver_weights_arr) != len(receiver_data): + raise ValueError( + f"receiver_weights length ({len(receiver_weights_arr)}) must " + f"match receiver_data length ({len(receiver_data)})" + ) + results = [] # Detect metric type and compute distance for each variable detector = VariableTypeDetector() for var in imputed_variables: - # Get values from both datasets - donor_values = donor_data[var].dropna().values - receiver_values = receiver_data[var].dropna().values + donor_values = donor_data[var].values + receiver_values = receiver_data[var].values + + # Check for null values - these are not allowed when comparing + if np.any(pd.isna(donor_values)): + raise ValueError( + f"Variable '{var}' in donor_data contains null values. 
" + "Please remove or impute null values before comparing " + "distributions." + ) + if np.any(pd.isna(receiver_values)): + raise ValueError( + f"Variable '{var}' in receiver_data contains null values. " + "Please remove or impute null values before comparing " + "distributions." + ) if len(donor_values) == 0 or len(receiver_values) == 0: log.warning( @@ -633,14 +698,24 @@ def compare_distributions( if var_type in ["bool", "categorical", "numeric_categorical"]: # Use KL Divergence for categorical metric_name = "kl_divergence" - distance = kl_divergence(donor_values, receiver_values) + distance = kl_divergence( + donor_values, + receiver_values, + donor_weights=donor_weights_arr, + receiver_weights=receiver_weights_arr, + ) log.debug( f"KL divergence for categorical variable '{var}': {distance:.6f}" ) else: # Use Wasserstein Distance for numerical metric_name = "wasserstein_distance" - distance = wasserstein_distance(donor_values, receiver_values) + distance = wasserstein_distance( + donor_values, + receiver_values, + u_weights=donor_weights_arr, + v_weights=receiver_weights_arr, + ) log.debug( f"Wasserstein distance for numerical variable '{var}': {distance:.6f}" ) diff --git a/microimpute/models/mdn.py b/microimpute/models/mdn.py index 09d4e00..72001ee 100644 --- a/microimpute/models/mdn.py +++ b/microimpute/models/mdn.py @@ -32,6 +32,33 @@ ]: logging.getLogger(_logger_name).setLevel(logging.ERROR) + # Suppress pytorch_tabular warnings + import warnings + + warnings.filterwarnings( + "ignore", + message=".*does not have many workers.*", + module="pytorch_lightning", + ) + warnings.filterwarnings( + "ignore", + message=".*pin_memory.*argument is set as true but not supported on MPS.*", + module="torch.utils.data.dataloader", + ) + + warnings.filterwarnings( + "ignore", + message=".*Setting an item of incompatible dtype is deprecated.*", + category=FutureWarning, + module="pytorch_tabular.tabular_datamodule", + ) + + warnings.filterwarnings( + "ignore", + message=".*training batches.*smaller than the logging interval.*", + module="pytorch_lightning.loops.fit_loop", + ) + # After import, also update the rank_zero_module logger from lightning_fabric.utilities.rank_zero import rank_zero_module from pytorch_tabular import TabularModel @@ -145,7 +172,7 @@ def __init__( self, seed: int, logger, - layers: str = "128-64", + layers: str = "128-64-32", activation: str = "ReLU", dropout: float = 0.1, use_batch_norm: bool = False, @@ -155,7 +182,7 @@ def __init__( learning_rate: float = 1e-3, max_epochs: int = 100, early_stopping_patience: int = 10, - batch_size: int = 64, + batch_size: int = 256, ): self.seed = seed self.logger = logger @@ -316,14 +343,14 @@ def __init__( self, seed: int, logger, - layers: str = "128-64", + layers: str = "128-64-32", activation: str = "ReLU", dropout: float = 0.1, use_batch_norm: bool = False, learning_rate: float = 1e-3, max_epochs: int = 100, early_stopping_patience: int = 10, - batch_size: int = 64, + batch_size: int = 256, ): self.seed = seed self.logger = logger diff --git a/myst.yml b/myst.yml index 160145e..911eb8a 100644 --- a/myst.yml +++ b/myst.yml @@ -28,22 +28,28 @@ project: - file: docs/models/quantreg/index children: - file: docs/models/quantreg/quantreg-imputation + - file: docs/models/mdn/index + children: + - file: docs/models/mdn/mdn-imputation - title: Imputation and benchmarking children: - file: docs/imputation-benchmarking/index children: + - file: docs/imputation-benchmarking/preprocessing + - file: docs/imputation-benchmarking/cross-validation 
+ - file: docs/imputation-benchmarking/metrics + - file: docs/imputation-benchmarking/visualizations - file: docs/imputation-benchmarking/benchmarking-methods - - file: docs/imputation-benchmarking/imputing-across-surveys - title: AutoImpute children: - file: docs/autoimpute/index children: - file: docs/autoimpute/autoimpute - - title: SCF to CPS example + - title: Use cases children: - - file: docs/examples/scf_to_cps/index + - file: docs/use_cases/index children: - - file: docs/examples/scf_to_cps/imputing-from-scf-to-cps + - file: docs/use_cases/scf_to_cps/imputing-from-scf-to-cps site: options: logo: docs/logo.png diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 8b1ee62..302df50 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -4,10 +4,7 @@ import pandas as pd import pytest -from microimpute.comparisons import ( - compare_metrics, - get_imputations, -) +from microimpute.comparisons import compare_metrics, get_imputations from microimpute.comparisons.autoimpute import autoimpute from microimpute.comparisons.autoimpute_helpers import ( select_best_model_dual_metrics, @@ -1222,37 +1219,35 @@ def test_compare_distributions_missing_columns() -> None: compare_distributions(donor, receiver, ["c"]) -def test_compare_distributions_with_nulls() -> None: - """Test that compare_distributions handles null values correctly.""" +def test_compare_distributions_rejects_nulls_in_data() -> None: + """Test that compare_distributions raises error when data contains nulls.""" np.random.seed(42) - donor = pd.DataFrame( + donor_with_nulls = pd.DataFrame( { "x": [1, 2, 3, 4, 5, np.nan, np.nan], "y": ["A", "B", "A", "B", "A", None, "B"], } ) - receiver = pd.DataFrame( + receiver_ok = pd.DataFrame( { - "x": [1.5, 2.5, 3.5, 4.5, 5.5, np.nan], - "y": ["A", "A", "B", "B", "B", None], + "x": [1.5, 2.5, 3.5, 4.5, 5.5, 6.5], + "y": ["A", "A", "B", "B", "B", "A"], } ) - # Should work - nulls are dropped - results = compare_distributions(donor, receiver, ["x", "y"]) + # Should raise error for nulls in donor + with pytest.raises(ValueError, match="contains null values"): + compare_distributions(donor_with_nulls, receiver_ok, ["x", "y"]) - assert len(results) == 2 - assert all(results["Distance"] >= 0) - -def test_compare_distributions_insufficient_data() -> None: - """Test error when variables have insufficient non-null data.""" +def test_compare_distributions_empty_data() -> None: + """Test error when variables have no data.""" donor = pd.DataFrame( { - "x": [np.nan, np.nan, np.nan], - "y": [None, None, None], + "x": pd.Series([], dtype=float), + "y": pd.Series([], dtype=str), } ) @@ -1297,3 +1292,340 @@ def test_compare_distributions_return_format() -> None: assert results["Variable"].dtype == "object" assert results["Metric"].dtype == "object" assert results["Distance"].dtype in ["float64", "float32"] + + +# === Weighted Distribution Comparison Tests === + + +def test_kl_divergence_with_weights() -> None: + """Test KL divergence with sample weights.""" + # Create data where weights matter + donor = np.array(["A", "A", "B", "B"]) + receiver = np.array(["A", "A", "B", "B"]) + + # Without weights: both distributions are 50% A, 50% B + kl_unweighted = kl_divergence(donor, receiver) + assert np.isclose( + kl_unweighted, 0.0, atol=1e-10 + ), "Unweighted identical distributions should have KL=0" + + # With weights: donor becomes 80% A, 20% B; receiver stays 50% A, 50% B + donor_weights = np.array( + [4.0, 4.0, 1.0, 1.0] + ) # A weighted 8, B weighted 2 + kl_weighted_donor = kl_divergence( + 
donor, receiver, donor_weights=donor_weights + ) + assert ( + kl_weighted_donor > 0 + ), "Weighted donor with different distribution should have KL > 0" + + # With receiver weights: donor 50/50, receiver becomes 80% A, 20% B + receiver_weights = np.array([4.0, 4.0, 1.0, 1.0]) + kl_weighted_receiver = kl_divergence( + donor, receiver, receiver_weights=receiver_weights + ) + assert ( + kl_weighted_receiver > 0 + ), "Weighted receiver with different distribution should have KL > 0" + + +def test_kl_divergence_weights_symmetry() -> None: + """Test that KL divergence with swapped weighted distributions gives same result.""" + donor = np.array(["A", "A", "B", "B"]) + receiver = np.array(["A", "A", "B", "B"]) + + # Make donor 80% A via weights + donor_weights = np.array([4.0, 4.0, 1.0, 1.0]) + + # Make receiver 80% A via weights + receiver_weights = np.array([4.0, 4.0, 1.0, 1.0]) + + # Both weighted the same way should give KL = 0 + kl = kl_divergence( + donor, + receiver, + donor_weights=donor_weights, + receiver_weights=receiver_weights, + ) + assert np.isclose( + kl, 0.0, atol=1e-10 + ), "Identically weighted distributions should have KL=0" + + +def test_compare_distributions_with_weights() -> None: + """Test compare_distributions with weight arrays.""" + np.random.seed(42) + + donor = pd.DataFrame( + { + "income": np.random.normal(50000, 10000, 100), + "region": np.random.choice(["A", "B", "C"], 100), + } + ) + donor_weights = np.random.uniform(0.5, 2.0, 100) + + receiver = pd.DataFrame( + { + "income": np.random.normal(52000, 10000, 100), + "region": np.random.choice(["A", "B", "C"], 100), + } + ) + receiver_weights = np.random.uniform(0.5, 2.0, 100) + + # Unweighted comparison + results_unweighted = compare_distributions( + donor, receiver, ["income", "region"] + ) + + # Weighted comparison + results_weighted = compare_distributions( + donor, + receiver, + ["income", "region"], + donor_weights=donor_weights, + receiver_weights=receiver_weights, + ) + + # Both should return valid results + assert len(results_unweighted) == 2 + assert len(results_weighted) == 2 + + # Results should be different (weights should affect the computation) + # Get income distances + income_unweighted = results_unweighted[ + results_unweighted["Variable"] == "income" + ]["Distance"].values[0] + income_weighted = results_weighted[ + results_weighted["Variable"] == "income" + ]["Distance"].values[0] + + # With random weights, results should typically differ + # (though not guaranteed, so we just check they're both valid) + assert income_unweighted >= 0 + assert income_weighted >= 0 + + +def test_compare_distributions_donor_weight_only() -> None: + """Test compare_distributions with only donor weights.""" + np.random.seed(42) + + donor = pd.DataFrame( + { + "x": np.random.normal(0, 1, 50), + } + ) + donor_weights = np.random.uniform(1, 3, 50) + + receiver = pd.DataFrame( + { + "x": np.random.normal(0.5, 1, 50), + } + ) + + # Should work with only donor weights + results = compare_distributions( + donor, receiver, ["x"], donor_weights=donor_weights + ) + + assert len(results) == 1 + assert results["Variable"].values[0] == "x" + assert results["Distance"].values[0] >= 0 + + +def test_compare_distributions_receiver_weight_only() -> None: + """Test compare_distributions with only receiver weights.""" + np.random.seed(42) + + donor = pd.DataFrame( + { + "x": np.random.normal(0, 1, 50), + } + ) + + receiver = pd.DataFrame( + { + "x": np.random.normal(0.5, 1, 50), + } + ) + receiver_weights = np.random.uniform(1, 3, 50) + + # 
Should work with only receiver weights + results = compare_distributions( + donor, receiver, ["x"], receiver_weights=receiver_weights + ) + + assert len(results) == 1 + assert results["Variable"].values[0] == "x" + assert results["Distance"].values[0] >= 0 + + +def test_compare_distributions_rejects_nulls() -> None: + """Test that compare_distributions raises error when data contains nulls.""" + donor_with_null = pd.DataFrame( + { + "x": [1.0, 2.0, np.nan, 4.0, 5.0], + } + ) + donor_ok = pd.DataFrame( + { + "x": [1.0, 2.0, 3.0, 4.0, 5.0], + } + ) + receiver_with_null = pd.DataFrame( + { + "x": [1.5, np.nan, 3.5, 4.5, 5.5], + } + ) + receiver_ok = pd.DataFrame( + { + "x": [1.5, 2.5, 3.5, 4.5, 5.5], + } + ) + + # Should raise error for null in donor + with pytest.raises(ValueError, match="donor_data contains null values"): + compare_distributions(donor_with_null, receiver_ok, ["x"]) + + # Should raise error for null in receiver + with pytest.raises(ValueError, match="receiver_data contains null values"): + compare_distributions(donor_ok, receiver_with_null, ["x"]) + + +def test_compare_distributions_weights_affect_wasserstein() -> None: + """Test that weights actually affect Wasserstein distance calculation.""" + # Create two identical value arrays + donor = pd.DataFrame( + { + "x": [1.0, 2.0, 3.0, 4.0], + } + ) + donor_weights = np.array([1.0, 1.0, 1.0, 1.0]) + + receiver = pd.DataFrame( + { + "x": [1.0, 2.0, 3.0, 4.0], + } + ) + receiver_weights = np.array([1.0, 1.0, 1.0, 1.0]) + + # Identical data and weights should give distance = 0 + results_identical = compare_distributions( + donor, + receiver, + ["x"], + donor_weights=donor_weights, + receiver_weights=receiver_weights, + ) + assert np.isclose( + results_identical["Distance"].values[0], 0.0, atol=1e-10 + ), "Identical weighted distributions should have distance=0" + + # Now change receiver weights to shift distribution toward higher values + receiver_shifted_weights = np.array( + [0.1, 0.1, 1.0, 1.0] + ) # More weight on higher values + + results_shifted = compare_distributions( + donor, + receiver, + ["x"], + donor_weights=donor_weights, + receiver_weights=receiver_shifted_weights, + ) + assert ( + results_shifted["Distance"].values[0] > 0 + ), "Different weighted distributions should have distance > 0" + + +def test_compare_distributions_weights_affect_kl() -> None: + """Test that weights actually affect KL divergence calculation.""" + # Create identical categorical arrays + donor = pd.DataFrame( + { + "cat": ["A", "A", "B", "B"], + } + ) + donor_weights = np.array([1.0, 1.0, 1.0, 1.0]) + + receiver = pd.DataFrame( + { + "cat": ["A", "A", "B", "B"], + } + ) + receiver_weights = np.array([1.0, 1.0, 1.0, 1.0]) + + # Identical data and weights should give KL = 0 + results_identical = compare_distributions( + donor, + receiver, + ["cat"], + donor_weights=donor_weights, + receiver_weights=receiver_weights, + ) + assert np.isclose( + results_identical["Distance"].values[0], 0.0, atol=1e-10 + ), "Identical weighted distributions should have KL=0" + + # Now change weights to create different distributions + # Donor: 50% A, 50% B (equal weights) + # Receiver: 90% A, 10% B (by weights) + receiver_shifted_weights = np.array( + [4.5, 4.5, 0.5, 0.5] + ) # 90% weight on A + + results_shifted = compare_distributions( + donor, + receiver, + ["cat"], + donor_weights=donor_weights, + receiver_weights=receiver_shifted_weights, + ) + assert ( + results_shifted["Distance"].values[0] > 0 + ), "Different weighted distributions should have KL > 0" + + +def 
test_compare_distributions_weight_length_mismatch() -> None: + """Test error handling for weight length mismatch.""" + donor = pd.DataFrame({"x": [1, 2, 3]}) + receiver = pd.DataFrame({"x": [1, 2, 3]}) + + # Should raise error for mismatched donor weights length + with pytest.raises(ValueError, match="donor_weights length"): + compare_distributions( + donor, + receiver, + ["x"], + donor_weights=np.array([1.0, 2.0]), # Wrong length + ) + + # Should raise error for mismatched receiver weights length + with pytest.raises(ValueError, match="receiver_weights length"): + compare_distributions( + donor, + receiver, + ["x"], + receiver_weights=np.array([1.0, 2.0]), # Wrong length + ) + + +def test_compare_distributions_with_series_weights() -> None: + """Test that compare_distributions works with pandas Series as weights.""" + donor = pd.DataFrame({"x": [1.0, 2.0, 3.0, 4.0]}) + donor_weights = pd.Series([1.0, 2.0, 3.0, 4.0]) + + receiver = pd.DataFrame({"x": [1.5, 2.5, 3.5, 4.5]}) + receiver_weights = pd.Series([1.0, 1.0, 1.0, 1.0]) + + # Should work with Series weights + results = compare_distributions( + donor, + receiver, + ["x"], + donor_weights=donor_weights, + receiver_weights=receiver_weights, + ) + + assert len(results) == 1 + assert results["Distance"].values[0] >= 0
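
One additional property the new weighted tests could cover, sketched here as a suggestion rather than part of the patch (it reuses the imports already present in tests/test_metrics.py): with uniform weights, the weighted and unweighted distances should coincide.

```python
def test_compare_distributions_uniform_weights_match_unweighted() -> None:
    """Uniform weights should reproduce the unweighted distances."""
    np.random.seed(0)
    donor = pd.DataFrame({"x": np.random.normal(0, 1, 40)})
    receiver = pd.DataFrame({"x": np.random.normal(0.3, 1, 40)})
    ones = np.ones(40)

    unweighted = compare_distributions(donor, receiver, ["x"])
    weighted = compare_distributions(
        donor,
        receiver,
        ["x"],
        donor_weights=ones,
        receiver_weights=ones,
    )

    # scipy normalizes weights internally, so uniform weights should give
    # the same Wasserstein distance as the unweighted call.
    assert np.isclose(
        unweighted["Distance"].values[0], weighted["Distance"].values[0]
    )
```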