From 9596911ccd07f794025020e198628b7cea5f59ee Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 13 Dec 2024 15:40:21 +0000 Subject: [PATCH 1/2] Fix (quantizer): fix shapes for mse/hqo optimization --- src/brevitas/core/stats/stats_op.py | 1 + src/brevitas/export/inference/handler.py | 1 + src/brevitas/quant/base.py | 3 --- src/brevitas_examples/llm/main.py | 1 - 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index ac520a707..84811b6e4 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -501,6 +501,7 @@ def evaluate_loss(self, x, candidate): self.set_observer_mode(False) quant_value = self.proxy_forward(x) quant_value = _unpack_quant_tensor(quant_value) + quant_value = quant_value.view(x.shape) loss = self.mse_loss_fn(x, quant_value) self.set_local_loss_mode(False) self.restore_observer_mode() diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 1416014ec..8f05b036c 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -13,6 +13,7 @@ from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase +from brevitas.proxy.groupwise_int_runtime_quant import GroupwiseActQuantProxyFromInjector from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 7cd050e21..c4fb36bf7 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -538,7 +538,6 @@ class MSEAsymmetricScale(ExtendedInjector): mse_scale = MSEAsymmetricScaleSubInjector scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS - scaling_stats_input_view_shape_impl = nn.Identity() @value def inner_stats_input_view_shape_impl(scaling_per_output): @@ -561,7 +560,6 @@ class MSESymmetricScale(ExtendedInjector): mse_scale = MSESymmetricScaleSubInjector scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS - scaling_stats_input_view_shape_impl = nn.Identity() @value def inner_stats_input_view_shape_impl(scaling_per_output): @@ -583,7 +581,6 @@ class MSEZeroPoint(ExtendedInjector): """ mse_zero_point = MSEZeroPointSubInjector - zero_point_stats_input_view_shape_impl = nn.Identity() @value def inner_stats_input_view_shape_impl(scaling_per_output): diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 65c5334d7..5ecf7b376 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -374,7 +374,6 @@ def main(args): if args.bias_corr: model = add_zero_bias_to_linear(model) - model = offload_model(model) with torch.no_grad(): From 88e81a849d4cd3d8ca70fc6aac956e14653d25b5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 13 Dec 2024 16:43:00 +0100 Subject: [PATCH 2/2] Update main.py --- src/brevitas_examples/llm/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 5ecf7b376..65c5334d7 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -374,6 +374,7 @@ def main(args): if args.bias_corr: model = add_zero_bias_to_linear(model) + model = offload_model(model) with torch.no_grad():