Change update! to update!! #116

Draft · wants to merge 1 commit into base: master
2 changes: 1 addition & 1 deletion docs/src/api.md
@@ -33,7 +33,7 @@ Optimisers.OptimiserChain
```@docs
Optimisers.setup
Optimisers.update
Optimisers.update!
Optimisers.update!!
Optimisers.adjust(::Any, ::Real)
```

2 changes: 1 addition & 1 deletion docs/src/index.md
@@ -60,7 +60,7 @@ Notice that a completely new instance of the model is returned. Internally, this
is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the
tree formed by the model and update the parameters using the gradients.

There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
There is also [`Optimisers.update!!`](@ref) which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
The method of `apply!` for each rule is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.
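The distinction this paragraph draws is easier to see in a small sketch (not part of this PR's diff; the toy model, gradient, and `Descent(0.1)` rule are invented for illustration, and it assumes a build of Optimisers.jl in which `update!!` is defined):

```julia
using Optimisers

m  = (x = Float32[1, 2], y = Float32[3, 4])   # toy "model": a NamedTuple of arrays
st = Optimisers.setup(Descent(0.1), m)        # tree of optimiser states

g  = (x = Float32[1, 1], y = Float32[1, 1])   # hand-written "gradient" with the same structure

st2, m2 = Optimisers.update(st, m, g)    # pure: works on defensive copies, m is untouched
st3, m3 = Optimisers.update!!(st, m, g)  # may re-use m's arrays and mutate st for speed

m2.x === m.x  # false: `update` allocated new arrays
m3.x === m.x  # true for ordinary Arrays: the same storage was re-used
```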
16 changes: 9 additions & 7 deletions src/Optimisers.jl
@@ -16,6 +16,8 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
WeightDecay, ClipGrad, ClipNorm, OptimiserChain


@deprecate update! update!! false
###
### one-array functions
###
@@ -70,7 +72,7 @@ init

Initialises the given optimiser for every trainable parameter within the model.
Returns a tree of the relevant states, which must be passed to [`update`](@ref)
or [`update!`](@ref).
or [`update!!`](@ref).

# Example
```jldoctest
@@ -113,7 +115,7 @@ Uses the optimiser and the gradient to change the trainable parameters in the model
Returns the improved model, and the optimiser states needed for the next update.
The initial tree of states comes from [`setup`](@ref).

See also [`update!`](@ref), which will be faster for models of ordinary `Array`s or `CuArray`s.
See also [`update!!`](@ref), which will be faster for models of ordinary `Array`s or `CuArray`s.

# Example
```jldoctest
@@ -131,7 +133,7 @@ julia> Optimisers.update(t, m, g)
update

"""
Optimisers.update!(tree, model, gradient) -> (tree, model)
Optimisers.update!!(tree, model, gradient) -> (tree, model)

Uses the optimiser and the gradient to change the trainable parameters in the model.
Returns the improved model, and the optimiser states needed for the next update.
Expand All @@ -154,12 +156,12 @@ julia> t = Optimisers.setup(Momentum(1/30, 0.9), m);
julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1]
(x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])

julia> t2, m2 = Optimisers.update!(t, m, g);
julia> t2, m2 = Optimisers.update!!(t, m, g);

julia> m2 # after update or update!, this is the new model
julia> m2 # this is the model with new parameters
(x = Float32[0.6666666, 1.5333333], y = Float32[3.6666667, 4.5333333])

julia> m2.x === m.x # update! has re-used this array, for efficiency
julia> m2.x === m.x # update!! has re-used this array, for efficiency
true

julia> m # original should be discarded, may be mutated but no guarantee
@@ -169,6 +171,6 @@ julia> t == t2 # original state is in fact guaranteed to be mutated
true
```
"""
update!
update!!

end # module
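For readers unfamiliar with `Base.@deprecate`: the new line in `src/Optimisers.jl` keeps the old name callable while steering callers to the new one, and the trailing `false` asks the macro not to export the deprecated name. Roughly, it expands to a forwarding method like the sketch below (an approximation of the macro's output, not code from this PR):

```julia
# Roughly what `@deprecate update! update!! false` generates
# (the real macro output differs in details; this is only a sketch):
function update!(args...; kwargs...)
    Base.depwarn("`update!` is deprecated, use `update!!` instead.", :update!)
    return update!!(args...; kwargs...)
end
```

Deprecation warnings are hidden by default in recent Julia releases unless the session is started with `--depwarn=yes`; package test suites usually enable them.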
4 changes: 3 additions & 1 deletion src/interface.jl
@@ -59,7 +59,9 @@ function update(tree, model, grad, higher...)
update!(t′, x′, grad, higher...)
end

function update!(tree, model, grad, higher...)
update!!(tree, model, grad, higher...) = old_update!(tree, model, grad, higher...)

function old_update!(tree, model, grad, higher...)
# First walk is to accumulate the gradient. This recursion visits every copy of
# shared leaves, but stops when branches are absent from the gradient:
grads = IdDict{Leaf, Any}()
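For background on the name: the `!!` suffix follows the "bang-bang" convention used by packages such as BangBang.jl, where a function may mutate its arguments when they support mutation but must otherwise return a fresh value, so callers always rely on the returned result. A toy illustration of the convention (not Optimisers.jl code; the function name is made up):

```julia
# `!!` convention: mutate the input when it supports mutation, otherwise build
# and return a new value; either way, callers use the return value.
setzero!!(x::AbstractArray) = (x .= 0; x)    # mutable input: overwrite in place
setzero!!(x::Tuple)         = map(_ -> 0, x) # immutable input: return a fresh value

v = [1, 2, 3]
setzero!!(v) === v   # true – the same array, now filled with zeros

t = (1, 2, 3)
setzero!!(t)         # (0, 0, 0) – a brand-new tuple, `t` is unchanged
```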