From adeeec36b4cfa44d6b439b92b7ead84f730a89bc Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 11 Jan 2025 02:26:27 +0100 Subject: [PATCH] Feat (mx): unpadding during dequantization (#1134) --- notebooks/minifloat_mx_tutorial.ipynb | 10 ++--- src/brevitas/export/inference/handler.py | 10 +++-- src/brevitas/graph/gpxq.py | 6 --- src/brevitas/proxy/float_runtime_quant.py | 2 +- .../proxy/groupwise_float_parameter_quant.py | 4 +- .../proxy/groupwise_float_runtime_quant.py | 10 +++-- .../proxy/groupwise_int_parameter_quant.py | 4 +- .../proxy/groupwise_int_runtime_quant.py | 10 +++-- src/brevitas/proxy/parameter_quant.py | 2 +- src/brevitas/proxy/runtime_quant.py | 18 ++++----- .../quant_tensor/base_quant_tensor.py | 4 +- .../groupwise_float_quant_tensor.py | 22 ++++------- .../groupwise_int_quant_tensor.py | 39 ++++++++++++------- src/brevitas/utils/quant_utils.py | 35 +++++++++++++++++ tests/brevitas/graph/test_gpxq.py | 20 +++------- 15 files changed, 113 insertions(+), 83 deletions(-) diff --git a/notebooks/minifloat_mx_tutorial.ipynb b/notebooks/minifloat_mx_tutorial.ipynb index 2a6f9bccb..db446cc3d 100644 --- a/notebooks/minifloat_mx_tutorial.ipynb +++ b/notebooks/minifloat_mx_tutorial.ipynb @@ -206,15 +206,15 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Non padding weights shape torch.Size([64, 8, 3, 3])\n", - "Padded weights shape torch.Size([64, 32, 3, 3])\n" + "Non padding weights shape torch.Size([64, 1, 8, 3, 3])\n", + "Padded weights shape torch.Size([64, 1, 32, 3, 3])\n" ] } ], @@ -257,8 +257,8 @@ "o = mx_model(x)\n", "\n", "# The quant weight of the padded model is different from the non padding one\n", - "print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value.shape}\")\n", - "print(f\"Padded weights shape {mx_model.conv.quant_weight().value.shape}\")\n", + "print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value_.shape}\")\n", + "print(f\"Padded weights shape {mx_model.conv.quant_weight().value_.shape}\")\n", "\n", "# However, results are still the same \n", "assert torch.allclose(o, o_no_padding)" diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 01fb8a63b..59944c2b0 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -24,6 +24,7 @@ from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector from brevitas.quant.experimental.mx_quant_ocp import GroupwiseActQuantProxyFromInjector +from brevitas.utils.quant_utils import groupwise_dequant_expand from brevitas.utils.torch_utils import float_internal_scale @@ -146,8 +147,8 @@ def __init__(self): def prepare_for_export(self, module): super().prepare_for_export(module) if module.is_quant_enabled: + self.group_dim = module.group_dim self.input_view = module.input_view_impl - self.flattened_view = module.apply_input_view if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: self.cached_weight = module._cached_weight.quant_tensor.value_ else: @@ -165,12 +166,13 @@ def forward(self, x: Tensor) -> Tuple[Tensor]: if self.cached_weight is not None: out = self.cached_weight else: + inp_shape = x.shape x = self.input_view(x) out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) # If we skip quant tensor, we return the 
flattened version of the groupwise tensor if self.skip_create_quant_tensor: - out = self.flattened_view(out) + out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0] return out, scale, zero_point, self.bit_width @@ -294,6 +296,7 @@ def prepare_for_export(self, module: nn.Module): if module.is_quant_enabled: self.input_view = module.input_view_impl self.flattened_view = module.apply_input_view + self.group_dim = module.group_dim if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: self.cached_weight = module._cached_weight.quant_tensor.value_ else: @@ -311,11 +314,12 @@ def forward(self, x: Tensor) -> Tuple[Tensor]: if self.cached_weight is not None: out = self.cached_weight else: + inp_shape = x.shape x = self.input_view(x) out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) # If we skip quant tensor, we return the flattened version of the groupwise tensor if self.skip_create_quant_tensor: - out = self.flattened_view(out) + out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0] return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index df3446614..6fc8aa09b 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -187,12 +187,6 @@ def __init__( self.layer = layer self.name = name self.act_order = act_order - if self.layer.weight_quant.is_groupwise: - weight = self.layer.weight_quant.apply_input_view(self.layer.weight) - weight = weight.view(self.layer.weight_quant.quant_injector.reshaped_groupwise_shape) - self.layer.weight.data = weight.data - self.layer.in_channels = weight.shape[1] if is_conv_transposed( - self.layer) else weight.shape[0] weight_shape = torch.tensor(layer.weight.shape) diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index d3e9edb63..2be77b8f5 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -77,7 +77,7 @@ def create_quant_tensor( self, qt_args: Union[torch.Tensor, Tuple[Any]], x: Optional[FloatQuantTensor] = None) -> FloatQuantTensor: - if x is None: + if isinstance(qt_args, tuple): out = FloatQuantTensor(*qt_args, signed=self.is_signed, training=self.training) else: out = FloatQuantTensor( diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index 12aacd23b..206e983b5 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -34,6 +34,7 @@ def apply_input_view(self, x): return x.flatten(start_dim, start_dim + 1) def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor: + shape = self.tracked_parameter_list[0].shape out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args return GroupwiseFloatQuantTensor( out, @@ -48,4 +49,5 @@ def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor: inf_values, nan_values, self.is_signed, - self.training) + self.training, + shape) diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index 835ebdd5d..5d76e4635 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ 
b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -29,8 +29,8 @@ def apply_input_view(self, x): def create_quant_tensor( self, qt_args: Union[torch.Tensor, Tuple[Any]], - x: Optional[GroupwiseFloatQuantTensor] = None) -> GroupwiseFloatQuantTensor: - if x is None: + x: Union[torch.Tensor, GroupwiseFloatQuantTensor]) -> GroupwiseFloatQuantTensor: + if isinstance(qt_args, tuple): value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args out = GroupwiseFloatQuantTensor( value, @@ -45,7 +45,8 @@ def create_quant_tensor( inf_values, nan_values, self.is_signed, - self.training) + self.training, + dequant_shape=x.shape) else: out = GroupwiseFloatQuantTensor( qt_args, @@ -60,5 +61,6 @@ def create_quant_tensor( x.inf_values, x.nan_values, x.signed, - self.training) + self.training, + x.shape) return out diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index 905e50c52..51ff97c28 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -34,6 +34,7 @@ def apply_input_view(self, x): return x.flatten(start_dim, start_dim + 1) def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor: + shape = self.tracked_parameter_list[0].shape out, scale, zero_point, bit_width = qt_args return GroupwiseIntQuantTensor( out, @@ -43,4 +44,5 @@ def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor: self.group_dim, bit_width, self.is_signed, - self.training) + self.training, + shape) diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index 453cb3f9b..96d047808 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -29,8 +29,8 @@ def apply_input_view(self, x): def create_quant_tensor( self, qt_args: Union[torch.Tensor, Tuple[Any]], - x: Optional[GroupwiseIntQuantTensor] = None) -> GroupwiseIntQuantTensor: - if x is None: + x: Union[torch.Tensor, GroupwiseIntQuantTensor]) -> GroupwiseIntQuantTensor: + if isinstance(qt_args, tuple): value, scale, zero_point, bit_width = qt_args out = GroupwiseIntQuantTensor( value, @@ -40,7 +40,8 @@ def create_quant_tensor( self.group_dim, bit_width, self.is_signed, - self.training) + self.training, + x.shape) else: out = GroupwiseIntQuantTensor( qt_args, @@ -50,5 +51,6 @@ def create_quant_tensor( self.group_dim, x.bit_width, x.signed, - self.training) + self.training, + x.shape) return out diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 2ca0afe92..7545f059e 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -157,7 +157,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) else: # quantization disabled - out = self.apply_input_view(x) + out = x return out diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index cff192490..3e7248602 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -157,9 +157,8 @@ def init_tensor_quant(self): @abstractmethod def create_quant_tensor( - self, - qt_args: Union[torch.Tensor, Tuple[Any]], - x: Optional[QuantTensor] = None) -> QuantTensor: + self, qt_args: Union[torch.Tensor, Tuple[Any]], x: Union[Tensor, + 
QuantTensor]) -> QuantTensor: # Supports the following: # - qt_args as tuple of Tensors and bools = standard quant activations # - qt_args as Tensor and x as QuantTensor = passthrough activation @@ -181,8 +180,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: elif not self.is_quant_enabled: # A tuple helps later with control flows # The second None value is used later - # If quant is not enabled, we still apply input_view in the case of groupwise + padding - y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y)) + y = self.fused_activation_quant_proxy.activation_impl(y) y = (y, None) else: y = self.fused_activation_quant_proxy(y) @@ -194,7 +192,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: else: # If the second value (i.e., scale) is None, then quant is disabled if y[1] is not None: - out = self.create_quant_tensor(y) + out = self.create_quant_tensor(y, x=x) elif self.is_passthrough_act and isinstance(x, QuantTensor): # preserve scale/zp/bit/sign even without output quant y = y[0] @@ -224,11 +222,9 @@ def bit_width(self, force_eval=True): return self.retrieve_attribute('bit_width', force_eval) def create_quant_tensor( - self, - qt_args: Union[Tensor, Tuple[Any]], - x: Optional[IntQuantTensor] = None) -> IntQuantTensor: - - if x is None: + self, qt_args: Union[torch.Tensor, Tuple[Any]], + x: Union[Tensor, IntQuantTensor]) -> IntQuantTensor: + if isinstance(qt_args, tuple): out = IntQuantTensor(*qt_args, self.is_signed, self.training) else: out = IntQuantTensor( diff --git a/src/brevitas/quant_tensor/base_quant_tensor.py b/src/brevitas/quant_tensor/base_quant_tensor.py index 7b5dcd597..6f430025d 100644 --- a/src/brevitas/quant_tensor/base_quant_tensor.py +++ b/src/brevitas/quant_tensor/base_quant_tensor.py @@ -1,4 +1,4 @@ -from typing import List, NamedTuple, Optional +from typing import List, NamedTuple, Optional, Tuple from torch import Tensor @@ -129,6 +129,7 @@ class GroupwiseFloatQuantTensorBase(NamedTuple): nan_values: List[str] signed_t: Tensor training_t: Tensor + dequant_shape: Optional[Tuple] = None class GroupwisIntQuantTensorBase(NamedTuple): @@ -140,6 +141,7 @@ class GroupwisIntQuantTensorBase(NamedTuple): bit_width: Tensor signed_t: Tensor training_t: Tensor + dequant_shape: Optional[Tuple] = None def _unpack_quant_tensor(input_data): diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index b507d3fe3..60c5ba84f 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -31,7 +31,8 @@ def __new__( inf_values, nan_values, signed, - training): + training, + dequant_shape=None): if not isinstance(scale, torch.Tensor): scale = torch.tensor(scale, dtype=torch.float) @@ -63,7 +64,8 @@ def __new__( inf_values, nan_values, signed, - training) + training, + dequant_shape) return quant_tensor @property @@ -89,19 +91,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) def expand(self): - curr_shape = self.value_.shape - start_dim = self.group_dim if self.group_dim != -1 else -2 - new_value = self.value_.flatten(start_dim, start_dim + 1) - if self.scale_.shape != (): - new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) - else: - new_scale = self.scale_ - if self.zero_point_.shape != (): - new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1) 
- else: - new_zp = self.zero_point_ - - return new_value, new_scale, new_zp + from brevitas.utils.quant_utils import groupwise_dequant_expand + return groupwise_dequant_expand( + self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape) @staticmethod def from_expanded(value, group_size, group_dim, compress=False): diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index 082ec1234..fa7e8438e 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -18,7 +18,17 @@ class GroupwiseIntQuantTensor(GroupwisIntQuantTensorBase, QuantTensor): - def __new__(cls, value, scale, zero_point, group_size, group_dim, bit_width, signed, training): + def __new__( + cls, + value, + scale, + zero_point, + group_size, + group_dim, + bit_width, + signed, + training, + dequant_shape=None): if not isinstance(scale, torch.Tensor): scale = torch.tensor(scale, dtype=torch.float) @@ -31,7 +41,16 @@ def __new__(cls, value, scale, zero_point, group_size, group_dim, bit_width, sig if not isinstance(training, torch.Tensor): training = torch.tensor(training, dtype=torch.bool) quant_tensor = super().__new__( - cls, value, scale, zero_point, group_size, group_dim, bit_width, signed, training) + cls, + value, + scale, + zero_point, + group_size, + group_dim, + bit_width, + signed, + training, + dequant_shape) return quant_tensor @property @@ -58,19 +77,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) def expand(self): - curr_shape = self.value_.shape - start_dim = self.group_dim if self.group_dim != -1 else -2 - new_value = self.value_.flatten(start_dim, start_dim + 1) - if self.scale_.shape != (): - new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) - else: - new_scale = self.scale_ - if self.zero_point_.shape != (): - new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1) - else: - new_zp = self.zero_point_ - - return new_value, new_scale, new_zp + from brevitas.utils.quant_utils import groupwise_dequant_expand + return groupwise_dequant_expand( + self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape) @staticmethod def from_expanded(value, group_size, group_dim, compress=False): diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index a7c86d7bc..d0d245089 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause +import torch + from brevitas.core.bit_width import BitWidthParameter from brevitas.core.function_wrapper import * from brevitas.core.quant import RescalingIntQuant @@ -215,3 +217,36 @@ def float_to_int_impl_to_enum(module): return FloatToIntImplType.STOCHASTIC_ROUND else: return None + + +def groupwise_dequant_expand(value_, scale_, zero_point_, group_dim, dequant_shape): + final_shape = dequant_shape + curr_shape = value_.shape + start_dim = group_dim if group_dim != -1 else -2 + new_value = value_.flatten(start_dim, start_dim + 1) + if scale_.shape != (): + new_scale = scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) + else: + new_scale = scale_ + if zero_point_.shape != (): + new_zp = zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1) + else: + new_zp = zero_point_ + + # If we padded during quantization, we unpad here: + # First, we compute how much we padded along the group_dim shape + # Then, we unbind the tensor along the group_dim shape, and drop the padded columns + # Finally, we stack the remaining tensors + unpadding_shape = final_shape[group_dim] + residual = new_value.shape[group_dim] - unpadding_shape + + if residual > 0: + new_value = torch.stack( + torch.unbind(new_value, dim=group_dim)[:unpadding_shape], dim=group_dim) + new_scale = torch.stack( + torch.unbind(new_scale, dim=group_dim)[:unpadding_shape], dim=group_dim) + if zero_point_.shape != (): + new_zp = torch.stack( + torch.unbind(new_zp, dim=group_dim)[:unpadding_shape], dim=group_dim) + + return new_value, new_scale, new_zp diff --git a/tests/brevitas/graph/test_gpxq.py b/tests/brevitas/graph/test_gpxq.py index 33116332c..aa2ec9f97 100644 --- a/tests/brevitas/graph/test_gpxq.py +++ b/tests/brevitas/graph/test_gpxq.py @@ -89,18 +89,8 @@ def test_toymodels(toy_quant_model, act_order, use_quant_activations, apply_gpxq dataset = TensorDataset(inp, inp) calib_loader = DataLoader(dataset, batch_size=16, num_workers=0, pin_memory=True, shuffle=True) - if ((name == 'gptq' or name == 'gpfq2') and torch_version < version.parse('1.10')): - # Usage of linalg_cholesky() is not compatible with torch 1.9.1 and below - with pytest.raises(AssertionError): - apply_gpxq( - calib_loader=calib_loader, - model=model, - act_order=act_order, - use_quant_activations=use_quant_activations) - - else: - apply_gpxq( - calib_loader=calib_loader, - model=model, - act_order=act_order, - use_quant_activations=use_quant_activations) + apply_gpxq( + calib_loader=calib_loader, + model=model, + act_order=act_order, + use_quant_activations=use_quant_activations)
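

For reference, a minimal standalone sketch (not part of the patch) of the unpadding that the new groupwise_dequant_expand helper performs during dequantization: when the group size does not divide the grouped dimension, the tensor is padded at quantization time, and the stored dequant_shape is used to drop the padded slices again. Tensor names and shapes below are illustrative assumptions; torch.narrow is used as a compact equivalent of the unbind/stack in the patch.

import torch

group_size, group_dim = 32, 1
dequant_shape = (64, 8, 3, 3)  # original (unpadded) weight shape

# Grouped representation after padding 8 -> 32 channels: (64, 1, 32, 3, 3)
value_ = torch.randn(64, 1, group_size, 3, 3)
scale_ = torch.rand(64, 1, 1, 3, 3)  # one scale per group, non-scalar

start_dim = group_dim if group_dim != -1 else -2
new_value = value_.flatten(start_dim, start_dim + 1)  # (64, 32, 3, 3)
new_scale = scale_.expand(value_.shape).flatten(start_dim, start_dim + 1)

# How much padding was added along group_dim: 32 - 8 = 24
residual = new_value.shape[group_dim] - dequant_shape[group_dim]
if residual > 0:
    # Keep the first dequant_shape[group_dim] slices along group_dim and
    # drop the padded ones; equivalent to the unbind/stack in the patch.
    new_value = new_value.narrow(group_dim, 0, dequant_shape[group_dim])
    new_scale = new_scale.narrow(group_dim, 0, dequant_shape[group_dim])

assert new_value.shape == dequant_shape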
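
A usage sketch mirroring the updated tutorial cell, under the assumption that MXInt8Weight is the MX integer weight quantizer used in the notebook and that its default group size is 32: a conv layer with 8 input channels is padded up to the group size, so value_ exposes the padded grouped storage while value dequantizes and unpads back to the nominal weight shape, which is exactly why the notebook now prints value_ to show the difference.

from brevitas.nn import QuantConv2d
from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight

conv = QuantConv2d(8, 64, 3, weight_quant=MXInt8Weight)
qw = conv.quant_weight()
print(qw.value_.shape)  # grouped, padded storage, expected (64, 1, 32, 3, 3)
print(qw.value.shape)   # dequantized and unpadded, expected (64, 8, 3, 3)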