diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index 5d013adfc..c2c7a0f2b 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -30,6 +30,7 @@ def __init__( scaling_stats_input_concat_dim: int, tracked_parameter_list: List[torch.nn.Parameter], scaling_shape: Tuple[int, ...], + force_parameter: bool = False, restrict_scaling_impl: Module = FloatRestrictValue(), restrict_threshold_impl: Optional[Module] = None, affine_rescaling: bool = False, @@ -48,7 +49,8 @@ def __init__( scaling_shape, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, - tracked_parameter_list) + tracked_parameter_list, + force_parameter) self.stats_scaling_impl = _StatsScaling( restrict_scaling_impl, restrict_threshold_impl, diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index a75e15b45..26bd12c9a 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -207,6 +207,7 @@ def __init__( scaling_stats_input_concat_dim: int, tracked_parameter_list: List[torch.nn.Parameter], scaling_shape: Tuple[int, ...], + force_parameter: bool = False, restrict_scaling_impl: Module = FloatRestrictValue(), restrict_threshold_impl: Optional[Module] = None, scaling_min_val: Optional[float] = None, @@ -218,7 +219,8 @@ def __init__( scaling_shape, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, - tracked_parameter_list) + tracked_parameter_list, + force_parameter) # Ensure retro-compatibility with shared threshold/scaling restrict if restrict_threshold_impl is None: diff --git a/src/brevitas/core/stats/stats_wrapper.py b/src/brevitas/core/stats/stats_wrapper.py index 49bf62a82..18e9c286a 100644 --- a/src/brevitas/core/stats/stats_wrapper.py +++ b/src/brevitas/core/stats/stats_wrapper.py @@ -93,11 +93,12 @@ def __init__( stats_output_shape: Tuple[int, ...], stats_input_view_shape_impl: nn.Module, 
stats_input_concat_dim: int, - tracked_parameter_list: List[torch.nn.Parameter]) -> None: + tracked_parameter_list: List[torch.nn.Parameter], + force_parameter: bool = False) -> None: super(_ParameterListStats, self).__init__() self.stats_input_concat_dim = stats_input_concat_dim - if len(tracked_parameter_list) >= 1: + if len(tracked_parameter_list) > 1 or force_parameter: self.first_tracked_param = _ViewParameterWrapper( tracked_parameter_list[0], stats_input_view_shape_impl) else: diff --git a/src/brevitas/core/stats/view_wrapper.py b/src/brevitas/core/stats/view_wrapper.py index 98c6ab538..2535ee246 100644 --- a/src/brevitas/core/stats/view_wrapper.py +++ b/src/brevitas/core/stats/view_wrapper.py @@ -52,8 +52,11 @@ def __init__(self, view_shape_impl: Module) -> None: self.view_shape_impl = view_shape_impl @brevitas.jit.script_method - def forward(self, x: Tensor) -> Tensor: - return self.view_shape_impl(x) + def forward(self, x: Optional[Tensor]) -> Tensor: + if x is not None: + return self.view_shape_impl(x) + else: + raise RuntimeError("Input cannot be None") class _ViewCatParameterWrapper(brevitas.jit.ScriptModule): diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 0fd5e683a..7cd050e21 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -418,7 +418,9 @@ def scaling_init(scaling_init_impl, bit_width): return scales per_channel_pre_norm = PerChannelPreNorm - + # Even if we have a single parameter per quantizer, + # we want to force the use of tracked_parameter_list for the scale computation, because of how the scale is initialized + force_parameter = True proxy_class = DecoupledWeightQuantProxyFromInjector tensor_quant = DecoupledRescalingIntQuant decoupled_int_quant = DecoupledIntQuant