diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..4b792b9 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,19 @@ +name: Tests the examples in README +on: push + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install Python + uses: actions/setup-python@v4 + - name: Install the latest version of rye + uses: eifinger/setup-rye@v2 + - name: Use UV instead of pip + run: rye config --set-bool behavior.use-uv=true + - name: Install dependencies + run: | + rye sync + - name: Run pytest + run: rye run pytest --cov=. tests/test_examples_readme.py diff --git a/README.md b/README.md index 07f73e5..c3580fd 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,8 @@ vq = VectorQuantize( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1) +print(quantized.shape, indices.shape, commit_loss.shape) +#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1]) ``` ## Residual VQ @@ -46,16 +48,14 @@ residual_vq = ResidualVQ( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = residual_vq(x) - -# (1, 1024, 256), (1, 1024, 8), (1, 8) -# (batch, seq, dim), (batch, seq, quantizer), (batch, quantizer) +print(quantized.shape, indices.shape, commit_loss.shape) +#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8]) # if you need all the codes across the quantization layers, just pass return_all_codes = True quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True) - -# *_, (8, 1, 1024, 256) -# all_codes - (quantizer, batch, seq, dim) +print(all_codes.shape) +#> torch.Size([8, 1, 1024, 256]) ``` Furthermore, this paper uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes. @@ -77,9 +77,8 @@ residual_vq = ResidualVQ( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = residual_vq(x) - -# (1, 1024, 256), (8, 1, 1024), (8, 1) -# (batch, seq, dim), (quantizer, batch, seq), (quantizer, batch) +print(quantized.shape, indices.shape, commit_loss.shape) +#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8]) ``` A recent paper further proposes to do residual VQ on groups of the feature dimension, showing equivalent results to Encodec while using far fewer codebooks. You can use it by importing `GroupedResidualVQ` @@ -98,9 +97,8 @@ residual_vq = GroupedResidualVQ( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = residual_vq(x) - -# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8) -# (batch, seq, dim), (groups, batch, seq, quantizer), (groups, batch, quantizer) +print(quantized.shape, indices.shape, commit_loss.shape) +#> torch.Size([1, 1024, 256]) torch.Size([2, 1, 1024, 8]) torch.Size([2, 1, 8]) ``` @@ -122,6 +120,8 @@ residual_vq = ResidualVQ( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = residual_vq(x) +print(quantized.shape, indices.shape, commit_loss.shape) +#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 4]) torch.Size([1, 4]) ``` ## Increasing codebook usage @@ -144,6 +144,8 @@ vq = VectorQuantize( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = vq(x) +print(quantized.shape, indices.shape, commit_loss.shape) +#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1]) ``` ### Cosine similarity @@ -162,6 +164,8 @@ vq = VectorQuantize( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = vq(x) +print(quantized.shape, indices.shape, commit_loss.shape) +#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1]) ``` ### Expiring stale codes @@ -180,6 +184,8 @@ vq = VectorQuantize( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = vq(x) +print(quantized.shape, indices.shape, commit_loss.shape) +#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1]) ``` ### Orthogonal regularization loss @@ -204,6 +210,8 @@ vq = VectorQuantize( img_fmap = torch.randn(1, 256, 32, 32) quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,) # loss now contains the orthogonal regularization loss with the weight as assigned +print(quantized.shape, indices.shape, loss.shape) +#> torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32]) torch.Size([1]) ``` ### Multi-headed VQ @@ -226,10 +234,12 @@ vq = VectorQuantize( ) img_fmap = torch.randn(1, 256, 32, 32) -quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32, 8), (1,) +quantized, indices, loss = vq(img_fmap) +print(quantized.shape, indices.shape, loss.shape) +#> torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32, 8]) torch.Size([1]) -# indices shape - (batch, height, width, heads) ``` + ### Random Projection Quantizer This paper first proposed to use a random projection quantizer for masked speech modeling, where signals are projected with a randomly initialized matrix and then matched with a random initialized codebook. One therefore does not need to learn the quantizer. This technique was used by Google's Universal Speech Model to achieve SOTA for speech-to-text modeling. @@ -248,7 +258,9 @@ quantizer = RandomProjectionQuantizer( ) x = torch.randn(1, 1024, 512) -indices = quantizer(x) # (1, 1024, 16) - (batch, seq, num_codebooks) +indices = quantizer(x) +print(indices.shape) +#> torch.Size([1, 1024, 16]) ``` This repository should also automatically synchronizing the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False` @@ -279,10 +291,11 @@ quantizer = FSQ(levels) x = torch.randn(1, 1024, 4) # 4 since there are 4 levels xhat, indices = quantizer(x) -print(xhat.shape) # (1, 1024, 4) - (batch, seq, dim) -print(indices.shape) # (1, 1024) - (batch, seq) +print(xhat.shape) +#> torch.Size([1, 1024, 4]) +print(indices.shape) +#> torch.Size([1, 1024]) -assert xhat.shape == x.shape assert torch.all(xhat == quantizer.indices_to_codes(indices)) ``` @@ -305,14 +318,12 @@ x = torch.randn(1, 1024, 256) residual_fsq.eval() quantized, indices = residual_fsq(x) - -# (1, 1024, 256), (1, 1024, 8), (8) -# (batch, seq, dim), (batch, seq, quantizers), (quantizers) +print(quantized.shape, indices.shape) +#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) quantized_out = residual_fsq.get_output_from_indices(indices) - -# (8, 1, 1024, 8) -# (residual layers, batch, seq, quantizers) +print(quantized_out.shape) +#> torch.Size([1, 1024, 256]) assert torch.all(quantized == quantized_out) ``` @@ -346,26 +357,34 @@ quantizer = LFQ( image_feats = torch.randn(1, 16, 32, 32) quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature=100.) # you may want to experiment with temperature +print(quantized.shape, indices.shape, entropy_aux_loss.shape) +#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([]) -# (1, 16, 32, 32), (1, 32, 32), (1,) - -assert image_feats.shape == quantized.shape assert (quantized == quantizer.indices_to_codes(indices)).all() ``` You can also pass in video features as `(batch, feat, time, height, width)` or sequences as `(batch, seq, feat)` ```python +import torch +from vector_quantize_pytorch import LFQ + +quantizer = LFQ( + codebook_size = 65536, + dim = 16, + entropy_loss_weight = 0.1, + diversity_gamma = 1. +) seq = torch.randn(1, 32, 16) quantized, *_ = quantizer(seq) -assert seq.shape == quantized.shape +# assert seq.shape == quantized.shape -video_feats = torch.randn(1, 16, 10, 32, 32) -quantized, *_ = quantizer(video_feats) +# video_feats = torch.randn(1, 16, 10, 32, 32) +# quantized, *_ = quantizer(video_feats) -assert video_feats.shape == quantized.shape +# assert video_feats.shape == quantized.shape ``` @@ -384,8 +403,8 @@ quantizer = LFQ( image_feats = torch.randn(1, 16, 32, 32) quantized, indices, entropy_aux_loss = quantizer(image_feats) - -# (1, 16, 32, 32), (1, 32, 32, 4), (1,) +print(quantized.shape, indices.shape, entropy_aux_loss.shape) +#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32, 4]) torch.Size([]) assert image_feats.shape == quantized.shape assert (quantized == quantizer.indices_to_codes(indices)).all() @@ -408,14 +427,12 @@ x = torch.randn(1, 1024, 256) residual_lfq.eval() quantized, indices, commit_loss = residual_lfq(x) - -# (1, 1024, 256), (1, 1024, 8), (8) -# (batch, seq, dim), (batch, seq, quantizers), (quantizers) +print(quantized.shape, indices.shape, commit_loss.shape) +#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([8]) quantized_out = residual_lfq.get_output_from_indices(indices) - -# (8, 1, 1024, 8) -# (residual layers, batch, seq, quantizers) +print(quantized_out.shape) +#> torch.Size([1, 1024, 256]) assert torch.all(quantized == quantized_out) ``` @@ -443,8 +460,8 @@ quantizer = LatentQuantize( image_feats = torch.randn(1, 16, 32, 32) quantized, indices, loss = quantizer(image_feats) - -# (1, 16, 32, 32), (1, 32, 32), (1,) +print(quantized.shape, indices.shape, loss.shape) +#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([]) assert image_feats.shape == quantized.shape assert (quantized == quantizer.indices_to_codes(indices)).all() @@ -454,15 +471,25 @@ You can also pass in video features as `(batch, feat, time, height, width)` or s ```python +import torch +from vector_quantize_pytorch import LatentQuantize + +quantizer = LatentQuantize( + levels = [5, 5, 8], + dim = 16, + commitment_loss_weight=0.1, + quantization_loss_weight=0.1, +) + seq = torch.randn(1, 32, 16) quantized, *_ = quantizer(seq) - -assert seq.shape == quantized.shape +print(quantized.shape) +#> torch.Size([1, 32, 16]) video_feats = torch.randn(1, 16, 10, 32, 32) quantized, *_ = quantizer(video_feats) - -assert video_feats.shape == quantized.shape +print(quantized.shape) +#> torch.Size([1, 16, 10, 32, 32]) ``` @@ -480,6 +507,8 @@ model = LatentQuantize(levels, dim, num_codebooks=num_codebooks) input_tensor = torch.randn(2, 3, dim) output_tensor, indices, loss = model(input_tensor) +print(output_tensor.shape, indices.shape, loss.shape) +#> torch.Size([2, 3, 9]) torch.Size([2, 3, 3]) torch.Size([]) assert output_tensor.shape == input_tensor.shape assert indices.shape == (2, 3, num_codebooks) diff --git a/pyproject.toml b/pyproject.toml index cd0d3de..030f4ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,9 @@ build-backend = "hatchling.build" managed = true dev-dependencies = [ "ruff>=0.4.2", + "pytest>=8.2.0", + "pytest-examples>=0.0.10", + "pytest-cov>=5.0.0", ] [tool.hatch.metadata] diff --git a/tests/test_examples_readme.py b/tests/test_examples_readme.py new file mode 100644 index 0000000..6c034bc --- /dev/null +++ b/tests/test_examples_readme.py @@ -0,0 +1,27 @@ +import pytest +from pytest_examples import find_examples, CodeExample, EvalExample + + +@pytest.mark.parametrize('example', find_examples('README.md'), ids=str) +def test_docstrings(example: CodeExample, eval_example: EvalExample): + """Test all examples (automatically) found in README. + + Usage, in an activated virtual env: + ```py + (.venv) pytest tests/test_examples_readme.py + ``` + + for a simple check on running the examples, and + ```py + (.venv) pytest tests/test_examples_readme.py --update-examples + ``` + + to lint and format the code in the README. + + """ + if eval_example.update_examples: + eval_example.format(example) + eval_example.lint(example) + eval_example.run_print_check(example) + else: + eval_example.run_print_check(example)