Po2 Tests
Giuseppe5 committed Sep 25, 2024
1 parent 25b95fc commit 6398a6b
Showing 1 changed file with 13 additions and 2 deletions.
tests/brevitas/graph/test_calibration.py (15 changes: 13 additions & 2 deletions)
@@ -12,6 +12,7 @@
 from brevitas.graph.calibrate import bias_correction_mode
 from brevitas.graph.calibrate import calibration_mode
 from brevitas.graph.calibrate import load_quant_model_mode
+from brevitas.inject.enum import RestrictValueType
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
 from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
@@ -27,7 +28,9 @@
 BATCH = 1
 REFERENCE_SCALES = {
     'int_quant': (0.00935234408825635910, 0.01362917013466358185),
-    'fp_quant': (0.00249395845457911491, 0.00363444536924362183)}
+    'fp_quant': (0.00249395845457911491, 0.00363444536924362183),
+    'int_po2_quant': (0.015625, 0.015625),
+    'fp_po2_quant': (0.001953125, 0.00390625),}
 REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]])
 REFERENCE_WEIGHTS = torch.tensor([[1.0023, 0.0205, 1.4604], [-0.2918, -1.8218, -0.7010],
                                   [1.4573, -0.9074, -0.2708]])
@@ -75,7 +78,15 @@ def forward(self, x):
     assert torch.allclose(expected_scale, scale)


-QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3ActPerTensorFloat}
+class Fp8e4m3ActPerTensorFixedPoint(Fp8e4m3ActPerTensorFloat):
+    restrict_scaling_type = RestrictValueType.POWER_OF_TWO
+
+
+QUANTS = {
+    'int_quant': Int8ActPerTensorFloat,
+    'fp_quant': Fp8e4m3ActPerTensorFloat,
+    'int_po2_quant': Int8ActPerTensorFixedPoint,
+    'fp_po2_quant': Fp8e4m3ActPerTensorFixedPoint}


 @pytest_cases.parametrize("act_quant", QUANTS.items(), ids=QUANTS.keys())
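
A note on the new reference values (not part of the commit itself): the scales added for the two *_po2_* quantizers are exact powers of two, which is what the restrict_scaling_type = RestrictValueType.POWER_OF_TWO override on the FP8 quantizer implies for the learned/calibrated scale factor. A minimal, stdlib-only sketch that checks this property, assuming nothing beyond the values shown in the diff above:

import math

# Power-of-two reference scales introduced by this commit (copied from the diff).
# Each should satisfy log2(scale) == integer, i.e. scale == 2**k for some integer k.
po2_scales = {
    'int_po2_quant': (0.015625, 0.015625),      # 2**-6, 2**-6
    'fp_po2_quant': (0.001953125, 0.00390625),  # 2**-9, 2**-8
}

for name, scales in po2_scales.items():
    for s in scales:
        exponent = math.log2(s)
        assert exponent == int(exponent), f"{name}: {s} is not a power of two"
        print(f"{name}: {s} = 2**{int(exponent)}")

By contrast, the existing 'int_quant' and 'fp_quant' reference scales are arbitrary floats, since those quantizers leave the scale unrestricted.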
