From e7ff7d7f543061d64e7bac2d69858da474ef6c77 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 8 Nov 2024 08:12:16 -0800 Subject: [PATCH] take care of a small pain point --- examples/autoencoder.py | 55 +++++++++-------------- examples/autoencoder_fsq.py | 51 ++++++++------------- examples/autoencoder_lfq.py | 69 +++++++++++------------------ pyproject.toml | 2 +- vector_quantize_pytorch/__init__.py | 2 + vector_quantize_pytorch/utils.py | 57 ++++++++++++++++++++++++ 6 files changed, 126 insertions(+), 110 deletions(-) create mode 100644 vector_quantize_pytorch/utils.py diff --git a/examples/autoencoder.py b/examples/autoencoder.py index 6b529b1..3b64d58 100644 --- a/examples/autoencoder.py +++ b/examples/autoencoder.py @@ -8,45 +8,29 @@ from torchvision import datasets, transforms from torch.utils.data import DataLoader -from vector_quantize_pytorch import VectorQuantize - +from vector_quantize_pytorch import VectorQuantize, Sequential lr = 3e-4 train_iter = 1000 num_codes = 256 seed = 1234 +rotation_trick = True device = "cuda" if torch.cuda.is_available() else "cpu" - -class SimpleVQAutoEncoder(nn.Module): - def __init__(self, **vq_kwargs): - super().__init__() - self.layers = nn.ModuleList( - [ - nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.GELU(), - nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - VectorQuantize(dim=32, accept_image_fmap = True, **vq_kwargs), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), - nn.GELU(), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), - ] - ) - return - - def forward(self, x): - for layer in self.layers: - if isinstance(layer, VectorQuantize): - x, indices, commit_loss = layer(x) - else: - x = layer(x) - - return x.clamp(-1, 1), indices, commit_loss - +def SimpleVQAutoEncoder(**vq_kwargs): + return Sequential( + nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.GELU(), + nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), + nn.MaxPool2d(kernel_size=2, stride=2), + VectorQuantize(dim=32, accept_image_fmap = True, **vq_kwargs), + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), + nn.GELU(), + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), + ) def train(model, train_loader, train_iterations=1000, alpha=10): def iterate_dataset(data_loader): @@ -62,7 +46,10 @@ def iterate_dataset(data_loader): for _ in (pbar := trange(train_iterations)): opt.zero_grad() x, _ = next(iterate_dataset(train_loader)) + out, indices, cmt_loss = model(x) + out = out.clamp(-1., 1.) + rec_loss = (out - x).abs().mean() (rec_loss + alpha * cmt_loss).backward() @@ -72,8 +59,6 @@ def iterate_dataset(data_loader): + f"cmt loss: {cmt_loss.item():.3f} | " + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" ) - return - transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] @@ -91,7 +76,7 @@ def iterate_dataset(data_loader): model = SimpleVQAutoEncoder( codebook_size=num_codes, - rotation_trick=True + rotation_trick=rotation_trick ).to(device) opt = torch.optim.AdamW(model.parameters(), lr=lr) diff --git a/examples/autoencoder_fsq.py b/examples/autoencoder_fsq.py index 081b2b6..a508f51 100644 --- a/examples/autoencoder_fsq.py +++ b/examples/autoencoder_fsq.py @@ -9,7 +9,7 @@ from torchvision import datasets, transforms from torch.utils.data import DataLoader -from vector_quantize_pytorch import FSQ +from vector_quantize_pytorch import FSQ, Sequential lr = 3e-4 @@ -20,36 +20,22 @@ device = "cuda" if torch.cuda.is_available() else "cpu" -class SimpleFSQAutoEncoder(nn.Module): - def __init__(self, levels: list[int]): - super().__init__() - self.layers = nn.ModuleList( - [ - nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.GELU(), - nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(32, len(levels), kernel_size=1), - FSQ(levels), - nn.Conv2d(len(levels), 32, kernel_size=3, stride=1, padding=1), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), - nn.GELU(), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), - ] - ) - return - - def forward(self, x): - for layer in self.layers: - if isinstance(layer, FSQ): - x, indices = layer(x) - else: - x = layer(x) - - return x.clamp(-1, 1), indices +def SimpleFSQAutoEncoder(levels: list[int]): + return Sequential( + nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.GELU(), + nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(32, len(levels), kernel_size=1), + FSQ(levels), + nn.Conv2d(len(levels), 32, kernel_size=3, stride=1, padding=1), + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), + nn.GELU(), + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), + ) def train(model, train_loader, train_iterations=1000): @@ -67,6 +53,8 @@ def iterate_dataset(data_loader): opt.zero_grad() x, _ = next(iterate_dataset(train_loader)) out, indices = model(x) + out = out.clamp(-1., 1.) + rec_loss = (out - x).abs().mean() rec_loss.backward() @@ -75,7 +63,6 @@ def iterate_dataset(data_loader): f"rec loss: {rec_loss.item():.3f} | " + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" ) - return transform = transforms.Compose( diff --git a/examples/autoencoder_lfq.py b/examples/autoencoder_lfq.py index e30e67c..0a48962 100644 --- a/examples/autoencoder_lfq.py +++ b/examples/autoencoder_lfq.py @@ -10,7 +10,7 @@ from torchvision import datasets, transforms from torch.utils.data import DataLoader -from vector_quantize_pytorch import LFQ +from vector_quantize_pytorch import LFQ, Sequential lr = 3e-4 train_iter = 1000 @@ -22,46 +22,31 @@ device = "cuda" if torch.cuda.is_available() else "cpu" -class LFQAutoEncoder(nn.Module): - def __init__( - self, - codebook_size, - **vq_kwargs - ): - super().__init__() - assert log2(codebook_size).is_integer() - quantize_dim = int(log2(codebook_size)) - - self.encode = nn.Sequential( - nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.GELU(), - nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - # In general norm layers are commonly used in Resnet-based encoder/decoders - # explicitly add one here with affine=False to avoid introducing new parameters - nn.GroupNorm(4, 32, affine=False), - nn.Conv2d(32, quantize_dim, kernel_size=1), - ) - - self.quantize = LFQ(dim=quantize_dim, **vq_kwargs) - - self.decode = nn.Sequential( - nn.Conv2d(quantize_dim, 32, kernel_size=3, stride=1, padding=1), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), - nn.GELU(), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), - ) - return - - def forward(self, x): - x = self.encode(x) - x, indices, entropy_aux_loss = self.quantize(x) - x = self.decode(x) - return x.clamp(-1, 1), indices, entropy_aux_loss - +def LFQAutoEncoder( + codebook_size, + **vq_kwargs +): + assert log2(codebook_size).is_integer() + quantize_dim = int(log2(codebook_size)) + + return Sequential( + nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.GELU(), + nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), + nn.MaxPool2d(kernel_size=2, stride=2), + # In general norm layers are commonly used in Resnet-based encoder/decoders + # explicitly add one here with affine=False to avoid introducing new parameters + nn.GroupNorm(4, 32, affine=False), + nn.Conv2d(32, quantize_dim, kernel_size=1), + LFQ(dim=quantize_dim, **vq_kwargs), + nn.Conv2d(quantize_dim, 32, kernel_size=3, stride=1, padding=1), + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), + nn.GELU(), + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), + ) def train(model, train_loader, train_iterations=1000): def iterate_dataset(data_loader): @@ -78,6 +63,7 @@ def iterate_dataset(data_loader): opt.zero_grad() x, _ = next(iterate_dataset(train_loader)) out, indices, entropy_aux_loss = model(x) + out = out.clamp(-1., 1.) rec_loss = F.l1_loss(out, x) (rec_loss + entropy_aux_loss).backward() @@ -88,7 +74,6 @@ def iterate_dataset(data_loader): + f"entropy aux loss: {entropy_aux_loss.item():.3f} | " + f"active %: {indices.unique().numel() / codebook_size * 100:.3f}" ) - return transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] diff --git a/pyproject.toml b/pyproject.toml index 5834438..7d38ea2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vector-quantize-pytorch" -version = "1.19.3" +version = "1.19.4" description = "Vector Quantization - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/vector_quantize_pytorch/__init__.py b/vector_quantize_pytorch/__init__.py index 451a402..06a0a24 100644 --- a/vector_quantize_pytorch/__init__.py +++ b/vector_quantize_pytorch/__init__.py @@ -6,3 +6,5 @@ from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ from vector_quantize_pytorch.latent_quantization import LatentQuantize + +from vector_quantize_pytorch.utils import Sequential diff --git a/vector_quantize_pytorch/utils.py b/vector_quantize_pytorch/utils.py new file mode 100644 index 0000000..bdb4386 --- /dev/null +++ b/vector_quantize_pytorch/utils.py @@ -0,0 +1,57 @@ +import torch +from torch import nn +from torch.nn import Module, ModuleList + +# quantization + +from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize +from vector_quantize_pytorch.residual_vq import ResidualVQ, GroupedResidualVQ +from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer +from vector_quantize_pytorch.finite_scalar_quantization import FSQ +from vector_quantize_pytorch.lookup_free_quantization import LFQ +from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ +from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ +from vector_quantize_pytorch.latent_quantization import LatentQuantize + +QUANTIZE_KLASSES = ( + VectorQuantize, + ResidualVQ, + GroupedResidualVQ, + RandomProjectionQuantizer, + FSQ, + LFQ, + ResidualLFQ, + GroupedResidualLFQ, + ResidualFSQ, + GroupedResidualFSQ, + LatentQuantize +) + +# classes + +class Sequential(Module): + def __init__( + self, + *fns: Module + ): + super().__init__() + assert sum([int(isinstance(fn, QUANTIZE_KLASSES)) for fn in fns]) == 1, 'this special Sequential must contain exactly one quantizer' + + self.fns = ModuleList(fns) + + def forward( + self, + x, + **kwargs + ): + for fn in self.fns: + + if not isinstance(fn, QUANTIZE_KLASSES): + x = fn(x) + continue + + x, *rest = fn(x, **kwargs) + + output = (x, *rest) + + return output