diff --git a/src/data_processing/transformations.py b/src/data_processing/transformations.py index 2a643e8..0cae6b5 100644 --- a/src/data_processing/transformations.py +++ b/src/data_processing/transformations.py @@ -11,27 +11,30 @@ def pivot_table( def agg_func(values): return sum(values) / len(values) + elif aggfunc == "sum": def agg_func(values): return sum(values) + elif aggfunc == "count": def agg_func(values): return len(values) + else: raise ValueError(f"Unsupported aggregation function: {aggfunc}") grouped_data = {} - for i in range(len(df)): - row = df.iloc[i] - index_val = row[index] - column_val = row[columns] - value = row[values] - if index_val not in grouped_data: - grouped_data[index_val] = {} - if column_val not in grouped_data[index_val]: - grouped_data[index_val][column_val] = [] - grouped_data[index_val][column_val].append(value) + + # Extract data as numpy arrays for fast iteration, avoiding .iloc row lookup + index_data = df[index].values + column_data = df[columns].values + value_data = df[values].values + + for index_val, column_val, value in zip(index_data, column_data, value_data): + inner = grouped_data.setdefault(index_val, {}) + inner.setdefault(column_val, []).append(value) + for index_val in grouped_data: result[index_val] = {} for column_val in grouped_data[index_val]: