Commit 378bd8c

Fix

Giuseppe5 committed Dec 16, 2024
1 parent c6fb3d6 commit 378bd8c
Showing 2 changed files with 17 additions and 12 deletions.
src/brevitas/core/scaling/standalone.py: 25 changes (14 additions, 11 deletions)
@@ -376,6 +376,11 @@ def __init__(
         self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
         self.init_done: bool = brevitas.jit.Attribute(False, bool)
 
+    def init_scale(self):
+        self.restrict_inplace_preprocess(self.buffer)
+        inplace_tensor_mul(self.value.detach(), self.buffer)
+        self.counter = self.counter + 1
+
     @brevitas.jit.script_method
     def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
         if self.counter < self.collect_stats_steps:
@@ -397,12 +402,10 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
             self.counter = new_counter
             return abs_binary_sign_grad(clamped_stats / threshold)
         elif self.counter == self.collect_stats_steps:
-            self.restrict_inplace_preprocess(self.buffer)
-            inplace_tensor_mul(self.value.detach(), self.buffer)
+            threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
+            self.init_scale()
             value = self.clamp_scaling(self.restrict_scaling(self.value))
-            threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
             value = value / threshold
-            self.counter = self.counter + 1
             return abs_binary_sign_grad(value)
         else:
             threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
@@ -414,22 +417,22 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
     def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(stats_input)
-        if self.training and not self.init_done:
+        if self.training:
             # Threshold division handled inside the training_forward
             return self.training_forward(stats_input, threshold)
         else:
-            if not self.init_done:
-                self.init_done = True
+            if self.counter <= self.collect_stats_steps:
+                out = self.buffer
                 # No clamping is necessary since statistics are already clamped in training_forward
-                self.restrict_inplace_preprocess(self.buffer)
-                inplace_tensor_mul(self.value.detach(), self.buffer)
-            out = self.value
+                out = self.restrict_scaling_pre(out)
+            else:
+                out = self.value
             threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
             out = self.restrict_scaling(out)
             out = out / threshold
             # We can clamp after restrict val since the learned parameter is already in log-domain
             out = abs_binary_sign_grad(self.clamp_scaling(out))
-            return out
+        return out
 
     def state_dict(self, destination=None, prefix='', keep_vars=False):
         output_dict = super(ParameterFromRuntimeStatsScaling, self).state_dict(
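Taken together, the standalone.py changes factor the scale initialization that was duplicated across training_forward and forward into the new init_scale() method, and revert the eval path to deriving the scale from the stats buffer until collection finishes, instead of gating on the init_done flag. Below is a toy paraphrase of the resulting three-phase lifecycle, with a simplified abs-max running average standing in for Brevitas' stats machinery; it is a sketch for orientation, not Brevitas code.

import torch


class ToyRuntimeStatsScale(torch.nn.Module):
    """Toy paraphrase of the three-phase scale lifecycle; illustration only."""

    def __init__(self, collect_stats_steps: int = 10):
        super().__init__()
        self.collect_stats_steps = collect_stats_steps
        self.counter = 0
        self.register_buffer('buffer', torch.ones(1))  # running statistic
        self.value = torch.nn.Parameter(torch.ones(1))  # learned scale

    def init_scale(self):
        # Fold the collected statistic into the learned parameter, exactly once.
        self.value.detach().mul_(self.buffer)
        self.counter = self.counter + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            if self.counter < self.collect_stats_steps:
                # Phase 1: accumulate a running average of an abs-max statistic.
                stat = x.abs().max().reshape(1)
                self.buffer.mul_(self.counter / (self.counter + 1))
                self.buffer.add_(stat / (self.counter + 1))
                self.counter += 1
                return stat
            if self.counter == self.collect_stats_steps:
                # Phase 2: the first step past collection initializes the parameter.
                self.init_scale()
            # Phase 3: the scale comes from the learned parameter.
            return self.value
        # Eval path: read the running statistic until initialization has happened.
        return self.buffer if self.counter <= self.collect_stats_steps else self.value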
src/brevitas/graph/calibrate.py: 4 changes (3 additions, 1 deletion)
@@ -65,7 +65,7 @@ def restore_return_quant_tensor(model, previous_state):
 def extend_collect_stats_steps(module):
     if hasattr(module, 'collect_stats_steps'):
         # We extend the collect steps in PTQ to match potentially long calibrations
-        module.collect_stats_steps = sys.maxsize
+        module.collect_stats_steps = sys.maxsize - 1
 
 
 def set_collect_stats_to_average(module):
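One plausible reading of the change from sys.maxsize to sys.maxsize - 1 (inferred from the surrounding code; the commit message does not say): init_scale() advances counter one step past collect_stats_steps, and under TorchScript the counter attribute is a 64-bit integer, so the ceiling needs one unit of headroom. A minimal sketch:

import sys

collect_stats_steps = sys.maxsize - 1  # leaves headroom for the final increment
counter = collect_stats_steps          # state at the end of a long calibration
counter = counter + 1                  # the increment performed by init_scale()
assert counter <= sys.maxsize          # still representable as a 64-bit int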
@@ -80,6 +80,8 @@ def finalize_collect_stats(module):
         # otherwise the restrict_preprocess might be applied twice: during calibration
         # (that happens in training mode) and then when the model is evaluated
         module.counter = max(module.collect_stats_steps, module.counter)
+        if hasattr(module, 'init_scale'):
+            module.init_scale()
 
 
 class calibration_mode:
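For context, the calibration_mode class whose definition begins above is the usual entry point for these helpers: entering it applies extend_collect_stats_steps and set_collect_stats_to_average, and leaving it applies finalize_collect_stats, which after this fix also calls init_scale(). A typical PTQ calibration loop might look as follows, assuming a Brevitas-quantized model and a calib_loader iterable of example batches (both names are placeholders):

import torch
from brevitas.graph.calibrate import calibration_mode

model.eval()
with torch.no_grad():
    with calibration_mode(model):
        for batch in calib_loader:  # calib_loader: any iterable of example inputs
            model(batch)
# On exit, finalize_collect_stats clamps the counter and calls init_scale(),
# so the learned scale is initialized from the collected statistics.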
