Commit

complete residual sim vq
lucidrains committed Nov 12, 2024
1 parent 8f5b428 commit 81d0f3b
Showing 6 changed files with 253 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.20.4"
version = "1.20.5"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
17 changes: 17 additions & 0 deletions tests/test_readme.py
@@ -376,3 +376,20 @@ def test_sim_vq():

    assert x.shape == quantized.shape
    assert torch.allclose(quantized, sim_vq.indices_to_codes(indices), atol = 1e-6)

def test_residual_sim_vq():

    from vector_quantize_pytorch import ResidualSimVQ

    residual_sim_vq = ResidualSimVQ(
        dim = 512,
        num_quantizers = 4,
        codebook_size = 1024,
        accept_image_fmap = True
    )

    x = torch.randn(1, 512, 32, 32)
    quantized, indices, commit_loss = residual_sim_vq(x)

    assert x.shape == quantized.shape
    assert torch.allclose(quantized, residual_sim_vq.get_output_from_indices(indices), atol = 1e-5)
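
For reference, a minimal sketch of the same class on flat sequences (accept_image_fmap left at its default of False); the sizes below are arbitrary and only illustrate the expected shapes:

import torch
from vector_quantize_pytorch import ResidualSimVQ

# flat sequences instead of image feature maps; dims are illustrative
residual_sim_vq = ResidualSimVQ(
    dim = 512,
    num_quantizers = 4,
    codebook_size = 1024
)

x = torch.randn(1, 1024, 512)
quantized, indices, commit_loss = residual_sim_vq(x)

assert quantized.shape == x.shape      # (1, 1024, 512)
assert indices.shape == (1, 1024, 4)   # one code id per quantizer, stacked on the last dim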
2 changes: 2 additions & 0 deletions vector_quantize_pytorch/__init__.py
@@ -6,6 +6,8 @@
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ
from vector_quantize_pytorch.latent_quantization import LatentQuantize

from vector_quantize_pytorch.sim_vq import SimVQ
from vector_quantize_pytorch.residual_sim_vq import ResidualSimVQ

from vector_quantize_pytorch.utils import Sequential
226 changes: 226 additions & 0 deletions vector_quantize_pytorch/residual_sim_vq.py
@@ -0,0 +1,226 @@
from __future__ import annotations

import random
from math import ceil
from functools import partial, cache
from itertools import zip_longest

import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
import torch.distributed as dist

from vector_quantize_pytorch.sim_vq import SimVQ, pack_one

from einx import get_at
from einops import rearrange, repeat, reduce, pack, unpack

# helper functions

def exists(val):
    return val is not None

def first(it):
    return it[0]

def default(val, d):
    return val if exists(val) else d

def round_up_multiple(num, mult):
    return ceil(num / mult) * mult

# distributed helpers

def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

def get_maybe_sync_seed(device, max_size = 10_000):
    rand_int = torch.randint(0, max_size, (), device = device)

    if is_distributed():
        dist.all_reduce(rand_int)

    return rand_int.item()

# main class

class ResidualSimVQ(Module):
    """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """

    def __init__(
        self,
        *,
        dim,
        num_quantizers,
        codebook_size,
        heads = 1,
        quantize_dropout = False,
        quantize_dropout_cutoff_index = 0,
        quantize_dropout_multiple_of = 1,
        accept_image_fmap = False,
        rotation_trick = True, # rotation trick from @cfifty, on top of sim vq
        **sim_vq_kwargs
    ):
        super().__init__()
        assert heads == 1, 'residual vq is not compatible with multi-headed codes'

        self.accept_image_fmap = accept_image_fmap

        self.num_quantizers = num_quantizers

        # define sim vq across layers

        self.layers = ModuleList([SimVQ(dim = dim, codebook_size = codebook_size, rotation_trick = rotation_trick, accept_image_fmap = accept_image_fmap, **sim_vq_kwargs) for _ in range(num_quantizers)])

        # quantize dropout

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4

    @property
    def codebook_size(self):
        return first(self.layers).codebook_size

    @property
    def codebook_dim(self):
        return first(self.layers).codebook_dim

    @property
    def codebooks(self):
        codebooks = [layer.codebook for layer in self.layers]
        codebooks = torch.stack(codebooks)
        return codebooks

    def get_codes_from_indices(self, indices):

        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)

        indices, inverse = pack_one(indices, 'b * q')

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct

        if quantize_dim < self.num_quantizers:
            assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)

        # take care of quantizer dropout

        mask = indices == -1.
        indices = indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later

        all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)

        # mask out any codes that were dropout-ed

        all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)

        all_codes = inverse(all_codes, 'q b * d')

        if self.accept_image_fmap:
            all_codes = rearrange(all_codes, 'q b ... d -> q b d ...')

        return all_codes

    def get_output_from_indices(self, indices):
        all_codes = self.get_codes_from_indices(indices)
        summed_residual_codes = reduce(all_codes, 'q ... -> ...', 'sum')
        return summed_residual_codes

    def forward(
        self,
        x,
        indices: Tensor | list[Tensor] | None = None,
        return_all_codes = False,
        rand_quantize_dropout_fixed_seed = None
    ):
        num_quant, quant_dropout_multiple_of, return_loss, device = self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device

        assert not (self.accept_image_fmap and exists(indices))

        quantized_out = 0.
        residual = x

        all_losses = []
        all_indices = []

        if isinstance(indices, list):
            indices = torch.stack(indices)

        should_quantize_dropout = self.training and self.quantize_dropout and not return_loss

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices and loss

        if should_quantize_dropout:

            # check if seed is manually passed in

            if not exists(rand_quantize_dropout_fixed_seed):
                rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)

            rand = random.Random(rand_quantize_dropout_fixed_seed)

            rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1

            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
            null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
            null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)

        # save all inputs across layers, for use during expiration at end under shared codebook setting

        all_residuals = []

        # go through the layers

        for quantizer_index, sim_vq in enumerate(self.layers):

            if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                continue

            # save for expiration

            all_residuals.append(residual)

            # sim vq forward

            quantized, *rest = sim_vq(residual)

            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized

            embed_indices, loss = rest

            all_indices.append(embed_indices)
            all_losses.append(loss)

        # stack all losses and indices

        all_losses, all_indices = map(partial(torch.stack, dim = -1), (all_losses, all_indices))

        ret = (quantized_out, all_indices, all_losses)

        if not return_all_codes:
            return ret

        # whether to return all codes from all codebooks across layers

        all_codes = self.get_codes_from_indices(all_indices)

        # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)

        return (*ret, all_codes)
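
To make the coarse-reconstruction path above concrete, a hedged sketch with hypothetical sizes: indices from only the first few quantizers can be passed back in, and get_codes_from_indices pads the missing layers with -1 and masks them out. quantize_dropout = True is needed to pass the assertion guarding that padding path.

import torch
from vector_quantize_pytorch import ResidualSimVQ

rsvq = ResidualSimVQ(
    dim = 256,
    num_quantizers = 8,
    codebook_size = 1024,
    quantize_dropout = True   # allows reconstruction from fewer than num_quantizers indices
)
rsvq.eval()

x = torch.randn(2, 1024, 256)
quantized, indices, commit_loss = rsvq(x)   # indices: (2, 1024, 8)

# keep only the two coarsest quantizers; the remaining six are treated as dropped out
coarse_recon = rsvq.get_output_from_indices(indices[..., :2])
assert coarse_recon.shape == x.shape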
2 changes: 1 addition & 1 deletion vector_quantize_pytorch/residual_vq.py
@@ -156,7 +156,7 @@ def __init__(
            manual_in_place_optimizer_update = True
        )

-        # take care of maybe different codebook sizes across depth, used in TIGER paper https://arxiv.org/abs/2305.05065
+        # take care of maybe different codebook sizes across depth

        codebook_sizes = cast_tuple(codebook_size, num_quantizers)

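The comment above concerns per-layer codebook sizes: since codebook_size is passed through cast_tuple, a tuple with one entry per quantizer should also be accepted. A hedged sketch with hypothetical sizes:

import torch
from vector_quantize_pytorch import ResidualVQ

# one codebook size per quantizer; a single int is broadcast by cast_tuple
residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 3,
    codebook_size = (1024, 512, 256)
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
assert quantized.shape == x.shape
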
8 changes: 6 additions & 2 deletions vector_quantize_pytorch/sim_vq.py
@@ -44,17 +44,21 @@
        accept_image_fmap = False,
        rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
        input_to_quantize_commit_loss_weight = 0.25,
+        frozen_codebook_dim = None # frozen codebook dim could have different dimensions than projection
    ):
        super().__init__()
        self.codebook_size = codebook_size
        self.accept_image_fmap = accept_image_fmap

-        codebook = torch.randn(codebook_size, dim) * (dim ** -0.5)
+        frozen_codebook_dim = default(frozen_codebook_dim, dim)
+        codebook = torch.randn(codebook_size, frozen_codebook_dim) * (frozen_codebook_dim ** -0.5)
        codebook = init_fn(codebook)

        # the codebook is actually implicit from a linear layer from frozen gaussian or uniform


        if not exists(codebook_transform):
-            codebook_transform = nn.Linear(dim, dim, bias = False)
+            codebook_transform = nn.Linear(frozen_codebook_dim, dim, bias = False)

        self.code_transform = codebook_transform

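A hedged sketch of the new frozen_codebook_dim option (values below are hypothetical): the frozen implicit codebook is sampled in a smaller dimension, and the trainable nn.Linear(frozen_codebook_dim, dim) shown above projects it up to the model dimension.

import torch
from vector_quantize_pytorch import SimVQ

sim_vq = SimVQ(
    dim = 512,
    codebook_size = 1024,
    frozen_codebook_dim = 128   # frozen codes live in 128-d; code_transform maps them to 512-d
)

x = torch.randn(1, 1024, 512)
quantized, indices, commit_loss = sim_vq(x)
assert quantized.shape == x.shape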
