From 10fdfe18e113cf9c977eb1f73b37f56a46571d29 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 13 Dec 2024 16:02:27 +0000 Subject: [PATCH 1/9] Feat (brevitas_examples/llm): inference_mode support --- src/brevitas/export/inference/handler.py | 68 +++++++++++++++++-- src/brevitas/export/inference/manager.py | 6 +- .../proxy/groupwise_int_runtime_quant.py | 7 +- src/brevitas_examples/llm/main.py | 7 +- 4 files changed, 75 insertions(+), 13 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 1416014ec..35201d45d 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -7,15 +7,18 @@ import torch +from brevitas import is_dynamo_compiling from brevitas.function.ops import max_float from brevitas.function.ops import max_int from brevitas.function.ops import min_int from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase +from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector +from brevitas.quant.experimental.mx_quant_ocp import GroupwiseActQuantProxyFromInjector from brevitas.utils.torch_utils import float_internal_scale @@ -40,6 +43,11 @@ def dequantize(self, x): class IntInferencetHandler(InferenceHandler): handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector) + def __init__(self): + super().__init__() + self.register_buffer('scale', torch.ones(1)) + self.register_buffer('zero_point', torch.ones(0)) + def attach_debug_info(self, module): pass @@ -51,12 +59,11 @@ def prepare_for_export(self, module): self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width) self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width) - def quantize(self, x): - return torch.clamp( - torch.round(x / self.scale + self.zero_point), self.min_clamp, self.max_clamp) + def quantize(self, x, scale, zero_point): + return torch.clamp(torch.round(x / scale + zero_point), self.min_clamp, self.max_clamp) - def dequantize(self, x): - return (x - self.zero_point) * self.scale + def dequantize(self, x, scale, zero_point): + return (x - zero_point) * scale def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width @@ -65,6 +72,10 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: class IntWeightInferencetHandler(IntInferencetHandler): handled_layer = WeightQuantProxyFromInjector + def __init__(self): + super().__init__() + self.register_buffer('cached_weight', torch.ones(1)) + def prepare_for_export(self, module): if module.is_quant_enabled: self.cached_weight = None @@ -76,7 +87,8 @@ def forward(self, x) -> Tuple[torch.Tensor]: if self.cached_weight is not None: x = self.cached_weight else: - x = self.dequantize(self.quantize(x)) + x = self.dequantize( + self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point) return x, self.scale, self.zero_point, self.bit_width @@ -114,7 +126,6 @@ def quantize(self, x): inf_mask = x.isinf() p_max_val_mask = x > self.max_value n_max_val_mask = -x > 
self.max_value - # Quantize x = x / self.scale internal_scale = float_internal_scale( @@ -151,3 +162,46 @@ def forward(self, x) -> Tuple[torch.Tensor]: else: x = self.dequantize(self.quantize(x)) return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values + + +class GroupwiseIntInferenceHandler(IntInferencetHandler): + handled_layer = GroupwiseActQuantProxyFromInjector + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.module_forward = module.fused_activation_quant_proxy + self.flattened_view = module.apply_input_view + self.input_view = module.input_view_impl + self.group_dim = module.group_dim + + def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: + x, *other = self.module_forward(x) + if is_dynamo_compiling: + start_dim = self.group_dim if self.group_dim != -1 else -2 + x = x.flatten(start_dim, start_dim + 1) + return x, *other + + +class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler): + handled_layer = GroupwiseWeightQuantProxyFromInjector + + def prepare_for_export(self, module): + super().prepare_for_export(module) + self.input_view = module.input_view_impl + self.flattened_view = module.apply_input_view + + def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: + x = self.input_view(x) + if self.scale.shape != (): + scale = self.input_view(self.scale) + else: + scale = self.scale + + if self.zero_point.shape != (): + zero_point = self.input_view(self.zero_point) + else: + zero_point = self.zero_point + out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) + if is_dynamo_compiling: + out = self.flattened_view(out) + return out, scale, zero_point, self.bit_width diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py index 936106884..ab09e83e5 100644 --- a/src/brevitas/export/inference/manager.py +++ b/src/brevitas/export/inference/manager.py @@ -6,6 +6,8 @@ from brevitas.export.inference.handler import FloatInferencetHandler from brevitas.export.inference.handler import FloatWeightInferencetHandler +from brevitas.export.inference.handler import GroupwiseIntInferenceHandler +from brevitas.export.inference.handler import GroupwiseIntWeightInferenceHandler from brevitas.export.inference.handler import IntInferencetHandler from brevitas.export.inference.handler import IntWeightInferencetHandler from brevitas.export.manager import _set_proxy_export_handler @@ -93,7 +95,9 @@ class InferenceManager(BaseManager): IntInferencetHandler, FloatInferencetHandler, IntWeightInferencetHandler, - FloatWeightInferencetHandler] + FloatWeightInferencetHandler, + GroupwiseIntInferenceHandler, + GroupwiseIntWeightInferenceHandler] @classmethod def set_export_mode(cls, model: Module, enabled: bool): diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index 453cb3f9b..ea91d1996 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -12,10 +12,11 @@ class GroupwiseActQuantProxyFromInjector(ActQuantProxyFromInjector): def __init__(self, quant_layer, quant_injector): super().__init__(quant_layer, quant_injector) self.cache_class = _CachedIOGroupwiseInt + self.group_dim = self.quant_injector.group_dim - @property - def group_dim(self): - return self.quant_injector.group_dim + # @property + # def group_dim(self): + # return self.quant_injector.group_dim 
@property def group_size(self): diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 6004ec97d..bd9669bd8 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -14,6 +14,7 @@ from transformers.utils.fx import _SUPPORTED_MODELS from brevitas.export import export_torch_qcdq +from brevitas.export.inference.manager import quant_inference_mode from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import LayerwiseActivationRotation @@ -421,8 +422,10 @@ def main(args): if args.eval and not args.no_quantize: print("Model eval...") - quant_ppl = compute_perplexity( - model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) + with torch.no_grad(), quant_inference_mode(model): + model(**calibration_loader[0]) + quant_ppl = compute_perplexity( + model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}") remove_hooks(model) From d3bb48031fcb79ef94de19c99710b10403acc1cf Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 14 Dec 2024 15:22:55 +0000 Subject: [PATCH 2/9] fix --- src/brevitas/export/inference/handler.py | 2 +- src/brevitas/proxy/groupwise_int_runtime_quant.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 35201d45d..e09e9a1d3 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -66,7 +66,7 @@ def dequantize(self, x, scale, zero_point): return (x - zero_point) * scale def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: - return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width + return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.bit_width class IntWeightInferencetHandler(IntInferencetHandler): diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index ea91d1996..453cb3f9b 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -12,11 +12,10 @@ class GroupwiseActQuantProxyFromInjector(ActQuantProxyFromInjector): def __init__(self, quant_layer, quant_injector): super().__init__(quant_layer, quant_injector) self.cache_class = _CachedIOGroupwiseInt - self.group_dim = self.quant_injector.group_dim - # @property - # def group_dim(self): - # return self.quant_injector.group_dim + @property + def group_dim(self): + return self.quant_injector.group_dim @property def group_size(self): From dadfa1ec75c398fb73fe672755f3b7ea4e744503 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 14 Dec 2024 15:27:58 +0000 Subject: [PATCH 3/9] Precommit --- src/brevitas/export/inference/handler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index e09e9a1d3..ae52e4fc9 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -179,7 +179,8 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: if is_dynamo_compiling: start_dim = self.group_dim if self.group_dim != -1 else -2 x = x.flatten(start_dim, start_dim + 1) - return x, *other + output_args 
= tuple([x] + list(other)) + return output_args class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler): From ed6a20ace453c33cdc1c564b17490d4545eb0adb Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 14 Dec 2024 15:37:40 +0000 Subject: [PATCH 4/9] float Groupwsie --- src/brevitas/export/inference/handler.py | 72 +++++++++++++++++++----- 1 file changed, 58 insertions(+), 14 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index ae52e4fc9..8c58f5125 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -14,6 +14,9 @@ from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase +from brevitas.proxy.groupwise_float_parameter_quant import \ + GroupwiseWeightFloatQuantProxyFromInjector +from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector @@ -92,6 +95,48 @@ def forward(self, x) -> Tuple[torch.Tensor]: return x, self.scale, self.zero_point, self.bit_width +class GroupwiseIntInferenceHandler(IntInferencetHandler): + handled_layer = GroupwiseActQuantProxyFromInjector + + def prepare_for_export(self, module): + super().prepare_for_export(module) + if module.is_quant_enabled: + self.module_forward = module.fused_activation_quant_proxy + self.group_dim = module.group_dim + + def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: + x, *other = self.module_forward(x) + if is_dynamo_compiling: + start_dim = self.group_dim if self.group_dim != -1 else -2 + x = x.flatten(start_dim, start_dim + 1) + output_args = tuple([x] + list(other)) + return output_args + + +class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler): + handled_layer = GroupwiseWeightQuantProxyFromInjector + + def prepare_for_export(self, module): + super().prepare_for_export(module) + self.input_view = module.input_view_impl + self.flattened_view = module.apply_input_view + + def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: + x = self.input_view(x) + if self.scale.shape != (): + scale = self.input_view(self.scale) + else: + scale = self.scale + if self.zero_point.shape != (): + zero_point = self.input_view(self.zero_point) + else: + zero_point = self.zero_point + out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) + if is_dynamo_compiling: + out = self.flattened_view(out) + return out, scale, zero_point, self.bit_width + + class FloatInferencetHandler(InferenceHandler): handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector) @@ -121,13 +166,13 @@ def prepare_for_export(self, module): self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias) self.min_value = torch.tensor(0.) 
if not module.is_signed else -self.max_value - def quantize(self, x): + def quantize(self, x, scale, zero_point): # Compute masks inf_mask = x.isinf() p_max_val_mask = x > self.max_value n_max_val_mask = -x > self.max_value # Quantize - x = x / self.scale + x = x / scale internal_scale = float_internal_scale( x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps) x = internal_scale * self.float_to_int_impl(x / internal_scale) @@ -139,11 +184,11 @@ def quantize(self, x): return x - def dequantize(self, x): - return (x - self.zero_point) * self.scale + def dequantize(self, x, scale, zero_point): + return (x - zero_point) * scale def forward(self, x) -> Tuple[torch.Tensor]: - return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values + return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values class FloatWeightInferencetHandler(FloatInferencetHandler): @@ -160,19 +205,19 @@ def forward(self, x) -> Tuple[torch.Tensor]: if self.cached_weight is not None: x = self.cached_weight else: - x = self.dequantize(self.quantize(x)) + x = self.dequantize( + self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point) return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values -class GroupwiseIntInferenceHandler(IntInferencetHandler): - handled_layer = GroupwiseActQuantProxyFromInjector +class GroupwiseFloatInferenceHandler(FloatInferencetHandler): + handled_layer = GroupwiseActFloatQuantProxyFromInjector def prepare_for_export(self, module): + super().prepare_for_export(module) if module.is_quant_enabled: self.module_forward = module.fused_activation_quant_proxy - self.flattened_view = module.apply_input_view - self.input_view = module.input_view_impl - self.group_dim = module.group_dim + self.group_dim = module.group_dim def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: x, *other = self.module_forward(x) @@ -183,8 +228,8 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: return output_args -class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler): - handled_layer = GroupwiseWeightQuantProxyFromInjector +class GroupwiseFloatWeightInferenceHandler(FloatWeightInferencetHandler): + handled_layer = GroupwiseWeightFloatQuantProxyFromInjector def prepare_for_export(self, module): super().prepare_for_export(module) @@ -197,7 +242,6 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: scale = self.input_view(self.scale) else: scale = self.scale - if self.zero_point.shape != (): zero_point = self.input_view(self.zero_point) else: From 9c07ca7f2997a932bc9f7731ccd940fe44248c9c Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 14 Dec 2024 15:41:43 +0000 Subject: [PATCH 5/9] import --- src/brevitas/export/inference/manager.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py index ab09e83e5..545cdc093 100644 --- a/src/brevitas/export/inference/manager.py +++ b/src/brevitas/export/inference/manager.py @@ -6,6 +6,8 @@ from brevitas.export.inference.handler import FloatInferencetHandler from 
brevitas.export.inference.handler import FloatWeightInferencetHandler +from brevitas.export.inference.handler import GroupwiseFloatInferenceHandler +from brevitas.export.inference.handler import GroupwiseFloatWeightInferenceHandler from brevitas.export.inference.handler import GroupwiseIntInferenceHandler from brevitas.export.inference.handler import GroupwiseIntWeightInferenceHandler from brevitas.export.inference.handler import IntInferencetHandler @@ -97,7 +99,9 @@ class InferenceManager(BaseManager): IntWeightInferencetHandler, FloatWeightInferencetHandler, GroupwiseIntInferenceHandler, - GroupwiseIntWeightInferenceHandler] + GroupwiseIntWeightInferenceHandler, + GroupwiseFloatInferenceHandler, + GroupwiseFloatWeightInferenceHandler] @classmethod def set_export_mode(cls, model: Module, enabled: bool): From 68879ca7fd263475aae25f4a823eed4b4b9aaa10 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 14 Dec 2024 16:27:41 +0000 Subject: [PATCH 6/9] fix --- src/brevitas/export/inference/handler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 8c58f5125..3ab6ba0a0 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -99,7 +99,6 @@ class GroupwiseIntInferenceHandler(IntInferencetHandler): handled_layer = GroupwiseActQuantProxyFromInjector def prepare_for_export(self, module): - super().prepare_for_export(module) if module.is_quant_enabled: self.module_forward = module.fused_activation_quant_proxy self.group_dim = module.group_dim @@ -214,7 +213,6 @@ class GroupwiseFloatInferenceHandler(FloatInferencetHandler): handled_layer = GroupwiseActFloatQuantProxyFromInjector def prepare_for_export(self, module): - super().prepare_for_export(module) if module.is_quant_enabled: self.module_forward = module.fused_activation_quant_proxy self.group_dim = module.group_dim From dd80cf40186d84ec5d0faf4dd0d93d5ae50251e7 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 14 Dec 2024 22:12:05 +0000 Subject: [PATCH 7/9] Fix compile --- src/brevitas/export/inference/handler.py | 10 +++++----- .../quant_tensor/groupwise_float_quant_tensor.py | 1 - 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 3ab6ba0a0..a582d95a9 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -105,7 +105,7 @@ def prepare_for_export(self, module): def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: x, *other = self.module_forward(x) - if is_dynamo_compiling: + if is_dynamo_compiling(): start_dim = self.group_dim if self.group_dim != -1 else -2 x = x.flatten(start_dim, start_dim + 1) output_args = tuple([x] + list(other)) @@ -131,7 +131,7 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: else: zero_point = self.zero_point out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) - if is_dynamo_compiling: + if is_dynamo_compiling(): out = self.flattened_view(out) return out, scale, zero_point, self.bit_width @@ -219,7 +219,7 @@ def prepare_for_export(self, module): def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: x, *other = self.module_forward(x) - if is_dynamo_compiling: + if is_dynamo_compiling(): start_dim = self.group_dim if self.group_dim != -1 else -2 x = x.flatten(start_dim, start_dim + 1) output_args = tuple([x] + list(other)) @@ -245,6 +245,6 @@ def forward(self, 
x, unused_scale=None) -> Tuple[torch.Tensor]: else: zero_point = self.zero_point out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) - if is_dynamo_compiling: + if is_dynamo_compiling(): out = self.flattened_view(out) - return out, scale, zero_point, self.bit_width + return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index 16f75c49e..b507d3fe3 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -92,7 +92,6 @@ def expand(self): curr_shape = self.value_.shape start_dim = self.group_dim if self.group_dim != -1 else -2 new_value = self.value_.flatten(start_dim, start_dim + 1) - new_value = self.value_.flatten(start_dim, start_dim + 1) if self.scale_.shape != (): new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) else: From b08f120f1088f998ec5961c0d6d529b8d3b5a7d8 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 15 Dec 2024 20:47:33 +0000 Subject: [PATCH 8/9] Expand inference_mode compatibility --- src/brevitas/export/inference/handler.py | 25 ++++++++++++++++++++---- src/brevitas/export/inference/manager.py | 4 +++- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index a582d95a9..d03bdefde 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -21,6 +21,7 @@ from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector +from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector from brevitas.quant.experimental.mx_quant_ocp import GroupwiseActQuantProxyFromInjector from brevitas.utils.torch_utils import float_internal_scale @@ -95,6 +96,17 @@ def forward(self, x) -> Tuple[torch.Tensor]: return x, self.scale, self.zero_point, self.bit_width +class DynamicIntInferenceHandler(IntInferencetHandler): + handled_layer = DynamicActQuantProxyFromInjector + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.module_forward = module.fused_activation_quant_proxy + + def forward(self, x, ununsed_scale=None): + return self.module_forward(x) + + class GroupwiseIntInferenceHandler(IntInferencetHandler): handled_layer = GroupwiseActQuantProxyFromInjector @@ -119,9 +131,10 @@ def prepare_for_export(self, module): super().prepare_for_export(module) self.input_view = module.input_view_impl self.flattened_view = module.apply_input_view + if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: + self.cached_weight = module._cached_weight.quant_tensor.value_ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: - x = self.input_view(x) if self.scale.shape != (): scale = self.input_view(self.scale) else: @@ -130,9 +143,13 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: zero_point = self.input_view(self.zero_point) else: zero_point = self.zero_point - out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) - if is_dynamo_compiling(): - out = self.flattened_view(out) + if self.cached_weight is not None: + out = self.cached_weight + 
else: + x = self.input_view(x) + out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) + if is_dynamo_compiling(): + out = self.flattened_view(out) return out, scale, zero_point, self.bit_width diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py index 545cdc093..b78a888f2 100644 --- a/src/brevitas/export/inference/manager.py +++ b/src/brevitas/export/inference/manager.py @@ -4,6 +4,7 @@ from torch.nn import Module import torch.nn as nn +from brevitas.export.inference.handler import DynamicIntInferenceHandler from brevitas.export.inference.handler import FloatInferencetHandler from brevitas.export.inference.handler import FloatWeightInferencetHandler from brevitas.export.inference.handler import GroupwiseFloatInferenceHandler @@ -69,6 +70,7 @@ def __exit__(self, type, value, traceback): # Disable all caching # deactivate export mode # restore return quant tensor + InferenceManager.set_export_mode(self.model, enabled=False) self.model.apply( lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False)) self.model.apply( @@ -76,7 +78,6 @@ def __exit__(self, type, value, traceback): if self.cache_quant_weight: self.model.apply( lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False)) - InferenceManager.set_export_mode(self.model, enabled=False) restore_return_quant_tensor(self.model, self.return_quant_tensor_state) def hook(self, module, inp, out): @@ -95,6 +96,7 @@ def hook(self, module, inp, out): class InferenceManager(BaseManager): handlers = [ IntInferencetHandler, + DynamicIntInferenceHandler, FloatInferencetHandler, IntWeightInferencetHandler, FloatWeightInferencetHandler, From ca976436539af1f8233668479f62017917d77df0 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 16 Dec 2024 16:03:19 +0000 Subject: [PATCH 9/9] Fix tests --- src/brevitas/common.py | 5 +- src/brevitas/export/inference/handler.py | 101 ++++++++++++++--------- 2 files changed, 64 insertions(+), 42 deletions(-) diff --git a/src/brevitas/common.py b/src/brevitas/common.py index 62ddf219a..7a0ec4483 100644 --- a/src/brevitas/common.py +++ b/src/brevitas/common.py @@ -3,6 +3,7 @@ from abc import ABCMeta from abc import abstractmethod +import warnings from brevitas import config @@ -29,8 +30,8 @@ def export_mode(self): @export_mode.setter def export_mode(self, value): if value and config.JIT_ENABLED: - raise RuntimeError( - "Export mode with BREVITAS_JIT is currently not supported. Save the model' " + warnings.warn( + "Export mode with BREVITAS_JIT might fail. 
If so, save the model' " "state_dict to a .pth, load it back with BREVITAS_JIT=0, and call export.") if value and self.training: raise RuntimeError("Can't enter export mode during training, only during inference") diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index d03bdefde..32c1ac5ac 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -6,6 +6,8 @@ from typing import Tuple import torch +from torch import Tensor +import torch.nn as nn from brevitas import is_dynamo_compiling from brevitas.function.ops import max_float @@ -28,19 +30,19 @@ class InferenceHandler(torch.nn.Module, ABC): - def attach_debug_info(self, module): + def attach_debug_info(self, module: nn.Module): pass @abstractmethod - def prepare_for_export(self, module): + def prepare_for_export(self, module: nn.Module): pass @abstractmethod - def quantize(self, x): + def quantize(self, x: Tensor): pass @abstractmethod - def dequantize(self, x): + def dequantize(self, x: Tensor): pass @@ -52,10 +54,7 @@ def __init__(self): self.register_buffer('scale', torch.ones(1)) self.register_buffer('zero_point', torch.ones(0)) - def attach_debug_info(self, module): - pass - - def prepare_for_export(self, module): + def prepare_for_export(self, module: nn.Module): if module.is_quant_enabled: self.scale = module.scale() self.zero_point = module.zero_point().to(self.scale.device) @@ -63,13 +62,13 @@ def prepare_for_export(self, module): self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width) self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width) - def quantize(self, x, scale, zero_point): + def quantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tuple[Tensor]: return torch.clamp(torch.round(x / scale + zero_point), self.min_clamp, self.max_clamp) - def dequantize(self, x, scale, zero_point): + def dequantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tensor: return (x - zero_point) * scale - def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: + def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]: return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.bit_width @@ -80,14 +79,15 @@ def __init__(self): super().__init__() self.register_buffer('cached_weight', torch.ones(1)) - def prepare_for_export(self, module): + def prepare_for_export(self, module: nn.Module): + super().prepare_for_export(module) if module.is_quant_enabled: - self.cached_weight = None - super().prepare_for_export(module) if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: self.cached_weight = module._cached_weight.value + else: + self.cached_weight = None - def forward(self, x) -> Tuple[torch.Tensor]: + def forward(self, x: Tensor) -> Tuple[Tensor]: if self.cached_weight is not None: x = self.cached_weight else: @@ -99,11 +99,11 @@ def forward(self, x) -> Tuple[torch.Tensor]: class DynamicIntInferenceHandler(IntInferencetHandler): handled_layer = DynamicActQuantProxyFromInjector - def prepare_for_export(self, module): + def prepare_for_export(self, module: nn.Module): if module.is_quant_enabled: self.module_forward = module.fused_activation_quant_proxy - def forward(self, x, ununsed_scale=None): + def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]: return self.module_forward(x) @@ -115,7 +115,7 @@ def 
prepare_for_export(self, module): self.module_forward = module.fused_activation_quant_proxy self.group_dim = module.group_dim - def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: + def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]: x, *other = self.module_forward(x) if is_dynamo_compiling(): start_dim = self.group_dim if self.group_dim != -1 else -2 @@ -129,12 +129,15 @@ class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler): def prepare_for_export(self, module): super().prepare_for_export(module) - self.input_view = module.input_view_impl - self.flattened_view = module.apply_input_view - if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: - self.cached_weight = module._cached_weight.quant_tensor.value_ + if module.is_quant_enabled: + self.input_view = module.input_view_impl + self.flattened_view = module.apply_input_view + if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: + self.cached_weight = module._cached_weight.quant_tensor.value_ + else: + self.cached_weight = None - def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: + def forward(self, x: Tensor) -> Tuple[Tensor]: if self.scale.shape != (): scale = self.input_view(self.scale) else: @@ -156,6 +159,11 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: class FloatInferencetHandler(InferenceHandler): handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector) + def __init__(self): + super().__init__() + self.register_buffer('scale', torch.ones(1)) + self.register_buffer('zero_point', torch.ones(0)) + def prepare_for_export(self, module): if module.is_quant_enabled: self.scale = module.scale() @@ -182,7 +190,7 @@ def prepare_for_export(self, module): self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias) self.min_value = torch.tensor(0.) 
if not module.is_signed else -self.max_value - def quantize(self, x, scale, zero_point): + def quantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tuple[Tensor]: # Compute masks inf_mask = x.isinf() p_max_val_mask = x > self.max_value @@ -200,24 +208,29 @@ def quantize(self, x, scale, zero_point): return x - def dequantize(self, x, scale, zero_point): + def dequantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tensor: return (x - zero_point) * scale - def forward(self, x) -> Tuple[torch.Tensor]: + def forward(self, x: Tensor) -> Tuple[Tensor]: return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values class FloatWeightInferencetHandler(FloatInferencetHandler): handled_layer = WeightFloatQuantProxyFromInjector + def __init__(self): + super().__init__() + self.register_buffer('cached_weight', torch.ones(1)) + def prepare_for_export(self, module): + super().prepare_for_export(module) if module.is_quant_enabled: - self.cached_weight = None - super().prepare_for_export(module) if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: self.cached_weight = module._cached_weight.value + else: + self.cached_weight = None - def forward(self, x) -> Tuple[torch.Tensor]: + def forward(self, x: Tensor) -> Tuple[Tensor]: if self.cached_weight is not None: x = self.cached_weight else: @@ -229,12 +242,12 @@ def forward(self, x) -> Tuple[torch.Tensor]: class GroupwiseFloatInferenceHandler(FloatInferencetHandler): handled_layer = GroupwiseActFloatQuantProxyFromInjector - def prepare_for_export(self, module): + def prepare_for_export(self, module: nn.Module): if module.is_quant_enabled: self.module_forward = module.fused_activation_quant_proxy self.group_dim = module.group_dim - def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: + def forward(self, x: Tensor) -> Tuple[Tensor]: x, *other = self.module_forward(x) if is_dynamo_compiling(): start_dim = self.group_dim if self.group_dim != -1 else -2 @@ -246,13 +259,17 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: class GroupwiseFloatWeightInferenceHandler(FloatWeightInferencetHandler): handled_layer = GroupwiseWeightFloatQuantProxyFromInjector - def prepare_for_export(self, module): + def prepare_for_export(self, module: nn.Module): super().prepare_for_export(module) - self.input_view = module.input_view_impl - self.flattened_view = module.apply_input_view + if module.is_quant_enabled: + self.input_view = module.input_view_impl + self.flattened_view = module.apply_input_view + if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: + self.cached_weight = module._cached_weight.quant_tensor.value_ + else: + self.cached_weight = None - def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: - x = self.input_view(x) + def forward(self, x: Tensor) -> Tuple[Tensor]: if self.scale.shape != (): scale = self.input_view(self.scale) else: @@ -261,7 +278,11 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: zero_point = self.input_view(self.zero_point) else: zero_point = self.zero_point - out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) - if is_dynamo_compiling(): - out = self.flattened_view(out) - return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, 
self.saturating, self.inf_values, self.nan_values + if self.cached_weight is not None: + out = self.cached_weight + else: + x = self.input_view(x) + out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) + if is_dynamo_compiling(): + out = self.flattened_view(out) + return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
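
Usage note: the eval path added to brevitas_examples/llm/main.py in PATCH 1/9 boils down to running evaluation inside the new quant_inference_mode context manager, with an initial forward pass inside the context so the inference handlers get prepared before perplexity is measured. A minimal sketch, assuming model, tokenizer, calibration_loader, validation_loader, compute_perplexity and args are set up as in the example script:

    import torch
    from brevitas.export.inference.manager import quant_inference_mode

    # Entering the context manager switches the quant proxies to the inference
    # handlers; the warm-up forward lets them cache scales and zero-points,
    # after which evaluation runs on those cached buffers.
    with torch.no_grad(), quant_inference_mode(model):
        model(**calibration_loader[0])
        quant_ppl = compute_perplexity(
            model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
    print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")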
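
In isolation, the integer handlers implement the usual affine fake-quantization round-trip shown in IntInferencetHandler: quantize as clamp(round(x / scale + zero_point)) to the [min_int, max_int] range, then dequantize as (q - zero_point) * scale. A self-contained sketch of that math, with illustrative names and clamp bounds:

    import torch

    def int_fake_quantize(x, scale, zero_point, min_clamp, max_clamp):
        # quantize: rescale, shift, round, clamp to the integer range
        q = torch.clamp(torch.round(x / scale + zero_point), min_clamp, max_clamp)
        # dequantize: undo the shift and map back to floating point
        return (q - zero_point) * scale

    # e.g. signed 8-bit, per-tensor scale, zero zero-point
    x = torch.randn(4, 8)
    y = int_fake_quantize(x, scale=torch.tensor(0.05), zero_point=torch.tensor(0.),
                          min_clamp=-128, max_clamp=127)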
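
The groupwise handlers wrap that same round-trip in a reshape: the input (and scale/zero-point, when not scalar) is viewed per group via the proxy's input_view_impl, fake-quantized, and flattened back along group_dim when compiling with dynamo. A sketch that stands in an explicit group_size reshape for the proxy's own input_view_impl; the helper name and reshaping scheme are illustrative assumptions, not Brevitas APIs:

    import torch

    def groupwise_int_fake_quantize(x, scale, zero_point, min_clamp, max_clamp,
                                    group_size, group_dim=-1):
        # view group_dim as (num_groups, group_size); scale/zero_point are expected
        # to broadcast with one value per group, e.g. shape (..., num_groups, 1)
        grouped = x.unflatten(group_dim, (-1, group_size))
        q = torch.clamp(torch.round(grouped / scale + zero_point), min_clamp, max_clamp)
        deq = (q - zero_point) * scale
        # flatten the two group dims back, mirroring the handlers' flatten(start_dim, start_dim + 1)
        start_dim = group_dim if group_dim != -1 else -2
        return deq.flatten(start_dim, start_dim + 1)

    x = torch.randn(2, 64)
    scale = torch.rand(2, 2, 1) + 1e-2   # one scale per group of 32
    y = groupwise_int_fake_quantize(x, scale, torch.tensor(0.), -128, 127, group_size=32)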