diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 88e78beaa..53765c13c 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -24,49 +24,73 @@ def bit_width(self): def scale(self): if not self.is_quant_enabled: return None - scale = self.__call__(self.tracked_parameter_list[0]).scale + elif self._cached_weight: + scale = self._cached_weight.scale + else: + scale = self.__call__(self.tracked_parameter_list[0]).scale return scale def zero_point(self): if not self.is_quant_enabled: return None - zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point + elif self._cached_weight: + zero_point = self._cached_weight.zero_point + else: + zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point return zero_point def exponent_bit_width(self): if not self.is_quant_enabled: return None - exponent_bit_width = self.__call__(self.tracked_parameter_list[0]).exponent_bit_width + elif self._cached_weight: + exponent_bit_width = self._cached_weight.exponent_bit_width + else: + exponent_bit_width = self.__call__(self.tracked_parameter_list[0]).exponent_bit_width return exponent_bit_width def mantissa_bit_width(self): if not self.is_quant_enabled: return None - mantissa_bit_width = self.__call__(self.tracked_parameter_list[0]).mantissa_bit_width + elif self._cached_weight: + mantissa_bit_width = self._cached_weight.mantissa_bit_width + else: + mantissa_bit_width = self.__call__(self.tracked_parameter_list[0]).mantissa_bit_width return mantissa_bit_width def exponent_bias(self): if not self.is_quant_enabled: return None - exponent_bias = self.__call__(self.tracked_parameter_list[0]).exponent_bias + elif self._cached_weight: + exponent_bias = self._cached_weight.exponent_bias + else: + exponent_bias = self.__call__(self.tracked_parameter_list[0]).exponent_bias return exponent_bias def is_saturating(self): if not self.is_quant_enabled: return None - saturating = self.__call__(self.tracked_parameter_list[0]).saturating + elif self._cached_weight: + saturating = self._cached_weight.saturating + else: + saturating = self.__call__(self.tracked_parameter_list[0]).saturating return saturating def inf_values(self): if not self.is_quant_enabled: return None - inf_values = self.__call__(self.tracked_parameter_list[0]).inf_values + elif self._cached_weight: + inf_values = self._cached_weight.inf_values + else: + inf_values = self.__call__(self.tracked_parameter_list[0]).inf_values return inf_values def nan_values(self): if not self.is_quant_enabled: return None - nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values + elif self._cached_weight: + nan_values = self._cached_weight.nan_values + else: + nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values return nan_values @property