diff --git a/pyproject.toml b/pyproject.toml
index 420295a..766c211 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.17.6"
+version = "1.17.7"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/vector_quantize_pytorch/residual_vq.py b/vector_quantize_pytorch/residual_vq.py
index ef1b135..bae1168 100644
--- a/vector_quantize_pytorch/residual_vq.py
+++ b/vector_quantize_pytorch/residual_vq.py
@@ -315,6 +315,10 @@ def forward(
         if self.implicit_neural_codebook:
             maybe_code_transforms = (None, *self.mlps)
 
+        # save all inputs across layers, for use during expiration at end under shared codebook setting
+
+        all_residuals = []
+
         # go through the layers
 
         for quantizer_index, (vq, maybe_mlp) in enumerate(zip(self.layers, maybe_code_transforms)):
@@ -333,6 +337,10 @@ def forward(
             if exists(maybe_mlp):
                 maybe_mlp = partial(maybe_mlp, condition = quantized_out)
 
+            # save for expiration
+
+            all_residuals.append(residual)
+
             # vector quantize forward
 
             quantized, *rest = vq(
@@ -360,8 +368,10 @@ def forward(
         # if shared codebook, update ema only at end
 
         if self.shared_codebook:
-            first(self.layers)._codebook.update_ema()
-            first(self.layers).update_in_place_optimizer()
+            shared_layer = first(self.layers)
+            shared_layer._codebook.update_ema()
+            shared_layer.update_in_place_optimizer()
+            shared_layer.expire_codes_(torch.cat(all_residuals, dim = -2))
 
         # project out, if needed
 
diff --git a/vector_quantize_pytorch/vector_quantize_pytorch.py b/vector_quantize_pytorch/vector_quantize_pytorch.py
index 0d4dc0d..b6160f0 100644
--- a/vector_quantize_pytorch/vector_quantize_pytorch.py
+++ b/vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -975,6 +975,18 @@ def update_in_place_optimizer(self):
         self.in_place_codebook_optimizer.step()
         self.in_place_codebook_optimizer.zero_grad()
 
+    def maybe_split_heads_from_input(self, x):
+        if self.heads == 1:
+            return x
+
+        ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d'
+        return rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = self.heads)
+
+    def expire_codes_(self, x):
+        x = self._codebook.transform_input(x)
+        x = self.maybe_split_heads_from_input(x)
+        self._codebook.expire_codes_(x)
+
     def forward(
         self,
         x,
@@ -1024,9 +1036,7 @@ def forward(
 
         # handle multi-headed separate codebooks
 
-        if is_multiheaded:
-            ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d'
-            x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads)
+        x = self.maybe_split_heads_from_input(x)
 
         # l2norm for cosine sim, otherwise identity
 
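
A minimal sketch of the shared-codebook path the residual_vq.py change affects. The constructor arguments shown (shared_codebook, threshold_ema_dead_code) are existing library options; the specific values are illustrative and not taken from this commit.

    import torch
    from vector_quantize_pytorch import ResidualVQ

    # illustrative hyperparameters, not taken from this commit
    residual_vq = ResidualVQ(
        dim = 256,
        num_quantizers = 4,
        codebook_size = 1024,
        shared_codebook = True,        # every quantizer layer reuses the first layer's codebook
        threshold_ema_dead_code = 2    # codes whose EMA cluster size falls below 2 get expired
    )

    x = torch.randn(1, 1024, 256)

    # in training mode, the shared codebook's EMA update and dead-code expiration
    # now run once after the layer loop, over the residuals collected from all layers
    quantized, indices, commit_loss = residual_vq(x)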
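
And a sketch of the multi-headed VectorQuantize configuration that the new maybe_split_heads_from_input and expire_codes_ helpers serve. Here codebook_dim is chosen so that heads * codebook_dim == dim, which keeps the input projection an identity and lets expire_codes_ accept a raw (batch, seq, dim) tensor; again, the values are illustrative assumptions rather than anything fixed by the diff.

    import torch
    from vector_quantize_pytorch import VectorQuantize

    # illustrative settings; heads * codebook_dim == dim keeps project_in an identity
    vq = VectorQuantize(
        dim = 256,
        codebook_dim = 32,
        heads = 8,
        separate_codebook_per_head = True,   # 'h b n d' layout after the head split
        codebook_size = 512,
        threshold_ema_dead_code = 2
    )

    x = torch.randn(1, 1024, 256)
    quantized, indices, commit_loss = vq(x)

    # the new public method applies the codebook's input transform, then the same
    # head split used by forward, before expiring dead codes against these samples
    vq.expire_codes_(x)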