diff --git a/pyproject.toml b/pyproject.toml
index a4351ec..4de811d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.20"
+version = "1.14.22"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/vector_quantize_pytorch/lookup_free_quantization.py b/vector_quantize_pytorch/lookup_free_quantization.py
index 67ab18b..b6797bf 100644
--- a/vector_quantize_pytorch/lookup_free_quantization.py
+++ b/vector_quantize_pytorch/lookup_free_quantization.py
@@ -152,7 +152,7 @@ def __init__(
         # whether to soft clamp the input value from -value to value

         self.soft_clamp_input_value = soft_clamp_input_value
-        assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= 1.
+        assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale

         # for no auxiliary loss, during inference

diff --git a/vector_quantize_pytorch/residual_lfq.py b/vector_quantize_pytorch/residual_lfq.py
index 03bd89b..21b2fde 100644
--- a/vector_quantize_pytorch/residual_lfq.py
+++ b/vector_quantize_pytorch/residual_lfq.py
@@ -39,6 +39,7 @@ def __init__(
         quantize_dropout = False,
         quantize_dropout_cutoff_index = 0,
         quantize_dropout_multiple_of = 1,
+        soft_clamp_input_value = None,
         **kwargs
     ):
         super().__init__()
@@ -59,11 +60,15 @@ def __init__(
             lfq = LFQ(
                 dim = codebook_dim,
                 codebook_scale = codebook_scale,
+                soft_clamp_input_value = soft_clamp_input_value,
                 **kwargs
             )

             self.layers.append(lfq)

+            if exists(soft_clamp_input_value):
+                soft_clamp_input_value *= 0.5
+
             assert all([not lfq.has_projections for lfq in self.layers])

         self.quantize_dropout = quantize_dropout and num_quantizers > 1
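
The residual_lfq.py change threads soft_clamp_input_value into each LFQ layer and halves it after every quantizer, so the relaxed per-layer assert (soft_clamp_input_value >= codebook_scale instead of >= 1.) keeps holding as the per-layer codebook scale shrinks. Below is a minimal sketch of that bookkeeping, assuming the residual stack halves codebook_scale per layer, as the lockstep halving in the diff suggests; layer_configs and the dict layout are illustrative only, not the library's API.

# Hypothetical sketch (not the library's ResidualLFQ/LFQ code) of keeping the
# soft clamp bound valid for every quantizer in a residual stack.

def layer_configs(num_quantizers, soft_clamp_input_value = None):
    configs = []
    codebook_scale = 1.

    for _ in range(num_quantizers):
        # mirrors the relaxed check: compare against codebook_scale, not 1.
        assert soft_clamp_input_value is None or soft_clamp_input_value >= codebook_scale

        configs.append(dict(
            codebook_scale = codebook_scale,
            soft_clamp_input_value = soft_clamp_input_value
        ))

        # later quantizers act on ever smaller residuals, so both the codebook
        # scale and the clamp bound are halved before the next layer
        codebook_scale *= 0.5
        if soft_clamp_input_value is not None:
            soft_clamp_input_value *= 0.5

    return configs

print(layer_configs(3, soft_clamp_input_value = 1.))
# [{'codebook_scale': 1.0,  'soft_clamp_input_value': 1.0},
#  {'codebook_scale': 0.5,  'soft_clamp_input_value': 0.5},
#  {'codebook_scale': 0.25, 'soft_clamp_input_value': 0.25}]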