Skip to content

Commit

Permalink
Restrict scaling for dynamic quant
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 27, 2024
1 parent a7dad76 commit e040ffd
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions src/brevitas_examples/common/generative/quant_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch import Tensor
import torch.nn as nn

from brevitas.core.restrict_val import _RestrictClampValue
from brevitas.core.zero_point import _ScaleShiftZeroPoint
from brevitas.function.ops_ste import abs_binary_sign_grad

def __init__(
        self,
        scaling_stats_impl: nn.Module,
        dynamic_scaling_broadcastable_fn: Callable,
        scaling_stats_input_view_shape_impl: nn.Module,
        restrict_scaling_impl: nn.Module,
        restrict_threshold_impl: nn.Module = None,
        scaling_min_val=None) -> None:
    """Set up the modules used to compute a runtime (dynamic) scale.

    Args:
        scaling_stats_impl: Module that computes the statistics from which
            the scale is derived; stored as ``self.stats_impl``.
        dynamic_scaling_broadcastable_fn: Callable invoked in ``forward`` as
            ``fn(scale, original_input_shape)``.
        scaling_stats_input_view_shape_impl: Module applied to the input
            before the statistics are computed.
        restrict_scaling_impl: Restriction applied to the scale. Must
            provide ``restrict_init_module()``.
        restrict_threshold_impl: Restriction applied to the threshold.
            When ``None``, ``restrict_scaling_impl`` is used for both.
        scaling_min_val: Forwarded to the ``_RestrictClampValue`` built for
            the scale; not forwarded to the one built for the threshold.
    """
    super(RuntimeDynamicStatsScaling, self).__init__()
    # Ensure retro-compatibility with shared threshold/scaling restrict
    if restrict_threshold_impl is None:
        restrict_threshold_impl = restrict_scaling_impl
    self.scaling_stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
    self.stats_impl = scaling_stats_impl
    self.dynamic_scaling_broadcastable_fn = dynamic_scaling_broadcastable_fn
    # Scale path: pre-restriction module followed by restrict + clamp.
    self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
    self.restrict_clamp_scaling = _RestrictClampValue(
        scaling_min_val=scaling_min_val, restrict_value_impl=restrict_scaling_impl)
    # Threshold path: same structure, but built without a minimum value.
    self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
    self.restrict_clamp_threshold = _RestrictClampValue(
        restrict_value_impl=restrict_threshold_impl)

def forward(self, x, threshold) -> Tensor:
    """Compute the dynamic scale for ``x`` relative to ``threshold``.

    The threshold and the computed statistics are each passed through their
    own pre-restriction module and restrict/clamp module before the
    division, so the restriction applies to both operands separately.

    Args:
        x: Input whose statistics define the scale; must expose ``.shape``.
        threshold: Value the restricted statistics are divided by.

    Returns:
        The result of ``dynamic_scaling_broadcastable_fn`` applied to the
        restricted scale and the original shape of ``x``.
    """
    # Remember the caller's shape before the input is re-viewed for stats.
    shape = x.shape
    threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold))
    x = self.scaling_stats_input_view_shape_impl(x)
    # Statistics are computed exactly once; the division by the threshold
    # happens only after the scale has been restricted and clamped.
    x = self.stats_impl(x)
    x = self.restrict_clamp_scaling(self.restrict_scaling_pre(x))
    x = x / threshold

    x = self.dynamic_scaling_broadcastable_fn(x, shape)
    return x
Expand Down

0 comments on commit e040ffd

Please sign in to comment.