7 changes: 7 additions & 0 deletions verl/protocol.py
@@ -199,6 +199,13 @@ def __getitem__(self, item):
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)

def __setitem__(self, item, value: 'DataProtoItem'):
if not isinstance(value, DataProtoItem):
raise TypeError(f"Expected value to be a DataProtoItem, got {type(value)}")

# TODO: update non_tensor_batch
self.batch[item] = value.batch

def __getstate__(self):
import io
buffer = io.BytesIO()
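A minimal usage sketch of the new DataProto.__setitem__ (illustrative only, not part of the diff; `dp` and the indices are hypothetical). As the TODO above notes, only the tensor batch is written back; the non_tensor_batch entries of the target index are left untouched:

    item = dp[0]      # __getitem__ returns a DataProtoItem view of sample 0
    dp[1] = item      # __setitem__ copies item.batch into dp.batch[1]; non_tensor_batch is not updated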
5 changes: 5 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -13,6 +13,7 @@ data:
filter_overlong_prompts: False # For large-scale datasets, filtering overlong prompts can be time-consuming; you should disable this and set `truncation='left'` instead.
truncation: error
image_key: images
template_key: null

actor_rollout_ref:
hybrid_engine: True
@@ -189,3 +190,7 @@ trainer:
remove_previous_ckpt_in_save: False
del_local_ckpt_after_load: False
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
partial_rollout:
enable: False
max_response_length: 2048
train_num_threshold: 0.6
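A hedged launch sketch showing how the new keys could be set from the command line via Hydra overrides (the entry point and the values are illustrative; only the keys themselves come from this diff):

    python3 -m verl.trainer.main_ppo \
        data.template_key=base \
        trainer.partial_rollout.enable=True \
        trainer.partial_rollout.max_response_length=2048 \
        trainer.partial_rollout.train_num_threshold=0.6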
10 changes: 9 additions & 1 deletion verl/trainer/main_ppo.py
@@ -15,6 +15,7 @@
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.trainer.ppo.ray_partial_rollout_trainer import RayPPOPartialRolloutTrainer

import ray
import hydra
@@ -142,6 +143,9 @@ def main_task(config):
elif reward_manager_name == 'prime':
from verl.workers.reward_manager import PrimeRewardManager
reward_manager_cls = PrimeRewardManager
elif reward_manager_name == 'deepscaler':
from verl.workers.reward_manager import DeepScalerRewardManager
reward_manager_cls = DeepScalerRewardManager
else:
raise NotImplementedError

@@ -153,7 +157,11 @@

resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

trainer = RayPPOTrainer(config=config,
if config.trainer.partial_rollout.enable:
ppo_trainer_cls = RayPPOPartialRolloutTrainer
else:
ppo_trainer_cls = RayPPOTrainer
trainer = ppo_trainer_cls(config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
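The deepscaler reward manager added above is likewise selected through config. A hedged example, assuming reward_manager_name is read from reward_model.reward_manager as in upstream verl (that lookup is not shown in this hunk):

    python3 -m verl.trainer.main_ppo \
        reward_model.reward_manager=deepscaler \
        trainer.partial_rollout.enable=True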
577 changes: 577 additions & 0 deletions verl/trainer/ppo/ray_partial_rollout_trainer.py

Large diffs are not rendered by default.

12 changes: 10 additions & 2 deletions verl/trainer/ppo/ray_trainer.py
@@ -550,7 +550,9 @@ def _create_dataloader(self):
max_prompt_length=self.config.data.max_prompt_length,
filter_prompts=True,
return_raw_chat=self.config.data.get('return_raw_chat', False),
truncation=self.config.data.get('truncation', 'error'))
truncation=self.config.data.get('truncation', 'error'),
filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', True),
template_key=self.config.data.get('template_key', None))
# use sampler for better ckpt resume
if self.config.data.shuffle:
train_dataloader_generator = torch.Generator()
@@ -574,7 +576,9 @@ def _create_dataloader(self):
max_prompt_length=self.config.data.max_prompt_length,
filter_prompts=True,
return_raw_chat=self.config.data.get('return_raw_chat', False),
truncation=self.config.data.get('truncation', 'error'))
truncation=self.config.data.get('truncation', 'error'),
filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', True),
template_key=self.config.data.get('template_key', None))
self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset,
# Validation datasets are sent to inference engines as a whole batch,
@@ -641,6 +645,9 @@ def _validate(self):

for test_data in self.val_dataloader:
test_batch = DataProto.from_single_dict(test_data)
test_batch = test_batch.repeat(
repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
)

# we only do validation on rule-based rm
if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model':
@@ -663,6 +670,7 @@
)

test_gen_batch.meta_info = {
'n': 1,
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id,
'recompute_log_prob': False,
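A small sketch of what the interleaved repeat in _validate does to the validation batch (illustrative; DataProto.repeat with interleave=True is assumed to behave like torch.repeat_interleave along the batch dimension, with n = actor_rollout_ref.rollout.val_kwargs.n). Because each prompt is pre-repeated here, meta_info['n'] is set to 1 so the rollout engine generates a single response per repeated prompt:

    import torch

    prompt_ids = torch.tensor([10, 11, 12])            # three toy validation prompts
    n = 2                                               # val_kwargs.n
    repeated = prompt_ids.repeat_interleave(n, dim=0)
    print(repeated)                                     # tensor([10, 10, 11, 11, 12, 12])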
35 changes: 28 additions & 7 deletions verl/utils/dataset/rl_dataset.py
@@ -26,7 +26,7 @@

from verl.utils.model import compute_position_id_with_mask
import verl.utils.torch_functional as verl_F

from verl.utils.dataset.utils import INPUT_TEMPLATE

def collate_fn(data_list: list[dict]) -> dict:
tensors = defaultdict(list)
@@ -89,7 +89,9 @@ def __init__(self,
chat_template_func=None,
return_raw_chat=False,
truncation='error',
filter_overlong_prompts=False):
filter_overlong_prompts=False,
template_key=None,
padding_size=None):
if not isinstance(parquet_files, (List, ListConfig)):
parquet_files = [parquet_files]

@@ -108,6 +110,12 @@ def __init__(self,
self.chat_template_func = chat_template_func
self.truncation = truncation
self.filter_overlong_prompts = filter_overlong_prompts
if template_key:
assert template_key in INPUT_TEMPLATE
self.input_template = INPUT_TEMPLATE[template_key]
else:
self.input_template = None
self.padding_size = padding_size if padding_size is not None else self.max_prompt_length

# whether to store the dataset in state_dict()
# default not store
@@ -135,9 +143,14 @@ def _read_files_and_tokenize(self):
if self.filter_overlong_prompts:
tokenizer = self.tokenizer
prompt_key = self.prompt_key
self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len(
tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
axis=1)]
if self.input_template:
self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len(
tokenizer.encode(self.input_template.format(doc[prompt_key][0]['content']))) <= self.max_prompt_length,
axis=1)]
else:
self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len(
tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
axis=1)]

print(f'filter dataset len: {len(self.dataframe)}')

@@ -161,7 +174,13 @@ def __getitem__(self, item):

chat = row_dict.pop(self.prompt_key)

prompt_with_chat_template = self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
if self.input_template:
question = chat[0]["content"]
prompt_with_chat_template = self.input_template.format(question)
else:
prompt_with_chat_template = self.tokenizer.apply_chat_template(
chat, add_generation_prompt=True, tokenize=False
)

is_multi_modal = self.image_key in row_dict
if is_multi_modal: # expand image token
@@ -190,7 +209,8 @@ def __getitem__(self, item):

input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template,
tokenizer=self.tokenizer,
max_length=self.max_prompt_length,
# max_length=self.max_prompt_length,
max_length=self.padding_size,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=True,
truncation=self.truncation)
@@ -211,6 +231,7 @@ def __getitem__(self, item):
row_dict['attention_mask'] = attention_mask[0]
row_dict['position_ids'] = position_ids[0]
row_dict['raw_prompt_ids'] = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
row_dict["prompt_length"] = row_dict['position_ids'].max().item() + 1

# encode prompts without chat template
if self.return_raw_chat:
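The new padding_size argument decouples the padded prompt width from max_prompt_length (it falls back to max_prompt_length when unset), and prompt_length recovers the unpadded length from position_ids. A toy sketch of that relationship (values are illustrative, assuming compute_position_id_with_mask assigns 0 to left padding and 0..k-1 to the k real tokens):

    import torch

    position_ids = torch.tensor([0, 0, 0, 0, 1, 2, 3])  # left-padded prompt, padding_size = 7
    prompt_length = position_ids.max().item() + 1        # 4 real prompt tokens
    print(prompt_length)                                  # 4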
5 changes: 5 additions & 0 deletions verl/utils/dataset/utils.py
@@ -0,0 +1,5 @@

INPUT_TEMPLATE = {
"base": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: {} Please reason step by step, and put your final answer within \\boxed{{}}. Assistant: <think>\n",
"distill": "<|begin▁of▁sentence|><|User|>{} Let's think step by step and output the final answer within \\boxed{{}}.<|Assistant|><think>\n"
}
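A quick check of the brace escaping in these templates: .format() fills the single {} with the question, while {{}} survives as the literal braces of \boxed{} (runnable as-is; the shortened template and question are illustrative):

    template = "User: {} Put your final answer within \\boxed{{}}. Assistant: <think>\n"
    print(template.format("What is 2 + 2?"))
    # User: What is 2 + 2? Put your final answer within \boxed{}. Assistant: <think>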
3 changes: 2 additions & 1 deletion verl/workers/reward_manager/__init__.py
@@ -13,4 +13,5 @@
# limitations under the License.

from .naive import NaiveRewardManager
from .prime import PrimeRewardManager
from .prime import PrimeRewardManager
from .deepscaler import DeepScalerRewardManager
173 changes: 173 additions & 0 deletions verl/workers/reward_manager/deepscaler.py
@@ -0,0 +1,173 @@
from verl import DataProto
# from verl.utils.reward_score import _default_compute_score
import torch
import json

class DeepScalerRewardManager:
"""The reward manager.
"""

def __init__(self, tokenizer, num_examine, compute_score=None, is_val=False, log_file=None) -> None:
self.tokenizer = tokenizer
self.num_examine = num_examine # the number of batches of decoded responses to print to the console
self.is_val = is_val
self.log_file = log_file
self.compute_score_fn = compute_score

def __call__(self, data: DataProto):
"""We will expand this function gradually based on the available datasets"""

# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
if 'rm_scores' in data.batch.keys():
return data.batch['rm_scores']

reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)
reward_info = {}

already_print_data_sources = {}

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Any
#import threading
# Thread-safe dict for tracking printed data sources
# print_lock = threading.Lock()

def process_timeout_item(args, log_file):
# i, data_item, already_print_data_sources = args
i, data_item, already_print_data_sources, is_val = args
prompt_ids = data_item.batch['prompts']
prompt_length = prompt_ids.shape[-1]

valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
valid_prompt_ids = prompt_ids[-valid_prompt_length:]

response_ids = data_item.batch['responses']
valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
valid_response_ids = response_ids[:valid_response_length]

# decode
sequences = torch.cat((valid_prompt_ids, valid_response_ids))
sequences_str = self.tokenizer.decode(sequences)

ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']

print("A task timed out!")
print(sequences_str)
print(f'ground_truth:{ground_truth}')

# Write query-score pairs to JSONL if log_file is provided
if log_file:
with open(log_file, "a", encoding="utf-8") as f:
record = {
"sequence": sequences_str,
"ground_truth": ground_truth,
"timeout": True,
}
f.write(json.dumps(record, ensure_ascii=False) + "\n")

return i, 0., valid_response_length, {'score': 0.}


def _print(data_item, reward_info, log_file=None):
prompt_ids = data_item.batch['prompts']
prompt_length = prompt_ids.shape[-1]

valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
valid_prompt_ids = prompt_ids[-valid_prompt_length:]

response_ids = data_item.batch['responses']
valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
valid_response_ids = response_ids[:valid_response_length]

# decode
sequences = torch.cat((valid_prompt_ids, valid_response_ids))
sequences_str = self.tokenizer.decode(sequences)

ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']
print(sequences_str)
print(f'ground_truth:{ground_truth}')
print(reward_info)

# Write query-score pairs to JSONL if log_file is provided
if log_file:
with open(log_file, "a", encoding="utf-8") as f:
record = {
"sequence": sequences_str,
"ground_truth": ground_truth,
"reward": reward_info,
}
f.write(json.dumps(record, ensure_ascii=False) + "\n")

def process_item(args):
# i, data_item, already_print_data_sources = args
i, data_item, already_print_data_sources, is_val = args
prompt_ids = data_item.batch['prompts']
prompt_length = prompt_ids.shape[-1]

valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
valid_prompt_ids = prompt_ids[-valid_prompt_length:]

response_ids = data_item.batch['responses']
valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
valid_response_ids = response_ids[:valid_response_length]

# decode
sequences = torch.cat((valid_prompt_ids, valid_response_ids))
sequences_str = self.tokenizer.decode(sequences)

ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']

# select rm_score
data_source = data_item.non_tensor_batch['data_source']
# compute_score_fn = _select_rm_score_fn(data_source)
# score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth)
score, info = self.compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth, is_val=is_val)

return i, score, valid_response_length, info

# Process items in parallel using ThreadPoolExecutor
# with ThreadPoolExecutor(max_workers=8) as executor:
# args = [(i, data[i], already_print_data_sources, self.is_val) for i in range(len(data))]
# results = list(executor.map(process_item, args))

import func_timeout
results = []
for i in range(len(data)):
args = (i, data[i], already_print_data_sources, self.is_val)
try:
result = process_item(args)
except func_timeout.FunctionTimedOut:
result = process_timeout_item(args, self.log_file)
results.append(result)

# with ThreadPoolExecutor(max_workers=8) as executor:
# args = [(i, data[i], already_print_data_sources, self.is_val) for i in range(len(data))]
# futures = [executor.submit(process_item, arg) for arg in args]

# results = []
# for i, future in enumerate(futures):
# try:
# result = future.result(timeout=60)
# results.append(result)
# except TimeoutError:
# print("A task timed out!")
# result = process_timeout_item(args[i], self.log_file)
# results.append(result)

_print(data[0], results[0][-1], log_file=self.log_file)
# for i in range(len(data)):
# _print(data[i], results[i][-1], log_file=self.log_file)

# Fill reward tensor with results
for i, score, valid_response_length, info in results:
reward_tensor[i, valid_response_length - 1] = score
for k, v in info.items():
if k not in reward_info:
reward_info[k] = torch.zeros(len(data))
reward_info[k][i] = v

# if self.is_val:
# return reward_tensor
# else:
# return reward_tensor, reward_info
return reward_tensor
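DeepScalerRewardManager assumes a compute_score callable with the keyword signature used in process_item, returning a (score, info) pair, and it relies on that callable (not the manager itself) to raise func_timeout.FunctionTimedOut when grading stalls. A hedged sketch of a compatible scorer (the name, the containment check, and the 10-second budget are assumptions, not part of this diff):

    import func_timeout

    @func_timeout.func_set_timeout(10)   # raises FunctionTimedOut if scoring exceeds 10 s
    def toy_compute_score(solution_str, ground_truth, is_val=False):
        # trivial rule-based check: does the ground truth appear in the rollout text?
        score = 1.0 if str(ground_truth) in solution_str else 0.0
        return score, {'score': score}

    # e.g. DeepScalerRewardManager(tokenizer, num_examine=1, compute_score=toy_compute_score)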