
Commit be92f79

address #144
lucidrains committed Jul 6, 2024
1 parent 3505761 commit be92f79
Showing 2 changed files with 86 additions and 71 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.14.46"
version = "1.15.0"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
155 changes: 85 additions & 70 deletions vector_quantize_pytorch/lookup_free_quantization.py
@@ -9,6 +9,7 @@
from math import log2, ceil
from functools import partial, cache
from collections import namedtuple
from contextlib import nullcontext
import torch.distributed as dist

import torch
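
The nullcontext import added above exists to pair with autocast(enabled = False) later in forward: when force_quantization_f32 is on, quantization runs with autocast disabled, otherwise a no-op context is used. A minimal sketch of that pattern, assuming autocast comes from torch.cuda.amp (the exact import in this file is not shown in the diff):

    # sketch of the conditional precision context, not the library's exact code
    from contextlib import nullcontext
    from functools import partial

    import torch
    from torch.cuda.amp import autocast  # assumption: source of the `autocast` used in this file

    def quantize_maybe_f32(x, force_f32 = True):
        # disable autocast for the quantization step, or do nothing
        context = partial(autocast, enabled = False) if force_f32 else nullcontext

        with context():
            if force_f32:
                orig_dtype = x.dtype
                x = x.float()  # upcast so thresholding happens in full precision

            quantized = torch.where(x > 0, torch.ones_like(x), -torch.ones_like(x))

            if force_f32:
                quantized = quantized.type(orig_dtype)

        return quantized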
@@ -112,7 +113,8 @@ def __init__(
channel_first = None,
experimental_softplus_entropy_loss = False,
entropy_loss_offset = 5., # how much to shift the loss before softplus
spherical = False # from https://arxiv.org/abs/2406.07548
spherical = False, # from https://arxiv.org/abs/2406.07548
force_quantization_f32 = True # will force the quantization step to be full precision
):
super().__init__()
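
A hedged usage sketch of the new constructor flag; the dim and codebook_size values and the surrounding autocast region are illustrative, not taken from this commit. The point is that the quantizer can live inside a mixed-precision forward pass while its quantization step stays in float32:

    # illustrative only: constructor values and training setup are assumptions
    import torch
    from vector_quantize_pytorch import LFQ

    quantizer = LFQ(
        dim = 16,
        codebook_size = 2 ** 16,
        force_quantization_f32 = True  # the option added in this commit
    )

    x = torch.randn(1, 1024, 16)

    with torch.autocast('cuda', dtype = torch.float16, enabled = torch.cuda.is_available()):
        quantized, indices, aux_loss = quantizer(x)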

@@ -192,6 +194,10 @@ def __init__(
self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
self.register_buffer('zero', torch.tensor(0.), persistent = False)

# whether to force quantization step to be f32

self.force_quantization_f32 = force_quantization_f32

# codes

all_codes = torch.arange(codebook_size)
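
For context on the buffers above: mask holds descending powers of two, so the index computed later in forward is simply each code's sign pattern read as a binary number. A small worked example with an assumed codebook_dim = 4:

    # illustrative only: codebook_dim = 4 is an assumption for the example
    import torch

    codebook_dim = 4
    mask = 2 ** torch.arange(codebook_dim - 1, -1, -1)  # tensor([8, 4, 2, 1])

    x = torch.tensor([0.3, -1.2, 0.7, -0.1])            # features for one code
    bits = (x > 0).int()                                 # tensor([1, 0, 1, 0])
    index = (bits * mask).sum()                          # 1*8 + 0*4 + 1*2 + 0*1 = 10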
@@ -241,7 +247,6 @@ def indices_to_codes(

return codes

@autocast(enabled = False)
def forward(
self,
x,
@@ -257,9 +262,6 @@ def forward(
c - number of codebook dim
"""

orig_dtype = x.dtype
x = x.float()

is_img_or_video = x.ndim >= 4
should_transpose = default(self.channel_first, is_img_or_video)

@@ -271,8 +273,7 @@ def forward(

assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

with autocast():
x = self.project_in(x)
x = self.project_in(x)

# maybe soft clamp

@@ -288,104 +289,122 @@ def forward(

x = self.maybe_l2norm(x)

# quantize by eq 3.
# whether to force quantization step to be full precision or not

original_input = x
force_f32 = self.force_quantization_f32

codebook_value = torch.ones_like(x) * self.codebook_scale
quantized = torch.where(x > 0, codebook_value, -codebook_value)
quantization_context = partial(autocast, enabled = False) if force_f32 else nullcontext

# calculate indices
with quantization_context():

indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
if force_f32:
orig_dtype = x.dtype
x = x.float()

# maybe l2norm
# quantize by eq 3.

quantized = self.maybe_l2norm(quantized)
original_input = x

# use straight-through gradients (optionally with custom activation fn) if training
codebook_value = torch.ones_like(x) * self.codebook_scale
quantized = torch.where(x > 0, codebook_value, -codebook_value)

if self.training:
x = self.activation(x)
x = x + (quantized - x).detach()
else:
x = quantized
# calculate indices

# entropy aux loss
indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')

if self.training:
codebook = self.codebook.float()
# maybe l2norm

codebook = self.maybe_l2norm(codebook)
quantized = self.maybe_l2norm(quantized)

# the same as euclidean distance up to a constant
distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
# use straight-through gradients (optionally with custom activation fn) if training

prob = (-distance * inv_temperature).softmax(dim = -1)
if self.training:
x = self.activation(x)
x = x + (quantized - x).detach()
else:
x = quantized

# account for mask
# entropy aux loss

if exists(mask):
prob = prob[mask]
else:
prob = rearrange(prob, 'b n ... -> (b n) ...')
if self.training:

# whether to only use a fraction of probs, for reducing memory
if force_f32:
codebook = self.codebook.float()

if self.frac_per_sample_entropy < 1.:
num_tokens = prob.shape[0]
num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
per_sample_probs = prob[rand_mask]
else:
per_sample_probs = prob
codebook = self.maybe_l2norm(codebook)

# calculate per sample entropy
# the same as euclidean distance up to a constant
distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)

per_sample_entropy = entropy(per_sample_probs).mean()
prob = (-distance * inv_temperature).softmax(dim = -1)

# distribution over all available tokens in the batch
# account for mask

avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')
if exists(mask):
prob = prob[mask]
else:
prob = rearrange(prob, 'b n ... -> (b n) ...')

avg_prob = maybe_distributed_mean(avg_prob)
# whether to only use a fraction of probs, for reducing memory

codebook_entropy = entropy(avg_prob).mean()
if self.frac_per_sample_entropy < 1.:
num_tokens = prob.shape[0]
num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
per_sample_probs = prob[rand_mask]
else:
per_sample_probs = prob

# 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
# 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
# calculate per sample entropy

entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
else:
# if not training, just return dummy 0
entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
per_sample_entropy = entropy(per_sample_probs).mean()

# whether to make the entropy loss positive or not through a (shifted) softplus
# distribution over all available tokens in the batch

if self.training and self.experimental_softplus_entropy_loss:
entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)
avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')

# commit loss
avg_prob = maybe_distributed_mean(avg_prob)

if self.training and self.commitment_loss_weight > 0.:
codebook_entropy = entropy(avg_prob).mean()

commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
# 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
# 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch

if exists(mask):
commit_loss = commit_loss[mask]
entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
else:
# if not training, just return dummy 0
entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero

commit_loss = commit_loss.mean()
else:
commit_loss = self.zero
# whether to make the entropy loss positive or not through a (shifted) softplus

if self.training and self.experimental_softplus_entropy_loss:
entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)

# commit loss

if self.training and self.commitment_loss_weight > 0.:

commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')

if exists(mask):
commit_loss = commit_loss[mask]

commit_loss = commit_loss.mean()
else:
commit_loss = self.zero

# input back to original dtype if needed

if force_f32:
x = x.type(orig_dtype)

# merge back codebook dim

x = rearrange(x, 'b n c d -> b n (c d)')

# project out to feature dimension if needed

with autocast():
x = self.project_out(x)
x = self.project_out(x)

# reconstitute image or video dimensions

@@ -404,10 +423,6 @@ def forward(

aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight

# restore original dtype

x = x.type(orig_dtype)

# returns

ret = Return(x, indices, aux_loss)
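
Two notes on the quantization logic moved under the new context manager. First, the entropy auxiliary loss: per-sample entropy is pushed down so each token commits to a code, while batch-level codebook entropy is pushed up so usage spreads over the codebook. A self-contained sketch with made-up probabilities (the entropy helper below is local to the example, not the library's):

    # illustrative numbers only
    import torch

    def entropy(prob, eps = 1e-9):
        return -(prob * (prob + eps).log()).sum(dim = -1)

    # per-token distributions over an assumed codebook of 4 codes
    prob = torch.tensor([
        [0.90, 0.05, 0.03, 0.02],
        [0.10, 0.80, 0.05, 0.05],
    ])

    per_sample_entropy = entropy(prob).mean()       # low when each token is confident
    codebook_entropy = entropy(prob.mean(dim = 0))  # high when usage spreads over all codes

    diversity_gamma = 1.                             # assumed value
    entropy_aux_loss = per_sample_entropy - diversity_gamma * codebook_entropy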
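Second, the straight-through binary quantization step itself, which is what force_quantization_f32 = True keeps in float32 regardless of the ambient autocast dtype. A minimal sketch with an assumed codebook_scale of 1., not the library's exact code:

    import torch

    def lfq_quantize(x, codebook_scale = 1., training = True):
        # x: (..., codebook_dim); thresholded at zero, as in the "quantize by eq 3." step
        codebook_value = torch.ones_like(x) * codebook_scale
        quantized = torch.where(x > 0, codebook_value, -codebook_value)

        if training:
            # straight-through estimator: forward uses `quantized`,
            # backward passes gradients through to `x` unchanged
            return x + (quantized - x).detach()

        return quantized

    x = torch.randn(2, 8, requires_grad = True)
    out = lfq_quantize(x)
    out.sum().backward()  # gradients reach x as if quantization were the identity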
