Skip to content

Commit

Permalink
remove unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
lialan committed Nov 7, 2024
1 parent bba2388 commit 7f25513
Showing 1 changed file with 0 additions and 191 deletions.
191 changes: 0 additions & 191 deletions compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,196 +101,6 @@ static bool isByteAligned(ShapedType type) {
return (numElements * elementBits) % 8 == 0;
}

struct PadSubbyteTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  /// Pads a sub-byte-aligned vector.transfer_write so the written vector's
  /// total bit width becomes a multiple of 8: the original value is inserted
  /// at offset zero into a zero-filled vector of the padded shape, and the
  /// write is updated to store that padded vector.
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const final {
    auto target = writeOp.getVector();
    auto targetType = cast<VectorType>(target.getType());
    if (isByteAligned(targetType)) {
      return failure();
    }

    auto elemType = targetType.getElementType();
    unsigned elementBits = elemType.getIntOrFloatBitWidth();
    auto numElements = targetType.getNumElements();

    // TODO: we should keep the source and sink ... otherwise we are
    // overwriting some part of the source tensor.

    // Pad the innermost dimension so numElements * elementBits is a
    // multiple of 8.
    SmallVector<int64_t> newShape(targetType.getShape());
    newShape.back() += (8 - (numElements * elementBits) % 8) / elementBits;
    auto newTargetType = VectorType::get(newShape, elemType);

    // Offsets/strides for InsertStridedSliceOp must be sized by the *vector*
    // rank; the old code used the rank of the written memref/tensor, which
    // can differ from the vector rank.
    SmallVector<int64_t> strides(targetType.getRank(), 1);
    SmallVector<int64_t> offsets(targetType.getRank(), 0);

    // Zero-filled vector of the padded type. getZeroAttr handles any int or
    // float element type; the previous DenseIntElementsAttr-of-bool
    // construction is only valid for i1 elements.
    auto zeroVector = rewriter.create<arith::ConstantOp>(
        writeOp.getLoc(), rewriter.getZeroAttr(newTargetType));

    auto extendedOp = rewriter.create<vector::InsertStridedSliceOp>(
        writeOp->getLoc(), target, zeroVector, offsets, strides);

    writeOp.getVectorMutable().assign(extendedOp);
    return success();
  }
};

struct PadSubbyteTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  /// Pads a sub-byte-aligned vector.transfer_read up to a byte-aligned
  /// vector type, then extracts the originally requested shape from the
  /// widened result.
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const final {
    auto resultType = cast<VectorType>(readOp.getResult().getType());
    if (isByteAligned(resultType)) {
      return failure();
    }

    unsigned elementBits = resultType.getElementType().getIntOrFloatBitWidth();
    auto numElements = resultType.getNumElements();

    // Pad the innermost dimension so the total bit width is byte aligned.
    SmallVector<int64_t> newShape(resultType.getShape());
    newShape.back() += (8 - (numElements * elementBits) % 8) / elementBits;
    // Create a new vector type with the padded shape.
    auto newType = VectorType::get(newShape, resultType.getElementType());

    // Reuse the original read's padding value instead of materializing a
    // fresh zero constant, so the padded tail lanes observe the same
    // out-of-bounds value the original read was configured with.
    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        readOp.getLoc(), newType, readOp.getSource(), readOp.getIndices(),
        readOp.getPadding());

    // Extract the original (unpadded) vector from the widened read.
    SmallVector<int64_t> offsets(resultType.getRank(), 0);
    SmallVector<int64_t> strides(resultType.getRank(), 1);
    rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
        readOp, newTransferReadOp, offsets, resultType.getShape(), strides);
    return success();
  }
};

