Skip to content

Commit

Permalink
Fix (quant_tensor): clean-up QT creation
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 27, 2024
1 parent 39ce837 commit afa4af1
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 68 deletions.
20 changes: 10 additions & 10 deletions src/brevitas/quant_tensor/base_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ class IntQuantTensorBase(NamedTuple):
scale: Tensor
zero_point: Tensor
bit_width: Tensor
signed_t: Tensor
training_t: Tensor
signed: bool
training: bool


class FloatQuantTensorBase(NamedTuple):
Expand All @@ -108,11 +108,11 @@ class FloatQuantTensorBase(NamedTuple):
exponent_bit_width: Tensor
mantissa_bit_width: Tensor
exponent_bias: Tensor
saturating_t: Tensor
saturating: bool
inf_values: List[str]
nan_values: List[str]
signed_t: Tensor
training_t: Tensor
signed: bool
training: bool


class GroupwiseFloatQuantTensorBase(NamedTuple):
Expand All @@ -124,11 +124,11 @@ class GroupwiseFloatQuantTensorBase(NamedTuple):
exponent_bit_width: Tensor
mantissa_bit_width: Tensor
exponent_bias: Tensor
saturating_t: Tensor
saturating: bool
inf_values: List[str]
nan_values: List[str]
signed_t: Tensor
training_t: Tensor
signed: bool
training: bool


class GroupwisIntQuantTensorBase(NamedTuple):
Expand All @@ -138,8 +138,8 @@ class GroupwisIntQuantTensorBase(NamedTuple):
group_size: Tensor
group_dim: Tensor
bit_width: Tensor
signed_t: Tensor
training_t: Tensor
signed: bool
training: bool


def _unpack_quant_tensor(input_data):
Expand Down
16 changes: 0 additions & 16 deletions src/brevitas/quant_tensor/float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@ def __new__(
exponent_bias = torch.tensor(exponent_bias, dtype=torch.float)
if not isinstance(saturating, torch.Tensor):
saturating = torch.tensor(saturating, dtype=torch.bool)
if not isinstance(signed, torch.Tensor):
signed = torch.tensor(signed, dtype=torch.bool)
if not isinstance(training, torch.Tensor):
training = torch.tensor(training, dtype=torch.bool)
quant_tensor = super().__new__(
cls,
value,
Expand All @@ -62,18 +58,6 @@ def __new__(
training)
return quant_tensor

@property
def signed(self):
return self.signed_t.item()

@property
def training(self):
return self.training_t.item()

@property
def saturating(self):
return self.saturating_t.item()

@property
def eps(self):
return torch.finfo(self.scale.dtype).tiny
Expand Down
19 changes: 1 addition & 18 deletions src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,7 @@ def __new__(
mantissa_bit_width = torch.tensor(mantissa_bit_width, dtype=torch.float)
if not isinstance(exponent_bias, torch.Tensor):
exponent_bias = torch.tensor(exponent_bias, dtype=torch.float)
if not isinstance(saturating, torch.Tensor):
saturating = torch.tensor(saturating, dtype=torch.bool)
if not isinstance(signed, torch.Tensor):
signed = torch.tensor(signed, dtype=torch.bool)
if not isinstance(training, torch.Tensor):
training = torch.tensor(training, dtype=torch.bool)

quant_tensor = super().__new__(
cls,
value,
Expand All @@ -66,18 +61,6 @@ def __new__(
training)
return quant_tensor

@property
def signed(self):
return self.signed_t.item()

@property
def training(self):
return self.training_t.item()

@property
def saturating(self):
return self.saturating_t.item()

def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
Expand Down
12 changes: 0 additions & 12 deletions src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,10 @@ def __new__(cls, value, scale, zero_point, group_size, group_dim, bit_width, sig
zero_point = torch.tensor(zero_point, dtype=torch.float)
if not isinstance(bit_width, torch.Tensor):
bit_width = torch.tensor(bit_width, dtype=torch.float)
if not isinstance(signed, torch.Tensor):
signed = torch.tensor(signed, dtype=torch.bool)
if not isinstance(training, torch.Tensor):
training = torch.tensor(training, dtype=torch.bool)
quant_tensor = super().__new__(
cls, value, scale, zero_point, group_size, group_dim, bit_width, signed, training)
return quant_tensor

@property
def signed(self):
return self.signed_t.item()

@property
def training(self):
return self.training_t.item()

@property
def saturating(self):
return self.saturating_t.item()
Expand Down
12 changes: 0 additions & 12 deletions src/brevitas/quant_tensor/int_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,9 @@ def __new__(cls, value, scale, zero_point, bit_width, signed, training):
zero_point = torch.tensor(zero_point, dtype=torch.float)
if not isinstance(bit_width, torch.Tensor):
bit_width = torch.tensor(bit_width, dtype=torch.float)
if not isinstance(signed, torch.Tensor):
signed = torch.tensor(signed, dtype=torch.bool)
if not isinstance(training, torch.Tensor):
training = torch.tensor(training, dtype=torch.bool)
quant_tensor = super().__new__(cls, value, scale, zero_point, bit_width, signed, training)
return quant_tensor

@property
def signed(self):
return self.signed_t.item()

@property
def training(self):
return self.training_t.item()

def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
Expand Down

0 comments on commit afa4af1

Please sign in to comment.