Skip to content

Commit

Permalink
Fix (proxy/runtime_quant): correct handling of mixed type quantization (
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Jul 30, 2024
1 parent 05c80f5 commit b889baa
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/brevitas/proxy/float_runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from brevitas.inject import BaseInjector as Injector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase
from brevitas.quant_tensor import FloatQuantTensor
from brevitas.quant_tensor.base_quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIOFloat


Expand Down Expand Up @@ -59,11 +60,11 @@ def is_fnuz(self):
) is None and self.exponent_bias() == 16
return is_fnuz_e4m3 or is_fnuz_e5m2

def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]:
def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, FloatQuantTensor]:
out = x
if self.fused_activation_quant_proxy is not None:
y = x
if isinstance(y, FloatQuantTensor):
if isinstance(y, QuantTensor):
y = y.value

if self.export_mode:
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,11 @@ def zero_point(self, force_eval=True):
def bit_width(self, force_eval=True):
return self.retrieve_attribute('bit_width', force_eval)

def forward(self, x: Union[Tensor, IntQuantTensor]) -> Union[Tensor, IntQuantTensor]:
def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, IntQuantTensor]:
out = x
if self.fused_activation_quant_proxy is not None:
y = x
if isinstance(y, IntQuantTensor):
if isinstance(y, QuantTensor):
y = y.value

if self.export_mode:
Expand Down

0 comments on commit b889baa

Please sign in to comment.