
Commit 3aa6867

bring back the other half of the commit loss even in the presence of rotation trick, addressing #177
lucidrains committed Dec 3, 2024
1 parent 7be8916 commit 3aa6867
Showing 2 changed files with 5 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/autoencoder_sim_vq.py
@@ -16,7 +16,7 @@
 seed = 1234
 
 rotation_trick = True   # rotation trick instead of straight-through
-use_mlp = True          # use a one layer mlp with relu instead of linear
+use_mlp = True          # use a one layer mlp with relu instead of linear
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
11 changes: 4 additions & 7 deletions vector_quantize_pytorch/sim_vq.py
@@ -118,18 +118,15 @@ def forward(

         # commit loss and straight through, as was done in the paper
 
-        commit_loss = F.mse_loss(x.detach(), quantized)
+        commit_loss = (
+            F.mse_loss(x.detach(), quantized) +
+            F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
+        )
 
         if self.rotation_trick:
             # rotation trick from @cfifty
             quantized = rotate_to(x, quantized)
         else:
-
-            commit_loss = (
-                commit_loss +
-                F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
-            )
-
             quantized = (quantized - x).detach() + x
 
         quantized = inverse_pack(quantized)
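
For context on what the diff changes: before this commit, the second half of the commitment loss (the term that pulls the encoder output x toward the detached codes) lived only in the straight-through branch, so enabling the rotation trick silently dropped it; issue #177 flagged this. The commit hoists both terms above the branch so they apply in either mode. Below is a minimal, self-contained sketch of the resulting logic, under stated assumptions: quantize_step, rotate_to_sketch, and the 0.25 commit_weight default are illustrative stand-ins for the library's internals, not its actual API.

import torch
import torch.nn.functional as F

def rotate_to_sketch(src, tgt, eps = 1e-6):
    # rotate `src` onto `tgt`'s direction and norm via two Householder
    # reflections, detaching the rotation itself so gradients pass through
    # `src` as if the transform were a constant (the rotation-trick idea)
    src_dir = src / src.norm(dim = -1, keepdim = True).clamp(min = eps)
    tgt_dir = tgt / tgt.norm(dim = -1, keepdim = True).clamp(min = eps)

    r = src_dir + tgt_dir
    r = (r / r.norm(dim = -1, keepdim = True).clamp(min = eps)).detach()
    tgt_dir = tgt_dir.detach()

    # reflect across the hyperplane orthogonal to r (maps src_dir to -tgt_dir),
    # then across the one orthogonal to tgt_dir (maps -tgt_dir to tgt_dir)
    out = src - 2 * (src * r).sum(dim = -1, keepdim = True) * r
    out = out - 2 * (out * tgt_dir).sum(dim = -1, keepdim = True) * tgt_dir

    # rescale to the target norm, keeping the scale out of the gradient path
    scale = tgt.norm(dim = -1, keepdim = True) / src.norm(dim = -1, keepdim = True).clamp(min = eps)
    return out * scale.detach()

def quantize_step(x, quantized, rotation_trick = True, commit_weight = 0.25):
    # both halves of the commitment loss are now computed unconditionally:
    # the first term trains whatever produces `quantized`, the weighted
    # second term commits the encoder output `x` to its chosen code
    commit_loss = (
        F.mse_loss(x.detach(), quantized) +
        F.mse_loss(x, quantized.detach()) * commit_weight
    )

    if rotation_trick:
        quantized = rotate_to_sketch(x, quantized)
    else:
        # straight-through estimator: forward value is `quantized`,
        # backward gradient is copied onto `x`
        quantized = (quantized - x).detach() + x

    return quantized, commit_loss

On either path the forward value is (up to numerics) the chosen code while gradients still reach x, so these two commitment terms are what steer the encoder and the code projection toward each other; the commit's point is that the encoder-side term should not depend on which gradient estimator is in use.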
