[LLVMGPUVectorDistribute] Fix batch dimensions extraction for attention-like ops (#19040)

Currently, the batch dimensions are extracted as the intersection of the
dimensions present across Q, K, and V. This is not correct when a dimension is
broadcast in one of the inputs (Q, K, or V): a broadcast dimension is absent
from that operand's indexing map, so it drops out of the intersection even
though it is still a batch dimension of the op.

Therefore, this commit changes the extraction to
B = (Q & K & O) U (K & V & O)
i.e. the parallel dimensions common to each of the two matmuls and the output
are treated as batch dimensions.
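To make the fix concrete, consider the split-k maps from the new test added
below: Q = (b1, m, k1), K = (b1, b2, k2, k1), V = (b1, b2, k2, n),
O = (b1, b2, m, n). The following is a minimal standalone sketch of the two
derivations, not the production code: std::set stands in for
llvm::SmallDenseSet, and the dimension numbering (b1=0, b2=1, m=2, n=3, k1=4,
k2=5) is assumed for illustration.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <set>

using Dims = std::set<int64_t>;

// Intersection of two sorted dimension sets.
static Dims intersect(const Dims &a, const Dims &b) {
  Dims r;
  std::set_intersection(a.begin(), a.end(), b.begin(), b.end(),
                        std::inserter(r, r.begin()));
  return r;
}

int main() {
  Dims q{0, 2, 4};    // Q = (b1, m, k1)
  Dims k{0, 1, 5, 4}; // K = (b1, b2, k2, k1)
  Dims v{0, 1, 5, 3}; // V = (b1, b2, k2, n)
  Dims o{0, 1, 2, 3}; // O = (b1, b2, m, n)

  // Old rule: B = Q & K & V. The split-k dim b2 is broadcast in Q, so it
  // drops out of the intersection even though it is a batch dimension.
  Dims oldB = intersect(intersect(q, k), v);
  assert(oldB == (Dims{0})); // {b1} only

  // New rule: B = (Q & K & O) U (K & V & O).
  Dims qk = intersect(intersect(q, k), o); // {b1}
  Dims kv = intersect(intersect(k, v), o); // {b1, b2}
  Dims newB;
  std::set_union(qk.begin(), qk.end(), kv.begin(), kv.end(),
                 std::inserter(newB, newB.begin()));
  assert(newB == (Dims{0, 1})); // b2 is correctly classified as batch
  return 0;
}

Under the old rule b2 is still present in K and V, so it would fall through
into a non-batch dimension set, which is what produced the incorrect
extraction described above.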

Signed-off-by: Manupa Karunaratne <[email protected]>
manupak authored Nov 7, 2024
1 parent 766db96 commit d90aaae
Showing 9 changed files with 121 additions and 29 deletions.
@@ -1751,7 +1751,8 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
IREE::LinalgExt::AttentionOp attnOp) {
FailureOr<IREE::LinalgExt::AttentionOpDetail> maybeOpInfo =
IREE::LinalgExt::AttentionOpDetail::get(
-attnOp.getQueryMap(), attnOp.getKeyMap(), attnOp.getValueMap());
+attnOp.getQueryMap(), attnOp.getKeyMap(), attnOp.getValueMap(),
+attnOp.getOutputMap());
assert(succeeded(maybeOpInfo) && "failed to infer attention dims");
auto opInfo = maybeOpInfo.value();

7 changes: 4 additions & 3 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -733,9 +733,10 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,

ArrayRef<int64_t> bounds = maybeBounds.value();

-auto opInfo = IREE::LinalgExt::AttentionOpDetail::get(
-op.getQueryMap(), op.getKeyMap(), op.getValueMap())
-.value();
+auto opInfo =
+IREE::LinalgExt::AttentionOpDetail::get(
+op.getQueryMap(), op.getKeyMap(), op.getValueMap(), op.getOutputMap())
+.value();

int64_t mDim = opInfo.getMDims().back();
int64_t k1Dim = opInfo.getK1Dims().back();
@@ -994,3 +994,82 @@ hal.executable private @attention_mfma_32x32x8 {
// MEMORY-LABEL: func.func @attention_mfma_32x32x8()
// MEMORY-COUNT-3: memref.alloc
// MEMORY-NOT: memref.alloc

// -----

!Q = tensor<1x16x64xf16>
!K_SK = tensor<1x4x256x64xf16>
!V_SK = tensor<1x4x256x64xf16>
!O_SK = tensor<1x4x16x64xf32>
!ROWRED_SK = tensor<1x4x16xf32>

#config = #iree_gpu.lowering_config<{ workgroup = [1, 1, 16, 0, 0, 0], reduction = [0, 0, 0, 0, 0, 32], promote_operands = [0, 1, 2] }>
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64] subgroup_size = 64>

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
hal.executable private @online_attention_split_k2 {
hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export public @online_attention_split_k2 ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @online_attention_split_k2() attributes {translation_info = #translation} {
%cst = arith.constant 1.0 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:!Q>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:!K_SK>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:!V_SK>
%out_arg = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:!O_SK>
%max_arg = hal.interface.binding.subspan layout(#pipeline_layout) binding(4) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:!ROWRED_SK>
%sum_arg = hal.interface.binding.subspan layout(#pipeline_layout) binding(5) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:!ROWRED_SK>
%4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [1, 16, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:!Q> -> !Q
%5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [1, 4, 256, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:!K_SK> -> !K_SK
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0], sizes = [1, 4, 256, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:!V_SK> -> !V_SK
%empty_o = tensor.empty() : !O_SK
%empty_rowmax = tensor.empty() : !ROWRED_SK
%empty_rowsum = tensor.empty() : !ROWRED_SK
%out:3 = iree_linalg_ext.online_attention {indexing_maps = [affine_map<(b1, b2, m, n, k1, k2) -> (b1, m, k1)>,
affine_map<(b1, b2, m, n, k1, k2) -> (b1, b2, k2, k1)>,
affine_map<(b1, b2, m, n, k1, k2) -> (b1, b2, k2, n)>,
affine_map<(b1, b2, m, n, k1, k2) -> ()>,
affine_map<(b1, b2, m, n, k1, k2) -> (b1, b2, m, n)>,
affine_map<(b1, b2, m, n, k1, k2) -> (b1, b2, m)>,
affine_map<(b1, b2, m, n, k1, k2) -> (b1, b2, m)>],
lowering_config = #config,
decomposition_config = {
qk_attrs = {attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 1, promote_operands = [0, 1]}>},
pv_attrs = {attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 1, promote_operands = [1]}>}
}}
ins(%4, %5, %6, %cst : !Q, !K_SK, !V_SK, f16) outs(%empty_o, %empty_rowmax, %empty_rowsum: !O_SK, !ROWRED_SK, !ROWRED_SK) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> !O_SK, !ROWRED_SK, !ROWRED_SK
flow.dispatch.tensor.store %out#0, %out_arg, offsets = [0, 0, 0, 0], sizes = [1, 4, 16, 64], strides = [1, 1, 1, 1] : !O_SK -> !flow.dispatch.tensor<writeonly:!O_SK>
flow.dispatch.tensor.store %out#1, %max_arg, offsets = [0, 0, 0], sizes = [1, 4, 16], strides = [1, 1, 1] : !ROWRED_SK -> !flow.dispatch.tensor<writeonly:!ROWRED_SK>
flow.dispatch.tensor.store %out#2, %sum_arg, offsets = [0, 0, 0], sizes = [1, 4, 16], strides = [1, 1, 1] : !ROWRED_SK -> !flow.dispatch.tensor<writeonly:!ROWRED_SK>
return
}
}
}
}

// CHECK-LABEL: func.func @online_attention_split_k2()
// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c32
// CHECK-SAME: -> (vector<1x1x1xf32>, vector<1x1x1xf32>, vector<1x4x1x1x1x4xf32>)
// CHECK-COUNT-16: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: scf.yield

// Check that we only use alloc for Q, K, and V. No shared memory for S is
// needed because the intrinsic layout matches.
// MEMORY-LABEL: func.func @online_attention_split_k2()
// MEMORY-COUNT-3: memref.alloc
// MEMORY-NOT: memref.alloc
@@ -409,8 +409,8 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
}
Value output = getOutput();

-FailureOr<AttentionOpDetail> maybeOpInfo =
-AttentionOpDetail::get(getQueryMap(), getKeyMap(), getValueMap());
+FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
+getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
AttentionOpDetail opInfo = maybeOpInfo.value();

@@ -527,8 +527,8 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
}

-FailureOr<AttentionOpDetail> maybeOpInfo =
-AttentionOpDetail::get(getQueryMap(), getKeyMap(), getValueMap());
+FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
+getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
AttentionOpDetail opInfo = maybeOpInfo.value();

@@ -1224,8 +1224,8 @@ LogicalResult AttentionOp::verify() {
if (indexingMaps.size() != getOperation()->getNumOperands()) {
return attnOp->emitOpError("expected an indexing map for each operand");
}
-FailureOr<AttentionOpDetail> maybeOpInfo =
-AttentionOpDetail::get(getQueryMap(), getKeyMap(), getValueMap());
+FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
+getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
if (failed(maybeOpInfo)) {
return attnOp->emitOpError("failed to verify op's indexing maps");
}
@@ -1397,8 +1397,8 @@ LogicalResult OnlineAttentionOp::verify() {
SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();

// Check if indexing maps can represent attention.
-FailureOr<AttentionOpDetail> maybeOpInfo =
-AttentionOpDetail::get(getQueryMap(), getKeyMap(), getValueMap());
+FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
+getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());

// Check shape compatibility based on indexing maps.
SmallVector<int64_t> shape(getIterationDomainRank());
@@ -1787,9 +1787,9 @@ getAttentionIterationDomain(Location loc, OpBuilder &b, int64_t domainRank,

static SmallVector<utils::IteratorType>
getAttentionIteratorTypes(int64_t domainRank, AffineMap qMap, AffineMap kMap,
-AffineMap vMap) {
+AffineMap vMap, AffineMap oMap) {
FailureOr<AttentionOpDetail> maybeOpInfo =
-AttentionOpDetail::get(qMap, kMap, vMap);
+AttentionOpDetail::get(qMap, kMap, vMap, oMap);
assert(succeeded(maybeOpInfo) && "Failed to infer attention op details");
AttentionOpDetail opInfo = maybeOpInfo.value();

@@ -1838,7 +1838,7 @@ SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &b) {

SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
return getAttentionIteratorTypes(getIterationDomainRank(), getQueryMap(),
-getKeyMap(), getValueMap());
+getKeyMap(), getValueMap(), getOutputMap());
}

FailureOr<TilingResult>
@@ -1991,7 +1991,7 @@ SmallVector<Range> OnlineAttentionOp::getIterationDomain(OpBuilder &b) {

SmallVector<utils::IteratorType> OnlineAttentionOp::getLoopIteratorTypes() {
return getAttentionIteratorTypes(getIterationDomainRank(), getQueryMap(),
-getKeyMap(), getValueMap());
+getKeyMap(), getValueMap(), getOutputMap());
}

FailureOr<TilingResult>
@@ -42,8 +42,9 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
Location loc = attnOp.getLoc();
MLIRContext *ctx = attnOp.getContext();

-FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
-attnOp.getQueryMap(), attnOp.getKeyMap(), attnOp.getValueMap());
+FailureOr<AttentionOpDetail> maybeOpInfo =
+AttentionOpDetail::get(attnOp.getQueryMap(), attnOp.getKeyMap(),
+attnOp.getValueMap(), attnOp.getOutputMap());
assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
AttentionOpDetail opInfo = maybeOpInfo.value();

@@ -27,18 +27,25 @@ findPermutationsIndexingOperand(AffineMap indexingMap) {
}; // namespace

void AttentionOpDetail::inferFromIndexingMaps(AffineMap qMap, AffineMap kMap,
-AffineMap vMap) {
+AffineMap vMap, AffineMap oMap) {
// Q = B x M x K1
// K = B x K2 x K1
// V = B x K2 x N
+// O = B x M x N
llvm::SmallDenseSet<int64_t> qSet = findPermutationsIndexingOperand(qMap);
llvm::SmallDenseSet<int64_t> kSet = findPermutationsIndexingOperand(kMap);
llvm::SmallDenseSet<int64_t> vSet = findPermutationsIndexingOperand(vMap);

-// B = Q & K & V
-llvm::SmallDenseSet<int64_t> bSet = qSet;
-llvm::set_intersect(bSet, vSet);
-llvm::set_intersect(bSet, kSet);
+llvm::SmallDenseSet<int64_t> oSet = findPermutationsIndexingOperand(oMap);
+
+// B = (Q & K & O) U (K & V & O)
+llvm::SmallDenseSet<int64_t> b1Set = qSet;
+llvm::set_intersect(b1Set, kSet);
+llvm::set_intersect(b1Set, oSet);
+llvm::SmallDenseSet<int64_t> b2Set = kSet;
+llvm::set_intersect(b2Set, vSet);
+llvm::set_intersect(b2Set, oSet);
+llvm::SmallDenseSet<int64_t> bSet = b1Set;
+llvm::set_union(bSet, b2Set);

// K1 = Q & K - B
llvm::SmallDenseSet<int64_t> k1Set = qSet;
@@ -74,10 +81,12 @@ void AttentionOpDetail::inferFromIndexingMaps(AffineMap qMap, AffineMap kMap,
llvm::sort(n);
}

-FailureOr<AttentionOpDetail>
-AttentionOpDetail::get(AffineMap qMap, AffineMap kMap, AffineMap vMap) {
+FailureOr<AttentionOpDetail> AttentionOpDetail::get(AffineMap qMap,
+AffineMap kMap,
+AffineMap vMap,
+AffineMap oMap) {
AttentionOpDetail opInfo;
-opInfo.inferFromIndexingMaps(qMap, kMap, vMap);
+opInfo.inferFromIndexingMaps(qMap, kMap, vMap, oMap);
opInfo.context = qMap.getContext();
opInfo.domainRank = qMap.getNumDims();
return opInfo;
@@ -41,7 +41,7 @@ namespace mlir::iree_compiler::IREE::LinalgExt {
class AttentionOpDetail {
public:
static FailureOr<AttentionOpDetail> get(AffineMap qMap, AffineMap kMap,
-AffineMap vMap);
+AffineMap vMap, AffineMap oMap);

int64_t getDomainRank() const { return domainRank; }
ArrayRef<int64_t> getBatchDims() const { return batch; }
@@ -52,7 +52,8 @@ class AttentionOpDetail {
AffineMap getSMap() const;

private:
-void inferFromIndexingMaps(AffineMap qMap, AffineMap kMap, AffineMap vMap);
+void inferFromIndexingMaps(AffineMap qMap, AffineMap kMap, AffineMap vMap,
+AffineMap oMap);
MLIRContext *getContext() const { return context; }
SmallVector<int64_t> batch;
SmallVector<int64_t> m;
