update divprune,mustdrop for llava-next #428
Conversation
Summary of Changes
Hello @SmudgedWings, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request primarily focuses on enhancing the token reduction capabilities for Llava models by introducing specific support for Llava-Next within the DivPrune and MustDrop methods. The changes involve significant refactoring of how visual tokens are processed and integrated into the model's input, ensuring efficient and effective pruning for the newer model architecture.
Highlights
- LLaVA-Next Compatibility: I've updated the DivPrune and MustDrop token reduction methods to ensure compatibility with Llava-Next models. This includes adapting how vision token lengths are determined and how arguments are passed within the model's forward hooks.
- Dynamic Token Pruning Integration: A key change is the introduction of a mechanism where MustDrop now explicitly returns index_masks (indicating which tokens are retained). This information is then leveraged by a new utility function, prepare_inputs_labels_for_multimodal_with_index_masks, which dynamically patches the LLaVA model's input preparation to correctly handle and insert only the pruned image features into the model's input embeddings (a rough sketch of this patching idea follows this list).
- Configuration & Code Cleanup: New YAML configuration files have been added for DivPrune and MustDrop, specifically tailored for Llava models, including suggested retained_tokens values for llava_next. Additionally, I've cleaned up tome.py by removing commented-out and unused code, improving code readability and maintainability.
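The patching mechanism described in the second highlight can be pictured with a short sketch. This is only an illustration of the idea, not code from this pull request: the name prepare_inputs_labels_for_multimodal_with_index_masks comes from the diff below, while apply_index_mask_patch and everything else in the snippet are hypothetical.

```python
import types

def apply_index_mask_patch(model, prepare_fn):
    """Bind a mask-aware input-preparation function onto a LLaVA model instance.

    `prepare_fn` is assumed to share the signature of LLaVA's
    `prepare_inputs_labels_for_multimodal`, to read the `index_masks` stashed on
    the vision tower by the token-reduction hooks, and to return the pruned image
    embeddings together with the resulting vision token length.
    """
    # Patch the bound method on this instance only, leaving the class untouched.
    model.prepare_inputs_labels_for_multimodal = types.MethodType(prepare_fn, model)
    return model
```

Judging from the diff further down, the new helper reads the masks from vision_tower.index_masks and clears them after use, so the language model only ever receives the retained visual tokens.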
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
- Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.
Code Review
This pull request updates the DivPrune and MustDrop token reduction methods to support llava-next. This involves adding new configuration files and modifying the implementation of these methods to handle the specifics of the new model version, particularly how vision tokens are processed.
The changes include:
- New YAML configuration files for DivPrune and MustDrop.
- Updates to divprune.py and mustdrop.py to support llava-next, including dynamic vision token length handling.
- A new utility function in utils.py to prepare multimodal inputs, which is a key part of the llava-next support.
- General code cleanup in tome.py.
My review has identified a critical issue in utils.py where an error is not being raised correctly. I've also found several high-severity issues related to leftover debugging code and a significant maintainability issue in divprune.py due to the use of *args with magic indices. Additionally, I've pointed out a potential logic error in divprune.py and a potential arithmetic error in mustdrop.py. I've provided suggestions for all these points to improve code quality, correctness, and maintainability.
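To illustrate the *args concern (the real hook in divprune.py is not reproduced in this review, so the hook name and argument order below are hypothetical), unpacking the positional arguments into named variables once, at the top of the hook, avoids scattering magic indices through the body:

```python
def divprune_input_hook(module, args, kwargs):
    # Hypothetical pre-forward hook; names and argument order are illustrative only.
    # Fragile style flagged above: indexing args by position throughout the body,
    #   e.g. inputs_embeds = args[0]; attention_mask = args[2]
    # Clearer: unpack once with names, so a signature change fails loudly here.
    inputs_embeds, position_ids, attention_mask, *rest = args
    # ... prune visual tokens in inputs_embeds here ...
    return (inputs_embeds, position_ids, attention_mask, *rest), kwargs
```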
if 'maxpool2x2' in mm_patch_merge_type:
    raise NotImplementedError
elif 'unpad' in mm_patch_merge_type and 'anyres_max' in image_aspect_ratio:
    NotImplementedError
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

# import pdb; pdb.set_trace()
right_add = random.randint(left_add, self.config.pos_skipping_range)
position_ids[:, :split_position] += left_add
position_ids[:, split_position:] += right_add
# import pdb; pdb.set_trace()
  visual_tokens = inputs_embeds[0][SYS_TOKEN_LEN: SYS_TOKEN_LEN + img_feature_len]
  selected_visual_tokens, cosine_matrix = divprune(
-     visual_tokens, img_feature_len, None, threshold_ratio=rate
+     visual_tokens, img_feature_len, None, threshold_ratio=1 - rate
    '`image_features[0].shape[0]` to the return value of the function '
    '`prepare_inputs_labels_for_multimodal`, and modify the related code.'
)
assert len(outs) == 7, message
Using assert to enforce an API contract on an external library is brittle. If the prepare_inputs_labels_for_multimodal function in the llava library changes its return signature, this will cause the application to crash. The assertion message is helpful for developers but doesn't prevent the crash.
A more robust approach would be to handle this case gracefully, for instance, by checking the length of outs and raising a more specific, informative exception or logging a clear error message.
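A minimal sketch of that suggestion, assuming outs is the tuple returned by the patched prepare_inputs_labels_for_multimodal, that a 7-element return is the contract this code relies on, and reusing the message string from the snippet above (the constant name is illustrative):

```python
EXPECTED_NUM_OUTPUTS = 7  # assumed contract: vtoken_length appended to llava's usual six outputs

if len(outs) != EXPECTED_NUM_OUTPUTS:
    raise RuntimeError(
        f'prepare_inputs_labels_for_multimodal returned {len(outs)} values, '
        f'expected {EXPECTED_NUM_OUTPUTS}; the installed llava version may be '
        'incompatible with this token-reduction patch. ' + message
    )
```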
  if pruning_paras.get('retained_tokens', None) is not None:
      retained_tokens = pruning_paras['retained_tokens']
-     fix_r = (pruning_paras['vision_token_length'] - retained_tokens) \
+     fix_r = (vtoken_length - retained_tokens) \
The calculation for fix_r can result in a negative value if retained_tokens is greater than vtoken_length. While this might not happen with current configurations, it could lead to unexpected behavior in the conditional_pooling function if inputs change. It would be safer to ensure fix_r is non-negative.
- fix_r = (vtoken_length - retained_tokens) \
+ fix_r = max(0, (vtoken_length - retained_tokens) // \
+     (window_size[0] * window_size[1] - 1))
def prepare_inputs_labels_for_multimodal_with_index_masks(
    self, input_ids, position_ids, attention_mask, past_key_values, labels,
    images, modalities=['image'], image_sizes=None
):
    vision_tower = self.get_vision_tower()
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        return input_ids, position_ids, attention_mask, past_key_values, None, labels, None

    if isinstance(modalities, str):
        modalities = [modalities]

    if type(images) is list or images.ndim == 5:
        if type(images) is list:
            images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]

        video_idx_in_batch = []
        for _ in range(len(modalities)):
            if modalities[_] == 'video':
                video_idx_in_batch.append(_)

        images_list = []
        for image in images:
            if image.ndim == 4:
                images_list.append(image)
            else:
                images_list.append(image.unsqueeze(0))

        concat_images = torch.cat([image for image in images_list], dim=0)
        split_sizes = [image.shape[0] for image in images_list]
        encoded_image_features = self.encode_images(concat_images)
        index_masks = vision_tower.index_masks
        encoded_image_features = torch.split(encoded_image_features, split_sizes)
        index_masks = torch.split(index_masks, split_sizes)
        image_features = []
        for idx, image_feat in enumerate(encoded_image_features):
            if idx in video_idx_in_batch:
                image_features.append(self.get_2dPool(image_feat))
            else:
                image_features.append(image_feat)
        mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
        # mm_patch_merge_type = mm_patch_merge_type.replace('_unpad', '')
        image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')

        if mm_patch_merge_type == 'flat':
            image_features = [x.flatten(0, 1) for x in image_features]
            index_masks = [x.flatten(0, 1) for x in index_masks]
            image_features = [x[m] for x, m in zip(image_features, index_masks)]
        elif mm_patch_merge_type.startswith('spatial'):
            new_image_features = []
            for image_idx, (image_feature, index_mask) in enumerate(
                zip(image_features, index_masks)
            ):
                if image_idx in video_idx_in_batch:  # video operations
                    raise NotImplementedError
                elif image_feature.shape[0] > 1:

                    base_image_feature, base_index_mask = image_feature[0], index_mask[0]
                    image_feature, index_mask = image_feature[1:], index_mask[1:]
                    height = width = self.get_vision_tower().num_patches_per_side
                    assert height * width == base_image_feature.shape[0]

                    if image_aspect_ratio == 'anyres':
                        if hasattr(self.get_vision_tower(), 'image_size'):
                            vision_tower_image_size = self.get_vision_tower().image_size
                        else:
                            raise ValueError('vision_tower_image_size is not found.')
                        try:
                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                                image_sizes[image_idx],
                                self.config.image_grid_pinpoints,
                                vision_tower_image_size
                            )
                        except Exception:
                            num_patch_width, num_patch_height = 2, 2
                        image_feature = image_feature.view(
                            num_patch_height, num_patch_width, height, width, -1
                        )
                        index_mask = index_mask.view(
                            num_patch_height, num_patch_width, height, width
                        )
                    else:
                        raise NotImplementedError

                    if 'maxpool2x2' in mm_patch_merge_type:
                        raise NotImplementedError
                    elif 'unpad' in mm_patch_merge_type and 'anyres_max' in image_aspect_ratio:
                        NotImplementedError
                    elif 'unpad' in mm_patch_merge_type:
                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(image_feature, image_sizes[image_idx])
                        image_feature = torch.cat(
                            (
                                image_feature,
                                self.model.image_newline[
                                    :, None, None
                                ].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
                            ), dim=-1
                        )
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        index_mask = index_mask.permute(0, 2, 1, 3).contiguous().unsqueeze(0)
                        index_mask = index_mask.flatten(1, 2).flatten(2, 3)
                        index_mask = unpad_image(index_mask, image_sizes[image_idx])
                        index_mask = torch.cat((
                            index_mask,
                            torch.ones(
                                *index_mask.shape[:-1], 1, dtype=torch.bool
                            ).to(index_mask.device)
                        ), dim=-1)
                        index_mask = index_mask.flatten(1, 2).squeeze(0)
                        image_feature = image_feature[index_mask]
                    else:
                        image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                        image_feature = image_feature.flatten(0, 3)
                        index_mask = index_mask.permute(0, 2, 1, 3).contiguous()
                        index_mask = index_mask.flatten(0, 3)
                        image_feature = image_feature[index_mask]
                    if 'nobase' in mm_patch_merge_type:
                        pass
                    else:
                        base_image_feature = base_image_feature[base_index_mask]
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    new_image_features.append(image_feature)
                else:  # single image operations
                    image_feature = image_feature[0]
                    index_mask = index_mask[0]
                    if 'unpad' in mm_patch_merge_type:
                        image_feature = torch.cat((
                            image_feature,
                            self.model.image_newline[None].to(image_feature.device)
                        ), dim=0)
                        index_mask = torch.cat((
                            index_mask,
                            torch.ones(1, dtype=torch.bool).to(index_mask.device)
                        ), dim=0)
                    image_feature = image_feature[index_mask]
                    new_image_features.append(image_feature)
            image_features = new_image_features
        else:
            raise ValueError(f'Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}')
    else:
        image_features = self.encode_images(images)
        image_features = image_features[index_masks].unsqueeze(0)
    vision_tower.index_masks = []
    vtoken_length = image_features[0].shape[0]
    # TODO: image start / end is not implemented here to support pretraining.
    if (
        getattr(self.config, 'tune_mm_mlp_adapter', False) and
        getattr(self.config, 'mm_use_im_start_end', False)
    ):
        raise NotImplementedError
    # rank_print(f"Total images : {len(image_features)}")

    # Let's just add dummy tensors if they do not exist,
    # it is a headache to deal with None all the time.
    # But it is not ideal, and if you have a better idea,
    # please open an issue / submit a PR, thanks.
    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(
            0, input_ids.shape[1],
            dtype=torch.long, device=input_ids.device
        )
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # remove the padding using attention_mask -- FIXME
    input_ids = [
        cur_input_ids[cur_attention_mask]
        for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
    ]
    labels = [
        cur_labels[cur_attention_mask]
        for cur_labels, cur_attention_mask in zip(labels, attention_mask)
    ]

    new_input_embeds = []
    new_labels = []
    cur_image_idx = 0
    # rank_print("Inserting Images embedding")
    for batch_idx, cur_input_ids in enumerate(input_ids):
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        # rank0_print(num_images)
        if num_images == 0:
            cur_image_features = image_features[cur_image_idx]
            cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
            cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
            new_input_embeds.append(cur_input_embeds)
            new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            continue

        image_token_indices = [-1] + \
            torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
        cur_input_ids_noim = []
        cur_labels = labels[batch_idx]
        cur_labels_noim = []
        for i in range(len(image_token_indices) - 1):
            cur_input_ids_noim.append(
                cur_input_ids[image_token_indices[i] + 1: image_token_indices[i + 1]]
            )
            cur_labels_noim.append(
                cur_labels[image_token_indices[i] + 1: image_token_indices[i + 1]]
            )
        split_sizes = [x.shape[0] for x in cur_labels_noim]
        cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
        cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
        cur_new_input_embeds = []
        cur_new_labels = []

        for i in range(num_images + 1):
            cur_new_input_embeds.append(cur_input_embeds_no_im[i])
            cur_new_labels.append(cur_labels_noim[i])
            if i < num_images:
                try:
                    cur_image_features = image_features[cur_image_idx]
                except IndexError:
                    cur_image_features = image_features[cur_image_idx - 1]
                cur_image_idx += 1
                cur_new_input_embeds.append(cur_image_features)
                cur_new_labels.append(
                    torch.full(
                        (cur_image_features.shape[0],),
                        IGNORE_INDEX,
                        device=cur_labels.device, dtype=cur_labels.dtype
                    )
                )

        cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

        # import pdb; pdb.set_trace()
        cur_new_input_embeds = torch.cat(cur_new_input_embeds)
        cur_new_labels = torch.cat(cur_new_labels)

        new_input_embeds.append(cur_new_input_embeds)
        new_labels.append(cur_new_labels)

    # Truncate sequences to max length as image embeddings can make the sequence longer
    tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
    # rank_print("Finishing Inserting")

    new_input_embeds = [
        x[:tokenizer_model_max_length]
        for x, modality in zip(new_input_embeds, modalities)
    ]
    new_labels = [
        x[:tokenizer_model_max_length]
        for x, modality in zip(new_labels, modalities)
    ]

    # Combine them
    max_len = max(x.shape[0] for x in new_input_embeds)
    batch_size = len(new_input_embeds)

    new_input_embeds_padded = []
    new_labels_padded = torch.full(
        (batch_size, max_len),
        IGNORE_INDEX,
        dtype=new_labels[0].dtype,
        device=new_labels[0].device
    )
    attention_mask = torch.zeros(
        (batch_size, max_len),
        dtype=attention_mask.dtype,
        device=attention_mask.device
    )
    position_ids = torch.zeros(
        (batch_size, max_len),
        dtype=position_ids.dtype, device=position_ids.device
    )
    # rank0_print("Prepare pos id")

    for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
        cur_len = cur_new_embed.shape[0]
        if getattr(self.config, 'tokenizer_padding_side', 'right') == 'left':
            new_input_embeds_padded.append(
                torch.cat(
                    (
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype, device=cur_new_embed.device
                        ),
                        cur_new_embed
                    ), dim=0
                )
            )
            if cur_len > 0:
                new_labels_padded[i, -cur_len:] = cur_new_labels
                attention_mask[i, -cur_len:] = True
                position_ids[i, -cur_len:] = torch.arange(
                    0, cur_len,
                    dtype=position_ids.dtype, device=position_ids.device
                )
        else:
            new_input_embeds_padded.append(
                torch.cat(
                    (
                        cur_new_embed,
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype, device=cur_new_embed.device
                        )
                    ), dim=0
                )
            )
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0, cur_len,
                    dtype=position_ids.dtype, device=position_ids.device
                )

    new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
    # rank0_print("tokenizer padding")

    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded

    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

    if _position_ids is None:
        position_ids = None
    if getattr(self.config, 'use_pos_skipping', False) and self.training:
        position_ids = torch.arange(
            new_input_embeds.size(1),
            device=new_input_embeds.device
        ).unsqueeze(0).to(new_input_embeds.device)
        split_position = random.randint(0, new_input_embeds.size(1))
        left_add = random.randint(0, self.config.pos_skipping_range)
        right_add = random.randint(left_add, self.config.pos_skipping_range)
        position_ids[:, :split_position] += left_add
        position_ids[:, split_position:] += right_add
        # import pdb; pdb.set_trace()
    # rank0_print("Finish preparing")
    # print(vtoken_length)
    return None, position_ids, attention_mask, past_key_values, \
        new_input_embeds, new_labels, vtoken_length
This function prepare_inputs_labels_for_multimodal_with_index_masks is extremely long (over 300 lines) and complex. This significantly impacts readability and maintainability.
To improve the code quality, I recommend refactoring this function into smaller, more focused helper functions. Each helper function should have a single, well-defined responsibility. This will make the code easier to understand, test, and modify in the future.
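One possible shape for that decomposition, sketched with hypothetical helper names (the split points are suggestions only, not code from this pull request):

```python
def prepare_inputs_labels_for_multimodal_with_index_masks(
    self, input_ids, position_ids, attention_mask, past_key_values, labels,
    images, modalities=['image'], image_sizes=None
):
    # Hypothetical top-level flow after refactoring; each helper owns one concern.
    image_features, vtoken_length = self._encode_and_prune_images(
        images, modalities, image_sizes)          # encode images and apply index_masks
    input_ids, labels, attention_mask, position_ids = self._materialize_optional_inputs(
        input_ids, labels, attention_mask, position_ids)  # replace None with dummy tensors
    new_input_embeds, new_labels = self._splice_image_features(
        input_ids, labels, image_features)        # insert pruned image tokens per sample
    return self._pad_and_pack_outputs(
        new_input_embeds, new_labels, attention_mask, position_ids,
        past_key_values, vtoken_length)           # truncate, pad, restore None sentinels
```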
    else:
        image_features.append(image_feat)
mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
# mm_patch_merge_type = mm_patch_merge_type.replace('_unpad', '')
    position_ids[:, split_position:] += right_add
    # import pdb; pdb.set_trace()
# rank0_print("Finish preparing")
# print(vtoken_length)
No description provided.