Skip to content

Commit

Permalink
Finish SpecializeEncodingPass! (update TypeInterface and FlowTensorType)
Browse files Browse the repository at this point in the history
Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW committed Nov 7, 2024
1 parent 998f183 commit 224a47a
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 2 deletions.
10 changes: 10 additions & 0 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ iree_compiler_cc_library(
"EncodingInterfaces.cpp.inc",
"EncodingOps.cpp",
"EncodingOps.cpp.inc",
"EncodingTypeInterfaces.cpp.inc",
"EncodingTypes.cpp.inc",
],
hdrs = [
Expand All @@ -60,6 +61,7 @@ iree_compiler_cc_library(
"EncodingOps.h",
"EncodingOps.h.inc",
"EncodingTypes.h",
"EncodingTypeInterfaces.h.inc",
"EncodingTypes.h.inc",
],
deps = [
Expand Down Expand Up @@ -155,6 +157,14 @@ iree_gentbl_cc_library(
["--gen-attr-interface-defs"],
"EncodingInterfaces.cpp.inc",
),
(
["--gen-type-interface-decls"],
"EncodingTypeInterfaces.h.inc",
),
(
["--gen-type-interface-defs"],
"EncodingTypeInterfaces.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "EncodingInterfaces.td",
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ iree_cc_library(
"EncodingInterfaces.h.inc"
"EncodingOps.h"
"EncodingOps.h.inc"
"EncodingTypeInterfaces.h.inc"
"EncodingTypes.h"
"EncodingTypes.h.inc"
SRCS
Expand All @@ -32,6 +33,7 @@ iree_cc_library(
"EncodingInterfaces.cpp.inc"
"EncodingOps.cpp"
"EncodingOps.cpp.inc"
"EncodingTypeInterfaces.cpp.inc"
"EncodingTypes.cpp.inc"
DEPS
::EncodingEnumsGen
Expand Down Expand Up @@ -94,6 +96,8 @@ iree_tablegen_library(
OUTS
--gen-attr-interface-decls EncodingInterfaces.h.inc
--gen-attr-interface-defs EncodingInterfaces.cpp.inc
--gen-type-interface-decls EncodingTypeInterfaces.h.inc
--gen-type-interface-defs EncodingTypeInterfaces.cpp.inc
)

iree_tablegen_library(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp.inc"
#include "iree/compiler/Dialect/Encoding/IR/EncodingEnums.cpp.inc"
#include "iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.cpp.inc"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypeInterfaces.cpp.inc"
#undef GET_ATTRDEF_CLASSES

namespace mlir::iree_compiler::IREE::Encoding {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
include "iree/compiler/Dialect/Encoding/IR/EncodingBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"

//===----------------------------------------------------------------------===//
// Attribute Interfaces
//===----------------------------------------------------------------------===//

def IREEEncoding_EncodingSolverInterfaceAttr :
AttrInterface<"EncodingSolverInterfaceAttr"> {
let cppNamespace = "::mlir::iree_compiler::IREE::Encoding";
Expand Down Expand Up @@ -52,4 +56,34 @@ def IREEEncoding_EncodingSolverInterfaceAttr :
];
}

//===----------------------------------------------------------------------===//
// Type Interfaces
//===----------------------------------------------------------------------===//

def IREEEncoding_EncodingTypeInterface :
TypeInterface<"EncodingTypeInterface"> {
let cppNamespace = "::mlir::iree_compiler::IREE::Encoding";

let description = [{
Interface used to access/update tensor types with encodings.
}];


let methods = [
InterfaceMethod<
[{
Returns the tensor type with the updated encoding.
}],
/*retTy=*/"::mlir::Type",
/*methodName=*/"updateEncoding",
/*args=*/(ins
"::mlir::iree_compiler::IREE::Encoding::EncodingAttr":$encoding),
/*defaultImplementation=*/[{
return {};
}]
>,
];
}


#endif // IREE_DIALECT_ENCODING_INTERFACES
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#define GET_TYPEDEF_CLASSES
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h.inc" // IWYU pragma: export
#undef GET_TYPEDEF_CLASSES
// The EncodingTypeInterfaces.h.inc needs to be included after EncodingAttrs
// because an interface method could have EncodingAttr types.
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypeInterfaces.h.inc" // IWYU pragma: export
// clang-format on

//===---------------------------------------------------------------------===//
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include <libxml/hash.h>

#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down Expand Up @@ -116,6 +118,12 @@ bool DispatchTensorType::hasStaticShape(ArrayRef<int64_t> shape) const {
return hasStaticShape() && getShape() == shape;
}

Type DispatchTensorType::updateEncoding(
IREE::Encoding::EncodingAttr encoding) const {
return DispatchTensorType::get(getAccess(), getShape(), getBoundElementType(),
encoding);
}

LogicalResult
DispatchTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
uint32_t access, Type boundType) {
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_DIALECT_FLOW_IR_FLOWTYPES_H_
#define IREE_COMPILER_DIALECT_FLOW_IR_FLOWTYPES_H_

#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/DenseMapInfo.h"
Expand Down Expand Up @@ -45,9 +46,12 @@ enum class TensorAccess : uint32_t {

// Blatantly ripped from ShapedType, because the closed type system means that
// we can't extend it and reuse all of this.
// TODO(hanchung): I'm not sure if I attach TypeInterface correct or not. It
// adds traits; it declares and implements the method itself.
class DispatchTensorType
: public Type::TypeBase<DispatchTensorType, Type,
detail::DispatchTensorTypeStorage> {
detail::DispatchTensorTypeStorage,
IREE::Encoding::EncodingTypeInterface::Trait> {
public:
using ImplType = detail::DispatchTensorTypeStorage;

Expand Down Expand Up @@ -132,6 +136,8 @@ class DispatchTensorType
}
return llvm::cast<RankedTensorType>(boundType);
}

Type updateEncoding(IREE::Encoding::EncodingAttr encoding) const;
};

void printType(DispatchTensorType &type, DialectAsmPrinter &p);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.h"
Expand All @@ -16,6 +17,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h"
Expand All @@ -27,7 +29,10 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down Expand Up @@ -84,6 +89,71 @@ getOperandsResourceAffinities(AffinityAnalysis &affinityAnalysis,
return operandAttrs;
}

static void updateExecutableOpEncodings(
ModuleOp moduleOp, Stream::ExecutableOp executableOp,
ArrayRef<Attribute> operandAttrs, AffinityAttr resultAffinity,
SymbolTable symbolTable,
std::function<LogicalResult(AffinityAttr, Operation *,
SetVector<Attribute> &)>
resolver) {
LLVM_DEBUG(llvm::dbgs() << "Update ExecutableOp: "
<< executableOp.getSymName() << "\n");
LLVM_DEBUG({
llvm::dbgs() << " operand affinities: [";
llvm::interleaveComma(operandAttrs, llvm::dbgs());
llvm::dbgs() << "]\n";
});

MLIRContext *ctx = executableOp.getContext();
for (auto exportOp :
executableOp.getOps<IREE::Stream::ExecutableExportOp>()) {
exportOp.getSymName();
auto funcOp = cast<mlir::FunctionOpInterface>(symbolTable.lookupSymbolIn(
executableOp.getInnerModule(), exportOp.getSymName()));
Region &region = funcOp.getFunctionBody();
auto argsAffinities = llvm::map_to_vector(
operandAttrs, [](Attribute attr) { return cast<ArrayAttr>(attr); });
auto resAffinityAttr =
ArrayAttr::get(ctx, {cast<Attribute>(resultAffinity)});
argsAffinities.resize(region.getNumArguments(), resAffinityAttr);
int idx = 0;
for (auto arg : region.getArguments()) {
if (!isa<IREE::Stream::BindingType>(arg.getType())) {
continue;
}
ArrayRef<Attribute> affinities = argsAffinities[idx++].getValue();
assert(affinities.size() == 1);
SetVector<Attribute> resolvedTargets;
if (failed(resolver(cast<Stream::AffinityAttr>(affinities[0]), moduleOp,
resolvedTargets))) {
LLVM_DEBUG(llvm::dbgs() << "failed on getting target resolvers\n");
continue;
}

for (auto user : arg.getUsers()) {
// TODO(hanchung): Is it the only case?
auto subspanOp = cast<IREE::Stream::BindingSubspanOp>(user);
auto resType =
dyn_cast<IREE::Flow::DispatchTensorType>(subspanOp.getType());
if (!resType) {
continue;
}
auto tensorType = dyn_cast<RankedTensorType>(resType.getBoundType());
if (!tensorType || !tensorType.getEncoding()) {
continue;
}
auto encoding =
dyn_cast<Encoding::EncodingAttr>(tensorType.getEncoding());

SmallVector<Attribute> targets(resolvedTargets.begin(),
resolvedTargets.end());
subspanOp.getResult().setType(
resType.updateEncoding(encoding.cloneWithTargets(targets)));
}
}
}
}

} // namespace

struct MakeEncodingSolvablePass
Expand Down Expand Up @@ -208,6 +278,7 @@ struct MakeEncodingSolvablePass
}
});

// Duplicate executables for each unqiue resource affinities.
IRRewriter rewriter(ctx);
DenseMap<std::tuple<AffinityAttr, ArrayAttr, Stream::ExecutableExportOp>,
Stream::ExecutableOp>
Expand Down Expand Up @@ -237,6 +308,7 @@ struct MakeEncodingSolvablePass
}
}

