From bc5aa8577e78147eac868e75f6228fb2e1fdbbfc Mon Sep 17 00:00:00 2001 From: Alan Li Date: Sat, 5 Oct 2024 01:41:37 +0000 Subject: [PATCH] Revert GenericVectorization and move things to FoldMemRef --- .../Codegen/Common/GenericVectorization.cpp | 101 ------------------ 1 file changed, 101 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp index 7780ac99327af..8aee5ba2c0e4a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp @@ -325,100 +325,6 @@ class GenericVectorizationPass final void runOnOperation() override; }; -struct PadSubtypeTransferReadPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::TransferReadOp readOp, - PatternRewriter &rewriter) const final { - auto resultType = mlir::cast(readOp.getResult().getType()); - // check that the type size is byte aligned - unsigned elementBits = resultType.getElementType().getIntOrFloatBitWidth(); - auto numElements = resultType.getNumElements(); - if ((numElements * elementBits) % 8 != 0) { - // pad the type to be byte aligned - SmallVector newShape = SmallVector(resultType.getShape()); - newShape.back() += - (8 - (numElements * elementBits) % 8) / elementBits; - // Create a new vector type with the padded shape - auto newType = VectorType::get(newShape, resultType.getElementType()); - - // Create a new transfer read op with the new type - auto paddingValue = rewriter.create( - readOp.getLoc(), resultType.getElementType(), - rewriter.getZeroAttr(resultType.getElementType())); - - // use a vector extract to extract the original vector - SmallVector offsets, strides; - for (unsigned i = 0; i < resultType.getRank(); ++i) { - offsets.push_back(0); - strides.push_back(1); - } - - auto newTransferReadOp = rewriter.create( - readOp.getLoc(), newType, readOp.getSource(), readOp.getIndices(), - paddingValue); - - rewriter.replaceOpWithNewOp( - readOp, newTransferReadOp, offsets, resultType.getShape(), strides); - } - return success(true); - } -}; - -struct PadSubtypeTransferWritePattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, - PatternRewriter &rewriter) const final { - auto source = writeOp.getSource(); - auto target = writeOp.getVector(); - auto targetType = mlir::cast(target.getType()); - auto sourceType = mlir::cast(source.getType()); - auto elemType = targetType.getElementType(); - // check that the type size is byte aligned - unsigned elementBits = targetType.getElementType().getIntOrFloatBitWidth(); - auto numElements = targetType.getNumElements(); - if ((numElements * elementBits) % 8 != 0) { - SmallVector strides; - SmallVector offsets; - for (unsigned i = 0; i < sourceType.getRank(); ++i) { - strides.push_back(1); - offsets.push_back(0); - } - - // TODO: we should keep the source and sink ... otherwise we are - // overwriting some part of the source tensor - SmallVector newShape = SmallVector(targetType.getShape()); - newShape.back() += - (8 - (numElements * elementBits) % 8) / elementBits; - auto newTargetType = VectorType::get(newShape, elemType); - - // create an empty vector of the correct size - auto numElements = newTargetType.getNumElements(); - SmallVector zeroValues; - for (unsigned i = 0; i < numElements; ++i) { - zeroValues.push_back(false); - } - auto zeroVector = rewriter.create( - writeOp.getLoc(), DenseIntElementsAttr::get(newTargetType, zeroValues)); - - auto extendedOp = rewriter.create( - writeOp->getLoc(), target, zeroVector, offsets, strides); - - writeOp.getVectorMutable().assign(extendedOp); - return success(); - } - return failure(); - } -}; - -static void populateSubBytePaddingPatterns(RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); -} - void GenericVectorizationPass::runOnOperation() { MLIRContext *context = &getContext(); auto funcOp = getOperation(); @@ -533,13 +439,6 @@ void GenericVectorizationPass::runOnOperation() { linalg::populatePadOpVectorizationPatterns(patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } - - // in case the data type is subbyte-aligned, we need to pad the data type - { - RewritePatternSet patterns(funcOp.getContext()); - mlir::iree_compiler::populateSubBytePaddingPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); - } } } // namespace