From ea8cb36c439e0144a4d6c444b3ec8a3536da3e53 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 5 Sep 2024 17:52:46 +0100 Subject: [PATCH 01/14] Added conv_in/conv_out to blacklist --- src/brevitas_examples/stable_diffusion/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index a1c4fef53..032b60539 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -222,7 +222,7 @@ def main(args): blacklist = [] non_blacklist = dict() for name, _ in pipe.unet.named_modules(): - if 'time_emb' in name: + if 'time_emb' in name or 'conv_in' in name or 'conv_out' in name: blacklist.append(name) else: if isinstance(_, (torch.nn.Linear, torch.nn.Conv2d)): From fcfedfaafe64e36fb199f8471544fa63012b4adc Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 5 Sep 2024 17:53:08 +0100 Subject: [PATCH 02/14] Switch SDPA quantization to FP8 FNUZ --- src/brevitas_examples/stable_diffusion/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 032b60539..82b2b9e98 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -364,7 +364,7 @@ def input_zp_stats_type(): dtype=dtype, device=args.device, weight_bit_width=weight_bit_width, - weight_quant_format='float_ocp_e4m3', + weight_quant_format='float_fnuz_e4m3', weight_quant_type='sym', weight_param_method=args.weight_param_method, weight_scale_precision=args.weight_scale_precision, @@ -373,7 +373,7 @@ def input_zp_stats_type(): quantize_weight_zero_point=args.quantize_weight_zero_point, quantize_input_zero_point=args.quantize_input_zero_point, input_bit_width=args.linear_output_bit_width, - input_quant_format='float_ocp_e4m3', + input_quant_format='float_fnuz_e4m3', input_scale_type=args.input_scale_type, input_scale_precision=args.input_scale_precision, input_param_method=args.input_param_method, From 60a94d845b13f0d9ae11bbbbbfb7e30e36b133a8 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 5 Sep 2024 17:47:43 +0000 Subject: [PATCH 03/14] Fix (example/sdxl): Send device to fid computation backend --- src/brevitas_examples/stable_diffusion/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 82b2b9e98..9b20e0a7b 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -520,7 +520,7 @@ def input_zp_stats_type(): if args.use_mlperf_inference: print(f"Computing accuracy with MLPerf pipeline") compute_mlperf_fid( - args.model, args.path_to_coco, pipe, args.prompt, output_dir, not args.vae_fp16_fix) + args.model, args.path_to_coco, pipe, args.prompt, output_dir, args.device, not args.vae_fp16_fix) else: print(f"Computing accuracy on default prompt") testing_prompts = TESTING_PROMPTS[:args.prompt] From 223700a3db20768606d4e6a88a24c6c283680813 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 6 Sep 2024 15:30:02 +0100 Subject: [PATCH 04/14] Fix (example/sdxl): Allow export of FP8 linear/conv layers. 
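Background for this change: with FP8 quantization the zero-point tensors handled in export.py are no longer guaranteed to be int8 — they can carry dtypes that NumPy cannot represent (such as the torch float8 types) and that do not support reductions like torch.sum. The diff therefore casts to float32 before the .numpy() round-trip and records the serialized dtype from the tensor itself instead of hard-coding int8. A minimal illustrative sketch of the failure mode (the tensor below is a stand-in, not taken from the export code, and float8 dtypes assume a recent PyTorch build):

    import torch

    zp = torch.zeros(4, dtype=torch.float8_e4m3fn)        # stand-in FP8 zero-point tensor
    # zp.numpy()                                          # would raise: NumPy has no float8 dtype
    safe = zp.to(torch.float32).cpu().numpy().tolist()    # the cast pattern used in handle_quant_param()
    print(safe, str(zp.dtype))                            # [0.0, 0.0, 0.0, 0.0] torch.float8_e4m3fn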
--- .../stable_diffusion/sd_quant/export.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index 89d846a79..99b77a275 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -40,18 +40,18 @@ def handle_quant_param(layer, layer_dict): layer_dict['output_scale_shape'] = output_scale.shape layer_dict['input_scale'] = input_scale.numpy().tolist() layer_dict['input_scale_shape'] = input_scale.shape - layer_dict['input_zp'] = input_zp.numpy().tolist() + layer_dict['input_zp'] = input_zp.to(torch.float32).cpu().numpy().tolist() layer_dict['input_zp_shape'] = input_zp.shape - layer_dict['input_zp_dtype'] = str(torch.int8) + layer_dict['input_zp_dtype'] = str(input_zp.dtype) layer_dict['weight_scale'] = weight_scale.cpu().numpy().tolist() nelems = layer.weight.shape[0] weight_scale_shape = [nelems] + [1] * (layer.weight.data.ndim - 1) layer_dict['weight_scale_shape'] = weight_scale_shape - if torch.sum(weight_zp) != 0.: + if torch.sum(weight_zp.to(torch.float32)) != 0.: weight_zp = weight_zp - 128. # apply offset to have signed z - layer_dict['weight_zp'] = weight_zp.cpu().numpy().tolist() + layer_dict['weight_zp'] = weight_zp.to(torch.float32).cpu().numpy().tolist() layer_dict['weight_zp_shape'] = weight_scale_shape - layer_dict['weight_zp_dtype'] = str(torch.int8) + layer_dict['weight_zp_dtype'] = str(weight_zp.dtype) return layer_dict @@ -63,6 +63,8 @@ def export_quant_params(pipe, output_dir, export_vae=False): vae_output_path = os.path.join(output_dir, 'vae.safetensors') print(f"Saving vae to {vae_output_path} ...") from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager + export_manager = StdQCDQONNXManager + export_manager.change_weight_export(export_weight_q_node=True) # We're exporting FP weights + quantization parameters quant_params = dict() state_dict = pipe.unet.state_dict() state_dict = {k: v for (k, v) in state_dict.items() if 'tensor_quant' not in k} @@ -70,7 +72,7 @@ def export_quant_params(pipe, output_dir, export_vae=False): state_dict = {k.replace('.layer.', '.'): v for (k, v) in state_dict.items()} handled_quant_layers = set() - with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, StdQCDQONNXManager): + with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager): for name, module in pipe.unet.named_modules(): if isinstance(module, EqualizedModule): if id(module.layer) in handled_quant_layers: From bec3da174fdf770a832bd958e7a69f18b14b3b68 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 10 Sep 2024 13:52:41 +0100 Subject: [PATCH 05/14] Fix: formatting --- src/brevitas_examples/stable_diffusion/main.py | 8 +++++++- src/brevitas_examples/stable_diffusion/sd_quant/export.py | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 9b20e0a7b..1cba32cd1 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -520,7 +520,13 @@ def input_zp_stats_type(): if args.use_mlperf_inference: print(f"Computing accuracy with MLPerf pipeline") compute_mlperf_fid( - args.model, args.path_to_coco, pipe, args.prompt, output_dir, args.device, not args.vae_fp16_fix) + args.model, + args.path_to_coco, + pipe, + args.prompt, + output_dir, + args.device, + 
not args.vae_fp16_fix) else: print(f"Computing accuracy on default prompt") testing_prompts = TESTING_PROMPTS[:args.prompt] diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index 99b77a275..a42a35204 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -64,7 +64,8 @@ def export_quant_params(pipe, output_dir, export_vae=False): print(f"Saving vae to {vae_output_path} ...") from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager export_manager = StdQCDQONNXManager - export_manager.change_weight_export(export_weight_q_node=True) # We're exporting FP weights + quantization parameters + export_manager.change_weight_export( + export_weight_q_node=True) # We're exporting FP weights + quantization parameters quant_params = dict() state_dict = pipe.unet.state_dict() state_dict = {k: v for (k, v) in state_dict.items() if 'tensor_quant' not in k} From 28371a5671bce96fd0d896e6b7619a838fb4d3a0 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 10 Sep 2024 14:29:56 +0100 Subject: [PATCH 06/14] feat: (example/sdxl): Added option to specify quantization blacklist from commandline --- src/brevitas_examples/stable_diffusion/main.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 1cba32cd1..31b3030b5 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -222,7 +222,7 @@ def main(args): blacklist = [] non_blacklist = dict() for name, _ in pipe.unet.named_modules(): - if 'time_emb' in name or 'conv_in' in name or 'conv_out' in name: + if any(map(lambda x: x in name, args.quant_blacklist)): blacklist.append(name) else: if isinstance(_, (torch.nn.Linear, torch.nn.Conv2d)): @@ -797,6 +797,13 @@ def input_zp_stats_type(): type=int, default=16, help='Group size for per_group weight quantization. Default: 16.') + parser.add_argument( + '--quant-blacklist', + type=str, + default=['time_emb'], + nargs='*', + metavar='NAME', + help='A list of module names to exclude from quantization. 
Default: %(default)s') add_bool_arg( parser, 'quantize-weight-zero-point', From f7684c3bfc8a9f7c0020fa592a417b7b7d7c7d3d Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 10 Sep 2024 18:22:32 +0100 Subject: [PATCH 07/14] Feat (example/sdxl): Allow customization of SDPA quant via the commandline --- .../stable_diffusion/main.py | 133 ++++++++++++++---- 1 file changed, 105 insertions(+), 28 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 31b3030b5..edb297daa 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -316,6 +316,29 @@ def input_zp_stats_type(): input_kwargs['zero_point_stats_impl'] = input_zp_stats_type + sdpa_kwargs = dict() + if args.sdpa_scale_stats_op == 'minmax': + + @value + def sdpa_scale_stats_type(): + if args.sdpa_quant_type == 'asym': + sdpa_scaling_stats_op = StatsOp.MIN_MAX + else: + sdpa_scaling_stats_op = StatsOp.MAX + return sdpa_scaling_stats_op + + sdpa_kwargs['scaling_stats_op'] = sdpa_scale_stats_type + + if args.sdpa_zp_stats_op == 'minmax': + + @value + def sdpa_zp_stats_type(): + if args.sdpa_quant_type == 'asym': + zero_point_stats_impl = NegativeMinOrZero + return zero_point_stats_impl + + sdpa_kwargs['zero_point_stats_impl'] = sdpa_zp_stats_type + print("Applying model quantization...") quantizers = generate_quantizers( dtype=dtype, @@ -360,29 +383,29 @@ def input_zp_stats_type(): if args.quantize_sdp: assert args.share_qkv_quant, "Currently SDPA quantization is supported only with shared QKV quantization" # TODO: reformat this - float_sdpa_quantizers = generate_quantizers( + sdpa_quantizers = generate_quantizers( dtype=dtype, device=args.device, - weight_bit_width=weight_bit_width, - weight_quant_format='float_fnuz_e4m3', - weight_quant_type='sym', - weight_param_method=args.weight_param_method, - weight_scale_precision=args.weight_scale_precision, - weight_quant_granularity=args.weight_quant_granularity, - weight_group_size=args.weight_group_size, - quantize_weight_zero_point=args.quantize_weight_zero_point, - quantize_input_zero_point=args.quantize_input_zero_point, - input_bit_width=args.linear_output_bit_width, - input_quant_format='float_fnuz_e4m3', - input_scale_type=args.input_scale_type, - input_scale_precision=args.input_scale_precision, - input_param_method=args.input_param_method, - input_quant_type='sym', - input_quant_granularity=args.input_quant_granularity, - input_kwargs=input_kwargs) + weight_bit_width=args.sdpa_bit_width, + weight_quant_format=args.sdpa_quant_format, + weight_quant_type=args.sdpa_quant_type, + weight_param_method=args.sdpa_param_method, + weight_scale_precision=args.sdpa_scale_precision, + weight_quant_granularity=args.sdpa_quant_granularity, + weight_group_size=32, # Not used, since args.sdpa_quant_granularity == 'per_tensor' + quantize_weight_zero_point=args.quantize_sdpa_zero_point, + quantize_input_zero_point=args.quantize_sdpa_zero_point, + input_bit_width=args.sdpa_bit_width, + input_quant_format=args.sdpa_quant_format, + input_scale_type=args.sdpa_scale_type, + input_scale_precision=args.sdpa_scale_precision, + input_param_method=args.sdpa_param_method, + input_quant_type=args.sdpa_quant_type, + input_quant_granularity=args.sdpa_quant_granularity, + input_kwargs=sdpa_kwargs) # We generate all quantizers, but we are only interested in activation quantization for # the output of softmax and the output of QKV - input_quant = float_sdpa_quantizers[0] + input_quant = 
sdpa_quantizers[0] rewriter = ModuleToModuleByClass( Attention, QuantAttention, @@ -400,11 +423,11 @@ def input_zp_stats_type(): if args.override_conv_quant_config: print( - f"Overriding Conv2d quantization to weights: {float_sdpa_quantizers[1]}, inputs: {float_sdpa_quantizers[2]}" + f"Overriding Conv2d quantization to weights: {sdpa_quantizers[1]}, inputs: {sdpa_quantizers[2]}" ) conv_qkwargs = layer_map[torch.nn.Conv2d][1] - conv_qkwargs['input_quant'] = float_sdpa_quantizers[2] - conv_qkwargs['weight_quant'] = float_sdpa_quantizers[1] + conv_qkwargs['input_quant'] = sdpa_quantizers[2] + conv_qkwargs['weight_quant'] = sdpa_quantizers[1] layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs) pipe.unet = layerwise_quantize( @@ -435,7 +458,7 @@ def input_zp_stats_type(): pipe = pipe.to(args.device) elif not args.dry_run: if (args.linear_input_bit_width > 0 or args.conv_input_bit_width > 0 or - args.linear_output_bit_width > 0) and args.input_scale_type == 'static': + args.sdpa_bit_width > 0) and args.input_scale_type == 'static': print("Applying activation calibration") with torch.no_grad(), calibration_mode(pipe.unet): run_val_inference( @@ -707,11 +730,6 @@ def input_zp_stats_type(): type=int, default=0, help='Input bit width. Default: 0 (not quantized).') - parser.add_argument( - '--linear-output-bit-width', - type=int, - default=0, - help='Input bit width. Default: 0 (not quantized).') parser.add_argument( '--weight-param-method', type=str, @@ -797,6 +815,60 @@ def input_zp_stats_type(): type=int, default=16, help='Group size for per_group weight quantization. Default: 16.') + parser.add_argument( + '--sdpa-bit-width', + type=int, + default=0, + help='Scaled dot product attention bit width. Default: 0 (not quantized).') + parser.add_argument( + '--sdpa-param-method', + type=str, + default='stats', + choices=['stats', 'mse'], + help='How scales/zero-point are determined for scaled dot product attention. Default: %(default)s.') + parser.add_argument( + '--sdpa-scale-stats-op', + type=str, + default='minmax', + choices=['minmax', 'percentile'], + help='Define what statistics op to use for scaled dot product attention scale. Default: %(default)s.') + parser.add_argument( + '--sdpa-zp-stats-op', + type=str, + default='minmax', + choices=['minmax', 'percentile'], + help='Define what statistics op to use for scaled dot product attention zero point. Default: %(default)s.') + parser.add_argument( + '--sdpa-scale-precision', + type=str, + default='float_scale', + choices=['float_scale', 'po2_scale'], + help='Whether the scaled dot product attention scale is a float value or a po2. Default: %(default)s.') + parser.add_argument( + '--sdpa-quant-type', + type=str, + default='asym', + choices=['sym', 'asym'], + help='Scaled dot product attention quantization type. Default: %(default)s.') + parser.add_argument( + '--sdpa-quant-format', + type=quant_format_validator, + default='int', + help= + 'Scaled dot product attention quantization format. Either int or eXmY, with X+Y==input_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: %(default)s.' + ) + parser.add_argument( + '--sdpa-quant-granularity', + type=str, + default='per_tensor', + choices=['per_tensor'], + help='Granularity for scales/zero-point of scaled dot product attention. 
Default: %(default)s.') + parser.add_argument( + '--sdpa-scale-type', + type=str, + default='static', + choices=['static', 'dynamic'], + help='Whether to do static or dynamic scaled dot product attention quantization. Default: %(default)s.') parser.add_argument( '--quant-blacklist', type=str, @@ -819,6 +891,11 @@ def input_zp_stats_type(): 'quantize-input-zero-point', default=False, help='Quantize input zero-point. Default: Enabled') + add_bool_arg( + parser, + 'quantize-sdpa-zero-point', + default=False, + help='Quantize scaled dot product attention zero-point. Default: %(default)s') add_bool_arg( parser, 'export-cpu-float32', default=False, help='Export FP32 on CPU. Default: Disabled') add_bool_arg( From 63948f1d044fa32e0c43b20212869dff405f7704 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 10 Sep 2024 18:38:22 +0100 Subject: [PATCH 08/14] precommit --- .../stable_diffusion/main.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index edb297daa..371bc0349 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -392,7 +392,7 @@ def sdpa_zp_stats_type(): weight_param_method=args.sdpa_param_method, weight_scale_precision=args.sdpa_scale_precision, weight_quant_granularity=args.sdpa_quant_granularity, - weight_group_size=32, # Not used, since args.sdpa_quant_granularity == 'per_tensor' + weight_group_size=32, # Not used, since args.sdpa_quant_granularity == 'per_tensor' quantize_weight_zero_point=args.quantize_sdpa_zero_point, quantize_input_zero_point=args.quantize_sdpa_zero_point, input_bit_width=args.sdpa_bit_width, @@ -825,25 +825,33 @@ def sdpa_zp_stats_type(): type=str, default='stats', choices=['stats', 'mse'], - help='How scales/zero-point are determined for scaled dot product attention. Default: %(default)s.') + help= + 'How scales/zero-point are determined for scaled dot product attention. Default: %(default)s.' + ) parser.add_argument( '--sdpa-scale-stats-op', type=str, default='minmax', choices=['minmax', 'percentile'], - help='Define what statistics op to use for scaled dot product attention scale. Default: %(default)s.') + help= + 'Define what statistics op to use for scaled dot product attention scale. Default: %(default)s.' + ) parser.add_argument( '--sdpa-zp-stats-op', type=str, default='minmax', choices=['minmax', 'percentile'], - help='Define what statistics op to use for scaled dot product attention zero point. Default: %(default)s.') + help= + 'Define what statistics op to use for scaled dot product attention zero point. Default: %(default)s.' + ) parser.add_argument( '--sdpa-scale-precision', type=str, default='float_scale', choices=['float_scale', 'po2_scale'], - help='Whether the scaled dot product attention scale is a float value or a po2. Default: %(default)s.') + help= + 'Whether the scaled dot product attention scale is a float value or a po2. Default: %(default)s.' + ) parser.add_argument( '--sdpa-quant-type', type=str, @@ -862,13 +870,16 @@ def sdpa_zp_stats_type(): type=str, default='per_tensor', choices=['per_tensor'], - help='Granularity for scales/zero-point of scaled dot product attention. Default: %(default)s.') + help= + 'Granularity for scales/zero-point of scaled dot product attention. 
Default: %(default)s.') parser.add_argument( '--sdpa-scale-type', type=str, default='static', choices=['static', 'dynamic'], - help='Whether to do static or dynamic scaled dot product attention quantization. Default: %(default)s.') + help= + 'Whether to do static or dynamic scaled dot product attention quantization. Default: %(default)s.' + ) parser.add_argument( '--quant-blacklist', type=str, From 00b1e89f75fa59b06cf1da2253ba45d261923b8e Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 11 Sep 2024 12:12:34 +0100 Subject: [PATCH 09/14] Fix (example/sdxl): Changes how quantization overriding works for conv. --- src/brevitas_examples/stable_diffusion/main.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 371bc0349..e4553fa93 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -389,11 +389,11 @@ def sdpa_zp_stats_type(): weight_bit_width=args.sdpa_bit_width, weight_quant_format=args.sdpa_quant_format, weight_quant_type=args.sdpa_quant_type, - weight_param_method=args.sdpa_param_method, - weight_scale_precision=args.sdpa_scale_precision, - weight_quant_granularity=args.sdpa_quant_granularity, - weight_group_size=32, # Not used, since args.sdpa_quant_granularity == 'per_tensor' - quantize_weight_zero_point=args.quantize_sdpa_zero_point, + weight_param_method=args.weight_param_method, + weight_scale_precision=args.weight_scale_precision, + weight_quant_granularity=args.weight_quant_granularity, # Must be compatible with `args.sdpa_quant_format` + weight_group_size=args.weight_group_size, + quantize_weight_zero_point=args.quantize_weight_zero_point, quantize_input_zero_point=args.quantize_sdpa_zero_point, input_bit_width=args.sdpa_bit_width, input_quant_format=args.sdpa_quant_format, From f5ed4b16b8a5794961a14036fdc4b7a8f627f495 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 11 Sep 2024 12:15:11 +0100 Subject: [PATCH 10/14] precommit --- src/brevitas_examples/stable_diffusion/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index e4553fa93..8c652367f 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -382,7 +382,7 @@ def sdpa_zp_stats_type(): if args.quantize_sdp: assert args.share_qkv_quant, "Currently SDPA quantization is supported only with shared QKV quantization" - # TODO: reformat this + # `args.weight_quant_granularity` must be compatible with `args.sdpa_quant_format` sdpa_quantizers = generate_quantizers( dtype=dtype, device=args.device, @@ -391,7 +391,7 @@ def sdpa_zp_stats_type(): weight_quant_type=args.sdpa_quant_type, weight_param_method=args.weight_param_method, weight_scale_precision=args.weight_scale_precision, - weight_quant_granularity=args.weight_quant_granularity, # Must be compatible with `args.sdpa_quant_format` + weight_quant_granularity=args.weight_quant_granularity, weight_group_size=args.weight_group_size, quantize_weight_zero_point=args.quantize_weight_zero_point, quantize_input_zero_point=args.quantize_sdpa_zero_point, From f57905ba36045d3f9d3ccb9a0ddfbfdcbd85b188 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 11 Sep 2024 15:36:16 +0100 Subject: [PATCH 11/14] support 2 versions of diffusers --- .../stable_diffusion/main.py | 69 ++++++--- 
.../stable_diffusion/sd_quant/nn.py | 140 +++++++++++++++++- 2 files changed, 187 insertions(+), 22 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 8c652367f..0fca839c1 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -12,12 +12,14 @@ import time from dependencies import value +import diffusers from diffusers import DiffusionPipeline from diffusers import EulerDiscreteScheduler from diffusers import StableDiffusionXLPipeline from diffusers.models.attention_processor import Attention -from diffusers.models.attention_processor import AttnProcessor import numpy as np +import packaging +import packaging.version import pandas as pd import torch from torch import nn @@ -35,7 +37,6 @@ from brevitas.graph.quantize import layerwise_quantize from brevitas.inject.enum import StatsOp from brevitas.nn.equalized_layer import EqualizedModule -from brevitas.nn.quant_activation import QuantIdentity from brevitas.utils.torch_utils import KwargsForwardHook from brevitas_examples.common.generative.quantize import generate_quant_maps from brevitas_examples.common.generative.quantize import generate_quantizers @@ -47,14 +48,17 @@ from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params +from brevitas_examples.stable_diffusion.sd_quant.nn import AttnProcessor from brevitas_examples.stable_diffusion.sd_quant.nn import AttnProcessor2_0 from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention +from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttentionLast from brevitas_examples.stable_diffusion.sd_quant.nn import QuantizableAttention from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs from brevitas_examples.stable_diffusion.sd_quant.utils import unet_input_shape +diffusers_version = packaging.version.parse(diffusers.__version__) TEST_SEED = 123456 torch.manual_seed(TEST_SEED) @@ -149,7 +153,7 @@ def main(args): calibration_prompts = CALIBRATION_PROMPTS if args.calibration_prompt_path is not None: calibration_prompts = load_calib_prompts(args.calibration_prompt_path) - print(args.calibration_prompt, len(calibration_prompts)) + assert args.calibration_prompt <= len(calibration_prompts) , f"Only {len(calibration_prompts)} prompts are available" calibration_prompts = calibration_prompts[:args.calibration_prompt] @@ -178,18 +182,29 @@ def main(args): args.model, torch_dtype=dtype, variant=variant, use_safetensors=True) pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) pipe.vae.config.force_upcast = True - if args.share_qkv_quant: - rewriter = ModuleToModuleByClass( - Attention, - QuantizableAttention, - query_dim=lambda module: module.to_q.in_features, - dim_head=lambda module: math.ceil(1 / (module.scale ** 2)), - bias=lambda module: hasattr(module.to_q, 'bias') and module.to_q.bias is not None, - processor=AttnProcessor2_0(), - dtype=dtype, - norm_num_groups=lambda module: None - if module.group_norm is None else module.group_norm.num_groups) - rewriter.apply(pipe.unet) + is_mlperf_diffusers = diffusers_version == 
packaging.version.parse('0.21.2') + + AttClass = Attention + if is_mlperf_diffusers: + QuantAttClass = QuantAttention + if args.share_qkv_quant: + AttClass = QuantizableAttention + rewriter = ModuleToModuleByClass( + Attention, + QuantizableAttention, + query_dim=lambda module: module.to_q.in_features, + dim_head=lambda module: math.ceil(1 / (module.scale ** 2)), + bias=lambda module: hasattr(module.to_q, 'bias') and module.to_q.bias is not None, + processor=AttnProcessor2_0(), + dtype=dtype, + norm_num_groups=lambda module: None + if module.group_norm is None else module.group_norm.num_groups) + rewriter.apply(pipe.unet) + else: + QuantAttClass = QuantAttentionLast + if args.share_qkv_quant: + pipe.fuse_qkv_projections() + print(f"Model loaded from {args.model}.") # Move model to target device @@ -232,7 +247,7 @@ def main(args): else: non_blacklist[name_to_add] += 1 print(f"Blacklisted layers: {set(blacklist)}") - print(f"Non blacklisted layers: {non_blacklist}") + print(f"Non blacklisted layers: {set(non_blacklist.keys())}") # Make sure there all LoRA layers are fused first, otherwise raise an error for m in pipe.unet.modules(): @@ -381,7 +396,6 @@ def sdpa_zp_stats_type(): layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs) if args.quantize_sdp: - assert args.share_qkv_quant, "Currently SDPA quantization is supported only with shared QKV quantization" # `args.weight_quant_granularity` must be compatible with `args.sdpa_quant_format` sdpa_quantizers = generate_quantizers( dtype=dtype, @@ -406,14 +420,27 @@ def sdpa_zp_stats_type(): # We generate all quantizers, but we are only interested in activation quantization for # the output of softmax and the output of QKV input_quant = sdpa_quantizers[0] + if is_mlperf_diffusers: + extra_kwargs = {} + query_lambda = lambda module: module.to_qkv.in_features if hasattr( + module, 'to_qkv') else module.to_q.in_features + else: + extra_kwargs = { + 'fuse_qkv': + args.share_qkv_quant, + 'cross_attention_dim': + lambda module: module.cross_attention_dim + if module.is_cross_attention else None} + query_lambda = lambda module: module.query_dim rewriter = ModuleToModuleByClass( - Attention, - QuantAttention, + AttClass, + QuantAttClass, matmul_input_quant=input_quant, - query_dim=lambda module: module.to_q.in_features, + query_dim=query_lambda, dim_head=lambda module: math.ceil(1 / (module.scale ** 2)), processor=AttnProcessor(), - is_equalized=args.activation_equalization) + is_equalized=args.activation_equalization, + **extra_kwargs) import brevitas.config as config config.IGNORE_MISSING_KEYS = True pipe.unet = rewriter.apply(pipe.unet) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/nn.py b/src/brevitas_examples/stable_diffusion/sd_quant/nn.py index 5a6c23ab9..299fb557b 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/nn.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/nn.py @@ -17,8 +17,11 @@ from typing import Any, Mapping, Optional +import diffusers from diffusers.models.attention_processor import Attention from diffusers.models.lora import LoRACompatibleLinear +import packaging +import packaging.version import torch import torch.nn.functional as F @@ -119,6 +122,142 @@ def load_state_dict( return super().load_state_dict(state_dict, strict, assign) +class QuantAttentionLast(Attention): + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + 
upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, + pre_only=False, + matmul_input_quant=None, + is_equalized=False, + fuse_qkv=False): + + super().__init__( + query_dim, + cross_attention_dim, + heads, + kv_heads, + dim_head, + dropout, + bias, + upcast_attention, + upcast_softmax, + cross_attention_norm, + cross_attention_norm_num_groups, + qk_norm, + added_kv_proj_dim, + added_proj_bias, + norm_num_groups, + spatial_norm_dim, + out_bias, + scale_qk, + only_cross_attention, + eps, + rescale_output_factor, + residual_connection, + _from_deprecated_attn_block, + processor, + out_dim, + context_pre_only, + pre_only, + ) + if fuse_qkv: + self.fuse_projections() + + self.output_softmax_quant = QuantIdentity(matmul_input_quant) + self.out_q = QuantIdentity(matmul_input_quant) + self.out_k = QuantIdentity(matmul_input_quant) + self.out_v = QuantIdentity(matmul_input_quant) + if is_equalized: + replacements = [] + for n, m in self.named_modules(): + if isinstance(m, torch.nn.Linear): + in_channels = m.in_features + eq_m = EqualizedModule(ScaleBias(in_channels, False, (1, 1, -1)), m) + r = ModuleInstanceToModuleInstance(m, eq_m) + replacements.append(r) + for r in replacements: + r.apply(self) + + def get_attention_scores( + self, + query: torch.Tensor, + key: torch.Tensor, + attention_mask: torch.Tensor = None) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. 
+ """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + attention_probs = _unpack_quant_tensor(self.output_softmax_quant(attention_probs)) + return attention_probs + + class QuantAttention(QuantizableAttention): def __init__( @@ -172,7 +311,6 @@ def __init__( dtype, processor, ) - self.output_softmax_quant = QuantIdentity(matmul_input_quant) self.out_q = QuantIdentity(matmul_input_quant) self.out_k = QuantIdentity(matmul_input_quant) From 8431f59c0ad0f28a93ac17f26af035c6650cfeb1 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 11 Sep 2024 16:28:05 +0100 Subject: [PATCH 12/14] small fixes --- src/brevitas_examples/stable_diffusion/main.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 0fca839c1..57895b6d1 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -485,7 +485,8 @@ def sdpa_zp_stats_type(): pipe = pipe.to(args.device) elif not args.dry_run: if (args.linear_input_bit_width > 0 or args.conv_input_bit_width > 0 or - args.sdpa_bit_width > 0) and args.input_scale_type == 'static': + args.sdpa_bit_width > 0 or + args.quantize_sdp) and args.input_scale_type == 'static': print("Applying activation calibration") with torch.no_grad(), calibration_mode(pipe.unet): run_val_inference( @@ -699,7 +700,7 @@ def sdpa_zp_stats_type(): help='Resolution along height and width dimension. Default: 512.') parser.add_argument('--guidance-scale', type=float, default=7.5, help='Guidance scale.') parser.add_argument( - '--calibration-steps', type=float, default=8, help='Steps used during calibration') + '--calibration-steps', type=int, default=8, help='Steps used during calibration') add_bool_arg( parser, 'output-path', @@ -882,7 +883,7 @@ def sdpa_zp_stats_type(): parser.add_argument( '--sdpa-quant-type', type=str, - default='asym', + default='sym', choices=['sym', 'asym'], help='Scaled dot product attention quantization type. Default: %(default)s.') parser.add_argument( From 071527f865150b99ec61a5d0c42ac50e84388573 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 11 Sep 2024 18:57:56 +0100 Subject: [PATCH 13/14] Fix (example/sdxl): Removed `--quantize-sdp` arg. 
Improved calibration logic --- src/brevitas_examples/stable_diffusion/main.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 57895b6d1..7bfc335f0 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -395,7 +395,7 @@ def sdpa_zp_stats_type(): 'weight_quant'] layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs) - if args.quantize_sdp: + if args.sdpa_bit_width > 0: # `args.weight_quant_granularity` must be compatible with `args.sdpa_quant_format` sdpa_quantizers = generate_quantizers( dtype=dtype, @@ -484,9 +484,17 @@ def sdpa_zp_stats_type(): print(f"Checkpoint loaded!") pipe = pipe.to(args.device) elif not args.dry_run: - if (args.linear_input_bit_width > 0 or args.conv_input_bit_width > 0 or - args.sdpa_bit_width > 0 or - args.quantize_sdp) and args.input_scale_type == 'static': + # Model needs calibration if any of its activation quantizers are 'static' + activation_bw = [ + args.linear_input_bit_width, + args.conv_input_bit_width, + args.sdpa_bit_width,] + activation_st = [ + args.input_scale_type, + args.input_scale_type, + args.sdpa_scale_type,] + needs_calibration = any(map(lambda b, st: (b > 0) and st == 'static', activation_bw, activation_st)) + if needs_calibration: print("Applying activation calibration") with torch.no_grad(), calibration_mode(pipe.unet): run_val_inference( @@ -952,7 +960,6 @@ def sdpa_zp_stats_type(): 'dry-run', default=False, help='Generate a quantized model without any calibration. Default: Disabled') - add_bool_arg(parser, 'quantize-sdp', default=False, help='Quantize SDP. Default: Disabled') add_bool_arg( parser, 'override-conv-quant-config', From 839694850b04186b9301e26e2a5f3e3bd355f948 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 11 Sep 2024 18:59:38 +0100 Subject: [PATCH 14/14] precommit --- src/brevitas_examples/stable_diffusion/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 7bfc335f0..3fe24a321 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -493,7 +493,8 @@ def sdpa_zp_stats_type(): args.input_scale_type, args.input_scale_type, args.sdpa_scale_type,] - needs_calibration = any(map(lambda b, st: (b > 0) and st == 'static', activation_bw, activation_st)) + needs_calibration = any( + map(lambda b, st: (b > 0) and st == 'static', activation_bw, activation_st)) if needs_calibration: print("Applying activation calibration") with torch.no_grad(), calibration_mode(pipe.unet):
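Note on the calibration trigger introduced in PATCH 13: activation calibration now runs as soon as any quantized activation path (linear input, conv input, or SDPA) uses a static scale, pairing each bit width with its scale type. A minimal standalone sketch of that check, with made-up argument values for illustration (in main.py these come from the argument parser):

    from argparse import Namespace

    # Illustrative values only.
    args = Namespace(
        linear_input_bit_width=0,      # linear inputs not quantized
        conv_input_bit_width=8,        # conv inputs: 8-bit, static scale
        sdpa_bit_width=8,              # SDPA: 8-bit, but dynamic scale
        input_scale_type='static',
        sdpa_scale_type='dynamic')

    activation_bw = [args.linear_input_bit_width, args.conv_input_bit_width, args.sdpa_bit_width]
    activation_st = [args.input_scale_type, args.input_scale_type, args.sdpa_scale_type]
    needs_calibration = any(
        map(lambda b, st: (b > 0) and st == 'static', activation_bw, activation_st))
    print(needs_calibration)  # True: the conv input path is quantized with a static scale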