Summary of this patch (text before the first "diff --git" header is not part of any hunk):

  * parameter_quant.py: the scale / zero_point / bit_width accessors and the
    pre_scale / pre_zero_point accessors return None when
    `self.is_quant_enabled` is false instead of invoking the quantizer.
    The quantization-disabled branch of the three-argument `forward` now
    returns the raw tensor `x` rather than wrapping it in a QuantTensor.
  * runtime_quant.py: scale / zero_point return None when quantization is
    disabled. `bit_width` gains a `force_eval=True` parameter and restores the
    previous training flag after the call, mirroring scale / zero_point. A
    second `bit_width` definition further down the file is removed. The
    truncation proxy's `bit_width` also returns None when disabled.

NOTE(review): indentation of the code lines below was reconstructed from
hunk line counts and Python structure; verify against the target tree
before applying.

diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 2927b1662..79ad7e9ec 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -82,16 +82,22 @@ def requires_quant_input(self):
         return False
 
     def scale(self):
+        if not self.is_quant_enabled:
+            return None
         scale = self.__call__(self.tracked_parameter_list[0]).scale
         return scale
 
     def zero_point(self):
+        if not self.is_quant_enabled:
+            return None
         zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point
         return zero_point
 
     def bit_width(self):
-        bit_width_ = self.__call__(self.tracked_parameter_list[0]).bit_width
-        return bit_width_
+        if not self.is_quant_enabled:
+            return None
+        bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width
+        return bit_width
 
     def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
         if self.is_quant_enabled:
@@ -105,11 +111,15 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
 class DecoupledWeightQuantProxyFromInjector(WeightQuantProxyFromInjector):
 
     def pre_scale(self):
+        if not self.is_quant_enabled:
+            return None
         output_tuple = self.tensor_quant(self.tracked_parameter_list[0])
         out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple
         return pre_scale
 
     def pre_zero_point(self):
+        if not self.is_quant_enabled:
+            return None
         output_tuple = self.tensor_quant(self.tracked_parameter_list[0])
         out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple
         return pre_zero_point
@@ -151,7 +161,7 @@ def forward(self, x: torch.Tensor, input_bit_width: torch.Tensor,
             out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed)
             return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
         else:  # quantization disabled
-            return QuantTensor(x, training=self.training)
+            return x
 
 
 class BiasQuantProxyFromInjector(ParameterQuantProxyFromInjector, BiasQuantProxyProtocol):
@@ -168,18 +178,22 @@ def requires_input_scale(self) -> bool:
         return False
 
     def scale(self):
-        if self.requires_input_scale:
+        if self.requires_input_scale or not self.is_quant_enabled:
             return None
         zhs = self._zero_hw_sentinel()
         scale = self.__call__(self.tracked_parameter_list[0], zhs).scale
         return scale
 
     def zero_point(self):
+        if not self.is_quant_enabled:
+            return None
         zhs = self._zero_hw_sentinel()
         zero_point = self.__call__(self.tracked_parameter_list[0], zhs).zero_point
         return zero_point
 
     def bit_width(self):
+        if not self.is_quant_enabled:
+            return None
         zhs = self._zero_hw_sentinel()
         bit_width = self.__call__(self.tracked_parameter_list[0], zhs).bit_width
         return bit_width
diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py
index 0324465c1..fe7b29daf 100644
--- a/src/brevitas/proxy/runtime_quant.py
+++ b/src/brevitas/proxy/runtime_quant.py
@@ -118,6 +118,8 @@ def init_tensor_quant(self):
             self.fused_activation_quant_proxy = None
 
     def scale(self, force_eval=True):
+        if not self.is_quant_enabled:
+            return None
         current_status = self.training
         if force_eval:
             self.eval()
@@ -126,6 +128,8 @@ def scale(self, force_eval=True):
         return scale
 
     def zero_point(self, force_eval=True):
+        if not self.is_quant_enabled:
+            return None
         current_status = self.training
         if force_eval:
             self.eval()
@@ -133,9 +137,15 @@ def zero_point(self, force_eval=True):
         self.train(current_status)
         return zero_point
 
-    def bit_width(self):
-        scale = self.__call__(self._zero_hw_sentinel()).bit_width
-        return scale
+    def bit_width(self, force_eval=True):
+        if not self.is_quant_enabled:
+            return None
+        current_status = self.training
+        if force_eval:
+            self.eval()
+        bit_width = self.__call__(self._zero_hw_sentinel()).bit_width
+        self.train(current_status)
+        return bit_width
 
     def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         if self.fused_activation_quant_proxy is not None:
@@ -179,10 +189,6 @@ def scale(self, force_eval=True):
     def zero_point(self, force_eval=True):
         raise RuntimeError("Zero point for Dynamic Act Quant is input-dependant")
 
-    def bit_width(self):
-        bit_width = self.__call__(self._zero_hw_sentinel()).bit_width
-        return bit_width
-
 
 class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):
 
@@ -198,6 +204,8 @@ def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]:
 class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):
 
     def bit_width(self):
+        if not self.is_quant_enabled:
+            return None
         zhs = self._zero_hw_sentinel()
         # Signed might or might not be defined. We just care about retrieving the bitwidth
         empty_imp = QuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training)