Conversation
@YanhuiDua commented on Dec 15, 2025

This PR introduces asynchronous RL support to Xtuner, enabling partial rollouts and version-based sample management for more efficient training data generation.

1. Key Concepts:

  • `staleness_threshold`: The maximum allowed fraction of stale (expired) samples in a training batch.
  • `enable_partial_rollout`: Whether to enable partial rollout for asynchronous data generation.
  • `tail_batch_candidate_steps`: Number of rollout steps after which a sample becomes a candidate for the tail batch. Set to 0 to disable the tail batch.
  • `tail_batch_trigger_size`: Number of candidate samples that must accumulate in the queue to trigger a tail-batch operation. Defaults to `global_batch_size` when not provided or set to 0 (see the config sketch below).
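
As a quick illustration, a minimal config sketch, assuming a plain dataclass-style config object; the class name `AsyncRolloutConfig` and the `resolved_trigger_size` helper are illustrative, not the PR's actual API:

```python
from dataclasses import dataclass


@dataclass
class AsyncRolloutConfig:  # hypothetical name, for illustration only
    staleness_threshold: float = 0.0      # max fraction of stale (expired) samples per batch
    enable_partial_rollout: int = 0       # 1 = keep & concatenate responses across pauses
    tail_batch_candidate_steps: int = 0   # 0 disables the tail batch
    tail_batch_trigger_size: int = 0      # 0 / unset -> falls back to global_batch_size

    def resolved_trigger_size(self, global_batch_size: int) -> int:
        # Documented fallback: use global_batch_size when not provided or set to 0.
        return self.tail_batch_trigger_size or global_batch_size
```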

2. Async logic:

| Strategy Type | Settings | Core Features |
| --- | --- | --- |
| Synchronous | `staleness_threshold=0.0`<br>`enable_partial_rollout=0`<br>`tail_batch_candidate_steps=0` | 1. No data oversending |
| Asynchronous 1 | `staleness_threshold=0.2`<br>`enable_partial_rollout=0`<br>`tail_batch_candidate_steps=0` | 1. 20% data oversending<br>2. Responses are not retained when a rollout is paused<br>3. Prioritize sampling data from the abort queue |
| Asynchronous 2 | `staleness_threshold=0.2`<br>`enable_partial_rollout=0`<br>`tail_batch_candidate_steps=1`<br>`tail_batch_trigger_size=0` | 1. 20% data oversending<br>2. Responses are not retained when a rollout is paused<br>3. Prioritize sampling data from the abort queue<br>4. A sample is moved to the candidate pool once its abort count reaches `tail_batch_candidate_steps + 1` |
| Asynchronous 3 | `staleness_threshold=0.2`<br>`enable_partial_rollout=1`<br>`tail_batch_candidate_steps=0`<br>`tail_batch_trigger_size=0` | 1. 20% data oversending<br>2. Responses are retained & concatenated when a rollout is paused<br>3. Prioritize sampling data from the abort queue |
| Asynchronous 4 | `staleness_threshold=0.2`<br>`enable_partial_rollout=1`<br>`tail_batch_candidate_steps=1`<br>`tail_batch_trigger_size=0` | 1. 20% data oversending<br>2. Responses are retained & concatenated when a rollout is paused<br>3. Prioritize sampling data from the abort queue<br>4. A sample is moved to the candidate pool once its abort count reaches `tail_batch_candidate_steps + 1` (`tail_batch_candidate_steps` corresponds to the off-policy step) |
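
As a rough illustration of the table, here is a minimal sketch of how `staleness_threshold` could bound data oversending and how the abort queue could be drained before fresh prompts. This is one plausible reading of the table, not the PR's actual implementation; the function and queue names are assumptions:

```python
from collections import deque


def max_inflight_samples(global_batch_size: int, staleness_threshold: float) -> int:
    # staleness_threshold=0.2 -> up to 20% data oversending on top of the batch size.
    return int(global_batch_size * (1 + staleness_threshold))


def next_prompts(abort_queue: deque, fresh_prompts: deque, budget: int) -> list:
    # Prioritize sampling from the abort queue, then fall back to fresh prompts.
    picked: list = []
    while len(picked) < budget and abort_queue:
        picked.append(abort_queue.popleft())
    while len(picked) < budget and fresh_prompts:
        picked.append(fresh_prompts.popleft())
    return picked
```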

3. Benchmark

4. Related PRs

  • Added async-related configuration parameters, including `partial_rollout`, `tail_batch_candidate_steps`, `tail_batch_trigger_size`, and `staleness_threshold`
  • Refactored replay buffer storage to support versioned samples with bucketed tracking of completed, aborted, and expired states (see the sketch after this list)
  • Renamed `Sampler` to `DatasetSampler` and separated dataset sampling logic from replay buffer sampling
  • Applied `sample_from_expired_storage` in dataflow: when set to True, the dataflow does not oversend data and returns data only after all tasks of the current batch are completed
  • Added task timing log info
  • Added partial rollout functionality with versioned response tracking to accumulate tokens across multiple generation steps
  • Implemented an automatic worker restart mechanism for when all rollout workers become inactive
  • Fixed state handling for aborted rollouts and improved error logging
  • Added TensorBoard logging for training and rollout metrics
  • Refactored the training loop in `fit()` to conditionally execute rollout, training, and weight synchronization based on debug mode
  • Fixed async runtime bugs
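
For the replay-buffer refactor above, a minimal sketch of versioned, bucketed sample tracking, assuming a simple in-memory structure; the class name `VersionedReplayBuffer` and its methods are illustrative, not the PR's actual `ReplayBuffer` API:

```python
from collections import defaultdict


class VersionedReplayBuffer:  # illustrative only
    """Tracks samples per rollout version in completed / aborted / expired buckets."""

    def __init__(self) -> None:
        self.buckets = {name: defaultdict(list) for name in ("completed", "aborted", "expired")}

    def add(self, state: str, version: int, sample) -> None:
        self.buckets[state][version].append(sample)

    def expire_stale(self, current_version: int, max_version_gap: int) -> None:
        # Move completed samples whose version lags too far behind into the expired bucket.
        for version in list(self.buckets["completed"]):
            if current_version - version > max_version_gap:
                self.buckets["expired"][version].extend(self.buckets["completed"].pop(version))
```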

@YanhuiDua force-pushed the support_async_rl_4 branch 2 times, most recently from 5e3f135 to aaa4860 on Dec 19, 2025 04:20
```python
waiting_tasks = set()
dataflow_start_time = time.perf_counter()
task_completion_times = []
with tqdm(total=self.target_batch_size, desc="rollout_controller for training samples") as pbar:
```
Use `tqdm(miniters=10)` (miniters = minimum progress display update interval, in iterations) and call `pbar.update(finished_samples)` inside the loop instead of a manual `pbar.refresh()`. This minimizes pbar operations in the loop.
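
A small sketch of the suggested pattern, assuming a polling helper that reports newly finished samples; `collect_finished_samples` and `target_batch_size` below are illustrative stand-ins:

```python
import random

from tqdm import tqdm

target_batch_size = 512  # illustrative value


def collect_finished_samples() -> int:
    # Stand-in for polling the rollout controller for newly finished samples.
    return random.randint(1, 8)


# miniters=10: tqdm redraws at most every 10 finished samples, so manual
# pbar.refresh() calls inside the loop are unnecessary.
with tqdm(total=target_batch_size, desc="rollout_controller for training samples", miniters=10) as pbar:
    while pbar.n < target_batch_size:
        pbar.update(collect_finished_samples())
```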

```python
data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx
)
)
```

Nice hierarchical code!

collator="fake_collator",
pack_level="none",
expired_threshold = (
min(remain_size, self.config.tail_batch_trigger_size)
@jayhenry commented on Dec 23, 2025:
use cast(int, xxx) instead
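
For reference, a minimal sketch of the suggested `typing.cast` usage; the wrapper function is illustrative and not part of the PR:

```python
from typing import cast


def compute_expired_threshold(remain_size: int, tail_batch_trigger_size) -> int:
    # cast() is a no-op at runtime; it only tells the type checker the value is an int.
    return min(remain_size, cast(int, tail_batch_trigger_size))
```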

```python
self.finished_samples_count = await self.replay_buffer.get_completed_samples_count.remote()
waiting_tasks = pending_tasks

while len(waiting_tasks) + self.finished_samples_count < max(data_concurrency, self.target_batch_size):
```
Suggested condition:

```python
len(waiting_tasks) + self.finished_samples_count < data_concurrency + init_finished_samples_count
```
