Skip to content

Commit

Permalink
simplify-unroll
Browse files Browse the repository at this point in the history
Signed-off-by: Benoit Jacob <[email protected]>
  • Loading branch information
bjacob committed Nov 11, 2024
1 parent a89bc41 commit 7c3f370
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func.func @matmul_lowering_WMMA_F32_16x16x16_F16() {
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x1x16x16xf16>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x1x16x16xf16>
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x16x16xf16>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x16x16xf16>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x4x8x2x16xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ struct TileSwizzle {

// The size of the dimension.
int16_t size = 0;

// Support constructing from any size type.
template <typename T>
Dim(Kind kind, T size) : kind(kind), size(size) {}
};

using ExpandShapeDimVectorType = llvm::SmallVector<Dim, 4>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,85 +3,77 @@
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"

namespace mlir::iree_compiler {

// Given an `expandShape` vector-of-vectors describing the mapping from source
// dimensions to expanded dimensions, returns the index of the first expanded
// dimension corresponding to the given source dimension index.
static int64_t
getExpandedDimFirstIdx(const TileSwizzle::ExpandShapeType &expandShape,
int64_t srcIndex) {
using Kind = TileSwizzle::Dim::Kind;

// Returns the index of the first destination dimension corresponding to the
// given source dimension `srcIdx`.
static int64_t expandedDimIdx(const TileSwizzle::ExpandShapeType &expandShape,
int srcIdx) {
int dstIndexFirst = 0;
for (int i = 0; i < srcIndex; ++i) {
for (int i = 0; i < srcIdx; ++i) {
dstIndexFirst += expandShape[i].size();
}
return dstIndexFirst;
}

void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
TileSwizzle::Dim::Kind kind) {
assert(unrollFactor > 1);
int dstIndexFirst = getExpandedDimFirstIdx(swizzle.expandShape, srcIndex);
TileSwizzle::Dim unrollDim;
unrollDim.size = unrollFactor;
unrollDim.kind = kind;
// Pushes `dim` to the front of `swizzle.expandShape[srcIdx]`, and updates
// `swizzle.permutation` to make the new dimension outer-most among the dims in
// `swizzle.expandShape[srcIdx]`.
//
// This can be used to unroll a kernel with kind = CrossIntrinsic,
// or to expand a kernel to multiple subgroups with kind = CrossThread.
//
// Example:
// Input swizzle = { expandShape = [[16], [4]], permutation = [1, 0] }
// Input srcIdx = 1
// Input dim.size = 4
// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
//
static void expand(TileSwizzle &swizzle, int srcIdx, TileSwizzle::Dim dim) {
int dstIndex = expandedDimIdx(swizzle.expandShape, srcIdx);
// The new unrolling dimension is inserted at the start of the expandShape
// dimensions group corresponding to srcIndex.
swizzle.expandShape[srcIndex].insert(swizzle.expandShape[srcIndex].begin(),
unrollDim);
// dimensions group corresponding to srcIdx.
swizzle.expandShape[srcIdx].insert(swizzle.expandShape[srcIdx].begin(), dim);
// Since we are not interleaving here, generating side-by-side copies of the
// original layout, the new unrolling dimension is the new outermost
// dimension. Existing entries get shifted to make room for it.
for (auto &p : swizzle.permutation) {
p += (p >= dstIndexFirst);
p += (p >= dstIndex);
}
swizzle.permutation.insert(swizzle.permutation.begin(), dstIndexFirst);
swizzle.permutation.insert(swizzle.permutation.begin(), dstIndex);
}

void interleave(TileSwizzle &swizzle, int srcIndex,
int expandedDimIndexToInterleaveAt) {
// Compute which inner dimension to permute the current outer dimension into.
int dstIndexFirst = getExpandedDimFirstIdx(swizzle.expandShape, srcIndex);
int dstIndexToInterleaveAt = dstIndexFirst + expandedDimIndexToInterleaveAt;

// Interleaves the layout in `swizzle` by mutating `swizzle.permutation` to
// move permutation[0], the outer-most dimension (which the expand() function
// just created), to the inner dimension given by `expandedIdx`, counted
// within the group of expanded dimensions `swizzle.expandShape[srcIdx]`.
//
// Example:
// Input swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
// Input srcIdx = 1
// Input expandedIdx = 1
// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [2, 0, 1] }
//
static void interleave(TileSwizzle &swizzle, int srcIdx, int expandedIdx) {
int dstIdx = expandedDimIdx(swizzle.expandShape, srcIdx) + expandedIdx;
SmallVector<int64_t> outPermutation(swizzle.permutation.size());
// The leading dimension, permutation[0], gets moved inwards to the
// position that we just computed, dstIndexToInterleaveAt.
outPermutation[dstIndexToInterleaveAt] = swizzle.permutation[0];
// position that we just computed, dstIdx.
outPermutation[dstIdx] = swizzle.permutation[0];
// Outer dimensions get shifted outwards to fill the gap.
for (int i = 0; i < dstIndexToInterleaveAt; ++i) {
for (int i = 0; i < dstIdx; ++i) {
outPermutation[i] = swizzle.permutation[i + 1];
}
// Inner dimensions don't change. That is to say that we only interleave
// at `targetInterleavedElements` granularity, we don't swizzle further
// internally to that.
for (int i = dstIndexToInterleaveAt + 1; i < outPermutation.size(); ++i) {
// Inner dimensions don't change.
for (int i = dstIdx + 1; i < outPermutation.size(); ++i) {
outPermutation[i] = swizzle.permutation[i];
}
swizzle.permutation = outPermutation;
}

// Returns the permutation of indices that sorts `v` with the given comparator.
//
// `Comparator` is a class template (e.g. `std::greater`) that is instantiated
// with the element type `T` and used as the ordering predicate. The result has
// one entry per element of `v`: entry `k` is the original index of the element
// that lands at position `k` once `v` is sorted.
//
// Note: the sort is done with `std::sort`, which is not a stable sort, so the
// relative order of indices whose elements compare equivalent is unspecified.
template <template <typename U> class Comparator, typename T>
static SmallVector<int64_t> getSortingPermutation(ArrayRef<T> v) {
using P = std::pair<int64_t, T>;
// Pair each element with its original index so that the index survives the
// sort.
SmallVector<P> pairs;
pairs.reserve(v.size());
for (auto [i, x] : llvm::enumerate(v)) {
pairs.push_back({i, x});
}
// Order the pairs by element value only; the index is just carried along.
std::sort(pairs.begin(), pairs.end(),
[](P p1, P p2) { return Comparator<T>{}(p1.second, p2.second); });
// Drop the values, keeping the original indices in sorted order.
SmallVector<int64_t> indices;
for (auto p : pairs) {
indices.push_back(p.first);
}
return indices;
}

TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
IREE::GPU::MMAFragment fragment) {
auto layout = IREE::GPU::getSingleSubgroupLayout(intrinsic, fragment);
Expand All @@ -95,57 +87,49 @@ TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
std::swap(layout.element[0], layout.element[1]);
}

// Initially populate swizzle.expandShape with just the thread sizes, no
// shape expansion for now.
TileSwizzle swizzle;
for (auto t : layout.thread) {
TileSwizzle::Dim dim;
dim.size = t;
dim.kind = TileSwizzle::Dim::Kind::CrossThread; // Because `layout.thread`.
swizzle.expandShape.push_back({dim});
}
// The layout strides decide the initial swizzle.permutation.
// Some WMMA intrinsics have tstrides=0 value. That always indicates an outer
// dimension, so overwrite 0 with a large value to get the right order.
SmallVector<int64_t, 2> order = layout.tstrides;
for (auto &val : order) {
val = (val == 0) ? INT64_MAX : val;
}
swizzle.permutation = getSortingPermutation<std::greater, int64_t>(order);
// Deal with any element size greater than 1 by inserting it innermost.
// Notice that this is similar to the unroll() function, just creating an
// inner dimension instead of an outer dimension.
// There are two source dimensions, corresponding to the arrays in `layout`
// all having size 2. Let's just guard that assumption with one assert here.
assert(layout.thread.size() == 2);
swizzle.expandShape.resize(2);
// Expand the shape from inner-most to outer-most dimension, so that we can
// simply use the `expand` helper function, which creates new outer dims.
// `layout.element` dims are inner-most, so we add them first.
for (auto [i, e] : llvm::enumerate(layout.element)) {
if (e != 1) {
TileSwizzle::Dim dim;
dim.size = e;
dim.kind = TileSwizzle::Dim::Kind::Internal; // Because `layout.element`.
swizzle.expandShape[i].push_back(dim);
int newIndex = getExpandedDimFirstIdx(swizzle.expandShape, i + 1) - 1;
for (auto &p : swizzle.permutation) {
p += (p >= newIndex);
}
swizzle.permutation.push_back(newIndex);
expand(swizzle, i, {Kind::Internal, e});
}
}
// Deal with any outer size greater than 1 as just a call to unroll.
// Iterate over dims in reverse order because we are creating a new outermost
// dimension each time.
// Next come `layout.thread` dims.
for (auto [i, t] : llvm::enumerate(layout.thread)) {
if (t != 1) {
expand(swizzle, i, {Kind::CrossThread, t});
}
}
// `layout.thread` dims are special in that they come with `layout.tstrides`
// which may call for a swap in `swizzle.permutation`. We only need to worry
// about that when both `layout.thread` sizes are greater than 1, in which
// case neither was skipped above, so `swizzle.permutation[0]` and
// `swizzle.permutation[1]` are exactly the two thread dims. Note that this
// condition also implies that we don't need to worry about
// `layout.tstrides == 0` which only happens with `layout.thread == 1`.
if (layout.thread[0] != 1 && layout.thread[1] != 1) {
if (layout.tstrides[0] > layout.tstrides[1]) {
std::swap(swizzle.permutation[0], swizzle.permutation[1]);
}
}
// Finally come `layout.outer` dims, added last so they are outer-most.
for (auto [i, o] : llvm::enumerate(layout.outer)) {
if (o != 1) {
// `layout.outer` means additional Internal dimensions, just like
// `layout.element`, just swizzled outermost.
unroll(swizzle, i, o, TileSwizzle::Dim::Kind::Internal);
expand(swizzle, i, {Kind::Internal, o});
}
}

return swizzle;
}

static int getInnermostNonInternalDimIdx(
const TileSwizzle::ExpandShapeDimVectorType &shape) {
for (int idx = shape.size() - 1; idx >= 0; --idx) {
if (shape[idx].kind != TileSwizzle::Dim::Kind::Internal) {
if (shape[idx].kind != Kind::Internal) {
return idx;
}
}
Expand All @@ -156,55 +140,54 @@ static int getInnermostNonInternalDimIdx(
TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
IREE::GPU::MMAFragment fragment) {
auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment);
using Kind = TileSwizzle::Dim::Kind;
switch (fragment) {
case IREE::GPU::MMAFragment::Lhs:
// A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
// Unroll on K with interleaving, then on M.
if (mma.getUnrollK() > 1) {
unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollK()});
int interleavingIdx =
getInnermostNonInternalDimIdx(swizzle.expandShape[1]);
interleave(swizzle, 1, interleavingIdx);
}
if (mma.getUnrollM() > 1) {
unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollM()});
}
if (mma.getUnrollMToSubgroups() > 1) {
unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollMToSubgroups()});
}
break;
case IREE::GPU::MMAFragment::Rhs:
// B-matrix (RHS). Since the pack ops already took care of transposing B,
// source dimensions are N (index 0) and K (index 1).
// Unroll on K with interleaving, then on N.
if (mma.getUnrollK() > 1) {
unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollK()});
int interleavingIdx =
getInnermostNonInternalDimIdx(swizzle.expandShape[1]);
interleave(swizzle, 1, interleavingIdx);
}
if (mma.getUnrollN() > 1) {
unroll(swizzle, 0, mma.getUnrollN(), Kind::CrossIntrinsic);
expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollN()});
}
if (mma.getUnrollNToSubgroups() > 1) {
unroll(swizzle, 0, mma.getUnrollNToSubgroups(), Kind::CrossThread);
expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollNToSubgroups()});
}
break;
case IREE::GPU::MMAFragment::Acc:
// C-matrix (accumulator). Source dimensions are M (index 0) and N (index
// 1). Unroll on N, then on M.
if (mma.getUnrollN() > 1) {
unroll(swizzle, 1, mma.getUnrollN(), Kind::CrossIntrinsic);
expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollN()});
}
if (mma.getUnrollNToSubgroups() > 1) {
unroll(swizzle, 1, mma.getUnrollNToSubgroups(), Kind::CrossThread);
expand(swizzle, 1, {Kind::CrossThread, mma.getUnrollNToSubgroups()});
}
if (mma.getUnrollM() > 1) {
unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollM()});
}
if (mma.getUnrollMToSubgroups() > 1) {
unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollMToSubgroups()});
}
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,40 +19,10 @@ TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
IREE::GPU::MMAFragment fragment);

