Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
Giuseppe5 committed Dec 17, 2024
1 parent b8c4877 commit a5e06a4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
17 changes: 14 additions & 3 deletions src/brevitas/export/inference/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torch import Tensor
import torch.nn as nn

from brevitas import is_dynamo_compiling
from brevitas.function.ops import max_float
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
Expand Down Expand Up @@ -110,14 +109,20 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
class GroupwiseIntInferenceHandler(IntInferencetHandler):
handled_layer = GroupwiseActQuantProxyFromInjector

# Constructor: runs the parent handler's initialisation, then creates the
# `skip_create_quant_tensor` flag with a default of False.
# The `forward` method of this handler checks this flag before flattening
# the groupwise output.
# NOTE(review): nothing visible in this hunk ever sets the flag to True;
# presumably the export/inference setup code toggles it -- confirm.
def __init__(self):
super().__init__()
self.skip_create_quant_tensor = False

# Capture from `module` the pieces that `forward` uses later:
#   - `module.fused_activation_quant_proxy`, stored as `self.module_forward`
#     (this is the callable `forward` invokes on its input);
#   - `module.group_dim`, stored as `self.group_dim` (read by `forward` to
#     pick the dimensions to flatten).
# The capture is gated on `module.is_quant_enabled`.
# NOTE(review): indentation is lost in this diff view, so it cannot be
# determined here whether the `group_dim` assignment is also inside the
# `if`; confirm against the real file.
# NOTE(review): `module` is presumably an instance of `handled_layer` -- confirm.
def prepare_for_export(self, module):
if module.is_quant_enabled:
self.module_forward = module.fused_activation_quant_proxy
self.group_dim = module.group_dim

def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
x, *other = self.module_forward(x)
if is_dynamo_compiling():

# If we skip quant tensor, we return the flattened version of the groupwise tensor
if self.skip_create_quant_tensor:
start_dim = self.group_dim if self.group_dim != -1 else -2
x = x.flatten(start_dim, start_dim + 1)
output_args = tuple([x] + list(other))
Expand Down Expand Up @@ -248,14 +253,20 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
class GroupwiseFloatInferenceHandler(FloatInferencetHandler):
handled_layer = GroupwiseActFloatQuantProxyFromInjector

# Constructor: runs the parent handler's initialisation, then creates the
# `skip_create_quant_tensor` flag with a default of False.
# The `forward` method of this handler checks this flag before flattening
# the groupwise output.
# NOTE(review): nothing visible in this hunk ever sets the flag to True;
# presumably the export/inference setup code toggles it -- confirm.
def __init__(self):
super().__init__()
self.skip_create_quant_tensor = False

# Capture from `module` the pieces that `forward` uses later:
#   - `module.fused_activation_quant_proxy`, stored as `self.module_forward`
#     (this is the callable `forward` invokes on its input);
#   - `module.group_dim`, stored as `self.group_dim` (read by `forward` to
#     pick the dimensions to flatten).
# The capture is gated on `module.is_quant_enabled`.
# NOTE(review): indentation is lost in this diff view, so it cannot be
# determined here whether the `group_dim` assignment is also inside the
# `if`; confirm against the real file.
# NOTE(review): `module` is annotated as `nn.Module`; presumably it is an
# instance of `handled_layer` in practice -- confirm.
def prepare_for_export(self, module: nn.Module):
if module.is_quant_enabled:
self.module_forward = module.fused_activation_quant_proxy
self.group_dim = module.group_dim

def forward(self, x: Tensor) -> Tuple[Tensor]:
x, *other = self.module_forward(x)
if is_dynamo_compiling():

# If we skip quant tensor, we return the flattened version of the groupwise tensor
if self.skip_create_quant_tensor:
start_dim = self.group_dim if self.group_dim != -1 else -2
x = x.flatten(start_dim, start_dim + 1)
output_args = tuple([x] + list(other))
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
# - quantization flow
if self.export_mode:
out = self.export_handler(x)
if is_dynamo_compiling():
if self.skip_create_quant_tensor:
out = out[0]
else:
out = self.create_quant_tensor(out)
Expand Down

0 comments on commit a5e06a4

Please sign in to comment.