Commit 4d62dc9

More tests fix. No export

Giuseppe5 committed Jan 12, 2024
1 parent c383e5f
Showing 4 changed files with 25 additions and 9 deletions.
6 changes: 3 additions & 3 deletions src/brevitas/core/quant/int_base.py
@@ -164,8 +164,8 @@ def forward(
             zero_point: Tensor,
             bit_width: Tensor,
             x: Tensor) -> Tensor:
-        y_int = self.to_int(pre_scale, pre_zero_point, bit_width, x)
-        y = y_int - zero_point
-        y = y * scale
+        y = self.to_int(pre_scale, pre_zero_point, bit_width, x)
+        # y = y_int - zero_point
+        # y = y * scale
         y = self.delay_wrapper(x, y)
         return y
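
This hunk makes forward return the raw integer tensor from to_int and comments out the dequantization step. As a rough sketch of the arithmetic involved (the standalone helper below is hypothetical; Brevitas routes scale and zero-point through its own modules):

    import torch

    # Hypothetical standalone version of the affine quantize step that
    # self.to_int performs, followed by the dequant half the diff disables.
    def to_int_sketch(x, scale, zero_point, bit_width):
        n = 2 ** (bit_width - 1)
        y_int = torch.round(x / scale + zero_point)
        return torch.clamp(y_int, -n, n - 1)  # clamp to the signed integer range

    x = torch.randn(4)
    scale, zero_point = torch.tensor(0.1), torch.tensor(0.)
    y = to_int_sketch(x, scale, zero_point, bit_width=8)
    # The two commented-out lines would map y back to the float domain:
    # y = (y - zero_point) * scale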
6 changes: 5 additions & 1 deletion src/brevitas/nn/mixin/parameter.py
@@ -198,7 +198,11 @@ def quant_bias_zero_point(self):
         if self.bias is None:
             return None
         if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
-            return self.bias_quant(self.bias).zero_point
+            bias_quant = self.bias_quant(self.bias)
+            if isinstance(bias_quant, QuantTensor):
+                return bias_quant.zero_point
+            else:
+                return None
         else:
             if self._cached_bias is None:
                 raise RuntimeError(
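
The guard added here covers bias quantizers that hand back a plain torch.Tensor rather than a QuantTensor (for instance when bias quantization is effectively disabled), in which case there is no zero-point metadata to report. A minimal sketch of the pattern, assuming the QuantTensor import from brevitas.quant_tensor:

    import torch
    from brevitas.quant_tensor import QuantTensor

    def zero_point_or_none(maybe_quant):
        # Only a QuantTensor carries quantization metadata; a plain Tensor
        # result means there is no zero point to expose.
        if isinstance(maybe_quant, QuantTensor):
            return maybe_quant.zero_point
        return None

    print(zero_point_or_none(torch.zeros(3)))  # -> None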
4 changes: 2 additions & 2 deletions src/brevitas/nn/quant_rnn.py
@@ -418,11 +418,11 @@ def forward(self, inp, state):
         quant_input = self.maybe_quantize_input(inp)
         quant_weight_ih, quant_weight_hh, quant_bias = self.gate_params_fwd(
             self.gate_params, quant_input)
+        quant_input_value = _unpack_quant_tensor(quant_input)
         if getattr(quant_bias, 'value', quant_bias) is None:
-            quant_bias = torch.tensor(0., device=quant_input.value.device)
+            quant_bias = torch.tensor(0., device=quant_input_value.device)
         else:
             quant_bias = _unpack_quant_tensor(quant_bias)
-        quant_input_value = _unpack_quant_tensor(quant_input)
         quant_state = self.maybe_quantize_state(quant_input_value, state, self.cell.output_quant)
         if self.export_mode:
             cell = self.export_handler
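
Hoisting _unpack_quant_tensor above the bias branch fixes the device lookup: quant_input may itself be a plain Tensor, so quant_input.value.device would raise AttributeError on that path. A sketch of the failure mode, with a hypothetical stand-in for the unpack helper:

    import torch

    def unpack(x):
        # Stand-in for Brevitas' _unpack_quant_tensor: return the wrapped
        # value of a QuantTensor, or the tensor itself if already plain.
        return x.value if hasattr(x, 'value') else x

    quant_input = torch.randn(2, 4)          # plain-Tensor path
    # quant_input.value.device               # would raise AttributeError here
    quant_input_value = unpack(quant_input)  # safe on both paths
    quant_bias = torch.tensor(0., device=quant_input_value.device)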
18 changes: 15 additions & 3 deletions tests/brevitas/fx/test_tracer.py
@@ -232,10 +232,22 @@ def test_module(module):
 @pytest.mark.parametrize('module', QUANT_TENSOR_MODULES)
 def test_quant_module(module):
     mod = module()
-    x = QuantTensor(torch.randn(INPUT_SIZE))
-    x_trace = QuantTensor(torch.randn(INPUT_SIZE))
+    x = QuantTensor(
+        torch.randint(-128, 127, INPUT_SIZE),
+        scale=0.1,
+        zero_point=0,
+        bit_width=8,
+        signed=True,
+        training=True)
+    x_trace = QuantTensor(
+        torch.randint(-128, 127, INPUT_SIZE),
+        scale=0.1,
+        zero_point=0,
+        bit_width=8,
+        signed=True,
+        training=True)
     with torch.no_grad():
         out = mod(x)
     graph_model = value_trace(mod, value_args={'x': x_trace})
     graph_out = graph_model(x)
-    assert graph_out.isclose(out).all().item()
+    assert torch.allclose(out, graph_out)
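
The test previously wrapped random floats in a bare QuantTensor with no metadata; the new inputs carry a full scale/zero-point/bit-width description, and the final check compares values with torch.allclose instead of QuantTensor.isclose. A sketch of building such an input outside the test (the shape is a stand-in for the test's INPUT_SIZE; the constructor signature is taken from the diff above):

    import torch
    from brevitas.quant_tensor import QuantTensor

    x = QuantTensor(
        torch.randint(-128, 127, (1, 3, 8, 8)),  # stand-in for INPUT_SIZE
        scale=0.1,
        zero_point=0,
        bit_width=8,
        signed=True,
        training=True)
    print(x.scale, x.zero_point, x.bit_width)  # metadata the old bare wrapper lacked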
