Skip to content

Commit

Permalink
Julia 1.11: the adventure continues (#1966)
Browse files Browse the repository at this point in the history
* Julia 1.11: the adventure continues

* more fixups of gc

* fixup

* fixup

* mem

* fix

* evoice

* fix

* improve del

* fix

* assertion

* fix

* around

* fix

* fix

* fix
  • Loading branch information
wsmoses authored Oct 16, 2024
1 parent 5172d77 commit b80735e
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 26 deletions.
9 changes: 7 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2965,8 +2965,13 @@ function zero_allocation(
push!(function_attributes(wrapper_f), StringAttribute("enzyme_no_escaping_allocation"))
push!(function_attributes(wrapper_f), EnumAttribute("alwaysinline", 0))
push!(function_attributes(wrapper_f), EnumAttribute("nofree", 0))
push!(function_attributes(wrapper_f), EnumAttribute("argmemonly", 0))
push!(function_attributes(wrapper_f), EnumAttribute("writeonly", 0))

if LLVM.version().major <= 15
push!(function_attributes(wrapper_f), EnumAttribute("argmemonly", 0))
push!(function_attributes(wrapper_f), EnumAttribute("writeonly", 0))
else
push!(function_attributes(wrapper_f), EnumAttribute("memory", WriteOnlyArgMemEffects.data))
end
push!(function_attributes(wrapper_f), EnumAttribute("willreturn", 0))
if LLVM.version().major >= 12
push!(function_attributes(wrapper_f), EnumAttribute("mustprogress", 0))
Expand Down
54 changes: 49 additions & 5 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,20 @@ function nodecayed_phis!(mod::LLVM.Module)
end
if addr == 13 && !hasload
if isa(v, LLVM.LoadInst)
return getparent(operands(v)[1], offset, true)
v2, o2, hl2 = getparent(operands(v)[1], LLVM.ConstantInt(offty, 0), true)
@assert o2 == LLVM.ConstantInt(offty, 0)
return v2, offset, true
end
if isa(v, LLVM.CallInst)
cf = LLVM.called_operand(v)
if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded"
ld = operands(v)[2]
if isa(ld, LLVM.LoadInst)
v2, o2, hl2 = getparent(operands(ld)[1], LLVM.ConstantInt(offty, 0), true)
@assert o2 == LLVM.ConstantInt(offty, sizeof(Int))
return v2, offset, true
end
end
end
end

Expand Down Expand Up @@ -894,7 +907,7 @@ function nodecayed_phis!(mod::LLVM.Module)
return v2, offset, skipload
end

if isa(v, LLVM.GetElementPtrInst) && !hasload
if isa(v, LLVM.GetElementPtrInst)
v2, offset, skipload =
getparent(operands(v)[1], offset, hasload)
offset = nuwadd!(
Expand Down Expand Up @@ -1035,9 +1048,40 @@ function nodecayed_phis!(mod::LLVM.Module)

position!(nb, nonphi)
if addr == 13
nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10))
nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11))
nphi = load!(nb, ty, nphi)
@static if VERSION < v"1.11-"
nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10))
nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11))
nphi = load!(nb, ty, nphi)
else
base_obj = nphi

# %value_phi11 = phi {} addrspace(10)* [ %55, %L78 ], [ %54, %L76 ]

# %.phi.trans.insert77 = bitcast {} addrspace(10)* %value_phi11 to { i64, {} addrspace(10)** } addrspace(10)*
# %.phi.trans.insert78 = addrspacecast { i64, {} addrspace(10)** } addrspace(10)* %.phi.trans.insert77 to { i64, {} addrspace(10)** } addrspace(11)*
# %.phi.trans.insert79 = getelementptr inbounds { i64, {} addrspace(10)** }, { i64, {} addrspace(10)** } addrspace(11)* %.phi.trans.insert78, i64 0, i32 1
# %.pre80 = load {} addrspace(10)**, {} addrspace(10)** addrspace(11)* %.phi.trans.insert79, align 8, !dbg !532, !tbaa !19, !alias.scope !26, !noalias !29

# %154 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %value_phi11, {} addrspace(10)** %.pre80), !dbg !532

jlt = LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), 10)
pjlt = LLVM.PointerType(jlt)
gent = LLVM.StructType([convert(LLVMType, Int), pjlt])
pgent = LLVM.PointerType(LLVM.StructType([convert(LLVMType, Int), pjlt]), 10)

nphi = bitcast!(nb, nphi, pgent)
nphi = addrspacecast!(nb, nphi, LLVM.PointerType(gent, 11))
nphi = inbounds_gep!(nb, gent, nphi, [LLVM.ConstantInt(Int64(0)), LLVM.ConstantInt(Int32(1))])
nphi = load!(nb, pjlt, nphi)

