From cfa76d9767abdc173a4b6ab2d9dde2eb6af75b61 Mon Sep 17 00:00:00 2001
From: llkn-2 <765732949@qq.com>
Date: Sun, 16 Mar 2025 17:09:27 +0800
Subject: [PATCH 1/6] [bugfix] fix validation
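
This patch moves validation-time multi-sampling from the rollout worker into the
trainer: _validate now repeats each test prompt val_kwargs.n times with
interleave=True and pins n=1 in the generation meta_info, while the rollout side
stops reading val_kwargs.n directly. A minimal sketch of the intended row
ordering (torch only, not the DataProto API):

    import torch

    prompts = torch.tensor([[101], [102]])          # two validation prompts
    n = 4                                           # stands in for val_kwargs.n
    expanded = prompts.repeat_interleave(n, dim=0)  # interleave=True ordering
    assert expanded.shape[0] == prompts.shape[0] * n
    # each repeated row then generates exactly one response (n=1 downstream)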
---
verl/trainer/ppo/ray_trainer.py | 4 ++++
verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py | 2 +-
2 files changed, 5 insertions(+), 1 deletion(-)
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 29bd1825e02..890b47bc5be 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -641,6 +641,9 @@ def _validate(self):
for test_data in self.val_dataloader:
test_batch = DataProto.from_single_dict(test_data)
+ test_batch = test_batch.repeat(
+ repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
+ )
# we only do validation on rule-based rm
if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model':
@@ -663,6 +666,7 @@ def _validate(self):
)
test_gen_batch.meta_info = {
+ 'n': 1,
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id,
'recompute_log_prob': False,
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
index 8e1de199fb6..e74864864da 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -204,7 +204,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
'top_k': self.config.val_kwargs.top_k,
'top_p': self.config.val_kwargs.top_p,
'temperature': self.config.val_kwargs.temperature,
- 'n': self.config.val_kwargs.n,
+ # 'n': self.config.val_kwargs.n,
}
# users can customize different sampling_params at different run
From 63b2f86dfda969d9f7fb5acabdf32ddcfa4da090 Mon Sep 17 00:00:00 2001
From: llkn-2 <765732949@qq.com>
Date: Sun, 16 Mar 2025 17:12:34 +0800
Subject: [PATCH 2/6] [feat] support configurable input template in dataset
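
This patch adds data.template_key. When set, RLHFDataset formats the raw user
question with a fixed string template from INPUT_TEMPLATE
(verl/utils/dataset/utils.py) instead of tokenizer.apply_chat_template, both for
overlong-prompt filtering and in __getitem__, and a new padding_size argument
lets prompts be padded to a length other than max_prompt_length. A rough
illustration of the template path (the template string here is a shortened
stand-in, not the real one):

    INPUT_TEMPLATE = {"base": "User: {} Assistant:"}

    def build_prompt(chat, template_key=None):
        # chat is the usual [{"role": ..., "content": ...}] message list
        if template_key is not None:
            return INPUT_TEMPLATE[template_key].format(chat[0]["content"])
        raise NotImplementedError("falls back to tokenizer.apply_chat_template")

    print(build_prompt([{"role": "user", "content": "1+1=?"}], "base"))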
---
verl/trainer/config/ppo_trainer.yaml | 1 +
verl/trainer/ppo/ray_trainer.py | 8 +++++--
verl/utils/dataset/rl_dataset.py | 35 ++++++++++++++++++++++------
verl/utils/dataset/utils.py | 5 ++++
4 files changed, 40 insertions(+), 9 deletions(-)
create mode 100644 verl/utils/dataset/utils.py
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index bf486279bcb..763054a677c 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -13,6 +13,7 @@ data:
filter_overlong_prompts: False # for large-scale datasets, filtering overlong prompts could be time-consuming. You should disable this and set truncation='left'.
truncation: error
image_key: images
+ template_key: null
actor_rollout_ref:
hybrid_engine: True
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 890b47bc5be..2f66b870a0f 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -550,7 +550,9 @@ def _create_dataloader(self):
max_prompt_length=self.config.data.max_prompt_length,
filter_prompts=True,
return_raw_chat=self.config.data.get('return_raw_chat', False),
- truncation=self.config.data.get('truncation', 'error'))
+ truncation=self.config.data.get('truncation', 'error'),
+ filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', True),
+ template_key=self.config.data.get('template_key', None))
# use sampler for better ckpt resume
if self.config.data.shuffle:
train_dataloader_generator = torch.Generator()
@@ -574,7 +576,9 @@ def _create_dataloader(self):
max_prompt_length=self.config.data.max_prompt_length,
filter_prompts=True,
return_raw_chat=self.config.data.get('return_raw_chat', False),
- truncation=self.config.data.get('truncation', 'error'))
+ truncation=self.config.data.get('truncation', 'error'),
+ filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', True),
+ template_key=self.config.data.get('template_key', None))
self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset,
# Validation datasets are sent to inference engines as a whole batch,
diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py
index b9872072224..c84410531ef 100644
--- a/verl/utils/dataset/rl_dataset.py
+++ b/verl/utils/dataset/rl_dataset.py
@@ -26,7 +26,7 @@
from verl.utils.model import compute_position_id_with_mask
import verl.utils.torch_functional as verl_F
-
+from verl.utils.dataset.utils import INPUT_TEMPLATE
def collate_fn(data_list: list[dict]) -> dict:
tensors = defaultdict(list)
@@ -89,7 +89,9 @@ def __init__(self,
chat_template_func=None,
return_raw_chat=False,
truncation='error',
- filter_overlong_prompts=False):
+ filter_overlong_prompts=False,
+ template_key=None,
+ padding_size=None):
if not isinstance(parquet_files, (List, ListConfig)):
parquet_files = [parquet_files]
@@ -108,6 +110,12 @@ def __init__(self,
self.chat_template_func = chat_template_func
self.truncation = truncation
self.filter_overlong_prompts = filter_overlong_prompts
+ if template_key:
+ assert template_key in INPUT_TEMPLATE
+ self.input_template = INPUT_TEMPLATE[template_key]
+ else:
+ self.input_template = None
+ self.padding_size = padding_size if padding_size is not None else self.max_prompt_length
# whether to store the dataset in state_dict()
# default not store
@@ -135,9 +143,14 @@ def _read_files_and_tokenize(self):
if self.filter_overlong_prompts:
tokenizer = self.tokenizer
prompt_key = self.prompt_key
- self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len(
- tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
- axis=1)]
+ if self.input_template:
+ self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len(
+ tokenizer.encode(self.input_template.format(doc[prompt_key][0]['content']))) <= self.max_prompt_length,
+ axis=1)]
+ else:
+ self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len(
+ tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
+ axis=1)]
print(f'filter dataset len: {len(self.dataframe)}')
@@ -161,7 +174,13 @@ def __getitem__(self, item):
chat = row_dict.pop(self.prompt_key)
- prompt_with_chat_template = self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
+ if self.input_template:
+ question = chat[0]["content"]
+ prompt_with_chat_template = self.input_template.format(question)
+ else:
+ prompt_with_chat_template = self.tokenizer.apply_chat_template(
+ chat, add_generation_prompt=True, tokenize=False
+ )
is_multi_modal = self.image_key in row_dict
if is_multi_modal: # expand image token
@@ -190,7 +209,8 @@ def __getitem__(self, item):
input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template,
tokenizer=self.tokenizer,
- max_length=self.max_prompt_length,
+ # max_length=self.max_prompt_length,
+ max_length=self.padding_size,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=True,
truncation=self.truncation)
@@ -211,6 +231,7 @@ def __getitem__(self, item):
row_dict['attention_mask'] = attention_mask[0]
row_dict['position_ids'] = position_ids[0]
row_dict['raw_prompt_ids'] = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
+ row_dict["prompt_length"] = row_dict['position_ids'].max().item() + 1
# encode prompts without chat template
if self.return_raw_chat:
diff --git a/verl/utils/dataset/utils.py b/verl/utils/dataset/utils.py
new file mode 100644
index 00000000000..f3e87080845
--- /dev/null
+++ b/verl/utils/dataset/utils.py
@@ -0,0 +1,5 @@
+
+INPUT_TEMPLATE = {
+ "base": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: {} Please reason step by step, and put your final answer within \\boxed{{}}. Assistant: \n",
+ "distill": "<|begin▁of▁sentence|><|User|>{} Let's think step by step and output the final answer within \\boxed{{}}.<|Assistant|>\n"
+}
\ No newline at end of file
From 2a64d1de893c0370471c6147484b8cdbd3258d31 Mon Sep 17 00:00:00 2001
From: llkn-2 <765732949@qq.com>
Date: Sun, 16 Mar 2025 17:14:39 +0800
Subject: [PATCH 3/6] [feat] support deepscaler reward manager
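
This patch registers reward_manager=deepscaler in main_ppo. The manager scores
each sample sequentially with the provided compute_score function, treats
func_timeout.FunctionTimedOut as a score of 0, and can append sequence /
ground-truth / reward records to a JSONL log file. The compute_score callable is
expected to have roughly this shape (the body below is an illustrative stand-in,
not the actual scoring rule):

    def compute_score(solution_str, ground_truth, is_val=False):
        correct = ground_truth in solution_str   # stand-in correctness check
        score = 1.0 if correct else 0.0
        return score, {"score": score}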
---
verl/trainer/main_ppo.py | 3 +
verl/workers/reward_manager/__init__.py | 3 +-
verl/workers/reward_manager/deepscaler.py | 173 ++++++++++++++++++++++
3 files changed, 178 insertions(+), 1 deletion(-)
create mode 100644 verl/workers/reward_manager/deepscaler.py
diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py
index 3d5837a6c16..265dec90d35 100644
--- a/verl/trainer/main_ppo.py
+++ b/verl/trainer/main_ppo.py
@@ -142,6 +142,9 @@ def main_task(config):
elif reward_manager_name == 'prime':
from verl.workers.reward_manager import PrimeRewardManager
reward_manager_cls = PrimeRewardManager
+ elif reward_manager_name == 'deepscaler':
+ from verl.workers.reward_manager import DeepScalerRewardManager
+ reward_manager_cls = DeepScalerRewardManager
else:
raise NotImplementedError
diff --git a/verl/workers/reward_manager/__init__.py b/verl/workers/reward_manager/__init__.py
index 3de46a7717e..331b3a75ab4 100644
--- a/verl/workers/reward_manager/__init__.py
+++ b/verl/workers/reward_manager/__init__.py
@@ -13,4 +13,5 @@
# limitations under the License.
from .naive import NaiveRewardManager
-from .prime import PrimeRewardManager
\ No newline at end of file
+from .prime import PrimeRewardManager
+from .deepscaler import DeepScalerRewardManager
\ No newline at end of file
diff --git a/verl/workers/reward_manager/deepscaler.py b/verl/workers/reward_manager/deepscaler.py
new file mode 100644
index 00000000000..e5d8fde1baa
--- /dev/null
+++ b/verl/workers/reward_manager/deepscaler.py
@@ -0,0 +1,173 @@
+from verl import DataProto
+# from verl.utils.reward_score import _default_compute_score
+import torch
+import json
+
+class DeepScalerRewardManager:
+ """Reward manager for DeepScaler-style rule-based scoring, with timeout fallback and optional JSONL logging.
+ """
+
+ def __init__(self, tokenizer, num_examine, compute_score=None, is_val=False, log_file=None) -> None:
+ self.tokenizer = tokenizer
+ self.num_examine = num_examine # the number of batches of decoded responses to print to the console
+ self.is_val = is_val
+ self.log_file = log_file
+ self.compute_score_fn = compute_score
+
+ def __call__(self, data: DataProto):
+ """We will expand this function gradually based on the available datasets"""
+
+ # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
+ if 'rm_scores' in data.batch.keys():
+ return data.batch['rm_scores']
+
+ reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)
+ reward_info = {}
+
+ already_print_data_sources = {}
+
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from typing import Dict, Any
+ #import threading
+ # Thread-safe dict for tracking printed data sources
+ # print_lock = threading.Lock()
+
+ def process_timeout_item(args, log_file):
+ # i, data_item, already_print_data_sources = args
+ i, data_item, already_print_data_sources, is_val = args
+ prompt_ids = data_item.batch['prompts']
+ prompt_length = prompt_ids.shape[-1]
+
+ valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
+ valid_prompt_ids = prompt_ids[-valid_prompt_length:]
+
+ response_ids = data_item.batch['responses']
+ valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
+ valid_response_ids = response_ids[:valid_response_length]
+
+ # decode
+ sequences = torch.cat((valid_prompt_ids, valid_response_ids))
+ sequences_str = self.tokenizer.decode(sequences)
+
+ ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']
+
+ print("A task timed out!")
+ print(sequences_str)
+ print(f'ground_truth:{ground_truth}')
+
+ # Write query-score pairs to JSONL if log_file is provided
+ if log_file:
+ with open(log_file, "a", encoding="utf-8") as f:
+ record = {
+ "sequence": sequences_str,
+ "ground_truth": ground_truth,
+ "timeout": True,
+ }
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
+
+ return i, 0., valid_response_length, {'score': 0.}
+
+
+ def _print(data_item, reward_info, log_file=None):
+ prompt_ids = data_item.batch['prompts']
+ prompt_length = prompt_ids.shape[-1]
+
+ valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
+ valid_prompt_ids = prompt_ids[-valid_prompt_length:]
+
+ response_ids = data_item.batch['responses']
+ valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
+ valid_response_ids = response_ids[:valid_response_length]
+
+ # decode
+ sequences = torch.cat((valid_prompt_ids, valid_response_ids))
+ sequences_str = self.tokenizer.decode(sequences)
+
+ ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']
+ print(sequences_str)
+ print(f'ground_truth:{ground_truth}')
+ print(reward_info)
+
+ # Write query-score pairs to JSONL if log_file is provided
+ if log_file:
+ with open(log_file, "a", encoding="utf-8") as f:
+ record = {
+ "sequence": sequences_str,
+ "ground_truth": ground_truth,
+ "reward": reward_info,
+ }
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
+
+ def process_item(args):
+ # i, data_item, already_print_data_sources = args
+ i, data_item, already_print_data_sources, is_val = args
+ prompt_ids = data_item.batch['prompts']
+ prompt_length = prompt_ids.shape[-1]
+
+ valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
+ valid_prompt_ids = prompt_ids[-valid_prompt_length:]
+
+ response_ids = data_item.batch['responses']
+ valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
+ valid_response_ids = response_ids[:valid_response_length]
+
+ # decode
+ sequences = torch.cat((valid_prompt_ids, valid_response_ids))
+ sequences_str = self.tokenizer.decode(sequences)
+
+ ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']
+
+ # select rm_score
+ data_source = data_item.non_tensor_batch['data_source']
+ # compute_score_fn = _select_rm_score_fn(data_source)
+ # score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth)
+ score, info = self.compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth, is_val=is_val)
+
+ return i, score, valid_response_length, info
+
+ # Process items in parallel using ThreadPoolExecutor
+ # with ThreadPoolExecutor(max_workers=8) as executor:
+ # args = [(i, data[i], already_print_data_sources, self.is_val) for i in range(len(data))]
+ # results = list(executor.map(process_item, args))
+
+ import func_timeout
+ results = []
+ for i in range(len(data)):
+ args = (i, data[i], already_print_data_sources, self.is_val)
+ try:
+ result = process_item(args)
+ except func_timeout.FunctionTimedOut:
+ result = process_timeout_item(args, self.log_file)
+ results.append(result)
+
+ # with ThreadPoolExecutor(max_workers=8) as executor:
+ # args = [(i, data[i], already_print_data_sources, self.is_val) for i in range(len(data))]
+ # futures = [executor.submit(process_item, arg) for arg in args]
+
+ # results = []
+ # for i, future in enumerate(futures):
+ # try:
+ # result = future.result(timeout=60)
+ # results.append(result)
+ # except TimeoutError:
+ # print("A task timed out!")
+ # result = process_timeout_item(args[i], self.log_file)
+ # results.append(result)
+
+ _print(data[0], results[0][-1], log_file=self.log_file)
+ # for i in range(len(data)):
+ # _print(data[i], results[i][-1], log_file=self.log_file)
+
+ # Fill reward tensor with results
+ for i, score, valid_response_length, info in results:
+ reward_tensor[i, valid_response_length - 1] = score
+ for k, v in info.items():
+ if k not in reward_info:
+ reward_info[k] = torch.zeros(len(data))
+ reward_info[k][i] = v
+
+ # if self.is_val:
+ # return reward_tensor
+ # else:
+ # return reward_tensor, reward_info
+ return reward_tensor
\ No newline at end of file
From 1c715547dc5ecfd02bc3bf4748241060663cf752 Mon Sep 17 00:00:00 2001
From: llkn-2 <765732949@qq.com>
Date: Sun, 16 Mar 2025 17:15:23 +0800
Subject: [PATCH 4/6] [feat] support partial rollout
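
This patch adds trainer.partial_rollout.{enable, max_response_length,
train_num_threshold}. RayPPOPartialRolloutTrainer caps each generation call at
partial_rollout.max_response_length tokens, carries sequences that have neither
emitted EOS nor reached data.max_response_length into the next step's gen batch,
and keeps deferring training on a batch until at least train_num_threshold of it
has finished. For GRPO, unfinished indices are expanded to whole rollout groups,
as in the helper below (restated here with a usage check; behavior matches the
version added in ray_partial_rollout_trainer.py):

    def expand_idx_to_group(seq_idxs, group_size):
        group_seq_idxs = set()
        for idx in seq_idxs:
            group_idx = idx // group_size
            group_seq_idxs.update(range(group_idx * group_size,
                                        (group_idx + 1) * group_size))
        return sorted(group_seq_idxs)

    assert expand_idx_to_group([5], group_size=4) == [4, 5, 6, 7]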
---
verl/protocol.py | 7 +
verl/trainer/config/ppo_trainer.yaml | 4 +
verl/trainer/main_ppo.py | 7 +-
.../ppo/ray_partial_rollout_trainer.py | 577 ++++++++++++++++++
.../rollout/vllm_rollout/vllm_rollout_spmd.py | 23 +-
5 files changed, 609 insertions(+), 9 deletions(-)
create mode 100644 verl/trainer/ppo/ray_partial_rollout_trainer.py
diff --git a/verl/protocol.py b/verl/protocol.py
index 94d7410f687..d6b0262b0d5 100644
--- a/verl/protocol.py
+++ b/verl/protocol.py
@@ -199,6 +199,13 @@ def __getitem__(self, item):
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
+ def __setitem__(self, item, value: 'DataProtoItem'):
+ if not isinstance(value, DataProtoItem):
+ raise TypeError(f"Expected value to be a DataProtoItem, got {type(value)}")
+
+ # TODO: update non_tensor_batch
+ self.batch[item] = value.batch
+
def __getstate__(self):
import io
buffer = io.BytesIO()
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 763054a677c..991fea5df16 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -190,3 +190,7 @@ trainer:
remove_previous_ckpt_in_save: False
del_local_ckpt_after_load: False
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
+ partial_rollout:
+ enable: False
+ max_response_length: 2048
+ train_num_threshold: 0.6
diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py
index 265dec90d35..b3cd6919bf4 100644
--- a/verl/trainer/main_ppo.py
+++ b/verl/trainer/main_ppo.py
@@ -15,6 +15,7 @@
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
+from verl.trainer.ppo.ray_partial_rollout_trainer import RayPPOPartialRolloutTrainer
import ray
import hydra
@@ -156,7 +157,11 @@ def main_task(config):
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
- trainer = RayPPOTrainer(config=config,
+ if config.trainer.partial_rollout.enable:
+ ppo_trainer_cls = RayPPOPartialRolloutTrainer
+ else:
+ ppo_trainer_cls = RayPPOTrainer
+ trainer = ppo_trainer_cls(config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
diff --git a/verl/trainer/ppo/ray_partial_rollout_trainer.py b/verl/trainer/ppo/ray_partial_rollout_trainer.py
new file mode 100644
index 00000000000..5892a6701ad
--- /dev/null
+++ b/verl/trainer/ppo/ray_partial_rollout_trainer.py
@@ -0,0 +1,577 @@
+import os
+import uuid
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from enum import Enum
+from pprint import pprint
+from typing import Type, Dict
+from copy import deepcopy
+
+import ray
+import numpy as np
+from codetiming import Timer
+from omegaconf import OmegaConf, open_dict
+from verl import DataProto
+from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto, DataProtoItem
+from verl.single_controller.base import Worker
+from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
+from verl.single_controller.ray.base import create_colocated_worker_cls
+from verl.trainer.ppo import core_algos
+from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
+from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
+from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
+from verl.utils.tracking import ValidationGenerationsLogger
+from torch.utils.data import RandomSampler, SequentialSampler
+from torchdata.stateful_dataloader import StatefulDataLoader
+
+from .ray_trainer import (
+ RayPPOTrainer,
+ Role,
+ WorkerType,
+ ResourcePoolManager,
+ _timer,
+ _compute_response_info,
+ compute_advantage,
+ reduce_metrics,
+ compute_data_metrics,
+ compute_timing_metrics
+)
+
+import torch
+
+def dataprotoitem_to_dataproto(item: DataProtoItem) -> DataProto:
+ """Convert a DataProtoItem to a DataProto object"""
+ return DataProto.from_dict(
+ tensors=item.batch, # TensorDict is already in correct format
+ non_tensors=item.non_tensor_batch, # Dict is already in correct format
+ meta_info=item.meta_info
+ )
+
+
+def expand_idx_to_group(seq_idxs, group_size):
+ group_seq_idxs = set()
+ for idx in seq_idxs:
+ group_idx = idx // group_size
+ for i in range(group_idx*group_size, (group_idx+1)*group_size):
+ group_seq_idxs.add(i)
+ return sorted(list(group_seq_idxs))
+
+
+def compute_generate_data_metrics(gen_batch):
+ response_info = _compute_response_info(gen_batch)
+ prompt_length = response_info['prompt_length']
+ response_length = response_info['response_length']
+ metrics = {
+ # prompt length
+ 'gen_batch/prompt_length/sum':
+ torch.sum(prompt_length).detach().item(),
+ 'gen_batch/prompt_length/max':
+ torch.max(prompt_length).detach().item(),
+ 'gen_batch/prompt_length/mean':
+ torch.mean(prompt_length).detach().item(),
+
+ # response length
+ 'gen_batch/response_length/sum':
+ torch.sum(response_length).detach().item(),
+ 'gen_batch/response_length/max':
+ torch.max(response_length).detach().item(),
+ 'gen_batch/response_length/mean':
+ torch.mean(response_length).detach().item(),
+ }
+ return metrics
+
+class RayPPOPartialRolloutTrainer(RayPPOTrainer):
+ def __init__(self,
+ config,
+ tokenizer,
+ role_worker_mapping: dict[Role, WorkerType],
+ resource_pool_manager: ResourcePoolManager,
+ ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
+ processor=None,
+ reward_fn=None,
+ val_reward_fn=None):
+ super().__init__(
+ config,
+ tokenizer,
+ role_worker_mapping,
+ resource_pool_manager,
+ ray_worker_group_cls,
+ processor,
+ reward_fn,
+ val_reward_fn
+ )
+
+ def _create_dataloader(self):
+ if self.config.trainer.partial_rollout.enable:
+ self.max_prompt_length_in_gen = (
+ self.config.data.max_prompt_length +
+ self.config.data.max_response_length -
+ self.config.trainer.partial_rollout.max_response_length
+ )
+ self.max_response_length_in_gen = self.config.trainer.partial_rollout.max_response_length
+ else:
+ self.max_prompt_length_in_gen = self.config.data.max_prompt_length
+ self.max_response_length_in_gen = self.config.data.max_response_length
+
+ # TODO: we have to make sure the batch size is divisible by the dp size
+ self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files,
+ tokenizer=self.tokenizer,
+ processor=self.processor,
+ prompt_key=self.config.data.prompt_key,
+ image_key=self.config.data.get('image_key', 'images'),
+ max_prompt_length=self.config.data.max_prompt_length,
+ return_raw_chat=self.config.data.get('return_raw_chat', False),
+ truncation=self.config.data.get('truncation', 'error'),
+ filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', True),
+ template_key=self.config.data.get('template_key', None),
+ padding_size=self.max_prompt_length_in_gen)
+ # use sampler for better ckpt resume
+ if self.config.data.shuffle:
+ train_dataloader_generator = torch.Generator()
+ train_dataloader_generator.manual_seed(self.config.data.get('seed', 1))
+ sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
+ else:
+ sampler = SequentialSampler(data_source=self.train_dataset)
+
+ self.train_dataloader = StatefulDataLoader(dataset=self.train_dataset,
+ # iter manually
+ batch_size=1,
+ num_workers=8,
+ drop_last=True,
+ collate_fn=collate_fn,
+ sampler=sampler)
+
+ self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files,
+ tokenizer=self.tokenizer,
+ processor=self.processor,
+ prompt_key=self.config.data.prompt_key,
+ image_key=self.config.data.get('image_key', 'images'),
+ max_prompt_length=self.config.data.max_prompt_length,
+ return_raw_chat=self.config.data.get('return_raw_chat', False),
+ truncation=self.config.data.get('truncation', 'error'),
+ filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', True),
+ template_key=self.config.data.get('template_key', None))
+ self.val_dataloader = StatefulDataLoader(
+ dataset=self.val_dataset,
+ # Validation datasets are sent to inference engines as a whole batch,
+ # which will schedule the memory themselves.
+ batch_size=len(self.val_dataset),
+ num_workers=8,
+ shuffle=False,
+ drop_last=False,
+ collate_fn=collate_fn)
+
+ assert len(self.train_dataloader) >= 1
+ assert len(
+ self.val_dataloader
+ ) == 1, "Validation dataloader must have a single batch, which inference engines will schedule the memory themselves."
+
+ print(f'Size of train dataloader: {len(self.train_dataloader)}')
+
+ # inject total_training_steps to actor/critic optim_config. This is hacky.
+ total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
+
+ if self.config.trainer.total_training_steps is not None:
+ total_training_steps = self.config.trainer.total_training_steps
+
+ self.total_training_steps = total_training_steps
+ print(f'Total training steps: {self.total_training_steps}')
+
+ OmegaConf.set_struct(self.config, True)
+ with open_dict(self.config):
+ self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
+ self.config.critic.optim.total_training_steps = total_training_steps
+
+ def _balance_gen_batch(self, batch: DataProto, metrics, logging_prefix='gen_seqlen'):
+ """Reorder the data on single controller such that each dp rank gets similar total tokens"""
+ attention_mask = batch.batch['attention_mask']
+ batch_size = attention_mask.shape[0]
+ global_seqlen_lst = batch.batch['attention_mask'].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,)
+ world_size = self.actor_rollout_wg.world_size
+ global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst,
+ k_partitions=world_size,
+ equal_size=True)
+ # reorder based on index. The data will be automatically equally partitioned by dispatch function
+ global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
+ batch.reorder(global_idx)
+ global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst,
+ partitions=global_partition_lst,
+ prefix=logging_prefix)
+ metrics.update(global_balance_stats)
+
+ idx_map = {}
+ for i, idx in enumerate(global_idx.tolist()):
+ idx_map[idx] = i
+
+ reorder_idx = []
+ for i in range(len(batch)):
+ reorder_idx.append(idx_map[i])
+ return torch.tensor(reorder_idx)
+
+ def _get_seq_idx_for_partial_rollout(self, batch):
+ ## unfinished: last response token is neither EOS nor PAD
+ unfinish_mask = (
+ (batch.batch['responses'][:, -1] != self.tokenizer.eos_token_id) &
+ (batch.batch['responses'][:, -1] != self.tokenizer.pad_token_id)
+ )
+ ## not exceeded: total response length is still below data.max_response_length
+ response_lengths = batch.batch['attention_mask'].sum(-1) - torch.tensor(batch.non_tensor_batch['prompt_length'].astype(int))
+ unexceed_mask = response_lengths < self.config.data.max_response_length
+ #TODO: add repeat detection
+ # pass
+ mask = unfinish_mask & unexceed_mask
+ return torch.nonzero(mask, as_tuple=True)[0].tolist()
+
+ def _recompute_batch(self, batch, old_log_probs):
+ from torch.nn.utils.rnn import pad_sequence
+ from verl.utils.model import compute_position_id_with_mask
+ from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
+
+ prompt_length = torch.tensor(batch.non_tensor_batch['prompt_length'].astype(int))
+ prompt_start_idx = (batch.batch['input_ids'] != self.tokenizer.pad_token_id).int().argmax(dim=1)
+ prompt_end_idx = prompt_start_idx + prompt_length
+ prompts = [batch.batch['input_ids'][i, prompt_start_idx[i] : prompt_end_idx[i]] for i in range(len(batch))]
+ prompts = torch.stack(
+ [pad_sequence_to_length(prompt, self.config.data.max_prompt_length, self.tokenizer.pad_token_id, left_pad=True) for prompt in prompts]
+ )
+
+ resp_length = batch.batch['attention_mask'].sum(-1) - prompt_length
+ resp_start_idx = prompt_end_idx
+ resp_end_idx = resp_start_idx + resp_length
+ # responses = [batch.batch['input_ids'][i, resp_start_idx[i]:] for i in range(len(batch))]
+ # responses = pad_sequence(responses, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+ responses = [batch.batch['input_ids'][i, resp_start_idx[i] : resp_end_idx[i]] for i in range(len(batch))]
+ responses = pad_sequence(responses, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+
+ old_log_probs = pad_sequence(old_log_probs, batch_first=True, padding_value=0.)
+ assert responses.shape == old_log_probs.shape, f"get responses.shape:{responses.shape}, old_log_probs.shape:{old_log_probs.shape}"
+
+ prompt_attention_mask = (prompts != self.tokenizer.pad_token_id).long()
+ response_attention_mask = get_eos_mask(
+ response_id=responses,
+ eos_token=self.tokenizer.eos_token_id,
+ dtype=torch.int64
+ )
+ attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
+
+ batch.batch['prompts'] = prompts
+ batch.batch['responses'] = responses
+ batch.batch['input_ids'] = torch.cat([prompts, responses], dim=-1)
+ batch.batch['attention_mask'] = attention_mask
+ batch.batch['position_ids'] = compute_position_id_with_mask(batch.batch['attention_mask'])
+ batch.batch['old_log_probs'] = old_log_probs
+ return batch
+
+ def fit(self):
+ """
+ The training loop of PPO.
+ The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
+ The lightweight advantage computation is done on the driver process.
+ """
+ from verl.utils.tracking import Tracking
+ from omegaconf import OmegaConf
+
+ logger = Tracking(project_name=self.config.trainer.project_name,
+ experiment_name=self.config.trainer.experiment_name,
+ default_backend=self.config.trainer.logger,
+ config=OmegaConf.to_container(self.config, resolve=True))
+
+ self.global_steps = 0
+
+ # load checkpoint before doing anything
+ self._load_checkpoint()
+
+ # perform validation before training
+ if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
+ val_metrics = self._validate()
+ pprint(f'Initial validation metrics: {val_metrics}')
+ logger.log(data=val_metrics, step=self.global_steps)
+ if self.config.trainer.get('val_only', False):
+ return
+
+ # we start from step 1
+ self.global_steps += 1
+
+ n_samples = self.config.actor_rollout_ref.rollout.n
+ self.partial_batch = DataProto()
+ self.partial_old_log_probs = []
+ for _ in range(self.config.trainer.total_epochs):
+ data_iter = iter(self.train_dataloader)
+ data_exhausted = False # Flag to indicate if the iterator is exhausted
+
+ while not data_exhausted:
+ metrics = {}
+ timing_raw = {}
+
+ new_batch = []
+ for _ in range(self.config.data.train_batch_size - len(self.partial_batch)//n_samples):
+ try:
+ batch_dict = next(data_iter)
+ del batch_dict['raw_prompt_ids']
+ new_batch.append(DataProto.from_single_dict(batch_dict))
+ except StopIteration:
+ data_exhausted = True
+
+ # If the iterator is exhausted, break the outer while loop as well
+ if data_exhausted:
+ print("Data iterator exhausted, breaking the loop.")
+ break
+
+ if len(new_batch) > 0:
+ new_batch = DataProto.concat(new_batch)
+ new_batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object)
+ new_batch.non_tensor_batch['continue_generate'] = np.array([False for _ in range(len(new_batch.batch))], dtype=object)
+ new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+
+ gen_batch = []
+ # add data from new batch
+ if len(new_batch) > 0:
+ gen_batch.append(new_batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']))
+
+ # add data from partial batch
+ idx_in_partial_batch = []
+ if len(self.partial_batch) > 0:
+ idx_in_partial_batch = (np.where(self.partial_batch.non_tensor_batch['continue_generate']==True)[0]).tolist()
+ partial_gen_batch = dataprotoitem_to_dataproto(self.partial_batch[idx_in_partial_batch])
+ partial_gen_batch = partial_gen_batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
+ for key in partial_gen_batch.batch.keys():
+ partial_gen_batch.batch[key] = partial_gen_batch.batch[key][:, self.max_response_length_in_gen:]
+ gen_batch.append(partial_gen_batch)
+ gen_batch = DataProto.concat(gen_batch)
+ # pad to be divisible by dp_size
+ gen_batch, padding_size = pad_dataproto_to_divisor(gen_batch, self.actor_rollout_wg.world_size)
+
+ metrics['batch/partial_rollout_num'] = len(self.partial_batch)
+ metrics['batch/continue_generate_num'] = len(idx_in_partial_batch)
+ print(f"step: {self.global_steps}, len(new_batch): {len(new_batch)}, len(partial_batch):{len(self.partial_batch)}, ",
+ f"len(continue_gen):{len(idx_in_partial_batch)}, len(gen_batch): {len(gen_batch)}, padding_size:{padding_size}")
+
+ with _timer('step', timing_raw):
+ # generate a batch
+ with _timer('gen', timing_raw):
+ gen_batch.meta_info['n'] = 1
+ gen_batch.meta_info['max_tokens'] = self.max_response_length_in_gen
+ reorder_idx = self._balance_gen_batch(gen_batch, metrics)
+ gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
+ gen_batch_output.reorder(reorder_idx)
+ metrics.update(compute_generate_data_metrics(gen_batch_output))
+ batch = []
+ if len(new_batch) > 0:
+ new_batch_output = gen_batch_output[:len(new_batch)]
+ new_batch = new_batch.union(new_batch_output)
+ batch.append(new_batch)
+
+ if len(self.partial_batch) > 0:
+ partial_batch_output = gen_batch_output[len(new_batch): len(new_batch)+len(idx_in_partial_batch)]
+ self.partial_batch[idx_in_partial_batch] = partial_batch_output
+ self.partial_batch.non_tensor_batch['continue_generate'][:] = False
+ batch.append(self.partial_batch)
+ batch = DataProto.concat(batch)
+
+ # recompute old_log_probs
+ with _timer('old_log_prob', timing_raw):
+ old_log_prob_proto = self.actor_rollout_wg.compute_log_prob(batch)
+
+ if self.config.trainer.partial_rollout.enable:
+ from verl.protocol import union_numpy_dict
+ from verl.utils.py_functional import union_two_dict
+ batch.non_tensor_batch = union_numpy_dict(batch.non_tensor_batch, old_log_prob_proto.non_tensor_batch)
+ batch.meta_info = union_two_dict(batch.meta_info, old_log_prob_proto.meta_info)
+
+ response_lengths = _compute_response_info(batch)['response_length'].int()
+ old_log_probs = old_log_prob_proto.batch['old_log_probs']
+ old_log_probs = [old_log_probs[i, :response_lengths[i]] for i in range(old_log_probs.shape[0])]
+
+ if len(self.partial_batch) > 0:
+ for i in range(len(self.partial_batch)):
+ idx_b = i + len(new_batch)
+ if i in idx_in_partial_batch:
+ old_log_probs[idx_b] = torch.cat(
+ (self.partial_old_log_probs[i], old_log_probs[idx_b])
+ )
+ else:
+ old_log_probs[idx_b] = self.partial_old_log_probs[i]
+ else:
+ batch = batch.union(old_log_prob_proto)
+
+ # get partial rollout
+ if self.config.trainer.partial_rollout.enable:
+ partial_idxs = self._get_seq_idx_for_partial_rollout(batch)
+ if len(partial_idxs) > 0:
+ batch.non_tensor_batch['continue_generate'][partial_idxs] = True
+ if self.config.algorithm.adv_estimator == "grpo":
+ partial_idxs = expand_idx_to_group(partial_idxs, n_samples)
+
+ remain_idxs = [i for i in range(len(batch)) if i not in partial_idxs]
+
+ print(f"step:{self.global_steps}, len(remain_idxs):{len(remain_idxs)}, len(partial_idxs):{len(partial_idxs)}")
+ if len(remain_idxs) < len(batch) * self.config.trainer.partial_rollout.train_num_threshold:
+ partial_idxs = list(range(len(batch)))
+ self.partial_batch = dataprotoitem_to_dataproto(batch[partial_idxs])
+ self.partial_old_log_probs = [old_log_probs[idx] for idx in partial_idxs]
+ continue
+ else:
+ if len(partial_idxs) > 0:
+ self.partial_batch = dataprotoitem_to_dataproto(batch[partial_idxs])
+ self.partial_old_log_probs = [old_log_probs[idx] for idx in partial_idxs]
+ else:
+ self.partial_batch = DataProto()
+ self.partial_old_log_probs = []
+
+ metrics['batch/train_seq_num'] = len(remain_idxs)
+ metrics['batch/train_new_num'] = len([i for i in remain_idxs if i < len(new_batch)])
+
+ batch = dataprotoitem_to_dataproto(batch[remain_idxs])
+ old_log_probs = [old_log_probs[idx] for idx in remain_idxs]
+ # reset prompt and response for training
+ batch = self._recompute_batch(batch, old_log_probs)
+ batch, _ = pad_dataproto_to_divisor(batch, self.actor_rollout_wg.world_size)
+
+ # compute values
+ if self.use_critic:
+ with _timer('values', timing_raw):
+ values = self.critic_wg.compute_values(batch)
+ batch = batch.union(values)
+
+ with _timer('adv', timing_raw):
+ # compute scores using reward model and/or reward function
+ if self.use_rm:
+ reward_tensor = self.rm_wg.compute_rm_score(batch)
+ batch = batch.union(reward_tensor)
+
+ with _timer('reward_fn', timing_raw):
+ # reward_tensor, reward_info = self.reward_fn(batch)
+ reward_tensor = self.reward_fn(batch)
+ batch.batch['token_level_scores'] = reward_tensor
+
+ # Rejection sampling based on rewards
+ # Group rewards by uid
+ uids = batch.non_tensor_batch['uid']
+ unique_uids = np.unique(uids)
+ valid_mask = torch.ones(len(uids), dtype=torch.bool)
+ solve_none = 0
+ solve_all = 0
+ for uid in unique_uids:
+ uid_mask = uids == uid
+ uid_rewards = reward_tensor[uid_mask].sum(-1) # Sum rewards for each sequence
+
+ # Check if all rewards are 0 or all are 1 for this uid
+ if (uid_rewards == 0).all():
+ valid_mask[uid_mask] = False
+ solve_none += 1
+ elif (uid_rewards == 1).all():
+ valid_mask[uid_mask] = False
+ solve_all += 1
+
+ # Log to metrics
+ metrics['batch/solve_none'] = solve_none
+ metrics['batch/solve_all'] = solve_all
+
+ # for key in reward_info:
+ # metrics[f'critic/{key}_reward/mean'] = reward_info[key].mean().item()
+
+ # if self.config.trainer.rejection_sample:
+ # # If no valid samples remain, skip this batch and get a new one
+ # if not valid_mask.any():
+ # continue
+
+ # # Filter batch to keep only valid samples
+ # batch = batch[valid_mask]
+ # batch = dataprotoitem_to_dataproto(batch)
+ # # Round down to the nearest multiple of world size
+ # num_trainer_replicas = self.actor_rollout_wg.world_size
+ # max_batch_size = (batch.batch['input_ids'].shape[0] // num_trainer_replicas) * num_trainer_replicas
+ # if not max_batch_size:
+ # # give up, you got everything either all wrong or right.
+ # continue
+
+ # size_mask = torch.zeros(batch.batch['input_ids'].shape[0], dtype=torch.bool)
+ # size_mask[:max_batch_size] = True
+ # batch = batch[size_mask]
+ # batch = dataprotoitem_to_dataproto(batch)
+
+ # # recompute old_log_probs
+ # with _timer('old_log_prob', timing_raw):
+ # old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+ # batch = batch.union(old_log_prob)
+
+ if self.use_reference_policy:
+ # compute reference log_prob
+ with _timer('ref', timing_raw):
+ ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
+ batch = batch.union(ref_log_prob)
+
+ # compute rewards with KL penalty if needed
+ # Note: This kl penalty applied directly over the rewards is disabled for GRPO. The kl penalty is applied at dp_actor.py
+ # where it is subtracted directly from the policy loss
+ # if not self.config.actor_rollout_ref.actor.use_kl_loss:
+ # batch, kl_metrics = apply_kl_penalty(batch,
+ # kl_ctrl=self.kl_ctrl,
+ # kl_penalty=self.config.algorithm.kl_penalty)
+ # metrics.update(kl_metrics)
+ # else:
+ # batch.batch['token_level_rewards'] = batch.batch['token_level_scores']
+
+ batch.batch['token_level_rewards'] = batch.batch['token_level_scores']
+
+ # compute advantages, executed on the driver process
+ batch = compute_advantage(batch,
+ adv_estimator=self.config.algorithm.adv_estimator,
+ gamma=self.config.algorithm.gamma,
+ lam=self.config.algorithm.lam,
+ num_repeat=self.config.actor_rollout_ref.rollout.n)
+
+ # balance the number of valid tokens on each dp rank.
+ # Note that this breaks the order of data inside the batch.
+ # Please take care when you implement group based adv computation such as GRPO and rloo
+ if self.config.trainer.balance_batch:
+ self._balance_batch(batch, metrics=metrics)
+
+ # compute global_valid tokens
+ batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()
+
+ # update critic
+ if self.use_critic:
+ with _timer('update_critic', timing_raw):
+ critic_output = self.critic_wg.update_critic(batch)
+ critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
+ metrics.update(critic_output_metrics)
+
+ # implement critic warmup
+ if self.config.trainer.critic_warmup <= self.global_steps:
+ # update actor
+ with _timer('update_actor', timing_raw):
+ actor_output = self.actor_rollout_wg.update_actor(batch)
+ actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
+ metrics.update(actor_output_metrics)
+
+ # validate
+ if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
+ self.global_steps % self.config.trainer.test_freq == 0:
+ with _timer('testing', timing_raw):
+ val_metrics: dict = self._validate()
+ pprint(f'validation metrics: {val_metrics}')
+ metrics.update(val_metrics)
+
+ if self.config.trainer.save_freq > 0 and \
+ self.global_steps % self.config.trainer.save_freq == 0:
+ with _timer('save_checkpoint', timing_raw):
+ self._save_checkpoint()
+
+ # collect metrics
+ metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
+ metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
+
+ # TODO: make a canonical logger that supports various backend
+ logger.log(data=metrics, step=self.global_steps)
+
+ self.global_steps += 1
+
+ if self.global_steps >= self.total_training_steps:
+
+ # perform validation after training
+ if self.val_reward_fn is not None:
+ val_metrics = self._validate()
+ pprint(f'Final validation metrics: {val_metrics}')
+ logger.log(data=val_metrics, step=self.global_steps)
+ return
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
index e74864864da..b4777f206df 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -207,6 +207,11 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
# 'n': self.config.val_kwargs.n,
}
+ # support passing any sampling params through meta_info
+ for k in prompts.meta_info.keys():
+ if hasattr(SamplingParams(), str(k)):
+ kwargs[k] = prompts.meta_info[k]
+
# users can customize different sampling_params at different run
with self.update_sampling_params(**kwargs):
outputs = self.inference_engine.generate(
@@ -222,17 +227,19 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
for sample_id in range(len(output.outputs)):
response.append(output.outputs[sample_id].token_ids)
+ pad_response_length = prompts.meta_info.get('max_tokens', self.config.response_length)
response = pad_2d_list_to_length(response, self.pad_token_id,
- max_length=self.config.response_length).to(idx.device)
-
- if self.sampling_params.n > 1 and do_sample:
- idx = _repeat_interleave(idx, self.sampling_params.n)
- attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
- position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
- batch_size = batch_size * self.sampling_params.n
+ max_length=pad_response_length).to(idx.device)
+
+ n = prompts.meta_info.get('n', self.sampling_params.n)
+ if n > 1 and do_sample:
+ idx = _repeat_interleave(idx, n)
+ attention_mask = _repeat_interleave(attention_mask, n)
+ position_ids = _repeat_interleave(position_ids, n)
+ batch_size = batch_size * n
if 'multi_modal_inputs' in non_tensor_batch.keys():
non_tensor_batch['multi_modal_inputs'] = _repeat_interleave(non_tensor_batch['multi_modal_inputs'],
- self.sampling_params.n)
+ n)
seq = torch.cat([idx, response], dim=-1)
From 607704ec7635eb138db5db90172e892b18c9cb46 Mon Sep 17 00:00:00 2001
From: yqyao
Date: Sun, 16 Mar 2025 19:46:07 +0800
Subject: [PATCH 5/6] feat: reduce vLLM peak GPU memory usage
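
Instead of waking the whole vLLM engine and only then loading weights, the
sharding manager now wakes the weights first (level=2), loads them, frees the
FSDP state_dict, and only afterwards wakes the KV cache (level=3), so the
state_dict copy and the KV cache never coexist in GPU memory. Ordering sketch
with a stub engine (the level argument mirrors the patched call sites and is not
claimed to be the upstream vLLM signature):

    class StubEngine:
        def wake_up(self, level):
            print(f"wake_up(level={level})")

    engine, params = StubEngine(), {"w": 0}
    engine.wake_up(level=2)   # 1) materialize weights only
    # model.load_weights(...) would run here
    del params                # 2) drop the state_dict copy, then empty the cache
    engine.wake_up(level=3)   # 3) only now allocate the KV cache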
---
verl/workers/sharding_manager/fsdp_vllm.py | 14 ++++++++++++--
1 file changed, 12 insertions(+), 2 deletions(-)
diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py
index f875d4d949b..66a9c4c2f33 100644
--- a/verl/workers/sharding_manager/fsdp_vllm.py
+++ b/verl/workers/sharding_manager/fsdp_vllm.py
@@ -84,6 +84,10 @@ def __enter__(self):
log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
params = self.module.state_dict()
+ # fsdp offload
+ # state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
+ # with FSDP.state_dict_type(self.module, StateDictType.SHARDED_STATE_DICT, state_dict_cfg):
+ # params = self.module.state_dict()
log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
# Copy, not share memory
load_format = 'hf' if self.full_params else 'dtensor'
@@ -91,16 +95,22 @@ def __enter__(self):
if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
self.inference_engine.sync_model_weights(params, load_format=load_format)
else:
- self.inference_engine.wake_up()
+ # self.inference_engine.wake_up()
+ # to reduce peak GPU memory usage:
+ # level=1 means load weights and kv cache, level=2 means load weights only, level=3 means load kv cache only
+ self.inference_engine.wake_up(level=2)
world_size = torch.distributed.get_world_size()
model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
loaded_params = model.load_weights(
((name, param.full_tensor() if world_size != 1 else param) for name, param in params.items()))
logger.info(f"vLLM load weights, loaded_params: {len(loaded_params)}")
+ del params
+ torch.cuda.empty_cache()
+ self.inference_engine.wake_up(level=3)
log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)
- del params
+ # del params
log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger)
# TODO: offload FSDP model weights
From 1a0f6eb15789b15d7c022c8bb1b734e1b3fe8e61 Mon Sep 17 00:00:00 2001
From: yqyao
Date: Mon, 17 Mar 2025 10:57:07 +0800
Subject: [PATCH 6/6] bugfix: fix token ids exceeding the vocabulary size
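
The rollout now registers a logits processor that masks every logit at an index
at or beyond len(tokenizer), so vLLM can never sample token ids from embedding
rows that were padded past the tokenizer vocabulary. Standalone illustration of
the masking (torch only; the sizes are made up for the example):

    import torch

    vocab_size = 4
    logits = torch.tensor([0.1, 0.2, 0.3, 0.4, 9.9, 9.9])  # 2 padded slots

    def process_token(token_ids, logits):
        logits[vocab_size:] = float("-inf")
        return logits

    masked = process_token([], logits.clone())
    assert torch.isinf(masked[vocab_size:]).all()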
---
verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
index b4777f206df..1653738d34b 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -136,6 +136,7 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf
self.sampling_params = SamplingParams(**kwargs)
self.pad_token_id = tokenizer.pad_token_id
+ self.vocab_size = len(tokenizer)
@contextmanager
def update_sampling_params(self, **kwargs):
@@ -211,6 +212,12 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
for k in prompts.meta_info.keys():
if hasattr(SamplingParams(), str(k)):
kwargs[k] = prompts.meta_info[k]
+
+ # Tokens with IDs exceeding the vocabulary size should be ignored.
+ def process_token(token_ids, logits):
+ logits[self.vocab_size:] = float("-inf")
+ return logits
+ kwargs['logits_processors'] = [process_token]
# users can customize different sampling_params at different run
with self.update_sampling_params(**kwargs):