Commit 6be9a09

Cleanup

Giuseppe5 committed Jan 10, 2024
1 parent ce01540

Showing 4 changed files with 39 additions and 26 deletions.
src/brevitas/nn/mixin/base.py (1 change: 1 addition & 0 deletions)

@@ -18,6 +18,7 @@
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import Injector
 from brevitas.nn.utils import compute_channel_view_shape
+from brevitas.quant_tensor import _is_all_nested_not_none
 from brevitas.quant_tensor import QuantTensor
 
 from .utils import filter_kwargs
src/brevitas/nn/quant_rnn.py (45 changes: 23 additions & 22 deletions)

@@ -23,7 +23,7 @@
 from brevitas.quant import Int8WeightPerTensorFloat
 from brevitas.quant import Int32Bias
 from brevitas.quant import Uint8ActPerTensorFloat
-from brevitas.quant_tensor import _get_dequantize_tensor
+from brevitas.quant_tensor import _unpack_quant_tensor
 from brevitas.quant_tensor import QuantTensor
 
 QuantTupleShortEnabled = List[Tuple[Tensor, Tensor, Tensor, Tensor]]
@@ -421,19 +421,20 @@ def forward(self, inp, state):
         if getattr(quant_bias, 'value', quant_bias) is None:
             quant_bias = torch.tensor(0., device=quant_input.value.device)
         else:
-            quant_bias = _get_dequantize_tensor(quant_bias)
-        quant_state = self.maybe_quantize_state(quant_input.value, state, self.cell.output_quant)
+            quant_bias = _unpack_quant_tensor(quant_bias)
+        quant_input_value = _unpack_quant_tensor(quant_input)
+        quant_state = self.maybe_quantize_state(quant_input_value, state, self.cell.output_quant)
         if self.export_mode:
             cell = self.export_handler
         elif self.fast_mode:
             cell = self.fast_cell
         else:
             cell = self.cell
         quant_outputs = cell(
-            _get_dequantize_tensor(quant_input),
-            _get_dequantize_tensor(quant_state),
-            _get_dequantize_tensor(quant_weight_ih),
-            _get_dequantize_tensor(quant_weight_hh),
+            _unpack_quant_tensor(quant_input),
+            _unpack_quant_tensor(quant_state),
+            _unpack_quant_tensor(quant_weight_ih),
+            _unpack_quant_tensor(quant_weight_hh),
             quant_bias)
         quant_output = self.pack_quant_outputs(quant_outputs)
         quant_state = self.pack_quant_state(quant_outputs[-1], self.cell.output_quant)
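Note on the getattr idiom in this hunk: quant_bias can arrive as a QuantTensor (payload under .value), as a plain tensor, or as None, and getattr(quant_bias, 'value', quant_bias) collapses all three cases into a single None check. A minimal sketch of the same logic, with resolve_bias as a hypothetical name:

import torch
from brevitas.quant_tensor import _unpack_quant_tensor

# Hypothetical helper mirroring the bias handling in forward() above.
def resolve_bias(quant_bias, device):
    # QuantTensor -> inspect .value; plain tensor or None -> use the object itself.
    if getattr(quant_bias, 'value', quant_bias) is None:
        return torch.tensor(0., device=device)  # no bias: scalar zero on the right device
    return _unpack_quant_tensor(quant_bias)  # unwrap to a plain torch.Tensor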
@@ -668,7 +669,7 @@ def fast_cell(self):
 
     def forward(self, inp, hidden_state, cell_state):
         quant_input = self.maybe_quantize_input(inp)
-        quant_input_value = _get_dequantize_tensor(quant_input)
+        quant_input_value = _unpack_quant_tensor(quant_input)
         quant_weight_ii, quant_weight_hi, quant_bias_input = self.gate_params_fwd(
             self.input_gate_params, quant_input)
         quant_weight_ic, quant_weight_hc, quant_bias_cell = self.gate_params_fwd(
@@ -686,19 +687,19 @@ def forward(self, inp, hidden_state, cell_state):
         if getattr(quant_bias_input, 'value', quant_bias_input) is None:
             quant_bias_input = torch.tensor(0., device=quant_input_value.device)
         else:
-            quant_bias_input = _get_dequantize_tensor(quant_bias_input)
+            quant_bias_input = _unpack_quant_tensor(quant_bias_input)
         if getattr(quant_bias_forget, 'value', quant_bias_forget) is None:
             quant_bias_forget = torch.tensor(0., device=quant_input_value.device)
         else:
-            quant_bias_forget = _get_dequantize_tensor(quant_bias_forget)
+            quant_bias_forget = _unpack_quant_tensor(quant_bias_forget)
         if getattr(quant_bias_cell, 'value', quant_bias_cell) is None:
             quant_bias_cell = torch.tensor(0., device=quant_input_value.device)
         else:
-            quant_bias_cell = _get_dequantize_tensor(quant_bias_cell)
+            quant_bias_cell = _unpack_quant_tensor(quant_bias_cell)
         if getattr(quant_bias_output, 'value', quant_bias_output) is None:
             quant_bias_output = torch.tensor(0., device=quant_input_value.device)
         else:
-            quant_bias_output = _get_dequantize_tensor(quant_bias_output)
+            quant_bias_output = _unpack_quant_tensor(quant_bias_output)
         quant_hidden_state = self.maybe_quantize_state(
             quant_input_value, hidden_state, self.cell.output_quant)
         quant_cell_state = self.maybe_quantize_state(
@@ -712,16 +713,16 @@ def forward(self, inp, hidden_state, cell_state):
             cell = self.cell
         quant_outputs, quant_hidden_state, quant_cell_state = cell(
             quant_input_value,
-            _get_dequantize_tensor(quant_hidden_state),
-            _get_dequantize_tensor(quant_cell_state),
-            quant_weight_ii=_get_dequantize_tensor(quant_weight_ii),
-            quant_weight_if=_get_dequantize_tensor(quant_weight_if),
-            quant_weight_ic=_get_dequantize_tensor(quant_weight_ic),
-            quant_weight_io=_get_dequantize_tensor(quant_weight_io),
-            quant_weight_hi=_get_dequantize_tensor(quant_weight_hi),
-            quant_weight_hf=_get_dequantize_tensor(quant_weight_hf),
-            quant_weight_hc=_get_dequantize_tensor(quant_weight_hc),
-            quant_weight_ho=_get_dequantize_tensor(quant_weight_ho),
+            _unpack_quant_tensor(quant_hidden_state),
+            _unpack_quant_tensor(quant_cell_state),
+            quant_weight_ii=_unpack_quant_tensor(quant_weight_ii),
+            quant_weight_if=_unpack_quant_tensor(quant_weight_if),
+            quant_weight_ic=_unpack_quant_tensor(quant_weight_ic),
+            quant_weight_io=_unpack_quant_tensor(quant_weight_io),
+            quant_weight_hi=_unpack_quant_tensor(quant_weight_hi),
+            quant_weight_hf=_unpack_quant_tensor(quant_weight_hf),
+            quant_weight_hc=_unpack_quant_tensor(quant_weight_hc),
+            quant_weight_ho=_unpack_quant_tensor(quant_weight_ho),
             quant_bias_input=quant_bias_input,
             quant_bias_forget=quant_bias_forget,
             quant_bias_cell=quant_bias_cell,
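Taken together, both forward() methods in this file follow the same unpack-compute-repack shape: QuantTensor arguments are unwrapped into plain tensors, the selected cell runs on them, and the results are re-wrapped via pack_quant_outputs/pack_quant_state. A rough sketch of that pattern; pack_outputs is a stand-in for the packing methods, whose bodies are not part of this diff:

from brevitas.quant_tensor import _unpack_quant_tensor

# Sketch only; `pack_outputs` is a placeholder, not a real Brevitas function.
def rnn_step(cell, pack_outputs, *quant_args, **quant_kwargs):
    plain_args = [_unpack_quant_tensor(a) for a in quant_args]
    plain_kwargs = {k: _unpack_quant_tensor(v) for k, v in quant_kwargs.items()}
    outputs = cell(*plain_args, **plain_kwargs)  # the cell only ever sees plain tensors
    return pack_outputs(outputs)  # re-wrap as QuantTensor(s) for downstream layers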
src/brevitas/quant_tensor/__init__.py (4 changes: 3 additions & 1 deletion)

@@ -34,7 +34,7 @@ class QuantTensorBase(NamedTuple):
 
 def _unpack_quant_tensor(input_data):
     if isinstance(input_data, QuantTensor):
-        return input_data.tensor
+        return input_data.value
     elif isinstance(input_data, tuple):
         return tuple([_unpack_quant_tensor(v) for v in input_data])
     elif isinstance(input_data, list):
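For reference, a small usage sketch of _unpack_quant_tensor as patched here, assuming the QuantTensor API at this commit (a NamedTuple whose value property dequantizes). The function's final fallback branch is collapsed in this diff, so the pass-through of plain tensors below is an assumption:

import torch
from brevitas.nn import QuantIdentity
from brevitas.quant_tensor import QuantTensor, _unpack_quant_tensor

qt = QuantIdentity(return_quant_tensor=True)(torch.randn(2, 4))
assert isinstance(qt, QuantTensor)
plain = _unpack_quant_tensor(qt)       # -> qt.value, a dequantized torch.Tensor
pair = _unpack_quant_tensor((qt, qt))  # tuples (and lists) are unpacked recursively
raw = _unpack_quant_tensor(torch.ones(3))  # assumed: non-QuantTensor inputs pass through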
@@ -132,6 +132,8 @@ def tensor(self):
     @property
     def value(self):
         if self.is_valid:
+            if self.zero_point is None or self.scale is None:
+                return self.qt_value
             return (self.qt_value - self.zero_point) * self.scale
         else:
             return self.qt_value
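The value property implements standard affine dequantization, value = (qt_value - zero_point) * scale, and the new guard skips it when the quantization metadata is missing rather than failing on arithmetic with None. A worked sketch with made-up numbers:

import torch

qt_value = torch.tensor([-128., 0., 127.])  # stored integer codes
scale, zero_point = torch.tensor(0.1), torch.tensor(0.)
value = (qt_value - zero_point) * scale  # tensor([-12.8000, 0.0000, 12.7000])

# With scale or zero_point set to None there is nothing to dequantize,
# so the guarded property returns the raw qt_value unchanged.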
tests/brevitas/nn/test_nn_quantizers.py (15 changes: 12 additions & 3 deletions)

@@ -199,9 +199,18 @@ def test_quant_mha(model_input, current_cases):
     case_id = get_case_id(cases_generator_func)
     args = case_id.split('-')[1:]  # Exclude first argument
     kwargs = parse_args(args)
 
-    if (kwargs['io_quant'] is None or
-            kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external':
+    is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized']
+    if not ((is_input_quanttensor and kwargs['weight_quant'] is not None) or
+            kwargs['io_quant'] is not None) and kwargs['return_quant_tensor']:
+        with pytest.raises(RuntimeError, match='QuantLayer is not correctly configured'):
+            output, _ = model(inp, inp, inp)
+        return
+    elif kwargs['io_quant'] is None and kwargs['return_quant_tensor']:
+        with pytest.raises(RuntimeError, match='QuantLayer is not correctly configured'):
+            output, _ = model(inp, inp, inp)
+        return
+    elif (kwargs['io_quant'] is None or
+            kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external':
         with pytest.raises(RuntimeError, match='Input scale required'):
             output, _ = model(inp, inp, inp)
         return
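The rewritten branches encode when the MHA layer under test can legitimately return a QuantTensor: either io_quant quantizes the output directly, or a quantized input flows through quantized weights. Distilled into a standalone predicate (a hypothetical helper, derived from the first branch above):

def can_return_quant_tensor(io_quant, weight_quant, input_quantized):
    # True when the output carries quantization metadata.
    is_input_quanttensor = io_quant is not None or input_quantized
    return (is_input_quanttensor and weight_quant is not None) or io_quant is not None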
