
Commit

Merge pull request #4 from wesbz/soundstream
Updating the RVQ to better match Soundstream's algorithm
lucidrains authored Oct 20, 2021
2 parents b1f5d8e + d4f0665 commit 369240f
Showing 1 changed file with 53 additions and 37 deletions.
90 changes: 53 additions & 37 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
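The diff below replaces the old miss-count expiry (max_codebook_misses_before_expiry, removed from VectorQuantize further down) with SoundStream-style dead-code expiry: each codebook gains a threshold_ema_dead_code argument and an expire_codes_ method, and any code whose EMA cluster size falls below that threshold is overwritten with vectors sampled from the current batch. A minimal usage sketch of the new argument follows; the argument values and the (batch, seq, dim) input shape are illustrative assumptions, not taken from this commit.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    decay = 0.8,                  # EMA decay used by the codebook
    commitment = 1.,              # weight of the commitment loss
    threshold_ema_dead_code = 2   # replace codes whose EMA cluster size falls below 2
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)  # forward returns (quantize, embed_ind, commit_loss)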
@@ -37,7 +37,8 @@ def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
         if use_cosine_sim:
             dists = samples @ means.t()
         else:
-            diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
+            diffs = rearrange(samples, 'n d -> n () d') \
+                    - rearrange(means, 'c d -> () c d')
             dists = -(diffs ** 2).sum(dim = -1)

         buckets = dists.max(dim = -1).indices
@@ -66,7 +67,8 @@ def __init__(
         kmeans_init = False,
         kmeans_iters = 10,
         decay = 0.8,
-        eps = 1e-5
+        eps = 1e-5,
+        threshold_ema_dead_code = 2
     ):
         super().__init__()
         self.decay = decay
@@ -76,6 +78,7 @@ def __init__(
         self.codebook_size = codebook_size
         self.kmeans_iters = kmeans_iters
         self.eps = eps
+        self.threshold_ema_dead_code = threshold_ema_dead_code

         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('cluster_size', torch.zeros(codebook_size))
@@ -89,9 +92,22 @@ def init_embed_(self, data):
         self.initted.data.copy_(torch.Tensor([True]))

     def replace(self, samples, mask):
-        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
+        modified_codebook = torch.where(
+            mask[..., None],
+            sample_vectors(samples, self.codebook_size),
+            self.embed
+        )
         self.embed.data.copy_(modified_codebook)

+    def expire_codes_(self, batch_samples):
+        if self.threshold_ema_dead_code == 0:
+            return
+
+        expired_codes = self.cluster_size < self.threshold_ema_dead_code
+        if torch.any(expired_codes):
+            batch_samples = rearrange(batch_samples, '... d -> (...) d')
+            self.replace(batch_samples, mask = expired_codes)
+
     def forward(self, x):
         shape, dtype = x.shape, x.dtype
         flatten = rearrange(x, '... d -> (...) d')
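The expire_codes_ method added above, together with replace, is the dead-code refresh: codes whose EMA cluster size has fallen below threshold_ema_dead_code are overwritten with vectors drawn from the current batch. A self-contained sketch of the same idea follows; expire_dead_codes and its random sampling are stand-ins for illustration (the repo's sample_vectors helper, not shown in this hunk, is assumed to behave similarly but may sample without replacement).

import torch
from einops import rearrange

def expire_dead_codes(embed, cluster_size, batch, threshold = 2):
    # embed: (codebook_size, dim) codebook
    # cluster_size: (codebook_size,) EMA usage counts
    # batch: (..., dim) encoder outputs from the current step
    if threshold == 0:
        return embed                                    # expiry disabled
    expired = cluster_size < threshold                  # bool mask of dead codes
    if not torch.any(expired):
        return embed
    samples = rearrange(batch, '... d -> (...) d')      # flatten to (n, dim)
    idx = torch.randint(0, samples.shape[0], (embed.shape[0],), device = samples.device)
    replacement = samples[idx]                          # one random batch vector per code
    return torch.where(expired[:, None], replacement, embed)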
@@ -107,7 +123,7 @@ def forward(self, x):
         )

         embed_ind = dist.max(dim = -1).indices
-        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(x.dtype)
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
         embed_ind = embed_ind.view(*shape[:-1])
         quantize = F.embedding(embed_ind, self.embed)

@@ -118,6 +134,7 @@ def forward(self, x):
             cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()
             embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
             self.embed.data.copy_(embed_normalized)
+            self.expire_codes_(x)

         return quantize, embed_ind

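The training branch above now finishes with expire_codes_ after the usual EMA codebook update. For context, here is a rough sketch of that update, assuming laplace_smoothing has the standard form (x + eps) / (x.sum() + n_categories * eps); the helper's body is outside this diff, so this is an assumption, not the repo's exact code.

import torch

def ema_codebook_update(embed_avg, cluster_size, embed_onehot, flatten, decay = 0.8, eps = 1e-5):
    # embed_avg: (codebook_size, dim), cluster_size: (codebook_size,)
    # embed_onehot: (n, codebook_size) one-hot assignments, flatten: (n, dim) inputs
    cluster_size.mul_(decay).add_(embed_onehot.sum(0), alpha = 1 - decay)
    embed_avg.mul_(decay).add_(embed_onehot.t() @ flatten, alpha = 1 - decay)

    # Laplace smoothing keeps rarely used codes from dividing by ~zero
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + cluster_size.shape[0] * eps) * n
    return embed_avg / smoothed.unsqueeze(1)            # normalized codebook, (codebook_size, dim)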
@@ -129,7 +146,8 @@ def __init__(
         kmeans_init = False,
         kmeans_iters = 10,
         decay = 0.8,
-        eps = 1e-5
+        eps = 1e-5,
+        threshold_ema_dead_code = 2
     ):
         super().__init__()
         self.decay = decay
@@ -142,20 +160,35 @@ def __init__(
         self.codebook_size = codebook_size
         self.kmeans_iters = kmeans_iters
         self.eps = eps
+        self.threshold_ema_dead_code = threshold_ema_dead_code

         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('embed', embed)

     def init_embed_(self, data):
-        embed = kmeans(data, self.codebook_size, self.kmeans_iters, use_cosine_sim = True)
+        embed = kmeans(data, self.codebook_size, self.kmeans_iters,
+                       use_cosine_sim = True)
         self.embed.data.copy_(embed)
         self.initted.data.copy_(torch.Tensor([True]))

     def replace(self, samples, mask):
         samples = l2norm(samples)
-        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
+        modified_codebook = torch.where(
+            mask[..., None],
+            sample_vectors(samples, self.codebook_size),
+            self.embed
+        )
         self.embed.data.copy_(modified_codebook)

+    def expire_codes_(self, batch_samples):
+        if self.threshold_ema_dead_code == 0:
+            return
+
+        expired_codes = self.cluster_size < self.threshold_ema_dead_code
+        if torch.any(expired_codes):
+            batch_samples = rearrange(batch_samples, '... d -> (...) d')
+            self.replace(batch_samples, mask = expired_codes)
+
     def forward(self, x):
         shape, dtype = x.shape, x.dtype
         flatten = rearrange(x, '... d -> (...) d')
@@ -180,8 +213,10 @@ def forward(self, x):
             embed_sum = flatten.t() @ embed_onehot
             embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
             embed_normalized = l2norm(embed_normalized)
-            embed_normalized = torch.where(zero_mask[..., None], embed, embed_normalized)
+            embed_normalized = torch.where(zero_mask[..., None], embed,
+                                           embed_normalized)
             ema_inplace(self.embed, embed_normalized, self.decay)
+            self.expire_codes_(x)

         return quantize, embed_ind

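In the cosine-similarity codebook the same EMA target is additionally l2-normalized so the codes stay on the unit sphere, and codes that received no assignments this step (zero_mask) keep their previous value before expire_codes_ runs. A small sketch of that step, with l2norm spelled out via F.normalize since its definition is outside this hunk; bins and zero_mask are reconstructed from the surrounding code and the masked_fill detail is an assumption.

import torch
import torch.nn.functional as F

def normalized_ema_target(embed, embed_onehot, flatten):
    # embed: (codebook_size, dim) current (unit-norm) codes
    # embed_onehot: (n, codebook_size), flatten: (n, dim) l2-normalized inputs
    bins = embed_onehot.sum(0)                          # assignments per code this step
    zero_mask = bins == 0
    bins = bins.masked_fill(zero_mask, 1.)              # avoid division by zero

    embed_sum = flatten.t() @ embed_onehot              # (dim, codebook_size)
    target = (embed_sum / bins.unsqueeze(0)).t()        # mean assigned vector per code
    target = F.normalize(target, p = 2, dim = -1)       # l2norm: keep codes on the unit sphere

    # codes with no assignments keep their previous embedding
    return torch.where(zero_mask[..., None], embed, target)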
@@ -200,59 +235,41 @@ def __init__(
         kmeans_init = False,
         kmeans_iters = 10,
         use_cosine_sim = False,
-        max_codebook_misses_before_expiry = 0
+        threshold_ema_dead_code = 0
     ):
         super().__init__()
         n_embed = default(n_embed, codebook_size)

         codebook_dim = default(codebook_dim, dim)
         requires_projection = codebook_dim != dim
-        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
-        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
+        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection \
+            else nn.Identity()
+        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection \
+            else nn.Identity()

         self.eps = eps
         self.commitment = commitment

-        klass = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
+        codebook_class = EuclideanCodebook if not use_cosine_sim \
+            else CosineSimCodebook

-        self._codebook = klass(
+        self._codebook = codebook_class(
             dim = codebook_dim,
             codebook_size = n_embed,
             kmeans_init = kmeans_init,
             kmeans_iters = kmeans_iters,
             decay = decay,
-            eps = eps
+            eps = eps,
+            threshold_ema_dead_code = threshold_ema_dead_code
         )

         self.codebook_size = codebook_size
-        self.max_codebook_misses_before_expiry = max_codebook_misses_before_expiry
-
-        if max_codebook_misses_before_expiry > 0:
-            codebook_misses = torch.zeros(codebook_size)
-            self.register_buffer('codebook_misses', codebook_misses)

     @property
     def codebook(self):
         return self._codebook.codebook

-    def expire_codes_(self, embed_ind, batch_samples):
-        if self.max_codebook_misses_before_expiry == 0:
-            return
-
-        embed_ind = rearrange(embed_ind, '... -> (...)')
-        misses = torch.bincount(embed_ind, minlength = self.codebook_size) == 0
-        self.codebook_misses += misses
-
-        expired_codes = self.codebook_misses >= self.max_codebook_misses_before_expiry
-        if not torch.any(expired_codes):
-            return
-
-        self.codebook_misses.masked_fill_(expired_codes, 0)
-        batch_samples = rearrange(batch_samples, '... d -> (...) d')
-        self._codebook.replace(batch_samples, mask = expired_codes)
-
     def forward(self, x):
-        dtype = x.dtype
         x = self.project_in(x)

         quantize, embed_ind = self._codebook(x)
@@ -262,7 +279,6 @@ def forward(self, x):
         if self.training:
             commit_loss = F.mse_loss(quantize.detach(), x) * self.commitment
             quantize = x + (quantize - x).detach()
-            self.expire_codes_(embed_ind, x)

         quantize = self.project_out(quantize)
         return quantize, embed_ind, commit_loss
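The training branch at the end of VectorQuantize.forward is the standard VQ-VAE recipe: a commitment loss pulls the encoder output toward its (detached) code, and the straight-through trick passes gradients through the quantization step. A minimal standalone illustration of those two lines; beta is a stand-in for the commitment weight used above.

import torch
import torch.nn.functional as F

def straight_through_quantize(x, quantize, beta = 1.0):
    # commitment loss: keep the encoder output close to its assigned code
    commit_loss = F.mse_loss(quantize.detach(), x) * beta
    # straight-through estimator: the forward value is quantize, but the gradient
    # of the output flows to x as if quantization were the identity
    quantize = x + (quantize - x).detach()
    return quantize, commit_loss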

0 comments on commit 369240f
