From 707965d5a0cbb0cecc13aaca1d634ebe09d3ee28 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 20 Jan 2026 20:17:08 +0800 Subject: [PATCH] Refactor `InferenceModel` --- trinity/common/models/model.py | 118 ++++++++++++++++++++++++ trinity/common/models/tinker_model.py | 104 ++------------------- trinity/common/models/vllm_model.py | 124 ++------------------------ 3 files changed, 132 insertions(+), 214 deletions(-) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index e0534518e0..5622511ca5 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -16,6 +16,7 @@ from trinity.common.config import InferenceModelConfig from trinity.common.constants import RunningStatus from trinity.common.experience import Experience +from trinity.common.models.utils import get_action_mask_method from trinity.utils.log import get_logger @@ -84,6 +85,123 @@ def get_model_path(self) -> Optional[str]: return self.config.model_path +class BaseInferenceModel(InferenceModel): + """Base class for inference models containing common logic.""" + + def __init__(self, config: InferenceModelConfig) -> None: + super().__init__(config) + self.tokenizer = None + self.chat_template = None + if self.config.chat_template: + self.chat_template = self.config.chat_template + self.action_mask_method = get_action_mask_method(self.chat_template) + self.enable_thinking = config.enable_thinking + + def apply_chat_template( + self, + tokenizer_or_processor, + messages: List[dict], + ) -> str: + assert tokenizer_or_processor is not None, "tokenizer_or_processor must be provided." + if self.chat_template is None: + assert self.tokenizer is not None, "self.tokenizer must be initialized." + self.chat_template = self.tokenizer.get_chat_template() + + if messages[-1]["role"] == "assistant": + prompt = tokenizer_or_processor.apply_chat_template( + messages, + tokenize=False, + continue_final_message=True, + chat_template=self.chat_template, + ) + else: + prompt = tokenizer_or_processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + chat_template=self.chat_template, + enable_thinking=self.enable_thinking, + ) + return prompt + + def _handle_prompt_truncation(self, prompt: str, **kwargs) -> Tuple[Sequence, bool]: + """Handle prompt truncation if needed.""" + # Tokenize once without truncation to check if truncation is needed + token_ids = self.tokenizer( # type: ignore + prompt, + truncation=False, + return_tensors="pt", + )[ + "input_ids" + ][0].tolist() + + # Check if truncation is needed and apply it + if ( + self.config.enable_prompt_truncation + and self.config.max_prompt_tokens is not None + and len(token_ids) > self.config.max_prompt_tokens + ): + self.logger.warning(f"Prompt was truncated to {self.config.max_prompt_tokens} tokens") + token_ids = token_ids[: self.config.max_prompt_tokens + 1] # leave one for response + return [ + Experience( + tokens=token_ids, + logprobs=torch.zeros(1, dtype=torch.float32), + prompt_length=len(token_ids) - 1, + prompt_text=self.tokenizer.decode(token_ids[:-1]), + response_text=self.tokenizer.decode(token_ids[-1]), + truncate_status="prompt_truncated", + reward=0.0, + ) + for _ in range(kwargs.get("n", 1)) + ], False + return token_ids, True + + async def convert_messages_to_experience( + self, + messages: List[dict], + tools: Optional[List[dict]] = None, + temperature: Optional[float] = None, + ) -> Experience: + """Convert a list of messages into an experience in async.""" + if 
self.tokenizer is None: + await self._initialize_tokenizer() + if self.chat_template is None: + self.chat_template = self.tokenizer.get_chat_template() + token_ids, action_mask, prompt_length = self.action_mask_method( + tokenizer=self.tokenizer, + messages=messages, + tools=tools, + chat_template=self.chat_template, + enable_thinking=self.enable_thinking, + ) # (seq_length, ), (seq_length, ) + + # Truncate tokens if they exceed the length limit + assert token_ids is not None + truncate_status = None + if self.config.max_model_len is not None and self.config.max_model_len > 0: + if len(token_ids) > self.config.max_model_len - 1: + truncate_status = "response_truncated" + self.logger.warning( + f"Warning: {len(token_ids)=} exceeds the length limit {(self.config.max_model_len - 1)=}" + ) + token_ids = token_ids[: self.config.max_model_len - 1] + action_mask = action_mask[: self.config.max_model_len - 1] + + temperature = temperature if temperature is not None else self.config.temperature + logprobs = await self.logprobs( + token_ids=token_ids.tolist(), temperature=temperature + ) # (seq_length - 1,) + return Experience( + tokens=token_ids, + logprobs=logprobs[prompt_length - 1 :], + prompt_length=prompt_length, + action_mask=action_mask[prompt_length:], # Exclude the prompt tokens + messages=messages, + truncate_status=truncate_status, + ) + + def _history_recorder(func): """Decorator to record history of the model calls.""" diff --git a/trinity/common/models/tinker_model.py b/trinity/common/models/tinker_model.py index 631795bf74..0c93c77daf 100644 --- a/trinity/common/models/tinker_model.py +++ b/trinity/common/models/tinker_model.py @@ -10,12 +10,11 @@ from trinity.common.config import InferenceModelConfig from trinity.common.experience import Experience -from trinity.common.models.model import InferenceModel -from trinity.common.models.utils import get_action_mask_method +from trinity.common.models.model import BaseInferenceModel from trinity.manager.synchronizer import Synchronizer -class TinkerModel(InferenceModel): +class TinkerModel(BaseInferenceModel): def __init__( self, config: InferenceModelConfig, @@ -25,12 +24,6 @@ def __init__( self.synchronizer = Synchronizer.get_actor(namespace=ray.get_runtime_context().namespace) self.model = None self.model_path = config.model_path - self.tokenizer = None - self.chat_template = None - if self.config.chat_template: - self.chat_template = self.config.chat_template - self.action_mask_method = get_action_mask_method(self.chat_template) - self.enable_thinking = config.enable_thinking async def _initialize_tokenizer(self) -> None: """Initialize the tokenizer.""" @@ -62,34 +55,9 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: if self.tokenizer is None: await self._initialize_tokenizer() - # Tokenize once without truncation to check if truncation is needed - token_ids = self.tokenizer( # type: ignore - prompt, - truncation=False, - return_tensors="pt", - )[ - "input_ids" - ][0].tolist() - - # Check if truncation is needed and apply it - if self.config.enable_prompt_truncation and self.config.max_prompt_tokens is not None: - if len(token_ids) > self.config.max_prompt_tokens: - self.logger.warning( - f"Prompt was truncated to {self.config.max_prompt_tokens} tokens" - ) - token_ids = token_ids[: self.config.max_prompt_tokens + 1] # leave one for response - return [ - Experience( - tokens=token_ids, - logprobs=torch.zeros(1, dtype=torch.float32), - prompt_length=len(token_ids) - 1, - 
prompt_text=self.tokenizer.decode(token_ids[:-1]), - response_text=self.tokenizer.decode(token_ids[-1]), - truncate_status="prompt_truncated", - reward=0.0, - ) - for _ in range(kwargs.get("n", 1)) - ] + token_ids, is_valid = self._handle_prompt_truncation(prompt, **kwargs) + if not is_valid: + return token_ids with_chat_completion = kwargs.get("with_chat_completion", False) if with_chat_completion: @@ -157,8 +125,6 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]: """Generate experiences from a list of history chat messages in async.""" if self.tokenizer is None: await self._initialize_tokenizer() - if self.chat_template is None: - self.chat_template = self.tokenizer.get_chat_template() # TODO: this is a hack to support openai chat messages, which only supports text for msg in messages: @@ -169,21 +135,7 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]: content_str = msg["content"] msg["content"] = content_str - if messages[-1]["role"] == "assistant": - prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - continue_final_message=True, - chat_template=self.chat_template, - ) - else: - prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - chat_template=self.chat_template, - enable_thinking=self.enable_thinking, - ) + prompt = self.apply_chat_template(self.tokenizer, messages) return await self.generate(prompt=prompt, **kwargs) async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor: @@ -191,50 +143,6 @@ async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor: logprobs = await self.model.compute_logprobs_async(types.ModelInput.from_ints(token_ids)) return torch.tensor(logprobs[1:], dtype=torch.float32) - async def convert_messages_to_experience( - self, - messages: List[dict], - tools: Optional[List[dict]] = None, - temperature: Optional[float] = None, - ) -> Experience: - """Convert a list of messages into an experience in async.""" - if self.tokenizer is None: - await self._initialize_tokenizer() - if self.chat_template is None: - self.chat_template = self.tokenizer.get_chat_template() - token_ids, action_mask, prompt_length = self.action_mask_method( - tokenizer=self.tokenizer, - messages=messages, - tools=tools, - chat_template=self.chat_template, - enable_thinking=self.enable_thinking, - ) # (seq_length, ), (seq_length, ) - - # Truncate tokens if they exceed the length limit - assert token_ids is not None - truncate_status = None - if self.config.max_model_len is not None and self.config.max_model_len > 0: - if len(token_ids) > self.config.max_model_len - 1: - truncate_status = "response_truncated" - self.logger.warning( - f"Warning: {len(token_ids)=} exceeds the length limit {(self.config.max_model_len - 1)=}" - ) - token_ids = token_ids[: self.config.max_model_len - 1] - action_mask = action_mask[: self.config.max_model_len - 1] - - temperature = temperature if temperature is not None else self.config.temperature - logprobs = await self.logprobs( - token_ids=token_ids.tolist(), temperature=temperature - ) # (seq_length - 1,) - return Experience( - tokens=token_ids, - logprobs=logprobs[prompt_length - 1 :], - prompt_length=prompt_length, - action_mask=action_mask[prompt_length:], # Exclude the prompt tokens - messages=messages, - truncate_status=truncate_status, - ) - async def prepare(self) -> None: """Prepare the model before inference.""" self.service_client = tinker.ServiceClient() diff --git 
a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 6c01cddc71..59369f5e18 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -17,13 +17,12 @@ build_multi_modal_inputs, convert_messages_to_mm_format, ) -from trinity.common.models.model import InferenceModel -from trinity.common.models.utils import get_action_mask_method +from trinity.common.models.model import BaseInferenceModel from trinity.common.models.vllm_patch import get_vllm_version # V0 engine is deprecated since vLLM v0.10.2, related code will be removed in the future. -class vLLMRolloutModel(InferenceModel): +class vLLMRolloutModel(BaseInferenceModel): """Wrapper around the vLLM engine to handle async requests. Args: @@ -79,7 +78,6 @@ def __init__( top_k=config.top_k, ignore_eos=config.ignore_eos, ) - self.enable_thinking = config.enable_thinking self.ray_namespace = config.ray_namespace self.request_id = 0 max_model_len = config.max_model_len @@ -137,11 +135,6 @@ def __init__( self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args) self.processor = None - self.tokenizer = None - self.chat_template = None - if self.config.chat_template: - self.chat_template = self.config.chat_template - self.action_mask_method = get_action_mask_method(self.chat_template) self.state_dict_meta = None self.model_version = 0 # TODO: resume the value from the checkpoint self.api_server_host = None @@ -184,23 +177,8 @@ async def chat(self, messages: List[Dict], lora_request=None, **kwargs) -> Seque """ if self.tokenizer is None: await self._initialize_tokenizer() - if self.chat_template is None: - self.chat_template = self.tokenizer.get_chat_template() - if messages[-1]["role"] == "assistant": - prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - continue_final_message=True, - chat_template=self.chat_template, - ) - else: - prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - chat_template=self.chat_template, - enable_thinking=self.enable_thinking, - ) + + prompt = self.apply_chat_template(self.tokenizer, messages) return await self.generate(prompt=prompt, lora_request=lora_request, **kwargs) async def generate(self, prompt: str, lora_request=None, **kwargs) -> Sequence[Experience]: @@ -216,34 +194,9 @@ async def generate(self, prompt: str, lora_request=None, **kwargs) -> Sequence[E if self.tokenizer is None: await self._initialize_tokenizer() - # Tokenize once without truncation to check if truncation is needed - token_ids = self.tokenizer( # type: ignore - prompt, - truncation=False, - return_tensors="pt", - )[ - "input_ids" - ][0].tolist() - - # Check if truncation is needed and apply it - if self.config.enable_prompt_truncation and self.config.max_prompt_tokens is not None: - if len(token_ids) > self.config.max_prompt_tokens: - self.logger.warning( - f"Prompt was truncated to {self.config.max_prompt_tokens} tokens" - ) - token_ids = token_ids[: self.config.max_prompt_tokens + 1] # leave one for response - return [ - Experience( - tokens=token_ids, - logprobs=torch.zeros(1, dtype=torch.float32), - prompt_length=len(token_ids) - 1, - prompt_text=self.tokenizer.decode(token_ids[:-1]), - response_text=self.tokenizer.decode(token_ids[-1]), - truncate_status="prompt_truncated", - reward=0.0, - ) - for i in range(kwargs.get("n", 1)) - ] + token_ids, is_valid = self._handle_prompt_truncation(prompt, **kwargs) + if not is_valid: + return token_ids output = await self._generate_internal( 
prompt={"prompt_token_ids": token_ids}, lora_request=lora_request, **kwargs @@ -290,25 +243,8 @@ async def chat_mm( """ if self.processor is None: self._initialize_processor() - if self.chat_template is None: - self.chat_template = self.tokenizer.get_chat_template() messages = convert_messages_to_mm_format(messages) - if messages[-1]["role"] == "assistant": - prompt = self.processor.apply_chat_template( - messages, - tokenize=False, - continue_final_message=True, - chat_template=self.chat_template, - ) - else: - prompt = self.processor.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - chat_template=self.chat_template, - enable_thinking=self.enable_thinking, - ) - + prompt = self.apply_chat_template(self.processor, messages) return await self.generate_mm(prompt=prompt, images=images, videos=videos, **kwargs) async def generate_mm( @@ -537,50 +473,6 @@ async def _generate_internal(self, prompt: Any, lora_request=None, **kwargs) -> raise RuntimeError("[vLLM] The request is not finished. This should not happen.") - async def convert_messages_to_experience( - self, - messages: List[dict], - tools: Optional[List[dict]] = None, - temperature: Optional[float] = None, - ) -> Experience: - """Convert a list of messages into an experience.""" - if self.tokenizer is None: - await self._initialize_tokenizer() - if self.chat_template is None: - self.chat_template = self.tokenizer.get_chat_template() - token_ids, action_mask, prompt_length = self.action_mask_method( - tokenizer=self.tokenizer, - messages=messages, - tools=tools, - chat_template=self.chat_template, - enable_thinking=self.enable_thinking, - ) # (seq_length, ), (seq_length, ) - - # Truncate tokens if they exceed the length limit - assert token_ids is not None - truncate_status = None - if self.config.max_model_len is not None and self.config.max_model_len > 0: - if len(token_ids) > self.config.max_model_len - 1: - truncate_status = "response_truncated" - self.logger.warning( - f"Warning: {len(token_ids)=} exceeds the length limit {(self.config.max_model_len - 1)=}" - ) - token_ids = token_ids[: self.config.max_model_len - 1] - action_mask = action_mask[: self.config.max_model_len - 1] - - temperature = temperature if temperature is not None else self.config.temperature - logprobs = await self.logprobs( - token_ids=token_ids.tolist(), temperature=temperature - ) # (seq_length - 1,) - return Experience( - tokens=token_ids, - logprobs=logprobs[prompt_length - 1 :], - prompt_length=prompt_length, - action_mask=action_mask[prompt_length:], # Exclude the prompt tokens - messages=messages, - truncate_status=truncate_status, - ) - async def shutdown(self): """Shutdown the vLLM v1 engine. This kills child processes forked by the vLLM engine. If not called, the child processes will be