
RFC: add a supertype to layers #2028

Closed
wants to merge 4 commits

Conversation

@mcabbott (Member) commented Jul 27, 2022

This proposes to give Flux's layer types various supertypes.

One reason to like this is that it simplifies the use of show. If your layer has the same supertype as Chain, it will be unfolded at top level just as Chain is. No mystery functions to overload. Closes #1932, closes #2044.

Edit: 568af9b goes further: We can define functor for this abstract type, eliminating the need to call @functor. (It's a pretty mysterious macro, and you can't @macroexpand it to have a look.) We can also define trainable for some abstract types; maybe that's also less mysterious.
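For concreteness, here is a minimal sketch (not the PR's exact code) of what defining functor once on the proposed supertype could look like, assuming the abstract type is called AbstractLayer and that subtypes have default positional constructors:

using Functors

abstract type AbstractLayer end   # name taken from this PR's diff

# One method covers every subtype: split the struct into a NamedTuple of its
# fields, and rebuild it via the default positional constructor.
function Functors.functor(::Type{<:AbstractLayer}, layer)
    names = fieldnames(typeof(layer))
    children = NamedTuple{names}(map(n -> getfield(layer, n), names))
    rebuild(nt) = Base.typename(typeof(layer)).wrapper(nt...)   # splats field values in order
    return children, rebuild
end

struct Affine <: AbstractLayer   # hypothetical layer, gets the behaviour for free
    W
    b
end

fmap(x -> 2 .* x, Affine([1.0 2.0], [3.0]))   # walks W and b, rebuilds an Affine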

Another reason is this: Flux.gpu and CUDA.cu both recursively move things, the latter via Adapt rather than Functors, which means cu does not preserve tied weights. But if we can rely on the shared arrays both living within a struct whose type we own (like Chain), then we can convert cu to something Functors-based at that boundary. (It's enough for one outer layer to do this -- using weird Dense-like layers marked only with @functor within a Chain is fine.)
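A minimal illustration of the tie-preservation point, using fmap with an ordinary array transformation in place of cu (so it runs without a GPU):

using Functors

W = rand(Float32, 3, 3)
tied = (a = W, b = W)                 # two fields sharing one array

moved = fmap(x -> Float64.(x), tied)
moved.a === moved.b                   # true: fmap caches by object identity, so the tie survives

# A plain recursive Adapt-style walk would transform a and b independently,
# producing two distinct arrays and losing the tie.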

Note that this supertype is entirely optional. The PR does not change the fact that functor is how Flux walks models, etc., and so it does not restrict how you can customise things. It only aims to make it easy to opt in to the standard behaviours, without a zoo of weird macros.

@ericphanson (Contributor) left a comment

I like this (after arguing for a macro in #2044 😄). One comment: it might be helpful to put the subtype declarations in the docstring for each function, e.g.

"""
    Chain(layers...) <: ContainerLayer

and to document the abstract types in https://github.com/FluxML/Flux.jl/blob/master/docs/src/models/layers.md.

:((NamedTuple{$F}(($(args...),)), $recon))
else
# Getting this parameterless type takes about 2μs, every time:
namedtuple(x), Base.splat(Base.typename(T).wrapper)
Contributor

I guess one option here is to use ConstructionBase.constructorof so that this can be customized in problematic cases (if there are any with layers... not really sure since the discussion in JuliaLang/julia#46213 is kinda hard to follow). But I guess we can't do that in the generated path, and we don't want diverging behavior.
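For reference, a sketch of what the non-generated rebuild step might look like if it went through ConstructionBase instead of calling the type's wrapper directly (the helper name here is made up):

using ConstructionBase

# constructorof(T) returns a callable taking the field values positionally,
# and packages with non-default constructors can overload it for their types.
rebuild_from(layer) = nt -> constructorof(typeof(layer))(nt...)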

@mcabbott (Member, Author)

My argument against ConstructionBase is this:

  • At present the path for "easy" structs is @functor MyStruct, which makes assumptions about the default constructor.

  • If you have a weird struct, you must write a method for functor which understands it. But this is discouraged.

With the supertype, you should automatically get the same behaviour as @functor MyStruct.

But Functors is still handling recursion and re-building, and you can still supply a more specific method for functor. It is in this case that a supertype method using ConstructionBase would provide an alternative way to customise behaviour. But now you have to know about two different recurse-and-rebuild libraries instead of one. And because we own the supertype, it's fair to assume that you are writing this for Flux, and that it isn't some large pre-existing thing which already has ConstructionBase methods written... I mean that's fine, you can still use Flux, but not via the supertype.

(I would actually probably favour ditching Functors, and simply demanding that all Flux models use structs with default constructors. Unlike Zygote, it has no ambition to apply to arbitrary code. But that's for another day; this PR adds a friendly path without changing how other paths work.)

@ToucheSir (Member) commented Aug 23, 2022

My argument is for ConstructionBase, but not in Flux. Rather, Functors should use it instead of rolling its own thing. The fallback behaviour will be identical (call T), but it's more likely a library will need to depend on ConstructionBase anyhow for compat with Accessors etc. instead of on Functors (which is why we see piracy).

That or nix Functors as Michael said, but then the argument is even stronger IMO. You can still get away with a default constructor, but if you want something fancy there is but one game left in town.

@mcabbott (Member, Author) left a comment

Yes to putting the description of these into the docs.

I'm not sure whether they should be in layer docstrings. IMO those should be aimed at people calling the layer, for whom the supertype doesn't matter at all. Whereas to write a new one, you should first do @less Dense(1 => 1).

Or look at the right bit of the manual... I think we had some examples of defining new structs, and probably those should also grow to mention supertypes.

Comment on lines +80 to +84
abstract type PartialTrainLayer{which} <: SimpleLayer end

function Optimisers.trainable(layer::PartialTrainLayer{which}) where {which}
NamedTuple{which}(map(sy -> getfield(layer, sy), which))
end
@mcabbott (Member, Author)

This type is the odd one out; I'm not entirely sure we should have it.

One: if we wish to have other traits which identify some subset of fields, the way trainable does, then (if they are independent) they cannot share the same type hierarchy, and so become a bit second-class. Maybe that's OK; trainability is pretty important.

Two: it doesn't cover all existing cases of trainable, e.g. we want mutable struct Recur{T,S} <: ContainerLayer for printing, but also have trainable(a::Recur) = (; cell = a.cell).
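For illustration, a hypothetical layer opting in through this type (the struct name is invented; PartialTrainLayer is the abstract type from the diff above):

using Flux, Optimisers

# Only :weight is declared trainable; :frozen is still moved by gpu/fmap,
# it just isn't returned by trainable.
struct MyAffine{A} <: Flux.PartialTrainLayer{(:weight,)}
    weight::A
    frozen::A
end

layer = MyAffine(rand(3), zeros(3))
Optimisers.trainable(layer)   # (weight = [...],) under the definition quoted above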

Comment on lines -192 to 193
struct RNNCell{F,A,V,S}
struct RNNCell{F,A,V,S} <: SimpleLayer # or should it be PartialTrainLayer{(:Wi, :Wh, :b)}?
σ::F
Wi::A
@mcabbott (Member, Author)

Here I'm also not sure. Should state0 be trainable?

Comment on lines -355 to -356
@functor BatchNorm
trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
@mcabbott (Member, Author)

Here PartialTrainLayer{(:β, :γ)} fixes which fields are trainable, permanently. That's OK, since β = nothing when it's not trainable, so it'll be ignored. It improves type-stability although I doubt this matters at all.
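Concretely, under the PartialTrainLayer definition quoted earlier, a non-affine BatchNorm would still report both fields, just with nothing in place of parameters, which Optimisers.jl ignores:

bn = BatchNorm(3; affine=false)   # here bn.β === nothing and bn.γ === nothing
Optimisers.trainable(bn)          # (β = nothing, γ = nothing) under this PR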

NamedTuple{F}(map(sy -> getfield(x, sy), F))
end

Adapt.adapt_structure(to, layer::AbstractLayer) = fmap(x -> adapt(to, x), layer)
@mcabbott (Member, Author)

Can be just this

Suggested change
Adapt.adapt_structure(to, layer::AbstractLayer) = fmap(x -> adapt(to, x), layer)
Adapt.adapt_structure(to, layer::AbstractLayer) = fmap(adapt(to), layer)

but check the version bound on Adapt

@mcabbott added this to the v0.14 milestone Aug 22, 2022
@darsnack (Member)

Since this has now been put on the next milestone, I should chime in. I personally do not like having a super-type. My reasons:

  • It precludes someone from using their own type hierarchy. There are real examples of this: InvertibleNetworks.jl and Mill.jl. And any code that writes functions of the form f: model -> model' will reasonably want to dispatch on the model type; depending on what exactly f and model are, a type hierarchy could be very useful there.
  • It's a sharp change from Flux's approach to this problem in the past. If we were getting something of real value here, then it might be worth it. But it seems like this is one implementation that is subjectively more convenient than a macro.
  • One positive for this approach is that it is more natural to Julia users than a macro. I agree in general with this for types, but I don't think it holds as much weight in our case. Any user wanting to implement a custom layer is going to have to look at the documentation... these types aren't so intuitive that they should naturally come to mind. I don't see how learning that you need to @functor MyLayer is so much worse at this point.
  • I agree that @functor is weird and confusing. I think Flux ought to provide something more friendly like @layer which does the Functor stuff but also show stuff and more.
  • The current type hierarchy—PartialTrainLayer{which} <: SimpleLayer <: AbstractLayer—is already getting messy. Julia's type system just isn't good for hanging all this information on, IMO (see the never-ending traits discussions). What happens for a layer which <: ContainerLayer but also has trainable parameters of its own?
  • A super-type makes adopting the defaults easy, but it doesn't avoid needing to know and override trainable, show, etc. The non-default cases are still just as annoying. I'd rather work on a solution that makes those cases better. Hiding the default case is not so hard of a problem.
  • I don't see us ever moving away from something Functors-like. We can nix Functors, but we will still have the need to walk and recurse nested structs. Having multiple levels/ways of opting into being a Flux layer seems confusing to me.

It seems like the main motivation for this PR originally (though it has grown) is to make _big_show easier to opt into. Once we adopt a type hierarchy, it will be hard to change our minds. We should think carefully about what it should be and what functionality should hang on top of it. If default show is what's at stake here, then a more minimal change might be easier for right now. Two options:

  • Define something like @layer which does @functor + show. This seems like it has no downsides. We mask the confusing @functor, and we can always introduce types on top of it later. If we want to remove it in favor of types, it seems easier to deprecate.
  • Introduce just SimpleLayer and ContainerLayer for the purpose of show only. Leave trainable, etc. as is for now.
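For reference, the second option's hierarchy would be as small as this (type names taken from the PR's diff; the comments sketch the intent, not final code):

abstract type AbstractLayer end
abstract type SimpleLayer    <: AbstractLayer end   # printed compactly, like Dense
abstract type ContainerLayer <: AbstractLayer end   # unfolded at top level, like Chain

# show methods would dispatch on these two subtypes only; functor, trainable, etc.
# would keep working exactly as they do today.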

@mcabbott (Member, Author)

I agree the present state of this PR is a sort-of "maximal" look at what's possible, to see what clashes result, and as noted I'm not sure that including trainable is a great idea. The "minimal" one would be just SimpleLayer and ContainerLayer.

Figuring out whether things clash with packages is one goal here. It looks at first glance as though InvertibleNetworks.jl has a type hierarchy which could subtype Flux.AbstractLayer, but perhaps nothing finer; I'm not sure, does anyone know more details? About Mill.jl I'm less sure -- anyone know? Note that both packages mentioned do load Flux.

I also agree that some kind of Flux.@layer macro is an alternative option. Maybe a smaller change from @functor / @treelike. (Maybe someone can make a PR to try it out and see what problems arise; @functor is a weird macro with eval, but perhaps it need not be.) To be useful for show, it would somehow need to take options to make the Chain-like / Dense-like distinction, defaulting to Dense-like perhaps. Maybe it ought to take options to control trainable too.
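One hypothetical spelling of such options, purely for illustration (no such macro exists in this PR):

Flux.@layer Dense                              # Dense-like: compact printing (the default)
Flux.@layer Chain children=:expand             # Chain-like: unfolded at top level
Flux.@layer BatchNorm trainable=(:β, :γ)       # also controlling which fields are trainable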

Unlike a supertype, the new macro could be obligatory. The idea of translating cu to an fmap walk at the outermost layer would then be easy.

Macros are always a bit opaque. You could argue that FluxML has probably gone a little too far in avoiding types, e.g. train! took 4 arguments of ::Any, good luck remembering... but now methods(Flux.train!) makes it easier not to mess up. Although it's certainly possible to go too far the other way.

(Don't read too much into the milestone, just trying to separate future ideas from bugfixes a bit.)

@ericphanson (Contributor)

I don't think a macro that defines a show method should be obligatory, because then you will get method overwrites if you define your own show method (which is one nice thing about the abstract type: it's a fallback then).

@darsnack (Member)

Good point. The messy thing about macros will be the order in which stuff is called. Something like

struct MyLayer
    foo
end

# the user's own show method, defined before calling the macro
Base.show(io::IO, m::MyLayer) = print(io, "MyLayer(", m.foo, ")")

@layer MyLayer   # if this also defines show, it overwrites the method above

seems like a simple error to run into. The current @functor only defines functor, which is unlikely to run into this issue.

I think with something like FluxML/Functors.jl#41, we remove trainable etc. completely. Then we are left with show which can be opt-in:

@layer MyLayer prettyprint=:simple

I was going to prototype the Functors issue anyways, so I can also test out the @layer stuff.

Another option is to make @layer apply to the struct definition itself. This would ensure that any methods defined later, like a custom show, come after the default.

Alternatively, introducing the hierarchy with just SimpleLayer and ContainerLayer only for non-parameter related stuff like show seems reasonable to me. At least there will be a clean demarcation between Functors stuff and other stuff.

@mcabbott (Member, Author)

Yes, the macro needs at least a way to choose Chain-like vs. Dense-like printing. It is not only methods for show that do this, and possibly one option should apply no methods to show at all. I don't think we should ever rely on the order of overwriting methods.

Macros which act on struct definitions seem a step more opaque, and also more complicated to write; e.g. how should such a macro interact with @kwdef?

@@ -55,16 +54,11 @@ _show_children(m::Maxout) = m.layers
_show_children(p::Parallel) = (p.connection, p.layers...)
_show_children(f::PairwiseFusion) = (f.connection, f.layers...)
@mcabbott (Member, Author)

These methods are a further complication to any attempt to make show easier to opt into. If you claim (through a macro or a type) that your layer is a Chain-like container, do you imply that you want any Tuple within it to be splatted, i.e. that you have a constructor like Parallel(f, layers...) = Parallel(f, layers)? Or is that still an obscure function you ought to overload?

Successfully merging this pull request may close these issues.

Macro to display model struct the way Flux does