Skip to content

Commit

Permalink
allow for arbitrary dimensions into SimVQ and ResidualSimVQ (video an…
Browse files Browse the repository at this point in the history
…d beyond)
  • Loading branch information
lucidrains committed Nov 13, 2024
1 parent 448c4a5 commit a766304
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 28 deletions.
2 changes: 1 addition & 1 deletion examples/autoencoder_sim_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
num_codes = 256
seed = 1234

rotation_trick = True # rotation trick instead ot straight-through
rotation_trick = True # rotation trick instead ot straight-through
use_mlp = True # use a one layer mlp with relu instead of linear

device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.20.8"
version = "1.20.9"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down
2 changes: 1 addition & 1 deletion tests/test_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def test_residual_sim_vq():
dim = 512,
num_quantizers = 4,
codebook_size = 1024,
accept_image_fmap = True
channel_first = True
)

x = torch.randn(1, 512, 32, 32)
Expand Down
22 changes: 8 additions & 14 deletions vector_quantize_pytorch/residual_sim_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,20 @@ def __init__(
quantize_dropout = False,
quantize_dropout_cutoff_index = 0,
quantize_dropout_multiple_of = 1,
accept_image_fmap = False,
channel_first = False,
rotation_trick = True, # rotation trick from @cfifty, on top of sim vq
**sim_vq_kwargs
):
super().__init__()
assert heads == 1, 'residual vq is not compatible with multi-headed codes'

self.accept_image_fmap = accept_image_fmap
self.channel_first = channel_first

self.num_quantizers = num_quantizers

# define sim vq across layers

self.layers = ModuleList([SimVQ(dim = dim, codebook_size = codebook_size, rotation_trick = rotation_trick, accept_image_fmap = accept_image_fmap, **sim_vq_kwargs) for _ in range(num_quantizers)])
self.layers = ModuleList([SimVQ(dim = dim, codebook_size = codebook_size, rotation_trick = rotation_trick, channel_first = channel_first, **sim_vq_kwargs) for _ in range(num_quantizers)])

# quantize dropout

Expand Down Expand Up @@ -100,7 +100,7 @@ def get_codes_from_indices(self, indices):

batch, quantize_dim = indices.shape[0], indices.shape[-1]

# may also receive indices in the shape of 'b h w q' (accept_image_fmap)
# may also receive indices in the shape of 'b h w q' (images)

indices, inverse = pack_one(indices, 'b * q')

Expand All @@ -122,11 +122,11 @@ def get_codes_from_indices(self, indices):

all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)

# if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
# if (channel_first = True) then return shape (quantize, batch, height, width, dimension)

all_codes = inverse(all_codes, 'q b * d')

if self.accept_image_fmap:
if self.channel_first:
all_codes = rearrange(all_codes, 'q b ... d -> q b d ...')

return all_codes
Expand All @@ -139,23 +139,17 @@ def get_output_from_indices(self, indices):
def forward(
self,
x,
indices: Tensor | list[Tensor] | None = None,
return_all_codes = False,
rand_quantize_dropout_fixed_seed = None
):
num_quant, quant_dropout_multiple_of, return_loss, device = self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device

assert not (self.accept_image_fmap and exists(indices))
num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device

quantized_out = 0.
residual = x

all_losses = []
all_indices = []

if isinstance(indices, list):
indices = torch.stack(indices)

should_quantize_dropout = self.training and self.quantize_dropout and not return_loss

# sample a layer index at which to dropout further residual quantization
Expand All @@ -175,7 +169,7 @@ def forward(
if quant_dropout_multiple_of != 1:
rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1

null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.channel_first else tuple(x.shape[:2])
null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)

Expand Down
23 changes: 12 additions & 11 deletions vector_quantize_pytorch/sim_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ def __init__(
codebook_size,
codebook_transform: Module | None = None,
init_fn: Callable = identity,
accept_image_fmap = False,
channel_first = False,
rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
input_to_quantize_commit_loss_weight = 0.25,
commitment_weight = 1.,
frozen_codebook_dim = None # frozen codebook dim could have different dimensions than projection
):
super().__init__()
self.codebook_size = codebook_size
self.accept_image_fmap = accept_image_fmap
self.channel_first = channel_first

frozen_codebook_dim = default(frozen_codebook_dim, dim)
codebook = torch.randn(codebook_size, frozen_codebook_dim) * (frozen_codebook_dim ** -0.5)
Expand Down Expand Up @@ -92,7 +92,7 @@ def indices_to_codes(
frozen_codes = get_at('[c] d, b ... -> b ... d', self.frozen_codebook, indices)
quantized = self.code_transform(frozen_codes)

if self.accept_image_fmap:
if self.channel_first:
quantized = rearrange(quantized, 'b ... d -> b d ...')

return quantized
Expand All @@ -101,9 +101,10 @@ def forward(
self,
x
):
if self.accept_image_fmap:
x = rearrange(x, 'b d h w -> b h w d')
x, inverse_pack = pack_one(x, 'b * d')
if self.channel_first:
x = rearrange(x, 'b d ... -> b ... d')

x, inverse_pack = pack_one(x, 'b * d')

implicit_codebook = self.codebook

Expand Down Expand Up @@ -131,11 +132,11 @@ def forward(

quantized = (quantized - x).detach() + x

if self.accept_image_fmap:
quantized = inverse_pack(quantized)
quantized = rearrange(quantized, 'b h w d-> b d h w')
quantized = inverse_pack(quantized)
indices = inverse_pack(indices, 'b *')

indices = inverse_pack(indices, 'b *')
if self.channel_first:
quantized = rearrange(quantized, 'b ... d-> b d ...')

return quantized, indices, commit_loss * self.commitment_weight

Expand All @@ -153,7 +154,7 @@ def forward(
nn.Linear(1024, 512)
),
codebook_size = 1024,
accept_image_fmap = True
channel_first = True
)

quantized, indices, commit_loss = sim_vq(x)
Expand Down

0 comments on commit a766304

Please sign in to comment.