From bbb24e34d457fdca3c2a6b442354f203e6294670 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 6 Oct 2023 10:40:26 -0700
Subject: [PATCH] give a learned bias to and from registers for maxvit +
 register token variant

---
 setup.py                              |  2 +-
 vit_pytorch/max_vit_with_registers.py | 25 +++++++++++++------------
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/setup.py b/setup.py
index 28f87cc..929b58a 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.5.2',
+  version = '1.5.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/max_vit_with_registers.py b/vit_pytorch/max_vit_with_registers.py
index a55b580..a84a603 100644
--- a/vit_pytorch/max_vit_with_registers.py
+++ b/vit_pytorch/max_vit_with_registers.py
@@ -119,9 +119,11 @@ def __init__(
         dim,
         dim_head = 32,
         dropout = 0.,
-        window_size = 7
+        window_size = 7,
+        num_registers = 1
     ):
         super().__init__()
+        assert num_registers > 0
         assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
 
         self.heads = dim // dim_head
@@ -142,7 +144,9 @@ def __init__(
 
         # relative positional bias
 
-        self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
+        num_rel_pos_bias = (2 * window_size - 1) ** 2
+
+        self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)
 
         pos = torch.arange(window_size)
         grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
@@ -151,10 +155,11 @@
         rel_pos += window_size - 1
         rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
 
+        rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
         self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
 
     def forward(self, x):
-        device, h = x.device, self.heads
+        device, h, bias_indices = x.device, self.heads, self.rel_pos_indices
 
         x = self.norm(x)
 
@@ -176,13 +181,8 @@
 
         # add positional bias
 
-        bias = self.rel_pos_bias(self.rel_pos_indices)
-        bias = rearrange(bias, 'i j h -> h i j')
-
-        num_registers = sim.shape[-1] - bias.shape[-1]
-        bias = F.pad(bias, (num_registers, 0, num_registers, 0), value = 0.)
-
-        sim = sim + bias
+        bias = self.rel_pos_bias(bias_indices)
+        sim = sim + rearrange(bias, 'i j h -> h i j')
 
         # attention
 
@@ -215,6 +215,7 @@
     ):
         super().__init__()
         assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
+        assert num_register_tokens > 0
 
         # convolutional stem
 
@@ -256,10 +257,10 @@
                     shrinkage_rate = mbconv_shrinkage_rate
                 )
 
-                block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
+                block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
                 block_ff = FeedForward(dim = layer_dim, dropout = dropout)
 
-                grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
+                grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
                 grid_ff = FeedForward(dim = layer_dim, dropout = dropout)
 
                 register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))
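
For illustration, a minimal standalone sketch of the indexing trick this patch implements: the relative position bias table gains one extra learned row, and the precomputed index buffer is padded with that row's index along the register positions, so attention to and from register tokens picks up a learned bias instead of the previous hard-coded zero pad. The values below (window_size = 7, num_registers = 1, heads = 4) are illustrative assumptions, not fixed by the patch, and the tensors are stand-ins rather than the module's real activations.

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

# illustrative hyperparameters (assumptions, not fixed by the patch)
window_size = 7
num_registers = 1
heads = 4

num_rel_pos_bias = (2 * window_size - 1) ** 2

# one extra embedding row (index num_rel_pos_bias) reserved for any
# query/key pair that involves a register token
rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, heads)

# standard 2d relative position indices for the window tokens
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

# pad the index grid on the register side with the reserved index, so the
# bias to and from registers is looked up (and learned) like any other
rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)

n = window_size ** 2 + num_registers            # registers prepended to the window tokens
sim = torch.randn(heads, n, n)                  # stand-in attention logits

bias = rel_pos_bias(rel_pos_indices)            # (n, n, heads)
sim = sim + rearrange(bias, 'i j h -> h i j')   # bias added before softmax

assert sim.shape == (heads, n, n)

Compared with the zero padding this replaces, the model can now learn how strongly window tokens attend to registers and vice versa, at the cost of a single extra row in the bias table; the forward pass also no longer needs to infer the register count from the shape mismatch between sim and bias.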