-
Notifications
You must be signed in to change notification settings - Fork 235
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
35c4c84
commit e7ff7d7
Showing
6 changed files
with
126 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[project] | ||
name = "vector-quantize-pytorch" | ||
version = "1.19.3" | ||
version = "1.19.4" | ||
description = "Vector Quantization - Pytorch" | ||
authors = [ | ||
{ name = "Phil Wang", email = "[email protected]" } | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import torch | ||
from torch import nn | ||
from torch.nn import Module, ModuleList | ||
|
||
# quantization | ||
|
||
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize | ||
from vector_quantize_pytorch.residual_vq import ResidualVQ, GroupedResidualVQ | ||
from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer | ||
from vector_quantize_pytorch.finite_scalar_quantization import FSQ | ||
from vector_quantize_pytorch.lookup_free_quantization import LFQ | ||
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ | ||
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ | ||
from vector_quantize_pytorch.latent_quantization import LatentQuantize | ||
|
||
# All quantizer classes exported by this package, gathered into one tuple
# so Sequential (below) can detect the single quantizer among its
# submodules with a single isinstance() check.
QUANTIZE_KLASSES = (
    VectorQuantize,
    ResidualVQ,
    GroupedResidualVQ,
    RandomProjectionQuantizer,
    FSQ,
    LFQ,
    ResidualLFQ,
    GroupedResidualLFQ,
    ResidualFSQ,
    GroupedResidualFSQ,
    LatentQuantize
)
|
||
# classes | ||
|
||
class Sequential(Module):
    """A Sequential container built around exactly one quantizer module.

    Behaves like ``nn.Sequential`` except that:
    - keyword arguments given to ``forward`` are routed only to the
      quantizer module, and
    - the quantizer's auxiliary outputs (indices, losses, ...) are kept
      and returned alongside the final transformed tensor.

    Args:
        *fns: the modules to apply in order; exactly one of them must be
            an instance of one of the classes in ``QUANTIZE_KLASSES``.

    Raises:
        AssertionError: if the number of quantizer modules is not exactly one.
    """

    def __init__(
        self,
        *fns: Module
    ):
        super().__init__()

        # exactly one quantizer is required, otherwise routing **kwargs and
        # collecting auxiliary outputs in forward() would be ambiguous.
        # sum() over a generator of bools counts matches without building a list
        num_quantizers = sum(isinstance(fn, QUANTIZE_KLASSES) for fn in fns)
        assert num_quantizers == 1, 'this special Sequential must contain exactly one quantizer'

        self.fns = ModuleList(fns)

    def forward(
        self,
        x,
        **kwargs
    ):
        """Apply every module in order, passing ``kwargs`` only to the quantizer.

        Returns:
            tuple: ``(x, *rest)`` where ``x`` is the output of the final
            module and ``rest`` holds the quantizer's auxiliary outputs
            (e.g. indices and auxiliary losses).
        """
        for fn in self.fns:
            # plain modules receive only the tensor
            if not isinstance(fn, QUANTIZE_KLASSES):
                x = fn(x)
                continue

            # the quantizer returns (x, *auxiliary_outputs); keep the
            # auxiliaries so they survive any subsequent plain modules
            x, *rest = fn(x, **kwargs)

        return (x, *rest)