Fix (proxy): fix groupwise scale/zp caching
Giuseppe5 committed Dec 20, 2024
1 parent 39ce837 commit f7aae4d
Showing 5 changed files with 60 additions and 17 deletions.
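
For orientation before the per-file diffs: export handlers now prefer the unexpanded groupwise accessors (scale_/zero_point_) when a proxy provides them, the groupwise proxies gain those accessors with a cache-first lookup, the base parameter proxy reuses its cached weight in scale()/zero_point()/bit_width(), and the groupwise cache classes store scale_ and zero_point_ directly. The sketch below illustrates the compact-versus-expanded shape distinction that the trailing underscore is read as denoting here; the shapes and quantization math are illustrative assumptions, not Brevitas internals.

import torch

# Hypothetical groupwise shapes, for illustration only.
O, I, G = 8, 64, 16                                # out/in features, group size
weight = torch.randn(O, I)

grouped = weight.view(O, I // G, G)                # split the input dim into groups
scale_ = grouped.abs().amax(dim=-1, keepdim=True)  # compact scale: (O, I // G, 1)
scale = scale_.expand(O, I // G, G).reshape(O, I)  # expanded scale: (O, I)

# Caching the compact scale_ stores one value per group; the expanded
# scale repeats each group value G times to broadcast over the weight.
assert scale.shape == weight.shape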
12 changes: 8 additions & 4 deletions src/brevitas/export/inference/handler.py
@@ -55,8 +55,10 @@ def __init__(self):
 
     def prepare_for_export(self, module: nn.Module):
         if module.is_quant_enabled:
-            self.scale = module.scale()
-            self.zero_point = module.zero_point().to(self.scale.device)
+            self.scale = module.scale_() if hasattr(module, 'scale_') else module.scale()
+            self.zero_point = module.zero_point_() if hasattr(
+                module, 'zero_point_') else module.zero_point()
+            self.zero_point = self.zero_point.to(self.scale.device)
             self.bit_width = module.bit_width()
             self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
             self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)
@@ -177,8 +179,10 @@ def __init__(self):
 
     def prepare_for_export(self, module):
         if module.is_quant_enabled:
-            self.scale = module.scale()
-            self.zero_point = module.zero_point().to(self.scale.device)
+            self.scale = module.scale_() if hasattr(module, 'scale_') else module.scale()
+            self.zero_point = module.zero_point_() if hasattr(
+                module, 'zero_point_') else module.zero_point()
+            self.zero_point = self.zero_point.to(self.scale.device)
             self.exponent_bit_width = module.exponent_bit_width()
             self.mantissa_bit_width = module.mantissa_bit_width()
             self.exponent_bias = module.exponent_bias()
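A minimal sketch of what the cached attributes feed at inference time, assuming a plain integer quantize/clamp/dequantize round trip; the function is hypothetical and only mirrors the attribute names set in prepare_for_export above.

import torch

def int_quant_dequant(x, scale, zero_point, min_clamp, max_clamp):
    # Quantize with the cached scale/zero-point, clamp to the cached
    # integer range, then dequantize back to floating point.
    q = torch.clamp(torch.round(x / scale + zero_point), min_clamp, max_clamp)
    return (q - zero_point) * scale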
18 changes: 18 additions & 0 deletions src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -14,6 +14,24 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         super().__init__(quant_layer, quant_injector)
         self.cache_class = _CachedIOGroupwiseFloat
 
+    def scale_(self):
+        if not self.is_quant_enabled:
+            return None
+        elif self._cached_weight:
+            scale = self._cached_weight.scale_
+        else:
+            scale = self.__call__(self.tracked_parameter_list[0]).scale_
+        return scale
+
+    def zero_point_(self):
+        if not self.is_quant_enabled:
+            return None
+        elif self._cached_weight:
+            zero_point = self._cached_weight.zero_point_
+        else:
+            zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point_
+        return zero_point
+
     @property
     def group_dim(self):
         return self.quant_injector.group_dim
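A hedged usage sketch of the new accessors: with quantization enabled and a cached weight available, both calls return the cached compact tensors without re-running the quantizer; otherwise they quantize the first tracked parameter on the fly. The layer.weight_quant attribute path is an assumption about a typical quant layer, not something this diff shows.

# Assuming `layer` is a quant layer whose weight proxy is groupwise:
proxy = layer.weight_quant          # hypothetical handle on the weight proxy
per_group_scale = proxy.scale_()    # compact per-group scale (cache-first)
per_group_zp = proxy.zero_point_()  # compact per-group zero-point
expanded = proxy.scale()            # expanded view, broadcastable over the weight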
18 changes: 18 additions & 0 deletions src/brevitas/proxy/groupwise_int_parameter_quant.py
@@ -14,6 +14,24 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         super().__init__(quant_layer, quant_injector)
         self.cache_class = _CachedIOGroupwiseInt
 
+    def scale_(self):
+        if not self.is_quant_enabled:
+            return None
+        elif self._cached_weight:
+            scale = self._cached_weight.scale_
+        else:
+            scale = self.__call__(self.tracked_parameter_list[0]).scale_
+        return scale
+
+    def zero_point_(self):
+        if not self.is_quant_enabled:
+            return None
+        elif self._cached_weight:
+            zero_point = self._cached_weight.zero_point_
+        else:
+            zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point_
+        return zero_point
+
     @property
     def group_dim(self):
         return self.quant_injector.group_dim
15 changes: 12 additions & 3 deletions src/brevitas/proxy/parameter_quant.py
@@ -195,19 +195,28 @@ def requires_quant_input(self):
     def scale(self):
         if not self.is_quant_enabled:
             return None
-        scale = self.__call__(self.tracked_parameter_list[0]).scale
+        elif self._cached_weight:
+            scale = self._cached_weight.scale
+        else:
+            scale = self.__call__(self.tracked_parameter_list[0]).scale
         return scale
 
     def zero_point(self):
         if not self.is_quant_enabled:
             return None
-        zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point
+        elif self._cached_weight:
+            zero_point = self._cached_weight.zero_point
+        else:
+            zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point
         return zero_point
 
     def bit_width(self):
         if not self.is_quant_enabled:
             return None
-        bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width
+        elif self._cached_weight:
+            bit_width = self._cached_weight.bit_width
+        else:
+            bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width
         return bit_width
 
     def create_quant_tensor(self, qt_args: Tuple[Any]) -> IntQuantTensor:
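The base-proxy change applies the same cache-first pattern in all three accessors: return the cached weight's metadata when present and fall back to quantizing the tracked parameter only when it is not. A generic sketch of that pattern, with hypothetical names:

from typing import Any, Callable, Optional

def cache_first(cached: Optional[Any], attr: str, recompute: Callable[[], Any]) -> Any:
    # Prefer cached metadata; recomputing means re-running quantization,
    # which is the expensive path this commit avoids.
    if cached is not None:
        return getattr(cached, attr)
    return getattr(recompute(), attr)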
14 changes: 4 additions & 10 deletions src/brevitas/utils/quant_utils.py
@@ -97,11 +97,8 @@ def __init__(self, quant_tensor: GroupwiseFloatQuantTensor, metadata_only: bool)
         # torch.compile compatibility
         self.value = quant_tensor.value
         # torch.compile compatibility
-        self.scale = quant_tensor.scale
-
-    @property
-    def zero_point(self):
-        return self.quant_tensor.zero_point
+        self.scale_ = quant_tensor.scale_
+        self.zero_point_ = quant_tensor.zero_point_
 
     @property
     def exponent_bit_width(self):
@@ -152,11 +149,8 @@ def __init__(self, quant_tensor: GroupwiseIntQuantTensor, metadata_only: bool):
         # torch.compile compatibility
         self.value = quant_tensor.value
         # torch.compile compatibility
-        self.scale = quant_tensor.scale
-
-    @property
-    def zero_point(self):
-        return self.quant_tensor.zero_point
+        self.scale_ = quant_tensor.scale_
+        self.zero_point_ = quant_tensor.zero_point_
 
     @property
     def bit_width(self):
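One plausible reading of this last change, sketched below with hypothetical names: replacing a zero_point property that dereferenced the wrapped quant tensor with eagerly stored scale_/zero_point_ attributes keeps the cached values valid even when only metadata is retained, and plain attribute reads avoid the property indirection that the surrounding torch.compile comments suggest is being worked around.

class MetadataCache:
    # Hypothetical stand-in for the _CachedIOGroupwise* classes.
    def __init__(self, value, scale_, zero_point_, metadata_only: bool):
        # Eager plain-attribute assignment: torch.compile-friendly, and
        # valid even if the full quantized value is not retained.
        self.scale_ = scale_
        self.zero_point_ = zero_point_
        self.value = None if metadata_only else value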
