diff --git a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py index 3403cfd9cb06f..bd45b452565d9 100644 --- a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py +++ b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py @@ -116,6 +116,7 @@ def run_sdxl_unet_rocm_benchmark(rocm_chip): # iree benchmark command for full sdxl pipeline return run_iree_command(exec_args) + def run_sdxl_punet_int8_fp16_rocm_benchmark(rocm_chip): exec_args = [ "iree-benchmark-module", @@ -136,6 +137,7 @@ def run_sdxl_punet_int8_fp16_rocm_benchmark(rocm_chip): # iree benchmark command for full sdxl pipeline return run_iree_command(exec_args) + def run_sdxl_prompt_encoder_rocm_benchmark(rocm_chip): exec_args = [ "iree-benchmark-module", @@ -366,7 +368,11 @@ def test_sdxl_rocm_benchmark( ] if rocm_chip == "gfx942": mean_time_rows.append( - ["Punet F16", f"{benchmark_punet_int8_fp16_mean_time}", f"{goldentime_rocm_punet_int8_fp16}"] + [ + "Punet F16", + f"{benchmark_punet_int8_fp16_mean_time}", + f"{goldentime_rocm_punet_int8_fp16}", + ] ) # Create dispatch count table's header and rows @@ -382,7 +388,11 @@ def test_sdxl_rocm_benchmark( ] if rocm_chip == "gfx942": dispatch_count_rows.append( - ["Punet F16", f"{punet_int8_fp16_dispatch_count}", f"{goldendispatch_rocm_punet_int8_fp16}"] + [ + "Punet F16", + f"{punet_int8_fp16_dispatch_count}", + f"{goldendispatch_rocm_punet_int8_fp16}", + ] ) # Create binary size table's header and rows @@ -398,7 +408,11 @@ def test_sdxl_rocm_benchmark( ] if rocm_chip == "gfx942": binary_size_rows.append( - ["Punet F16", f"{punet_int8_fp16_binary_size}", f"{goldensize_rocm_punet_int8_fp16}"] + [ + "Punet F16", + f"{punet_int8_fp16_binary_size}", + f"{goldensize_rocm_punet_int8_fp16}", + ] ) # Create mean time table using tabulate diff --git a/experimental/benchmarks/sdxl/conftest.py b/experimental/benchmarks/sdxl/conftest.py index 369b85103d452..43d6be07b6a23 100644 --- a/experimental/benchmarks/sdxl/conftest.py +++ b/experimental/benchmarks/sdxl/conftest.py @@ -106,10 +106,12 @@ def goldentime_rocm_e2e(request): def goldentime_rocm_unet(request): return request.config.getoption("--goldentime-rocm-unet-ms") + @pytest.fixture def goldentime_rocm_punet_int8_fp16(request): return request.config.getoption("--goldentime-rocm-punet-int8-fp16-ms") + @pytest.fixture def goldentime_rocm_clip(request): return request.config.getoption("--goldentime-rocm-clip-ms") @@ -124,10 +126,12 @@ def goldentime_rocm_vae(request): def goldendispatch_rocm_unet(request): return request.config.getoption("--goldendispatch-rocm-unet") + @pytest.fixture def goldendispatch_rocm_punet_int8_fp16(request): return request.config.getoption("--goldendispatch-rocm-punet-int8-fp16") + @pytest.fixture def goldendispatch_rocm_clip(request): return request.config.getoption("--goldendispatch-rocm-clip") @@ -142,10 +146,12 @@ def goldendispatch_rocm_vae(request): def goldensize_rocm_unet(request): return request.config.getoption("--goldensize-rocm-unet-bytes") + @pytest.fixture def goldensize_rocm_punet_int8_fp16(request): return request.config.getoption("--goldensize-rocm-punet-int8-fp16-bytes") + @pytest.fixture def goldensize_rocm_clip(request): return request.config.getoption("--goldensize-rocm-clip-bytes")