From fa2211d63b7c682b64dd2718799b674a13af4b0d Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 7 Jan 2025 08:20:13 -0800
Subject: [PATCH] do not do straight through nor rotation trick if input does
 not require grad, to make Genie2 cleaner

---
 pyproject.toml                  |  2 +-
 tests/test_readme.py            |  8 +++++++-
 .../vector_quantize_pytorch.py  | 15 +++++++++------
 3 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index e3e105a..62a2f9e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.21.0"
+version = "1.21.1"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_readme.py b/tests/test_readme.py
index 6ada148..17e8907 100644
--- a/tests/test_readme.py
+++ b/tests/test_readme.py
@@ -6,9 +6,11 @@ def exists(v):
 
 @pytest.mark.parametrize('use_cosine_sim', (True, False))
 @pytest.mark.parametrize('rotation_trick', (True, False))
+@pytest.mark.parametrize('input_requires_grad', (True, False))
 def test_vq(
     use_cosine_sim,
-    rotation_trick
+    rotation_trick,
+    input_requires_grad
 ):
 
     from vector_quantize_pytorch import VectorQuantize
@@ -22,6 +24,10 @@ def test_vq(
     )
 
     x = torch.randn(1, 1024, 256)
+
+    if input_requires_grad:
+        x.requires_grad_()
+
     quantized, indices, commit_loss = vq(x)
 
 def test_vq_eval():
diff --git a/vector_quantize_pytorch/vector_quantize_pytorch.py b/vector_quantize_pytorch/vector_quantize_pytorch.py
index 394c215..01d901c 100644
--- a/vector_quantize_pytorch/vector_quantize_pytorch.py
+++ b/vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -1023,7 +1023,7 @@ def forward(
         return_loss_breakdown = False,
         codebook_transform_fn: Callable | None = None
     ):
-        orig_input = x
+        orig_input, input_requires_grad = x, x.requires_grad
 
         # handle masking, either passed in as `mask` or `lens`
 
@@ -1117,11 +1117,14 @@ def forward(
 
         commit_quantize = maybe_detach(quantize)
 
-        if self.rotation_trick:
-            quantize = rotate_to(x, quantize)
-        else:
-            # standard STE to get gradients through VQ layer.
-            quantize = x + (quantize - x).detach()
+        # spare rotation trick calculation if inputs do not need gradients
+
+        if input_requires_grad:
+            if self.rotation_trick:
+                quantize = rotate_to(x, quantize)
+            else:
+                # standard STE to get gradients through VQ layer.
+                quantize = x + (quantize - x).detach()
 
         if self.sync_update_v > 0.:
             # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf