Fix ParallelStreamingDataset with resume=True not resuming after loading a state dict when breaking early
#771
Conversation
Codecov Report ✅ All modified and coverable lines are covered by tests.

@@ Coverage Diff @@
##             main     #771   +/- ##
=====================================
  Coverage      80%      80%
=====================================
  Files          52       52
  Lines        7362     7364     +2
=====================================
+ Hits         5906     5908     +2
  Misses       1456     1456
ParallelStreamingDataset with resume=True not resuming after loading a state dict when breaking early
Pull request overview
This PR fixes a bug where ParallelStreamingDataset with resume=True would not properly resume after manually loading a state dict following a training crash when iterations were incomplete (breaking early from an epoch). The fix restructures the control flow in StreamingDataLoader.__iter__() to ensure that the ParallelStreamingDataset cycling logic executes even when self.restore=True, enabling proper resume behavior.
Key changes:
- Restructured the __iter__ method in StreamingDataLoader to handle ParallelStreamingDataset cycling logic before checking the restore flag (see the sketch below)
- Enhanced test coverage with new scenarios that simulate training crashes and manual state dict loading
- Updated test parameters and expected values to better validate cycling and resume behavior
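The control-flow change above can be illustrated with a minimal, self-contained sketch. It does not use litdata: `ToyLoader`, its attributes, and the way it stores its position are illustrative stand-ins, not the actual StreamingDataLoader implementation; the point is only that the resume position is consulted before branching on the restore flag.

```python
# Minimal sketch (assumptions, not litdata code): a loader whose __iter__ honours a
# manually loaded position both when cycling between epochs and when restoring
# after a crash, instead of only in the non-restore branch.
class ToyLoader:
    def __init__(self, data, cycle=True):
        self.data = data      # stands in for the wrapped datasets
        self.cycle = cycle    # stands in for ParallelStreamingDataset._length != None
        self.restore = False  # True after load_state_dict()
        self._pos = 0         # position to resume from

    def load_state_dict(self, state):
        self._pos = state["pos"]
        self.restore = True

    def __iter__(self):
        # The fix in spirit: apply the cycling/restore positioning before checking
        # the restore flag, so a loaded state dict is honoured even when the
        # previous epoch broke early.
        start = self._pos if (self.cycle or self.restore) else 0
        self.restore = False
        for i in range(start, len(self.data)):
            self._pos = i + 1
            yield self.data[i]


loader = ToyLoader(list(range(8)))
it = iter(loader)
first = [next(it) for _ in range(3)]  # "crash" after three samples
state = {"pos": loader._pos}          # stands in for loader.state_dict()

resumed = ToyLoader(list(range(8)))
resumed.load_state_dict(state)        # manual resume in a new session
assert first + list(resumed) == list(range(8))
```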
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/litdata/streaming/dataloader.py | Fixed the __iter__ method to properly handle ParallelStreamingDataset cycling when resuming from a manually loaded state dict by restructuring conditional logic |
| tests/streaming/test_parallel.py | Added comprehensive test scenarios for crash simulation and manual resume, updated helper function to support tmpdir reuse, and adjusted test parameters and expected values for better coverage |
Review comment on tests/streaming/test_parallel.py:

# for some reason the workers are swapping their samples compared to the previous epoch when not resuming
# so we update expected_2 and batches_2 accordingly
expected_2 = [expected_2[i + 1] if i % 2 == 0 else expected_2[i - 1] for i in range(len(expected_2))]
batches_2 = [batches_2[i + 1] if i % 2 == 0 else batches_2[i - 1] for i in range(len(batches_2))]
Copilot AI · Dec 16, 2025
This comment suggests uncertainty about the behavior ("for some reason the workers are swapping their samples"). The test then compensates by swapping elements in expected_2 and batches_2 lists. This workaround indicates either: (1) the underlying behavior is not well understood, or (2) there's a non-deterministic or undocumented aspect of worker sample distribution. Consider investigating the root cause and documenting why this swapping occurs, or fixing the underlying issue if it's a bug.
Suggested change:

- # for some reason the workers are swapping their samples compared to the previous epoch when not resuming
- # so we update expected_2 and batches_2 accordingly
- expected_2 = [expected_2[i + 1] if i % 2 == 0 else expected_2[i - 1] for i in range(len(expected_2))]
- batches_2 = [batches_2[i + 1] if i % 2 == 0 else batches_2[i - 1] for i in range(len(batches_2))]
+ # The order of samples delivered by the workers should be deterministic when shuffle=False.
+ # If this test fails, investigate the worker sample assignment logic in the DataLoader.
+ # Remove the workaround that swaps expected_2 and batches_2; the test should reflect the true, documented order.
Review comment on src/litdata/streaming/dataloader.py:

# For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not
# want to restart at index 0 at every epoch. So we set them in restore state.
Copilot AI · Dec 16, 2025
The comment states "we do not want to restart at index 0 at every epoch" but could be more explicit about what this fix addresses. Consider expanding the comment to mention that this handles both automatic cycling (between epochs in the same session) and manual resume (after loading a state dict from a previous session/crash), as this is the key bug being fixed.
Suggested change:

- # For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not
- # want to restart at index 0 at every epoch. So we set them in restore state.
+ # For ParallelStreamingDataset with _length != None, we want to cycle the wrapped datasets and avoid
+ # restarting at index 0 at every epoch. This logic ensures that we correctly handle both automatic cycling
+ # between epochs in the same session and manual resume after loading a state dict from a previous session
+ # or crash. We set the datasets in restore state to maintain the correct position across both scenarios.
Before submitting
What does this PR do?
Fixes #770.
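To illustrate the scenario behind #770, here is a hedged usage sketch of breaking out of an epoch early, saving the loader state, and resuming from it in a later session. The ParallelStreamingDataset constructor arguments and import path below are assumptions based on this PR's description, not verified against the litdata API, and the URIs are placeholders.

```python
# Hedged sketch of the crash-and-resume scenario this PR fixes.
from litdata import StreamingDataLoader, StreamingDataset
from litdata.streaming.parallel import ParallelStreamingDataset  # assumed import path

ds_a = StreamingDataset("s3://my-bucket/dataset-a")  # placeholder URIs
ds_b = StreamingDataset("s3://my-bucket/dataset-b")
# Assumed signature: a fixed length makes the wrapped datasets cycle, and
# resume=True asks them to continue from their last position across epochs.
dataset = ParallelStreamingDataset([ds_a, ds_b], length=100, resume=True)
loader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)

# Simulate a training crash: break out of the epoch early and keep the loader state.
state = None
for step, batch in enumerate(loader):
    if step == 5:
        state = loader.state_dict()
        break

# In a later session, load the saved state and keep iterating. Before this fix,
# the wrapped datasets restarted at index 0 instead of resuming from this point.
loader.load_state_dict(state)
for batch in loader:
    pass  # training continues from where the previous run stopped
```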
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