
Conversation

@zhangbilang
Collaborator

No description provided.


@gemini-code-assist (bot) left a comment


Summary of Changes

Hello @SmudgedWings, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily focuses on enhancing the token reduction capabilities for Llava models by introducing specific support for Llava-Next within the DivPrune and MustDrop methods. The changes involve significant refactoring of how visual tokens are processed and integrated into the model's input, ensuring efficient and effective pruning for the newer model architecture.

Highlights

  • LLaVA-Next Compatibility: I've updated the DivPrune and MustDrop token reduction methods to ensure compatibility with Llava-Next models. This includes adapting how vision token lengths are determined and how arguments are passed within the model's forward hooks.
  • Dynamic Token Pruning Integration: A key change is the introduction of a mechanism where MustDrop now explicitly returns index_masks (indicating which tokens are retained). This information is then leveraged by a new utility function, prepare_inputs_labels_for_multimodal_with_index_masks, which dynamically patches the LLaVA model's input preparation to correctly handle and insert only the pruned image features into the model's input embeddings (a minimal patching sketch follows this list).
  • Configuration & Code Cleanup: New YAML configuration files have been added for DivPrune and MustDrop, specifically tailored for Llava models, including suggested retained_tokens values for llava_next. Additionally, I've cleaned up tome.py by removing commented-out and unused code, improving code readability and maintainability.
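
Below is a minimal sketch of what that dynamic patching might look like from the caller's side. It only illustrates the idea of binding the pruning-aware input-preparation routine onto a LLaVA model instance; the import path and the `patch_llava_input_preparation` helper are assumptions for illustration, not code from this PR.

```python
import types

# Hypothetical import path -- the PR only says the function lives in utils.py.
from llmc.compression.token_reduction.utils import (
    prepare_inputs_labels_for_multimodal_with_index_masks,
)


def patch_llava_input_preparation(model):
    """Bind the index-mask-aware input preparation onto one model instance.

    `model` is assumed to be a LLaVA(-Next) model exposing the usual
    prepare_inputs_labels_for_multimodal method; only this instance is patched.
    """
    model.prepare_inputs_labels_for_multimodal = types.MethodType(
        prepare_inputs_labels_for_multimodal_with_index_masks, model
    )
    return model
```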


@gemini-code-assist (bot) left a comment


Code Review

This pull request updates the DivPrune and MustDrop token reduction methods to support llava-next. This involves adding new configuration files and modifying the implementation of these methods to handle the specifics of the new model version, particularly how vision tokens are processed.

The changes include:

  • New YAML configuration files for DivPrune and MustDrop.
  • Updates to divprune.py and mustdrop.py to support llava-next, including dynamic vision token length handling.
  • A new utility function in utils.py to prepare multimodal inputs, which is a key part of the llava-next support.
  • General code cleanup in tome.py.

My review has identified a critical issue in utils.py where an error is not being raised correctly. I've also found several high-severity issues related to leftover debugging code and a significant maintainability issue in divprune.py due to the use of *args with magic indices. Additionally, I've pointed out a potential logic error in divprune.py and a potential arithmetic error in mustdrop.py. I've provided suggestions for all these points to improve code quality, correctness, and maintainability.
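
On the *args point: one low-cost way to avoid magic indices in the hooks is to unpack the positional arguments into named fields once. The sketch below is illustrative only; the field order is an assumption, since the actual hook signature in divprune.py is not shown in this thread.

```python
from typing import Any, NamedTuple, Sequence


class PreparedMultimodalArgs(NamedTuple):
    """Named view over positional hook args (hypothetical field order)."""
    input_ids: Any
    position_ids: Any
    attention_mask: Any
    past_key_values: Any
    inputs_embeds: Any


def unpack_hook_args(args: Sequence[Any]) -> PreparedMultimodalArgs:
    # Replaces scattered accesses like args[4] with named access such as
    # parsed.inputs_embeds, which is easier to audit when the upstream
    # signature changes.
    return PreparedMultimodalArgs(*args[:5])
```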

if 'maxpool2x2' in mm_patch_merge_type:
    raise NotImplementedError
elif 'unpad' in mm_patch_merge_type and 'anyres_max' in image_aspect_ratio:
    NotImplementedError


critical

NotImplementedError is a class, not an exception instance. This line of code does nothing as it is. To indicate that this code path is not implemented, you must raise the exception.

Suggested change
- NotImplementedError
+ raise NotImplementedError
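
A small standalone illustration of the point (not code from the PR): a bare class name is an expression statement that is evaluated and discarded, while `raise` actually aborts the unsupported path.

```python
def unsupported_silent():
    NotImplementedError        # evaluates the class object, then discards it
    return "keeps running"


def unsupported_correct():
    raise NotImplementedError  # execution stops here


print(unsupported_silent())    # -> keeps running
try:
    unsupported_correct()
except NotImplementedError:
    print("raised as expected")
```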


cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

# import pdb; pdb.set_trace()


high

A commented-out pdb trace was found. This appears to be a leftover from a debugging session and should be removed from the production codebase.

right_add = random.randint(left_add, self.config.pos_skipping_range)
position_ids[:, :split_position] += left_add
position_ids[:, split_position:] += right_add
# import pdb; pdb.set_trace()


high

A commented-out pdb trace was found. This appears to be a leftover from a debugging session and should be removed from the production codebase.

  visual_tokens = inputs_embeds[0][SYS_TOKEN_LEN: SYS_TOKEN_LEN + img_feature_len]
  selected_visual_tokens, cosine_matrix = divprune(
-     visual_tokens, img_feature_len, None, threshold_ratio=rate
+     visual_tokens, img_feature_len, None, threshold_ratio=1 - rate


medium

The threshold ratio is being subtracted from 1. Ensure this is the intended behavior, as it changes the meaning of the threshold_ratio parameter. If the intention is to use the reduction ratio directly, this subtraction is unnecessary and could lead to confusion.
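
To make the ambiguity concrete, here is a small worked example. It assumes, hypothetically, that divprune keeps roughly `threshold_ratio * img_feature_len` visual tokens; the actual semantics live inside `divprune()` and are not visible in this diff, which is exactly why the intent should be confirmed.

```python
img_feature_len = 576   # e.g. a 24x24 patch grid from a 336px CLIP vision tower
rate = 0.25             # the configured rate from the YAML (interpretation under review)

kept_if_rate_is_keep_ratio = int(rate * img_feature_len)        # 144 tokens kept
kept_if_rate_is_drop_ratio = int((1 - rate) * img_feature_len)  # 432 tokens kept
print(kept_if_rate_is_keep_ratio, kept_if_rate_is_drop_ratio)
```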

    '`image_features[0].shape[0]` to the return value of the function '
    '`prepare_inputs_labels_for_multimodal`, and modify the related code.'
)
assert len(outs) == 7, message


medium

Using assert to enforce an API contract on an external library is brittle. If the prepare_inputs_labels_for_multimodal function in the llava library changes its return signature, this will cause the application to crash. The assertion message is helpful for developers but doesn't prevent the crash.

A more robust approach would be to handle this case gracefully, for instance, by checking the length of outs and raising a more specific, informative exception or logging a clear error message.
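
One possible shape for that suggestion is sketched below. The name `outs` and the expected arity of 7 come from the surrounding diff; the exception type and wording are choices of this sketch, not the PR.

```python
if len(outs) != 7:
    raise RuntimeError(
        'prepare_inputs_labels_for_multimodal returned '
        f'{len(outs)} values but 7 were expected. The installed llava version '
        'may be incompatible: append `image_features[0].shape[0]` to its return '
        'value, or pin a compatible llava release.'
    )
```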

  if pruning_paras.get('retained_tokens', None) is not None:
      retained_tokens = pruning_paras['retained_tokens']
-     fix_r = (pruning_paras['vision_token_length'] - retained_tokens) \
+     fix_r = (vtoken_length - retained_tokens) \


medium

The calculation for fix_r can result in a negative value if retained_tokens is greater than vtoken_length. While this might not happen with current configurations, it could lead to unexpected behavior in the conditional_pooling function if inputs change. It would be safer to ensure fix_r is non-negative.

Suggested change
- fix_r = (vtoken_length - retained_tokens) \
+ fix_r = max(0, (vtoken_length - retained_tokens) // \
+     (window_size[0] * window_size[1] - 1))

Comment on lines +213 to +561
def prepare_inputs_labels_for_multimodal_with_index_masks(
    self, input_ids, position_ids, attention_mask, past_key_values, labels,
    images, modalities=['image'], image_sizes=None
):
    vision_tower = self.get_vision_tower()
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        return input_ids, position_ids, attention_mask, past_key_values, None, labels, None

    if isinstance(modalities, str):
        modalities = [modalities]

    if type(images) is list or images.ndim == 5:
        if type(images) is list:
            images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]

        video_idx_in_batch = []
        for _ in range(len(modalities)):
            if modalities[_] == 'video':
                video_idx_in_batch.append(_)

        images_list = []
        for image in images:
            if image.ndim == 4:
                images_list.append(image)
            else:
                images_list.append(image.unsqueeze(0))

        concat_images = torch.cat([image for image in images_list], dim=0)
        split_sizes = [image.shape[0] for image in images_list]
        encoded_image_features = self.encode_images(concat_images)
        index_masks = vision_tower.index_masks
        encoded_image_features = torch.split(encoded_image_features, split_sizes)
        index_masks = torch.split(index_masks, split_sizes)
        image_features = []
        for idx, image_feat in enumerate(encoded_image_features):
            if idx in video_idx_in_batch:
                image_features.append(self.get_2dPool(image_feat))
            else:
                image_features.append(image_feat)
        mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
        # mm_patch_merge_type = mm_patch_merge_type.replace('_unpad', '')
        image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')

        if mm_patch_merge_type == 'flat':
            image_features = [x.flatten(0, 1) for x in image_features]
            index_masks = [x.flatten(0, 1) for x in index_masks]
            image_features = [x[m] for x, m in zip(image_features, index_masks)]
        elif mm_patch_merge_type.startswith('spatial'):
            new_image_features = []
            for image_idx, (image_feature, index_mask) in enumerate(
                zip(image_features, index_masks)
            ):
                if image_idx in video_idx_in_batch:  # video operations
                    raise NotImplementedError
                elif image_feature.shape[0] > 1:

                    base_image_feature, base_index_mask = image_feature[0], index_mask[0]
                    image_feature, index_mask = image_feature[1:], index_mask[1:]
                    height = width = self.get_vision_tower().num_patches_per_side
                    assert height * width == base_image_feature.shape[0]

                    if image_aspect_ratio == 'anyres':
                        if hasattr(self.get_vision_tower(), 'image_size'):
                            vision_tower_image_size = self.get_vision_tower().image_size
                        else:
                            raise ValueError('vision_tower_image_size is not found.')
                        try:
                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                                image_sizes[image_idx],
                                self.config.image_grid_pinpoints,
                                vision_tower_image_size
                            )
                        except Exception:
                            num_patch_width, num_patch_height = 2, 2
                        image_feature = image_feature.view(
                            num_patch_height, num_patch_width, height, width, -1
                        )
                        index_mask = index_mask.view(
                            num_patch_height, num_patch_width, height, width
                        )
                    else:
                        raise NotImplementedError

                    if 'maxpool2x2' in mm_patch_merge_type:
                        raise NotImplementedError
                    elif 'unpad' in mm_patch_merge_type and 'anyres_max' in image_aspect_ratio:
                        NotImplementedError
                    elif 'unpad' in mm_patch_merge_type:
                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(image_feature, image_sizes[image_idx])
                        image_feature = torch.cat(
                            (
                                image_feature,
                                self.model.image_newline[
                                    :, None, None
                                ].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
                            ), dim=-1
                        )
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        index_mask = index_mask.permute(0, 2, 1, 3).contiguous().unsqueeze(0)
                        index_mask = index_mask.flatten(1, 2).flatten(2, 3)
                        index_mask = unpad_image(index_mask, image_sizes[image_idx])
                        index_mask = torch.cat((
                            index_mask,
                            torch.ones(
                                *index_mask.shape[:-1], 1, dtype=torch.bool
                            ).to(index_mask.device)
                        ), dim=-1)
                        index_mask = index_mask.flatten(1, 2).squeeze(0)
                        image_feature = image_feature[index_mask]
                    else:
                        image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                        image_feature = image_feature.flatten(0, 3)
                        index_mask = index_mask.permute(0, 2, 1, 3).contiguous()
                        index_mask = index_mask.flatten(0, 3)
                        image_feature = image_feature[index_mask]
                    if 'nobase' in mm_patch_merge_type:
                        pass
                    else:
                        base_image_feature = base_image_feature[base_index_mask]
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    new_image_features.append(image_feature)
                else:  # single image operations
                    image_feature = image_feature[0]
                    index_mask = index_mask[0]
                    if 'unpad' in mm_patch_merge_type:
                        image_feature = torch.cat((
                            image_feature,
                            self.model.image_newline[None].to(image_feature.device)
                        ), dim=0)
                        index_mask = torch.cat((
                            index_mask,
                            torch.ones(1, dtype=torch.bool).to(index_mask.device)
                        ), dim=0)
                    image_feature = image_feature[index_mask]
                    new_image_features.append(image_feature)
            image_features = new_image_features
        else:
            raise ValueError(f'Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}')
    else:
        image_features = self.encode_images(images)
        image_features = image_features[index_masks].unsqueeze(0)
    vision_tower.index_masks = []
    vtoken_length = image_features[0].shape[0]
    # TODO: image start / end is not implemented here to support pretraining.
    if (
        getattr(self.config, 'tune_mm_mlp_adapter', False) and
        getattr(self.config, 'mm_use_im_start_end', False)
    ):
        raise NotImplementedError
    # rank_print(f"Total images : {len(image_features)}")

    # Let's just add dummy tensors if they do not exist,
    # it is a headache to deal with None all the time.
    # But it is not ideal, and if you have a better idea,
    # please open an issue / submit a PR, thanks.
    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(
            0, input_ids.shape[1],
            dtype=torch.long, device=input_ids.device
        )
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # remove the padding using attention_mask -- FIXME
    input_ids = [
        cur_input_ids[cur_attention_mask]
        for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
    ]
    labels = [
        cur_labels[cur_attention_mask]
        for cur_labels, cur_attention_mask in zip(labels, attention_mask)
    ]

    new_input_embeds = []
    new_labels = []
    cur_image_idx = 0
    # rank_print("Inserting Images embedding")
    for batch_idx, cur_input_ids in enumerate(input_ids):
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        # rank0_print(num_images)
        if num_images == 0:
            cur_image_features = image_features[cur_image_idx]
            cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
            cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
            new_input_embeds.append(cur_input_embeds)
            new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            continue

        image_token_indices = [-1] + \
            torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
        cur_input_ids_noim = []
        cur_labels = labels[batch_idx]
        cur_labels_noim = []
        for i in range(len(image_token_indices) - 1):
            cur_input_ids_noim.append(
                cur_input_ids[image_token_indices[i] + 1: image_token_indices[i + 1]]
            )
            cur_labels_noim.append(
                cur_labels[image_token_indices[i] + 1: image_token_indices[i + 1]]
            )
        split_sizes = [x.shape[0] for x in cur_labels_noim]
        cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
        cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
        cur_new_input_embeds = []
        cur_new_labels = []

        for i in range(num_images + 1):
            cur_new_input_embeds.append(cur_input_embeds_no_im[i])
            cur_new_labels.append(cur_labels_noim[i])
            if i < num_images:
                try:
                    cur_image_features = image_features[cur_image_idx]
                except IndexError:
                    cur_image_features = image_features[cur_image_idx - 1]
                cur_image_idx += 1
                cur_new_input_embeds.append(cur_image_features)
                cur_new_labels.append(
                    torch.full(
                        (cur_image_features.shape[0],),
                        IGNORE_INDEX,
                        device=cur_labels.device, dtype=cur_labels.dtype
                    )
                )

        cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

        # import pdb; pdb.set_trace()
        cur_new_input_embeds = torch.cat(cur_new_input_embeds)
        cur_new_labels = torch.cat(cur_new_labels)

        new_input_embeds.append(cur_new_input_embeds)
        new_labels.append(cur_new_labels)

    # Truncate sequences to max length as image embeddings can make the sequence longer
    tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
    # rank_print("Finishing Inserting")

    new_input_embeds = [
        x[:tokenizer_model_max_length]
        for x, modality in zip(new_input_embeds, modalities)
    ]
    new_labels = [
        x[:tokenizer_model_max_length]
        for x, modality in zip(new_labels, modalities)
    ]

    # Combine them
    max_len = max(x.shape[0] for x in new_input_embeds)
    batch_size = len(new_input_embeds)

    new_input_embeds_padded = []
    new_labels_padded = torch.full(
        (batch_size, max_len),
        IGNORE_INDEX,
        dtype=new_labels[0].dtype,
        device=new_labels[0].device
    )
    attention_mask = torch.zeros(
        (batch_size, max_len),
        dtype=attention_mask.dtype,
        device=attention_mask.device
    )
    position_ids = torch.zeros(
        (batch_size, max_len),
        dtype=position_ids.dtype, device=position_ids.device
    )
    # rank0_print("Prepare pos id")

    for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
        cur_len = cur_new_embed.shape[0]
        if getattr(self.config, 'tokenizer_padding_side', 'right') == 'left':
            new_input_embeds_padded.append(
                torch.cat(
                    (
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype, device=cur_new_embed.device
                        ),
                        cur_new_embed
                    ), dim=0
                )
            )
            if cur_len > 0:
                new_labels_padded[i, -cur_len:] = cur_new_labels
                attention_mask[i, -cur_len:] = True
                position_ids[i, -cur_len:] = torch.arange(
                    0, cur_len,
                    dtype=position_ids.dtype, device=position_ids.device
                )
        else:
            new_input_embeds_padded.append(
                torch.cat(
                    (
                        cur_new_embed,
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype, device=cur_new_embed.device
                        )
                    ), dim=0
                )
            )
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0, cur_len,
                    dtype=position_ids.dtype, device=position_ids.device
                )

    new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
    # rank0_print("tokenizer padding")

    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded

    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

    if _position_ids is None:
        position_ids = None
    if getattr(self.config, 'use_pos_skipping', False) and self.training:
        position_ids = torch.arange(
            new_input_embeds.size(1),
            device=new_input_embeds.device
        ).unsqueeze(0).to(new_input_embeds.device)
        split_position = random.randint(0, new_input_embeds.size(1))
        left_add = random.randint(0, self.config.pos_skipping_range)
        right_add = random.randint(left_add, self.config.pos_skipping_range)
        position_ids[:, :split_position] += left_add
        position_ids[:, split_position:] += right_add
        # import pdb; pdb.set_trace()
    # rank0_print("Finish preparing")
    # print(vtoken_length)
    return None, position_ids, attention_mask, past_key_values, \
        new_input_embeds, new_labels, vtoken_length


medium

This function prepare_inputs_labels_for_multimodal_with_index_masks is extremely long (over 300 lines) and complex. This significantly impacts readability and maintainability.

To improve the code quality, I recommend refactoring this function into smaller, more focused helper functions. Each helper function should have a single, well-defined responsibility. This will make the code easier to understand, test, and modify in the future.
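
For instance, the body could be split roughly along its existing phases. The helper names below are hypothetical and the bodies are stubbed out; only prepare_inputs_labels_for_multimodal_with_index_masks itself exists in this PR.

```python
def _encode_images_with_masks(model, images, modalities):
    """Encode images and collect index_masks from the vision tower."""
    raise NotImplementedError


def _merge_patches_with_masks(model, image_features, index_masks, image_sizes):
    """Apply the mm_patch_merge_type handling (flat / spatial / unpad) to
    features and masks together."""
    raise NotImplementedError


def _splice_image_embeddings(model, input_ids, labels, attention_mask,
                             position_ids, image_features):
    """Replace IMAGE_TOKEN_INDEX placeholders with the pruned image features
    and re-pad the batch."""
    raise NotImplementedError
```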

    else:
        image_features.append(image_feat)
mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
# mm_patch_merge_type = mm_patch_merge_type.replace('_unpad', '')


medium

This line appears to be commented-out code. If it's no longer needed, it should be removed to improve code clarity.

position_ids[:, split_position:] += right_add
# import pdb; pdb.set_trace()
# rank0_print("Finish preparing")
# print(vtoken_length)


medium

A commented-out print statement was found. This appears to be a leftover from a debugging session and should be removed.

@zhangbilang zhangbilang merged commit 695fbc3 into ModelTC:main Jul 28, 2025
2 checks passed