2 changes: 2 additions & 0 deletions ci/scripts/CI_ENV.sh
@@ -5,6 +5,8 @@ export INTERN_VL_1B_PATH=${CI_SHARE_MODEL}/InternVL3_5-1B-HF
export VIDEO_ROOT=${CI_SHARE_DATA}/images
export QWEN3_4B_PATH=${CI_SHARE_MODEL}/Qwen3-4B-Instruct-2507
export ROLLOUT_DATA_PATH=${CI_SHARE_DATA}/gsm8k/train.jsonl
export GEO3K_TRAIN_DATA_PATH=${CI_SHARE_DATA}/geometry3k/train.jsonl
export GEO3K_MEDIA_ROOT=${CI_SHARE_DATA}/geometry3k/
export DEEPSEEK_V3_PATH=${CI_SHARE_MODEL}/DeepSeek-V3.1
export GPT_OSS_MINI_PATH=${CI_SHARE_MODEL}/gpt-oss-20b-bf16
export ROLLOUT_TEST_DATA_PATH=${CI_SHARE_DATA}/gsm8k/test.jsonl
213 changes: 109 additions & 104 deletions tests/ray/test_rollout.py
@@ -52,9 +52,9 @@ def init_config(self):
self.max_response_length = 1024
self.rollout_cfg = RolloutConfig(
env="test_rollout",
model_path=MODEL_PATH,
model_name=os.path.basename(MODEL_PATH).lower(),
tokenizer_path=MODEL_PATH,
model_path=self.model_path,
model_name=os.path.basename(self.model_path).lower(),
tokenizer_path=self.model_path,
rollout_cross_node_comm=False,
tensor_parallel_size=1,
expert_parallel_size=1,
@@ -91,7 +91,7 @@ def init_config(self):
pack_level='none',
group_by_length=False,
)
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
self.replay_buffer_cfg = ReplayBufferConfig(
dataset_cfg=self.train_dataset_cfg,
dataloader_cfg=self.dataloader_cfg,
@@ -102,6 +102,7 @@
def setUp(self):
ray.init(num_cpus=80, ignore_reinit_error=True)
self.data_path = TRAIN_DATA_PATH

self.model_path = MODEL_PATH
self.temp_dir = tempfile.TemporaryDirectory()
self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")
@@ -147,7 +148,7 @@ def test_lmdeploy_dataflow(self):
deep=True,
update=dict(
expert_parallel_size=2,
model_path=MOE_MODEL_PATH,
model_path=self.model_path,
model_name=os.path.basename(MOE_MODEL_PATH).lower(),
tokenizer_path=MOE_MODEL_PATH,
),
@@ -170,83 +171,125 @@ def test_lmdeploy_dataflow(self):
self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
ray.get(self.test_env.shutdown.remote(), timeout=300)

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_dataflow_save_resume(self):
rollout_cfg = self.rollout_cfg

self.dataflow_cfg.enable_partial_rollout = 0
def _get_sorted_input_ids(self, responses):
"""Helper to extract and sort input_ids from responses."""
all_ids = []
for data_items in responses[0]:
for data_item in data_items:
all_ids.extend(data_item.data.input_ids)
all_ids.sort()
return all_ids

def _run_dataflow_save_resume_test(self, rollout_cfg, dataflow_cfg):
"""
Generic driver for dataflow save/resume tests.
"""
# 1. Initialize Environment and DataFlow
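# With partial rollout enabled, save() can fire while requests are still in
# flight; those samples are paused and must be drained after resume
# (see run_continuation below).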
is_partial_rollout = dataflow_cfg.enable_partial_rollout == 1
self.test_env = SingleTurnEnvironment.remote(
"test_env",
self.pg,
rollout_cfg=rollout_cfg,
)
self.test_flow = DataFlow.remote("test_env",
self.dataflow_cfg,
self.replay_buffer_cfg,
self.test_env
)
self.test_flow = DataFlow.remote(
"test_env",
dataflow_cfg,
self.replay_buffer_cfg,
self.test_env
)

# 2. Initial Run
ray.get(self.test_flow.run.remote(), timeout=300)
save_dir = Path(self.temp_dir.name) / 'checkpoints' / 'ckpt-step-2'
save_dir.mkdir(parents=True, exist_ok=True)

# Capture status before saving (critical for partial rollout consistency check)
rl_status_before_save = ray.get(self.test_flow.get_replaybuffer_status.remote())

# 3. Save
save_dir = Path(self.temp_dir.name) / 'checkpoints' / 'ckpt-step-2'
save_dir.mkdir(parents=True, exist_ok=True)
ray.get(self.test_flow.save.remote(save_dir))
responses_old = ray.get(self.test_flow.run.remote(), timeout=300)

ray.get(self.test_flow.resume.remote(save_dir))
responses_new = ray.get(self.test_flow.run.remote(), timeout=300)
# Define run logic based on mode
def run_continuation(status_ref):
if is_partial_rollout:
remain = status_ref["rollout_paused_count"]
# Finish the remaining paused samples
return ray.get(self.test_flow.run.remote(num=remain, enable_partial_rollout=0), timeout=300)
else:
# Normal run
return ray.get(self.test_flow.run.remote(), timeout=300)

all_train_prompt_ids_old = []
for data_items in responses_old[0]:
for data_item in data_items:
all_train_prompt_ids_old.extend(data_item.data.input_ids)
# 4. Continue running after save
responses_old = run_continuation(rl_status_before_save)
rb_status_old = ray.get(self.test_flow.get_replaybuffer_status.remote())

all_train_prompt_ids_new = []
for data_items in responses_new[0]:
for data_item in data_items:
all_train_prompt_ids_new.extend(data_item.data.input_ids)

all_train_prompt_ids_old.sort()
all_train_prompt_ids_new.sort()
self.assertEqual(all_train_prompt_ids_old, all_train_prompt_ids_new)
# 5. Resume from saved checkpoint
ray.get(self.test_flow.resume.remote(save_dir))
rl_status_resume = ray.get(self.test_flow.get_replaybuffer_status.remote())
responses_new = run_continuation(rl_status_resume)
rb_status_new = ray.get(self.test_flow.get_replaybuffer_status.remote())

# 6. Cleanup
ray.get(self.test_env.shutdown.remote(), timeout=300)

@unittest.skip("skip lmdeploy async dataflow after lmdeploy support abort_request")
def test_lmdeploy_async_dataflow(self):
self.dataflow_cfg.enable_partial_rollout = 1
self.test_env = SingleTurnEnvironment.remote(
"test_env",
self.pg,
rollout_cfg=self.rollout_cfg,
)
self.test_flow = DataFlow.remote("test_env",
self.dataflow_cfg,
self.replay_buffer_cfg,
self.test_env
)
extra_params = {"stream": False, "return_token_ids": True, "return_logprobs": True}
dump_path = os.path.join(self.temp_dir.name, "unfinished_buffer.pickle")
responses = ray.get(self.test_flow.run.remote(extra_params=extra_params, dump=True, dump_path=dump_path))
finished_samples_count = sum(1 for data in responses[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
status = ray.get(self.test_flow.get_replaybuffer_status.remote())
finished_count = status["rollout_finished_count"]  # already excludes the samples returned by data_flow
paused_count = status["rollout_paused_count"]
sample_count = status["action_count"]
self.assertEqual(len(responses) + finished_count + paused_count, sample_count)
self.assertEqual(len(responses), self.dataflow_cfg.global_batch_size)
# 7. Assertions
# Compare Data
ids_old = self._get_sorted_input_ids(responses_old)
ids_new = self._get_sorted_input_ids(responses_new)
self.assertEqual(ids_old, ids_new)

# Compare ReplayBuffer Status (Old run vs New run)
for key in rb_status_old:
self.assertEqual(rb_status_old[key], rb_status_new[key])

# For partial rollout, verify the resumed state matches the saved state
if is_partial_rollout:
for key in rl_status_before_save:
self.assertEqual(rl_status_before_save[key], rl_status_resume[key])

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_dataflow_save_resume(self):
rollout_cfg = self.rollout_cfg
dataflow_cfg = self.dataflow_cfg
dataflow_cfg.enable_partial_rollout = 0
self._run_dataflow_save_resume_test(rollout_cfg, dataflow_cfg)

ray.get(self.test_flow.clear_replaybuffer.remote())
response_resume = ray.get(self.test_flow.run.remote(extra_params=extra_params, resume=True, resume_path=dump_path))
finished_resume_samples_count = sum(1 for data in response_resume[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
self.assertEqual(finished_resume_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
status = ray.get(self.test_flow.get_replaybuffer_status.remote())
finished_count = status["rollout_finished_count"]
paused_count = status["rollout_paused_count"]
sample_count = status["action_count"]
self.assertEqual(len(response_resume) + finished_count + paused_count, sample_count)
self.assertEqual(len(response_resume), self.dataflow_cfg.global_batch_size)
ray.get(self.test_env.shutdown.remote())
@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_dataflow_save_resume_with_partial_rollout(self):
rollout_cfg = self.rollout_cfg
dataflow_cfg = self.dataflow_cfg
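# Assumption: a small max_concurrent keeps some requests in flight when
# save() is called, so the partial-rollout pause/resume path is exercised.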
dataflow_cfg.max_concurrent = 4
dataflow_cfg.enable_partial_rollout = 1
self._run_dataflow_save_resume_test(rollout_cfg, dataflow_cfg)

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_dataflow_save_resume_with_partial_rollout_r3(self):
model_path = MOE_MODEL_PATH
rollout_cfg = RolloutConfig(
env="test_rollout",
model_path=model_path,
model_name=os.path.basename(model_path).lower(),
tokenizer_path=model_path,
rollout_cross_node_comm=False,
tensor_parallel_size=1,
expert_parallel_size=1,
gpus_per_node=8,
dtype="bfloat16",
launch_server_method="ray",
context_length=self.max_prompt_length + self.max_response_length,
worker_log_dir=self.worker_log_dir,
enable_return_routed_experts=True,
)
dataflow_cfg = DataFlowConfig(
env="test",
prompt_repeat_k=2,
global_batch_size=2,
enable_partial_rollout=1,
max_concurrent=4,
worker_log_dir=self.worker_log_dir,
)
self._run_dataflow_save_resume_test(rollout_cfg, dataflow_cfg)

@unittest.skip("skip lmdeploy turbomind generate test due to ci environment issue")
def test_lmdeploy_turbomind_generate(self):
@@ -289,44 +332,6 @@ def test_sglang_dataflow(self):
self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
ray.get(self.test_env.shutdown.remote(), timeout=300)
print("responses: ", responses)

@unittest.skipIf(os.environ.get("XTUNER_USE_SGLANG", "0") == "0", "lmdeploy backend is not enabled")
def test_sglang_async_dataflow(self):
self.dataflow_cfg.enable_partial_rollout = 1
self.rollout_cfg.launch_server_method="multiprocessing"
self.test_env = SingleTurnEnvironment.remote(
"test_env",
self.pg,
rollout_cfg=self.rollout_cfg,
)
self.test_flow = DataFlow.remote("test_env",
self.dataflow_cfg,
self.replay_buffer_cfg,
self.test_env
)
extra_params = {"stream": True, "return_token_ids": True, "return_logprobs": True}
dump_path = os.path.join(self.temp_dir.name, "unfinished_buffer.pickle")
responses = ray.get(self.test_flow.run.remote(extra_params=extra_params, dump=False, dump_path=dump_path))
finished_samples_count = sum(1 for data in responses[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
status = ray.get(self.test_flow.get_replaybuffer_status.remote())
finished_count = status["rollout_finished_count"]  # already excludes the samples returned by data_flow
paused_count = status["rollout_paused_count"]
sample_count = status["action_count"]
self.assertEqual(len(responses) + finished_count + paused_count, sample_count)
self.assertEqual(len(responses), self.dataflow_cfg.global_batch_size)

ray.get(self.test_flow.clear_replaybuffer.remote())
response_resume = ray.get(self.test_flow.run.remote(extra_params=extra_params, resume=True, resume_path=dump_path))
finished_resume_samples_count = sum(1 for data in response_resume[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
self.assertEqual(finished_resume_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
status = ray.get(self.test_flow.get_replaybuffer_status.remote())
finished_count = status["rollout_finished_count"]
paused_count = status["rollout_paused_count"]
sample_count = status["action_count"]
self.assertEqual(len(response_resume) + finished_count + paused_count, sample_count)
self.assertEqual(len(response_resume), self.dataflow_cfg.global_batch_size)
ray.get(self.test_env.shutdown.remote())

if __name__ == "__main__":
unittest.main()