From bfd9d56602e7e3b06ecce0370a53cb4606020ef0 Mon Sep 17 00:00:00 2001 From: saienduri Date: Fri, 8 Nov 2024 15:48:56 -0600 Subject: [PATCH 01/13] initial testing of punet Signed-off-by: saienduri Signed-off-by: saienduri --- .../attention_and_matmul_spec_punet.mlir | 309 ++++++++++++++++++ .../benchmarks/sdxl/benchmark_sdxl_rocm.py | 84 ++++- experimental/benchmarks/sdxl/conftest.py | 29 ++ .../shark-test-suite-models/sdxl/test_unet.py | 3 +- 4 files changed, 423 insertions(+), 2 deletions(-) create mode 100644 build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir diff --git a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir new file mode 100644 index 000000000000..c92ddd0f9cde --- /dev/null +++ b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir @@ -0,0 +1,309 @@ +module attributes { transform.with_named_sequence } { +//===----------------------------------------------------------------------===// +// Tuning infra +//===----------------------------------------------------------------------===// + + transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly}, + %config: !transform.any_param {transform.readonly}) { + transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param + // transform.print %op {name = "Applied"} : !transform.any_op + transform.yield + } + + transform.named_sequence @apply_attn_op_config(%attention: !transform.any_op {transform.readonly}, + %config: !transform.any_param {transform.readonly}, + %decomposition_config: !transform.any_param {transform.readonly}) { + transform.annotate %attention "compilation_info" = %config : !transform.any_op, !transform.any_param + transform.annotate %attention "decomposition_config" = %decomposition_config : !transform.any_op, !transform.any_param + // transform.print %attention {name = "Applied attention 
config"} : !transform.any_op + transform.yield + } + + transform.named_sequence @match_broadcast_rhs_mmt_i8_i8_i32( + %root: !transform.any_op {transform.readonly}) -> (!transform.any_op) { + transform.match.operation_name %root ["linalg.generic"] : !transform.any_op + // transform.print %root {name = "Generic"} : !transform.any_op + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root { + ^bb0(%lhs: tensor, %rhs: tensor, %out: tensor): + %20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel", "reduction"]} + ins(%lhs, %rhs : tensor, tensor) outs(%out : tensor) { + ^bb0(%in: i8, %in_0: i8, %acc: i32): + %22 = arith.extsi %in : i8 to i32 + %23 = arith.extsi %in_0 : i8 to i32 + %24 = arith.muli %22, %23 : i32 + %25 = arith.addi %acc, %24 : i32 + linalg.yield %25 : i32 + } -> tensor + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + transform.yield %root : !transform.any_op + } + +//===----------------------------------------------------------------------===// +// Attention tuning +//===----------------------------------------------------------------------===// + +transform.named_sequence @match_attention_f16(%attention: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param, !transform.any_param) { + transform.match.operation_name %attention ["iree_linalg_ext.attention"] : !transform.any_op + %in0 = transform.get_operand %attention[0] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %in0 = tensor : !transform.any_value + + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 128, 0, 0, 0], reduction=[0, 0, 0, 0, 0, 32], promote_operands = [1, 2]}>, + translation_info = #iree_codegen.translation_info> + 
-> !transform.any_param + + %decomposition_config = transform.param.constant { + qk_attrs = {attention_qk_matmul, + lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.virtual_mma_layout, + subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [1] }>}, + pv_attrs = {attention_pv_matmul, + lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [1] }>} + } -> !transform.any_param + + transform.yield %attention, %config, %decomposition_config : !transform.any_op, !transform.any_param, !transform.any_param + } + +transform.named_sequence @match_attention_f8(%attention: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param, !transform.any_param) { + transform.match.operation_name %attention ["iree_linalg_ext.attention"] : !transform.any_op + %in0 = transform.get_operand %attention[0] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %in0 = tensor : !transform.any_value + + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 0, 0, 0], reduction=[0, 0, 0, 0, 32, 0], promote_operands = [1, 2]}>, + translation_info = #iree_codegen.translation_info> + -> !transform.any_param + + %decomposition_config = transform.param.constant { + qk_attrs = {attention_qk_matmul, + lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [1] }>}, + pv_attrs = {attention_pv_matmul, + lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.virtual_mma_layout, + subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [1] }>} + } -> !transform.any_param + + transform.yield %attention, %config, %decomposition_config : !transform.any_op, !transform.any_param, !transform.any_param + } + +// TUNING_SPEC_BEGIN DO NOT REMOVE + 
+//===----------------------------------------------------------------------===// +// Matmul tuning +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Convolution tuning +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Batch matmul tuning +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Broadcast rhs mmt tuning +//===----------------------------------------------------------------------===// + + transform.named_sequence @match_broadcast_rhs_mmt_Bx1024x10240x1280(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<10240x1280xi8> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], + mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 4, subgroup_n_count = 2, + reduction = [0, 0, 0, 128], + workgroup = [1, 128, 320, 0]}>, + translation_info = #iree_codegen.translation_info}> + > -> !transform.any_param + transform.yield %generic, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_broadcast_rhs_mmt_Bx1024x1280x1280(%generic: !transform.any_op {transform.readonly}) -> 
(!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<1280x1280xi8> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], + mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 2, subgroup_n_count = 2, + reduction = [0, 0, 0, 128], + workgroup = [1, 64, 160, 0]}>, + translation_info = #iree_codegen.translation_info>}> + > -> !transform.any_param + transform.yield %generic, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x5120x640(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<5120x640xi8> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], + mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 2, subgroup_n_count = 4, + reduction = [0, 0, 0, 64], + workgroup = [1, 256, 128, 0]}>, + translation_info = #iree_codegen.translation_info}> + > -> 
!transform.any_param + transform.yield %generic, %config : !transform.any_op, !transform.any_param + } + +//===----------------------------------------------------------------------===// +// Contraction tuning +//===----------------------------------------------------------------------===// + + transform.named_sequence @match_matmul_like_Bx20x1024x64x1280_i8xi8xi32(%cont: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont { + ^bb0(%lhs: tensor, %rhs: tensor<20x64x1280xi8>, %out: tensor): + %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} + ins(%lhs, %rhs : tensor, tensor<20x64x1280xi8>) + outs(%out : tensor) { + ^bb0(%in: i8, %in_0: i8, %acc: i32): + %18 = arith.extsi %in : i8 to i32 + %19 = arith.extsi %in_0 : i8 to i32 + %20 = arith.muli %18, %19 : i32 + %21 = arith.addi %acc, %20 : i32 + linalg.yield %21 : i32 + } -> tensor + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], + mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 2, subgroup_n_count = 2, + reduction = [0, 0, 0, 0, 128], + workgroup = [1, 1, 64, 160, 0]}>, + translation_info = #iree_codegen.translation_info> + }> + > -> !transform.any_param + transform.yield %cont, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_matmul_like_Bx20x64x64x2048_i8xi8xi32(%cont: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont { + ^bb0(%lhs: tensor, %rhs: 
tensor<20x64x2048xi8>, %out: tensor): + %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} + ins(%lhs, %rhs : tensor, tensor<20x64x2048xi8>) + outs(%out : tensor) { + ^bb0(%in: i8, %in_0: i8, %acc: i32): + %18 = arith.extsi %in : i8 to i32 + %19 = arith.extsi %in_0 : i8 to i32 + %20 = arith.muli %18, %19 : i32 + %21 = arith.addi %acc, %20 : i32 + linalg.yield %21 : i32 + } -> tensor + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], + mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 2, subgroup_n_count = 1, + reduction = [0, 0, 0, 0, 128], + workgroup = [1, 1, 32, 320, 0]}>, + translation_info = #iree_codegen.translation_info}> + > -> !transform.any_param + transform.yield %cont, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_matmul_like_Bx10x4096x64x640_i8xi8xi32(%cont: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont { + ^bb0(%lhs: tensor, %rhs: tensor<10x64x640xi8>, %out: tensor): + %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} + ins(%lhs, %rhs : tensor, tensor<10x64x640xi8>) + outs(%out : tensor) { + ^bb0(%in: i8, %in_0: i8, %acc: i32): + %18 = arith.extsi %in : i8 to i32 + %19 = arith.extsi %in_0 : i8 to i32 + %20 = arith.muli %18, %19 : i32 + %21 = arith.addi %acc, %20 : i32 + 
linalg.yield %21 : i32 + } -> tensor + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], + mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 8, subgroup_n_count = 1, + reduction = [0, 0, 0, 0, 64], + workgroup = [1, 1, 256, 64, 0]}>, + translation_info = #iree_codegen.translation_info}> + > -> !transform.any_param + transform.yield %cont, %config : !transform.any_op, !transform.any_param + } + +// TUNING_SPEC_END DO NOT REMOVE + +//===----------------------------------------------------------------------===// +// Entry point +//===----------------------------------------------------------------------===// + + transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {transform.consumed}) { + transform.foreach_match in %variant_op + // Attention. + @match_attention_f16 -> @apply_attn_op_config + , @match_attention_f8 -> @apply_attn_op_config + + // TUNING_MATCH_BEGIN DO NOT REMOVE + + // Matmul. + + // Convolution. + + // Batch matmul. + + // Broadcast rhs mmt. + , @match_broadcast_rhs_mmt_Bx4096x5120x640 -> @apply_op_config + + // Carried over from SPX. + , @match_broadcast_rhs_mmt_Bx1024x10240x1280 -> @apply_op_config + , @match_broadcast_rhs_mmt_Bx1024x1280x1280 -> @apply_op_config + + // Contraction. 
+ , @match_matmul_like_Bx20x1024x64x1280_i8xi8xi32 -> @apply_op_config + , @match_matmul_like_Bx10x4096x64x640_i8xi8xi32 -> @apply_op_config + , @match_matmul_like_Bx20x64x64x2048_i8xi8xi32 -> @apply_op_config + + // TUNING_MATCH_END DO NOT REMOVE + : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} //// module diff --git a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py index 879a0d40dff3..d5fb0da0cc33 100644 --- a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py +++ b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py @@ -16,13 +16,15 @@ vmfb_dir = os.getenv("TEST_OUTPUT_ARTIFACTS", default=Path.cwd()) benchmark_dir = os.path.dirname(os.path.realpath(__file__)) -artifacts_dir = os.getenv("IREE_TEST_FILES", default=Path.cwd()) + "/artifacts" +artifacts_dir = os.getenv("IREE_TEST_FILES", default=Path.cwd()) / "artifacts" artifacts_dir = Path(os.path.expanduser(artifacts_dir)).resolve() prompt_encoder_dir = f"{artifacts_dir}/sdxl_clip" scheduled_unet_dir = f"{artifacts_dir}/sdxl_unet_fp16" +punet_int8_fp16_dir = f"{artifacts_dir}/sdxl_punet_int8_fp16" vae_decode_dir = f"{artifacts_dir}/sdxl_vae" prompt_encoder_dir_compile = f"{vmfb_dir}/sdxl_clip_vmfbs" scheduled_unet_dir_compile = f"{vmfb_dir}/sdxl_unet_fp16_vmfbs" +punet_int8_fp16_dir_compile = f"{vmfb_dir}/sdxl_punet_int8_fp16_vmfbs" vae_decode_dir_compile = f"{vmfb_dir}/sdxl_vae_vmfbs" @@ -114,6 +116,25 @@ def run_sdxl_unet_rocm_benchmark(rocm_chip): # iree benchmark command for full sdxl pipeline return run_iree_command(exec_args) +def run_sdxl_punet_int8_fp16_rocm_benchmark(rocm_chip): + exec_args = [ + "iree-benchmark-module", + f"--device=hip", + "--device_allocator=caching", + f"--module={punet_int8_fp16_dir_compile}/punet.rocm_{rocm_chip}.vmfb", + f"--parameters=model={punet_int8_fp16_dir}/punet_weights.irpa", + "--function=main", + f"--input=1x4x128x128xf16", + f"--input=1xf16", + f"--input=2x64x2048xf16", + 
f"--input=2x1280xf16", + f"--input=2x6xf16", + f"--input=1xf16", + "--benchmark_repetitions=10", + "--benchmark_min_warmup_time=3.0", + ] + # iree benchmark command for full sdxl pipeline + return run_iree_command(exec_args) def run_sdxl_prompt_encoder_rocm_benchmark(rocm_chip): exec_args = [ @@ -196,13 +217,16 @@ def job_summary_process(ret_value, output): def test_sdxl_rocm_benchmark( goldentime_rocm_e2e, goldentime_rocm_unet, + goldentime_rocm_punet_int8_fp16, goldentime_rocm_clip, goldentime_rocm_vae, rocm_chip, goldendispatch_rocm_unet, + goldendispatch_rocm_punet_int8_fp16, goldendispatch_rocm_clip, goldendispatch_rocm_vae, goldensize_rocm_unet, + goldensize_rocm_punet_int8_fp16, goldensize_rocm_clip, goldensize_rocm_vae, ): @@ -244,6 +268,36 @@ def test_sdxl_rocm_benchmark( ) logging.getLogger().info(compilation_line) + if rocm_chip == "gfx942": + # punet int8 f16 attention benchmark + ret_value, output = run_sdxl_punet_int8_fp16_rocm_benchmark(rocm_chip) + benchmark_punet_int8_fp16_mean_time = job_summary_process(ret_value, output) + mean_line = ( + f"Punet F16 Benchmark Time: {str(benchmark_punet_int8_fp16_mean_time)} ms" + f" (golden time {goldentime_rocm_punet_int8_fp16} ms)" + ) + logging.getLogger().info(mean_line) + + # punet int8 f16 compilation stats check + with open(f"{punet_int8_fp16_dir_compile}/compilation_info.json", "r") as file: + comp_stats = json.load(file) + punet_int8_fp16_dispatch_count = int( + comp_stats["stream-aggregate"]["execution"]["dispatch-count"] + ) + compilation_line = ( + f"Punet F16 Dispatch Count: {punet_int8_fp16_dispatch_count}" + f" (golden dispatch count {goldendispatch_rocm_punet_int8_fp16})" + ) + logging.getLogger().info(compilation_line) + + module_path = f"{punet_int8_fp16_dir_compile}/punet.rocm_{rocm_chip}.vmfb" + punet_int8_fp16_binary_size = Path(module_path).stat().st_size + compilation_line = ( + f"Punet F16 Binary Size: {punet_int8_fp16_binary_size} bytes" + f" (golden binary size 
{goldensize_rocm_punet_int8_fp16} bytes)" + ) + logging.getLogger().info(compilation_line) + # prompt encoder benchmark ret_value, output = run_sdxl_prompt_encoder_rocm_benchmark(rocm_chip) benchmark_clip_mean_time = job_summary_process(ret_value, output) @@ -310,6 +364,10 @@ def test_sdxl_rocm_benchmark( ["Prompt Encoder", f"{benchmark_clip_mean_time}", f"{goldentime_rocm_clip}"], ["VAE Decode", f"{benchmark_vae_mean_time}", f"{goldentime_rocm_vae}"], ] + if rocm_chip == "gfx942": + mean_time_rows.append( + ["Punet F16", f"{benchmark_punet_int8_fp16_mean_time}", f"{goldentime_rocm_punet_int8_fp16}"] + ) # Create dispatch count table's header and rows dispatch_count_header = [ @@ -322,6 +380,10 @@ def test_sdxl_rocm_benchmark( ["Prompt Encoder", f"{clip_dispatch_count}", f"{goldendispatch_rocm_clip}"], ["VAE Decode", f"{vae_dispatch_count}", f"{goldendispatch_rocm_vae}"], ] + if rocm_chip == "gfx942": + dispatch_count_rows.append( + ["Punet F16", f"{punet_int8_fp16_dispatch_count}", f"{goldendispatch_rocm_punet_int8_fp16}"] + ) # Create binary size table's header and rows binary_size_header = [ @@ -334,6 +396,10 @@ def test_sdxl_rocm_benchmark( ["Prompt Encoder", f"{clip_binary_size}", f"{goldensize_rocm_clip}"], ["VAE Decode", f"{vae_binary_size}", f"{goldensize_rocm_vae}"], ] + if rocm_chip == "gfx942": + binary_size_rows.append( + ["Punet F16", f"{punet_int8_fp16_binary_size}", f"{goldensize_rocm_punet_int8_fp16}"] + ) # Create mean time table using tabulate mean_time_full = [mean_time_header] + mean_time_rows @@ -384,6 +450,22 @@ def test_sdxl_rocm_benchmark( goldensize_rocm_unet, "SDXL scheduled unet binary size should not get bigger", ) + if rocm_chip == "gfx942": + check.less_equal( + benchmark_punet_int8_fp16_mean_time, + goldentime_rocm_punet_int8_fp16, + "SDXL punet f16 benchmark time should not regress", + ) + check.equal( + punet_int8_fp16_dispatch_count, + goldendispatch_rocm_punet_int8_fp16, + "SDXL punet f16 dispatch count should not regress", + ) + 
check.less_equal( + punet_int8_fp16_binary_size, + goldensize_rocm_punet_int8_fp16, + "SDXL punet f16 binary size should not get bigger", + ) check.less_equal( benchmark_clip_mean_time, goldentime_rocm_clip, diff --git a/experimental/benchmarks/sdxl/conftest.py b/experimental/benchmarks/sdxl/conftest.py index df241132c644..369b85103d45 100644 --- a/experimental/benchmarks/sdxl/conftest.py +++ b/experimental/benchmarks/sdxl/conftest.py @@ -14,6 +14,12 @@ def pytest_addoption(parser): type=float, help="Golden time to test benchmark", ) + parser.addoption( + "--goldentime-rocm-punet-int8-fp16-ms", + action="store", + type=float, + help="Golden time to test benchmark", + ) parser.addoption( "--goldentime-rocm-clip-ms", action="store", @@ -33,6 +39,13 @@ def pytest_addoption(parser): type=int, help="Golden dispatch count to test benchmark", ) + parser.addoption( + "--goldendispatch-rocm-punet-int8-fp16", + action="store", + default=1276, + type=int, + help="Golden dispatch count to test benchmark", + ) parser.addoption( "--goldendispatch-rocm-clip", action="store", @@ -54,6 +67,13 @@ def pytest_addoption(parser): type=int, help="Golden vmfb size to test benchmark", ) + parser.addoption( + "--goldensize-rocm-punet-int8-fp16-bytes", + action="store", + default=2065046, + type=int, + help="Golden vmfb size to test benchmark", + ) parser.addoption( "--goldensize-rocm-clip-bytes", action="store", @@ -86,6 +106,9 @@ def goldentime_rocm_e2e(request): def goldentime_rocm_unet(request): return request.config.getoption("--goldentime-rocm-unet-ms") +@pytest.fixture +def goldentime_rocm_punet_int8_fp16(request): + return request.config.getoption("--goldentime-rocm-punet-int8-fp16-ms") @pytest.fixture def goldentime_rocm_clip(request): @@ -101,6 +124,9 @@ def goldentime_rocm_vae(request): def goldendispatch_rocm_unet(request): return request.config.getoption("--goldendispatch-rocm-unet") +@pytest.fixture +def goldendispatch_rocm_punet_int8_fp16(request): + return 
request.config.getoption("--goldendispatch-rocm-punet-int8-fp16") @pytest.fixture def goldendispatch_rocm_clip(request): @@ -116,6 +142,9 @@ def goldendispatch_rocm_vae(request): def goldensize_rocm_unet(request): return request.config.getoption("--goldensize-rocm-unet-bytes") +@pytest.fixture +def goldensize_rocm_punet_int8_fp16(request): + return request.config.getoption("--goldensize-rocm-punet-int8-fp16-bytes") @pytest.fixture def goldensize_rocm_clip(request): diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index 76e3e6c053c6..d3e64535572f 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -13,6 +13,7 @@ vmfb_dir = os.getenv("TEST_OUTPUT_ARTIFACTS", default=Path.cwd()) rocm_chip = os.getenv("ROCM_CHIP", default="gfx90a") +iree_test_path_extension = os.getenv("IREE_TEST_PATH_EXTENSION", default=Path.cwd()) ############################################################################### # Fixtures @@ -194,7 +195,6 @@ def SDXL_PUNET_INT8_FP8_OUT( f"--iree-hip-target={rocm_chip}", "--iree-opt-const-eval=false", "--iree-global-opt-propagate-transposes=true", - "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true", "--iree-dispatch-creation-enable-aggressive-fusion=true", "--iree-opt-aggressively-propagate-transposes=true", "--iree-opt-outer-dim-concat=true", @@ -214,6 +214,7 @@ def SDXL_PUNET_INT8_FP8_OUT( ] INT8_PUNET_FLAGS = [ + f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_punet.mlir", "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))", ] From 
e6a597de32570eae1d6dabde39af9a89a2690421 Mon Sep 17 00:00:00 2001 From: saienduri Date: Fri, 8 Nov 2024 18:25:45 -0600 Subject: [PATCH 02/13] change outputs and mlirs Signed-off-by: saienduri Signed-off-by: saienduri --- .github/workflows/pkgci_regression_test.yml | 3 +++ experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py | 2 +- .../shark-test-suite-models/sdxl/test_unet.py | 8 ++++---- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index 4c36b09384ed..0ab7eec1fc8e 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -247,6 +247,9 @@ jobs: --goldensize-rocm-unet-bytes 2270000 \ --goldensize-rocm-clip-bytes 860000 \ --goldensize-rocm-vae-bytes 840000 \ + --goldentime-rocm-punet-int8-fp16-ms 55 \ + --goldendispatch-rocm-punet-int8-fp16 1276 \ + --goldensize-rocm-punet-int8-fp16-bytes 2270000 \ --rocm-chip gfx942 \ --log-cli-level=info \ --retries 7 diff --git a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py index d5fb0da0cc33..3403cfd9cb06 100644 --- a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py +++ b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py @@ -16,7 +16,7 @@ vmfb_dir = os.getenv("TEST_OUTPUT_ARTIFACTS", default=Path.cwd()) benchmark_dir = os.path.dirname(os.path.realpath(__file__)) -artifacts_dir = os.getenv("IREE_TEST_FILES", default=Path.cwd()) / "artifacts" +artifacts_dir = os.getenv("IREE_TEST_FILES", default=Path.cwd()) + "/artifacts" artifacts_dir = Path(os.path.expanduser(artifacts_dir)).resolve() prompt_encoder_dir = f"{artifacts_dir}/sdxl_clip" scheduled_unet_dir = f"{artifacts_dir}/sdxl_unet_fp16" diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index d3e64535572f..74dab81191d1 100644 --- 
a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -94,7 +94,7 @@ ) sdxl_punet_int8_fp16_inference_output_0 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/new_punet_out.0.bin", + "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/11-8-2024/punet_fp16_out.0.bin", group="sdxl_punet_int8_fp16", ) @@ -104,14 +104,14 @@ ) sdxl_punet_int8_fp16_mlir = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/punet.mlir", + "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/11-8-2024/punet_fp16.mlir", group="sdxl_punet_int8_fp16", ) # INT8 Punet + FP8 Attention sdxl_punet_int8_fp8_inference_output_0 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/new_punet_fp8_out.0.bin", + "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/11-8-2024/punet_fp8_out.0.bin", group="sdxl_punet_int8_fp8", ) @@ -121,7 +121,7 @@ ) sdxl_punet_int8_fp8_mlir = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/rob/sdxl-punet/punet_fp8.mlir", + "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/11-8-2024/punet_fp8.mlir", group="sdxl_punet_int8_fp8", ) From a4d128f7d53ca3573f39c6a3be02b510b66787ef Mon Sep 17 00:00:00 2001 From: saienduri Date: Fri, 8 Nov 2024 18:33:05 -0600 Subject: [PATCH 03/13] lint Signed-off-by: saienduri Signed-off-by: saienduri --- .../benchmarks/sdxl/benchmark_sdxl_rocm.py | 20 ++++++++++++++++--- experimental/benchmarks/sdxl/conftest.py | 6 ++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py index 3403cfd9cb06..bd45b452565d 100644 --- a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py +++ 
b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py @@ -116,6 +116,7 @@ def run_sdxl_unet_rocm_benchmark(rocm_chip): # iree benchmark command for full sdxl pipeline return run_iree_command(exec_args) + def run_sdxl_punet_int8_fp16_rocm_benchmark(rocm_chip): exec_args = [ "iree-benchmark-module", @@ -136,6 +137,7 @@ def run_sdxl_punet_int8_fp16_rocm_benchmark(rocm_chip): # iree benchmark command for full sdxl pipeline return run_iree_command(exec_args) + def run_sdxl_prompt_encoder_rocm_benchmark(rocm_chip): exec_args = [ "iree-benchmark-module", @@ -366,7 +368,11 @@ def test_sdxl_rocm_benchmark( ] if rocm_chip == "gfx942": mean_time_rows.append( - ["Punet F16", f"{benchmark_punet_int8_fp16_mean_time}", f"{goldentime_rocm_punet_int8_fp16}"] + [ + "Punet F16", + f"{benchmark_punet_int8_fp16_mean_time}", + f"{goldentime_rocm_punet_int8_fp16}", + ] ) # Create dispatch count table's header and rows @@ -382,7 +388,11 @@ def test_sdxl_rocm_benchmark( ] if rocm_chip == "gfx942": dispatch_count_rows.append( - ["Punet F16", f"{punet_int8_fp16_dispatch_count}", f"{goldendispatch_rocm_punet_int8_fp16}"] + [ + "Punet F16", + f"{punet_int8_fp16_dispatch_count}", + f"{goldendispatch_rocm_punet_int8_fp16}", + ] ) # Create binary size table's header and rows @@ -398,7 +408,11 @@ def test_sdxl_rocm_benchmark( ] if rocm_chip == "gfx942": binary_size_rows.append( - ["Punet F16", f"{punet_int8_fp16_binary_size}", f"{goldensize_rocm_punet_int8_fp16}"] + [ + "Punet F16", + f"{punet_int8_fp16_binary_size}", + f"{goldensize_rocm_punet_int8_fp16}", + ] ) # Create mean time table using tabulate diff --git a/experimental/benchmarks/sdxl/conftest.py b/experimental/benchmarks/sdxl/conftest.py index 369b85103d45..43d6be07b6a2 100644 --- a/experimental/benchmarks/sdxl/conftest.py +++ b/experimental/benchmarks/sdxl/conftest.py @@ -106,10 +106,12 @@ def goldentime_rocm_e2e(request): def goldentime_rocm_unet(request): return request.config.getoption("--goldentime-rocm-unet-ms") + @pytest.fixture 
def goldentime_rocm_punet_int8_fp16(request): return request.config.getoption("--goldentime-rocm-punet-int8-fp16-ms") + @pytest.fixture def goldentime_rocm_clip(request): return request.config.getoption("--goldentime-rocm-clip-ms") @@ -124,10 +126,12 @@ def goldentime_rocm_vae(request): def goldendispatch_rocm_unet(request): return request.config.getoption("--goldendispatch-rocm-unet") + @pytest.fixture def goldendispatch_rocm_punet_int8_fp16(request): return request.config.getoption("--goldendispatch-rocm-punet-int8-fp16") + @pytest.fixture def goldendispatch_rocm_clip(request): return request.config.getoption("--goldendispatch-rocm-clip") @@ -142,10 +146,12 @@ def goldendispatch_rocm_vae(request): def goldensize_rocm_unet(request): return request.config.getoption("--goldensize-rocm-unet-bytes") + @pytest.fixture def goldensize_rocm_punet_int8_fp16(request): return request.config.getoption("--goldensize-rocm-punet-int8-fp16-bytes") + @pytest.fixture def goldensize_rocm_clip(request): return request.config.getoption("--goldensize-rocm-clip-bytes") From 2fb0a17d5575b65003c3e8ed01f9eddbb2dfd375 Mon Sep 17 00:00:00 2001 From: saienduri Date: Fri, 8 Nov 2024 19:56:28 -0600 Subject: [PATCH 04/13] add xfail for mi250 compile and change vmfb names Signed-off-by: saienduri --- experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py | 4 ++-- .../shark-test-suite-models/sdxl/test_unet.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py index bd45b452565d..d4e76dbb825b 100644 --- a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py +++ b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py @@ -122,7 +122,7 @@ def run_sdxl_punet_int8_fp16_rocm_benchmark(rocm_chip): "iree-benchmark-module", f"--device=hip", "--device_allocator=caching", - f"--module={punet_int8_fp16_dir_compile}/punet.rocm_{rocm_chip}.vmfb", + 
f"--module={punet_int8_fp16_dir_compile}/punet_fp16.rocm_{rocm_chip}.vmfb", f"--parameters=model={punet_int8_fp16_dir}/punet_weights.irpa", "--function=main", f"--input=1x4x128x128xf16", @@ -292,7 +292,7 @@ def test_sdxl_rocm_benchmark( ) logging.getLogger().info(compilation_line) - module_path = f"{punet_int8_fp16_dir_compile}/punet.rocm_{rocm_chip}.vmfb" + module_path = f"{punet_int8_fp16_dir_compile}/punet_fp16.rocm_{rocm_chip}.vmfb" punet_int8_fp16_binary_size = Path(module_path).stat().st_size compilation_line = ( f"Punet F16 Binary Size: {punet_int8_fp16_binary_size} bytes" diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index 74dab81191d1..dd867b30be73 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -316,6 +316,13 @@ def test_run_unet_fp16_rocm( def test_compile_punet_int8_fp16_rocm(sdxl_punet_int8_fp16_mlir): + if rocm_chip == "gfx90a": + request.node.add_marker( + pytest.mark.xfail( + reason="Expected punet_int8_fp8 compilation on mi250 to fail", + strict=True, + ) + ) VmfbManager.sdxl_punet_int8_fp16_rocm_vmfb = iree_compile( sdxl_punet_int8_fp16_mlir, ROCM_COMPILE_FLAGS + INT8_PUNET_FLAGS, From 6911ec3a656ae6fad93f81e1489d31dad7277a4c Mon Sep 17 00:00:00 2001 From: saienduri Date: Fri, 8 Nov 2024 21:01:43 -0600 Subject: [PATCH 05/13] add back hz fusion Signed-off-by: saienduri --- .../shark-test-suite-models/sdxl/test_unet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index dd867b30be73..93e9a5131494 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ 
b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -12,7 +12,7 @@ from pathlib import Path vmfb_dir = os.getenv("TEST_OUTPUT_ARTIFACTS", default=Path.cwd()) -rocm_chip = os.getenv("ROCM_CHIP", default="gfx90a") +rocm_chip = os.getenv("ROCM_CHIP", default="gfx942") iree_test_path_extension = os.getenv("IREE_TEST_PATH_EXTENSION", default=Path.cwd()) ############################################################################### @@ -195,6 +195,7 @@ def SDXL_PUNET_INT8_FP8_OUT( f"--iree-hip-target={rocm_chip}", "--iree-opt-const-eval=false", "--iree-global-opt-propagate-transposes=true", + "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true", "--iree-dispatch-creation-enable-aggressive-fusion=true", "--iree-opt-aggressively-propagate-transposes=true", "--iree-opt-outer-dim-concat=true", @@ -315,7 +316,7 @@ def test_run_unet_fp16_rocm( ) -def test_compile_punet_int8_fp16_rocm(sdxl_punet_int8_fp16_mlir): +def test_compile_punet_int8_fp16_rocm(request, sdxl_punet_int8_fp16_mlir): if rocm_chip == "gfx90a": request.node.add_marker( pytest.mark.xfail( From b0da27ebde60cab4fc4a44c3822cc71a9bb48529 Mon Sep 17 00:00:00 2001 From: saienduri Date: Mon, 11 Nov 2024 04:43:06 -0600 Subject: [PATCH 06/13] add fp8 benchmarks Signed-off-by: saienduri --- .github/workflows/pkgci_regression_test.yml | 5 +- .../benchmarks/sdxl/benchmark_sdxl_rocm.py | 91 ++++++++++++++++++- experimental/benchmarks/sdxl/conftest.py | 35 +++++++ 3 files changed, 129 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index 0ab7eec1fc8e..ef38f9abd270 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -248,8 +248,11 @@ jobs: --goldensize-rocm-clip-bytes 860000 \ --goldensize-rocm-vae-bytes 840000 \ --goldentime-rocm-punet-int8-fp16-ms 55 \ - --goldendispatch-rocm-punet-int8-fp16 1276 \ + --goldendispatch-rocm-punet-int8-fp16 1284 \ 
--goldensize-rocm-punet-int8-fp16-bytes 2270000 \ + --goldentime-rocm-punet-int8-fp8-ms 59 \ + --goldendispatch-rocm-punet-int8-fp8 1556 \ + --goldensize-rocm-punet-int8-fp8-bytes 2244086 \ --rocm-chip gfx942 \ --log-cli-level=info \ --retries 7 diff --git a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py index d4e76dbb825b..f5f4be3a3ca5 100644 --- a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py +++ b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py @@ -16,15 +16,17 @@ vmfb_dir = os.getenv("TEST_OUTPUT_ARTIFACTS", default=Path.cwd()) benchmark_dir = os.path.dirname(os.path.realpath(__file__)) -artifacts_dir = os.getenv("IREE_TEST_FILES", default=Path.cwd()) + "/artifacts" +artifacts_dir = f"{os.getenv('IREE_TEST_FILES', default=Path.cwd())}/artifacts" artifacts_dir = Path(os.path.expanduser(artifacts_dir)).resolve() prompt_encoder_dir = f"{artifacts_dir}/sdxl_clip" scheduled_unet_dir = f"{artifacts_dir}/sdxl_unet_fp16" punet_int8_fp16_dir = f"{artifacts_dir}/sdxl_punet_int8_fp16" +punet_int8_fp8_dir = f"{artifacts_dir}/sdxl_punet_int8_fp8" vae_decode_dir = f"{artifacts_dir}/sdxl_vae" prompt_encoder_dir_compile = f"{vmfb_dir}/sdxl_clip_vmfbs" scheduled_unet_dir_compile = f"{vmfb_dir}/sdxl_unet_fp16_vmfbs" punet_int8_fp16_dir_compile = f"{vmfb_dir}/sdxl_punet_int8_fp16_vmfbs" +punet_int8_fp8_dir_compile = f"{vmfb_dir}/sdxl_punet_int8_fp8_vmfbs" vae_decode_dir_compile = f"{vmfb_dir}/sdxl_vae_vmfbs" @@ -137,6 +139,25 @@ def run_sdxl_punet_int8_fp16_rocm_benchmark(rocm_chip): # iree benchmark command for full sdxl pipeline return run_iree_command(exec_args) +def run_sdxl_punet_int8_fp8_rocm_benchmark(rocm_chip): + exec_args = [ + "iree-benchmark-module", + f"--device=hip", + "--device_allocator=caching", + f"--module={punet_int8_fp8_dir_compile}/punet_fp8.rocm_{rocm_chip}.vmfb", + f"--parameters=model={punet_int8_fp8_dir}/punet_fp8_weights.irpa", + "--function=main", + f"--input=1x4x128x128xf16", + 
f"--input=1xf16", + f"--input=2x64x2048xf16", + f"--input=2x1280xf16", + f"--input=2x6xf16", + f"--input=1xf16", + "--benchmark_repetitions=10", + "--benchmark_min_warmup_time=3.0", + ] + # iree benchmark command for full sdxl pipeline + return run_iree_command(exec_args) def run_sdxl_prompt_encoder_rocm_benchmark(rocm_chip): exec_args = [ @@ -220,15 +241,18 @@ def test_sdxl_rocm_benchmark( goldentime_rocm_e2e, goldentime_rocm_unet, goldentime_rocm_punet_int8_fp16, + goldentime_rocm_punet_int8_fp8, goldentime_rocm_clip, goldentime_rocm_vae, rocm_chip, goldendispatch_rocm_unet, goldendispatch_rocm_punet_int8_fp16, + goldendispatch_rocm_punet_int8_fp8, goldendispatch_rocm_clip, goldendispatch_rocm_vae, goldensize_rocm_unet, goldensize_rocm_punet_int8_fp16, + goldensize_rocm_punet_int8_fp8, goldensize_rocm_clip, goldensize_rocm_vae, ): @@ -300,6 +324,35 @@ def test_sdxl_rocm_benchmark( ) logging.getLogger().info(compilation_line) + # punet int8 f8 attention benchmark + ret_value, output = run_sdxl_punet_int8_fp8_rocm_benchmark(rocm_chip) + benchmark_punet_int8_fp8_mean_time = job_summary_process(ret_value, output) + mean_line = ( + f"Punet F8 Benchmark Time: {str(benchmark_punet_int8_fp8_mean_time)} ms" + f" (golden time {goldentime_rocm_punet_int8_fp8} ms)" + ) + logging.getLogger().info(mean_line) + + # punet int8 f8 compilation stats check + with open(f"{punet_int8_fp8_dir_compile}/compilation_info.json", "r") as file: + comp_stats = json.load(file) + punet_int8_fp8_dispatch_count = int( + comp_stats["stream-aggregate"]["execution"]["dispatch-count"] + ) + compilation_line = ( + f"Punet F8 Dispatch Count: {punet_int8_fp8_dispatch_count}" + f" (golden dispatch count {goldendispatch_rocm_punet_int8_fp8})" + ) + logging.getLogger().info(compilation_line) + + module_path = f"{punet_int8_fp8_dir_compile}/punet_fp8.rocm_{rocm_chip}.vmfb" + punet_int8_fp8_binary_size = Path(module_path).stat().st_size + compilation_line = ( + f"Punet F8 Binary Size: 
{punet_int8_fp8_binary_size} bytes" + f" (golden binary size {goldensize_rocm_punet_int8_fp8} bytes)" + ) + logging.getLogger().info(compilation_line) + # prompt encoder benchmark ret_value, output = run_sdxl_prompt_encoder_rocm_benchmark(rocm_chip) benchmark_clip_mean_time = job_summary_process(ret_value, output) @@ -374,6 +427,13 @@ def test_sdxl_rocm_benchmark( f"{goldentime_rocm_punet_int8_fp16}", ] ) + mean_time_rows.append( + [ + "Punet F8", + f"{benchmark_punet_int8_fp8_mean_time}", + f"{goldentime_rocm_punet_int8_fp8}", + ] + ) # Create dispatch count table's header and rows dispatch_count_header = [ @@ -394,6 +454,13 @@ def test_sdxl_rocm_benchmark( f"{goldendispatch_rocm_punet_int8_fp16}", ] ) + dispatch_count_rows.append( + [ + "Punet F8", + f"{punet_int8_fp8_dispatch_count}", + f"{goldendispatch_rocm_punet_int8_fp8}", + ] + ) # Create binary size table's header and rows binary_size_header = [ @@ -414,6 +481,13 @@ def test_sdxl_rocm_benchmark( f"{goldensize_rocm_punet_int8_fp16}", ] ) + binary_size_rows.append( + [ + "Punet F8", + f"{punet_int8_fp8_binary_size}", + f"{goldensize_rocm_punet_int8_fp8}", + ] + ) # Create mean time table using tabulate mean_time_full = [mean_time_header] + mean_time_rows @@ -480,6 +554,21 @@ def test_sdxl_rocm_benchmark( goldensize_rocm_punet_int8_fp16, "SDXL punet f16 binary size should not get bigger", ) + check.less_equal( + benchmark_punet_int8_fp8_mean_time, + goldentime_rocm_punet_int8_fp8, + "SDXL punet f8 benchmark time should not regress", + ) + check.equal( + punet_int8_fp8_dispatch_count, + goldendispatch_rocm_punet_int8_fp8, + "SDXL punet f8 dispatch count should not regress", + ) + check.less_equal( + punet_int8_fp8_binary_size, + goldensize_rocm_punet_int8_fp8, + "SDXL punet f8 binary size should not get bigger", + ) check.less_equal( benchmark_clip_mean_time, goldentime_rocm_clip, diff --git a/experimental/benchmarks/sdxl/conftest.py b/experimental/benchmarks/sdxl/conftest.py index 43d6be07b6a2..1a7102386a71 
100644 --- a/experimental/benchmarks/sdxl/conftest.py +++ b/experimental/benchmarks/sdxl/conftest.py @@ -20,6 +20,12 @@ def pytest_addoption(parser): type=float, help="Golden time to test benchmark", ) + parser.addoption( + "--goldentime-rocm-punet-int8-fp8-ms", + action="store", + type=float, + help="Golden time to test benchmark", + ) parser.addoption( "--goldentime-rocm-clip-ms", action="store", @@ -42,6 +48,13 @@ def pytest_addoption(parser): parser.addoption( "--goldendispatch-rocm-punet-int8-fp16", action="store", + default=1284, + type=int, + help="Golden dispatch count to test benchmark", + ) + parser.addoption( + "--goldendispatch-rocm-punet-int8-fp8", + action="store", default=1276, type=int, help="Golden dispatch count to test benchmark", @@ -74,6 +87,13 @@ def pytest_addoption(parser): type=int, help="Golden vmfb size to test benchmark", ) + parser.addoption( + "--goldensize-rocm-punet-int8-fp8-bytes", + action="store", + default=2065046, + type=int, + help="Golden vmfb size to test benchmark", + ) parser.addoption( "--goldensize-rocm-clip-bytes", action="store", @@ -112,6 +132,11 @@ def goldentime_rocm_punet_int8_fp16(request): return request.config.getoption("--goldentime-rocm-punet-int8-fp16-ms") +@pytest.fixture +def goldentime_rocm_punet_int8_fp8(request): + return request.config.getoption("--goldentime-rocm-punet-int8-fp8-ms") + + @pytest.fixture def goldentime_rocm_clip(request): return request.config.getoption("--goldentime-rocm-clip-ms") @@ -132,6 +157,11 @@ def goldendispatch_rocm_punet_int8_fp16(request): return request.config.getoption("--goldendispatch-rocm-punet-int8-fp16") +@pytest.fixture +def goldendispatch_rocm_punet_int8_fp8(request): + return request.config.getoption("--goldendispatch-rocm-punet-int8-fp8") + + @pytest.fixture def goldendispatch_rocm_clip(request): return request.config.getoption("--goldendispatch-rocm-clip") @@ -152,6 +182,11 @@ def goldensize_rocm_punet_int8_fp16(request): return 
request.config.getoption("--goldensize-rocm-punet-int8-fp16-bytes") +@pytest.fixture +def goldensize_rocm_punet_int8_fp8(request): + return request.config.getoption("--goldensize-rocm-punet-int8-fp8-bytes") + + @pytest.fixture def goldensize_rocm_clip(request): return request.config.getoption("--goldensize-rocm-clip-bytes") From c3a38e7ae4dab9e899cd61c7574443466ae0468f Mon Sep 17 00:00:00 2001 From: saienduri Date: Mon, 11 Nov 2024 04:45:20 -0600 Subject: [PATCH 07/13] lint Signed-off-by: saienduri --- experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py index f5f4be3a3ca5..db1610928477 100644 --- a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py +++ b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py @@ -139,6 +139,7 @@ def run_sdxl_punet_int8_fp16_rocm_benchmark(rocm_chip): # iree benchmark command for full sdxl pipeline return run_iree_command(exec_args) + def run_sdxl_punet_int8_fp8_rocm_benchmark(rocm_chip): exec_args = [ "iree-benchmark-module", @@ -159,6 +160,7 @@ def run_sdxl_punet_int8_fp8_rocm_benchmark(rocm_chip): # iree benchmark command for full sdxl pipeline return run_iree_command(exec_args) + def run_sdxl_prompt_encoder_rocm_benchmark(rocm_chip): exec_args = [ "iree-benchmark-module", From acfcd6e69df7b4d6c3230c87b779eeb9ad680684 Mon Sep 17 00:00:00 2001 From: saienduri Date: Mon, 11 Nov 2024 05:11:15 -0600 Subject: [PATCH 08/13] add timeouts and update golden values Signed-off-by: saienduri --- .github/workflows/pkgci_regression_test.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index ef38f9abd270..f3facd282541 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -227,6 +227,7 @@ jobs: 
--goldensize-rocm-clip-bytes 860000 \ --goldensize-rocm-vae-bytes 840000 \ --rocm-chip gfx90a \ + --timeout=1200 \ --log-cli-level=info \ --retries 7 echo "$(> $GITHUB_STEP_SUMMARY @@ -249,11 +250,12 @@ jobs: --goldensize-rocm-vae-bytes 840000 \ --goldentime-rocm-punet-int8-fp16-ms 55 \ --goldendispatch-rocm-punet-int8-fp16 1284 \ - --goldensize-rocm-punet-int8-fp16-bytes 2270000 \ + --goldensize-rocm-punet-int8-fp16-bytes 2560000 \ --goldentime-rocm-punet-int8-fp8-ms 59 \ - --goldendispatch-rocm-punet-int8-fp8 1556 \ - --goldensize-rocm-punet-int8-fp8-bytes 2244086 \ + --goldendispatch-rocm-punet-int8-fp8 1564 \ + --goldensize-rocm-punet-int8-fp8-bytes 2800000 \ --rocm-chip gfx942 \ --log-cli-level=info \ + --timeout=1200 \ --retries 7 echo "$(> $GITHUB_STEP_SUMMARY From ec222cff13f679ebfd7ef9bcd8e118524628447c Mon Sep 17 00:00:00 2001 From: saienduri Date: Mon, 11 Nov 2024 12:53:17 -0600 Subject: [PATCH 09/13] update per test timeout Signed-off-by: saienduri --- .github/workflows/pkgci_regression_test.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index f3facd282541..950bf48a2b35 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -112,7 +112,7 @@ jobs: --no-skip-tests-missing-files \ --capture=no \ --log-cli-level=info \ - --timeout=1200 \ + --timeout=240 \ --durations=0 \ --config-files=${MODELS_CONFIG_FILE_PATH} @@ -189,7 +189,7 @@ jobs: -rpfE \ --capture=no \ --log-cli-level=info \ - --timeout=1200 \ + --timeout=240 \ --durations=0 env: ROCM_CHIP: ${{ matrix.rocm-chip }} @@ -203,7 +203,7 @@ jobs: -rpfE \ --capture=no \ --log-cli-level=info \ - --timeout=1200 \ + --timeout=240 \ --durations=0 env: ROCM_CHIP: ${{ matrix.rocm-chip }} @@ -227,7 +227,7 @@ jobs: --goldensize-rocm-clip-bytes 860000 \ --goldensize-rocm-vae-bytes 840000 \ --rocm-chip gfx90a \ - --timeout=1200 \ + 
--timeout=240 \ --log-cli-level=info \ --retries 7 echo "$(> $GITHUB_STEP_SUMMARY @@ -256,6 +256,6 @@ jobs: --goldensize-rocm-punet-int8-fp8-bytes 2800000 \ --rocm-chip gfx942 \ --log-cli-level=info \ - --timeout=1200 \ + --timeout=240 \ --retries 7 echo "$(> $GITHUB_STEP_SUMMARY From f2550953b63ad930d587c2f7ffb0284cbda39aa2 Mon Sep 17 00:00:00 2001 From: saienduri Date: Mon, 11 Nov 2024 13:32:16 -0600 Subject: [PATCH 10/13] update fp8 dispatch count Signed-off-by: saienduri --- .github/workflows/pkgci_regression_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index 950bf48a2b35..e95bf545acf7 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -252,7 +252,7 @@ jobs: --goldendispatch-rocm-punet-int8-fp16 1284 \ --goldensize-rocm-punet-int8-fp16-bytes 2560000 \ --goldentime-rocm-punet-int8-fp8-ms 59 \ - --goldendispatch-rocm-punet-int8-fp8 1564 \ + --goldendispatch-rocm-punet-int8-fp8 1563 \ --goldensize-rocm-punet-int8-fp8-bytes 2800000 \ --rocm-chip gfx942 \ --log-cli-level=info \ From 8a61b66d7b8254bfffc577b98fd45c25be975329 Mon Sep 17 00:00:00 2001 From: saienduri Date: Mon, 11 Nov 2024 14:03:14 -0600 Subject: [PATCH 11/13] back to 1564 Signed-off-by: saienduri --- .github/workflows/pkgci_regression_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index e95bf545acf7..950bf48a2b35 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -252,7 +252,7 @@ jobs: --goldendispatch-rocm-punet-int8-fp16 1284 \ --goldensize-rocm-punet-int8-fp16-bytes 2560000 \ --goldentime-rocm-punet-int8-fp8-ms 59 \ - --goldendispatch-rocm-punet-int8-fp8 1563 \ + --goldendispatch-rocm-punet-int8-fp8 1564 \ --goldensize-rocm-punet-int8-fp8-bytes 
2800000 \ --rocm-chip gfx942 \ --log-cli-level=info \ From 81e5bc2840e576a666b5260b2f892734b6a3d12a Mon Sep 17 00:00:00 2001 From: saienduri Date: Tue, 12 Nov 2024 13:43:06 -0600 Subject: [PATCH 12/13] try the latest spec Signed-off-by: saienduri --- .../attention_and_matmul_spec_punet.mlir | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir index c92ddd0f9cde..a566203907e4 100644 --- a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir +++ b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir @@ -51,8 +51,8 @@ transform.named_sequence @match_attention_f16(%attention: !transform.any_op {tra transform.iree.match.cast_compatible_type %in0 = tensor : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 128, 0, 0, 0], reduction=[0, 0, 0, 0, 0, 32], promote_operands = [1, 2]}>, - translation_info = #iree_codegen.translation_info, + translation_info = #iree_codegen.translation_info> @@ -76,8 +76,8 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran transform.iree.match.cast_compatible_type %in0 = tensor : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 0, 0, 0], reduction=[0, 0, 0, 0, 32, 0], promote_operands = [1, 2]}>, - translation_info = #iree_codegen.translation_info, + translation_info = #iree_codegen.translation_info> @@ -125,7 +125,7 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran subgroup_m_count = 4, subgroup_n_count = 2, reduction = [0, 0, 0, 128], workgroup = [1, 128, 320, 0]}>, - translation_info = #iree_codegen.translation_info}> > -> 
!transform.any_param @@ -144,7 +144,7 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran subgroup_m_count = 2, subgroup_n_count = 2, reduction = [0, 0, 0, 128], workgroup = [1, 64, 160, 0]}>, - translation_info = #iree_codegen.translation_info>}> @@ -164,7 +164,7 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran subgroup_m_count = 2, subgroup_n_count = 4, reduction = [0, 0, 0, 64], workgroup = [1, 256, 128, 0]}>, - translation_info = #iree_codegen.translation_info}> > -> !transform.any_param @@ -199,7 +199,7 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran subgroup_m_count = 2, subgroup_n_count = 2, reduction = [0, 0, 0, 0, 128], workgroup = [1, 1, 64, 160, 0]}>, - translation_info = #iree_codegen.translation_info> @@ -232,7 +232,7 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran subgroup_m_count = 2, subgroup_n_count = 1, reduction = [0, 0, 0, 0, 128], workgroup = [1, 1, 32, 320, 0]}>, - translation_info = #iree_codegen.translation_info}> > -> !transform.any_param @@ -263,7 +263,7 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran subgroup_m_count = 8, subgroup_n_count = 1, reduction = [0, 0, 0, 0, 64], workgroup = [1, 1, 256, 64, 0]}>, - translation_info = #iree_codegen.translation_info}> > -> !transform.any_param From 9aaf51791a79d79f3bfe72027a11f070a7c25151 Mon Sep 17 00:00:00 2001 From: saienduri Date: Wed, 13 Nov 2024 01:09:53 -0800 Subject: [PATCH 13/13] update outs for fp16 and fp8 Signed-off-by: saienduri --- .../shark-test-suite-models/sdxl/test_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index 93e9a5131494..990b35fc1688 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ 
b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -94,7 +94,7 @@ ) sdxl_punet_int8_fp16_inference_output_0 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/11-8-2024/punet_fp16_out.0.bin", + "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/11-13-2024/punet_fp16_out.0.bin", group="sdxl_punet_int8_fp16", ) @@ -111,7 +111,7 @@ # INT8 Punet + FP8 Attention sdxl_punet_int8_fp8_inference_output_0 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/11-8-2024/punet_fp8_out.0.bin", + "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/11-13-2024/punet_fp8_out.0.bin", group="sdxl_punet_int8_fp8", )