Merge pull request #125 from MisterBourbaki/tests
Add a simple test suite
lucidrains authored May 6, 2024
2 parents 4a643eb + 709f978 commit 190ac99
Showing 4 changed files with 123 additions and 45 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/test.yml
@@ -0,0 +1,19 @@
name: Tests the examples in README
on: push

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Install Python
        uses: actions/setup-python@v4
      - name: Install the latest version of rye
        uses: eifinger/setup-rye@v2
      - name: Use UV instead of pip
        run: rye config --set-bool behavior.use-uv=true
      - name: Install dependencies
        run: |
          rye sync
      - name: Run pytest
        run: rye run pytest --cov=. tests/test_examples_readme.py
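
For reference, the same checks can be run locally (a minimal sketch mirroring the CI steps above, assuming rye is already installed):

```sh
rye sync   # install the package plus the dev dependencies (pytest, pytest-examples, pytest-cov)
rye run pytest --cov=. tests/test_examples_readme.py   # same invocation the workflow uses
```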
119 changes: 74 additions & 45 deletions README.md
@@ -27,6 +27,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
```

## Residual VQ
@@ -46,16 +48,14 @@ residual_vq = ResidualVQ(
x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)

# (batch, seq, dim), (batch, seq, quantizer), (batch, quantizer)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8])

# if you need all the codes across the quantization layers, just pass return_all_codes = True

quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)

# all_codes - (quantizer, batch, seq, dim)
print(all_codes.shape)
#> torch.Size([8, 1, 1024, 256])
```

Furthermore, <a href="https://arxiv.org/abs/2203.01941">this paper</a> uses Residual-VQ to construct the RQ-VAE, for generating high-resolution images with more compressed codes.
@@ -77,9 +77,8 @@ residual_vq = ResidualVQ(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)

# (batch, seq, dim), (batch, seq, quantizer), (batch, quantizer)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8])
```

<a href="https://arxiv.org/abs/2305.02765">A recent paper</a> further proposes to do residual VQ on groups of the feature dimension, showing equivalent results to Encodec while using far fewer codebooks. You can use it by importing `GroupedResidualVQ`
@@ -98,9 +97,8 @@ residual_vq = GroupedResidualVQ(
x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)

# (batch, seq, dim), (groups, batch, seq, quantizer), (groups, batch, quantizer)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([2, 1, 1024, 8]) torch.Size([2, 1, 8])

```

@@ -122,6 +120,8 @@ residual_vq = ResidualVQ(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 4]) torch.Size([1, 4])
```

## Increasing codebook usage
@@ -144,6 +144,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
```

### Cosine similarity
@@ -162,6 +164,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
```

### Expiring stale codes
@@ -180,6 +184,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
```

### Orthogonal regularization loss
@@ -204,6 +210,8 @@ vq = VectorQuantize(
img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap)
# loss now contains the orthogonal regularization loss with the weight as assigned
print(quantized.shape, indices.shape, loss.shape)
#> torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32]) torch.Size([1])
```

### Multi-headed VQ
@@ -226,10 +234,12 @@ vq = VectorQuantize(
)

img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap)
print(quantized.shape, indices.shape, loss.shape)
#> torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32, 8]) torch.Size([1])

# indices shape - (batch, height, width, heads)
```

### Random Projection Quantizer

<a href="https://arxiv.org/abs/2202.01855">This paper</a> first proposed to use a random projection quantizer for masked speech modeling, where signals are projected with a randomly initialized matrix and then matched with a random initialized codebook. One therefore does not need to learn the quantizer. This technique was used by Google's <a href="https://ai.googleblog.com/2023/03/universal-speech-model-usm-state-of-art.html">Universal Speech Model</a> to achieve SOTA for speech-to-text modeling.
@@ -248,7 +258,9 @@ quantizer = RandomProjectionQuantizer(
)

x = torch.randn(1, 1024, 512)
indices = quantizer(x) # (batch, seq, num_codebooks)
print(indices.shape)
#> torch.Size([1, 1024, 16])
```

This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False`, as in the sketch below.
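
For example (a minimal sketch, assuming the flag is passed straight through to the `VectorQuantize` constructor):

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    sync_codebook = False   # force codebook sync off; set True to force it on across processes
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```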
@@ -279,10 +291,11 @@ quantizer = FSQ(levels)
x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
xhat, indices = quantizer(x)

# xhat - (batch, seq, dim), indices - (batch, seq)
print(xhat.shape)
#> torch.Size([1, 1024, 4])
print(indices.shape)
#> torch.Size([1, 1024])

assert xhat.shape == x.shape
assert torch.all(xhat == quantizer.indices_to_codes(indices))
```

@@ -305,14 +318,12 @@ x = torch.randn(1, 1024, 256)
residual_fsq.eval()

quantized, indices = residual_fsq(x)

# (batch, seq, dim), (batch, seq, quantizers)
print(quantized.shape, indices.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8])

quantized_out = residual_fsq.get_output_from_indices(indices)

# (batch, seq, dim)
print(quantized_out.shape)
#> torch.Size([1, 1024, 256])

assert torch.all(quantized == quantized_out)
```
@@ -346,26 +357,34 @@ quantizer = LFQ(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature=100.) # you may want to experiment with temperature
print(quantized.shape, indices.shape, entropy_aux_loss.shape)
#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
```

You can also pass in video features as `(batch, feat, time, height, width)` or sequences as `(batch, seq, feat)`

```python
import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(
    codebook_size = 65536,
    dim = 16,
    entropy_loss_weight = 0.1,
    diversity_gamma = 1.
)

seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)

# assert seq.shape == quantized.shape

# video_feats = torch.randn(1, 16, 10, 32, 32)
# quantized, *_ = quantizer(video_feats)

# assert video_feats.shape == quantized.shape

```

@@ -384,8 +403,8 @@ quantizer = LFQ(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats)

print(quantized.shape, indices.shape, entropy_aux_loss.shape)
#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32, 4]) torch.Size([])

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -408,14 +427,12 @@ x = torch.randn(1, 1024, 256)
residual_lfq.eval()

quantized, indices, commit_loss = residual_lfq(x)

# (batch, seq, dim), (batch, seq, quantizers), (quantizers)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([8])

quantized_out = residual_lfq.get_output_from_indices(indices)

# (batch, seq, dim)
print(quantized_out.shape)
#> torch.Size([1, 1024, 256])

assert torch.all(quantized == quantized_out)
```
Expand Down Expand Up @@ -443,8 +460,8 @@ quantizer = LatentQuantize(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, loss = quantizer(image_feats)

print(quantized.shape, indices.shape, loss.shape)
#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -454,15 +471,25 @@ You can also pass in video features as `(batch, feat, time, height, width)` or sequences as `(batch, seq, feat)`

```python

import torch
from vector_quantize_pytorch import LatentQuantize

quantizer = LatentQuantize(
    levels = [5, 5, 8],
    dim = 16,
    commitment_loss_weight = 0.1,
    quantization_loss_weight = 0.1,
)

seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)

print(quantized.shape)
#> torch.Size([1, 32, 16])

video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)

print(quantized.shape)
#> torch.Size([1, 16, 10, 32, 32])

```

@@ -480,6 +507,8 @@ model = LatentQuantize(levels, dim, num_codebooks=num_codebooks)

input_tensor = torch.randn(2, 3, dim)
output_tensor, indices, loss = model(input_tensor)
print(output_tensor.shape, indices.shape, loss.shape)
#> torch.Size([2, 3, 9]) torch.Size([2, 3, 3]) torch.Size([])

assert output_tensor.shape == input_tensor.shape
assert indices.shape == (2, 3, num_codebooks)
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -43,6 +43,9 @@ build-backend = "hatchling.build"
managed = true
dev-dependencies = [
"ruff>=0.4.2",
"pytest>=8.2.0",
"pytest-examples>=0.0.10",
"pytest-cov>=5.0.0",
]

[tool.hatch.metadata]
27 changes: 27 additions & 0 deletions tests/test_examples_readme.py
@@ -0,0 +1,27 @@
import pytest
from pytest_examples import find_examples, CodeExample, EvalExample


@pytest.mark.parametrize('example', find_examples('README.md'), ids=str)
def test_docstrings(example: CodeExample, eval_example: EvalExample):
    """Test all examples (automatically) found in README.
    Usage, in an activated virtual env:
    ```py
    (.venv) pytest tests/test_examples_readme.py
    ```
    for a simple check on running the examples, and
    ```py
    (.venv) pytest tests/test_examples_readme.py --update-examples
    ```
    to lint and format the code in the README.
    """
    if eval_example.update_examples:
        eval_example.format(example)
        eval_example.lint(example)
        eval_example.run_print_check(example)
    else:
        eval_example.run_print_check(example)
