From 931b676e99cb2ad6ba668b0f241fd79c0b15794b Mon Sep 17 00:00:00 2001 From: tripleMu <865626@163.com> Date: Wed, 22 Jun 2022 11:07:27 +0800 Subject: [PATCH 1/3] Add convit --- flowvision/models/__init__.py | 1 + flowvision/models/convit.py | 422 ++++++++++++++++++++++++++++++++++ 2 files changed, 423 insertions(+) create mode 100644 flowvision/models/convit.py diff --git a/flowvision/models/__init__.py b/flowvision/models/__init__.py index 764a537b..264a5ce8 100644 --- a/flowvision/models/__init__.py +++ b/flowvision/models/__init__.py @@ -32,6 +32,7 @@ from .van import * from .levit import * from .mobilevit import * +from .convit import * from . import style_transfer from . import detection diff --git a/flowvision/models/convit.py b/flowvision/models/convit.py new file mode 100644 index 00000000..c73ef3a2 --- /dev/null +++ b/flowvision/models/convit.py @@ -0,0 +1,422 @@ +import oneflow as flow +import oneflow.nn as nn +from functools import partial +import oneflow.nn.functional as F + + +from flowvision.models.helpers import to_2tuple +from flowvision.layers import trunc_normal_,DropPath,Mlp +from .registry import ModelCreator +from .utils import load_state_dict_from_url + +__all__ = [ + "VisionTransformer", + "convit_tiny", + "convit_small", + "convit_base", +] + +model_urls = { + "convit_tiny": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ConVit/convit_tiny.zip", + "convit_small": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ConVit/convit_small.zip", + "convit_base": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ConVit/convit_base.zip", +} + + +class GPSA(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., + locality_strength=1., use_local_init=True): + super().__init__() + self.num_heads = num_heads + self.dim = dim + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.pos_proj = nn.Linear(3, num_heads) + self.proj_drop = nn.Dropout(proj_drop) + self.locality_strength = locality_strength + self.gating_param = nn.Parameter(flow.ones(self.num_heads)) + self.apply(self._init_weights) + if use_local_init: + self.local_init(locality_strength=locality_strength) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + B, N, C = x.shape + if not hasattr(self, 'rel_indices') or self.rel_indices.size(1) != N: + self.get_rel_indices(N) + + attn = self.get_attention(x) + v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def get_attention(self, x): + B, N, C = x.shape + qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k = qk[0], qk[1] + pos_score = self.rel_indices.expand(B, -1, -1, -1) + pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2) + patch_score = (q @ k.transpose(-2, -1)) * self.scale + patch_score = patch_score.softmax(dim=-1) + pos_score = 
pos_score.softmax(dim=-1) + + gating = self.gating_param.view(1, -1, 1, 1) + attn = (1. - flow.sigmoid(gating)) * patch_score + flow.sigmoid(gating) * pos_score + attn /= attn.sum(dim=-1).unsqueeze(-1) + attn = self.attn_drop(attn) + return attn + + def get_attention_map(self, x, return_map=False): + + attn_map = self.get_attention(x).mean(0) # average over batch + distances = self.rel_indices.squeeze()[:, :, -1] ** .5 + dist = flow.einsum('nm,hnm->h', (distances, attn_map)) + dist /= distances.size(0) + if return_map: + return dist, attn_map + else: + return dist + + def local_init(self, locality_strength=1.): + + self.v.weight.data.copy_(flow.eye(self.dim)) + locality_distance = 1 # max(1,1/locality_strength**.5) + + kernel_size = int(self.num_heads ** .5) + center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2 + for h1 in range(kernel_size): + for h2 in range(kernel_size): + position = h1 + kernel_size * h2 + self.pos_proj.weight.data[position, 2] = -1 + self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance + self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance + self.pos_proj.weight.data *= locality_strength + + def get_rel_indices(self, num_patches): + img_size = int(num_patches ** .5) + rel_indices = flow.zeros(1, num_patches, num_patches, 3) + ind = flow.arange(img_size,dtype=flow.float32).view(1, -1) - flow.arange(img_size,dtype=flow.float32).view(-1, 1) + indx = ind.repeat(img_size, img_size) + indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1) + indd = indx ** 2 + indy ** 2 + rel_indices[:, :, :, 2] = indd.unsqueeze(0) + rel_indices[:, :, :, 1] = indy.unsqueeze(0) + rel_indices[:, :, :, 0] = indx.unsqueeze(0) + device = self.qk.weight.device + self.rel_indices = rel_indices.to(device) + + +class MHSA(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_attention_map(self, x, return_map=False): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + attn_map = (q @ k.transpose(-2, -1)) * self.scale + attn_map = attn_map.softmax(dim=-1).mean(0) + + img_size = int(N ** .5) + ind = flow.arange(img_size).view(1, -1) - flow.arange(img_size).view(-1, 1) + indx = ind.repeat(img_size, img_size) + indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1) + indd = indx ** 2 + indy ** 2 + distances = indd ** .5 + distances = distances.to('cuda') + + dist = flow.einsum('nm,hnm->h', (distances, attn_map)) + dist /= N + + if return_map: + return dist, attn_map + else: + return dist + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale 
+ attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs): + super().__init__() + self.norm1 = norm_layer(dim) + self.use_gpsa = use_gpsa + if self.use_gpsa: + self.attn = GPSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, + proj_drop=drop, **kwargs) + else: + self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, + proj_drop=drop, **kwargs) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding, from timm + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.apply(self._init_weights) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ x = self.proj(x).flatten(2).transpose(1, 2) + return x + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding, from timm + """ + + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with flow.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(flow.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + self.apply(self._init_weights) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=48, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None, + local_up_to_layer=10, locality_strength=1., use_pos_embed=True): + super().__init__() + self.num_classes = num_classes + self.local_up_to_layer = local_up_to_layer + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.locality_strength = locality_strength + self.use_pos_embed = use_pos_embed + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + self.num_patches = num_patches + + self.cls_token = nn.Parameter(flow.zeros(1, 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + if self.use_pos_embed: + self.pos_embed = nn.Parameter(flow.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.pos_embed, std=.02) + + dpr = [x.item() for x in flow.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + use_gpsa=True, + locality_strength=locality_strength) + if i < local_up_to_layer else + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + use_gpsa=False) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')] + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.cls_token, 
std=.02) + self.head.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + + if self.use_pos_embed: + x = x + self.pos_embed + x = self.pos_drop(x) + + for u, blk in enumerate(self.blocks): + if u == self.local_up_to_layer: + x = flow.cat((cls_tokens, x), dim=1) + x = blk(x) + + x = self.norm(x) + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +@ModelCreator.register_model +def convit_tiny(pretrained: bool = False, progress: bool = True, **kwargs): + num_heads = 4 + kwargs.setdefault('embed_dim',48) + kwargs['embed_dim'] *= num_heads + model = VisionTransformer( + num_heads=num_heads, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + if pretrained: + state_dict = load_state_dict_from_url( + model_urls["convit_tiny"], + model_dir="./checkpoints", + progress=progress, + ) + model.load_state_dict(state_dict) + return model + +@ModelCreator.register_model +def convit_small(pretrained: bool = False, progress: bool = True, **kwargs): + num_heads = 9 + kwargs.setdefault('embed_dim',48) + kwargs['embed_dim'] *= num_heads + model = VisionTransformer( + num_heads=num_heads, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + if pretrained: + state_dict = load_state_dict_from_url( + model_urls["convit_small"], + model_dir="./checkpoints", + progress=progress, + ) + model.load_state_dict(state_dict) + return model + +@ModelCreator.register_model +def convit_base(pretrained: bool = False, progress: bool = True, **kwargs): + num_heads = 16 + kwargs.setdefault('embed_dim',48) + kwargs['embed_dim'] *= num_heads + model = VisionTransformer( + num_heads=num_heads, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + if pretrained: + state_dict = load_state_dict_from_url( + model_urls["convit_base"], + model_dir="./checkpoints", + progress=progress, + ) + model.load_state_dict(state_dict) + return model + + From 161d44d91168da424c6b87035f4ca5331807f72c Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Wed, 22 Jun 2022 03:08:32 +0000 Subject: [PATCH 2/3] auto format by CI --- flowvision/models/convit.py | 304 ++++++++++++++------ flowvision/models/mobilevit.py | 467 ++++++++++++++++++++----------- flowvision/models/senet.py | 2 +- projects/classification/utils.py | 4 +- 4 files changed, 528 insertions(+), 249 deletions(-) diff --git a/flowvision/models/convit.py b/flowvision/models/convit.py index c73ef3a2..91c3c903 100644 --- a/flowvision/models/convit.py +++ b/flowvision/models/convit.py @@ -5,7 +5,7 @@ from flowvision.models.helpers import to_2tuple -from flowvision.layers import trunc_normal_,DropPath,Mlp +from flowvision.layers import trunc_normal_, DropPath, Mlp from .registry import ModelCreator from .utils import load_state_dict_from_url @@ -24,8 +24,17 @@ class GPSA(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, 
attn_drop=0., proj_drop=0., - locality_strength=1., use_local_init=True): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + locality_strength=1.0, + use_local_init=True, + ): super().__init__() self.num_heads = num_heads self.dim = dim @@ -47,7 +56,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0. def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -56,11 +65,15 @@ def _init_weights(self, m): def forward(self, x): B, N, C = x.shape - if not hasattr(self, 'rel_indices') or self.rel_indices.size(1) != N: + if not hasattr(self, "rel_indices") or self.rel_indices.size(1) != N: self.get_rel_indices(N) attn = self.get_attention(x) - v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = ( + self.v(x) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) @@ -68,7 +81,11 @@ def forward(self, x): def get_attention(self, x): B, N, C = x.shape - qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qk = ( + self.qk(x) + .reshape(B, N, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) q, k = qk[0], qk[1] pos_score = self.rel_indices.expand(B, -1, -1, -1) pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2) @@ -77,7 +94,9 @@ def get_attention(self, x): pos_score = pos_score.softmax(dim=-1) gating = self.gating_param.view(1, -1, 1, 1) - attn = (1. - flow.sigmoid(gating)) * patch_score + flow.sigmoid(gating) * pos_score + attn = (1.0 - flow.sigmoid(gating)) * patch_score + flow.sigmoid( + gating + ) * pos_score attn /= attn.sum(dim=-1).unsqueeze(-1) attn = self.attn_drop(attn) return attn @@ -85,33 +104,39 @@ def get_attention(self, x): def get_attention_map(self, x, return_map=False): attn_map = self.get_attention(x).mean(0) # average over batch - distances = self.rel_indices.squeeze()[:, :, -1] ** .5 - dist = flow.einsum('nm,hnm->h', (distances, attn_map)) + distances = self.rel_indices.squeeze()[:, :, -1] ** 0.5 + dist = flow.einsum("nm,hnm->h", (distances, attn_map)) dist /= distances.size(0) if return_map: return dist, attn_map else: return dist - def local_init(self, locality_strength=1.): + def local_init(self, locality_strength=1.0): self.v.weight.data.copy_(flow.eye(self.dim)) locality_distance = 1 # max(1,1/locality_strength**.5) - kernel_size = int(self.num_heads ** .5) + kernel_size = int(self.num_heads ** 0.5) center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2 for h1 in range(kernel_size): for h2 in range(kernel_size): position = h1 + kernel_size * h2 self.pos_proj.weight.data[position, 2] = -1 - self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance - self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance + self.pos_proj.weight.data[position, 1] = ( + 2 * (h1 - center) * locality_distance + ) + self.pos_proj.weight.data[position, 0] = ( + 2 * (h2 - center) * locality_distance + ) self.pos_proj.weight.data *= locality_strength def get_rel_indices(self, num_patches): - img_size = int(num_patches ** .5) + img_size = int(num_patches ** 0.5) rel_indices = flow.zeros(1, num_patches, num_patches, 3) - ind = 
flow.arange(img_size,dtype=flow.float32).view(1, -1) - flow.arange(img_size,dtype=flow.float32).view(-1, 1) + ind = flow.arange(img_size, dtype=flow.float32).view(1, -1) - flow.arange( + img_size, dtype=flow.float32 + ).view(-1, 1) indx = ind.repeat(img_size, img_size) indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1) indd = indx ** 2 + indy ** 2 @@ -123,7 +148,15 @@ def get_rel_indices(self, num_patches): class MHSA(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads @@ -137,7 +170,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0. def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -146,20 +179,24 @@ def _init_weights(self, m): def get_attention_map(self, x, return_map=False): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) q, k, v = qkv[0], qkv[1], qkv[2] attn_map = (q @ k.transpose(-2, -1)) * self.scale attn_map = attn_map.softmax(dim=-1).mean(0) - img_size = int(N ** .5) + img_size = int(N ** 0.5) ind = flow.arange(img_size).view(1, -1) - flow.arange(img_size).view(-1, 1) indx = ind.repeat(img_size, img_size) indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1) indd = indx ** 2 + indy ** 2 - distances = indd ** .5 - distances = distances.to('cuda') + distances = indd ** 0.5 + distances = distances.to("cuda") - dist = flow.einsum('nm,hnm->h', (distances, attn_map)) + dist = flow.einsum("nm,hnm->h", (distances, attn_map)) dist /= N if return_map: @@ -169,7 +206,11 @@ def get_attention_map(self, x, return_map=False): def forward(self, x): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale @@ -183,22 +224,53 @@ def forward(self, x): class Block(nn.Module): - - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_gpsa=True, + **kwargs, + ): super().__init__() self.norm1 = norm_layer(dim) self.use_gpsa = use_gpsa if self.use_gpsa: - self.attn = GPSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, - proj_drop=drop, **kwargs) + self.attn = GPSA( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + **kwargs, + ) else: - self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, - proj_drop=drop, **kwargs) - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.attn = MHSA( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + **kwargs, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) def forward(self, x): x = x + self.drop_path(self.attn(self.norm1(x))) @@ -219,19 +291,22 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): self.patch_size = patch_size self.num_patches = num_patches - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) self.apply(self._init_weights) def forward(self, x): B, C, H, W = x.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2) return x def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -243,7 +318,9 @@ class HybridEmbed(nn.Module): """ CNN Feature Map Embedding, from timm """ - def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + def __init__( + self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768 + ): super().__init__() assert isinstance(backbone, nn.Module) img_size = to_2tuple(img_size) @@ -276,23 +353,51 @@ class VisionTransformer(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=48, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None, - local_up_to_layer=10, locality_strength=1., use_pos_embed=True): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=48, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + hybrid_backbone=None, + norm_layer=nn.LayerNorm, + global_pool=None, + local_up_to_layer=10, + locality_strength=1.0, + use_pos_embed=True, + ): super().__init__() self.num_classes = num_classes self.local_up_to_layer = local_up_to_layer - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_features = ( + self.embed_dim + ) = embed_dim # num_features for consistency with other models self.locality_strength = locality_strength self.use_pos_embed = use_pos_embed if hybrid_backbone is not None: self.patch_embed = HybridEmbed( - hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + hybrid_backbone, + img_size=img_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) else: self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, 
in_chans=in_chans, embed_dim=embed_dim) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) num_patches = self.patch_embed.num_patches self.num_patches = num_patches @@ -301,33 +406,56 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em if self.use_pos_embed: self.pos_embed = nn.Parameter(flow.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.pos_embed, std=.02) - - dpr = [x.item() for x in flow.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - self.blocks = nn.ModuleList([ - Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, - use_gpsa=True, - locality_strength=locality_strength) - if i < local_up_to_layer else - Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, - use_gpsa=False) - for i in range(depth)]) + trunc_normal_(self.pos_embed, std=0.02) + + dpr = [ + x.item() for x in flow.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + use_gpsa=True, + locality_strength=locality_strength, + ) + if i < local_up_to_layer + else Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + use_gpsa=False, + ) + for i in range(depth) + ] + ) self.norm = norm_layer(embed_dim) # Classifier head - self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')] - self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.feature_info = [dict(num_chs=embed_dim, reduction=0, module="head")] + self.head = ( + nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) - trunc_normal_(self.cls_token, std=.02) + trunc_normal_(self.cls_token, std=0.02) self.head.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -335,14 +463,16 @@ def _init_weights(self, m): nn.init.constant_(m.weight, 1.0) def no_weight_decay(self): - return {'pos_embed', 'cls_token'} + return {"pos_embed", "cls_token"} def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=""): self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) def forward_features(self, x): B = x.shape[0] @@ -371,52 +501,46 @@ def forward(self, x): @ModelCreator.register_model def convit_tiny(pretrained: bool = False, progress: bool = True, **kwargs): num_heads = 4 - kwargs.setdefault('embed_dim',48) - kwargs['embed_dim'] *= num_heads + kwargs.setdefault("embed_dim", 48) + kwargs["embed_dim"] *= num_heads model = VisionTransformer( - num_heads=num_heads, - 
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + num_heads=num_heads, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs + ) if pretrained: state_dict = load_state_dict_from_url( - model_urls["convit_tiny"], - model_dir="./checkpoints", - progress=progress, + model_urls["convit_tiny"], model_dir="./checkpoints", progress=progress, ) model.load_state_dict(state_dict) return model + @ModelCreator.register_model def convit_small(pretrained: bool = False, progress: bool = True, **kwargs): num_heads = 9 - kwargs.setdefault('embed_dim',48) - kwargs['embed_dim'] *= num_heads + kwargs.setdefault("embed_dim", 48) + kwargs["embed_dim"] *= num_heads model = VisionTransformer( - num_heads=num_heads, - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + num_heads=num_heads, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs + ) if pretrained: state_dict = load_state_dict_from_url( - model_urls["convit_small"], - model_dir="./checkpoints", - progress=progress, + model_urls["convit_small"], model_dir="./checkpoints", progress=progress, ) model.load_state_dict(state_dict) return model + @ModelCreator.register_model def convit_base(pretrained: bool = False, progress: bool = True, **kwargs): num_heads = 16 - kwargs.setdefault('embed_dim',48) - kwargs['embed_dim'] *= num_heads + kwargs.setdefault("embed_dim", 48) + kwargs["embed_dim"] *= num_heads model = VisionTransformer( - num_heads=num_heads, - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + num_heads=num_heads, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs + ) if pretrained: state_dict = load_state_dict_from_url( - model_urls["convit_base"], - model_dir="./checkpoints", - progress=progress, + model_urls["convit_base"], model_dir="./checkpoints", progress=progress, ) model.load_state_dict(state_dict) return model - - diff --git a/flowvision/models/mobilevit.py b/flowvision/models/mobilevit.py index 8774538c..1155c62d 100644 --- a/flowvision/models/mobilevit.py +++ b/flowvision/models/mobilevit.py @@ -16,7 +16,7 @@ model_urls = { "mobilevit_small": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/MobileViT/mobilevit_s.zip", "mobilevit_x_small": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/MobileViT/mobilevit_xs.zip", - "mobilevit_xx_small": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/MobileViT/mobilevit_xxs.zip" + "mobilevit_xx_small": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/MobileViT/mobilevit_xxs.zip", } @@ -26,8 +26,13 @@ class MultiHeadAttention(nn.Module): https://arxiv.org/abs/1706.03762 """ - def __init__(self, embed_dim: int, num_heads: int, attn_dropout: Optional[float] = 0.0, - bias: Optional[bool] = True): + def __init__( + self, + embed_dim: int, + num_heads: int, + attn_dropout: Optional[float] = 0.0, + bias: Optional[bool] = True, + ): """ :param embed_dim: Embedding dimension :param num_heads: Number of attention heads @@ -35,12 +40,18 @@ def __init__(self, embed_dim: int, num_heads: int, attn_dropout: Optional[float] :param bias: Bias """ super(MultiHeadAttention, self).__init__() - assert embed_dim % num_heads == 0, "Got: embed_dim={} and num_heads={}".format(embed_dim, num_heads) + assert embed_dim % num_heads == 0, "Got: embed_dim={} and num_heads={}".format( + embed_dim, num_heads + ) - self.qkv_proj = LinearLayer(in_features=embed_dim, out_features=3 * embed_dim, bias=bias) + self.qkv_proj = LinearLayer( + in_features=embed_dim, 
out_features=3 * embed_dim, bias=bias + ) self.attn_dropout = nn.Dropout(attn_dropout) - self.out_proj = LinearLayer(in_features=embed_dim, out_features=embed_dim, bias=bias) + self.out_proj = LinearLayer( + in_features=embed_dim, out_features=embed_dim, bias=bias + ) self.head_dim = embed_dim // num_heads self.scaling = self.head_dim ** -0.5 @@ -53,10 +64,7 @@ def forward(self, x: Tensor) -> Tensor: b_sz, n_patches, in_channels = x.shape # [B x N x C] --> [B x N x 3 x h x C] - qkv = ( - self.qkv_proj(x) - .reshape(b_sz, n_patches, 3, self.num_heads, -1) - ) + qkv = self.qkv_proj(x).reshape(b_sz, n_patches, 3, self.num_heads, -1) # [B x N x 3 x h x C] --> [B x h x 3 x N x C] qkv = qkv.transpose(1, 3) @@ -91,15 +99,23 @@ class TransformerEncoder(nn.Module): https://arxiv.org/abs/1706.03762 """ - def __init__(self, embed_dim: int, ffn_latent_dim: int, num_heads: Optional[int] = 8, - attn_dropout: Optional[float] = 0.0, - dropout: Optional[float] = 0.1, ffn_dropout: Optional[float] = 0.0): + def __init__( + self, + embed_dim: int, + ffn_latent_dim: int, + num_heads: Optional[int] = 8, + attn_dropout: Optional[float] = 0.0, + dropout: Optional[float] = 0.1, + ffn_dropout: Optional[float] = 0.0, + ): super(TransformerEncoder, self).__init__() self.pre_norm_mha = nn.Sequential( nn.LayerNorm(embed_dim), - MultiHeadAttention(embed_dim, num_heads, attn_dropout=attn_dropout, bias=True), - nn.Dropout(dropout) + MultiHeadAttention( + embed_dim, num_heads, attn_dropout=attn_dropout, bias=True + ), + nn.Dropout(dropout), ) self.pre_norm_ffn = nn.Sequential( @@ -108,7 +124,7 @@ def __init__(self, embed_dim: int, ffn_latent_dim: int, num_heads: Optional[int] nn.SiLU(), nn.Dropout(ffn_dropout), LinearLayer(in_features=ffn_latent_dim, out_features=embed_dim, bias=True), - nn.Dropout(dropout) + nn.Dropout(dropout), ) self.embed_dim = embed_dim self.ffn_dim = ffn_latent_dim @@ -128,32 +144,58 @@ class MobileViTBlock(nn.Module): MobileViT block: https://arxiv.org/abs/2110.02178?context=cs.LG """ - def __init__(self, in_channels: int, transformer_dim: int, ffn_dim: int, - n_transformer_blocks: Optional[int] = 2, - head_dim: Optional[int] = 32, attn_dropout: Optional[float] = 0.1, - dropout: Optional[float] = 0.1, ffn_dropout: Optional[float] = 0.1, patch_h: Optional[int] = 8, - patch_w: Optional[int] = 8, - conv_ksize: Optional[int] = 3, - dilation: Optional[int] = 1, var_ffn: Optional[bool] = False, - no_fusion: Optional[bool] = False): + def __init__( + self, + in_channels: int, + transformer_dim: int, + ffn_dim: int, + n_transformer_blocks: Optional[int] = 2, + head_dim: Optional[int] = 32, + attn_dropout: Optional[float] = 0.1, + dropout: Optional[float] = 0.1, + ffn_dropout: Optional[float] = 0.1, + patch_h: Optional[int] = 8, + patch_w: Optional[int] = 8, + conv_ksize: Optional[int] = 3, + dilation: Optional[int] = 1, + var_ffn: Optional[bool] = False, + no_fusion: Optional[bool] = False, + ): conv_3x3_in = ConvLayer( - in_channels=in_channels, out_channels=in_channels, - kernel_size=conv_ksize, stride=1, use_norm=True, use_act=True, dilation=dilation + in_channels=in_channels, + out_channels=in_channels, + kernel_size=conv_ksize, + stride=1, + use_norm=True, + use_act=True, + dilation=dilation, ) conv_1x1_in = ConvLayer( - in_channels=in_channels, out_channels=transformer_dim, - kernel_size=1, stride=1, use_norm=False, use_act=False + in_channels=in_channels, + out_channels=transformer_dim, + kernel_size=1, + stride=1, + use_norm=False, + use_act=False, ) conv_1x1_out = ConvLayer( - 
in_channels=transformer_dim, out_channels=in_channels, - kernel_size=1, stride=1, use_norm=True, use_act=True + in_channels=transformer_dim, + out_channels=in_channels, + kernel_size=1, + stride=1, + use_norm=True, + use_act=True, ) conv_3x3_out = None if not no_fusion: conv_3x3_out = ConvLayer( - in_channels=2 * in_channels, out_channels=in_channels, - kernel_size=conv_ksize, stride=1, use_norm=True, use_act=True + in_channels=2 * in_channels, + out_channels=in_channels, + kernel_size=conv_ksize, + stride=1, + use_norm=True, + use_act=True, ) super(MobileViTBlock, self).__init__() self.local_rep = nn.Sequential() @@ -166,8 +208,14 @@ def __init__(self, in_channels: int, transformer_dim: int, ffn_dim: int, ffn_dims = [ffn_dim] * n_transformer_blocks global_rep = [ - TransformerEncoder(embed_dim=transformer_dim, ffn_latent_dim=ffn_dims[block_idx], num_heads=num_heads, - attn_dropout=attn_dropout, dropout=dropout, ffn_dropout=ffn_dropout) + TransformerEncoder( + embed_dim=transformer_dim, + ffn_latent_dim=ffn_dims[block_idx], + num_heads=num_heads, + attn_dropout=attn_dropout, + dropout=dropout, + ffn_dropout=ffn_dropout, + ) for block_idx in range(n_transformer_blocks) ] global_rep.append(nn.LayerNorm(transformer_dim)) @@ -206,7 +254,9 @@ def unfolding(self, feature_map: Tensor) -> Tuple[Tensor, Dict]: interpolate = False if new_w != orig_w or new_h != orig_h: # Note: Padding can be done, but then it needs to be handled in attention function. - feature_map = F.interpolate(feature_map, size=(new_h, new_w), mode="bilinear", align_corners=False) + feature_map = F.interpolate( + feature_map, size=(new_h, new_w), mode="bilinear", align_corners=False + ) interpolate = True # number of patches along width and height @@ -215,11 +265,15 @@ def unfolding(self, feature_map: Tensor) -> Tuple[Tensor, Dict]: num_patches = num_patch_h * num_patch_w # N # [B, C, H, W] --> [B * C * n_h, p_h, n_w, p_w] - reshaped_fm = feature_map.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w) + reshaped_fm = feature_map.reshape( + batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w + ) # [B * C * n_h, p_h, n_w, p_w] --> [B * C * n_h, n_w, p_h, p_w] transposed_fm = reshaped_fm.transpose(1, 2) # [B * C * n_h, n_w, p_h, p_w] --> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w - reshaped_fm = transposed_fm.reshape(batch_size, in_channels, num_patches, patch_area) + reshaped_fm = transposed_fm.reshape( + batch_size, in_channels, num_patches, patch_area + ) # [B, C, N, P] --> [B, P, N, C] transposed_fm = reshaped_fm.transpose(1, 3) # [B, P, N, C] --> [BP, N, C] @@ -231,16 +285,20 @@ def unfolding(self, feature_map: Tensor) -> Tuple[Tensor, Dict]: "interpolate": interpolate, "total_patches": num_patches, "num_patches_w": num_patch_w, - "num_patches_h": num_patch_h + "num_patches_h": num_patch_h, } return patches, info_dict def folding(self, patches: Tensor, info_dict: Dict) -> Tensor: n_dim = patches.dim() - assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(patches.shape) + assert n_dim == 3, "Tensor should be of shape BPxNxC. 
Got: {}".format( + patches.shape + ) # [BP, N, C] --> [B, P, N, C] - patches = patches.contiguous().view(info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1) + patches = patches.contiguous().view( + info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1 + ) batch_size, pixels, num_patches, channels = patches.size() num_patch_h = info_dict["num_patches_h"] @@ -250,13 +308,22 @@ def folding(self, patches: Tensor, info_dict: Dict) -> Tensor: patches = patches.transpose(1, 3) # [B, C, N, P] --> [B*C*n_h, n_w, p_h, p_w] - feature_map = patches.reshape(batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w) + feature_map = patches.reshape( + batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w + ) # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] feature_map = feature_map.transpose(1, 2) # [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W] - feature_map = feature_map.reshape(batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w) + feature_map = feature_map.reshape( + batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w + ) if info_dict["interpolate"]: - feature_map = F.interpolate(feature_map, size=info_dict["orig_size"], mode="bilinear", align_corners=False) + feature_map = F.interpolate( + feature_map, + size=info_dict["orig_size"], + mode="bilinear", + align_corners=False, + ) return feature_map def forward(self, x: Tensor) -> Tensor: @@ -276,15 +343,15 @@ def forward(self, x: Tensor) -> Tensor: fm = self.conv_proj(fm) if self.fusion is not None: - fm = self.fusion( - flow.cat((res, fm), dim=1) - ) + fm = self.fusion(flow.cat((res, fm), dim=1)) return fm -def make_divisible(v: Union[float, int], - divisor: Optional[int] = 8, - min_value: Optional[Union[float, int]] = None) -> Union[float, int]: +def make_divisible( + v: Union[float, int], + divisor: Optional[int] = 8, + min_value: Optional[Union[float, int]] = None, +) -> Union[float, int]: """ This function is taken from the original tf repo. 
It ensures that all layers have a channel number that is divisible by 8 @@ -309,13 +376,14 @@ class InvertedResidual(nn.Module): Inverted residual block (MobileNetv2): https://arxiv.org/abs/1801.04381 """ - def __init__(self, - in_channels: int, - out_channels: int, - stride: int, - expand_ratio: Union[int, float], - dilation: int = 1 - ) -> None: + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int, + expand_ratio: Union[int, float], + dilation: int = 1, + ) -> None: assert stride in [1, 2] super(InvertedResidual, self).__init__() self.stride = stride @@ -325,19 +393,41 @@ def __init__(self, block = nn.Sequential() if expand_ratio != 1: - block.add_module(name="exp_1x1", - module=ConvLayer(in_channels=in_channels, out_channels=hidden_dim, kernel_size=1, - use_act=True, use_norm=True)) + block.add_module( + name="exp_1x1", + module=ConvLayer( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + use_act=True, + use_norm=True, + ), + ) block.add_module( name="conv_3x3", - module=ConvLayer(in_channels=hidden_dim, out_channels=hidden_dim, stride=stride, kernel_size=3, - groups=hidden_dim, use_act=True, use_norm=True, dilation=dilation) + module=ConvLayer( + in_channels=hidden_dim, + out_channels=hidden_dim, + stride=stride, + kernel_size=3, + groups=hidden_dim, + use_act=True, + use_norm=True, + dilation=dilation, + ), ) - block.add_module(name="red_1x1", - module=ConvLayer(in_channels=hidden_dim, out_channels=out_channels, kernel_size=1, - use_act=False, use_norm=True)) + block.add_module( + name="red_1x1", + module=ConvLayer( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + use_act=False, + use_norm=True, + ), + ) self.block = block self.in_channels = in_channels @@ -353,25 +443,27 @@ def forward(self, x: Tensor, *args, **kwargs) -> Tensor: class GlobalPool(nn.Module): - def __init__(self, pool_type='mean', keep_dim=False): + def __init__(self, pool_type="mean", keep_dim=False): """ Global pooling :param pool_type: Global pool operation type (mean, rms, abs) :param keep_dim: Keep dimensions the same as the input or not """ super(GlobalPool, self).__init__() - pool_types = ['mean', 'rms', 'abs'] - assert pool_type in pool_types, 'Supported pool types are: {}. Got {}'.format(pool_types, pool_type) + pool_types = ["mean", "rms", "abs"] + assert pool_type in pool_types, "Supported pool types are: {}. 
Got {}".format( + pool_types, pool_type + ) self.pool_type = pool_type self.keep_dim = keep_dim def _global_pool(self, x): assert x.dim() == 4, "Got: {}".format(x.shape) - if self.pool_type == 'rms': + if self.pool_type == "rms": x = x ** 2 x = flow.mean(x, dim=[-2, -1], keepdim=self.keep_dim) x = x ** -0.5 - elif self.pool_type == 'abs': + elif self.pool_type == "abs": x = flow.mean(flow.abs(x), dim=[-2, -1], keepdim=self.keep_dim) else: # default is mean @@ -384,11 +476,9 @@ def forward(self, x: Tensor) -> Tensor: class LinearLayer(nn.Module): - def __init__(self, - in_features: int, - out_features: int, - bias: Optional[bool] = True - ) -> None: + def __init__( + self, in_features: int, out_features: int, bias: Optional[bool] = True + ) -> None: """ Applies a linear transformation to the input data :param in_features: size of each input sample @@ -421,21 +511,45 @@ def forward(self, x: Tensor) -> Tensor: class Conv2d(nn.Conv2d): - def __init__(self, in_channels: int, out_channels: int, kernel_size: tuple or int, stride: tuple or int, - padding: tuple or int, dilation: int or tuple, groups: int, bias: bool, padding_mode: str - ): - super(Conv2d, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, - stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, - padding_mode=padding_mode) + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: tuple or int, + stride: tuple or int, + padding: tuple or int, + dilation: int or tuple, + groups: int, + bias: bool, + padding_mode: str, + ): + super(Conv2d, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + ) class ConvLayer(nn.Module): - def __init__(self, in_channels: int, out_channels: int, kernel_size: int or tuple, - stride: Optional[int or tuple] = 1, - dilation: Optional[int or tuple] = 1, groups: Optional[int] = 1, - bias: Optional[bool] = False, padding_mode: Optional[str] = 'zeros', - use_norm: Optional[bool] = True, use_act: Optional[bool] = True - ) -> None: + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int or tuple, + stride: Optional[int or tuple] = 1, + dilation: Optional[int or tuple] = 1, + groups: Optional[int] = 1, + bias: Optional[bool] = False, + padding_mode: Optional[str] = "zeros", + use_norm: Optional[bool] = True, + use_act: Optional[bool] = True, + ) -> None: """ Applies a 2D convolution over an input signal composed of several input planes. :param opts: arguments @@ -454,7 +568,7 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: int or tupl super(ConvLayer, self).__init__() if use_norm: - assert not bias, 'Do not use bias when using normalization layers.' + assert not bias, "Do not use bias when using normalization layers." if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) @@ -469,18 +583,35 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: int or tupl assert isinstance(stride, (tuple, list)) assert isinstance(dilation, (tuple, list)) - padding = (int((kernel_size[0] - 1) / 2) * dilation[0], int((kernel_size[1] - 1) / 2) * dilation[1]) + padding = ( + int((kernel_size[0] - 1) / 2) * dilation[0], + int((kernel_size[1] - 1) / 2) * dilation[1], + ) - assert in_channels % groups == 0, \ - 'Input channels are not divisible by groups. 
{}%{} != 0 '.format(in_channels, groups) - assert out_channels % groups == 0, \ - 'Output channels are not divisible by groups. {}%{} != 0 '.format(out_channels, groups) + assert ( + in_channels % groups == 0 + ), "Input channels are not divisible by groups. {}%{} != 0 ".format( + in_channels, groups + ) + assert ( + out_channels % groups == 0 + ), "Output channels are not divisible by groups. {}%{} != 0 ".format( + out_channels, groups + ) block = nn.Sequential() - conv_layer = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, - stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, - padding_mode=padding_mode) + conv_layer = Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + ) block.add_module(name="conv", module=conv_layer) @@ -517,12 +648,7 @@ class MobileViT(nn.Module): """ def __init__( - self, - arch, - num_classes=1000, - classifier_dropout=0.1, - pool_type='mean', - **kwargs + self, arch, num_classes=1000, classifier_dropout=0.1, pool_type="mean", **kwargs ) -> None: image_channels = 3 out_channels = 16 @@ -546,58 +672,75 @@ def __init__( # store model configuration in a dictionary self.model_conf_dict = dict() self.conv_1 = ConvLayer( - in_channels=image_channels, out_channels=out_channels, - kernel_size=3, stride=2, use_norm=True, use_act=True + in_channels=image_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + use_norm=True, + use_act=True, ) - self.model_conf_dict['conv1'] = {'in': image_channels, 'out': out_channels} + self.model_conf_dict["conv1"] = {"in": image_channels, "out": out_channels} in_channels = out_channels self.layer_1, out_channels = self._make_layer( input_channel=in_channels, cfg=mobilevit_config["layer1"] ) - self.model_conf_dict['layer1'] = {'in': in_channels, 'out': out_channels} + self.model_conf_dict["layer1"] = {"in": in_channels, "out": out_channels} in_channels = out_channels self.layer_2, out_channels = self._make_layer( input_channel=in_channels, cfg=mobilevit_config["layer2"] ) - self.model_conf_dict['layer2'] = {'in': in_channels, 'out': out_channels} + self.model_conf_dict["layer2"] = {"in": in_channels, "out": out_channels} in_channels = out_channels self.layer_3, out_channels = self._make_layer( input_channel=in_channels, cfg=mobilevit_config["layer3"] ) - self.model_conf_dict['layer3'] = {'in': in_channels, 'out': out_channels} + self.model_conf_dict["layer3"] = {"in": in_channels, "out": out_channels} in_channels = out_channels self.layer_4, out_channels = self._make_layer( input_channel=in_channels, cfg=mobilevit_config["layer4"], dilate=dilate_l4 ) - self.model_conf_dict['layer4'] = {'in': in_channels, 'out': out_channels} + self.model_conf_dict["layer4"] = {"in": in_channels, "out": out_channels} in_channels = out_channels self.layer_5, out_channels = self._make_layer( input_channel=in_channels, cfg=mobilevit_config["layer5"], dilate=dilate_l5 ) - self.model_conf_dict['layer5'] = {'in': in_channels, 'out': out_channels} + self.model_conf_dict["layer5"] = {"in": in_channels, "out": out_channels} in_channels = out_channels exp_channels = min(mobilevit_config["last_layer_exp_factor"] * in_channels, 960) self.conv_1x1_exp = ConvLayer( - in_channels=in_channels, out_channels=exp_channels, - kernel_size=1, stride=1, use_act=True, use_norm=True + in_channels=in_channels, + out_channels=exp_channels, + kernel_size=1, 
+ stride=1, + use_act=True, + use_norm=True, ) - self.model_conf_dict['exp_before_cls'] = {'in': in_channels, 'out': exp_channels} + self.model_conf_dict["exp_before_cls"] = { + "in": in_channels, + "out": exp_channels, + } self.classifier = nn.Sequential() - self.classifier.add_module(name="global_pool", module=GlobalPool(pool_type=pool_type, keep_dim=False)) + self.classifier.add_module( + name="global_pool", module=GlobalPool(pool_type=pool_type, keep_dim=False) + ) if 0.0 < classifier_dropout < 1.0: - self.classifier.add_module(name="dropout", module=nn.Dropout(p=classifier_dropout, inplace=True)) + self.classifier.add_module( + name="dropout", module=nn.Dropout(p=classifier_dropout, inplace=True) + ) self.classifier.add_module( name="fc", - module=LinearLayer(in_features=exp_channels, out_features=num_classes, bias=True) + module=LinearLayer( + in_features=exp_channels, out_features=num_classes, bias=True + ), ) # weight initialization @@ -610,7 +753,7 @@ def initialize_weights(self): for m in modules: if isinstance(m, nn.Conv2d): if m.weight is not None: - nn.init.kaiming_normal_(m.weight, mode='fan_out') + nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, (nn.Linear, LinearLayer)): @@ -627,22 +770,21 @@ def initialize_weights(self): if m.bias is not None: nn.init.zeros_(m.bias) - def _make_layer(self, input_channel, cfg: Dict, dilate: Optional[bool] = False) -> Tuple[nn.Sequential, int]: + def _make_layer( + self, input_channel, cfg: Dict, dilate: Optional[bool] = False + ) -> Tuple[nn.Sequential, int]: block_type = cfg.get("block_type", "mobilevit") if block_type.lower() == "mobilevit": return self._make_mit_layer( - input_channel=input_channel, - cfg=cfg, - dilate=dilate + input_channel=input_channel, cfg=cfg, dilate=dilate ) else: - return self._make_mobilenet_layer( - input_channel=input_channel, - cfg=cfg - ) + return self._make_mobilenet_layer(input_channel=input_channel, cfg=cfg) @staticmethod - def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]: + def _make_mobilenet_layer( + input_channel: int, cfg: Dict + ) -> Tuple[nn.Sequential, int]: output_channels = cfg.get("out_channels") num_blocks = cfg.get("num_blocks", 2) expand_ratio = cfg.get("expand_ratio", 4) @@ -655,13 +797,15 @@ def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, in_channels=input_channel, out_channels=output_channels, stride=stride, - expand_ratio=expand_ratio + expand_ratio=expand_ratio, ) block.append(layer) input_channel = output_channels return nn.Sequential(*block), input_channel - def _make_mit_layer(self, input_channel, cfg: Dict, dilate: Optional[bool] = False) -> Tuple[nn.Sequential, int]: + def _make_mit_layer( + self, input_channel, cfg: Dict, dilate: Optional[bool] = False + ) -> Tuple[nn.Sequential, int]: prev_dilation = self.dilation block = [] stride = cfg.get("stride", 1) @@ -676,7 +820,7 @@ def _make_mit_layer(self, input_channel, cfg: Dict, dilate: Optional[bool] = Fal out_channels=cfg.get("out_channels"), stride=stride, expand_ratio=cfg.get("mv_expand_ratio", 4), - dilation=prev_dilation + dilation=prev_dilation, ) block.append(layer) @@ -691,9 +835,10 @@ def _make_mit_layer(self, input_channel, cfg: Dict, dilate: Optional[bool] = Fal num_heads = 4 head_dim = transformer_dim // num_heads - assert transformer_dim % head_dim == 0, \ - "Transformer input dimension should be divisible by head dimension. 
" \ + assert transformer_dim % head_dim == 0, ( + "Transformer input dimension should be divisible by head dimension. " "Got {} and {}.".format(transformer_dim, head_dim) + ) block.append( MobileViTBlock( @@ -708,7 +853,7 @@ def _make_mit_layer(self, input_channel, cfg: Dict, dilate: Optional[bool] = Fal attn_dropout=0.0, head_dim=head_dim, no_fusion=False, - conv_ksize=3 + conv_ksize=3, ) ) @@ -732,20 +877,20 @@ def forward(self, x: Tensor) -> Tensor: CONFIG = { - 'mobilevit_xx_small': { + "mobilevit_xx_small": { "layer1": { "out_channels": 16, "expand_ratio": 2, "num_blocks": 1, "stride": 1, - "block_type": "mv2" + "block_type": "mv2", }, "layer2": { "out_channels": 24, "expand_ratio": 2, "num_blocks": 3, "stride": 2, - "block_type": "mv2" + "block_type": "mv2", }, "layer3": { # 28x28 "out_channels": 48, @@ -758,7 +903,7 @@ def forward(self, x: Tensor) -> Tensor: "mv_expand_ratio": 2, "head_dim": None, "num_heads": 4, - "block_type": "mobilevit" + "block_type": "mobilevit", }, "layer4": { # 14x14 "out_channels": 64, @@ -771,7 +916,7 @@ def forward(self, x: Tensor) -> Tensor: "mv_expand_ratio": 2, "head_dim": None, "num_heads": 4, - "block_type": "mobilevit" + "block_type": "mobilevit", }, "layer5": { # 7x7 "out_channels": 80, @@ -784,24 +929,24 @@ def forward(self, x: Tensor) -> Tensor: "mv_expand_ratio": 2, "head_dim": None, "num_heads": 4, - "block_type": "mobilevit" + "block_type": "mobilevit", }, - "last_layer_exp_factor": 4 + "last_layer_exp_factor": 4, }, - 'mobilevit_x_small': { + "mobilevit_x_small": { "layer1": { "out_channels": 32, "expand_ratio": 4, "num_blocks": 1, "stride": 1, - "block_type": "mv2" + "block_type": "mv2", }, "layer2": { "out_channels": 48, "expand_ratio": 4, "num_blocks": 3, "stride": 2, - "block_type": "mv2" + "block_type": "mv2", }, "layer3": { # 28x28 "out_channels": 64, @@ -814,7 +959,7 @@ def forward(self, x: Tensor) -> Tensor: "mv_expand_ratio": 4, "head_dim": None, "num_heads": 4, - "block_type": "mobilevit" + "block_type": "mobilevit", }, "layer4": { # 14x14 "out_channels": 80, @@ -827,7 +972,7 @@ def forward(self, x: Tensor) -> Tensor: "mv_expand_ratio": 4, "head_dim": None, "num_heads": 4, - "block_type": "mobilevit" + "block_type": "mobilevit", }, "layer5": { # 7x7 "out_channels": 96, @@ -840,24 +985,24 @@ def forward(self, x: Tensor) -> Tensor: "mv_expand_ratio": 4, "head_dim": None, "num_heads": 4, - "block_type": "mobilevit" + "block_type": "mobilevit", }, - "last_layer_exp_factor": 4 + "last_layer_exp_factor": 4, }, - 'mobilevit_small': { + "mobilevit_small": { "layer1": { "out_channels": 32, "expand_ratio": 4, "num_blocks": 1, "stride": 1, - "block_type": "mv2" + "block_type": "mv2", }, "layer2": { "out_channels": 64, "expand_ratio": 4, "num_blocks": 3, "stride": 2, - "block_type": "mv2" + "block_type": "mv2", }, "layer3": { # 28x28 "out_channels": 96, @@ -870,7 +1015,7 @@ def forward(self, x: Tensor) -> Tensor: "mv_expand_ratio": 4, "head_dim": None, "num_heads": 4, - "block_type": "mobilevit" + "block_type": "mobilevit", }, "layer4": { # 14x14 "out_channels": 128, @@ -883,7 +1028,7 @@ def forward(self, x: Tensor) -> Tensor: "mv_expand_ratio": 4, "head_dim": None, "num_heads": 4, - "block_type": "mobilevit" + "block_type": "mobilevit", }, "layer5": { # 7x7 "out_channels": 160, @@ -896,14 +1041,16 @@ def forward(self, x: Tensor) -> Tensor: "mv_expand_ratio": 4, "head_dim": None, "num_heads": 4, - "block_type": "mobilevit" + "block_type": "mobilevit", }, - "last_layer_exp_factor": 4 - } + "last_layer_exp_factor": 4, + }, } -def 
_create_mobilevit(arch: str, pretrained: bool = False, progress: bool = True, **model_kwargs): +def _create_mobilevit( + arch: str, pretrained: bool = False, progress: bool = True, **model_kwargs +): model = MobileViT(arch=arch, **model_kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) @@ -931,7 +1078,9 @@ def mobilevit_small(pretrained: bool = False, progress: bool = True, **kwargs): >>> mobilevit_s = flowvision.models.mobilevit_small(pretrained=False, progress=True) """ - return _create_mobilevit(arch='mobilevit_small', pretrained=pretrained, progress=progress, **kwargs) + return _create_mobilevit( + arch="mobilevit_small", pretrained=pretrained, progress=progress, **kwargs + ) @ModelCreator.register_model @@ -954,7 +1103,9 @@ def mobilevit_x_small(pretrained: bool = False, progress: bool = True, **kwargs) >>> mobilevit_xs = flowvision.models.mobilevit_x_small(pretrained=False, progress=True) """ - return _create_mobilevit(arch='mobilevit_x_small', pretrained=pretrained, progress=progress, **kwargs) + return _create_mobilevit( + arch="mobilevit_x_small", pretrained=pretrained, progress=progress, **kwargs + ) @ModelCreator.register_model @@ -977,4 +1128,6 @@ def mobilevit_xx_small(pretrained: bool = False, progress: bool = True, **kwargs >>> mobilevit_xxs = flowvision.models.mobilevit_xx_small(pretrained=False, progress=True) """ - return _create_mobilevit(arch='mobilevit_xx_small', pretrained=pretrained, progress=progress, **kwargs) + return _create_mobilevit( + arch="mobilevit_xx_small", pretrained=pretrained, progress=progress, **kwargs + ) diff --git a/flowvision/models/senet.py b/flowvision/models/senet.py index 89a4acac..adeff188 100644 --- a/flowvision/models/senet.py +++ b/flowvision/models/senet.py @@ -20,7 +20,7 @@ "se_resnet101", "se_resnet152", "se_resnext50_32x4d", - "se_resnext101_32x4d" + "se_resnext101_32x4d", ] model_urls = { diff --git a/projects/classification/utils.py b/projects/classification/utils.py index 00ef8dc5..b3a84ec3 100644 --- a/projects/classification/utils.py +++ b/projects/classification/utils.py @@ -70,7 +70,9 @@ def get_grad_norm(parameters, norm_type=2): def auto_resume_helper(output_dir): checkpoints = os.listdir(output_dir) - checkpoints = [ckpt for ckpt in checkpoints if os.path.isdir(os.path.join(output_dir,ckpt))] + checkpoints = [ + ckpt for ckpt in checkpoints if os.path.isdir(os.path.join(output_dir, ckpt)) + ] print(f"All checkpoints founded in {output_dir}: {checkpoints}") if len(checkpoints) > 0: latest_checkpoint = max( From 0696e2e71f53176eb2f4a31c6fccdce5e0c7358a Mon Sep 17 00:00:00 2001 From: tripleMu <865626@163.com> Date: Thu, 30 Jun 2022 16:56:21 +0800 Subject: [PATCH 3/3] Add xcit and pretrain model (42 xcit models) --- flowvision/models/__init__.py | 1 + flowvision/models/xcit.py | 866 ++++++++++++++++++++++++++++++++++ 2 files changed, 867 insertions(+) create mode 100644 flowvision/models/xcit.py diff --git a/flowvision/models/__init__.py b/flowvision/models/__init__.py index 264a5ce8..ffedbf6a 100644 --- a/flowvision/models/__init__.py +++ b/flowvision/models/__init__.py @@ -33,6 +33,7 @@ from .levit import * from .mobilevit import * from .convit import * +from .xcit import * from . import style_transfer from . 
import detection diff --git a/flowvision/models/xcit.py b/flowvision/models/xcit.py new file mode 100644 index 00000000..0b5d760e --- /dev/null +++ b/flowvision/models/xcit.py @@ -0,0 +1,866 @@ +import math +from functools import partial + +import oneflow as flow +import oneflow.nn as nn + +from flowvision.models.helpers import to_2tuple +from flowvision.layers import trunc_normal_, DropPath, Mlp +from .registry import ModelCreator +from .utils import load_state_dict_from_url + +__all__ = [ + "XCiT", + "xcit_nano_12_p16_224", + "xcit_nano_12_p16_224_dist", + "xcit_nano_12_p16_384_dist", + "xcit_tiny_12_p16_224", + "xcit_tiny_12_p16_224_dist", + "xcit_tiny_12_p16_384_dist", + "xcit_tiny_24_p16_224", + "xcit_tiny_24_p16_224_dist", + "xcit_tiny_24_p16_384_dist", + "xcit_small_12_p16_224", + "xcit_small_12_p16_224_dist", + "xcit_small_12_p16_384_dist", + "xcit_small_24_p16_224", + "xcit_small_24_p16_224_dist", + "xcit_small_24_p16_384_dist", + "xcit_medium_24_p16_224", + "xcit_medium_24_p16_224_dist", + "xcit_medium_24_p16_384_dist", + "xcit_large_24_p16_224", + "xcit_large_24_p16_224_dist", + "xcit_large_24_p16_384_dist", + "xcit_nano_12_p8_224", + "xcit_nano_12_p8_224_dist", + "xcit_nano_12_p8_384_dist", + "xcit_tiny_12_p8_224", + "xcit_tiny_12_p8_224_dist", + "xcit_tiny_12_p8_384_dist", + "xcit_tiny_24_p8_224", + "xcit_tiny_24_p8_224_dist", + "xcit_tiny_24_p8_384_dist", + "xcit_small_12_p8_224", + "xcit_small_12_p8_224_dist", + "xcit_small_12_p8_384_dist", + "xcit_small_24_p8_224", + "xcit_small_24_p8_224_dist", + "xcit_small_24_p8_384_dist", + "xcit_medium_24_p8_224", + "xcit_medium_24_p8_224_dist", + "xcit_medium_24_p8_384_dist", + "xcit_large_24_p8_224", + "xcit_large_24_p8_224_dist", + "xcit_large_24_p8_384_dist", +] + +model_urls = { + 'xcit_nano_12_p16_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_nano_12_p16_224.zip', + 'xcit_nano_12_p16_224_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_nano_12_p16_224_dist.zip', + 'xcit_nano_12_p16_384_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_nano_12_p16_384_dist.zip', + 'xcit_tiny_12_p16_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_tiny_12_p16_224.zip', + 'xcit_tiny_12_p16_224_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_tiny_12_p16_224_dist.zip', + 'xcit_tiny_12_p16_384_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_tiny_12_p16_384_dist.zip', + 'xcit_tiny_24_p16_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_tiny_24_p16_224.zip', + 'xcit_tiny_24_p16_224_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_tiny_24_p16_224_dist.zip', + 'xcit_tiny_24_p16_384_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_tiny_24_p16_384_dist.zip', + 'xcit_small_12_p16_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_small_12_p16_224.zip', + 'xcit_small_12_p16_224_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_small_12_p16_224_dist.zip', + 'xcit_small_12_p16_384_dist': 
'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_small_12_p16_384_dist.zip', + 'xcit_small_24_p16_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_small_24_p16_224.zip', + 'xcit_small_24_p16_224_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_small_24_p16_224_dist.zip', + 'xcit_small_24_p16_384_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_small_24_p16_384_dist.zip', + 'xcit_medium_24_p16_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_medium_24_p16_224.zip', + 'xcit_medium_24_p16_224_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_medium_24_p16_224_dist.zip', + 'xcit_medium_24_p16_384_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_medium_24_p16_384_dist.zip', + 'xcit_large_24_p16_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_large_24_p16_224.zip', + 'xcit_large_24_p16_224_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_large_24_p16_224_dist.zip', + 'xcit_large_24_p16_384_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_large_24_p16_384_dist.zip', + 'xcit_nano_12_p8_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_nano_12_p8_224.zip', + 'xcit_nano_12_p8_224_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_nano_12_p8_224_dist.zip', + 'xcit_nano_12_p8_384_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_nano_12_p8_384_dist.zip', + 'xcit_tiny_12_p8_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_tiny_12_p8_224.zip', + 'xcit_tiny_12_p8_224_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_tiny_12_p8_224_dist.zip', + 'xcit_tiny_12_p8_384_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_tiny_12_p8_384_dist.zip', + 'xcit_tiny_24_p8_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_tiny_24_p8_224.zip', + 'xcit_tiny_24_p8_224_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_tiny_24_p8_224_dist.zip', + 'xcit_tiny_24_p8_384_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_tiny_24_p8_384_dist.zip', + 'xcit_small_12_p8_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_small_12_p8_224.zip', + 'xcit_small_12_p8_224_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_small_12_p8_224_dist.zip', + 'xcit_small_12_p8_384_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_small_12_p8_384_dist.zip', + 'xcit_small_24_p8_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_small_24_p8_224.zip', + 'xcit_small_24_p8_224_dist': 
'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_small_24_p8_224_dist.zip', + 'xcit_small_24_p8_384_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_small_24_p8_384_dist.zip', + 'xcit_medium_24_p8_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_medium_24_p8_224.zip', + 'xcit_medium_24_p8_224_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_medium_24_p8_224_dist.zip', + 'xcit_medium_24_p8_384_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_medium_24_p8_384_dist.zip', + 'xcit_large_24_p8_224': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_large_24_p8_224.zip', + 'xcit_large_24_p8_224_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_large_24_p8_224_dist.zip', + 'xcit_large_24_p8_384_dist': 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Xcit/xcit_large_24_p8_384_dist.zip'} + + +class ClassAttn(nn.Module): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + # with slight modifications to do CA + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + q = q * self.scale + v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C) + x_cls = self.proj(x_cls) + x_cls = self.proj_drop(x_cls) + + return x_cls + + +class PositionalEncodingFourier(nn.Module): + """ + Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper. 
+ Based on the official XCiT code + - https://github.com/facebookresearch/xcit/blob/master/xcit.py + """ + + def __init__(self, hidden_dim=32, dim=768, temperature=10000): + super().__init__() + self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) + self.scale = 2 * math.pi + self.temperature = temperature + self.hidden_dim = hidden_dim + self.dim = dim + self.eps = 1e-6 + + def forward(self, B: int, H: int, W: int): + device = self.token_projection.weight.device + y_embed = flow.arange(1, H + 1, dtype=flow.float32, device=device).unsqueeze(1).repeat(1, 1, W) + x_embed = flow.arange(1, W + 1, dtype=flow.float32, device=device).repeat(1, H, 1) + y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale + dim_t = flow.arange(self.hidden_dim, dtype=flow.float32, device=device) + dim_t = self.temperature ** (2 * flow.div(dim_t, 2).floor() / self.hidden_dim) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = flow.stack([pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4).flatten(3) + pos_y = flow.stack([pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], dim=4).flatten(3) + pos = flow.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.token_projection(pos) + return pos.repeat(B, 1, 1, 1) # (B, C, H, W) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution + batch norm""" + return flow.nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False), + nn.BatchNorm2d(out_planes) + ) + + +class ConvPatchEmbed(nn.Module): + """Image to Patch Embedding using multiple convolutional layers""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, act_layer=nn.GELU): + super().__init__() + img_size = to_2tuple(img_size) + num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + if patch_size == 16: + self.proj = flow.nn.Sequential( + conv3x3(in_chans, embed_dim // 8, 2), + act_layer(), + conv3x3(embed_dim // 8, embed_dim // 4, 2), + act_layer(), + conv3x3(embed_dim // 4, embed_dim // 2, 2), + act_layer(), + conv3x3(embed_dim // 2, embed_dim, 2), + ) + elif patch_size == 8: + self.proj = flow.nn.Sequential( + conv3x3(in_chans, embed_dim // 4, 2), + act_layer(), + conv3x3(embed_dim // 4, embed_dim // 2, 2), + act_layer(), + conv3x3(embed_dim // 2, embed_dim, 2), + ) + else: + raise ValueError('For convolutional projection, patch size has to be in [8, 16]') + + def forward(self, x): + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + x = x.flatten(2).transpose(1, 2) # (B, N, C) + return x, (Hp, Wp) + + +class LPI(nn.Module): + """ + Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows to augment the + implicit communication performed by the block diagonal scatter attention. 
Implemented using 2 layers of separable + 3x3 convolutions with GeLU and BatchNorm2d + """ + + def __init__(self, in_features, out_features=None, act_layer=nn.GELU, kernel_size=3): + super().__init__() + out_features = out_features or in_features + + padding = kernel_size // 2 + + self.conv1 = flow.nn.Conv2d( + in_features, in_features, kernel_size=kernel_size, padding=padding, groups=in_features) + self.act = act_layer() + self.bn = nn.BatchNorm2d(in_features) + self.conv2 = flow.nn.Conv2d( + in_features, out_features, kernel_size=kernel_size, padding=padding, groups=out_features) + + def forward(self, x, H: int, W: int): + B, N, C = x.shape + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = self.conv1(x) + x = self.act(x) + x = self.bn(x) + x = self.conv2(x) + x = x.reshape(B, C, N).permute(0, 2, 1) + return x + + +class ClassAttentionBlock(nn.Module): + """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239""" + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1., tokens_norm=False): + super().__init__() + self.norm1 = norm_layer(dim) + + self.attn = ClassAttn( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + + if eta is not None: # LayerScale Initialization (no layerscale when None) + self.gamma1 = nn.Parameter(eta * flow.ones(dim), requires_grad=True) + self.gamma2 = nn.Parameter(eta * flow.ones(dim), requires_grad=True) + else: + self.gamma1, self.gamma2 = 1.0, 1.0 + + # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 + self.tokens_norm = tokens_norm + + def forward(self, x): + x_norm1 = self.norm1(x) + x_attn = flow.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1) + x = x + self.drop_path(self.gamma1 * x_attn) + if self.tokens_norm: + x = self.norm2(x) + else: + x = flow.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1) + x_res = x + cls_token = x[:, 0:1] + cls_token = self.gamma2 * self.mlp(cls_token) + x = flow.cat([cls_token, x[:, 1:]], dim=1) + x = x_res + self.drop_path(x) + return x + + +class XCA(nn.Module): + """ Cross-Covariance Attention (XCA) + Operation where the channels are updated using a weighted sum. 
The weights are obtained from the (softmax + normalized) Cross-covariance matrix (Q^T \\cdot K \\in d_h \\times d_h) + """ + + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(flow.ones(num_heads, 1, 1)) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + # Result of next line is (qkv, B, num (H)eads, (C')hannels per head, N) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + # Paper section 3.2 l2-Normalization and temperature scaling + q = flow.nn.functional.normalize(q, dim=-1) + k = flow.nn.functional.normalize(k, dim=-1) + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # (B, H, C', N), permute -> (B, N, H, C') + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def no_weight_decay(self): + return {'temperature'} + + +class XCABlock(nn.Module): + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1.): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm3 = norm_layer(dim) + self.local_mp = LPI(in_features=dim, act_layer=act_layer) + + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + + self.gamma1 = nn.Parameter(eta * flow.ones(dim), requires_grad=True) + self.gamma3 = nn.Parameter(eta * flow.ones(dim), requires_grad=True) + self.gamma2 = nn.Parameter(eta * flow.ones(dim), requires_grad=True) + + def forward(self, x, H: int, W: int): + x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x))) + # NOTE official code has 3 then 2, so keeping it the same to be consistent with loaded weights + # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 + x = x + self.drop_path(self.gamma3 * self.local_mp(self.norm3(x), H, W)) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + return x + + +class XCiT(nn.Module): + """ + Based on timm and DeiT code bases + https://github.com/rwightman/pytorch-image-models/tree/master/timm + https://github.com/facebookresearch/deit/ + """ + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, + depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False): + """ + Args: + img_size (int, tuple): input image size + patch_size (int): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + drop_rate (float): dropout rate 
after positional embedding, and in XCA/CA projection + MLP + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate (constant across all layers) + norm_layer: (nn.Module): normalization layer + cls_attn_layers: (int) Depth of Class attention layers + use_pos_embed: (bool) whether to use positional encoding + eta: (float) layerscale initialization value + tokens_norm: (bool) Whether to normalize all tokens or just the cls_token in the CA + + Notes: + - Although `layer_norm` is user specifiable, there are hard-coded `BatchNorm2d`s in the local patch + interaction (class LPI) and the patch embedding (class ConvPatchEmbed) + """ + super().__init__() + assert global_pool in ('', 'avg', 'token') + img_size = to_2tuple(img_size) + assert (img_size[0] % patch_size == 0) and (img_size[0] % patch_size == 0), \ + '`patch_size` should divide image dimensions evenly' + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim + self.global_pool = global_pool + + self.patch_embed = ConvPatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, act_layer=act_layer) + + self.cls_token = nn.Parameter(flow.zeros(1, 1, embed_dim)) + self.use_pos_embed = use_pos_embed + if use_pos_embed: + self.pos_embed = PositionalEncodingFourier(dim=embed_dim) + self.pos_drop = nn.Dropout(p=drop_rate) + + self.blocks = nn.ModuleList([ + XCABlock( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=drop_path_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta) + for _ in range(depth)]) + + self.cls_attn_blocks = nn.ModuleList([ + ClassAttentionBlock( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta, tokens_norm=tokens_norm) + for _ in range(cls_attn_layers)]) + + # Classifier head + self.norm = norm_layer(embed_dim) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + # Init weights + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=r'^blocks\.(\d+)', + cls_attn_blocks=[(r'^cls_attn_blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'token') + self.global_pool = global_pool + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + # x is (B, N, C). 
(Hp, Hw) is (height in units of patches, width in units of patches) + x, (Hp, Wp) = self.patch_embed(x) + + if self.use_pos_embed: + # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C) + pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1) + x = x + pos_encoding + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x, Hp, Wp) + + x = flow.cat((self.cls_token.expand(B, -1, -1), x), dim=1) + + for blk in self.cls_attn_blocks: + x = blk(x) + + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + if 'model' in state_dict: + state_dict = state_dict['model'] + # For consistency with timm's transformer models while being compatible with official weights source we rename + # pos_embeder to pos_embed. Also account for use_pos_embed == False + use_pos_embed = getattr(model, 'pos_embed', None) is not None + pos_embed_keys = [k for k in state_dict if k.startswith('pos_embed')] + for k in pos_embed_keys: + if use_pos_embed: + state_dict[k.replace('pos_embeder.', 'pos_embed.')] = state_dict.pop(k) + else: + del state_dict[k] + # timm's implementation of class attention in CaiT is slightly more efficient as it does not compute query vectors + # for all tokens, just the class token. To use official weights source we must split qkv into q, k, v + if 'cls_attn_blocks.0.attn.qkv.weight' in state_dict and 'cls_attn_blocks.0.attn.q.weight' in model.state_dict(): + num_ca_blocks = len(model.cls_attn_blocks) + for i in range(num_ca_blocks): + qkv_weight = state_dict.pop(f'cls_attn_blocks.{i}.attn.qkv.weight') + qkv_weight = qkv_weight.reshape(3, -1, qkv_weight.shape[-1]) + for j, subscript in enumerate('qkv'): + state_dict[f'cls_attn_blocks.{i}.attn.{subscript}.weight'] = qkv_weight[j] + qkv_bias = state_dict.pop(f'cls_attn_blocks.{i}.attn.qkv.bias', None) + if qkv_bias is not None: + qkv_bias = qkv_bias.reshape(3, -1) + for j, subscript in enumerate('qkv'): + state_dict[f'cls_attn_blocks.{i}.attn.{subscript}.bias'] = qkv_bias[j] + return state_dict + + +def _create_xcit(arch: str, pretrained: bool = False, progress: bool = True, **kwargs): + model = XCiT(**kwargs) + if pretrained: + state_dict = load_state_dict_from_url( + model_urls[arch], + model_dir="./checkpoints", + progress=progress, + ) + model.load_state_dict(state_dict) + return model + + +@ModelCreator.register_model +def xcit_nano_12_p16_224(pretrained:bool =False, progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_nano_12_p16_224_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p16_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_nano_12_p16_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=128, depth=12, num_heads=4, 
eta=1.0, tokens_norm=False, img_size=384, **kwargs) + model = _create_xcit('xcit_nano_12_p16_384_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_tiny_12_p16_224(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_tiny_12_p16_224_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p16_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_tiny_12_p16_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p16_384_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_small_12_p16_224(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_small_12_p16_224_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p16_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_small_12_p16_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p16_384_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_tiny_24_p16_224(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_tiny_24_p16_224_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p16_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_tiny_24_p16_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p16_384_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_small_24_p16_224(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( 
+ patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p16_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_small_24_p16_224_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p16_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_small_24_p16_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p16_384_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_medium_24_p16_224(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_medium_24_p16_224_dist(pretrained:bool =False,progress: bool = True,**kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p16_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_medium_24_p16_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p16_384_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_large_24_p16_224(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_large_24_p16_224_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p16_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_large_24_p16_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p16_384_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +# Patch size 8x8 models +@ModelCreator.register_model +def xcit_nano_12_p8_224(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def 
xcit_nano_12_p8_224_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p8_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_nano_12_p8_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p8_384_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_tiny_12_p8_224(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_tiny_12_p8_224_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p8_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_tiny_12_p8_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p8_384_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_small_12_p8_224(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p8_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_small_12_p8_224_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p8_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_small_12_p8_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p8_384_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_tiny_24_p8_224(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_tiny_24_p8_224_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p8_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + 
+@ModelCreator.register_model +def xcit_tiny_24_p8_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p8_384_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_small_24_p8_224(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_small_24_p8_224_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p8_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_small_24_p8_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p8_384_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_medium_24_p8_224(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_medium_24_p8_224_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p8_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_medium_24_p8_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p8_384_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_large_24_p8_224(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_large_24_p8_224_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p8_224_dist', pretrained=pretrained, progress=progress ,**model_kwargs) + return model + + +@ModelCreator.register_model +def xcit_large_24_p8_384_dist(pretrained:bool =False,progress: bool = True, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p8_384_dist', pretrained=pretrained, 
progress=progress, **model_kwargs) + return model
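
For reference, a minimal usage sketch for the newly registered xcit entry points, assuming this patch is applied so that flowvision/models/__init__.py re-exports them as shown above. Loading pretrained weights additionally needs access to the OSS URLs listed in model_urls, so the sketch builds a randomly initialized model instead:

>>> import oneflow as flow
>>> import flowvision
>>> xcit_nano = flowvision.models.xcit_nano_12_p16_224(pretrained=False, progress=True)
>>> xcit_nano = xcit_nano.eval()  # put the BatchNorm layers in the conv patch embed into eval mode
>>> x = flow.randn(1, 3, 224, 224)  # dummy ImageNet-resolution input
>>> logits = xcit_nano(x)  # shape (1, 1000) with the default 1000-class head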