Skip to content

Commit

Permalink
Merge pull request FluxML#62 from darsnack/scheduler-w-constants
Browse files Browse the repository at this point in the history
Add support for constants in `Scheduler`
  • Loading branch information
darsnack authored Mar 5, 2024
2 parents 463aab2 + 4ba2f57 commit e2f05d1
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/src/tutorials/optimizers.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Scheduling optimizers

A schedule by itself is not helpful; we need to use the schedules to adjust parameters. In this tutorial, we will examine three ways to do just that---iterating the schedule, using a stateful iterator, and using a scheduled optimizer.
A schedule by itself is not helpful; we need to use the schedules to adjust parameters. In this tutorial, we will examine three ways to do just that---iterating the schedule, using a stateful iterator, and using a scheduled optimizer. The final option is the preferred method for FluxML.

## Iterating during training

Expand Down
18 changes: 15 additions & 3 deletions src/scheduler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,25 @@ julia> opt = Scheduler(Descent, CosAnneal(l0 = 0.1, l1 = 0.8, period = 10));
# schedule learning rate and momentum of Momentum
julia> opt = Scheduler(Momentum, CosAnneal(l0 = 0.1, l1 = 0.8, period = 10), Exp(0.999, 0.8));
# schedule the weight decay term of AdamW
julia> opt = Scheduler(AdamW, decay = Exp(1e-3, 0.7));
# schedule the weight decay term of AdamW with a custom fixed learning rate
julia> opt = Scheduler(AdamW, eta = 1e-4, decay = Exp(1e-3, 0.7));
```
"""
struct Scheduler{T<:Union{<:Tuple, <:NamedTuple}, F} <: AbstractRule
    constructor::F    # callable that builds the underlying optimiser rule
    schedules::T      # per-hyperparameter schedules (positional or keyword)

    # Positional schedules: promote bare numbers to `Constant`, store as a tuple.
    function Scheduler(constructor, schedules::Tuple)
        wrapped = map(sched -> sched isa Number ? Constant(sched) : sched, schedules)
        return new{typeof(wrapped), typeof(constructor)}(constructor, wrapped)
    end
    # Keyword schedules: promote bare numbers to `Constant`, preserving key order.
    function Scheduler(constructor, schedules::NamedTuple{K}) where K
        wrapped = NamedTuple{K}(map(sched -> sched isa Number ? Constant(sched) : sched,
                                    schedules))
        return new{typeof(wrapped), typeof(constructor)}(constructor, wrapped)
    end
end
# Positional convenience form: gather trailing schedule arguments into a `Tuple`.
Scheduler(constructor, schedules...) = Scheduler(constructor, schedules)
# Keyword convenience form: gather keyword schedule arguments into a `NamedTuple`.
Scheduler(constructor; schedules...) = Scheduler(constructor, (; schedules...))
Expand All @@ -45,7 +57,7 @@ _get_opt(scheduler::Scheduler{<:Tuple}, t) =
# Build a fresh optimiser for step `t` by evaluating each keyword schedule at `t`
# and forwarding the results to the constructor as keyword arguments.
function _get_opt(scheduler::Scheduler{<:NamedTuple}, t)
    evaluated = map(sched -> sched(t), scheduler.schedules)

    return scheduler.constructor(; evaluated...)
end

Optimisers.init(o::Scheduler, x::AbstractArray) =
Expand Down
20 changes: 20 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,25 @@ end
m = m′
o = o′
end

# Schedule only `rho` by keyword; `eta` keeps Momentum's default of 0.01.
# NOTE(review): the `≈` operators were garbled out of this hunk in transit —
# `@test lhs rhs` is not valid `@test` syntax; restored `isapprox` comparisons.
o = Optimisers.setup(Scheduler(Optimisers.Momentum, rho = srho), m)
for t in 1:10
    g = Zygote.gradient(m -> sum(m.W * x + m.b), m)[1]
    o′, m′ = Optimisers.update(o, m, g)
    # Momentum step: velocity ← rho(t) * velocity + eta * grad; param ← param - velocity
    @test m′.W ≈ m.W - (srho(t) * o.W.state.opt + g.W * 0.01)
    @test m′.b ≈ m.b - (srho(t) * o.b.state.opt + g.b * 0.01)
    m = m′
    o = o′
end

# Pass a plain number for `rho`; the Scheduler should wrap it in `Constant`,
# so every step uses rho = 0.8.
# NOTE(review): the `≈` operators were garbled out of this hunk in transit —
# `@test lhs rhs` is not valid `@test` syntax; restored `isapprox` comparisons.
o = Optimisers.setup(Scheduler(Optimisers.Momentum, rho = 0.8), m)
for t in 1:10
    g = Zygote.gradient(m -> sum(m.W * x + m.b), m)[1]
    o′, m′ = Optimisers.update(o, m, g)
    # Momentum step with fixed rho: velocity ← 0.8 * velocity + 0.01 * grad
    @test m′.W ≈ m.W - (0.8 * o.W.state.opt + g.W * 0.01)
    @test m′.b ≈ m.b - (0.8 * o.b.state.opt + g.b * 0.01)
    m = m′
    o = o′
end
end
end

0 comments on commit e2f05d1

Please sign in to comment.