Commit 6c3f9a7

remove the hack, as it does not work
lucidrains committed Jun 30, 2024
1 parent c14fa4d commit 6c3f9a7
Showing 2 changed files with 4 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.14.43"
version = "1.14.44"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
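The only change to pyproject.toml is the patch-version bump that ships this fix. A hedged usage note, assuming the 1.14.44 release corresponding to this commit is published to PyPI under the package name shown above:

    pip install vector-quantize-pytorch==1.14.44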
18 changes: 3 additions & 15 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -280,7 +280,6 @@ def __init__(
threshold_ema_dead_code = 2,
reset_cluster_size = None,
use_ddp = False,
- distributed_replace_codes = True,
learnable_codebook = False,
gumbel_sample = gumbel_sample,
sample_codebook_temp = 1.,
@@ -315,8 +314,7 @@ def __init__(

self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors

- self.distributed_replace_codes = distributed_replace_codes
- self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans and distributed_replace_codes else batched_sample_vectors
+ self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors

self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
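For context, `batched_sample_vectors` is the non-distributed fallback that both `sample_fn` and, after this commit, `replace_sample_fn` resolve to when DDP or synced k-means is off. A rough, self-contained sketch of that kind of local sampler, under the assumption that it simply draws rows from the current batch (an illustration, not the library's exact implementation):

    import torch

    def batched_sample_vectors_sketch(samples, num):
        # samples: (heads, n, dim); draw `num` rows per head,
        # sampling with replacement when the batch has fewer than `num` rows
        heads, n, dim = samples.shape
        if n >= num:
            indices = torch.stack([torch.randperm(n)[:num] for _ in range(heads)])
        else:
            indices = torch.randint(0, n, (heads, num))
        return torch.gather(samples, 1, indices.unsqueeze(-1).expand(-1, -1, dim))

    # e.g. batched_sample_vectors_sketch(torch.randn(1, 128, 64), 16) -> shape (1, 16, 64)

The distributed variant, `sample_vectors_distributed`, is what every DDP-and-synced path now uses for replacement as well, so all ranks agree on the sampled vectors without any extra averaging step.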
@@ -448,9 +446,6 @@ def replace(self, batch_samples, batch_mask):
sampled = self.replace_sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
sampled = rearrange(sampled, '1 ... -> ...')

- if not self.distributed_replace_codes:
-     sampled = maybe_distributed_mean(sampled)

self.embed.data[ind][mask] = sampled
self.cluster_size.data[ind][mask] = self.reset_cluster_size
self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
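The deleted branch is the "hack" from the commit message: with `distributed_replace_codes = False`, each rank sampled replacement vectors locally (via `batched_sample_vectors`) and then averaged them across processes with `maybe_distributed_mean`. Averaging independently sampled vectors blurs them together, which may be why the commit message says it does not work. A rough sketch of what that helper presumably does, judging only by its name and call site (assumed behavior, not copied from the library):

    import torch
    import torch.distributed as dist

    def maybe_distributed_mean_sketch(t):
        # average a tensor across all ranks when torch.distributed is initialized,
        # otherwise return it unchanged
        if not (dist.is_available() and dist.is_initialized()):
            return t
        dist.all_reduce(t)                 # in-place sum across ranks
        return t / dist.get_world_size()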
@@ -559,7 +554,6 @@ def __init__(
threshold_ema_dead_code = 2,
reset_cluster_size = None,
use_ddp = False,
- distributed_replace_codes = True,
learnable_codebook = False,
gumbel_sample = gumbel_sample,
sample_codebook_temp = 1.,
@@ -590,8 +584,7 @@ def __init__(

self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors

- self.distributed_replace_codes = distributed_replace_codes
- self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans and distributed_replace_codes else batched_sample_vectors
+ self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors

self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
@@ -638,9 +631,6 @@ def replace(self, batch_samples, batch_mask):
sampled = self.replace_sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
sampled = rearrange(sampled, '1 ... -> ...')

- if not self.distributed_replace_codes:
-     sampled = maybe_distributed_mean(sampled)

self.embed.data[ind][mask] = sampled
self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
self.cluster_size.data[ind][mask] = self.reset_cluster_size
@@ -762,7 +752,6 @@ def __init__(
stochastic_sample_codes = False,
sample_codebook_temp = 1.,
straight_through = False,
- distributed_replace_codes = True,
reinmax = False, # using reinmax for improved straight-through, assuming straight through helps at all
sync_codebook = None,
sync_affine_param = False,
@@ -845,8 +834,7 @@ def __init__(
learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook,
sample_codebook_temp = sample_codebook_temp,
gumbel_sample = gumbel_sample_fn,
- ema_update = ema_update,
- distributed_replace_codes = distributed_replace_codes
+ ema_update = ema_update
)

if affine_param:
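After this commit, `distributed_replace_codes` is removed from every `__init__` signature above, so it can no longer be passed to `VectorQuantize`. A minimal usage sketch; `sync_codebook` appears in the diff's constructor signature, while the remaining arguments and the forward return signature are assumptions based on common usage of this library:

    import torch
    from vector_quantize_pytorch import VectorQuantize

    vq = VectorQuantize(
        dim = 256,                  # assumed common arguments
        codebook_size = 512,
        sync_codebook = True,       # shown in the __init__ signature above; enables the DDP paths
        # distributed_replace_codes = True   # removed by this commit; passing it would now raise a TypeError
    )

    x = torch.randn(1, 1024, 256)              # (batch, sequence, dim)
    quantized, indices, commit_loss = vq(x)    # assumed forward return signature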
