Skip to content

Commit

Permalink
Revert GenericVectorization and move things to FoldMemRef
Browse files Browse the repository at this point in the history
  • Loading branch information
lialan committed Oct 7, 2024
1 parent 178a875 commit bc5aa85
Showing 1 changed file with 0 additions and 101 deletions.
101 changes: 0 additions & 101 deletions compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,100 +325,6 @@ class GenericVectorizationPass final
void runOnOperation() override;
};

struct PadSubtypeTransferReadPattern
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const final {
auto resultType = mlir::cast<VectorType>(readOp.getResult().getType());
// check that the type size is byte aligned
unsigned elementBits = resultType.getElementType().getIntOrFloatBitWidth();
auto numElements = resultType.getNumElements();
if ((numElements * elementBits) % 8 != 0) {
// pad the type to be byte aligned
SmallVector<int64_t> newShape = SmallVector<int64_t>(resultType.getShape());
newShape.back() +=
(8 - (numElements * elementBits) % 8) / elementBits;
// Create a new vector type with the padded shape
auto newType = VectorType::get(newShape, resultType.getElementType());

// Create a new transfer read op with the new type
auto paddingValue = rewriter.create<arith::ConstantOp>(
readOp.getLoc(), resultType.getElementType(),
rewriter.getZeroAttr(resultType.getElementType()));

// use a vector extract to extract the original vector
SmallVector<int64_t> offsets, strides;
for (unsigned i = 0; i < resultType.getRank(); ++i) {
offsets.push_back(0);
strides.push_back(1);
}

auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), newType, readOp.getSource(), readOp.getIndices(),
paddingValue);

rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
readOp, newTransferReadOp, offsets, resultType.getShape(), strides);
}
return success(true);
}
};

struct PadSubtypeTransferWritePattern
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const final {
auto source = writeOp.getSource();
auto target = writeOp.getVector();
auto targetType = mlir::cast<VectorType>(target.getType());
auto sourceType = mlir::cast<ShapedType>(source.getType());
auto elemType = targetType.getElementType();
// check that the type size is byte aligned
unsigned elementBits = targetType.getElementType().getIntOrFloatBitWidth();
auto numElements = targetType.getNumElements();
if ((numElements * elementBits) % 8 != 0) {
SmallVector<int64_t> strides;
SmallVector<int64_t> offsets;
for (unsigned i = 0; i < sourceType.getRank(); ++i) {
strides.push_back(1);
offsets.push_back(0);
}

// TODO: we should keep the source and sink ... otherwise we are
// overwriting some part of the source tensor
SmallVector<int64_t> newShape = SmallVector<int64_t>(targetType.getShape());
newShape.back() +=
(8 - (numElements * elementBits) % 8) / elementBits;
auto newTargetType = VectorType::get(newShape, elemType);

// create an empty vector of the correct size
auto numElements = newTargetType.getNumElements();
SmallVector<bool> zeroValues;
for (unsigned i = 0; i < numElements; ++i) {
zeroValues.push_back(false);
}
auto zeroVector = rewriter.create<arith::ConstantOp>(
writeOp.getLoc(), DenseIntElementsAttr::get(newTargetType, zeroValues));

auto extendedOp = rewriter.create<vector::InsertStridedSliceOp>(
writeOp->getLoc(), target, zeroVector, offsets, strides);

writeOp.getVectorMutable().assign(extendedOp);
return success();
}
return failure();
}
};

static void populateSubBytePaddingPatterns(RewritePatternSet &patterns) {
patterns.add<PadSubtypeTransferReadPattern, PadSubtypeTransferWritePattern>(
patterns.getContext());
}

void GenericVectorizationPass::runOnOperation() {
MLIRContext *context = &getContext();
auto funcOp = getOperation();
Expand Down Expand Up @@ -533,13 +439,6 @@ void GenericVectorizationPass::runOnOperation() {
linalg::populatePadOpVectorizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

// in case the data type is subbyte-aligned, we need to pad the data type
{
RewritePatternSet patterns(funcOp.getContext());
mlir::iree_compiler::populateSubBytePaddingPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
}

} // namespace
Expand Down

0 comments on commit bc5aa85

Please sign in to comment.