From 76455504bfcad1f5e8011995cd0cb64890c84305 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 24 Oct 2024 10:21:15 +0100 Subject: [PATCH 01/14] Fix (groupwise): correct log and groupdim --- src/brevitas/core/scaling/runtime.py | 4 ++++ src/brevitas/quant/solver/common.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index f11eb1f2a..0e6037903 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -187,6 +187,8 @@ def __init__( self.scaling_min_val = scaling_min_val self.input_view_impl = input_view_impl self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + self.restrict_module = self.restrict_clamp_scaling.restrict_value_impl.restrict_init_module( + ) @brevitas.jit.script_method def forward( @@ -197,6 +199,8 @@ def forward( threshold = torch.ones(1).type_as(stats_input) stats_input_reshaped = self.input_view_impl(stats_input) out = self.scaling_stats_impl(stats_input_reshaped) / threshold + # Apply log scaling + out = self.restrict_module(out) # Scaling min val out = self.restrict_clamp_scaling(out) return out diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 4d46cc704..a4930e43d 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -178,7 +178,8 @@ def stats_reduce_dim(scaling_stats_op, scaling_per_output, group_dim=None): elif scaling_per_output == ScalingPerOutputType.TENSOR: return None elif scaling_per_output == ScalingPerOutputType.GROUP: - return group_dim + 1 + reduce_dim = group_dim + 1 if group_dim != -1 else -1 + return reduce_dim @value def keepdim(scaling_per_output): From eb72870c53bc2a9575a1855bfa109ece4fcc7c22 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 24 Oct 2024 12:56:14 +0100 Subject: [PATCH 02/14] More fix --- src/brevitas/core/function_wrapper/shape.py | 5 +++-- src/brevitas/core/scaling/runtime.py | 5 +++-- src/brevitas/quant/experimental/mx_quant_ocp.py | 5 +++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py index e175e4445..e8b42312a 100644 --- a/src/brevitas/core/function_wrapper/shape.py +++ b/src/brevitas/core/function_wrapper/shape.py @@ -195,8 +195,9 @@ def forward(self, x): tensor_shape = x.shape tensor_shape_list = list(tensor_shape) - tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) - block_dim = self.group_dim + 1 if self.group_dim != -1 else -1 + tensor_shape_list[self.group_dim] = ( + tensor_shape_list[self.group_dim] + self.group_size - 1) // self.group_size + block_dim = self.group_dim + 1 if self.group_dim != -1 else len(tensor_shape_list) tensor_shape_list.insert(block_dim, self.group_size) x = x.view(tensor_shape_list) return x diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index 0e6037903..2dc4cea1c 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -198,9 +198,10 @@ def forward( if threshold is None: threshold = torch.ones(1).type_as(stats_input) stats_input_reshaped = self.input_view_impl(stats_input) - out = self.scaling_stats_impl(stats_input_reshaped) / threshold + threshold = self.restrict_clamp_scaling(self.restrict_module(threshold)) + out = self.scaling_stats_impl(stats_input_reshaped) # Apply log scaling out = 
self.restrict_module(out) # Scaling min val - out = self.restrict_clamp_scaling(out) + out = self.restrict_clamp_scaling(out) / threshold return out diff --git a/src/brevitas/quant/experimental/mx_quant_ocp.py b/src/brevitas/quant/experimental/mx_quant_ocp.py index 2299c1783..551f4f3d7 100644 --- a/src/brevitas/quant/experimental/mx_quant_ocp.py +++ b/src/brevitas/quant/experimental/mx_quant_ocp.py @@ -4,6 +4,7 @@ from dependencies import value from brevitas.core.function_wrapper.ops_ste import CeilSte +from brevitas.core.function_wrapper.ops_ste import FloorSte from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling from brevitas.inject import ExtendedInjector from brevitas.inject.enum import RestrictValueType @@ -46,14 +47,14 @@ class GroupwiseActProxyMixin(ExtendedInjector): class MXWeightMixin(ExtendedInjector): group_size = 32 restrict_scaling_type = RestrictValueType.POWER_OF_TWO - restrict_value_float_to_int_impl = CeilSte + restrict_value_float_to_int_impl = FloorSte scaling_per_output_type = ScalingPerOutputType.GROUP class MXActMixin(ExtendedInjector): group_size = 32 restrict_scaling_type = RestrictValueType.POWER_OF_TWO - restrict_value_float_to_int_impl = CeilSte + restrict_value_float_to_int_impl = FloorSte scaling_impl = RuntimeDynamicGroupStatsScaling scaling_per_output_type = ScalingPerOutputType.GROUP From 7a8afeb0e0b26356b07840fd8a3163bf10dc6fe7 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 24 Oct 2024 21:57:23 +0100 Subject: [PATCH 03/14] More fixes --- src/brevitas/core/restrict_val.py | 12 -------- src/brevitas/core/scaling/runtime.py | 3 +- src/brevitas/core/scaling/standalone.py | 38 ++++++++++++++---------- tests/brevitas/graph/test_calibration.py | 3 +- 4 files changed, 25 insertions(+), 31 deletions(-) diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py index 59b3fe8ec..7eb9845f9 100644 --- a/src/brevitas/core/restrict_val.py +++ b/src/brevitas/core/restrict_val.py @@ -90,9 +90,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return Identity() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x / threshold - @brevitas.jit.script_method def forward(self, x: Tensor) -> Tensor: return x @@ -116,9 +113,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return InplaceLogTwo() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x - threshold - @brevitas.jit.script_method def forward(self, x: Tensor): x = self.power_of_two(x) @@ -143,9 +137,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return Identity() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x / threshold - @brevitas.jit.script_method def forward(self, x: Tensor): x = self.float_to_int_impl(x) @@ -171,9 +162,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return InplaceLogTwo() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x - threshold - @brevitas.jit.script_method def forward(self, x: Tensor): x = self.float_to_int_impl(x) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index 2dc4cea1c..fee4175bc 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -90,10 +90,11 @@ def forward( if threshold is None: threshold = torch.ones(1).type_as(stats) threshold = self.restrict_scaling_pre(threshold) + threshold = 
self.restrict_clamp_scaling(threshold) stats = self.restrict_scaling_pre(stats) - stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold) stats = self.affine_rescaling(stats) stats = self.restrict_clamp_scaling(stats) + stats = stats / threshold return stats diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index 4917b859a..e43fd577a 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -220,9 +220,10 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor # This is because we don't want to store a parameter dependant on a runtime value (threshold) # And because restrict needs to happen after we divide by threshold if self.init_done: - threshold = self.restrict_inplace_preprocess(threshold) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value)) + threshold = self.stats_scaling_impl.restrict_clamp_scaling( + self.restrict_preprocess(threshold)) + value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) + value = value / threshold return value else: stats = self.parameter_list_stats() @@ -231,10 +232,11 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor if self.local_loss_mode: return self.stats_scaling_impl(stats, threshold) stats = self.restrict_inplace_preprocess(stats) - threshold = self.restrict_inplace_preprocess(threshold) + threshold = self.stats_scaling_impl.restrict_clamp_scaling( + self.restrict_preprocess(threshold)) inplace_tensor_mul(self.value.detach(), stats) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value)) + value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) + value = value / threshold self.init_done = True return value @@ -360,14 +362,16 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: elif self.counter == self.collect_stats_steps: self.restrict_inplace_preprocess(self.buffer) inplace_tensor_mul(self.value.detach(), self.buffer) - threshold = self.restrict_preprocess(threshold) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) + threshold = self.restrict_scaling(self.restrict_preprocess(threshold)) + value = self.clamp_scaling(self.restrict_scaling(self.value)) + value = value / threshold self.counter = self.counter + 1 - return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value))) + return abs_binary_sign_grad(value) else: - threshold = self.restrict_preprocess(threshold) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value))) + threshold = self.restrict_scaling(self.restrict_preprocess(threshold)) + value = self.clamp_scaling(self.restrict_scaling(self.value)) + value = value / threshold + return abs_binary_sign_grad(value) @brevitas.jit.script_method def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor: @@ -378,12 +382,14 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Te return self.training_forward(stats_input, threshold) else: if self.counter <= self.collect_stats_steps: - out = self.buffer / threshold + out = self.buffer out = 
self.restrict_preprocess(out) else: - threshold = self.restrict_preprocess(threshold) - out = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out))) + out = self.value + threshold = self.restrict_scaling(self.restrict_preprocess(threshold)) + out = self.clamp_scaling(self.restrict_scaling(out)) + out = out / threshold + out = abs_binary_sign_grad(self.clamp_scaling(out)) return out def state_dict(self, destination=None, prefix='', keep_vars=False): diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index fbfc76842..16f944e97 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -60,7 +60,7 @@ def reference_implementation_scale_factors_po2( return scale -@given(inp=float_tensor_random_size_st()) +@given(inp=float_tensor_random_size_st(max_val=1e10, min_val=-1e10)) def test_scale_factors_ptq_calibration_po2(inp): class TestModel(nn.Module): @@ -80,7 +80,6 @@ def forward(self, x): expected_scale = reference_implementation_scale_factors_po2(inp) scale = model.act.act_quant.scale() - assert torch.allclose(expected_scale, scale) From 66fbbfde22dae6e7478f3768e30d8963414ba265 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 28 Oct 2024 12:24:56 +0000 Subject: [PATCH 04/14] Decouple threshold restrict impl from scaling --- src/brevitas/core/scaling/runtime.py | 47 +++++++++++----- src/brevitas/core/scaling/standalone.py | 75 ++++++++++++++++++------- 2 files changed, 89 insertions(+), 33 deletions(-) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index fee4175bc..9792ebdae 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -30,12 +30,18 @@ def __init__( tracked_parameter_list: List[torch.nn.Parameter], scaling_shape: Tuple[int, ...], restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, affine_rescaling: bool = False, affine_shift_scale: bool = False, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(StatsFromParameterScaling, self).__init__() + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.parameter_list_stats = _ParameterListStats( scaling_stats_impl, scaling_shape, @@ -44,6 +50,7 @@ def __init__( tracked_parameter_list) self.stats_scaling_impl = _StatsScaling( restrict_scaling_impl, + restrict_threshold_impl, scaling_shape, scaling_min_val, affine_rescaling, @@ -65,6 +72,7 @@ class _StatsScaling(brevitas.jit.ScriptModule): def __init__( self, restrict_scaling_impl: Module, + restrict_threshold_impl: Module, scaling_shape: Tuple[int, ...], scaling_min_val: Optional[float], affine_rescaling: bool, @@ -81,16 +89,18 @@ def __init__( else: self.affine_rescaling = Identity() self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + self.restrict_clamp_threshold = _RestrictClampValue( + restrict_value_impl=restrict_threshold_impl) self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module() - self.restrict_scaling_impl = restrict_scaling_impl + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() @brevitas.jit.script_method def forward( self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = 
None) -> torch.Tensor: if threshold is None: threshold = torch.ones(1).type_as(stats) - threshold = self.restrict_scaling_pre(threshold) - threshold = self.restrict_clamp_scaling(threshold) + threshold = self.restrict_threshold_pre(threshold) + threshold = self.restrict_clamp_threshold(threshold) stats = self.restrict_scaling_pre(stats) stats = self.affine_rescaling(stats) stats = self.restrict_clamp_scaling(stats) @@ -108,12 +118,17 @@ def __init__( affine_rescaling: bool = False, affine_shift_scale: bool = False, restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_stats_momentum: float = DEFAULT_MOMENTUM, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(RuntimeStatsScaling, self).__init__() + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.runtime_stats = _RuntimeStats( scaling_stats_impl, scaling_shape, @@ -123,6 +138,7 @@ def __init__( device) self.stats_scaling_impl = _StatsScaling( restrict_scaling_impl, + restrict_threshold_impl, scaling_shape, scaling_min_val, affine_rescaling, @@ -174,13 +190,14 @@ def _load_from_state_dict( class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule): def __init__( - self, - group_size: int, - group_dim: int, - input_view_impl: Module, - scaling_stats_impl: Module, - scaling_min_val: Optional[float], - restrict_scaling_impl: Module = FloatRestrictValue()) -> None: + self, + group_size: int, + group_dim: int, + input_view_impl: Module, + scaling_stats_impl: Module, + scaling_min_val: Optional[float], + restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None) -> None: super(RuntimeDynamicGroupStatsScaling, self).__init__() self.group_size = group_size self.group_dim = group_dim @@ -188,7 +205,11 @@ def __init__( self.scaling_min_val = scaling_min_val self.input_view_impl = input_view_impl self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) - self.restrict_module = self.restrict_clamp_scaling.restrict_value_impl.restrict_init_module( + self.restrict_clamp_threshold = _RestrictClampValue( + restrict_value_impl=restrict_threshold_impl) + self.restrict_scaling_pre = self.restrict_clamp_scaling.restrict_value_impl.restrict_init_module( + ) + self.restrict_threshold_pre = self.restrict_clamp_threshold.restrict_value_impl.restrict_init_module( ) @brevitas.jit.script_method @@ -199,10 +220,10 @@ def forward( if threshold is None: threshold = torch.ones(1).type_as(stats_input) stats_input_reshaped = self.input_view_impl(stats_input) - threshold = self.restrict_clamp_scaling(self.restrict_module(threshold)) + threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold)) out = self.scaling_stats_impl(stats_input_reshaped) # Apply log scaling - out = self.restrict_module(out) + out = self.restrict_scaling_pre(out) # Scaling min val out = self.restrict_clamp_scaling(out) / threshold return out diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index e43fd577a..d9347898f 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -62,20 +62,27 @@ def __init__( self, scaling_init: Union[float, Tensor], restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, 
scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(ConstScaling, self).__init__() + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + self.restrict_clamp_threshold = _RestrictClampValue( + restrict_value_impl=restrict_threshold_impl) if isinstance(scaling_init, Tensor): scaling_init = scaling_init.to(device=device, dtype=dtype) scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) - self.restrict_init_module = restrict_scaling_impl.restrict_init_module() self.value = StatelessBuffer(scaling_init.detach()) else: scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init) - self.restrict_init_module = restrict_scaling_impl.restrict_init_module() self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device)) + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() @brevitas.jit.script_method def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor: @@ -83,7 +90,7 @@ def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Te threshold = torch.ones(1).type_as(placeholder) # We first apply any restriction to scaling # For IntQuant, this is no-op, retrocompatible. - threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold)) + threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold)) restricted_value = self.restrict_clamp_scaling(self.value()) restricted_value = restricted_value / threshold return restricted_value @@ -133,11 +140,16 @@ def __init__( scaling_init: Union[float, Tensor], scaling_shape: Optional[Tuple[int, ...]] = None, restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(ParameterScaling, self).__init__() + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + if (isinstance(scaling_init, Tensor) and scaling_shape is not None and scaling_init.shape != SCALAR_SHAPE and scaling_init.shape != scaling_shape): raise RuntimeError("scaling_init.shape is non-scalar and != from scaling_shape.") @@ -149,12 +161,14 @@ def __init__( scaling_init = torch.tensor(scaling_init, dtype=dtype, device=device) scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) - self.restrict_init_module = restrict_scaling_impl.restrict_init_module() if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None: scaling_init = torch.full(scaling_shape, scaling_init, dtype=dtype, device=device) self.value = Parameter(scaling_init) self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + self.restrict_clamp_threshold = _RestrictClampValue( + restrict_value_impl=restrict_threshold_impl) + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() @brevitas.jit.script_method def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor: @@ -162,7 +176,7 @@ def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Te threshold = torch.ones(1).type_as(placeholder) # We first apply 
any restriction to scaling # For IntQuant, this is no-op, retrocompatible. - threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold)) + threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold)) value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) return value / threshold @@ -193,6 +207,7 @@ def __init__( tracked_parameter_list: List[torch.nn.Parameter], scaling_shape: Tuple[int, ...], restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: @@ -203,13 +218,26 @@ def __init__( scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, tracked_parameter_list) - self.restrict_scaling_impl = restrict_scaling_impl + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.stats_scaling_impl = _StatsScaling( - restrict_scaling_impl, scaling_shape, scaling_min_val, False, False, dtype, device) + restrict_scaling_impl, + restrict_threshold_impl, + scaling_shape, + scaling_min_val, + False, + False, + dtype, + device) + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() + self.restrict_inplace_scaling_pre = restrict_scaling_impl.restrict_init_inplace_module() + self.init_done: bool = brevitas.jit.Attribute(False, bool) self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool) - self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module() - self.restrict_preprocess = restrict_scaling_impl.restrict_init_module() + self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) @brevitas.jit.script_method @@ -220,8 +248,8 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor # This is because we don't want to store a parameter dependant on a runtime value (threshold) # And because restrict needs to happen after we divide by threshold if self.init_done: - threshold = self.stats_scaling_impl.restrict_clamp_scaling( - self.restrict_preprocess(threshold)) + threshold = self.stats_scaling_impl.restrict_clamp_threshold( + self.restrict_threshold_pre(threshold)) value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) value = value / threshold return value @@ -231,9 +259,9 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor stats = stats + 0. * self.value if self.local_loss_mode: return self.stats_scaling_impl(stats, threshold) - stats = self.restrict_inplace_preprocess(stats) - threshold = self.stats_scaling_impl.restrict_clamp_scaling( - self.restrict_preprocess(threshold)) + stats = self.restrict_inplace_scaling_pre(stats) + threshold = self.stats_scaling_impl.restrict_clamp_threshold( + self.restrict_threshold_pre(threshold)) inplace_tensor_mul(self.value.detach(), stats) value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) value = value / threshold @@ -314,12 +342,18 @@ def __init__( scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(), scaling_shape: Tuple[int, ...] 
= SCALAR_SHAPE, restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(ParameterFromRuntimeStatsScaling, self).__init__() assert collect_stats_steps > 0, 'Steps should be more than 0' + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.collect_stats_steps: int = brevitas.jit.Attribute(collect_stats_steps, int) self.counter: int = brevitas.jit.Attribute(0, int) self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl @@ -328,13 +362,14 @@ def __init__( scaling_stats_momentum, Optional[float]) self.register_buffer('buffer', torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) - self.restrict_scaling_impl = restrict_scaling_impl self.restrict_scaling = _RestrictValue(restrict_scaling_impl) + self.restrict_threshold = _RestrictValue(restrict_threshold_impl) self.clamp_scaling = _ClampValue(scaling_min_val) self.local_loss_mode: bool = brevitas.jit.Attribute( False, bool) # required to support MSE eval or variants self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module() - self.restrict_preprocess = restrict_scaling_impl.restrict_init_module() + self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module() + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() @brevitas.jit.script_method def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: @@ -362,13 +397,13 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: elif self.counter == self.collect_stats_steps: self.restrict_inplace_preprocess(self.buffer) inplace_tensor_mul(self.value.detach(), self.buffer) - threshold = self.restrict_scaling(self.restrict_preprocess(threshold)) + threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold)) value = self.clamp_scaling(self.restrict_scaling(self.value)) value = value / threshold self.counter = self.counter + 1 return abs_binary_sign_grad(value) else: - threshold = self.restrict_scaling(self.restrict_preprocess(threshold)) + threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold)) value = self.clamp_scaling(self.restrict_scaling(self.value)) value = value / threshold return abs_binary_sign_grad(value) @@ -383,10 +418,10 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Te else: if self.counter <= self.collect_stats_steps: out = self.buffer - out = self.restrict_preprocess(out) + out = self.restrict_scaling_pre(out) else: out = self.value - threshold = self.restrict_scaling(self.restrict_preprocess(threshold)) + threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold)) out = self.clamp_scaling(self.restrict_scaling(out)) out = out / threshold out = abs_binary_sign_grad(self.clamp_scaling(out)) From 6601a2f9fb063fe9d25513a0f24f7816e43dbd1a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 28 Oct 2024 13:08:43 +0000 Subject: [PATCH 05/14] Clean-up --- src/brevitas/core/restrict_val.py | 5 ++++- .../quant/experimental/mx_quant_ocp.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git 
a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py index 7eb9845f9..7d6d83231 100644 --- a/src/brevitas/core/restrict_val.py +++ b/src/brevitas/core/restrict_val.py @@ -24,7 +24,10 @@ class _RestrictClampValue(brevitas.jit.ScriptModule): - def __init__(self, scaling_min_val: Optional[float], restrict_value_impl: Optional[Module]): + def __init__( + self, + scaling_min_val: Optional[float] = None, + restrict_value_impl: Optional[Module] = None): super(_RestrictClampValue, self).__init__() if scaling_min_val is not None and scaling_min_val != 0: self.clamp_min_ste = ScalarClampMinSte(scaling_min_val) diff --git a/src/brevitas/quant/experimental/mx_quant_ocp.py b/src/brevitas/quant/experimental/mx_quant_ocp.py index 551f4f3d7..5900fe663 100644 --- a/src/brevitas/quant/experimental/mx_quant_ocp.py +++ b/src/brevitas/quant/experimental/mx_quant_ocp.py @@ -1,10 +1,13 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from dependencies import this from dependencies import value from brevitas.core.function_wrapper.ops_ste import CeilSte from brevitas.core.function_wrapper.ops_ste import FloorSte +from brevitas.core.restrict_val import PowerOfTwo +from brevitas.core.restrict_val import PowerOfTwoRestrictValue from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling from brevitas.inject import ExtendedInjector from brevitas.inject.enum import RestrictValueType @@ -44,14 +47,25 @@ class GroupwiseActProxyMixin(ExtendedInjector): proxy_class = GroupwiseActQuantProxyFromInjector +class RestrictThresholdMixin(ExtendedInjector): + restrict_value_float_to_int_impl = FloorSte + restrict_scaling_impl = PowerOfTwoRestrictValue + + class MXWeightMixin(ExtendedInjector): + threshold_mixin = RestrictThresholdMixin group_size = 32 restrict_scaling_type = RestrictValueType.POWER_OF_TWO restrict_value_float_to_int_impl = FloorSte scaling_per_output_type = ScalingPerOutputType.GROUP + @value + def restrict_threshold_impl(): + return this.threshold_mixin.restrict_scaling_impl + class MXActMixin(ExtendedInjector): + threshold_mixin = RestrictThresholdMixin group_size = 32 restrict_scaling_type = RestrictValueType.POWER_OF_TWO restrict_value_float_to_int_impl = FloorSte @@ -66,6 +80,10 @@ def stats_reduce_dim(group_dim): else: return group_dim + 1 + @value + def restrict_threshold_impl(): + return this.threshold_mixin.restrict_scaling_impl + class MXFloat8e4m3Weight(MXWeightMixin, GroupwiseWeightFloatProxyMixin, From 41add72f3d48d03bf49d8bb80b22487a16a2a151 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 28 Oct 2024 13:22:39 +0000 Subject: [PATCH 06/14] fix --- src/brevitas/core/scaling/standalone.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index d9347898f..13ead5afc 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -244,9 +244,6 @@ def __init__( def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor: if threshold is None: threshold = torch.ones(1).type_as(ignored) - # Threshold division must happen after we update self.value, but before we apply restrict_preproces - # This is because we don't want to store a parameter dependant on a runtime value (threshold) - # And because restrict needs to happen after we divide by threshold if self.init_done: threshold = self.stats_scaling_impl.restrict_clamp_threshold( 
self.restrict_threshold_pre(threshold)) @@ -373,9 +370,6 @@ def __init__( @brevitas.jit.script_method def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: - # Threshold division must happen after we update self.value, but before we apply restrict_preproces - # This is because we don't want to store a parameter dependent on a runtime value (threshold) - # And because restrict needs to happen after we divide by threshold if self.counter < self.collect_stats_steps: stats_input = self.stats_input_view_shape_impl(stats_input) stats = self.stats(stats_input) @@ -437,7 +431,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): del output_dict[prefix + 'value'] # Save buffer into value for any non-zero number of collection steps elif self.counter <= self.collect_stats_steps: - output_dict[prefix + 'value'] = self.restrict_preprocess(self.buffer) + output_dict[prefix + 'value'] = self.restrict_scaling_pre(self.buffer) return output_dict def _load_from_state_dict( From b2cc565c64c0d07082563d5e2af9f91d335e2582 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 23 Oct 2024 13:34:08 +0100 Subject: [PATCH 07/14] Fix (minifloat): correct minifloat computation and tests --- src/brevitas/quant_tensor/float_quant_tensor.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index 459f0eec7..9252b8d72 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -150,11 +150,6 @@ def minifloat(self, float_datatype=True): int_scale = float_internal_scale( minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps) float_value = torch.round(self._pre_round_float_value) * int_scale - return float_value.type(self.scale.dtype) - else: - raise RuntimeError(f"FloatQuantTensor not valid.") - - @staticmethod def check_input_type(tensor): if not isinstance(tensor, FloatQuantTensor): raise RuntimeError("Tensor is not a FloatQuantTensor") From cd649ab891d3e6f51d757a0342449d3a19661493 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 23 Oct 2024 10:04:51 +0100 Subject: [PATCH 08/14] Test --- tests/brevitas/core/test_quant_mx.py | 209 +++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 tests/brevitas/core/test_quant_mx.py diff --git a/tests/brevitas/core/test_quant_mx.py b/tests/brevitas/core/test_quant_mx.py new file mode 100644 index 000000000..8ac55b849 --- /dev/null +++ b/tests/brevitas/core/test_quant_mx.py @@ -0,0 +1,209 @@ +""" +Brief MXFP quantizer +""" +# pylint: disable=missing-function-docstring, redefined-outer-name + +import struct + +try: + from mx.mx_ops import _quantize_mx as mx +except: + mx = None +import pytest_cases +import torch + +from brevitas.nn.quant_activation import QuantIdentity +from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act +from brevitas.utils.torch_utils import float_internal_scale + +torch.manual_seed(0) + + +# debug utility +def to_string(val: torch.Tensor | float, spaced: bool = True, code: str = "f") -> str | list[str]: + """ Debug util for visualizing float values """ + + def scalar_to_string(val: float, spaced: bool) -> str: + s = ''.join(bin(c).replace('0b', '').rjust(8, '0') for c in struct.pack('!' 
+ code, val)) + spaced = spaced and len(s) == 32 + return f"{s[0]} {s[1:9]} {s[9:]}" if spaced else s + + if isinstance(val, float): + return scalar_to_string(val, spaced) + val = val.view(-1) + return [scalar_to_string(val[i].item(), spaced) for i in range(val.numel())] + + +# debug utility +def check_bits(val: torch.Tensor | float, mbits: int) -> (bool, int): + """ return (too many precision bits, lowest mantissa bit) """ + strings = to_string(val, spaced=False) + if isinstance(strings, str): + strings = [strings] + error, lowest = False, 0 + for s in strings: + mant = s[9:] + error = error or "1" in mant[mbits:] + lowest = max(lowest, mant.find("1")) + return error, lowest + + +# Avoid returning exp 0 if we is 0 +def safe_frexp(x: torch.Tensor) -> torch.Tensor: + """torch.frexp returns unbiased exponent 0 for 0.0, which is not what we want.""" + if x.is_cuda and x.dtype not in (torch.float32, torch.float16): + x = x.float() # no gpu support for frexp on bfloat16 or any float8 + return torch.where(x == 0.0, -126, x.frexp().exponent - 1) + + +class MXFP: + """ + MXFP - Quantize OCP MXFP floating point types. + A type is defined as ebits, mbits, bias, and inf/nan handling. + """ + CONFIG = dict( + e5m2=(5, 2, 15, "ieee"), + e4m3=(4, 3, 7, "fn"), + e3m2=(3, 2, 3, "fnuz"), + e2m3=(2, 3, 1, "fnuz"), + e2m1=(2, 1, 1, "fnuz")) + + def __init__(self, name, tile_size: int | None = 32): + self.name = name.lower() + assert self.name in self.CONFIG + self.ebits, self.mbits, self.bias, self.infnan = self.CONFIG[self.name] + self.tile_size = tile_size + + @property # maximum unbiased exponent for this type + def emax(self) -> int: + return 2 ** self.ebits - 1 - self.bias - int(self.infnan == "ieee") + + @property # minimum unbiased exponent for this type + def emin(self) -> int: + return 1 - self.bias + + @property # maximum representable value; the "fn" reserves values for all non-sign bits == 1 + def maxval(self) -> float: + return 2 ** self.emax * (2.0 - (1 + int(self.infnan == "fn")) * 2 ** (-self.mbits)) + + @property # for alternative scale selection + def midmax(self) -> float: + return (2 ** (self.emax + 1) - self.maxval) / 2. + self.maxval + + @property # minimum representable positive value + def minval(self) -> float: + return 2 ** self.emin * 2 ** (-self.mbits) + + def quantize(self, tensor: torch.Tensor, axis: int = -1, select: bool = False): + """ + Fake quantize along the indicated dimension. This method assumes the tile dimension is the size of the tile, + so some reshaping and possibly padding is likely required. From there, we have 5 needed lines of code. + """ + exp = safe_frexp(tensor) # safe_frexp pretends the mantissa is < 1.0 + shared = exp.amax(axis, keepdim=True) # shared exponent per the OCP MX spec + + # This is an alternative to the OCP MX scale selection, which chooses the maximum exponent (maxexp). + # Instead, choose maxexp + 1 if absmax is closer to 2^(maxexp+1) than maxval. This reduces error on + # the highest magnitude value at the potential cost increased error or underflow of the smallest. + # Ad hoc MSE test shows that e4m3, due to reserving the most significant value for Nan, benefits the + # most from this technique. In hardware or a kernel, this is as simple as comparing bits [30:21] + # instead of [30:23] when getting max exponent, then add 1 to the max eeeeeeeemm and shift right two. 
+ # e2m1 e3m2 e2m3 e4m3 e5m2 + # max 0.01325 0.00291 0.00080 0.00085 0.00291 + # best 0.01254 0.00280 0.00079 0.00071 0.00280 + + if select: + midmax = self.midmax * (shared - self.emax).exp2() + shared[tensor.abs().amax(axis, keepdim=True) > midmax] += 1 + + # The way this works is to appropriately shift values so that rounding can work, then shift them back. + # All values that are representable as normal given the scale are shifted up by the difference + # between the individual exponent and zero, plus the mantissa width. Subnormals get the same, + # but with decreasing mantissa bits. The maxval for saturation is adjusted on a per block basis. + scale = (self.mbits - (shared - exp - (self.emax - self.emin)).clamp_min(0) - exp).exp2() + # about that last line of code: + # The "offset" is the number of mbits lost to subnormal/underflow. This is based on the difference between + # the shared exponent and the individual exponent, adjusted to the dynamic range of normals for this type. + # It can't be negative, because we subtract it from mbits, and don't want to exceed the available mbits. + # offset = (shared - exp - (self.emax - self.emin)).clamp_min(0) + # The shift left will be mbits - offset - exp, which for negative exponents gets them into the right range. + maxval = self.maxval * (shared - self.emax).exp2() # scale maxval per tile + return ((tensor * scale).round() / scale).clamp(-maxval, maxval), scale + + +INP = torch.tensor([[ + -0.569248080254, + 0.919971406460, + 1.110816121101, + 1.289874076843, + -1.478173971176, + 2.567232847214, + -0.473119795322, + 0.335550755262, + -1.629325985909, + -0.549743652344, + -0.479834258556, + -0.499681532383, + -1.066980361938, + 1.114939570427, + -0.140671432018, + 0.805753588676, + -0.093348234892, + 0.687050223351, + -0.838315367699, + 0.000891821750, + 0.841894090176, + -0.400034159422, + 1.039461970329, + 0.358153104782, + -0.246000945568, + 2.302516460419, + -1.881689190865, + -0.049727022648, + -1.044978618622, + -0.956500828266, + 0.033531859517, + 0.710086584091]]) +# Falsifying value is [0, 19] + +MAP = { + "e4m3": (4, 3),} +# "e5m2": (5,2), +# "e2m3": (2,3), +# "e3m2": (3,2), +# "e2m1": (2,1)} + + +@pytest_cases.parametrize('bit_widths', list(MAP.keys())) +@pytest_cases.parametrize('select', [False]) +def test_mx(bit_widths, select): + # print("-------------------------------------------") + torch.set_printoptions(precision=12, sci_mode=False) + exp, mant = MAP[bit_widths] + act_quant = QuantIdentity( + MXFloat8e4m3Act, + exponent_bit_width=exp, + mantissa_bit_width=mant, + bit_width=mant + exp + 1, + group_dim=-1, + return_quant_tensor=True) + act_quant.eval() + x = INP + + dtype = MXFP(bit_widths) + q, scale = dtype.quantize(x, select=select) + qx = act_quant(x) + error, lowest = check_bits(q, dtype.mbits) + + exp_bias = torch.tensor(2 ** (exp - 1) - 1) + + int_scale = float_internal_scale( + x / qx.scale, torch.tensor(mant), 1. 
- exp_bias - torch.tensor(mant), torch.tensor(1e-8)) + brev_scale = 1 / (int_scale * qx.scale) + if mx is None: + print("Install microscaling library, --no-deps flag recommended") + else: + y = mx( + x, 8, elem_format="fp8_e4m3", block_size=32, axes=-1, round='even', custom_cuda=False) + assert torch.allclose(qx.value, q, atol=1e-4) + assert torch.allclose(brev_scale, scale, atol=1e-4) From 2bbaa3641963e92ad470ed0523492e1a469692e3 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 24 Oct 2024 12:56:53 +0100 Subject: [PATCH 09/14] Test update --- tests/brevitas/core/test_quant_mx.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/brevitas/core/test_quant_mx.py b/tests/brevitas/core/test_quant_mx.py index 8ac55b849..8127c3d58 100644 --- a/tests/brevitas/core/test_quant_mx.py +++ b/tests/brevitas/core/test_quant_mx.py @@ -166,17 +166,13 @@ def quantize(self, tensor: torch.Tensor, axis: int = -1, select: bool = False): 0.710086584091]]) # Falsifying value is [0, 19] -MAP = { - "e4m3": (4, 3),} -# "e5m2": (5,2), -# "e2m3": (2,3), -# "e3m2": (3,2), -# "e2m1": (2,1)} +MAP = {"e4m3": (4, 3), "e5m2": (5, 2), "e2m3": (2, 3), "e3m2": (3, 2), "e2m1": (2, 1)} @pytest_cases.parametrize('bit_widths', list(MAP.keys())) @pytest_cases.parametrize('select', [False]) -def test_mx(bit_widths, select): +@pytest_cases.parametrize('iter', [0]) +def test_mx(bit_widths, select, iter): # print("-------------------------------------------") torch.set_printoptions(precision=12, sci_mode=False) exp, mant = MAP[bit_widths] @@ -185,7 +181,7 @@ def test_mx(bit_widths, select): exponent_bit_width=exp, mantissa_bit_width=mant, bit_width=mant + exp + 1, - group_dim=-1, + group_dim=1, return_quant_tensor=True) act_quant.eval() x = INP From 07bbff0b45de983d0ec7f1aed25ed8bd27a2ba9e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 24 Oct 2024 15:12:02 +0100 Subject: [PATCH 10/14] Update tests, now passing (local) --- tests/brevitas/core/test_quant_mx.py | 64 ++++++++-------------------- 1 file changed, 18 insertions(+), 46 deletions(-) diff --git a/tests/brevitas/core/test_quant_mx.py b/tests/brevitas/core/test_quant_mx.py index 8127c3d58..d300768ae 100644 --- a/tests/brevitas/core/test_quant_mx.py +++ b/tests/brevitas/core/test_quant_mx.py @@ -5,10 +5,13 @@ import struct +from tests.brevitas.hyp_helper import float_tensor_nz_st + try: from mx.mx_ops import _quantize_mx as mx except: mx = None +from hypothesis import given import pytest_cases import torch @@ -131,48 +134,17 @@ def quantize(self, tensor: torch.Tensor, axis: int = -1, select: bool = False): return ((tensor * scale).round() / scale).clamp(-maxval, maxval), scale -INP = torch.tensor([[ - -0.569248080254, - 0.919971406460, - 1.110816121101, - 1.289874076843, - -1.478173971176, - 2.567232847214, - -0.473119795322, - 0.335550755262, - -1.629325985909, - -0.549743652344, - -0.479834258556, - -0.499681532383, - -1.066980361938, - 1.114939570427, - -0.140671432018, - 0.805753588676, - -0.093348234892, - 0.687050223351, - -0.838315367699, - 0.000891821750, - 0.841894090176, - -0.400034159422, - 1.039461970329, - 0.358153104782, - -0.246000945568, - 2.302516460419, - -1.881689190865, - -0.049727022648, - -1.044978618622, - -0.956500828266, - 0.033531859517, - 0.710086584091]]) -# Falsifying value is [0, 19] - -MAP = {"e4m3": (4, 3), "e5m2": (5, 2), "e2m3": (2, 3), "e3m2": (3, 2), "e2m1": (2, 1)} +MAP = { + "fp8_e4m3": (4, 3), + "fp8_e5m2": (5, 2), + "fp6_e2m3": (2, 3), + "fp6_e3m2": (3, 2), + "fp4_e2m1": (2, 1)} 
+@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10)) @pytest_cases.parametrize('bit_widths', list(MAP.keys())) -@pytest_cases.parametrize('select', [False]) -@pytest_cases.parametrize('iter', [0]) -def test_mx(bit_widths, select, iter): +def test_mx(inp, bit_widths): # print("-------------------------------------------") torch.set_printoptions(precision=12, sci_mode=False) exp, mant = MAP[bit_widths] @@ -184,12 +156,12 @@ def test_mx(bit_widths, select, iter): group_dim=1, return_quant_tensor=True) act_quant.eval() - x = INP + x = inp - dtype = MXFP(bit_widths) - q, scale = dtype.quantize(x, select=select) + # dtype = MXFP(bit_widths) + # q, scale = dtype.quantize(x, select=False) qx = act_quant(x) - error, lowest = check_bits(q, dtype.mbits) + # error, lowest = check_bits(q, dtype.mbits) exp_bias = torch.tensor(2 ** (exp - 1) - 1) @@ -200,6 +172,6 @@ def test_mx(bit_widths, select, iter): print("Install microscaling library, --no-deps flag recommended") else: y = mx( - x, 8, elem_format="fp8_e4m3", block_size=32, axes=-1, round='even', custom_cuda=False) - assert torch.allclose(qx.value, q, atol=1e-4) - assert torch.allclose(brev_scale, scale, atol=1e-4) + x, 8, elem_format=bit_widths, block_size=32, axes=-1, round='even', custom_cuda=False) + assert torch.allclose(qx.value, y, atol=1e-4) + # assert torch.allclose(brev_scale, scale, atol=1e-4) From 3340972192d5aa56318b03861d7ad0dd794c443f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 24 Oct 2024 21:57:23 +0100 Subject: [PATCH 11/14] More fixes --- tests/brevitas/core/test_quant_mx.py | 48 +++++++++++++++++++++------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/tests/brevitas/core/test_quant_mx.py b/tests/brevitas/core/test_quant_mx.py index d300768ae..44f375962 100644 --- a/tests/brevitas/core/test_quant_mx.py +++ b/tests/brevitas/core/test_quant_mx.py @@ -5,6 +5,7 @@ import struct +from brevitas.nn.quant_linear import QuantLinear from tests.brevitas.hyp_helper import float_tensor_nz_st try: @@ -17,7 +18,7 @@ from brevitas.nn.quant_activation import QuantIdentity from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act -from brevitas.utils.torch_utils import float_internal_scale +from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight torch.manual_seed(0) @@ -144,10 +145,11 @@ def quantize(self, tensor: torch.Tensor, axis: int = -1, select: bool = False): @given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10)) @pytest_cases.parametrize('bit_widths', list(MAP.keys())) -def test_mx(inp, bit_widths): +def test_act_mx(inp, bit_widths): # print("-------------------------------------------") torch.set_printoptions(precision=12, sci_mode=False) exp, mant = MAP[bit_widths] + act_quant = QuantIdentity( MXFloat8e4m3Act, exponent_bit_width=exp, @@ -158,20 +160,44 @@ def test_mx(inp, bit_widths): act_quant.eval() x = inp - # dtype = MXFP(bit_widths) - # q, scale = dtype.quantize(x, select=False) qx = act_quant(x) - # error, lowest = check_bits(q, dtype.mbits) - exp_bias = torch.tensor(2 ** (exp - 1) - 1) + if mx is None: + print("Install microscaling library, --no-deps flag recommended") + else: + y = mx( + x, 8, elem_format=bit_widths, block_size=32, axes=-1, round='even', custom_cuda=False) + assert torch.allclose(qx.value, y, atol=1e-8) + + +@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10)) +@pytest_cases.parametrize('bit_widths', list(MAP.keys())) +@pytest_cases.parametrize('weight_quant_type', ['stats', 
'parameter_from_stats']) +def test_weight_mx(inp, bit_widths, weight_quant_type): + # print("-------------------------------------------") + torch.set_printoptions(precision=12, sci_mode=False) + exp, mant = MAP[bit_widths] + weight_quant = QuantLinear( + 32, + 1, + bias=False, + weight_quant=MXFloat8e4m3Weight, + weight_scaling_impl_type=weight_quant_type, + weight_exponent_bit_width=exp, + weight_mantissa_bit_width=mant, + weight_bit_width=mant + exp + 1) + + x = inp + weight_quant.weight.data = x + weight_quant.weight_quant.init_tensor_quant() + + qx_weight = weight_quant.quant_weight() + qx_weight_two = weight_quant.quant_weight() - int_scale = float_internal_scale( - x / qx.scale, torch.tensor(mant), 1. - exp_bias - torch.tensor(mant), torch.tensor(1e-8)) - brev_scale = 1 / (int_scale * qx.scale) if mx is None: print("Install microscaling library, --no-deps flag recommended") else: y = mx( x, 8, elem_format=bit_widths, block_size=32, axes=-1, round='even', custom_cuda=False) - assert torch.allclose(qx.value, y, atol=1e-4) - # assert torch.allclose(brev_scale, scale, atol=1e-4) + assert torch.allclose(qx_weight.value, y, atol=1e-8) + assert torch.allclose(qx_weight_two.value, y, atol=1e-8) From 7494e2ed5ad6e31fa89f46f2a28949ad02b855d5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 28 Oct 2024 13:56:34 +0000 Subject: [PATCH 12/14] Update tests --- tests/brevitas/core/test_quant_mx.py | 38 ++++++++-------------------- 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/tests/brevitas/core/test_quant_mx.py b/tests/brevitas/core/test_quant_mx.py index 44f375962..b2ab279d4 100644 --- a/tests/brevitas/core/test_quant_mx.py +++ b/tests/brevitas/core/test_quant_mx.py @@ -4,21 +4,17 @@ # pylint: disable=missing-function-docstring, redefined-outer-name import struct +from typing import Tuple -from brevitas.nn.quant_linear import QuantLinear -from tests.brevitas.hyp_helper import float_tensor_nz_st - -try: - from mx.mx_ops import _quantize_mx as mx -except: - mx = None from hypothesis import given import pytest_cases import torch from brevitas.nn.quant_activation import QuantIdentity +from brevitas.nn.quant_linear import QuantLinear from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight +from tests.brevitas.hyp_helper import float_tensor_nz_st torch.manual_seed(0) @@ -39,7 +35,7 @@ def scalar_to_string(val: float, spaced: bool) -> str: # debug utility -def check_bits(val: torch.Tensor | float, mbits: int) -> (bool, int): +def check_bits(val: torch.Tensor | float, mbits: int) -> Tuple[bool, int]: """ return (too many precision bits, lowest mantissa bit) """ strings = to_string(val, spaced=False) if isinstance(strings, str): @@ -132,21 +128,15 @@ def quantize(self, tensor: torch.Tensor, axis: int = -1, select: bool = False): # offset = (shared - exp - (self.emax - self.emin)).clamp_min(0) # The shift left will be mbits - offset - exp, which for negative exponents gets them into the right range. 
maxval = self.maxval * (shared - self.emax).exp2() # scale maxval per tile - return ((tensor * scale).round() / scale).clamp(-maxval, maxval), scale + return ((tensor * scale).round() / scale).clamp(-maxval, maxval) -MAP = { - "fp8_e4m3": (4, 3), - "fp8_e5m2": (5, 2), - "fp6_e2m3": (2, 3), - "fp6_e3m2": (3, 2), - "fp4_e2m1": (2, 1)} +MAP = {"e4m3": (4, 3), "e5m2": (5, 2), "e2m3": (2, 3), "e3m2": (3, 2), "e2m1": (2, 1)} @given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10)) @pytest_cases.parametrize('bit_widths', list(MAP.keys())) def test_act_mx(inp, bit_widths): - # print("-------------------------------------------") torch.set_printoptions(precision=12, sci_mode=False) exp, mant = MAP[bit_widths] @@ -160,13 +150,11 @@ def test_act_mx(inp, bit_widths): act_quant.eval() x = inp + quantizer = MXFP(bit_widths) + qx = act_quant(x) - if mx is None: - print("Install microscaling library, --no-deps flag recommended") - else: - y = mx( - x, 8, elem_format=bit_widths, block_size=32, axes=-1, round='even', custom_cuda=False) + y = quantizer.quantize(x) assert torch.allclose(qx.value, y, atol=1e-8) @@ -174,7 +162,6 @@ def test_act_mx(inp, bit_widths): @pytest_cases.parametrize('bit_widths', list(MAP.keys())) @pytest_cases.parametrize('weight_quant_type', ['stats', 'parameter_from_stats']) def test_weight_mx(inp, bit_widths, weight_quant_type): - # print("-------------------------------------------") torch.set_printoptions(precision=12, sci_mode=False) exp, mant = MAP[bit_widths] weight_quant = QuantLinear( @@ -190,14 +177,11 @@ def test_weight_mx(inp, bit_widths, weight_quant_type): x = inp weight_quant.weight.data = x weight_quant.weight_quant.init_tensor_quant() + quantizer = MXFP(bit_widths) qx_weight = weight_quant.quant_weight() qx_weight_two = weight_quant.quant_weight() - if mx is None: - print("Install microscaling library, --no-deps flag recommended") - else: - y = mx( - x, 8, elem_format=bit_widths, block_size=32, axes=-1, round='even', custom_cuda=False) + y = quantizer.quantize(x) assert torch.allclose(qx_weight.value, y, atol=1e-8) assert torch.allclose(qx_weight_two.value, y, atol=1e-8) From 87358dfdd618b21877454d53769eeb5119787cd5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 28 Oct 2024 15:33:11 +0000 Subject: [PATCH 13/14] Feat (groupwise): builder class for quantizer --- .../quant/experimental/mx_quant_ocp.py | 125 ++++++++++++++++++ .../common/generative/quantize.py | 2 +- .../common/generative/quantizers.py | 9 -- 3 files changed, 126 insertions(+), 10 deletions(-) diff --git a/src/brevitas/quant/experimental/mx_quant_ocp.py b/src/brevitas/quant/experimental/mx_quant_ocp.py index 5900fe663..d682da2dc 100644 --- a/src/brevitas/quant/experimental/mx_quant_ocp.py +++ b/src/brevitas/quant/experimental/mx_quant_ocp.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause +from typing import Optional + from dependencies import this from dependencies import value @@ -8,6 +10,7 @@ from brevitas.core.function_wrapper.ops_ste import FloorSte from brevitas.core.restrict_val import PowerOfTwo from brevitas.core.restrict_val import PowerOfTwoRestrictValue +from brevitas.core.restrict_val import RoundSte from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling from brevitas.inject import ExtendedInjector from brevitas.inject.enum import RestrictValueType @@ -22,10 +25,14 @@ from brevitas.quant.base import MinMaxStatsScaling from brevitas.quant.base import MSEAsymmetricScale from brevitas.quant.base import MSESymmetricScale +from brevitas.quant.base import MSESymmetricScaleSubInjector from brevitas.quant.base import ShiftedMinUintQuant +from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat from brevitas.quant.experimental.float_base import ScaledFloatActBase from brevitas.quant.experimental.float_base import ScaledFloatWeightBase +from brevitas.quant.experimental.float_quant_fnuz import FpFNUZMixin from brevitas.quant.experimental.float_quant_ocp import FpOCPAct +from brevitas.quant.experimental.float_quant_ocp import FpOCPMixin from brevitas.quant.experimental.float_quant_ocp import FpOCPWeight from brevitas.quant.solver.act import ActQuantSolver from brevitas.quant.solver.weight import WeightQuantSolver @@ -154,3 +161,121 @@ class ShiftedMXUInt8WeightMSE(MSEAsymmetricScale, ShiftedMXUInt8Weight): MX Int signed weight quantizer with per-channel MSE-based scaling. """ pass + + +class Fp8e4m3WeightSymmetricGroupQuant(Fp8e4m3WeightPerChannelFloat): + """ + Block / group / vector signed symmetric e4m3 weight quantizer with float scales. + We inherit from a per-channel quantizer to re-use some underlying machinery. 
+ """ + proxy_class = GroupwiseWeightFloatQuantProxyFromInjector + scaling_per_output_type = ScalingPerOutputType.GROUP + + +def build_options( + weight_quant, + bit_width, + scale_stats_op, + is_po2_scale, + scale_computation_type, + scale_rounding_func_type: Optional[str], + group_size: int = 32, + group_dim: Optional[int] = None, + scaling_min_val: float = 1e-8): + + options = dict() + scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte} + + options['group_size'] = group_size + options['bit_width'] = bit_width + options['scaling_min_val'] = scaling_min_val + + if scale_stats_op == 'mse': + weight_quant = type('MSEWeightQuant', (MSESymmetricScale, weight_quant), {}) + else: + options['scale_stats_op'] = scale_stats_op + + if group_dim is not None: + options['group_dim'] = group_dim + + if scale_computation_type == 'param_from_stats': + options['scaling_impl_type'] = 'parameter_from_stats' + elif scale_computation_type == 'stats': + options['scaling_impl_type'] = 'stats' + else: + raise RuntimeError("Not supported") + + if is_po2_scale: + scale_rounding_func = scale_rounding_func_dict[scale_rounding_func_type] + options['restrict_scaling_type'] = RestrictValueType.POWER_OF_TWO + options['restrict_value_float_to_int_impl'] = scale_rounding_func + else: + # If not po2, threshold does need any restriction and will match float restriction of the scale + options['restrict_scaling_type'] = RestrictValueType.FP + options['restrict_threshold_impl'] = None + assert scale_rounding_func_type is None, "Rounding for scale not needed when float" + return options, weight_quant + + +class GroupwiseIntWeightQuantizerBuilder: + + def __new__( + self, + bit_width, + scale_stats_op, + is_po2_scale, + scale_computation_type, + scale_rounding_func_type: Optional[str], + group_size: int = 32, + group_dim: Optional[int] = None, + scaling_min_val: float = 1e-8, + ): + + weight_quant = MXInt8Weight + options, weight_quant = build_options(weight_quant, bit_width, + scale_stats_op, + is_po2_scale, + scale_computation_type, + scale_rounding_func_type, + group_size, + group_dim, + scaling_min_val) + weight_quant = weight_quant.let(**options) + return weight_quant + + +class GroupwiseFloatWeightQuantizerBuilder(GroupwiseIntWeightQuantizerBuilder): + + def __new__( + self, + exponent_bit_width, + mantissa_bit_width, + bit_width, + scale_stats_op, + is_po2_scale, + scale_computation_type, + scale_rounding_func_type: Optional[str], + group_size: int = 32, + group_dim: Optional[int] = None, + scaling_min_val: float = 1e-8, + format: Optional[str] = None): + weight_quant = Fp8e4m3WeightSymmetricGroupQuant + + if format == 'ocp': + weight_quant = type('OCPWeightQuant', (FpOCPMixin, weight_quant), {}) + if format == 'fnuz': + weight_quant = type('OCPWeightQuant', (FpFNUZMixin, weight_quant), {}) + + options, weight_quant = build_options(weight_quant, bit_width, + scale_stats_op, + is_po2_scale, + scale_computation_type, + scale_rounding_func_type, + group_size, + group_dim, + scaling_min_val) + options['exponent_bit_width'] = exponent_bit_width + options['mantissa_bit_width'] = mantissa_bit_width + + weight_quant = weight_quant.let(**options) + return weight_quant diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 9460fadf1..98e467708 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -20,6 +20,7 @@ from brevitas.quant.experimental.float_quant_ocp 
 from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
 from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
+from brevitas.quant.experimental.mx_quant_ocp import Fp8e4m3WeightSymmetricGroupQuant
 from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
 from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
 from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE
@@ -55,7 +56,6 @@
 from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d
 from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear
 from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat
-from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant
 from brevitas_examples.common.generative.quantizers import Int8DynamicActPerGroupFloat
 from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat
 from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py
index c3c99a96f..4f7040d08 100644
--- a/src/brevitas_examples/common/generative/quantizers.py
+++ b/src/brevitas_examples/common/generative/quantizers.py
@@ -49,15 +49,6 @@ class IntWeightSymmetricGroupQuant(Int8WeightPerChannelFloat):
     scaling_per_output_type = ScalingPerOutputType.GROUP
 
 
-class Fp8e4m3WeightSymmetricGroupQuant(Fp8e4m3WeightPerChannelFloat):
-    """
-    Block / group / vector signed symmetric e4m3 weight quantizer with float scales.
-    We inherit from a per-channel quantizer to re-use some underlying machinery.
-    """
-    proxy_class = GroupwiseWeightFloatQuantProxyFromInjector
-    scaling_per_output_type = ScalingPerOutputType.GROUP
-
-
 class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
     """
     Symmetric quantizer with per tensor dynamic scale.
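A minimal usage sketch of the builders introduced in the patch above (illustration only, not part of the patch series; the layer shape and argument values here are assumptions):

    from brevitas.nn import QuantLinear
    from brevitas.quant.experimental.mx_quant_ocp import GroupwiseIntWeightQuantizerBuilder

    # Int8 group-wise weight quantizer: max statistics, power-of-two scale rounded
    # with floor, scale kept as a learnable parameter initialized from statistics.
    weight_quant = GroupwiseIntWeightQuantizerBuilder(
        bit_width=8,
        scale_stats_op='max',
        is_po2_scale=True,
        scale_computation_type='param_from_stats',
        scale_rounding_func_type='floor',
        group_size=32)

    # Attach the generated quantizer injector to a quantized layer.
    layer = QuantLinear(64, 64, bias=False, weight_quant=weight_quant)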
From dca94faa61afdcc5b58622ff69258293116c5f6 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 28 Oct 2024 16:59:49 +0000
Subject: [PATCH 14/14] Integration with llm entrypoint

---
 .../quant/experimental/mx_quant_ocp.py        |  1 +
 .../common/generative/quantize.py             | 33 +++++++++++++++++--
 src/brevitas_examples/llm/main.py             | 10 +++++-
 3 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/src/brevitas/quant/experimental/mx_quant_ocp.py b/src/brevitas/quant/experimental/mx_quant_ocp.py
index d682da2dc..b2d719bc6 100644
--- a/src/brevitas/quant/experimental/mx_quant_ocp.py
+++ b/src/brevitas/quant/experimental/mx_quant_ocp.py
@@ -206,6 +206,7 @@ def build_options(
         raise RuntimeError("Not supported")
 
     if is_po2_scale:
+        assert scale_rounding_func_type is not None
         scale_rounding_func = scale_rounding_func_dict[scale_rounding_func_type]
         options['restrict_scaling_type'] = RestrictValueType.POWER_OF_TWO
         options['restrict_value_float_to_int_impl'] = scale_rounding_func
diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index 98e467708..457877459 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -21,6 +21,8 @@
 from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
 from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
 from brevitas.quant.experimental.mx_quant_ocp import Fp8e4m3WeightSymmetricGroupQuant
+from brevitas.quant.experimental.mx_quant_ocp import GroupwiseFloatWeightQuantizerBuilder
+from brevitas.quant.experimental.mx_quant_ocp import GroupwiseIntWeightQuantizerBuilder
 from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
 from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
 from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE
@@ -222,7 +224,8 @@ def generate_quantizers(
         quantize_input_zero_point=False,
         device=None,
         weight_kwargs=None,
-        input_kwargs=None):
+        input_kwargs=None,
+        weight_scale_rounding_func_type=None):
     """
     Replace float layers with quant layers in the target model
     """
@@ -243,8 +246,32 @@ def generate_quantizers(
     else:
         input_float_format = {}
 
-    weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][
-        weight_param_method][weight_quant_granularity][weight_quant_type]
+    if weight_quant_granularity == 'per_group':
+        if weight_quant_format == 'int':
+            weight_quant = GroupwiseIntWeightQuantizerBuilder(
+                bit_width=weight_bit_width,
+                scale_stats_op='max' if weight_param_method != 'mse' else weight_param_method,
+                is_po2_scale=weight_scale_precision == 'po2_scale',
+                scale_computation_type='param_from_stats',
+                scale_rounding_func_type=weight_scale_rounding_func_type,
+                group_dim=weight_group_dim,
+                group_size=weight_group_size,
+                scaling_min_val=1e-4 if dtype == torch.float16 else 1e-8)
+        else:
+            weight_quant = GroupwiseFloatWeightQuantizerBuilder(
+                exponent_bit_width=weight_float_format['exponent_bit_width'],
+                mantissa_bit_width=weight_float_format['mantissa_bit_width'],
+                bit_width=weight_bit_width,
+                scale_stats_op='max' if weight_param_method != 'mse' else weight_param_method,
+                is_po2_scale=weight_scale_precision == 'po2_scale',
+                scale_computation_type='param_from_stats',
+                scale_rounding_func_type=weight_scale_rounding_func_type,
+                group_dim=weight_group_dim,
+                group_size=weight_group_size,
+                scaling_min_val=1e-4 if dtype == torch.float16 else 1e-8)
+    else:
+        weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][
+            weight_param_method][weight_quant_granularity][weight_quant_type]
 
     if input_bit_width is not None and input_scale_type == 'no_scale':
         input_quant = sym_input_quant = linear_input_quant = INPUT_QUANT_MAP[input_quant_format][
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 4a87f5a1a..5ef39fffa 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -253,7 +253,9 @@ def main(args):
         input_quant_granularity=args.input_quant_granularity,
         input_group_size=args.input_group_size,
         quantize_input_zero_point=args.quantize_input_zero_point,
-        device=device)
+        device=device,
+        weight_scale_rounding_func_type=args.weight_scale_rounding_func_type
+    )
     layer_map = generate_quant_maps(
         linear_input_quant=linear_input_quant,
         weight_quant=weight_quant,
@@ -400,6 +402,12 @@
         default='per_group',
         choices=['per_channel', 'per_tensor', 'per_group'],
         help='Granularity for scales/zero-point of weights. Default: per_group.')
+    parser.add_argument(
+        '--weight-scale-rounding-func-type',
+        type=str,
+        default=None,
+        choices=['round', 'ceil', 'floor'],
+        help='Rounding function to use with po2 (power-of-two) scales. Default: None.')
     parser.add_argument(
         '--weight-group-dim',
         type=int,