diff --git a/src/brevitas/common.py b/src/brevitas/common.py
index 62ddf219a..7a0ec4483 100644
--- a/src/brevitas/common.py
+++ b/src/brevitas/common.py
@@ -3,6 +3,7 @@
 
 from abc import ABCMeta
 from abc import abstractmethod
+import warnings
 
 from brevitas import config
 
@@ -29,8 +30,8 @@ def export_mode(self):
     @export_mode.setter
     def export_mode(self, value):
         if value and config.JIT_ENABLED:
-            raise RuntimeError(
-                "Export mode with BREVITAS_JIT is currently not supported. Save the model' "
+            warnings.warn(
+                "Export mode with BREVITAS_JIT might fail. If so, save the model' "
                 "state_dict to a .pth, load it back with BREVITAS_JIT=0, and call export.")
         if value and self.training:
             raise RuntimeError("Can't enter export mode during training, only during inference")
diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py
index 1416014ec..32c1ac5ac 100644
--- a/src/brevitas/export/inference/handler.py
+++ b/src/brevitas/export/inference/handler.py
@@ -6,44 +6,55 @@
 from typing import Tuple
 
 import torch
+from torch import Tensor
+import torch.nn as nn
 
+from brevitas import is_dynamo_compiling
 from brevitas.function.ops import max_float
 from brevitas.function.ops import max_int
 from brevitas.function.ops import min_int
 from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector
 from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
 from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase
+from brevitas.proxy.groupwise_float_parameter_quant import \
+    GroupwiseWeightFloatQuantProxyFromInjector
+from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector
+from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
 from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
 from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
+from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
+from brevitas.quant.experimental.mx_quant_ocp import GroupwiseActQuantProxyFromInjector
 from brevitas.utils.torch_utils import float_internal_scale
 
 
 class InferenceHandler(torch.nn.Module, ABC):
 
-    def attach_debug_info(self, module):
+    def attach_debug_info(self, module: nn.Module):
         pass
 
     @abstractmethod
-    def prepare_for_export(self, module):
+    def prepare_for_export(self, module: nn.Module):
         pass
 
     @abstractmethod
-    def quantize(self, x):
+    def quantize(self, x: Tensor):
         pass
 
     @abstractmethod
-    def dequantize(self, x):
+    def dequantize(self, x: Tensor):
         pass
 
 
 class IntInferencetHandler(InferenceHandler):
     handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector)
 
-    def attach_debug_info(self, module):
-        pass
+    def __init__(self):
+        super().__init__()
+        self.register_buffer('scale', torch.ones(1))
+        self.register_buffer('zero_point', torch.ones(0))
 
-    def prepare_for_export(self, module):
+    def prepare_for_export(self, module: nn.Module):
         if module.is_quant_enabled:
             self.scale = module.scale()
             self.zero_point = module.zero_point().to(self.scale.device)
@@ -51,38 +62,108 @@ def prepare_for_export(self, module):
             self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
             self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)
 
-    def quantize(self, x):
-        return torch.clamp(
-            torch.round(x / self.scale + self.zero_point), self.min_clamp, self.max_clamp)
+    def quantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tuple[Tensor]:
+        return torch.clamp(torch.round(x / scale + zero_point), self.min_clamp, self.max_clamp)
 
-    def dequantize(self, x):
-        return (x - self.zero_point) * self.scale
+    def dequantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tensor:
+        return (x - zero_point) * scale
 
-    def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
-        return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width
+    def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
+        return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.bit_width
 
 
 class IntWeightInferencetHandler(IntInferencetHandler):
     handled_layer = WeightQuantProxyFromInjector
 
-    def prepare_for_export(self, module):
+    def __init__(self):
+        super().__init__()
+        self.register_buffer('cached_weight', torch.ones(1))
+
+    def prepare_for_export(self, module: nn.Module):
+        super().prepare_for_export(module)
         if module.is_quant_enabled:
-            self.cached_weight = None
-            super().prepare_for_export(module)
             if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
                 self.cached_weight = module._cached_weight.value
+            else:
+                self.cached_weight = None
 
-    def forward(self, x) -> Tuple[torch.Tensor]:
+    def forward(self, x: Tensor) -> Tuple[Tensor]:
         if self.cached_weight is not None:
             x = self.cached_weight
         else:
-            x = self.dequantize(self.quantize(x))
+            x = self.dequantize(
+                self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point)
         return x, self.scale, self.zero_point, self.bit_width
 
 
+class DynamicIntInferenceHandler(IntInferencetHandler):
+    handled_layer = DynamicActQuantProxyFromInjector
+
+    def prepare_for_export(self, module: nn.Module):
+        if module.is_quant_enabled:
+            self.module_forward = module.fused_activation_quant_proxy
+
+    def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
+        return self.module_forward(x)
+
+
+class GroupwiseIntInferenceHandler(IntInferencetHandler):
+    handled_layer = GroupwiseActQuantProxyFromInjector
+
+    def prepare_for_export(self, module):
+        if module.is_quant_enabled:
+            self.module_forward = module.fused_activation_quant_proxy
+            self.group_dim = module.group_dim
+
+    def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
+        x, *other = self.module_forward(x)
+        if is_dynamo_compiling():
+            start_dim = self.group_dim if self.group_dim != -1 else -2
+            x = x.flatten(start_dim, start_dim + 1)
+        output_args = tuple([x] + list(other))
+        return output_args
+
+
+class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler):
+    handled_layer = GroupwiseWeightQuantProxyFromInjector
+
+    def prepare_for_export(self, module):
+        super().prepare_for_export(module)
+        if module.is_quant_enabled:
+            self.input_view = module.input_view_impl
+            self.flattened_view = module.apply_input_view
+            if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
+                self.cached_weight = module._cached_weight.quant_tensor.value_
+            else:
+                self.cached_weight = None
+
+    def forward(self, x: Tensor) -> Tuple[Tensor]:
+        if self.scale.shape != ():
+            scale = self.input_view(self.scale)
+        else:
+            scale = self.scale
+        if self.zero_point.shape != ():
+            zero_point = self.input_view(self.zero_point)
+        else:
+            zero_point = self.zero_point
+        if self.cached_weight is not None:
+            out = self.cached_weight
+        else:
+            x = self.input_view(x)
+            out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
+        if is_dynamo_compiling():
+            out = self.flattened_view(out)
+        return out, scale, zero_point, self.bit_width
+
+
 class FloatInferencetHandler(InferenceHandler):
     handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector)
 
+    def __init__(self):
+        super().__init__()
+        self.register_buffer('scale', torch.ones(1))
+        self.register_buffer('zero_point', torch.ones(0))
+
     def prepare_for_export(self, module):
         if module.is_quant_enabled:
             self.scale = module.scale()
@@ -109,14 +190,13 @@ def prepare_for_export(self, module):
                 self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
             self.min_value = torch.tensor(0.) if not module.is_signed else -self.max_value
 
-    def quantize(self, x):
+    def quantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tuple[Tensor]:
         # Compute masks
         inf_mask = x.isinf()
         p_max_val_mask = x > self.max_value
         n_max_val_mask = -x > self.max_value
 
-        # Quantize
-        x = x / self.scale
+        x = x / scale
         internal_scale = float_internal_scale(
             x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps)
         x = internal_scale * self.float_to_int_impl(x / internal_scale)
@@ -128,26 +208,81 @@
 
         return x
 
-    def dequantize(self, x):
-        return (x - self.zero_point) * self.scale
+    def dequantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tensor:
+        return (x - zero_point) * scale
 
-    def forward(self, x) -> Tuple[torch.Tensor]:
-        return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
+    def forward(self, x: Tensor) -> Tuple[Tensor]:
+        return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
 
 
 class FloatWeightInferencetHandler(FloatInferencetHandler):
     handled_layer = WeightFloatQuantProxyFromInjector
 
+    def __init__(self):
+        super().__init__()
+        self.register_buffer('cached_weight', torch.ones(1))
+
     def prepare_for_export(self, module):
+        super().prepare_for_export(module)
         if module.is_quant_enabled:
-            self.cached_weight = None
-            super().prepare_for_export(module)
             if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
                 self.cached_weight = module._cached_weight.value
+            else:
+                self.cached_weight = None
 
-    def forward(self, x) -> Tuple[torch.Tensor]:
+    def forward(self, x: Tensor) -> Tuple[Tensor]:
         if self.cached_weight is not None:
             x = self.cached_weight
         else:
-            x = self.dequantize(self.quantize(x))
+            x = self.dequantize(
+                self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point)
         return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
+
+
+class GroupwiseFloatInferenceHandler(FloatInferencetHandler):
+    handled_layer = GroupwiseActFloatQuantProxyFromInjector
+
+    def prepare_for_export(self, module: nn.Module):
+        if module.is_quant_enabled:
+            self.module_forward = module.fused_activation_quant_proxy
+            self.group_dim = module.group_dim
+
+    def forward(self, x: Tensor) -> Tuple[Tensor]:
+        x, *other = self.module_forward(x)
+        if is_dynamo_compiling():
+            start_dim = self.group_dim if self.group_dim != -1 else -2
+            x = x.flatten(start_dim, start_dim + 1)
+        output_args = tuple([x] + list(other))
+        return output_args
+
+
+class GroupwiseFloatWeightInferenceHandler(FloatWeightInferencetHandler):
+    handled_layer = GroupwiseWeightFloatQuantProxyFromInjector
+
+    def prepare_for_export(self, module: nn.Module):
+        super().prepare_for_export(module)
+        if module.is_quant_enabled:
+            self.input_view = module.input_view_impl
+            self.flattened_view = module.apply_input_view
+            if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
+                self.cached_weight = module._cached_weight.quant_tensor.value_
+            else:
+                self.cached_weight = None
+
+    def forward(self, x: Tensor) -> Tuple[Tensor]:
+        if self.scale.shape != ():
+            scale = self.input_view(self.scale)
+        else:
+            scale = self.scale
+        if self.zero_point.shape != ():
+            zero_point = self.input_view(self.zero_point)
+        else:
+            zero_point = self.zero_point
+        if self.cached_weight is not None:
+            out = self.cached_weight
+        else:
+            x = self.input_view(x)
+            out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
+        if is_dynamo_compiling():
+            out = self.flattened_view(out)
+        return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py
index 936106884..b78a888f2 100644
--- a/src/brevitas/export/inference/manager.py
+++ b/src/brevitas/export/inference/manager.py
@@ -4,8 +4,13 @@
 from torch.nn import Module
 import torch.nn as nn
 
+from brevitas.export.inference.handler import DynamicIntInferenceHandler
 from brevitas.export.inference.handler import FloatInferencetHandler
 from brevitas.export.inference.handler import FloatWeightInferencetHandler
+from brevitas.export.inference.handler import GroupwiseFloatInferenceHandler
+from brevitas.export.inference.handler import GroupwiseFloatWeightInferenceHandler
+from brevitas.export.inference.handler import GroupwiseIntInferenceHandler
+from brevitas.export.inference.handler import GroupwiseIntWeightInferenceHandler
 from brevitas.export.inference.handler import IntInferencetHandler
 from brevitas.export.inference.handler import IntWeightInferencetHandler
 from brevitas.export.manager import _set_proxy_export_handler
@@ -65,6 +70,7 @@ def __exit__(self, type, value, traceback):
         # Disable all caching
         # deactivate export mode
         # restore return quant tensor
+        InferenceManager.set_export_mode(self.model, enabled=False)
         self.model.apply(
             lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False))
         self.model.apply(
@@ -72,7 +78,6 @@ def __exit__(self, type, value, traceback):
         if self.cache_quant_weight:
             self.model.apply(
                 lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
-        InferenceManager.set_export_mode(self.model, enabled=False)
         restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
 
     def hook(self, module, inp, out):
@@ -91,9 +96,14 @@ def hook(self, module, inp, out):
 class InferenceManager(BaseManager):
     handlers = [
         IntInferencetHandler,
+        DynamicIntInferenceHandler,
         FloatInferencetHandler,
         IntWeightInferencetHandler,
-        FloatWeightInferencetHandler]
+        FloatWeightInferencetHandler,
+        GroupwiseIntInferenceHandler,
+        GroupwiseIntWeightInferenceHandler,
+        GroupwiseFloatInferenceHandler,
+        GroupwiseFloatWeightInferenceHandler]
 
     @classmethod
     def set_export_mode(cls, model: Module, enabled: bool):
diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
index 16f75c49e..b507d3fe3 100644
--- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
+++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
@@ -92,7 +92,6 @@ def expand(self):
         curr_shape = self.value_.shape
         start_dim = self.group_dim if self.group_dim != -1 else -2
         new_value = self.value_.flatten(start_dim, start_dim + 1)
-        new_value = self.value_.flatten(start_dim, start_dim + 1)
         if self.scale_.shape != ():
             new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
         else:
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 6004ec97d..bd9669bd8 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -14,6 +14,7 @@
 from transformers.utils.fx import _SUPPORTED_MODELS
 
 from brevitas.export import export_torch_qcdq
+from brevitas.export.inference.manager import quant_inference_mode
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
 from brevitas.graph.equalize import GraphRotationEqualization
 from brevitas.graph.equalize import LayerwiseActivationRotation
@@ -421,8 +422,10 @@ def main(args):
 
     if args.eval and not args.no_quantize:
         print("Model eval...")
-        quant_ppl = compute_perplexity(
-            model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
+        with torch.no_grad(), quant_inference_mode(model):
+            model(**calibration_loader[0])
+            quant_ppl = compute_perplexity(
+                model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
         print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
         remove_hooks(model)
 
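
Usage note: the sketch below shows how the `quant_inference_mode` context manager wired up in `manager.py` is meant to be driven, mirroring the `llm/main.py` change in this patch. The toy `QuantLinear` model and the random input are illustrative placeholders only, not part of the patch.

```python
import torch

from brevitas.export.inference.manager import quant_inference_mode
from brevitas.nn import QuantLinear

# Illustrative stand-in for any Brevitas-quantized model (assumption, not from the patch).
model = QuantLinear(16, 8, bias=True)
model.eval()
x = torch.randn(4, 16)

with torch.no_grad(), quant_inference_mode(model):
    # A first forward pass lets the inference handlers cache scale, zero-point
    # and (optionally) the quantized weight before the real evaluation runs.
    model(x)
    out = model(x)
```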
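For reference, the integer fake-quantization round trip that `IntInferencetHandler.quantize`/`dequantize` now perform with explicit `scale`/`zero_point` arguments reduces to the standalone arithmetic below; the scale, zero-point and signed 8-bit clamp bounds are made-up example values.

```python
import torch

# Example values only: per-tensor scale/zero-point and signed 8-bit bounds
# (what min_int/max_int return for is_signed=True, narrow_range=False, bit_width=8).
scale = torch.tensor(0.05)
zero_point = torch.tensor(0.)
min_clamp, max_clamp = -128, 127

x = torch.randn(4)
q = torch.clamp(torch.round(x / scale + zero_point), min_clamp, max_clamp)  # quantize
x_dq = (q - zero_point) * scale  # dequantize: recovers x up to one quantization step
print(x, x_dq)
```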