
Feat (brevitas_examples/sdxl): inference_mode + compile (#1133)
Giuseppe5 authored Dec 17, 2024
1 parent 09f1371 commit 48efcf6
Showing 2 changed files with 79 additions and 15 deletions.
59 changes: 53 additions & 6 deletions src/brevitas_examples/stable_diffusion/README.md
@@ -70,14 +70,14 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--gptq | --no-gptq] [--bias-correction | --no-bias-correction]
[--dtype {float32,float16,bfloat16}]
[--attention-slicing | --no-attention-slicing]
[--compile | --no-compile]
[--export-target {,onnx,params_only}]
[--export-weight-q-node | --no-export-weight-q-node]
[--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH]
[--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH]
[--conv-input-bit-width CONV_INPUT_BIT_WIDTH]
[--act-eq-alpha ACT_EQ_ALPHA]
[--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH]
[--linear-output-bit-width LINEAR_OUTPUT_BIT_WIDTH]
[--weight-param-method {stats,mse}]
[--input-param-method {stats,mse}]
[--input-scale-stats-op {minmax,percentile}]
@@ -92,13 +92,24 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--input-quant-granularity {per_tensor}]
[--input-scale-type {static,dynamic}]
[--weight-group-size WEIGHT_GROUP_SIZE]
[--sdpa-bit-width SDPA_BIT_WIDTH]
[--sdpa-param-method {stats,mse}]
[--sdpa-scale-stats-op {minmax,percentile}]
[--sdpa-zp-stats-op {minmax,percentile}]
[--sdpa-scale-precision {float_scale,po2_scale}]
[--sdpa-quant-type {sym,asym}]
[--sdpa-quant-format SDPA_QUANT_FORMAT]
[--sdpa-quant-granularity {per_tensor}]
[--sdpa-scale-type {static,dynamic}]
[--quant-blacklist [NAME ...]]
[--quantize-weight-zero-point | --no-quantize-weight-zero-point]
[--exclude-blacklist-act-eq | --no-exclude-blacklist-act-eq]
[--quantize-input-zero-point | --no-quantize-input-zero-point]
[--quantize-sdpa-zero-point | --no-quantize-sdpa-zero-point]
[--export-cpu-float32 | --no-export-cpu-float32]
[--use-mlperf-inference | --no-use-mlperf-inference]
[--use-negative-prompts | --no-use-negative-prompts]
[--dry-run | --no-dry-run] [--quantize-sdp | --no-quantize-sdp]
[--dry-run | --no-dry-run]
[--override-conv-quant-config | --no-override-conv-quant-config]
[--vae-fp16-fix | --no-vae-fp16-fix]
[--share-qkv-quant | --no-share-qkv-quant]
@@ -159,6 +170,8 @@ options:
--attention-slicing Enable attention slicing. Default: Disabled
--no-attention-slicing
Disable attention slicing. Default: Disabled
--compile Enable compilation during inference. Default: Disabled
--no-compile Disable compilation during inference. Default: Disabled
--export-target {,onnx,params_only}
Target export flow.
--export-weight-q-node
@@ -177,8 +190,6 @@ options:
Alpha for activation equalization. Default: 0.9
--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH
Input bit width. Default: 0 (not quantized).
--linear-output-bit-width LINEAR_OUTPUT_BIT_WIDTH
Output bit width. Default: 0 (not quantized).
--weight-param-method {stats,mse}
How scales/zero-point are determined. Default: stats.
--input-param-method {stats,mse}
@@ -221,6 +232,38 @@ options:
--weight-group-size WEIGHT_GROUP_SIZE
Group size for per_group weight quantization. Default:
16.
--sdpa-bit-width SDPA_BIT_WIDTH
Scaled dot product attention bit width. Default: 0
(not quantized).
--sdpa-param-method {stats,mse}
How scales/zero-point are determined for scaled dot
product attention. Default: stats.
--sdpa-scale-stats-op {minmax,percentile}
Define what statistics op to use for scaled dot
product attention scale. Default: minmax.
--sdpa-zp-stats-op {minmax,percentile}
Define what statistics op to use for scaled dot
product attention zero point. Default: minmax.
--sdpa-scale-precision {float_scale,po2_scale}
Whether the scaled dot product attention scale is a
float value or a po2. Default: float_scale.
--sdpa-quant-type {sym,asym}
Scaled dot product attention quantization type.
Default: sym.
--sdpa-quant-format SDPA_QUANT_FORMAT
Scaled dot product attention quantization format.
Either int or eXmY, with X+Y==input_bit_width-1. It's
possible to add float_ocp_ or float_fnuz_ before the
exponent/mantissa bitwidth. Default: int.
--sdpa-quant-granularity {per_tensor}
Granularity for scales/zero-point of scaled dot
product attention. Default: per_tensor.
--sdpa-scale-type {static,dynamic}
Whether to do static or dynamic scaled dot product
attention quantization. Default: static.
--quant-blacklist [NAME ...]
A list of module names to exclude from quantization.
Default: ['time_emb']
--quantize-weight-zero-point
Enable weight zero-point quantization. Default: Enabled
--no-quantize-weight-zero-point
@@ -235,6 +278,12 @@ options:
Enable input zero-point quantization. Default: Enabled
--no-quantize-input-zero-point
Disable input zero-point quantization. Default: Enabled
--quantize-sdpa-zero-point
Enable scaled dot product attention zero-point
quantization. Default: False
--no-quantize-sdpa-zero-point
Disable scaled dot product attention zero-point
quantization. Default: False
--export-cpu-float32 Enable FP32 export on CPU. Default: Disabled
--no-export-cpu-float32
Disable FP32 export on CPU. Default: Disabled
@@ -254,8 +303,6 @@ options:
calibration. Default: Disabled
--no-dry-run Disable generating a quantized model without any
calibration. Default: Disabled
--quantize-sdp Enable SDP quantization. Default: Disabled
--no-quantize-sdp Disable SDP quantization. Default: Disabled
--override-conv-quant-config
Enable quantizing convolutions in the same way as SDP
(i.e., FP8). Default: Disabled
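The --sdpa-quant-format help above describes the eXmY convention, where the exponent and mantissa widths must satisfy X + Y == bit_width - 1 (the remaining bit is the sign), optionally prefixed by float_ocp_ or float_fnuz_. A minimal sketch of how such a format string might be checked follows; parse_quant_format is a hypothetical stand-in for illustration, not the repository's quant_format_validator.

import re

def parse_quant_format(fmt: str, bit_width: int) -> dict:
    # 'int' needs no further validation.
    if fmt == "int":
        return {"type": "int"}
    # Optional float_ocp_/float_fnuz_ prefix, then eXmY.
    match = re.fullmatch(r"(float_ocp_|float_fnuz_)?e(\d+)m(\d+)", fmt)
    if match is None:
        raise ValueError(f"Unrecognized quant format: {fmt}")
    exp_bits, mant_bits = int(match.group(2)), int(match.group(3))
    # Exponent + mantissa bits must leave exactly one bit for the sign.
    if exp_bits + mant_bits != bit_width - 1:
        raise ValueError(f"e{exp_bits}m{mant_bits} does not fit in {bit_width} bits")
    return {"type": "float", "exponent": exp_bits, "mantissa": mant_bits}

# Example: OCP FP8 e4m3 is valid at 8 bits (4 + 3 == 8 - 1).
print(parse_quant_format("float_ocp_e4m3", 8))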
35 changes: 26 additions & 9 deletions src/brevitas_examples/stable_diffusion/main.py
@@ -43,6 +43,7 @@
from brevitas_examples.common.parse_utils import add_bool_arg
from brevitas_examples.common.parse_utils import quant_format_validator
from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager
from brevitas_examples.llm.main import quant_inference_mode
from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import compute_mlperf_fid
from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE
from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE
@@ -247,7 +248,6 @@ def main(args):
        else:
            non_blacklist[name_to_add] += 1
    print(f"Blacklisted layers: {set(blacklist)}")
    print(f"Non blacklisted layers: {set(non_blacklist.keys())}")

    # Make sure all LoRA layers are fused first, otherwise raise an error
    for m in pipe.unet.modules():
@@ -610,14 +610,29 @@ def sdpa_zp_stats_type():
    # with brevitas_proxy_inference_mode(pipe.unet):
    if args.use_mlperf_inference:
        print(f"Computing accuracy with MLPerf pipeline")
        compute_mlperf_fid(
            args.model,
            args.path_to_coco,
            pipe,
            args.prompt,
            output_dir,
            args.device,
            not args.vae_fp16_fix)
        with torch.no_grad(), quant_inference_mode(pipe.unet):
            # Perform a single forward pass before optionally compiling
            run_val_inference(
                pipe,
                args.resolution,
                [calibration_prompts[0]],  # We need a list
                test_seeds,
                args.device,
                dtype,
                total_steps=1,
                use_negative_prompts=args.use_negative_prompts,
                test_latents=latents,
                guidance_scale=args.guidance_scale)
            if args.compile:
                pipe.unet = torch.compile(pipe.unet)
            compute_mlperf_fid(
                args.model,
                args.path_to_coco,
                pipe,
                args.prompt,
                output_dir,
                args.device,
                not args.vae_fp16_fix)
    else:
        print(f"Computing accuracy on default prompt")
        testing_prompts = TESTING_PROMPTS[:args.prompt]
@@ -734,6 +749,8 @@ def sdpa_zp_stats_type():
        'attention-slicing',
        default=False,
        help='Enable attention slicing. Default: Disabled')
    add_bool_arg(
        parser, 'compile', default=False, help='Compile during inference. Default: Disabled')
    parser.add_argument(
        '--export-target',
        type=str,
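The main.py hunk above pairs quant_inference_mode with torch.compile: one warm-up forward pass runs inside the inference-mode context so the quantized UNet settles into its inference representation, and only then is it compiled. A condensed sketch of that pattern, assuming Brevitas is installed; evaluate_with_compile and run_once are illustrative stand-ins, not functions from this repository.

import torch
from brevitas_examples.llm.main import quant_inference_mode

def evaluate_with_compile(pipe, run_once, compile_unet=False):
    with torch.no_grad(), quant_inference_mode(pipe.unet):
        # Single warm-up pass before optionally compiling.
        run_once(pipe)
        if compile_unet:
            # Compile only after the quantized graph has been exercised once.
            pipe.unet = torch.compile(pipe.unet)
        # The full evaluation (e.g. compute_mlperf_fid) would run here,
        # still inside the inference-mode context.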
