GPU data tiling on RDNA3 (#18980)
A few things were needed:
* Populate required fields in KnownTargets.cpp.
* Support the case where the intrinsic vector operand size is greater
than the max load instruction size (here the operand is 16xf16 = 256 bits,
vs. a 128-bit max load); a minimal sketch follows this list.
* `buildMmaOperation` creates `vector.insert_strided_slice` ops to insert
the new accumulator vectors into the accumulator tile. In doing so, it
was relying on `vector.insert_strided_slice`'s implicit expand-shape
semantics, in ways that worked for the shapes we had seen on CDNA3 but
not here. Solved by explicitly expanding shapes with `vector.shape_cast`
ops.
* In the thread-distribution code (populateOperandXxx), we needed to account
for the distinction between two thread grids: "layout" vs.
"distribution". On RDNA3, there is a distribution-only
dimension that isn't reflected in the layout-centric TileSwizzles.
* On RDNA3, some float arithmetic is strongly non-IEEE754-compliant:
even exactly-representable small integral values, on which float
arithmetic should be exact, come back with epsilon-sized numerical errors!
Addressed by allowing a small tolerance in the f16 e2e matmul test.
* Fix a bug: the doubly-nullable type
`std::optional<IREE::GPU::DataTiledMMAAttr>` tricked us; changed it to a
plain `IREE::GPU::DataTiledMMAAttr`, where a null attribute signals failure.
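
For the second bullet, a minimal sketch of the unrollK clamp (not part of the
diff; the 128-bit figure is the new maxLoadInstructionBits entry in
KnownTargets.cpp further down):

```cpp
// Assumes RDNA3's WMMA_F32_16x16x16_F16: each A/B vector operand is
// 16 x f16 = 256 bits per thread, wider than the 128-bit max load, so the
// plain division yields 0 and std::max keeps unrollK at 1 (previously this
// case bailed out with `return {}`).
#include <algorithm>
#include <cassert>

int main() {
  const int maxLoadInstructionBits = 128;  // RDNA3 value from KnownTargets.cpp
  const int intrinsicLoadBits = 16 * 16;   // 16 x f16 = 256 bits
  const int unrollK = std::max(1, maxLoadInstructionBits / intrinsicLoadBits);
  assert(unrollK == 1);
  return 0;
}
```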

---------

Signed-off-by: Benoit Jacob <[email protected]>
bjacob authored Nov 7, 2024
1 parent c651ba9 commit 8e5f218
Showing 8 changed files with 201 additions and 30 deletions.
@@ -103,12 +103,8 @@ chooseDataTiledMMAAttr(TypeRange eTypes, IREE::GPU::TargetAttr target,
// unrollK=4 to turn 4 separate 32-bit loads into one 128-bit load.
int intrinsicLoadBits =
std::min(sizeInBits(intrinsicA), sizeInBits(intrinsicB));
if (*wgp.getMaxLoadInstructionBits() % intrinsicLoadBits != 0) {
// Never seen that case: the ISA does not have a suitable load instruction
// to feed that intrinsic?!
return {};
}
const int unrollK = *wgp.getMaxLoadInstructionBits() / intrinsicLoadBits;
const int unrollK =
std::max(1, *wgp.getMaxLoadInstructionBits() / intrinsicLoadBits);

// The total amount of unrolling along the M and N dimensions is normally
// limited only by the number of available registers, since larger M and N
@@ -493,13 +489,13 @@ class GPUConvertToMultiMma final
} else {
gpuTargetAttr = getCLGPUTarget(op.getContext());
}
std::optional<IREE::GPU::DataTiledMMAAttr> mma = chooseDataTiledMMAAttr(
IREE::GPU::DataTiledMMAAttr mma = chooseDataTiledMMAAttr(
resultEncoding.getElementTypesArray(), gpuTargetAttr, resultEncoding);
if (!mma) {
LLVM_DEBUG(llvm::dbgs() << "can't find supported Mma intrinsic\n");
return failure();
}
LLVM_DEBUG(llvm::dbgs() << "Target MMA: " << mma.value() << "\n");
LLVM_DEBUG(llvm::dbgs() << "Target MMA: " << mma << "\n");

FailureOr<linalg::ContractionDimensions> contractionDims =
linalg::inferContractionDims(linalgOp);
@@ -535,8 +531,7 @@ class GPUConvertToMultiMma final
Location loc = op.getLoc();
auto mmaOp = rewriter.create<IREE::GPU::MultiMmaOp>(
loc, operands[0], operands[1], operands[2],
ArrayRef<AffineMap>{lhsMap, rhsMap, accMap}, iteratorTypes,
mma.value());
ArrayRef<AffineMap>{lhsMap, rhsMap, accMap}, iteratorTypes, mma);
rewriter.replaceOp(op, mmaOp);
return success();
}
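
As an aside on the `std::optional` change above, here is a standalone sketch
(hypothetical `FakeAttr` type, not IREE code) of why a `std::optional` of an
already-nullable attribute handle is easy to misuse:

```cpp
#include <optional>

// Stand-in for an MLIR attribute: a nullable handle with an explicit bool test.
struct FakeAttr {
  void *impl = nullptr;
  explicit operator bool() const { return impl != nullptr; }
};

int main() {
  std::optional<FakeAttr> mma = FakeAttr{};   // engaged optional holding a *null* attribute
  bool passesOptionalCheck = mma.has_value(); // true: only tests engagement
  bool attributeIsValid = mma && *mma;        // false: also tests the handle itself
  // The mismatch between these two checks is the trap; returning the attribute
  // directly leaves only one null state to test with `if (!mma)`.
  return (passesOptionalCheck && !attributeIsValid) ? 0 : 1;
}
```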
@@ -30,7 +30,8 @@ iree_lit_test_suite(
"gpu_infer_memory_space.mlir",
"gpu_lower_to_ukernels.mlir",
"gpu_combine_value_barriers.mlir",
"gpu_materialize_encoding.mlir",
"gpu_materialize_encoding_gfx942.mlir",
"gpu_materialize_encoding_gfx1100.mlir",
"gpu_nested_layout_contract_amdgpu.mlir",
"gpu_nested_layout_vector_distribution.mlir",
"gpu_pipeline.mlir",
@@ -26,7 +26,8 @@ iree_lit_test_suite(
"gpu_greedily_distribute_to_threads.mlir"
"gpu_infer_memory_space.mlir"
"gpu_lower_to_ukernels.mlir"
"gpu_materialize_encoding.mlir"
"gpu_materialize_encoding_gfx1100.mlir"
"gpu_materialize_encoding_gfx942.mlir"
"gpu_nested_layout_contract_amdgpu.mlir"
"gpu_nested_layout_vector_distribution.mlir"
"gpu_pipeline.mlir"
@@ -0,0 +1,60 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" \
// RUN: --iree-gpu-test-target=gfx1100 \
// RUN: --split-input-file %s | FileCheck %s

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#encoding_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
#encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
#encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
#pipeline_layout_3 = #hal.pipeline.layout<constants = 3, bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
func.func @matmul_lowering_WMMA_F32_16x16x16_F16() {
%c0 = arith.constant 0 : index
%M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index
%N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index
%K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0)
: !flow.dispatch.tensor<readonly:tensor<?x?xf16, #encoding_lhs>>{%M, %K}
%1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0)
: !flow.dispatch.tensor<readonly:tensor<?x?xf16, #encoding_rhs>>{%K, %N}
%2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0)
: !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #encoding_result>>{%M, %N}
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
: !flow.dispatch.tensor<readonly:tensor<?x?xf16, #encoding_lhs>>{%M, %K}
-> tensor<?x?xf16, #encoding_lhs>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
: !flow.dispatch.tensor<readonly:tensor<?x?xf16, #encoding_rhs>>{%K, %N}
-> tensor<?x?xf16, #encoding_rhs>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
: !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #encoding_result>>{%M, %N}
-> tensor<?x?xf32, #encoding_result>
%6 = linalg.matmul
ins(%3, %4 : tensor<?x?xf16, #encoding_lhs>,
tensor<?x?xf16, #encoding_rhs>)
outs(%5 : tensor<?x?xf32, #encoding_result>)
-> tensor<?x?xf32, #encoding_result>
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
: tensor<?x?xf32, #encoding_result>
-> !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #encoding_result>>{%M, %N}
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK: func.func @matmul_lowering_WMMA_F32_16x16x16_F16
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x1x16x16xf16>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x1x16x16xf16>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x4x8x2x16xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = WMMA_F32_16x16x16_F16, unroll_m = 4, unroll_n_to_subgroups = 4>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
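
One way to sanity-check the CHECK'd data-tile shapes above (my reading, which
assumes the leading 4s are unroll_m / unroll_n_to_subgroups and the trailing
dims are the 16x16 WMMA intrinsic tile):

```cpp
// Element-count arithmetic only; compiles as a standalone translation unit.
static_assert(4 * 1 * 16 * 16 == 64 * 16,
              "LHS tile 4x1x16x16 f16 = (M = 4*16) x (K = 1*16)");
static_assert(4 * 4 * (8 * 2 * 16) == 64 * 64,
              "ACC tile 4x4x8x2x16 f32 = (M = 4*16) x (N = 4*16), "
              "with 8x2x16 = 256 = one 16x16 accumulator tile per subgroup");
```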
84 changes: 67 additions & 17 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -955,9 +955,6 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides(
Value threadId, ArrayRef<int64_t> permutation,
SmallVector<OpFoldResult> &offsets, SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) const {
// TODO(bjacob): Support WMMA intrinsics.

// Get the swizzle describing the internal layout of this fragment.
TileSwizzle swizzle = getSwizzle(*this, fragment);

LLVM_DEBUG({
@@ -966,36 +963,79 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides(
DBGS() << " swizzle: " << swizzle << "\n";
});

// Populate tile sizes.
MLIRContext *ctx = builder.getContext();
SmallVector<OpFoldResult> tileSizes = getAsIndexOpFoldResult(
ctx, sliceSwizzledShape(swizzle, [](TileSwizzle::Dim d) {
return d.kind != TileSwizzle::Dim::Kind::CrossThread;
}));

// Populate tile offsets by delinearizing threadId over the CrossThread dims.
// Since the AffineDelinearizeIndexOp does not bound the input index, we
// must bound the threadId by the product of the offset ranges.
SmallVector<int64_t> tileOffsetsBasis =
// Most of the rest of this function is the computation of the offsets.
// The basic idea is to delinearize the threadId over the basis of
// cross-thread dimensions. These cross-thread dimensions may be either
// the intrinsic's own, or they may come from expansion to multiple subgroups.
// Normally, that distinction is irrelevant here: we just delinearize the
// thread-id over all cross-thread dimensions.
//
// There is one case that makes things more complicated, encountered so far
// only on RDNA3. That is when some intrinsic has multiple (so far, 2) threads
// reading the same data. This redundancy is not encoded in the TileSwizzle
// structures that we are using here. Instead, in that case, the thread grid
// (as encoded in the TileSwizzle) is smaller than the subgroup size. In that
// case, there is an implied thread-distribution-only dimension along which
// multiple threads read exactly the same data.
// So we need to distinguish layoutThreadSizes vs. distributionThreadSizes.
SmallVector<int64_t> layoutThreadSizes =
sliceSwizzledShape(swizzle, [](TileSwizzle::Dim d) {
return d.kind == TileSwizzle::Dim::Kind::CrossThread;
});

// Bound for threadId is the product of tileOffsetsBasis.
// In layoutThreadSizes, intrinsic level dimensions are mixed with expansion
// to multiple subgroups, so in order to tell if there are additional
// distribution-only thread dimensions, we need to get back to the intrinsic.
TileSwizzle intrinsicSwizzle =
getIntrinsicSwizzle(getIntrinsic().getValue(), fragment);
SmallVector<int64_t> intrinsicLayoutThreadSizes =
sliceSwizzledShape(intrinsicSwizzle, [](TileSwizzle::Dim d) {
return d.kind == TileSwizzle::Dim::Kind::CrossThread;
});
int64_t intrinsicLayoutThreadBound =
ShapedType::getNumElements(intrinsicLayoutThreadSizes);
SmallVector<int64_t> distributionThreadSizes = layoutThreadSizes;
int distributionOnlyDimIdx =
distributionThreadSizes.size() - intrinsicLayoutThreadSizes.size();
// Now we are able to tell if there is an extra distribution-only dimension.
bool hasDistributionOnlyDim = intrinsicLayoutThreadBound < getSubgroupSize();
if (hasDistributionOnlyDim) {
// Insert the extra distribution-only dimension. This will need to be paired
// below with erasing the corresponding dim out of the delinearized indices.
distributionThreadSizes.insert(
distributionThreadSizes.begin() + distributionOnlyDimIdx,
getSubgroupSize() / intrinsicLayoutThreadBound);
}

// AffineDelinearizeIndexOp requires an in-bounds input index, so we bound it.
OpFoldResult threadIdBound =
builder.getIndexAttr(ShapedType::getNumElements(tileOffsetsBasis));
builder.getIndexAttr(ShapedType::getNumElements(distributionThreadSizes));
AffineExpr d0 = builder.getAffineDimExpr(0), d1 = builder.getAffineDimExpr(1);
OpFoldResult boundedThreadId = affine::makeComposedFoldedAffineApply(
builder, loc, {d0 % d1}, {threadId, threadIdBound});

// Obtain the offsets from delinearization along the distributionThreadSizes.
SmallVector<OpFoldResult> tileOffsets =
builder
.create<affine::AffineDelinearizeIndexOp>(
loc,
getValueOrCreateConstantIndexOp(builder, loc, boundedThreadId),
getAsIndexOpFoldResult(ctx, tileOffsetsBasis))
getAsIndexOpFoldResult(ctx, distributionThreadSizes))
->getResults();

if (hasDistributionOnlyDim) {
// Erase the delinearized index that corresponds to the extra distribution
// dimension that we had inserted above. This is what causes multiple
// threads (which only differed in the index being discarded here) to read
// exactly the same data.
tileOffsets.erase(tileOffsets.begin() + distributionOnlyDimIdx);
}

// Strides are trivial: each slice is contiguous along the *expanded* dims
// even if it may not be contiguous in the flattened layout.
SmallVector<OpFoldResult> tileStrides(tileSizes.size(),
@@ -1071,8 +1111,6 @@ FailureOr<Value> DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder,
Type resultType, Value lhs,
Value rhs,
Value acc) const {
// TODO(bjacob): Support WMMA intrinsics.

// Validation. Similar to MMAAttr::buildMmaOperation.
auto [aType, bType, cType] = getABCVectorTypes();
if (aType != lhs.getType() || bType != rhs.getType() ||
@@ -1137,15 +1175,27 @@ FailureOr<Value> DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder,
sliceSwizzledShape(accSwizzle, [](TileSwizzle::Dim dim) {
return dim.kind == TileSwizzle::Dim::Kind::CrossIntrinsic;
});
SmallVector<int64_t> accInternalShape =
sliceSwizzledShape(accSwizzle, [](TileSwizzle::Dim dim) {
return dim.kind == TileSwizzle::Dim::Kind::Internal;
});

LLVM_DEBUG({
DBGS() << "accCrossIntrinsicShape: ";
llvm::interleaveComma(accCrossIntrinsicShape, llvm::dbgs());
llvm::dbgs() << "\n";
DBGS() << "accInternalShape: ";
llvm::interleaveComma(accInternalShape, llvm::dbgs());
llvm::dbgs() << "\n";
});
SmallVector<int64_t> strides(intrinsicCType.getRank(), 1);
SmallVector<int64_t> indices(accCrossIntrinsicShape.size(), 0);
int dstRank = accCrossIntrinsicShape.size();
SmallVector<int64_t> strides(dstRank, 1);
SmallVector<int64_t> indices(dstRank, 0);
for (Value intrAcc : intrinsicsAcc) {
acc = builder.create<vector::InsertStridedSliceOp>(loc, intrAcc, acc,
auto expandedAcc = builder.create<vector::ShapeCastOp>(
loc, VectorType::get(accInternalShape, cType.getElementType()),
intrAcc);
acc = builder.create<vector::InsertStridedSliceOp>(loc, expandedAcc, acc,
indices, strides);
incrementIndices(indices, accCrossIntrinsicShape);
}
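
To make the distribution-only dimension handling in
populateOperandOffsetsSizesStrides above more concrete, here is a standalone
sketch (not part of the commit), assuming a 32-thread subgroup and an intrinsic
whose layout uses only 16 threads, so two threads read the same data:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Plain integer counterpart of AffineDelinearizeIndexOp: outermost result first.
static std::vector<int64_t> delinearize(int64_t id,
                                        const std::vector<int64_t> &basis) {
  std::vector<int64_t> idx(basis.size());
  for (int i = static_cast<int>(basis.size()) - 1; i >= 0; --i) {
    idx[i] = id % basis[i];
    id /= basis[i];
  }
  return idx;
}

int main() {
  const int64_t subgroupSize = 32;
  std::vector<int64_t> layoutThreadSizes = {16};  // intrinsic cross-thread dims only
  const int64_t layoutThreadBound = 16;
  std::vector<int64_t> distributionThreadSizes = layoutThreadSizes;
  const int distributionOnlyDimIdx = 0;
  // Insert the replication dimension: 32 / 16 = 2 threads per data slice.
  distributionThreadSizes.insert(
      distributionThreadSizes.begin() + distributionOnlyDimIdx,
      subgroupSize / layoutThreadBound);

  auto offsetsOf = [&](int64_t tid) {
    std::vector<int64_t> idx =
        delinearize(tid % subgroupSize, distributionThreadSizes);
    // Erase the replication index so replicated threads collapse to the same offsets.
    idx.erase(idx.begin() + distributionOnlyDimIdx);
    return idx;
  };

  // Threads 3 and 19 differ only in the erased replication index.
  assert(offsetsOf(3) == offsetsOf(19));
  return 0;
}
```

In the real code the indices come out of affine::AffineDelinearizeIndexOp and
the basis mixes subgroup and intrinsic cross-thread dims; this only isolates the
insert-then-erase replication step.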
@@ -216,7 +216,10 @@ const WgpDetails *getRDNA3WgpDetails() {
{1024, 1024, 1024},
1024,
64 * 1024,
{0x7fffffff, 0x7fffffff, 0x7fffffff}};
{0x7fffffff, 0x7fffffff, 0x7fffffff},
/*maxLoadInstructionBits=*/128,
/*simdsPerWgp=*/4,
/*vgprSpaceBits=*/256 * 32};
return &rdna3Wgp;
}
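
How I read the three new fields above (an interpretation, not verified against
the WgpDetails definition): maxLoadInstructionBits is the widest per-thread
load, simdsPerWgp the SIMD count per workgroup processor, and vgprSpaceBits a
per-lane register budget expressed as a VGPR count times 32 bits:

```cpp
// Arithmetic reading of vgprSpaceBits = 256 * 32 (assumption: 256 VGPRs per
// lane, 32 bits each).
static_assert(256 * 32 == 8192, "8192 bits = 1 KiB of registers per lane");
```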

61 changes: 61 additions & 0 deletions tests/e2e/matmul/CMakeLists.txt
@@ -1788,4 +1788,65 @@ iree_generated_e2e_runner_test(
"requires-gpu-rdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rdna3_dt_f16
TEST_TYPE
matmul
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
"--acc_type=f32"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
"rocm"
DRIVERS
"hip"
COMPILER_FLAGS
${IREE_HIP_TEST_COMPILER_FLAGS}
"--iree-opt-data-tiling"
"--iree-global-opt-experimental-rocm-data-tiling"
"--iree-global-opt-enable-early-materialization=true"
RUNNER_ARGS
"--require_exact_results=false"
"--acceptable_fp_delta=1e-04"
LABELS
"noasan"
"nomsan"
"notsan"
"noubsan"
"requires-gpu-rdna3"
)
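
The f16 test above relaxes exactness because of the RDNA3 float behavior noted
in the commit message. A minimal sketch of an absolute-delta comparison in the
spirit of --acceptable_fp_delta=1e-04 (the runner's actual comparison logic may
differ):

```cpp
#include <cmath>

// Accept |actual - expected| <= delta instead of exact equality.
bool acceptableResult(float actual, float expected, float delta = 1e-4f) {
  return std::fabs(actual - expected) <= delta;
}
```

The i8/i32 test that follows keeps exact comparison, since integer accumulation
is unaffected.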

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rdna3_dt_i8
TEST_TYPE
matmul
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=i8"
"--acc_type=i32"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
"rocm"
DRIVERS
"hip"
COMPILER_FLAGS
${IREE_HIP_TEST_COMPILER_FLAGS}
"--iree-opt-data-tiling"
"--iree-global-opt-experimental-rocm-data-tiling"
"--iree-global-opt-enable-early-materialization=true"
LABELS
"noasan"
"nomsan"
"notsan"
"noubsan"
"requires-gpu-rdna3"
)

endif()
