From 80f4e848bffb4ad6ba2d2d0bcb4318a92322bb67 Mon Sep 17 00:00:00 2001
From: Phil Wang <lucidrains@gmail.com>
Date: Mon, 11 Nov 2024 10:21:30 -0800
Subject: [PATCH] allow for an MLP implicitly parameterized codebook instead
 of just a single Linear, for SimVQ. seems to converge just fine with
 rotation trick

---
 pyproject.toml                    |  2 +-
 vector_quantize_pytorch/sim_vq.py | 13 ++++++++++++-
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 3a4ae17..c1f9753 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.20.1"
+version = "1.20.2"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/vector_quantize_pytorch/sim_vq.py b/vector_quantize_pytorch/sim_vq.py
index 29ad136..7e05f5e 100644
--- a/vector_quantize_pytorch/sim_vq.py
+++ b/vector_quantize_pytorch/sim_vq.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from typing import Callable
 
 import torch
@@ -38,6 +39,7 @@ def __init__(
         self,
         dim,
         codebook_size,
+        codebook_transform: Module | None = None,
         init_fn: Callable = identity,
         accept_image_fmap = False,
         rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
@@ -51,7 +53,11 @@ def __init__(
 
         # the codebook is actually implicit from a linear layer from frozen gaussian or uniform
 
-        self.codebook_to_codes = nn.Linear(dim, dim, bias = False)
+        if not exists(codebook_transform):
+            codebook_transform = nn.Linear(dim, dim, bias = False)
+
+        self.codebook_to_codes = codebook_transform
+
         self.register_buffer('codebook', codebook)
 
@@ -114,6 +120,11 @@ def forward(
 
     sim_vq = SimVQ(
         dim = 512,
+        codebook_transform = nn.Sequential(
+            nn.Linear(512, 1024),
+            nn.ReLU(),
+            nn.Linear(1024, 512)
+        ),
         codebook_size = 1024,
         accept_image_fmap = True
     )
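
For reference, a minimal usage sketch of the new codebook_transform argument outside the diff. It assumes SimVQ is exported from the vector_quantize_pytorch package top level and that, with accept_image_fmap = True, the module takes a (batch, dim, height, width) tensor and returns the quantized output, code indices, and commit loss, as in the example at the bottom of the patched file.

    import torch
    from torch import nn
    from vector_quantize_pytorch import SimVQ

    # the codebook is now implicitly parameterized by any module mapping dim -> dim,
    # here a small MLP instead of the previous single frozen-input nn.Linear
    sim_vq = SimVQ(
        dim = 512,
        codebook_size = 1024,
        codebook_transform = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512)
        ),
        accept_image_fmap = True
    )

    x = torch.randn(1, 512, 32, 32)              # (batch, dim, height, width) feature map
    quantized, indices, commit_loss = sim_vq(x)  # quantized keeps the input shape
    assert quantized.shape == x.shape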