diff --git a/PERFORMANCE_OPTIMIZATION_REPORT.md b/PERFORMANCE_OPTIMIZATION_REPORT.md new file mode 100644 index 0000000..fae1c64 --- /dev/null +++ b/PERFORMANCE_OPTIMIZATION_REPORT.md @@ -0,0 +1,95 @@ +# C3PO Performance Optimization Report + +## Executive Summary + +This report documents performance inefficiencies identified in the C3PO (Contextualized Critiques with Constrained Preference Optimization) codebase. The analysis covers the main pipeline components including sampling, training, evaluation, and data processing modules. + +## Key Findings + +### 1. **HIGH IMPACT**: Dataset Caching Disabled (FIXED) +**Location**: `src/dataset/format.py` - All dataset transformation functions +**Issue**: All `dataset.map()` calls explicitly disable caching with `load_from_cache_file=False` +**Impact**: Forces recomputation of dataset transformations on every run, significantly slowing down repeated operations +**Lines Affected**: 65, 72, 80, 88, 115, 121, 127, 135, 141, 149, 172, 177, 183 +**Fix Applied**: Removed `load_from_cache_file=False` parameter to enable default caching behavior + +### 2. **MEDIUM IMPACT**: Inefficient API Call Batching +**Location**: `src/sample.py` - `sample_completions()` function (lines 96-163) +**Issue**: Sequential API calls for different completion types instead of batching +**Impact**: +- Makes 4 separate API calls per prompt batch (baseline, revised, in-context, CoT) +- Each call waits for completion before starting the next +- Increases total latency by ~4x for completion generation +**Recommendation**: Batch all completion requests into a single API call with different prompts + +### 3. **MEDIUM IMPACT**: Redundant Dataset Operations +**Location**: `src/dataset/format.py` - `to_lcdpo()` and `to_sft_weighted()` functions +**Issue**: Unnecessary dataset length calculations and truncation warnings +**Impact**: +- `min(len(dataset), min(len(negative_dataset), len(general_dataset)))` computed multiple times +- Redundant dataset selections and warnings for each dataset +**Lines**: 94-105, 155-166 +**Recommendation**: Compute minimum length once and reuse + +### 4. **MEDIUM IMPACT**: String Processing Inefficiencies +**Location**: `src/sample.py` - Response parsing (lines 45, 70, 127, 138) +**Issue**: Multiple string operations on API responses +**Impact**: +- `r.split("REVISED_CATEGORIES:")[-1].strip()` pattern repeated +- String operations not optimized for batch processing +**Recommendation**: Create utility functions for common parsing patterns + +### 5. **LOW-MEDIUM IMPACT**: File I/O Inefficiencies +**Location**: `src/dataset/feedback_utils.py` - Dataset serialization (lines 125-139) +**Issue**: JSON serialization/deserialization for large datasets +**Impact**: +- Converting entire datasets to dictionaries before JSON serialization +- Memory overhead for large datasets +**Recommendation**: Use more efficient serialization formats (e.g., Arrow, Parquet) + +### 6. **LOW IMPACT**: Repeated Model Loading +**Location**: `src/sample.py` - Multiple `get_model()` calls +**Issue**: Models loaded separately for categories, prompts, and completions +**Impact**: +- Potential redundant model initialization if same model used for multiple tasks +- Memory overhead from multiple model instances +**Lines**: 36, 57, 99 +**Recommendation**: Implement model caching/reuse strategy + +### 7. **LOW IMPACT**: Inefficient Linear Layer Discovery +**Location**: `src/utils.py` - `find_all_linear_names()` function (lines 147-157) +**Issue**: Iterates through all model modules to find linear layers +**Impact**: +- O(n) search through model architecture +- Called during LoRA setup for every training run +**Recommendation**: Cache results or use more efficient module filtering + +## Performance Impact Assessment + +| Optimization | Impact Level | Effort | Risk | Priority | +|--------------|-------------|--------|------|----------| +| Dataset Caching | HIGH | LOW | LOW | 1 (FIXED) | +| API Call Batching | MEDIUM | MEDIUM | MEDIUM | 2 | +| Dataset Operations | MEDIUM | LOW | LOW | 3 | +| String Processing | MEDIUM | LOW | LOW | 4 | +| File I/O | LOW-MEDIUM | MEDIUM | MEDIUM | 5 | +| Model Loading | LOW | MEDIUM | MEDIUM | 6 | +| Linear Layer Discovery | LOW | LOW | LOW | 7 | + +## Implementation Notes + +### Fixed: Dataset Caching Optimization +- **Change**: Removed `load_from_cache_file=False` from all `dataset.map()` calls +- **Benefit**: Enables HuggingFace datasets' built-in caching mechanism +- **Risk**: Minimal - only enables existing functionality +- **Testing**: Verify dataset operations produce identical results + +### Recommended Next Steps +1. Implement API call batching in `sample_completions()` +2. Optimize dataset length calculations in format functions +3. Create utility functions for common string parsing operations +4. Evaluate alternative serialization formats for large datasets + +## Conclusion + +The dataset caching fix provides immediate performance benefits with minimal risk. The other identified optimizations offer additional performance gains and should be prioritized based on the impact/effort matrix above. All optimizations maintain backward compatibility and existing functionality. diff --git a/src/dataset/format.py b/src/dataset/format.py index 90db781..69f67e5 100644 --- a/src/dataset/format.py +++ b/src/dataset/format.py @@ -62,14 +62,14 @@ def to_dpo(dataset: Dataset, negative_dataset: Dataset = None, general_dataset: "prompt": x["prompt"], "rejected": x["baseline_response"], "chosen": x["revised_response"] - }, remove_columns=dataset.features, load_from_cache_file=False) + }, remove_columns=dataset.features) if negative_dataset is not None: negative_dataset = negative_dataset.map(lambda x: { "prompt": x["prompt"], "rejected": x["revised_response"], "chosen": x["baseline_response"] - }, remove_columns=negative_dataset.features, load_from_cache_file=False) + }, remove_columns=negative_dataset.features) dataset = concatenate_datasets([dataset, negative_dataset]) if general_dataset is not None: @@ -77,7 +77,7 @@ def to_dpo(dataset: Dataset, negative_dataset: Dataset = None, general_dataset: "prompt": x["prompt"], "rejected": x["revised_response"], "chosen": x["baseline_response"] - }, remove_columns=general_dataset.features, load_from_cache_file=False) + }, remove_columns=general_dataset.features) dataset = concatenate_datasets([dataset, general_dataset]) prompt_format = FORMAT_MAPPING[model_name_or_path]["prompt"] @@ -85,7 +85,7 @@ def to_dpo(dataset: Dataset, negative_dataset: Dataset = None, general_dataset: dataset = dataset.map(lambda x: { # The DPO trainer adds the eos/bos tokens itself so no need to do that here "prompt": prompt_format(x['prompt']) - }, load_from_cache_file=False) + }) return dataset @@ -112,18 +112,18 @@ def to_lcdpo(dataset: Dataset, negative_dataset: Dataset = None, general_dataset "prompt": prompt_format(x['prompt']), "rejected": x["baseline_response"], "chosen": x["revised_response"] - }, remove_columns=dataset.features, load_from_cache_file=False) + }, remove_columns=dataset.features) if negative_dataset is not None: negative_dataset = negative_dataset.map(lambda x: { "hard_negative": full_format(x["prompt"], x["baseline_response"]) - }, remove_columns=negative_dataset.features, load_from_cache_file=False) + }, remove_columns=negative_dataset.features) dataset = dataset.add_column("hard_negative", negative_dataset["hard_negative"]) if general_dataset is not None: general_dataset = general_dataset.map(lambda x: { "soft_negative": full_format(x["prompt"], x["baseline_response"]) - }, remove_columns=general_dataset.features, load_from_cache_file=False) + }, remove_columns=general_dataset.features) dataset = dataset.add_column("soft_negative", general_dataset["soft_negative"]) return dataset @@ -132,20 +132,20 @@ def to_sft(dataset: Dataset, negative_dataset: Dataset = None, general_dataset: dataset = dataset.map(lambda x: { "prompt": x["prompt"], "completion": f' {x["revised_response"]}' # TODO: hack to fix tokenization issue when there are to neighboring parentheses (e.g. '[/INST][...]' ) - }, remove_columns=dataset.features, load_from_cache_file=False) + }, remove_columns=dataset.features) if negative_dataset is not None: negative_dataset = negative_dataset.map(lambda x: { "prompt": x["prompt"], "completion": f' {x["baseline_response"]}' # TODO: hack to fix tokenization issue when there are to neighboring parentheses (e.g. '[/INST][...]' ) - }, remove_columns=negative_dataset.features, load_from_cache_file=False) + }, remove_columns=negative_dataset.features) dataset = concatenate_datasets([dataset, negative_dataset]) if general_dataset is not None: general_dataset = general_dataset.map(lambda x: { "prompt": x["prompt"], "completion": f' {x["baseline_response"]}' # TODO: hack to fix tokenization issue when there are to neighboring parentheses (e.g. '[/INST][...]' ) - }, remove_columns=general_dataset.features, load_from_cache_file=False) + }, remove_columns=general_dataset.features) dataset = concatenate_datasets([dataset, general_dataset]) return dataset @@ -169,17 +169,17 @@ def to_sft_weighted(dataset: Dataset, negative_dataset: Dataset = None, general_ dataset = dataset.map(lambda x: { "text": full_format(x["prompt"], x["revised_response"]) # TODO: hack to fix tokenization issue when there are to neighboring parentheses (e.g. '[/INST][...]' ) - }, remove_columns=dataset.features, load_from_cache_file=False) + }, remove_columns=dataset.features) if negative_dataset is not None: negative_dataset = negative_dataset.map(lambda x: { "hard_negative": full_format(x["prompt"], x["baseline_response"]) - }, remove_columns=negative_dataset.features, load_from_cache_file=False) + }, remove_columns=negative_dataset.features) dataset = dataset.add_column("hard_negative", negative_dataset["hard_negative"]) if general_dataset is not None: general_dataset = general_dataset.map(lambda x: { "soft_negative": full_format(x["prompt"], x["baseline_response"]) - }, remove_columns=general_dataset.features, load_from_cache_file=False) + }, remove_columns=general_dataset.features) dataset = dataset.add_column("soft_negative", general_dataset["soft_negative"]) - return dataset \ No newline at end of file + return dataset