struct PadSubbyteVectorLoadPattern : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  /// Pads a sub-byte-aligned vector.load to a byte-aligned vector type and
  /// extracts the originally requested shape from the widened load.
  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const final {
    auto result = loadOp.getResult();
    auto resultType = mlir::cast<VectorType>(result.getType());
    if (isByteAligned(resultType)) {
      return failure();
    }

    unsigned elementBits = resultType.getElementType().getIntOrFloatBitWidth();
    auto numElements = resultType.getNumElements();

    // Pad the innermost dimension so the total bit width is a multiple of 8.
    SmallVector<int64_t> newShape(resultType.getShape());
    newShape.back() += (8 - (numElements * elementBits) % 8) / elementBits;
    auto newTargetType = VectorType::get(newShape, resultType.getElementType());

    // Load the widened vector from the same base and indices.
    auto newVectorLoad = rewriter.create<vector::LoadOp>(
        loadOp.getLoc(), newTargetType, loadOp.getBase(), loadOp.getIndices());

    // Note: removed the dead `zeroValues` vector the original built here; it
    // was never used.

    // Extract the original (unpadded) shape via a strided slice.
    SmallVector<int64_t> offsets(resultType.getRank(), 0);
    SmallVector<int64_t> strides(resultType.getRank(), 1);

    rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
        loadOp, newVectorLoad, offsets, resultType.getShape(), strides);
    return success();
  }
};

struct PadSubbyteVectorStorePattern : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  /// Pads a sub-byte-aligned vector.store to a byte-aligned width: the value
  /// is inserted into a zero-filled padded vector, and a masked store keeps
  /// the padded tail lanes from being written to memory.
  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const final {
    auto storeValue = storeOp.getValueToStore();
    auto valueType = mlir::cast<ShapedType>(storeValue.getType());
    if (isByteAligned(valueType)) {
      return failure();
    }

    auto elemType = valueType.getElementType();
    unsigned elementBits = valueType.getElementType().getIntOrFloatBitWidth();
    auto numElements = valueType.getNumElements();

    // Pad the innermost dimension so the total bit width is a multiple of 8.
    SmallVector<int64_t> newShape(valueType.getShape());
    newShape.back() += (8 - (numElements * elementBits) % 8) / elementBits;
    auto newValueType = VectorType::get(newShape, elemType);

    // Offsets/strides for InsertStridedSliceOp must be sized by the rank of
    // the stored *vector*; the old code used the destination memref's rank,
    // which can differ.
    SmallVector<int64_t> strides(valueType.getRank(), 1);
    SmallVector<int64_t> offsets(valueType.getRank(), 0);

    // Zero-filled vector of the padded type. getZeroAttr handles any int or
    // float element type; the previous DenseIntElementsAttr-of-bool
    // construction is only valid for i1 elements.
    auto zeroVector = rewriter.create<arith::ConstantOp>(
        storeOp.getLoc(), rewriter.getZeroAttr(newValueType));

    auto extendedOp = rewriter.create<vector::InsertStridedSliceOp>(
        storeOp->getLoc(), storeValue, zeroVector, offsets, strides);

    // Mask covering only the original extent. CreateMaskOp requires an i1
    // vector result; the old code passed the sub-byte element type, which is
    // invalid.
    SmallVector<Value> maskShape;
    for (auto dim : valueType.getShape()) {
      maskShape.push_back(
          rewriter.create<arith::ConstantIndexOp>(storeOp.getLoc(), dim));
    }
    auto maskType = VectorType::get(newShape, rewriter.getI1Type());
    auto mask = rewriter.create<vector::CreateMaskOp>(storeOp.getLoc(),
                                                      maskType, maskShape);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        storeOp, storeOp.getBase(), storeOp.getIndices(), mask, extendedOp);
    return success();
  }
};

/// Registers the sub-byte padding patterns for transfer_read/transfer_write
/// and vector load/store defined above.
static void populateSubbyteTypeHandlingPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<PadSubbyteTransferReadPattern, PadSubbyteTransferWritePattern,
               PadSubbyteVectorLoadPattern, PadSubbyteVectorStorePattern>(
      context);
}

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -329,7 +139,6 @@ struct EmulateNarrowTypePass final
affine::AffineDialect, IREE::HAL::HALDialect>(opLegalCallback);

RewritePatternSet patterns(ctx);
populateSubbyteTypeHandlingPatterns(patterns);
arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
populateIREEResolveExtractStridedMetadataPatterns(ctx, patterns);
Expand Down

0 comments on commit 7f25513

Please sign in to comment.