-
Notifications
You must be signed in to change notification settings - Fork 395
[Feature] support async rl #1360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
efb3109 to
1601d51
Compare
5e3f135 to
aaa4860
Compare
| waiting_tasks = set() | ||
| dataflow_start_time = time.perf_counter() | ||
| task_completion_times = [] | ||
| with tqdm(total=self.target_batch_size, desc="rollout_controller for training samples") as pbar: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用 tqdm(miniters=10) (Minimum progress display update interval in iters)并在循环中使用 pbar.update(finished_samples) 来代替 manual pbar.fresh。最小化pbar在loop中的操作。
| data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx | ||
| ) | ||
| ) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice hierarchical code!
| collator="fake_collator", | ||
| pack_level="none", | ||
| expired_threshold = ( | ||
| min(remain_size, self.config.tail_batch_trigger_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use cast(int, xxx) instead
| self.finished_samples_count = await self.replay_buffer.get_completed_samples_count.remote() | ||
| waiting_tasks = pending_tasks | ||
|
|
||
| while len(waiting_tasks) + self.finished_samples_count < max(data_concurrency, self.target_batch_size): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
len(waiting_tasks) + self.finished_samples_count < data_concurrency + init_finished_samples_count
31b3535 to
953a613
Compare
f6fa0fd to
4bd4c4f
Compare
This PR introduces asynchronous RL support to Xtuner, enabling partial rollouts and version-based sample management for more efficient training data generation.
1. Key Concepts:
2. Async logic:
staleness_threshold=0.0enable_partial_rollout=0tail_batch_candidate_steps=0staleness_threshold=0.2enable_partial_rollout=0tail_batch_candidate_steps=02. Responses not retained when paused rollout
3. Prioritize sampling data from the abort queue
staleness_threshold=0.2enable_partial_rollout=0tail_batch_candidate_steps=1tail_batch_trigger_size=02. Responses not retained when paused
3. Prioritize sampling data from the abort queue
4. Put it into the candidate pool when sample abort num reaches
tail_batch_candidate_steps+1staleness_threshold=0.2enable_partial_rollout=1tail_batch_candidate_steps=0tail_batch_trigger_size=02. Responses retained & concatenated when paused
3. Prioritize sampling data from the abort queue
staleness_threshold=0.2enable_partial_rollout=1tail_batch_candidate_steps=1tail_batch_trigger_size=02. Responses retained & concatenated when paused
3. Prioritize sampling data from the abort queue
4. Put it into the candidate pool when sample abort num reaches
tail_batch_candidate_steps+1. thetail_batch_candidate_stepsmeans off policy step3. BenchMark
4. Relative PR
sample_from_expired_storagein dataflow. Whensample_from_expired_storageis set to True, the dataflow will not oversend data and will return data only after all tasks of the current batch are completed.