Utilize ChainRulesCore thunks #966

Open. Wants to merge 13 commits into base: master.

oschulz

@oschulz oschulz commented May 7, 2021

Note: Requires FluxML/ZygoteRules.jl#17

Currently, Zygote always unthunks ChainRulesCore thunks, which is wasteful and may also lead to trouble in cases where a thunk just can't be run for the given types/contents.

With

using ChainRulesCore, Zygote
foo(a, b) = a * b
function ChainRulesCore.rrule(::typeof(foo), a, b)
    y = foo(a, b)
    function foo_pullback(Ȳ)
        ∂a = @thunk (@info "Thunk ∂a"; return Ȳ * b')
        ∂b = @thunk (@info "Thunk ∂b"; return a' * Ȳ)
        return (NO_FIELDS, ∂a, ∂b)
    end
    return y, foo_pullback
end
a = rand(4,3); b = rand(3,2); Ȳ = rand(4,2);
Zygote.pullback(foo, a, b)[2](Ȳ)
let a = a; Zygote.pullback(b -> foo(a, b), b)[2](Ȳ) end

we obviously get

julia> Zygote.pullback(foo, a, b)[2](Ȳ)
[ Info: Thunk ∂a
[ Info: Thunk ∂b
([...], [...])

but we currently also get

julia> let a = a; Zygote.pullback(b -> foo(a, b), b)[2](Ȳ) end
[ Info: Thunk ∂a
[ Info: Thunk ∂b
([...],)

so Zygote executes both thunks even if we only require the pullback with respect to b.

With this PR, we should get

julia> let a = a; Zygote.pullback(b -> foo(a, b), b)[2](Ȳ) end
[ Info: Thunk ∂b
([0.5793294513282393 0.5442572290286626; 0.748469566554706 0.8319235475298901; 0.609222600261931 0.6201995902507516],)

instead.
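
The effect described above can be illustrated without Zygote at all. The following is a minimal plain-Julia sketch, assuming a hand-rolled lazy wrapper: `Lazy`, `force`, `evaluated` and `foo_pullback_sketch` are hypothetical names used only for illustration, standing in for ChainRulesCore's `Thunk` and `unthunk` rather than reproducing them.

```julia
# Hypothetical stand-ins for ChainRulesCore's Thunk / unthunk (not the real API).
struct Lazy{F}
    f::F
end
force(x::Lazy) = x.f()   # run the deferred computation
force(x) = x             # non-lazy values pass through unchanged

const evaluated = Symbol[]   # records which partials were actually computed

# Pullback of foo(a, b) = a * b that defers both partials.
function foo_pullback_sketch(a, b, dY)
    da = Lazy(() -> (push!(evaluated, :da); dY * b'))
    db = Lazy(() -> (push!(evaluated, :db); a' * dY))
    return da, db
end

a = rand(4, 3); b = rand(3, 2); dY = rand(4, 2)
_, db = foo_pullback_sketch(a, b, dY)
grad_b = force(db)       # only the partial we actually need is computed
```

Because `da` is never forced, the product `dY * b'` is never formed, which is the saving this PR is after.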

CC @oxinabox , @mzgubic

src/compiler/interface.jl Outdated
@oschulz
Author

oschulz commented May 7, 2021

I think it may actually be the unthunk in tailmemaybe(x::Tuple) that results in that error, not the one in Grads(IdDict(...)).

@oschulz
Author

oschulz commented May 7, 2021

It looks like for this to work, we'll need to overload a lot of functions like x', Adjoint, transpose, Transpose, diag, Diagonal, permutedims, permutedims!, Ref, Vector, Matrix, Array, collect, convert, and so on, to unthunk thunks, since thunks can now appear as inputs to pullback functions.
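
A rough sketch of what such forcing overloads could look like, in plain Julia. `Lazy` and `force` are hypothetical stand-ins for ChainRulesCore's `Thunk` and `unthunk`; the real overloads would be defined on the ChainRulesCore types and are not shown here.

```julia
using LinearAlgebra

# Hypothetical stand-ins for ChainRulesCore's Thunk / unthunk (not the real API).
struct Lazy{F}
    f::F
end
force(x::Lazy) = x.f()
force(x) = x

# Forcing overloads: each operator evaluates the deferred value first, so
# pullback code written for plain arrays keeps working when handed a Lazy.
Base.adjoint(x::Lazy) = adjoint(force(x))
Base.transpose(x::Lazy) = transpose(force(x))
Base.collect(x::Lazy) = collect(force(x))
LinearAlgebra.diag(x::Lazy) = diag(force(x))

lazy = Lazy(() -> [1.0 2.0; 3.0 4.0])
lazy'          # adjoint of the forced matrix
diag(lazy)     # [1.0, 4.0]
```

Each overload is a one-liner, but one is needed per function that may receive a tangent, which is why the list above is long.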

@oschulz
Author

oschulz commented May 8, 2021

Ok, with FluxML/ZygoteRules.jl#17 this passes the Zygote test suite on my system now (incl. CUDA tests).

In both the simple example above and more complex cases that I've tested only the required thunks are unthunked. Hard to give guarantees on this, obviously.

The @adjoint_keepthunks in FluxML/ZygoteRules.jl#17 provides a way for adjoint authors to declare whether their pullback code is prepared to handle thunks or not.

Maybe we could add a similar mechanism to ChainRulesCore (an rrule_keepthunks or so) to provide a soft transition to (hopefully) more and more rules that support thunks in the ecosystem? (see below)

@oschulz
Author

oschulz commented May 8, 2021

Maybe we could add a similar mechanism to ChainRulesCore (an rrule_keepthunks or so) to provide a soft transition to (hopefully) more and more rules that support thunks in the ecosystem?

I had a look, the number of packages that actually define rrules seems finite (Bijectors, ChainRules, ChainRulesCore, ComponentArrays, CRlibm, DiffEqBase, DistributionsAD, LoopVectorization, NNlib, SpecialFunctions, SymbolicUtils, WebSockets and Zygote in my fairly well-filled package dir). So maybe the better approach would be:

  • Do unthunk in Zygote.wrap_chainrules_input for now (already in this PR). Release a patch version of ZygoteRules with ZygoteRules.jl#17 (Make @adjoint unthunk pullback inputs) and a patch version of Zygote with this PR (#966, Utilize ChainRulesCore thunks). These changes should be non-breaking across the ecosystem.

  • Add thunk-support to functions like x', Adjoint, transpose, Transpose, diag, Diagonal, permutedims, permutedims!, Ref, Vector, Matrix, Array, collect, convert (probably a few more) in ChainRules (maybe some of them in ChainRulesCore?).

  • Release ChainRulesCore v0.10, alert users that from now on we're going to take thunks seriously. :-) (Meaning, they may appear in the input to pullback functions). With the additional thunk-supporting methods, we probably wouldn't need to change too much code in rrule-using packages like those listed above.

  • Release a Zygote v0.7 or so, requiring ChainRulesCore v0.10 and removing unthunking in Zygote.wrap_chainrules_input.

This way, we can have a "soft" path to using thunks more and more. We'll have an immediate benefit since this PR can be done as patch releases of ZygoteRules and Zygote, without ecosystem propagation delay. And after a (hopefully not too long) while, rrules will be required to be thunk-aware, so no need to introduce an rrule_keepthunks. And packages with existing ZygoteRules.@adjoints can just replace them with ZygoteRules.@adjoint_keepthunks in their code if the pullback is thunk-compatible (many will be after more basic functions support thunks) - bit by bit, so thunks can propagate deeper with minimal effort.

Update: see below for changed proposal

@oxinabox
Member

and after a (hopefully not too long) while, rrules will be required to be thunk-aware, so no need to introduce an rrule_keepthunks

I think this idea is closer to the right way.
We want to overload all linear operators in Base/stdlibs to work on Zero and Thunk (and all other abstract differentials) anyway.
We can probably enforce thunk support via automatically testing it in ChainRulesTestUtils.

How much of a problem is it if linear operators on a Thunk unthunk them?
I guess we can decide on a per-operator basis. Maybe we can have Thunk-ness propagate until + is called (which means a gradient is being accumulated) or something like that, though I worry about the overhead that carrying that computational graph around adds (but maybe it isn't so much).
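
One way to picture the "propagate until accumulation" option is the plain-Julia sketch below. It is only an illustration under stated assumptions: `Lazy`, `force` and `accumulate_sketch` are hypothetical names (the latter standing in for whatever accumulation step the AD performs), not ChainRulesCore or Zygote API.

```julia
# Hypothetical stand-ins for ChainRulesCore's Thunk / unthunk (not the real API).
struct Lazy{F}
    f::F
end
force(x::Lazy) = x.f()
force(x) = x

# Linear operators stay lazy: they wrap the pending computation in a new Lazy.
Base.adjoint(x::Lazy) = Lazy(() -> adjoint(force(x)))
Base.:*(c::Number, x::Lazy) = Lazy(() -> c * force(x))

# Accumulation is where the value is finally needed, so it forces both sides.
accumulate_sketch(x, y) = force(x) + force(y)

calls = Ref(0)   # counts how often the underlying computation runs
g = Lazy(() -> (calls[] += 1; [1.0 2.0; 3.0 4.0]))
h = 2 * g'                   # still lazy, nothing computed yet
calls_before = calls[]
total = accumulate_sketch(h, zeros(2, 2))
```

Every lazy operator here allocates another closure that holds on to its input, which is the bookkeeping overhead the comment above worries about.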

What if we made @adjoint always unthunk, and didn't introduce a @adjoint_keepthunks, but instead ensured that all existing rrules supported thunks, and moved the rules you have marked here with @adjoint_keepthunks into ChainRules? And just told people wanting to keep thunks that they should write rrules instead of using @adjoint?
I guess it would be too long to wait for all of the rrules to be updated to support thunks?

cc @mzgubic @willtebbutt @nickrobinson251

@oschulz
Author

oschulz commented May 11, 2021

How much of a problem is it if linear operators on a Thunk unthunk them?

I guess most of them will have to?

What if we made @adjoint always unthunk, and didn't introduce a @adjoint_keepthunks , but instead ensured that all existing rrules supported thunks. [...]

How about this:

  • Step 1: We remove @adjoint_keepthunks from ZygoteRules.jl#17 (Make @adjoint unthunk pullback inputs) and add it to Zygote as an internal Zygote.@_adjoint_keepthunks for now. That way, it's easy to change all the low-level adjoints in Zygote (we need to add keepthunks to a lot of adjoints in Zygote itself to have any benefit from this, as done in this PR) and easy to change back in cases where we might have been too zealous. We can then merge and release ZygoteRules.jl#17 and #966 (Utilize ChainRulesCore thunks) as non-breaking and without introducing new APIs.

  • Step 2: We release a breaking new version of ChainRulesCore that mandates that rrules are prepared to handle thunks, and update the rules in ChainRules accordingly.

  • Step 3: We add all Zygote-internal @_adjoint_keepthunks rules (and maybe some more @_adjoint rules as well) as rrules to ChainRules. We test performance implications (since rrules undergo some wrapping/unwrapping and a lot of these rules are low level and will be called frequently) while we remove those Zygote-internal adjoints until none are left. Then we can get rid of Zygote.@_adjoint_keepthunks again. We encourage the ecosystem to do the same (use ChainRulesCore instead of ZygoteRules).
    May need to fix things like #811 (Zygote ignores ChainRulesCore.rrule for Base.getindex on AbstractArray types) first.

I guess it would be too long to wait for all of the rrules to be updated to support thunks?

I suspect switching all necessary Zygote-internal adjoints to rrules and having the ecosystem make all rrules thunk-compatible (step 3 above) will take some time. But if we do it like above, we'd get an immediate benefit. (Purely from an egoistic point of view, I have two use cases right now that would profit a lot, and since there are other people involved, living on dev-branches for Zygote and ZygoteRules would not be an option).

@mzgubic
Collaborator

mzgubic commented May 11, 2021

My main takeaway from this discussion is that not supporting thunks inside rrules is a serious deficiency, and we should work towards fixing that. It should probably be on the list for 1.0? JuliaDiff/ChainRules.jl#408

A couple of questions:

  • @oschulz Why do we need methods for Zero()? I thought we needed methods for Thunks?
  • @oxinabox Could we move all the @adjoint_keepthunk methods to rrules? Since some of the existing adjoints touch the accum_param? I guess we'd have to keep the adjoints as a wrapper?

I think of Zygote as doing one of the three things to generate a pullback: hit an rrule, hit an adjoint, or do its compiler magic. Right now, we unthunk before getting out of the rrule, so the other two parts never see thunks.

I guess the minimal change we can make is:

  • move unthunk from chain_rules_output to chain_rules_input
  • add unthunk to @adjoint

In this way, we keep compatibility with ChainRules (since the rrules never get thunks - though that should really be supported IMO, and this unthunk can be dropped once ready), and we keep compatibility with existing adjoints. What we gain is that the compiler magic part figures out it can ignore some thunks.

What we lose (compared to this PR) is the efficiency in adjoints that have been moved to use @adjoint_keepthunks. Do we actually gain much by having them? I guess you do in your use case @oschulz ?

@oschulz
Author

oschulz commented May 11, 2021

@oschulz Why do we need methods for Zero()? I thought we needed methods for Thunks?

Ah, sorry, yes - in fact, more basic functions need to support (and in some cases pass through) both Zero() and thunks. That would make a lot of existing rrules that are written in a generic fashion thunk-compatible automatically. Others will need to be changed, of course.

Maybe we could offer a function unthunking in ChainRulesCore so that rrules can return y, unthunking(back) for rrules that need to unthunk.
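
For illustration, such a helper could be as small as the plain-Julia sketch below. This is an assumption about its shape, not an existing ChainRulesCore function: `Lazy` and `force` are hypothetical stand-ins for `Thunk` and `unthunk`, and `double_rule` is a made-up rule.

```julia
# Hypothetical stand-ins for ChainRulesCore's Thunk / unthunk (not the real API).
struct Lazy{F}
    f::F
end
force(x::Lazy) = x.f()
force(x) = x

# Hypothetical helper: wrap a pullback so it always sees a materialized tangent.
unthunking(back) = dy -> back(force(dy))

# A made-up rule for x -> 2 .* x whose pullback is not thunk-aware.
function double_rule(x)
    back(dy) = (2 .* dy,)    # broadcasting needs a real array, not a Lazy
    return 2 .* x, unthunking(back)
end

y, pb = double_rule([1.0, 2.0])
pb(Lazy(() -> [10.0, 20.0]))   # ([20.0, 40.0],)
pb([10.0, 20.0])               # plain arrays pass through force unchanged
```

The wrapper forces the whole input unconditionally, so it only fits rules that never want to keep part of their input lazy.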

I guess the minimal change we can make is:

  • move unthunk from chain_rules_output to chain_rules_input
  • add unthunk to @adjoint
    In this way, we keep compatibility with ChainRules

Yes, that's what this PR and FluxML/ZygoteRules.jl#17 do. Zygote needs quite a few non-unthunking internal adjoints to get an actual benefit; they are already included in this PR. At least in some use cases I tested, it works quite nicely.

Long term we should then mandate that rrules accept thunks, as suggested in JuliaDiff/ChainRules.jl#408. That's why I'd like to propose the "three-step" plan above.

@oxinabox
Member

@oxinabox Could we move all the @adjoint_keepthunk methods to rrules? Since some of the existing adjoints touch the accum_param?

Looking at them, I think they either are things that should just move to ChainRules.jl
or they are Zygote internals that can nevertheless be written using rrules

The only one I am unsure about is:
https://github.com/FluxML/Zygote.jl/pull/966/files#diff-cd0210083ce3136f79bee6ebca2bcca77f41a14f11b5a7a65ea1cc54803164c3R103-R110
which I don't understand what it is doing.


I think @DhairyaLGandhi needs to make the call as to whether to do the parts of the 3-step plan that involve changes to Zygote/ZygoteRules.
I am gently in favour of them.
On the one hand they seem to work, and that is really useful, and fixing things sooner rather than later is better: Better an egg in the hand than two in the bush.
On the other hand it is more complexity to maintain in Zygote; and I dislike adding features to ZygoteRules when we are trying to stop using it.

I am strongly in favor of making changes in ChainRules and ChainRulesCore to facilitate this, either as part of the 3-step plan or as part of going straight to a state where keeping thunks is done via rrule and @adjoint always removes them.


Maybe we could offer a function unthunking in ChainRulesCore so that rrules can return y, unthunking(back) for rrules that need to unthunk.

I think better to leave that to the AD package.
It is a simple enough function, I would rather a little code duplication than expand the API surface in a way that might be wrong/made redundant. (I was reading https://sandimetz.com/blog/2016/1/20/the-wrong-abstraction recently, I think it makes some good points)

@oschulz
Author

oschulz commented May 11, 2021

dislike adding features to ZygoteRules when we are trying to stop using it

We wouldn't really add a feature to ZygoteRules though, right? We'd just make @adjoint unthunk in general. And Zygote.@_adjoint_keepthunks would be a purely internal and temporary thing in Zygote.

@oschulz
Author

oschulz commented May 11, 2021

Maybe we could offer a function unthunking in ChainRulesCore so that rrules can return y, unthunking(back) for rrules that need to unthunk.
I think better to leave that to the AD package.

Ah, no, I meant to make life easier on rule writers (who would only depend on ChainRulesCore not an AD package). I assume there will be quite a few rules that need to always unthunk, and this could be a tool to keep their code more concise.

@oxinabox
Member

Ah, no, I meant to make life easier on rule writers (who would only depend on ChainRulesCore not an AD package). I assume there will be quite a few rules that need to always unthunk, and this could be a tool to keep their code more concise.

I am, for the same reasons, happy for now to just leave it for the rule authors.
It is still a very short function.
And in that case when unthunking the input, often they will want to unthunk some parts but not others.
So mostly just calling unthunk as needed seems easiest.
But anyway it is easy to revise this opinion later.

@oxinabox
Member

We wouldn't really add a feature to ZygoteRules though, right? We'd just make @adjoint unthunk in general. And Zygote.@_adjoint_keepthunks would be a purely internal and temporary thing in Zygote.

Ah, yeah that makes sense.
That makes me even more in favor of this plan

@DhairyaLGandhi
Member

Will have to go through the discussion to make a more informed call, but why is it that we need a separate macro to define thunk-aware adjoints? Generally, I see this more as an implementation detail which shouldn't leak to the API. Apologies if it's been discussed already, I'm a little late to the thread.

@oschulz
Author

oschulz commented May 11, 2021

I am, for the same reasons, happy for now to just leave it for the rule authors. It is still a very short function.

You're right - unthunking() may be an unnecessary complication. Explicit unthunk()s in the rule code will be clearer.

@oschulz
Author

oschulz commented May 11, 2021

but why is it that we need a separate macro to define thunk aware adjoints? Generally, I see this more as an implementation detail which shouldn't leak to the API

Yes, hence my changed proposal to do this as a Zygote-internal @_adjoint_keepthunks.

We do need two macros (one of them non-public/API-stable) for a soft transition, since we need a lot of non-unthunking Zygote-internal adjoints right now to have any benefit from thunks, but the ecosystem will break if existing adjoints suddenly get thunks as an input. That's why I'd like to do this in steps.

@oschulz
Author

oschulz commented May 11, 2021

Ok I've removed @adjoint_keepthunks from ZygoteRules and added an internal @_adjoint_keepthunks to Zygote itself instead. I've also added some comments to label it as temporary.

Zygote.@_adjoint_keepthunks still uses the extended ZygoteRules.gradm(ex, mut, keepthunks) introduced by FluxML/ZygoteRules.jl#17. gradm is a fairly complicated bit of code and the difference between thunking and unthunking is minimal, so I didn't want to basically duplicate it in Zygote. But it's not an exported function, so we can remove the keepthunks argument from gradm later as a non-breaking change. FluxML/ZygoteRules.jl#17 is necessary in any case, since we do need to make @adjoint unthunk for any of this to work.

@oschulz
Author

oschulz commented May 14, 2021

Here's a little demo on how this will work once we can disable the unthunking of rrule input (step 3 in the proposal above).

With

# Requires
# * https://github.com/FluxML/ZygoteRules.jl/pull/17
# * https://github.com/FluxML/Zygote.jl/pull/966

using Zygote, LinearAlgebra, ChainRulesCore
using ChainRules: CommutativeMulNumber


# Make thunks in log more readable:
Base.show(io::IO, x::Thunk) = print(io, typeof(x).name.name)
Base.show(io::IO, x::InplaceableThunk) = print(io, typeof(x).name.name)


# Disable unthunking of rrule input (to become default behavior in the future):
function Zygote.wrap_chainrules_input(x)
    @info "wrap_chainrules_input($(x))"
    return x
end


# Same rrule as in ChainRules, just add some logging:
function ChainRulesCore.rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    @info "rrule for $(A) * $(B)"
    function times_pullback(Ȳ)
        @info "pullback of $(A) * $(B) for $(Ȳ)"
        return (
            NO_FIELDS,
            InplaceableThunk(
                @thunk((@info("!left thunk.val of pullback of $(A) * $(B) for $(Ȳ) = $(Ȳ) * $(B)"); Ȳ * B')),
                X̄ -> (@info("!left thunk.add! of pullback of $(A) * $(B) for $(Ȳ) = $(Ȳ) * $(B)");mul!(X̄, Ȳ, B', true, true))
            ),
            InplaceableThunk(
                @thunk((@info("!right thunk.val of pullback of $(A) * $(B) for $(Ȳ) = $(A') * $(Ȳ)"); A' * Ȳ)),
                X̄ -> (@info("!right thunk.add! of pullback of $(A) * $(B) for $(Ȳ) = $(A') * $(Ȳ)");mul!(X̄, A', Ȳ, true, true))
            )
        )
    end
    return A * B, times_pullback
end


A, B, C, D = [fill(i,1,1) for i in 2:5]

We get:

julia> Zygote.gradient(X -> sum(A * B * C * X), D)
[ Info: rrule for [2] * [3]
[ Info: rrule for [6] * [4]
[ Info: rrule for [24] * [5]
[ Info: wrap_chainrules_input(1×1 Fill{Int64}: entries equal to 1)
[ Info: pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: pullback of [6] * [4] for InplaceableThunk
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: pullback of [2] * [3] for InplaceableThunk
[ Info: !right thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = [24] * 1×1 Fill{Int64}: entries equal to 1
([24],)

julia> Zygote.gradient(X -> sum(X * B * C * D), A)
[ Info: rrule for [2] * [3]
[ Info: rrule for [6] * [4]
[ Info: rrule for [24] * [5]
[ Info: wrap_chainrules_input(1×1 Fill{Int64}: entries equal to 1)
[ Info: pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: pullback of [6] * [4] for InplaceableThunk
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: pullback of [2] * [3] for InplaceableThunk
[ Info: !left thunk.val of pullback of [2] * [3] for InplaceableThunk = InplaceableThunk * [3]
[ Info: !left thunk.val of pullback of [6] * [4] for InplaceableThunk = InplaceableThunk * [4]
[ Info: !left thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = 1×1 Fill{Int64}: entries equal to 1 * [5]
([60],)

This should save a lot of computation time in applications that don't differentiate with respect to every argument along the computational graph. Standard ML applications usually want gradients for almost everything, of course (since almost everything is a free parameter), but applications like fitting, MCMC, "scientific" ML and the like tend to be more selective in which gradient(s) they need.

If we re-enable unthunking of rrule input (as is part of this PR, to keep compatibility until "step 3")

function Zygote.wrap_chainrules_input(x)
    @info "wrap_chainrules_input($(x))"
    return Zygote.unthunk_tangent(x)
end

things aren't quite that nice if we differentiate with respect to D

julia> Zygote.gradient(X -> sum(A * B * C * X), D)
[ Info: rrule for [2] * [3]
[ Info: rrule for [6] * [4]
[ Info: rrule for [24] * [5]
[ Info: wrap_chainrules_input(1×1 Fill{Int64}: entries equal to 1)
[ Info: pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: !left thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = 1×1 Fill{Int64}: entries equal to 1 * [5]
[ Info: pullback of [6] * [4] for [5]
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: !left thunk.val of pullback of [6] * [4] for [5] = [5] * [4]
[ Info: pullback of [2] * [3] for [20]
[ Info: !right thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = [24] * 1×1 Fill{Int64}: entries equal to 1
([24],)

but we still get some immediate benefit from this PR, since with the current release of Zygote, things look like this:

julia> Zygote.gradient(X -> sum(A * B * C * X), D)
[ Info: rrule for [2] * [3]
[ Info: rrule for [6] * [4]
[ Info: rrule for [24] * [5]
[ Info: pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1
[ Info: !left thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = 1×1 Fill{Int64}: entries equal to 1 * [5]
[ Info: !right thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = [24] * 1×1 Fill{Int64}: entries equal to 1
[ Info: pullback of [6] * [4] for [5]
[ Info: !left thunk.val of pullback of [6] * [4] for [5] = [5] * [4]
[ Info: !right thunk.val of pullback of [6] * [4] for [5] = [6] * [5]
[ Info: pullback of [2] * [3] for [20]
[ Info: !left thunk.val of pullback of [2] * [3] for [20] = [20] * [3]
[ Info: !right thunk.val of pullback of [2] * [3] for [20] = [2] * [20]
([24],)

@oschulz
Author

oschulz commented May 14, 2021

Here are some benchmarks for the example above, with bigger arrays:

using Zygote, BenchmarkTools

A, B, C, D = [fill(i,100,100) for i in 2:5]
julia> @btime sum($A * $B * $C * $D)
  1.081 ms (24 allocations: 235.59 KiB)

Using current Zygote v0.6.10:

julia> @btime Zygote.gradient(X -> sum($A * $B * $C * X), $D)
  3.066 ms (222 allocations: 792.28 KiB)

With this PR (requires FluxML/ZygoteRules.jl#17):

julia> @btime Zygote.gradient(X -> sum($A * $B * $C * X), $D)
  2.192 ms (245 allocations: 562.34 KiB)

And after disabling forced unthunking of rrules input (step 3 in the plan above):

julia> @inline Zygote.wrap_chainrules_input(x) = x

julia> @btime Zygote.gradient(X -> sum($A * $B * $C * X), $D)
  1.413 ms (229 allocations: 416.22 KiB)

@oschulz
Author

oschulz commented May 30, 2021

Hey @DhairyaLGandhi , it seems all that's missing before we can move with this is your blessing. Would you have time to take a look?

@DhairyaLGandhi
Member

I'll take a look this week, thanks for this!

@oschulz
Author

oschulz commented Jun 10, 2021

@DhairyaLGandhi , gentle bump - I don't mean to nag, but I would love to start using the first stage of this for some projects.

@DhairyaLGandhi
Member

On it :)

Not nagging at all, I was trying to explore what it would entail for writing rules in zygote and what it means for the ergonomics anyway

@oschulz
Author

oschulz commented Jun 14, 2021

Thanks!

There should be no difference for users writing ZygoteRules rules. Starting with step 2 in the plan above, users writing ChainRulesCore rules (and I guess the idea is that this will be the preferred way in the future, @oxinabox ?) will need to start using unthunk for the tangents their pullbacks receive, if the pullback uses functionality that isn't itself thunk-compatible. That should really be all, regarding ergonomics.

@oxinabox
Member

@mzgubic is working as we speak on making sure all rules in ChainRules.jl support receiving thunks as inputs.
Mostly by adding missing overloads for common linear operators to ChainRulesCore, but probably in some cases via calling unthunk at the start.
With the intent that when we release ChainRulesCore 1.0 and ChainRulesTestUtils 1.0,
it will be a requirement of all rules.

@CarloLucibello
Member

Worth rebasing to see exactly where we are

@pxl-th
Member

pxl-th commented Dec 24, 2024

I've fixed some of the issues after 482ab1b in my local fork, but there's one where I'm not sure where it'd be best to land a fix.

MWE for the problem (taken from one of the failing tests):

x = ones(Float32, 3, 2)
Zygote.gradient(x) do x
    sum(map(norm, eachcol(x)))
end

Returns:

(Float32[0.0 0.0; 0.0 0.0; 0.0 0.0],)

The reason this happens is because ZBack no longer unthunks Thunks when being called (in wrap_chainrules_output), so rrule for eachcol receives dy as:

Vector{ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1071#1074"{Float32, SubArray{Float32, 1, Matrix{Float32}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Float32}}, ChainRules.var"#1070#1073"{Float32, SubArray{Float32, 1, Matrix{Float32}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Float32}}}

Where each element of Vector is produced by norm_pullback_2.
While previously due to unthunking in ZBack we'd get for dy:

Vector{Vector{Float32}}

And so the execution of ∇eachslice wouldn't terminate here.

I feel like this is due to the rrule for eachcol not handling unthunking correctly, so the fix should go there.
We can simply unthunk each element of dys as we iterate it (if it is a thunk).

If this is an acceptable fix, then I can open a corresponding PR.
But maybe I'm missing something and there's a better way of doing it.
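
The element-wise fix proposed above might look roughly like the plain-Julia sketch below. It is not the actual ∇eachslice code from Zygote or ChainRules: `Lazy`, `force` and `eachcol_pullback_sketch` are hypothetical names, with `Lazy`/`force` standing in for ChainRulesCore's thunk types and `unthunk`.

```julia
# Hypothetical stand-ins for ChainRulesCore's Thunk / unthunk (not the real API).
struct Lazy{F}
    f::F
end
force(x::Lazy) = x.f()
force(x) = x

# Hypothetical stand-in for the eachcol pullback: stack per-column tangents
# into a matrix, forcing each element only as it is consumed.
function eachcol_pullback_sketch(dys, nrows)
    dx = zeros(Float32, nrows, length(dys))
    for (j, dy) in enumerate(dys)
        dx[:, j] .= force(dy)    # works for lazy and plain column tangents
    end
    return dx
end

# A vector of lazy column tangents, like the Vector{InplaceableThunk{...}} above.
dys = [Lazy(() -> fill(Float32(j), 3)) for j in 1:2]
dx = eachcol_pullback_sketch(dys, 3)
```

Since `force` is the identity on plain values, the same loop also handles the previous `Vector{Vector{Float32}}` input.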

@oschulz
Author

oschulz commented Dec 27, 2024

Thanks for picking this back up guys!

@pxl-th
Member

pxl-th commented Dec 27, 2024

I'm still looking at the remaining tests. But here're some initial results with lazy Zygote:

Context: Alternating rounds of generator vs 2 discriminators training (computing grads either w.r.t. generator or discriminators).

Batch size: 16.
GPU memory utilization: from 22.36 GiB down to 18.72 GiB.
1 Epoch training time: from 16 minutes down to 14 minutes.

Context: Using Flux.Conv layer only as a loss function with fixed weights (previously it'd compute grads w.r.t. conv weights).

1k training steps benchmark: from 93 seconds down to 46 seconds.

@pxl-th
Member

pxl-th commented Dec 28, 2024

With latest commit all of the newly introduced test failures are now fixed.
Relies on JuliaDiff/ChainRules.jl#814

@ToucheSir
Member

Do you mind rebasing one more time so we can see what CI looks like with #1545?

@pxl-th
Member

pxl-th commented Dec 31, 2024

Hm... Still some failures. I'll take a look at them in a bit

Labels
ChainRules adjoint -> rrule, and further integration performance