From a5e06a4d770bb4584caed482b1a6351a1b861a7e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 17 Dec 2024 14:47:01 +0000 Subject: [PATCH] fix --- src/brevitas/export/inference/handler.py | 17 ++++++++++++++--- src/brevitas/proxy/parameter_quant.py | 2 +- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 911e2bd2d..c7fc21790 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -9,7 +9,6 @@ from torch import Tensor import torch.nn as nn -from brevitas import is_dynamo_compiling from brevitas.function.ops import max_float from brevitas.function.ops import max_int from brevitas.function.ops import min_int @@ -110,6 +109,10 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]: class GroupwiseIntInferenceHandler(IntInferencetHandler): handled_layer = GroupwiseActQuantProxyFromInjector + def __init__(self): + super().__init__() + self.skip_create_quant_tensor = False + def prepare_for_export(self, module): if module.is_quant_enabled: self.module_forward = module.fused_activation_quant_proxy @@ -117,7 +120,9 @@ def prepare_for_export(self, module): def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]: x, *other = self.module_forward(x) - if is_dynamo_compiling(): + + # If we skip quant tensor, we return the flattened version of the groupwise tensor + if self.skip_create_quant_tensor: start_dim = self.group_dim if self.group_dim != -1 else -2 x = x.flatten(start_dim, start_dim + 1) output_args = tuple([x] + list(other)) @@ -248,6 +253,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor]: class GroupwiseFloatInferenceHandler(FloatInferencetHandler): handled_layer = GroupwiseActFloatQuantProxyFromInjector + def __init__(self): + super().__init__() + self.skip_create_quant_tensor = False + def prepare_for_export(self, module: nn.Module): if module.is_quant_enabled: self.module_forward = module.fused_activation_quant_proxy @@ -255,7 +264,9 @@ def prepare_for_export(self, module: nn.Module): def forward(self, x: Tensor) -> Tuple[Tensor]: x, *other = self.module_forward(x) - if is_dynamo_compiling(): + + # If we skip quant tensor, we return the flattened version of the groupwise tensor + if self.skip_create_quant_tensor: start_dim = self.group_dim if self.group_dim != -1 else -2 x = x.flatten(start_dim, start_dim + 1) output_args = tuple([x] + list(other)) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index b57ae4716..5c4e447d4 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -133,7 +133,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: # - quantization flow if self.export_mode: out = self.export_handler(x) - if is_dynamo_compiling(): + if self.skip_create_quant_tensor: out = out[0] else: out = self.create_quant_tensor(out)