From a7d5e674aad53fd264ecc295419c554783b1db22 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 25 Sep 2024 15:52:35 +0100 Subject: [PATCH] threshold fix --- src/brevitas/core/scaling/pre_scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/core/scaling/pre_scaling.py b/src/brevitas/core/scaling/pre_scaling.py index d73c86461..82d12b298 100644 --- a/src/brevitas/core/scaling/pre_scaling.py +++ b/src/brevitas/core/scaling/pre_scaling.py @@ -97,7 +97,7 @@ def forward(self, weights: Tensor) -> Tensor: weights = self.stats_input_view_shape_impl(weights) d_w = self.stats(weights) # denominator for weight normalization g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g - s = self.scaling_impl(weights) # s + s = self.scaling_impl(weights, torch.tensor(1.).type_as(weights)) # s value = (s * d_w) / g return value @@ -184,7 +184,7 @@ def calc_max_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Te def inner_forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool): weights = self.stats_input_view_shape_impl(weights) d_w = self.stats(weights) # denominator for weight normalization - s = self.scaling_impl(weights) # s + s = self.scaling_impl(weights, torch.tensor(1.).type_as(weights)) # s g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g T = self.calc_max_l1_norm(input_bit_width, input_is_signed) # T / s g = torch.clamp_max(g / s, T)