From 36ddc7a6baa4d7f54c4ba7cd5558deb78808afd1 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 10 Oct 2024 10:42:37 -0700 Subject: [PATCH] go all the way with the normalized vit, fix some scales --- setup.py | 2 +- vit_pytorch/normalized_vit.py | 21 +++++++++++---------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index e7e77aa..01a0e3c 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.8.2', + version = '1.8.4', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description=long_description, diff --git a/vit_pytorch/normalized_vit.py b/vit_pytorch/normalized_vit.py index c3a5925..5a21126 100644 --- a/vit_pytorch/normalized_vit.py +++ b/vit_pytorch/normalized_vit.py @@ -179,18 +179,18 @@ def __init__( self.to_patch_embedding = nn.Sequential( Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size), - nn.LayerNorm(patch_dim), - nn.Linear(patch_dim, dim), - nn.LayerNorm(dim), + NormLinear(patch_dim, dim, norm_dim_in = False), ) - self.abs_pos_emb = nn.Embedding(num_patches, dim) + self.abs_pos_emb = NormLinear(dim, num_patches) residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth) # layers self.dim = dim + self.scale = dim ** 0.5 + self.layers = ModuleList([]) self.residual_lerp_scales = nn.ParameterList([]) @@ -201,8 +201,8 @@ def __init__( ])) self.residual_lerp_scales.append(nn.ParameterList([ - nn.Parameter(torch.ones(dim) * residual_lerp_scale_init), - nn.Parameter(torch.ones(dim) * residual_lerp_scale_init), + nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale), + nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale), ])) self.logit_scale = nn.Parameter(torch.ones(num_classes)) @@ -225,22 +225,23 @@ def forward(self, images): tokens = self.to_patch_embedding(images) - pos_emb = self.abs_pos_emb(torch.arange(tokens.shape[-2], device = device)) + seq_len = tokens.shape[-2] + pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)] tokens = l2norm(tokens + pos_emb) for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales): attn_out = l2norm(attn(tokens)) - tokens = l2norm(tokens.lerp(attn_out, attn_alpha)) + tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale)) ff_out = l2norm(ff(tokens)) - tokens = l2norm(tokens.lerp(ff_out, ff_alpha)) + tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale)) pooled = reduce(tokens, 'b n d -> b d', 'mean') logits = self.to_pred(pooled) - logits = logits * self.logit_scale * (self.dim ** 0.5) + logits = logits * self.logit_scale * self.scale return logits