From 2c26cd14c83ffdaff28e3db713803b71788995ae Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 20 Aug 2024 11:41:02 +0100 Subject: [PATCH] Review, notebook missing --- src/brevitas/graph/quantize.py | 6 ---- src/brevitas/nn/mixin/base.py | 5 ++- src/brevitas/proxy/float_runtime_quant.py | 14 ++++----- .../proxy/groupwise_float_parameter_quant.py | 2 +- .../proxy/groupwise_float_runtime_quant.py | 10 +++--- .../proxy/groupwise_int_parameter_quant.py | 4 +-- .../proxy/groupwise_int_runtime_quant.py | 2 +- .../quant_tensor/float_quant_tensor.py | 21 +++++++------ .../quant_tensor/float_torch_handler.py | 6 ++-- .../groupwise_float_quant_tensor.py | 29 ++++++++++------- .../groupwise_int_quant_tensor.py | 31 +++++++++---------- 11 files changed, 65 insertions(+), 65 deletions(-) diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index 3a665d5ad..ee035b9bd 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -8,13 +8,11 @@ from brevitas.core.scaling.standalone import ParameterScaling from brevitas.fx.brevitas_tracer import symbolic_trace from brevitas.graph.base import ModuleToModuleByClass -from brevitas.graph.calibrate import DisableEnableQuantization from brevitas.graph.channel_splitting import GraphChannelSplitting from brevitas.graph.equalize import EqualizeGraph from brevitas.graph.fixed_point import CollapseConsecutiveConcats from brevitas.graph.fixed_point import MergeBatchNorm from brevitas.graph.fixed_point import MoveSplitBatchNormBeforeCat -from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool from brevitas.graph.quantize_impl import act_handler from brevitas.graph.quantize_impl import add_output_quant_handler from brevitas.graph.quantize_impl import inp_placeholder_handler @@ -26,11 +24,7 @@ from brevitas.graph.standardize import MeanMethodToAdaptiveAvgPool2d from brevitas.graph.standardize import RemoveStochasticModules from brevitas.graph.standardize import TorchFunctionalToModule -from brevitas.nn import quant_layer import brevitas.nn as qnn -from brevitas.proxy.groupwise_float_parameter_quant import \ - GroupwiseWeightFloatQuantProxyFromInjector -from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector from brevitas.quant import Int8ActPerTensorFloat from brevitas.quant import Int8ActPerTensorFloatMinMaxInit from brevitas.quant import Int8WeightPerTensorFloat diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index bbdd77ac7..167852508 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -21,6 +21,8 @@ from brevitas.quant_tensor import FloatQuantTensor from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor +from brevitas.quant_tensor.groupwise_float_quant_tensor import GroupwiseFloatQuantTensor +from brevitas.quant_tensor.groupwise_int_quant_tensor import GroupwiseIntQuantTensor from .utils import filter_kwargs @@ -71,7 +73,8 @@ def _set_global_is_quant_layer(self, value): config._IS_INSIDE_QUANT_LAYER = value def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]): - quant_tensor_classes = [IntQuantTensor, FloatQuantTensor] + quant_tensor_classes = [ + IntQuantTensor, FloatQuantTensor, GroupwiseIntQuantTensor, GroupwiseFloatQuantTensor] for qt_class in quant_tensor_classes: if len(inp) == len(qt_class._fields): return qt_class diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index 4b5bf2b6d..021aefd12 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -87,14 +87,14 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, FloatQuantTens y, x.scale, x.zero_point, - x.mantissa_bit_width, x.exponent_bit_width, + x.mantissa_bit_width, x.exponent_bias, - x.signed, - self.training, x.saturating, x.inf_values, - x.nan_values) + x.nan_values, + x.signed, + self.training) else: out = y else: @@ -143,11 +143,11 @@ def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuan x.mantissa_bit_width, x.exponent_bit_width, x.exponent_bias, - x.signed, - self.training, x.saturating, x.inf_values, - x.nan_values) + x.nan_values, + x.signed, + self.training) else: out = y else: diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index 04e3e9895..cd38d9906 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Union import torch from torch import Tensor diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index c79f28fdb..4ab182d20 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -59,16 +59,16 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseFloat y, x.scale, x.zero_point, - self.group_dim, self.group_size, - x.mantissa_bit_width, + self.group_dim, x.exponent_bit_width, + x.mantissa_bit_width, x.exponent_bias, - x.signed, - self.training, x.saturating, x.inf_values, - x.nan_values) + x.nan_values, + x.signed, + self.training) else: out = y else: diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index e1952ddef..035ee9729 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -31,9 +31,9 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseIntQuantTensor]: out, scale, zero_point, - bit_width, - self.group_dim, self.group_size, + self.group_dim, + bit_width, self.is_signed, self.training) else: # quantization disabled diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index 3aceb3c02..e9788e89b 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -54,8 +54,8 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseIntQu y, x.scale, x.zero_point, - self.group_dim, self.group_size, + self.group_dim, x.bit_width, x.signed, self.training) diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index 74f42dc94..cf4ba1420 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -256,11 +256,11 @@ def cat(tensors, dim, out=None): exponent_bit_width=output_exponent_bit_width, mantissa_bit_width=output_mantissa_bit_width, exponent_bias=output_exponent_bias, - signed=output_signed, - training=output_training, saturating=output_saturating, inf_values=output_inf_values, - nan_values=output_nan_values) + nan_values=output_nan_values, + signed=output_signed, + training=output_training) else: tensors = [_unpack_quant_tensor(qt) for qt in tensors] output_value = torch.cat(tensors, dim=dim) @@ -280,11 +280,11 @@ def __neg__(self): exponent_bit_width=self.exponent_bit_width, mantissa_bit_width=self.mantissa_bit_width, exponent_bias=self.exponent_bias, - signed=self.signed, - training=self.training, saturating=self.saturating, inf_values=self.inf_values, - nan_values=self.nan_values) + nan_values=self.nan_values, + signed=self.signed, + training=self.training) else: # TODO: implement raise NotImplementedError @@ -304,7 +304,7 @@ def __mul__(self, other): return output def __str__(self): - return f"FloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" + return f"FloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, exponent_bit_width={self.exponent_bit_width}, mantissa_bit_width={self.mantissa_bit_width}, exponent_bias={self.exponent_bias}, inf_values={self.inf_values}, nan_values={self.nan_values}, signed_t={self.signed_t}, training_t={self.training_t})" def __truediv__(self, other): if isinstance(other, QuantTensor): @@ -325,10 +325,11 @@ def __abs__(self): zero_point=self.zero_point, exponent_bit_width=self.exponent_bit_width, mantissa_bit_width=self.mantissa_bit_width, - signed=False, - training=self.training, + exponent_bias=self.exponent_bias, saturating=self.saturating, inf_values=self.inf_values, - nan_values=self.nan_values) + nan_values=self.nan_values, + signed=False, + training=self.training) else: return self diff --git a/src/brevitas/quant_tensor/float_torch_handler.py b/src/brevitas/quant_tensor/float_torch_handler.py index 7fb4507c1..60401fc4a 100644 --- a/src/brevitas/quant_tensor/float_torch_handler.py +++ b/src/brevitas/quant_tensor/float_torch_handler.py @@ -92,11 +92,11 @@ def embedding_handler(input, quant_weight, *args, **kwargs): exponent_bit_width, mantissa_bit_width, exponent_bias, - signed, - training, saturating, inf_values, - nan_values) + nan_values, + signed, + training) return out diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index d636d92f6..7d73bf7de 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -103,7 +103,7 @@ def expand(self): return new_value, new_scale, new_zp @staticmethod - def from_expanded(value, group_dim, group_size, compress=False): + def from_expanded(value, group_size, group_dim, compress=False): size = list(value.shape) assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size' if compress: @@ -253,22 +253,24 @@ def __neg__(self): # In case the dtype of self.minifloat is different from the one of the scale neg_value = neg_value.type(scale.dtype) neg_value = GroupwiseFloatQuantTensor.from_expanded( - neg_value, self.group_dim, self.group_size, compress=False) + neg_value, self.group_size, self.group_dim, compress=False) scale = GroupwiseFloatQuantTensor.from_expanded( - scale, self.group_dim, self.group_size, compress=True) + scale, self.group_size, self.group_dim, compress=True) if self.signed: return GroupwiseFloatQuantTensor( value=neg_value, scale=scale, zero_point=self.zero_point, + group_size=self.group_size, + group_dim=self.group_dim, exponent_bit_width=self.exponent_bit_width, mantissa_bit_width=self.mantissa_bit_width, exponent_bias=self.exponent_bias, - signed=self.signed, - training=self.training, saturating=self.saturating, inf_values=self.inf_values, - nan_values=self.nan_values) + nan_values=self.nan_values, + signed=self.signed, + training=self.training) else: # TODO: implement raise NotImplementedError @@ -288,7 +290,7 @@ def __mul__(self, other): return output def __str__(self): - return f"GroupwiseFloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" + return f"GroupwiseFloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, group_size={self.group_size}, group_dim={self.group_dim}, exponent_bit_width={self.exponent_bit_width}, mantissa_bit_width={self.mantissa_bit_width}, exponent_bias={self.exponent_bias}, inf_values={self.inf_values}, nan_values={self.nan_values}, signed_t={self.signed_t}, training_t={self.training_t})" def __truediv__(self, other): if isinstance(other, QuantTensor): @@ -307,19 +309,22 @@ def __abs__(self): # In case the dtype of self.minifloat is different from the one of the scale abs_value = abs_value.type(scale.dtype) abs_value = GroupwiseFloatQuantTensor.from_expanded( - abs_value, self.group_dim, self.group_size, compress=False) + abs_value, self.group_size, self.group_dim, compress=False) scale = GroupwiseFloatQuantTensor.from_expanded( - scale, self.group_dim, self.group_size, compress=True) + scale, self.group_size, self.group_dim, compress=True) return GroupwiseFloatQuantTensor( value=abs_value, scale=self.scale, zero_point=self.zero_point, + group_size=self.group_size, + group_dim=self.group_dim, exponent_bit_width=self.exponent_bit_width, mantissa_bit_width=self.mantissa_bit_width, - signed=False, - training=self.training, + exponent_bias=self.exponent_bias, saturating=self.saturating, inf_values=self.inf_values, - nan_values=self.nan_values) + nan_values=self.nan_values, + signed=False, + training=self.training) else: return self diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index 371b4ef6a..976e86130 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -72,7 +72,7 @@ def expand(self): return new_value, new_scale, new_zp @staticmethod - def from_expanded(value, group_dim, group_size, compress=False): + def from_expanded(value, group_size, group_dim, compress=False): size = list(value.shape) assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size' if compress: @@ -243,22 +243,20 @@ def __neg__(self): # In case the dtype of self.minifloat is different from the one of the scale neg_value = neg_value.type(scale.dtype) neg_value = GroupwiseIntQuantTensor.from_expanded( - neg_value, self.group_dim, self.group_size, compress=False) + neg_value, self.group_size, self.group_dim, compress=False) scale = GroupwiseIntQuantTensor.from_expanded( - scale, self.group_dim, self.group_size, compress=True) + scale, self.group_size, self.group_dim, compress=True) if self.signed: return GroupwiseIntQuantTensor( value=neg_value, scale=scale, zero_point=self.zero_point, - exponent_bit_width=self.exponent_bit_width, - mantissa_bit_width=self.mantissa_bit_width, - exponent_bias=self.exponent_bias, + group_size=self.group_size, + group_dim=self.group_dim, + bit_width=self.bit_width, signed=self.signed, training=self.training, - saturating=self.saturating, - inf_values=self.inf_values, - nan_values=self.nan_values) + saturating=self.saturating) else: # TODO: implement raise NotImplementedError @@ -278,7 +276,7 @@ def __mul__(self, other): return output def __str__(self): - return f"GroupwiseIntQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" + return f"GroupwiseIntQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, group_size={self.group_size}, group_dim={self.group_dim}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" def __truediv__(self, other): if isinstance(other, QuantTensor): @@ -297,19 +295,18 @@ def __abs__(self): # In case the dtype of self.minifloat is different from the one of the scale abs_value = abs_value.type(scale.dtype) abs_value = GroupwiseIntQuantTensor.from_expanded( - abs_value, self.group_dim, self.group_size, compress=False) + abs_value, self.group_size, self.group_dim, compress=False) scale = GroupwiseIntQuantTensor.from_expanded( - scale, self.group_dim, self.group_size, compress=True) + scale, self.group_size, self.group_dim, compress=True) return GroupwiseIntQuantTensor( value=abs_value, scale=self.scale, zero_point=self.zero_point, - exponent_bit_width=self.exponent_bit_width, - mantissa_bit_width=self.mantissa_bit_width, + group_size=self.group_size, + group_dim=self.group_dim, + bit_width=self.bit_width, signed=False, training=self.training, - saturating=self.saturating, - inf_values=self.inf_values, - nan_values=self.nan_values) + saturating=self.saturating) else: return self