Rule for mixed precision training #152

Open
wants to merge 9 commits into
base: master
Changes from 5 commits
3 changes: 2 additions & 1 deletion docs/src/api.md
@@ -25,8 +25,9 @@ In addition to the main course, you may wish to order some of these condiments:
Optimisers.AccumGrad
Optimisers.ClipGrad
Optimisers.ClipNorm
-Optimisers.WeightDecay
+Optimisers.MixedPrecision
Optimisers.OptimiserChain
+Optimisers.WeightDecay
```

## Model Interface
2 changes: 1 addition & 1 deletion src/Optimisers.jl
@@ -15,7 +15,7 @@ include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
-AccumGrad
+AccumGrad, MixedPrecision

###
### one-array functions
65 changes: 65 additions & 0 deletions src/rules.jl
@@ -752,3 +752,68 @@ function apply!(o::AccumGrad, state, x, dx)
return (accum_dx, counter + 1), nothing
end
end

"""
    MixedPrecision(opt)
    MixedPrecision{T}(opt)

An optimiser that wraps another optimiser `opt` in order to perform mixed precision
training [1].

The state of `MixedPrecision` will contain a copy in precision `T` of the trainable parameter `x`,
call it `xT`.
The internal state of `opt` also operates at precision `T`.
If `T` is not specified, it defaults to `Float32`.
CarloLucibello marked this conversation as resolved.

Call `g` the gradient of `x`. Both `g` and `x` are typically in a precision lower than `T`
(e.g. `Float16`).

In the `update!(opt_state, x, g)` call, `opt` is used to update `xT` instead of `x`,
then `x` is updated with the value of `xT`.

[1] Micikevicius et al. '17, "Mixed Precision Training", https://arxiv.org/abs/1710.03740 .
Review comment (Member):

Suggested change
[1] Micikevicius et al. '17, "Mixed Precision Training", https://arxiv.org/abs/1710.03740 .
# Reference
[1] Micikevicius et al., "Mixed Precision Training", ICLR 2018, https://arxiv.org/abs/1710.03740 .


# Examples

```julia
x = rand(Float16, 2) # a trainable parameter in low precision

opt = MixedPrecision(Adam(1e-3))
opt_state = Optimisers.setup(opt, x) # the state contains a copy of x in Float32 precision

g = rand(Float16, 2) # a gradient in low precision

# the update is accumulated in high precision,
# then the low-precision x is synced to match
Optimisers.update!(opt_state, x, g)
```
"""
struct MixedPrecision{T<:Number, O<:AbstractRule} <: AbstractRule
  opt::O
Review comment (Member):

Maybe this should be rule?

Suggested change
opt::O
rule::O

end

@functor MixedPrecision

MixedPrecision(opt::AbstractRule) = MixedPrecision{Float32, typeof(opt)}(opt)
MixedPrecision{T}(opt::AbstractRule) where T = MixedPrecision{T, typeof(opt)}(opt)

function init(o::MixedPrecision{T}, x::AbstractArray) where T
  xT = T.(x)
  return (xT, init(o.opt, xT))
end

function apply!(o::MixedPrecision{T}, state, x, dx) where T
  xT, st = state
  st′, dx′ = apply!(o.opt, st, xT, dx)
  xT = subtract!(xT, dx′)
  if maywrite(x)
    x .= xT
    dx′ = nothing
darsnack marked this conversation as resolved.
Review comment (Member):

I think this is correct.

But perhaps weird things will happen if you try to compose it, e.g. OptimiserChain(MixedPrecision(...), ClipGrad()). If so then we should make sure such things give an error.

  else
    dx′ = x .- eltype(x).(xT)
Review comment (Member):

On this path, should the subtraction happen in high or low precision, does it matter?

This is the sort of place that I worry about scaling & the range of Float16. But haven't thought hard.
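
One way to probe this question is a small standalone check in Base Julia. The values are hypothetical and the names `low` and `high` are illustrative only, not part of the PR. It compares the two orders of operation for a step smaller than the Float16 spacing just below 1 (about 4.9e-4):

```julia
# Hypothetical values: a Float16 parameter and a Float32 master copy
# that has moved by less than the Float16 spacing just below 1.
x  = Float16[1.0]
xT = Float32[1.0 - 1e-4]

# As in the diff: round xT to Float16 first, then subtract in Float16.
low = x .- eltype(x).(xT)

# Alternative: subtract in Float32, then round the difference to Float16.
high = eltype(x).(eltype(xT).(x) .- xT)

@show low high
```

Here the Float16 subtraction returns exactly zero, because `xT` rounds back to `1.0`, while the Float32 subtraction leaves a small non-zero difference. So the two paths can return different `dx′`; whether that changes the final `x` depends on how the caller applies `dx′`.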

  end
  return (xT, st′), dx′
end

adjust(o::MixedPrecision, eta::Real) = MixedPrecision(adjust(o.opt, eta))
adjust(o::MixedPrecision; kw...) = MixedPrecision(adjust(o.opt; kw...))
Comment on lines +816 to +817
Review comment (Member):

Don't these forget T?
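
A possible fix, sketched with stand-in types so that it runs without Optimisers.jl. `AbstractRule`, `ToyDescent`, and the `adjust` methods below are simplified assumptions for illustration, not the package's definitions. The idea is to capture `T` in the method signature and pass it back to the constructor:

```julia
# Stand-in definitions (assumptions, not the Optimisers.jl source).
abstract type AbstractRule end

struct ToyDescent <: AbstractRule
  eta::Float64
end

struct MixedPrecision{T<:Number, O<:AbstractRule} <: AbstractRule
  opt::O
end

MixedPrecision(opt::AbstractRule) = MixedPrecision{Float32, typeof(opt)}(opt)
MixedPrecision{T}(opt::AbstractRule) where {T<:Number} = MixedPrecision{T, typeof(opt)}(opt)

# Toy adjust for the wrapped rule: replace the learning rate.
adjust(o::ToyDescent, eta::Real) = ToyDescent(eta)
adjust(o::ToyDescent; eta = o.eta) = ToyDescent(eta)

# Capture T and re-apply it, so adjusting does not reset the precision to Float32.
adjust(o::MixedPrecision{T}, eta::Real) where {T} = MixedPrecision{T}(adjust(o.opt, eta))
adjust(o::MixedPrecision{T}; kw...) where {T} = MixedPrecision{T}(adjust(o.opt; kw...))

mp = adjust(MixedPrecision{Float64}(ToyDescent(0.1)), 0.2)
@show typeof(mp)
```

With the two-line version in the diff, the same call would go through `MixedPrecision(opt)` and come back with the default `Float32` parameter instead of `Float64`.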

19 changes: 18 additions & 1 deletion test/rules.jl
@@ -9,6 +9,7 @@ RULES = [
Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(),
MixedPrecision{Float64}(Adam()),
# A few chained combinations:
OptimiserChain(WeightDecay(), Adam(0.001)),
OptimiserChain(ClipNorm(), Adam(0.001)),
@@ -266,4 +267,20 @@ end

tree, x4 = Optimisers.update(tree, x3, g4)
@test x4 ≈ x3
end
end

@testset "MixedPrecision" begin
  x = rand(Float16, 2)
  opt_state = Optimisers.setup(MixedPrecision(Adam(1e-3)), x)
  @test opt_state.state[1] isa Vector{Float32}
  @test opt_state.state[2][1] isa Vector{Float32}
  g = rand(Float16, 2)
  # pass the same gradient `g` that the final check refers to
  new_state, new_x = Optimisers.update(opt_state, x, g)
  @test new_x == Float16.(new_state.state[1])
  @test new_x ≈ x .- 1e-3 .* g

  x = rand(Float16, 2)
  opt_state = Optimisers.setup(MixedPrecision{Float64}(Adam(1e-3)), x)
  @test opt_state.state[1] isa Vector{Float64}
  @test opt_state.state[2][1] isa Vector{Float64}
end
28 changes: 28 additions & 0 deletions test/runtests.jl
@@ -258,6 +258,20 @@ y2z(x) = x
@test sc2.γ.rule.opts[1].delta == 2.5
@test sc2.γ.rule.opts[2].eta === 0.001f0 # unchanged
@test sc2.γ.state[2][1] ≈ [0.1, 0.2, 0.2]

# MixedPrecision
mp = Optimisers.setup(MixedPrecision(Momentum(0.1, 0.9)), m)
mp1, mp2 = Optimisers.update(mp, m, (α = nothing, γ = [1,10,100],))
@test mp1.γ.rule.opt.eta == 0.1
@test mp1.γ.state[2] ≈ [0.1, 1, 10]

mp2 = Optimisers.adjust(mp1, 0.2)
@test mp2.γ.rule.opt.eta == 0.2
@test mp2.γ.rule.opt.rho == 0.9

mp3 = Optimisers.adjust(mp1; eta=0.3, rho=0.7)
@test mp3.γ.rule.opt.eta == 0.3
@test mp3.γ.rule.opt.rho == 0.7
end

@testset "adjusting parameters, in-place" begin
@@ -302,6 +316,20 @@ y2z(x) = x
@test sc1.γ.rule.opts[1].delta == 2.5
@test sc1.γ.rule.opts[2].eta === 0.2f0 # unchanged
@test sc1.γ.state[2][1] ≈ [0.1, 0.2, 0.2]

# MixedPrecision
mp = Optimisers.setup(MixedPrecision(Momentum(0.1, 0.9)), m)
mp1, mp2 = Optimisers.update(mp, m, (α = nothing, γ = [1,10,100],))
@test mp1.γ.rule.opt.eta == 0.1
@test mp1.γ.state[2] ≈ [0.1, 1, 10]

Optimisers.adjust!(mp1, 0.2)
@test mp1.γ.rule.opt.eta == 0.2
@test mp1.γ.rule.opt.rho == 0.9

Optimisers.adjust!(mp1; eta=0.3, rho=0.7)
@test mp1.γ.rule.opt.eta == 0.3
@test mp1.γ.rule.opt.rho == 0.7
end

@testset "freeze/thaw" begin