diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index ff1898db2..6dda6bb32 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -406,9 +406,10 @@ def __sub__(self, other): def __truediv__(self, other): if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none: - output_tensor = self.value / other.tensor - output_scale = self.scale / other.scale - output_bit_width = self.bit_width - other.bit_width + output_tensor = self.value / other.tensor # Note, output tensor not guaranteed to pass self.is_valid() + max_int_denominator = 2 ** (other.bit_width - int(other.signed)) + output_scale = self.scale / (other.scale * max_int_denominator) + output_bit_width = self.bit_width + other.bit_width output_signed = self.signed or other.signed output_training = self.training or other.training if self.is_zero_zero_point(self) and self.is_zero_zero_point(other):