From 5e2d00a3d19101c78068a8cc53e1ada17bbcdc11 Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo Date: Tue, 4 Jul 2023 18:28:00 +0100 Subject: [PATCH] Add per tensor/row/group dynamic scale support, some dtype improvements --- .../llm/llm_quant/equalize.py | 7 +- .../llm/llm_quant/ln_affine_merge.py | 4 +- .../llm/llm_quant/quant_blocks.py | 101 ++++++------ .../llm/llm_quant/quantize.py | 148 +++++++++++++----- .../llm/llm_quant/quantizers.py | 50 +++++- .../llm/llm_quant/run_utils.py | 18 ++- src/brevitas_examples/llm/main.py | 38 +++-- .../llm/test_linear_mlir_export.py | 2 +- 8 files changed, 251 insertions(+), 117 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/equalize.py b/src/brevitas_examples/llm/llm_quant/equalize.py index cd0746dc6..f3e4c3b0d 100644 --- a/src/brevitas_examples/llm/llm_quant/equalize.py +++ b/src/brevitas_examples/llm/llm_quant/equalize.py @@ -29,6 +29,7 @@ def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha): @torch.no_grad() def apply_act_equalization( model, + dtype, act_equalization_type, dataloader, nsamples, @@ -47,7 +48,7 @@ def apply_act_equalization( assert ref_kwargs is not None, "Ref kwargs required to perform tracing and lift the model into FX." # We can't do fp16 tracing on CPU as many kernels are not implemented # So we have to cast to fp32 first, trace, apply equalization, and then cast back - with cast_to_float32(model): + with cast_to_float32(model, dtype): graph_model = value_trace(model, value_args=ref_kwargs) # TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode # or an FX interpreter to run it on GPU @@ -65,9 +66,9 @@ def apply_act_equalization( @torch.no_grad() -def apply_weight_equalization(model, ref_kwargs, scale_computation_type='range'): +def apply_weight_equalization(model, dtype, ref_kwargs, scale_computation_type='range'): # We can't do fp16 tracing on CPU as many kernels are not implemented # So we have to cast to fp32 first, trace, apply equalization, and then cast back - with cast_to_float32(model): + with cast_to_float32(model, dtype): graph_model = value_trace(model, value_args=ref_kwargs) EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model) diff --git a/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py index 191800af1..37aa8d5d3 100644 --- a/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py +++ b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py @@ -84,9 +84,9 @@ def merge_layernorm_affine_params(graph_model): @torch.no_grad() -def apply_layernorm_affine_merge(model, ref_kwargs): +def apply_layernorm_affine_merge(model, dtype, ref_kwargs): # We can't do fp16 tracing on CPU as many kernels are not implemented # So we have to cast to fp32 first, trace, apply merging, and then cast back - with cast_to_float32(model): + with cast_to_float32(model, dtype): graph_model = value_trace(model, value_args=ref_kwargs) merge_layernorm_affine_params(graph_model) diff --git a/src/brevitas_examples/llm/llm_quant/quant_blocks.py b/src/brevitas_examples/llm/llm_quant/quant_blocks.py index 2f673be75..a4334157b 100644 --- a/src/brevitas_examples/llm/llm_quant/quant_blocks.py +++ b/src/brevitas_examples/llm/llm_quant/quant_blocks.py @@ -12,7 +12,6 @@ import brevitas from brevitas.core.function_wrapper.shape import PermuteDims from brevitas.core.utils import SliceTensor -from brevitas.core.utils import StatelessBuffer class OverSubChannelBlockView(brevitas.jit.ScriptModule): @@ 
-33,58 +32,6 @@ def forward(self, x: torch.Tensor): return y -class AbsMaxKeepDim(brevitas.jit.ScriptModule): - __constants__ = ['stats_reduce_dim'] - - def __init__(self, stats_reduce_dim) -> None: - super(AbsMaxKeepDim, self).__init__() - self.stats_reduce_dim = stats_reduce_dim - - @brevitas.jit.script_method - def forward(self, x: Tensor): - if self.stats_reduce_dim is not None: - y = torch.max(torch.abs(x), dim=self.stats_reduce_dim, keepdim=True)[0] - else: - y = torch.max(torch.abs(x)) - return y - - -class AbsMinMaxKeepDim(brevitas.jit.ScriptModule): - __constants__ = ['stats_reduce_dim'] - - def __init__(self, stats_reduce_dim: Optional[int] = None) -> None: - super(AbsMinMaxKeepDim, self).__init__() - self.stats_reduce_dim = stats_reduce_dim - - @brevitas.jit.script_method - def forward(self, x: Tensor): - if self.stats_reduce_dim is None: - return torch.abs(torch.max(x) - torch.min(x)) - else: - max_val = torch.max(x, dim=self.stats_reduce_dim, keepdim=True)[0] - min_val = torch.min(x, dim=self.stats_reduce_dim, keepdim=True)[0] - return torch.abs(max_val - min_val) - - -class NegativeMinOrZeroKeepDim(brevitas.jit.ScriptModule): - __constants__ = ['stats_reduce_dim'] - - def __init__(self, stats_reduce_dim: Optional[int] = None) -> None: - super(NegativeMinOrZeroKeepDim, self).__init__() - self.stats_reduce_dim = stats_reduce_dim - self.zero = StatelessBuffer(torch.tensor(0.0)) - - @brevitas.jit.script_method - def forward(self, x: Tensor) -> Tensor: - if self.stats_reduce_dim is None: - min_val = torch.min(x, keepdim=True) - else: - min_val = torch.min(x, dim=self.stats_reduce_dim, keepdim=True)[0] - min_val = torch.where( - min_val <= self.zero().to(min_val.dtype), min_val, self.zero().to(min_val.dtype)) - return min_val - - class ExpandReshapeScalingWrapper(brevitas.jit.ScriptModule): __constants__ = ['expanded_scaling_shape', 'reshaped_scaling_shape'] @@ -138,3 +85,51 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor): zero_point = self.wrapped_zero_point_impl.scale_shift_zero_point( -zero_point_stats, scale, bit_width) return zero_point + + +class RuntimeDynamicStatsScaling(brevitas.jit.ScriptModule): + __constants__ = ['dynamic_scaling_broadcastable_shape'] + + def __init__( + self, + scaling_stats_impl: nn.Module, + dynamic_scaling_broadcastable_shape: Tuple[int, ...], + scaling_stats_input_view_shape_impl: nn.Module) -> None: + super(RuntimeDynamicStatsScaling, self).__init__() + self.scaling_stats_input_view_shape_impl = scaling_stats_input_view_shape_impl + self.stats_impl = scaling_stats_impl + self.dynamic_scaling_broadcastable_shape = dynamic_scaling_broadcastable_shape + + @brevitas.jit.script_method + def forward(self, x) -> Tensor: + x = self.scaling_stats_input_view_shape_impl(x) + x = self.stats_impl(x) + x = x.view(self.dynamic_scaling_broadcastable_shape) + return x + + +class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule): + + def __init__(self, group_size: int, group_dim: int, scaling_stats_impl: nn.Module) -> None: + super(RuntimeDynamicGroupStatsScaling, self).__init__() + self.group_size = group_size + self.group_dim = group_dim + self.scaling_stats_impl = scaling_stats_impl + + @brevitas.jit.script_method + def group_scaling_reshape(self, stats_input): + tensor_shape = stats_input.shape + tensor_shape_list = list(tensor_shape) + tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) + tensor_shape_list.insert(self.group_dim + 1, self.group_size) + stats_input = 
stats_input.view(tensor_shape_list) + return stats_input + + @brevitas.jit.script_method + def forward(self, stats_input) -> Tensor: + stats_input_reshaped = self.group_scaling_reshape(stats_input) + out = self.scaling_stats_impl(stats_input_reshaped) + out = torch.clamp_min(out, min=torch.tensor(1e-6, device=out.device, dtype=out.dtype)) + out = out.expand(stats_input_reshaped.shape) + out = out.reshape(stats_input.shape) + return out diff --git a/src/brevitas_examples/llm/llm_quant/quantize.py b/src/brevitas_examples/llm/llm_quant/quantize.py index 86768e272..f8eb0649f 100644 --- a/src/brevitas_examples/llm/llm_quant/quantize.py +++ b/src/brevitas_examples/llm/llm_quant/quantize.py @@ -26,6 +26,9 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE +from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerGroupFloat +from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerRowFloat +from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerTensorFloat from brevitas_examples.llm.llm_quant.quantizers import Int8ActPerRowFloat from brevitas_examples.llm.llm_quant.quantizers import Int8ActPerRowFloatMSE from brevitas_examples.llm.llm_quant.quantizers import IntWeightSymmetricGroupQuant @@ -62,24 +65,34 @@ 'sym': Int8WeightPerChannelFixedPointMSE},},}} INPUT_QUANT_MAP = { - 'float': { - 'stats': { - 'per_tensor': { - 'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}, - 'per_row': { - 'sym': Int8ActPerRowFloat, 'asym': ShiftedUint8ActPerRowFloat},}, - 'mse': { - 'per_tensor': { - 'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE}, - 'per_row': { - 'sym': Int8ActPerRowFloatMSE, 'asym': ShiftedUint8ActPerRowFloatMSE},},}, - 'po2': { - 'stats': { - 'per_tensor': { - 'sym': Int8ActPerTensorFixedPoint},}, - 'mse': { - 'per_tensor': { - 'sym': Int8ActPerTensorFixedPointMSE},},}} + 'static': { + 'float': { + 'stats': { + 'per_tensor': { + 'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}, + 'per_row': { + 'sym': Int8ActPerRowFloat, 'asym': ShiftedUint8ActPerRowFloat},}, + 'mse': { + 'per_tensor': { + 'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE}, + 'per_row': { + 'sym': Int8ActPerRowFloatMSE, 'asym': ShiftedUint8ActPerRowFloatMSE},},}, + 'po2': { + 'stats': { + 'per_tensor': { + 'sym': Int8ActPerTensorFixedPoint},}, + 'mse': { + 'per_tensor': { + 'sym': Int8ActPerTensorFixedPointMSE},},}}, + 'dynamic': { + 'float': { + 'stats': { + 'per_tensor': { + 'sym': Int8ActDynamicPerTensorFloat}, + 'per_row': { + 'sym': Int8ActDynamicPerRowFloat}, + 'per_group': { + 'sym': Int8ActDynamicPerGroupFloat},}}}} def quantize_model( @@ -87,37 +100,43 @@ def quantize_model( dtype, weight_bit_width, weight_param_method, - weight_scale_type, + weight_scale_precision, weight_quant_type, weight_quant_granularity, weight_group_size, quantize_weight_zero_point, input_bit_width=None, + input_scale_precision=None, input_scale_type=None, input_param_method=None, input_quant_type=None, input_quant_granularity=None, + input_group_size=None, quantize_input_zero_point=False, seqlen=None): """ Replace float layers with quant layers in the target model """ # Retrive base input and weight quantizers - weight_quant = WEIGHT_QUANT_MAP[weight_scale_type][weight_param_method][ + weight_quant = 
WEIGHT_QUANT_MAP[weight_scale_precision][weight_param_method][ weight_quant_granularity][weight_quant_type] if input_bit_width is not None: - input_quant = INPUT_QUANT_MAP[input_scale_type][input_param_method][ + input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][input_param_method][ input_quant_granularity][input_quant_type] # Some activations in MHA should always be symmetric - sym_input_quant = INPUT_QUANT_MAP[input_scale_type][input_param_method][ - input_quant_granularity]['sym'] - # Linear layers with 2d input should always be per tensor - per_tensor_input_quant = INPUT_QUANT_MAP[input_scale_type][input_param_method][ - 'per_tensor'][input_quant_type] + sym_input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][ + input_param_method][input_quant_granularity]['sym'] + # Linear layers with 2d input should always be per tensor or per group, as there is no row dimension + if input_quant_granularity == 'per_tensor' or input_quant_granularity == 'per_row': + linear_2d_input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][ + input_param_method]['per_tensor'][input_quant_type] + else: + assert input_quant_granularity == 'per_group' + linear_2d_input_quant = input_quant else: input_quant = None sym_input_quant = None - per_tensor_input_quant = None + linear_2d_input_quant = None # Modify the weight quantizer based on the arguments passed in weight_quant = weight_quant.let( @@ -129,7 +148,7 @@ def quantize_model( # weight scale is converted to a standalone parameter # This is done already by default in the per_group quantizer if weight_quant_granularity != 'per_group': - weight_quant = weight_quant.let(weight_scale_impl_type='parameter_from_stats') + weight_quant = weight_quant.let(scaling_impl_type='parameter_from_stats') # weight zero-point is converted to a standalone parameter # This is done already by default in the per_group quantizer if weight_quant_type == 'asym' and weight_quant_granularity != 'per_group': @@ -142,20 +161,34 @@ def quantize_model( 'bit_width': input_bit_width, 'quantize_zero_point': quantize_input_zero_point, 'dtype': dtype}) - if input_quant_granularity == 'per_row': + if input_scale_type == 'static' and input_quant_granularity == 'per_row': # QuantMHA internally always uses Seq, B, E input_quant = input_quant.let( **{ - 'channel_dim': 0, 'per_channel_broadcastable_shape': (seqlen, 1, 1), 'scaling_stats_permute_dims': (0, 1, 2)}) + elif input_scale_type == 'dynamic': + if input_quant_granularity == 'per_tensor': + input_quant = input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (1, -1, 1), + 'permute_dims': (1, 0, 2), + 'stats_reduce_dim': 1}) + elif input_quant_granularity == 'per_row': + input_quant = input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (seqlen, -1, 1), + 'permute_dims': (1, 0, 2), + 'stats_reduce_dim': 2}) + elif input_quant_granularity == 'per_group': + input_quant = input_quant.let(**{'group_dim': 2, 'group_size': input_group_size}) if sym_input_quant is not None: sym_input_quant = sym_input_quant.let( **{ 'bit_width': input_bit_width, 'quantize_zero_point': quantize_input_zero_point, 'dtype': dtype}) - if input_quant_granularity == 'per_row': + if input_scale_type == 'static' and input_quant_granularity == 'per_row': q_scaled_quant = sym_input_quant.let( **{ 'per_channel_broadcastable_shape': (1, seqlen, 1), @@ -166,19 +199,64 @@ def quantize_model( 'scaling_stats_permute_dims': (2, 0, 1)}) v_quant = q_scaled_quant attn_output_weights_quant = q_scaled_quant + elif 
input_scale_type == 'dynamic': + if input_quant_granularity == 'per_tensor': + q_scaled_quant = sym_input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (-1, 1, 1), + 'permute_dims': None, + 'stats_reduce_dim': 1}) + k_transposed_quant = sym_input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (-1, 1, 1), + 'permute_dims': None, + 'stats_reduce_dim': 1}) + elif input_quant_granularity == 'per_row': + q_scaled_quant = sym_input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (-1, seqlen, 1), + 'permute_dims': None, + 'stats_reduce_dim': 2}) + k_transposed_quant = sym_input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (-1, 1, seqlen), + 'permute_dims': None, + 'stats_reduce_dim': 1}) + elif input_quant_granularity == 'per_group': + q_scaled_quant = sym_input_quant.let( + **{ + 'group_dim': 2, 'group_size': input_group_size}) + k_transposed_quant = sym_input_quant.let( + **{ + 'group_dim': 1, 'group_size': input_group_size}) + v_quant = q_scaled_quant + attn_output_weights_quant = q_scaled_quant else: q_scaled_quant = v_quant = k_transposed_quant = attn_output_weights_quant = sym_input_quant else: q_scaled_quant = v_quant = k_transposed_quant = attn_output_weights_quant = None - if per_tensor_input_quant is not None: - per_tensor_input_quant = per_tensor_input_quant.let( + if linear_2d_input_quant is not None: + linear_2d_input_quant = linear_2d_input_quant.let( **{ 'bit_width': input_bit_width, 'quantize_zero_point': quantize_input_zero_point, 'dtype': dtype}) + if input_scale_type == 'dynamic': + # Note: this breaks if applied to 3d Linear inputs, + # in case standard MHA layers haven't been inserted + if input_quant_granularity == 'per_tensor' or input_quant_granularity == 'per_row': + linear_2d_input_quant = linear_2d_input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (-1, 1), + 'permute_dims': None, + 'stats_reduce_dim': 1}) + elif input_quant_granularity == 'per_group': + linear_2d_input_quant = linear_2d_input_quant.let( + **{ + 'group_dim': 1, 'group_size': input_group_size}) quant_linear_kwargs = { - 'input_quant': per_tensor_input_quant, 'weight_quant': weight_quant, 'dtype': dtype} + 'input_quant': linear_2d_input_quant, 'weight_quant': weight_quant, 'dtype': dtype} quant_mha_kwargs = { 'in_proj_input_quant': input_quant, @@ -190,7 +268,7 @@ def quantize_model( 'q_scaled_quant': q_scaled_quant, 'k_transposed_quant': k_transposed_quant, 'v_quant': v_quant, - 'out_proj_input_quant': per_tensor_input_quant, + 'out_proj_input_quant': linear_2d_input_quant, 'out_proj_weight_quant': weight_quant, 'out_proj_bias_quant': None, 'out_proj_output_quant': None, diff --git a/src/brevitas_examples/llm/llm_quant/quantizers.py b/src/brevitas_examples/llm/llm_quant/quantizers.py index 107afad5a..9848d994c 100644 --- a/src/brevitas_examples/llm/llm_quant/quantizers.py +++ b/src/brevitas_examples/llm/llm_quant/quantizers.py @@ -5,7 +5,14 @@ from torch import nn -from brevitas.core.scaling import StatsFromParameterScaling +from brevitas.core.function_wrapper.shape import OverBatchOverOutputChannelView +from brevitas.core.function_wrapper.shape import OverBatchOverTensorView +from brevitas.core.function_wrapper.shape import OverTensorView +from brevitas.core.scaling import ParameterFromStatsFromParameterScaling +from brevitas.core.stats import AbsMinMax +from brevitas.core.stats import NegativeMinOrZero +from brevitas.core.stats import NegativePercentileOrZero +from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint from 
brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
 from brevitas.inject import this
 from brevitas.inject import value
@@ -49,10 +56,9 @@ def reshaped_scaling_shape(module):
     scaling_input_shape = this.expanded_scaling_shape
     scaling_stats_input_view_shape_impl = OverSubChannelBlockView
     scaling_impl = ExpandReshapeScalingWrapper
-    wrapped_scaling_impl = StatsFromParameterScaling
-    scaling_stats_impl = AbsMaxKeepDim
     # scale is converted to a parameter right away
-    scaling_impl_type = 'parameter_from_stats'
+    wrapped_scaling_impl = ParameterFromStatsFromParameterScaling
+    keepdim = True
     stats_reduce_dim = 2
     # Set bit_width and block size externally
     bit_width = None
@@ -70,8 +76,9 @@ class ShiftedUintWeightAsymmetricGroupQuant(IntWeightSymmetricGroupQuant):
     zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
     zero_point_stats_input_concat_dim = 0
     zero_point_impl = ExpandReshapeZeroPointWrapper
-    zero_point_stats_impl = NegativeMinOrZeroKeepDim
-    scaling_stats_impl = AbsMinMaxKeepDim
+    zero_point_stats_impl = NegativeMinOrZero
+    scaling_stats_impl = AbsMinMax
+    keepdim = True
     # zero-point is converted to a parameter right away
     wrapped_zero_point_impl = ParameterFromStatsFromParameterZeroPoint
     quantize_zero_point = False
@@ -92,3 +99,34 @@ class ShiftedUint8ActPerRowFloat(ShiftedUint8ActPerTensorFloat):
 
 class ShiftedUint8ActPerRowFloatMSE(ShiftedUint8ActPerTensorFloatMSE):
     scaling_per_output_channel = True
+
+
+class Int8ActDynamicPerTensorFloat(Int8ActPerTensorFloat):
+    """
+    Symmetric quantizer with per tensor dynamic scale.
+    """
+    scaling_impl = RuntimeDynamicStatsScaling
+    scaling_stats_input_view_shape_impl = OverBatchOverTensorView
+    scaling_stats_op = 'max'
+
+
+class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat):
+    """
+    Symmetric quantizer with per row dynamic scale.
+    """
+    scaling_impl = RuntimeDynamicStatsScaling
+    scaling_stats_input_view_shape_impl = OverBatchOverOutputChannelView
+    scaling_stats_op = 'max'
+
+
+class Int8ActDynamicPerGroupFloat(Int8ActPerRowFloat):
+    """
+    Symmetric quantizer with per group dynamic scale.
+    """
+    scaling_impl = RuntimeDynamicGroupStatsScaling
+    keepdim = True
+    scaling_stats_op = 'max'
+
+    @value
+    def stats_reduce_dim(group_dim):
+        return group_dim + 1
diff --git a/src/brevitas_examples/llm/llm_quant/run_utils.py b/src/brevitas_examples/llm/llm_quant/run_utils.py
index c81a62e4f..0ba096c4a 100644
--- a/src/brevitas_examples/llm/llm_quant/run_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/run_utils.py
@@ -145,18 +145,20 @@ def apply_layer_ptq_fn(
 
 
 @contextmanager
-def cast_to_float32(model):
+def cast_to_float32(model, target_dtype):
     dtype_dict = {}
-    for name, p in model.named_parameters():
+    for name, p in model.state_dict().items():
+        # This allows us to pick up duplicated parameters
         dtype_dict[name] = p.dtype
-    for name, b in model.named_buffers():
-        dtype_dict[name] = b.dtype
     if any(dtype != torch.float32 for dtype in dtype_dict.values()):
         model.to(dtype=torch.float32)
     try:
         yield model
     finally:
-        for name, p in model.named_parameters():
-            p.data = p.data.to(dtype_dict[name])
-        for name, b in model.named_buffers():
-            b.data = b.data.to(dtype_dict[name])
+        for name, p in {**dict(model.named_parameters()), **dict(model.named_buffers())}.items():
+            if name in dtype_dict:
+                p.data = p.data.to(dtype_dict[name])
+            else:
+                # target_dtype covers any new tensors that might have been
+                # introduced in the process (e.g. during equalization)
+                p.data = p.data.to(target_dtype)
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 459705185..a802c743b 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -43,7 +43,7 @@
     choices=['stats', 'mse'],
     help='How scales/zero-point are determined. Default: stats.')
 parser.add_argument(
-    '--weight-scale-type',
+    '--weight-scale-precision',
     type=str,
     default='float',
     choices=['float', 'po2'],
@@ -77,13 +77,21 @@
     type=str,
     default='stats',
     choices=['stats', 'mse'],
-    help='How scales/zero-point are determined. Default: stats.')
+    help=
+    'How scales/zero-point are determined. Default: stats (percentile for static, absmax or minmax for dynamic).'
+)
 parser.add_argument(
-    '--input-scale-type',
+    '--input-scale-precision',
     type=str,
     default='float',
     choices=['float', 'po2'],
     help='Whether input scale is a float value or a po2. Default: float.')
+parser.add_argument(
+    '--input-scale-type',
+    type=str,
+    default='static',
+    choices=['static', 'dynamic'],
+    help='Whether input scale is a static value or a dynamic value. Default: static.')
 parser.add_argument(
     '--input-quant-type',
     type=str,
@@ -94,8 +102,13 @@
     '--input-quant-granularity',
     type=str,
     default='per_tensor',
-    choices=['per_tensor', 'per_row'],
+    choices=['per_tensor', 'per_row', 'per_group'],
     help='Granularity for scales/zero-point of inputs. Default: per_tensor.')
+parser.add_argument(
+    '--input-group-size',
+    type=int,
+    default=64,
+    help='Group size for per_group input quantization. Default: 64.')
 parser.add_argument(
     '--quantize-input-zero-point', action='store_true', help='Quantize input zero-point.')
 parser.add_argument('--gptq', action='store_true', help='Apply GPTQ.')
@@ -152,6 +165,8 @@ def model_export(model, ref_input, args):
 
 def validate(args):
     if not args.no_quantize:
+        if args.export_target is not None and args.input_bit_width is not None:
+            assert args.input_scale_type == 'static', "Only static scale supported for export currently."
         if args.export_target == 'sharded_torchmlir_group_weight':
             assert args.weight_quant_granularity == 'per_group', "Sharded torch group export requires per group weight quant."
             assert args.input_bit_width is None, "Sharded torch group weight export doesn't support input quant."
@@ -172,8 +187,10 @@ def validate(args):
         assert args.quantize_weight_zero_point, "Quantized weight zero point required."
     if args.input_bit_width is not None and args.input_quant_type == 'asym':
         assert args.quantize_input_zero_point, "Quantized input zero point required."
-    if args.input_bit_width:
-        assert args.act_calibration, "Input quantization is being applied without activation calibration. Set --act-calibration."
+    if (args.input_bit_width and
+        (args.input_scale_type == 'static' or
+            (args.input_scale_type == 'dynamic' and args.input_quant_type == 'asym'))):
+        assert args.act_calibration, "Static input quantization is being applied without activation calibration. Set --act-calibration."
def main(): @@ -203,7 +220,7 @@ def main(): # since currently there is support only for merging into Linear if args.ln_affine_merge: print("Apply LN affine merge...") - apply_layernorm_affine_merge(model, ref_kwargs={'input_ids': calibration_loader[0]}) + apply_layernorm_affine_merge(model, dtype, ref_kwargs={'input_ids': calibration_loader[0]}) print("LN affine merge applied.") # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing @@ -215,13 +232,14 @@ def main(): if args.weight_equalization: print("Apply weight equalization...") - apply_weight_equalization(model, ref_kwargs={'input_ids': calibration_loader[0]}) + apply_weight_equalization(model, dtype, ref_kwargs={'input_ids': calibration_loader[0]}) print("Weight equalization applied.") if args.act_equalization is not None: print("Apply act equalization (SmoothQuant)...") apply_act_equalization( model, + dtype, args.act_equalization, calibration_loader, args.nsamples, @@ -236,15 +254,17 @@ def main(): weight_quant_type=args.weight_quant_type, weight_bit_width=args.weight_bit_width, weight_param_method=args.weight_param_method, - weight_scale_type=args.weight_scale_type, + weight_scale_precision=args.weight_scale_precision, weight_quant_granularity=args.weight_quant_granularity, weight_group_size=args.weight_group_size, quantize_weight_zero_point=args.quantize_weight_zero_point, input_bit_width=args.input_bit_width, input_quant_type=args.input_quant_type, input_param_method=args.input_param_method, + input_scale_precision=args.input_scale_precision, input_scale_type=args.input_scale_type, input_quant_granularity=args.input_quant_granularity, + input_group_size=args.input_group_size, quantize_input_zero_point=args.quantize_input_zero_point, seqlen=args.seqlen) print("Model quantization applied.") diff --git a/src/brevitas_examples/llm/test_linear_mlir_export.py b/src/brevitas_examples/llm/test_linear_mlir_export.py index 870ec9406..417721a58 100644 --- a/src/brevitas_examples/llm/test_linear_mlir_export.py +++ b/src/brevitas_examples/llm/test_linear_mlir_export.py @@ -59,7 +59,7 @@ def quantize_and_export(args): weight_bit_width=args.weight_bit_width, weight_group_size=args.weight_group_size, weight_param_method='stats', - weight_scale_type='float', + weight_scale_precision='float', weight_quant_granularity='per_group', quantize_weight_zero_point=False)
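The new --input-scale-type, --input-scale-precision, --input-quant-granularity and --input-group-size flags resolve a quantizer through the reworked INPUT_QUANT_MAP in quantize.py, which is now keyed by scale type first. A minimal sketch of that lookup order follows; the map here is an abridged toy version with only two entries, and resolve_input_quant is a helper written just for this note (the class names are the ones imported in quantize.py):

import pprint

# Lookup order used by quantize_model():
# scale_type -> scale_precision -> param_method -> granularity -> quant_type
TOY_INPUT_QUANT_MAP = {
    'static': {'float': {'stats': {'per_tensor': {'sym': 'Int8ActPerTensorFloat'}}}},
    'dynamic': {'float': {'stats': {'per_row': {'sym': 'Int8ActDynamicPerRowFloat'}}}}}


def resolve_input_quant(scale_type, scale_precision, param_method, granularity, quant_type):
    # A KeyError here means the flag combination is not in the map,
    # e.g. dynamic scales are only defined for float precision and sym quantization
    return TOY_INPUT_QUANT_MAP[scale_type][scale_precision][param_method][granularity][quant_type]


# Mirrors e.g. --input-scale-type dynamic --input-quant-granularity per_row
pprint.pprint(resolve_input_quant('dynamic', 'float', 'stats', 'per_row', 'sym'))

With --input-scale-type dynamic and symmetric input quantization, validate() no longer requires --act-calibration, since the scale is recomputed from each incoming tensor at runtime.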
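Conceptually, a symmetric per-row dynamic int8 quantizer just recomputes an absmax scale from every incoming tensor. The sketch below is plain PyTorch and only illustrates that arithmetic; the 127 divisor and the 1e-6 floor are illustrative choices for this note, not Brevitas internals:

import torch


def dynamic_per_row_int8(x):
    # Scale is recomputed from the current tensor at every call (dynamic),
    # so the symmetric case needs no activation calibration pass
    scale = torch.clamp_min(x.abs().amax(dim=-1, keepdim=True), 1e-6) / 127.
    x_int = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_int, scale


x = torch.randn(4, 64)
x_int, scale = dynamic_per_row_int8(x)
x_dq = x_int.float() * scale  # dequantized approximation of x
print((x - x_dq).abs().max())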
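The per-group variant adds the reshape performed by the new RuntimeDynamicGroupStatsScaling block: the grouped dimension is split into (num_groups, group_size), statistics are reduced over the group axis, and the result is expanded back over each group. A plain-PyTorch sketch, with absmax standing in for the injected scaling_stats_impl and the same 1e-6 clamp as in quant_blocks.py:

import torch


def group_absmax_scale(x, group_size, group_dim):
    # Split group_dim into (num_groups, group_size), as group_scaling_reshape() does
    shape = list(x.shape)
    assert shape[group_dim] % group_size == 0
    shape[group_dim] //= group_size
    shape.insert(group_dim + 1, group_size)
    grouped = x.view(shape)
    # Reduce over the group axis, keepdim so the scale broadcasts over its group
    scale = grouped.abs().amax(dim=group_dim + 1, keepdim=True)
    scale = torch.clamp_min(scale, 1e-6)  # avoid zero scales
    # Expand each group's scale over its elements and restore the input shape
    return scale.expand(grouped.shape).reshape(x.shape)


x = torch.randn(2, 8, 64)  # (batch, seq, hidden), grouped over the hidden dim
print(group_absmax_scale(x, group_size=16, group_dim=2).shape)  # torch.Size([2, 8, 64])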
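Because the cast_to_float32 changes are spread over several hunks, the patched helper is easier to read consolidated. The sketch below keeps the same logic, recording dtypes from state_dict() so tied or duplicated parameters are covered and letting tensors created while the model was fp32 fall back to target_dtype; the nn.Linear usage lines are illustrative only:

from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def cast_to_float32(model, target_dtype):
    # state_dict() also yields duplicated/tied parameters, unlike named_parameters()
    dtype_dict = {name: t.dtype for name, t in model.state_dict().items()}
    if any(dtype != torch.float32 for dtype in dtype_dict.values()):
        model.to(dtype=torch.float32)
    try:
        yield model
    finally:
        for name, t in {**dict(model.named_parameters()), **dict(model.named_buffers())}.items():
            # Tensors introduced while the model was fp32 (e.g. during equalization)
            # have no recorded dtype and fall back to target_dtype
            t.data = t.data.to(dtype_dict.get(name, target_dtype))


model = nn.Linear(4, 4).to(torch.float16)
with cast_to_float32(model, torch.float16):
    assert model.weight.dtype == torch.float32  # safe to trace on CPU here
assert model.weight.dtype == torch.float16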