From 5a41cfb0ac57d1e7db01820ddc81206c32654bf3 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 28 Jan 2025 13:47:15 +0000 Subject: [PATCH] Fix (export/trunc): Retrieve bit_width from cache --- src/brevitas/proxy/runtime_quant.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 2e3749236..12dcf9528 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -286,13 +286,7 @@ def zero_point(self): return self.retrieve_attribute('zero_point') def bit_width(self): - if not self.is_quant_enabled: - return None - zhs = self._zero_hw_sentinel() - # Signed might or might not be defined. We just care about retrieving the bitwidth - empty_imp = IntQuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training) - bit_width = self.__call__(empty_imp).bit_width - return bit_width + return self.retrieve_attribute('bit_width') def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: if self.is_quant_enabled: