Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Controlling which optimizer rule you're adjusting. #202

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

murrellb
Copy link

This PR tries to address the issue in #201 where, if two Rules have the same parameter name, you can't use the adjust interface to control them independently.

This is intended to be non-breaking, so the previous behavior is undisturbed:
image
but you can now additionally specify the type you wish to change:
image

Implementation-wise, this added quite a few adjust and adjust! methods. I'm not 100% sure all of them are needed, so maybe we can prune a few back but I wasn't quite sure of the edge cases I'm not aware of, so I tried to cover everything.

I this PR's version, analogous to above, if you just pass in a float you'll change eta, but only in the optimizer you intend:
image
This is maybe less critical to have, so we could cut back all of the non-keyword methods that specify a type.

Doing all this introduces a bit of core package developer complexity, but it is all generic so you shouldn't need to touch it when adding a new optimizer, and the usage seems pretty intuitive from a user perspective, so I don't think it adds interface complexity. To be honest, the current behavior seems like a footgun, and as a user with an OptimizerChain, I'd only ever use adjust/adjust! while specifying the type, because otherwise I'd have to check the parameter list of all the OptimizerChain components just to be sure I'm not doing something I don't intend.

PR Checklist

  • Tests are added
  • Documentation, if applicable

@mcabbott
Copy link
Member

mcabbott commented Dec 21, 2024

I can't easily copy-paste from screenshots...

  1. The spelling proposed here is Optimisers.adjust!(os, SignDecay, lambda=0.3), passing just the type.

Alternatives are:

  1. Optimisers.adjust!(os, SignDecay(0.3)), passing a whole new rule.

  2. Changing e.g. SignDecay to use lambda1 and WeightDecay to use lambda2 to avoid the clash.

  3. Something more exotic, like adjust!(fun::Function, os, lambda=0.3) where this fun must return a Bool and gets... the rule?

  4. Leave it to you:

julia> using Optimisers, Functors

julia> os = Optimisers.setup(OptimiserChain(SignDecay(0.1), AdamW(lambda=0.1)), (; x=[12 34.]))
(x = Leaf(OptimiserChain(SignDecay(0.1), AdamW(0.001, (0.9, 0.999), 0.1, 1.0e-8, true)), (nothing, ([0.0 0.0], [0.0 0.0], (0.9, 0.999)))),)

julia> fmap(os; exclude = r -> r isa SignDecay) do r
         Optimisers.adjust(r, lambda=0.333)
       end
(x = Leaf(OptimiserChain(SignDecay(0.333), AdamW(0.001, (0.9, 0.999), 0.1, 1.0e-8, true)), (nothing, ([0.0 0.0], [0.0 0.0], (0.9, 0.999)))),)

Are there other options?

I suppose 3 doesn't help if you wish to adjust some learning rates and not others, we aren't going to rename them like eta37.

The advantage of 2 is you need never know the field names. The disadvantage is that adjust!(os, AdamW(lambda=0.3)) will reset eta too, you have to still have your original choices around.

@murrellb
Copy link
Author

This should be easy, convenient, and visible (so that the user is aware of the possible footgun). That rules out 5.

Optimisers.adjust!(os, SignDecay(0.3)), passing a whole new rule.

Will this replace a whole OptimiserChain, or just the rule that matches the type? The fact that I'm asking means that the user has to memorize/understand an extra convention to use this.

Something more exotic, like adjust!(fun::Function, os, lambda=0.3) where this fun must return a Bool and gets... the rule?

I'm not sure what I'd do with this that you can't do with the current PR, but with (slightly) more difficulty.

Another option is allowing you to label an optimizer component, with a default that matches its type. So you can go:

os = Optimisers.setup(OptimiserChain(SignDecay(0.1, label = "a"), AdamW(lambda=0.1)), (; x=[12 34.]))
Optimisers.adjust!(os, "a", lambda = 0.3)
Optimisers.adjust!(os, "AdamW", lambda = 0.01) #"AdamW" is the default label for AdamW.

