Describe the bug
In SampleCache, get_cache_path doesn't take the rank into account, and no lock is applied when saving or loading parquet files. When multiple ranks are launched with the accelerate backend, every rank calls cache_samples and writes to the same path, which can corrupt the parquet caches.
Subsequent get_samples_from_cache calls then fail, because the parquet load in this block fails (cache_management.py, line 283):
for task_id in task_ids:
    if task_id.sampling_method != sampling_method:
        continue
    cache_file = self.get_cache_path(task_id)
    try:
        dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
        dataset_df = dataset.to_pandas().set_index("sample_id")
        task_datasets[task_id] = dataset_df
    except Exception as e:
        logger.warning(f"Error loading prediction cache for {str(task_id)}: {e}")
To Reproduce
My setup is fairly complicated, so I can't provide a simple repro directly here. I can try to produce one if it's absolutely necessary.
Expected behavior
Only rank 0 should write the sample cache, and caching should work when multiple ranks are running.
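To make the expected behavior concrete, here is a minimal sketch of a rank-0-only, atomic cache write. It is not the library's code: `write_cache_atomically`, its `write_fn` callback, and the explicit `rank` argument are hypothetical names used for illustration, and how the rank is obtained is left to the caller.

```python
import os
import tempfile
from pathlib import Path
from typing import Callable


def write_cache_atomically(
    cache_file: Path, write_fn: Callable[[Path], None], rank: int
) -> bool:
    """Write cache_file from rank 0 only, via a temp file plus atomic rename.

    Hypothetical helper, not part of the library.
    Returns True if this process wrote the file, False if it skipped.
    """
    if rank != 0:
        return False  # non-zero ranks never touch the cache file

    cache_file.parent.mkdir(parents=True, exist_ok=True)
    # Create the temp file in the same directory so os.replace stays
    # on one filesystem and the rename is atomic.
    fd, tmp_name = tempfile.mkstemp(dir=cache_file.parent, suffix=".tmp")
    os.close(fd)
    tmp_path = Path(tmp_name)
    try:
        write_fn(tmp_path)  # e.g. lambda p: df.to_parquet(p)
        # Readers see either the old file or the complete new one,
        # never a partially written file.
        os.replace(tmp_path, cache_file)
    finally:
        # Clean up the temp file if the write failed before the rename.
        tmp_path.unlink(missing_ok=True)
    return True
```

With the accelerate backend, the rank could presumably come from `accelerator.process_index`, with an `accelerator.wait_for_everyone()` barrier after the write so that other ranks only read the cache once rank 0 has finished. Whether that fits the existing cache_samples flow is an assumption.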
Version
I'm using this branch: #1083, but the cache portion of the code is identical.