From 4d62dc9e8ff9a989784aee9e7f740c5f6f8a6b17 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Fri, 12 Jan 2024 17:59:13 +0000
Subject: [PATCH] More test fixes. No export

---
 src/brevitas/core/quant/int_base.py |  6 +++---
 src/brevitas/nn/mixin/parameter.py  |  6 +++++-
 src/brevitas/nn/quant_rnn.py        |  4 ++--
 tests/brevitas/fx/test_tracer.py    | 18 +++++++++++++++---
 4 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/src/brevitas/core/quant/int_base.py b/src/brevitas/core/quant/int_base.py
index 962c20b3b..6753ab5db 100644
--- a/src/brevitas/core/quant/int_base.py
+++ b/src/brevitas/core/quant/int_base.py
@@ -164,8 +164,8 @@ def forward(
             zero_point: Tensor,
             bit_width: Tensor,
             x: Tensor) -> Tensor:
-        y_int = self.to_int(pre_scale, pre_zero_point, bit_width, x)
-        y = y_int - zero_point
-        y = y * scale
+        y = self.to_int(pre_scale, pre_zero_point, bit_width, x)
+        # y = y_int - zero_point
+        # y = y * scale
         y = self.delay_wrapper(x, y)
         return y
diff --git a/src/brevitas/nn/mixin/parameter.py b/src/brevitas/nn/mixin/parameter.py
index 095c981f1..3acfe7c95 100644
--- a/src/brevitas/nn/mixin/parameter.py
+++ b/src/brevitas/nn/mixin/parameter.py
@@ -198,7 +198,11 @@ def quant_bias_zero_point(self):
         if self.bias is None:
             return None
         if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
-            return self.bias_quant(self.bias).zero_point
+            bias_quant = self.bias_quant(self.bias)
+            if isinstance(bias_quant, QuantTensor):
+                return bias_quant.zero_point
+            else:
+                return None
         else:
             if self._cached_bias is None:
                 raise RuntimeError(
diff --git a/src/brevitas/nn/quant_rnn.py b/src/brevitas/nn/quant_rnn.py
index 7daf54ce0..4cd7fb87c 100644
--- a/src/brevitas/nn/quant_rnn.py
+++ b/src/brevitas/nn/quant_rnn.py
@@ -418,11 +418,11 @@ def forward(self, inp, state):
         quant_input = self.maybe_quantize_input(inp)
         quant_weight_ih, quant_weight_hh, quant_bias = self.gate_params_fwd(
             self.gate_params, quant_input)
+        quant_input_value = _unpack_quant_tensor(quant_input)
         if getattr(quant_bias, 'value', quant_bias) is None:
-            quant_bias = torch.tensor(0., device=quant_input.value.device)
+            quant_bias = torch.tensor(0., device=quant_input_value.device)
         else:
             quant_bias = _unpack_quant_tensor(quant_bias)
-        quant_input_value = _unpack_quant_tensor(quant_input)
         quant_state = self.maybe_quantize_state(quant_input_value, state, self.cell.output_quant)
         if self.export_mode:
             cell = self.export_handler
diff --git a/tests/brevitas/fx/test_tracer.py b/tests/brevitas/fx/test_tracer.py
index be5698d2c..cf910b51f 100644
--- a/tests/brevitas/fx/test_tracer.py
+++ b/tests/brevitas/fx/test_tracer.py
@@ -232,10 +232,22 @@ def test_module(module):
 @pytest.mark.parametrize('module', QUANT_TENSOR_MODULES)
 def test_quant_module(module):
     mod = module()
-    x = QuantTensor(torch.randn(INPUT_SIZE))
-    x_trace = QuantTensor(torch.randn(INPUT_SIZE))
+    x = QuantTensor(
+        torch.randint(-128, 127, INPUT_SIZE),
+        scale=0.1,
+        zero_point=0,
+        bit_width=8,
+        signed=True,
+        training=True)
+    x_trace = QuantTensor(
+        torch.randint(-128, 127, INPUT_SIZE),
+        scale=0.1,
+        zero_point=0,
+        bit_width=8,
+        signed=True,
+        training=True)
     with torch.no_grad():
         out = mod(x)
         graph_model = value_trace(mod, value_args={'x': x_trace})
         graph_out = graph_model(x)
-        assert graph_out.isclose(out).all().item()
+        assert torch.allclose(out, graph_out)
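
Note on the int_base.py hunk: the two commented-out lines are the affine dequantization step, value = (y_int - zero_point) * scale, so forward() now returns the integer representation without rescaling. The same relation is what the explicit QuantTensor metadata added in test_tracer.py (scale=0.1, zero_point=0, bit_width=8, signed=True) describes. A short sketch of that relation follows; the free-standing dequantize helper is a hypothetical name used for illustration, not the Brevitas API.

    import torch

    def dequantize(y_int, scale, zero_point):
        # Affine dequantization: map integer codes back to real values.
        return (y_int - zero_point) * scale

    # Signed 8-bit codes; note torch.randint's upper bound is exclusive.
    y_int = torch.randint(-128, 128, (2, 4))
    value = dequantize(y_int, scale=0.1, zero_point=0)
    # Largest magnitude is 128 * 0.1; allow a small float tolerance.
    assert float(value.abs().max()) <= 12.8 + 1e-5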
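Note on the quant_rnn.py hunk: maybe_quantize_input can return either a QuantTensor or a plain torch.Tensor (for instance when input quantization is disabled), and a plain tensor has no .value attribute, so quant_input.value.device could raise before this change. Unpacking once before the branch and reading .device from the unpacked tensor works in both cases. Below is a minimal, self-contained sketch of that pattern; the FakeQuantTensor class and the simplified _unpack_quant_tensor are illustrative assumptions, not the Brevitas implementations.

    import torch

    class FakeQuantTensor:
        # Hypothetical stand-in for Brevitas' QuantTensor, illustration only.
        def __init__(self, value):
            self.value = value

    def _unpack_quant_tensor(t):
        # Simplified sketch: a quant tensor unpacks to its underlying value,
        # a plain tensor passes through unchanged.
        return t.value if isinstance(t, FakeQuantTensor) else t

    for inp in (torch.randn(2, 4), FakeQuantTensor(torch.randn(2, 4))):
        quant_input_value = _unpack_quant_tensor(inp)
        # Safe for both inputs; inp.value.device would fail for the plain tensor.
        quant_bias = torch.tensor(0., device=quant_input_value.device)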