95 changes: 95 additions & 0 deletions PERFORMANCE_OPTIMIZATION_REPORT.md
@@ -0,0 +1,95 @@
# C3PO Performance Optimization Report

## Executive Summary

This report documents performance inefficiencies identified in the C3PO (Contextualized Critiques with Constrained Preference Optimization) codebase. The analysis covers the main pipeline components including sampling, training, evaluation, and data processing modules.

## Key Findings

### 1. **HIGH IMPACT**: Dataset Caching Disabled (FIXED)
**Location**: `src/dataset/format.py` - All dataset transformation functions
**Issue**: All `dataset.map()` calls explicitly disable caching with `load_from_cache_file=False`
**Impact**: Forces recomputation of dataset transformations on every run, significantly slowing down repeated operations
**Lines Affected**: 65, 72, 80, 88, 115, 121, 127, 135, 141, 149, 172, 177, 183
**Fix Applied**: Removed `load_from_cache_file=False` parameter to enable default caching behavior
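
For reference, the change is the removal of a single keyword argument from each call (shown in full in the diff at the end of this PR; `fn` below stands in for each transformation lambda):

```python
# Before: caching explicitly disabled, so every run recomputes the transform
dataset = dataset.map(fn, remove_columns=dataset.features, load_from_cache_file=False)

# After: HuggingFace datasets' fingerprint-based cache is used by default
dataset = dataset.map(fn, remove_columns=dataset.features)
```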

### 2. **MEDIUM IMPACT**: Inefficient API Call Batching
**Location**: `src/sample.py` - `sample_completions()` function (lines 96-163)
**Issue**: Sequential API calls for different completion types instead of batching
**Impact**:
- Makes 4 separate API calls per prompt batch (baseline, revised, in-context, CoT)
- Each call waits for completion before starting the next
- Increases total latency by ~4x for completion generation
**Recommendation**: Batch all completion requests into a single API call with different prompts
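
A minimal sketch of the batched approach; `model.generate` and the `make_*_prompt` helpers are stand-ins for the repo's actual sampling API, not its real signatures:

```python
def sample_all_completions(model, prompts: list[str]) -> dict[str, list[str]]:
    # Build every prompt variant up front instead of issuing four
    # sequential requests per batch.
    variants = {
        "baseline": prompts,
        "revised": [make_revision_prompt(p) for p in prompts],  # hypothetical helper
        "in_context": [make_icl_prompt(p) for p in prompts],    # hypothetical helper
        "cot": [make_cot_prompt(p) for p in prompts],           # hypothetical helper
    }

    # One flat request lets the backend batch across all variants.
    flat = [p for plist in variants.values() for p in plist]
    responses = model.generate(flat)  # assumed batched generation call

    # Slice the flat response list back into per-variant lists.
    out, i = {}, 0
    for name, plist in variants.items():
        out[name] = responses[i:i + len(plist)]
        i += len(plist)
    return out
```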

### 3. **MEDIUM IMPACT**: Redundant Dataset Operations
**Location**: `src/dataset/format.py` - `to_lcdpo()` and `to_sft_weighted()` functions
**Issue**: Unnecessary dataset length calculations and truncation warnings
**Impact**:
- `min(len(dataset), min(len(negative_dataset), len(general_dataset)))` computed multiple times
- Redundant dataset selections and warnings for each dataset
**Lines**: 94-105, 155-166
**Recommendation**: Compute minimum length once and reuse
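
A sketch of the suggested refactor (variable names assumed to match the function arguments; the original `None` checks are omitted for brevity). Since `min()` is variadic, the nested call collapses to a single expression:

```python
import warnings

# Compute the common length once and reuse it.
n = min(len(dataset), len(negative_dataset), len(general_dataset))

for name, ds in [("feedback", dataset), ("negative", negative_dataset), ("general", general_dataset)]:
    if len(ds) > n:
        warnings.warn(f"Truncating {name} dataset from {len(ds)} to {n} examples")

dataset = dataset.select(range(n))
negative_dataset = negative_dataset.select(range(n))
general_dataset = general_dataset.select(range(n))
```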

### 4. **MEDIUM IMPACT**: String Processing Inefficiencies
**Location**: `src/sample.py` - Response parsing (lines 45, 70, 127, 138)
**Issue**: Multiple string operations on API responses
**Impact**:
- `r.split("REVISED_CATEGORIES:")[-1].strip()` pattern repeated
- String operations not optimized for batch processing
**Recommendation**: Create utility functions for common parsing patterns
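
For example, the repeated split-and-strip pattern could live in one helper (name hypothetical):

```python
def extract_after_tag(response: str, tag: str) -> str:
    """Return the text following the last occurrence of `tag`, whitespace-stripped."""
    return response.split(tag)[-1].strip()

# Replaces inline occurrences such as:
#   r.split("REVISED_CATEGORIES:")[-1].strip()
categories = extract_after_tag(r, "REVISED_CATEGORIES:")
```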

### 5. **LOW-MEDIUM IMPACT**: File I/O Inefficiencies
**Location**: `src/dataset/feedback_utils.py` - Dataset serialization (lines 125-139)
**Issue**: JSON serialization/deserialization for large datasets
**Impact**:
- Converting entire datasets to dictionaries before JSON serialization
- Memory overhead for large datasets
**Recommendation**: Use more efficient serialization formats (e.g., Arrow, Parquet)
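
HuggingFace datasets can already write Arrow and Parquet directly, avoiding the dict round-trip; a sketch (paths illustrative):

```python
from datasets import Dataset, load_from_disk

# Write without materializing the dataset as a Python dict first
dataset.save_to_disk("data/feedback_arrow")    # Arrow: memory-mapped on reload
dataset.to_parquet("data/feedback.parquet")    # Parquet: compact, columnar

# Reload either format
dataset = load_from_disk("data/feedback_arrow")
dataset = Dataset.from_parquet("data/feedback.parquet")
```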

### 6. **LOW IMPACT**: Repeated Model Loading
**Location**: `src/sample.py` - Multiple `get_model()` calls
**Issue**: Models loaded separately for categories, prompts, and completions
**Impact**:
- Potential redundant model initialization when the same model is used for multiple tasks
- Memory overhead from multiple model instances
**Lines**: 36, 57, 99
**Recommendation**: Implement model caching/reuse strategy
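
A minimal memoization sketch; it assumes `get_model` is deterministic for a given name (its real signature may carry more arguments):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def get_cached_model(model_name_or_path: str):
    # Delegates to the repo's loader; subsequent calls with the same
    # name reuse the already-initialized instance.
    return get_model(model_name_or_path)
```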

### 7. **LOW IMPACT**: Inefficient Linear Layer Discovery
**Location**: `src/utils.py` - `find_all_linear_names()` function (lines 147-157)
**Issue**: Iterates through all model modules to find linear layers
**Impact**:
- O(n) search through model architecture
- Called during LoRA setup for every training run
**Recommendation**: Cache results or use more efficient module filtering
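
Because the module names depend only on the architecture, the result can be memoized per model instance; a sketch built around the existing `find_all_linear_names`:

```python
_linear_names_cache: dict[int, list[str]] = {}

def find_all_linear_names_cached(model) -> list[str]:
    # Key on object identity: valid as long as the model's module tree
    # is not mutated between calls.
    key = id(model)
    if key not in _linear_names_cache:
        _linear_names_cache[key] = find_all_linear_names(model)
    return _linear_names_cache[key]
```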

## Performance Impact Assessment

| Optimization | Impact Level | Effort | Risk | Priority |
|--------------|-------------|--------|------|----------|
| Dataset Caching | HIGH | LOW | LOW | 1 (FIXED) |
| API Call Batching | MEDIUM | MEDIUM | MEDIUM | 2 |
| Dataset Operations | MEDIUM | LOW | LOW | 3 |
| String Processing | MEDIUM | LOW | LOW | 4 |
| File I/O | LOW-MEDIUM | MEDIUM | MEDIUM | 5 |
| Model Loading | LOW | MEDIUM | MEDIUM | 6 |
| Linear Layer Discovery | LOW | LOW | LOW | 7 |

## Implementation Notes

### Fixed: Dataset Caching Optimization
- **Change**: Removed `load_from_cache_file=False` from all `dataset.map()` calls
- **Benefit**: Enables HuggingFace datasets' built-in caching mechanism
- **Risk**: Minimal - only enables existing functionality
- **Testing**: Verify dataset operations produce identical results

### Recommended Next Steps
1. Implement API call batching in `sample_completions()`
2. Optimize dataset length calculations in format functions
3. Create utility functions for common string parsing operations
4. Evaluate alternative serialization formats for large datasets

## Conclusion

The dataset caching fix provides immediate performance benefits with minimal risk. The other identified optimizations offer additional performance gains and should be prioritized based on the impact/effort matrix above. All optimizations maintain backward compatibility and existing functionality.
28 changes: 14 additions & 14 deletions src/dataset/format.py
@@ -62,30 +62,30 @@ def to_dpo(dataset: Dataset, negative_dataset: Dataset = None, general_dataset:
         "prompt": x["prompt"],
         "rejected": x["baseline_response"],
         "chosen": x["revised_response"]
-    }, remove_columns=dataset.features, load_from_cache_file=False)
+    }, remove_columns=dataset.features)
 
     if negative_dataset is not None:
         negative_dataset = negative_dataset.map(lambda x: {
             "prompt": x["prompt"],
             "rejected": x["revised_response"],
             "chosen": x["baseline_response"]
-        }, remove_columns=negative_dataset.features, load_from_cache_file=False)
+        }, remove_columns=negative_dataset.features)
         dataset = concatenate_datasets([dataset, negative_dataset])
 
     if general_dataset is not None:
         general_dataset = general_dataset.map(lambda x: {
             "prompt": x["prompt"],
             "rejected": x["revised_response"],
             "chosen": x["baseline_response"]
-        }, remove_columns=general_dataset.features, load_from_cache_file=False)
+        }, remove_columns=general_dataset.features)
         dataset = concatenate_datasets([dataset, general_dataset])
 
     prompt_format = FORMAT_MAPPING[model_name_or_path]["prompt"]
 
     dataset = dataset.map(lambda x: {
         # The DPO trainer adds the eos/bos tokens itself so no need to do that here
         "prompt": prompt_format(x['prompt'])
-    }, load_from_cache_file=False)
+    })
     return dataset


@@ -112,18 +112,18 @@ def to_lcdpo(dataset: Dataset, negative_dataset: Dataset = None, general_dataset
         "prompt": prompt_format(x['prompt']),
         "rejected": x["baseline_response"],
         "chosen": x["revised_response"]
-    }, remove_columns=dataset.features, load_from_cache_file=False)
+    }, remove_columns=dataset.features)
 
     if negative_dataset is not None:
         negative_dataset = negative_dataset.map(lambda x: {
             "hard_negative": full_format(x["prompt"], x["baseline_response"])
-        }, remove_columns=negative_dataset.features, load_from_cache_file=False)
+        }, remove_columns=negative_dataset.features)
         dataset = dataset.add_column("hard_negative", negative_dataset["hard_negative"])
 
     if general_dataset is not None:
         general_dataset = general_dataset.map(lambda x: {
             "soft_negative": full_format(x["prompt"], x["baseline_response"])
-        }, remove_columns=general_dataset.features, load_from_cache_file=False)
+        }, remove_columns=general_dataset.features)
         dataset = dataset.add_column("soft_negative", general_dataset["soft_negative"])
     return dataset

@@ -132,20 +132,20 @@ def to_sft(dataset: Dataset, negative_dataset: Dataset = None, general_dataset:
     dataset = dataset.map(lambda x: {
         "prompt": x["prompt"],
         "completion": f' {x["revised_response"]}' # TODO: hack to fix tokenization issue when there are two neighboring parentheses (e.g. '[/INST][...]')
-    }, remove_columns=dataset.features, load_from_cache_file=False)
+    }, remove_columns=dataset.features)
 
     if negative_dataset is not None:
         negative_dataset = negative_dataset.map(lambda x: {
             "prompt": x["prompt"],
             "completion": f' {x["baseline_response"]}' # TODO: hack to fix tokenization issue when there are two neighboring parentheses (e.g. '[/INST][...]')
-        }, remove_columns=negative_dataset.features, load_from_cache_file=False)
+        }, remove_columns=negative_dataset.features)
         dataset = concatenate_datasets([dataset, negative_dataset])
 
     if general_dataset is not None:
         general_dataset = general_dataset.map(lambda x: {
             "prompt": x["prompt"],
             "completion": f' {x["baseline_response"]}' # TODO: hack to fix tokenization issue when there are two neighboring parentheses (e.g. '[/INST][...]')
-        }, remove_columns=general_dataset.features, load_from_cache_file=False)
+        }, remove_columns=general_dataset.features)
         dataset = concatenate_datasets([dataset, general_dataset])
     return dataset

@@ -169,17 +169,17 @@ def to_sft_weighted(dataset: Dataset, negative_dataset: Dataset = None, general_
 
     dataset = dataset.map(lambda x: {
         "text": full_format(x["prompt"], x["revised_response"]) # TODO: hack to fix tokenization issue when there are two neighboring parentheses (e.g. '[/INST][...]')
-    }, remove_columns=dataset.features, load_from_cache_file=False)
+    }, remove_columns=dataset.features)
 
     if negative_dataset is not None:
         negative_dataset = negative_dataset.map(lambda x: {
             "hard_negative": full_format(x["prompt"], x["baseline_response"])
-        }, remove_columns=negative_dataset.features, load_from_cache_file=False)
+        }, remove_columns=negative_dataset.features)
         dataset = dataset.add_column("hard_negative", negative_dataset["hard_negative"])
 
     if general_dataset is not None:
         general_dataset = general_dataset.map(lambda x: {
             "soft_negative": full_format(x["prompt"], x["baseline_response"])
-        }, remove_columns=general_dataset.features, load_from_cache_file=False)
+        }, remove_columns=general_dataset.features)
         dataset = dataset.add_column("soft_negative", general_dataset["soft_negative"])
     return dataset