From 2a055e2561d31c932560eb1157f28de0e8558029 Mon Sep 17 00:00:00 2001
From: zhangbilang <2045955563@qq.com>
Date: Tue, 12 Aug 2025 09:13:06 +0800
Subject: [PATCH] update visualizer

---
 configs/sparsification/methods/DART/dart.yml  |   2 +-
 .../methods/VisPruner/vispruner.yml           |   2 +-
 .../compression/token_reduction/visualizer.py |   3 +-
 llmc/utils/visualizer.py                      | 207 ++++++++++++++++--
 4 files changed, 191 insertions(+), 23 deletions(-)

diff --git a/configs/sparsification/methods/DART/dart.yml b/configs/sparsification/methods/DART/dart.yml
index 426f256a..104f560c 100644
--- a/configs/sparsification/methods/DART/dart.yml
+++ b/configs/sparsification/methods/DART/dart.yml
@@ -17,7 +17,7 @@ sparse:
   special:
     method: DART
     pruning_loc: 5
-    reduction_ratio: 0.778
+    reduction_ratio: 0.7778
     pivot_image_token: 4
     pivot_text_token : 4
 save:
diff --git a/configs/sparsification/methods/VisPruner/vispruner.yml b/configs/sparsification/methods/VisPruner/vispruner.yml
index eb70c554..eef5ffc8 100644
--- a/configs/sparsification/methods/VisPruner/vispruner.yml
+++ b/configs/sparsification/methods/VisPruner/vispruner.yml
@@ -18,7 +18,7 @@ sparse:
   method: TokenReduction
   special:
     method: VisPruner
-    prune_ratio: 0.778 # 0.667 0.778 0.889
+    prune_ratio: 0.7778 # 0.6667 0.7778 0.8889
     important_ratio: 0.5
 save:
   save_trans: False
diff --git a/llmc/compression/token_reduction/visualizer.py b/llmc/compression/token_reduction/visualizer.py
index 256c282b..732b901a 100644
--- a/llmc/compression/token_reduction/visualizer.py
+++ b/llmc/compression/token_reduction/visualizer.py
@@ -48,7 +48,7 @@ def visualizer_hook(module, inps, layer_outs, pruning_paras):
                 save_path=''
             )
             visualize_grid_to_grid(
-                visual_attention_maps[0, 4, :, :],
+                visual_attention_maps[0, 31, :, :],
                 300,
                 image,
                 grid_size=24,
@@ -72,6 +72,7 @@ def visualizer_hook(module, inps, layer_outs, pruning_paras):
                 functools.partial(get_attentions_hook, pruning_paras=self.pruning_paras),
             )
             if idx == (len(self.blocks) - 1):
+                # self.model.language_model.layers[-1]
                 blk.register_forward_hook(
                     functools.partial(visualizer_hook, pruning_paras=self.pruning_paras),
                 )
diff --git a/llmc/utils/visualizer.py b/llmc/utils/visualizer.py
index b9fc3ebd..5dec72f7 100644
--- a/llmc/utils/visualizer.py
+++ b/llmc/utils/visualizer.py
@@ -1,15 +1,14 @@
 import os
 
-import cv2
 import numpy as np
 import torch
-import torchvision.transforms.functional as TF
+import torch.nn.functional as F
 from loguru import logger
 from PIL import Image, ImageDraw
-from torchvision.transforms import ToPILImage
 
 try:
     import matplotlib.pyplot as plt
+    import seaborn as sns
 except Exception:
     logger.warning(
         'Can not import matplotlib. '
@@ -30,22 +29,32 @@ def to_pil_image(
 
 def save_image(image_tensor, mean, std, save_path):
     img = to_pil_image(image_tensor)
-    Image.fromarray(img).save(save_path)
+
+    if not save_path.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf')):
+        os.makedirs(save_path, exist_ok=True)
+        base_path = os.path.join(save_path, '{:04d}_visprunerP.png')
+        idx = 0
+        while os.path.exists(base_path.format(idx)):
+            idx += 1
+        save_path = base_path.format(idx)
+
+    else:
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+    img.save(save_path)
 
 
 def visualize_kept_patches(
     image,
-    keep_idx,
+    keep_idx=None,
     mean=[0.48145466, 0.4578275, 0.40821073],
     std=[0.26862954, 0.26130258, 0.27577711],
     patch_size=14,
-    darken_ratio=0.3,
+    darken_ratio=0.8,
     save_path=None,
 ):
-    assert image.ndim == 3 and image.shape[0] == 3, \
-        f'Expected image of shape [3, H, W], got {image.shape}'
-    # save_image(image,mean,std,save_path)
-
+    # save_image(image, mean, std, save_path)
+    # return
     _, H, W = image.shape  # 3 336 336
     device = image.device
     num_patches_h = H // patch_size  # 24
@@ -59,8 +68,11 @@ def visualize_kept_patches(
     mask = patch_mask.repeat_interleave(patch_size, dim=0).repeat_interleave(patch_size, dim=1)
     mask = mask.unsqueeze(0)  # shape [1, H, W]
 
-    # Darken image
-    masked_image = image * (mask + (~mask) * darken_ratio)
+    # Fade pruned patches toward white instead of darkening them
+    prune_mask = ~mask
+    white_tensor = torch.ones_like(image)
+    masked_image = image * (1 - darken_ratio * prune_mask.float()) + \
+        white_tensor * (darken_ratio * prune_mask.float())
 
     save_image(masked_image, mean, std, save_path)
 
@@ -85,14 +97,6 @@ def grid_show(to_shows, cols, save_path=None, dpi=100):
     plt.savefig(save_path, bbox_inches='tight', dpi=dpi)
     plt.close()
 
-# def visualize_head(att_map):
-#     ax = plt.gca()
-#     # Plot the heatmap
-#     im = ax.imshow(att_map)
-#     # Create colorbar
-#     cbar = ax.figure.colorbar(im, ax=ax)
-#     plt.show()
-
 
 def visualize_heads(att_map, cols, save_path):
     to_shows = []
@@ -215,6 +219,169 @@ def visualize_grid_to_grid(att_map, grid_index, image, grid_size=14, alpha=0.6,
     plt.close()
 
 
+def visualize_attention(attention, grid_size=24, save_path=None):
+
+    if hasattr(attention, 'detach'):
+        attention = attention.detach().cpu().numpy()
+
+    H, W = attention.shape
+    new_H = H // grid_size * grid_size
+    new_W = W // grid_size * grid_size
+    attention = attention[:new_H, :new_W]
+
+    blocks = attention.reshape(new_H // grid_size, grid_size, new_W // grid_size, grid_size)
+    block_means = blocks.mean(axis=(1, 3))
+
+    mask = np.triu(np.ones_like(block_means, dtype=bool), k=1)
+
+    plt.figure(figsize=(10, 10))
+    sns.heatmap(block_means, mask=mask, cmap='viridis', square=True, cbar=True)
+
+    ticks = np.arange(0, block_means.shape[0], 1)
+    labels = ['' for i in ticks]
+    plt.xticks(ticks=ticks, labels=labels, rotation=90)
+    plt.yticks(ticks=ticks, labels=labels)
+
+    plt.title('Attention Map')
+    plt.tight_layout()
+    plt.savefig(save_path, bbox_inches='tight')
+    plt.close()
+
+
+def visualize_attention_v2(attention, grid_size=24, save_path=None):
+
+    if hasattr(attention, 'detach'):
+        attention = attention.detach().cpu().numpy()
+
+    # Partition the sequence into SYS / IMG / INS token blocks
+    block_ranges = []
+
+    # SYS: 2 blocks
+
+    sys_splits = [0, 17, 35]
+    for i in range(len(sys_splits) - 1):
+        block_ranges.append((sys_splits[i], sys_splits[i + 1]))
+    # IMG: 24 blocks of size 24
+    for i in range(24):
+        start = 35 + i * 24
+        end = start + 24
+        block_ranges.append((start, end))
+
+    # INS: 6 blocks
+    ins_splits = [611 + i * 91 for i in range(7)]  # 611 + 6 * 91 = 1157 → crop to 1155
+    ins_splits[-1] = 1155
+    for i in range(len(ins_splits) - 1):
+        block_ranges.append((ins_splits[i], ins_splits[i + 1]))
+
+    # Average attention within each block pair
+    num_blocks = len(block_ranges)
+    block_attention = np.zeros((num_blocks, num_blocks))
+    for i in range(num_blocks):
+        i_start, i_end = block_ranges[i]
+        for j in range(num_blocks):
+            j_start, j_end = block_ranges[j]
+            block = attention[i_start:i_end, j_start:j_end]
+            block_attention[i, j] = block.mean()
+
+    mask = np.triu(np.ones_like(block_attention, dtype=bool), k=1)
+    plt.figure(figsize=(10, 10))
+    block_attention = block_attention / block_attention.max(axis=1, keepdims=True)
+    sns.heatmap(block_attention, mask=mask, cmap='viridis', square=True, cbar=True)
+    # sns.heatmap(block_attention, cmap='viridis', square=True, cbar=True)
+
+    section_labels = ['SYS', 'IMG', 'INS']
+    section_boundaries = [2, 26, 32]  # block_ranges holds 2 SYS, 24 IMG and 6 INS blocks
+    ticks = np.arange(0, num_blocks)
+    plt.xticks(ticks=ticks, labels=[''] * num_blocks)
+    plt.yticks(ticks=ticks, labels=[''] * num_blocks)
+    plt.xticks(ticks=section_boundaries, labels=section_labels, fontsize=12)
+    plt.yticks(ticks=section_boundaries, labels=section_labels, fontsize=12)
+    plt.title('Attention Map')
+    plt.tight_layout()
+    plt.savefig(save_path, bbox_inches='tight')
+    plt.close()
+
+
+def visualize_cosin_token(token_embedding, save_path=None):
+
+    plt.rcParams['font.size'] = 15
+
+    x = token_embedding[0, 14: 14 + 196 * 4, :]
+    x_norm = F.normalize(x, p=2, dim=1)
+    similarity_matrix = x_norm @ x_norm.T
+
+    sim_np = similarity_matrix.cpu().numpy()
+    sim_np = np.triu(sim_np, k=1)
+    valid_sim = sim_np[sim_np > 0]
+    vmin = np.percentile(valid_sim, 90)  # colormap floor at the 90th percentile
+
+    plt.subplots(figsize=(10, 10))
+    sns.heatmap(similarity_matrix.cpu().numpy(), cmap='Reds', vmin=vmin, vmax=1)
+
+    start = 0
+    step = 196
+    ticks = np.arange(start, 196 * 5, step)
+    plt.xticks(ticks, ticks)
+    plt.yticks(ticks, ticks)
+
+    plt.title('')
+    plt.xlabel('')
+    plt.ylabel('')
+    plt.tight_layout()
+    plt.savefig(save_path, format='pdf')
+    plt.rcdefaults()
+    plt.close()
+
+
+def visualize_cosin_token_32p(token_embedding, save_path=None):
+
+    plt.rcParams['font.size'] = 20
+
+    all_tokens = token_embedding[0, 14:14 + 196 * 32, :]
+    x_norm = F.normalize(all_tokens, p=2, dim=1)
+    similarity_matrix = x_norm @ x_norm.T
+    sim_np = similarity_matrix.cpu().numpy()
+    sim_np = np.triu(sim_np, k=1)
+    valid_sim = sim_np[sim_np > 0]
+    vmin = np.percentile(valid_sim, 90)  # colormap floor at the 90th percentile
+
+    group_size = 4
+    num_groups = 8
+    tokens_per_group = 196 * group_size
+    step = 196
+
+    fig, axs = plt.subplots(2, 4, figsize=(22, 10))  # 2 x 4 subplot grid
+    axs = axs.flatten()
+
+    for i in range(num_groups):
+        x = all_tokens[i * tokens_per_group: (i + 1) * tokens_per_group, :]
+        x_norm = F.normalize(x, p=2, dim=1)
+        similarity_matrix = x_norm @ x_norm.T
+
+        ax = axs[i]
+        sns.heatmap(
+            similarity_matrix.cpu().numpy(), cmap='Reds',
+            vmin=vmin, vmax=1, ax=ax, cbar=False
+        )
+
+        ticks = np.arange(0, tokens_per_group, step)
+        labels = np.arange(i * tokens_per_group, (i + 1) * tokens_per_group, step)
+        ax.set_xticks(ticks)
+        ax.set_yticks(ticks)
+        ax.set_xticklabels(labels, rotation=0)
+        ax.set_yticklabels(labels)
+        start_frame = i * group_size
+        end_frame = (i + 1) * group_size - 1
+        ax.set_xlabel(f'Frame {start_frame}-{end_frame}', fontsize=17, labelpad=10)
+
+    plt.tight_layout()
+    # plt.savefig(save_path, format='pdf')
+    # plt.savefig(save_path.replace('.pdf', '.svg'), format='svg', bbox_inches='tight')
+    plt.savefig(save_path, dpi=300)
+    plt.rcdefaults()
+    plt.close()
+
+
 def highlight_grid(image, grid_indexes, grid_size=14):
     if not isinstance(grid_size, tuple):
         grid_size = (grid_size, grid_size)
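-- 
Editor's note, not part of the patch: a minimal smoke test for the new
helpers, so they can be exercised without running the full pruning
pipeline. Everything below is synthetic and hypothetical: it assumes
`keep_idx` holds flat indices into the 24 x 24 patch grid of a CLIP-style
336 x 336 input (matching the defaults above), that matplotlib and seaborn
are installed, and that `vis_out/` is just an example output directory.

    import os

    import torch

    from llmc.utils.visualizer import (visualize_attention,
                                       visualize_cosin_token,
                                       visualize_kept_patches)

    os.makedirs('vis_out', exist_ok=True)

    # Fake normalized image [3, 336, 336]; keep 128 of the 576 patches.
    image = torch.randn(3, 336, 336)
    keep_idx = torch.randperm(24 * 24)[:128]
    # Passing a directory makes save_image() pick an auto-numbered filename.
    visualize_kept_patches(image, keep_idx, save_path='vis_out/')

    # Fake causal attention over the 576 visual tokens, shown per 24x24 block.
    attn = torch.tril(torch.rand(576, 576))
    visualize_attention(attn, grid_size=24, save_path='vis_out/attn_blocks.png')

    # Fake token embeddings; visualize_cosin_token() reads tokens
    # 14 : 14 + 196 * 4, so the sequence dimension must be at least 798.
    emb = torch.randn(1, 800, 1024)
    visualize_cosin_token(emb, save_path='vis_out/token_sim.pdf')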