Skip to content

Commit

Permalink
[CIR][ABI][Lowering] Fixes calling convention (#1308)
Browse files Browse the repository at this point in the history
This PR fixes two run time bugs in the calling convention pass. These
bugs were found with `csmith`.

Case #1.  Return value from a function. 
Before this PR, the returned value was stored through a bitcast of the
destination memory location.
But in the following example this is not safe: the size of the memory slot
is less than the size of the returned value, so the store operation causes
a segfault!
```
#pragma pack(push)
#pragma pack(1)
 typedef struct {
    int f0 : 18;
    int f1 : 31;
    int f2 : 5;
    int f3 : 29;
    int f4 : 24;
 } PackedS;
#pragma pack(pop)
```
CIR type for this struct is `!ty_PackedS1_ = !cir.struct<struct
"PackedS1" {!cir.array<!u8i x 14>}>`, i.e. it occupies 14 bytes.
Before this PR, the following code
```
PackedS foo(void) {
     PackedS s;
     return s;
}

 void check(void) {     
     PackedS y = foo();    
}
```
produced the following CIR:
```
  %0 = cir.alloca !ty_PackedS1_, !cir.ptr<!ty_PackedS1_>, ["y", init] {alignment = 1 : i64}
  %1 = cir.call @foo() : () -> !cir.array<!u64i x 2> 
  %2 = cir.cast(bitcast, %0 : !cir.ptr<!ty_PackedS1_>), !cir.ptr<!cir.array<!u64i x 2>>
  cir.store %1, %2 : !cir.array<!u64i x 2>, !cir.ptr<!cir.array<!u64i x 2>>
```
As one can see, `%1` is an array of two 64-bit integers (16 bytes), but memory
was allocated for only 14 bytes (the size of the struct). Hence the segfault!
This PR fixes such cases: the coercion now goes through memory,
which matches the OG (original CodeGen) behavior.


Case #2.  Passing an argument from a pointer deref.
Previously, for struct types passed by value, we tried to find the alloca
instruction in order to use it as the source for the memcpy operation. But if
we have a pointer dereference (in other words, if the alloca result is a
`<!cir.ptr < !cir.ptr ... > >`), we don't need to search for the
address of the location where this pointer is stored - instead we're
interested in the pointer itself. And this is the general approach: instead
of trying to find an alloca instruction, we need to find the first pointer
on the way - that is the address we need to use as the memcpy
source.

I combined these two cases into a single PR since there are only a few
changes actually. But I can split it in two if you'd prefer.
  • Loading branch information
gitoleg authored Feb 5, 2025
1 parent fee4bb6 commit 5373f42
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 19 deletions.
37 changes: 25 additions & 12 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ cir::AllocaOp findAlloca(mlir::Operation *op) {
return {};
}

/// Find the address a value originates from.
///
/// Starting at \p v, follow operand 0 of each defining operation until a
/// value of pointer type is reached, and return that value. If \p v is
/// already a pointer it is returned unchanged.
///
/// \returns the first pointer-typed value found on the way, or a null
/// mlir::Value when the walk reaches a value with no defining operation or
/// one defined by an operation other than the kinds listed below.
mlir::Value findAddr(mlir::Value v) {
  mlir::Value current = v;

  while (!mlir::isa<cir::PointerType>(current.getType())) {
    mlir::Operation *def = current.getDefiningOp();

    // Only look through the listed operation kinds; anything else ends the
    // search without a result.
    // NOTE(review): cir::ReturnOp presumably never defines a value; it is
    // kept here so the accepted set stays exactly as before - confirm.
    const bool canLookThrough =
        def && mlir::isa<cir::CastOp, cir::LoadOp, cir::ReturnOp>(def);
    if (!canLookThrough)
      return {};

    current = def->getOperand(0);
  }

  return current;
}

/// Create a store to \param Dst from \param Src where the source and
/// destination may have different types.
///
Expand Down Expand Up @@ -338,10 +349,10 @@ mlir::Value createCoercedValue(mlir::Value Src, mlir::Type Ty,
return CGF.buildAggregateBitcast(Src, Ty);
}

if (auto alloca = findAlloca(Src.getDefiningOp())) {
auto tmpAlloca = createTmpAlloca(CGF, alloca.getLoc(), Ty);
createMemCpy(CGF, tmpAlloca, alloca, SrcSize.getFixedValue());
return CGF.getRewriter().create<LoadOp>(alloca.getLoc(),
if (mlir::Value addr = findAddr(Src)) {
auto tmpAlloca = createTmpAlloca(CGF, addr.getLoc(), Ty);
createMemCpy(CGF, tmpAlloca, addr, SrcSize.getFixedValue());
return CGF.getRewriter().create<LoadOp>(addr.getLoc(),
tmpAlloca.getResult());
}

Expand Down Expand Up @@ -371,7 +382,6 @@ mlir::Value createCoercedNonPrimitive(mlir::Value src, mlir::Type ty,

auto tySize = LF.LM.getDataLayout().getTypeStoreSize(ty);
createMemCpy(LF, alloca, addr, tySize.getFixedValue());

auto newLoad = bld.create<LoadOp>(src.getLoc(), alloca.getResult());
bld.replaceAllOpUsesWith(load, newLoad);

Expand Down Expand Up @@ -1265,6 +1275,14 @@ mlir::Value LowerFunction::rewriteCallOp(const LowerFunctionInfo &CallInfo,

// FIXME(cir): Use return value slot here.
mlir::Value RetVal = callOp.getResult();
mlir::Value dstPtr;
for (auto *user : Caller->getUsers()) {
if (auto storeOp = mlir::dyn_cast<StoreOp>(user)) {
assert(!dstPtr && "multiple destinations for the return value");
dstPtr = storeOp.getAddr();
}
}

// TODO(cir): Check for volatile return values.
cir_cconv_assert(!cir::MissingFeatures::volatileTypes());

Expand All @@ -1283,16 +1301,11 @@ mlir::Value LowerFunction::rewriteCallOp(const LowerFunctionInfo &CallInfo,
if (mlir::dyn_cast<StructType>(RetTy) &&
mlir::cast<StructType>(RetTy).getNumElements() != 0) {
RetVal = newCallOp.getResult();
createCoercedStore(RetVal, dstPtr, false, *this);

llvm::SmallVector<StoreOp, 8> workList;
for (auto *user : Caller->getUsers())
if (auto storeOp = mlir::dyn_cast<StoreOp>(user))
workList.push_back(storeOp);
for (StoreOp storeOp : workList) {
auto destPtr =
createCoercedBitcast(storeOp.getAddr(), RetVal.getType(), *this);
rewriter.replaceOpWithNewOp<StoreOp>(storeOp, RetVal, destPtr);
}
rewriter.eraseOp(storeOp);
}

// NOTE(cir): No need to convert from a temp to an RValue. This is
Expand Down
94 changes: 87 additions & 7 deletions clang/test/CIR/CallConvLowering/AArch64/aarch64-cc-structs.c
Original file line number Diff line number Diff line change
Expand Up @@ -302,19 +302,18 @@ void pass_nested_u(NESTED_U a) {}

// CHECK: cir.func no_proto @call_nested_u()
// CHECK: %[[#V0:]] = cir.alloca !ty_NESTED_U, !cir.ptr<!ty_NESTED_U>
// CHECK: %[[#V1:]] = cir.alloca !u64i, !cir.ptr<!u64i>, ["tmp"] {alignment = 8 : i64}
// CHECK: %[[#V1:]] = cir.alloca !u64i, !cir.ptr<!u64i>, ["tmp"]
// CHECK: %[[#V2:]] = cir.load %[[#V0]] : !cir.ptr<!ty_NESTED_U>, !ty_NESTED_U
// CHECK: %[[#V3:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_NESTED_U>)
// CHECK: %[[#V4:]] = cir.load %[[#V3]]
// CHECK: %[[#V5:]] = cir.cast(bitcast, %[[#V3]]
// CHECK: %[[#V6:]] = cir.load %[[#V5]]
// CHECK: %[[#V7:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_NESTED_U>), !cir.ptr<!void>
// CHECK: %[[#V3:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_NESTED_U>), !cir.ptr<!ty_anon2E0_>
// CHECK: %[[#V4:]] = cir.load %[[#V3]] : !cir.ptr<!ty_anon2E0_>, !ty_anon2E0_
// CHECK: %[[#V5:]] = cir.cast(bitcast, %[[#V3]] : !cir.ptr<!ty_anon2E0_>), !cir.ptr<!ty_anon2E1_>
// CHECK: %[[#V6:]] = cir.load %[[#V5]] : !cir.ptr<!ty_anon2E1_>, !ty_anon2E1_
// CHECK: %[[#V7:]] = cir.cast(bitcast, %[[#V5]] : !cir.ptr<!ty_anon2E1_>), !cir.ptr<!void>
// CHECK: %[[#V8:]] = cir.cast(bitcast, %[[#V1]] : !cir.ptr<!u64i>), !cir.ptr<!void>
// CHECK: %[[#V9:]] = cir.const #cir.int<2> : !u64i
// CHECK: cir.libc.memcpy %[[#V9]] bytes from %[[#V7]] to %[[#V8]] : !u64i, !cir.ptr<!void> -> !cir.ptr<!void>
// CHECK: %[[#V10:]] = cir.load %[[#V1]] : !cir.ptr<!u64i>, !u64i
// CHECK: cir.call @pass_nested_u(%[[#V10]]) : (!u64i) -> ()
// CHECK: cir.return

// LLVM: void @call_nested_u()
// LLVM: %[[#V1:]] = alloca %struct.NESTED_U, i64 1, align 1
Expand All @@ -330,3 +329,84 @@ void call_nested_u() {
NESTED_U a;
pass_nested_u(a);
}


// Byte-packed struct made only of bit-fields: 18 + 31 + 5 + 29 + 24 = 107
// bits of fields. The checks for bar() below expect a 14-byte memcpy for this
// type, while foo() is lowered to return two 64-bit integers, so the struct
// is smaller than the value it is returned as.
#pragma pack(push)
#pragma pack(1)
typedef struct {
int f0 : 18;
int f1 : 31;
int f2 : 5;
int f3 : 29;
int f4 : 24;
} PackedS1;
#pragma pack(pop)

// Returns a PackedS1 by value. `s` is never initialized - presumably only the
// lowered return type matters to this test, not the returned contents.
PackedS1 foo(void) {
PackedS1 s;
return s;
}

// Caller side of the return-value case: the result of foo() initializes `y`.
// The checks below expect the call result to be stored into a two-element
// u64 "tmp" slot first and then copied into `y` with a 14-byte memcpy.
void bar(void) {
PackedS1 y = foo();
}

// CHECK: cir.func @bar
// CHECK: %[[#V0:]] = cir.alloca !ty_PackedS1_, !cir.ptr<!ty_PackedS1_>, ["y", init]
// CHECK: %[[#V1:]] = cir.alloca !cir.array<!u64i x 2>, !cir.ptr<!cir.array<!u64i x 2>>, ["tmp"]
// CHECK: %[[#V2:]] = cir.call @foo() : () -> !cir.array<!u64i x 2>
// CHECK: cir.store %[[#V2]], %[[#V1]] : !cir.array<!u64i x 2>, !cir.ptr<!cir.array<!u64i x 2>>
// CHECK: %[[#V3:]] = cir.cast(bitcast, %[[#V1]] : !cir.ptr<!cir.array<!u64i x 2>>), !cir.ptr<!void>
// CHECK: %[[#V4:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_PackedS1_>), !cir.ptr<!void>
// CHECK: %[[#V5:]] = cir.const #cir.int<14> : !u64i
// CHECK: cir.libc.memcpy %[[#V5]] bytes from %[[#V3]] to %[[#V4]] : !u64i, !cir.ptr<!void> -> !cir.ptr<!void>

// LLVM: void @bar
// LLVM: %[[#V1:]] = alloca %struct.PackedS1, i64 1, align 1
// LLVM: %[[#V2:]] = alloca [2 x i64], i64 1, align 8
// LLVM: %[[#V3:]] = call [2 x i64] @foo()
// LLVM: store [2 x i64] %[[#V3]], ptr %[[#V2]], align 8
// LLVM: call void @llvm.memcpy.p0.p0.i64(ptr %[[#V1]], ptr %[[#V2]], i64 14, i1 false)


// Byte-packed struct of a short followed by an int. The checks for qux()
// below expect a 6-byte memcpy when a value of this type is passed.
#pragma pack(push)
#pragma pack(1)
typedef struct {
short f0;
int f1;
} PackedS2;
#pragma pack(pop)

// Global array whose elements baz() and qux() take the addresses of.
PackedS2 g[3] = {{1,2},{3,4},{5,6}};

// Callee that takes a PackedS2 by value and writes a.f0 into g[2].f0 through
// a pointer.
void baz(PackedS2 a) {
short *x = &g[2].f0;
(*x) = a.f0;
}

// Passes a dereferenced pointer (*s1, pointing at g[1]) by value to baz().
// The checks below expect the memcpy source to be the pointer loaded from
// `s1`, not the address of the `s1` slot itself.
void qux(void) {
const PackedS2 *s1 = &g[1];
baz(*s1);
}

// check source of memcpy
// CHECK: cir.func @qux
// CHECK: %[[#V0:]] = cir.alloca !cir.ptr<!ty_PackedS2_>, !cir.ptr<!cir.ptr<!ty_PackedS2_>>, ["s1", init]
// CHECK: %[[#V1:]] = cir.alloca !u64i, !cir.ptr<!u64i>, ["tmp"]
// CHECK: %[[#V2:]] = cir.get_global @g : !cir.ptr<!cir.array<!ty_PackedS2_ x 3>>
// CHECK: %[[#V3:]] = cir.const #cir.int<1> : !s32i
// CHECK: %[[#V4:]] = cir.cast(array_to_ptrdecay, %[[#V2]] : !cir.ptr<!cir.array<!ty_PackedS2_ x 3>>), !cir.ptr<!ty_PackedS2_>
// CHECK: %[[#V5:]] = cir.ptr_stride(%[[#V4]] : !cir.ptr<!ty_PackedS2_>, %[[#V3]] : !s32i), !cir.ptr<!ty_PackedS2_>
// CHECK: cir.store %[[#V5]], %[[#V0]] : !cir.ptr<!ty_PackedS2_>, !cir.ptr<!cir.ptr<!ty_PackedS2_>>
// CHECK: %[[#V6:]] = cir.load deref %[[#V0]] : !cir.ptr<!cir.ptr<!ty_PackedS2_>>, !cir.ptr<!ty_PackedS2_>
// CHECK: %[[#V7:]] = cir.cast(bitcast, %[[#V6]] : !cir.ptr<!ty_PackedS2_>), !cir.ptr<!void>
// CHECK: %[[#V8:]] = cir.const #cir.int<6> : !u64i
// CHECK: cir.libc.memcpy %[[#V8]] bytes from %[[#V7]]

// LLVM: void @qux
// LLVM: %[[#V1:]] = alloca ptr, i64 1, align 8
// LLVM: %[[#V2:]] = alloca i64, i64 1, align 8
// LLVM: store ptr getelementptr (%struct.PackedS2, ptr @g, i64 1), ptr %[[#V1]], align 8
// LLVM: %[[#V3:]] = load ptr, ptr %[[#V1]], align 8
// LLVM: %[[#V4:]] = load %struct.PackedS2, ptr %[[#V3]], align 1
// LLVM: call void @llvm.memcpy.p0.p0.i64(ptr %[[#V2]], ptr %[[#V3]], i64 6, i1 false)

0 comments on commit 5373f42

Please sign in to comment.