From 260009d25dcc0abf25c31c582e66cc80694266e1 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 5 Dec 2024 09:39:46 +0000 Subject: [PATCH] Fix --- src/brevitas/core/scaling/runtime.py | 4 +++- src/brevitas/core/scaling/standalone.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index 5d013adfc..c2c7a0f2b 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -30,6 +30,7 @@ def __init__( scaling_stats_input_concat_dim: int, tracked_parameter_list: List[torch.nn.Parameter], scaling_shape: Tuple[int, ...], + force_parameter: bool = False, restrict_scaling_impl: Module = FloatRestrictValue(), restrict_threshold_impl: Optional[Module] = None, affine_rescaling: bool = False, @@ -48,7 +49,8 @@ def __init__( scaling_shape, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, - tracked_parameter_list) + tracked_parameter_list, + force_parameter) self.stats_scaling_impl = _StatsScaling( restrict_scaling_impl, restrict_threshold_impl, diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index a75e15b45..26bd12c9a 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -207,6 +207,7 @@ def __init__( scaling_stats_input_concat_dim: int, tracked_parameter_list: List[torch.nn.Parameter], scaling_shape: Tuple[int, ...], + force_parameter: bool = False, restrict_scaling_impl: Module = FloatRestrictValue(), restrict_threshold_impl: Optional[Module] = None, scaling_min_val: Optional[float] = None, @@ -218,7 +219,8 @@ def __init__( scaling_shape, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, - tracked_parameter_list) + tracked_parameter_list, + force_parameter) # Ensure retro-compatibility with shared threshold/scaling restrict if restrict_threshold_impl is None: