From b889bb25162c0f387c89f0c202991d23a4cfd0f7 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 5 Sep 2024 16:13:17 +0200 Subject: [PATCH] Feat (mx): PTQ MX + Float support (#1010) --------- Co-authored-by: Nick Fraser --- .../common/generative/quant_blocks.py | 5 +- .../imagenet_classification/ptq/README.md | 26 +++--- .../benchmark/ptq_benchmark_torchvision.py | 4 + .../imagenet_classification/ptq/ptq_common.py | 91 ++++++++++++++----- .../ptq/ptq_evaluate.py | 16 ++-- 5 files changed, 94 insertions(+), 48 deletions(-) diff --git a/src/brevitas_examples/common/generative/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py index 18149578d..696340a2c 100644 --- a/src/brevitas_examples/common/generative/quant_blocks.py +++ b/src/brevitas_examples/common/generative/quant_blocks.py @@ -3,15 +3,12 @@ # SPDX-License-Identifier: BSD-3-Clause """ -from typing import Callable, List, Optional, Tuple +from typing import Callable import torch from torch import Tensor import torch.nn as nn -import brevitas -from brevitas.core.function_wrapper.shape import PermuteDims -from brevitas.core.utils import SliceTensor from brevitas.core.zero_point import _ScaleShiftZeroPoint from brevitas.function.ops_ste import abs_binary_sign_grad diff --git a/src/brevitas_examples/imagenet_classification/ptq/README.md b/src/brevitas_examples/imagenet_classification/ptq/README.md index 5387014e9..74653f96b 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/README.md +++ b/src/brevitas_examples/imagenet_classification/ptq/README.md @@ -80,7 +80,8 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir [--bias-bit-width {32,16,None}] [--act-quant-type {sym,asym}] [--weight-quant-type {sym,asym}] - [--weight-quant-granularity {per_tensor,per_channel}] + [--weight-quant-granularity {per_tensor,per_channel,per_group}] + [--act-quant-granularity {per_tensor,per_group}] [--weight-quant-calibration-type {stats,mse}] [--act-equalization {fx,layerwise,None}] [--act-quant-calibration-type {stats,mse}] @@ -90,11 +91,11 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir [--learned-round-lr LEARNED_ROUND_LR] [--act-quant-percentile ACT_QUANT_PERCENTILE] [--export-onnx-qcdq] [--export-torch-qcdq] - [--scaling-per-output-channel | --no-scaling-per-output-channel] [--bias-corr | --no-bias-corr] [--graph-eq-merge-bias | --no-graph-eq-merge-bias] [--weight-narrow-range | --no-weight-narrow-range] - [--gpfq-p GPFQ_P] [--quant-format {int,float}] + [--gpfq-p GPFQ_P] + [--quant-format {int,float,float_ocp}] [--layerwise-first-last-mantissa-bit-width LAYERWISE_FIRST_LAST_MANTISSA_BIT_WIDTH] [--layerwise-first-last-exponent-bit-width LAYERWISE_FIRST_LAST_EXPONENT_BIT_WIDTH] [--weight-mantissa-bit-width WEIGHT_MANTISSA_BIT_WIDTH] @@ -104,6 +105,7 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir [--accumulator-bit-width ACCUMULATOR_BIT_WIDTH] [--onnx-opset-version ONNX_OPSET_VERSION] [--channel-splitting-ratio CHANNEL_SPLITTING_RATIO] + [--compression-rate COMPRESSION_RATE] [--gptq | --no-gptq] [--gpfq | --no-gpfq] [--gpfa2q | --no-gpfa2q] [--gpxq-act-order | --no-gpxq-act-order] @@ -115,7 +117,7 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir PyTorch ImageNet PTQ Validation -options: +optional arguments: -h, --help show this help message and exit --calibration-dir CALIBRATION_DIR Path to folder containing Imagenet calibration folder @@ -176,7 +178,9 @@ options: Activation quantization 
type (default: sym) --weight-quant-type {sym,asym} Weight quantization type (default: sym) - --weight-quant-granularity {per_tensor,per_channel} + --weight-quant-granularity {per_tensor,per_channel,per_group} + Weight quantization type (default: per_tensor) + --act-quant-granularity {per_tensor,per_group} Activation quantization type (default: per_tensor) --weight-quant-calibration-type {stats,mse} Weight quantization calibration type (default: stats) @@ -201,12 +205,6 @@ options: (default: 99.999) --export-onnx-qcdq If true, export the model in onnx qcdq format --export-torch-qcdq If true, export the model in torch qcdq format - --scaling-per-output-channel - Enable Weight scaling per output channel (default: - enabled) - --no-scaling-per-output-channel - Disable Weight scaling per output channel (default: - enabled) --bias-corr Enable Bias correction after calibration (default: enabled) --no-bias-corr Disable Bias correction after calibration (default: @@ -224,7 +222,7 @@ options: Disable Narrow range for weight quantization (default: disabled) --gpfq-p GPFQ_P P parameter for GPFQ (default: 1.0) - --quant-format {int,float} + --quant-format {int,float,float_ocp} Quantization format to use for weights and activations (default: int) --layerwise-first-last-mantissa-bit-width LAYERWISE_FIRST_LAST_MANTISSA_BIT_WIDTH @@ -252,6 +250,9 @@ options: --channel-splitting-ratio CHANNEL_SPLITTING_RATIO Split Ratio for Channel Splitting. When set to 0.0, Channel Splitting will not be applied. (default: 0.0) + --compression-rate COMPRESSION_RATE + Specify compression rate < 1.0 for random projection. + Default is 0.0 and does not use RP. --gptq Enable GPTQ (default: disabled) --no-gptq Disable GPTQ (default: disabled) --gpfq Enable GPFQ (default: disabled) @@ -280,7 +281,6 @@ options: --no-uint_sym_act_for_unsigned_values Disable Use unsigned act quant when possible (default: enabled) - ``` The script requires to specify the calibration folder (`--calibration-dir`), from which the calibration samples will be taken (configurable with the `--calibration-samples` argument), and a validation folder (`--validation-dir`). 
diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py index 668eee22c..69a5f626a 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py @@ -89,7 +89,9 @@ def unique(sequence): 'act_bit_width': [8], # Act bit width 'bias_bit_width': [32], # Bias Bit-Width for Po2 scale 'weight_quant_granularity': ['per_channel'], # Scaling Per Output Channel + 'act_quant_granularity': ['per_tensor'], # Scaling Per Output Channel 'act_quant_type': ['sym'], # Act Quant Type + 'act_scale_computation_type': ['static'], # Act Quant Type 'act_param_method': ['stats'], # Act Param Method 'weight_param_method': ['mse'], # Weight Quant Type 'bias_corr': [True], # Bias Correction @@ -240,7 +242,9 @@ def ptq_torchvision_models(args): weight_param_method=config_namespace.weight_param_method, act_param_method=config_namespace.act_param_method, bias_bit_width=config_namespace.bias_bit_width, + act_scale_computation_type=config_namespace.act_scale_computation_type, weight_quant_granularity=config_namespace.weight_quant_granularity, + act_quant_granularity=config_namespace.act_quant_granularity, act_quant_percentile=config_namespace.act_quant_percentile, act_quant_type=config_namespace.act_quant_type, scale_factor_type=config_namespace.scale_factor_type, diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 3c6b82243..bac596be5 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -29,6 +29,20 @@ from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloatMSE +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloatMSE +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloatMSE +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloatMSE +from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act +from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight +from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE +from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act +from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight +from brevitas.quant.experimental.mx_quant_ocp import MXInt8WeightMSE +from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8Weight +from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8WeightMSE from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint @@ -96,12 +110,16 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'per_tensor': { 'sym': Int8WeightPerTensorFixedPoint}, 
'per_channel': { - 'sym': Int8WeightPerChannelFixedPoint},}, + 'sym': Int8WeightPerChannelFixedPoint}, + 'per_group': { + 'sym': MXInt8Weight, 'asym': ShiftedMXUInt8Weight}}, 'mse': { 'per_tensor': { 'sym': Int8WeightPerTensorFixedPointMSE}, 'per_channel': { - 'sym': Int8WeightPerChannelFixedPointMSE}},}}, + 'sym': Int8WeightPerChannelFixedPointMSE}, + 'per_group': { + 'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}},}}, 'float': { 'float_scale': { 'stats': { @@ -113,7 +131,26 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'per_tensor': { 'sym': Fp8e4m3WeightPerTensorFloatMSE}, 'per_channel': { - 'sym': Fp8e4m3WeightPerChannelFloatMSE}}}}} + 'sym': Fp8e4m3WeightPerChannelFloatMSE}}}}, + 'float_ocp': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3OCPWeightPerTensorFloat}, + 'per_channel': { + 'sym': Fp8e4m3OCPWeightPerChannelFloat}}, + 'mse': { + 'per_tensor': { + 'sym': Fp8e4m3OCPWeightPerTensorFloatMSE}, + 'per_channel': { + 'sym': Fp8e4m3OCPWeightPerChannelFloatMSE}}}, + 'po2_scale': { + 'stats': { + 'per_group': { + 'sym': MXFloat8e4m3Weight}}, + 'mse': { + 'per_group': { + 'sym': MXFloat8e4m3WeightMSE}}}}} INPUT_QUANT_MAP = { 'int': { @@ -139,7 +176,10 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'stats': { 'per_tensor': { 'sym': CNNInt8DynamicActPerTensorFloat, - 'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}}}, + 'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}, + 'po2_scale': { + 'stats': { + 'per_group': MXInt8Act}}}}, 'float': { 'static': { 'float_scale': { @@ -148,7 +188,21 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'sym': Fp8e4m3ActPerTensorFloat}}, 'mse': { 'per_tensor': { - 'sym': Fp8e4m3ActPerTensorFloatMSE}}}}}} + 'sym': Fp8e4m3ActPerTensorFloatMSE}}}}}, + 'float_ocp': { + 'static': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3OCPActPerTensorFloat}}, + 'mse': { + 'per_tensor': { + 'sym': Fp8e4m3OCPActPerTensorFloatMSE}}}}, + 'dynamic': { + 'po2_scale': { + 'stats': { + 'per_group': { + 'sym': MXFloat8e4m3Act}}}}}} def quantize_model( @@ -252,14 +306,14 @@ def layerwise_bit_width_fn_weight(module): weight_bit_width_dict['weight_bit_width'] = weight_bit_width act_bit_width_dict['act_bit_width'] = act_bit_width - if quant_format == 'float' and backend == 'layerwise': + if 'float' in quant_format and backend == 'layerwise': weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act weight_bit_width_dict['weight_mantissa_bit_width'] = layerwise_bit_width_fn_weight_mantissa weight_bit_width_dict['weight_exponent_bit_width'] = layerwise_bit_width_fn_weight_exponent act_bit_width_dict['act_mantissa_bit_width'] = layerwise_bit_width_fn_act_mantissa act_bit_width_dict['act_exponent_bit_width'] = layerwise_bit_width_fn_act_exponent - elif quant_format == 'float' and backend != 'layerwise': + elif 'float' in quant_format and backend != 'layerwise': weight_bit_width_dict['weight_bit_width'] = weight_bit_width act_bit_width_dict['act_bit_width'] = act_bit_width weight_bit_width_dict['weight_mantissa_bit_width'] = weight_mantissa_bit_width @@ -334,12 +388,12 @@ def kwargs_prefix(prefix, weight_kwargs): return {prefix + k: v for k, v in weight_kwargs.items()} weight_bit_width_dict = {'bit_width': weight_bit_width} - if weight_quant_format == 'float': + if 'float' in weight_quant_format: weight_bit_width_dict['exponent_bit_width'] = weight_exponent_bit_width 
weight_bit_width_dict['mantissa_bit_width'] = weight_mantissa_bit_width act_bit_width_dict = {'bit_width': act_bit_width} - if act_quant_format == 'float': + if 'float' in act_quant_format: act_bit_width_dict['exponent_bit_width'] = act_exponent_bit_width act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width @@ -355,16 +409,12 @@ def kwargs_prefix(prefix, weight_kwargs): # Some activations in MHA should always be symmetric sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][ act_scale_type][act_param_method][act_quant_granularity]['sym'] - # Linear layers with 2d input should always be per tensor - per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][ - act_scale_type][act_param_method]['per_tensor'][act_quant_type] + act_quant = act_quant.let(**act_bit_width_dict) sym_act_quant = sym_act_quant.let(**act_bit_width_dict) - per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict) else: act_quant = None sym_act_quant = None - per_tensor_act_quant = None # Modify the weight quantizer based on the arguments passed in weight_quant = weight_quant.let( @@ -383,13 +433,6 @@ def kwargs_prefix(prefix, weight_kwargs): sym_act_quant = sym_act_quant.let( **{ 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device}) - if per_tensor_act_quant is not None: - per_tensor_act_quant = per_tensor_act_quant.let( - **{ - 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device}) - if act_quant_type == 'asym' and act_quant_percentile is not None: - per_tensor_act_quant = per_tensor_act_quant.let( - **{'low_percentile_q': 100 - act_quant_percentile}) weight_quant_dict = {'weight_quant': weight_quant} @@ -431,9 +474,9 @@ def kwargs_prefix(prefix, weight_kwargs): unsigned_quant_act_kwargs['signed'] = False # Layerwise is basic quant kwargs + input_quant - layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant} + layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': act_quant} - layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': per_tensor_act_quant} + layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': act_quant} quant_layer_map = { torch.nn.Linear: (qnn.QuantLinear, quant_wbiol_kwargs), @@ -526,7 +569,7 @@ def apply_gptq(calib_loader, model, act_order=False): dtype = next(model.parameters()).dtype device = next(model.parameters()).device with torch.no_grad(): - with gptq_mode(model, act_order=act_order, use_quant_activations=False) as gptq: + with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq: gptq_model = gptq.model for i in tqdm(range(gptq.num_layers)): for i, (images, target) in enumerate(calib_loader): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 3a9bb29fa..c960a89e6 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -120,7 +120,12 @@ def parse_type(v, default_type): parser.add_argument( '--weight-quant-granularity', default='per_tensor', - choices=['per_tensor', 'per_channel'], + choices=['per_tensor', 'per_channel', 'per_group'], + help='Weight quantization type (default: per_tensor)') +parser.add_argument( + '--act-quant-granularity', + default='per_tensor', + choices=['per_tensor', 'per_group'], help='Activation quantization type (default: per_tensor)') 
parser.add_argument( '--weight-quant-calibration-type', @@ -168,11 +173,7 @@ def parse_type(v, default_type): '--export-torch-qcdq', action='store_true', help='If true, export the model in torch qcdq format') -add_bool_arg( - parser, - 'scaling-per-output-channel', - default=True, - help='Weight scaling per output channel (default: enabled)') + add_bool_arg( parser, 'bias-corr', default=True, help='Bias correction after calibration (default: enabled)') add_bool_arg( @@ -189,7 +190,7 @@ def parse_type(v, default_type): parser.add_argument( '--quant-format', default='int', - choices=['int', 'float'], + choices=['int', 'float', 'float_ocp'], help='Quantization format to use for weights and activations (default: int)') parser.add_argument( '--layerwise-first-last-mantissa-bit-width', @@ -409,6 +410,7 @@ def main(): weight_narrow_range=args.weight_narrow_range, weight_param_method=args.weight_quant_calibration_type, weight_quant_granularity=args.weight_quant_granularity, + act_quant_granularity=args.act_quant_granularity, weight_quant_type=args.weight_quant_type, layerwise_first_last_bit_width=args.layerwise_first_last_bit_width, act_bit_width=args.act_bit_width,