5 changes: 2 additions & 3 deletions examples/v1/config/rl_qwen3_8B_grpo.py
@@ -32,9 +32,9 @@
# basic settings
experimental_name = "grpo_gsm8k"
total_epochs = 15
global_batch_size = 1024
global_batch_size = 128
prompt_repeat_k = 5
rollout_tp_size = 2
rollout_tp_size = 1
rollout_ep_size = 1
max_prompt_length = 512
max_response_length = 1024
@@ -83,7 +83,6 @@
expert_parallel_size=rollout_ep_size,
gpu_memory_utilization=0.75,
context_length = max_response_length + max_prompt_length,
prompt_repeat_k=prompt_repeat_k,
# rollout_max_batch_size_per_instance=rollout_max_batch_size_per_instance, # optional, will be determined automatically if not set
)

1 change: 1 addition & 0 deletions xtuner/v1/config/fsdp.py
@@ -37,6 +37,7 @@ class FSDPConfig(BaseModel):
hsdp_sharding_size: Annotated[
Optional[int], Parameter(help="Sharding size for HSDP (Hybrid Sharding Data Parallel)")
] = None
enable_autocast: Annotated[bool, Parameter(help="Enable autocast for mixed precision training")] = False

def model_post_init(self, __context: Any) -> None:
if self.hsdp_sharding_size is not None:
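The new `enable_autocast` flag gates a `torch.autocast` region around the model forward (see the `train_engine.py` hunk below) instead of relying on FSDP's `MixedPrecisionPolicy` to downcast parameters. A minimal sketch of the mechanism, independent of xtuner and using CPU autocast purely for illustration (the engine uses `device_type="cuda"`): parameters stay fp32 while the matmul runs in bf16.

```python
import torch
import torch.nn as nn

# Sketch only: master weights remain fp32, compute inside the autocast
# region runs in bf16 -- the behaviour enable_autocast switches on.
layer = nn.Linear(64, 64)          # fp32 parameters
x = torch.randn(2, 64)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)

print(layer.weight.dtype)  # torch.float32 -- parameters untouched
print(y.dtype)             # torch.bfloat16 -- activations computed in bf16
```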
1 change: 1 addition & 0 deletions xtuner/v1/data_proto/rl_data.py
@@ -225,6 +225,7 @@ class SampleParams(BaseModel):
stops: Annotated[List[str], Parameter(help="List of stop sequences.")] = []
stop_token_ids: Annotated[List[int], Parameter(help="List of stop token IDs.")] = []
skip_special_tokens: Annotated[bool, Parameter(help="Whether to skip special tokens.")] = True
sampling_seed: Annotated[int | None, Parameter(help="Random seed for sampling.")] = None


class RolloutExtraParams(TypedDict):
6 changes: 5 additions & 1 deletion xtuner/v1/engine/train_engine.py
@@ -262,7 +262,11 @@ def train_step(self, data_batches: list[ModelItem]):
total_forward_tokens += (num_tokens.sum()) ** 2

if self.intra_layer_micro_batch == 1:
output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
if self.fsdp_cfg.enable_autocast:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
else:
output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
else:
# For intra_layer_micro_batch > 1, we need to handle the data batches differently.
# Here we assume that the model can handle a list of seq_ctx and loss_ctx.
25 changes: 15 additions & 10 deletions xtuner/v1/model/dense/dense.py
@@ -185,9 +185,13 @@ def fully_shard(
self.rotary_emb = self.build_rotary_embedding(self.config)

self._maybe_compile_layers()
mp_policy = MixedPrecisionPolicy(
param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
)
# TODO: this code path has not been rigorously validated for numerical precision
if self.fsdp_config.enable_autocast:
mp_policy = MixedPrecisionPolicy()
else:
mp_policy = MixedPrecisionPolicy(
param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
)
num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio)

generator = torch.Generator()
@@ -201,8 +205,9 @@ def fully_shard(
layer = checkpoint_wrapper(
layer, preserve_rng_state=checkpoint_preserve_rng_state, checkpoint_impl=CheckpointImpl.REENTRANT
)
# __class__ without self attribute
layer.__class__.forward = maybe_compile(layer.__class__.forward, fullgraph=True)
if not self.fsdp_config.enable_autocast:
# __class__ without self attribute
layer.__class__.forward = maybe_compile(layer.__class__.forward, fullgraph=True)

self.layers[str(layer_idx)] = layer
fully_shard(
@@ -252,11 +257,11 @@
)
self.set_modules_to_forward_prefetch([self.embed_tokens, self.layers["0"]]) # type: ignore

for _, module in self.named_modules():
if isinstance(module, nn.Embedding):
module.forward = types.MethodType(self.patched_emb_forward, module) # type: ignore
elif isinstance(module, RMSNorm):
module.forward = types.MethodType(self.patched_rms_norm_forward, module) # type: ignore
# for _, module in self.named_modules():
# if isinstance(module, nn.Embedding):
# module.forward = types.MethodType(self.patched_emb_forward, module) # type: ignore
# elif isinstance(module, RMSNorm):
# module.forward = types.MethodType(self.patched_rms_norm_forward, module) # type: ignore
self._to_empty_meta()

# Make sure it works properly when using fsdp
6 changes: 3 additions & 3 deletions xtuner/v1/module/attention/mha.py
@@ -379,9 +379,9 @@ def forward(
kwargs["s_aux"] = sinks
# [b, n_head, seq, head_dim]
attn_output: torch.Tensor = self.attn_impl_func( # type: ignore
query_states,
key_states,
value_states,
query_states if query_states.dtype == torch.bfloat16 else query_states.to(torch.bfloat16),
key_states if key_states.dtype == torch.bfloat16 else key_states.to(torch.bfloat16),
value_states if value_states.dtype == torch.bfloat16 else value_states.to(torch.bfloat16),
cu_seqlens_q=seq_ctx.cu_seq_lens_q,
cu_seqlens_k=seq_ctx.cu_seq_lens_k,
max_seqlen_q=seq_ctx.max_length_q,
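The q/k/v changes above cast to bf16 only when the tensor is not already bf16, so the common FSDP mixed-precision path avoids an extra copy while the autocast/fp32 path still feeds bf16 into the attention kernel. A small illustrative helper showing the same guard (the name `to_bf16` is ours, not part of the diff):

```python
import torch

def to_bf16(t: torch.Tensor) -> torch.Tensor:
    # No-op (and no copy) if the tensor is already bf16; otherwise cast.
    return t if t.dtype == torch.bfloat16 else t.to(torch.bfloat16)

q = torch.randn(4, 8, dtype=torch.bfloat16)
k = torch.randn(4, 8, dtype=torch.float32)

assert to_bf16(q) is q                      # already bf16: same tensor, no copy
assert to_bf16(k).dtype == torch.bfloat16   # fp32 input: cast before the kernel
```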
2 changes: 1 addition & 1 deletion xtuner/v1/module/lm_head/lm_head.py
@@ -33,7 +33,7 @@ def forward( # type: ignore[override]
b = self.bias
if loss_ctx is None:
logits = F.linear(hidden_states, w, b)
return None, (logits.float(), {})
return None, (logits, {})  # temporarily drop the .float() cast for alignment
else:
return loss_ctx.forward(hidden_states, w, b)

12 changes: 6 additions & 6 deletions xtuner/v1/module/rms_norm/rms_norm.py
@@ -22,13 +22,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
weight = self.weight

# just for align
# input_dtype = hidden_states.dtype
# hidden_states = hidden_states.to(torch.float32)
# variance = hidden_states.pow(2).mean(-1, keepdim=True)
# hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# return (weight * hidden_states).to(input_dtype) # gpt_oss
# return weight * hidden_states.to(input_dtype) # Llama
return rms_norm(hidden_states, weight, epsilon=self.variance_epsilon) # xtuner
return weight * hidden_states.to(input_dtype) # Llama
# return rms_norm(hidden_states, weight, epsilon=self.variance_epsilon) # xtuner

def init_weights(self):
self.weight.data.fill_(1.0)
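The norm now follows the Llama-style reference path: accumulate the variance in fp32, normalize, cast back to the input dtype, then scale by the weight, instead of calling the fused `rms_norm` op. A standalone sketch of that reference computation (pure PyTorch, no xtuner imports), which is what the inference engine is being aligned against:

```python
import torch

def rms_norm_ref(hidden_states: torch.Tensor, weight: torch.Tensor,
                 eps: float = 1e-6) -> torch.Tensor:
    # Llama-style reference: variance accumulated in fp32 for stability,
    # result cast back to the input dtype before applying the weight.
    input_dtype = hidden_states.dtype
    h = hidden_states.to(torch.float32)
    variance = h.pow(2).mean(-1, keepdim=True)
    h = h * torch.rsqrt(variance + eps)
    return weight * h.to(input_dtype)

x = torch.randn(2, 5, 16, dtype=torch.bfloat16)
w = torch.ones(16, dtype=torch.bfloat16)
print(rms_norm_ref(x, w).dtype)  # torch.bfloat16
```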
2 changes: 2 additions & 0 deletions xtuner/v1/ops/rotary_emb.py
@@ -10,6 +10,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# for alignment with the inference engine
@torch.compile(dynamic=True)
def apply_rotary_pos_emb_cuda(
q: torch.Tensor,
k: torch.Tensor,
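`@torch.compile(dynamic=True)` compiles the rotary embedding with symbolic shapes, so varying sequence lengths across rollout batches reuse one compiled graph instead of recompiling, and the fused kernel presumably tracks the inference engine's rope numerics more closely (the stated purpose of the comment). A toy sketch of the decorator on a rotate-half style function, assuming only PyTorch 2.x:

```python
import torch

@torch.compile(dynamic=True)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Same shape-wise helper as in rotary_emb.py: split the last dim
    # in half and rotate.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# dynamic=True: both calls share one compiled graph despite different lengths.
print(rotate_half(torch.randn(1, 128, 64)).shape)
print(rotate_half(torch.randn(1, 257, 64)).shape)
```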
2 changes: 1 addition & 1 deletion xtuner/v1/ray/config/worker.py
@@ -104,7 +104,7 @@ class RolloutConfig(BaseModel):
gpu_memory_utilization: Annotated[
float, Parameter(group=infer_group, help="GPU memory utilization for the rollout worker.")
] = 0.85
random_seed: Annotated[int, Parameter(group=infer_group, help="Random seed for the rollout worker.")] = 1024
random_seed: Annotated[int | None, Parameter(group=infer_group, help="Random seed for the rollout worker.")] = None
# distributed config
rollout_cross_node_comm: Annotated[
bool,
7 changes: 7 additions & 0 deletions xtuner/v1/ray/environment/base_env.py
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any, List
import os

from xtuner.v1.data_proto.rl_data import RLDataFlowItem

@@ -30,6 +31,12 @@ def __init__(
self.rollout_controller = self.init_rollout_controller(rollout_cfg, rollout_pg)
self.judger_controller = self.init_judger_controller(judger_cfg, judger_pg)

self.random_seed = rollout_cfg.random_seed
self.enable_logprob_zero_diff = os.environ.get("XTUNER_ENABLE_LOGPROB_ZERO_DIFF", "0") == "1"
if self.random_seed is None and self.enable_logprob_zero_diff:
print("XTUNER_ENABLE_LOGPROB_ZERO_DIFF is enabled but random_seed is None, defaulting random_seed to 42")
self.random_seed = 42

def init_rollout_controller(self, rollout_cfg: Any, placement_group: Any):
"""Initializes the rollout controller with the appropriate worker
backend.
30 changes: 19 additions & 11 deletions xtuner/v1/ray/environment/single_turn_env.py
@@ -1,7 +1,7 @@
import asyncio
import os
from typing import List

import copy
import ray

from xtuner.v1.data_proto.rl_data import (
@@ -62,18 +62,26 @@ async def generate(
and state from the rollout controller.
"""
if self.rollout_controller:
if self.enable_logprob_zero_diff:
assert self.random_seed is not None, "XTUNER_ENABLE_LOGPROB_ZERO_DIFF is enabled, random_seed must be set"
group_sampling_seeds = [self.random_seed + i for i in range(len(group_data_items))]
else:
group_sampling_seeds = [self.random_seed for _ in range(len(group_data_items))]

# The input data is converted inside the env so that rollout_controller can also be used on its own as a rollout engine, keeping the modules decoupled
# Each module returns its own data item, which is updated in the env
response_future = [
self.rollout_controller.rollout.remote(
prompt=sample.data.messages,
input_ids=sample.data.input_ids,
sample_params=sample_params,
extra_params=extra_params,
extra_info=sample.data.extra_info,
)
for sample in group_data_items
]
response_future = []
for i, sample in enumerate(group_data_items):
_sample_params = copy.deepcopy(sample_params)
_sample_params.sampling_seed = group_sampling_seeds[i]
future = self.rollout_controller.rollout.remote(
prompt=sample.data.messages,
input_ids=sample.data.input_ids,
sample_params=_sample_params,
extra_params=extra_params,
extra_info=sample.data.extra_info,
)
response_future.append(future)
try:
rollout_responses = await asyncio.wait_for(
asyncio.gather(*response_future), timeout=self.rollout_timeout
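With `XTUNER_ENABLE_LOGPROB_ZERO_DIFF` enabled, every sample in a group gets a distinct but reproducible seed (`base + i`); otherwise all samples share the base seed. A small standalone sketch of that seed-derivation logic, with the function name chosen here for illustration:

```python
from typing import List, Optional

def derive_sampling_seeds(base_seed: Optional[int], group_size: int,
                          zero_diff: bool) -> List[Optional[int]]:
    """Mirror of the seed logic in SingleTurnEnv.generate (illustrative only)."""
    if zero_diff:
        # Deterministic inference: each rollout in the group needs its own
        # seed so the k samples differ, yet a rerun reproduces them exactly.
        assert base_seed is not None, "zero-diff mode requires a base seed"
        return [base_seed + i for i in range(group_size)]
    # Default mode: the shared seed (possibly None) is passed through as-is.
    return [base_seed] * group_size

print(derive_sampling_seeds(42, 4, zero_diff=True))   # [42, 43, 44, 45]
print(derive_sampling_seeds(42, 4, zero_diff=False))  # [42, 42, 42, 42]
```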
8 changes: 7 additions & 1 deletion xtuner/v1/ray/rollout/sglang.py
@@ -179,9 +179,14 @@ def _transform_rollout_config_to_server_configs(self):
sglang_server_args.max_running_requests = self.config.rollout_max_batch_size_per_instance
sglang_server_args.log_level = log_level
sglang_server_args.log_level_http = log_level_http
sglang_server_args.enable_deterministic_inference = enable_deterministic_inference
sglang_server_args.tp_size = num_gpus_per_engine
sglang_server_args.ep_size = num_gpus_per_engine
sglang_server_args.enable_deterministic_inference = enable_deterministic_inference

if self.enable_logprob_zero_diff:
sglang_server_args.enable_deterministic_inference = True
sglang_server_args.rl_on_policy_target = 'fsdp'
sglang_server_args.attention_backend = 'fa3'

if grammar_backend is not None:
sglang_server_args.grammar_backend = grammar_backend
@@ -212,6 +217,7 @@ def _transform_sample_params(self, sample_params: Dict):
"stop": sample_params["stops"],
"stop_token_ids": sample_params["stop_token_ids"],
"skip_special_tokens": sample_params["skip_special_tokens"],
"sampling_seed": sample_params["sampling_seed"],
}
return sglang_sample_params

1 change: 1 addition & 0 deletions xtuner/v1/ray/rollout/worker.py
@@ -76,6 +76,7 @@ def __init__(
self.enable_return_routed_experts = self.config.enable_return_routed_experts
if self.rank == 0:
self.logger.info(f"RolloutConfig:\n{self.config.model_dump_json(indent=2)}")
self.enable_logprob_zero_diff = os.environ.get("XTUNER_ENABLE_LOGPROB_ZERO_DIFF", "0") == "1"

def init_dist_port(self):
"""Initialize distributed communication ports.
33 changes: 26 additions & 7 deletions xtuner/v1/rl/base/worker.py
@@ -147,6 +147,15 @@ def __init__(
super().__init__(worker_cfg, rank, master_addr, master_port, world_size, accelerator)
self.config = cast(WorkerConfig, self.config)
torch.accelerator.set_device_index(int(os.environ["LOCAL_RANK"]))

if self.config.fsdp_cfg.enable_autocast:
from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode, disable_batch_invariant_mode
print("FSDPTrainRayActor call enable_batch_invariant_mode for true-on-policy")
enable_batch_invariant_mode(
# In Qwen3, rope `inv_freq_expanded.float() @ position_ids_expanded.float()` uses bmm
# and disabling it will make it aligned
enable_bmm=False,
)
self._engine = self._build_engine(worker_cfg)

self._has_ref = False
@@ -213,7 +222,8 @@ def _build_ref_model(
if isinstance(ref_model_cfg, VisionComposeConfigProtocol):
assert ref_model_cfg.text_config.float8_cfg is None, "VisionComposeConfigProtocol does not support float8"
if ref_model_fsdp_cfg is None:
ref_model_fsdp_cfg = FSDPConfig(recompute_ratio=0, cpu_offload=False, requires_grad=False)
ref_model_fsdp_cfg = FSDPConfig(recompute_ratio=0, cpu_offload=False, requires_grad=False,
enable_autocast=self.config.fsdp_cfg.enable_autocast)
model.language_model.fully_shard(ref_model_fsdp_cfg) # type: ignore
model.vision_tower.fully_shard(ref_model_fsdp_cfg) # type: ignore
model.multi_modal_projector.fully_shard(ref_model_fsdp_cfg) # type: ignore
@@ -230,7 +240,8 @@
else:
float8_handler = None
if ref_model_fsdp_cfg is None:
ref_model_fsdp_cfg = FSDPConfig(recompute_ratio=0, cpu_offload=False, requires_grad=False)
ref_model_fsdp_cfg = FSDPConfig(recompute_ratio=0, cpu_offload=False, requires_grad=False,
enable_autocast=self.config.fsdp_cfg.enable_autocast)
model = model.fully_shard(ref_model_fsdp_cfg, float8_handler) # type: ignore

model.from_hf(hf_path=load_from)
@@ -268,7 +279,11 @@ def compute_actor_logprobs(
self, seq_ctx_list: list[SequenceContext], loss_ctx_input_list: list[RLLossContextInputItem]
) -> list[RLLossContextInputItem]:
for seq_ctx, loss_ctx_input in zip(seq_ctx_list, loss_ctx_input_list):
output = self._engine.forward_only(seq_ctx=seq_ctx)
if self.config.fsdp_cfg.enable_autocast:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
output = self._engine.forward_only(seq_ctx=seq_ctx)
else:
output = self._engine.forward_only(seq_ctx=seq_ctx)
loss_ctx_input.old_logprobs = gather_logprobs(output["logits"], loss_ctx_input.shifted_labels)
return loss_ctx_input_list

@@ -279,7 +294,11 @@ def compute_ref_logprobs(
self._ref_model.to_device(DEVICE)
for seq_ctx, loss_ctx_input in zip(seq_ctx_list, loss_ctx_input_list):
with torch.no_grad():
ref_output = self._ref_model(seq_ctx=seq_ctx, loss_ctx=None)
if self.config.fsdp_cfg.enable_autocast:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
ref_output = self._ref_model(seq_ctx=seq_ctx, loss_ctx=None)
else:
ref_output = self._ref_model(seq_ctx=seq_ctx, loss_ctx=None)
ref_logprobs = gather_logprobs(ref_output["logits"], loss_ctx_input.shifted_labels)
loss_ctx_input.ref_logprobs = ref_logprobs
self._ref_model.to_device("cpu")
@@ -402,7 +421,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
else:
min_diff = torch.min(rollout_logprobs - old_logprobs)
max_diff = torch.max(rollout_logprobs - old_logprobs)
mean_diff = torch.mean(rollout_logprobs - old_logprobs)
mean_diff = torch.sum(rollout_logprobs - old_logprobs) / len(old_logprobs)
if rollout_logprobs.numel() == 1:
std_diff = torch.tensor(0.0)
else:
@@ -600,7 +619,7 @@ def get_params(tensor_list, name_list, save_dtype):
else:
saved_list.append(f"layers.{i}.{sub_name}")
local_tensor = param._local_tensor if isinstance(param, DTensor) else param
local_tensor = local_tensor.bfloat16()
# local_tensor = local_tensor.bfloat16()
load_spec = language_model.load_spec_mapping.get(f"layers.{i}.{sub_name}")

if isinstance(model.config, VisionComposeConfigProtocol):
@@ -628,7 +647,7 @@
if name in saved_list:
continue
local_tensor = param._local_tensor if isinstance(param, DTensor) else param
local_tensor = local_tensor.bfloat16()
# local_tensor = local_tensor.bfloat16()
load_spec = model.load_spec_mapping.get(name)

if isinstance(model.config, VisionComposeConfigProtocol):
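The diff statistics between rollout logprobs and recomputed training logprobs are how zero-diff is verified in `fit`; note the mean is now `sum / len(old_logprobs)`, which only equals `torch.mean` when the tensor is one-dimensional. A standalone sketch of those statistics, including the single-element guard:

```python
import torch

def logprob_diff_stats(rollout_logprobs: torch.Tensor, old_logprobs: torch.Tensor):
    # Same quantities as in fit(): min/max/mean/std of the elementwise diff.
    diff = rollout_logprobs - old_logprobs
    min_diff = torch.min(diff)
    max_diff = torch.max(diff)
    mean_diff = torch.sum(diff) / len(old_logprobs)  # equals torch.mean only for 1-D tensors
    # torch.std of a single element is NaN, hence the explicit guard.
    std_diff = torch.tensor(0.0) if rollout_logprobs.numel() == 1 else torch.std(diff)
    return min_diff, max_diff, mean_diff, std_diff

a = torch.tensor([-1.20, -0.35, -2.10])
print(logprob_diff_stats(a, a))  # all zeros: rollout and training agree exactly
```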
13 changes: 11 additions & 2 deletions xtuner/v1/train/rl_trainer.py
@@ -268,6 +268,15 @@ def __init__(
train_worker_cfg.log_dir = log_dir
dataflow_config.worker_log_dir = log_dir
rollout_config.worker_log_dir = log_dir

if rollout_config.random_seed is None:
print(f'rollout_config.random_seed is None, set to {seed}')
rollout_config.random_seed = seed

if os.environ.get("XTUNER_ENABLE_LOGPROB_ZERO_DIFF", "0") == "1":
print("XTUNER_ENABLE_LOGPROB_ZERO_DIFF is enabled, the rollout/training logprob diff is expected to be zero")
train_worker_cfg.fsdp_cfg.enable_autocast = True

self._enable_evaluate = False
self._enable_initial_evaluate = False
if evaluator_config:
@@ -385,7 +394,7 @@ def fit(self):
"""
self.logger.info("Start RL training")
if self._enable_initial_evaluate and self._enable_evaluate and self._evaluator:
ray.get(self._rollout_env_controller.check_active_workers.remote())
# ray.get(self._rollout_env_controller.check_active_workers.remote())
scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True))
trajectory_save_path = self.exp_dir / "eval_0_trajectory.jsonl"
self._save_trajectories(eval_data_groups, trajectory_save_path)
@@ -395,7 +404,7 @@
step_timer_dict = {}
# 1. Rollout
with timer("generation", step_timer_dict):
ray.get(self._rollout_env_controller.check_active_workers.remote())
# ray.get(self._rollout_env_controller.check_active_workers.remote())
data_groups, multimodal_train_infos = ray.get(self._rollout_dataflow.run.remote())
# 2. Offload rollout models and save trajectories
with timer("offload_and_dump", step_timer_dict):
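Putting the pieces together, zero-diff mode is driven by one environment variable, which the trainer then propagates into the FSDP config, the rollout backend, and the per-sample seeds. A hedged sketch of how a launch script might opt in, summarizing only the switches that appear in this diff:

```python
import os

# Must be set before the trainer and the ray workers are created.
os.environ["XTUNER_ENABLE_LOGPROB_ZERO_DIFF"] = "1"

# Downstream effects wired up by this PR (see the hunks above):
#   * RLTrainer sets train_worker_cfg.fsdp_cfg.enable_autocast = True
#   * sglang workers get enable_deterministic_inference=True,
#     rl_on_policy_target='fsdp', attention_backend='fa3'
#   * each rollout sample in a group receives sampling_seed = random_seed + i
#   * a None rollout random_seed falls back to the trainer seed (or 42 in the env)
```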