Able to turn off bias for FSQ as well; as in scalar quantization, the codebook is implicit in the previous projection.
lucidrains committed May 6, 2024
1 parent be995a7 commit 5ae60bd
Showing 1 changed file with 4 additions and 3 deletions.

vector_quantize_pytorch/finite_scalar_quantization.py
@@ -48,7 +48,8 @@ def __init__(
         keep_num_codebooks_dim: Optional[bool] = None,
         scale: Optional[float] = None,
         allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
-        channel_first: bool = False
+        channel_first: bool = False,
+        projection_has_bias: bool = True
     ):
         super().__init__()
         _levels = torch.tensor(levels, dtype=int32)
@@ -75,8 +76,8 @@ def __init__(
         self.channel_first = channel_first

         has_projections = self.dim != effective_codebook_dim
-        self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
-        self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
+        self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = projection_has_bias) if has_projections else nn.Identity()
+        self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = projection_has_bias) if has_projections else nn.Identity()

         self.has_projections = has_projections
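A minimal usage sketch of the new flag. This assumes the package is installed as `vector-quantize-pytorch`, that `levels` and `dim` are passed as in the existing FSQ constructor, and that the forward pass returns `(quantized, indices)`; the example values for `levels`, `dim`, and the tensor shapes are illustrative only.

```python
import torch
from vector_quantize_pytorch import FSQ

# When `dim` differs from the effective codebook dimension, FSQ creates
# project_in / project_out Linear layers. Setting projection_has_bias = False
# (added in this commit) drops their bias terms, since the implicit codebook
# is already determined by the projection weights.
quantizer = FSQ(
    levels = [8, 5, 5, 5],        # example quantization levels (assumed)
    dim = 256,                    # example model dimension, so projections are created (assumed)
    projection_has_bias = False
)

# nn.Linear(..., bias = False) stores no bias parameter
assert quantizer.project_in.bias is None

x = torch.randn(1, 1024, 256)     # (batch, sequence, dim)
quantized, indices = quantizer(x)
assert quantized.shape == x.shape
```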
