diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index c8e612909..965bd3dce 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -212,10 +212,9 @@ def forward(self, ignored: torch.Tensor, threshold: torch.Tensor) -> torch.Tenso stats = stats + 0. * self.value if self.local_loss_mode: return self.stats_scaling_impl(stats, threshold) - stats = self.restrict_inplace_preprocess(stats) + stats = self.restrict_inplace_preprocess(stats / threshold) inplace_tensor_mul(self.value.detach(), stats) - value = abs_binary_sign_grad( - self.stats_scaling_impl.restrict_clamp_scaling(self.value / threshold)) + value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) self.init_done = True return value @@ -338,14 +337,12 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tens self.counter = new_counter return abs_binary_sign_grad(clamped_stats) / threshold elif self.counter == self.collect_stats_steps: - self.restrict_inplace_preprocess(self.buffer) + self.restrict_inplace_preprocess(self.buffer / threshold) inplace_tensor_mul(self.value.detach(), self.buffer) self.counter = self.counter + 1 - return abs_binary_sign_grad( - self.clamp_scaling(self.restrict_scaling(self.value / threshold))) + return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value))) else: - return abs_binary_sign_grad( - self.clamp_scaling(self.restrict_scaling(self.value / threshold))) + return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value))) @brevitas.jit.script_method def forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor: @@ -356,7 +353,7 @@ def forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor: out = self.buffer / threshold out = self.restrict_preprocess(out) else: - out = self.value / threshold + out = self.value out = 
abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out))) return out diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 52819c22f..10d8f7e7c 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -47,9 +47,9 @@ def reference_implementation_scale_factors_po2( quant = compute_quantile(x, q) quant = torch.max(min_val, quant) quant_float_to_int = torch.ceil( - torch.log2(quant)) # Float to Int Implementation for PowerOfTwo scale + torch.log2(quant / int_scale)) # Float to Int Implementation for PowerOfTwo scale - scale = torch.pow(torch.tensor(2.), quant_float_to_int) / int_scale + scale = torch.pow(torch.tensor(2.), quant_float_to_int) return scale