From 3aa68673e79b4ed20b93b302c64fe4c50f637685 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 3 Dec 2024 06:15:14 -0800
Subject: [PATCH] bring back the other half of the commit loss even in the
 presence of rotation trick, addressing
 https://github.com/lucidrains/vector-quantize-pytorch/issues/177

---
 examples/autoencoder_sim_vq.py    |  2 +-
 vector_quantize_pytorch/sim_vq.py | 11 ++++-------
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/examples/autoencoder_sim_vq.py b/examples/autoencoder_sim_vq.py
index 7269eaa..535b96d 100644
--- a/examples/autoencoder_sim_vq.py
+++ b/examples/autoencoder_sim_vq.py
@@ -16,7 +16,7 @@
 seed = 1234
 
 rotation_trick = True  # rotation trick instead ot straight-through
-use_mlp = True  # use a one layer mlp with relu instead of linear
+use_mlp = True         # use a one layer mlp with relu instead of linear
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
diff --git a/vector_quantize_pytorch/sim_vq.py b/vector_quantize_pytorch/sim_vq.py
index 427760b..ec02a65 100644
--- a/vector_quantize_pytorch/sim_vq.py
+++ b/vector_quantize_pytorch/sim_vq.py
@@ -118,18 +118,15 @@ def forward(
 
         # commit loss and straight through, as was done in the paper
 
-        commit_loss = F.mse_loss(x.detach(), quantized)
+        commit_loss = (
+            F.mse_loss(x.detach(), quantized) +
+            F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
+        )
 
         if self.rotation_trick:
             # rotation trick from @cfifty
             quantized = rotate_to(x, quantized)
         else:
-
-            commit_loss = (
-                commit_loss +
-                F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
-            )
-
             quantized = (quantized - x).detach() + x
 
         quantized = inverse_pack(quantized)
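
Note (not part of the patch): a minimal standalone sketch of the behavior this commit restores. The commitment loss now keeps both halves regardless of which estimator carries gradients -- the first term pulls the selected codes toward the frozen encoder output, the second (weighted) term pulls the encoder output toward the frozen codes, and this second half previously vanished when the rotation trick branch was taken. The function names below and the 0.25 default weight are illustrative assumptions, not the library's API; `rotate_to` is omitted since only the loss is being demonstrated.

    # hypothetical standalone sketch of the two-sided commit loss;
    # the default weight of 0.25 is assumed for illustration
    import torch
    import torch.nn.functional as F

    def commit_loss(x, quantized, input_to_quantize_commit_loss_weight = 0.25):
        # codebook half: pull codes toward the (detached) encoder output
        # encoder half:  pull the encoder output toward the (detached) codes
        return (
            F.mse_loss(x.detach(), quantized) +
            F.mse_loss(x, quantized.detach()) * input_to_quantize_commit_loss_weight
        )

    def straight_through(x, quantized):
        # forward pass yields the quantized values,
        # backward pass copies gradients straight to x
        return (quantized - x).detach() + x

    x = torch.randn(2, 16, requires_grad = True)
    quantized = torch.randn(2, 16, requires_grad = True)

    commit_loss(x, quantized).backward()

    # both halves receive gradient, with or without the rotation trick
    assert x.grad is not None and quantized.grad is not None

Because both mse terms are computed before the `if self.rotation_trick` branch, the encoder-side gradient no longer depends on whether the rotation trick or the straight-through estimator is used, which is the asymmetry raised in issue #177.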