Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/sparsification/methods/DART/dart.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ sparse:
special:
method: DART
pruning_loc: 5
reduction_ratio: 0.778
reduction_ratio: 0.7778
pivot_image_token: 4
pivot_text_token : 4
save:
Expand Down
2 changes: 1 addition & 1 deletion configs/sparsification/methods/VisPruner/vispruner.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ sparse:
method: TokenReduction
special:
method: VisPruner
prune_ratio: 0.778 # 0.667 0.778 0.889
prune_ratio: 0.7778 # 0.6667 0.7778 0.8889
important_ratio: 0.5
save:
save_trans: False
Expand Down
3 changes: 2 additions & 1 deletion llmc/compression/token_reduction/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def visualizer_hook(module, inps, layer_outs, pruning_paras):
save_path=''
)
visualize_grid_to_grid(
visual_attention_maps[0, 4, :, :],
visual_attention_maps[0, 31, :, :],

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The index 31 is a "magic number," making it unclear what it represents. To improve readability and maintainability, it should be replaced with a named constant or derived from a configuration value (e.g., num_attention_heads - 1).

300,
image,
grid_size=24,
Expand All @@ -72,6 +72,7 @@ def visualizer_hook(module, inps, layer_outs, pruning_paras):
functools.partial(get_attentions_hook, pruning_paras=self.pruning_paras),
)
if idx == (len(self.blocks) - 1):
# self.model.language_model.layers[-1]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This commented-out line appears to be a development note. It should be removed to keep the codebase clean.

blk.register_forward_hook(
functools.partial(visualizer_hook, pruning_paras=self.pruning_paras),
)
207 changes: 187 additions & 20 deletions llmc/utils/visualizer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import os

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from loguru import logger
from PIL import Image, ImageDraw
from torchvision.transforms import ToPILImage

