diff --git a/README.md b/README.md
index 42ba239..035572f 100644
--- a/README.md
+++ b/README.md
@@ -276,6 +276,47 @@ indices = quantizer(x)
This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False`.
+### Sim VQ
+
+<img src="./images/simvq.png" width="450px"/>
+
+A new ICLR 2025 paper proposes a scheme where the codebook is frozen, and the codes are implicitly generated through a linear projection. The authors claim this setup leads to less codebook collapse as well as easier convergence. I have found this to perform even better when paired with the rotation trick from Fifty et al. and with the linear projection expanded to a small one-layer MLP. You can experiment with it as follows
+
+```python
+import torch
+from vector_quantize_pytorch import SimVQ
+
+sim_vq = SimVQ(
+ dim = 512,
+ codebook_size = 1024
+)
+
+x = torch.randn(1, 1024, 512)
+quantized, indices, commit_loss = sim_vq(x)
+
+assert x.shape == quantized.shape
+assert torch.allclose(quantized, sim_vq.indices_to_codes(indices), atol = 1e-6)
+```
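+
+Both `rotation_trick` and `commitment_weight` can be set on the constructor. A minimal sketch with illustrative values only, not tuned recommendations:
+
+```python
+import torch
+from vector_quantize_pytorch import SimVQ
+
+sim_vq = SimVQ(
+    dim = 512,
+    codebook_size = 1024,
+    rotation_trick = True,      # on by default
+    commitment_weight = 0.25    # multiplied into the returned commit_loss
+)
+
+x = torch.randn(1, 1024, 512)
+quantized, indices, commit_loss = sim_vq(x)
+```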
+
+For the residual flavor, just import `ResidualSimVQ` instead
+
+```python
+import torch
+from vector_quantize_pytorch import ResidualSimVQ
+
+residual_sim_vq = ResidualSimVQ(
+ dim = 512,
+ num_quantizers = 4,
+ codebook_size = 1024
+)
+
+x = torch.randn(1, 1024, 512)
+quantized, indices, commit_loss = residual_sim_vq(x)
+
+assert x.shape == quantized.shape
+assert torch.allclose(quantized, residual_sim_vq.get_output_from_indices(indices), atol = 1e-6)
+```
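+
+For intuition, here is a rough conceptual sketch of the residual idea (illustrative only, not the `ResidualSimVQ` implementation): each stage quantizes whatever the previous stages failed to capture, using its own frozen codebook passed through a trainable transform, and the final output is the sum of the per-stage quantizations.
+
+```python
+import torch
+
+# conceptual sketch only - not the library code
+# (training would additionally need a straight-through or rotation-trick gradient path)
+
+dim, codebook_size, num_quantizers = 512, 1024, 4
+
+frozen_codebooks = [torch.randn(codebook_size, dim) for _ in range(num_quantizers)]  # never trained
+transforms = [torch.nn.Linear(dim, dim) for _ in range(num_quantizers)]              # trained
+
+x = torch.randn(1, 1024, 512)
+
+residual = x
+quantized_out = torch.zeros_like(x)
+
+for codebook, transform in zip(frozen_codebooks, transforms):
+    codes = transform(codebook)                                 # implicit codebook for this stage
+    dist = torch.cdist(residual.flatten(0, 1), codes)           # distance to every code
+    quantized = codes[dist.argmin(dim = -1)].view_as(residual)  # nearest-code lookup
+
+    quantized_out = quantized_out + quantized                   # accumulate stage outputs
+    residual = residual - quantized                             # next stage sees the leftover
+```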
+
### Finite Scalar Quantization
diff --git a/images/simvq.png b/images/simvq.png
new file mode 100644
index 0000000..7b5d606
Binary files /dev/null and b/images/simvq.png differ
diff --git a/pyproject.toml b/pyproject.toml
index 671334e..ad856ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
-version = "1.20.7"
+version = "1.20.8"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/vector_quantize_pytorch/sim_vq.py b/vector_quantize_pytorch/sim_vq.py
index f264ffe..a873173 100644
--- a/vector_quantize_pytorch/sim_vq.py
+++ b/vector_quantize_pytorch/sim_vq.py
@@ -44,6 +44,7 @@ def __init__(
accept_image_fmap = False,
rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
input_to_quantize_commit_loss_weight = 0.25,
+ commitment_weight = 1.,
frozen_codebook_dim = None # frozen codebook dim could have different dimensions than projection
):
super().__init__()
@@ -74,6 +75,10 @@ def __init__(
self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
+ # total commitment loss weight
+
+ self.commitment_weight = commitment_weight
+
@property
def codebook(self):
return self.code_transform(self.frozen_codebook)
@@ -132,7 +137,7 @@ def forward(
indices = inverse_pack(indices, 'b *')
- return quantized, indices, commit_loss
+ return quantized, indices, commit_loss * self.commitment_weight
# main