Controlling which optimizer rule you're adjusting. #202
I can't easily copy-paste from screenshots...
Alternatives are:
julia> using Optimisers, Functors
julia> os = Optimisers.setup(OptimiserChain(SignDecay(0.1), AdamW(lambda=0.1)), (; x=[12 34.]))
(x = Leaf(OptimiserChain(SignDecay(0.1), AdamW(0.001, (0.9, 0.999), 0.1, 1.0e-8, true)), (nothing, ([0.0 0.0], [0.0 0.0], (0.9, 0.999)))),)
julia> fmap(os; exclude = r -> r isa SignDecay) do r
Optimisers.adjust(r, lambda=0.333)
end
(x = Leaf(OptimiserChain(SignDecay(0.333), AdamW(0.001, (0.9, 0.999), 0.1, 1.0e-8, true)), (nothing, ([0.0 0.0], [0.0 0.0], (0.9, 0.999)))),)
Are there other options?
I suppose 3 doesn't help if you wish to adjust some learning rates and not others, we aren't going to rename them like
The advantage of 2 is you need never know the field names. The disadvantage is that
This should be easy, convenient, and visible (so that the user is aware of the possible footgun). That rules out 5.
Will this replace a whole OptimiserChain, or just the rule that matches the type? The fact that I'm asking means that the user has to memorize/understand an extra convention to use this.
I'm not sure what I'd do with this that you can't do with the current PR, but with (slightly) more difficulty. Another option is allowing you to label an optimizer component, with a default that matches its type. So you can go:
os = Optimisers.setup(OptimiserChain(SignDecay(0.1, label = "a"), AdamW(lambda=0.1)), (; x=[12 34.]))
Optimisers.adjust!(os, "a", lambda = 0.3)
Optimisers.adjust!(os, "AdamW", lambda = 0.01) # "AdamW" is the default label for AdamW.
This would let you have two optimisers of the same type in a chain, but label them differently so you can adjust them independently (which my current PR doesn't allow). However I can't think of motivation for needing this, and this would mean we have to introduce changes to all the current optimisers (and might invalidate checkpointed opt states, depending on how it is done?). Still prefer the current PR's solution, but I'm not picky as long as it is easy to use and visible.
I agree 5 is messy. But the point is that the state tree is some simple nested Julia object, which you are free to manipulate however you like -- you do not have to work through any explicit API of this package, and hence the API does not have to expand to allow arbitrarily complex manipulations. Maybe I'm a little hesitant to add API too quickly, which is why I think we should think through alternatives carefully. Adding labels seems way over the top to me.
Obviously just the rule, surely? If you have setup
This will only be "obvious" to a user who has both i) a very clear understanding of how this package does what it does, and ii) a mental model of the package devs. They also have to be very confident that there isn't some other reason (that they're currently unaware of) that forces them to, e.g., re-specify the entire rule (e.g. the package internals for some reason not handling automatic construction of the rest of the optimizer chain). Not obvious.
I think the proposed solution is a light and intuitive addition to the interface and solves a concrete problem with no simple workaround, so we should go with it.
Thanks for drafting this up. I think the only missing pieces are tests and an expanded docstring for adjust[!]?
#adjust with type control
###

adjust!(ℓ::Leaf, oT::Type, eta::Real) = (ℓ.rule = adjust(ℓ.rule, oT, eta); nothing)
Can all the Types here be changed to Type{<:AbstractRule}? It would provide some more information for users of the interface and ensure we're not promising support for arbitrary types.
adjust!(tree, oT::Type, eta::Real) = foreach(st -> adjust!(st, oT, eta), tree)
adjust!(tree, oT::Type; kw...) = foreach(st -> adjust!(st, oT; kw...), tree)

adjust(r::AbstractRule, oT::Type, eta::Real) = ifelse(isa(r, oT), adjust(r, eta), r)
These overloads are allocating the output of adjust(r, ...) even if it doesn't end up being used. Unless ifelse brings a demonstrable type stability benefit over a plain conditional, I'd recommend using a ternary instead:
- adjust(r::AbstractRule, oT::Type, eta::Real) = ifelse(isa(r, oT), adjust(r, eta), r)
+ adjust(r::AbstractRule, oT::Type{<:AbstractRule}, eta::Real) = isa(r, oT) ? adjust(r, eta) : r
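The eager-evaluation point can be checked in isolation. The following is a minimal sketch, not code from the PR: `expensive_adjust` and the `CALLS` counter are hypothetical stand-ins for `adjust(r, eta)`. Because `ifelse` is an ordinary function, its arguments are evaluated before the call, whereas the ternary only evaluates the branch it takes.

```julia
# Hypothetical stand-in for `adjust(r, eta)`: counts how often it runs.
const CALLS = Ref(0)
expensive_adjust(x) = (CALLS[] += 1; x + 1)

# `ifelse` is a regular function, so both value arguments are evaluated
# before it is called, even when the condition is false.
with_ifelse(cond, x) = ifelse(cond, expensive_adjust(x), x)

# The ternary operator short-circuits: only the chosen branch runs.
with_ternary(cond, x) = cond ? expensive_adjust(x) : x

CALLS[] = 0
result_ifelse = with_ifelse(false, 1)
calls_ifelse = CALLS[]      # the adjusted value was built, then discarded

CALLS[] = 0
result_ternary = with_ternary(false, 1)
calls_ternary = CALLS[]     # the adjusted value was never built
```

Both calls return the unmodified input; they differ only in whether the unused branch ran. Whether this matters for real rules depends on how much the compiler can remove, which this sketch does not measure.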
I would guess the branch is resolved when compiling? Possibly it will prefer this style:
- adjust(r::AbstractRule, oT::Type, eta::Real) = ifelse(isa(r, oT), adjust(r, eta), r)
+ adjust(r::AbstractRule, ::Type{T}, eta::Real) where {T<:AbstractRule} = isa(r, T) ? adjust(r, eta) : r
I think you can also make dispatch do this, although whether it's a good idea I don't know... less obvious to read:
adjust(r::C, oT::Type{A}, eta::Real) where {C <: A <: AbstractRule} where {C<:AbstractRule} = adjust(r, eta)
adjust(r::AbstractRule, oT::Type{<:AbstractRule}, eta::Real) = r
Edit, less obscure:
- adjust(r::AbstractRule, oT::Type, eta::Real) = ifelse(isa(r, oT), adjust(r, eta), r)
+ adjust(r::T, oT::Type{T}, eta::Real) where {T<:AbstractRule} = adjust(r, eta)
+ adjust(r::AbstractRule, oT::Type{<:AbstractRule}, eta::Real) = r
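To compare the dispatch-based variant with the `isa` check, here is a self-contained sketch. It does not use Optimisers.jl: `ToyRule`, `ToyDescent`, `ToyMomentum`, `set_eta`, `adjust_isa` and `adjust_dispatch` are all hypothetical names made up for illustration, and both toy rules deliberately share an `eta` field.

```julia
# Toy stand-ins for Optimisers.jl rules (hypothetical, for illustration only).
abstract type ToyRule end
struct ToyDescent <: ToyRule
    eta::Float64
end
struct ToyMomentum <: ToyRule
    eta::Float64
    rho::Float64
end

# Stand-in for `adjust(r, eta)`: rebuild the rule with a new learning rate.
set_eta(r::ToyDescent, eta::Real) = ToyDescent(eta)
set_eta(r::ToyMomentum, eta::Real) = ToyMomentum(eta, r.rho)

# Variant 1: explicit `isa` check with a ternary.
adjust_isa(r::ToyRule, ::Type{T}, eta::Real) where {T<:ToyRule} =
    r isa T ? set_eta(r, eta) : r

# Variant 2: let dispatch pick the method; the second method is the fallback
# that returns the rule unchanged.
adjust_dispatch(r::T, ::Type{T}, eta::Real) where {T<:ToyRule} = set_eta(r, eta)
adjust_dispatch(r::ToyRule, ::Type{<:ToyRule}, eta::Real) = r

# A plain tuple stands in for an optimiser chain.
chain = (ToyDescent(0.1), ToyMomentum(0.01, 0.9))

# Only the ToyMomentum entry should change in either variant.
via_isa = map(r -> adjust_isa(r, ToyMomentum, 0.5), chain)
via_dispatch = map(r -> adjust_dispatch(r, ToyMomentum, 0.5), chain)
```

For concrete rule types like these, the two variants give the same result; the choice is mostly about readability, as the comments above note.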
This PR tries to address the issue in #201 where, if two Rules have the same parameter name, you can't use the adjust interface to control them independently.
This is intended to be non-breaking, so the previous behavior is undisturbed:
but you can now additionally specify the type you wish to change:
Implementation-wise, this added quite a few adjust and adjust! methods. I'm not 100% sure all of them are needed, so maybe we can prune a few back, but I wasn't sure which edge cases I might be missing, so I tried to cover everything.
In this PR's version, analogous to above, if you just pass in a float you'll change eta, but only in the optimizer you intend:
This is maybe less critical to have, so we could cut back all of the non-keyword methods that specify a type.
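For reference, the keyword-only form with a type filter might look roughly like the following. This is a sketch with toy types rather than the PR's actual methods: `ToyRule`, `ToySignDecay`, `ToyAdamW`, `set_fields` and `adjust_only` are hypothetical names, and a plain tuple stands in for the optimiser chain.

```julia
# Toy stand-ins for two rules that share a `lambda` field (hypothetical).
abstract type ToyRule end
struct ToySignDecay <: ToyRule
    lambda::Float64
end
struct ToyAdamW <: ToyRule
    eta::Float64
    lambda::Float64
end

# Rebuild a rule, replacing any field named in `kw`; in this toy version,
# keywords that do not match a field are silently ignored.
function set_fields(r::R; kw...) where {R<:ToyRule}
    vals = map(name -> get(kw, name, getfield(r, name)), fieldnames(R))
    return R(vals...)
end

# Type filter: only rules that are instances of `T` are touched.
adjust_only(r::ToyRule, ::Type{T}; kw...) where {T<:ToyRule} =
    r isa T ? set_fields(r; kw...) : r

chain = (ToySignDecay(0.1), ToyAdamW(0.001, 0.1))

# Without a type, every rule with a `lambda` field changes.
all_changed = map(r -> set_fields(r; lambda = 0.333), chain)

# With a type, only the ToySignDecay entry changes.
one_changed = map(r -> adjust_only(r, ToySignDecay; lambda = 0.333), chain)
```

The untyped call shows the footgun discussed in this thread (both `lambda` values move together), while the typed call leaves the `ToyAdamW` entry alone.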
Doing all this introduces a bit of core package developer complexity, but it is all generic so you shouldn't need to touch it when adding a new optimizer, and the usage seems pretty intuitive from a user perspective, so I don't think it adds interface complexity. To be honest, the current behavior seems like a footgun, and as a user with an OptimiserChain, I'd only ever use adjust/adjust! while specifying the type, because otherwise I'd have to check the parameter list of all the OptimiserChain components just to be sure I'm not doing something I don't intend.
PR Checklist