diff --git a/pyproject.toml b/pyproject.toml
index 8150df5..7880b1f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.43"
+version = "1.14.44"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/vector_quantize_pytorch/vector_quantize_pytorch.py b/vector_quantize_pytorch/vector_quantize_pytorch.py
index 2c41ec9..89e5793 100644
--- a/vector_quantize_pytorch/vector_quantize_pytorch.py
+++ b/vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -280,7 +280,6 @@ def __init__(
         threshold_ema_dead_code = 2,
         reset_cluster_size = None,
         use_ddp = False,
-        distributed_replace_codes = True,
         learnable_codebook = False,
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
@@ -315,8 +314,7 @@ def __init__(
 
         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
 
-        self.distributed_replace_codes = distributed_replace_codes
-        self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans and distributed_replace_codes else batched_sample_vectors
+        self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
 
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
         self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
@@ -448,9 +446,6 @@ def replace(self, batch_samples, batch_mask):
             sampled = self.replace_sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
             sampled = rearrange(sampled, '1 ... -> ...')
 
-            if not self.distributed_replace_codes:
-                sampled = maybe_distributed_mean(sampled)
-
             self.embed.data[ind][mask] = sampled
             self.cluster_size.data[ind][mask] = self.reset_cluster_size
             self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
@@ -559,7 +554,6 @@ def __init__(
         threshold_ema_dead_code = 2,
         reset_cluster_size = None,
         use_ddp = False,
-        distributed_replace_codes = True,
         learnable_codebook = False,
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
@@ -590,8 +584,7 @@ def __init__(
 
         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
 
-        self.distributed_replace_codes = distributed_replace_codes
-        self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans and distributed_replace_codes else batched_sample_vectors
+        self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
 
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
         self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
@@ -638,9 +631,6 @@ def replace(self, batch_samples, batch_mask):
             sampled = self.replace_sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
             sampled = rearrange(sampled, '1 ... -> ...')
 
-            if not self.distributed_replace_codes:
-                sampled = maybe_distributed_mean(sampled)
-
             self.embed.data[ind][mask] = sampled
             self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
             self.cluster_size.data[ind][mask] = self.reset_cluster_size
@@ -762,7 +752,6 @@ def __init__(
         stochastic_sample_codes = False,
         sample_codebook_temp = 1.,
         straight_through = False,
-        distributed_replace_codes = True,
         reinmax = False, # using reinmax for improved straight-through, assuming straight through helps at all
         sync_codebook = None,
         sync_affine_param = False,
@@ -845,8 +834,7 @@ def __init__(
             learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook,
             sample_codebook_temp = sample_codebook_temp,
             gumbel_sample = gumbel_sample_fn,
-            ema_update = ema_update,
-            distributed_replace_codes = distributed_replace_codes
+            ema_update = ema_update
         )
 
         if affine_param:
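For context, here is a minimal usage sketch of the public API after this change; it is not part of the diff, and the hyperparameter values are illustrative. The `distributed_replace_codes` keyword is removed entirely, so under DDP dead-code replacement always samples across ranks via `sample_vectors_distributed` instead of optionally averaging the sampled vectors with `maybe_distributed_mean`.

```python
# Minimal sketch, assuming the post-change API of vector-quantize-pytorch.
# Passing distributed_replace_codes to the constructor would no longer be accepted.
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    threshold_ema_dead_code = 2,  # codes whose EMA cluster size falls below this are resampled
    # sync_codebook = True        # enable when training under DistributedDataParallel
    #                             # (requires an initialized process group); dead codes are
    #                             # then resampled via sample_vectors_distributed
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)  # (1, 1024, 256), (1, 1024), scalar
```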