From e7a21fc631b046ea4a891e0cbc0b04b6d8825b0b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 18 Dec 2024 16:56:37 +0000 Subject: [PATCH 01/10] Feat (mx): unpadding during dequantization --- .../proxy/groupwise_float_parameter_quant.py | 4 +- .../proxy/groupwise_int_parameter_quant.py | 4 +- .../quant_tensor/base_quant_tensor.py | 4 +- .../groupwise_float_quant_tensor.py | 22 ++++++++++- .../groupwise_int_quant_tensor.py | 39 ++++++++++++++++++- 5 files changed, 66 insertions(+), 7 deletions(-) diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index 12aacd23b..206e983b5 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -34,6 +34,7 @@ def apply_input_view(self, x): return x.flatten(start_dim, start_dim + 1) def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor: + shape = self.tracked_parameter_list[0].shape out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args return GroupwiseFloatQuantTensor( out, @@ -48,4 +49,5 @@ def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor: inf_values, nan_values, self.is_signed, - self.training) + self.training, + shape) diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index 905e50c52..51ff97c28 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -34,6 +34,7 @@ def apply_input_view(self, x): return x.flatten(start_dim, start_dim + 1) def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor: + shape = self.tracked_parameter_list[0].shape out, scale, zero_point, bit_width = qt_args return GroupwiseIntQuantTensor( out, @@ -43,4 +44,5 @@ def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor: self.group_dim, bit_width, self.is_signed, - self.training) + self.training, + shape) diff --git a/src/brevitas/quant_tensor/base_quant_tensor.py b/src/brevitas/quant_tensor/base_quant_tensor.py index 7b5dcd597..6f430025d 100644 --- a/src/brevitas/quant_tensor/base_quant_tensor.py +++ b/src/brevitas/quant_tensor/base_quant_tensor.py @@ -1,4 +1,4 @@ -from typing import List, NamedTuple, Optional +from typing import List, NamedTuple, Optional, Tuple from torch import Tensor @@ -129,6 +129,7 @@ class GroupwiseFloatQuantTensorBase(NamedTuple): nan_values: List[str] signed_t: Tensor training_t: Tensor + dequant_shape: Optional[Tuple] = None class GroupwisIntQuantTensorBase(NamedTuple): @@ -140,6 +141,7 @@ class GroupwisIntQuantTensorBase(NamedTuple): bit_width: Tensor signed_t: Tensor training_t: Tensor + dequant_shape: Optional[Tuple] = None def _unpack_quant_tensor(input_data): diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index b507d3fe3..a4099b785 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -31,7 +31,8 @@ def __new__( inf_values, nan_values, signed, - training): + training, + dequant_shape=None): if not isinstance(scale, torch.Tensor): scale = torch.tensor(scale, dtype=torch.float) @@ -63,7 +64,8 @@ def __new__( inf_values, nan_values, signed, - training) + training, + dequant_shape) return quant_tensor @property @@ -89,6 +91,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) def expand(self): + final_shape = self.dequant_shape curr_shape = self.value_.shape start_dim = self.group_dim if self.group_dim != -1 else -2 new_value = self.value_.flatten(start_dim, start_dim + 1) @@ -101,6 +104,21 @@ def expand(self): else: new_zp = self.zero_point_ + # If we padded during quantization, we unpad here: + # First, we compute how much we padded along the group_dim shape + # Then, we unbind the tensor along the group_dim shape, and drop the padded columns + # Finally, we stack the remaining tensors + unpadding_shape = final_shape[self.group_dim] + residual = curr_shape[self.group_dim] - unpadding_shape + + if residual > 0: + new_value = torch.stack( + torch.unbind(new_value, dim=self.group_dim)[residual:], dim=self.group_dim) + new_scale = torch.stack( + torch.unbind(new_scale, dim=self.group_dim)[residual:], dim=self.group_dim) + new_zp = torch.stack( + torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim) + return new_value, new_scale, new_zp @staticmethod diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index 082ec1234..65fc9b73f 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -18,7 +18,17 @@ class GroupwiseIntQuantTensor(GroupwisIntQuantTensorBase, QuantTensor): - def __new__(cls, value, scale, zero_point, group_size, group_dim, bit_width, signed, training): + def __new__( + cls, + value, + scale, + zero_point, + group_size, + group_dim, + bit_width, + signed, + training, + dequant_shape=None): if not isinstance(scale, torch.Tensor): scale = torch.tensor(scale, dtype=torch.float) @@ -31,7 +41,16 @@ def __new__(cls, value, scale, zero_point, group_size, group_dim, bit_width, sig if not isinstance(training, torch.Tensor): training = torch.tensor(training, dtype=torch.bool) quant_tensor = super().__new__( - cls, value, scale, zero_point, group_size, group_dim, bit_width, signed, training) + cls, + value, + scale, + zero_point, + group_size, + group_dim, + bit_width, + signed, + training, + dequant_shape) return quant_tensor @property @@ -58,6 +77,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) def expand(self): + final_shape = self.dequant_shape curr_shape = self.value_.shape start_dim = self.group_dim if self.group_dim != -1 else -2 new_value = self.value_.flatten(start_dim, start_dim + 1) @@ -70,6 +90,21 @@ def expand(self): else: new_zp = self.zero_point_ + # If we padded during quantization, we unpad here: + # First, we compute how much we padded along the group_dim shape + # Then, we unbind the tensor along the group_dim shape, and drop the padded columns + # Finally, we stack the remaining tensors + unpadding_shape = final_shape[self.group_dim] + residual = curr_shape[self.group_dim] - unpadding_shape + + if residual > 0: + new_value = torch.stack( + torch.unbind(new_value, dim=self.group_dim)[residual:], dim=self.group_dim) + new_scale = torch.stack( + torch.unbind(new_scale, dim=self.group_dim)[residual:], dim=self.group_dim) + new_zp = torch.stack( + torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim) + return new_value, new_scale, new_zp @staticmethod From 584aade0554eb927b679a39cb3a91462569ed271 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 18 Dec 2024 17:15:30 +0000 Subject: [PATCH 02/10] unpadding everything --- src/brevitas/proxy/float_runtime_quant.py | 2 +- src/brevitas/proxy/groupwise_float_runtime_quant.py | 8 +++++--- src/brevitas/proxy/groupwise_int_runtime_quant.py | 8 +++++--- src/brevitas/proxy/runtime_quant.py | 7 +++---- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index d3e9edb63..2be77b8f5 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -77,7 +77,7 @@ def create_quant_tensor( self, qt_args: Union[torch.Tensor, Tuple[Any]], x: Optional[FloatQuantTensor] = None) -> FloatQuantTensor: - if x is None: + if isinstance(qt_args, tuple): out = FloatQuantTensor(*qt_args, signed=self.is_signed, training=self.training) else: out = FloatQuantTensor( diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index 835ebdd5d..e1e8b75d7 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -30,7 +30,7 @@ def create_quant_tensor( self, qt_args: Union[torch.Tensor, Tuple[Any]], x: Optional[GroupwiseFloatQuantTensor] = None) -> GroupwiseFloatQuantTensor: - if x is None: + if isinstance(qt_args, tuple): value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args out = GroupwiseFloatQuantTensor( value, @@ -45,7 +45,8 @@ def create_quant_tensor( inf_values, nan_values, self.is_signed, - self.training) + self.training, + dequant_shape=x.shape) else: out = GroupwiseFloatQuantTensor( qt_args, @@ -60,5 +61,6 @@ def create_quant_tensor( x.inf_values, x.nan_values, x.signed, - self.training) + self.training, + x.shape) return out diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index 453cb3f9b..09dbc1ae3 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -30,7 +30,7 @@ def create_quant_tensor( self, qt_args: Union[torch.Tensor, Tuple[Any]], x: Optional[GroupwiseIntQuantTensor] = None) -> GroupwiseIntQuantTensor: - if x is None: + if isinstance(qt_args, tuple): value, scale, zero_point, bit_width = qt_args out = GroupwiseIntQuantTensor( value, @@ -40,7 +40,8 @@ def create_quant_tensor( self.group_dim, bit_width, self.is_signed, - self.training) + self.training, + x.shape) else: out = GroupwiseIntQuantTensor( qt_args, @@ -50,5 +51,6 @@ def create_quant_tensor( self.group_dim, x.bit_width, x.signed, - self.training) + self.training, + x.shape) return out diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index cff192490..32ecaa81e 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -157,9 +157,8 @@ def init_tensor_quant(self): @abstractmethod def create_quant_tensor( - self, - qt_args: Union[torch.Tensor, Tuple[Any]], - x: Optional[QuantTensor] = None) -> QuantTensor: + self, qt_args: Union[torch.Tensor, Tuple[Any]], x: Union[Tensor, + QuantTensor]) -> QuantTensor: # Supports the following: # - qt_args as tuple of Tensors and bools = standard quant activations # - qt_args as Tensor and x as QuantTensor = passthrough activation @@ -194,7 +193,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: else: # If the second value (i.e., scale) is None, then quant is disabled if y[1] is not None: - out = self.create_quant_tensor(y) + out = self.create_quant_tensor(y, x=x) elif self.is_passthrough_act and isinstance(x, QuantTensor): # preserve scale/zp/bit/sign even without output quant y = y[0] From 84a8e2378175052e56f5d547319059835fce11a8 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 18 Dec 2024 21:35:33 +0000 Subject: [PATCH 03/10] fix for tensor --- src/brevitas/proxy/runtime_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 32ecaa81e..6cd2c03ed 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -227,7 +227,7 @@ def create_quant_tensor( qt_args: Union[Tensor, Tuple[Any]], x: Optional[IntQuantTensor] = None) -> IntQuantTensor: - if x is None: + if isinstance(qt_args, tuple): out = IntQuantTensor(*qt_args, self.is_signed, self.training) else: out = IntQuantTensor( From d0a9335fc075afacd37d5ab20309341f95744571 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 18 Dec 2024 22:53:27 +0000 Subject: [PATCH 04/10] Fix zero point --- src/brevitas/quant_tensor/groupwise_float_quant_tensor.py | 5 +++-- src/brevitas/quant_tensor/groupwise_int_quant_tensor.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index a4099b785..fe40e5319 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -116,8 +116,9 @@ def expand(self): torch.unbind(new_value, dim=self.group_dim)[residual:], dim=self.group_dim) new_scale = torch.stack( torch.unbind(new_scale, dim=self.group_dim)[residual:], dim=self.group_dim) - new_zp = torch.stack( - torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim) + if self.zero_point_.shape != (): + new_zp = torch.stack( + torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim) return new_value, new_scale, new_zp diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index 65fc9b73f..67e6e769f 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -102,8 +102,9 @@ def expand(self): torch.unbind(new_value, dim=self.group_dim)[residual:], dim=self.group_dim) new_scale = torch.stack( torch.unbind(new_scale, dim=self.group_dim)[residual:], dim=self.group_dim) - new_zp = torch.stack( - torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim) + if self.zero_point_.shape != (): + new_zp = torch.stack( + torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim) return new_value, new_scale, new_zp From ad0aab85fce313e7675639489d353d0d1fe33228 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 18 Dec 2024 23:15:19 +0000 Subject: [PATCH 05/10] Fix weight residual computation --- src/brevitas/quant_tensor/groupwise_float_quant_tensor.py | 2 +- src/brevitas/quant_tensor/groupwise_int_quant_tensor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index fe40e5319..4c1d1708c 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -109,7 +109,7 @@ def expand(self): # Then, we unbind the tensor along the group_dim shape, and drop the padded columns # Finally, we stack the remaining tensors unpadding_shape = final_shape[self.group_dim] - residual = curr_shape[self.group_dim] - unpadding_shape + residual = new_value.shape[self.group_dim] - unpadding_shape if residual > 0: new_value = torch.stack( diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index 67e6e769f..4e2357579 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -95,7 +95,7 @@ def expand(self): # Then, we unbind the tensor along the group_dim shape, and drop the padded columns # Finally, we stack the remaining tensors unpadding_shape = final_shape[self.group_dim] - residual = curr_shape[self.group_dim] - unpadding_shape + residual = new_value.shape[self.group_dim] - unpadding_shape if residual > 0: new_value = torch.stack( From 7f2dd15790beabf9d9396be284ac5b3ffbd6abde Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 19 Dec 2024 08:33:27 +0000 Subject: [PATCH 06/10] fix --- notebooks/minifloat_mx_tutorial.ipynb | 10 +++++----- src/brevitas/graph/gpxq.py | 6 ------ .../quant_tensor/groupwise_float_quant_tensor.py | 6 +++--- .../quant_tensor/groupwise_int_quant_tensor.py | 6 +++--- 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/notebooks/minifloat_mx_tutorial.ipynb b/notebooks/minifloat_mx_tutorial.ipynb index 2a6f9bccb..db446cc3d 100644 --- a/notebooks/minifloat_mx_tutorial.ipynb +++ b/notebooks/minifloat_mx_tutorial.ipynb @@ -206,15 +206,15 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Non padding weights shape torch.Size([64, 8, 3, 3])\n", - "Padded weights shape torch.Size([64, 32, 3, 3])\n" + "Non padding weights shape torch.Size([64, 1, 8, 3, 3])\n", + "Padded weights shape torch.Size([64, 1, 32, 3, 3])\n" ] } ], @@ -257,8 +257,8 @@ "o = mx_model(x)\n", "\n", "# The quant weight of the padded model is different from the non padding one\n", - "print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value.shape}\")\n", - "print(f\"Padded weights shape {mx_model.conv.quant_weight().value.shape}\")\n", + "print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value_.shape}\")\n", + "print(f\"Padded weights shape {mx_model.conv.quant_weight().value_.shape}\")\n", "\n", "# However, results are still the same \n", "assert torch.allclose(o, o_no_padding)" diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index df3446614..6fc8aa09b 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -187,12 +187,6 @@ def __init__( self.layer = layer self.name = name self.act_order = act_order - if self.layer.weight_quant.is_groupwise: - weight = self.layer.weight_quant.apply_input_view(self.layer.weight) - weight = weight.view(self.layer.weight_quant.quant_injector.reshaped_groupwise_shape) - self.layer.weight.data = weight.data - self.layer.in_channels = weight.shape[1] if is_conv_transposed( - self.layer) else weight.shape[0] weight_shape = torch.tensor(layer.weight.shape) diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index 4c1d1708c..f5be0ec67 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -113,12 +113,12 @@ def expand(self): if residual > 0: new_value = torch.stack( - torch.unbind(new_value, dim=self.group_dim)[residual:], dim=self.group_dim) + torch.unbind(new_value, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim) new_scale = torch.stack( - torch.unbind(new_scale, dim=self.group_dim)[residual:], dim=self.group_dim) + torch.unbind(new_scale, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim) if self.zero_point_.shape != (): new_zp = torch.stack( - torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim) + torch.unbind(new_zp, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim) return new_value, new_scale, new_zp diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index 4e2357579..c92cc01fd 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -99,12 +99,12 @@ def expand(self): if residual > 0: new_value = torch.stack( - torch.unbind(new_value, dim=self.group_dim)[residual:], dim=self.group_dim) + torch.unbind(new_value, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim) new_scale = torch.stack( - torch.unbind(new_scale, dim=self.group_dim)[residual:], dim=self.group_dim) + torch.unbind(new_scale, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim) if self.zero_point_.shape != (): new_zp = torch.stack( - torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim) + torch.unbind(new_zp, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim) return new_value, new_scale, new_zp From c909178a87324576d5c5002c5b5bfd99d90fb845 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 19 Dec 2024 10:45:05 +0000 Subject: [PATCH 07/10] fix --- src/brevitas/proxy/parameter_quant.py | 2 +- src/brevitas/proxy/runtime_quant.py | 3 +- .../groupwise_float_quant_tensor.py | 33 ++--------------- .../groupwise_int_quant_tensor.py | 33 ++--------------- src/brevitas/utils/quant_utils.py | 35 +++++++++++++++++++ tests/brevitas/graph/test_gpxq.py | 20 +++-------- 6 files changed, 48 insertions(+), 78 deletions(-) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 2ca0afe92..7545f059e 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -157,7 +157,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) else: # quantization disabled - out = self.apply_input_view(x) + out = x return out diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 6cd2c03ed..95949add7 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -180,8 +180,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: elif not self.is_quant_enabled: # A tuple helps later with control flows # The second None value is used later - # If quant is not enabled, we still apply input_view in the case of groupwise + padding - y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y)) + y = self.fused_activation_quant_proxy.activation_impl(y) y = (y, None) else: y = self.fused_activation_quant_proxy(y) diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index f5be0ec67..8166f1b15 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -91,36 +91,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) def expand(self): - final_shape = self.dequant_shape - curr_shape = self.value_.shape - start_dim = self.group_dim if self.group_dim != -1 else -2 - new_value = self.value_.flatten(start_dim, start_dim + 1) - if self.scale_.shape != (): - new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) - else: - new_scale = self.scale_ - if self.zero_point_.shape != (): - new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1) - else: - new_zp = self.zero_point_ - - # If we padded during quantization, we unpad here: - # First, we compute how much we padded along the group_dim shape - # Then, we unbind the tensor along the group_dim shape, and drop the padded columns - # Finally, we stack the remaining tensors - unpadding_shape = final_shape[self.group_dim] - residual = new_value.shape[self.group_dim] - unpadding_shape - - if residual > 0: - new_value = torch.stack( - torch.unbind(new_value, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim) - new_scale = torch.stack( - torch.unbind(new_scale, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim) - if self.zero_point_.shape != (): - new_zp = torch.stack( - torch.unbind(new_zp, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim) - - return new_value, new_scale, new_zp + from brevitas.utils.quant_utils import groupwise_dequant + return groupwise_dequant( + self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape) @staticmethod def from_expanded(value, group_size, group_dim, compress=False): diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index c92cc01fd..7d97ad4cb 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -77,36 +77,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) def expand(self): - final_shape = self.dequant_shape - curr_shape = self.value_.shape - start_dim = self.group_dim if self.group_dim != -1 else -2 - new_value = self.value_.flatten(start_dim, start_dim + 1) - if self.scale_.shape != (): - new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) - else: - new_scale = self.scale_ - if self.zero_point_.shape != (): - new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1) - else: - new_zp = self.zero_point_ - - # If we padded during quantization, we unpad here: - # First, we compute how much we padded along the group_dim shape - # Then, we unbind the tensor along the group_dim shape, and drop the padded columns - # Finally, we stack the remaining tensors - unpadding_shape = final_shape[self.group_dim] - residual = new_value.shape[self.group_dim] - unpadding_shape - - if residual > 0: - new_value = torch.stack( - torch.unbind(new_value, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim) - new_scale = torch.stack( - torch.unbind(new_scale, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim) - if self.zero_point_.shape != (): - new_zp = torch.stack( - torch.unbind(new_zp, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim) - - return new_value, new_scale, new_zp + from brevitas.utils.quant_utils import groupwise_dequant + return groupwise_dequant( + self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape) @staticmethod def from_expanded(value, group_size, group_dim, compress=False): diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index a7c86d7bc..22c17eb9e 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +import torch + from brevitas.core.bit_width import BitWidthParameter from brevitas.core.function_wrapper import * from brevitas.core.quant import RescalingIntQuant @@ -215,3 +217,36 @@ def float_to_int_impl_to_enum(module): return FloatToIntImplType.STOCHASTIC_ROUND else: return None + + +def groupwise_dequant(value_, scale_, zero_point_, group_dim, dequant_shape): + final_shape = dequant_shape + curr_shape = value_.shape + start_dim = group_dim if group_dim != -1 else -2 + new_value = value_.flatten(start_dim, start_dim + 1) + if scale_.shape != (): + new_scale = scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) + else: + new_scale = scale_ + if zero_point_.shape != (): + new_zp = zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1) + else: + new_zp = zero_point_ + + # If we padded during quantization, we unpad here: + # First, we compute how much we padded along the group_dim shape + # Then, we unbind the tensor along the group_dim shape, and drop the padded columns + # Finally, we stack the remaining tensors + unpadding_shape = final_shape[group_dim] + residual = new_value.shape[group_dim] - unpadding_shape + + if residual > 0: + new_value = torch.stack( + torch.unbind(new_value, dim=group_dim)[:unpadding_shape], dim=group_dim) + new_scale = torch.stack( + torch.unbind(new_scale, dim=group_dim)[:unpadding_shape], dim=group_dim) + if zero_point_.shape != (): + new_zp = torch.stack( + torch.unbind(new_zp, dim=group_dim)[:unpadding_shape], dim=group_dim) + + return new_value, new_scale, new_zp diff --git a/tests/brevitas/graph/test_gpxq.py b/tests/brevitas/graph/test_gpxq.py index 33116332c..aa2ec9f97 100644 --- a/tests/brevitas/graph/test_gpxq.py +++ b/tests/brevitas/graph/test_gpxq.py @@ -89,18 +89,8 @@ def test_toymodels(toy_quant_model, act_order, use_quant_activations, apply_gpxq dataset = TensorDataset(inp, inp) calib_loader = DataLoader(dataset, batch_size=16, num_workers=0, pin_memory=True, shuffle=True) - if ((name == 'gptq' or name == 'gpfq2') and torch_version < version.parse('1.10')): - # Usage of linalg_cholesky() is not compatible with torch 1.9.1 and below - with pytest.raises(AssertionError): - apply_gpxq( - calib_loader=calib_loader, - model=model, - act_order=act_order, - use_quant_activations=use_quant_activations) - - else: - apply_gpxq( - calib_loader=calib_loader, - model=model, - act_order=act_order, - use_quant_activations=use_quant_activations) + apply_gpxq( + calib_loader=calib_loader, + model=model, + act_order=act_order, + use_quant_activations=use_quant_activations) From 0f61e8e61b932f22861e07119950ee74a51c062d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 19 Dec 2024 11:09:35 +0000 Subject: [PATCH 08/10] Last fixes --- src/brevitas/export/inference/handler.py | 9 ++++++--- .../quant_tensor/groupwise_float_quant_tensor.py | 4 ++-- src/brevitas/quant_tensor/groupwise_int_quant_tensor.py | 4 ++-- src/brevitas/utils/quant_utils.py | 2 +- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 01fb8a63b..be030985f 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -24,6 +24,7 @@ from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector from brevitas.quant.experimental.mx_quant_ocp import GroupwiseActQuantProxyFromInjector +from brevitas.utils.quant_utils import groupwise_dequant_expand from brevitas.utils.torch_utils import float_internal_scale @@ -146,8 +147,8 @@ def __init__(self): def prepare_for_export(self, module): super().prepare_for_export(module) if module.is_quant_enabled: + self.group_dim = module.group_dim self.input_view = module.input_view_impl - self.flattened_view = module.apply_input_view if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: self.cached_weight = module._cached_weight.quant_tensor.value_ else: @@ -165,12 +166,13 @@ def forward(self, x: Tensor) -> Tuple[Tensor]: if self.cached_weight is not None: out = self.cached_weight else: + inp_shape = x.shape x = self.input_view(x) out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) # If we skip quant tensor, we return the flattened version of the groupwise tensor if self.skip_create_quant_tensor: - out = self.flattened_view(out) + out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0] return out, scale, zero_point, self.bit_width @@ -311,11 +313,12 @@ def forward(self, x: Tensor) -> Tuple[Tensor]: if self.cached_weight is not None: out = self.cached_weight else: + inp_shape = x.shape x = self.input_view(x) out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) # If we skip quant tensor, we return the flattened version of the groupwise tensor if self.skip_create_quant_tensor: - out = self.flattened_view(out) + out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0] return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index 8166f1b15..60c5ba84f 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -91,8 +91,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) def expand(self): - from brevitas.utils.quant_utils import groupwise_dequant - return groupwise_dequant( + from brevitas.utils.quant_utils import groupwise_dequant_expand + return groupwise_dequant_expand( self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape) @staticmethod diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index 7d97ad4cb..fa7e8438e 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -77,8 +77,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) def expand(self): - from brevitas.utils.quant_utils import groupwise_dequant - return groupwise_dequant( + from brevitas.utils.quant_utils import groupwise_dequant_expand + return groupwise_dequant_expand( self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape) @staticmethod diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index 22c17eb9e..d0d245089 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -219,7 +219,7 @@ def float_to_int_impl_to_enum(module): return None -def groupwise_dequant(value_, scale_, zero_point_, group_dim, dequant_shape): +def groupwise_dequant_expand(value_, scale_, zero_point_, group_dim, dequant_shape): final_shape = dequant_shape curr_shape = value_.shape start_dim = group_dim if group_dim != -1 else -2 From e6b7bccdba6b2b41cc53f2a7051d7554284ff585 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 19 Dec 2024 13:48:07 +0000 Subject: [PATCH 09/10] group dim --- src/brevitas/export/inference/handler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index be030985f..59944c2b0 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -296,6 +296,7 @@ def prepare_for_export(self, module: nn.Module): if module.is_quant_enabled: self.input_view = module.input_view_impl self.flattened_view = module.apply_input_view + self.group_dim = module.group_dim if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: self.cached_weight = module._cached_weight.quant_tensor.value_ else: From 1fdd11192bfc715e184154f611e3464f2b2128bf Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 10 Jan 2025 15:40:03 +0000 Subject: [PATCH 10/10] typing --- src/brevitas/proxy/groupwise_float_runtime_quant.py | 2 +- src/brevitas/proxy/groupwise_int_runtime_quant.py | 2 +- src/brevitas/proxy/runtime_quant.py | 6 ++---- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index e1e8b75d7..5d76e4635 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -29,7 +29,7 @@ def apply_input_view(self, x): def create_quant_tensor( self, qt_args: Union[torch.Tensor, Tuple[Any]], - x: Optional[GroupwiseFloatQuantTensor] = None) -> GroupwiseFloatQuantTensor: + x: Union[torch.Tensor, GroupwiseFloatQuantTensor]) -> GroupwiseFloatQuantTensor: if isinstance(qt_args, tuple): value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args out = GroupwiseFloatQuantTensor( diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index 09dbc1ae3..96d047808 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -29,7 +29,7 @@ def apply_input_view(self, x): def create_quant_tensor( self, qt_args: Union[torch.Tensor, Tuple[Any]], - x: Optional[GroupwiseIntQuantTensor] = None) -> GroupwiseIntQuantTensor: + x: Union[torch.Tensor, GroupwiseIntQuantTensor]) -> GroupwiseIntQuantTensor: if isinstance(qt_args, tuple): value, scale, zero_point, bit_width = qt_args out = GroupwiseIntQuantTensor( diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 95949add7..3e7248602 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -222,10 +222,8 @@ def bit_width(self, force_eval=True): return self.retrieve_attribute('bit_width', force_eval) def create_quant_tensor( - self, - qt_args: Union[Tensor, Tuple[Any]], - x: Optional[IntQuantTensor] = None) -> IntQuantTensor: - + self, qt_args: Union[torch.Tensor, Tuple[Any]], + x: Union[Tensor, IntQuantTensor]) -> IntQuantTensor: if isinstance(qt_args, tuple): out = IntQuantTensor(*qt_args, self.is_signed, self.training) else: