Feat (examples/sdxl): Updates to SDXL entry-point #1020

Merged: 14 commits, Sep 12, 2024
Changes from 13 commits
238 changes: 187 additions & 51 deletions src/brevitas_examples/stable_diffusion/main.py
@@ -12,12 +12,14 @@
import time

from dependencies import value
import diffusers
from diffusers import DiffusionPipeline
from diffusers import EulerDiscreteScheduler
from diffusers import StableDiffusionXLPipeline
from diffusers.models.attention_processor import Attention
from diffusers.models.attention_processor import AttnProcessor
import numpy as np
import packaging
import packaging.version
import pandas as pd
import torch
from torch import nn
@@ -35,7 +37,6 @@
from brevitas.graph.quantize import layerwise_quantize
from brevitas.inject.enum import StatsOp
from brevitas.nn.equalized_layer import EqualizedModule
from brevitas.nn.quant_activation import QuantIdentity
from brevitas.utils.torch_utils import KwargsForwardHook
from brevitas_examples.common.generative.quantize import generate_quant_maps
from brevitas_examples.common.generative.quantize import generate_quantizers
@@ -47,14 +48,17 @@
from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE
from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx
from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params
from brevitas_examples.stable_diffusion.sd_quant.nn import AttnProcessor
from brevitas_examples.stable_diffusion.sd_quant.nn import AttnProcessor2_0
from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention
from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttentionLast
from brevitas_examples.stable_diffusion.sd_quant.nn import QuantizableAttention
from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents
from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs
from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs
from brevitas_examples.stable_diffusion.sd_quant.utils import unet_input_shape

diffusers_version = packaging.version.parse(diffusers.__version__)
TEST_SEED = 123456
torch.manual_seed(TEST_SEED)

@@ -149,7 +153,7 @@ def main(args):
calibration_prompts = CALIBRATION_PROMPTS
if args.calibration_prompt_path is not None:
calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
print(args.calibration_prompt, len(calibration_prompts))

assert args.calibration_prompt <= len(calibration_prompts), f"Only {len(calibration_prompts)} prompts are available"
calibration_prompts = calibration_prompts[:args.calibration_prompt]

@@ -178,18 +182,29 @@ def main(args):
args.model, torch_dtype=dtype, variant=variant, use_safetensors=True)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.vae.config.force_upcast = True
if args.share_qkv_quant:
rewriter = ModuleToModuleByClass(
Attention,
QuantizableAttention,
query_dim=lambda module: module.to_q.in_features,
dim_head=lambda module: math.ceil(1 / (module.scale ** 2)),
bias=lambda module: hasattr(module.to_q, 'bias') and module.to_q.bias is not None,
processor=AttnProcessor2_0(),
dtype=dtype,
norm_num_groups=lambda module: None
if module.group_norm is None else module.group_norm.num_groups)
rewriter.apply(pipe.unet)
is_mlperf_diffusers = diffusers_version == packaging.version.parse('0.21.2')

AttClass = Attention
if is_mlperf_diffusers:
QuantAttClass = QuantAttention
if args.share_qkv_quant:
AttClass = QuantizableAttention
rewriter = ModuleToModuleByClass(
Attention,
QuantizableAttention,
query_dim=lambda module: module.to_q.in_features,
dim_head=lambda module: math.ceil(1 / (module.scale ** 2)),
bias=lambda module: hasattr(module.to_q, 'bias') and module.to_q.bias is not None,
processor=AttnProcessor2_0(),
dtype=dtype,
norm_num_groups=lambda module: None
if module.group_norm is None else module.group_norm.num_groups)
rewriter.apply(pipe.unet)
else:
QuantAttClass = QuantAttentionLast
if args.share_qkv_quant:
pipe.fuse_qkv_projections()
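The `dim_head` lambdas here and in the later rewriter invert the attention scale stored by diffusers: `Attention` keeps `scale = dim_head ** -0.5`, so the head dimension can be recovered from it. A minimal sketch of that inversion (sample values chosen for illustration):

import math

# diffusers' Attention stores scale = dim_head ** -0.5, hence dim_head = 1 / scale ** 2;
# math.ceil guards against the inversion landing slightly below the integer value.
def recover_dim_head(scale: float) -> int:
    return math.ceil(1 / (scale ** 2))

assert recover_dim_head(64 ** -0.5) == 64
assert recover_dim_head(16 ** -0.5) == 16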

print(f"Model loaded from {args.model}.")

# Move model to target device
@@ -222,7 +237,7 @@ def main(args):
blacklist = []
non_blacklist = dict()
for name, _ in pipe.unet.named_modules():
if 'time_emb' in name:
if any(map(lambda x: x in name, args.quant_blacklist)):
blacklist.append(name)
else:
if isinstance(_, (torch.nn.Linear, torch.nn.Conv2d)):
@@ -232,7 +247,7 @@
else:
non_blacklist[name_to_add] += 1
print(f"Blacklisted layers: {set(blacklist)}")
print(f"Non blacklisted layers: {non_blacklist}")
print(f"Non blacklisted layers: {set(non_blacklist.keys())}")

# Make sure all LoRA layers are fused first, otherwise raise an error
for m in pipe.unet.modules():
@@ -316,6 +331,29 @@ def input_zp_stats_type():

input_kwargs['zero_point_stats_impl'] = input_zp_stats_type

sdpa_kwargs = dict()
if args.sdpa_scale_stats_op == 'minmax':

@value
def sdpa_scale_stats_type():
if args.sdpa_quant_type == 'asym':
sdpa_scaling_stats_op = StatsOp.MIN_MAX
else:
sdpa_scaling_stats_op = StatsOp.MAX
return sdpa_scaling_stats_op

sdpa_kwargs['scaling_stats_op'] = sdpa_scale_stats_type

if args.sdpa_zp_stats_op == 'minmax':

@value
def sdpa_zp_stats_type():
if args.sdpa_quant_type == 'asym':
zero_point_stats_impl = NegativeMinOrZero
return zero_point_stats_impl

sdpa_kwargs['zero_point_stats_impl'] = sdpa_zp_stats_type
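The `@value`-decorated helpers above defer these choices to Brevitas' quantizer injectors; stripped of that machinery, the selection they encode is roughly the following (a sketch, not the injected implementation):

from brevitas.inject.enum import StatsOp

# Asymmetric quantization tracks both ends of the observed range plus a zero-point,
# so it uses MIN_MAX statistics; symmetric quantization only needs the largest
# magnitude (MAX) to derive its scale.
def pick_sdpa_scale_stats_op(quant_type: str) -> StatsOp:
    return StatsOp.MIN_MAX if quant_type == 'asym' else StatsOp.MAX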

print("Applying model quantization...")
quantizers = generate_quantizers(
dtype=dtype,
@@ -357,40 +395,52 @@ def input_zp_stats_type():
'weight_quant']
layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs)

if args.quantize_sdp:
assert args.share_qkv_quant, "Currently SDPA quantization is supported only with shared QKV quantization"
# TODO: reformat this
float_sdpa_quantizers = generate_quantizers(
if args.sdpa_bit_width > 0:
# `args.weight_quant_granularity` must be compatible with `args.sdpa_quant_format`
sdpa_quantizers = generate_quantizers(
dtype=dtype,
device=args.device,
weight_bit_width=weight_bit_width,
weight_quant_format='float_ocp_e4m3',
weight_quant_type='sym',
weight_bit_width=args.sdpa_bit_width,
weight_quant_format=args.sdpa_quant_format,
weight_quant_type=args.sdpa_quant_type,
weight_param_method=args.weight_param_method,
weight_scale_precision=args.weight_scale_precision,
weight_quant_granularity=args.weight_quant_granularity,
weight_group_size=args.weight_group_size,
quantize_weight_zero_point=args.quantize_weight_zero_point,
quantize_input_zero_point=args.quantize_input_zero_point,
input_bit_width=args.linear_output_bit_width,
input_quant_format='float_ocp_e4m3',
input_scale_type=args.input_scale_type,
input_scale_precision=args.input_scale_precision,
input_param_method=args.input_param_method,
input_quant_type='sym',
input_quant_granularity=args.input_quant_granularity,
input_kwargs=input_kwargs)
quantize_input_zero_point=args.quantize_sdpa_zero_point,
input_bit_width=args.sdpa_bit_width,
input_quant_format=args.sdpa_quant_format,
input_scale_type=args.sdpa_scale_type,
input_scale_precision=args.sdpa_scale_precision,
input_param_method=args.sdpa_param_method,
input_quant_type=args.sdpa_quant_type,
input_quant_granularity=args.sdpa_quant_granularity,
input_kwargs=sdpa_kwargs)
# We generate all quantizers, but we are only interested in activation quantization for
# the output of softmax and the output of QKV
input_quant = float_sdpa_quantizers[0]
input_quant = sdpa_quantizers[0]
if is_mlperf_diffusers:
extra_kwargs = {}
query_lambda = lambda module: module.to_qkv.in_features if hasattr(
module, 'to_qkv') else module.to_q.in_features
else:
extra_kwargs = {
'fuse_qkv':
args.share_qkv_quant,
'cross_attention_dim':
lambda module: module.cross_attention_dim
if module.is_cross_attention else None}
query_lambda = lambda module: module.query_dim
rewriter = ModuleToModuleByClass(
Attention,
QuantAttention,
AttClass,
QuantAttClass,
matmul_input_quant=input_quant,
query_dim=lambda module: module.to_q.in_features,
query_dim=query_lambda,
dim_head=lambda module: math.ceil(1 / (module.scale ** 2)),
processor=AttnProcessor(),
is_equalized=args.activation_equalization)
is_equalized=args.activation_equalization,
**extra_kwargs)
import brevitas.config as config
config.IGNORE_MISSING_KEYS = True
pipe.unet = rewriter.apply(pipe.unet)
@@ -400,11 +450,11 @@ def input_zp_stats_type():

if args.override_conv_quant_config:
print(
f"Overriding Conv2d quantization to weights: {float_sdpa_quantizers[1]}, inputs: {float_sdpa_quantizers[2]}"
f"Overriding Conv2d quantization to weights: {sdpa_quantizers[1]}, inputs: {sdpa_quantizers[2]}"
)
conv_qkwargs = layer_map[torch.nn.Conv2d][1]
conv_qkwargs['input_quant'] = float_sdpa_quantizers[2]
conv_qkwargs['weight_quant'] = float_sdpa_quantizers[1]
conv_qkwargs['input_quant'] = sdpa_quantizers[2]
conv_qkwargs['weight_quant'] = sdpa_quantizers[1]
layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs)

pipe.unet = layerwise_quantize(
@@ -434,8 +484,17 @@ def input_zp_stats_type():
print(f"Checkpoint loaded!")
pipe = pipe.to(args.device)
elif not args.dry_run:
if (args.linear_input_bit_width > 0 or args.conv_input_bit_width > 0 or
args.linear_output_bit_width > 0) and args.input_scale_type == 'static':
# Model needs calibration if any of its activation quantizers are 'static'
Collaborator comment: Sweet
activation_bw = [
args.linear_input_bit_width,
args.conv_input_bit_width,
args.sdpa_bit_width,]
activation_st = [
args.input_scale_type,
args.input_scale_type,
args.sdpa_scale_type,]
needs_calibration = any(map(lambda b, st: (b > 0) and st == 'static', activation_bw, activation_st))
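`map` over two iterables pairs their elements positionally, so the check above can be read as a zip over bit widths and scale types (an equivalent spelling, shown only for clarity):

# Calibration is required whenever a quantized activation (bit width > 0) uses a static scale.
needs_calibration = any(
    bw > 0 and st == 'static' for bw, st in zip(activation_bw, activation_st))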
if needs_calibration:
print("Applying activation calibration")
with torch.no_grad(), calibration_mode(pipe.unet):
run_val_inference(
@@ -520,7 +579,13 @@ def input_zp_stats_type():
if args.use_mlperf_inference:
print(f"Computing accuracy with MLPerf pipeline")
compute_mlperf_fid(
args.model, args.path_to_coco, pipe, args.prompt, output_dir, not args.vae_fp16_fix)
args.model,
args.path_to_coco,
pipe,
args.prompt,
output_dir,
args.device,
not args.vae_fp16_fix)
else:
print(f"Computing accuracy on default prompt")
testing_prompts = TESTING_PROMPTS[:args.prompt]
@@ -643,7 +708,7 @@ def input_zp_stats_type():
help='Resolution along height and width dimension. Default: 512.')
parser.add_argument('--guidance-scale', type=float, default=7.5, help='Guidance scale.')
parser.add_argument(
'--calibration-steps', type=float, default=8, help='Steps used during calibration')
'--calibration-steps', type=int, default=8, help='Steps used during calibration')
add_bool_arg(
parser,
'output-path',
@@ -701,11 +766,6 @@ def input_zp_stats_type():
type=int,
default=0,
help='Input bit width. Default: 0 (not quantized).')
parser.add_argument(
'--linear-output-bit-width',
type=int,
default=0,
help='Input bit width. Default: 0 (not quantized).')
parser.add_argument(
'--weight-param-method',
type=str,
@@ -791,6 +851,78 @@ def input_zp_stats_type():
type=int,
default=16,
help='Group size for per_group weight quantization. Default: 16.')
parser.add_argument(
'--sdpa-bit-width',
type=int,
default=0,
help='Scaled dot product attention bit width. Default: 0 (not quantized).')
parser.add_argument(
'--sdpa-param-method',
type=str,
default='stats',
choices=['stats', 'mse'],
help=
'How scales/zero-point are determined for scaled dot product attention. Default: %(default)s.'
)
parser.add_argument(
'--sdpa-scale-stats-op',
type=str,
default='minmax',
choices=['minmax', 'percentile'],
help=
'Define what statistics op to use for scaled dot product attention scale. Default: %(default)s.'
)
parser.add_argument(
'--sdpa-zp-stats-op',
type=str,
default='minmax',
choices=['minmax', 'percentile'],
help=
'Define what statistics op to use for scaled dot product attention zero point. Default: %(default)s.'
)
parser.add_argument(
'--sdpa-scale-precision',
type=str,
default='float_scale',
choices=['float_scale', 'po2_scale'],
help=
'Whether the scaled dot product attention scale is a float value or a po2. Default: %(default)s.'
)
parser.add_argument(
'--sdpa-quant-type',
type=str,
default='sym',
choices=['sym', 'asym'],
help='Scaled dot product attention quantization type. Default: %(default)s.')
parser.add_argument(
'--sdpa-quant-format',
type=quant_format_validator,
default='int',
help=
'Scaled dot product attention quantization format. Either int or eXmY, with X+Y==input_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: %(default)s.'
)
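For the eXmY formats accepted here, the exponent and mantissa widths must sum to the bit width minus one sign bit. A worked check for the common 8-bit OCP FP8 case (values are illustrative):

# float_ocp_e4m3 with --sdpa-bit-width 8: 4 exponent bits + 3 mantissa bits + 1 sign bit = 8.
exponent_bits, mantissa_bits, sign_bits = 4, 3, 1
assert exponent_bits + mantissa_bits == 8 - 1
assert exponent_bits + mantissa_bits + sign_bits == 8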
parser.add_argument(
'--sdpa-quant-granularity',
type=str,
default='per_tensor',
choices=['per_tensor'],
help=
'Granularity for scales/zero-point of scaled dot product attention. Default: %(default)s.')
parser.add_argument(
'--sdpa-scale-type',
type=str,
default='static',
choices=['static', 'dynamic'],
help=
'Whether to do static or dynamic scaled dot product attention quantization. Default: %(default)s.'
)
parser.add_argument(
'--quant-blacklist',
type=str,
default=['time_emb'],
nargs='*',
metavar='NAME',
help='A list of module names to exclude from quantization. Default: %(default)s')
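A minimal reproduction of how the new `--quant-blacklist` option parses, assuming only this option on a fresh parser (`nargs='*'` collects zero or more space-separated names into a list):

import argparse

# Mirrors the option added above: the default is ['time_emb'], and any number of
# module-name tokens can be supplied after the flag.
p = argparse.ArgumentParser()
p.add_argument('--quant-blacklist', type=str, default=['time_emb'], nargs='*', metavar='NAME')

assert p.parse_args([]).quant_blacklist == ['time_emb']
assert p.parse_args(['--quant-blacklist', 'time_emb', 'conv_in']).quant_blacklist == ['time_emb', 'conv_in']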
add_bool_arg(
parser,
'quantize-weight-zero-point',
@@ -806,6 +938,11 @@ def input_zp_stats_type():
'quantize-input-zero-point',
default=False,
help='Quantize input zero-point. Default: Enabled')
add_bool_arg(
parser,
'quantize-sdpa-zero-point',
default=False,
help='Quantize scaled dot product attention zero-point. Default: %(default)s')
add_bool_arg(
parser, 'export-cpu-float32', default=False, help='Export FP32 on CPU. Default: Disabled')
add_bool_arg(
Expand All @@ -823,7 +960,6 @@ def input_zp_stats_type():
'dry-run',
default=False,
help='Generate a quantized model without any calibration. Default: Disabled')
add_bool_arg(parser, 'quantize-sdp', default=False, help='Quantize SDP. Default: Disabled')
add_bool_arg(
parser,
'override-conv-quant-config',