Skip to content

Commit

Permalink
0-d vector fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Groverkss committed Nov 6, 2024
1 parent 92ac1b2 commit 15c2937
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,9 @@ struct DistributeBroadcast final : OpDistributionPattern<vector::BroadcastOp> {
auto vectorType = VectorType::get(distShape, elementType);

VectorValue srcVector = dyn_cast<VectorValue>(broadcastOp.getSource());
// If the srcVector is a scalar (like f32) or a rank-0 vector (like
// vector<f32>), we proceed with the scalar distribution branch.
if (!srcVector || !isNonZeroRank(srcVector)) {
// If the srcVector is a scalar (like f32) we proceed with the scalar
// distribution branch.
if (!srcVector) {
// The way distribution currently works, there is no partial thread
// distribution, so a scalar is available to all threads. Scalar
// distribution is simply a broadcast from scalar to the distributed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,14 @@ void DistributionPattern::replaceOpWithDistributedValues(
for (auto [opResult, replacement] :
llvm::zip_equal(op->getOpResults(), values)) {
// If this value is a vector type, it must be converted back to simd.
if (auto replacementType = dyn_cast<VectorType>(replacement.getType())) {
if (replacementType.getRank() != 0) {
auto oldResult = cast<VectorValue>(opResult);
// Create a toSIMD op to convert the value back to the simd.
rewriter.setInsertionPointAfterValue(oldResult);
Value toSIMD = rewriter.create<IREE::VectorExt::ToSIMDOp>(
oldResult.getLoc(), oldResult.getType(), replacement);
// Add to replacements.
replacement = toSIMD;
}
if (isa<VectorType>(replacement.getType())) {
auto oldResult = cast<VectorValue>(opResult);
// Create a toSIMD op to convert the value back to the simd.
rewriter.setInsertionPointAfterValue(oldResult);
Value toSIMD = rewriter.create<IREE::VectorExt::ToSIMDOp>(
oldResult.getLoc(), oldResult.getType(), replacement);
// Add to replacements.
replacement = toSIMD;
}
replacements.push_back(replacement);
}
Expand Down

0 comments on commit 15c2937

Please sign in to comment.