Skip to content

Commit

Permalink
Type unstable custom rule tape (#1755)
Browse files Browse the repository at this point in the history
* Type unstable custom rule tape

* fix

* fix

* Specialize tape type further
  • Loading branch information
wsmoses authored Aug 26, 2024
1 parent 32b7aa2 commit c1e98c9
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 6 deletions.
46 changes: 46 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,52 @@ function emit_jl!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value
call!(B, FT, fn, [val])
end

function emit_getfield!(B::LLVM.IRBuilder, val::LLVM.Value, fld::LLVM.Value)::LLVM.Value
curent_bb = position(B)
fn = LLVM.parent(curent_bb)
mod = LLVM.parent(fn)

T_jlvalue = LLVM.StructType(LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
T_pprjlvalue = LLVM.PointerType(T_prjlvalue)
T_int32 = LLVM.Int32Type()

gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32])
inv, _ = get_function!(mod, "jl_f_getfield", gen_FT)

args = [val, fld]

@static if VERSION < v"1.9.0-"
FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue]; vararg=true)
inv = bitcast!(B, inv, LLVM.PointerType(FT))
res = call!(B, FT, inv, args)
LLVM.callconv!(res, 37)
else
julia_call, FT = get_function!(mod, "julia.call",
LLVM.FunctionType(T_prjlvalue,
[LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true))
res = call!(B, FT, julia_call, LLVM.Value[inv, args...])
end
return res
end


function emit_nthfield!(B::LLVM.IRBuilder, val::LLVM.Value, fld::LLVM.Value)::LLVM.Value
curent_bb = position(B)
fn = LLVM.parent(curent_bb)
mod = LLVM.parent(fn)

T_jlvalue = LLVM.StructType(LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
T_size_t = convert(LLVM.LLVMType, Int)

gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_size_t])
inv, _ = get_function!(mod, "jl_get_nth_field_checked", gen_FT)

args = [val, fld]
call!(B, gen_FT, inv, args)
end

function emit_jl_throw!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value
curent_bb = position(B)
fn = LLVM.parent(curent_bb)
Expand Down
53 changes: 47 additions & 6 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,20 @@ end

if (aug_RT <: EnzymeRules.AugmentedReturn || aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && !(aug_RT isa UnionAll) && !(aug_RT isa Union) && !(aug_RT === Union{})
TapeT = EnzymeRules.tape_type(aug_RT)
elseif (aug_RT isa UnionAll) && (aug_RT <: EnzymeRules.AugmentedReturn) && aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturn.body.body.body.name
if aug_RT.body.parameters[3] isa TypeVar
TapeT = aug_RT.body.parameters[3].ub
else
TapeT = Any
end
elseif (aug_RT isa UnionAll) && (aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturnFlexShadow.body.body.body.name
if aug_RT.body.parameters[3] isa TypeVar
TapeT = aug_RT.body.parameters[3].ub
else
TapeT = Any
end
else
TapeT = Any
end

mod = LLVM.parent(LLVM.parent(LLVM.parent(orig)))
Expand Down Expand Up @@ -778,6 +792,12 @@ end

miRT = enzyme_custom_extract_mi(llvmf)[2]
_, sret, returnRoots = get_return_info(miRT)
sret_union = is_sret_union(miRT)

if sret_union
emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " had a union sret of type "*string(miRT)*" which is not currently supported")
return tapeV
end

if !forward
funcTy = rev_TT.parameters[isKWCall ? 4 : 2]
Expand Down Expand Up @@ -960,16 +980,33 @@ end
ST = EnzymeRules.AugmentedReturnFlexShadow{needsPrimal ? RealRt : Nothing, needsShadowJL ? EnzymeRules.shadow_type(aug_RT) : Nothing, TapeT}
end
end
abstract = false
if aug_RT != ST
ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Any}
emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " return type mismatch, expected "*string(ST)*" found "* string(aug_RT))
return tapeV
abs = (EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, T} where T)
if aug_RT <: abs
abstract = true
else
ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Any}
emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " return type mismatch, expected "*string(ST)*" found "* string(aug_RT))
return tapeV
end
end

resV = if abstract
StructTy = convert(LLVMType, EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Nothing})
if StructTy != LLVM.VoidType()
load!(B, StructTy, bitcast!(B, res, LLVM.PointerType(StructTy, addrspace(value_type(res)))))
else
res
end
else
res
end

idx = 0
if needsPrimal
@assert !isghostty(RealRt)
normalV = extract_value!(B, res, idx)
normalV = extract_value!(B, resV, idx)
if get_return_info(RealRt)[2] !== nothing
val = new_from_original(gutils, operands(orig)[1])
store!(B, normalV, val)
Expand All @@ -982,7 +1019,7 @@ end
if needsShadow
if needsShadowJL
@assert !isghostty(RealRt)
shadowV = extract_value!(B, res, idx)
shadowV = extract_value!(B, resV, idx)
if get_return_info(RealRt)[2] !== nothing
dval = invert_pointer(gutils, operands(orig)[1], B)

Expand All @@ -1002,7 +1039,11 @@ end
end
end
if needsTape
tapeV = extract_value!(B, res, idx).ref
tapeV = if abstract
emit_nthfield!(B, res, LLVM.ConstantInt(2)).ref
else
extract_value!(B, res, idx).ref
end
idx+=1
end
else
Expand Down
33 changes: 33 additions & 0 deletions test/rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,5 +378,38 @@ end
@test dvals [0., 0., 46.7, 0.]
end

unstabletape(x) = x^2

function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(unstabletape)}, ::Type{<:Active}, x::Active)
tape = if x.val < 3
400
else
(x.val +7 ) * 10
end
if needs_primal(config)
return AugmentedReturn{eltype(x), Nothing, typeof(tape)}(func.val(x.val), nothing, tape)
else
return AugmentedReturn{Nothing, Nothing, typeof(tape)}(nothing, nothing, tape)
end
end

function reverse(config::ConfigWidth{1}, ::Const{typeof(unstabletape)}, dret, tape, x::Active{T}) where T
return (T(tape)::T,)
end

unstabletapesq(x) = unstabletape(x)^2

@testset "Unstable Tape" begin
@test Enzyme.autodiff(Enzyme.Reverse, unstabletape, Active(2.0))[1][1] 400.0
@test Enzyme.autodiff(Enzyme.ReverseWithPrimal, unstabletape, Active(2.0))[1][1] 400.0
@test Enzyme.autodiff(Enzyme.Reverse, unstabletape, Active(5.0))[1][1] (5.0 + 7) * 10
@test Enzyme.autodiff(Enzyme.ReverseWithPrimal, unstabletape, Active(5.0))[1][1] (5.0 + 7) * 10

@test Enzyme.autodiff(Enzyme.Reverse, unstabletapesq, Active(2.0))[1][1] (400.0)
@test Enzyme.autodiff(Enzyme.ReverseWithPrimal, unstabletapesq, Active(2.0))[1][1] (400.0)
@test Enzyme.autodiff(Enzyme.Reverse, unstabletapesq, Active(5.0))[1][1] ((5.0 + 7) * 10)
@test Enzyme.autodiff(Enzyme.ReverseWithPrimal, unstabletapesq, Active(5.0))[1][1] ((5.0 + 7) * 10)
end

include("mixedrrule.jl")
end # ReverseRules

2 comments on commit c1e98c9

@wsmoses
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/113839

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.12.33 -m "<description of version>" c1e98c9ebb2921a78fd69ecab104b4a3cae1efb3
git push origin v0.12.33

Please sign in to comment.