diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index 86ef58b77..1da67ff3a 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -4,6 +4,7 @@
 import math
 
 from hypothesis import given
+import pytest_cases
 from pytest_cases import fixture
 import torch
 import torch.nn as nn
@@ -13,6 +14,7 @@
 from brevitas.graph.calibrate import load_quant_model_mode
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
+from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
 from brevitas.quant.scaled_int import Int8ActPerTensorFloat
 # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
 from brevitas.utils.torch_utils import kthvalue
@@ -21,6 +23,14 @@
 IN_CH = 8
 OUT_CH = 16
 BATCH = 1
+# Expected (act, act_1) scale factors after calibrating on REFERENCE_INP,
+# keyed by the ids used in QUANTS.
+# NOTE(review): the act_1 value presumably depends on QuantLinear's initial
+# weights; no seed is set in the test itself - confirm it is fixed elsewhere.
+REFERENCE_SCALES = {
+    'int_quant': (0.00935234408825635910, 0.00859776325523853302),
+    'fp_quant': (0.00249395845457911491, 0.00190271728206425905)}
+REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]])
 
 
 def compute_quantile(x, q):
@@ -65,6 +75,43 @@ def forward(self, x):
     assert torch.allclose(expected_scale, scale)
 
 
+QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3ActPerTensorFloat}
+
+
+@pytest_cases.parametrize("act_quant", QUANTS.items(), ids=QUANTS.keys())
+def test_scale_factors_ptq_calibration_reference(act_quant):
+    """Check that PTQ calibration reproduces the stored reference scale factors."""
+
+    reference, act_quant = act_quant
+
+    class TestModel(nn.Module):
+
+        def __init__(self):
+            super(TestModel, self).__init__()
+            self.act = qnn.QuantReLU(act_quant=act_quant)
+            self.linear = qnn.QuantLinear(3, 8)
+            self.act_1 = qnn.QuantIdentity(act_quant=act_quant)
+
+        def forward(self, x):
+            o = self.act(x)
+            o = self.linear(o)
+            return self.act_1(o)
+
+    # Reference input
+    inp = REFERENCE_INP
+    model = TestModel()
+    model.eval()
+    with torch.no_grad():
+        with calibration_mode(model):
+            model(inp)
+
+    computed_scale = model.act.act_quant.scale(), model.act_1.act_quant.scale()
+    reference_values = REFERENCE_SCALES[reference]
+    assert all(
+        torch.allclose(comp, torch.tensor(ref))
+        for comp, ref in zip(computed_scale, reference_values))
+
+
 def test_calibration_training_state():
 
     class TestModel(nn.Module):