From 8e5f218a1d73c39ccdebb4a1d19225d5329cc6f0 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Thu, 7 Nov 2024 16:40:15 -0500 Subject: [PATCH] GPU data tiling on RDNA3 (#18980) A few things were needed: * Populate required fields in KnownTargets.cpp. * Support the case where the intrinsic vector operand size is greater than the load instruction size (here it is 16xf16 = 256 bit). * `buildMmaOperation` creates `vector.insert_strided_slice` to insert the new accumulator vectors into the accumulator tile. In doing so, it was relying on `vector.insert_strided_slice` implicit expand-shape semantics, in ways that worked for the shapes we had seen in CDNA3 but not here. Solved by explicitly expanding shapes with `vector.shape_cast` ops. * In thread-distribution code (populateOperandXxx), we needed to account for the nuance between two distinct thread grids: "layout" vs "distribution". In the case of RDNA3, there is a distribution-only dimension that isn't reflected in the layout-centric TileSwizzle's. * On RDNA3, some float arithmetic is strongly non-IEEE754-compliant: even exactly-representable small integral values, on which float arithmetic should be exact, have epsilon numerical errors! Addressed by tolerance. * Fix a bug: the doubly-nullable type `std::optional` tricked us, change to `IREE::GPU::DataTiledMMAAttr`. --------- Signed-off-by: Benoit Jacob --- .../Common/GPU/GPUMaterializeEncoding.cpp | 15 ++-- .../Codegen/Common/GPU/test/BUILD.bazel | 3 +- .../Codegen/Common/GPU/test/CMakeLists.txt | 3 +- .../gpu_materialize_encoding_gfx1100.mlir | 60 +++++++++++++ ...r => gpu_materialize_encoding_gfx942.mlir} | 0 .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 84 +++++++++++++++---- .../Dialect/GPU/TargetUtils/KnownTargets.cpp | 5 +- tests/e2e/matmul/CMakeLists.txt | 61 ++++++++++++++ 8 files changed, 201 insertions(+), 30 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir rename compiler/src/iree/compiler/Codegen/Common/GPU/test/{gpu_materialize_encoding.mlir => gpu_materialize_encoding_gfx942.mlir} (100%) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp index 778cd082736a..4d494d952711 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp @@ -103,12 +103,8 @@ chooseDataTiledMMAAttr(TypeRange eTypes, IREE::GPU::TargetAttr target, // unrollK=4 to turn 4 separate 32-bit loads into one 128-bit load. int intrinsicLoadBits = std::min(sizeInBits(intrinsicA), sizeInBits(intrinsicB)); - if (*wgp.getMaxLoadInstructionBits() % intrinsicLoadBits != 0) { - // Never seen that case: the ISA does not have a suitable load instruction - // to feed that intrinsic?! 
- return {}; - } - const int unrollK = *wgp.getMaxLoadInstructionBits() / intrinsicLoadBits; + const int unrollK = + std::max(1, *wgp.getMaxLoadInstructionBits() / intrinsicLoadBits); // The total amount of unrolling along the M and N dimensions is normally // limited only by the number of available registers, since larger M and N @@ -493,13 +489,13 @@ class GPUConvertToMultiMma final } else { gpuTargetAttr = getCLGPUTarget(op.getContext()); } - std::optional mma = chooseDataTiledMMAAttr( + IREE::GPU::DataTiledMMAAttr mma = chooseDataTiledMMAAttr( resultEncoding.getElementTypesArray(), gpuTargetAttr, resultEncoding); if (!mma) { LLVM_DEBUG(llvm::dbgs() << "can't find supported Mma intrinsic\n"); return failure(); } - LLVM_DEBUG(llvm::dbgs() << "Target MMA: " << mma.value() << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Target MMA: " << mma << "\n"); FailureOr contractionDims = linalg::inferContractionDims(linalgOp); @@ -535,8 +531,7 @@ class GPUConvertToMultiMma final Location loc = op.getLoc(); auto mmaOp = rewriter.create( loc, operands[0], operands[1], operands[2], - ArrayRef{lhsMap, rhsMap, accMap}, iteratorTypes, - mma.value()); + ArrayRef{lhsMap, rhsMap, accMap}, iteratorTypes, mma); rewriter.replaceOp(op, mmaOp); return success(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel index 967780cfad59..126ebe675715 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel @@ -30,7 +30,8 @@ iree_lit_test_suite( "gpu_infer_memory_space.mlir", "gpu_lower_to_ukernels.mlir", "gpu_combine_value_barriers.mlir", - "gpu_materialize_encoding.mlir", + "gpu_materialize_encoding_gfx942.mlir", + "gpu_materialize_encoding_gfx1100.mlir", "gpu_nested_layout_contract_amdgpu.mlir", "gpu_nested_layout_vector_distribution.mlir", "gpu_pipeline.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt index d1b0b3ae59e0..cddf4ebf17f1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt @@ -26,7 +26,8 @@ iree_lit_test_suite( "gpu_greedily_distribute_to_threads.mlir" "gpu_infer_memory_space.mlir" "gpu_lower_to_ukernels.mlir" - "gpu_materialize_encoding.mlir" + "gpu_materialize_encoding_gfx1100.mlir" + "gpu_materialize_encoding_gfx942.mlir" "gpu_nested_layout_contract_amdgpu.mlir" "gpu_nested_layout_vector_distribution.mlir" "gpu_pipeline.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir new file mode 100644 index 000000000000..f6e944544ec6 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir @@ -0,0 +1,60 @@ +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" \ +// RUN: --iree-gpu-test-target=gfx1100 \ +// RUN: --split-input-file %s | FileCheck %s + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#encoding_lhs = #iree_encoding.encoding> +#encoding_rhs = #iree_encoding.encoding> +#encoding_result = #iree_encoding.encoding> +#pipeline_layout_3 = #hal.pipeline.layout, + #hal.pipeline.binding, + 
#hal.pipeline.binding +]> +func.func @matmul_lowering_WMMA_F32_16x16x16_F16() { + %c0 = arith.constant 0 : index + %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index + %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index + %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %N} + -> tensor + %6 = linalg.matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor + -> !flow.dispatch.tensor>{%M, %N} + return +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK: func.func @matmul_lowering_WMMA_F32_16x16x16_F16 +// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0) +// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1) +// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2) +// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor +// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor +// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor +// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]], +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout +// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]] diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx942.mlir similarity index 100% rename from compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir rename to compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx942.mlir diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index a09ae277819f..3cc43be79cfc 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -955,9 +955,6 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides( Value threadId, ArrayRef permutation, SmallVector &offsets, SmallVector &sizes, SmallVector &strides) const { - // TODO(bjacob): Support WMMA intrinsics. 
- - // Get the swizzle describing the internal layout of this fragment. TileSwizzle swizzle = getSwizzle(*this, fragment); LLVM_DEBUG({ @@ -966,36 +963,79 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides( DBGS() << " swizzle: " << swizzle << "\n"; }); - // Populate tile sizes. MLIRContext *ctx = builder.getContext(); SmallVector tileSizes = getAsIndexOpFoldResult( ctx, sliceSwizzledShape(swizzle, [](TileSwizzle::Dim d) { return d.kind != TileSwizzle::Dim::Kind::CrossThread; })); - // Populate tile offsets by delinearizing threadId over the CrossThread dims. - // Since the AffineDelinearizeIndexOp does not bound the input index, we - // must bound the threadId by the product of the offset ranges. - SmallVector tileOffsetsBasis = + // Most of the rest of this function is the computation of the offsets. + // The basic idea is to delinearize the threadId over the basis of + // cross-thread dimensions. These cross-thread dimensions may be either + // the intrinsic's own, or they may come from expansion to multiple subgroups. + // Normally, that distinction is irrelevant here: we just delinearize the + // thread-id over all cross-thread dimensions. + // + // There is one case that makes things more complicated, encountered so far + // only on RDNA3. That is when some intrinsic has multiple (so far, 2) threads + // reading the same data. This redundancy is not encoded in the TileSwizzle + // structures that we are using here. Instead, in that case, the thread grid + // (as encoded in the TileSwizzle) is smaller than the subgroup size. In that + // case, there is an implied thread-distribution-only dimension along which + // multiple threads read exactly the same data. + // So we need to distinguish layoutThreadSizes vs. distributionThreadSizes. + SmallVector layoutThreadSizes = sliceSwizzledShape(swizzle, [](TileSwizzle::Dim d) { return d.kind == TileSwizzle::Dim::Kind::CrossThread; }); - - // Bound for threadId is the product of tileOffsetsBasis. + // In layoutThreadSizes, intrinsic level dimensions are mixed with expansion + // to multiple subgroups, so in order to tell if there are additional + // distribution-only thread dimensions, we need to get back to the intrinsic. + TileSwizzle intrinsicSwizzle = + getIntrinsicSwizzle(getIntrinsic().getValue(), fragment); + SmallVector intrinsicLayoutThreadSizes = + sliceSwizzledShape(intrinsicSwizzle, [](TileSwizzle::Dim d) { + return d.kind == TileSwizzle::Dim::Kind::CrossThread; + }); + int64_t intrinsicLayoutThreadBound = + ShapedType::getNumElements(intrinsicLayoutThreadSizes); + SmallVector distributionThreadSizes = layoutThreadSizes; + int distributionOnlyDimIdx = + distributionThreadSizes.size() - intrinsicLayoutThreadSizes.size(); + // Now we are able to tell if there is an extra distribution-only dimension. + bool hasDistributionOnlyDim = intrinsicLayoutThreadBound < getSubgroupSize(); + if (hasDistributionOnlyDim) { + // Insert the extra distribution-only dimension. This will need to be paired + // below with erasing the corresponding dim out of the delinearized indices. + distributionThreadSizes.insert( + distributionThreadSizes.begin() + distributionOnlyDimIdx, + getSubgroupSize() / intrinsicLayoutThreadBound); + } + + // AffineDelinearizeIndexOp requires an in-bounds input index, so we bound it. 
OpFoldResult threadIdBound = - builder.getIndexAttr(ShapedType::getNumElements(tileOffsetsBasis)); + builder.getIndexAttr(ShapedType::getNumElements(distributionThreadSizes)); AffineExpr d0 = builder.getAffineDimExpr(0), d1 = builder.getAffineDimExpr(1); OpFoldResult boundedThreadId = affine::makeComposedFoldedAffineApply( builder, loc, {d0 % d1}, {threadId, threadIdBound}); + // Obtain the offsets from delinearization along the distributionThreadSizes. SmallVector tileOffsets = builder .create( loc, getValueOrCreateConstantIndexOp(builder, loc, boundedThreadId), - getAsIndexOpFoldResult(ctx, tileOffsetsBasis)) + getAsIndexOpFoldResult(ctx, distributionThreadSizes)) ->getResults(); + if (hasDistributionOnlyDim) { + // Erase the delinearized index that corresponds to the extra distribution + // dimension that we had inserted above. This is what causes multiple + // threads (which only differed in the index being discarded here) to read + // exactly the same data. + tileOffsets.erase(tileOffsets.begin() + distributionOnlyDimIdx); + } + // Strides are trivial: each slice is contiguous along the *expanded* dims // even if it may not be contiguous in the flattened layout. SmallVector tileStrides(tileSizes.size(), @@ -1071,8 +1111,6 @@ FailureOr DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder, Type resultType, Value lhs, Value rhs, Value acc) const { - // TODO(bjacob): Support WMMA intrinsics. - // Validation. Similar to MMAAttr::buildMmaOperation. auto [aType, bType, cType] = getABCVectorTypes(); if (aType != lhs.getType() || bType != rhs.getType() || @@ -1137,15 +1175,27 @@ FailureOr DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder, sliceSwizzledShape(accSwizzle, [](TileSwizzle::Dim dim) { return dim.kind == TileSwizzle::Dim::Kind::CrossIntrinsic; }); + SmallVector accInternalShape = + sliceSwizzledShape(accSwizzle, [](TileSwizzle::Dim dim) { + return dim.kind == TileSwizzle::Dim::Kind::Internal; + }); + LLVM_DEBUG({ DBGS() << "accCrossIntrinsicShape: "; llvm::interleaveComma(accCrossIntrinsicShape, llvm::dbgs()); llvm::dbgs() << "\n"; + DBGS() << "accInternalShape: "; + llvm::interleaveComma(accInternalShape, llvm::dbgs()); + llvm::dbgs() << "\n"; }); - SmallVector strides(intrinsicCType.getRank(), 1); - SmallVector indices(accCrossIntrinsicShape.size(), 0); + int dstRank = accCrossIntrinsicShape.size(); + SmallVector strides(dstRank, 1); + SmallVector indices(dstRank, 0); for (Value intrAcc : intrinsicsAcc) { - acc = builder.create(loc, intrAcc, acc, + auto expandedAcc = builder.create( + loc, VectorType::get(accInternalShape, cType.getElementType()), + intrAcc); + acc = builder.create(loc, expandedAcc, acc, indices, strides); incrementIndices(indices, accCrossIntrinsicShape); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp index 5e8f031ff8ac..82fd46d9be2c 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp @@ -216,7 +216,10 @@ const WgpDetails *getRDNA3WgpDetails() { {1024, 1024, 1024}, 1024, 64 * 1024, - {0x7fffffff, 0x7fffffff, 0x7fffffff}}; + {0x7fffffff, 0x7fffffff, 0x7fffffff}, + /*maxLoadInstructionBits=*/128, + /*simdsPerWgp=*/4, + /*vgprSpaceBits=*/256 * 32}; return &rdna3Wgp; } diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt index a70cb98b3003..22a701efded4 100644 --- a/tests/e2e/matmul/CMakeLists.txt 
+++ b/tests/e2e/matmul/CMakeLists.txt @@ -1788,4 +1788,65 @@ iree_generated_e2e_runner_test( "requires-gpu-rdna3" ) +iree_generated_e2e_runner_test( + NAME + e2e_matmul_rdna3_dt_f16 + TEST_TYPE + matmul + GENERATOR + "generate_e2e_matmul_tests.py" + GENERATOR_ARGS + "--lhs_rhs_type=f16" + "--acc_type=f32" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-matmul-test + TARGET_BACKENDS + "rocm" + DRIVERS + "hip" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + "--iree-opt-data-tiling" + "--iree-global-opt-experimental-rocm-data-tiling" + "--iree-global-opt-enable-early-materialization=true" + RUNNER_ARGS + "--require_exact_results=false" + "--acceptable_fp_delta=1e-04" + LABELS + "noasan" + "nomsan" + "notsan" + "noubsan" + "requires-gpu-rdna3" +) + +iree_generated_e2e_runner_test( + NAME + e2e_matmul_rdna3_dt_i8 + TEST_TYPE + matmul + GENERATOR + "generate_e2e_matmul_tests.py" + GENERATOR_ARGS + "--lhs_rhs_type=i8" + "--acc_type=i32" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-matmul-test + TARGET_BACKENDS + "rocm" + DRIVERS + "hip" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + "--iree-opt-data-tiling" + "--iree-global-opt-experimental-rocm-data-tiling" + "--iree-global-opt-enable-early-materialization=true" + LABELS + "noasan" + "nomsan" + "notsan" + "noubsan" + "requires-gpu-rdna3" +) + endif()
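
The first functional change is the `unrollK` computation in `chooseDataTiledMMAAttr`: instead of bailing out when the widest load instruction is not a multiple of the intrinsic operand load size, the ratio is clamped to at least 1. A minimal standalone sketch of that behavior, with an illustrative free function that is not code from the patch:

```cpp
// Minimal sketch of the unrollK selection after this patch. The old code
// returned early whenever maxLoadInstructionBits was not a multiple of the
// intrinsic operand load size; the new code clamps the ratio to >= 1.
// Function name and main() are illustrative, not from the patch.
#include <algorithm>
#include <cassert>
#include <cstdint>

int64_t chooseUnrollK(int64_t maxLoadInstructionBits, int64_t intrinsicLoadBits) {
  // Integer division yields 0 when the intrinsic operand (e.g. 16xf16 = 256
  // bits for RDNA3 WMMA) is wider than the widest load (128 bits on RDNA3),
  // so the clamp keeps unrollK = 1 instead of rejecting the intrinsic.
  return std::max<int64_t>(1, maxLoadInstructionBits / intrinsicLoadBits);
}

int main() {
  assert(chooseUnrollK(128, 256) == 1); // RDNA3 WMMA f16: no K-unrolling.
  assert(chooseUnrollK(128, 32) == 4);  // CDNA3-style: 4 x 32-bit -> one 128-bit load.
  return 0;
}
```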
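
The subtlest part of the patch is the distinction between the "layout" thread grid (what the `TileSwizzle` encodes) and the "distribution" thread grid (the whole subgroup) in `populateOperandOffsetsSizesStrides`. The sketch below models only the index arithmetic in plain C++; variable names mirror the patch, but the function, its signature, and the numbers in `main()` are illustrative, and the real code builds `affine.delinearize_index` ops on MLIR values instead of doing scalar math.

```cpp
// Plain-C++ model of the layout-vs-distribution thread-grid logic.
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

std::vector<int64_t> computeTileOffsets(int64_t threadId, int64_t subgroupSize,
                                        const std::vector<int64_t> &layoutThreadSizes,
                                        int64_t intrinsicLayoutThreadBound,
                                        int distributionOnlyDimIdx) {
  // If the intrinsic's own cross-thread grid is smaller than the subgroup,
  // insert an extra distribution-only dimension along which several threads
  // will read identical data.
  std::vector<int64_t> distributionThreadSizes = layoutThreadSizes;
  bool hasDistributionOnlyDim = intrinsicLayoutThreadBound < subgroupSize;
  if (hasDistributionOnlyDim) {
    distributionThreadSizes.insert(
        distributionThreadSizes.begin() + distributionOnlyDimIdx,
        subgroupSize / intrinsicLayoutThreadBound);
  }
  // Bound the thread id by the product of the distribution sizes, then
  // delinearize it (row-major, last dimension fastest).
  int64_t bound = std::accumulate(distributionThreadSizes.begin(),
                                  distributionThreadSizes.end(), int64_t{1},
                                  std::multiplies<int64_t>());
  int64_t id = threadId % bound;
  std::vector<int64_t> offsets(distributionThreadSizes.size());
  for (int i = static_cast<int>(distributionThreadSizes.size()) - 1; i >= 0; --i) {
    offsets[i] = id % distributionThreadSizes[i];
    id /= distributionThreadSizes[i];
  }
  // Erase the distribution-only index: threads that differ only in that index
  // get the same offsets, i.e. they read exactly the same data.
  if (hasDistributionOnlyDim)
    offsets.erase(offsets.begin() + distributionOnlyDimIdx);
  return offsets;
}

int main() {
  // RDNA3-like illustration: a 16-thread intrinsic layout inside a 32-thread
  // subgroup leaves one distribution-only dimension of size 2, so lanes t and
  // t + 16 land on the same slice.
  std::vector<int64_t> layout = {16};
  assert(computeTileOffsets(0, 32, layout, 16, 0) ==
         computeTileOffsets(16, 32, layout, 16, 0));
  assert(computeTileOffsets(1, 32, layout, 16, 0) !=
         computeTileOffsets(0, 32, layout, 16, 0));
  return 0;
}
```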
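
The last bullet of the description is the classic double-nullability trap: wrapping an already-nullable MLIR attribute in `std::optional` creates two distinct "empty" states. The sketch below reproduces the failure mode with a hypothetical nullable `Handle` standing in for `IREE::GPU::DataTiledMMAAttr`; it is an illustration, not the compiler code.

```cpp
// Sketch of the doubly-nullable pitfall fixed in GPUConvertToMultiMma.
// `Handle` is a hypothetical stand-in for an MLIR attribute, which is itself
// nullable.
#include <cassert>
#include <optional>

struct Handle {
  void *impl = nullptr;
  explicit operator bool() const { return impl != nullptr; }
};

int main() {
  // With a std::optional<Handle> return type, a failed lookup can surface as
  // an engaged optional wrapping a null Handle. A plain `if (!mma)` check on
  // the optional misses that state, and a later mma.value() silently yields a
  // null handle.
  std::optional<Handle> mma = Handle{};
  assert(mma.has_value() && !*mma);
  // Returning the nullable attribute type directly (as the patch does with
  // DataTiledMMAAttr) leaves a single null state, so `if (!mma)` suffices.
  Handle direct{};
  assert(!direct);
  return 0;
}
```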
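
Because of the RDNA3 float-arithmetic quirk noted in the description, the new `e2e_matmul_rdna3_dt_f16` test passes `--require_exact_results=false` and `--acceptable_fp_delta=1e-04`. The sketch below shows the kind of comparison those flags request; the actual check inside `iree-e2e-matmul-test` may differ in detail.

```cpp
// Tolerance-based acceptance instead of bit-exact equality (illustrative).
#include <cassert>
#include <cmath>

bool acceptableResult(float actual, float expected, bool requireExact,
                      float acceptableFpDelta = 1e-4f) {
  if (requireExact)
    return actual == expected;
  return std::fabs(actual - expected) <= acceptableFpDelta;
}

int main() {
  // Small integral values should be exact under IEEE 754, but per the
  // description RDNA3 introduces epsilon-level errors, so the test tolerates
  // them within the configured delta.
  assert(acceptableResult(42.00001f, 42.0f, /*requireExact=*/false));
  assert(!acceptableResult(42.00001f, 42.0f, /*requireExact=*/true));
  return 0;
}
```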