From 68a09dc9b4e79959d57b1a8aa05312e9382488bc Mon Sep 17 00:00:00 2001 From: SmudgedWings <2045955563@qq.com> Date: Mon, 28 Jul 2025 23:46:45 +0800 Subject: [PATCH] update divprune,mustdrop for llava-next --- .../methods/DivPrune/divprune.yml | 23 ++ .../methods/MustDrop/mustdrop.yml | 25 ++ llmc/compression/token_reduction/divprune.py | 86 ++--- llmc/compression/token_reduction/mustdrop.py | 69 +++- llmc/compression/token_reduction/tome.py | 150 -------- llmc/compression/token_reduction/utils.py | 361 +++++++++++++++++- 6 files changed, 487 insertions(+), 227 deletions(-) create mode 100644 configs/sparsification/methods/DivPrune/divprune.yml create mode 100644 configs/sparsification/methods/MustDrop/mustdrop.yml diff --git a/configs/sparsification/methods/DivPrune/divprune.yml b/configs/sparsification/methods/DivPrune/divprune.yml new file mode 100644 index 00000000..0234c0bd --- /dev/null +++ b/configs/sparsification/methods/DivPrune/divprune.yml @@ -0,0 +1,23 @@ +base: + seed: &seed 42 +model: + type: Llava + path: model path + torch_dtype: auto +eval: + eval_pos: [pretrain, transformed] + type: vqa + name: [mme] + download: False + path: MME dataset path + bs: 1 + inference_per_block: False +sparse: + method: TokenReduction + special: + method: DivPrune + reduction_ratio: 0.9444 # 0.7778 0.8889 0.9444 +save: + save_trans: False + save_fake: False + save_path: /path/to/save/ diff --git a/configs/sparsification/methods/MustDrop/mustdrop.yml b/configs/sparsification/methods/MustDrop/mustdrop.yml new file mode 100644 index 00000000..87731fae --- /dev/null +++ b/configs/sparsification/methods/MustDrop/mustdrop.yml @@ -0,0 +1,25 @@ +base: + seed: &seed 42 +model: + type: Llava + path: model path + torch_dtype: auto +eval: + eval_pos: [pretrain, transformed] + type: vqa + name: [mme] + download: False + path: MME dataset path + bs: 1 +sparse: + vision: + method: TokenReduction + special: + method: MustDrop + spatial_threshold: 0.6 + window_size: [3, 3] + retained_tokens: 128 # llava_next: 128, 64, 32 llava: 192, 128, 64 +save: + save_trans: False + save_fake: False + save_path: /path/to/save/ diff --git a/llmc/compression/token_reduction/divprune.py b/llmc/compression/token_reduction/divprune.py index 50b4cf4e..9ca45e86 100644 --- a/llmc/compression/token_reduction/divprune.py +++ b/llmc/compression/token_reduction/divprune.py @@ -1,4 +1,3 @@ -import functools from functools import wraps from types import MethodType @@ -7,7 +6,6 @@ from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY from .token_reduction_module import TokenReductionModule -from .utils import prefill_wrapper def pairwise_cosine_similarity(matrix): @@ -22,7 +20,7 @@ def divprune( cosine_matrix=None, threshold_ratio=0.1, ): - threshold_terms = int(round(threshold_ratio * image_feature_length)) + threshold_terms = round(threshold_ratio * image_feature_length) if cosine_matrix is None: cosine_matrix = 1.0 - (pairwise_cosine_similarity(visual_feature_vectors)) @@ -53,22 +51,16 @@ def divprune( return s, cosine_matrix -def divprune_post_hook( - input_ids, - position_ids, - attention_mask, - past_key_values, - inputs_embeds, - labels, - pruning_paras=None, -): - rate = pruning_paras['rate'] - SYS_TOKEN_LEN = pruning_paras['image_token_start_index'] - img_feature_len = pruning_paras['image_token_length'] +def divprune_post_hook(*args, pruning_paras=None): + args = list(args) + position_ids, attention_mask, inputs_embeds = args[1], args[2], args[4] + rate = pruning_paras['reduction_ratio'] + SYS_TOKEN_LEN = 
pruning_paras['vision_token_start_index'] + img_feature_len = pruning_paras['vision_token_length'] device = inputs_embeds.device visual_tokens = inputs_embeds[0][SYS_TOKEN_LEN: SYS_TOKEN_LEN + img_feature_len] selected_visual_tokens, cosine_matrix = divprune( - visual_tokens, img_feature_len, None, threshold_ratio=rate + visual_tokens, img_feature_len, None, threshold_ratio=1 - rate ) selected_visual_tokens += SYS_TOKEN_LEN @@ -83,20 +75,13 @@ def divprune_post_hook( ) keep_indexs = keep_indexs.sort().values - inputs_embeds = inputs_embeds[:, keep_indexs] if position_ids is not None: - position_ids = position_ids[:, keep_indexs, :] + args[1] = position_ids[:, keep_indexs, :] if attention_mask is not None: - attention_mask = attention_mask[:, keep_indexs] - - return ( - input_ids, - position_ids, - attention_mask, - past_key_values, - inputs_embeds, - labels, - ) + args[2] = attention_mask[:, keep_indexs] + args[4] = inputs_embeds[:, keep_indexs] + + return tuple(args) @TOKEN_REDUCTION_REGISTRY.register('DivPrune') @@ -107,43 +92,34 @@ def __init__(self, config, model, blocks): self.register_reduction_modules() def add_sparse_config(self): - self.special_config['image_token_length'] = self.model.pruning_config[ - 'image_token_length' - ] - self.pruning_paras = self.special_config def register_reduction_modules(self): - def input_hook_llava(fn, pruning_paras): + def input_hook_llava(fn, pruning_paras, llava_next): @wraps(fn) def wrapper(self, *args, **kwargs): - if len(args) == 0: - return fn(*args, **kwargs) - input_args = args[0] - if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1: + if args[0].shape[1] == 1: return fn(*args, **kwargs) - - input_ids = args[0] - attention_mask = args[2] - token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX - pruning_paras['image_token_start_index'] = torch.where(token_indices)[ - 0 - ][0].item() - - outputs = fn(*args, **kwargs) - - return divprune_post_hook(*outputs, pruning_paras=pruning_paras) - + outs = fn(*args, **kwargs) + + if llava_next: + message = ( + 'To obtain the vision_token_length for LLaVA-1.6, you should append ' + '`image_features[0].shape[0]` to the return value of the function ' + '`prepare_inputs_labels_for_multimodal`, and modify the related code.' 
+ ) + assert len(outs) == 7, message + pruning_paras['vision_token_length'] = outs[-1] + return divprune_post_hook(*outs, pruning_paras=pruning_paras) return wrapper if self.model.__class__.__name__ == 'Llava': - from llava.constants import IMAGE_TOKEN_INDEX - hook_fn = input_hook_llava( - self.model.vlm_model.prepare_inputs_labels_for_multimodal, - self.pruning_paras, - ) self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType( - hook_fn, self.model.vlm_model + input_hook_llava( + self.model.vlm_model.prepare_inputs_labels_for_multimodal, + self.pruning_paras, + llava_next=self.special_config['vision_token_length'] is None + ), self.model.vlm_model ) diff --git a/llmc/compression/token_reduction/mustdrop.py b/llmc/compression/token_reduction/mustdrop.py index ab00fe53..97c63942 100644 --- a/llmc/compression/token_reduction/mustdrop.py +++ b/llmc/compression/token_reduction/mustdrop.py @@ -1,10 +1,16 @@ import functools +import math +from types import MethodType +from typing import Callable, Tuple import torch +import torch.nn.functional as F +from einops import rearrange from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY from .token_reduction_module import TokenReductionModule +from .utils import prepare_inputs_labels_for_multimodal_with_index_masks @TOKEN_REDUCTION_REGISTRY.register('MustDrop') @@ -15,18 +21,11 @@ def __init__(self, config, model, blocks): self.register_reduction_modules() def add_sparse_config(self): - self.pruning_loc = self.special_config['pruning_loc'] + self.pruning_loc = self.model.pruning_config.get('select_layer', -1) self.pruning_paras = self.special_config def register_reduction_modules(self): - import math - from typing import Callable, Tuple - - import numpy as np - import torch.nn.functional as F - from einops import rearrange - def conditional_pooling( feat: torch.Tensor, threshold: float, @@ -170,7 +169,14 @@ def merge(x: torch.Tensor, mode='mean') -> torch.Tensor: ) x = torch.cat([dst, unm], dim=1) x = torch.cat((x_cls, x), dim=1) - return x + + index_masks = torch.zeros((n, t1), dtype=torch.bool, device=x_feat.device) + dst_flat = dst_idx.view(n, -1) + unm_flat = unm_idx.view(n, -1) + index_masks.scatter_(1, dst_flat, True) + index_masks.scatter_(1, unm_flat, True) + + return x, index_masks return merge @@ -181,26 +187,49 @@ def merge_wavg( if size is None: size = torch.ones_like(x[..., 0, None]) - x = merge(x * size, mode='sum') - size = merge(size, mode='sum') + x, index_masks = merge(x * size, mode='sum') + size, _ = merge(size, mode='sum') x = x / size - return x, size + return x, size, index_masks - def spatial_merge_hook(module, args, kwargs, layer_outs, pruning_paras): + def spatial_merge_hook(module, inps, outs, pruning_paras, llava_next): spatial_threshold = pruning_paras['spatial_threshold'] window_size = pruning_paras['window_size'] - hidden_states = layer_outs[0] + hidden_states = outs[0] + vtoken_length = hidden_states.shape[1] fix_r = 0 if pruning_paras.get('retained_tokens', None) is not None: retained_tokens = pruning_paras['retained_tokens'] - fix_r = (pruning_paras['vision_token_length'] - retained_tokens) \ + fix_r = (vtoken_length - retained_tokens) \ // (window_size[0] * window_size[1] - 1) merge = conditional_pooling(hidden_states, spatial_threshold, window_size, fix_r) - hidden_states, size = merge_wavg(merge, hidden_states, None) - return (hidden_states,) + hidden_states, size, index_masks = merge_wavg(merge, hidden_states, None) + + if not llava_next: + return (hidden_states,) - 
self.blocks[self.pruning_loc - 1].register_forward_hook( - functools.partial(spatial_merge_hook, pruning_paras=self.pruning_paras), - with_kwargs=True, + pruning_paras['index_masks'] = index_masks + return outs + + def update_index_masks_hook(module, inps, outs, pruning_paras): + module.index_masks = pruning_paras['index_masks'] + + self.blocks[self.pruning_loc].register_forward_hook( + functools.partial( + spatial_merge_hook, + pruning_paras=self.pruning_paras, + llava_next=self.special_config['vision_token_length'] is None + ), ) + + if self.special_config['vision_token_length'] is None: + + self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType( + prepare_inputs_labels_for_multimodal_with_index_masks, + self.model.vlm_model + ) + + self.model.vision_model.register_forward_hook( + functools.partial(update_index_masks_hook, pruning_paras=self.pruning_paras), + ) diff --git a/llmc/compression/token_reduction/tome.py b/llmc/compression/token_reduction/tome.py index 759c795d..55c9e051 100644 --- a/llmc/compression/token_reduction/tome.py +++ b/llmc/compression/token_reduction/tome.py @@ -7,13 +7,6 @@ from loguru import logger from transformers.models.clip.modeling_clip import CLIPEncoderLayer -try: - from transformers.models.qwen2_vl.modeling_qwen2_vl import \ - Qwen2VLVisionBlock -except ModuleNotFoundError: - logger.info('Qwen2VLVisionBlock not found, if need, please upgrade transformers first.') - Qwen2VLVisionBlock = None - from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY from .token_reduction_module import TokenReductionModule @@ -60,32 +53,6 @@ def patch_layer(self): tome_CLIPEncoderLayer_forward, block ) - # elif isinstance(block, Qwen2VLVisionBlock): # qwenvl - # block.self_attn.original_forward = block.self_attn.forward - # block.self_attn.forward = types.MethodType( - # tome_VisionSdpaAttention_forward, - # block.self_attn - # ) - # block.original_forward = block.forward - # block.forward = types.MethodType( - # tome_Qwen2VLVisionBlock_forward, - # block - # ) - # else: # intervl2 token 剪枝数量有要求 - # block.attn.original_naive_attn_forward = block.attn._naive_attn - # block.attn._naive_attn = types.MethodType(tome_naive_attn, block.attn) - # block.attn.original_flash_attn_forward = block.attn._flash_attn - # block.attn._flash_attn = types.MethodType(tome_flash_attn, block.attn) - # block.attn.original_forward = block.attn.forward - # block.attn.forward = types.MethodType( - # tome_InternAttention_forward, - # block.attn - # ) - # block.original_forward = block.forward - # block.forward = types.MethodType( - # tome_InternVisionEncoderLayer_forward, - # block - # ) def do_nothing(x, mode=None): @@ -294,120 +261,3 @@ def tome_CLIPEncoderLayer_forward( outputs += (attn_weights,) return outputs - - -# def tome_VisionSdpaAttention_forward( -# self, hidden_states: torch.Tensor, -# cu_seqlens: torch.Tensor, -# rotary_pos_emb: torch.Tensor = None -# ) -> torch.Tensor: -# from transformers.models.qwen2_vl.modeling_qwen2_vl import \ -# apply_rotary_pos_emb_vision -# seq_length = hidden_states.shape[0] -# q, k, v = self.qkv(hidden_states) \ -# .reshape(seq_length, 3, self.num_heads, -1) \ -# .permute(1, 0, 2, 3) \ -# .unbind(0) -# q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) -# k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - -# attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) -# for i in range(1, len(cu_seqlens)): -# start, end = cu_seqlens[i - 1], cu_seqlens[i] 
-# attention_mask[..., start:end, start:end] = True -# q = q.transpose(0, 1) -# k = k.transpose(0, 1) -# v = v.transpose(0, 1) -# attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) -# attn_output = attn_output.transpose(0, 1) -# attn_output = attn_output.reshape(seq_length, -1) -# attn_output = self.proj(attn_output) -# return attn_output, k.mean(1) - - -# def tome_Qwen2VLVisionBlock_forward( -# self, hidden_states, cu_seqlens, rotary_pos_emb -# ) -> torch.Tensor: -# residual = hidden_states -# hidden_states, key_mean = self.attn( -# self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb -# ) -# hidden_states = residual + hidden_states - -# # ToMe -# merge, _ = bipartite_soft_matching( -# key_mean, -# self.r, -# True -# ) -# hidden_states, _ = merge_wavg(merge, hidden_states) -# hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) -# return hidden_states - -# def tome_naive_attn(self, x): -# B, N, C = x.shape -# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) -# q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - -# if self.qk_normalization: -# B_, H_, N_, D_ = q.shape -# q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) -# k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) - -# attn = ((q * self.scale) @ k.transpose(-2, -1)) -# attn = attn.softmax(dim=-1) -# attn = self.attn_drop(attn) - -# x = (attn @ v).transpose(1, 2).reshape(B, N, C) -# x = self.proj(x) -# x = self.proj_drop(x) -# return x, k.mean(1) - -# def tome_flash_attn(self, x, key_padding_mask=None, need_weights=False): -# from einops import rearrange -# qkv = self.qkv(x) -# qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) - -# if self.qk_normalization: -# q, k, v = qkv.unbind(2) -# q = self.q_norm(q.flatten(-2, -1)).view(q.shape) -# k = self.k_norm(k.flatten(-2, -1)).view(k.shape) -# qkv = torch.stack([q, k, v], dim=2) - -# context, _ = self.inner_attn( -# qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False -# ) -# outs = self.proj(rearrange(context, 'b s h d -> b s (h d)')) -# outs = self.proj_drop(outs) -# return outs, k.mean(1) - - -# def tome_InternAttention_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: -# if self.use_flash_attn: -# x, key_mean = self._flash_attn(hidden_states) -# else: -# x, key_mean = self._naive_attn(hidden_states) -# return x, key_mean - - -# def tome_InternVisionEncoderLayer_forward( -# self, -# hidden_states: torch.Tensor, -# ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]: - -# residual = hidden_states -# x_attn, key_mean = self.attn(self.norm1(hidden_states)) -# hidden_states = residual + self.drop_path1(x_attn * self.ls1) - -# merge, _ = bipartite_soft_matching( -# key_mean, -# self.r, -# True -# ) -# hidden_states, _ = merge_wavg(merge, hidden_states) - -# residual = hidden_states -# hidden_states = self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2) -# hidden_states = residual + hidden_states - -# return hidden_states diff --git a/llmc/compression/token_reduction/utils.py b/llmc/compression/token_reduction/utils.py index cf657e22..d3451ffb 100755 --- a/llmc/compression/token_reduction/utils.py +++ b/llmc/compression/token_reduction/utils.py @@ -1,12 +1,18 @@ import ast import re from functools import wraps -from typing import Any, List, 
Optional, Tuple, Union +from typing import List, Tuple, Union import torch -import torch.nn as nn +from loguru import logger from transformers.models.clip.modeling_clip import CLIPEncoderLayer +try: + from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX +except Exception as e: + logger.debug('LLaVA is not installed. Please install LLaVA to use this model.\nError: %s' % e) +import random + def prefill_wrapper(func): @wraps(func) @@ -202,3 +208,354 @@ def unpad_image(tensor, original_size): unpadded_tensor = tensor[:, :, padding: current_width - padding] return unpadded_tensor + + +def prepare_inputs_labels_for_multimodal_with_index_masks( + self, input_ids, position_ids, attention_mask, past_key_values, labels, + images, modalities=['image'], image_sizes=None +): + vision_tower = self.get_vision_tower() + if vision_tower is None or images is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels, None + + if isinstance(modalities, str): + modalities = [modalities] + + if type(images) is list or images.ndim == 5: + if type(images) is list: + images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] + + video_idx_in_batch = [] + for _ in range(len(modalities)): + if modalities[_] == 'video': + video_idx_in_batch.append(_) + + images_list = [] + for image in images: + if image.ndim == 4: + images_list.append(image) + else: + images_list.append(image.unsqueeze(0)) + + concat_images = torch.cat([image for image in images_list], dim=0) + split_sizes = [image.shape[0] for image in images_list] + encoded_image_features = self.encode_images(concat_images) + index_masks = vision_tower.index_masks + encoded_image_features = torch.split(encoded_image_features, split_sizes) + index_masks = torch.split(index_masks, split_sizes) + image_features = [] + for idx, image_feat in enumerate(encoded_image_features): + if idx in video_idx_in_batch: + image_features.append(self.get_2dPool(image_feat)) + else: + image_features.append(image_feat) + mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat') + # mm_patch_merge_type = mm_patch_merge_type.replace('_unpad', '') + image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square') + + if mm_patch_merge_type == 'flat': + image_features = [x.flatten(0, 1) for x in image_features] + index_masks = [x.flatten(0, 1) for x in index_masks] + image_features = [x[m] for x, m in zip(image_features, index_masks)] + elif mm_patch_merge_type.startswith('spatial'): + new_image_features = [] + for image_idx, (image_feature, index_mask) in enumerate( + zip(image_features, index_masks) + ): + if image_idx in video_idx_in_batch: # video operations + raise NotImplementedError + elif image_feature.shape[0] > 1: + + base_image_feature, base_index_mask = image_feature[0], index_mask[0] + image_feature, index_mask = image_feature[1:], index_mask[1:] + height = width = self.get_vision_tower().num_patches_per_side + assert height * width == base_image_feature.shape[0] + + if image_aspect_ratio == 'anyres': + if hasattr(self.get_vision_tower(), 'image_size'): + vision_tower_image_size = self.get_vision_tower().image_size + else: + raise ValueError('vision_tower_image_size is not found.') + try: + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + vision_tower_image_size + ) + except Exception: + num_patch_width, num_patch_height = 2, 2 + image_feature = image_feature.view( + num_patch_height, num_patch_width, 
height, width, -1 + ) + index_mask = index_mask.view( + num_patch_height, num_patch_width, height, width + ) + else: + raise NotImplementedError + + if 'maxpool2x2' in mm_patch_merge_type: + raise NotImplementedError + elif 'unpad' in mm_patch_merge_type and 'anyres_max' in image_aspect_ratio: + NotImplementedError + elif 'unpad' in mm_patch_merge_type: + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.model.image_newline[ + :, None, None + ].expand(*image_feature.shape[:-1], 1).to(image_feature.device) + ), dim=-1 + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + index_mask = index_mask.permute(0, 2, 1, 3).contiguous().unsqueeze(0) + index_mask = index_mask.flatten(1, 2).flatten(2, 3) + index_mask = unpad_image(index_mask, image_sizes[image_idx]) + index_mask = torch.cat(( + index_mask, + torch.ones( + *index_mask.shape[:-1], 1, dtype=torch.bool + ).to(index_mask.device) + ), dim=-1) + index_mask = index_mask.flatten(1, 2).squeeze(0) + image_feature = image_feature[index_mask] + else: + image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() + image_feature = image_feature.flatten(0, 3) + index_mask = index_mask.permute(0, 2, 1, 3).contiguous() + index_mask = index_mask.flatten(0, 3) + image_feature = image_feature[index_mask] + if 'nobase' in mm_patch_merge_type: + pass + else: + base_image_feature = base_image_feature[base_index_mask] + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + new_image_features.append(image_feature) + else: # single image operations + image_feature = image_feature[0] + index_mask = index_mask[0] + if 'unpad' in mm_patch_merge_type: + image_feature = torch.cat(( + image_feature, + self.model.image_newline[None].to(image_feature.device) + ), dim=0) + index_mask = torch.cat(( + index_mask, + torch.ones(1, dtype=torch.bool).to(index_mask.device) + ), dim=0) + image_feature = image_feature[index_mask] + new_image_features.append(image_feature) + image_features = new_image_features + else: + raise ValueError(f'Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}') + else: + image_features = self.encode_images(images) + image_features = image_features[index_masks].unsqueeze(0) + vision_tower.index_masks = [] + vtoken_length = image_features[0].shape[0] + # TODO: image start / end is not implemented here to support pretraining. + if ( + getattr(self.config, 'tune_mm_mlp_adapter', False) and + getattr(self.config, 'mm_use_im_start_end', False) + ): + raise NotImplementedError + # rank_print(f"Total images : {len(image_features)}") + + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. 
+ _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange( + 0, input_ids.shape[1], + dtype=torch.long, device=input_ids.device + ) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + input_ids = [ + cur_input_ids[cur_attention_mask] + for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) + ] + labels = [ + cur_labels[cur_attention_mask] + for cur_labels, cur_attention_mask in zip(labels, attention_mask) + ] + + new_input_embeds = [] + new_labels = [] + cur_image_idx = 0 + # rank_print("Inserting Images embedding") + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + # rank0_print(num_images) + if num_images == 0: + cur_image_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_image_idx += 1 + continue + + image_token_indices = [-1] + \ + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append( + cur_input_ids[image_token_indices[i] + 1: image_token_indices[i + 1]] + ) + cur_labels_noim.append( + cur_labels[image_token_indices[i] + 1: image_token_indices[i + 1]] + ) + split_sizes = [x.shape[0] for x in cur_labels_noim] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) + cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + if i < num_images: + try: + cur_image_features = image_features[cur_image_idx] + except IndexError: + cur_image_features = image_features[cur_image_idx - 1] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append( + torch.full( + (cur_image_features.shape[0],), + IGNORE_INDEX, + device=cur_labels.device, dtype=cur_labels.dtype + ) + ) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + # import pdb; pdb.set_trace() + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as image embeddings can make the sequence longer + tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) + # rank_print("Finishing Inserting") + + new_input_embeds = [ + x[:tokenizer_model_max_length] + for x, modality in zip(new_input_embeds, modalities) + ] + new_labels = [ + x[:tokenizer_model_max_length] + for x, modality in zip(new_labels, modalities) + ] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full( + (batch_size, max_len), + IGNORE_INDEX, + dtype=new_labels[0].dtype, + 
device=new_labels[0].device + ) + attention_mask = torch.zeros( + (batch_size, max_len), + dtype=attention_mask.dtype, + device=attention_mask.device + ) + position_ids = torch.zeros( + (batch_size, max_len), + dtype=position_ids.dtype, device=position_ids.device + ) + # rank0_print("Prepare pos id") + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, 'tokenizer_padding_side', 'right') == 'left': + new_input_embeds_padded.append( + torch.cat( + ( + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, device=cur_new_embed.device + ), + cur_new_embed + ), dim=0 + ) + ) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange( + 0, cur_len, + dtype=position_ids.dtype, device=position_ids.device + ) + else: + new_input_embeds_padded.append( + torch.cat( + ( + cur_new_embed, + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, device=cur_new_embed.device + ) + ), dim=0 + ) + ) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange( + 0, cur_len, + dtype=position_ids.dtype, device=position_ids.device + ) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + # rank0_print("tokenizer padding") + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + if getattr(self.config, 'use_pos_skipping', False) and self.training: + position_ids = torch.arange( + new_input_embeds.size(1), + device=new_input_embeds.device + ).unsqueeze(0).to(new_input_embeds.device) + split_position = random.randint(0, new_input_embeds.size(1)) + left_add = random.randint(0, self.config.pos_skipping_range) + right_add = random.randint(left_add, self.config.pos_skipping_range) + position_ids[:, :split_position] += left_add + position_ids[:, split_position:] += right_add + # import pdb; pdb.set_trace() + # rank0_print("Finish preparing") + # print(vtoken_length) + return None, position_ids, attention_mask, past_key_values, \ + new_input_embeds, new_labels, vtoken_length
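
Note on the DivPrune change: `threshold_ratio` is now passed as `1 - reduction_ratio`, i.e. the fraction of visual tokens to keep, and the kept set is chosen for mutual diversity under cosine distance. Since the diff only touches the edges of `divprune()`, the following is a minimal, self-contained sketch of the greedy max-min selection it relies on; the function names, the seeding rule, and the shapes are illustrative assumptions, not the repo's exact implementation.

    import torch


    def pairwise_cosine_distance(x: torch.Tensor) -> torch.Tensor:
        # (T, D) token features -> (T, T) matrix of 1 - cosine similarity.
        x = torch.nn.functional.normalize(x.float(), dim=-1)
        return 1.0 - x @ x.t()


    def select_diverse_tokens(visual_tokens: torch.Tensor, keep_ratio: float) -> torch.Tensor:
        # Greedy max-min (farthest-point) selection: returns indices of roughly
        # keep_ratio * T tokens that are mutually far apart in cosine distance.
        t = visual_tokens.shape[0]
        keep = max(1, round(keep_ratio * t))
        dist = pairwise_cosine_distance(visual_tokens)

        # Seed with the token farthest from the others on average (an assumption;
        # any seed works), then repeatedly add the token whose minimum distance
        # to the already-selected set is largest.
        selected = [int(dist.mean(dim=1).argmax())]
        min_dist = dist[selected[0]].clone()
        for _ in range(keep - 1):
            min_dist[selected] = -1.0            # never re-pick a selected token
            nxt = int(min_dist.argmax())
            selected.append(nxt)
            min_dist = torch.minimum(min_dist, dist[nxt])
        return torch.tensor(sorted(selected), dtype=torch.long)


    if __name__ == '__main__':
        tokens = torch.randn(576, 1024)                      # e.g. LLaVA's 24x24 visual tokens
        keep_idx = select_diverse_tokens(tokens, keep_ratio=1 - 0.9444)
        print(keep_idx.shape)                                # ~32 retained token indices

With `reduction_ratio: 0.9444` from the DivPrune config above, the kept fraction is about 0.056, so roughly 32 of LLaVA's 576 visual tokens survive before the selected indices are shifted by the vision-token start index and merged back with the text tokens.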
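
Note on the MustDrop / LLaVA-NeXT path: the spatial-merge hook now returns a boolean `index_masks` alongside the merged features, `update_index_masks_hook` stashes those masks on the vision tower, and the patched `prepare_inputs_labels_for_multimodal_with_index_masks` later filters the projected image features with them. A minimal sketch of that mask hand-off pattern follows; the tiny module and the norm-based keep rule are placeholders for illustration, not MustDrop's actual windowed spatial merging.

    import torch
    import torch.nn as nn


    class TinyVisionTower(nn.Module):
        # Stand-in for a ViT vision tower; emits per-image token features.
        def __init__(self, dim=64):
            super().__init__()
            self.proj = nn.Linear(dim, dim)
            self.index_masks = None              # filled in by the forward hook

        def forward(self, pixels):               # pixels: (B, T, D)
            return self.proj(pixels)


    def keep_salient_tokens_hook(module, inputs, output, keep_ratio=0.25):
        # Forward hook: mark the top-`keep_ratio` tokens (here by L2 norm, a
        # placeholder criterion) as kept and store the boolean mask on the
        # module, mirroring how update_index_masks_hook hands masks over to
        # the multimodal preprocessor.
        scores = output.norm(dim=-1)                         # (B, T)
        k = max(1, int(keep_ratio * scores.shape[1]))
        topk = scores.topk(k, dim=1).indices
        masks = torch.zeros_like(scores, dtype=torch.bool)
        masks.scatter_(1, topk, True)
        module.index_masks = masks
        return output


    tower = TinyVisionTower()
    tower.register_forward_hook(keep_salient_tokens_hook)

    feats = tower(torch.randn(2, 576, 64))                   # (2, 576, 64)
    kept = [f[m] for f, m in zip(feats, tower.index_masks)]  # each roughly (144, 64)
    print([t.shape for t in kept])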