From 90d4e31a64cd76da03a74e54795456606a628e44 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 25 Sep 2024 15:30:41 +0100 Subject: [PATCH] fix some tests --- src/brevitas/core/scaling/standalone.py | 15 ++++++--------- tests/brevitas/graph/test_calibration.py | 4 ++-- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index c8e612909..965bd3dce 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -212,10 +212,9 @@ def forward(self, ignored: torch.Tensor, threshold: torch.Tensor) -> torch.Tenso stats = stats + 0. * self.value if self.local_loss_mode: return self.stats_scaling_impl(stats, threshold) - stats = self.restrict_inplace_preprocess(stats) + stats = self.restrict_inplace_preprocess(stats / threshold) inplace_tensor_mul(self.value.detach(), stats) - value = abs_binary_sign_grad( - self.stats_scaling_impl.restrict_clamp_scaling(self.value / threshold)) + value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) self.init_done = True return value @@ -338,14 +337,12 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tens self.counter = new_counter return abs_binary_sign_grad(clamped_stats) / threshold elif self.counter == self.collect_stats_steps: - self.restrict_inplace_preprocess(self.buffer) + self.restrict_inplace_preprocess(self.buffer / threshold) inplace_tensor_mul(self.value.detach(), self.buffer) self.counter = self.counter + 1 - return abs_binary_sign_grad( - self.clamp_scaling(self.restrict_scaling(self.value / threshold))) + return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value))) else: - return abs_binary_sign_grad( - self.clamp_scaling(self.restrict_scaling(self.value / threshold))) + return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.valu))) @brevitas.jit.script_method def forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor: @@ -356,7 +353,7 @@ def forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor: out = self.buffer / threshold out = self.restrict_preprocess(out) else: - out = self.value / threshold + out = self.value out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out))) return out diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 52819c22f..10d8f7e7c 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -47,9 +47,9 @@ def reference_implementation_scale_factors_po2( quant = compute_quantile(x, q) quant = torch.max(min_val, quant) quant_float_to_int = torch.ceil( - torch.log2(quant)) # Float to Int Implementation for PowerOfTwo scale + torch.log2(quant / int_scale)) # Float to Int Implementation for PowerOfTwo scale - scale = torch.pow(torch.tensor(2.), quant_float_to_int) / int_scale + scale = torch.pow(torch.tensor(2.), quant_float_to_int) return scale