Feat (mx): unpadding during dequantization (#1134)
Giuseppe5 authored Jan 11, 2025
1 parent b83ab89 commit adeeec3
Showing 15 changed files with 113 additions and 83 deletions.
10 changes: 5 additions & 5 deletions notebooks/minifloat_mx_tutorial.ipynb
@@ -206,15 +206,15 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 16,
+    "execution_count": 3,
     "metadata": {},
     "outputs": [
      {
       "name": "stdout",
       "output_type": "stream",
       "text": [
-       "Non padding weights shape torch.Size([64, 8, 3, 3])\n",
-       "Padded weights shape torch.Size([64, 32, 3, 3])\n"
+       "Non padding weights shape torch.Size([64, 1, 8, 3, 3])\n",
+       "Padded weights shape torch.Size([64, 1, 32, 3, 3])\n"
      ]
     }
    ],
@@ -257,8 +257,8 @@
     "o = mx_model(x)\n",
     "\n",
     "# The quant weight of the padded model is different from the non padding one\n",
-    "print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value.shape}\")\n",
-    "print(f\"Padded weights shape {mx_model.conv.quant_weight().value.shape}\")\n",
+    "print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value_.shape}\")\n",
+    "print(f\"Padded weights shape {mx_model.conv.quant_weight().value_.shape}\")\n",
     "\n",
     "# However, results are still the same \n",
     "assert torch.allclose(o, o_no_padding)"
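
Note on the notebook change above: after this PR, quant_weight().value is unpadded back to the original weight shape during dequantization, so the notebook switches to value_ (the underlying grouped tensor) to show the padded vs. non-padded difference. For context, a minimal sketch of how groupwise padding produces the torch.Size([64, 1, 32, 3, 3]) printed above, assuming group size 32 along the input-channel dim; pad_and_group is a hypothetical helper written for illustration, not Brevitas API:

import torch
import torch.nn.functional as F

def pad_and_group(weight: torch.Tensor, group_size: int, group_dim: int = 1) -> torch.Tensor:
    # Pad group_dim up to a multiple of group_size, then split it into
    # (num_groups, group_size). Hypothetical helper, for illustration only.
    size = weight.shape[group_dim]
    pad = (group_size - size % group_size) % group_size
    if pad > 0:
        # F.pad takes (left, right) pairs starting from the last dim.
        pad_spec = [0, 0] * (weight.dim() - group_dim - 1) + [0, pad]
        weight = F.pad(weight, pad_spec)
    new_shape = weight.shape[:group_dim] + (-1, group_size) + weight.shape[group_dim + 1:]
    return weight.view(new_shape)

w = torch.randn(64, 8, 3, 3)                  # conv weight with 8 input channels
print(pad_and_group(w, group_size=32).shape)  # torch.Size([64, 1, 32, 3, 3])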
10 changes: 7 additions & 3 deletions src/brevitas/export/inference/handler.py
@@ -24,6 +24,7 @@
 from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
 from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
 from brevitas.quant.experimental.mx_quant_ocp import GroupwiseActQuantProxyFromInjector
+from brevitas.utils.quant_utils import groupwise_dequant_expand
 from brevitas.utils.torch_utils import float_internal_scale
@@ -146,8 +147,8 @@ def __init__(self):
     def prepare_for_export(self, module):
         super().prepare_for_export(module)
         if module.is_quant_enabled:
+            self.group_dim = module.group_dim
             self.input_view = module.input_view_impl
-            self.flattened_view = module.apply_input_view
             if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
                 self.cached_weight = module._cached_weight.quant_tensor.value_
             else:
@@ -165,12 +166,13 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         if self.cached_weight is not None:
             out = self.cached_weight
         else:
+            inp_shape = x.shape
             x = self.input_view(x)
             out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)

         # If we skip quant tensor, we return the flattened version of the groupwise tensor
         if self.skip_create_quant_tensor:
-            out = self.flattened_view(out)
+            out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0]
         return out, scale, zero_point, self.bit_width
@@ -294,6 +296,7 @@ def prepare_for_export(self, module: nn.Module):
         if module.is_quant_enabled:
             self.input_view = module.input_view_impl
             self.flattened_view = module.apply_input_view
+            self.group_dim = module.group_dim
             if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
                 self.cached_weight = module._cached_weight.quant_tensor.value_
             else:
@@ -311,11 +314,12 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         if self.cached_weight is not None:
             out = self.cached_weight
         else:
+            inp_shape = x.shape
             x = self.input_view(x)
             out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)

         # If we skip quant tensor, we return the flattened version of the groupwise tensor
         if self.skip_create_quant_tensor:
-            out = self.flattened_view(out)
+            out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0]

         return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
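
The new groupwise_dequant_expand utility imported above is not part of this diff, so its body is not visible here. Judging from its call sites and from the expand() implementation it replaces in the quant tensor classes below, it plausibly flattens the (num_groups, group_size) pair back into a single dimension, broadcasts scale and zero-point the same way, and then slices off any padding so the dequantized value matches the original input shape. A hedged sketch under those assumptions, not the actual Brevitas implementation:

import torch
from typing import Optional, Tuple

def groupwise_dequant_expand_sketch(
        value: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        group_dim: int,
        dequant_shape: Optional[Tuple[int, ...]] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    curr_shape = value.shape
    start_dim = group_dim if group_dim != -1 else -2
    # Merge (num_groups, group_size) back into one dim, as the old expand() did.
    new_value = value.flatten(start_dim, start_dim + 1)
    if scale.shape != ():
        new_scale = scale.expand(curr_shape).flatten(start_dim, start_dim + 1)
    else:
        new_scale = scale
    if zero_point.shape != ():
        new_zp = zero_point.expand(curr_shape).flatten(start_dim, start_dim + 1)
    else:
        new_zp = zero_point
    # Assumed new step: drop the padding so the value matches the original,
    # pre-padding shape recorded at quantization time.
    if dequant_shape is not None and new_value.shape != tuple(dequant_shape):
        new_value = new_value.narrow(group_dim, 0, dequant_shape[group_dim])
    return new_value, new_scale, new_zp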
6 changes: 0 additions & 6 deletions src/brevitas/graph/gpxq.py
@@ -187,12 +187,6 @@ def __init__(
         self.layer = layer
         self.name = name
         self.act_order = act_order
-        if self.layer.weight_quant.is_groupwise:
-            weight = self.layer.weight_quant.apply_input_view(self.layer.weight)
-            weight = weight.view(self.layer.weight_quant.quant_injector.reshaped_groupwise_shape)
-            self.layer.weight.data = weight.data
-            self.layer.in_channels = weight.shape[1] if is_conv_transposed(
-                self.layer) else weight.shape[0]

         weight_shape = torch.tensor(layer.weight.shape)
2 changes: 1 addition & 1 deletion src/brevitas/proxy/float_runtime_quant.py
@@ -77,7 +77,7 @@ def create_quant_tensor(
             self,
             qt_args: Union[torch.Tensor, Tuple[Any]],
             x: Optional[FloatQuantTensor] = None) -> FloatQuantTensor:
-        if x is None:
+        if isinstance(qt_args, tuple):
             out = FloatQuantTensor(*qt_args, signed=self.is_signed, training=self.training)
         else:
             out = FloatQuantTensor(
4 changes: 3 additions & 1 deletion src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -34,6 +34,7 @@ def apply_input_view(self, x):
         return x.flatten(start_dim, start_dim + 1)

     def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor:
+        shape = self.tracked_parameter_list[0].shape
         out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
         return GroupwiseFloatQuantTensor(
             out,
@@ -48,4 +49,5 @@ def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor:
             inf_values,
             nan_values,
             self.is_signed,
-            self.training)
+            self.training,
+            shape)
10 changes: 6 additions & 4 deletions src/brevitas/proxy/groupwise_float_runtime_quant.py
@@ -29,8 +29,8 @@ def apply_input_view(self, x):
     def create_quant_tensor(
             self,
             qt_args: Union[torch.Tensor, Tuple[Any]],
-            x: Optional[GroupwiseFloatQuantTensor] = None) -> GroupwiseFloatQuantTensor:
-        if x is None:
+            x: Union[torch.Tensor, GroupwiseFloatQuantTensor]) -> GroupwiseFloatQuantTensor:
+        if isinstance(qt_args, tuple):
             value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
             out = GroupwiseFloatQuantTensor(
                 value,
@@ -45,7 +45,8 @@ def create_quant_tensor(
                 inf_values,
                 nan_values,
                 self.is_signed,
-                self.training)
+                self.training,
+                dequant_shape=x.shape)
         else:
             out = GroupwiseFloatQuantTensor(
                 qt_args,
@@ -60,5 +61,6 @@ def create_quant_tensor(
                 x.inf_values,
                 x.nan_values,
                 x.signed,
-                self.training)
+                self.training,
+                x.shape)
         return out
4 changes: 3 additions & 1 deletion src/brevitas/proxy/groupwise_int_parameter_quant.py
@@ -34,6 +34,7 @@ def apply_input_view(self, x):
         return x.flatten(start_dim, start_dim + 1)

     def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor:
+        shape = self.tracked_parameter_list[0].shape
         out, scale, zero_point, bit_width = qt_args
         return GroupwiseIntQuantTensor(
             out,
@@ -43,4 +44,5 @@ def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor:
             self.group_dim,
             bit_width,
             self.is_signed,
-            self.training)
+            self.training,
+            shape)
10 changes: 6 additions & 4 deletions src/brevitas/proxy/groupwise_int_runtime_quant.py
@@ -29,8 +29,8 @@ def apply_input_view(self, x):
     def create_quant_tensor(
             self,
             qt_args: Union[torch.Tensor, Tuple[Any]],
-            x: Optional[GroupwiseIntQuantTensor] = None) -> GroupwiseIntQuantTensor:
-        if x is None:
+            x: Union[torch.Tensor, GroupwiseIntQuantTensor]) -> GroupwiseIntQuantTensor:
+        if isinstance(qt_args, tuple):
             value, scale, zero_point, bit_width = qt_args
             out = GroupwiseIntQuantTensor(
                 value,
@@ -40,7 +40,8 @@ def create_quant_tensor(
                 self.group_dim,
                 bit_width,
                 self.is_signed,
-                self.training)
+                self.training,
+                x.shape)
         else:
             out = GroupwiseIntQuantTensor(
                 qt_args,
@@ -50,5 +51,6 @@ def create_quant_tensor(
                 self.group_dim,
                 x.bit_width,
                 x.signed,
-                self.training)
+                self.training,
+                x.shape)
         return out
2 changes: 1 addition & 1 deletion src/brevitas/proxy/parameter_quant.py
@@ -157,7 +157,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
                     out.detach(),
                     metadata_only=self.cache_inference_quant_weight_metadata_only)
         else:  # quantization disabled
-            out = self.apply_input_view(x)
+            out = x
         return out
18 changes: 7 additions & 11 deletions src/brevitas/proxy/runtime_quant.py
@@ -157,9 +157,8 @@ def init_tensor_quant(self):
     @abstractmethod
     def create_quant_tensor(
-            self,
-            qt_args: Union[torch.Tensor, Tuple[Any]],
-            x: Optional[QuantTensor] = None) -> QuantTensor:
+            self, qt_args: Union[torch.Tensor, Tuple[Any]], x: Union[Tensor,
+                                                                     QuantTensor]) -> QuantTensor:
         # Supports the following:
         # - qt_args as tuple of Tensors and bools = standard quant activations
         # - qt_args as Tensor and x as QuantTensor = passthrough activation
@@ -181,8 +180,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         elif not self.is_quant_enabled:
             # A tuple helps later with control flows
             # The second None value is used later
-            # If quant is not enabled, we still apply input_view in the case of groupwise + padding
-            y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y))
+            y = self.fused_activation_quant_proxy.activation_impl(y)
             y = (y, None)
         else:
             y = self.fused_activation_quant_proxy(y)
@@ -194,7 +192,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         else:
             # If the second value (i.e., scale) is None, then quant is disabled
             if y[1] is not None:
-                out = self.create_quant_tensor(y)
+                out = self.create_quant_tensor(y, x=x)
             elif self.is_passthrough_act and isinstance(x, QuantTensor):
                 # preserve scale/zp/bit/sign even without output quant
                 y = y[0]
@@ -224,11 +222,9 @@ def bit_width(self, force_eval=True):
         return self.retrieve_attribute('bit_width', force_eval)

     def create_quant_tensor(
-            self,
-            qt_args: Union[Tensor, Tuple[Any]],
-            x: Optional[IntQuantTensor] = None) -> IntQuantTensor:
-
-        if x is None:
+            self, qt_args: Union[torch.Tensor, Tuple[Any]],
+            x: Union[Tensor, IntQuantTensor]) -> IntQuantTensor:
+        if isinstance(qt_args, tuple):
             out = IntQuantTensor(*qt_args, self.is_signed, self.training)
         else:
             out = IntQuantTensor(
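
The signature change above is what enables the shape tracking: forward() now always passes the pre-quantization input as x, and create_quant_tensor() dispatches on whether qt_args is a tuple rather than on x is None, so the groupwise subclasses can record x.shape as the dequant_shape. A condensed, self-contained sketch of that dispatch; SketchQuantTensor is a stand-in class written for illustration, not the real QuantTensor hierarchy:

import torch
from typing import Any, NamedTuple, Optional, Tuple, Union

class SketchQuantTensor(NamedTuple):
    # Stand-in for IntQuantTensor / GroupwiseIntQuantTensor, illustration only.
    value: torch.Tensor
    scale: torch.Tensor
    zero_point: torch.Tensor
    dequant_shape: Optional[Tuple[int, ...]] = None

    @property
    def shape(self):
        return self.value.shape

def create_quant_tensor(
        qt_args: Union[torch.Tensor, Tuple[Any, ...]],
        x: Union[torch.Tensor, SketchQuantTensor]) -> SketchQuantTensor:
    # Before this PR, x was only supplied for passthrough activations, so
    # "x is None" distinguished the two cases. Now x is always supplied
    # (its shape is what gets recorded for unpadding), so the type of
    # qt_args drives the dispatch instead.
    if isinstance(qt_args, tuple):
        # Fresh quantization: qt_args = (value, scale, zero_point, ...).
        return SketchQuantTensor(*qt_args, dequant_shape=tuple(x.shape))
    # Passthrough: qt_args is a plain tensor; metadata is inherited from x.
    return SketchQuantTensor(qt_args, x.scale, x.zero_point, tuple(x.shape))

v = torch.randn(4, 8)
qt = create_quant_tensor((v, torch.tensor(0.1), torch.tensor(0.0)), x=v)
assert qt.dequant_shape == (4, 8)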
4 changes: 3 additions & 1 deletion src/brevitas/quant_tensor/base_quant_tensor.py
@@ -1,4 +1,4 @@
-from typing import List, NamedTuple, Optional
+from typing import List, NamedTuple, Optional, Tuple

 from torch import Tensor
@@ -129,6 +129,7 @@ class GroupwiseFloatQuantTensorBase(NamedTuple):
     nan_values: List[str]
     signed_t: Tensor
     training_t: Tensor
+    dequant_shape: Optional[Tuple] = None


 class GroupwisIntQuantTensorBase(NamedTuple):
@@ -140,6 +141,7 @@ class GroupwisIntQuantTensorBase(NamedTuple):
     bit_width: Tensor
     signed_t: Tensor
     training_t: Tensor
+    dequant_shape: Optional[Tuple] = None


 def _unpack_quant_tensor(input_data):
22 changes: 7 additions & 15 deletions src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
@@ -31,7 +31,8 @@ def __new__(
             inf_values,
             nan_values,
             signed,
-            training):
+            training,
+            dequant_shape=None):

         if not isinstance(scale, torch.Tensor):
             scale = torch.tensor(scale, dtype=torch.float)
@@ -63,7 +64,8 @@ def __new__(
             inf_values,
             nan_values,
             signed,
-            training)
+            training,
+            dequant_shape)
         return quant_tensor

     @property
@@ -89,19 +91,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
         return func(*args, **kwargs)

     def expand(self):
-        curr_shape = self.value_.shape
-        start_dim = self.group_dim if self.group_dim != -1 else -2
-        new_value = self.value_.flatten(start_dim, start_dim + 1)
-        if self.scale_.shape != ():
-            new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
-        else:
-            new_scale = self.scale_
-        if self.zero_point_.shape != ():
-            new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
-        else:
-            new_zp = self.zero_point_
-
-        return new_value, new_scale, new_zp
+        from brevitas.utils.quant_utils import groupwise_dequant_expand
+        return groupwise_dequant_expand(
+            self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape)

     @staticmethod
     def from_expanded(value, group_size, group_dim, compress=False):
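
To see the round trip that the new dequant_shape field enables, a small end-to-end numeric example under the same assumptions as the sketches above, using int8-style per-group scaling as a stand-in for the MX formats: pad 8 channels to one group of 32, fake-quantize, then flatten the group dims back and trim the padding on dequantization:

import torch
import torch.nn.functional as F

# 8 input channels padded to a single group of 32 (the zero padding
# quantizes exactly, so it costs nothing in accuracy).
w = torch.randn(64, 8, 3, 3)
padded = F.pad(w, (0, 0, 0, 0, 0, 24))             # [64, 32, 3, 3]
grouped = padded.view(64, 1, 32, 3, 3)             # the stored value_
# One int8-style scale per group (dim 2 is the group_size dim).
scale = grouped.abs().amax(dim=2, keepdim=True) / 127.0
q = torch.clamp((grouped / scale).round(), -128, 127)
# Dequantize, merge the group dims back, and trim the padding.
deq = (q * scale).flatten(1, 2).narrow(1, 0, 8)    # [64, 8, 3, 3]
assert deq.shape == w.shape
assert torch.allclose(deq, w, atol=scale.max().item() / 2 + 1e-8)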
39 changes: 24 additions & 15 deletions src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
@@ -18,7 +18,17 @@

 class GroupwiseIntQuantTensor(GroupwisIntQuantTensorBase, QuantTensor):

-    def __new__(cls, value, scale, zero_point, group_size, group_dim, bit_width, signed, training):
+    def __new__(
+            cls,
+            value,
+            scale,
+            zero_point,
+            group_size,
+            group_dim,
+            bit_width,
+            signed,
+            training,
+            dequant_shape=None):

         if not isinstance(scale, torch.Tensor):
             scale = torch.tensor(scale, dtype=torch.float)
@@ -31,7 +41,16 @@ def __new__(cls, value, scale, zero_point, group_size, group_dim, bit_width, sig
         if not isinstance(training, torch.Tensor):
             training = torch.tensor(training, dtype=torch.bool)
         quant_tensor = super().__new__(
-            cls, value, scale, zero_point, group_size, group_dim, bit_width, signed, training)
+            cls,
+            value,
+            scale,
+            zero_point,
+            group_size,
+            group_dim,
+            bit_width,
+            signed,
+            training,
+            dequant_shape)
         return quant_tensor

     @property
@@ -58,19 +77,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
         return func(*args, **kwargs)

     def expand(self):
-        curr_shape = self.value_.shape
-        start_dim = self.group_dim if self.group_dim != -1 else -2
-        new_value = self.value_.flatten(start_dim, start_dim + 1)
-        if self.scale_.shape != ():
-            new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
-        else:
-            new_scale = self.scale_
-        if self.zero_point_.shape != ():
-            new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
-        else:
-            new_zp = self.zero_point_
-
-        return new_value, new_scale, new_zp
+        from brevitas.utils.quant_utils import groupwise_dequant_expand
+        return groupwise_dequant_expand(
+            self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape)

     @staticmethod
     def from_expanded(value, group_size, group_dim, compress=False):
(Diffs for the remaining 2 of the 15 changed files are not shown.)