// Returns the swizzle for the full data-tiled-mma tile, including all the
// relevant unrolling factors.
// relevant unrolling and expansion factors.
TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
IREE::GPU::MMAFragment fragment);

// Unrolls the dimension given by `srcIndex` by the given `unrollFactor`.
// This is not interleaving layouts. The layout will consist of multiple copies
// of the input tile, side by side.
//
// The enum parameter `kind` initializes the corresponding member on the newly
// created TileSwizzle::Dim.
//
// Example:
// Input swizzle = { expandShape = [[16], [4]], permutation = [1, 0] }
// Input srcIndex = 1
// Input unrollFactor = 4
// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
//
void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
TileSwizzle::Dim::Kind kind);

// Interleaves the layout in `swizzle` by mutating `swizzle.permutation` to
// move permutation[0], the outer-most dimension (which the unroll() function
// created to be the unrolling dimension), to the inner dimension given by
// `expandedDimIndexToInterleaveAt`.
//
// Example:
// Input swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
// Input srcIndex = 1
// Input expandedDimIndexToInterleaveAt = 1
// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [2, 0, 1] }
//
void interleave(TileSwizzle &swizzle, int srcIndex,
int expandedDimIndexToInterleaveAt);

} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_GPUTILESWIZZLEUTILS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ builtin.module attributes { transform.with_named_sequence } {
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.0 : f16
%lhs = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 1, 16]>>}}
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, VECTORX], [1, 16]>>}}
%rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 1, 16]>>}}
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, VECTORX], [1, 16]>>}}
%output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %init : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, VECTORX, LANEY], [1, 8, 2]>, <[ BATCHY, LANEX], [1, 16]>>}}
return %output : vector<16x16xf32>
Expand Down

0 comments on commit 7c3f370

Please sign in to comment.