Commit

take care of a small pain point
lucidrains committed Nov 8, 2024
1 parent 35c4c84 commit e7ff7d7
Showing 6 changed files with 126 additions and 110 deletions.
55 changes: 20 additions & 35 deletions examples/autoencoder.py
@@ -8,45 +8,29 @@
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from vector_quantize_pytorch import VectorQuantize

from vector_quantize_pytorch import VectorQuantize, Sequential

lr = 3e-4
train_iter = 1000
num_codes = 256
seed = 1234
rotation_trick = True
device = "cuda" if torch.cuda.is_available() else "cpu"


class SimpleVQAutoEncoder(nn.Module):
    def __init__(self, **vq_kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.GELU(),
                nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                VectorQuantize(dim=32, accept_image_fmap = True, **vq_kwargs),
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
                nn.GELU(),
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
            ]
        )
        return

    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, VectorQuantize):
                x, indices, commit_loss = layer(x)
            else:
                x = layer(x)

        return x.clamp(-1, 1), indices, commit_loss

def SimpleVQAutoEncoder(**vq_kwargs):
    return Sequential(
        nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.GELU(),
        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        VectorQuantize(dim=32, accept_image_fmap = True, **vq_kwargs),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
        nn.GELU(),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
    )

def train(model, train_loader, train_iterations=1000, alpha=10):
    def iterate_dataset(data_loader):
@@ -62,7 +46,10 @@ def iterate_dataset(data_loader):
    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(iterate_dataset(train_loader))

        out, indices, cmt_loss = model(x)
        out = out.clamp(-1., 1.)

        rec_loss = (out - x).abs().mean()
        (rec_loss + alpha * cmt_loss).backward()

@@ -72,8 +59,6 @@ def iterate_dataset(data_loader):
+ f"cmt loss: {cmt_loss.item():.3f} | "
+ f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
)
return


transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
@@ -91,7 +76,7 @@ def iterate_dataset(data_loader):

model = SimpleVQAutoEncoder(
    codebook_size=num_codes,
    rotation_trick=True
    rotation_trick=rotation_trick
).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=lr)
51 changes: 19 additions & 32 deletions examples/autoencoder_fsq.py
@@ -9,7 +9,7 @@
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from vector_quantize_pytorch import FSQ
from vector_quantize_pytorch import FSQ, Sequential


lr = 3e-4
@@ -20,36 +20,22 @@
device = "cuda" if torch.cuda.is_available() else "cpu"


class SimpleFSQAutoEncoder(nn.Module):
    def __init__(self, levels: list[int]):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.GELU(),
                nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, len(levels), kernel_size=1),
                FSQ(levels),
                nn.Conv2d(len(levels), 32, kernel_size=3, stride=1, padding=1),
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
                nn.GELU(),
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
            ]
        )
        return

    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, FSQ):
                x, indices = layer(x)
            else:
                x = layer(x)

        return x.clamp(-1, 1), indices
def SimpleFSQAutoEncoder(levels: list[int]):
    return Sequential(
        nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.GELU(),
        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(32, len(levels), kernel_size=1),
        FSQ(levels),
        nn.Conv2d(len(levels), 32, kernel_size=3, stride=1, padding=1),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
        nn.GELU(),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
    )


def train(model, train_loader, train_iterations=1000):
@@ -67,6 +53,8 @@ def iterate_dataset(data_loader):
        opt.zero_grad()
        x, _ = next(iterate_dataset(train_loader))
        out, indices = model(x)
        out = out.clamp(-1., 1.)

        rec_loss = (out - x).abs().mean()
        rec_loss.backward()

@@ -75,7 +63,6 @@ def iterate_dataset(data_loader):
f"rec loss: {rec_loss.item():.3f} | "
+ f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
)
return


transform = transforms.Compose(
69 changes: 27 additions & 42 deletions examples/autoencoder_lfq.py
@@ -10,7 +10,7 @@
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from vector_quantize_pytorch import LFQ
from vector_quantize_pytorch import LFQ, Sequential

lr = 3e-4
train_iter = 1000
@@ -22,46 +22,31 @@

device = "cuda" if torch.cuda.is_available() else "cpu"

class LFQAutoEncoder(nn.Module):
    def __init__(
        self,
        codebook_size,
        **vq_kwargs
    ):
        super().__init__()
        assert log2(codebook_size).is_integer()
        quantize_dim = int(log2(codebook_size))

        self.encode = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # In general norm layers are commonly used in Resnet-based encoder/decoders
            # explicitly add one here with affine=False to avoid introducing new parameters
            nn.GroupNorm(4, 32, affine=False),
            nn.Conv2d(32, quantize_dim, kernel_size=1),
        )

        self.quantize = LFQ(dim=quantize_dim, **vq_kwargs)

        self.decode = nn.Sequential(
            nn.Conv2d(quantize_dim, 32, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
        )
        return

    def forward(self, x):
        x = self.encode(x)
        x, indices, entropy_aux_loss = self.quantize(x)
        x = self.decode(x)
        return x.clamp(-1, 1), indices, entropy_aux_loss

def LFQAutoEncoder(
    codebook_size,
    **vq_kwargs
):
    assert log2(codebook_size).is_integer()
    quantize_dim = int(log2(codebook_size))

    return Sequential(
        nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.GELU(),
        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        # In general norm layers are commonly used in Resnet-based encoder/decoders
        # explicitly add one here with affine=False to avoid introducing new parameters
        nn.GroupNorm(4, 32, affine=False),
        nn.Conv2d(32, quantize_dim, kernel_size=1),
        LFQ(dim=quantize_dim, **vq_kwargs),
        nn.Conv2d(quantize_dim, 32, kernel_size=3, stride=1, padding=1),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
        nn.GELU(),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
    )

def train(model, train_loader, train_iterations=1000):
    def iterate_dataset(data_loader):
@@ -78,6 +63,7 @@ def iterate_dataset(data_loader):
        opt.zero_grad()
        x, _ = next(iterate_dataset(train_loader))
        out, indices, entropy_aux_loss = model(x)
        out = out.clamp(-1., 1.)

        rec_loss = F.l1_loss(out, x)
        (rec_loss + entropy_aux_loss).backward()
@@ -88,7 +74,6 @@ def iterate_dataset(data_loader):
+ f"entropy aux loss: {entropy_aux_loss.item():.3f} | "
+ f"active %: {indices.unique().numel() / codebook_size * 100:.3f}"
)
return

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.19.3"
version = "1.19.4"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
2 changes: 2 additions & 0 deletions vector_quantize_pytorch/__init__.py
@@ -6,3 +6,5 @@
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ
from vector_quantize_pytorch.latent_quantization import LatentQuantize

from vector_quantize_pytorch.utils import Sequential
57 changes: 57 additions & 0 deletions vector_quantize_pytorch/utils.py
@@ -0,0 +1,57 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList

# quantization

from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
from vector_quantize_pytorch.residual_vq import ResidualVQ, GroupedResidualVQ
from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
from vector_quantize_pytorch.finite_scalar_quantization import FSQ
from vector_quantize_pytorch.lookup_free_quantization import LFQ
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ
from vector_quantize_pytorch.latent_quantization import LatentQuantize

QUANTIZE_KLASSES = (
    VectorQuantize,
    ResidualVQ,
    GroupedResidualVQ,
    RandomProjectionQuantizer,
    FSQ,
    LFQ,
    ResidualLFQ,
    GroupedResidualLFQ,
    ResidualFSQ,
    GroupedResidualFSQ,
    LatentQuantize
)

# classes

class Sequential(Module):
    def __init__(
        self,
        *fns: Module
    ):
        super().__init__()
        assert sum([int(isinstance(fn, QUANTIZE_KLASSES)) for fn in fns]) == 1, 'this special Sequential must contain exactly one quantizer'

        self.fns = ModuleList(fns)

    def forward(
        self,
        x,
        **kwargs
    ):
        for fn in self.fns:

            if not isinstance(fn, QUANTIZE_KLASSES):
                x = fn(x)
                continue

            x, *rest = fn(x, **kwargs)

        output = (x, *rest)

        return output
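
For context, a minimal usage sketch of the new Sequential helper (the toy model below is hypothetical; the import path and the VectorQuantize arguments mirror the updated examples/autoencoder.py above). Non-quantizer layers receive only the tensor, the single quantizer layer receives any extra keyword arguments, and its auxiliary returns ride along with the final output:

import torch
from torch import nn
from vector_quantize_pytorch import VectorQuantize, Sequential

# hypothetical toy model: conv encoder -> vector quantizer -> conv decoder
model = Sequential(
    nn.Conv2d(1, 32, kernel_size=3, padding=1),
    VectorQuantize(dim=32, codebook_size=256, accept_image_fmap=True),
    nn.Conv2d(32, 1, kernel_size=3, padding=1),
)

x = torch.randn(2, 1, 28, 28)

# the quantizer's extra outputs (indices, commit loss) are returned after the
# decoded tensor, so the caller no longer needs to special-case the quantizer layer
recon, indices, commit_loss = model(x)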