// Update dispatch sites.
for (auto dispatchOp : candidates) {
SmallVector<IREE::Stream::AffinityAttr> affinities;
assert(affinityAnalysis.tryLookupExecutionAffinity(dispatchOp,
Expand All @@ -253,7 +325,8 @@ struct MakeEncodingSolvablePass
affinities[0], ArrayAttr::get(ctx, operandAttrs), exportOp);

if (!dispatchSiteToExecutable.count(info)) {
LLVM_DEBUG(llvm::dbgs() << "not found, skip\n");
LLVM_DEBUG(llvm::dbgs() << "not found, skip\n "
<< dispatchOp.getEntryPoints() << "\n");
continue;
}

Expand All @@ -269,6 +342,22 @@ struct MakeEncodingSolvablePass
});
}

// Attach encoding targets to all the executables.
for (auto dispatchOp : candidates) {
SmallVector<IREE::Stream::AffinityAttr> affinities;
assert(affinityAnalysis.tryLookupExecutionAffinity(dispatchOp,
affinities));
SymbolRefAttr entryPoint =
*dispatchOp.getEntryPoints().getAsRange<SymbolRefAttr>().begin();
auto exportOp = cast<IREE::Stream::ExecutableExportOp>(
symbolTable.lookupSymbolIn(moduleOp, entryPoint));
auto executableOp =
exportOp->getParentOfType<IREE::Stream::ExecutableOp>();
SmallVector<Attribute> operandAttrs =
operandsResourceAffinities[dispatchOp];
updateExecutableOpEncodings(moduleOp, executableOp, operandAttrs,
affinities[0], symbolTable, resolver);
}
}
}
};
Expand Down

0 comments on commit 224a47a

Please sign in to comment.