From 3bb00f57d8343925c679af2413afaf0d024601de Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 11 Nov 2024 11:55:12 -0800
Subject: [PATCH] seems to work even better with a one layer mlp

---
 examples/autoencoder_sim_vq.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/examples/autoencoder_sim_vq.py b/examples/autoencoder_sim_vq.py
index bd1c13b..3681ba9 100644
--- a/examples/autoencoder_sim_vq.py
+++ b/examples/autoencoder_sim_vq.py
@@ -14,7 +14,10 @@
 train_iter = 10000
 num_codes = 256
 seed = 1234
-rotation_trick = True
+
+rotation_trick = True # rotation trick instead of straight-through
+use_mlp = True        # use a one layer mlp with relu instead of linear
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 def SimVQAutoEncoder(**vq_kwargs):
@@ -77,7 +80,12 @@ def iterate_dataset(data_loader):
 
 model = SimVQAutoEncoder(
     codebook_size = num_codes,
-    rotation_trick = rotation_trick
+    rotation_trick = rotation_trick,
+    codebook_transform = nn.Sequential(
+        nn.Linear(32, 128),
+        nn.ReLU(),
+        nn.Linear(128, 32),
+    ) if use_mlp else None
 ).to(device)
 
 opt = torch.optim.AdamW(model.parameters(), lr=lr)
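
For context, a minimal standalone sketch of what this patch wires up, assuming SimVQAutoEncoder forwards its **vq_kwargs (codebook_size, rotation_trick, codebook_transform) to the SimVQ layer of vector-quantize-pytorch, and that SimVQ returns (quantized, indices, commit_loss) on the forward pass as in that library's usage examples; this is illustrative, not part of the patch:

    # illustrative sketch, not part of the patch: SimVQ with a one layer
    # mlp as the learned codebook transform, mirroring the change above
    import torch
    from torch import nn
    from vector_quantize_pytorch import SimVQ

    dim = 32  # latent width, matching the nn.Linear(32, ...) sizes in the patch

    sim_vq = SimVQ(
        dim = dim,
        codebook_size = 256,                # num_codes in the example script
        rotation_trick = True,              # rotation trick instead of straight-through
        codebook_transform = nn.Sequential( # one layer mlp with relu instead of linear
            nn.Linear(dim, 128),
            nn.ReLU(),
            nn.Linear(128, dim),
        )
    )

    x = torch.randn(1, 1024, dim)
    quantized, indices, commit_loss = sim_vq(x)  # assumed return signature
    assert quantized.shape == x.shape

Passing codebook_transform = None (the use_mlp = False branch in the patch) should fall back to the plain linear codebook projection the example used before this change.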