[Refactor] refactor packing in RL train controller and train worker #1393
Conversation
Pull request overview
This PR refactors the packing logic in the RL training controller and worker components to improve token balancing and code organization. The key changes introduce a Karmarkar-Karp algorithm for balanced partitioning, extract helper methods for better code maintainability, and restructure how data batches are distributed across workers.
Key Changes
- Introduces sequence-length balanced partitioning using the Karmarkar-Karp differencing algorithm to better distribute workload across devices
- Refactors the worker's `fit` method to accept a nested list structure `list[list[WorkerInputItem]]` instead of a flat list, aligning with the new per-step packing approach
- Extracts reusable helper methods (`_resolve_ray_data`, `_apply_rollout_is_correction`, `_create_padding_sample`, `_pack`, `_balance_split_batch`) to reduce code duplication and improve maintainability
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| xtuner/v1/rl/utils.py | Adds Karmarkar-Karp algorithm implementation with get_seqlen_balanced_partitions function for balanced workload distribution across partitions |
| xtuner/v1/rl/base/worker.py | Refactors fit method to handle nested batch structure, extracts ray data resolution and importance sampling logic into separate methods, adds get_worker_cfg accessor method |
| xtuner/v1/rl/base/controller.py | Major refactoring of packing logic with new balanced splitting, padding creation, and improved data distribution across workers with per-step gradient accumulation support |
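For reference, the Karmarkar-Karp (largest differencing) method mentioned above can be sketched as follows. This is a minimal, illustrative reimplementation of k-way partitioning by sequence length, not the code added in this PR; xtuner's `karmarkar_karp` is adapted from verl and additionally supports an `equal_size` constraint, which is omitted here.

```python
import heapq


def balanced_partitions(seqlens: list[int], k: int) -> list[list[int]]:
    """Return k lists of indices whose summed sequence lengths are close."""
    if not seqlens:
        return [[] for _ in range(k)]

    heap = []
    for idx, seqlen in enumerate(seqlens):
        # Each heap entry holds k sub-partitions as (token_sum, indices) pairs;
        # initially one sub-partition holds this sample and the rest are empty.
        parts = [(seqlen, [idx])] + [(0, []) for _ in range(k - 1)]
        sums = [s for s, _ in parts]
        heapq.heappush(heap, (-(max(sums) - min(sums)), idx, parts))

    while len(heap) > 1:
        # Pop the two entries with the largest internal spread and merge them,
        # pairing the largest sub-partitions of one with the smallest of the other.
        _, tie_a, a = heapq.heappop(heap)
        _, tie_b, b = heapq.heappop(heap)
        a = sorted(a, key=lambda p: p[0], reverse=True)
        b = sorted(b, key=lambda p: p[0])
        merged = [(sa + sb, ia + ib) for (sa, ia), (sb, ib) in zip(a, b)]
        sums = [s for s, _ in merged]
        heapq.heappush(heap, (-(max(sums) - min(sums)), min(tie_a, tie_b), merged))

    _, _, parts = heap[0]
    return [sorted(indices) for _, indices in parts]


# Example: split 8 sequences over 4 partitions with roughly equal token totals.
print(balanced_partitions([512, 4096, 1024, 2048, 256, 8192, 128, 3072], k=4))
```

Always merging the two entries with the largest spread, largest-against-smallest, keeps the final per-partition token sums close, which is the property the controller needs to balance workload across DP ranks.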
xtuner/v1/rl/utils.py

```python
# Adapted from https://github.com/volcengine/verl/blob/main/verl/utils/seqlen_balancing.py
def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool):
    # see: https://en.wikipedia.org/wiki/Largest_differencing_method
    class Set:
```
Copilot AI · Dec 24, 2025
This class implements `__lt__`, but does not implement `__le__` or `__ge__`.
```python
                return len(self.items) < len(other.items)
            return self.items < other.items

    class State:
```
Copilot AI · Dec 24, 2025
This class implements `__lt__`, but does not implement `__le__` or `__ge__`.
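As a small illustration of these two comments (a simplified stand-in class, not the PR's actual code): `heapq`, which this kind of differencing implementation typically relies on, only needs `__lt__`, so the missing `__le__`/`__ge__` are harmless for heap use; if full comparison semantics are wanted, `functools.total_ordering` can derive them from `__eq__` plus `__lt__`.

```python
import functools
import heapq


@functools.total_ordering
class Set:
    """Simplified stand-in for the nested helper class flagged above."""

    def __init__(self, items: list[int]):
        self.items = sorted(items)
        self.total = sum(items)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Set) and (self.total, self.items) == (other.total, other.items)

    def __lt__(self, other: "Set") -> bool:
        if self.total != other.total:
            return self.total < other.total
        if len(self.items) != len(other.items):
            return len(self.items) < len(other.items)
        return self.items < other.items


a, b = Set([3, 5]), Set([10])
heap = [b, a]
heapq.heapify(heap)           # heapq relies only on __lt__
print(a < b, a <= b, a >= b)  # total_ordering supplies <= and >=: True True False
```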
xtuner/v1/rl/base/controller.py

```python
    rollout_logprobs: torch.Tensor | None


class RawTrainingController:
```
Core functions need unit tests. While adding unit tests, you will also end up adjusting the function interfaces to make them easier to test, which in turn makes the interface design more reasonable.
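A minimal sketch of the kind of unit test being asked for, assuming `get_seqlen_balanced_partitions(seqlen_list, k_partitions, equal_size)` returns `k_partitions` lists of sample indices (the signature follows the verl function this code is adapted from; adjust if the xtuner version differs):

```python
import pytest

from xtuner.v1.rl.utils import get_seqlen_balanced_partitions


@pytest.mark.parametrize("k_partitions", [2, 4])
def test_seqlen_balanced_partitions(k_partitions):
    seqlens = [512, 4096, 1024, 2048, 256, 8192, 128, 3072]
    parts = get_seqlen_balanced_partitions(seqlens, k_partitions, equal_size=False)

    # Every sample index appears in exactly one partition.
    assert len(parts) == k_partitions
    flat = sorted(i for part in parts for i in part)
    assert flat == list(range(len(seqlens)))

    # Loose sanity bound on balance: the spread between the heaviest and
    # lightest partition should not exceed the longest single sequence.
    sums = [sum(seqlens[i] for i in part) for part in parts]
    assert max(sums) - min(sums) <= max(seqlens)
```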
xtuner/v1/rl/base/controller.py (Outdated)
```python
        get_logger().info(f"default split into {dp_size} partitions with tokens: {tokens_in_partition}")

        packed_data_batches: list[list[list[dict]]] = [[[] for _ in range(optimizer_steps)] for _ in range(dp_size)]
        max_packs_per_card = [0] * optimizer_steps
```
Rename to `max_packed_batch_num_per_step`.
`max_packs_per_step` is more accurate: the maximum number of packs per step.
```python
        # old logprobs are updated in place in compute_actor_logprobs
        loss_ctx_input_list = self.compute_actor_logprobs(seq_ctx_list, loss_ctx_input_list)
        loss_ctx_input_list, metrics = self._apply_rollout_is_correction(
```
Great! The previously very long `fit` function is now better structured and much easier to read.
xtuner/v1/rl/base/controller.py (Outdated)
```python
            n_routed_experts=n_routed_experts,
        )
        padding_samples = [padding_sample for _ in range(num_padding_packs)]
        packed_data_batches[dp_rank][step_idx].extend(padding_samples)
```
You could add an overall timing record for data packing and related processing. Reason: previously part of the data processing ran on the single-node TrainController and part on the multi-node Workers; now it all runs on the single-node Controller, which could be slower. Data processing is usually simple, so it should not slow down much, but adding this monitoring makes it easy to observe and adjust in time.
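A possible shape for that timing hook (illustrative only; `log_elapsed` and the wrapped calls are not part of this PR, and the controller is assumed to already have `get_logger()` available):

```python
import time
from contextlib import contextmanager


@contextmanager
def log_elapsed(logger, name: str):
    """Log how long the wrapped block took, e.g. the controller-side pack/pad."""
    start = time.perf_counter()
    try:
        yield
    finally:
        logger.info(f"{name} took {time.perf_counter() - start:.2f}s")


# Hypothetical usage inside RawTrainingController.fit():
# with log_elapsed(get_logger(), "controller-side pack/pad"):
#     batches_per_dp_group = self._balance_split_batch(data_batches, dp_size)
#     ...
```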
Great PR description! It lets me review in the same order as described above :) Additionally, when you write out the core function call chain corresponding to your original design in the "Key Changes", you will find that some high-level functions are missing from your implementation. Unit tests can play the same role sometimes: for example, if you want to write a unit test for the core padding function, you need to abstract the related code pieces into a function of their own.
Motivation
The current xtuner data distribution mechanism has a pack allocation issue that leads to an unstable number of training steps and hurts training effectiveness.
The data distribution pipeline consists of three stages:
1. Pack `data_batch` by token count, creating one pack per 32K tokens, resulting in N packs
2. Distribute the N packs across the M workers, giving each worker N/M packs
3. Split each worker's packs according to the `optimizer_step` parameter

When `N/M` is not divisible by `optimizer_step`, the actual number of training steps fails to match the expected value.
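As an illustration (hypothetical numbers, not taken from the original description): with N = 10 packs and M = 2 workers, each worker holds N/M = 5 packs; since 5 is not divisible by `optimizer_step = 4`, the packs cannot be split evenly into 4 micro-batches per optimizer step, and the realized number of gradient-accumulation/optimizer steps drifts from the configured value.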
Key Changes
1. Token-aware Pre-allocation
In `RawTrainingController.fit()` (`controller.py`), samples are evenly distributed across the M workers and further split into `optimizer_step` buckets per worker, based on token count. This ensures balanced token distribution across all workers and steps:

```python
batches_per_dp_group = self._balance_split_batch(data_batches, dp_size)
mini_batch_for_steps = self._balance_split_batch(dp_worker_data_batches, optimizer_steps)
```
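A plausible sketch of how such a helper can sit on top of `get_seqlen_balanced_partitions` (the actual `_balance_split_batch` in `controller.py` may differ, e.g. in how it reads each sample's token count; `input_ids` is an assumption here):

```python
from xtuner.v1.rl.utils import get_seqlen_balanced_partitions


def balance_split_batch(data_batches: list[dict], num_groups: int) -> list[list[dict]]:
    """Split samples into num_groups groups with roughly equal token totals."""
    # Assumption: each sample's token count is the length of its "input_ids".
    seqlens = [len(sample["input_ids"]) for sample in data_batches]
    partitions = get_seqlen_balanced_partitions(seqlens, num_groups, equal_size=False)
    return [[data_batches[i] for i in indices] for indices in partitions]
```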
2. Pack & Pad per Bucket
Within each pre-allocated bucket, data is packed and padded so that no pack exceeds `pack_max_length`. Padding is applied where necessary, and the number of packs per step is aligned across all workers:

```python
batch4pack_list = self._rearrange_batch_for_pack(step_mini_batch, pack_max_length)
step_pack = self._pad_and_pack_batches(batch4pack, pack_max_length)
self._pad_to_max_packs_across_workes(packed_data_batches, step_idx, max_packs, pack_max_length)
```
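The pack step itself can be pictured as greedy grouping under `pack_max_length`; a much-simplified sketch (the real `_rearrange_batch_for_pack` / `_pad_and_pack_batches` additionally create padding samples and align pack counts across workers):

```python
def rearrange_for_pack(samples: list[dict], pack_max_length: int) -> list[list[dict]]:
    """Greedily group samples so each pack's total token count stays within pack_max_length."""
    packs: list[list[dict]] = []
    current: list[dict] = []
    current_len = 0
    # Longest-first keeps packs close to pack_max_length and reduces padding.
    for sample in sorted(samples, key=lambda s: len(s["input_ids"]), reverse=True):
        seqlen = len(sample["input_ids"])
        if current and current_len + seqlen > pack_max_length:
            packs.append(current)
            current, current_len = [], 0
        current.append(sample)
        current_len += seqlen
    if current:
        packs.append(current)
    return packs
```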
3. Worker-side Training
In `TrainingWorker.fit()` (`worker.py`), each worker processes its assigned data, including sequence context resolution, logprobs computation, importance sampling correction, and the actual training step:

```python
seq_ctx = self._resolve_ray_data(data["seq_ctx"], language_cfg)
self.compute_actor_logprobs()
self._apply_rollout_is_correction()
train_step()
```