Add NewRecur for Blocked RNNs #2316

Closed · wants to merge 2 commits
53 changes: 53 additions & 0 deletions recur_funcs.jl
@@ -0,0 +1,53 @@
using Flux

function run_new_recur()
    cell = Flux.RNNCell(1, 1, identity)
    layer = Flux.Recur(cell)
    layer.cell.Wi .= 5.0
    layer.cell.Wh .= 4.0
    layer.cell.b .= 0.0f0
    layer.cell.state0 .= 7.0
    x = [[2.0f0], [3.0f0]]

    # Theoretical primal and gradients, from unrolling the two-step linear
    # recurrence h2 = Wh*(Wh*state0 + Wi*x1) + Wi*x2.
    primal =
        layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
        x[2] .* layer.cell.Wi
    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
    ∇b = layer.cell.Wh .+ 1
    ∇state0 = layer.cell.Wh .^ 2

    x_block = reshape(reduce(vcat, x), 1, 1, length(x))
    nm_layer = Flux.NewRecur(cell; return_sequence = true)
    _out = layer(x_block) # stateful reference run (unused)
    e, g = Flux.withgradient(nm_layer) do layer
        out = layer(x_block)
        sum(out[1, 1, end])
    end
    grads = g[1][:cell]

    @show primal[1] ≈ e
    @show ∇Wi ≈ grads[:Wi]
    @show ∇Wh ≈ grads[:Wh]
    @show ∇b ≈ grads[:b]
    @show ∇state0 ≈ grads[:state0]

    return
end

function run_scan_full()
    x = [[2.0f0], [3.0f0], [4.0f0]]
    x_block = reshape(reduce(vcat, x), 1, 1, length(x))
    w = zeros(1)
    # Smoke test outside the gradient context.
    _out = Flux.scan_full((a, b) -> (sum(w .* b), sum(w .* b)), 0.0f0, x_block)
    e, g = Flux.withgradient(w) do w
        # scan_full returns (final_carry, outputs); take the last output.
        _, ys = Flux.scan_full((a, b) -> (sum(w .* b), sum(w .* b)), 0.0f0, x_block)
        ys[end]
    end
    grads = g[1]
    return e, grads
end
195 changes: 194 additions & 1 deletion src/layers/recurrent.jl
@@ -51,6 +51,185 @@ end

reshape_cell_output(h, x) = reshape(h, :, size(x)[2:end]...)



# non-stateful recurrence

"""
scan_full

Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, then returns the full output of the sequence and the final carry. See `scan_partial` to only return the final output of the sequence.
"""
function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray})
    # Recurrence operation used in the fold. Takes the state of the
    # fold and the next input; returns the new state.
    function recurrence_op((carry, outputs), input)
        carry, out = func(carry, input)
        return carry, vcat(outputs, [out])
    end
    # Fold left to right.
    return Base.mapfoldl_impl(identity, recurrence_op, (init_carry, empty(xs)), xs)
end

function scan_full(func, init_carry, x_block)
    # x_block is an AbstractArray; scan over its last dimension.
    xs_ = Flux.eachlastdim(x_block)

    # Needed due to a bug in eachlastdim, which produces a vector in a
    # gradient context but a generator otherwise.
    xs = xs_ isa Base.Generator ? collect(xs_) : xs_
    scan_full(func, init_carry, xs)
end
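
# A minimal usage sketch (illustrative, not part of this PR's API): a running
# sum over a 1×1×3 block. The element-wise ops and Float32 literals here are
# assumptions for the demo, not requirements of scan_full.
#
#   xs = reshape(Float32[1, 2, 3], 1, 1, 3)
#   carry, ys = scan_full((c, x) -> (c .+ x, c .+ x), zeros(Float32, 1, 1), xs)
#   # carry holds the total (6), and ys collects the partial sums 1, 3, 6.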

# Chain rule for Base.mapfoldl_impl, cf. the Lux.jl version kept commented out
# below. Each entry of `hobbits` stores the primal carry and the pullback for
# one step of the fold; the first step is handled separately in `accum_init`
# because its input types can differ from the rest of the sequence.
function ChainRulesCore.rrule(
    config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
    ::typeof(Base.mapfoldl_impl),
    ::typeof(identity),
    op::G,
    init,
    x::Union{AbstractArray, Tuple};
) where {G}
    accum_init = ChainRulesCore.rrule_via_ad(config, op, init, x[1])
    hobbits = accumulate(x[begin+1:end]; init=accum_init) do (a, _), b
        c, back = ChainRulesCore.rrule_via_ad(config, op, a, b)
    end

    y = first(last(hobbits))
    axe = axes(x)
    project = ChainRulesCore.ProjectTo(x)
    function unfoldl(dy)
        # Walk the pullbacks right to left, accumulating cotangents for the
        # operator, the carry, and each input.
        trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
            ds, da, db = back(dc)
        end
        f_ds, f_da, f_db = accum_init[2](trio[end][2])
        dop = sum(first, trio) + f_ds
        dx = [[f_db]; map(last, Iterators.reverse(trio))]
        d_init = f_da
        return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe)))
    end
    return y, unfoldl
end
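
# A gradient sketch exercising the rrule above, mirroring run_scan_full in
# recur_funcs.jl (assumes Zygote, which Flux loads, drives the closure):
#
#   w = ones(Float32, 1)
#   xs = reshape(Float32[2, 3], 1, 1, 2)
#   val, grad = Flux.withgradient(w) do w
#       _, ys = scan_full((c, x) -> (c + sum(w .* x), c + sum(w .* x)), 0.0f0, xs)
#       ys[end]
#   end
#   # val == 5.0f0 and grad[1] == [5.0f0], since d(w * (2 + 3)) / dw = 5.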

# From Lux.jl
# function ChainRulesCore.rrule(
# config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
# ::typeof(Base.mapfoldl_impl),
# ::typeof(identity),
# op::G,
# init,
# x::Union{AbstractArray, Tuple};
# ) where {G}
# hobbits = Vector{Any}(undef, length(x)) # Unfortunately Zygote needs this
# accumulate!(hobbits, x; init=(init, nothing)) do (a, _), b
# c, back = ChainRulesCore.rrule_via_ad(config, op, a, b)
# end
# y = first(last(hobbits))
# axe = axes(x)
# project = ChainRulesCore.ProjectTo(x)
# function unfoldl(dy)
# trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
# ds, da, db = back(dc)
# end
# dop = sum(first, trio)
# dx = map(last, Iterators.reverse(trio))
# @show dx
# d_init = trio[end][2]
# return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe)))
# end
# return y, unfoldl
# end


"""
scan_partial

Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, then returns the final output of the sequence and the final carry. See `scan_full` to return the entire output sequence.
"""
function scan_partial(func, init_carry, xs::AbstractVector{<:AbstractArray})
    x_init, x_rest = Iterators.peel(xs)
    (carry, y) = func(init_carry, x_init)
    for x in x_rest
        (carry, y) = func(carry, x)
    end
    carry, y
end

function scan_partial(func, init_carry, x_block)
    # x_block is an AbstractArray; scan over its last dimension.
    xs_ = Flux.eachlastdim(x_block)

    # Needed due to a bug in eachlastdim, which produces a vector in a
    # gradient context but a generator otherwise.
    xs = xs_ isa Base.Generator ? collect(xs_) : xs_
    scan_partial(func, init_carry, xs)
end
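
# A minimal usage sketch: the same recurrence as the scan_full example above,
# but only the final carry and final output are returned.
#
#   xs = reshape(Float32[1, 2, 3], 1, 1, 3)
#   carry, y = scan_partial((c, x) -> (c .+ x, c .+ x), zeros(Float32, 1, 1), xs)
#   # carry == y, both holding the total sum 6.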


"""
NewRecur
New Recur. An experimental recur interface for removing statefullness in recurrent architectures for flux. This struct has two type parameters. The first `RET_SEQUENCE` is a boolean which determines whether `scan_full` (`RET_SEQUENCE=true`) or `scan_partial` (`RET_SEQUENCE=false`) is used to scan through the sequence. This structure has no internal state, and instead returns:

```julia
l = NewRNN(1,2)
xs # Some input array Input x BatchSize x Time
init_carry # the initial carry of the cell.
l(xs) # -> returns the output of the RNN, uses cell.state0 as init_carry.
l(init_carry, xs) # -> returns (final_carry, output), where the size ofoutput is determined by RET_SEQUENCE.
```
"""
struct NewRecur{RET_SEQUENCE, T}
    cell::T
    function NewRecur(cell; return_sequence::Bool=false)
        new{return_sequence, typeof(cell)}(cell)
    end
    function NewRecur{true}(cell)
        new{true, typeof(cell)}(cell)
    end
    function NewRecur{false}(cell)
        new{false, typeof(cell)}(cell)
    end
end

Flux.@functor NewRecur
Flux.trainable(a::NewRecur) = (; cell = a.cell)
Base.show(io::IO, m::NewRecur) = print(io, "NewRecur(", m.cell, ")")

(l::NewRecur)(init_carry, x_mat::AbstractMatrix) =
    throw(ArgumentError("Matrix input is ambiguous with NewRecur; pass a 3-dimensional array or a vector of arrays."))
(l::NewRecur)(init_carry, x_vec::AbstractVector{T}) where {T<:Number} =
    throw(ArgumentError("Vector input is ambiguous with NewRecur; pass a 3-dimensional array or a vector of arrays."))

function (l::NewRecur)(xs::AbstractArray)
    results = l(l.cell.state0, xs)
    return results[2] # Only return the output, not the carry.
end

function (l::NewRecur{false})(init_carry, xs)
    carry, y = scan_partial(l.cell, init_carry, xs)
    return carry, y
end

function (l::NewRecur{true})(init_carry, xs)
    carry, ys = scan_full(l.cell, init_carry, xs)
    return carry, stack(ys, dims=3)
end
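
# An end-to-end sketch (the sizes below are illustrative assumptions): run a
# NewRecur-wrapped RNNCell over a Features × Batch × Time block, either from
# cell.state0 or from an explicit initial carry.
#
#   cell = Flux.RNNCell(2, 3)
#   layer = NewRecur(cell; return_sequence=true)
#   xs = rand(Float32, 2, 5, 7)        # 2 features, batch of 5, 7 time steps
#   ys = layer(xs)                     # 3×5×7 output, starting from cell.state0
#   carry, ys = layer(cell.state0, xs) # thread the carry explicitly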



# Stateful recurrence

"""
@@ -187,8 +366,14 @@ function (m::Recur)(x::AbstractArray{T, 3}) where T
reshape(reduce(hcat, h), sze[1], sze[2], length(h))
end

########
#
# Recurrent Cells
#
########

# Vanilla RNN
struct RNNCell{F,I,H,V,S}
σ::F
Wi::I
Expand Down Expand Up @@ -289,6 +474,8 @@ julia> r(rand(4, 10)) |> size # batch size of 10
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
Recur(m::RNNCell) = Recur(m, m.state0)

NewRNN(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence)
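
# Usage sketch: NewRNN mirrors the RNN constructor but wraps the cell in a
# NewRecur instead of a Recur (the sizes here are illustrative).
#
#   NewRNN(3, 5; return_sequence=true)(rand(Float32, 3, 1, 10)) |> size # (5, 1, 10)
#   NewRNN(3, 5)(rand(Float32, 3, 1, 10)) |> size                       # (5, 1)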

# LSTM

struct LSTMCell{I,H,V,S}
Expand Down Expand Up @@ -362,6 +549,8 @@ julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
Recur(m::LSTMCell) = Recur(m, m.state0)

NewLSTM(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.LSTMCell(a...; ka...); return_sequence=return_sequence)

# GRU

function _gru_output(gxs, ghs, bs)
Expand Down Expand Up @@ -436,6 +625,8 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
Recur(m::GRUCell) = Recur(m, m.state0)

NewGRU(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.GRUCell(a...; ka...); return_sequence=return_sequence)

# GRU v3

struct GRUv3Cell{I,H,V,HH,S}
Expand Down Expand Up @@ -505,3 +696,5 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
"""
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
Recur(m::GRUv3Cell) = Recur(m, m.state0)

NewGRUv3(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.GRUv3Cell(a...; ka...); return_sequence=return_sequence)