-
Notifications
You must be signed in to change notification settings - Fork 66
update visualizer #434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
update visualizer #434
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,7 +48,7 @@ def visualizer_hook(module, inps, layer_outs, pruning_paras): | |
| save_path='' | ||
| ) | ||
| visualize_grid_to_grid( | ||
| visual_attention_maps[0, 4, :, :], | ||
| visual_attention_maps[0, 31, :, :], | ||
| 300, | ||
| image, | ||
| grid_size=24, | ||
|
|
@@ -72,6 +72,7 @@ def visualizer_hook(module, inps, layer_outs, pruning_paras): | |
| functools.partial(get_attentions_hook, pruning_paras=self.pruning_paras), | ||
| ) | ||
| if idx == (len(self.blocks) - 1): | ||
| # self.model.language_model.layers[-1] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| blk.register_forward_hook( | ||
| functools.partial(visualizer_hook, pruning_paras=self.pruning_paras), | ||
| ) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,15 +1,14 @@ | ||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| import cv2 | ||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||
| import torchvision.transforms.functional as TF | ||||||||||||||||||||||||||||||||||||||||
| import torch.nn.functional as F | ||||||||||||||||||||||||||||||||||||||||
| from loguru import logger | ||||||||||||||||||||||||||||||||||||||||
| from PIL import Image, ImageDraw | ||||||||||||||||||||||||||||||||||||||||
| from torchvision.transforms import ToPILImage | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||
| import matplotlib.pyplot as plt | ||||||||||||||||||||||||||||||||||||||||
| import seaborn as sns | ||||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||||||||||||||||
| 'Can not import matplotlib. ' | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -30,22 +29,32 @@ def to_pil_image( | |||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def save_image(image_tensor, mean, std, save_path): | ||||||||||||||||||||||||||||||||||||||||
| img = to_pil_image(image_tensor) | ||||||||||||||||||||||||||||||||||||||||
| Image.fromarray(img).save(save_path) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if not save_path.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf')): | ||||||||||||||||||||||||||||||||||||||||
| os.makedirs(save_path, exist_ok=True) | ||||||||||||||||||||||||||||||||||||||||
| base_path = os.path.join(save_path, '{:04d}_visprunerP.png') | ||||||||||||||||||||||||||||||||||||||||
| idx = 0 | ||||||||||||||||||||||||||||||||||||||||
| while os.path.exists(base_path.format(idx)): | ||||||||||||||||||||||||||||||||||||||||
| idx += 1 | ||||||||||||||||||||||||||||||||||||||||
| save_path = base_path.format(idx) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| os.makedirs(os.path.dirname(save_path), exist_ok=True) | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+33
to
+42
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic to determine if
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| img.save(save_path) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def visualize_kept_patches( | ||||||||||||||||||||||||||||||||||||||||
| image, | ||||||||||||||||||||||||||||||||||||||||
| keep_idx, | ||||||||||||||||||||||||||||||||||||||||
| keep_idx=None, | ||||||||||||||||||||||||||||||||||||||||
| mean=[0.48145466, 0.4578275, 0.40821073], | ||||||||||||||||||||||||||||||||||||||||
| std=[0.26862954, 0.26130258, 0.27577711], | ||||||||||||||||||||||||||||||||||||||||
| patch_size=14, | ||||||||||||||||||||||||||||||||||||||||
| darken_ratio=0.3, | ||||||||||||||||||||||||||||||||||||||||
| darken_ratio=0.8, | ||||||||||||||||||||||||||||||||||||||||
| save_path=None, | ||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
47
to
55
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The assertion |
||||||||||||||||||||||||||||||||||||||||
| assert image.ndim == 3 and image.shape[0] == 3, \ | ||||||||||||||||||||||||||||||||||||||||
| f'Expected image of shape [3, H, W], got {image.shape}' | ||||||||||||||||||||||||||||||||||||||||
| # save_image(image,mean,std,save_path) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # save_image(image, mean, std, save_path) | ||||||||||||||||||||||||||||||||||||||||
| # return | ||||||||||||||||||||||||||||||||||||||||
| _, H, W = image.shape # 3 336 336 | ||||||||||||||||||||||||||||||||||||||||
| device = image.device | ||||||||||||||||||||||||||||||||||||||||
| num_patches_h = H // patch_size # 24 | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -59,8 +68,11 @@ def visualize_kept_patches( | |||||||||||||||||||||||||||||||||||||||
| mask = patch_mask.repeat_interleave(patch_size, dim=0).repeat_interleave(patch_size, dim=1) | ||||||||||||||||||||||||||||||||||||||||
| mask = mask.unsqueeze(0) # shape [1, H, W] | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Darken image | ||||||||||||||||||||||||||||||||||||||||
| masked_image = image * (mask + (~mask) * darken_ratio) | ||||||||||||||||||||||||||||||||||||||||
| # white | ||||||||||||||||||||||||||||||||||||||||
| prune_mask = ~mask | ||||||||||||||||||||||||||||||||||||||||
| white_tensor = torch.ones_like(image) | ||||||||||||||||||||||||||||||||||||||||
| masked_image = image * (1 - darken_ratio * prune_mask.float()) + \ | ||||||||||||||||||||||||||||||||||||||||
| white_tensor * (darken_ratio * prune_mask.float()) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| save_image(masked_image, mean, std, save_path) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
@@ -85,14 +97,6 @@ def grid_show(to_shows, cols, save_path=None, dpi=100): | |||||||||||||||||||||||||||||||||||||||
| plt.savefig(save_path, bbox_inches='tight', dpi=dpi) | ||||||||||||||||||||||||||||||||||||||||
| plt.close() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # def visualize_head(att_map): | ||||||||||||||||||||||||||||||||||||||||
| # ax = plt.gca() | ||||||||||||||||||||||||||||||||||||||||
| # # Plot the heatmap | ||||||||||||||||||||||||||||||||||||||||
| # im = ax.imshow(att_map) | ||||||||||||||||||||||||||||||||||||||||
| # # Create colorbar | ||||||||||||||||||||||||||||||||||||||||
| # cbar = ax.figure.colorbar(im, ax=ax) | ||||||||||||||||||||||||||||||||||||||||
| # plt.show() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def visualize_heads(att_map, cols, save_path): | ||||||||||||||||||||||||||||||||||||||||
| to_shows = [] | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -215,6 +219,169 @@ def visualize_grid_to_grid(att_map, grid_index, image, grid_size=14, alpha=0.6, | |||||||||||||||||||||||||||||||||||||||
| plt.close() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def visualize_attention(attention, grid_size=24, save_path=None): | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if hasattr(attention, 'detach'): | ||||||||||||||||||||||||||||||||||||||||
| attention = attention.detach().cpu().numpy() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| H, W = attention.shape | ||||||||||||||||||||||||||||||||||||||||
| new_H = H // grid_size * grid_size | ||||||||||||||||||||||||||||||||||||||||
| new_W = W // grid_size * grid_size | ||||||||||||||||||||||||||||||||||||||||
| attention = attention[:new_H, :new_W] | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| blocks = attention.reshape(new_H // grid_size, grid_size, new_W // grid_size, grid_size) | ||||||||||||||||||||||||||||||||||||||||
| block_means = blocks.mean(axis=(1, 3)) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| mask = np.triu(np.ones_like(block_means, dtype=bool), k=1) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| plt.figure(figsize=(10, 10)) | ||||||||||||||||||||||||||||||||||||||||
| sns.heatmap(block_means, mask=mask, cmap='viridis', square=True, cbar=True) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ticks = np.arange(0, block_means.shape[0], 1) | ||||||||||||||||||||||||||||||||||||||||
| labels = ['' for i in ticks] | ||||||||||||||||||||||||||||||||||||||||
| plt.xticks(ticks=ticks, labels=labels, rotation=90) | ||||||||||||||||||||||||||||||||||||||||
| plt.yticks(ticks=ticks, labels=labels) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| plt.title('Attention Map') | ||||||||||||||||||||||||||||||||||||||||
| plt.tight_layout() | ||||||||||||||||||||||||||||||||||||||||
| plt.savefig(save_path, bbox_inches='tight') | ||||||||||||||||||||||||||||||||||||||||
| plt.close() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def visualize_attention_v2(attention, grid_size=24, save_path=None): | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if hasattr(attention, 'detach'): | ||||||||||||||||||||||||||||||||||||||||
| attention = attention.detach().cpu().numpy() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # 分区 | ||||||||||||||||||||||||||||||||||||||||
| block_ranges = [] | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # SYS: 2 blocks | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| sys_splits = [0, 17, 35] | ||||||||||||||||||||||||||||||||||||||||
| for i in range(len(sys_splits) - 1): | ||||||||||||||||||||||||||||||||||||||||
| block_ranges.append((sys_splits[i], sys_splits[i + 1])) | ||||||||||||||||||||||||||||||||||||||||
| # IMG: 24 blocks of size 24 | ||||||||||||||||||||||||||||||||||||||||
| for i in range(24): | ||||||||||||||||||||||||||||||||||||||||
| start = 35 + i * 24 | ||||||||||||||||||||||||||||||||||||||||
| end = start + 24 | ||||||||||||||||||||||||||||||||||||||||
| block_ranges.append((start, end)) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # INS: 6 blocks | ||||||||||||||||||||||||||||||||||||||||
| ins_splits = [611 + i * 91 for i in range(7)] # 611 + 6 * 91 = 1157 → crop to 1155 | ||||||||||||||||||||||||||||||||||||||||
| ins_splits[-1] = 1155 | ||||||||||||||||||||||||||||||||||||||||
| for i in range(len(ins_splits) - 1): | ||||||||||||||||||||||||||||||||||||||||
| block_ranges.append((ins_splits[i], ins_splits[i + 1])) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # 对每个 block pair 求平均 | ||||||||||||||||||||||||||||||||||||||||
| num_blocks = len(block_ranges) | ||||||||||||||||||||||||||||||||||||||||
| block_attention = np.zeros((num_blocks, num_blocks)) | ||||||||||||||||||||||||||||||||||||||||
| for i in range(num_blocks): | ||||||||||||||||||||||||||||||||||||||||
| i_start, i_end = block_ranges[i] | ||||||||||||||||||||||||||||||||||||||||
| for j in range(num_blocks): | ||||||||||||||||||||||||||||||||||||||||
| j_start, j_end = block_ranges[j] | ||||||||||||||||||||||||||||||||||||||||
| block = attention[i_start:i_end, j_start:j_end] | ||||||||||||||||||||||||||||||||||||||||
| block_attention[31 - i, j] = block.mean() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| mask = np.triu(np.ones_like(block_attention, dtype=bool), k=1) | ||||||||||||||||||||||||||||||||||||||||
| plt.figure(figsize=(10, 10)) | ||||||||||||||||||||||||||||||||||||||||
| block_attention = block_attention / block_attention.max(axis=1, keepdims=True) | ||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The normalization
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| sns.heatmap(block_attention, mask=mask, cmap='viridis', square=True, cbar=True) | ||||||||||||||||||||||||||||||||||||||||
| # sns.heatmap(block_attention, cmap='viridis', square=True, cbar=True) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| section_labels = ['SYS', 'IMG', 'INS'] | ||||||||||||||||||||||||||||||||||||||||
| section_boundaries = [2, 26, 32] # block_ranges 分别为2个SYS,24个IMG,6个INS | ||||||||||||||||||||||||||||||||||||||||
| ticks = np.arange(0, num_blocks) | ||||||||||||||||||||||||||||||||||||||||
| plt.xticks(ticks=ticks, labels=[''] * num_blocks) | ||||||||||||||||||||||||||||||||||||||||
| plt.yticks(ticks=ticks, labels=[''] * num_blocks) | ||||||||||||||||||||||||||||||||||||||||
| plt.xticks(ticks=section_boundaries, labels=section_labels, fontsize=12) | ||||||||||||||||||||||||||||||||||||||||
| plt.yticks(ticks=section_boundaries, labels=section_labels, fontsize=12) | ||||||||||||||||||||||||||||||||||||||||
| plt.title('Attention Map') | ||||||||||||||||||||||||||||||||||||||||
| plt.tight_layout() | ||||||||||||||||||||||||||||||||||||||||
| plt.savefig(save_path, bbox_inches='tight') | ||||||||||||||||||||||||||||||||||||||||
| plt.close() | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+251
to
+302
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function contains many hardcoded "magic numbers" (e.g., |
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def visualize_cosin_token(token_embedding, save_path=None): | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| plt.rcParams['font.size'] = 15 | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| x = token_embedding[0, 14: 14 + 196 * 4, :] | ||||||||||||||||||||||||||||||||||||||||
| x_norm = F.normalize(x, p=2, dim=1) | ||||||||||||||||||||||||||||||||||||||||
| similarity_matrix = x_norm @ x_norm.T | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| sim_np = similarity_matrix.cpu().numpy() | ||||||||||||||||||||||||||||||||||||||||
| sim_np = np.triu(sim_np, k=1) | ||||||||||||||||||||||||||||||||||||||||
| valid_sim = sim_np[sim_np > 0] | ||||||||||||||||||||||||||||||||||||||||
| vmin = np.percentile(valid_sim, 90) # 10% min | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| plt.subplots(figsize=(10, 10)) | ||||||||||||||||||||||||||||||||||||||||
| sns.heatmap(similarity_matrix.cpu().numpy(), cmap='Reds', vmin=vmin, vmax=1) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| start = 0 | ||||||||||||||||||||||||||||||||||||||||
| step = 196 | ||||||||||||||||||||||||||||||||||||||||
| ticks = np.arange(start, 196 * 5, step) | ||||||||||||||||||||||||||||||||||||||||
| plt.xticks(ticks, ticks) | ||||||||||||||||||||||||||||||||||||||||
| plt.yticks(ticks, ticks) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| plt.title('') | ||||||||||||||||||||||||||||||||||||||||
| plt.xlabel('') | ||||||||||||||||||||||||||||||||||||||||
| plt.ylabel('') | ||||||||||||||||||||||||||||||||||||||||
| plt.tight_layout() | ||||||||||||||||||||||||||||||||||||||||
| plt.savefig(save_path, format='pdf') | ||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The file format is hardcoded to
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| plt.rcdefaults() | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+307
to
+332
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modifying |
||||||||||||||||||||||||||||||||||||||||
| plt.close() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def visualize_cosin_token_32p(token_embedding, save_path=None): | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| plt.rcParams['font.size'] = 20 | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| all_tokens = token_embedding[0, 14:14 + 196 * 32, :] | ||||||||||||||||||||||||||||||||||||||||
| x_norm = F.normalize(all_tokens, p=2, dim=1) | ||||||||||||||||||||||||||||||||||||||||
| similarity_matrix = x_norm @ x_norm.T | ||||||||||||||||||||||||||||||||||||||||
| sim_np = similarity_matrix.cpu().numpy() | ||||||||||||||||||||||||||||||||||||||||
| sim_np = np.triu(sim_np, k=1) | ||||||||||||||||||||||||||||||||||||||||
| valid_sim = sim_np[sim_np > 0] | ||||||||||||||||||||||||||||||||||||||||
| vmin = np.percentile(valid_sim, 90) # 10% min | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| group_size = 4 | ||||||||||||||||||||||||||||||||||||||||
| num_groups = 8 | ||||||||||||||||||||||||||||||||||||||||
| tokens_per_group = 196 * group_size | ||||||||||||||||||||||||||||||||||||||||
| step = 196 | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| fig, axs = plt.subplots(2, 4, figsize=(22, 10)) # 2x4排布 | ||||||||||||||||||||||||||||||||||||||||
| axs = axs.flatten() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| for i in range(num_groups): | ||||||||||||||||||||||||||||||||||||||||
| x = all_tokens[i * tokens_per_group: (i + 1) * tokens_per_group, :] | ||||||||||||||||||||||||||||||||||||||||
| x_norm = F.normalize(x, p=2, dim=1) | ||||||||||||||||||||||||||||||||||||||||
| similarity_matrix = x_norm @ x_norm.T | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ax = axs[i] | ||||||||||||||||||||||||||||||||||||||||
| sns.heatmap( | ||||||||||||||||||||||||||||||||||||||||
| similarity_matrix.cpu().numpy(), cmap='Reds', | ||||||||||||||||||||||||||||||||||||||||
| vmin=vmin, vmax=1, ax=ax, cbar=False | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ticks = np.arange(0, tokens_per_group, step) | ||||||||||||||||||||||||||||||||||||||||
| labels = np.arange(i * tokens_per_group, (i + 1) * tokens_per_group, step) | ||||||||||||||||||||||||||||||||||||||||
| ax.set_xticks(ticks) | ||||||||||||||||||||||||||||||||||||||||
| ax.set_yticks(ticks) | ||||||||||||||||||||||||||||||||||||||||
| ax.set_xticklabels(labels, rotation=0) | ||||||||||||||||||||||||||||||||||||||||
| ax.set_yticklabels(labels) | ||||||||||||||||||||||||||||||||||||||||
| start_frame = i * group_size | ||||||||||||||||||||||||||||||||||||||||
| end_frame = (i + 1) * group_size - 1 | ||||||||||||||||||||||||||||||||||||||||
| ax.set_xlabel(f'Frame {start_frame}-{end_frame}', fontsize=17, labelpad=10) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| plt.tight_layout() | ||||||||||||||||||||||||||||||||||||||||
| # plt.savefig(save_path, format='pdf') | ||||||||||||||||||||||||||||||||||||||||
| # plt.savefig(save_path.replace('.pdf', '.svg'), format='svg', bbox_inches='tight') | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+378
to
+379
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||
| plt.savefig(save_path, dpi=300) | ||||||||||||||||||||||||||||||||||||||||
| plt.rcdefaults() | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+338
to
+381
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||
| plt.close() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def highlight_grid(image, grid_indexes, grid_size=14): | ||||||||||||||||||||||||||||||||||||||||
| if not isinstance(grid_size, tuple): | ||||||||||||||||||||||||||||||||||||||||
| grid_size = (grid_size, grid_size) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The index
31is a "magic number," making it unclear what it represents. To improve readability and maintainability, it should be replaced with a named constant or derived from a configuration value (e.g.,num_attention_heads - 1).