[NFC] Refactor AttentionOpDetail struct (#19042)
Currently, the AttentionOpDetail struct accepts *all* the
indexing maps of an attention-like operator and relies on
presumed operand indices to pick out the maps it needs.

Going forward, this is fragile: if we also want to obtain
oMap, it has no static index, unlike the Q, K and V inputs.

Therefore, this commit refactors AttentionOpDetail so that
callers provide exactly the maps it requires.

Signed-off-by: Manupa Karunaratne <[email protected]>
manupak authored Nov 7, 2024
1 parent 8f925d4 commit 766db96
Showing 8 changed files with 31 additions and 38 deletions.
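For orientation, the sketch below contrasts the old and new calling convention described in the commit message. It is not part of the diff; `op` stands for an IREE::LinalgExt::AttentionOp the caller already has, and both calls are lifted from the updated call sites further down.

// Before: hand over every indexing map and let AttentionOpDetail pull
// Q, K and V out of presumed slots 0, 1 and 2.
FailureOr<IREE::LinalgExt::AttentionOpDetail> oldInfo =
    IREE::LinalgExt::AttentionOpDetail::get(op.getIndexingMapsArray());

// After: the caller passes exactly the maps that are needed.
FailureOr<IREE::LinalgExt::AttentionOpDetail> newInfo =
    IREE::LinalgExt::AttentionOpDetail::get(
        op.getQueryMap(), op.getKeyMap(), op.getValueMap());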
@@ -1750,7 +1750,8 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
 static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
                                    IREE::LinalgExt::AttentionOp attnOp) {
   FailureOr<IREE::LinalgExt::AttentionOpDetail> maybeOpInfo =
-      IREE::LinalgExt::AttentionOpDetail::get(attnOp.getIndexingMapsArray());
+      IREE::LinalgExt::AttentionOpDetail::get(
+          attnOp.getQueryMap(), attnOp.getKeyMap(), attnOp.getValueMap());
   assert(succeeded(maybeOpInfo) && "failed to infer attention dims");
   auto opInfo = maybeOpInfo.value();
6 changes: 3 additions & 3 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -733,9 +733,9 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,

   ArrayRef<int64_t> bounds = maybeBounds.value();

-  auto opInfo =
-      IREE::LinalgExt::AttentionOpDetail::get(op.getIndexingMapsArray())
-          .value();
+  auto opInfo = IREE::LinalgExt::AttentionOpDetail::get(
+                    op.getQueryMap(), op.getKeyMap(), op.getValueMap())
+                    .value();

   int64_t mDim = opInfo.getMDims().back();
   int64_t k1Dim = opInfo.getK1Dims().back();
@@ -410,7 +410,7 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
   Value output = getOutput();

   FailureOr<AttentionOpDetail> maybeOpInfo =
-      AttentionOpDetail::get(getIndexingMapsArray());
+      AttentionOpDetail::get(getQueryMap(), getKeyMap(), getValueMap());
   assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
   AttentionOpDetail opInfo = maybeOpInfo.value();

@@ -528,7 +528,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   }

   FailureOr<AttentionOpDetail> maybeOpInfo =
-      AttentionOpDetail::get(getIndexingMapsArray());
+      AttentionOpDetail::get(getQueryMap(), getKeyMap(), getValueMap());
   assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
   AttentionOpDetail opInfo = maybeOpInfo.value();
@@ -1225,7 +1225,7 @@ LogicalResult AttentionOp::verify() {
     return attnOp->emitOpError("expected an indexing map for each operand");
   }
   FailureOr<AttentionOpDetail> maybeOpInfo =
-      AttentionOpDetail::get(indexingMaps);
+      AttentionOpDetail::get(getQueryMap(), getKeyMap(), getValueMap());
   if (failed(maybeOpInfo)) {
     return attnOp->emitOpError("failed to verify op's indexing maps");
   }
@@ -1398,7 +1398,7 @@ LogicalResult OnlineAttentionOp::verify() {

   // Check if indexing maps can represent attention.
   FailureOr<AttentionOpDetail> maybeOpInfo =
-      AttentionOpDetail::get(indexingMaps);
+      AttentionOpDetail::get(getQueryMap(), getKeyMap(), getValueMap());

   // Check shape compatibility based on indexing maps.
   SmallVector<int64_t> shape(getIterationDomainRank());
@@ -1786,10 +1786,10 @@ getAttentionIterationDomain(Location loc, OpBuilder &b, int64_t domainRank,
 }

 static SmallVector<utils::IteratorType>
-getAttentionIteratorTypes(int64_t domainRank,
-                          ArrayRef<AffineMap> indexingMaps) {
+getAttentionIteratorTypes(int64_t domainRank, AffineMap qMap, AffineMap kMap,
+                          AffineMap vMap) {
   FailureOr<AttentionOpDetail> maybeOpInfo =
-      AttentionOpDetail::get(indexingMaps);
+      AttentionOpDetail::get(qMap, kMap, vMap);
   assert(succeeded(maybeOpInfo) && "Failed to infer attention op details");
   AttentionOpDetail opInfo = maybeOpInfo.value();

@@ -1837,8 +1837,8 @@ SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &b) {
 }

 SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
-  return getAttentionIteratorTypes(getIterationDomainRank(),
-                                   getIndexingMapsArray());
+  return getAttentionIteratorTypes(getIterationDomainRank(), getQueryMap(),
+                                   getKeyMap(), getValueMap());
 }

 FailureOr<TilingResult>
@@ -1990,8 +1990,8 @@ SmallVector<Range> OnlineAttentionOp::getIterationDomain(OpBuilder &b) {
 }

 SmallVector<utils::IteratorType> OnlineAttentionOp::getLoopIteratorTypes() {
-  return getAttentionIteratorTypes(getIterationDomainRank(),
-                                   getIndexingMapsArray());
+  return getAttentionIteratorTypes(getIterationDomainRank(), getQueryMap(),
+                                   getKeyMap(), getValueMap());
 }

 FailureOr<TilingResult>
@@ -42,8 +42,8 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
   Location loc = attnOp.getLoc();
   MLIRContext *ctx = attnOp.getContext();

-  FailureOr<AttentionOpDetail> maybeOpInfo =
-      AttentionOpDetail::get(attnOp.getIndexingMapsArray());
+  FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
+      attnOp.getQueryMap(), attnOp.getKeyMap(), attnOp.getValueMap());
   assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
   AttentionOpDetail opInfo = maybeOpInfo.value();
@@ -26,13 +26,8 @@ findPermutationsIndexingOperand(AffineMap indexingMap) {

 }; // namespace

-void AttentionOpDetail::inferFromIndexingMaps(
-    ArrayRef<AffineMap> indexingMaps) {
-  assert(indexingMaps.size() >= 4);
-  AffineMap qMap = indexingMaps[0];
-  AffineMap kMap = indexingMaps[1];
-  AffineMap vMap = indexingMaps[2];
-
+void AttentionOpDetail::inferFromIndexingMaps(AffineMap qMap, AffineMap kMap,
+                                              AffineMap vMap) {
   // Q = B x M x K1
   // K = B x K2 x K1
   // V = B x K2 x N
@@ -80,10 +75,11 @@ void AttentionOpDetail::inferFromIndexingMaps(
 }

 FailureOr<AttentionOpDetail>
-AttentionOpDetail::get(ArrayRef<AffineMap> indexingMaps) {
+AttentionOpDetail::get(AffineMap qMap, AffineMap kMap, AffineMap vMap) {
   AttentionOpDetail opInfo;
-  opInfo.inferFromIndexingMaps(indexingMaps);
-  opInfo.maps = SmallVector<AffineMap>(indexingMaps);
+  opInfo.inferFromIndexingMaps(qMap, kMap, vMap);
+  opInfo.context = qMap.getContext();
+  opInfo.domainRank = qMap.getNumDims();
   return opInfo;
 }
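As a rough, self-contained illustration of the maps this inference expects (Q = B x M x K1, K = B x K2 x K1, V = B x K2 x N, per the comments above), the sketch below builds the three maps by hand and feeds them to the refactored entry point. The domain ordering (B, M, N, K1, K2), the helper function, and the expected results are assumptions for this example, not part of the patch.

#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

namespace IREELinalgExt = mlir::iree_compiler::IREE::LinalgExt;

// Hypothetical example: build Q/K/V maps over the domain (B, M, N, K1, K2)
// and let AttentionOpDetail classify the dimensions from just those maps.
static void exampleAttentionDims(mlir::MLIRContext *ctx) {
  auto d = [&](unsigned i) { return mlir::getAffineDimExpr(i, ctx); };
  // Q = B x M x K1, K = B x K2 x K1, V = B x K2 x N.
  mlir::AffineMap qMap = mlir::AffineMap::get(5, 0, {d(0), d(1), d(3)}, ctx);
  mlir::AffineMap kMap = mlir::AffineMap::get(5, 0, {d(0), d(4), d(3)}, ctx);
  mlir::AffineMap vMap = mlir::AffineMap::get(5, 0, {d(0), d(4), d(2)}, ctx);
  auto info = IREELinalgExt::AttentionOpDetail::get(qMap, kMap, vMap);
  if (mlir::succeeded(info)) {
    // With this ordering one would expect batch = {0}, m = {1}, k1 = {3},
    // k2 = {4}, n = {2}, and getDomainRank() == 5.
    (void)info->getMDims();
  }
}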
18 changes: 7 additions & 11 deletions compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h
@@ -40,31 +40,27 @@ namespace mlir::iree_compiler::IREE::LinalgExt {
 /// Tiling on K1 is generally not done because it's so small and is non-trivial.
 class AttentionOpDetail {
 public:
-  static FailureOr<AttentionOpDetail> get(ArrayRef<AffineMap> indexingMaps);
+  static FailureOr<AttentionOpDetail> get(AffineMap qMap, AffineMap kMap,
+                                          AffineMap vMap);

-  int64_t getDomainRank() const { return maps[0].getNumDims(); }
+  int64_t getDomainRank() const { return domainRank; }
   ArrayRef<int64_t> getBatchDims() const { return batch; }
   ArrayRef<int64_t> getMDims() const { return m; }
   ArrayRef<int64_t> getK1Dims() const { return k1; }
   ArrayRef<int64_t> getK2Dims() const { return k2; }
   ArrayRef<int64_t> getNDims() const { return n; }

-  ArrayRef<AffineMap> getIndexingMaps() const { return maps; }
-
   AffineMap getSMap() const;

 private:
-  void inferFromIndexingMaps(ArrayRef<AffineMap> indexingMaps);
-
-  MLIRContext *getContext() const { return maps[0].getContext(); }
-
+  void inferFromIndexingMaps(AffineMap qMap, AffineMap kMap, AffineMap vMap);
+  MLIRContext *getContext() const { return context; }
   SmallVector<int64_t> batch;
   SmallVector<int64_t> m;
   SmallVector<int64_t> k1;
   SmallVector<int64_t> k2;
   SmallVector<int64_t> n;

-  SmallVector<AffineMap> maps;
+  MLIRContext *context;
+  int64_t domainRank;
 };

 }; // namespace mlir::iree_compiler::IREE::LinalgExt
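Reflecting the new header, a typical consumer now looks roughly like the call sites updated in this commit. This is a sketch only: `op` is an assumed IREE::LinalgExt::AttentionOp, the enclosing function is assumed to return LogicalResult, and the error message is illustrative.

// Sketch of the intended calling convention after this change: the caller
// supplies exactly the Q/K/V maps and then queries the classified dimensions.
FailureOr<AttentionOpDetail> maybeOpInfo =
    AttentionOpDetail::get(op.getQueryMap(), op.getKeyMap(), op.getValueMap());
if (failed(maybeOpInfo))
  return op.emitOpError("indexing maps do not describe attention");
AttentionOpDetail opInfo = maybeOpInfo.value();
// The classified dimensions drive tiling / distribution decisions.
int64_t mDim = opInfo.getMDims().back();   // innermost M (parallel) dim
int64_t k2Dim = opInfo.getK2Dims().back(); // innermost K2 (reduction) dim
int64_t rank = opInfo.getDomainRank();     // now stored, no longer read off maps[0]
(void)mDim; (void)k2Dim; (void)rank;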
