Add subgroup size to distribute patterns, keep hacking on tests
krzysz00 committed Nov 12, 2024
1 parent dbbf653 commit c67c36d
Showing 9 changed files with 64 additions and 65 deletions.
@@ -32,54 +32,55 @@ namespace {
/// parameterized by the thread grid.
static SmallVector<Value> computeSIMDIndex(const LayoutIterator::State &state,
LayoutAttr layout, Value laneId,
int64_t subgroupSize,
RewriterBase &rewriter) {
MLIRContext *ctx = layout.getContext();
AffineExpr threadX, threadY, threadZ;
bindSymbols(ctx, threadX, threadY, threadZ);
Location loc = laneId.getLoc();

auto [laneDimX, laneDimY, laneDimZ] = layout.getLaneGrid();
int64_t gridsPerSubgroup =
llvm::divideCeil(subgroupSize, laneDimX * laneDimY * laneDimZ);
// Note: we add an extra leading entry to the delinearization here to handle
// the case where the vector layout requires fewer lanes than are present in
// the subgroup. Otherwise, we'd, for example, construct delinearizations with
// the basis (1, 1, 16) when there are 32 lanes, which would simplify to no
// delinearization at all. The extra term in the grid captures the overflow.
auto reversedLaneGrid = rewriter.create<affine::AffineDelinearizeIndexOp>(
loc, laneId,
ArrayRef<int64_t>{gridsPerSubgroup, laneDimZ, laneDimY, laneDimX});

SmallVector<Value> simdIndex;

// Calculate the index for each dim separately.
for (PerDimLayoutAttr dimLayout : layout.getLayouts()) {
AffineExpr offset = getAffineConstantExpr(0, ctx);
AffineExpr stride = getAffineConstantExpr(1, ctx);
for (auto [label, shape] : llvm::reverse(
llvm::zip(dimLayout.getLabels(), dimLayout.getShapes()))) {
SmallVector<Value> linearizeVals;
for (LayoutDimensionAttr label : dimLayout.getLabels()) {
int64_t position = state.lookup(label.getValue()).getPosition();

// Note: these positions index into a reversed lane grid that has an extra
// leading term we must ignore (so the X coordinate is result #3 and the Z
// coordinate is result #1).
switch (label.getValue()) {
case LayoutDimension::LANEX:
offset = offset + stride * threadX;
linearizeVals.push_back(reversedLaneGrid.getResult(3));
break;
case LayoutDimension::LANEY:
offset = offset + stride * threadY;
linearizeVals.push_back(reversedLaneGrid.getResult(2));
break;
case LayoutDimension::LANEZ:
offset = offset + stride * threadZ;
linearizeVals.push_back(reversedLaneGrid.getResult(1));
break;
default:
offset = offset + stride * getAffineConstantExpr(position, ctx);
linearizeVals.push_back(
rewriter.createOrFold<arith::ConstantIndexOp>(loc, position));
break;
}
stride = stride * getAffineConstantExpr(shape, ctx);
}

auto [laneDimX, laneDimY, laneDimZ] = layout.getLaneGrid();
SmallVector<Value> laneGrid = {
rewriter.create<arith::ConstantIndexOp>(laneId.getLoc(), laneDimZ),
rewriter.create<arith::ConstantIndexOp>(laneId.getLoc(), laneDimY),
rewriter.create<arith::ConstantIndexOp>(laneId.getLoc(), laneDimX)};
FailureOr<SmallVector<Value>> maybeReversedLaneGridVals =
affine::delinearizeIndex(rewriter, laneId.getLoc(), laneId, laneGrid);
assert(succeeded(maybeReversedLaneGridVals) &&
"Failed to delinearize lane index");
SmallVector<Value> laneGridVals = {(*maybeReversedLaneGridVals)[2],
(*maybeReversedLaneGridVals)[1],
(*maybeReversedLaneGridVals)[0]};

// Compute the index for the dim.
AffineMap indexMap = AffineMap::get(0, 3, offset);
Value index = rewriter.create<affine::AffineApplyOp>(
rewriter.getUnknownLoc(), indexMap, laneGridVals);
Value index = rewriter.create<affine::AffineLinearizeIndexOp>(
rewriter.getUnknownLoc(), linearizeVals, dimLayout.getShapes(),
/*disjoint=*/true);
simdIndex.push_back(index);
}
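
For intuition, here is a minimal standalone sketch (plain C++, illustrative only, not part of this commit) of the delinearization described in the comment above. Assuming subgroupSize = 32 and a lane grid of 1 x 1 x 16, gridsPerSubgroup works out to 2, so the basis becomes (2, 1, 1, 16) rather than (1, 1, 16), and lanes 16..31 land in the leading overflow term instead of the delinearization simplifying away.

#include <array>
#include <cstdint>
#include <cstdio>

// Illustrative only: mimic what the reversed-grid delinearization computes.
// The leading term keeps the remaining quotient (the overflow), mirroring the
// extra basis entry added above; the rest are the (Z, Y, X) lane coordinates.
static std::array<int64_t, 4> delinearizeLane(int64_t laneId,
                                              std::array<int64_t, 4> basis) {
  std::array<int64_t, 4> result{};
  for (int i = 3; i >= 1; --i) {
    result[i] = laneId % basis[i];
    laneId /= basis[i];
  }
  result[0] = laneId;
  return result;
}

int main() {
  const int64_t subgroupSize = 32, laneX = 16, laneY = 1, laneZ = 1;
  const int64_t gridsPerSubgroup =
      (subgroupSize + laneX * laneY * laneZ - 1) / (laneX * laneY * laneZ);
  for (int64_t lane : {0, 5, 16, 21}) {
    auto [overflow, z, y, x] =
        delinearizeLane(lane, {gridsPerSubgroup, laneZ, laneY, laneX});
    // e.g. lane 21 -> overflow = 1, z = 0, y = 0, x = 5: the X coordinate
    // stays in [0, 16) even though the subgroup has 32 lanes.
    std::printf("lane %2lld -> (%lld, %lld, %lld, %lld)\n", (long long)lane,
                (long long)overflow, (long long)z, (long long)y,
                (long long)x);
  }
  return 0;
}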

@@ -199,8 +200,9 @@ struct DistributeXferLayoutAttr : OpDistributionPattern<OpTy> {
"expected vector::TransferReadOp or vector::TransferWriteOp");

DistributeXferLayoutAttr(MLIRContext *context, Value laneId,
PatternBenefit benefit = 1)
: OpDistributionPattern<OpTy>(context, benefit), laneId(laneId) {}
int64_t subgroupSize, PatternBenefit benefit = 1)
: OpDistributionPattern<OpTy>(context, benefit), laneId(laneId),
subgroupSize(subgroupSize) {}

VectorValue accessMemory(OpTy xferOp, VectorValue accumulator,
LayoutAttr vectorLayout,
@@ -237,7 +239,7 @@ struct DistributeXferLayoutAttr : OpDistributionPattern<OpTy> {
llvm::SmallBitVector &projectedDims,
RewriterBase &rewriter) const {
SmallVector<Value> simdIndices =
computeSIMDIndex(state, memoryLayout, laneId, rewriter);
computeSIMDIndex(state, memoryLayout, laneId, subgroupSize, rewriter);
SmallVector<Value> memoryIndices(indices);

// The memory layout has some projected leading dims that indices doesn't.
@@ -272,6 +274,7 @@ struct DistributeXferLayoutAttr : OpDistributionPattern<OpTy> {
}

Value laneId;
int64_t subgroupSize;
};

struct DistributeTransferReadLayoutAttr final
@@ -1118,10 +1121,11 @@ void populateGPUDistributionPatterns(RewritePatternSet &patterns) {
}

void populateGPUDistributionLayoutAttrPatterns(Value laneId,
int64_t subgroupSize,
RewritePatternSet &patterns) {
patterns
.add<DistributeTransferReadLayoutAttr, DistributeTransferWriteLayoutAttr>(
patterns.getContext(), laneId);
patterns.getContext(), laneId, subgroupSize);
patterns.add<DistributeBroadcastLayoutAttr, DistributeTransposeLayoutAttr>(
patterns.getContext());
}
@@ -32,6 +32,7 @@ void populateDropSharedMemoryDeallocOpPatterns(RewritePatternSet &patterns);
void populateGPUDistributionPatterns(RewritePatternSet &patterns);

void populateGPUDistributionLayoutAttrPatterns(Value laneId,
int64_t subgroupSize,
RewritePatternSet &patterns);

void populateGPUReductionDistributionPatterns(RewritePatternSet &patterns,
@@ -197,9 +197,6 @@ builtin.module attributes { transform.with_named_sequence } {
#layout_row_major = #iree_vector_ext.layout<<[BATCHX, LANEY], [2, 8]>, <[BATCHY, LANEX, VECTORX], [2, 1, 8]>>
#layout_col_major = #iree_vector_ext.layout<<[BATCHX, LANEY, VECTORX], [1, 4, 4]>, <[BATCHY, LANEX], [2, 8]>>

// TODO: Use affine min tricks based on the grid size to elide the mod.
// Note that this IR is invalid if subgroup size != 8.

func.func @distribute_transfer_write_row_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) {
%c0 = arith.constant 0 : index
%rootl = iree_vector_ext.to_layout %root to layout(#layout_row_major) : vector<16x16xf16>
@@ -208,24 +205,23 @@
: vector<16x16xf16>, memref<64x64xf16>
func.return
}
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 mod 8)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 mod 8 + 8)>

// CHECK-LABEL: @distribute_transfer_write_row_major
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
// CHECK-DAG: %[[LANEID:.+]] = gpu.thread_id x
// CHECK: %[[VEC_LANE_Y:.+]] = affine.apply #[[$MAP0]]()[%[[LANEID]]]
// CHECK: %[[SPLIT_ID:.+]]:2 = affine.delinearize_index %[[LANEID]] into (8, 8)
// CHECK: %[[DIST_SRC_VEC:.+]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<2x2x8xf16>
// CHECK: %[[BATCH_0_0:.+]] = vector.extract %[[DIST_SRC_VEC]][0, 0] : vector<8xf16> from vector<2x2x8xf16>
// CHECK: vector.store %[[BATCH_0_0]], %{{.*}}[%[[VEC_LANE_Y]], %[[C0]]] : memref<64x64xf16>, vector<8xf16>
// CHECK: vector.store %[[BATCH_0_0]], %{{.*}}[%[[SPLIT_ID]]#1, %[[C0]]] : memref<64x64xf16>, vector<8xf16>

// CHECK: %[[NEXT_VEC_LANE_Y:.+]] = affine.apply #[[$MAP1]]()[%[[LANEID]]]
// CHECK: %[[NEXT_VEC_LANE_Y:.+]] = affine.linearize_index disjoint [%[[C1]], %[[SPLIT_ID]]#1] by (2, 8) : index
// CHECK: %[[BATCH_1_0:.+]] = vector.extract %[[DIST_SRC_VEC]][1, 0] : vector<8xf16> from vector<2x2x8xf16>
// CHECK: vector.store %[[BATCH_1_0]], %{{.*}}[%[[NEXT_VEC_LANE_Y]], %[[C0]]] : memref<64x64xf16>, vector<8xf16>

// CHECK: %[[BATCH_0_1:.+]] = vector.extract %[[DIST_SRC_VEC]][0, 1] : vector<8xf16> from vector<2x2x8xf16>
// CHECK: vector.store %[[BATCH_0_1]], %{{.*}}[%[[VEC_LANE_Y]], %[[C8]]] : memref<64x64xf16>, vector<8xf16>
// CHECK: vector.store %[[BATCH_0_1]], %{{.*}}[%[[SPLIT_ID]]#1, %[[C8]]] : memref<64x64xf16>, vector<8xf16>

// CHECK: %[[BATCH_1_1:.+]] = vector.extract %[[DIST_SRC_VEC]][1, 1] : vector<8xf16> from vector<2x2x8xf16>
// CHECK: vector.store %[[BATCH_1_1]], %{{.*}}[%[[NEXT_VEC_LANE_Y]], %[[C8]]] : memref<64x64xf16>, vector<8xf16>
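
As a hedged sanity check (plain C++, illustrative only, not FileCheck output), the new delinearize/linearize checks above agree with the affine maps the old checks used (s0 mod 8 and s0 mod 8 + 8):

#include <cassert>
#include <cstdint>

int main() {
  for (int64_t lane = 0; lane < 64; ++lane) {
    // affine.delinearize_index %lane into (8, 8): result #1 is the low digit.
    int64_t splitId1 = lane % 8;
    // affine.linearize_index disjoint [1, %splitId1] by (2, 8) is
    // 1 * 8 + splitId1, which matches the old "s0 mod 8 + 8" map.
    int64_t nextVecLaneY = 1 * 8 + splitId1;
    assert(nextVecLaneY == lane % 8 + 8);
  }
  return 0;
}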
@@ -560,8 +556,6 @@ builtin.module attributes { transform.with_named_sequence } {
#layoutB2 = #iree_vector_ext.layout<<[ BATCHX, LANEY, VECTORX], [1, 1, 16]>, <[ BATCHY, LANEX], [1, 16]>>
#layoutC2 = #iree_vector_ext.layout<<[ BATCHX, VECTORY, LANEY, VECTORX], [1, 8, 2, 1]>, <[ BATCHY, LANEX], [1, 16]>>

// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + (s0 floordiv 32) * 16)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 mod 16)>
// CHECK-LABEL: func.func @resolve_wmma_layout_conflict_with_shared_memory
func.func @resolve_wmma_layout_conflict_with_shared_memory(%15 : vector<16x16xf16>,
%14 : vector<16x16xf16>,
@@ -607,19 +601,18 @@ func.func @resolve_wmma_layout_conflict_with_shared_memory(%15 : vector<16x16xf1
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[VEC_INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x16xf16

// CHECK: %[[VEC_INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x16xf16
// CHECK: %[[TID_X:.+]] = gpu.thread_id x
// CHECK: %[[TID_Y:.+]] = gpu.thread_id y
// CHECK: %[[TID_Z:.+]] = gpu.thread_id z
// CHECK: %[[SUBGROUP_OFFSET:.+]] = affine.apply #[[$MAP0]]()[%[[TID_X]], %[[TID_Y]], %[[TID_Z]]]
// CHECK: %[[TIDX:.+]] = gpu.thread_id x
// CHECK: %[[TIDY:.+]] = gpu.thread_id y
// CHECK: %[[SUBGROUP_OFFSET:.+]] = affine.linearize_index disjoint [%[[TIDY]], %[[C0]]] by (2, 16)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][0, %[[SUBGROUP_OFFSET]]] [16, 16] [1, 1]
// CHECK: %[[HALF_LANE_ID:.+]] = affine.apply #[[$MAP1]]()[%[[TID_X]]]
// CHECK-COUNT-8: vector.store %{{.+}}, %[[SUBVIEW]][%{{.+}}, %[[HALF_LANE_ID]]]
// CHECK: %[[SPLIT_LANE_ID:.+]]:2 = affine.delinearize_index %[[TIDX]] into (2, 16)
// CHECK-COUNT-8: vector.store %{{.+}}, %[[SUBVIEW]][%{{.+}}, %[[SPLIT_LANE_ID]]#1]
// CHECK-AFTER: gpu.barrier

// CHECK: %[[LANE_OFFSET:.+]] = arith.addi %[[SUBGROUP_OFFSET]], %[[HALF_LANE_ID]]
// CHECK: %[[LANE_OFFSET:.+]] = arith.addi %[[SUBGROUP_OFFSET]], %[[SPLIT_LANE_ID]]#1
// CHECK: %[[LOAD0:.+]] = vector.load %[[ALLOC]][%[[C0]], %[[LANE_OFFSET]]]
// CHECK: %[[INSERT0:.+]] = vector.insert_strided_slice %[[LOAD0]], %[[VEC_INIT]] {offsets = [0, 0, 0], strides = [1]} : vector<1xf16> into vector<1x1x16xf16>
// CHECK: %[[LOAD1:.+]] = vector.load %[[ALLOC]][%[[C1]], %[[LANE_OFFSET]]]
@@ -636,7 +629,7 @@ func.func @resolve_wmma_layout_conflict_with_shared_memory(%15 : vector<16x16xf1
builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func {experimental = true} : !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func {experimental = true, subgroup_size = 32 : i64} : !transform.any_op
transform.yield
}
}
@@ -1113,18 +1113,17 @@ transform_dialect::TestGpuVectorDistribution::applyToOne(
rewriter.setInsertionPointToStart(&target.getFunctionBody().front());
// This is a test op so we unsafely use thread_id x as the lane ID. In
// general this should linearize the thread IDs based on the workgroup size
// and divide by the subgroup size. i.e.
// and take the modulo by the subgroup size. i.e.
//
// lane_id = (tid_x + tid_y * dim_x + tid_z * dim_y * dim_x) / subgroup_size;
// lane_id = (tid_x + tid_y * dim_x + tid_z * dim_y * dim_x) % subgroup_size;
Value laneId =
rewriter.create<gpu::ThreadIdOp>(target.getLoc(), gpu::Dimension::x);
int64_t subgroupSize = getSubgroupSize();

populateGPUDistributionPatterns(patterns);
populateGPUDistributionLayoutAttrPatterns(laneId, patterns);
populateGPUDistributionLayoutAttrPatterns(laneId, subgroupSize, patterns);
populateGPUReductionDistributionPatterns(patterns);
// For testing we use subgroup size = 64.
populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId,
/*subgroupSize=*/64);
populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, subgroupSize);
populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns);
if (getExperimental())
populateGPULayoutResolutionDistributionPatterns(patterns);
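
A minimal sketch (plain C++, hypothetical helper, not part of this commit) of the lane-ID formula the comment in the hunk above describes; the test op itself still just uses thread_id x:

#include <cstdint>

// Linearize the 3-D thread ID over the workgroup shape, then take the lane
// within the subgroup. The dimX/dimY arguments are hypothetical workgroup
// sizes; nothing here is taken from the pass.
int64_t laneIdFromThreadId(int64_t tidX, int64_t tidY, int64_t tidZ,
                           int64_t dimX, int64_t dimY, int64_t subgroupSize) {
  int64_t linearTid = tidX + tidY * dimX + tidZ * dimY * dimX;
  return linearTid % subgroupSize;
}
// Example: with dimX = 64, dimY = 2, subgroupSize = 64, thread (12, 1, 0) has
// linear ID 76 and is lane 12 of its subgroup.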
@@ -631,7 +631,8 @@ def TestGpuVectorDistribution :
}];

let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedOptionalAttr<BoolAttr, "false">:$experimental);
DefaultValuedOptionalAttr<BoolAttr, "false">:$experimental,
DefaultValuedOptionalAttr<I64Attr, "64">:$subgroup_size);
let results = (outs);

let assemblyFormat = [{ $target attr-dict `:` type($target)}];
@@ -34,7 +34,7 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions {
int64_t subgroupSize)
: VectorLayoutOptions(root), patterns(root->getContext()) {
populateGPUDistributionPatterns(patterns);
populateGPUDistributionLayoutAttrPatterns(laneId, patterns);
populateGPUDistributionLayoutAttrPatterns(laneId, subgroupSize, patterns);
populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId,
subgroupSize);
populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns);
@@ -1473,13 +1473,13 @@ transform_dialect::AMDGPUDistributeVectorsOp::applyToOne(
rewriter.setInsertionPointToStart(&target.getFunctionBody().front());
Value laneId =
rewriter.create<gpu::ThreadIdOp>(target.getLoc(), gpu::Dimension::x);
int64_t subgroupSize = getSubgroupSize();

populateGPUDistributionPatterns(patterns);
populateGPUDistributionLayoutAttrPatterns(laneId, patterns);
populateGPUDistributionLayoutAttrPatterns(laneId, subgroupSize, patterns);
populateGPUReductionDistributionPatterns(patterns);
// For testing we use subgroup size = 64.
populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId,
/*subgroupSize=*/64);
populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, subgroupSize);
populateAMDGPUDistributionPatterns(patterns);
populateGPULayoutResolutionDistributionPatterns(patterns);
if (failed(distributeVectorOps(target, patterns, options))) {
@@ -699,7 +699,8 @@ def AMDGPUDistributeVectorsOp :
}];

let arguments = (ins TransformHandleTypeInterface:$target,
UnitAttr:$test_conversion);
UnitAttr:$test_conversion,
DefaultValuedOptionalAttr<I64Attr, "64">:$subgroup_size);
let results = (outs TransformHandleTypeInterface:$result);

let assemblyFormat = [{
