diff --git a/README.md b/README.md
index 09b57c0..96e0e1a 100644
--- a/README.md
+++ b/README.md
@@ -100,6 +100,25 @@
 quantized, indices, commit_loss = residual_vq(x)
 # (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)
 ```
+
+This [paper](https://arxiv.org/abs/2305.05065) from Google DeepMind proposes that residual vector quantization can induce hierarchical semantic IDs for building a recommender system. In their scheme, the number of codes increases with quantizer depth. This repository supports that scheme as follows
+
+```python
+import torch
+from vector_quantize_pytorch import ResidualVQ
+
+residual_vq = ResidualVQ(
+    dim = 256,
+    codebook_size = (5, 128, 256),   # from topmost hierarchy to lowest: 5 codes, then 128, then 256
+)
+
+x = torch.randn(2, 16, 256)
+
+quantized, indices, commit_loss = residual_vq(x)
+
+# (2, 16, 256), (2, 16, 3), (2, 3)
+```
+
 ## Initialization
 
 The SoundStream paper proposes that the codebook should be initialized by the kmeans centroids of the first batch. You can easily turn on this feature with one flag `kmeans_init = True`, for either `VectorQuantize` or `ResidualVQ` class
@@ -713,4 +732,15 @@
     volume  = {abs/2410.06424},
     url     = {https://api.semanticscholar.org/CorpusID:273229218}
 }
-```
\ No newline at end of file
+```
+
+```bibtex
+@article{Rajput2023RecommenderSW,
+    title   = {Recommender Systems with Generative Retrieval},
+    author  = {Shashank Rajput and Nikhil Mehta and Anima Singh and Raghunandan H. Keshavan and Trung Hieu Vu and Lukasz Heldt and Lichan Hong and Yi Tay and Vinh Q. Tran and Jonah Samost and Maciej Kula and Ed H. Chi and Maheswaran Sathiamoorthy},
+    journal = {ArXiv},
+    year    = {2023},
+    volume  = {abs/2305.05065},
+    url     = {https://api.semanticscholar.org/CorpusID:258564854}
+}
+```
diff --git a/pyproject.toml b/pyproject.toml
index 364efe5..39a5d02 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.18.8"
+version = "1.19.0"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_readme.py b/tests/test_readme.py
index 79c8a7f..e383996 100644
--- a/tests/test_readme.py
+++ b/tests/test_readme.py
@@ -201,6 +201,24 @@ def test_rq():
     x = torch.randn(1, 1024, 512)
     indices = quantizer(x)
 
+def test_tiger():
+    from vector_quantize_pytorch import ResidualVQ
+
+    residual_vq = ResidualVQ(
+        dim = 2,
+        codebook_size = (5, 128, 256),
+    )
+
+    x = torch.randn(2, 2, 2)
+
+    residual_vq.train()
+
+    quantized, indices, commit_loss = residual_vq(x, freeze_codebook = True)
+
+    quantized_out = residual_vq.get_output_from_indices(indices) # pass your indices in here; to reproduce the forward output exactly they must come from .eval(), as during training some indices may be dropped out (-1)
+
+    assert torch.allclose(quantized, quantized_out, atol = 1e-5)
+
 def test_fsq():
     from vector_quantize_pytorch import FSQ
diff --git a/vector_quantize_pytorch/residual_vq.py b/vector_quantize_pytorch/residual_vq.py
index 3e70ae2..d4f5c1c 100644
--- a/vector_quantize_pytorch/residual_vq.py
+++ b/vector_quantize_pytorch/residual_vq.py
@@ -1,5 +1,4 @@
 from __future__ import annotations
-from typing import List
 
 import random
 from math import ceil
@@ -28,6 +27,12 @@ def first(it):
 def default(val, d):
     return val if exists(val) else d
 
+def cast_tuple(t, length = 1):
+    return t if isinstance(t, tuple) else ((t,) * length)
+
+def unique(arr):
+    return list({*arr})
+
 def round_up_multiple(num, mult):
     return ceil(num / mult) * mult
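The two helpers above carry the feature: `cast_tuple` normalizes `codebook_size` to a tuple with one size per quantizer, and `unique` detects whether all sizes match. A quick standalone illustration of their behavior, restated from the one-line definitions in the hunk above:

```python
def cast_tuple(t, length = 1):
    # pass tuples through untouched, otherwise repeat the scalar `length` times
    return t if isinstance(t, tuple) else ((t,) * length)

def unique(arr):
    # deduplicate via a set; only the cardinality is used downstream
    return list({*arr})

assert cast_tuple(256, 4) == (256, 256, 256, 256)  # uniform sizes from an int
assert cast_tuple((5, 128, 256)) == (5, 128, 256)  # tuple passed through as-is
assert len(unique((256, 256, 256))) == 1           # uniform -> codebooks can be stacked
assert len(unique((5, 128, 256))) == 3             # non-uniform -> handled layer by layer
```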
@@ -110,7 +115,8 @@ def __init__(
         self,
         *,
         dim,
-        num_quantizers,
+        num_quantizers: int | None = None,
+        codebook_size: int | tuple[int, ...],
         codebook_dim = None,
         shared_codebook = False,
         heads = 1,
@@ -124,6 +130,8 @@ def __init__(
     ):
         super().__init__()
         assert heads == 1, 'residual vq is not compatible with multi-headed codes'
 
+        assert exists(num_quantizers) or isinstance(codebook_size, tuple), 'either num_quantizers must be given, or codebook_size must be a tuple with one size per quantizer'
+
         codebook_dim = default(codebook_dim, dim)
         codebook_input_dim = codebook_dim * heads
@@ -132,8 +140,6 @@ def __init__(
         self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
         self.has_projections = requires_projection
 
-        self.num_quantizers = num_quantizers
-
         self.accept_image_fmap = accept_image_fmap
         self.implicit_neural_codebook = implicit_neural_codebook
 
@@ -150,7 +156,21 @@ def __init__(
             manual_in_place_optimizer_update = True
         )
 
-        self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for _ in range(num_quantizers)])
+        # take care of possibly different codebook sizes across depth, used in the TIGER paper https://arxiv.org/abs/2305.05065
+
+        codebook_sizes = cast_tuple(codebook_size, num_quantizers)
+
+        assert not exists(num_quantizers) or len(codebook_sizes) == num_quantizers
+
+        num_quantizers = len(codebook_sizes)
+        self.num_quantizers = num_quantizers
+
+        self.codebook_sizes = codebook_sizes
+        self.uniform_codebook_size = len(unique(codebook_sizes)) == 1
+
+        # define vq across layers
+
+        self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_size = layer_codebook_size, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for layer_codebook_size in codebook_sizes])
 
         assert all([not vq.has_projections for vq in self.layers])
 
@@ -167,6 +187,8 @@ def __init__(
 
         if implicit_neural_codebook:
             self.mlps = ModuleList([MLP(dim = codebook_dim, l2norm_output = first(self.layers).use_cosine_sim, **mlp_kwargs) for _ in range(num_quantizers - 1)])
+        else:
+            self.mlps = (None,) * (num_quantizers - 1)
 
         # sharing codebook logic
 
@@ -175,6 +197,8 @@ def __init__(
         if not shared_codebook:
             return
 
+        assert self.uniform_codebook_size, 'shared codebook requires uniform codebook sizes across quantizers'
+
         first_vq, *rest_vq = self.layers
         codebook = first_vq._codebook
 
@@ -192,8 +216,13 @@ def codebook_dim(self):
     @property
     def codebooks(self):
         codebooks = [layer._codebook.embed for layer in self.layers]
-        codebooks = torch.stack(codebooks, dim = 0)
-        codebooks = rearrange(codebooks, 'q 1 c d -> q c d')
+
+        codebooks = tuple(rearrange(codebook, '1 ... -> ...') for codebook in codebooks)
+
+        if not self.uniform_codebook_size:
+            return codebooks
+
+        codebooks = torch.stack(codebooks)
         return codebooks
 
     def get_codes_from_indices(self, indices):
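With this change, `codebooks` returns a tuple of per-layer tensors when sizes differ, and keeps the old stacked-tensor behavior when they match. A small sketch of the two cases; the shapes are inferred from the rearrange patterns above, assuming `codebook_dim` defaults to `dim`:

```python
import torch
from vector_quantize_pytorch import ResidualVQ

# non-uniform sizes -> a tuple of per-layer codebooks, one (codes, dim) tensor each
rvq = ResidualVQ(dim = 256, codebook_size = (5, 128, 256))
print([cb.shape for cb in rvq.codebooks])
# [torch.Size([5, 256]), torch.Size([128, 256]), torch.Size([256, 256])]

# uniform size -> a single stacked (quantizers, codes, dim) tensor, as before
rvq_uniform = ResidualVQ(dim = 256, num_quantizers = 3, codebook_size = 128)
print(rvq_uniform.codebooks.shape)
# torch.Size([3, 128, 256])
```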
@@ -216,13 +245,12 @@ def get_codes_from_indices(self, indices):
         mask = indices == -1.
         indices = indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later
 
-        if not self.implicit_neural_codebook:
-            # gather all the codes
+        if not self.implicit_neural_codebook and self.uniform_codebook_size:
 
             all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)
 
         else:
-            # else if using implicit neural codebook, codes will need to be derived layer by layer
+            # else if using implicit neural codebook, or non-uniform codebook sizes, codes will need to be derived layer by layer
 
             code_transform_mlps = (None, *self.mlps)
 
@@ -261,7 +289,7 @@ def forward(
         self,
         x,
         mask = None,
-        indices: Tensor | List[Tensor] | None = None,
+        indices: Tensor | list[Tensor] | None = None,
         return_all_codes = False,
         sample_codebook_temp = None,
         freeze_codebook = False,
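Taken together, the diff enables the TIGER-style usage below: quantize item embeddings into coarse-to-fine semantic IDs, then recover the quantized representation from the IDs alone. A minimal sketch assuming only the public API exercised in `test_tiger` above; `item_embeddings` is a hypothetical stand-in for real item representations:

```python
import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    codebook_size = (5, 128, 256),  # coarse-to-fine hierarchy
)

residual_vq.eval()  # in eval mode no indices are dropped out (-1), so the round trip below is exact

item_embeddings = torch.randn(2, 16, 256)  # hypothetical batch of item representations

with torch.no_grad():
    quantized, indices, _ = residual_vq(item_embeddings)

semantic_ids = indices  # (2, 16, 3) - one hierarchical id triple per item

# reconstruct the quantized vectors from the semantic ids alone
reconstructed = residual_vq.get_output_from_indices(semantic_ids)
assert torch.allclose(quantized, reconstructed, atol = 1e-5)
```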