diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 82b2b9e98..9b20e0a7b 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -520,7 +520,7 @@ def input_zp_stats_type(): if args.use_mlperf_inference: print(f"Computing accuracy with MLPerf pipeline") compute_mlperf_fid( - args.model, args.path_to_coco, pipe, args.prompt, output_dir, not args.vae_fp16_fix) + args.model, args.path_to_coco, pipe, args.prompt, output_dir, args.device, not args.vae_fp16_fix) else: print(f"Computing accuracy on default prompt") testing_prompts = TESTING_PROMPTS[:args.prompt]