This would let you have two optimisers of the same type in a chain, but label them differently so you can adjust them independently (which my current PR doesn't allow). However I can't think of motivation for needing this, and this would mean we have to introduce changes to all the current optimisers (and might invalidate checkpointed opt states, depending on how it is done?).

Still prefer the current PR's solution, but I'm not picky as long as it is easy to use and visible.

@mcabbott
Copy link
Member

mcabbott commented Dec 21, 2024

I agree 5 is messy.

But the point is that the state tree is some simple nested Julia object, which you are free to manipulate however you like -- you do not have to work through any explicit API of this package, and hence the API does not have to expand to allow arbitrarily complex manipulations.

Maybe I'm a little hesitant to add API too quickly, which is why I think we should think through alternatives carefully.

Adding labels seems way over the top to me.

or just the rule that matches the type?

Obviously just the rule, surely? If you have setup Momentum, and do adjust!(... AdamW...) in any form, it can't replace the whole rule because it can't change the saved momenta -- that's the whole point of adjust! vs. a fresh setup.

@murrellb
Copy link
Author

Obviously just the rule, surely? If you have setup Momentum, and do adjust!(... AdamW...) in any form, it can't replace the whole rule because it can't change the saved momenta -- that's the whole point of adjust! vs. a fresh setup.

This will only be "obvious" to a user who has both i) a very clear understanding of how this package does what it does, and ii) a mental model of the package devs. They also have to be very confident that there isn't some other reason (that they're currently unaware of) that doesn't eg. force them to have to re-specify the entire rule (eg. the package internals for some reason not handling automatic construction of the rest of the optimizer chain). Not obvious.

@CarloLucibello
Copy link
Member

I think the proposed solution is a light addition and intuitive addition to the interface and solves a concrete problem with no simple workaround, so we should go with it.

Copy link
Member

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for drafting this up. I think the only missing pieces are tests and an expanded docstring for adjust[!]?

#adjust with type control
###

adjust!(ℓ::Leaf, oT::Type, eta::Real) = (ℓ.rule = adjust(ℓ.rule, oT, eta); nothing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can all the Types here be changed to Type{<:AbstractRule}? It would provide some more information for users of the interface and ensure we're not promising support for arbitrary types.

adjust!(tree, oT::Type, eta::Real) = foreach(st -> adjust!(st, oT, eta), tree)
adjust!(tree, oT::Type; kw...) = foreach(st -> adjust!(st, oT; kw...), tree)

adjust(r::AbstractRule, oT::Type, eta::Real) = ifelse(isa(r, oT), adjust(r, eta), r)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These overloads are allocating the output of adjust(r, ...) even if it doesn't end up being used. Unless ifelse brings a demonstrable type stability benefit over a plain conditional, I'd recommend using a ternary instead:

Suggested change
adjust(r::AbstractRule, oT::Type, eta::Real) = ifelse(isa(r, oT), adjust(r, eta), r)
adjust(r::AbstractRule, oT::Type{<:AbstractRule}, eta::Real) = isa(r, oT) ? adjust(r, eta) : r

Copy link
Member

@mcabbott mcabbott Dec 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would guess the branch is resolved when compiling? Possibly it will prefer this style:

Suggested change
adjust(r::AbstractRule, oT::Type, eta::Real) = ifelse(isa(r, oT), adjust(r, eta), r)
adjust(r::AbstractRule, ::Type{T}, eta::Real) where {T<:AbstractRule} = isa(r, T) ? adjust(r, eta) : r

I think you can also make dispatch do this, although whether it's a good idea I don't know... less obvious to read:

adjust(r::C, oT::Type{A}, eta::Real) where {C <: A <: AbstractRule}  where {C<:AbstractRule} = adjust(r, eta)
adjust(r:: AbstractRule, oT::Type{<:AbstractRule}, eta::Real) = r

Edit, less obscure:

Suggested change
adjust(r::AbstractRule, oT::Type, eta::Real) = ifelse(isa(r, oT), adjust(r, eta), r)
adjust(r::T, oT::Type{T}, eta::Real) where {T<:AbstractRule} = adjust(r, eta)
adjust(r:: AbstractRule, oT::Type{<: AbstractRule}, eta::Real) = r

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants