fix distributed_replace_codes for #142
lucidrains committed Jun 28, 2024
1 parent bb9e878 commit d23b27c
Showing 2 changed files with 28 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.41"
+version = "1.14.42"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
33 changes: 27 additions & 6 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -1,4 +1,4 @@
-from functools import partial
+from functools import partial, cache
 from collections import namedtuple
 
 import torch
@@ -46,8 +46,6 @@ def cdist(x, y):
 def log(t, eps = 1e-20):
     return torch.log(t.clamp(min = eps))
 
-# entropy
-
 def entropy(prob, eps = 1e-5):
     return (-prob * log(prob, eps = eps)).sum(dim = -1)
 
@@ -243,6 +241,20 @@ def batched_embedding(indices, embeds):
     embeds = repeat(embeds, 'h c d -> h b c d', b = batch)
     return embeds.gather(2, indices)
 
+# distributed helpers
+
+@cache
+def is_distributed():
+    return distributed.is_initialized() and distributed.get_world_size() > 1
+
+def maybe_distributed_mean(t):
+    if not is_distributed():
+        return t
+
+    distributed.all_reduce(t)
+    t = t / distributed.get_world_size()
+    return t
+
 # regularization losses
 
 def orthogonal_loss_fn(t):
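
The two helpers added in this hunk are the core of the fix. A minimal standalone sketch of what they compute, not repository code (the @cache decorator is dropped so the example does not depend on process-group state):

import torch
import torch.distributed as distributed

def is_distributed():
    # true only when torch.distributed is initialized with more than one rank
    return distributed.is_initialized() and distributed.get_world_size() > 1

def maybe_distributed_mean(t):
    # outside a distributed run, return the tensor unchanged
    if not is_distributed():
        return t
    # all_reduce sums t in place across ranks; dividing by the world size gives the mean
    distributed.all_reduce(t)
    return t / distributed.get_world_size()

t = torch.tensor([1., 2., 3.])
print(maybe_distributed_mean(t))  # unchanged here, since no process group was initialized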
@@ -302,6 +314,8 @@ def __init__(
         assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now'
 
         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
+
+        self.distributed_replace_codes = distributed_replace_codes
         self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans and distributed_replace_codes else batched_sample_vectors
 
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
@@ -433,9 +447,11 @@ def replace(self, batch_samples, batch_mask):
         for ind, (samples, mask) in enumerate(zip(batch_samples, batch_mask)):
             sampled = self.replace_sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
             sampled = rearrange(sampled, '1 ... -> ...')
 
-            self.embed.data[ind][mask] = sampled
-
+            if not self.distributed_replace_codes:
+                sampled = maybe_distributed_mean(sampled)
+
+            self.embed.data[ind][mask] = sampled
             self.cluster_size.data[ind][mask] = self.reset_cluster_size
             self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
 
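With distributed_replace_codes = False, replacement vectors for dead codes are sampled locally on each rank and then averaged across ranks before being written into the codebook, so every rank ends up writing identical codes. A toy single-process illustration of why the averaging matters (hypothetical tensors, not library code):

import torch

torch.manual_seed(0)

# pretend these are the replacement vectors two ranks sampled independently
sampled_rank0 = torch.randn(4)
sampled_rank1 = torch.randn(4)

# without synchronization the ranks would write different codes and their codebooks would drift apart
assert not torch.allclose(sampled_rank0, sampled_rank1)

# maybe_distributed_mean all-reduces and divides by the world size, which for two ranks
# is just this mean; both ranks then write the same replacement code
synced = (sampled_rank0 + sampled_rank1) / 2
print(synced)

When distributed_replace_codes is left at its default of True, replace_sample_fn goes through sample_vectors_distributed instead, so the extra averaging step is skipped.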
@@ -573,6 +589,8 @@ def __init__(
         self.sample_codebook_temp = sample_codebook_temp
 
         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
+
+        self.distributed_replace_codes = distributed_replace_codes
         self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans and distributed_replace_codes else batched_sample_vectors
 
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
@@ -620,6 +638,9 @@ def replace(self, batch_samples, batch_mask):
             sampled = self.replace_sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
             sampled = rearrange(sampled, '1 ... -> ...')
 
+            if not self.distributed_replace_codes:
+                sampled = maybe_distributed_mean(sampled)
+
             self.embed.data[ind][mask] = sampled
             self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
             self.cluster_size.data[ind][mask] = self.reset_cluster_size
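
For context, a hedged usage sketch of the option this commit fixes. Only dim, codebook_size, kmeans_init and the returned triple are documented API; passing distributed_replace_codes at the VectorQuantize level is an assumption about how the flag is threaded through to the codebook classes, not something this diff confirms:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    kmeans_init = True,
    distributed_replace_codes = False  # assumed to be forwarded to the codebook's replace() path
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)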
@@ -808,7 +829,7 @@ def __init__(
         )
 
         if not exists(sync_codebook):
-            sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
+            sync_codebook = is_distributed()
 
         codebook_kwargs = dict(
             dim = codebook_dim,
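
A side note on the new from functools import cache at the top of the file: is_distributed is decorated with @cache, so the distributed check behind the sync_codebook default above is evaluated once per process and memoized. A standalone illustration of that memoization behaviour (assumed example, not repository code):

from functools import cache

calls = 0

@cache
def zero_arg_check():
    global calls
    calls += 1
    return False

zero_arg_check()
zero_arg_check()
print(calls)  # prints 1 -- the function body only ran on the first call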
