Skip to content

Commit

Permalink
Merge pull request #156 from lweitkamp/master
Browse files Browse the repository at this point in the history
Remove double l2 normalization
  • Loading branch information
lucidrains authored Aug 27, 2024
2 parents 1447998 + ab5a61a commit 15629a5
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion vector_quantize_pytorch/vector_quantize_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def forward(
embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
embed_normalized = l2norm(embed_normalized)

self.embed.data.copy_(l2norm(embed_normalized))
self.embed.data.copy_(embed_normalized)
self.expire_codes_(x)

if needs_codebook_dim:
Expand Down

0 comments on commit 15629a5

Please sign in to comment.