Utilize ChainRulesCore thunks #966
base: master
I think it may actually be the unthunk in |
It looks like for this to work, we'll need to overload a lot of functions like |
Ok, with FluxML/ZygoteRules.jl#17 this passes the Zygote test suite on my system now (incl. CUDA tests). In both the simple example above and more complex cases that I've tested, only the required thunks are unthunked. Hard to give guarantees on this, obviously. The
|
I had a look, the number of packages that actually define rrules seems finite (Bijectors, ChainRules, ChainRulesCore, ComponentArrays, CRlibm, DiffEqBase, DistributionsAD, LoopVectorization, NNlib, SpecialFunctions, SymbolicUtils, WebSockets and Zygote in my fairly well-filled package dir). So maybe the better approach would be:
This way, we can have a "soft" path to using thunks more and more. We'll have an immediate benefit since this PR can be done as patch releases of ZygoteRules and Zygote, without ecosystem propagation delay. And after a (hopefully not too long) while, Update: see below for changed proposal |
I think this idea is closer to the right way. How much of a problem is it if linear operators on a What if we made |
I guess most of them will have to?
How about this:
I suspect switching all necessary Zygote-internal adjoints to |
My main takeaway from this discussion is that not supporting thunks inside rrules is a serious deficiency, and we should work towards fixing that. It should probably be on the list for 1.0? JuliaDiff/ChainRules.jl#408 A couple of questions:
I think of Zygote as doing one of the three things to generate a pullback: hit an rrule, hit an adjoint, or do its compiler magic. Right now, we unthunk before getting out of the rrule, so the other two parts never see thunks. I guess the minimal change we can make is:
In this way, we keep compatibility with ChainRules (since the rrules never get thunks - though that should really be supported IMO, and this unthunk can be dropped once ready), and we keep compatibility with existing adjoints. What we gain is that the compiler magic part figures out it can ignore some thunks. What we lose (compared to this PR) is the efficiency in adjoints that have been moved to use |
Ah, sorry, yes - in fact, more basic functions need to support (and in some cases pass through) both Zero() and thunks. That would make a lot of existing Maybe we could offer a function
Yes, that's what this PR and FluxML/ZygoteRules.jl#17 do. Zygote needs quite a few non-unthunking internal adjoints to get an actual benefit; they are already included in this PR. At least in some use cases I tested, it works quite nicely. Long term we should then mandate that |
Looking at them, I think they either are things that should just move to ChainRules.jl The only one I am unsure about is: I think @DhairyaLGandhi needs to make the call as to whether to do the parts of the 3 step plan that involve changes to Zygote/ZygoteRules. I am strongly in favor of making changes in ChainRules and ChainRulesCore to facilitate this. Either as part of 3 step plan or as part of going straight to keeping thunks is done via
I think better to leave that to the AD package. |
We wouldn't really add a feature to ZygoteRules though, right? We'd just make |
Ah, no, I meant to make life easier on rule writers (who would only depend on |
I am, for the same reasons, happy for now to just leave it for the rule authors. |
Ah, yeah that makes sense. |
Will have to go through the discussion to make a more informed call, but why is it that we need a separate macro to define thunk-aware adjoints? Generally, I see this more as an implementation detail which shouldn't leak to the API. Apologies if it's been discussed already, I'm a little late to the thread. |
You're right - |
Yes, hence my changed proposal to do this as a Zygote-internal We do need two macros (one of them non-public/API-stable) for a soft transition, since we need a lot of non-unthunking Zygote-internal adjoints right now to have any benefit from thunks, but the ecosystem will break if existing adjoints suddenly get thunks as an input. That's why I'd like to do this in steps. |
Ok I've removed
|
Here's a little demo on how this will work once we can disable the unthunking of With
# Requires
# * https://github.com/FluxML/ZygoteRules.jl/pull/17
# * https://github.com/FluxML/Zygote.jl/pull/966
using Zygote, LinearAlgebra, ChainRulesCore
using ChainRules: CommutativeMulNumber
# Make thunks in log more readable:
Base.show(io::IO, x::Thunk) = print(io, typeof(x).name.name)
Base.show(io::IO, x::InplaceableThunk) = print(io, typeof(x).name.name)
# Disable unthunking of rrule input (to become default behavior in the future):
function Zygote.wrap_chainrules_input(x)
    @info "wrap_chainrules_input($(x))"
    return x
end
# Same rrule as in ChainRules, just add some logging:
function ChainRulesCore.rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    @info "rrule for $(A) * $(B)"
    function times_pullback(Ȳ)
        @info "pullback of $(A) * $(B) for $(Ȳ)"
        return (
            NO_FIELDS,
            InplaceableThunk(
                @thunk((@info("!left thunk.val of pullback of $(A) * $(B) for $(Ȳ) = $(Ȳ) * $(B)"); Ȳ * B')),
                X̄ -> (@info("!left thunk.add! of pullback of $(A) * $(B) for $(Ȳ) = $(Ȳ) * $(B)"); mul!(X̄, Ȳ, B', true, true))
            ),
            InplaceableThunk(
                @thunk((@info("!right thunk.val of pullback of $(A) * $(B) for $(Ȳ) = $(A') * $(Ȳ)"); A' * Ȳ)),
                X̄ -> (@info("!right thunk.add! of pullback of $(A) * $(B) for $(Ȳ) = $(A') * $(Ȳ)"); mul!(X̄, A', Ȳ, true, true))
            )
        )
    end
    return A * B, times_pullback
end
A, B, C, D = [fill(i,1,1) for i in 2:5]
We get:
This should save a lot of computation time in applications that don't differentiate with respect to every argument along the computational graph. Standard ML applications usually want gradients for almost everything, of course (since almost everything is a free parameter), but applications like fitting, MCMC, "scientific" ML and the like tend to be more selective in which gradient(s) they need. If we re-enable unthunking of rrule input (as is part of this PR, to keep compatibility until "step 3")
function Zygote.wrap_chainrules_input(x)
    @info "wrap_chainrules_input($(x))"
    return Zygote.unthunk_tangent(x)
end
things aren't quite that nice if we differentiate with respect to
but we still get some immediate benefit from this PR, since with the current release of Zygote, things look like this:
|
Here are some benchmarks for the example above, with bigger arrays:
using Zygote, BenchmarkTools
A, B, C, D = [fill(i,100,100) for i in 2:5]
Using current Zygote v0.6.10:
With this PR (requires FluxML/ZygoteRules.jl#17):
And after disabling forced unthunking of rrules input (step 3 in the plan above):
|
Hey @DhairyaLGandhi , it seems all that's missing before we can move with this is your blessing. Would you have time to take a look? |
I'll take a look this week, thanks for this! |
@DhairyaLGandhi , gentle bump - I don't mean to nag, but I would love to start using the first stage of this for some projects. |
On it :) Not nagging at all, I was trying to explore what it would entail for writing rules in zygote and what it means for the ergonomics anyway |
Thanks! There should be no difference for users writing |
@mzgubic is working as we speak on making sure all rules in ChainRules.jl support receiving thunks as inputs. |
Worth rebasing and see exactly where we are |
I've fixed some of the issues after 482ab1b in my local fork, but there's one where I'm not sure where it'd be optimal to land a fix. MWE for the problem (taken from one of the failing tests):
using Zygote, LinearAlgebra  # norm comes from LinearAlgebra
x = ones(Float32, 3, 2)
Zygote.gradient(x) do x
    sum(map(norm, eachcol(x)))
end
Returns:
The reason this happens is because
Where each element of
And so the execution of I feel like this is due to the rrule for eachcol not handling unthunking correctly, so the fix should go there. If this is an acceptable fix, I can open a corresponding PR. |
fd8de94 to c21697d
Thanks for picking this back up guys! |
I'm still looking at the remaining tests. But here're some initial results with lazy Zygote: Context: Alternating rounds of generator vs 2 discriminators training (computing grads either w.r.t. generator or discriminators). Batch size: Context: Using 1k training steps benchmark: from |
c21697d to 73eb1e3
With the latest commit, all of the newly introduced test failures are now fixed. |
Do you mind rebasing one more time so we can see what CI looks like with #1545? |
Introduces @_adjoint_keepthunks to mark adjoints that should pass thunks through.
Co-authored-by: Brian Chen <[email protected]>
This reverts commit 34865ea.
Co-authored-by: Brian Chen <[email protected]>
322c320 to 9a850e0
9a850e0 to 8797788
Hm... Still some failures. I'll take a look at them in a bit |
Note: Requires FluxML/ZygoteRules.jl#17
Currently, Zygote always unthunks ChainRulesCore thunks, which is wasteful and may also lead to trouble in cases where a thunk just can't be run for the given types/contents. With
we obviously get
but we currently also get
so Zygote also executes both thunks even if we only require the pullback with respect to b. With this PR, we should get
instead.
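To make the intended laziness concrete, here is a rough, self-contained sketch. It is not the example from this PR and does not go through Zygote at all: it only uses `@thunk` and `unthunk` from ChainRulesCore. The names `mul_pullback_sketch` and `calls` are hypothetical and exist only for this illustration.

```julia
using ChainRulesCore

# Hypothetical counter, used only to observe how many thunks actually ran.
const calls = Ref(0)

# Hypothetical pullback for Y = A * B that returns both cotangents lazily.
function mul_pullback_sketch(A, B, Ȳ)
    ∂A = @thunk begin
        calls[] += 1
        Ȳ * B'
    end
    ∂B = @thunk begin
        calls[] += 1
        A' * Ȳ
    end
    return ∂A, ∂B
end

A, B, Ȳ = fill(2.0, 1, 1), fill(3.0, 1, 1), fill(1.0, 1, 1)
∂A, ∂B = mul_pullback_sketch(A, B, Ȳ)

# Only the cotangent we actually need is evaluated; ∂A is never run.
gB = unthunk(∂B)
```

After this runs, `calls[]` should be 1: the thunk for `∂A` was constructed but never executed. The goal of this PR is for Zygote to behave the same way when only one of the gradients is requested.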
CC @oxinabox , @mzgubic