diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py index 633918d1d..193707dc0 100644 --- a/src/brevitas/core/zero_point.py +++ b/src/brevitas/core/zero_point.py @@ -311,6 +311,7 @@ def __init__( self.stats_output_shape = pre_zero_point_shape self.stats_input_view_shape_impl = pre_zero_point_stats_input_view_shape_impl + @brevitas.jit.script_method def get_zero_center(self, x: Tensor) -> Tensor: x = self.stats_input_view_shape_impl(x) u = torch.mean(x, axis=self.stats_reduce_dim, keepdim=True)