GTy = LLVM.FunctionType(LLVM.PointerType(jlt, 13), LLVM.LLVMType[jlt, pjlt])
gcloaded, _ = get_function!(
mod,
"julia.gc_loaded",
GTy
)
nphi = call!(nb, GTy, gcloaded, LLVM.Value[base_obj, nphi])
end
else
nphi = addrspacecast!(nb, nphi, ty)
end
Expand Down
5 changes: 5 additions & 0 deletions src/compiler/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ const ReadOnlyArgMemEffects = MemoryEffect(
(MRI_NoModRef << getLocationPos(InaccessibleMem)) |
(MRI_NoModRef << getLocationPos(Other)),
)
const WriteOnlyArgMemEffects = MemoryEffect(
(MRI_Mod << getLocationPos(ArgMem)) |
(MRI_NoModRef << getLocationPos(InaccessibleMem)) |
(MRI_NoModRef << getLocationPos(Other)),
)
const NoEffects = MemoryEffect(
(MRI_NoModRef << getLocationPos(ArgMem)) |
(MRI_NoModRef << getLocationPos(InaccessibleMem)) |
Expand Down
150 changes: 131 additions & 19 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -373,38 +373,150 @@ function check_ir!(job, errors, mod::LLVM.Module)
eraseInst(mod, f)
end
rewrite_ccalls!(mod)

del = LLVM.Function[]
for f in collect(functions(mod))
check_ir!(job, errors, imported, f)
if in(f, del)
continue
end
check_ir!(job, errors, imported, f, del)
end
for d in del
LLVM.API.LLVMDeleteFunction(d)
end

del = LLVM.Function[]
for f in collect(functions(mod))
check_ir!(job, errors, imported, f)
if in(f, del)
continue
end
check_ir!(job, errors, imported, f, del)
end
for d in del
LLVM.API.LLVMDeleteFunction(d)
end

return errors
end

function check_ir!(job, errors, imported, f::LLVM.Function)

function unwrap_ptr_casts(val::LLVM.Value)
while true
is_simple_cast = false
is_simple_cast |= isa(val, LLVM.BitCastInst)
is_simple_cast |= isa(val, LLVM.AddrSpaceCastInst) || isa(val, LLVM.PtrToIntInst)
is_simple_cast |= isa(val, LLVM.ConstantExpr) && opcode(val) == LLVM.API.LLVMAddrSpaceCast
is_simple_cast |= isa(val, LLVM.ConstantExpr) && opcode(val) == LLVM.API.LLVMIntToPtr
is_simple_cast |= isa(val, LLVM.ConstantExpr) && opcode(val) == LLVM.API.LLVMBitCast

if !is_simple_cast
return val
else
val = operands(val)[1]
end
end
end

function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns)
calls = []
isInline = API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0
for bb in blocks(f), inst in instructions(bb)
mod = LLVM.parent(f)
for bb in blocks(f), inst in collect(instructions(bb))
if isa(inst, LLVM.CallInst)
push!(calls, inst)
# remove illegal invariant.load and jtbaa_const invariants
elseif isInline && isa(inst, LLVM.LoadInst)
md = metadata(inst)
if haskey(md, LLVM.MD_tbaa)
modified = LLVM.Metadata(
ccall(
(:EnzymeMakeNonConstTBAA, API.libEnzyme),
LLVM.API.LLVMMetadataRef,
(LLVM.API.LLVMMetadataRef,),
md[LLVM.MD_tbaa],
),
)
setindex!(md, modified, LLVM.MD_tbaa)
end
if haskey(md, LLVM.MD_invariant_load)
delete!(md, LLVM.MD_invariant_load)
elseif isa(inst, LLVM.LoadInst)

fn_got = unwrap_ptr_casts(operands(inst)[1])
fname = String(name(fn_got))
match_ = match(r"^jlplt_(.*)_\d+_got$", fname)

if match_ !== nothing
fname = match_[1]
FT = nothing
todo = LLVM.Instruction[inst]
while length(todo) != 0
v = pop!(todo)
for u in LLVM.uses(v)
u = LLVM.user(u)
if isa(u, LLVM.CallInst)
FT = called_type(u)
break
end
if isa(u, LLVM.BitCastInst)
push!(todo, u)
continue
end
end
if FT !== nothing
break
end
end
@assert FT !== nothing
newf, _ = get_function!(mod, String(fname), FT)

initfn = unwrap_ptr_casts(LLVM.initializer(fn_got))
loadfn = first(instructions(first(blocks(initfn))))::LLVM.LoadInst
opv = operands(loadfn)[1]::LLVM.GlobalVariable

if startswith(fname, "jl_") || startswith(fname, "ijl_")
else
@assert "unsupported jl got"
msg = sprint() do io::IO
println(
io,
"Enzyme internal error unsupported got",
)
println(io, "inst=", inst)
println(io, "fname=", fname)
println(io, "FT=", FT)
println(io, "fn_got=", fn_got)
println(io, "init=", string(initfn))
println(io, "opv=", string(opv))
end
throw(AssertionError(msg))
end

if value_type(newf) != value_type(inst)
newf = const_pointercast(newf, value_type(inst))
end
replace_uses!(inst, newf)
LLVM.API.LLVMInstructionEraseFromParent(inst)

baduse = false
for u in LLVM.uses(fn_got)
u = LLVM.user(u)
if isa(u, LLVM.StoreInst)
continue
end
baduse = true
end

if !baduse
push!(deletedfns, initfn)
LLVM.initializer!(fn_got, LLVM.null(value_type(LLVM.initializer(fn_got))))
replace_uses!(opv, LLVM.null(value_type(opv)))
LLVM.API.LLVMDeleteGlobal(opv)
replace_uses!(fn_got, LLVM.null(value_type(fn_got)))
LLVM.API.LLVMDeleteGlobal(fn_got)
end

elseif isInline
md = metadata(inst)
if haskey(md, LLVM.MD_tbaa)
modified = LLVM.Metadata(
ccall(
(:EnzymeMakeNonConstTBAA, API.libEnzyme),
LLVM.API.LLVMMetadataRef,
(LLVM.API.LLVMMetadataRef,),
md[LLVM.MD_tbaa],
),
)
setindex!(md, modified, LLVM.MD_tbaa)
end
if haskey(md, LLVM.MD_invariant_load)
delete!(md, LLVM.MD_invariant_load)
end
end
end
end
Expand Down

2 comments on commit b80735e

@wsmoses
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/117367

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.13.10 -m "<description of version>" b80735e9437180b1c5952582fcca06acb62a15d4
git push origin v0.13.10

Please sign in to comment.