diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index b81504c40..c79f28fdb 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -4,6 +4,7 @@ from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase from brevitas.quant_tensor import GroupwiseFloatQuantTensor +from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat @@ -17,13 +18,11 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size - def forward( - self, x: Union[Tensor, - GroupwiseFloatQuantTensor]) -> Union[Tensor, GroupwiseFloatQuantTensor]: + def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseFloatQuantTensor]: out = x if self.fused_activation_quant_proxy is not None: y = x - if isinstance(y, GroupwiseFloatQuantTensor): + if isinstance(y, QuantTensor): y = y.value if self.export_mode: