Feat (examples/sdxl): Updates to SDXL entry-point #1020

Merged · 14 commits · Sep 12, 2024
Changes from 7 commits
150 changes: 120 additions & 30 deletions src/brevitas_examples/stable_diffusion/main.py
@@ -222,7 +222,7 @@ def main(args):
blacklist = []
non_blacklist = dict()
for name, _ in pipe.unet.named_modules():
if 'time_emb' in name:
if any(map(lambda x: x in name, args.quant_blacklist)):
blacklist.append(name)
else:
if isinstance(_, (torch.nn.Linear, torch.nn.Conv2d)):
@@ -316,6 +316,29 @@ def input_zp_stats_type():

input_kwargs['zero_point_stats_impl'] = input_zp_stats_type

sdpa_kwargs = dict()
if args.sdpa_scale_stats_op == 'minmax':

@value
def sdpa_scale_stats_type():
if args.sdpa_quant_type == 'asym':
sdpa_scaling_stats_op = StatsOp.MIN_MAX
else:
sdpa_scaling_stats_op = StatsOp.MAX
return sdpa_scaling_stats_op

sdpa_kwargs['scaling_stats_op'] = sdpa_scale_stats_type

if args.sdpa_zp_stats_op == 'minmax':

@value
def sdpa_zp_stats_type():
if args.sdpa_quant_type == 'asym':
zero_point_stats_impl = NegativeMinOrZero
return zero_point_stats_impl

sdpa_kwargs['zero_point_stats_impl'] = sdpa_zp_stats_type

print("Applying model quantization...")
quantizers = generate_quantizers(
dtype=dtype,
@@ -360,29 +383,29 @@ def input_zp_stats_type():
if args.quantize_sdp:
assert args.share_qkv_quant, "Currently SDPA quantization is supported only with shared QKV quantization"
# TODO: reformat this
float_sdpa_quantizers = generate_quantizers(
sdpa_quantizers = generate_quantizers(
dtype=dtype,
device=args.device,
weight_bit_width=weight_bit_width,
weight_quant_format='float_ocp_e4m3',
weight_quant_type='sym',
weight_param_method=args.weight_param_method,
weight_scale_precision=args.weight_scale_precision,
weight_quant_granularity=args.weight_quant_granularity,
weight_group_size=args.weight_group_size,
quantize_weight_zero_point=args.quantize_weight_zero_point,
quantize_input_zero_point=args.quantize_input_zero_point,
input_bit_width=args.linear_output_bit_width,
input_quant_format='float_ocp_e4m3',
input_scale_type=args.input_scale_type,
input_scale_precision=args.input_scale_precision,
input_param_method=args.input_param_method,
input_quant_type='sym',
input_quant_granularity=args.input_quant_granularity,
input_kwargs=input_kwargs)
weight_bit_width=args.sdpa_bit_width,
weight_quant_format=args.sdpa_quant_format,
weight_quant_type=args.sdpa_quant_type,
weight_param_method=args.sdpa_param_method,
weight_scale_precision=args.sdpa_scale_precision,
weight_quant_granularity=args.sdpa_quant_granularity,
weight_group_size=32, # Not used, since args.sdpa_quant_granularity == 'per_tensor'
quantize_weight_zero_point=args.quantize_sdpa_zero_point,
quantize_input_zero_point=args.quantize_sdpa_zero_point,
input_bit_width=args.sdpa_bit_width,
input_quant_format=args.sdpa_quant_format,
input_scale_type=args.sdpa_scale_type,
input_scale_precision=args.sdpa_scale_precision,
input_param_method=args.sdpa_param_method,
input_quant_type=args.sdpa_quant_type,
input_quant_granularity=args.sdpa_quant_granularity,
input_kwargs=sdpa_kwargs)
# We generate all quantizers, but we are only interested in activation quantization for
# the output of softmax and the output of QKV
input_quant = float_sdpa_quantizers[0]
input_quant = sdpa_quantizers[0]
rewriter = ModuleToModuleByClass(
Attention,
QuantAttention,
@@ -400,11 +423,11 @@ def input_zp_stats_type():

if args.override_conv_quant_config:
print(
f"Overriding Conv2d quantization to weights: {float_sdpa_quantizers[1]}, inputs: {float_sdpa_quantizers[2]}"
f"Overriding Conv2d quantization to weights: {sdpa_quantizers[1]}, inputs: {sdpa_quantizers[2]}"
)
conv_qkwargs = layer_map[torch.nn.Conv2d][1]
conv_qkwargs['input_quant'] = float_sdpa_quantizers[2]
conv_qkwargs['weight_quant'] = float_sdpa_quantizers[1]
conv_qkwargs['input_quant'] = sdpa_quantizers[2]
conv_qkwargs['weight_quant'] = sdpa_quantizers[1]
layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs)

pipe.unet = layerwise_quantize(
@@ -435,7 +458,7 @@ def input_zp_stats_type():
pipe = pipe.to(args.device)
elif not args.dry_run:
if (args.linear_input_bit_width > 0 or args.conv_input_bit_width > 0 or
args.linear_output_bit_width > 0) and args.input_scale_type == 'static':
args.sdpa_bit_width > 0) and args.input_scale_type == 'static':
print("Applying activation calibration")
with torch.no_grad(), calibration_mode(pipe.unet):
run_val_inference(
@@ -520,7 +543,13 @@ def input_zp_stats_type():
if args.use_mlperf_inference:
print(f"Computing accuracy with MLPerf pipeline")
compute_mlperf_fid(
args.model, args.path_to_coco, pipe, args.prompt, output_dir, not args.vae_fp16_fix)
args.model,
args.path_to_coco,
pipe,
args.prompt,
output_dir,
args.device,
not args.vae_fp16_fix)
else:
print(f"Computing accuracy on default prompt")
testing_prompts = TESTING_PROMPTS[:args.prompt]
@@ -701,11 +730,6 @@ def input_zp_stats_type():
type=int,
default=0,
help='Input bit width. Default: 0 (not quantized).')
parser.add_argument(
'--linear-output-bit-width',
type=int,
default=0,
help='Input bit width. Default: 0 (not quantized).')
parser.add_argument(
'--weight-param-method',
type=str,
@@ -791,6 +815,67 @@ def input_zp_stats_type():
type=int,
default=16,
help='Group size for per_group weight quantization. Default: 16.')
parser.add_argument(
'--sdpa-bit-width',
type=int,
default=0,
help='Scaled dot product attention bit width. Default: 0 (not quantized).')
parser.add_argument(
'--sdpa-param-method',
type=str,
default='stats',
choices=['stats', 'mse'],
help='How scales/zero-point are determined for scaled dot product attention. Default: %(default)s.')
parser.add_argument(
'--sdpa-scale-stats-op',
type=str,
default='minmax',
choices=['minmax', 'percentile'],
help='Define what statistics op to use for scaled dot product attention scale. Default: %(default)s.')
parser.add_argument(
'--sdpa-zp-stats-op',
type=str,
default='minmax',
choices=['minmax', 'percentile'],
help='Define what statistics op to use for scaled dot product attention zero point. Default: %(default)s.')
parser.add_argument(
'--sdpa-scale-precision',
type=str,
default='float_scale',
choices=['float_scale', 'po2_scale'],
help='Whether the scaled dot product attention scale is a float value or a po2. Default: %(default)s.')
parser.add_argument(
'--sdpa-quant-type',
type=str,
default='asym',
choices=['sym', 'asym'],
help='Scaled dot product attention quantization type. Default: %(default)s.')
parser.add_argument(
'--sdpa-quant-format',
type=quant_format_validator,
default='int',
help=
'Scaled dot product attention quantization format. Either int or eXmY, with X+Y==input_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: %(default)s.'
)
parser.add_argument(
'--sdpa-quant-granularity',
type=str,
default='per_tensor',
choices=['per_tensor'],
help='Granularity for scales/zero-point of scaled dot product attention. Default: %(default)s.')
parser.add_argument(
'--sdpa-scale-type',
type=str,
default='static',
choices=['static', 'dynamic'],
help='Whether to do static or dynamic scaled dot product attention quantization. Default: %(default)s.')
parser.add_argument(
'--quant-blacklist',
type=str,
default=['time_emb'],
nargs='*',
metavar='NAME',
help='A list of module names to exclude from quantization. Default: %(default)s')
add_bool_arg(
parser,
'quantize-weight-zero-point',
@@ -806,6 +891,11 @@ def input_zp_stats_type():
'quantize-input-zero-point',
default=False,
help='Quantize input zero-point. Default: Enabled')
add_bool_arg(
parser,
'quantize-sdpa-zero-point',
default=False,
help='Quantize scaled dot product attention zero-point. Default: %(default)s')
add_bool_arg(
parser, 'export-cpu-float32', default=False, help='Export FP32 on CPU. Default: Disabled')
add_bool_arg(
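Note that the new `--quant-blacklist` flag matches by substring rather than by exact module name: a module is excluded if any blacklist entry appears anywhere in its dotted name, exactly as the `any(map(lambda x: x in name, args.quant_blacklist))` check above does. A minimal standalone sketch of that behaviour (the helper and the toy modules are illustrative, not part of the PR):

```python
import torch

def split_by_blacklist(model: torch.nn.Module, quant_blacklist):
    """Partition module names the way the updated entry point does:
    a module is skipped if ANY blacklist entry is a substring of its name."""
    blacklist, quantizable = [], []
    for name, module in model.named_modules():
        if any(pattern in name for pattern in quant_blacklist):
            blacklist.append(name)
        elif isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            quantizable.append(name)
    return blacklist, quantizable

# Substring matching: 'time_emb' also catches e.g. 'time_embedding.linear_1'.
toy = torch.nn.ModuleDict({
    'time_embedding': torch.nn.Linear(4, 4),
    'down_block': torch.nn.Conv2d(3, 3, 1),
})
print(split_by_blacklist(toy, ['time_emb']))  # (['time_embedding'], ['down_block'])
```

Because the argument is declared with `nargs='*'`, passing `--quant-blacklist` with no names yields an empty list and clears the default `['time_emb']` exclusion.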
15 changes: 9 additions & 6 deletions src/brevitas_examples/stable_diffusion/sd_quant/export.py
@@ -40,18 +40,18 @@ def handle_quant_param(layer, layer_dict):
layer_dict['output_scale_shape'] = output_scale.shape
layer_dict['input_scale'] = input_scale.numpy().tolist()
layer_dict['input_scale_shape'] = input_scale.shape
layer_dict['input_zp'] = input_zp.numpy().tolist()
layer_dict['input_zp'] = input_zp.to(torch.float32).cpu().numpy().tolist()
layer_dict['input_zp_shape'] = input_zp.shape
layer_dict['input_zp_dtype'] = str(torch.int8)
layer_dict['input_zp_dtype'] = str(input_zp.dtype)
layer_dict['weight_scale'] = weight_scale.cpu().numpy().tolist()
nelems = layer.weight.shape[0]
weight_scale_shape = [nelems] + [1] * (layer.weight.data.ndim - 1)
layer_dict['weight_scale_shape'] = weight_scale_shape
if torch.sum(weight_zp) != 0.:
if torch.sum(weight_zp.to(torch.float32)) != 0.:
weight_zp = weight_zp - 128. # apply offset to have signed z
layer_dict['weight_zp'] = weight_zp.cpu().numpy().tolist()
layer_dict['weight_zp'] = weight_zp.to(torch.float32).cpu().numpy().tolist()
layer_dict['weight_zp_shape'] = weight_scale_shape
layer_dict['weight_zp_dtype'] = str(torch.int8)
layer_dict['weight_zp_dtype'] = str(weight_zp.dtype)
return layer_dict


@@ -63,14 +63,17 @@ def export_quant_params(pipe, output_dir, export_vae=False):
vae_output_path = os.path.join(output_dir, 'vae.safetensors')
print(f"Saving vae to {vae_output_path} ...")
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
export_manager = StdQCDQONNXManager
export_manager.change_weight_export(
export_weight_q_node=True) # We're exporting FP weights + quantization parameters
quant_params = dict()
state_dict = pipe.unet.state_dict()
state_dict = {k: v for (k, v) in state_dict.items() if 'tensor_quant' not in k}
state_dict = {k: v for (k, v) in state_dict.items() if not k.endswith('.scale.weight')}
state_dict = {k.replace('.layer.', '.'): v for (k, v) in state_dict.items()}

handled_quant_layers = set()
with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, StdQCDQONNXManager):
with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager):
for name, module in pipe.unet.named_modules():
if isinstance(module, EqualizedModule):
if id(module.layer) in handled_quant_layers:
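The zero-point handling above avoids two weaknesses of the previous export path: `Tensor.numpy()` raises for dtypes NumPy cannot represent (such as bfloat16), and the serialized dtype was hard-coded to `torch.int8` even when the actual zero-point tensor used a different dtype. A small sketch of the pattern, with an illustrative helper name and tensor (not part of the PR):

```python
import torch

def zp_to_serializable(zp: torch.Tensor) -> dict:
    """Serialize a zero-point tensor as plain Python data plus its real dtype,
    casting through float32 so .numpy() also works for bfloat16 inputs."""
    return {
        'zp': zp.to(torch.float32).cpu().numpy().tolist(),
        'zp_shape': list(zp.shape),
        'zp_dtype': str(zp.dtype),  # recorded from the tensor, not hard-coded
    }

# torch.zeros(4, dtype=torch.bfloat16).numpy() would raise a TypeError;
# the float32 round-trip below sidesteps that while keeping the original dtype.
print(zp_to_serializable(torch.zeros(4, dtype=torch.bfloat16)))
# {'zp': [0.0, 0.0, 0.0, 0.0], 'zp_shape': [4], 'zp_dtype': 'torch.bfloat16'}
```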