Feat (scaling): no tracked_parameter_list with individual quantizer (#…
Giuseppe5 authored Dec 6, 2024
1 parent 0ea7bac commit 7caa716
Showing 5 changed files with 17 additions and 7 deletions.
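
In short (all of it visible in the diff below): a new force_parameter flag, defaulting to False, is threaded from the two parameter-based scaling implementations down to _ParameterListStats, and the condition guarding the tracked-parameter wrapper changes from

    if len(tracked_parameter_list) >= 1:

to

    if len(tracked_parameter_list) > 1 or force_parameter:

so a quantizer that tracks a single parameter no longer goes through the tracked_parameter_list path unless it opts in explicitly, as the decoupled weight quantizers in base.py now do.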
src/brevitas/core/scaling/runtime.py: 3 additions & 1 deletion

@@ -30,6 +30,7 @@ def __init__(
             scaling_stats_input_concat_dim: int,
             tracked_parameter_list: List[torch.nn.Parameter],
             scaling_shape: Tuple[int, ...],
+            force_parameter: bool = False,
             restrict_scaling_impl: Module = FloatRestrictValue(),
             restrict_threshold_impl: Optional[Module] = None,
             affine_rescaling: bool = False,
@@ -48,7 +49,8 @@
             scaling_shape,
             scaling_stats_input_view_shape_impl,
             scaling_stats_input_concat_dim,
-            tracked_parameter_list)
+            tracked_parameter_list,
+            force_parameter)
         self.stats_scaling_impl = _StatsScaling(
             restrict_scaling_impl,
             restrict_threshold_impl,
src/brevitas/core/scaling/standalone.py: 3 additions & 1 deletion

@@ -207,6 +207,7 @@ def __init__(
             scaling_stats_input_concat_dim: int,
             tracked_parameter_list: List[torch.nn.Parameter],
             scaling_shape: Tuple[int, ...],
+            force_parameter: bool = False,
             restrict_scaling_impl: Module = FloatRestrictValue(),
             restrict_threshold_impl: Optional[Module] = None,
             scaling_min_val: Optional[float] = None,
@@ -218,7 +219,8 @@
             scaling_shape,
             scaling_stats_input_view_shape_impl,
             scaling_stats_input_concat_dim,
-            tracked_parameter_list)
+            tracked_parameter_list,
+            force_parameter)

         # Ensure retro-compatibility with shared threshold/scaling restrict
         if restrict_threshold_impl is None:
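Both scaling classes change identically: force_parameter is accepted as a keyword with a False default and forwarded verbatim to _ParameterListStats, so existing call sites that do not pass the flag construct exactly as before. The behavioral change is confined to the single-parameter case handled next.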
src/brevitas/core/stats/stats_wrapper.py: 3 additions & 2 deletions

@@ -93,11 +93,12 @@ def __init__(
             stats_output_shape: Tuple[int, ...],
             stats_input_view_shape_impl: nn.Module,
             stats_input_concat_dim: int,
-            tracked_parameter_list: List[torch.nn.Parameter]) -> None:
+            tracked_parameter_list: List[torch.nn.Parameter],
+            force_parameter: bool = False) -> None:
         super(_ParameterListStats, self).__init__()

         self.stats_input_concat_dim = stats_input_concat_dim
-        if len(tracked_parameter_list) >= 1:
+        if len(tracked_parameter_list) > 1 or force_parameter:
             self.first_tracked_param = _ViewParameterWrapper(
                 tracked_parameter_list[0], stats_input_view_shape_impl)
         else:
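The heart of the change is this condition. A minimal runnable sketch of the new gating, with a stub in place of _ViewParameterWrapper; the else branch is collapsed in this view, so treating it as a "leave the wrapper unset" fallback is an assumption:

    import torch
    from typing import List, Optional

    class _StubViewWrapper:
        # Stand-in for brevitas' _ViewParameterWrapper: holds the parameter
        # and the view implementation applied to it during stats collection.
        def __init__(self, parameter, view_impl):
            self.parameter = parameter
            self.view_impl = view_impl

    def pick_first_tracked(
            tracked_parameter_list: List[torch.nn.Parameter],
            force_parameter: bool = False) -> Optional[_StubViewWrapper]:
        # New rule: wrap the first parameter only when several parameters
        # are tracked, or when the quantizer opts in via force_parameter.
        if len(tracked_parameter_list) > 1 or force_parameter:
            return _StubViewWrapper(tracked_parameter_list[0], lambda t: t.view(-1))
        return None  # single parameter, not forced (assumed fallback)

    w = torch.nn.Parameter(torch.randn(4, 4))
    print(pick_first_tracked([w]) is None)                        # True
    print(pick_first_tracked([w], force_parameter=True) is None)  # False
    print(pick_first_tracked([w, w]) is None)                     # False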
src/brevitas/core/stats/view_wrapper.py: 5 additions & 2 deletions

@@ -52,8 +52,11 @@ def __init__(self, view_shape_impl: Module) -> None:
         self.view_shape_impl = view_shape_impl

     @brevitas.jit.script_method
-    def forward(self, x: Tensor) -> Tensor:
-        return self.view_shape_impl(x)
+    def forward(self, x: Optional[Tensor]) -> Tensor:
+        if x is not None:
+            return self.view_shape_impl(x)
+        else:
+            raise RuntimeError("Input cannot be None")


 class _ViewCatParameterWrapper(brevitas.jit.ScriptModule):
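The widened signature is the usual TorchScript pattern for optional inputs: the compiler narrows Optional[Tensor] to Tensor only after an explicit is-not-None test, and a method declared to return Tensor may not pass None through silently. A standalone sketch of the same pattern (the class name and the Flatten view are illustrative, not the Brevitas module):

    import torch
    from torch import Tensor, nn
    from typing import Optional

    class ViewWrapperSketch(nn.Module):  # hypothetical stand-in

        def __init__(self, view_shape_impl: nn.Module) -> None:
            super().__init__()
            self.view_shape_impl = view_shape_impl

        def forward(self, x: Optional[Tensor]) -> Tensor:
            # TorchScript narrows Optional[Tensor] to Tensor only here.
            if x is not None:
                return self.view_shape_impl(x)
            else:
                raise RuntimeError("Input cannot be None")

    scripted = torch.jit.script(ViewWrapperSketch(nn.Flatten()))
    print(scripted(torch.randn(2, 3, 4)).shape)  # torch.Size([2, 12])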
src/brevitas/quant/base.py: 3 additions & 1 deletion

@@ -418,7 +418,9 @@ def scaling_init(scaling_init_impl, bit_width):
         return scales

     per_channel_pre_norm = PerChannelPreNorm
-
+    # Even if we have a single parameter per quantizer,
+    # we want to force the use of tracked_parameter_list for the scale computation because of the initialization
+    force_parameter = True
     proxy_class = DecoupledWeightQuantProxyFromInjector
     tensor_quant = DecoupledRescalingIntQuant
     decoupled_int_quant = DecoupledIntQuant
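base.py can set the flag as a plain class attribute because Brevitas quantizers are assembled by dependency injection: attributes are resolved by name into the constructor arguments of the components they reach, so force_parameter = True lands in _ParameterListStats without manual threading at the quantizer level. A heavily simplified sketch of that resolution style (not Brevitas' real injector):

    import inspect

    class _ParameterListStatsSketch:
        # Stand-in exposing the same keyword as the real module.
        def __init__(self, force_parameter: bool = False):
            self.force_parameter = force_parameter

    class QuantizerConfig:
        # Mirrors the base.py change above: a plain class attribute.
        force_parameter = True

    def inject(cls, config):
        # Pass along only those config attributes the constructor declares.
        accepted = inspect.signature(cls).parameters
        kwargs = {
            name: value for name, value in vars(config).items()
            if not name.startswith('_') and name in accepted}
        return cls(**kwargs)

    print(inject(_ParameterListStatsSketch, QuantizerConfig).force_parameter)  # True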
