Skip to content

Commit

Permalink
calibration with reference values
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Sep 24, 2024
1 parent 5d5dfce commit 2b8d1f2
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions tests/brevitas/graph/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math

from hypothesis import given
import pytest_cases
from pytest_cases import fixture
import torch
import torch.nn as nn
Expand All @@ -13,6 +14,7 @@
from brevitas.graph.calibrate import load_quant_model_mode
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFixedPoint
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue
Expand All @@ -21,6 +23,10 @@
IN_CH = 8
OUT_CH = 16
BATCH = 1
REFERENCE_SCALES = {
'int_quant': (0.00935234408825635910, 0.00859776325523853302),
'fp_quant': (0.00249395845457911491, 0.00190271728206425905)}
REFERNECE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]])


def compute_quantile(x, q):
Expand Down Expand Up @@ -65,6 +71,42 @@ def forward(self, x):
assert torch.allclose(expected_scale, scale)


QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3ActPerTensorFloat}


@pytest_cases.parametrize("act_quant", QUANTS.items(), ids=QUANTS.keys())
def test_scale_factors_ptq_calibration_reference(act_quant):

reference, act_quant = act_quant

class TestModel(nn.Module):

def __init__(self):
super(TestModel, self).__init__()
self.act = qnn.QuantReLU(act_quant=act_quant)
self.linear = qnn.QuantLinear(3, 8)
self.act_1 = qnn.QuantIdentity(act_quant=act_quant)

def forward(self, x):
o = self.act(x)
o = self.linear(o)
return self.act_1(o)

# Reference input
inp = REFERNECE_INP
model = TestModel()
model.eval()
with torch.no_grad():
with calibration_mode(model):
model(inp)

computed_scale = model.act.act_quant.scale(), model.act_1.act_quant.scale()
reference_values = REFERENCE_SCALES[reference]
assert all([
torch.allclose(comp, torch.tensor(ref)) for comp,
ref in zip(computed_scale, reference_values)])


def test_calibration_training_state():

class TestModel(nn.Module):
Expand Down

0 comments on commit 2b8d1f2

Please sign in to comment.