From 2b681d54234190fdcf70b2844af8cd33700ee0be Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 8 Aug 2024 12:44:09 -0700 Subject: [PATCH] Fixup more than simple jacobian (#1712) * Fixup more than simple jacobian * remove prints * work around stack * ease jacobian * fix * fix * fix * fix * fix * fix --- src/Enzyme.jl | 277 +++++++++++++++++++++++++++++++++++++++-------- src/compiler.jl | 5 + test/runtests.jl | 125 +++++++++++++++++++++ 3 files changed, 362 insertions(+), 45 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 5ac12d1fe0..5bdecbcceb 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1002,6 +1002,10 @@ end end end +@inline function onehot(x::AbstractFloat) + return (one(x),) +end + """ gradient(::ReverseMode, f, x) @@ -1126,10 +1130,15 @@ grad = gradient(Forward, f, [2.0, 3.0]) ``` """ @inline function gradient(::ForwardMode, f, x; shadow=onehot(x)) - if length(x) == 0 + if length(shadow) == 0 return () end - values(only(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) + res = values(only(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) + if x isa AbstractFloat + res[1] + else + res + end end @inline function chunkedonehot(x, ::Val{chunk}) where chunk @@ -1141,6 +1150,10 @@ end end end +@inline function chunkedonehot(x::AbstractFloat, ::Val{chunk}) where chunk + return ((one(x),),) +end + @inline tupleconcat(x) = x @inline tupleconcat(x, y) = (x..., y...) @inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...) @@ -1171,44 +1184,84 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2)) tmp = ntuple(length(shadow)) do i values(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) end - tupleconcat(tmp...) + res = tupleconcat(tmp...) 
+ if x isa AbstractFloat + res[1] + else + res + end end @inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X} - ntuple(length(shadow)) do i + res = ntuple(length(shadow)) do i autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] end + if x isa AbstractFloat + res[1] + else + res + end end """ jacobian(::ForwardMode, f, x; shadow=onehot(x)) jacobian(::ForwardMode, f, x, ::Val{chunk}; shadow=onehot(x)) -Compute the jacobian of an array-input function `f` using (potentially vector) -forward mode. This is a simple rename of the [`gradient`](@ref) function, -and all relevant arguments apply here. +Compute the jacobian of an array or scalar-input function `f` using (potentially vector) +forward mode. All relevant arguments of the forward-mode [`gradient`](@ref) function +apply here. Example: ```jldoctest -f(x) = [x[1]*x[2], x[2]] +f(x) = [ x[1] * x[2], x[2] + x[3] ] -grad = jacobian(Forward, f, [2.0, 3.0]) +grad = jacobian(Forward, f, [2.0, 3.0, 4.0]) # output -2×2 Matrix{Float64}: - 3.0 2.0 - 0.0 1.0 +2×3 Matrix{Float64}: + 3.0 2.0 0.0 + 0.0 1.0 1.0 ``` + +For functions which return an AbstractArray, this function will return an array +whose shape is `(size(output)..., size(input)...)` + +For functions which return other types, this function will return an array or tuple +of shape `size(input)` of values of the output type. 
""" @inline function jacobian(::ForwardMode, f, x; shadow=onehot(x)) - cols = if length(x) == 0 - return () + cols = if length(shadow) == 0 + () else values(only(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) end - reduce(hcat, cols) + if x isa AbstractFloat + cols[1] + elseif length(cols) > 0 && cols[1] isa AbstractArray + inshape = size(x) + outshape = size(cols[1]) + # st : outshape x total inputs + st = @static if VERSION >= v"1.9" + Base.stack(cols) + else + reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) + end + + st3 = if length(inshape) <= 1 || VERSION < v"1.9" + st + else + reshape(st, (outshape..., inshape...)) + end + + st3 + elseif x isa AbstractArray + inshape = size(x) + reshape(collect(cols), inshape) + else + cols + end end @inline function jacobian(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} @@ -1216,50 +1269,109 @@ end throw(ErrorException("Cannot differentiate with a batch size of 0")) end tmp = ntuple(length(shadow)) do i + Base.@_inline_meta values(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) end cols = tupleconcat(tmp...) 
- reduce(hcat, cols) + if x isa AbstractFloat + cols[1] + elseif length(cols) > 0 && cols[1] isa AbstractArray + inshape = size(x) + outshape = size(cols[1]) + # st : outshape x total inputs + st = @static if VERSION >= v"1.9" + Base.stack(cols) + else + reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) + end + + st3 = if length(inshape) <= 1 || VERSION < v"1.9" + st + else + reshape(st, (outshape..., inshape...)) + end + + st3 + elseif x isa AbstractArray + inshape = size(x) + reshape(collect(cols), inshape) + else + cols + end end @inline function jacobian(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F,X} cols = ntuple(length(shadow)) do i + Base.@_inline_meta autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] end - reduce(hcat, cols) + if x isa AbstractFloat + cols[1] + elseif length(cols) > 0 && cols[1] isa AbstractArray + inshape = size(x) + outshape = size(cols[1]) + # st : outshape x total inputs + st = @static if VERSION >= v"1.9" + Base.stack(cols) + else + reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) + end + + st3 = if length(inshape) <= 1 || VERSION < v"1.9" + st + else + reshape(st, (outshape..., inshape...)) + end + + st3 + elseif x isa AbstractArray + inshape = size(x) + reshape(collect(cols), inshape) + else + cols + end end """ - jacobian(::ReverseMode, f, x, ::Val{num_outs}, ::Val{chunk}) + jacobian(::ReverseMode, f, x, ::Val{num_outs}, ::Val{chunk}=Val(1)) + jacobian(::ReverseMode, f, x) -Compute the jacobian of an array-input function `f` using (potentially vector) +Compute the jacobian of an array-output function `f` using (potentially vector) reverse mode. The `chunk` argument denotes the chunk size to use and `num_outs` denotes the number of outputs `f` will return in an array. 
Example: ```jldoctest -f(x) = [x[1]*x[2], x[2]] +f(x) = [ x[1] * x[2], x[2] + x[3] ] -grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) +grad = jacobian(Reverse, f, [2.0, 3.0, 4.0], Val(2)) # output -2×2 Matrix{Float64}: - 3.0 2.0 - 0.0 1.0 +2×3 transpose(::Matrix{Float64}) with eltype Float64: + 3.0 2.0 0.0 + 0.0 1.0 1.0 +``` + +For functions which return an AbstractArray, this function will return an array +whose shape is `(size(output)..., size(input)...)` + +For functions which return other types, this function will return an array or tuple +of shape `size(output)` of values of the input type. ``` """ -@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, ReturnPrimal, RABI<:ABI, ErrIfFuncWritten} - @assert !ReturnPrimal +@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, RABI<:ABI, ErrIfFuncWritten} num = ((n_out_val + chunk - 1) ÷ chunk) if chunk == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end - tt′ = Tuple{BatchDuplicated{Core.Typeof(x), chunk}} - tt = Tuple{Core.Typeof(x)} + XT = Core.Typeof(x) + MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState + tt′ = MD ? Tuple{BatchMixedDuplicated{XT, chunk}} : Tuple{BatchDuplicated{XT, chunk}} + tt = Tuple{XT} rt = Core.Compiler.return_type(f, tt) ModifiedBetween = Val((false, false)) FA = Const{Core.Typeof(f)} @@ -1281,28 +1393,59 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) tmp = ntuple(num) do i Base.@_inline_meta - dx = ntuple(i == num ? last_size : chunk) do idx + dx = ntuple(Val(i == num ? last_size : chunk)) do idx Base.@_inline_meta - zero(x) + z = make_zero(x) + MD ? Ref(z) : z end - res = (i == num ? 
primal2 : primal)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx)) tape = res[1] j = 0 for shadow in res[3] j += 1 @inbounds shadow[(i-1)*chunk+j] += Compiler.default_adjoint(eltype(typeof(shadow))) end - (i == num ? adjoint2 : adjoint)(Const(f), BatchDuplicated(x, dx), tape) - return dx + (i == num ? adjoint2 : adjoint)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), tape) + return MD ? (ntuple(Val(i == num ? last_size : chunk)) do idx + Base.@_inline_meta + dx[idx][] + end) : dx, (i == 1 ? size(res[3][1]) : nothing) + end + rows = tupleconcat(map(first, tmp)...) + outshape = tmp[1][2] + if x isa AbstractArray + inshape = size(x) + + st = @static if VERSION >= v"1.9" + Base.stack(rows) + else + reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) + end + + st2 = if length(outshape) == 1 || VERSION < v"1.9" + st + else + reshape(st, (inshape..., outshape...)) + end + + st3 = if length(outshape) == 1 && length(inshape) == 1 + transpose(st2) + else + transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... ) + PermutedDimsArray(st2, transp) + end + + st3 + else + reshape(collect(rows), outshape) end - rows = tupleconcat(tmp...) - mapreduce(LinearAlgebra.adjoint, vcat, rows) end -@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,ReturnPrimal,RABI<:ABI, ErrIfFuncWritten} - @assert !ReturnPrimal - tt′ = Tuple{Duplicated{Core.Typeof(x)}} - tt = Tuple{Core.Typeof(x)} +@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,RABI<:ABI, ErrIfFuncWritten} + XT = Core.Typeof(x) + MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState + tt′ = MD ? 
Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} + tt = Tuple{XT} rt = Core.Compiler.return_type(f, tt) ModifiedBetween = Val((false, false)) FA = Const{Core.Typeof(f)} @@ -1312,16 +1455,60 @@ end Val(codegen_world_age(Core.Typeof(f), tt)) end primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) - rows = ntuple(n_outs) do i + tmp = ntuple(n_outs) do i Base.@_inline_meta - dx = zero(x) - res = primal(Const(f), Duplicated(x, dx)) + z = make_zero(x) + dx = MD ? Ref(z) : z + res = primal(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx)) tape = res[1] @inbounds res[3][i] += Compiler.default_adjoint(eltype(typeof(res[3]))) - adjoint(Const(f), Duplicated(x, dx), tape) - return dx + adjoint(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx), tape) + return MD ? dx[] : dx, (i == 1 ? size(res[3]) : nothing) + end + rows = map(first, tmp) + outshape = tmp[1][2] + if x isa AbstractArray + inshape = size(x) + st = @static if VERSION >= v"1.9" + Base.stack(rows) + else + reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) + end + + st2 = if length(outshape) == 1 || VERSION < v"1.9" + st + else + reshape(st, (inshape..., outshape...)) + end + + st3 = if length(outshape) == 1 && length(inshape) == 1 + transpose(st2) + else + transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... 
) + PermutedDimsArray(st2, transp) + end + + st3 + else + reshape(collect(rows), outshape) + end +end + +@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, RABI<:ABI, ErrIfFuncWritten} + res = f(x) + jac = if res isa AbstractArray + jacobian(ReverseMode{false,RABI, ErrIfFuncWritten}(), f, x, Val(length(res))) + elseif res isa AbstractFloat + gradient(ReverseMode{false,RABI, ErrIfFuncWritten}(), f, x) + else + throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))")) + end + + if ReturnPrimal + (res, jac) + else + jac end - mapreduce(LinearAlgebra.adjoint, vcat, rows) end """ diff --git a/src/compiler.jl b/src/compiler.jl index c0f00a0bf1..e1ca4c395a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3886,7 +3886,12 @@ include("rules/activityrules.jl") @inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: DuplicatedNoNeed = API.DFT_DUP_NONEED @inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: BatchDuplicatedNoNeed = API.DFT_DUP_NONEED +const DumpPreEnzyme = Ref(false) + function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs) + if DumpPreEnzyme[] + API.EnzymeDumpModuleRef(mod.ref) + end world = job.world interp = GPUCompiler.get_interpreter(job) rt = job.config.params.rt diff --git a/test/runtests.jl b/test/runtests.jl index 8fdd6f6037..94015cfa4f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2803,6 +2803,131 @@ end end end +@testset "Simple Jacobian" begin + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0) ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0) ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0]) ≈ [4.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, Val(1)) ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, Val(1)) ≈ 
[1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], Val(1)) ≈ [4.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, Val(2)) ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, Val(2)) ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], Val(2)) ≈ [4.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, Val(2)) ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, Val(2), Val(1)) ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, Val(2), Val(2)) ≈ [1.0, 2.0] + + x = float.(reshape(1:6, 2, 3)) + + fillabs2(x) = [sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x), 1000*sum(abs2, x)] + + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x) + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, Val(1)) + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, Val(2)) + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + + jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, Val(4), Val(1)) + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + jac = Enzyme.jacobian(Enzyme.Reverse, 
fillabs2, x, Val(4), Val(2)) + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + struct InpStruct + i1::Float64 + i2::Float64 + i3::Float64 + end + + fillinpabs2(x) = [(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 10*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 100*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 1000*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3)] + + x2 = InpStruct(1.0, 2.0, 3.0) + + jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, Val(4), Val(1)) + + @test jac[1] == InpStruct(2.0, 4.0, 6.0) + @test jac[2] == InpStruct(20.0, 40.0, 60.0) + @test jac[3] == InpStruct(200.0, 400.0, 600.0) + @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) + + jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, Val(4), Val(2)) + + @test jac[1] == InpStruct(2.0, 4.0, 6.0) + @test jac[2] == InpStruct(20.0, 40.0, 60.0) + @test jac[3] == InpStruct(200.0, 400.0, 600.0) + @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) + + struct OutStruct + i1::Float64 + i2::Float64 + i3::Float64 + end + + filloutabs2(x) = OutStruct(sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x)) + + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x) + + @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) + @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) + + @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) + @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) + + @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) + @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) + + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, Val(1)) + + @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) + @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) + + @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) + @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) + + @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) + @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) + + jac = 
Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, Val(2)) + + @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) + @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) + + @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) + @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) + + @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) + @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) + +end + + @testset "Jacobian" begin function inout(v) [v[2], v[1]*v[1], v[1]*v[1]*v[1]]