Skip to content

Commit

Permalink
test train and eval pathways for residual vq
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 15, 2024
1 parent 3ad24a3 commit 1447998
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions tests/test_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ def test_vq_mask():

@pytest.mark.parametrize('implicit_neural_codebook', (True, False))
@pytest.mark.parametrize('use_cosine_sim', (True, False))
@pytest.mark.parametrize('train', (True, False))
def test_residual_vq(
implicit_neural_codebook,
use_cosine_sim
use_cosine_sim,
train
):
from vector_quantize_pytorch import ResidualVQ

Expand All @@ -80,14 +82,9 @@ def test_residual_vq(

x = torch.randn(1, 256, 32)

quantized, indices, commit_loss = residual_vq(x)
quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)

# test eval mode and `get_output_from_indices`

residual_vq.eval()
quantized, indices, commit_loss = residual_vq(x)
residual_vq.train(train)

quantized, indices, commit_loss = residual_vq(x, freeze_codebook = train and not implicit_neural_codebook)
quantized_out = residual_vq.get_output_from_indices(indices)
assert torch.allclose(quantized, quantized_out, atol = 1e-6)

Expand Down

0 comments on commit 1447998

Please sign in to comment.