From 224a47a461460e28f1e714b22fd02c34c39a1b4f Mon Sep 17 00:00:00 2001 From: hanhanW Date: Thu, 7 Nov 2024 14:39:52 -0800 Subject: [PATCH] Finish SpecializeEncodingPass! (update TypeInterface and FlowTensorType) Signed-off-by: hanhanW --- .../compiler/Dialect/Encoding/IR/BUILD.bazel | 10 ++ .../Dialect/Encoding/IR/CMakeLists.txt | 4 + .../Dialect/Encoding/IR/EncodingDialect.cpp | 1 + .../Dialect/Encoding/IR/EncodingInterfaces.td | 34 +++++++ .../Dialect/Encoding/IR/EncodingTypes.h | 3 + .../compiler/Dialect/Flow/IR/FlowTypes.cpp | 8 ++ .../iree/compiler/Dialect/Flow/IR/FlowTypes.h | 8 +- .../Transforms/MakeEncodingSolvable.cpp | 91 ++++++++++++++++++- 8 files changed, 157 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel index 3c89e75e4fd89..d25d00c04478f 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel @@ -49,6 +49,7 @@ iree_compiler_cc_library( "EncodingInterfaces.cpp.inc", "EncodingOps.cpp", "EncodingOps.cpp.inc", + "EncodingTypeInterfaces.cpp.inc", "EncodingTypes.cpp.inc", ], hdrs = [ @@ -60,6 +61,7 @@ iree_compiler_cc_library( "EncodingOps.h", "EncodingOps.h.inc", "EncodingTypes.h", + "EncodingTypeInterfaces.h.inc", "EncodingTypes.h.inc", ], deps = [ @@ -155,6 +157,14 @@ iree_gentbl_cc_library( ["--gen-attr-interface-defs"], "EncodingInterfaces.cpp.inc", ), + ( + ["--gen-type-interface-decls"], + "EncodingTypeInterfaces.h.inc", + ), + ( + ["--gen-type-interface-defs"], + "EncodingTypeInterfaces.cpp.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "EncodingInterfaces.td", diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt index 98438d6502561..95e0019b6f73b 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt @@ -21,6 +21,7 @@ iree_cc_library( "EncodingInterfaces.h.inc" "EncodingOps.h" "EncodingOps.h.inc" + "EncodingTypeInterfaces.h.inc" "EncodingTypes.h" "EncodingTypes.h.inc" SRCS @@ -32,6 +33,7 @@ iree_cc_library( "EncodingInterfaces.cpp.inc" "EncodingOps.cpp" "EncodingOps.cpp.inc" + "EncodingTypeInterfaces.cpp.inc" "EncodingTypes.cpp.inc" DEPS ::EncodingEnumsGen @@ -94,6 +96,8 @@ iree_tablegen_library( OUTS --gen-attr-interface-decls EncodingInterfaces.h.inc --gen-attr-interface-defs EncodingInterfaces.cpp.inc + --gen-type-interface-decls EncodingTypeInterfaces.h.inc + --gen-type-interface-defs EncodingTypeInterfaces.cpp.inc ) iree_tablegen_library( diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingDialect.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingDialect.cpp index eef0843c121c5..2e9582768a577 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingDialect.cpp @@ -21,6 +21,7 @@ #include "iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp.inc" #include "iree/compiler/Dialect/Encoding/IR/EncodingEnums.cpp.inc" #include "iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.cpp.inc" +#include "iree/compiler/Dialect/Encoding/IR/EncodingTypeInterfaces.cpp.inc" #undef GET_ATTRDEF_CLASSES namespace mlir::iree_compiler::IREE::Encoding { diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td index 9a5e3ddac1cc0..9b46d5e2f243e 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td @@ -10,6 +10,10 @@ include "iree/compiler/Dialect/Encoding/IR/EncodingBase.td" include "mlir/IR/BuiltinAttributeInterfaces.td" +//===----------------------------------------------------------------------===// +// Attribute Interfaces +//===----------------------------------------------------------------------===// + def IREEEncoding_EncodingSolverInterfaceAttr : AttrInterface<"EncodingSolverInterfaceAttr"> { let cppNamespace = "::mlir::iree_compiler::IREE::Encoding"; @@ -52,4 +56,34 @@ def IREEEncoding_EncodingSolverInterfaceAttr : ]; } +//===----------------------------------------------------------------------===// +// Type Interfaces +//===----------------------------------------------------------------------===// + +def IREEEncoding_EncodingTypeInterface : + TypeInterface<"EncodingTypeInterface"> { + let cppNamespace = "::mlir::iree_compiler::IREE::Encoding"; + + let description = [{ + Interface used to access/update tensor types with encodings. + }]; + + + let methods = [ + InterfaceMethod< + [{ + Returns the tensor type with the updated encoding. + }], + /*retTy=*/"::mlir::Type", + /*methodName=*/"updateEncoding", + /*args=*/(ins + "::mlir::iree_compiler::IREE::Encoding::EncodingAttr":$encoding), + /*defaultImplementation=*/[{ + return {}; + }] + >, + ]; +} + + #endif // IREE_DIALECT_ENCODING_INTERFACES diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h index 62a3efdcb98a5..f11cbafaebc1c 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h @@ -26,6 +26,9 @@ #define GET_TYPEDEF_CLASSES #include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h.inc" // IWYU pragma: export #undef GET_TYPEDEF_CLASSES +// The EncodingTypeInterfaces.h.inc needs to be included after EncodingAttrs +// because an interface method could have EncodingAttr types. +#include "iree/compiler/Dialect/Encoding/IR/EncodingTypeInterfaces.h.inc" // IWYU pragma: export // clang-format on //===---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp index 58c853fb08cbe..6eabf88a86281 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp @@ -5,7 +5,9 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" +#include +#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/IR/BuiltinTypes.h" @@ -116,6 +118,12 @@ bool DispatchTensorType::hasStaticShape(ArrayRef shape) const { return hasStaticShape() && getShape() == shape; } +Type DispatchTensorType::updateEncoding( + IREE::Encoding::EncodingAttr encoding) const { + return DispatchTensorType::get(getAccess(), getShape(), getBoundElementType(), + encoding); +} + LogicalResult DispatchTensorType::verify(function_ref emitError, uint32_t access, Type boundType) { diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h index 85f0141d68366..5178f85cb3239 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_DIALECT_FLOW_IR_FLOWTYPES_H_ #define IREE_COMPILER_DIALECT_FLOW_IR_FLOWTYPES_H_ +#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" #include "llvm/ADT/DenseMapInfo.h" @@ -45,9 +46,12 @@ enum class TensorAccess : uint32_t { // Blatantly ripped from ShapedType, because the closed type system means that // we can't extend it and reuse all of this. +// TODO(hanchung): I'm not sure if I attach TypeInterface correct or not. It +// adds traits; it declares and implements the method itself. class DispatchTensorType : public Type::TypeBase { + detail::DispatchTensorTypeStorage, + IREE::Encoding::EncodingTypeInterface::Trait> { public: using ImplType = detail::DispatchTensorTypeStorage; @@ -132,6 +136,8 @@ class DispatchTensorType } return llvm::cast(boundType); } + + Type updateEncoding(IREE::Encoding::EncodingAttr encoding) const; }; void printType(DispatchTensorType &type, DialectAsmPrinter &p); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MakeEncodingSolvable.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MakeEncodingSolvable.cpp index a426a0e80bb61..8d60ffc431268 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MakeEncodingSolvable.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MakeEncodingSolvable.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" +#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" #include "iree/compiler/Dialect/Stream/Analysis/Affinity.h" #include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" #include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.h" @@ -16,6 +17,7 @@ #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/IR/Attributes.h" @@ -27,7 +29,10 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -84,6 +89,71 @@ getOperandsResourceAffinities(AffinityAnalysis &affinityAnalysis, return operandAttrs; } +static void updateExecutableOpEncodings( + ModuleOp moduleOp, Stream::ExecutableOp executableOp, + ArrayRef operandAttrs, AffinityAttr resultAffinity, + SymbolTable symbolTable, + std::function &)> + resolver) { + LLVM_DEBUG(llvm::dbgs() << "Update ExecutableOp: " + << executableOp.getSymName() << "\n"); + LLVM_DEBUG({ + llvm::dbgs() << " operand affinities: ["; + llvm::interleaveComma(operandAttrs, llvm::dbgs()); + llvm::dbgs() << "]\n"; + }); + + MLIRContext *ctx = executableOp.getContext(); + for (auto exportOp : + executableOp.getOps()) { + exportOp.getSymName(); + auto funcOp = cast(symbolTable.lookupSymbolIn( + executableOp.getInnerModule(), exportOp.getSymName())); + Region ®ion = funcOp.getFunctionBody(); + auto argsAffinities = llvm::map_to_vector( + operandAttrs, [](Attribute attr) { return cast(attr); }); + auto resAffinityAttr = + ArrayAttr::get(ctx, {cast(resultAffinity)}); + argsAffinities.resize(region.getNumArguments(), resAffinityAttr); + int idx = 0; + for (auto arg : region.getArguments()) { + if (!isa(arg.getType())) { + continue; + } + ArrayRef affinities = argsAffinities[idx++].getValue(); + assert(affinities.size() == 1); + SetVector resolvedTargets; + if (failed(resolver(cast(affinities[0]), moduleOp, + resolvedTargets))) { + LLVM_DEBUG(llvm::dbgs() << "failed on getting target resolvers\n"); + continue; + } + + for (auto user : arg.getUsers()) { + // TODO(hanchung): Is it the only case? + auto subspanOp = cast(user); + auto resType = + dyn_cast(subspanOp.getType()); + if (!resType) { + continue; + } + auto tensorType = dyn_cast(resType.getBoundType()); + if (!tensorType || !tensorType.getEncoding()) { + continue; + } + auto encoding = + dyn_cast(tensorType.getEncoding()); + + SmallVector targets(resolvedTargets.begin(), + resolvedTargets.end()); + subspanOp.getResult().setType( + resType.updateEncoding(encoding.cloneWithTargets(targets))); + } + } + } +} + } // namespace struct MakeEncodingSolvablePass @@ -208,6 +278,7 @@ struct MakeEncodingSolvablePass } }); + // Duplicate executables for each unqiue resource affinities. IRRewriter rewriter(ctx); DenseMap, Stream::ExecutableOp> @@ -237,6 +308,7 @@ struct MakeEncodingSolvablePass } } + // Update dispatch sites. for (auto dispatchOp : candidates) { SmallVector affinities; assert(affinityAnalysis.tryLookupExecutionAffinity(dispatchOp, @@ -253,7 +325,8 @@ struct MakeEncodingSolvablePass affinities[0], ArrayAttr::get(ctx, operandAttrs), exportOp); if (!dispatchSiteToExecutable.count(info)) { - LLVM_DEBUG(llvm::dbgs() << "not found, skip\n"); + LLVM_DEBUG(llvm::dbgs() << "not found, skip\n " + << dispatchOp.getEntryPoints() << "\n"); continue; } @@ -269,6 +342,22 @@ struct MakeEncodingSolvablePass }); } + // Attach encoding targets to all the executables. + for (auto dispatchOp : candidates) { + SmallVector affinities; + assert(affinityAnalysis.tryLookupExecutionAffinity(dispatchOp, + affinities)); + SymbolRefAttr entryPoint = + *dispatchOp.getEntryPoints().getAsRange().begin(); + auto exportOp = cast( + symbolTable.lookupSymbolIn(moduleOp, entryPoint)); + auto executableOp = + exportOp->getParentOfType(); + SmallVector operandAttrs = + operandsResourceAffinities[dispatchOp]; + updateExecutableOpEncodings(moduleOp, executableOp, operandAttrs, + affinities[0], symbolTable, resolver); + } } } };