diff --git a/src/brevitas/quant_tensor/base_quant_tensor.py b/src/brevitas/quant_tensor/base_quant_tensor.py index 7b5dcd597..8c951fe62 100644 --- a/src/brevitas/quant_tensor/base_quant_tensor.py +++ b/src/brevitas/quant_tensor/base_quant_tensor.py @@ -97,8 +97,8 @@ class IntQuantTensorBase(NamedTuple): scale: Tensor zero_point: Tensor bit_width: Tensor - signed_t: Tensor - training_t: Tensor + signed: bool + training: bool class FloatQuantTensorBase(NamedTuple): @@ -108,11 +108,11 @@ class FloatQuantTensorBase(NamedTuple): exponent_bit_width: Tensor mantissa_bit_width: Tensor exponent_bias: Tensor - saturating_t: Tensor + saturating: bool inf_values: List[str] nan_values: List[str] - signed_t: Tensor - training_t: Tensor + signed: bool + training: bool class GroupwiseFloatQuantTensorBase(NamedTuple): @@ -124,11 +124,11 @@ class GroupwiseFloatQuantTensorBase(NamedTuple): exponent_bit_width: Tensor mantissa_bit_width: Tensor exponent_bias: Tensor - saturating_t: Tensor + saturating: bool inf_values: List[str] nan_values: List[str] - signed_t: Tensor - training_t: Tensor + signed: bool + training: bool class GroupwisIntQuantTensorBase(NamedTuple): @@ -138,8 +138,8 @@ class GroupwisIntQuantTensorBase(NamedTuple): group_size: Tensor group_dim: Tensor bit_width: Tensor - signed_t: Tensor - training_t: Tensor + signed: bool + training: bool def _unpack_quant_tensor(input_data): diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index 8db6fda90..c221e1ca3 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -43,10 +43,6 @@ def __new__( exponent_bias = torch.tensor(exponent_bias, dtype=torch.float) if not isinstance(saturating, torch.Tensor): saturating = torch.tensor(saturating, dtype=torch.bool) - if not isinstance(signed, torch.Tensor): - signed = torch.tensor(signed, dtype=torch.bool) - if not isinstance(training, torch.Tensor): - training = torch.tensor(training, dtype=torch.bool) quant_tensor = super().__new__( cls, value, @@ -62,18 +58,6 @@ def __new__( training) return quant_tensor - @property - def signed(self): - return self.signed_t.item() - - @property - def training(self): - return self.training_t.item() - - @property - def saturating(self): - return self.saturating_t.item() - @property def eps(self): return torch.finfo(self.scale.dtype).tiny diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index b507d3fe3..b1d3204b9 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -43,12 +43,7 @@ def __new__( mantissa_bit_width = torch.tensor(mantissa_bit_width, dtype=torch.float) if not isinstance(exponent_bias, torch.Tensor): exponent_bias = torch.tensor(exponent_bias, dtype=torch.float) - if not isinstance(saturating, torch.Tensor): - saturating = torch.tensor(saturating, dtype=torch.bool) - if not isinstance(signed, torch.Tensor): - signed = torch.tensor(signed, dtype=torch.bool) - if not isinstance(training, torch.Tensor): - training = torch.tensor(training, dtype=torch.bool) + quant_tensor = super().__new__( cls, value, @@ -66,18 +61,6 @@ def __new__( training) return quant_tensor - @property - def signed(self): - return self.signed_t.item() - - @property - def training(self): - return self.training_t.item() - - @property - def saturating(self): - return self.saturating_t.item() - def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index 082ec1234..fd1ed938d 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -26,22 +26,10 @@ def __new__(cls, value, scale, zero_point, group_size, group_dim, bit_width, sig zero_point = torch.tensor(zero_point, dtype=torch.float) if not isinstance(bit_width, torch.Tensor): bit_width = torch.tensor(bit_width, dtype=torch.float) - if not isinstance(signed, torch.Tensor): - signed = torch.tensor(signed, dtype=torch.bool) - if not isinstance(training, torch.Tensor): - training = torch.tensor(training, dtype=torch.bool) quant_tensor = super().__new__( cls, value, scale, zero_point, group_size, group_dim, bit_width, signed, training) return quant_tensor - @property - def signed(self): - return self.signed_t.item() - - @property - def training(self): - return self.training_t.item() - @property def saturating(self): return self.saturating_t.item() diff --git a/src/brevitas/quant_tensor/int_quant_tensor.py b/src/brevitas/quant_tensor/int_quant_tensor.py index f06e321b2..5a8c42daf 100644 --- a/src/brevitas/quant_tensor/int_quant_tensor.py +++ b/src/brevitas/quant_tensor/int_quant_tensor.py @@ -28,21 +28,9 @@ def __new__(cls, value, scale, zero_point, bit_width, signed, training): zero_point = torch.tensor(zero_point, dtype=torch.float) if not isinstance(bit_width, torch.Tensor): bit_width = torch.tensor(bit_width, dtype=torch.float) - if not isinstance(signed, torch.Tensor): - signed = torch.tensor(signed, dtype=torch.bool) - if not isinstance(training, torch.Tensor): - training = torch.tensor(training, dtype=torch.bool) quant_tensor = super().__new__(cls, value, scale, zero_point, bit_width, signed, training) return quant_tensor - @property - def signed(self): - return self.signed_t.item() - - @property - def training(self): - return self.training_t.item() - def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {}