diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index 28f1e1b5e..546fa8f8a 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -8,6 +8,7 @@ from brevitas.inject import BaseInjector as Injector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase from brevitas.quant_tensor import FloatQuantTensor +from brevitas.quant_tensor.base_quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIOFloat @@ -59,11 +60,11 @@ def is_fnuz(self): ) is None and self.exponent_bias() == 16 return is_fnuz_e4m3 or is_fnuz_e5m2 - def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]: + def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, FloatQuantTensor]: out = x if self.fused_activation_quant_proxy is not None: y = x - if isinstance(y, FloatQuantTensor): + if isinstance(y, QuantTensor): y = y.value if self.export_mode: diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 4ef93cad6..a89bc9abb 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -157,11 +157,11 @@ def zero_point(self, force_eval=True): def bit_width(self, force_eval=True): return self.retrieve_attribute('bit_width', force_eval) - def forward(self, x: Union[Tensor, IntQuantTensor]) -> Union[Tensor, IntQuantTensor]: + def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, IntQuantTensor]: out = x if self.fused_activation_quant_proxy is not None: y = x - if isinstance(y, IntQuantTensor): + if isinstance(y, QuantTensor): y = y.value if self.export_mode: