From 20aba2e9a3264d64c1255d2b8b726c7486ee0868 Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Tue, 13 Feb 2024 11:42:13 +0000
Subject: [PATCH 1/7] added test for issue

---
 tests/brevitas/graph/test_calibration.py | 28 ++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index 58561cbd1..efcd1812b 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -187,3 +187,31 @@ def simple_hook(mod, inp, out):
     )  # In bias_correction mode, the input to each layer is equal to the FP output of the previous layer
     assert (inputs[1] == fp_outs[1, 0, :]).all(
     )  # In bias_correction mode, the input to each layer is equal to the FP output of the previous layer
+
+
+def test_import_bias_correction():
+
+    class SimpleQuantLinearNet(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.net = nn.Sequential(qnn.QuantLinear(IN_CH, OUT_CH, bias=False))
+
+        def forward(self, inp):
+            return self.net(inp)
+
+    model = SimpleQuantLinearNet()
+
+    with bias_correction_mode(model):
+        model(torch.randn((1, IN_CH)))
+
+    for m in model.modules():
+        if isinstance(m, qnn.QuantLinear):
+            assert m.bias is not None
+
+    new_model = SimpleQuantLinearNet()
+    new_model.load_state_dict(model.state_dict())
+
+    for m in new_model.modules():
+        if isinstance(m, qnn.QuantLinear):
+            assert m.bias is not None

From addf1417aaa9826dcadeeae0413e4008246c14ad Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Tue, 13 Feb 2024 16:58:13 +0000
Subject: [PATCH 2/7] first pass version of handling unexpected bias keys

---
 src/brevitas/graph/calibrate.py          | 22 ++++++++++++++++++++++
 tests/brevitas/graph/test_calibration.py |  4 +++-
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index 9eed8b38e..bcda54741 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -85,6 +85,28 @@ def __exit__(self, type, value, traceback):
             self.model, is_training=self.previous_training_state, quantization_enabled=True)
 
 
+class allow_unexpected_bias_keys:
+
+    def __init__(self, model):
+        self.model = model
+
+    def __enter__(self):
+        self.tracked_modules = []
+        for module in self.model.modules():
+            if issubclass(type(module), QuantWBIOL):
+                if module.bias is None:
+                    module.register_parameter(
+                        'bias',
+                        nn.Parameter(torch.empty(module.weight.shape[0])).to(module.weight.device))
+                    self.tracked_modules.append(module)
+
+    def __exit__(self, type, value, traceback):
+        for module in self.tracked_modules:
+            # empty tensor has a numel result of 0
+            if torch.numel(module.bias) == 0:
+                module.bias = None
+
+
 class bias_correction_mode:
 
     def __init__(self, model, enabled=True):
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index efcd1812b..23d4a63e7 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -8,6 +8,7 @@
 import torch
 import torch.nn as nn
 
+from brevitas.graph.calibrate import allow_unexpected_bias_keys
 from brevitas.graph.calibrate import bias_correction_mode
 from brevitas.graph.calibrate import calibration_mode
 import brevitas.nn as qnn
@@ -210,7 +211,8 @@ def forward(self, inp):
             assert m.bias is not None
 
     new_model = SimpleQuantLinearNet()
-    new_model.load_state_dict(model.state_dict())
+    with allow_unexpected_bias_keys(new_model):
+        new_model.load_state_dict(model.state_dict())
 
     for m in new_model.modules():
         if isinstance(m, qnn.QuantLinear):
             assert m.bias is not None
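
The test added in PATCH 1 pins down the underlying issue: bias_correction_mode registers a new bias parameter on layers built with bias=False, so the corrected model's state_dict() contains keys that a freshly constructed copy of the same architecture does not expect, and a plain load_state_dict() on that copy fails with unexpected bias keys. PATCH 2 works around this by temporarily registering an empty placeholder bias on every bias-less QuantWBIOL layer while the state dict is loaded; a usage sketch with the series' final names follows the last patch.
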
From 11b508569f57b4bb3baccc51faefa7544159e5f8 Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Tue, 13 Feb 2024 17:13:36 +0000
Subject: [PATCH 3/7] added class name to __all__ in calibrate.py

---
 src/brevitas/graph/calibrate.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index bcda54741..9e0fd60e9 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -24,7 +24,11 @@
 from .base import Transform
 
 __all__ = [
-    'ClipFloatWeights', 'DisableEnableQuantization', 'bias_correction_mode', 'calibration_mode']
+    'ClipFloatWeights',
+    'DisableEnableQuantization',
+    'bias_correction_mode',
+    'calibration_mode',
+    'allow_unexpected_bias_keys']
 
 
 _PARAM_PROXIES = (WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)

From 99bc408d22505197721893c6b5b4e082be85de16 Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Tue, 20 Feb 2024 09:05:16 +0000
Subject: [PATCH 4/7] added flag to bias correction mode context manager and
 the _BiasCorrection class to ignore any layers without a bias

---
 src/brevitas/graph/calibrate.py          |  9 +++++----
 tests/brevitas/graph/test_calibration.py | 21 +++++++++++++++++++++
 2 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index 9e0fd60e9..ee62b85ef 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -113,9 +113,9 @@ def __exit__(self, type, value, traceback):
 
 class bias_correction_mode:
 
-    def __init__(self, model, enabled=True):
+    def __init__(self, model, enabled=True, only_layers_with_bias=False):
         self.model = model
-        self.bias_correction = _BiasCorrection()
+        self.bias_correction = _BiasCorrection(only_layers_with_bias=only_layers_with_bias)
         self.enabled = enabled
         self.hooks = []
 
@@ -235,7 +235,7 @@ class _BiasCorrection(DisableEnableQuantization):
 
     LAYERS = (QuantWBIOL,)
 
-    def __init__(self, layers=LAYERS):
+    def __init__(self, layers=LAYERS, only_layers_with_bias=False):
         super(_BiasCorrection, self).__init__()
         self.layers = layers
         self.iterations = {}
@@ -243,6 +243,7 @@ def __init__(self, layers=LAYERS):
         self.float_mean_map = {}
         self.collect_float_mean_hooks = []
         self.correct_bias_hooks = []
+        self.only_layers_with_bias = only_layers_with_bias
 
     def compute_mean(self, inp, transpose_dim):
         inp = inp.transpose(0, transpose_dim)
@@ -274,7 +275,7 @@ def apply_correction(self, model):
                 correction = self.correction_map[name] / self.iterations[name]
                 if module.bias is not None:
                     module.bias.data += correction
-                else:
+                elif self.only_layers_with_bias is False:
                     module.register_parameter(
                         'bias', nn.Parameter(correction).to(module.weight.device))
 
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index 23d4a63e7..86546dd67 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -217,3 +217,24 @@ def forward(self, inp):
     for m in new_model.modules():
         if isinstance(m, qnn.QuantLinear):
             assert m.bias is not None
+
+
+def test_bias_correction_flag():
+
+    class SimpleQuantLinearNet(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.net = nn.Sequential(qnn.QuantLinear(IN_CH, OUT_CH, bias=False))
+
+        def forward(self, inp):
+            return self.net(inp)
+
+    model = SimpleQuantLinearNet()
+
+    with bias_correction_mode(model, only_layers_with_bias=True):
+        model(torch.randn((1, IN_CH)))
+
+    for m in model.modules():
+        if isinstance(m, qnn.QuantLinear):
+            assert m.bias is None

From 49c9917409d800d14198592cd5650a9a3fe3dd96 Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Tue, 20 Feb 2024 14:00:19 +0000
Subject: [PATCH 5/7] changed name of parameter

---
 src/brevitas/graph/calibrate.py          | 10 +++++-----
 tests/brevitas/graph/test_calibration.py |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index ee62b85ef..d2d983691 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -113,9 +113,9 @@ def __exit__(self, type, value, traceback):
 
 class bias_correction_mode:
 
-    def __init__(self, model, enabled=True, only_layers_with_bias=False):
+    def __init__(self, model, enabled=True, skip_if_no_bias=False):
         self.model = model
-        self.bias_correction = _BiasCorrection(only_layers_with_bias=only_layers_with_bias)
+        self.bias_correction = _BiasCorrection(skip_if_no_bias=skip_if_no_bias)
         self.enabled = enabled
         self.hooks = []
 
@@ -235,7 +235,7 @@ class _BiasCorrection(DisableEnableQuantization):
 
     LAYERS = (QuantWBIOL,)
 
-    def __init__(self, layers=LAYERS, only_layers_with_bias=False):
+    def __init__(self, layers=LAYERS, skip_if_no_bias=False):
         super(_BiasCorrection, self).__init__()
         self.layers = layers
         self.iterations = {}
@@ -243,7 +243,7 @@ def __init__(self, layers=LAYERS, only_layers_with_bias=False):
         self.float_mean_map = {}
         self.collect_float_mean_hooks = []
         self.correct_bias_hooks = []
-        self.only_layers_with_bias = only_layers_with_bias
+        self.skip_if_no_bias = skip_if_no_bias
 
     def compute_mean(self, inp, transpose_dim):
         inp = inp.transpose(0, transpose_dim)
@@ -275,7 +275,7 @@ def apply_correction(self, model):
                 correction = self.correction_map[name] / self.iterations[name]
                 if module.bias is not None:
                     module.bias.data += correction
-                elif self.only_layers_with_bias is False:
+                elif self.skip_if_no_bias is False:
                     module.register_parameter(
                         'bias', nn.Parameter(correction).to(module.weight.device))
 
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index 86546dd67..fed0a5367 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -232,7 +232,7 @@ def forward(self, inp):
 
     model = SimpleQuantLinearNet()
 
-    with bias_correction_mode(model, only_layers_with_bias=True):
+    with bias_correction_mode(model, skip_if_no_bias=True):
         model(torch.randn((1, IN_CH)))
 
     for m in model.modules():
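
Note on the flag introduced above: PATCH 4 adds it as only_layers_with_bias and PATCH 5 renames it to skip_if_no_bias; when set, the correction is only folded into layers that already carry a bias, and no new bias parameter is registered on bias-less layers. A minimal usage sketch under that assumption, mirroring the test; the layer sizes are illustrative and not the test constants:

    import torch
    import torch.nn as nn

    import brevitas.nn as qnn
    from brevitas.graph.calibrate import bias_correction_mode

    IN_CH, OUT_CH = 8, 16  # illustrative sizes

    model = nn.Sequential(qnn.QuantLinear(IN_CH, OUT_CH, bias=False))

    # With skip_if_no_bias=True the bias-less QuantLinear is left untouched,
    # so the model's state_dict keys are the same before and after correction.
    with bias_correction_mode(model, skip_if_no_bias=True):
        model(torch.randn(1, IN_CH))

    assert model[0].bias is None  # no bias parameter was added
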
From f9b3445fc5d5bfb212a7d6faaedee67d352191e4 Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Wed, 21 Feb 2024 09:03:26 +0000
Subject: [PATCH 6/7] renamed context manager to load_quant_model

---
 src/brevitas/graph/calibrate.py          | 4 ++--
 tests/brevitas/graph/test_calibration.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index d2d983691..0e1920ef9 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -28,7 +28,7 @@
     'DisableEnableQuantization',
     'bias_correction_mode',
     'calibration_mode',
-    'allow_unexpected_bias_keys']
+    'load_quant_model']
 
 
 _PARAM_PROXIES = (WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)
@@ -89,7 +89,7 @@ def __exit__(self, type, value, traceback):
             self.model, is_training=self.previous_training_state, quantization_enabled=True)
 
 
-class allow_unexpected_bias_keys:
+class load_quant_model:
 
     def __init__(self, model):
         self.model = model
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index fed0a5367..a68485920 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -8,9 +8,9 @@
 import torch
 import torch.nn as nn
 
-from brevitas.graph.calibrate import allow_unexpected_bias_keys
 from brevitas.graph.calibrate import bias_correction_mode
 from brevitas.graph.calibrate import calibration_mode
+from brevitas.graph.calibrate import load_quant_model
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
 from tests.brevitas.hyp_helper import float_tensor_random_size_st
@@ -211,7 +211,7 @@ def forward(self, inp):
             assert m.bias is not None
 
     new_model = SimpleQuantLinearNet()
-    with allow_unexpected_bias_keys(new_model):
+    with load_quant_model(new_model):
         new_model.load_state_dict(model.state_dict())
 
     for m in new_model.modules():
From e13c29e7f866d0b1251024731fe2387536221bfd Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Wed, 21 Feb 2024 10:24:09 +0000
Subject: [PATCH 7/7] removed tracked_modules creation into __init__

---
 src/brevitas/graph/calibrate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index 0e1920ef9..bb435b7ef 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -93,9 +93,9 @@ class load_quant_model:
 
     def __init__(self, model):
         self.model = model
+        self.tracked_modules = []
 
     def __enter__(self):
-        self.tracked_modules = []
         for module in self.model.modules():
             if issubclass(type(module), QuantWBIOL):
                 if module.bias is None:
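
With the whole series applied, the intended save/load round trip looks roughly like the sketch below, using the final names load_quant_model (PATCH 6) and skip_if_no_bias (PATCH 5); the checkpoint path and layer sizes are placeholders:

    import torch
    import torch.nn as nn

    import brevitas.nn as qnn
    from brevitas.graph.calibrate import bias_correction_mode
    from brevitas.graph.calibrate import load_quant_model

    IN_CH, OUT_CH = 8, 16  # illustrative sizes


    def make_model():
        return nn.Sequential(qnn.QuantLinear(IN_CH, OUT_CH, bias=False))


    calibrated = make_model()
    with bias_correction_mode(calibrated):
        calibrated(torch.randn(1, IN_CH))  # bias correction registers a 'bias' parameter here

    torch.save(calibrated.state_dict(), 'corrected.pt')  # placeholder path

    # A fresh instance has no bias key for the QuantLinear; load_quant_model
    # pre-registers an empty placeholder bias so the corrected state_dict loads cleanly.
    restored = make_model()
    with load_quant_model(restored):
        restored.load_state_dict(torch.load('corrected.pt'))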