diff --git a/lib/models/cls_cvt.py b/lib/models/cls_cvt.py index 0267176..e5952e6 100644 --- a/lib/models/cls_cvt.py +++ b/lib/models/cls_cvt.py @@ -140,7 +140,7 @@ def _build_projection(self, groups=dim_in )), ('bn', nn.BatchNorm2d(dim_in)), - ('rearrage', Rearrange('b c h w -> b (h w) c')), + ('rearrange', Rearrange('b c h w -> b (h w) c')), ])) elif method == 'avg': proj = nn.Sequential(OrderedDict([ @@ -150,7 +150,7 @@ def _build_projection(self, stride=stride, ceil_mode=True )), - ('rearrage', Rearrange('b c h w -> b (h w) c')), + ('rearrange', Rearrange('b c h w -> b (h w) c')), ])) elif method == 'linear': proj = None