Add explicit train!, unify update!, and auto-translate the two Adams #2082

Merged · 18 commits · Nov 20, 2022
Changes from 1 commit
explicit train, take 2
mcabbott committed Nov 18, 2022

commit c20fb9eae8ef5891c8556c2da189580ff3c0ba40
1 change: 1 addition & 0 deletions NEWS.md
@@ -2,6 +2,7 @@

## v0.13.7
* Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078)
* New method of `train!` using Zygote's "explicit" mode, which allows changing the AD back-end.

## v0.13.4
* Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)
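A minimal sketch of the explicit-mode workflow this NEWS entry refers to (the `model`, `data`, and `loss` names here are illustrative; `Flux.setup` is the helper wired in by this PR's new `Train` module):

```julia
using Flux, Statistics

model = Dense(1 => 1)                                   # explicit model, no Params needed
data  = [(rand(Float32, 1, 10), rand(Float32, 1, 10))]  # iterable of (x, y) batches
loss(m, x, y) = mean(abs2.(m(x) .- y))                  # the model is now the loss's first argument

opt = Flux.setup(Adam(), model)      # per-model optimiser state
Flux.train!(loss, model, data, opt)  # new 4-arg explicit method
```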
8 changes: 6 additions & 2 deletions Project.toml
@@ -34,11 +34,13 @@ MacroTools = "0.5"
NNlib = "0.8.9"
NNlibCUDA = "0.2.4"
OneHotArrays = "0.1, 0.2"
Optimisers = "0.2.1"
Optimisers = "0.2.10"
ProgressLogging = "0.1"
Reexport = "0.2, 1.0"
SpecialFunctions = "1.8.2, 2.1.2"
StatsBase = "0.33"
Tracker = "0.2.22"
Yota = "0.8.1"
Zygote = "0.6.34"
julia = "1.6"

@@ -48,7 +50,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Yota = "cd998857-8626-517d-b929-70ad188a48f0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker", "Yota"]
12 changes: 6 additions & 6 deletions docs/src/models/overview.md
@@ -77,9 +77,9 @@ julia> predict(x_train)
In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions.

```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> loss(x, y) = Flux.Losses.mse(predict(x), y);
julia> loss(model, x, y) = mean(abs2.(model(x) .- y));

julia> loss(x_train, y_train)
julia> loss(predict, x_train, y_train)
122.64734f0
```

@@ -131,7 +131,7 @@ The first parameter is the weight and the second is the bias. Flux will adjust p
This optimiser implements the classic gradient descent strategy. Now improve the parameters of the model with a call to [`Flux.train!`](@ref) like this:

```jldoctest overview
julia> train!(loss, parameters, data, opt)
julia> train!(loss, predict, data, opt)
```

And check the loss:
@@ -156,10 +156,10 @@ In the previous section, we made a single call to `train!` which iterates over t

```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> for epoch in 1:200
train!(loss, parameters, data, opt)
train!(loss, predict, data, opt)
end

julia> loss(x_train, y_train)
julia> loss(predict, x_train, y_train)
0.00339581f0

julia> parameters
@@ -188,7 +188,7 @@ First, we gathered real-world data into the variables `x_train`, `y_train`, `x_t

Then, we built a single input, single output predictive model, `predict = Dense(1 => 1)`. The initial predictions weren't accurate, because we had not trained the model yet.

After building the model, we trained it with `train!(loss, parameters, data, opt)`. The loss function is first, followed by the `parameters` holding the weights and biases of the model, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process.
After building the model, we trained it with `train!(loss, predict, data, opt)`. The loss function is first, followed by the model itself, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process.

After we trained the model, we verified it with the test data to verify the results.
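To make the documentation change above concrete, here is a hedged side-by-side of the two calling conventions (the `loss_xy`/`loss_mxy` naming follows the deprecation message added in `src/deprecations.jl` below; `predict`, `data`, and `opt` are as in this guide):

```julia
# old implicit style: the loss closes over the model, parameters passed as Params
loss_xy(x, y) = Flux.Losses.mse(predict(x), y)
Flux.train!(loss_xy, Flux.params(predict), data, opt)

# new explicit style: the model is the first argument of both the loss and train!
loss_mxy(m, x, y) = Flux.Losses.mse(m(x), y)
Flux.train!(loss_mxy, predict, data, opt)
```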

4 changes: 4 additions & 0 deletions src/Flux.jl
@@ -34,6 +34,10 @@ export Descent, Adam, Momentum, Nesterov, RMSProp,
AdamW, RAdam, AdaBelief, InvDecay, ExpDecay,
WeightDecay, ClipValue, ClipNorm

include("train.jl")
using .Train
# using .Train: setup, @train_autodiff

using CUDA
const use_cuda = Ref{Union{Nothing,Bool}}(nothing)

62 changes: 62 additions & 0 deletions src/deprecations.jl
@@ -82,3 +82,65 @@ Base.@deprecate_binding ADAGrad AdaGrad
Base.@deprecate_binding ADADelta AdaDelta

@deprecate rng_from_array() default_rng_value()

#=
# Valid method in Optimise, old implicit style, is:
train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
# Valid methods in Train, new explicit style, are:
train!(loss, model, data, opt)
train!(loss, model, data, opt::Optimisers.AbstractRule)
# ... and 3-arg:
train!(loss, model, opt)
train!(loss, model, opt::Optimisers.AbstractRule)
# Provide friendly errors for what happens if you mix these up:
=#
import .Optimise: train!
train!(loss, ps::Params, data, opt) = error("can't mix implicit Params with explicit state")
train!(loss, ps::Params, opt) = error("can't mix implicit Params with explicit state")

train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error("can't mix implicit Params with explicit rule")
train!(loss, ps::Params, opt::Optimisers.AbstractRule) = error("can't mix implicit Params with explicit rule")

train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))
train!(loss, model, opt::Optimise.AbstractOptimiser) = train!(loss, model, _old_to_new(opt))

train!(loss, ps::Params, opt::Optimise.AbstractOptimiser; cb=0) = error("3-arg train does not exist for implicit mode")

# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.
# Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
# where `loss_mxy` accepts the model as its first argument.
# """
# ))

# Next, to use the new `setup` with the still-exported old-style Adam etc:
import .Train: setup
setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)

for T in [:Descent, :Adam, :Momentum, :Nesterov,
:AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief,
# :InvDecay, :ExpDecay,
]
@eval function _old_to_new(rule::$T)
args = map(f -> getfield(rule, f), fieldnames(Optimisers.$T))
Optimisers.$T(args...)
end
end
_old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too.
_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd) # called gamma now
_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields
_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs
const ClipGrad = Optimise.ClipValue
_old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred

_old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule")

Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = error("please use Flux.setup rather than Optimisers.setup; it may be able to translate this rule")

# v0.14 deprecations

# Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc:
# Base.@deprecate_binding Optimiser OptimiserChain
# Base.@deprecate_binding ClipValue ClipGrad
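As a hedged illustration of what this translation layer does (the model and hyperparameters below are arbitrary), an old-style composed `Optimiser` passed to `Flux.setup` is converted component-by-component via `_old_to_new`, yielding an `Optimisers.OptimiserChain`:

```julia
using Flux

model = Dense(3 => 2)

# old-style composed rule, built from the still-exported Flux.Optimise types
old_rule = Flux.Optimise.Optimiser(WeightDecay(1f-4), Adam(1f-3))

# setup translates each component, roughly:
#   Optimisers.OptimiserChain(Optimisers.WeightDecay(1f-4), Optimisers.Adam(1f-3, ...))
state = Flux.setup(old_rule, model)
```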
13 changes: 12 additions & 1 deletion src/optimise/train.jl
@@ -1,16 +1,23 @@
using ProgressLogging: @progress, @withprogress, @logprogress
import Zygote: Params, gradient, withgradient

# Add methods to Optimisers.jl's function, so that there is just one Flux.update!
# for both explicit and implicit parameters.
import Optimisers.update!

"""
update!(opt, p, g)
update!(opt, ps::Params, gs)

Perform an update step of the parameters `ps` (or the single parameter `p`)
according to optimizer `opt` and the gradients `gs` (the gradient `g`).
according to optimizer `opt::AbstractOptimiser` and the gradients `gs` (the gradient `g`).

As a result, the parameters are mutated and the optimizer's internal state may change.
The gradient could be mutated as well.

!!! note
This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.14.
The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain.
"""
function update!(opt::AbstractOptimiser, x, x̄)
x̄r = copyto!(similar(x̄), x̄) # Flux.Optimise assumes it can mutate the gradient. This is not
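A rough sketch of the two `update!` styles the unified docstring now covers (the explicit form comes from Optimisers.jl; every name below is illustrative):

```julia
using Flux

model = Dense(2 => 1)
x, y  = rand(Float32, 2, 8), rand(Float32, 1, 8)

# explicit style (kept after 0.14): state tree from setup, gradient w.r.t. the model
state = Flux.setup(Descent(0.1), model)
grad  = gradient(m -> Flux.mse(m(x), y), model)[1]
Flux.update!(state, model, grad)

# implicit style (to be removed in 0.14): Params plus an AbstractOptimiser
ps = Flux.params(model)
gs = gradient(() -> Flux.mse(model(x), y), ps)
Flux.update!(Descent(0.1), ps, gs)
```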
@@ -88,6 +95,10 @@ batchmemaybe(x::Tuple) = x
Uses a `loss` function and training `data` to improve the
model's parameters according to a particular optimisation rule `opt`.

!!! note
This method with implicit `Params` will be removed from Flux 0.14.
It should be replaced with the explicit method `train!(loss, model, data, opt)`.

For each `d in data`, first the gradient of the `loss` is computed like this:
```
gradient(() -> loss(d...), pars) # if d isa Tuple