Skip to content

Commit

Permalink
Fix casting for arith.cmpi operands to be of same type.
Browse files Browse the repository at this point in the history
  • Loading branch information
sahas3 committed Jan 5, 2025
1 parent 6ff8070 commit a3cd406
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,15 @@ static Value wrapIndicesAroundMax(OpBuilder &b, Location loc, Value index,
Value input, int64_t dim) {
// performs the operation : index = index % maxIndex to wrap index around
// maxIndex
Value maxIndexValue = castIndexToInt64(b, loc, getDimOp(b, loc, input, dim));
Value isBeyondMaxIndices = b.create<arith::CmpIOp>(
Value maxIndexValue = getDimOp(b, loc, input, dim);
maxIndexValue =
b.createOrFold<arith::IndexCastOp>(loc, index.getType(), maxIndexValue);
Value isBeyondMaxIndices = b.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, index, maxIndexValue);
Value wrappedIndices = b.create<arith::RemSIOp>(loc, index, maxIndexValue);
return b.create<arith::SelectOp>(loc, isBeyondMaxIndices, wrappedIndices,
index);
Value wrappedIndices =
b.createOrFold<arith::RemSIOp>(loc, index, maxIndexValue);
return b.createOrFold<arith::SelectOp>(loc, isBeyondMaxIndices,
wrappedIndices, index);
}

namespace {
Expand Down

0 comments on commit a3cd406

Please sign in to comment.