Skip to content

Commit

Permalink
Fix (proxy): fix attribute retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Mar 1, 2024
1 parent 8d95bbd commit 255408e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
22 changes: 18 additions & 4 deletions src/brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,22 @@ def requires_quant_input(self):
return False

def scale(self):
if not self.is_quant_enabled:
return None
scale = self.__call__(self.tracked_parameter_list[0]).scale
return scale

def zero_point(self):
if not self.is_quant_enabled:
return None
zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point
return zero_point

def bit_width(self):
bit_width_ = self.__call__(self.tracked_parameter_list[0]).bit_width
return bit_width_
if not self.is_quant_enabled:
return None
bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width
return bit_width

def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
if self.is_quant_enabled:
Expand All @@ -105,11 +111,15 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
class DecoupledWeightQuantProxyFromInjector(WeightQuantProxyFromInjector):

def pre_scale(self):
if not self.is_quant_enabled:
return None
output_tuple = self.tensor_quant(self.tracked_parameter_list[0])
out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple
return pre_scale

def pre_zero_point(self):
if not self.is_quant_enabled:
return None
output_tuple = self.tensor_quant(self.tracked_parameter_list[0])
out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple
return pre_zero_point
Expand Down Expand Up @@ -151,7 +161,7 @@ def forward(self, x: torch.Tensor, input_bit_width: torch.Tensor,
out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed)
return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
else: # quantization disabled
return QuantTensor(x, training=self.training)
return x


class BiasQuantProxyFromInjector(ParameterQuantProxyFromInjector, BiasQuantProxyProtocol):
Expand All @@ -168,18 +178,22 @@ def requires_input_scale(self) -> bool:
return False

def scale(self):
if self.requires_input_scale:
if self.requires_input_scale or not self.is_quant_enabled:
return None
zhs = self._zero_hw_sentinel()
scale = self.__call__(self.tracked_parameter_list[0], zhs).scale
return scale

def zero_point(self):
if not self.is_quant_enabled:
return None
zhs = self._zero_hw_sentinel()
zero_point = self.__call__(self.tracked_parameter_list[0], zhs).zero_point
return zero_point

def bit_width(self):
if not self.is_quant_enabled:
return None
zhs = self._zero_hw_sentinel()
bit_width = self.__call__(self.tracked_parameter_list[0], zhs).bit_width
return bit_width
Expand Down
22 changes: 15 additions & 7 deletions src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def init_tensor_quant(self):
self.fused_activation_quant_proxy = None

def scale(self, force_eval=True):
if not self.is_quant_enabled:
return None
current_status = self.training
if force_eval:
self.eval()
Expand All @@ -126,16 +128,24 @@ def scale(self, force_eval=True):
return scale

def zero_point(self, force_eval=True):
if not self.is_quant_enabled:
return None
current_status = self.training
if force_eval:
self.eval()
zero_point = self.__call__(self._zero_hw_sentinel()).zero_point
self.train(current_status)
return zero_point

def bit_width(self):
scale = self.__call__(self._zero_hw_sentinel()).bit_width
return scale
def bit_width(self, force_eval=True):
if not self.is_quant_enabled:
return None
current_status = self.training
if force_eval:
self.eval()
bit_width = self.__call__(self._zero_hw_sentinel()).bit_width
self.train(current_status)
return bit_width

def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
if self.fused_activation_quant_proxy is not None:
Expand Down Expand Up @@ -179,10 +189,6 @@ def scale(self, force_eval=True):
def zero_point(self, force_eval=True):
raise RuntimeError("Zero point for Dynamic Act Quant is input-dependant")

def bit_width(self):
bit_width = self.__call__(self._zero_hw_sentinel()).bit_width
return bit_width


class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

Expand All @@ -198,6 +204,8 @@ def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]:
class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

def bit_width(self):
if not self.is_quant_enabled:
return None
zhs = self._zero_hw_sentinel()
# Signed might or might not be defined. We just care about retrieving the bitwidth
empty_imp = QuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training)
Expand Down

0 comments on commit 255408e

Please sign in to comment.