Commit

start leveraging einx get_at for extra clarity
lucidrains committed Jan 20, 2024
1 parent 5f7851b commit 60de331
Showing 3 changed files with 12 additions and 17 deletions.
3 changes: 2 additions & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'vector_quantize_pytorch',
packages = find_packages(),
-version = '1.12.12',
+version = '1.12.14',
license='MIT',
description = 'Vector Quantization - Pytorch',
long_description_content_type = 'text/markdown',
@@ -18,6 +18,7 @@
],
install_requires=[
'einops>=0.7.0',
+'einx',
'torch'
],
classifiers=[
13 changes: 5 additions & 8 deletions vector_quantize_pytorch/residual_lfq.py
@@ -12,6 +12,8 @@

from einops import rearrange, repeat, reduce, pack, unpack

+from einx import get_at
+
# helper functions

def exists(val):
@@ -92,17 +94,12 @@ def get_codes_from_indices(self, indices):
assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)

-# get ready for gathering
-
-codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch)
-gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1])
-
# take care of quantizer dropout

-mask = gather_indices == -1.
-gather_indices = gather_indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later
+mask = indices == -1.
+indices = indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later

-all_codes = codebooks.gather(2, gather_indices) # gather all codes
+all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)

# mask out any codes that were dropout-ed

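The hunk above swaps a manual repeat-and-gather for a single einx.get_at call: the bracketed [c] axis in the pattern 'q [c] d, b n q -> q b n d' marks the codebook axis being indexed, so the broadcasting that repeat and gather had to spell out happens inside the pattern. A minimal standalone sketch of the equivalence, using toy shapes (the values of q, c, d, b, n are illustrative, not from the repository):

import torch
from einops import repeat
from einx import get_at

q, c, d, b, n = 4, 16, 8, 2, 10  # quantizers, codes per codebook, code dim, batch, sequence (toy values)
codebooks = torch.randn(q, c, d)
indices = torch.randint(0, c, (b, n, q))

# old approach: broadcast codebooks and indices to a shared shape, then gather along the code axis
codebooks_expanded = repeat(codebooks, 'q c d -> q b c d', b = b)
gather_indices = repeat(indices, 'b n q -> q b n d', d = d)
gathered = codebooks_expanded.gather(2, gather_indices)

# new approach: the pattern indexes the bracketed [c] axis directly
fetched = get_at('q [c] d, b n q -> q b n d', codebooks, indices)

assert torch.allclose(gathered, fetched)

Both produce a (q, b, n, d) tensor whose entry at (qi, bi, ni) is codebooks[qi, indices[bi, ni, qi]]; get_at states that lookup in one expression, which is the clarity the commit message refers to. The identical rewrite is applied to residual_vq.py below.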
13 changes: 5 additions & 8 deletions vector_quantize_pytorch/residual_vq.py
@@ -10,6 +10,8 @@

from einops import rearrange, repeat, reduce, pack, unpack

+from einx import get_at
+
# helper functions

def exists(val):
@@ -94,17 +96,12 @@ def get_codes_from_indices(self, indices):
assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)

-# get ready for gathering
-
-codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch)
-gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1])
-
# take care of quantizer dropout

-mask = gather_indices == -1.
-gather_indices = gather_indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later
+mask = indices == -1.
+indices = indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later

-all_codes = codebooks.gather(2, gather_indices) # gather all codes
+all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)

# mask out any codes that were dropout-ed

