From b80735e9437180b1c5952582fcca06acb62a15d4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 16 Oct 2024 01:42:20 -0500 Subject: [PATCH] Julia 1.11: the adventure continues (#1966) * Julia 1.11: the adventure continues * more fixups of gc * fixup * fixup * mem * fix * evoice * fix * improve del * fix * assertion * fix * around * fix * fix * fix --- src/compiler.jl | 9 ++- src/compiler/optimize.jl | 54 +++++++++++-- src/compiler/utils.jl | 5 ++ src/compiler/validation.jl | 150 ++++++++++++++++++++++++++++++++----- 4 files changed, 192 insertions(+), 26 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 049d96285d..710725dcca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2965,8 +2965,13 @@ function zero_allocation( push!(function_attributes(wrapper_f), StringAttribute("enzyme_no_escaping_allocation")) push!(function_attributes(wrapper_f), EnumAttribute("alwaysinline", 0)) push!(function_attributes(wrapper_f), EnumAttribute("nofree", 0)) - push!(function_attributes(wrapper_f), EnumAttribute("argmemonly", 0)) - push!(function_attributes(wrapper_f), EnumAttribute("writeonly", 0)) + + if LLVM.version().major <= 15 + push!(function_attributes(wrapper_f), EnumAttribute("argmemonly", 0)) + push!(function_attributes(wrapper_f), EnumAttribute("writeonly", 0)) + else + push!(function_attributes(wrapper_f), EnumAttribute("memory", WriteOnlyArgMemEffects.data)) + end push!(function_attributes(wrapper_f), EnumAttribute("willreturn", 0)) if LLVM.version().major >= 12 push!(function_attributes(wrapper_f), EnumAttribute("mustprogress", 0)) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index eccc5789dd..7214aa540f 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -793,7 +793,20 @@ function nodecayed_phis!(mod::LLVM.Module) end if addr == 13 && !hasload if isa(v, LLVM.LoadInst) - return getparent(operands(v)[1], offset, true) + v2, o2, hl2 = getparent(operands(v)[1], LLVM.ConstantInt(offty, 0), true) + @assert o2 == LLVM.ConstantInt(offty, 0) + return v2, offset, true + end + if isa(v, LLVM.CallInst) + cf = LLVM.called_operand(v) + if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded" + ld = operands(v)[2] + if isa(ld, LLVM.LoadInst) + v2, o2, hl2 = getparent(operands(ld)[1], LLVM.ConstantInt(offty, 0), true) + @assert o2 == LLVM.ConstantInt(offty, sizeof(Int)) + return v2, offset, true + end + end end end @@ -894,7 +907,7 @@ function nodecayed_phis!(mod::LLVM.Module) return v2, offset, skipload end - if isa(v, LLVM.GetElementPtrInst) && !hasload + if isa(v, LLVM.GetElementPtrInst) v2, offset, skipload = getparent(operands(v)[1], offset, hasload) offset = nuwadd!( @@ -1035,9 +1048,40 @@ function nodecayed_phis!(mod::LLVM.Module) position!(nb, nonphi) if addr == 13 - nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) - nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11)) - nphi = load!(nb, ty, nphi) + @static if VERSION < v"1.11-" + nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) + nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11)) + nphi = load!(nb, ty, nphi) + else + base_obj = nphi + + # %value_phi11 = phi {} addrspace(10)* [ %55, %L78 ], [ %54, %L76 ] + + # %.phi.trans.insert77 = bitcast {} addrspace(10)* %value_phi11 to { i64, {} addrspace(10)** } addrspace(10)* + # %.phi.trans.insert78 = addrspacecast { i64, {} addrspace(10)** } addrspace(10)* %.phi.trans.insert77 to { i64, {} addrspace(10)** } addrspace(11)* + # %.phi.trans.insert79 = getelementptr inbounds { i64, {} addrspace(10)** }, { i64, {} addrspace(10)** } addrspace(11)* %.phi.trans.insert78, i64 0, i32 1 + # %.pre80 = load {} addrspace(10)**, {} addrspace(10)** addrspace(11)* %.phi.trans.insert79, align 8, !dbg !532, !tbaa !19, !alias.scope !26, !noalias !29 + + # %154 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %value_phi11, {} addrspace(10)** %.pre80), !dbg !532 + + jlt = LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), 10) + pjlt = LLVM.PointerType(jlt) + gent = LLVM.StructType([convert(LLVMType, Int), pjlt]) + pgent = LLVM.PointerType(LLVM.StructType([convert(LLVMType, Int), pjlt]), 10) + + nphi = bitcast!(nb, nphi, pgent) + nphi = addrspacecast!(nb, nphi, LLVM.PointerType(gent, 11)) + nphi = inbounds_gep!(nb, gent, nphi, [LLVM.ConstantInt(Int64(0)), LLVM.ConstantInt(Int32(1))]) + nphi = load!(nb, pjlt, nphi) + + GTy = LLVM.FunctionType(LLVM.PointerType(jlt, 13), LLVM.LLVMType[jlt, pjlt]) + gcloaded, _ = get_function!( + mod, + "julia.gc_loaded", + GTy + ) + nphi = call!(nb, GTy, gcloaded, LLVM.Value[base_obj, nphi]) + end else nphi = addrspacecast!(nb, nphi, ty) end diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 8a801067eb..5539c5ed06 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -35,6 +35,11 @@ const ReadOnlyArgMemEffects = MemoryEffect( (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other)), ) +const WriteOnlyArgMemEffects = MemoryEffect( + (MRI_Mod << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), +) const NoEffects = MemoryEffect( (MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index b672c50f57..839aa120d7 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -373,38 +373,150 @@ function check_ir!(job, errors, mod::LLVM.Module) eraseInst(mod, f) end rewrite_ccalls!(mod) + + del = LLVM.Function[] for f in collect(functions(mod)) - check_ir!(job, errors, imported, f) + if in(f, del) + continue + end + check_ir!(job, errors, imported, f, del) + end + for d in del + LLVM.API.LLVMDeleteFunction(d) end + + del = LLVM.Function[] for f in collect(functions(mod)) - check_ir!(job, errors, imported, f) + if in(f, del) + continue + end + check_ir!(job, errors, imported, f, del) + end + for d in del + LLVM.API.LLVMDeleteFunction(d) end return errors end -function check_ir!(job, errors, imported, f::LLVM.Function) + +function unwrap_ptr_casts(val::LLVM.Value) + while true + is_simple_cast = false + is_simple_cast |= isa(val, LLVM.BitCastInst) + is_simple_cast |= isa(val, LLVM.AddrSpaceCastInst) || isa(val, LLVM.PtrToIntInst) + is_simple_cast |= isa(val, LLVM.ConstantExpr) && opcode(val) == LLVM.API.LLVMAddrSpaceCast + is_simple_cast |= isa(val, LLVM.ConstantExpr) && opcode(val) == LLVM.API.LLVMIntToPtr + is_simple_cast |= isa(val, LLVM.ConstantExpr) && opcode(val) == LLVM.API.LLVMBitCast + + if !is_simple_cast + return val + else + val = operands(val)[1] + end + end +end + +function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) calls = [] isInline = API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0 - for bb in blocks(f), inst in instructions(bb) + mod = LLVM.parent(f) + for bb in blocks(f), inst in collect(instructions(bb)) if isa(inst, LLVM.CallInst) push!(calls, inst) # remove illegal invariant.load and jtbaa_const invariants - elseif isInline && isa(inst, LLVM.LoadInst) - md = metadata(inst) - if haskey(md, LLVM.MD_tbaa) - modified = LLVM.Metadata( - ccall( - (:EnzymeMakeNonConstTBAA, API.libEnzyme), - LLVM.API.LLVMMetadataRef, - (LLVM.API.LLVMMetadataRef,), - md[LLVM.MD_tbaa], - ), - ) - setindex!(md, modified, LLVM.MD_tbaa) - end - if haskey(md, LLVM.MD_invariant_load) - delete!(md, LLVM.MD_invariant_load) + elseif isa(inst, LLVM.LoadInst) + + fn_got = unwrap_ptr_casts(operands(inst)[1]) + fname = String(name(fn_got)) + match_ = match(r"^jlplt_(.*)_\d+_got$", fname) + + if match_ !== nothing + fname = match_[1] + FT = nothing + todo = LLVM.Instruction[inst] + while length(todo) != 0 + v = pop!(todo) + for u in LLVM.uses(v) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + FT = called_type(u) + break + end + if isa(u, LLVM.BitCastInst) + push!(todo, u) + continue + end + end + if FT !== nothing + break + end + end + @assert FT !== nothing + newf, _ = get_function!(mod, String(fname), FT) + + initfn = unwrap_ptr_casts(LLVM.initializer(fn_got)) + loadfn = first(instructions(first(blocks(initfn))))::LLVM.LoadInst + opv = operands(loadfn)[1]::LLVM.GlobalVariable + + if startswith(fname, "jl_") || startswith(fname, "ijl_") + else + @assert "unsupported jl got" + msg = sprint() do io::IO + println( + io, + "Enzyme internal error unsupported got", + ) + println(io, "inst=", inst) + println(io, "fname=", fname) + println(io, "FT=", FT) + println(io, "fn_got=", fn_got) + println(io, "init=", string(initfn)) + println(io, "opv=", string(opv)) + end + throw(AssertionError(msg)) + end + + if value_type(newf) != value_type(inst) + newf = const_pointercast(newf, value_type(inst)) + end + replace_uses!(inst, newf) + LLVM.API.LLVMInstructionEraseFromParent(inst) + + baduse = false + for u in LLVM.uses(fn_got) + u = LLVM.user(u) + if isa(u, LLVM.StoreInst) + continue + end + baduse = true + end + + if !baduse + push!(deletedfns, initfn) + LLVM.initializer!(fn_got, LLVM.null(value_type(LLVM.initializer(fn_got)))) + replace_uses!(opv, LLVM.null(value_type(opv))) + LLVM.API.LLVMDeleteGlobal(opv) + replace_uses!(fn_got, LLVM.null(value_type(fn_got))) + LLVM.API.LLVMDeleteGlobal(fn_got) + end + + elseif isInline + md = metadata(inst) + if haskey(md, LLVM.MD_tbaa) + modified = LLVM.Metadata( + ccall( + (:EnzymeMakeNonConstTBAA, API.libEnzyme), + LLVM.API.LLVMMetadataRef, + (LLVM.API.LLVMMetadataRef,), + md[LLVM.MD_tbaa], + ), + ) + setindex!(md, modified, LLVM.MD_tbaa) + end + if haskey(md, LLVM.MD_invariant_load) + delete!(md, LLVM.MD_invariant_load) + end end end end