diff --git a/examples/v1/config/rl_qwen3_8B_grpo.py b/examples/v1/config/rl_qwen3_8B_grpo.py
index fd9227b85..b51aaa0ec 100644
--- a/examples/v1/config/rl_qwen3_8B_grpo.py
+++ b/examples/v1/config/rl_qwen3_8B_grpo.py
@@ -32,9 +32,9 @@
 # basic settings
 experimental_name = "grpo_gsm8k"
 total_epochs = 15
-global_batch_size = 1024
+global_batch_size = 128
 prompt_repeat_k = 5
-rollout_tp_size = 2
+rollout_tp_size = 1
 rollout_ep_size = 1
 max_prompt_length = 512
 max_response_length = 1024
@@ -83,7 +83,6 @@
     expert_parallel_size=rollout_ep_size,
     gpu_memory_utilization=0.75,
     context_length = max_response_length + max_prompt_length,
-    prompt_repeat_k=prompt_repeat_k,
     # rollout_max_batch_size_per_instance=rollout_max_batch_size_per_instance,  # optional, will be determined automatically if not set
 )
diff --git a/xtuner/v1/config/fsdp.py b/xtuner/v1/config/fsdp.py
index 6d855fa71..6dc1d1afe 100644
--- a/xtuner/v1/config/fsdp.py
+++ b/xtuner/v1/config/fsdp.py
@@ -37,6 +37,7 @@ class FSDPConfig(BaseModel):
     hsdp_sharding_size: Annotated[
         Optional[int], Parameter(help="Sharding size for HSDP (Hybrid Sharding Data Parallel)")
     ] = None
+    enable_autocast: Annotated[bool, Parameter(help="Enable autocast for mixed precision training")] = False

     def model_post_init(self, __context: Any) -> None:
         if self.hsdp_sharding_size is not None:
diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py
index ca3318ea0..572af5ca6 100644
--- a/xtuner/v1/data_proto/rl_data.py
+++ b/xtuner/v1/data_proto/rl_data.py
@@ -225,6 +225,7 @@ class SampleParams(BaseModel):
     stops: Annotated[List[str], Parameter(help="List of stop sequences.")] = []
     stop_token_ids: Annotated[List[int], Parameter(help="List of stop token IDs.")] = []
     skip_special_tokens: Annotated[bool, Parameter(help="Whether to skip special tokens.")] = True
+    sampling_seed: Annotated[int | None, Parameter(help="Random seed for sampling.")] = None


 class RolloutExtraParams(TypedDict):
diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py
index 4b3e99873..864be660b 100644
--- a/xtuner/v1/engine/train_engine.py
+++ b/xtuner/v1/engine/train_engine.py
@@ -262,7 +262,11 @@ def train_step(self, data_batches: list[ModelItem]):
             total_forward_tokens += (num_tokens.sum()) ** 2

         if self.intra_layer_micro_batch == 1:
-            output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
+            if self.fsdp_cfg.enable_autocast:
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                    output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
+            else:
+                output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
         else:
             # For intra_layer_micro_batch > 1, we need to handle the data batches differently.
             # Here we assume that the model can handle a list of seq_ctx and loss_ctx.
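Note: the `enable_autocast` guard above is repeated around several forward calls later in this patch (`train_step`, `compute_actor_logprobs`, `compute_ref_logprobs`). A minimal sketch of a wrapper that would collapse each call site to a single `with` statement; `maybe_autocast` is a hypothetical helper, not part of this patch:

```python
import contextlib

import torch


def maybe_autocast(enabled: bool, dtype: torch.dtype = torch.bfloat16):
    """Return a CUDA autocast context when enabled, otherwise a no-op context."""
    if enabled:
        return torch.autocast(device_type="cuda", dtype=dtype)
    return contextlib.nullcontext()


# usage sketch (assumes the attribute names used in this patch):
# with maybe_autocast(self.fsdp_cfg.enable_autocast):
#     output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
```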
diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py
index cdb8235d7..ee4ed54ce 100644
--- a/xtuner/v1/model/dense/dense.py
+++ b/xtuner/v1/model/dense/dense.py
@@ -185,9 +185,13 @@ def fully_shard(
         self.rotary_emb = self.build_rotary_embedding(self.config)
         self._maybe_compile_layers()

-        mp_policy = MixedPrecisionPolicy(
-            param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
-        )
+        # TODO: this code path has not been rigorously validated for numerical accuracy
+        if self.fsdp_config.enable_autocast:
+            mp_policy = MixedPrecisionPolicy()
+        else:
+            mp_policy = MixedPrecisionPolicy(
+                param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
+            )
         num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio)

         generator = torch.Generator()
@@ -201,8 +205,9 @@ def fully_shard(
                 layer = checkpoint_wrapper(
                     layer, preserve_rng_state=checkpoint_preserve_rng_state, checkpoint_impl=CheckpointImpl.REENTRANT
                 )
-            # __class__ without self attribute
-            layer.__class__.forward = maybe_compile(layer.__class__.forward, fullgraph=True)
+            if not self.fsdp_config.enable_autocast:
+                # __class__ without self attribute
+                layer.__class__.forward = maybe_compile(layer.__class__.forward, fullgraph=True)
             self.layers[str(layer_idx)] = layer

             fully_shard(
@@ -252,11 +257,11 @@ def fully_shard(
         )
         self.set_modules_to_forward_prefetch([self.embed_tokens, self.layers["0"]])  # type: ignore

-        for _, module in self.named_modules():
-            if isinstance(module, nn.Embedding):
-                module.forward = types.MethodType(self.patched_emb_forward, module)  # type: ignore
-            elif isinstance(module, RMSNorm):
-                module.forward = types.MethodType(self.patched_rms_norm_forward, module)  # type: ignore
+        # for _, module in self.named_modules():
+        #     if isinstance(module, nn.Embedding):
+        #         module.forward = types.MethodType(self.patched_emb_forward, module)  # type: ignore
+        #     elif isinstance(module, RMSNorm):
+        #         module.forward = types.MethodType(self.patched_rms_norm_forward, module)  # type: ignore

         self._to_empty_meta()
         # Make sure it works properly when using fsdp
diff --git a/xtuner/v1/module/attention/mha.py b/xtuner/v1/module/attention/mha.py
index fd4d18565..edce68c06 100644
--- a/xtuner/v1/module/attention/mha.py
+++ b/xtuner/v1/module/attention/mha.py
@@ -379,9 +379,9 @@ def forward(
             kwargs["s_aux"] = sinks
         # [b, n_head, seq, head_dim]
         attn_output: torch.Tensor = self.attn_impl_func(  # type: ignore
-            query_states,
-            key_states,
-            value_states,
+            query_states if query_states.dtype == torch.bfloat16 else query_states.to(torch.bfloat16),
+            key_states if key_states.dtype == torch.bfloat16 else key_states.to(torch.bfloat16),
+            value_states if value_states.dtype == torch.bfloat16 else value_states.to(torch.bfloat16),
            cu_seqlens_q=seq_ctx.cu_seq_lens_q,
            cu_seqlens_k=seq_ctx.cu_seq_lens_k,
            max_seqlen_q=seq_ctx.max_length_q,
diff --git a/xtuner/v1/module/lm_head/lm_head.py b/xtuner/v1/module/lm_head/lm_head.py
index 417a890aa..e1f741853 100644
--- a/xtuner/v1/module/lm_head/lm_head.py
+++ b/xtuner/v1/module/lm_head/lm_head.py
@@ -33,7 +33,7 @@ def forward(  # type: ignore[override]
         b = self.bias
         if loss_ctx is None:
             logits = F.linear(hidden_states, w, b)
-            return None, (logits.float(), {})
+            return None, (logits, {})  # .float() removed temporarily for alignment
         else:
             return loss_ctx.forward(hidden_states, w, b)
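A side note on the conditional casts in `mha.py` above: `Tensor.to()` already returns the input tensor unchanged when the dtype matches, so the `dtype == torch.bfloat16` check is redundant. A small illustration (not part of the patch):

```python
import torch


def to_bf16(x: torch.Tensor) -> torch.Tensor:
    # Tensor.to() is a no-op (returns x itself) when the dtype already matches,
    # so this is equivalent to the conditional expression used in the diff.
    return x.to(torch.bfloat16)


x = torch.randn(4, dtype=torch.bfloat16)
assert to_bf16(x) is x  # no copy is made for an already-bf16 tensor
```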
diff --git a/xtuner/v1/module/rms_norm/rms_norm.py b/xtuner/v1/module/rms_norm/rms_norm.py
index d2238ad24..5b591f145 100644
--- a/xtuner/v1/module/rms_norm/rms_norm.py
+++ b/xtuner/v1/module/rms_norm/rms_norm.py
@@ -22,13 +22,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         weight = self.weight

         # just for align
-        # input_dtype = hidden_states.dtype
-        # hidden_states = hidden_states.to(torch.float32)
-        # variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         # return (weight * hidden_states).to(input_dtype)  # gpt_oss
-        # return weight * hidden_states.to(input_dtype)  # Llama
-        return rms_norm(hidden_states, weight, epsilon=self.variance_epsilon)  # xtuner
+        return weight * hidden_states.to(input_dtype)  # Llama
+        # return rms_norm(hidden_states, weight, epsilon=self.variance_epsilon)  # xtuner

     def init_weights(self):
         self.weight.data.fill_(1.0)
diff --git a/xtuner/v1/ops/rotary_emb.py b/xtuner/v1/ops/rotary_emb.py
index bcd975dec..41ffb85bc 100644
--- a/xtuner/v1/ops/rotary_emb.py
+++ b/xtuner/v1/ops/rotary_emb.py
@@ -10,6 +10,8 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+# to align with the inference engine
+@torch.compile(dynamic=True)
 def apply_rotary_pos_emb_cuda(
     q: torch.Tensor,
     k: torch.Tensor,
diff --git a/xtuner/v1/ray/config/worker.py b/xtuner/v1/ray/config/worker.py
index ba91d5ca5..6a3e6ce78 100644
--- a/xtuner/v1/ray/config/worker.py
+++ b/xtuner/v1/ray/config/worker.py
@@ -104,7 +104,7 @@ class RolloutConfig(BaseModel):
     gpu_memory_utilization: Annotated[
         float, Parameter(group=infer_group, help="GPU memory utilization for the rollout worker.")
     ] = 0.85
-    random_seed: Annotated[int, Parameter(group=infer_group, help="Random seed for the rollout worker.")] = 1024
+    random_seed: Annotated[int | None, Parameter(group=infer_group, help="Random seed for the rollout worker.")] = None
     # distributed config
     rollout_cross_node_comm: Annotated[
         bool,
diff --git a/xtuner/v1/ray/environment/base_env.py b/xtuner/v1/ray/environment/base_env.py
index c23607549..d99a23664 100644
--- a/xtuner/v1/ray/environment/base_env.py
+++ b/xtuner/v1/ray/environment/base_env.py
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from typing import Any, List
+import os

 from xtuner.v1.data_proto.rl_data import RLDataFlowItem

@@ -30,6 +31,12 @@ def __init__(
         self.rollout_controller = self.init_rollout_controller(rollout_cfg, rollout_pg)
         self.judger_controller = self.init_judger_controller(judger_cfg, judger_pg)

+        self.random_seed = rollout_cfg.random_seed
+        self.enable_logprob_zero_diff = os.environ.get("XTUNER_ENABLE_LOGPROB_ZERO_DIFF", "0") == "1"
+        if self.random_seed is None and self.enable_logprob_zero_diff:
+            print("XTUNER_ENABLE_LOGPROB_ZERO_DIFF is enabled but random_seed is None, falling back to random_seed=42")
+            self.random_seed = 42
+
     def init_rollout_controller(self, rollout_cfg: Any, placement_group: Any):
         """Initializes the rollout controller with the appropriate worker backend.
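The `XTUNER_ENABLE_LOGPROB_ZERO_DIFF` check above is repeated verbatim in `base_env.py`, `rollout/worker.py`, and `rl_trainer.py`. A minimal sketch of a shared helper that would keep the flag name in one place (the helper name is an assumption, not part of the patch):

```python
import os


def logprob_zero_diff_enabled() -> bool:
    """True when XTUNER_ENABLE_LOGPROB_ZERO_DIFF=1 is set in the environment."""
    return os.environ.get("XTUNER_ENABLE_LOGPROB_ZERO_DIFF", "0") == "1"
```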
""" if self.rollout_controller: + if self.enable_logprob_zero_diff: + assert self.random_seed is not None, "XTUNER_ENABLE_LOGPROB_ZERO_DIFF is enabled, random_seed must be set" + group_sampling_seeds = [self.random_seed + i for i in range(len(group_data_items))] + else: + group_sampling_seeds = [self.random_seed for _ in range(len(group_data_items))] + # 在env中对输入的数据进行转换,是为了支持rollout_controller单独作为rollout engine使用,使各个模块进行解耦 # 每个模块返回独立的data item, 在env中进行更新 - response_future = [ - self.rollout_controller.rollout.remote( - prompt=sample.data.messages, - input_ids=sample.data.input_ids, - sample_params=sample_params, - extra_params=extra_params, - extra_info=sample.data.extra_info, - ) - for sample in group_data_items - ] + response_future = [] + for i, sample in enumerate(group_data_items): + _sample_params = copy.deepcopy(sample_params) + _sample_params.sampling_seed = group_sampling_seeds[i] + future = self.rollout_controller.rollout.remote( + prompt=sample.data.messages, + input_ids=sample.data.input_ids, + sample_params=_sample_params, + extra_params=extra_params, + extra_info=sample.data.extra_info, + ) + response_future.append(future) try: rollout_responses = await asyncio.wait_for( asyncio.gather(*response_future), timeout=self.rollout_timeout diff --git a/xtuner/v1/ray/rollout/sglang.py b/xtuner/v1/ray/rollout/sglang.py index 611a2cb54..2620458e4 100644 --- a/xtuner/v1/ray/rollout/sglang.py +++ b/xtuner/v1/ray/rollout/sglang.py @@ -179,9 +179,14 @@ def _transform_rollout_config_to_server_configs(self): sglang_server_args.max_running_requests = self.config.rollout_max_batch_size_per_instance sglang_server_args.log_level = log_level sglang_server_args.log_level_http = log_level_http - sglang_server_args.enable_deterministic_inference = enable_deterministic_inference sglang_server_args.tp_size = num_gpus_per_engine sglang_server_args.ep_size = num_gpus_per_engine + sglang_server_args.enable_deterministic_inference = enable_deterministic_inference + + if self.enable_logprob_zero_diff: + sglang_server_args.enable_deterministic_inference = True + sglang_server_args.rl_on_policy_target = 'fsdp' + sglang_server_args.attention_backend = 'fa3' if grammar_backend is not None: sglang_server_args.grammar_backend = grammar_backend @@ -212,6 +217,7 @@ def _transform_sample_params(self, sample_params: Dict): "stop": sample_params["stops"], "stop_token_ids": sample_params["stop_token_ids"], "skip_special_tokens": sample_params["skip_special_tokens"], + "sampling_seed": sample_params["sampling_seed"], } return sglang_sample_params diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py index 61b1dcd26..638f68758 100644 --- a/xtuner/v1/ray/rollout/worker.py +++ b/xtuner/v1/ray/rollout/worker.py @@ -76,6 +76,7 @@ def __init__( self.enable_return_routed_experts = self.config.enable_return_routed_experts if self.rank == 0: self.logger.info(f"RolloutConfig:\n{self.config.model_dump_json(indent=2)}") + self.enable_logprob_zero_diff = os.environ.get("XTUNER_ENABLE_LOGPROB_ZERO_DIFF", "0") == "1" def init_dist_port(self): """Initialize distributed communication ports. 
diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index f22e9ecf2..305a38ce1 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -147,6 +147,15 @@ def __init__(
         super().__init__(worker_cfg, rank, master_addr, master_port, world_size, accelerator)
         self.config = cast(WorkerConfig, self.config)
         torch.accelerator.set_device_index(int(os.environ["LOCAL_RANK"]))
+
+        if self.config.fsdp_cfg.enable_autocast:
+            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode, disable_batch_invariant_mode
+            print("FSDPTrainRayActor calls enable_batch_invariant_mode for true on-policy training")
+            enable_batch_invariant_mode(
+                # In Qwen3, RoPE's `inv_freq_expanded.float() @ position_ids_expanded.float()` uses bmm,
+                # and disabling the bmm override here keeps the results aligned.
+                enable_bmm=False,
+            )
         self._engine = self._build_engine(worker_cfg)
         self._has_ref = False

@@ -213,7 +222,8 @@ def _build_ref_model(
         if isinstance(ref_model_cfg, VisionComposeConfigProtocol):
             assert ref_model_cfg.text_config.float8_cfg is None, "VisionComposeConfigProtocol does not support float8"
             if ref_model_fsdp_cfg is None:
-                ref_model_fsdp_cfg = FSDPConfig(recompute_ratio=0, cpu_offload=False, requires_grad=False)
+                ref_model_fsdp_cfg = FSDPConfig(recompute_ratio=0, cpu_offload=False, requires_grad=False,
+                                                enable_autocast=self.config.fsdp_cfg.enable_autocast)
             model.language_model.fully_shard(ref_model_fsdp_cfg)  # type: ignore
             model.vision_tower.fully_shard(ref_model_fsdp_cfg)  # type: ignore
             model.multi_modal_projector.fully_shard(ref_model_fsdp_cfg)  # type: ignore
@@ -230,7 +240,8 @@ def _build_ref_model(
         else:
             float8_handler = None
             if ref_model_fsdp_cfg is None:
-                ref_model_fsdp_cfg = FSDPConfig(recompute_ratio=0, cpu_offload=False, requires_grad=False)
+                ref_model_fsdp_cfg = FSDPConfig(recompute_ratio=0, cpu_offload=False, requires_grad=False,
+                                                enable_autocast=self.config.fsdp_cfg.enable_autocast)
             model = model.fully_shard(ref_model_fsdp_cfg, float8_handler)  # type: ignore

         model.from_hf(hf_path=load_from)
@@ -268,7 +279,11 @@ def compute_actor_logprobs(
         self, seq_ctx_list: list[SequenceContext], loss_ctx_input_list: list[RLLossContextInputItem]
     ) -> list[RLLossContextInputItem]:
         for seq_ctx, loss_ctx_input in zip(seq_ctx_list, loss_ctx_input_list):
-            output = self._engine.forward_only(seq_ctx=seq_ctx)
+            if self.config.fsdp_cfg.enable_autocast:
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                    output = self._engine.forward_only(seq_ctx=seq_ctx)
+            else:
+                output = self._engine.forward_only(seq_ctx=seq_ctx)
             loss_ctx_input.old_logprobs = gather_logprobs(output["logits"], loss_ctx_input.shifted_labels)
         return loss_ctx_input_list

@@ -279,7 +294,11 @@ def compute_ref_logprobs(
         self._ref_model.to_device(DEVICE)
         for seq_ctx, loss_ctx_input in zip(seq_ctx_list, loss_ctx_input_list):
             with torch.no_grad():
-                ref_output = self._ref_model(seq_ctx=seq_ctx, loss_ctx=None)
+                if self.config.fsdp_cfg.enable_autocast:
+                    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                        ref_output = self._ref_model(seq_ctx=seq_ctx, loss_ctx=None)
+                else:
+                    ref_output = self._ref_model(seq_ctx=seq_ctx, loss_ctx=None)
             ref_logprobs = gather_logprobs(ref_output["logits"], loss_ctx_input.shifted_labels)
             loss_ctx_input.ref_logprobs = ref_logprobs
         self._ref_model.to_device("cpu")
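For orientation, `gather_logprobs` in the two hunks above turns the model's logits into per-token log-probabilities of the sampled labels, which are then compared against the rollout engine's logprobs in `fit` below. A typical shape of such a helper (the real xtuner implementation may differ, e.g. in how padding labels are handled):

```python
import torch


def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: [..., seq_len, vocab_size]; labels: [..., seq_len]
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    return logprobs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)


logits = torch.randn(1, 5, 32)
labels = torch.randint(0, 32, (1, 5))
print(gather_logprobs(logits, labels).shape)  # torch.Size([1, 5])
```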
@@ -402,7 +421,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
                 else:
                     min_diff = torch.min(rollout_logprobs - old_logprobs)
                     max_diff = torch.max(rollout_logprobs - old_logprobs)
-                    mean_diff = torch.mean(rollout_logprobs - old_logprobs)
+                    mean_diff = torch.sum(rollout_logprobs - old_logprobs) / len(old_logprobs)
                     if rollout_logprobs.numel() == 1:
                         std_diff = torch.tensor(0.0)
                     else:
@@ -600,7 +619,7 @@ def get_params(tensor_list, name_list, save_dtype):
                     else:
                         saved_list.append(f"layers.{i}.{sub_name}")
                     local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-                    local_tensor = local_tensor.bfloat16()
+                    # local_tensor = local_tensor.bfloat16()
                     load_spec = language_model.load_spec_mapping.get(f"layers.{i}.{sub_name}")

                     if isinstance(model.config, VisionComposeConfigProtocol):
@@ -628,7 +647,7 @@ def get_params(tensor_list, name_list, save_dtype):
                 if name in saved_list:
                     continue
                 local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-                local_tensor = local_tensor.bfloat16()
+                # local_tensor = local_tensor.bfloat16()
                 load_spec = model.load_spec_mapping.get(name)

                 if isinstance(model.config, VisionComposeConfigProtocol):
diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index 3f7525153..aac09f8d0 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -268,6 +268,15 @@ def __init__(
         train_worker_cfg.log_dir = log_dir
         dataflow_config.worker_log_dir = log_dir
         rollout_config.worker_log_dir = log_dir
+
+        if rollout_config.random_seed is None:
+            print(f"rollout_config.random_seed is None, setting it to {seed}")
+            rollout_config.random_seed = seed
+
+        if os.environ.get("XTUNER_ENABLE_LOGPROB_ZERO_DIFF", "0") == "1":
+            print("XTUNER_ENABLE_LOGPROB_ZERO_DIFF is enabled, verify that the rollout/training logprob diff is zero")
+            train_worker_cfg.fsdp_cfg.enable_autocast = True
+
         self._enable_evaluate = False
         self._enable_initial_evaluate = False
         if evaluator_config:
@@ -385,7 +394,7 @@ def fit(self):
         """
         self.logger.info("Start RL training")
         if self._enable_initial_evaluate and self._enable_evaluate and self._evaluator:
-            ray.get(self._rollout_env_controller.check_active_workers.remote())
+            # ray.get(self._rollout_env_controller.check_active_workers.remote())
             scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True))
             trajectory_save_path = self.exp_dir / "eval_0_trajectory.jsonl"
             self._save_trajectories(eval_data_groups, trajectory_save_path)
@@ -395,7 +404,7 @@ def fit(self):
             step_timer_dict = {}
             # 1. Rollout
             with timer("generation", step_timer_dict):
-                ray.get(self._rollout_env_controller.check_active_workers.remote())
+                # ray.get(self._rollout_env_controller.check_active_workers.remote())
                 data_groups, multimodal_train_infos = ray.get(self._rollout_dataflow.run.remote())
             # 2. Offload rollout models and save trajectories
             with timer("offload_and_dump", step_timer_dict):
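Putting the pieces together, a usage sketch for the new zero-diff mode (the trainer entry point shown here is hypothetical; configure the trainer as in the existing examples): set the environment variable before building the trainer, and `rl_trainer.py` will force `enable_autocast` on the training workers while the rollout side switches sglang to deterministic inference with per-sample `sampling_seed`s.

```python
import os

# Must be set before the trainer and workers are constructed so that both the
# training actors and the rollout workers read the same value.
os.environ["XTUNER_ENABLE_LOGPROB_ZERO_DIFF"] = "1"

# trainer = RLTrainer(...)  # hypothetical entry point; configure as in examples/v1/config/rl_qwen3_8B_grpo.py
# trainer.fit()             # the logged rollout-vs-training logprob diff is expected to be zero
```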