diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py
index b4c48d5fe..e1ba13a00 100644
--- a/src/brevitas/nn/mixin/base.py
+++ b/src/brevitas/nn/mixin/base.py
@@ -18,6 +18,7 @@
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import Injector
 from brevitas.nn.utils import compute_channel_view_shape
+from brevitas.quant_tensor import _is_all_nested_not_none
 from brevitas.quant_tensor import QuantTensor
 
 from .utils import filter_kwargs
diff --git a/src/brevitas/nn/quant_rnn.py b/src/brevitas/nn/quant_rnn.py
index 8268e9eb9..7daf54ce0 100644
--- a/src/brevitas/nn/quant_rnn.py
+++ b/src/brevitas/nn/quant_rnn.py
@@ -23,7 +23,7 @@
 from brevitas.quant import Int8WeightPerTensorFloat
 from brevitas.quant import Int32Bias
 from brevitas.quant import Uint8ActPerTensorFloat
-from brevitas.quant_tensor import _get_dequantize_tensor
+from brevitas.quant_tensor import _unpack_quant_tensor
 from brevitas.quant_tensor import QuantTensor
 
 QuantTupleShortEnabled = List[Tuple[Tensor, Tensor, Tensor, Tensor]]
@@ -421,8 +421,9 @@ def forward(self, inp, state):
         if getattr(quant_bias, 'value', quant_bias) is None:
             quant_bias = torch.tensor(0., device=quant_input.value.device)
         else:
-            quant_bias = _get_dequantize_tensor(quant_bias)
-        quant_state = self.maybe_quantize_state(quant_input.value, state, self.cell.output_quant)
+            quant_bias = _unpack_quant_tensor(quant_bias)
+        quant_input_value = _unpack_quant_tensor(quant_input)
+        quant_state = self.maybe_quantize_state(quant_input_value, state, self.cell.output_quant)
         if self.export_mode:
             cell = self.export_handler
         elif self.fast_mode:
@@ -430,10 +431,10 @@
         else:
             cell = self.cell
         quant_outputs = cell(
-            _get_dequantize_tensor(quant_input),
-            _get_dequantize_tensor(quant_state),
-            _get_dequantize_tensor(quant_weight_ih),
-            _get_dequantize_tensor(quant_weight_hh),
+            _unpack_quant_tensor(quant_input),
+            _unpack_quant_tensor(quant_state),
+            _unpack_quant_tensor(quant_weight_ih),
+            _unpack_quant_tensor(quant_weight_hh),
             quant_bias)
         quant_output = self.pack_quant_outputs(quant_outputs)
         quant_state = self.pack_quant_state(quant_outputs[-1], self.cell.output_quant)
@@ -668,7 +669,7 @@ def fast_cell(self):
 
     def forward(self, inp, hidden_state, cell_state):
         quant_input = self.maybe_quantize_input(inp)
-        quant_input_value = _get_dequantize_tensor(quant_input)
+        quant_input_value = _unpack_quant_tensor(quant_input)
         quant_weight_ii, quant_weight_hi, quant_bias_input = self.gate_params_fwd(
             self.input_gate_params, quant_input)
         quant_weight_ic, quant_weight_hc, quant_bias_cell = self.gate_params_fwd(
@@ -686,19 +687,19 @@ def forward(self, inp, hidden_state, cell_state):
         if getattr(quant_bias_input, 'value', quant_bias_input) is None:
             quant_bias_input = torch.tensor(0., device=quant_input_value.device)
         else:
-            quant_bias_input = _get_dequantize_tensor(quant_bias_input)
+            quant_bias_input = _unpack_quant_tensor(quant_bias_input)
         if getattr(quant_bias_forget, 'value', quant_bias_forget) is None:
             quant_bias_forget = torch.tensor(0., device=quant_input_value.device)
         else:
-            quant_bias_forget = _get_dequantize_tensor(quant_bias_forget)
+            quant_bias_forget = _unpack_quant_tensor(quant_bias_forget)
         if getattr(quant_bias_cell, 'value', quant_bias_cell) is None:
             quant_bias_cell = torch.tensor(0., device=quant_input_value.device)
         else:
-            quant_bias_cell = _get_dequantize_tensor(quant_bias_cell)
+            quant_bias_cell = _unpack_quant_tensor(quant_bias_cell)
         if getattr(quant_bias_output, 'value', quant_bias_output) is None:
             quant_bias_output = torch.tensor(0., device=quant_input_value.device)
         else:
-            quant_bias_output = _get_dequantize_tensor(quant_bias_output)
+            quant_bias_output = _unpack_quant_tensor(quant_bias_output)
         quant_hidden_state = self.maybe_quantize_state(
             quant_input_value, hidden_state, self.cell.output_quant)
         quant_cell_state = self.maybe_quantize_state(
@@ -712,16 +713,16 @@ def forward(self, inp, hidden_state, cell_state):
             cell = self.cell
         quant_outputs, quant_hidden_state, quant_cell_state = cell(
             quant_input_value,
-            _get_dequantize_tensor(quant_hidden_state),
-            _get_dequantize_tensor(quant_cell_state),
-            quant_weight_ii=_get_dequantize_tensor(quant_weight_ii),
-            quant_weight_if=_get_dequantize_tensor(quant_weight_if),
-            quant_weight_ic=_get_dequantize_tensor(quant_weight_ic),
-            quant_weight_io=_get_dequantize_tensor(quant_weight_io),
-            quant_weight_hi=_get_dequantize_tensor(quant_weight_hi),
-            quant_weight_hf=_get_dequantize_tensor(quant_weight_hf),
-            quant_weight_hc=_get_dequantize_tensor(quant_weight_hc),
-            quant_weight_ho=_get_dequantize_tensor(quant_weight_ho),
+            _unpack_quant_tensor(quant_hidden_state),
+            _unpack_quant_tensor(quant_cell_state),
+            quant_weight_ii=_unpack_quant_tensor(quant_weight_ii),
+            quant_weight_if=_unpack_quant_tensor(quant_weight_if),
+            quant_weight_ic=_unpack_quant_tensor(quant_weight_ic),
+            quant_weight_io=_unpack_quant_tensor(quant_weight_io),
+            quant_weight_hi=_unpack_quant_tensor(quant_weight_hi),
+            quant_weight_hf=_unpack_quant_tensor(quant_weight_hf),
+            quant_weight_hc=_unpack_quant_tensor(quant_weight_hc),
+            quant_weight_ho=_unpack_quant_tensor(quant_weight_ho),
             quant_bias_input=quant_bias_input,
             quant_bias_forget=quant_bias_forget,
             quant_bias_cell=quant_bias_cell,
diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py
index 9eeb5ee9e..3423bc212 100644
--- a/src/brevitas/quant_tensor/__init__.py
+++ b/src/brevitas/quant_tensor/__init__.py
@@ -34,7 +34,7 @@ class QuantTensorBase(NamedTuple):
 
 def _unpack_quant_tensor(input_data):
     if isinstance(input_data, QuantTensor):
-        return input_data.tensor
+        return input_data.value
     elif isinstance(input_data, tuple):
         return tuple([_unpack_quant_tensor(v) for v in input_data])
     elif isinstance(input_data, list):
@@ -132,6 +132,8 @@ def tensor(self):
     @property
     def value(self):
         if self.is_valid:
+            if self.zero_point is None or self.scale is None:
+                return self.qt_value
             return (self.qt_value - self.zero_point) * self.scale
         else:
             return self.qt_value
diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py
index 55296ff35..8788289e5 100644
--- a/tests/brevitas/nn/test_nn_quantizers.py
+++ b/tests/brevitas/nn/test_nn_quantizers.py
@@ -199,9 +199,18 @@ def test_quant_mha(model_input, current_cases):
     case_id = get_case_id(cases_generator_func)
     args = case_id.split('-')[1:]  # Exclude first argument
     kwargs = parse_args(args)
-
-    if (kwargs['io_quant'] is None or
-            kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external':
+    is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized']
+    if not ((is_input_quanttensor and kwargs['weight_quant'] is not None) or
+            kwargs['io_quant'] is not None) and kwargs['return_quant_tensor']:
+        with pytest.raises(RuntimeError, match='QuantLayer is not correctly configured'):
+            output, _ = model(inp, inp, inp)
+        return
+    elif kwargs['io_quant'] is None and kwargs['return_quant_tensor']:
+        with pytest.raises(RuntimeError, match='QuantLayer is not correctly configured'):
+            output, _ = model(inp, inp, inp)
+        return
+    elif (kwargs['io_quant'] is None or
+            kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external':
        with pytest.raises(RuntimeError, match='Input scale required'):
             output, _ = model(inp, inp, inp)
         return
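
The gist of the change, sketched below as a self-contained example for reviewers: _unpack_quant_tensor now dequantizes through the value property instead of returning the raw tensor, and value falls back to the underlying integer representation when scale or zero_point metadata is missing. MiniQuantTensor and unpack are hypothetical stand-ins for illustration, not the Brevitas API (the real value property additionally checks is_valid).

# Minimal sketch of the patched behavior; hypothetical names, not Brevitas API.
from typing import NamedTuple, Optional

import torch
from torch import Tensor


class MiniQuantTensor(NamedTuple):
    # qt_value holds the (affine-)quantized representation.
    qt_value: Tensor
    scale: Optional[Tensor] = None
    zero_point: Optional[Tensor] = None

    @property
    def value(self) -> Tensor:
        # Mirrors the patched property: if quantization metadata is
        # missing, return the underlying tensor as-is instead of
        # evaluating (qt_value - zero_point) * scale against None.
        if self.zero_point is None or self.scale is None:
            return self.qt_value
        return (self.qt_value - self.zero_point) * self.scale


def unpack(input_data):
    # Same shape as _unpack_quant_tensor after the patch: quant tensors
    # are dequantized via .value, containers are unpacked recursively,
    # anything else (plain Tensor, None, float) passes through unchanged.
    if isinstance(input_data, MiniQuantTensor):
        return input_data.value
    elif isinstance(input_data, tuple):
        return tuple(unpack(v) for v in input_data)
    elif isinstance(input_data, list):
        return [unpack(v) for v in input_data]
    else:
        return input_data


qt = MiniQuantTensor(
    qt_value=torch.tensor([10., 20.]),
    scale=torch.tensor(0.1),
    zero_point=torch.tensor(0.))
print(unpack(qt))                                   # tensor([1., 2.])
print(unpack((qt, torch.ones(2))))                  # recurses into tuples
print(unpack(MiniQuantTensor(torch.tensor([3.]))))  # tensor([3.]): metadata missing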
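
The bias handling repeated four times in quant_rnn.py follows one pattern, extracted below as a standalone helper continuing the sketch above (resolve_bias is a hypothetical name introduced here for illustration).

# Continues the sketch above; reuses unpack, qt and torch from it.
def resolve_bias(quant_bias, device):
    # getattr covers both cases uniformly: a quantized bias exposes
    # .value, while a plain None (no bias quantizer configured) has no
    # such attribute and falls back to itself.
    if getattr(quant_bias, 'value', quant_bias) is None:
        # No bias available: substitute a zero scalar on the right
        # device so the downstream addition is a no-op.
        return torch.tensor(0., device=device)
    return unpack(quant_bias)


print(resolve_bias(None, device='cpu'))  # tensor(0.)
print(resolve_bias(qt, device='cpu'))    # tensor([1., 2.])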