118 changes: 118 additions & 0 deletions trinity/common/models/model.py
@@ -16,6 +16,7 @@
from trinity.common.config import InferenceModelConfig
from trinity.common.constants import RunningStatus
from trinity.common.experience import Experience
from trinity.common.models.utils import get_action_mask_method
from trinity.utils.log import get_logger


@@ -84,6 +85,123 @@ def get_model_path(self) -> Optional[str]:
return self.config.model_path


class BaseInferenceModel(InferenceModel):
"""Base class for inference models containing common logic."""

def __init__(self, config: InferenceModelConfig) -> None:
super().__init__(config)
self.tokenizer = None
self.chat_template = None
if self.config.chat_template:
self.chat_template = self.config.chat_template
self.action_mask_method = get_action_mask_method(self.chat_template)
self.enable_thinking = config.enable_thinking

def apply_chat_template(
self,
tokenizer_or_processor,
messages: List[dict],
) -> str:
assert tokenizer_or_processor is not None, "tokenizer_or_processor must be provided."
if self.chat_template is None:
assert self.tokenizer is not None, "self.tokenizer must be initialized."
self.chat_template = self.tokenizer.get_chat_template()

if messages[-1]["role"] == "assistant":
prompt = tokenizer_or_processor.apply_chat_template(
messages,
tokenize=False,
continue_final_message=True,
chat_template=self.chat_template,
)
else:
prompt = tokenizer_or_processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
chat_template=self.chat_template,
enable_thinking=self.enable_thinking,
)
return prompt
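
To make the two template branches above concrete, here is a minimal, hypothetical illustration (not part of this PR) of what they correspond to on a HuggingFace tokenizer that ships a chat template; the model name is only an example:

    # Hypothetical sketch: the two apply_chat_template paths used above.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example model

    # Trailing assistant message -> continue that partial reply in place.
    partial = [
        {"role": "user", "content": "Write a haiku."},
        {"role": "assistant", "content": "Subwords fall like rain,"},
    ]
    prompt = tok.apply_chat_template(partial, tokenize=False, continue_final_message=True)

    # Any other final role -> append a fresh generation prompt for a new assistant turn.
    fresh = [{"role": "user", "content": "Write a haiku."}]
    prompt = tok.apply_chat_template(fresh, tokenize=False, add_generation_prompt=True)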

def _handle_prompt_truncation(self, prompt: str, **kwargs) -> Tuple[Sequence, bool]:
"""Handle prompt truncation if needed."""
# Tokenize once without truncation to check if truncation is needed
token_ids = self.tokenizer( # type: ignore
prompt,
truncation=False,
return_tensors="pt",
)[
"input_ids"
][0].tolist()

# Check if truncation is needed and apply it
if (
self.config.enable_prompt_truncation
and self.config.max_prompt_tokens is not None
and len(token_ids) > self.config.max_prompt_tokens
):
self.logger.warning(f"Prompt was truncated to {self.config.max_prompt_tokens} tokens")
token_ids = token_ids[: self.config.max_prompt_tokens + 1] # leave one for response
return [
Experience(
tokens=token_ids,
logprobs=torch.zeros(1, dtype=torch.float32),
prompt_length=len(token_ids) - 1,
prompt_text=self.tokenizer.decode(token_ids[:-1]),
response_text=self.tokenizer.decode(token_ids[-1]),
truncate_status="prompt_truncated",
reward=0.0,
)
for _ in range(kwargs.get("n", 1))
], False
Comment on lines +145 to +157
Contributor (severity: medium)

The logic for handling prompt truncation and creating a dummy Experience object is a bit dense. Using an explicit variable for the prompt token count would make the code more self-documenting and easier to understand at a glance.

Suggested change (replacing the truncation block above):

    prompt_token_count = self.config.max_prompt_tokens
    token_ids = token_ids[:prompt_token_count + 1]  # leave one for response
    return [
        Experience(
            tokens=token_ids,
            logprobs=torch.zeros(1, dtype=torch.float32),
            prompt_length=prompt_token_count,
            prompt_text=self.tokenizer.decode(token_ids[:-1]),
            response_text=self.tokenizer.decode(token_ids[-1]),
            truncate_status="prompt_truncated",
            reward=0.0,
        )
        for _ in range(kwargs.get("n", 1))
    ], False

return token_ids, True

async def convert_messages_to_experience(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience in async."""
if self.tokenizer is None:
await self._initialize_tokenizer()
if self.chat_template is None:
self.chat_template = self.tokenizer.get_chat_template()
token_ids, action_mask, prompt_length = self.action_mask_method(
tokenizer=self.tokenizer,
messages=messages,
tools=tools,
chat_template=self.chat_template,
enable_thinking=self.enable_thinking,
) # (seq_length, ), (seq_length, )

# Truncate tokens if they exceed the length limit
assert token_ids is not None
truncate_status = None
if self.config.max_model_len is not None and self.config.max_model_len > 0:
if len(token_ids) > self.config.max_model_len - 1:
truncate_status = "response_truncated"
self.logger.warning(
f"Warning: {len(token_ids)=} exceeds the length limit {(self.config.max_model_len - 1)=}"
)
token_ids = token_ids[: self.config.max_model_len - 1]
action_mask = action_mask[: self.config.max_model_len - 1]
Comment on lines +182 to +189
Contributor (severity: medium)

The use of max_model_len - 1 is a bit of a magic number. Introducing a local variable for the maximum length and adding a comment to explain the -1 would improve code clarity and maintainability, making it easier for future developers to understand the rationale behind this truncation logic.

Suggested change (replacing the truncation block above):

    if self.config.max_model_len is not None and self.config.max_model_len > 0:
        # The -1 is to leave space for at least one token to be generated for logprobs calculation.
        max_len = self.config.max_model_len - 1
        if len(token_ids) > max_len:
            truncate_status = "response_truncated"
            self.logger.warning(
                f"Warning: {len(token_ids)=} exceeds the length limit {max_len=}"
            )
            token_ids = token_ids[:max_len]
            action_mask = action_mask[:max_len]


temperature = temperature if temperature is not None else self.config.temperature
logprobs = await self.logprobs(
token_ids=token_ids.tolist(), temperature=temperature
) # (seq_length - 1,)
return Experience(
tokens=token_ids,
logprobs=logprobs[prompt_length - 1 :],
prompt_length=prompt_length,
action_mask=action_mask[prompt_length:], # Exclude the prompt tokens
messages=messages,
truncate_status=truncate_status,
)
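
A quick note on the slicing above (editorial, not text from the PR): the backend's logprobs tensor has length seq_length - 1 (no logprob for the first token), so logprobs[prompt_length - 1:] and action_mask[prompt_length:] both end up with one entry per response token. A minimal sanity check under that assumption:

    # Shape check for the slicing in convert_messages_to_experience (illustrative only).
    import torch

    N, prompt_length = 10, 4                      # 10 tokens total, 4 of them prompt
    logprobs = torch.zeros(N - 1)                 # one logprob per token after the first
    action_mask = torch.ones(N, dtype=torch.bool)

    response_logprobs = logprobs[prompt_length - 1:]
    response_mask = action_mask[prompt_length:]
    assert len(response_logprobs) == len(response_mask) == N - prompt_length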


def _history_recorder(func):
"""Decorator to record history of the model calls."""

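With the common logic now living in BaseInferenceModel, a backend only has to supply the tokenizer, sampling, and logprob pieces. A rough, hypothetical sketch of the resulting contract (DummyModel is not part of this PR; it just mirrors how TinkerModel uses the base class below):

    # Hypothetical subclass sketch -- illustrates the contract implied by this PR.
    from typing import List, Sequence

    import torch
    from transformers import AutoTokenizer

    from trinity.common.experience import Experience
    from trinity.common.models.model import BaseInferenceModel


    class DummyModel(BaseInferenceModel):
        async def _initialize_tokenizer(self) -> None:
            # Lazily load the tokenizer, as TinkerModel does.
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)

        async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
            if self.tokenizer is None:
                await self._initialize_tokenizer()
            # Shared prompt-truncation handling from the base class.
            result, is_valid = self._handle_prompt_truncation(prompt, **kwargs)
            if not is_valid:
                return result  # dummy "prompt_truncated" experiences
            ...  # backend-specific sampling goes here

        async def logprobs(self, token_ids: List[int], **kwargs) -> torch.Tensor:
            ...  # backend-specific logprob computation, shape (seq_length - 1,)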
104 changes: 6 additions & 98 deletions trinity/common/models/tinker_model.py
@@ -10,12 +10,11 @@

from trinity.common.config import InferenceModelConfig
from trinity.common.experience import Experience
from trinity.common.models.model import InferenceModel
from trinity.common.models.utils import get_action_mask_method
from trinity.common.models.model import BaseInferenceModel
from trinity.manager.synchronizer import Synchronizer


class TinkerModel(InferenceModel):
class TinkerModel(BaseInferenceModel):
def __init__(
self,
config: InferenceModelConfig,
@@ -25,12 +24,6 @@ def __init__(
self.synchronizer = Synchronizer.get_actor(namespace=ray.get_runtime_context().namespace)
self.model = None
self.model_path = config.model_path
self.tokenizer = None
self.chat_template = None
if self.config.chat_template:
self.chat_template = self.config.chat_template
self.action_mask_method = get_action_mask_method(self.chat_template)
self.enable_thinking = config.enable_thinking

async def _initialize_tokenizer(self) -> None:
"""Initialize the tokenizer."""
@@ -62,34 +55,9 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
if self.tokenizer is None:
await self._initialize_tokenizer()

# Tokenize once without truncation to check if truncation is needed
token_ids = self.tokenizer( # type: ignore
prompt,
truncation=False,
return_tensors="pt",
)[
"input_ids"
][0].tolist()

# Check if truncation is needed and apply it
if self.config.enable_prompt_truncation and self.config.max_prompt_tokens is not None:
if len(token_ids) > self.config.max_prompt_tokens:
self.logger.warning(
f"Prompt was truncated to {self.config.max_prompt_tokens} tokens"
)
token_ids = token_ids[: self.config.max_prompt_tokens + 1] # leave one for response
return [
Experience(
tokens=token_ids,
logprobs=torch.zeros(1, dtype=torch.float32),
prompt_length=len(token_ids) - 1,
prompt_text=self.tokenizer.decode(token_ids[:-1]),
response_text=self.tokenizer.decode(token_ids[-1]),
truncate_status="prompt_truncated",
reward=0.0,
)
for _ in range(kwargs.get("n", 1))
]
token_ids, is_valid = self._handle_prompt_truncation(prompt, **kwargs)
if not is_valid:
return token_ids

with_chat_completion = kwargs.get("with_chat_completion", False)
if with_chat_completion:
@@ -157,8 +125,6 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
"""Generate experiences from a list of history chat messages in async."""
if self.tokenizer is None:
await self._initialize_tokenizer()
if self.chat_template is None:
self.chat_template = self.tokenizer.get_chat_template()

# TODO: this is a hack to support openai chat messages, which only supports text
for msg in messages:
@@ -169,72 +135,14 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
content_str = msg["content"]
msg["content"] = content_str

if messages[-1]["role"] == "assistant":
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
continue_final_message=True,
chat_template=self.chat_template,
)
else:
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
chat_template=self.chat_template,
enable_thinking=self.enable_thinking,
)
prompt = self.apply_chat_template(self.tokenizer, messages)
return await self.generate(prompt=prompt, **kwargs)

async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor:
"""Generate logprobs for a list of tokens in async."""
logprobs = await self.model.compute_logprobs_async(types.ModelInput.from_ints(token_ids))
return torch.tensor(logprobs[1:], dtype=torch.float32)

async def convert_messages_to_experience(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience in async."""
if self.tokenizer is None:
await self._initialize_tokenizer()
if self.chat_template is None:
self.chat_template = self.tokenizer.get_chat_template()
token_ids, action_mask, prompt_length = self.action_mask_method(
tokenizer=self.tokenizer,
messages=messages,
tools=tools,
chat_template=self.chat_template,
enable_thinking=self.enable_thinking,
) # (seq_length, ), (seq_length, )

# Truncate tokens if they exceed the length limit
assert token_ids is not None
truncate_status = None
if self.config.max_model_len is not None and self.config.max_model_len > 0:
if len(token_ids) > self.config.max_model_len - 1:
truncate_status = "response_truncated"
self.logger.warning(
f"Warning: {len(token_ids)=} exceeds the length limit {(self.config.max_model_len - 1)=}"
)
token_ids = token_ids[: self.config.max_model_len - 1]
action_mask = action_mask[: self.config.max_model_len - 1]

temperature = temperature if temperature is not None else self.config.temperature
logprobs = await self.logprobs(
token_ids=token_ids.tolist(), temperature=temperature
) # (seq_length - 1,)
return Experience(
tokens=token_ids,
logprobs=logprobs[prompt_length - 1 :],
prompt_length=prompt_length,
action_mask=action_mask[prompt_length:], # Exclude the prompt tokens
messages=messages,
truncate_status=truncate_status,
)

async def prepare(self) -> None:
"""Prepare the model before inference."""
self.service_client = tinker.ServiceClient()