diff --git a/ci/scripts/CI_ENV.sh b/ci/scripts/CI_ENV.sh
index 1555cb10d..1d20540e8 100644
--- a/ci/scripts/CI_ENV.sh
+++ b/ci/scripts/CI_ENV.sh
@@ -5,6 +5,8 @@ export INTERN_VL_1B_PATH=${CI_SHARE_MODEL}/InternVL3_5-1B-HF
 export VIDEO_ROOT=${CI_SHARE_DATA}/images
 export QWEN3_4B_PATH=${CI_SHARE_MODEL}/Qwen3-4B-Instruct-2507
 export ROLLOUT_DATA_PATH=${CI_SHARE_DATA}/gsm8k/train.jsonl
+export GEO3K_TRAIN_DATA_PATH=${CI_SHARE_DATA}/geometry3k/train.jsonl
+export GEO3K_MEDIA_ROOT=${CI_SHARE_DATA}/geometry3k/
 export DEEPSEEK_V3_PATH=${CI_SHARE_MODEL}/DeepSeek-V3.1
 export GPT_OSS_MINI_PATH=${CI_SHARE_MODEL}/gpt-oss-20b-bf16
 export ROLLOUT_TEST_DATA_PATH=${CI_SHARE_DATA}/gsm8k/test.jsonl
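The two new variables follow the existing pattern in `CI_ENV.sh`: the CI job sources the script and the tests read the values through `os.getenv` (see `tests/ray/test_vl_rollout.py` below). A minimal sketch of that consumption, assuming the script has been sourced; the guard class here is illustrative and not part of this patch:

```python
import os
import unittest

# Assumed to be exported by `source ci/scripts/CI_ENV.sh` before the test run.
GEO3K_TRAIN_DATA_PATH = os.getenv("GEO3K_TRAIN_DATA_PATH")
GEO3K_MEDIA_ROOT = os.getenv("GEO3K_MEDIA_ROOT")


@unittest.skipUnless(GEO3K_TRAIN_DATA_PATH and GEO3K_MEDIA_ROOT, "geometry3k CI data is not available")
class TestGeo3kEnv(unittest.TestCase):
    def test_paths_exist(self):
        self.assertTrue(os.path.isfile(GEO3K_TRAIN_DATA_PATH))
        self.assertTrue(os.path.isdir(GEO3K_MEDIA_ROOT))
```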
diff --git a/tests/ray/test_rollout.py b/tests/ray/test_rollout.py
index 3c8436079..2bb348e09 100644
--- a/tests/ray/test_rollout.py
+++ b/tests/ray/test_rollout.py
@@ -52,9 +52,9 @@ def init_config(self):
         self.max_response_length = 1024
         self.rollout_cfg = RolloutConfig(
             env="test_rollout",
-            model_path=MODEL_PATH,
-            model_name=os.path.basename(MODEL_PATH).lower(),
-            tokenizer_path=MODEL_PATH,
+            model_path=self.model_path,
+            model_name=os.path.basename(self.model_path).lower(),
+            tokenizer_path=self.model_path,
             rollout_cross_node_comm=False,
             tensor_parallel_size=1,
             expert_parallel_size=1,
@@ -91,7 +91,7 @@ def init_config(self):
             pack_level='none',
             group_by_length=False,
         )
-        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
         self.replay_buffer_cfg = ReplayBufferConfig(
             dataset_cfg=self.train_dataset_cfg,
             dataloader_cfg=self.dataloader_cfg,
@@ -102,6 +102,7 @@ def setUp(self):
         ray.init(num_cpus=80, ignore_reinit_error=True)
         self.data_path = TRAIN_DATA_PATH
+        self.model_path = MODEL_PATH
         self.temp_dir = tempfile.TemporaryDirectory()
         self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")
@@ -147,7 +148,7 @@ def test_lmdeploy_dataflow(self):
             deep=True,
             update=dict(
                 expert_parallel_size=2,
-                model_path=MOE_MODEL_PATH,
+                model_path=self.model_path,
                 model_name=os.path.basename(MOE_MODEL_PATH).lower(),
                 tokenizer_path=MOE_MODEL_PATH,
             ),
@@ -170,83 +171,125 @@ def test_lmdeploy_dataflow(self):
         self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
         ray.get(self.test_env.shutdown.remote(), timeout=300)

-    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
-    def test_lmdeploy_dataflow_save_resume(self):
-        rollout_cfg = self.rollout_cfg
-
-        self.dataflow_cfg.enable_partial_rollout = 0
+    def _get_sorted_input_ids(self, responses):
+        """Helper to extract and sort input_ids from responses."""
+        all_ids = []
+        for data_items in responses[0]:
+            for data_item in data_items:
+                all_ids.extend(data_item.data.input_ids)
+        all_ids.sort()
+        return all_ids
+
+    def _run_dataflow_save_resume_test(self, rollout_cfg, dataflow_cfg):
+        """Generic driver for dataflow save/resume tests."""
+        # 1. Initialize Environment and DataFlow
+        is_partial_rollout = dataflow_cfg.enable_partial_rollout == 1
         self.test_env = SingleTurnEnvironment.remote(
             "test_env",
             self.pg,
             rollout_cfg=rollout_cfg,
         )
-        self.test_flow = DataFlow.remote("test_env",
-                                         self.dataflow_cfg,
-                                         self.replay_buffer_cfg,
-                                         self.test_env
-                                         )
+        self.test_flow = DataFlow.remote(
+            "test_env",
+            dataflow_cfg,
+            self.replay_buffer_cfg,
+            self.test_env
+        )
+
+        # 2. Initial run
         ray.get(self.test_flow.run.remote(), timeout=300)
-        save_dir = Path(self.temp_dir.name) / 'checkpoints' / 'ckpt-step-2'
-        save_dir.mkdir(parents=True, exist_ok=True)
+
+        # Capture status before saving (critical for the partial rollout consistency check)
+        rl_status_before_save = ray.get(self.test_flow.get_replaybuffer_status.remote())
+        # 3. Save
+        save_dir = Path(self.temp_dir.name) / 'checkpoints' / 'ckpt-step-2'
+        save_dir.mkdir(parents=True, exist_ok=True)
         ray.get(self.test_flow.save.remote(save_dir))
-        responses_old = ray.get(self.test_flow.run.remote(), timeout=300)
-        ray.get(self.test_flow.resume.remote(save_dir))
-        responses_new = ray.get(self.test_flow.run.remote(), timeout=300)
+        # Define run logic based on mode
+        def run_continuation(status_ref):
+            if is_partial_rollout:
+                remain = status_ref["rollout_paused_count"]
+                # Finish the remaining paused samples
+                return ray.get(self.test_flow.run.remote(num=remain, enable_partial_rollout=0), timeout=300)
+            else:
+                # Normal run
+                return ray.get(self.test_flow.run.remote(), timeout=300)

-        all_train_prompt_ids_old = []
-        for data_items in responses_old[0]:
-            for data_item in data_items:
-                all_train_prompt_ids_old.extend(data_item.data.input_ids)
+        # 4. Continue running after save
+        responses_old = run_continuation(rl_status_before_save)
+        rb_status_old = ray.get(self.test_flow.get_replaybuffer_status.remote())

-        all_train_prompt_ids_new = []
-        for data_items in responses_new[0]:
-            for data_item in data_items:
-                all_train_prompt_ids_new.extend(data_item.data.input_ids)
-        all_train_prompt_ids_old.sort()
-        all_train_prompt_ids_new.sort()
-        self.assertEqual(all_train_prompt_ids_old, all_train_prompt_ids_new)
+        # 5. Resume from the saved checkpoint and rerun
+        ray.get(self.test_flow.resume.remote(save_dir))
+        rl_status_resume = ray.get(self.test_flow.get_replaybuffer_status.remote())
+        responses_new = run_continuation(rl_status_resume)
+        rb_status_new = ray.get(self.test_flow.get_replaybuffer_status.remote())

+        # 6. Cleanup
         ray.get(self.test_env.shutdown.remote(), timeout=300)

-    @unittest.skip("skip lmdeploy async dataflow after lmdeploy support abort_request")
-    def test_lmdeploy_async_dataflow(self):
-        self.dataflow_cfg.enable_partial_rollout = 1
-        self.test_env = SingleTurnEnvironment.remote(
-            "test_env",
-            self.pg,
-            rollout_cfg=self.rollout_cfg,
-        )
-        self.test_flow = DataFlow.remote("test_env",
-                                         self.dataflow_cfg,
-                                         self.replay_buffer_cfg,
-                                         self.test_env
-                                         )
-        extra_params = {"stream": False, "return_token_ids": True, "return_logprobs": True}
-        dump_path = os.path.join(self.temp_dir.name, "unfinished_buffer.pickle")
-        responses = ray.get(self.test_flow.run.remote(extra_params=extra_params, dump=True, dump_path=dump_path))
-        finished_samples_count = sum(1 for data in responses[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
-        self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
-        status = ray.get(self.test_flow.get_replaybuffer_status.remote())
-        finished_count = status["rollout_finished_count"]  # 已经去掉了data_flow返回的数量
-        paused_count = status["rollout_paused_count"]
-        sample_count = status["action_count"]
-        self.assertEqual(len(responses) + finished_count + paused_count, sample_count)
-        self.assertEqual(len(responses), self.dataflow_cfg.global_batch_size)
+        # 7. Assertions
+        # Compare data
+        ids_old = self._get_sorted_input_ids(responses_old)
+        ids_new = self._get_sorted_input_ids(responses_new)
+        self.assertEqual(ids_old, ids_new)
+
+        # Compare ReplayBuffer status (old run vs new run)
+        for key in rb_status_old:
+            self.assertEqual(rb_status_old[key], rb_status_new[key])
+
+        # For partial rollout, verify the resumed state matches the saved state
+        if is_partial_rollout:
+            for key in rl_status_before_save:
+                self.assertEqual(rl_status_before_save[key], rl_status_resume[key])
+
+    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
+    def test_lmdeploy_dataflow_save_resume(self):
+        rollout_cfg = self.rollout_cfg
+        dataflow_cfg = self.dataflow_cfg
+        dataflow_cfg.enable_partial_rollout = 0
+        self._run_dataflow_save_resume_test(rollout_cfg, dataflow_cfg)

-        ray.get(self.test_flow.clear_replaybuffer.remote())
-        response_resume = ray.get(self.test_flow.run.remote(extra_params=extra_params, resume=True, resume_path=dump_path))
-        finished_resume_samples_count = sum(1 for data in response_resume[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
-        self.assertEqual(finished_resume_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
-        status = ray.get(self.test_flow.get_replaybuffer_status.remote())
-        finished_count = status["rollout_finished_count"]
-        paused_count = status["rollout_paused_count"]
-        sample_count = status["action_count"]
-        self.assertEqual(len(response_resume) + finished_count + paused_count, sample_count)
-        self.assertEqual(len(response_resume), self.dataflow_cfg.global_batch_size)
-        ray.get(self.test_env.shutdown.remote())
+    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
+    def test_lmdeploy_dataflow_save_resume_with_partial_rollout(self):
+        rollout_cfg = self.rollout_cfg
+        dataflow_cfg = self.dataflow_cfg
+        dataflow_cfg.max_concurrent = 4
+        dataflow_cfg.enable_partial_rollout = 1
+        self._run_dataflow_save_resume_test(rollout_cfg, dataflow_cfg)
+
+    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
+    def test_lmdeploy_dataflow_save_resume_with_partial_rollout_r3(self):
+        model_path = MOE_MODEL_PATH
+        rollout_cfg = RolloutConfig(
+            env="test_rollout",
+            model_path=model_path,
+            model_name=os.path.basename(model_path).lower(),
+            tokenizer_path=model_path,
+            rollout_cross_node_comm=False,
+            tensor_parallel_size=1,
+            expert_parallel_size=1,
+            gpus_per_node=8,
+            dtype="bfloat16",
+            launch_server_method="ray",
+            context_length=self.max_prompt_length + self.max_response_length,
+            worker_log_dir=self.worker_log_dir,
+            enable_return_routed_experts=True,
+        )
+        dataflow_cfg = DataFlowConfig(
+            env="test",
+            prompt_repeat_k=2,
+            global_batch_size=2,
+            enable_partial_rollout=1,
+            max_concurrent=4,
+            worker_log_dir=self.worker_log_dir,
+        )
+        self._run_dataflow_save_resume_test(rollout_cfg, dataflow_cfg)

     @unittest.skip("skip lmdeploy turbomind generate test due to ci environment issue")
     def test_lmdeploy_turbomind_generate(self):
@@ -289,44 +332,6 @@ def test_sglang_dataflow(self):
         self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
         ray.get(self.test_env.shutdown.remote(), timeout=300)
         print("responses: ", responses)
-
-    @unittest.skipIf(os.environ.get("XTUNER_USE_SGLANG", "0") == "0", "lmdeploy backend is not enabled")
-    def test_sglang_async_dataflow(self):
-        self.dataflow_cfg.enable_partial_rollout = 1
-        self.rollout_cfg.launch_server_method="multiprocessing"
-        self.test_env = SingleTurnEnvironment.remote(
-            "test_env",
-            self.pg,
-            rollout_cfg=self.rollout_cfg,
-        )
-        self.test_flow = DataFlow.remote("test_env",
-                                         self.dataflow_cfg,
-                                         self.replay_buffer_cfg,
-                                         self.test_env
-                                         )
-        extra_params = {"stream": True, "return_token_ids": True, "return_logprobs": True}
-        dump_path = os.path.join(self.temp_dir.name, "unfinished_buffer.pickle")
-        responses = ray.get(self.test_flow.run.remote(extra_params=extra_params, dump=False, dump_path=dump_path))
-        finished_samples_count = sum(1 for data in responses[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
-        self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
-        status = ray.get(self.test_flow.get_replaybuffer_status.remote())
-        finished_count = status["rollout_finished_count"]  # 已经去掉了data_flow返回的数量
-        paused_count = status["rollout_paused_count"]
-        sample_count = status["action_count"]
-        self.assertEqual(len(responses) + finished_count + paused_count, sample_count)
-        self.assertEqual(len(responses), self.dataflow_cfg.global_batch_size)
-
-        ray.get(self.test_flow.clear_replaybuffer.remote())
-        response_resume = ray.get(self.test_flow.run.remote(extra_params=extra_params, resume=True, resume_path=dump_path))
-        finished_resume_samples_count = sum(1 for data in response_resume[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
-        self.assertEqual(finished_resume_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
-        status = ray.get(self.test_flow.get_replaybuffer_status.remote())
-        finished_count = status["rollout_finished_count"]
-        paused_count = status["rollout_paused_count"]
-        sample_count = status["action_count"]
-        self.assertEqual(len(response_resume) + finished_count + paused_count, sample_count)
-        self.assertEqual(len(response_resume), self.dataflow_cfg.global_batch_size)
-        ray.get(self.test_env.shutdown.remote())

 if __name__ == "__main__":
     unittest.main()
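The refactored tests above drive both save/resume variants through one helper; the only behavioural difference is how the post-save run is issued. A condensed sketch of that call pattern, with names and timeouts taken from the diff (`test_flow` is the `DataFlow` Ray actor created in the test):

```python
import ray


def continue_after_save(test_flow, status, is_partial_rollout):
    """Sketch of the two continuation modes exercised by _run_dataflow_save_resume_test."""
    if is_partial_rollout:
        # Partial rollout: finish only the samples that were paused at save time,
        # with partial rollout disabled so this run drains them to completion.
        remain = status["rollout_paused_count"]
        return ray.get(test_flow.run.remote(num=remain, enable_partial_rollout=0), timeout=300)
    # Synchronous mode: an ordinary full run.
    return ray.get(test_flow.run.remote(), timeout=300)
```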
diff --git a/tests/ray/test_vl_rollout.py b/tests/ray/test_vl_rollout.py
new file mode 100644
index 000000000..7aa273b5a
--- /dev/null
+++ b/tests/ray/test_vl_rollout.py
@@ -0,0 +1,204 @@
+import os
+import subprocess
+from functools import wraps
+import unittest
+import tempfile
+import ray
+import torch
+from pathlib import Path
+from transformers import AutoTokenizer
+from xtuner.v1.ray.config.worker import RolloutConfig
+from xtuner.v1.ray.judger.controller import JudgerConfig
+from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
+from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig
+from xtuner.v1.data_proto.rl_data import SampleParams
+from xtuner.v1.ray.environment import SingleTurnEnvironment
+from xtuner.v1.ray.rollout import RolloutController
+from xtuner.v1.ray.judger import JudgerController
+from xtuner.v1.datasets import RLTokenizeFnConfig, build_datasets, build_dataloader
+from xtuner.v1.datasets.config import (
+    DataloaderConfig,
+    DatasetConfig,
+)
+
+MODEL_PATH = os.getenv("QWEN3_VL_DENSE_PATH")
+TRAIN_DATA_PATH = os.getenv("GEO3K_TRAIN_DATA_PATH")
+MEDIA_ROOT = os.getenv("GEO3K_MEDIA_ROOT")
+
+resource_map = {
+    "npu": "NPU",
+    "cuda": "GPU",
+}
+
+
+class TestRollout(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        os.environ["XTUNER_USE_FA3"] = "1"
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        del os.environ["XTUNER_USE_FA3"]
+
+    def init_config(self):
+        self.resources_cfg = AcceleratorResourcesConfig(
+            accelerator=resource_map[torch.accelerator.current_accelerator().type],
+            num_workers=8,
+            num_cpus_per_worker=8,
+            cpu_memory_per_worker=16 * 1024**3,  # 16 GB
+        )
+        self.max_prompt_length = 2048
+        self.max_response_length = 2048
+        self.rollout_cfg = RolloutConfig(
+            env="test_rollout",
+            model_path=self.model_path,
+            model_name=os.path.basename(self.model_path).lower(),
+            tokenizer_path=self.model_path,
+            rollout_cross_node_comm=False,
+            tensor_parallel_size=2,
+            expert_parallel_size=1,
+            gpus_per_node=8,  # gpu: 8, npu: 16
+            dtype="bfloat16",
+            launch_server_method="ray",
+            context_length=self.max_prompt_length + self.max_response_length,
+            worker_log_dir=self.worker_log_dir,
+        )
+        from xtuner.v1.ray.judger.geo3k import GEO3KJudgerConfig
+        geo3k_judger_config = GEO3KJudgerConfig()
+        self.judger_cfg = JudgerConfig(reward_judger_configs=[geo3k_judger_config])
+
+        self.dataflow_cfg = DataFlowConfig(
+            env="test",
+            prompt_repeat_k=2,
+            global_batch_size=2,
+            enable_partial_rollout=0,
+            max_retry_times=1,
+            worker_log_dir=self.worker_log_dir,
+        )
+        self.training_sample_params = SampleParams(
+            max_tokens=self.max_response_length,
+        )
+        self.evaluation_sample_params = SampleParams(
+            max_tokens=self.max_response_length,
+            top_p=1.0,
+            temperature=0.0,
+            top_k=1,
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+        from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig
+        tokenize_fn_cfg = Qwen3VLTokenizeFnConfig(processor_path=self.model_path)
+        train_dataset_cfg = [
+            {
+                "dataset": DatasetConfig(name="geo3k",
+                                         anno_path=self.data_path,
+                                         class_name='VLMJsonlDataset',
+                                         media_root=self.media_root,
+                                         sample_ratio=1.0),
+                "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length,
+                                                  tokenize_fn_cfg=tokenize_fn_cfg),
+            }
+        ]
+        dataloader_config = DataloaderConfig(num_workers=8,
+                                             collator="fake_collator",
+                                             pack_level="none")
+
+        self.replay_buffer_cfg = ReplayBufferConfig(
+            dataset_cfg=train_dataset_cfg,
+            dataloader_cfg=dataloader_config,
+            tokenizer=self.tokenizer,
+        )
+
+    def setUp(self):
+        ray.init(num_cpus=80, ignore_reinit_error=True)
+        self.data_path = TRAIN_DATA_PATH
+        self.model_path = MODEL_PATH
+        self.media_root = MEDIA_ROOT
+        self.temp_dir = tempfile.TemporaryDirectory()
+        self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")
+        self.init_config()
+        self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)
+
+    def tearDown(self):
+        ray.shutdown()
+        # When lmdeploy enables ep>1, it uses deep_ep. Implicit destruction of its Buffer can leave some ray actors stuck.
+        # As a workaround, use pkill to clean up ray::RayWorkerWrapper processes after closing the ray cluster connection.
+        # TODO(chenchiyu): add explicit deep_ep destroy in lmdeploy.
+        self._cleanup_lmdeploy_ray_worker_wrapper()
+        self.temp_dir.cleanup()
+
+    def _cleanup_lmdeploy_ray_worker_wrapper(self):
+        try:
+            result = subprocess.run(["pkill", "-f", "ray::RayWorkerWrapper*"], capture_output=True, text=True, timeout=10)
+            if result.returncode != 0:
+                print(f"pkill command failed with return code {result.returncode}: {result.stderr}."
+                      " Maybe no lmdeploy ray::RayWorkerWrapper processes found.")
+        except Exception as e:
+            print(f"Error stopping ray::RayWorkerWrapper cluster: {e}")
+
+    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
+    def test_vl_resume_with_partial_rollout(self):
+        rollout_cfg = self.rollout_cfg
+        # rollout_cfg.enable_return_routed_experts = True
+        self.test_env = SingleTurnEnvironment.remote(
+            "test_env",
+            self.pg,
+            rollout_cfg=rollout_cfg,
+        )
+        dataflow_cfg = self.dataflow_cfg
+        dataflow_cfg.max_concurrent = 4
+        dataflow_cfg.enable_partial_rollout = 0
+        self.test_flow = DataFlow.remote("test_env",
+                                         dataflow_cfg,
+                                         self.replay_buffer_cfg,
+                                         self.test_env
+                                         )
+        ray.get(self.test_flow.run.remote(), timeout=300)
+        rl_status_save = ray.get(self.test_flow.get_replaybuffer_status.remote())
+        save_dir = Path(self.temp_dir.name) / 'checkpoints' / 'ckpt-step-2'
+        save_dir.mkdir(parents=True, exist_ok=True)
+
+        ray.get(self.test_flow.save.remote(save_dir))
+        remain_paused_samples_old = rl_status_save["rollout_paused_count"]
+        responses_old = ray.get(self.test_flow.run.remote(num=remain_paused_samples_old, enable_partial_rollout=0), timeout=300)
+        rb_status_old = ray.get(self.test_flow.get_replaybuffer_status.remote())
+
+        mm_info_old = []
+        for multimodal_train_infos in responses_old[1]:
+            image_grid_thw = multimodal_train_infos["image_grid_thw"].numpy().flatten()
+            mm_info_old.extend(image_grid_thw)
+
+        ray.get(self.test_flow.resume.remote(save_dir))
+        rl_status_resume = ray.get(self.test_flow.get_replaybuffer_status.remote())
+        remain_paused_samples_new = rl_status_resume["rollout_paused_count"]
+        responses_new = ray.get(self.test_flow.run.remote(num=remain_paused_samples_new, enable_partial_rollout=0), timeout=300)
+        rb_status_new = ray.get(self.test_flow.get_replaybuffer_status.remote())
+
+        mm_info_new = []
+        for multimodal_train_infos in responses_new[1]:
+            image_grid_thw = multimodal_train_infos["image_grid_thw"].numpy().flatten()
+            mm_info_new.extend(image_grid_thw)
+
+        all_train_prompt_ids_old = []
+        for data_items in responses_old[0]:
+            for data_item in data_items:
+                all_train_prompt_ids_old.extend(data_item.data.input_ids)
+
+        all_train_prompt_ids_new = []
+        for data_items in responses_new[0]:
+            for data_item in data_items:
+                all_train_prompt_ids_new.extend(data_item.data.input_ids)
+
+        all_train_prompt_ids_old.sort()
+        all_train_prompt_ids_new.sort()
+        mm_info_old.sort()
+        mm_info_new.sort()
+        self.assertEqual(all_train_prompt_ids_old, all_train_prompt_ids_new)
+        self.assertEqual(mm_info_old, mm_info_new)
+        for key in rb_status_old:
+            self.assertEqual(rb_status_old[key], rb_status_new[key])
+        for key in rl_status_save:
+            self.assertEqual(rl_status_save[key], rl_status_resume[key])
+        ray.get(self.test_env.shutdown.remote(), timeout=300)
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
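The VL assertion relies on `run()` returning both the grouped samples (`responses[0]`) and their multimodal train infos (`responses[1]`); the comparison reduces each side to a sortable flat list. A compact sketch of that reduction, mirroring the loops in the test above and assuming, as with the geometry3k data, that every group carries an `image_grid_thw` entry:

```python
def flatten_for_comparison(responses):
    """Reduce one run() result to sortable lists of input_ids and image_grid_thw values."""
    input_ids, grid_thw = [], []
    for data_items in responses[0]:      # grouped RLDataFlowItem lists
        for data_item in data_items:
            input_ids.extend(data_item.data.input_ids)
    for mm_info in responses[1]:         # per-group multimodal train infos
        grid_thw.extend(mm_info["image_grid_thw"].numpy().flatten())
    return sorted(input_ids), sorted(grid_thw)
```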
diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py
index c5cfe1f57..cf40abf6c 100644
--- a/xtuner/v1/data_proto/rl_data.py
+++ b/xtuner/v1/data_proto/rl_data.py
@@ -1,8 +1,11 @@
-from typing import Any, Dict, List, Optional, TypedDict
+from __future__ import annotations
+from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict
+
+import torch
 from cyclopts import Parameter
 from pydantic import BaseModel, ConfigDict, Field
-from typing_extensions import Annotated
+from typing_extensions import Annotated, NotRequired

 from xtuner.v1.utils import StrEnum
@@ -12,6 +15,13 @@
 from xtuner.v1.utils.logger import get_logger

+if TYPE_CHECKING:
+    import ray
+
+    RayObjectRef = ray.ObjectRef
+else:
+    RayObjectRef: TypeAlias = Any
+
 logger = get_logger()
@@ -46,7 +56,7 @@ class RolloutState(StrEnum):
     SKIPPED = "skipped"

     @staticmethod
-    def from_str(state_str: str) -> "RolloutState":
+    def from_str(state_str: str) -> RolloutState:
         for state in RolloutState:
             if state.value == state_str:
                 return state
@@ -72,6 +82,12 @@ class RLUIDItem(BaseModel):
     version: int = -1


+class MultimodalTrainInfo(TypedDict):
+    pixel_values: NotRequired[torch.Tensor | RayObjectRef | None]  # type: ignore[valid-type]
+    image_grid_thw: NotRequired[torch.Tensor]
+    position_ids: NotRequired[torch.Tensor]
+
+
 class RLDatasetItem(BaseModel):
     """Represents the data structure output from the dataset.
@@ -85,15 +101,19 @@
         extra_info (Dict[str, Any]): Additional user-defined information.
     """

-    model_config = ConfigDict(extra="forbid")
-    messages: Optional[List[Dict[str, Any]]] = None
-    input_ids: Optional[List[int]] = None
-    num_tokens: Optional[int] = None
-    ability: Optional[str] = None
-    reward_model: Optional[Dict[str, Any]] = None
-    data_source: Optional[Dict[str, Any]] = None
-    extra_info: Dict[str, Any] = dict()
-    multimodal_train_info: Optional[Dict[str, Any]] = None
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+    messages: list[dict[str, Any]] | None = None
+    input_ids: list[int] | None = None
+    num_tokens: int | None = None
+    ability: str | None = None
+    reward_model: dict[str, Any] | None = None
+    data_source: dict[str, Any] | None = None
+    extra_info: dict[str, Any] = dict()
+    multimodal_train_info: MultimodalTrainInfo | None = None
+
+
+class RolloutExtraInfo(TypedDict):
+    routed_experts: NotRequired[list[int] | RayObjectRef]  # type: ignore[valid-type]


 class RLRolloutResponseItem(BaseModel):
@@ -109,12 +129,12 @@
     """

     model_config = ConfigDict(extra="forbid")
-    response: Optional[str] = None
-    response_ids: Optional[List[int]] = None
-    num_return_tokens: Optional[int] = None
-    finish_reason: Optional[str] = None  # "stop", "length", "abort", "failed", "skipped"
-    logprobs: Optional[List[float]] = None
-    extra_info: Dict[str, Any] = dict()
+    response: str | None = None
+    response_ids: list[int] | None = None
+    num_return_tokens: int | None = None
+    finish_reason: str | None = None  # "stop", "length", "abort", "failed", "skipped"
+    logprobs: list[float] | None = None
+    extra_info: RolloutExtraInfo = Field(default_factory=dict)
     state: RolloutState = RolloutState.INIT
@@ -128,15 +148,15 @@
     """

     model_config = ConfigDict(extra="forbid")
-    uid: Optional[int] = None
-    reward: Dict[str, Any] = Field(default_factory=lambda: {"score": 0.0, "val": 0.0})
-    extra_info: Dict[str, Any] = dict()
+    uid: int | None = None
+    reward: dict[str, Any] = Field(default_factory=lambda: {"score": 0.0, "val": 0.0})
+    extra_info: dict[str, Any] = dict()


 class RLAgentDataItem(BaseModel):
     # todo: define agent output data structure
     model_config = ConfigDict(extra="forbid")
-    extra_info: Dict[str, Any] = dict()
+    extra_info: dict[str, Any] = dict()


 class RLEnvDataItem(BaseModel):
@@ -154,7 +174,7 @@ class RLEnvDataItem(BaseModel):
     rollout: RLRolloutResponseItem = RLRolloutResponseItem()
     judger: RLJudgerResponseItem = RLJudgerResponseItem()
     agent: RLAgentDataItem = RLAgentDataItem()
-    extra_info: Dict[str, Any] = dict()
+    extra_info: dict[str, Any] = dict()


 class RLExtraDataItem(BaseModel):
@@ -168,7 +188,7 @@ class RLExtraDataItem(BaseModel):
     model_config = ConfigDict(extra="forbid")

     retry_times: int = 0
-    extra_info: Dict[str, Any] = dict()
+    extra_info: dict[str, Any] = dict()


 class RLDataFlowItem(BaseModel):
@@ -191,7 +211,7 @@ class RLDataFlowItem(BaseModel):
     extra_info: RLExtraDataItem = RLExtraDataItem()


-def is_valid_for_replaybuffer(group_data_items: List[RLDataFlowItem]) -> bool:
+def is_valid_for_replaybuffer(group_data_items: list[RLDataFlowItem]) -> bool:
     """Checks if a group of data items is valid for insertion into the replay
     buffer.

@@ -220,7 +240,7 @@ def is_valid_for_replaybuffer(group_data_items: List[RLDataFlowItem]) -> bool:
     return True


-def is_valid_for_training(group_data_items: List[RLDataFlowItem]) -> bool:
+def is_valid_for_training(group_data_items: list[RLDataFlowItem]) -> bool:
     """Checks if a group of data items is valid for a training step.

     Args:
@@ -318,8 +338,8 @@ class SampleParams(BaseModel):
     frequency_penalty: Annotated[float, Parameter(help="The parameter for frequency penalty.")] = 0.0
     min_tokens: Annotated[int, Parameter(help="Minimum number of tokens to generate.")] = 0
     max_tokens: Annotated[int, Parameter(help="Maximum number of tokens to generate.")] = 2048
-    stops: Annotated[List[str], Parameter(help="List of stop sequences.")] = []
-    stop_token_ids: Annotated[List[int], Parameter(help="List of stop token IDs.")] = []
+    stops: Annotated[list[str], Parameter(help="List of stop sequences.")] = []
+    stop_token_ids: Annotated[list[int], Parameter(help="List of stop token IDs.")] = []
     skip_special_tokens: Annotated[bool, Parameter(help="Whether to skip special tokens.")] = True
@@ -337,8 +357,8 @@ class RolloutExtraParams(TypedDict):
 # 说明: 这里没定义API server情况数据格式,因为直接使用openai server的格式
 class RLRolloutRequestItem(BaseModel):
     model_config = ConfigDict(extra="forbid")
-    messages: List[Dict[str, Any]]
-    tools: List = Field(default_factory=list)
+    messages: list[dict[str, Any]]
+    tools: list = Field(default_factory=list)
     tool_choice: str = "auto"
     sample_params: SampleParams = Field(default_factory=SampleParams)
-    extra_params: Dict[str, Any] = Field(default_factory=dict)
+    extra_params: dict[str, Any] = Field(default_factory=dict)
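The new `MultimodalTrainInfo` and `RolloutExtraInfo` TypedDicts are plain dicts at runtime: `NotRequired` only affects type checking, and `RayObjectRef` collapses to `Any` outside `TYPE_CHECKING`, so the module does not need ray imported for pydantic schema building. A small usage sketch; the tensor shapes are illustrative, not taken from the patch:

```python
import torch

from xtuner.v1.data_proto.rl_data import MultimodalTrainInfo

# "pixel_values" may hold a Tensor, a ray.ObjectRef, or None;
# "position_ids" is NotRequired and simply omitted here.
mm_info: MultimodalTrainInfo = {
    "pixel_values": torch.randn(4, 1176),
    "image_grid_thw": torch.tensor([[1, 2, 2]]),
}

# At runtime this is still an ordinary dict.
assert "position_ids" not in mm_info
```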
diff --git a/xtuner/v1/ray/dataflow/flow.py b/xtuner/v1/ray/dataflow/flow.py
index aa14c5122..358625714 100644
--- a/xtuner/v1/ray/dataflow/flow.py
+++ b/xtuner/v1/ray/dataflow/flow.py
@@ -146,6 +146,7 @@ def _reset_internal_states(
         global_batch_size: Optional[int] = None,
         sample_params: Optional[SampleParams] = None,
         extra_params: Optional[Dict] = None,
+        enable_partial_rollout: Optional[bool] = None,
     ):
         """Resets all internal state variables of DataFlow."""
         self.finished_samples_count = 0
@@ -156,6 +157,10 @@
         else:
             self.target_batch_size = self.config.global_batch_size

+        if enable_partial_rollout is not None:
+            self.enable_partial_rollout = enable_partial_rollout
+        else:
+            self.enable_partial_rollout = self.config.enable_partial_rollout
         self.sample_params = sample_params if sample_params else self.config.sample_params
         self.extra_params = extra_params if extra_params else self.config.extra_params
         logger_msg = (
@@ -256,10 +261,13 @@ async def concurrent_task_runner(self):
                     pbar.n = self.finished_samples_count
                     pbar.refresh()
                     next_update_threshold += update_step
+            self.logger.info(
+                f"waiting_tasks: {len(waiting_tasks)}, finished_samples_count: {self.finished_samples_count}"
+            )
             while len(waiting_tasks) < self.config.max_concurrent:
                 # In async mode, we keep spawning. In sync mode, we stop if we have enough tasks in flight.
                 if (
-                    not self.config.enable_partial_rollout
+                    not self.enable_partial_rollout
                     and self.finished_samples_count + len(waiting_tasks) >= self.target_batch_size
                 ):
                     break
@@ -331,10 +339,7 @@ async def run(
         self,
         num: Optional[int] = None,
         sample_params: Optional[SampleParams] = None,
         extra_params: Optional[Dict] = None,
-        dump: bool = False,
-        dump_path: Optional[str] = None,
-        resume: bool = False,
-        resume_path: Optional[str] = None,
+        enable_partial_rollout: Optional[bool] = None,
     ):
         """Starts the data generation process.

         Args:
             ...

         Returns:
             List[RLDataFlowItem]: A list of collected training samples.
         """
-        self._reset_internal_states(global_batch_size=num, sample_params=sample_params, extra_params=extra_params)
-
-        if resume:
-            assert resume_path, "Resuming is enabled but no resume path is provided."
-            self.logger.info(f"Resuming replay buffer from {resume_path}")
-            await self.replay_buffer.resume_storage.remote(resume_path)
-
+        self._reset_internal_states(
+            global_batch_size=num,
+            sample_params=sample_params,
+            extra_params=extra_params,
+            enable_partial_rollout=enable_partial_rollout,
+        )
         await self.concurrent_task_runner()
-
-        if dump:
-            assert dump_path, "Dumping is enabled but no dump path is provided."
-            self.logger.info(f"Dump replay buffer from {dump_path}")
-            await self.replay_buffer.dump_storage.remote(dump_path)
-
         return await self.replay_buffer.get_samples.remote(self.target_batch_size)  # type: ignore[attr-defined]

     def logging_replaybuffer_state(self):
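`run()` now takes an optional `enable_partial_rollout` that overrides the config value for that run only; when it is `None`, `_reset_internal_states` falls back to `config.enable_partial_rollout`. A rough usage sketch of the new call pattern, assuming `flow` is an already-constructed `DataFlow` actor handle (names follow the tests above):

```python
import ray

# Main run: uses whatever config.enable_partial_rollout says.
samples = ray.get(flow.run.remote())
status = ray.get(flow.get_replaybuffer_status.remote())

# Drain the samples that were paused by partial rollout,
# forcing partial rollout off for this one run only.
drained = ray.get(flow.run.remote(num=status["rollout_paused_count"], enable_partial_rollout=0))
```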
diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py
index 98a8e7524..c2cc7287b 100644
--- a/xtuner/v1/ray/dataflow/replay_buffer.py
+++ b/xtuner/v1/ray/dataflow/replay_buffer.py
@@ -1,10 +1,10 @@
-import itertools
 from collections import defaultdict
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from uuid import uuid4

+import numpy
 import ray
 import torch
 from cyclopts import Parameter
@@ -14,6 +14,7 @@
 from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

 from xtuner.v1.data_proto.rl_data import (
+    MultimodalTrainInfo,
     RLDataFlowItem,
     RLDatasetItem,
     RLExtraDataItem,
@@ -122,7 +123,9 @@ def mapping_replaymeta_to_dataitem(replay_meta: ReplayMeta) -> List[RLDataFlowIt
         replay_meta.observation_ids, replay_meta.observation_refs, replay_meta.observation_versions
     ):
         env_data = ray.get(obs_ref)
-        ray._private.internal_api.free(obs_ref)
+        # NOTE: This mapping function is used by both dump and get. ObjectRefs are kept during dump (for training continuity)
+        # but released during get (via del replaymeta) when no longer needed, so we do not free them manually here.
+        # ray._private.internal_api.free(obs_ref)
         item = RLDataFlowItem(
             uid=RLUIDItem(env=env_str, root_id=root_id, action_id=action_id, observation_id=obs_id, version=version),
@@ -295,9 +298,6 @@ def sample(self, env: str, enable_partial_rollout: int, prompt_repeat_k: int) ->
         # note: Sample grouped sample at once. They share the same action_id
         return self.sample_from_datasets(env, prompt_repeat_k)

-    def resume(self, num: int) -> None:
-        self.train_dataloader_iter = itertools.islice(self.train_dataloader, num, None)
-

 class ReplayBufferStorage:
     """Handles the storage of experiences for the replay buffer."""
@@ -317,7 +317,6 @@ def __init__(self, worker_log_dir):
             list
         )  # action_id: [observation_id, observation_id, ...]
         self.logger = get_logger(log_dir=worker_log_dir, tag="ReplayBuffer")
-        self._multimodal_train_infos: Dict[int, Dict[str, Any]] = {}

     def add(self, grouped_dataitem: List[RLDataFlowItem]):
         """Adds a group of data items to the storage.
@@ -350,10 +349,12 @@ def add(self, grouped_dataitem: List[RLDataFlowItem]):

         # observation
         for observation_id in replay_meta.observation_ids:
-            self._action2observations[action_id].append(observation_id)
             self._observations[observation_id] = replay_meta
             self._observations2states[observation_id] = replay_meta.state
-            self._states[replay_meta.state].append(observation_id)
+            if observation_id not in self._action2observations[action_id]:
+                self._action2observations[action_id].append(observation_id)
+            if observation_id not in self._states[replay_meta.state]:
+                self._states[replay_meta.state].append(observation_id)

     def clear(self):
         attrs_to_clear = [
@@ -369,7 +370,7 @@ def clear(self):
         for attr in attrs_to_clear:
             getattr(self, attr).clear()

-    def get(self, global_batch_size: int) -> Tuple[List[List[RLDataFlowItem]], List[Dict[str, Any] | None]]:
+    def get(self, global_batch_size: int) -> Tuple[List[List[RLDataFlowItem]], List[MultimodalTrainInfo | None]]:
         """Retrieves a batch of finished sample groups from the buffer.

         Args:
@@ -397,6 +398,11 @@ def get(self, global_batch_size: int) -> Tuple[List[List[RLDataFlowItem]], List[
         remain_finished_list = self._returned[global_batch_size:]
         for action_id in target_finished_list:
             replay_meta = self._actions.pop(action_id)
+            observation_ids = self._action2observations.pop(action_id)
+            for obs_id in observation_ids:
+                self._observations.pop(obs_id)
+                self._observations2states.pop(obs_id)
+            # todo: add a unified state management
             replay_meta.state = RolloutState.ARCHIVED
             group_samples = mapping_replaymeta_to_dataitem(replay_meta)
@@ -446,7 +452,73 @@ def print(self):
         )
         self.logger.info(log_message)

-    def dump(self, file_path: str):
+    def resolve_ray_objects(self, data_item: RLDataFlowItem):
+        """Resolves ray.ObjectRefs in a RLDataFlowItem to their actual values.
+
+        Args:
+            data_item (RLDataFlowItem): The data item containing ray.ObjectRefs.
+        Returns:
+            RLDataFlowItem: The data item with ray.ObjectRefs resolved.
+        """
+
+        # Resolve data.multimodal_train_info
+        if hasattr(data_item.data, "multimodal_train_info"):
+            multimodal_info = data_item.data.multimodal_train_info
+            if multimodal_info and "pixel_values" in multimodal_info:
+                pixel_values_ref = multimodal_info["pixel_values"]
+                if isinstance(pixel_values_ref, ObjectRef):
+                    multimodal_info["pixel_values"] = ray.get(pixel_values_ref)
+                    data_item.data.multimodal_train_info = multimodal_info
+        # Resolve rollout.extra_info.routed_experts
+        if "routed_experts" in data_item.env.rollout.extra_info:
+            if isinstance(data_item.env.rollout.extra_info["routed_experts"], ObjectRef):
+                data_item.env.rollout.extra_info["routed_experts"] = ray.get(
+                    data_item.env.rollout.extra_info["routed_experts"]
+                )
+                self.logger.info("Resolved routed_experts ObjectRef in rollout.extra_info")
+
+ """ + # convert data.multimodal_train_info to ray.ObjectRef + if hasattr(data_item.data, "multimodal_train_info"): + multimodal_info = data_item.data.multimodal_train_info + if multimodal_info and "pixel_values" in multimodal_info: + pixel_values_ref = ray.put(multimodal_info["pixel_values"]) + del multimodal_info["pixel_values"] + data_item.data.multimodal_train_info = pixel_values_ref + # convert rollout.extra_info.router_experts to ray.ObjectRef + if "routed_experts" in data_item.env.rollout.extra_info: + routed_experts_ref = ray.put(data_item.env.rollout.extra_info["routed_experts"]) + del data_item.env.rollout.extra_info["routed_experts"] + data_item.env.rollout.extra_info["routed_experts"] = routed_experts_ref + + def has_objectref(self, item: RLDataFlowItem) -> bool: + def check(obj): + if isinstance(obj, ray.ObjectRef): + return True + if isinstance(obj, BaseModel): + return any(check(getattr(obj, f)) for f in obj.model_fields) + if isinstance(obj, (list, tuple, set)): + return any(check(x) for x in obj) + if isinstance(obj, dict): + return any(check(v) for v in obj.values()) + if isinstance(obj, (str, int, float, bool, type(None), torch.Tensor, numpy.ndarray)): + return False + # 如果不满足以上类型,抛出错误,防止意想不到的问题 + raise TypeError( + f"Unsupported type: {type(obj)} in {obj} " + f"Expected ray.ObjectRef, BaseModel, list/tuple/set, dict, or primitive types." + ) + + return check(item) + + def dump(self, file_path: Path): """Dumps the entire state of the replay buffer storage to a single file, resolving all ray.ObjectRefs to their actual values. @@ -454,60 +526,60 @@ def dump(self, file_path: str): file_path (str): The path to the file where the state will be saved. """ - import os - import pickle - - self.logger.info(f"Starting to dump ReplayBufferStorage state to {file_path}...") - os.makedirs(os.path.dirname(file_path), exist_ok=True) - - all_data_items = [] - for replay_meta in self._actions.values(): - group_data_items = mapping_replaymeta_to_dataitem(replay_meta) - all_data_items.append(group_data_items) + all_data_items = [mapping_replaymeta_to_dataitem(replay_meta) for replay_meta in self._actions.values()] + + for data_items in all_data_items: + for item in data_items: + self.resolve_ray_objects(item) + res = self.has_objectref(item) + assert not res, "ReplayBufferStorage.dump found unresolved ray.ObjectRef in RLDataFlowItem" + + state = { + "_paused": self._paused, + "_returned": self._returned, + "_actions": all_data_items, + "_root2actions": dict(self._root2actions), + "_observations2states": self._observations2states, + "_states": dict(self._states), + "_action2observations": dict(self._action2observations), + } - with open(file_path, "wb") as f: - pickle.dump(all_data_items, f) + torch.save(state, file_path) self.logger.info(f"ReplayBufferStorage state dumped to {file_path}") - def resume(self, file_path: Path | str): + def resume(self, file_path: Path): """Resumes the replay buffer storage from a single file. Args: file_path (str): The path to the file from which to restore the state. """ - import os - import pickle - - self.logger.info(f"Starting to resume ReplayBufferStorage state from {file_path}...") - if not os.path.exists(file_path): - self.logger.error(f"State file not found: {file_path}. 
Cannot resume.") - return - - with open(file_path, "rb") as f: - all_data_items = pickle.load(f) - - for group_data_items in all_data_items: - replay_meta = mapping_dataitem_to_replaymeta(group_data_items) - root_id = replay_meta.root_id + if len(self._actions) > 0: + self.logger.warning("ReplayBufferStorage is not empty. Resuming will overwrite the existing state.") + self.clear() + + state = torch.load(file_path, map_location="cpu", weights_only=False) + + self._paused = state["_paused"] + self._returned = state["_returned"] + self._root2actions = defaultdict(list, state["_root2actions"]) + self._observations2states = state["_observations2states"] + self._states = defaultdict(list, state["_states"]) + self._action2observations = defaultdict(list, state["_action2observations"]) + + dump_actions = state["_actions"] + # 重建 _actions 和 _observations: 与replaymeta相关 + for group_dataitem in dump_actions: + for data_item in group_dataitem: + self.convert_to_ray_objref(data_item) + replay_meta = mapping_dataitem_to_replaymeta(group_dataitem) action_id = replay_meta.action_id - state_str = replay_meta.state - if state_str == "abort": - self._paused.append(action_id) - elif state_str == "returned": - self._returned.append(action_id) - self._root2actions[root_id].append(action_id) self._actions[action_id] = replay_meta - for observation_id in replay_meta.observation_ids: - self._action2observations[action_id].append(observation_id) + for observation_id in self._action2observations[action_id]: self._observations[observation_id] = replay_meta - self._observations2states[observation_id] = replay_meta.state - self._states[replay_meta.state].append(observation_id) self.logger.info(f"ReplayBufferStorage state successfully resumed from {file_path}") - self.print() - @ray.remote class ReplayBuffer: @@ -547,6 +619,7 @@ def __init__( self.storage, ) self.post_processor_func = config.postprocessor_func + self.logger = get_logger(log_dir=config.worker_log_dir, tag="ReplayBuffer") def get_train_dataset_length(self): """Returns the length of the training dataloader.""" @@ -607,25 +680,6 @@ def print(self): """Prints the current state of the replay buffer storage.""" self.storage.print() - def dump_storage(self, file_path: str): - """Dumps the replay buffer's storage to a file. - - Args: - file_path (str): The path to the file for saving the data. - """ - self.storage.dump(file_path) - - def resume_storage(self, file_path: Path | str): - """Resumes the replay buffer's storage from a file. - - Args: - file_path (str): The path to the file from which to restore the - state. - """ - self.storage.resume(file_path) - num = self.storage.get_prompt_num() - self.sampler.resume(num) - def status(self): return self.storage.status() @@ -637,10 +691,16 @@ def save(self, file_path: Path | str): """ if isinstance(file_path, str): file_path = Path(file_path) + + # save dataloader dataloader_path = file_path / "dataloader" dataloader_state = self._dataloader.get_state_dict(self.sampler.reduced_consumed_samples) torch.save(dataloader_state, dataloader_path) + # save storage + rb_storage_path = file_path / "replay_buffer_storage.pth" + self.storage.dump(rb_storage_path) + def resume(self, file_path: Path | str): """Resumes the replay buffer's storage from a file. 
@@ -651,16 +711,28 @@
         if isinstance(file_path, str):
             file_path = Path(file_path)
         dataloader_path = file_path / "dataloader"
-        dataloader_state = torch.load(dataloader_path, map_location=DEVICE)
-        self._dataloader.load_state_dict(dataloader_state)
-
-        self.sampler = Sampler(
-            self._dataloader,
-            self.tokenizer,
-            self.storage,
-        )
-        self.sampler.reduced_consumed_samples = dataloader_state["sampler"]["step"]
-        self.sampler.cur_epoch = dataloader_state["sampler"]["epoch"]
+        if dataloader_path.exists():
+            dataloader_state = torch.load(dataloader_path, map_location=DEVICE)
+            self._dataloader.load_state_dict(dataloader_state)
+
+            # resume dataloader
+            self.sampler = Sampler(
+                self._dataloader,
+                self.tokenizer,
+                self.storage,
+            )
+            self.sampler.reduced_consumed_samples = dataloader_state["sampler"]["step"]
+            self.sampler.cur_epoch = dataloader_state["sampler"]["epoch"]
+        else:
+            self.logger.warning(f"Dataloader state file {dataloader_path} does not exist. Skipping dataloader resume.")
+
+        # resume storage
+        rb_storage_path = file_path / "replay_buffer_storage.pth"
+        if rb_storage_path.exists():
+            self.storage.resume(rb_storage_path)
+        else:
+            self.logger.warning(
+                f"ReplayBufferStorage state file {rb_storage_path} does not exist. Skipping storage resume."
+            )

     def get_finished_samples(self):
         """Returns the number of finished sample groups in the storage."""
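Taken together, checkpointing is now a two-file layout under the checkpoint directory: `dataloader` (dataloader/sampler state) and `replay_buffer_storage.pth` (the storage state with all `ObjectRef`s materialized). A rough sketch of the round trip from a driver's point of view; `rb` is assumed to be the `ReplayBuffer` actor handle (in the tests the same calls are reached through the `DataFlow` actor's `save()`/`resume()`):

```python
from pathlib import Path

import ray

ckpt = Path("work_dirs/checkpoints/ckpt-step-2")
ckpt.mkdir(parents=True, exist_ok=True)

# save(): writes `dataloader` plus `replay_buffer_storage.pth`, resolving
# pixel_values / routed_experts ObjectRefs to real tensors before torch.save.
ray.get(rb.save.remote(ckpt))

# resume(): reloads whichever of the two files exist; large tensors are pushed
# back into the object store with ray.put so consumers keep receiving ObjectRefs.
ray.get(rb.resume.remote(ckpt))
```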