try:
import matplotlib.pyplot as plt
import seaborn as sns
except Exception:
logger.warning(
'Can not import matplotlib. '
Expand All @@ -30,22 +29,32 @@ def to_pil_image(

def save_image(image_tensor, mean, std, save_path):
    """Convert ``image_tensor`` to an image and write it to disk.

    ``save_path`` is interpreted in one of two ways:

    * If it ends with ``.png``, ``.jpg``, ``.jpeg`` or ``.pdf``
      (case-insensitive) it is used as the output file name, and its parent
      directory is created when missing.
    * Otherwise it is treated as a directory. The directory is created when
      missing and the image is written to the first unused name of the form
      ``NNNN_visprunerP.png`` inside it.

    Args:
        image_tensor: Image passed to ``to_pil_image``.
        mean: Not used by this function; kept for interface compatibility.
        std: Not used by this function; kept for interface compatibility.
        save_path: Output file or output directory, as described above. An
            empty string means the current working directory.
    """
    # NOTE(review): to_pil_image is defined elsewhere in this module and is
    # presumably returning a PIL image (it must expose ``.save``) - confirm.
    img = to_pil_image(image_tensor)

    if not save_path.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf')):
        # Directory mode. os.makedirs('') raises FileNotFoundError, so only
        # create the directory when a non-empty path was given.
        if save_path:
            os.makedirs(save_path, exist_ok=True)
        base_path = os.path.join(save_path, '{:04d}_visprunerP.png')
        idx = 0
        # Pick the first index that does not clash with an existing file.
        while os.path.exists(base_path.format(idx)):
            idx += 1
        save_path = base_path.format(idx)
    else:
        # File mode. A bare file name such as 'out.png' has an empty dirname,
        # which os.makedirs rejects, so skip directory creation in that case.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    img.save(save_path)


def visualize_kept_patches(
image,
keep_idx,
keep_idx=None,
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
patch_size=14,
darken_ratio=0.3,
darken_ratio=0.8,
save_path=None,
):
Comment on lines 47 to 55

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The assertion assert image.ndim == 3 and image.shape[0] == 3 was removed. This check is valuable for ensuring the input image has the expected shape and provides a clear error message if it doesn't. Without it, the code might fail later with a less specific ValueError. Please consider adding it back at the beginning of the function body.

assert image.ndim == 3 and image.shape[0] == 3, \
f'Expected image of shape [3, H, W], got {image.shape}'
# save_image(image,mean,std,save_path)

# save_image(image, mean, std, save_path)
# return
_, H, W = image.shape # 3 336 336
device = image.device
num_patches_h = H // patch_size # 24
Expand All @@ -59,8 +68,11 @@ def visualize_kept_patches(
mask = patch_mask.repeat_interleave(patch_size, dim=0).repeat_interleave(patch_size, dim=1)
mask = mask.unsqueeze(0) # shape [1, H, W]

# Darken image
masked_image = image * (mask + (~mask) * darken_ratio)
# white
prune_mask = ~mask
white_tensor = torch.ones_like(image)
masked_image = image * (1 - darken_ratio * prune_mask.float()) + \
white_tensor * (darken_ratio * prune_mask.float())

save_image(masked_image, mean, std, save_path)

Expand All @@ -85,14 +97,6 @@ def grid_show(to_shows, cols, save_path=None, dpi=100):
plt.savefig(save_path, bbox_inches='tight', dpi=dpi)
plt.close()

# def visualize_head(att_map):
# ax = plt.gca()
# # Plot the heatmap
# im = ax.imshow(att_map)
# # Create colorbar
# cbar = ax.figure.colorbar(im, ax=ax)
# plt.show()


def visualize_heads(att_map, cols, save_path):
to_shows = []
Expand Down Expand Up @@ -215,6 +219,169 @@ def visualize_grid_to_grid(att_map, grid_index, image, grid_size=14, alpha=0.6,
plt.close()


def visualize_attention(attention, grid_size=24, save_path=None):
    """Plot a block-averaged heatmap of a 2-D attention matrix.

    The matrix is cropped to a multiple of ``grid_size`` along both axes,
    averaged over non-overlapping ``grid_size x grid_size`` tiles, and drawn
    with seaborn. Cells strictly above the main diagonal of the averaged
    matrix are masked out.

    Args:
        attention: 2-D array, or an object exposing ``detach`` (e.g. a torch
            tensor), holding the attention weights.
        grid_size: Side length of the averaging tile.
        save_path: Destination handed directly to ``plt.savefig``; it must be
            supplied by the caller even though the default is ``None``.

    Raises:
        ValueError: If ``attention`` is smaller than ``grid_size`` along
            either axis, so that no complete tile exists.
    """
    if hasattr(attention, 'detach'):
        # Cast to float32 before leaving torch: NumPy has no bfloat16 dtype,
        # so calling .numpy() on a bf16 tensor raises.
        attention = attention.detach().float().cpu().numpy()

    H, W = attention.shape
    num_row_blocks = H // grid_size
    num_col_blocks = W // grid_size
    if num_row_blocks == 0 or num_col_blocks == 0:
        raise ValueError(
            f'attention of shape {attention.shape} is smaller than '
            f'grid_size={grid_size}; nothing to plot'
        )

    # Drop the trailing rows/columns that do not fill a complete tile.
    attention = attention[:num_row_blocks * grid_size, :num_col_blocks * grid_size]

    # (row_block, row_in_block, col_block, col_in_block) -> mean per tile.
    blocks = attention.reshape(num_row_blocks, grid_size, num_col_blocks, grid_size)
    block_means = blocks.mean(axis=(1, 3))

    # Hide everything strictly above the main diagonal.
    mask = np.triu(np.ones_like(block_means, dtype=bool), k=1)

    plt.figure(figsize=(10, 10))
    sns.heatmap(block_means, mask=mask, cmap='viridis', square=True, cbar=True)

    # One unlabeled tick per tile; each axis uses its own tile count so that
    # non-square inputs get the right number of ticks.
    plt.xticks(
        ticks=np.arange(num_col_blocks),
        labels=[''] * num_col_blocks,
        rotation=90,
    )
    plt.yticks(ticks=np.arange(num_row_blocks), labels=[''] * num_row_blocks)

    plt.title('Attention Map')
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()


def visualize_attention_v2(attention, grid_size=24, save_path=None):
    """Plot section-level average attention for a fixed SYS/IMG/INS layout.

    The token axis is split into 32 hard-coded blocks: 2 "SYS" blocks covering
    positions [0, 35), 24 "IMG" blocks of 24 tokens covering [35, 611), and
    6 "INS" blocks covering [611, 1155). The mean attention of every
    (query block, key block) pair is computed, each row is scaled by its
    maximum, and the result is drawn as a heatmap.

    Args:
        attention: 2-D array, or an object exposing ``detach`` (e.g. a torch
            tensor), indexed as ``[query_position, key_position]``.
        grid_size: Not used by this function; the block layout is fixed.
        save_path: Destination handed directly to ``plt.savefig``; it must be
            supplied by the caller even though the default is ``None``.
    """
    if hasattr(attention, 'detach'):
        # Cast to float32 before leaving torch: NumPy has no bfloat16 dtype,
        # so calling .numpy() on a bf16 tensor raises.
        attention = attention.detach().float().cpu().numpy()

    # --- Partition the sequence into (start, end) blocks -------------------
    # NOTE(review): these offsets look specific to one prompt/model layout -
    # confirm before reusing with other inputs.
    block_ranges = []

    # SYS: 2 blocks
    sys_splits = [0, 17, 35]
    block_ranges.extend(zip(sys_splits[:-1], sys_splits[1:]))

    # IMG: 24 blocks of size 24
    img_start = 35
    img_block_size = 24
    num_img_blocks = 24
    for i in range(num_img_blocks):
        start = img_start + i * img_block_size
        block_ranges.append((start, start + img_block_size))

    # INS: 6 blocks of 91 tokens; 611 + 6 * 91 = 1157, so the last block is
    # cropped to end at 1155.
    ins_splits = [611 + i * 91 for i in range(7)]
    ins_splits[-1] = 1155
    block_ranges.extend(zip(ins_splits[:-1], ins_splits[1:]))

    # --- Average attention over every pair of blocks -----------------------
    num_blocks = len(block_ranges)
    block_attention = np.zeros((num_blocks, num_blocks))
    for i, (i_start, i_end) in enumerate(block_ranges):
        for j, (j_start, j_end) in enumerate(block_ranges):
            block = attention[i_start:i_end, j_start:j_end]
            # Rows are stored in reverse order: query block i ends up in row
            # num_blocks - 1 - i.
            block_attention[num_blocks - 1 - i, j] = block.mean()

    # Scale every row by its own maximum. Rows whose maximum is zero are left
    # at zero instead of dividing by zero and filling the row with NaN.
    row_max = block_attention.max(axis=1, keepdims=True)
    block_attention = np.divide(
        block_attention,
        row_max,
        out=np.zeros_like(block_attention),
        where=row_max != 0,
    )

    # NOTE(review): the mask hides the strict upper triangle of the stored
    # (row-reversed) matrix - confirm this combination is intended.
    mask = np.triu(np.ones_like(block_attention, dtype=bool), k=1)

    plt.figure(figsize=(10, 10))
    sns.heatmap(block_attention, mask=mask, cmap='viridis', square=True, cbar=True)

    section_labels = ['SYS', 'IMG', 'INS']
    # Cumulative block counts: 2 SYS, 24 IMG, 6 INS.
    section_boundaries = [2, 26, 32]
    ticks = np.arange(0, num_blocks)
    # Blank out the per-block tick labels, then label the section boundaries.
    plt.xticks(ticks=ticks, labels=[''] * num_blocks)
    plt.yticks(ticks=ticks, labels=[''] * num_blocks)
    plt.xticks(ticks=section_boundaries, labels=section_labels, fontsize=12)
    plt.yticks(ticks=section_boundaries, labels=section_labels, fontsize=12)
    plt.title('Attention Map')
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()
Comment on lines +251 to +302

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This function contains many hardcoded "magic numbers" (e.g., [0, 17, 35], 35 + i * 24, 31 - i, [2, 26, 32]) that seem specific to a particular model architecture. This makes the function rigid and hard to maintain or reuse for other models. These values should be refactored to be passed in as parameters or loaded from a configuration.



def visualize_cosin_token(token_embedding, save_path=None):
    """Plot pairwise cosine similarity for a fixed window of tokens.

    Takes ``196 * 4`` consecutive tokens starting at position 14 of the first
    batch element, L2-normalises them, and draws the resulting similarity
    matrix as a heatmap. The colour scale starts at the 90th percentile of
    the positive off-diagonal similarities, so only the top 10% of values are
    visually distinguishable.

    Args:
        token_embedding: Tensor indexed as ``[batch, token, feature]``.
        save_path: Destination handed to ``plt.savefig``; it must be supplied
            by the caller even though the default is ``None``. The output
            format follows the file extension and falls back to PDF when the
            path has no extension.
    """
    # NOTE(review): offset 14 and 196 tokens x 4 frames look specific to one
    # model/prompt layout - confirm before reusing with other inputs.
    token_offset = 14
    tokens_per_frame = 196
    num_frames = 4

    x = token_embedding[0, token_offset: token_offset + tokens_per_frame * num_frames, :]
    # detach + float32: .numpy() rejects tensors that require grad and
    # bfloat16 tensors.
    x_norm = F.normalize(x.detach().float(), p=2, dim=1)
    similarity = (x_norm @ x_norm.T).cpu().numpy()

    # Colour-scale floor: 90th percentile of the positive entries above the
    # diagonal. Fall back to seaborn's automatic floor when there are none.
    upper = np.triu(similarity, k=1)
    valid_sim = upper[upper > 0]
    vmin = np.percentile(valid_sim, 90) if valid_sim.size else None

    # Respect the extension of save_path; keep PDF as the default when the
    # path carries no extension.
    fmt = 'pdf'
    if isinstance(save_path, (str, os.PathLike)) and \
            os.path.splitext(os.fspath(save_path))[1]:
        fmt = None  # let matplotlib infer the format from the extension

    # Scope the font size to this figure instead of mutating global rcParams
    # and then resetting every rc setting with plt.rcdefaults().
    with plt.rc_context({'font.size': 15}):
        plt.subplots(figsize=(10, 10))
        sns.heatmap(similarity, cmap='Reds', vmin=vmin, vmax=1)

        # One tick per frame boundary, labelled with the token index.
        ticks = np.arange(0, tokens_per_frame * (num_frames + 1), tokens_per_frame)
        plt.xticks(ticks, ticks)
        plt.yticks(ticks, ticks)

        plt.title('')
        plt.xlabel('')
        plt.ylabel('')
        plt.tight_layout()
        plt.savefig(save_path, format=fmt)
        plt.close()


def visualize_cosin_token_32p(token_embedding, save_path=None):
    """Plot cosine-similarity heatmaps for 32 frames in eight 4-frame groups.

    Takes ``196 * 32`` consecutive tokens starting at position 14 of the
    first batch element and L2-normalises them. One heatmap is drawn per
    group of 4 frames (784 tokens) in a 2 x 4 grid. All panels share a colour
    scale whose floor is the 90th percentile of the positive off-diagonal
    similarities of the full token set.

    Args:
        token_embedding: Tensor indexed as ``[batch, token, feature]``.
        save_path: Destination handed directly to ``plt.savefig``; it must be
            supplied by the caller even though the default is ``None``.
    """
    # NOTE(review): offset 14 and 196 tokens per frame look specific to one
    # model/prompt layout - confirm before reusing with other inputs.
    token_offset = 14
    tokens_per_frame = 196
    num_frames = 32
    group_size = 4
    num_groups = 8
    tokens_per_group = tokens_per_frame * group_size

    all_tokens = token_embedding[
        0, token_offset: token_offset + tokens_per_frame * num_frames, :
    ]
    # detach + float32: .numpy() rejects tensors that require grad and
    # bfloat16 tensors.
    x_norm = F.normalize(all_tokens.detach().float(), p=2, dim=1)
    full_similarity = (x_norm @ x_norm.T).cpu().numpy()

    # Shared colour-scale floor: 90th percentile of the positive entries above
    # the diagonal. Fall back to seaborn's automatic floor when there are none.
    upper = np.triu(full_similarity, k=1)
    valid_sim = upper[upper > 0]
    vmin = np.percentile(valid_sim, 90) if valid_sim.size else None
    del upper, valid_sim  # release the large temporaries before plotting

    # Scope the font size to this figure instead of mutating global rcParams
    # and then resetting every rc setting with plt.rcdefaults().
    with plt.rc_context({'font.size': 20}):
        fig, axs = plt.subplots(2, 4, figsize=(22, 10))  # 2 x 4 panel layout
        axs = axs.flatten()

        for i in range(num_groups):
            lo = i * tokens_per_group
            hi = lo + tokens_per_group
            # Rows of x_norm are already unit length, so the group's
            # similarity matrix is the matching diagonal block of the full
            # matrix.
            group_similarity = full_similarity[lo:hi, lo:hi]

            ax = axs[i]
            sns.heatmap(
                group_similarity, cmap='Reds',
                vmin=vmin, vmax=1, ax=ax, cbar=False
            )

            # Ticks sit at frame boundaries inside the panel; labels show the
            # token index relative to the start of the selected window.
            ticks = np.arange(0, tokens_per_group, tokens_per_frame)
            labels = np.arange(lo, hi, tokens_per_frame)
            ax.set_xticks(ticks)
            ax.set_yticks(ticks)
            ax.set_xticklabels(labels, rotation=0)
            ax.set_yticklabels(labels)
            start_frame = i * group_size
            end_frame = (i + 1) * group_size - 1
            ax.set_xlabel(f'Frame {start_frame}-{end_frame}', fontsize=17, labelpad=10)

        plt.tight_layout()
        plt.savefig(save_path, dpi=300)
        plt.close(fig)


def highlight_grid(image, grid_indexes, grid_size=14):
if not isinstance(grid_size, tuple):
grid_size = (grid_size, grid_size)
Expand Down