diff --git a/README.md b/README.md
index 42ba239..035572f 100644
--- a/README.md
+++ b/README.md
@@ -276,6 +276,47 @@ indices = quantizer(x)
 
 This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False`
 
+### Sim VQ
+
+<img src="./images/simvq.png" width="400px"></img>
+
+A new ICLR 2025 paper proposes a scheme where the codebook is frozen, and the codes are implicitly generated through a linear projection. The authors claim this setup leads to less codebook collapse as well as easier convergence. I have found this to perform even better when paired with the rotation trick from Fifty et al., and with the linear projection expanded to a small one-layer MLP. You can experiment with it like so
+
+```python
+import torch
+from vector_quantize_pytorch import SimVQ
+
+sim_vq = SimVQ(
+    dim = 512,
+    codebook_size = 1024
+)
+
+x = torch.randn(1, 1024, 512)
+quantized, indices, commit_loss = sim_vq(x)
+
+assert x.shape == quantized.shape
+assert torch.allclose(quantized, sim_vq.indices_to_codes(indices), atol = 1e-6)
+```
+
+For the residual flavor, just import `ResidualSimVQ` instead
+
+```python
+import torch
+from vector_quantize_pytorch import ResidualSimVQ
+
+residual_sim_vq = ResidualSimVQ(
+    dim = 512,
+    num_quantizers = 4,
+    codebook_size = 1024
+)
+
+x = torch.randn(1, 1024, 512)
+quantized, indices, commit_loss = residual_sim_vq(x)
+
+assert x.shape == quantized.shape
+assert torch.allclose(quantized, residual_sim_vq.get_output_from_indices(indices), atol = 1e-6)
+```
+
 ### Finite Scalar Quantization
diff --git a/images/simvq.png b/images/simvq.png
new file mode 100644
index 0000000..7b5d606
Binary files /dev/null and b/images/simvq.png differ
diff --git a/pyproject.toml b/pyproject.toml
index 671334e..ad856ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.20.7"
+version = "1.20.8"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/vector_quantize_pytorch/sim_vq.py b/vector_quantize_pytorch/sim_vq.py
index f264ffe..a873173 100644
--- a/vector_quantize_pytorch/sim_vq.py
+++ b/vector_quantize_pytorch/sim_vq.py
@@ -44,6 +44,7 @@ def __init__(
         accept_image_fmap = False,
         rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
         input_to_quantize_commit_loss_weight = 0.25,
+        commitment_weight = 1.,
         frozen_codebook_dim = None # frozen codebook dim could have different dimensions than projection
     ):
         super().__init__()
@@ -74,6 +75,10 @@ def __init__(
 
         self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
 
+        # total commitment loss weight
+
+        self.commitment_weight = commitment_weight
+
     @property
     def codebook(self):
         return self.code_transform(self.frozen_codebook)
@@ -132,7 +137,7 @@ def forward(
 
             indices = inverse_pack(indices, 'b *')
 
-        return quantized, indices, commit_loss
+        return quantized, indices, commit_loss * self.commitment_weight
 
 # main
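
The `sim_vq.py` change above also exposes a new `commitment_weight` argument, which the README examples do not show. Below is a minimal sketch of passing it through `SimVQ`; the `0.25` value is purely illustrative (the default added by this diff is `1.`).

```python
import torch
from vector_quantize_pytorch import SimVQ

# illustrative weight - commitment_weight defaults to 1. per the diff above
sim_vq = SimVQ(
    dim = 512,
    codebook_size = 1024,
    commitment_weight = 0.25
)

x = torch.randn(1, 1024, 512)
quantized, indices, commit_loss = sim_vq(x)

# commit_loss comes back already scaled by commitment_weight
# (see the change to forward in this diff), so it can be added
# directly to the reconstruction / task loss without re-weighting
```

Note the change only scales the returned commitment loss; the rotation trick and the rest of the forward pass are untouched.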