From afa4af1839c09ce11696d16f2714cab2e3d3ff66 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 26 Dec 2024 17:15:39 +0000 Subject: [PATCH 1/4] Fix (quant_tensor): clean-up QT creation --- .../quant_tensor/base_quant_tensor.py | 20 +++++++++---------- .../quant_tensor/float_quant_tensor.py | 16 --------------- .../groupwise_float_quant_tensor.py | 19 +----------------- .../groupwise_int_quant_tensor.py | 12 ----------- src/brevitas/quant_tensor/int_quant_tensor.py | 12 ----------- 5 files changed, 11 insertions(+), 68 deletions(-) diff --git a/src/brevitas/quant_tensor/base_quant_tensor.py b/src/brevitas/quant_tensor/base_quant_tensor.py index 7b5dcd597..8c951fe62 100644 --- a/src/brevitas/quant_tensor/base_quant_tensor.py +++ b/src/brevitas/quant_tensor/base_quant_tensor.py @@ -97,8 +97,8 @@ class IntQuantTensorBase(NamedTuple): scale: Tensor zero_point: Tensor bit_width: Tensor - signed_t: Tensor - training_t: Tensor + signed: bool + training: bool class FloatQuantTensorBase(NamedTuple): @@ -108,11 +108,11 @@ class FloatQuantTensorBase(NamedTuple): exponent_bit_width: Tensor mantissa_bit_width: Tensor exponent_bias: Tensor - saturating_t: Tensor + saturating: bool inf_values: List[str] nan_values: List[str] - signed_t: Tensor - training_t: Tensor + signed: bool + training: bool class GroupwiseFloatQuantTensorBase(NamedTuple): @@ -124,11 +124,11 @@ class GroupwiseFloatQuantTensorBase(NamedTuple): exponent_bit_width: Tensor mantissa_bit_width: Tensor exponent_bias: Tensor - saturating_t: Tensor + saturating: bool inf_values: List[str] nan_values: List[str] - signed_t: Tensor - training_t: Tensor + signed: bool + training: bool class GroupwisIntQuantTensorBase(NamedTuple): @@ -138,8 +138,8 @@ class GroupwisIntQuantTensorBase(NamedTuple): group_size: Tensor group_dim: Tensor bit_width: Tensor - signed_t: Tensor - training_t: Tensor + signed: bool + training: bool def _unpack_quant_tensor(input_data): diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index 8db6fda90..c221e1ca3 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -43,10 +43,6 @@ def __new__( exponent_bias = torch.tensor(exponent_bias, dtype=torch.float) if not isinstance(saturating, torch.Tensor): saturating = torch.tensor(saturating, dtype=torch.bool) - if not isinstance(signed, torch.Tensor): - signed = torch.tensor(signed, dtype=torch.bool) - if not isinstance(training, torch.Tensor): - training = torch.tensor(training, dtype=torch.bool) quant_tensor = super().__new__( cls, value, @@ -62,18 +58,6 @@ def __new__( training) return quant_tensor - @property - def signed(self): - return self.signed_t.item() - - @property - def training(self): - return self.training_t.item() - - @property - def saturating(self): - return self.saturating_t.item() - @property def eps(self): return torch.finfo(self.scale.dtype).tiny diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index b507d3fe3..b1d3204b9 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -43,12 +43,7 @@ def __new__( mantissa_bit_width = torch.tensor(mantissa_bit_width, dtype=torch.float) if not isinstance(exponent_bias, torch.Tensor): exponent_bias = torch.tensor(exponent_bias, dtype=torch.float) - if not isinstance(saturating, torch.Tensor): - saturating = torch.tensor(saturating, dtype=torch.bool) - if not isinstance(signed, torch.Tensor): - signed = torch.tensor(signed, dtype=torch.bool) - if not isinstance(training, torch.Tensor): - training = torch.tensor(training, dtype=torch.bool) + quant_tensor = super().__new__( cls, value, @@ -66,18 +61,6 @@ def __new__( training) return quant_tensor - @property - def signed(self): - return self.signed_t.item() - - @property - def training(self): - return self.training_t.item() - - @property - def saturating(self): - return self.saturating_t.item() - def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index 082ec1234..fd1ed938d 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -26,22 +26,10 @@ def __new__(cls, value, scale, zero_point, group_size, group_dim, bit_width, sig zero_point = torch.tensor(zero_point, dtype=torch.float) if not isinstance(bit_width, torch.Tensor): bit_width = torch.tensor(bit_width, dtype=torch.float) - if not isinstance(signed, torch.Tensor): - signed = torch.tensor(signed, dtype=torch.bool) - if not isinstance(training, torch.Tensor): - training = torch.tensor(training, dtype=torch.bool) quant_tensor = super().__new__( cls, value, scale, zero_point, group_size, group_dim, bit_width, signed, training) return quant_tensor - @property - def signed(self): - return self.signed_t.item() - - @property - def training(self): - return self.training_t.item() - @property def saturating(self): return self.saturating_t.item() diff --git a/src/brevitas/quant_tensor/int_quant_tensor.py b/src/brevitas/quant_tensor/int_quant_tensor.py index f06e321b2..5a8c42daf 100644 --- a/src/brevitas/quant_tensor/int_quant_tensor.py +++ b/src/brevitas/quant_tensor/int_quant_tensor.py @@ -28,21 +28,9 @@ def __new__(cls, value, scale, zero_point, bit_width, signed, training): zero_point = torch.tensor(zero_point, dtype=torch.float) if not isinstance(bit_width, torch.Tensor): bit_width = torch.tensor(bit_width, dtype=torch.float) - if not isinstance(signed, torch.Tensor): - signed = torch.tensor(signed, dtype=torch.bool) - if not isinstance(training, torch.Tensor): - training = torch.tensor(training, dtype=torch.bool) quant_tensor = super().__new__(cls, value, scale, zero_point, bit_width, signed, training) return quant_tensor - @property - def signed(self): - return self.signed_t.item() - - @property - def training(self): - return self.training_t.item() - def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} From 1297617c74b5d3e616052bf0f36a54df995ee023 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 27 Dec 2024 00:51:14 +0000 Subject: [PATCH 2/4] Fix --- src/brevitas/quant_tensor/float_quant_tensor.py | 2 +- src/brevitas/quant_tensor/groupwise_float_quant_tensor.py | 2 +- src/brevitas/quant_tensor/groupwise_int_quant_tensor.py | 6 +++--- src/brevitas/quant_tensor/int_quant_tensor.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index c221e1ca3..4bd313160 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -299,7 +299,7 @@ def __mul__(self, other): return output def __str__(self): - return f"FloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, exponent_bit_width={self.exponent_bit_width}, mantissa_bit_width={self.mantissa_bit_width}, exponent_bias={self.exponent_bias}, inf_values={self.inf_values}, nan_values={self.nan_values}, signed_t={self.signed_t}, training_t={self.training_t})" + return f"FloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, exponent_bit_width={self.exponent_bit_width}, mantissa_bit_width={self.mantissa_bit_width}, exponent_bias={self.exponent_bias}, inf_values={self.inf_values}, nan_values={self.nan_values}, signed={self.signed}, training={self.training})" def __truediv__(self, other): if isinstance(other, QuantTensor): diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index b1d3204b9..65978e600 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -288,7 +288,7 @@ def __mul__(self, other): return output def __str__(self): - return f"GroupwiseFloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, group_size={self.group_size}, group_dim={self.group_dim}, exponent_bit_width={self.exponent_bit_width}, mantissa_bit_width={self.mantissa_bit_width}, exponent_bias={self.exponent_bias}, inf_values={self.inf_values}, nan_values={self.nan_values}, signed_t={self.signed_t}, training_t={self.training_t})" + return f"GroupwiseFloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, group_size={self.group_size}, group_dim={self.group_dim}, exponent_bit_width={self.exponent_bit_width}, mantissa_bit_width={self.mantissa_bit_width}, exponent_bias={self.exponent_bias}, inf_values={self.inf_values}, nan_values={self.nan_values}, signed={self.signed}, training={self.training})" def __truediv__(self, other): if isinstance(other, QuantTensor): diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index fd1ed938d..af6aa49d0 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -154,9 +154,9 @@ def int(self, float_datatype=False): else: return int_value.type(torch.float32) else: - if self.bit_width <= 8. and self.signed_t.item(): + if self.bit_width <= 8. and self.signed: return int_value.to(torch.int8) - elif self.bit_width <= 8. and not self.signed_t.item(): + elif self.bit_width <= 8. and not self.signed: return int_value.to(torch.uint8) else: return int_value.to(torch.int32) @@ -266,7 +266,7 @@ def __mul__(self, other): return output def __str__(self): - return f"GroupwiseIntQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, group_size={self.group_size}, group_dim={self.group_dim}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" + return f"GroupwiseIntQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, group_size={self.group_size}, group_dim={self.group_dim}, bit_width={self.bit_width}, signed={self.signed}, training={self.training})" def __truediv__(self, other): if isinstance(other, QuantTensor): diff --git a/src/brevitas/quant_tensor/int_quant_tensor.py b/src/brevitas/quant_tensor/int_quant_tensor.py index 5a8c42daf..072eccff4 100644 --- a/src/brevitas/quant_tensor/int_quant_tensor.py +++ b/src/brevitas/quant_tensor/int_quant_tensor.py @@ -106,9 +106,9 @@ def int(self, float_datatype=False): else: return int_value.type(torch.float32) else: - if self.bit_width <= 8. and self.signed_t.item(): + if self.bit_width <= 8. and self.signed.item(): return int_value.to(torch.int8) - elif self.bit_width <= 8. and not self.signed_t.item(): + elif self.bit_width <= 8. and not self.signed.item(): return int_value.to(torch.uint8) else: return int_value.to(torch.int32) @@ -309,7 +309,7 @@ def __mul__(self, other): return output def __str__(self): - return f"IntQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" + return f"IntQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed={self.signed}, training={self.training})" def __truediv__(self, other): if isinstance(other, IntQuantTensor): From 4e2c4c5e97bd4d49850e7d2d4bcfb9355d33ae3f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 27 Dec 2024 00:54:32 +0000 Subject: [PATCH 3/4] No item --- src/brevitas/quant_tensor/int_quant_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/quant_tensor/int_quant_tensor.py b/src/brevitas/quant_tensor/int_quant_tensor.py index 072eccff4..6e9f1053a 100644 --- a/src/brevitas/quant_tensor/int_quant_tensor.py +++ b/src/brevitas/quant_tensor/int_quant_tensor.py @@ -106,9 +106,9 @@ def int(self, float_datatype=False): else: return int_value.type(torch.float32) else: - if self.bit_width <= 8. and self.signed.item(): + if self.bit_width <= 8. and self.signed(): return int_value.to(torch.int8) - elif self.bit_width <= 8. and not self.signed.item(): + elif self.bit_width <= 8. and not self.signed(): return int_value.to(torch.uint8) else: return int_value.to(torch.int32) From 8d38d7681833cc07ae2e20a11a77674b6192f764 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 27 Dec 2024 00:57:12 +0000 Subject: [PATCH 4/4] No parenthesis --- src/brevitas/quant_tensor/int_quant_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/quant_tensor/int_quant_tensor.py b/src/brevitas/quant_tensor/int_quant_tensor.py index 6e9f1053a..4a9c72d9f 100644 --- a/src/brevitas/quant_tensor/int_quant_tensor.py +++ b/src/brevitas/quant_tensor/int_quant_tensor.py @@ -106,9 +106,9 @@ def int(self, float_datatype=False): else: return int_value.type(torch.float32) else: - if self.bit_width <= 8. and self.signed(): + if self.bit_width <= 8. and self.signed: return int_value.to(torch.int8) - elif self.bit_width <= 8. and not self.signed(): + elif self.bit_width <= 8. and not self.signed: return int_value.to(torch.uint8) else: return int_value.to(torch.int32)