add ability to use cosine similarity for measuring distance to codes
lucidrains committed Oct 19, 2021
1 parent 345ae69 commit a1b4a71
Showing 3 changed files with 173 additions and 51 deletions.
18 changes: 18 additions & 0 deletions README.md
```python
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```

### Cosine Similarity

The <a href="https://openreview.net/forum?id=pfNyExj7z2">Improved VQGAN paper</a> also proposes to l2-normalize the codes and the encoded vectors, which boils down to using cosine similarity as the distance measure. The authors claim that constraining the vectors to a sphere improves code usage and downstream reconstruction. You can turn this on by setting `use_cosine_sim = True`.

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    use_cosine_sim = True   # set this to True
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```
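As a quick sanity check of the claim that l2 normalization reduces the distance computation to cosine similarity: for unit vectors `x` and `c`, `||x - c||^2 = 2 - 2 * <x, c>`, so the code that minimizes the Euclidean distance is exactly the code that maximizes the cosine similarity. A small standalone sketch (not part of the library):

```python
import torch
import torch.nn.functional as F

x     = F.normalize(torch.randn(1024, 256), dim = -1)   # l2-normalized "encoded" vectors
codes = F.normalize(torch.randn(512, 256), dim = -1)    # l2-normalized codebook

sq_dist = torch.cdist(x, codes) ** 2    # squared euclidean distance to every code
cos_sim = x @ codes.t()                 # cosine similarity to every code

# for unit vectors, ||x - c||^2 = 2 - 2 * <x, c>
assert torch.allclose(sq_dist, 2 - 2 * cos_sim, atol = 1e-4)

# so the nearest code by distance is the most similar code by cosine
print((sq_dist.argmin(dim = -1) == cos_sim.argmax(dim = -1)).float().mean())   # ~1.0
```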

## Citations

```bibtex
2 changes: 1 addition & 1 deletion setup.py
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '0.3.1',
+  version = '0.3.2',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   author = 'Phil Wang',
```
204 changes: 154 additions & 50 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
```diff
@@ -15,7 +15,7 @@ def ema_inplace(moving_avg, new, decay):
 def laplace_smoothing(x, n_categories, eps = 1e-5):
     return (x + eps) / (x.sum() + n_categories * eps)
 
-def kmeans(x, num_clusters, num_iters = 10):
+def kmeans(x, num_clusters, num_iters = 10, use_cosine_sim = False):
     samples = rearrange(x, '... d -> (...) d')
     num_samples, dim, dtype, device = *samples.shape, x.dtype, x.device
 
```
```diff
@@ -27,9 +27,13 @@ def kmeans(x, num_clusters, num_iters = 10):
     means = samples[indices]
 
     for _ in range(num_iters):
-        diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
-        dists = (diffs ** 2).sum(dim = -1)
-        buckets = dists.argmin(dim = -1)
+        if use_cosine_sim:
+            dists = samples @ means.t()
+            buckets = dists.max(dim = -1).indices
+        else:
+            diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
+            dists = (diffs ** 2).sum(dim = -1)
+            buckets = dists.argmin(dim = -1)
 
         bins = torch.bincount(buckets, minlength = num_clusters)
         zero_mask = bins == 0
```
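With `use_cosine_sim = True` the initialization becomes a spherical k-means: samples are assigned to the cluster with the highest dot product, and the recomputed means are re-normalized on every iteration (the `F.normalize` call in the remainder of the hunk below). A minimal sketch of calling the helper directly, assuming the inputs are already l2-normalized; `kmeans` is an internal function, not part of the public API:

```python
import torch
import torch.nn.functional as F
from vector_quantize_pytorch.vector_quantize_pytorch import kmeans  # internal helper, import path assumed

x = F.normalize(torch.randn(2048, 256), dim = -1)             # unit-norm vectors to cluster
means = kmeans(x, num_clusters = 512, use_cosine_sim = True)
print(means.shape)   # torch.Size([512, 256]) - the helper now returns (clusters, dim) rather than (dim, clusters)
```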
```diff
@@ -38,86 +42,186 @@
         new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype)
         new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples)
         new_means = new_means / bins[..., None]
 
+        if use_cosine_sim:
+            new_means = F.normalize(new_means, dim = -1)
+
         means = torch.where(zero_mask[..., None], means, new_means)
 
-    return rearrange(means, 'n d -> d n')
+    return means
 
-class VectorQuantize(nn.Module):
+# distance types
+
+class EuclideanCodebook(nn.Module):
     def __init__(
         self,
         dim,
         codebook_size,
-        decay = 0.8,
-        commitment = 1.,
-        eps = 1e-5,
-        n_embed = None,
         kmeans_init = False,
         kmeans_iters = 10,
-        codebook_dim = None
+        decay = 0.8,
+        eps = 1e-5
     ):
         super().__init__()
-        n_embed = default(n_embed, codebook_size)
-        self.n_embed = n_embed
-
-        codebook_dim = default(codebook_dim, dim)
-        requires_projection = codebook_dim != dim
-        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
-        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
-
         self.decay = decay
-        self.eps = eps
-        self.commitment = commitment
 
         init_fn = torch.randn if not kmeans_init else torch.zeros
-        embed = init_fn(codebook_dim, n_embed)
+        embed = init_fn(codebook_size, dim)
 
+        self.codebook_size = codebook_size
         self.kmeans_iters = kmeans_iters
+        self.eps = eps
 
         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
-        self.register_buffer('cluster_size', torch.zeros(n_embed))
+        self.register_buffer('cluster_size', torch.zeros(codebook_size))
         self.register_buffer('embed', embed)
         self.register_buffer('embed_avg', embed.clone())
 
     @property
     def codebook(self):
         return self.embed.transpose(0, 1)
 
     def init_embed_(self, data):
-        embed = kmeans(data, self.n_embed, self.kmeans_iters)
+        embed = kmeans(data, self.codebook_size, self.kmeans_iters)
         self.embed.data.copy_(embed)
         self.embed_avg.data.copy_(embed.clone())
         self.initted.data.copy_(torch.Tensor([True]))
 
-    def forward(self, input):
-        input = self.project_in(input)
-
+    def forward(self, x):
         if not self.initted:
-            self.init_embed_(input)
+            self.init_embed_(x)
 
-        dtype = input.dtype
-        flatten = rearrange(input, '... d -> (...) d')
-        dist = (
+        shape, dtype = x.shape, x.dtype
+        flatten = rearrange(x, '... d -> (...) d')
+        embed = self.embed.t()
+
+        dist = -(
             flatten.pow(2).sum(1, keepdim=True)
-            - 2 * flatten @ self.embed
-            + self.embed.pow(2).sum(0, keepdim=True)
+            - 2 * flatten @ embed
+            + embed.pow(2).sum(0, keepdim=True)
         )
 
-        _, embed_ind = (-dist).max(1)
-        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
-        embed_ind = embed_ind.view(*input.shape[:-1])
-
-        commit_loss = 0.
-        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
+        embed_ind = dist.max(dim = -1).indices
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(x.dtype)
+        embed_ind = embed_ind.view(*shape[:-1])
+        quantize = F.embedding(embed_ind, self.embed)
 
         if self.training:
             ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
-            embed_sum = flatten.transpose(0, 1) @ embed_onehot
-            ema_inplace(self.embed_avg, embed_sum, self.decay)
-            cluster_size = laplace_smoothing(self.cluster_size, self.n_embed, self.eps) * self.cluster_size.sum()
-            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
+            embed_sum = flatten.t() @ embed_onehot
+            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
             self.embed.data.copy_(embed_normalized)
 
-            commit_loss = F.mse_loss(quantize.detach(), input) * self.commitment
-
-        quantize = input + (quantize - input).detach()
-        return quantize, embed_ind, commit_loss
+        return quantize, embed_ind
```
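`EuclideanCodebook` now owns the k-means initialization, the EMA codebook update and the nearest-code lookup, and returns only the quantized vectors and their indices; the commitment loss and the straight-through estimator move into the `VectorQuantize` wrapper further down. A sketch of exercising it on its own (it is normally constructed by `VectorQuantize`; the import path is an assumption about the module layout):

```python
import torch
from vector_quantize_pytorch.vector_quantize_pytorch import EuclideanCodebook  # internal module, import path assumed

codebook = EuclideanCodebook(dim = 256, codebook_size = 512, kmeans_init = True, kmeans_iters = 10)

x = torch.randn(1, 1024, 256)
quantize, indices = codebook(x)   # (1, 1024, 256) and (1, 1024) - no commitment loss at this level
```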

```diff
+class CosineSimCodebook(nn.Module):
+    def __init__(
+        self,
+        dim,
+        codebook_size,
+        kmeans_init = False,
+        kmeans_iters = 10,
+        decay = 0.8,
+        eps = 1e-5
+    ):
+        super().__init__()
+        self.decay = decay
+
+        if not kmeans_init:
+            embed = F.normalize(torch.randn(codebook_size, dim), dim = -1)
+        else:
+            embed = torch.zeros(codebook_size, dim)
+
+        self.codebook_size = codebook_size
+        self.kmeans_iters = kmeans_iters
+        self.eps = eps
+
+        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
+        self.register_buffer('embed', embed)
+
+    def init_embed_(self, data):
+        embed = kmeans(data, self.codebook_size, self.kmeans_iters, use_cosine_sim = True)
+        self.embed.data.copy_(embed)
+        self.initted.data.copy_(torch.Tensor([True]))
+
+    def forward(self, x):
+        shape, dtype = x.shape, x.dtype
+        flatten = rearrange(x, '... d -> (...) d')
+        flatten = F.normalize(flatten, dim = -1)
+        embed = F.normalize(self.embed, dim = -1)
+
+        if not self.initted:
+            self.init_embed_(flatten)
+
+        dist = flatten @ embed.t()
+        embed_ind = dist.max(dim = -1).indices
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind = embed_ind.view(*shape[:-1])
+
+        quantize = F.embedding(embed_ind, self.embed)
+
+        if self.training:
+            bins = embed_onehot.sum(0)
+            zero_mask = (bins == 0)
+            bins = bins.masked_fill(zero_mask, 1.)
+
+            embed_sum = flatten.t() @ embed_onehot
+            embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
+            embed_normalized = F.normalize(embed_normalized, dim = -1)
+            embed_normalized = torch.where(zero_mask[..., None], embed, embed_normalized)
+            ema_inplace(self.embed, embed_normalized, self.decay)
+
+        return quantize, embed_ind
```
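`CosineSimCodebook` l2-normalizes both the flattened inputs and the codes before taking the dot product, and its EMA update averages the assigned vectors and re-normalizes them, replacing the Laplace-smoothed cluster-size bookkeeping of the Euclidean codebook. A short sketch of the two practical consequences (again via the internal module path, which is assumed):

```python
import torch
from vector_quantize_pytorch.vector_quantize_pytorch import CosineSimCodebook  # internal module, import path assumed

codebook = CosineSimCodebook(dim = 256, codebook_size = 512).eval()
print(codebook.embed.norm(dim = -1))   # ~1.0 everywhere: codes are initialized on the unit sphere

x = torch.randn(1, 1024, 256)
_, ind_a = codebook(x)
_, ind_b = codebook(x * 100.)          # inputs are l2-normalized inside forward,
print((ind_a == ind_b).all())          # so rescaling them leaves the chosen codes unchanged (float ties aside)
```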

```diff
+# main class
+
+class VectorQuantize(nn.Module):
+    def __init__(
+        self,
+        dim,
+        codebook_size,
+        n_embed = None,
+        codebook_dim = None,
+        decay = 0.8,
+        commitment = 1.,
+        eps = 1e-5,
+        kmeans_init = False,
+        kmeans_iters = 10,
+        use_cosine_sim = False
+    ):
+        super().__init__()
+        n_embed = default(n_embed, codebook_size)
+
+        codebook_dim = default(codebook_dim, dim)
+        requires_projection = codebook_dim != dim
+        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
+        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
+
+        self.eps = eps
+        self.commitment = commitment
+
+        klass = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
+
+        self._codebook = klass(
+            dim = codebook_dim,
+            codebook_size = n_embed,
+            kmeans_init = kmeans_init,
+            kmeans_iters = kmeans_iters,
+            decay = decay,
+            eps = eps
+        )
+
+    @property
+    def codebook(self):
+        return self._codebook.codebook
+
+    def forward(self, x):
+        dtype = x.dtype
+        x = self.project_in(x)
+
+        quantize, embed_ind = self._codebook(x)
+
+        commit_loss = 0.
+        if self.training:
+            commit_loss = F.mse_loss(quantize.detach(), x) * self.commitment
+            quantize = x + (quantize - x).detach()
+
+        quantize = self.project_out(quantize)
+        return quantize, embed_ind, commit_loss
```
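The `VectorQuantize` wrapper is now only responsible for the optional projection into a smaller codebook dimension, the commitment loss and the straight-through gradient, and delegates everything else to whichever codebook `use_cosine_sim` selects. A sketch combining the options exposed by the constructor above (the particular values are illustrative):

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    codebook_dim = 32,        # project 256 -> 32 before quantizing, and back out afterwards
    kmeans_init = True,       # seed the codebook with k-means on the first batch
    kmeans_iters = 10,
    use_cosine_sim = True     # selects CosineSimCodebook instead of EuclideanCodebook
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)   # (1, 1024, 256), (1, 1024), scalar commitment loss
```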
