some minor style changes #135

Open
wants to merge 2 commits into master
4 changes: 4 additions & 0 deletions src/Optimisers.jl
@@ -2,6 +2,8 @@ module Optimisers

using Functors: functor, fmap, isleaf, @functor, fmapstructure, children, AbstractWalk
using LinearAlgebra
using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero, ZeroTangent
using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk

include("interface.jl")
export AbstractRule
@@ -16,6 +18,8 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion

include("deprecations.jl")

###
### one-array functions
###
12 changes: 12 additions & 0 deletions src/deprecations.jl
@@ -0,0 +1,12 @@
# To be removed in Optimisers v0.3

@deprecate iswriteable maywrite false # remove when releasing [email protected]

@deprecate ADAM Adam
@deprecate NADAM NAdam
@deprecate ADAMW AdamW
@deprecate RADAM RAdam
@deprecate OADAM OAdam
@deprecate ADAGrad AdaGrad
@deprecate ADADelta AdaDelta
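As a sanity check on what this new file does: each `Base.@deprecate` line turns the old name into a thin forwarding method that emits a deprecation warning and then calls the new constructor, and the trailing `false` on the `iswriteable` line suppresses re-exporting the old name. A minimal sketch of the resulting behaviour (the learning rate and the REPL-style comments are only illustrative):

```julia
using Optimisers

# The old all-caps name still works: it warns (when depwarns are enabled,
# e.g. `julia --depwarn=yes`) and forwards all arguments to `Adam`.
opt = ADAM(1e-3)     # WARNING: ADAM is deprecated, use Adam instead.

opt isa Adam         # true: the deprecated name builds the new rule type
```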

1 change: 0 additions & 1 deletion src/destructure.jl
@@ -1,5 +1,4 @@

using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
const NoT = NoTangent()

"""
50 changes: 28 additions & 22 deletions src/interface.jl
@@ -1,7 +1,7 @@

using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero, ZeroTangent
base(dx::Tangent) = backing(canonicalize(dx))
base(dx) = dx

const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}

abstract type AbstractRule end
@@ -15,6 +15,7 @@ mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing...
state::S
frozen::Bool # ... and to allow freeze! to act on this.
end

Leaf(rule, state; frozen::Bool = false) = Leaf(rule, state, frozen)

@functor Leaf
@@ -23,9 +24,9 @@ Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b)

function setup(rule::AbstractRule, model)
cache = IdDict()
tree = _setup(rule, model; cache)
state = _setup(rule, model; cache)
isempty(cache) && @warn "setup found no trainable parameters in this model"
tree
state
end

# _setup is almost fmapstructure, but needs a _trainable_walk, and a cache which ignores numbers etc.
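The `tree` to `state` rename in `setup` is purely cosmetic: it still returns a structure mirroring the model, with a `Leaf` wrapping the rule and its optimiser state at every trainable array, and `()` elsewhere. A hedged sketch with a toy NamedTuple model (the field names are made up for illustration):

```julia
using Optimisers

model = (weight = rand(3, 3), bias = zeros(3), fun = sin)

state = Optimisers.setup(Adam(1e-3), model)
# `state` mirrors `model`: a Leaf carrying Adam's moment buffers for `weight`
# and `bias`, and `()` for the non-trainable `fun` field.
```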
@@ -56,38 +57,38 @@ end
### update
###

function update(tree, model, grad, higher...)
t′ = fmap(copy, tree; exclude = maywrite) # walks inside Leaf
function update(state, model, grad, higher...)
t′ = fmap(copy, state; exclude = maywrite) # walks inside Leaf
x′ = fmap(copy, model; exclude = maywrite)
update!(t′, x′, grad, higher...)
end

function update!(tree, model, grad, higher...)
function update!(state, model, grad, higher...)
# First walk is to accumulate the gradient. This recursion visits every copy of
# shared leaves, but stops when branches are absent from the gradient:
grads = IdDict{Leaf, Any}()
_grads!(grads, tree, model, grad, higher...)
# Second walk is to update the model. The params cache indexed by (tree,x),
_grads!(grads, state, model, grad, higher...)
# Second walk is to update the model. The params cache indexed by (state,x),
# so that identified Leafs can tie isbits parameters, but setup won't do that for you:
newmodel = _update!(tree, model; grads, params = IdDict())
tree, newmodel # note that tree is guaranteed to be updated. Also that it's not necc a tree.
newmodel = _update!(state, model; grads, params = IdDict())
state, newmodel # Note that state is guaranteed to be updated. Also that it's not necessarily a tree.
end

function _update!(tree, x; grads, params)
haskey(params, (tree,x)) && return params[(tree,x)]
isbits(tree) && return x # means () is not cached, and also (((),),)
function _update!(state, x; grads, params)
haskey(params, (state, x)) && return params[(state, x)]
isbits(state) && return x # means () is not cached, and also (((),),)
x′, re = functor(x)
x′′ = re(valuemap((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
x′′ = re(valuemap((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), state, x′))
if ismutable(x′′)
params[(tree,x)] = x′′
params[(state, x)] = x′′
else # no ties to preserve between immutable structs, right?
x′′
end
end
function _update!(ℓ::Leaf, x; grads, params)
haskey(params, (ℓ,x)) && return params[(ℓ,x)]
haskey(params, (ℓ, x)) && return params[(ℓ, x)]
ℓ.frozen && return x
params[(ℓ,x)] = if haskey(grads, ℓ)
params[(ℓ, x)] = if haskey(grads, ℓ)
ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...)
subtract!(x, x̄′)
else
@@ -98,18 +99,21 @@ end
subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)

_grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing

function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...)
x̄s₀ = get(dict, ℓ, map(_ -> ZeroTangent(), x̄s))
dict[ℓ] = map(+, x̄s, x̄s₀) # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible.
nothing
end

_grads!(dict::IdDict, t, x, ::Zero...) = nothing
function _grads!(dict::IdDict, tree, x, x̄s...)

function _grads!(dict::IdDict, state, x, x̄s...)
# The only reason _grads! takes model is that functor(typeof(x), base(x̄)) may differ from
# functor(typeof(tree), base(x̄)), for things like Transpose
# functor(typeof(state), base(x̄)), for things like Transpose
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
x′, _ = functor(typeof(x), x)
valueforeach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
valueforeach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), state, x′, x̄s′...)
end

# default all rules to first order calls
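The `IdDict` keyed on `Leaf`, together with the `(state, x)` params cache above, is what makes tied parameters behave: every occurrence of a shared array adds into the same accumulated gradient, and the array itself is updated once. A sketch of that behaviour, assuming a toy model that binds one array in two places (the names and numbers are illustrative):

```julia
using Optimisers

w = ones(2, 2)
model = (a = w, b = w)                 # the same array, tied in two places

state = Optimisers.setup(Descent(0.1), model)
state.a === state.b                    # true: one shared Leaf for the tied array

grad = (a = ones(2, 2), b = ones(2, 2))
state, model = Optimisers.update!(state, model, grad)

# Both contributions were accumulated before the single in-place update,
# so each element moved by 0.1 * (1 + 1) = 0.2:
model.a[1, 1] ≈ 0.8                    # and model.a === model.b still holds
```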
@@ -142,8 +146,6 @@ For now, simply `x isa DenseArray` allowing `Array`, `CuArray`, etc.
maywrite(::DenseArray) = true # see https://github.com/FluxML/Optimisers.jl/issues/99 for discussion
maywrite(_) = false

@deprecate iswriteable maywrite false # remove when releasing [email protected]

"""
trainable(x::Layer) -> NamedTuple

@@ -175,6 +177,7 @@ end

valuemap(f, x...) = map(f, x...)
valuemap(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)

valueforeach(f, x...) = foreach(f, x...)
valueforeach(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v)
f(v, (get(y, k, nothing) for y in ys)...)
@@ -216,8 +219,11 @@ macro lazy(ex)
end

function lazy end

Broadcast.broadcasted(::typeof(lazy), x) = Lazy(x)

struct Lazy{T}; bc::T; end

Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)

onevalue(λ::T, x::AbstractArray{T}) where T = map(_ -> λ, x)
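For completeness, the two public entry points touched in this file keep their documented contract: `update` copies the state and the model before delegating to `update!`, while `update!` writes into whatever `maywrite` allows and returns the new state and model, which should always replace the inputs. A short usage sketch (the toy model and gradient are illustrative):

```julia
using Optimisers

model = (weight = rand(3),)
state = Optimisers.setup(Momentum(0.01), model)
grad  = (weight = ones(3),)

# Non-mutating: copies state and model first, so the originals are left alone.
new_state, new_model = Optimisers.update(state, model, grad)

# Mutating variant, the usual choice in a training loop.
# Always keep the returned values; the inputs may or may not have been written to.
state, model = Optimisers.update!(state, model, grad)
```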
8 changes: 0 additions & 8 deletions src/rules.jl
@@ -1,11 +1,3 @@
@deprecate ADAM Adam
@deprecate NADAM NAdam
@deprecate ADAMW AdamW
@deprecate RADAM RAdam
@deprecate OADAM OAdam
@deprecate ADAGrad AdaGrad
@deprecate ADADelta AdaDelta

"""
Descent(η = 1f-1)
