-
Notifications
You must be signed in to change notification settings - Fork 49
Refactor InferenceModel
#485
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |||||||||||||||||||||||||||||||||||||
| from trinity.common.config import InferenceModelConfig | ||||||||||||||||||||||||||||||||||||||
| from trinity.common.constants import RunningStatus | ||||||||||||||||||||||||||||||||||||||
| from trinity.common.experience import Experience | ||||||||||||||||||||||||||||||||||||||
| from trinity.common.models.utils import get_action_mask_method | ||||||||||||||||||||||||||||||||||||||
| from trinity.utils.log import get_logger | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
@@ -84,6 +85,123 @@ def get_model_path(self) -> Optional[str]: | |||||||||||||||||||||||||||||||||||||
| return self.config.model_path | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| class BaseInferenceModel(InferenceModel): | ||||||||||||||||||||||||||||||||||||||
| """Base class for inference models containing common logic.""" | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def __init__(self, config: InferenceModelConfig) -> None: | ||||||||||||||||||||||||||||||||||||||
| super().__init__(config) | ||||||||||||||||||||||||||||||||||||||
| self.tokenizer = None | ||||||||||||||||||||||||||||||||||||||
| self.chat_template = None | ||||||||||||||||||||||||||||||||||||||
| if self.config.chat_template: | ||||||||||||||||||||||||||||||||||||||
| self.chat_template = self.config.chat_template | ||||||||||||||||||||||||||||||||||||||
| self.action_mask_method = get_action_mask_method(self.chat_template) | ||||||||||||||||||||||||||||||||||||||
| self.enable_thinking = config.enable_thinking | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def apply_chat_template( | ||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||
| tokenizer_or_processor, | ||||||||||||||||||||||||||||||||||||||
| messages: List[dict], | ||||||||||||||||||||||||||||||||||||||
| ) -> str: | ||||||||||||||||||||||||||||||||||||||
| assert tokenizer_or_processor is not None, "tokenizer_or_processor must be provided." | ||||||||||||||||||||||||||||||||||||||
| if self.chat_template is None: | ||||||||||||||||||||||||||||||||||||||
| assert self.tokenizer is not None, "self.tokenizer must be initialized." | ||||||||||||||||||||||||||||||||||||||
| self.chat_template = self.tokenizer.get_chat_template() | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if messages[-1]["role"] == "assistant": | ||||||||||||||||||||||||||||||||||||||
| prompt = tokenizer_or_processor.apply_chat_template( | ||||||||||||||||||||||||||||||||||||||
| messages, | ||||||||||||||||||||||||||||||||||||||
| tokenize=False, | ||||||||||||||||||||||||||||||||||||||
| continue_final_message=True, | ||||||||||||||||||||||||||||||||||||||
| chat_template=self.chat_template, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| prompt = tokenizer_or_processor.apply_chat_template( | ||||||||||||||||||||||||||||||||||||||
| messages, | ||||||||||||||||||||||||||||||||||||||
| tokenize=False, | ||||||||||||||||||||||||||||||||||||||
| add_generation_prompt=True, | ||||||||||||||||||||||||||||||||||||||
| chat_template=self.chat_template, | ||||||||||||||||||||||||||||||||||||||
| enable_thinking=self.enable_thinking, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| return prompt | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _handle_prompt_truncation(self, prompt: str, **kwargs) -> Tuple[Sequence, bool]: | ||||||||||||||||||||||||||||||||||||||
| """Handle prompt truncation if needed.""" | ||||||||||||||||||||||||||||||||||||||
| # Tokenize once without truncation to check if truncation is needed | ||||||||||||||||||||||||||||||||||||||
| token_ids = self.tokenizer( # type: ignore | ||||||||||||||||||||||||||||||||||||||
| prompt, | ||||||||||||||||||||||||||||||||||||||
| truncation=False, | ||||||||||||||||||||||||||||||||||||||
| return_tensors="pt", | ||||||||||||||||||||||||||||||||||||||
| )[ | ||||||||||||||||||||||||||||||||||||||
| "input_ids" | ||||||||||||||||||||||||||||||||||||||
| ][0].tolist() | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Check if truncation is needed and apply it | ||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||
| self.config.enable_prompt_truncation | ||||||||||||||||||||||||||||||||||||||
| and self.config.max_prompt_tokens is not None | ||||||||||||||||||||||||||||||||||||||
| and len(token_ids) > self.config.max_prompt_tokens | ||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||
| self.logger.warning(f"Prompt was truncated to {self.config.max_prompt_tokens} tokens") | ||||||||||||||||||||||||||||||||||||||
| token_ids = token_ids[: self.config.max_prompt_tokens + 1] # leave one for response | ||||||||||||||||||||||||||||||||||||||
| return [ | ||||||||||||||||||||||||||||||||||||||
| Experience( | ||||||||||||||||||||||||||||||||||||||
| tokens=token_ids, | ||||||||||||||||||||||||||||||||||||||
| logprobs=torch.zeros(1, dtype=torch.float32), | ||||||||||||||||||||||||||||||||||||||
| prompt_length=len(token_ids) - 1, | ||||||||||||||||||||||||||||||||||||||
| prompt_text=self.tokenizer.decode(token_ids[:-1]), | ||||||||||||||||||||||||||||||||||||||
| response_text=self.tokenizer.decode(token_ids[-1]), | ||||||||||||||||||||||||||||||||||||||
| truncate_status="prompt_truncated", | ||||||||||||||||||||||||||||||||||||||
| reward=0.0, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| for _ in range(kwargs.get("n", 1)) | ||||||||||||||||||||||||||||||||||||||
| ], False | ||||||||||||||||||||||||||||||||||||||
| return token_ids, True | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| async def convert_messages_to_experience( | ||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||
| messages: List[dict], | ||||||||||||||||||||||||||||||||||||||
| tools: Optional[List[dict]] = None, | ||||||||||||||||||||||||||||||||||||||
| temperature: Optional[float] = None, | ||||||||||||||||||||||||||||||||||||||
| ) -> Experience: | ||||||||||||||||||||||||||||||||||||||
| """Convert a list of messages into an experience in async.""" | ||||||||||||||||||||||||||||||||||||||
| if self.tokenizer is None: | ||||||||||||||||||||||||||||||||||||||
| await self._initialize_tokenizer() | ||||||||||||||||||||||||||||||||||||||
| if self.chat_template is None: | ||||||||||||||||||||||||||||||||||||||
| self.chat_template = self.tokenizer.get_chat_template() | ||||||||||||||||||||||||||||||||||||||
| token_ids, action_mask, prompt_length = self.action_mask_method( | ||||||||||||||||||||||||||||||||||||||
| tokenizer=self.tokenizer, | ||||||||||||||||||||||||||||||||||||||
| messages=messages, | ||||||||||||||||||||||||||||||||||||||
| tools=tools, | ||||||||||||||||||||||||||||||||||||||
| chat_template=self.chat_template, | ||||||||||||||||||||||||||||||||||||||
| enable_thinking=self.enable_thinking, | ||||||||||||||||||||||||||||||||||||||
| ) # (seq_length, ), (seq_length, ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Truncate tokens if they exceed the length limit | ||||||||||||||||||||||||||||||||||||||
| assert token_ids is not None | ||||||||||||||||||||||||||||||||||||||
| truncate_status = None | ||||||||||||||||||||||||||||||||||||||
| if self.config.max_model_len is not None and self.config.max_model_len > 0: | ||||||||||||||||||||||||||||||||||||||
| if len(token_ids) > self.config.max_model_len - 1: | ||||||||||||||||||||||||||||||||||||||
| truncate_status = "response_truncated" | ||||||||||||||||||||||||||||||||||||||
| self.logger.warning( | ||||||||||||||||||||||||||||||||||||||
| f"Warning: {len(token_ids)=} exceeds the length limit {(self.config.max_model_len - 1)=}" | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| token_ids = token_ids[: self.config.max_model_len - 1] | ||||||||||||||||||||||||||||||||||||||
| action_mask = action_mask[: self.config.max_model_len - 1] | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+182
to
+189
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The use of
Suggested change
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| temperature = temperature if temperature is not None else self.config.temperature | ||||||||||||||||||||||||||||||||||||||
| logprobs = await self.logprobs( | ||||||||||||||||||||||||||||||||||||||
| token_ids=token_ids.tolist(), temperature=temperature | ||||||||||||||||||||||||||||||||||||||
| ) # (seq_length - 1,) | ||||||||||||||||||||||||||||||||||||||
| return Experience( | ||||||||||||||||||||||||||||||||||||||
| tokens=token_ids, | ||||||||||||||||||||||||||||||||||||||
| logprobs=logprobs[prompt_length - 1 :], | ||||||||||||||||||||||||||||||||||||||
| prompt_length=prompt_length, | ||||||||||||||||||||||||||||||||||||||
| action_mask=action_mask[prompt_length:], # Exclude the prompt tokens | ||||||||||||||||||||||||||||||||||||||
| messages=messages, | ||||||||||||||||||||||||||||||||||||||
| truncate_status=truncate_status, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _history_recorder(func): | ||||||||||||||||||||||||||||||||||||||
| """Decorator to record history of the model calls.""" | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic for handling prompt truncation and creating a dummy
Experienceobject is a bit dense. Using an explicit variable for the prompt token count would make the code more self-documenting and easier to understand at a glance.