diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index 96cc32f4ba59..cede292ec1c3 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -421,12 +421,15 @@ static Value wrapIndicesAroundMax(OpBuilder &b, Location loc, Value index, Value input, int64_t dim) { // performs the operation : index = index % maxIndex to wrap index around // maxIndex - Value maxIndexValue = castIndexToInt64(b, loc, getDimOp(b, loc, input, dim)); - Value isBeyondMaxIndices = b.create( + Value maxIndexValue = getDimOp(b, loc, input, dim); + maxIndexValue = + b.createOrFold(loc, index.getType(), maxIndexValue); + Value isBeyondMaxIndices = b.createOrFold( loc, arith::CmpIPredicate::sge, index, maxIndexValue); - Value wrappedIndices = b.create(loc, index, maxIndexValue); - return b.create(loc, isBeyondMaxIndices, wrappedIndices, - index); + Value wrappedIndices = + b.createOrFold(loc, index, maxIndexValue); + return b.createOrFold(loc, isBeyondMaxIndices, + wrappedIndices, index); } namespace {