Make RNNs blocked (and maybe fixing gradients along the way) #2258
Comments
On your second approach, how about emulating what PyTorch does with immutable struct wrappers over the RNN cell types? Say Integration with This approach avoids having to deal with state by making the user save it and carry it over themselves. It's not as ergonomic since they'd have to thread said states through larger models, but that's more than doable with layers like
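A minimal plain-Julia sketch of the immutable-cell idea, with no Flux dependency. The names StatelessRNNCell and initial_state are hypothetical and are not Flux's API. The cell holds only parameters, and the caller passes in the previous hidden state and receives the new one.

```julia
# Hypothetical immutable cell: stores parameters only, never the hidden state.
struct StatelessRNNCell{M<:AbstractMatrix,V<:AbstractVector}
    Wi::M   # input-to-hidden weights (hidden x features)
    Wh::M   # hidden-to-hidden weights (hidden x hidden)
    b::V    # bias (hidden)
end

# One step: previous state h (hidden x batch) and input x (features x batch)
# in, new state out. Nothing inside the cell is mutated.
(c::StatelessRNNCell)(h::AbstractMatrix, x::AbstractMatrix) =
    tanh.(c.Wi * x .+ c.Wh * h .+ c.b)

# The initial state is created by the caller, not stored in the layer.
initial_state(c::StatelessRNNCell, batch::Integer) =
    zeros(eltype(c.b), length(c.b), batch)

# Usage: the caller threads the state through consecutive calls.
cell = StatelessRNNCell(fill(0.1, 4, 3), zeros(4, 4), zeros(4))
h0 = initial_state(cell, 2)   # 4 x 2
x1 = ones(3, 2)               # 3 features, batch of 2
h1 = cell(h0, x1)
h2 = cell(h1, x1)
```

Because the cell carries no state, resetting is just passing a fresh initial_state, and carrying state across batches is explicit in user code. That is the ergonomic cost the comment mentions.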
Good idea. This also seems similar to Haiku's approach, afaict. I think this could also give us an opportunity to provide an interface for a static unroll vs. a dynamic unroll. At first we should just use a loop, but there might be an opportunity to use a generated function to replace the loop with the unrolled version. That might be more problematic than it's worth, though, depending on how the
I added a first pass of this functionality to
Ok. We have merged a potential interface (
I like that idea. We should probably figure out how to make the rrule type-stable as part of that :P
Working PR in #2316.
Motivation and description
Given #2185 and other issues caused by the current mutability of the Recur interface, we should move to a more standard blocked interface (i.e., a 3D input array for a simple RNN). This has the benefits of:
I have not tested how we might fix the gradients by moving to this restricted interface. However, if we decide to remove the statefulness (see below), we can fix the gradients as shown in FluxML/Fluxperimental.jl#7.
Possible Implementation
I see two ways to make this change: one is a wider change to the Flux Chain interface, and the other tries to fix only Recur. In either case, the implementation would assume that the final dimension of the multi-dimensional input array is the time index. For a simple RNN, the incoming array would have dimensions Features x Batch x Time. Passing a 1D or 2D array to Recur would throw an error, to avoid ambiguity.
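To make the blocked convention concrete, here is a minimal plain-Julia sketch with no Flux dependency. The function blocked_rnn and its argument order are hypothetical, chosen for illustration. It runs the simple loop over the last (time) dimension of a Features x Batch x Time array and rejects inputs of any other dimensionality with an ArgumentError.

```julia
# Hypothetical blocked RNN forward pass (plain Julia, no Flux).
# x is Features x Batch x Time; the last dimension is always time.
function blocked_rnn(Wi::AbstractMatrix, Wh::AbstractMatrix, b::AbstractVector,
                     x::AbstractArray{T,3}) where {T}
    nfeat, batch, ntime = size(x)
    size(Wi, 2) == nfeat ||
        throw(DimensionMismatch("Wi has $(size(Wi, 2)) columns, input has $nfeat features"))
    h = zeros(T, length(b), batch)   # initial hidden state: Hidden x Batch
    hs = Matrix{T}[]                 # hidden state after each timestep
    for t in 1:ntime
        h = tanh.(Wi * x[:, :, t] .+ Wh * h .+ b)
        push!(hs, h)
    end
    return cat(hs...; dims = 3)      # Hidden x Batch x Time
end

# Any other dimensionality is rejected instead of guessed at, since a 2D array
# could mean either one timestep of a batch or one unbatched sequence.
blocked_rnn(Wi, Wh, b, x::AbstractArray) =
    throw(ArgumentError("expected a 3D array (features x batch x time), got ndims = $(ndims(x))"))

Wi = fill(0.1, 4, 3); Wh = zeros(4, 4); b = zeros(4)
x = ones(3, 2, 5)                    # 3 features, batch of 2, 5 timesteps
y = blocked_rnn(Wi, Wh, b, x)        # 4 x 2 x 5
```

With this convention, a single timestep can still be passed as reshape(x_t, size(x_t)..., 1), which makes the caller's intent explicit instead of leaving the layer to infer it.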
One possible implementation is to make the full change and remove state from the network entirely (see FluxML/Fluxperimental.jl#7). This would overhaul large parts of the interface to Chain and could be targeted at Flux 0.14. See the implementation in the PR above and FluxML/Fluxperimental.jl#5 for details.
The second possible approach is to first remove only the loop-over-timesteps interface and replace it with the 3D interface. This initial change restricts the interface to 3D input, but I haven't tested how we could fix the gradients while keeping mutability and statefulness in Recur. The interface/implementation would likely look much like:
Flux.jl/src/layers/recurrent.jl, lines 184 to 188 in c9c262d
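The referenced lines are not reproduced in this copy of the issue. As an illustration only, and not the code at that commit, the following plain-Julia sketch shows a wrapper that keeps Recur-style mutable state but accepts only 3D input. StatefulRecur, reset_state!, and the (state, x_t) -> (state, y_t) cell convention are assumptions made for this example.

```julia
# Hypothetical stateful wrapper for the second approach: the hidden state
# stays inside a mutable struct, but only Features x Batch x Time input is accepted.
mutable struct StatefulRecur{C,S}
    cell::C     # any callable (state, x_t) -> (new_state, y_t)
    state::S    # hidden state, carried across calls until reset
    init::S     # value restored by reset_state!
end

StatefulRecur(cell, init) = StatefulRecur(cell, init, init)

function (m::StatefulRecur)(x::AbstractArray{T,3}) where {T}
    ys = []
    for t in axes(x, 3)                  # loop over the time dimension
        m.state, y = m.cell(m.state, x[:, :, t])
        push!(ys, y)
    end
    return cat(ys...; dims = 3)          # Output x Batch x Time
end

(m::StatefulRecur)(x::AbstractArray) =
    throw(ArgumentError("expected a 3D array (features x batch x time), got ndims = $(ndims(x))"))

reset_state!(m::StatefulRecur) = (m.state = m.init; m)

# Example cell (hypothetical): a plain tanh RNN step returning (state, output).
rnn_step = let Wi = fill(0.1, 4, 3), Wh = zeros(4, 4)
    (h, xt) -> (hn = tanh.(Wi * xt .+ Wh * h); (hn, hn))
end

recur = StatefulRecur(rnn_step, zeros(4, 2))
out = recur(ones(3, 2, 5))               # 4 x 2 x 5; recur.state holds the last step
```

Because the state persists across calls, a second sequence continues from the previous hidden state until reset_state! is called. That mutation is exactly what the text above says has not yet been tested for gradient correctness.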