handle expiration of codes for residual vq with shared codebooks, handling #162
lucidrains committed Sep 26, 2024
1 parent 6105a32 commit 22a0375
Showing 3 changed files with 26 additions and 6 deletions.
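For context, a minimal usage sketch of the configuration this commit addresses: a ResidualVQ whose layers share a single codebook, with dead-code expiration enabled. The parameter values below are illustrative, and forwarding threshold_ema_dead_code through ResidualVQ to the underlying quantizer is an assumption based on the library's README, not something shown in this diff.

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 4,
    codebook_size = 512,
    shared_codebook = True,        # every quantizer layer reuses the first layer's codebook
    threshold_ema_dead_code = 2    # assumed: codes whose EMA cluster size falls below this are expired / re-seeded
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)  # with this commit, expiration also runs in the shared-codebook case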
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.17.6"
+version = "1.17.7"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
14 changes: 12 additions & 2 deletions vector_quantize_pytorch/residual_vq.py
@@ -315,6 +315,10 @@ def forward(
         if self.implicit_neural_codebook:
             maybe_code_transforms = (None, *self.mlps)
 
+        # save all inputs across layers, for use during expiration at end under shared codebook setting
+
+        all_residuals = []
+
         # go through the layers
 
         for quantizer_index, (vq, maybe_mlp) in enumerate(zip(self.layers, maybe_code_transforms)):
@@ -333,6 +337,10 @@ def forward(
             if exists(maybe_mlp):
                 maybe_mlp = partial(maybe_mlp, condition = quantized_out)
 
+            # save for expiration
+
+            all_residuals.append(residual)
+
             # vector quantize forward
 
             quantized, *rest = vq(
@@ -360,8 +368,10 @@ def forward(
         # if shared codebook, update ema only at end
 
         if self.shared_codebook:
-            first(self.layers)._codebook.update_ema()
-            first(self.layers).update_in_place_optimizer()
+            shared_layer = first(self.layers)
+            shared_layer._codebook.update_ema()
+            shared_layer.update_in_place_optimizer()
+            shared_layer.expire_codes_(torch.cat(all_residuals, dim = -2))
 
         # project out, if needed
 
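The residual fed to each quantizer layer is now stashed in all_residuals and, once the loop finishes, concatenated along the sequence dimension (dim = -2) so the shared codebook's expiration step sees the inputs of every layer at once. A small shape-only sketch of that concatenation, with made-up tensor sizes:

import torch

num_quantizers, batch, seq_len, dim = 4, 2, 1024, 256

# one residual tensor per quantizer layer, each of shape (batch, seq_len, dim)
all_residuals = [torch.randn(batch, seq_len, dim) for _ in range(num_quantizers)]

# concatenating along dim = -2 stacks them on the sequence axis, giving the shared
# codebook a single (batch, num_quantizers * seq_len, dim) batch to judge dead codes against
combined = torch.cat(all_residuals, dim = -2)
print(combined.shape)  # torch.Size([2, 4096, 256])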
16 changes: 13 additions & 3 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -975,6 +975,18 @@ def update_in_place_optimizer(self):
         self.in_place_codebook_optimizer.step()
         self.in_place_codebook_optimizer.zero_grad()
 
+    def maybe_split_heads_from_input(self, x):
+        if self.heads == 1:
+            return x
+
+        ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d'
+        return rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = self.heads)
+
+    def expire_codes_(self, x):
+        x = self._codebook.transform_input(x)
+        x = self.maybe_split_heads_from_input(x)
+        self._codebook.expire_codes_(x)
+
     def forward(
         self,
         x,
@@ -1024,9 +1036,7 @@ def forward(
 
         # handle multi-headed separate codebooks
 
-        if is_multiheaded:
-            ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d'
-            x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads)
+        x = self.maybe_split_heads_from_input(x)
 
         # l2norm for cosine sim, otherwise identity
 
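The new maybe_split_heads_from_input helper factors the multi-head reshape out of forward so that the new expire_codes_ entry point can reuse it. A standalone sketch of the two einops layouts it chooses between, with illustrative shapes not taken from the diff:

import torch
from einops import rearrange

batch, seq_len, heads, dim_head = 2, 16, 4, 64
x = torch.randn(batch, seq_len, heads * dim_head)

# separate codebook per head: heads become a leading axis
per_head = rearrange(x, 'b n (h d) -> h b n d', h = heads)
print(per_head.shape)  # torch.Size([4, 2, 16, 64])

# single codebook shared across heads: heads are folded into the batch axis
shared = rearrange(x, 'b n (h d) -> 1 (b h) n d', h = heads)
print(shared.shape)    # torch.Size([1, 8, 16, 64])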
