Allow multiple datasets in fit_to_data and add option to return opt_state #210

Open

aseyboldt wants to merge 3 commits into main

Conversation

aseyboldt (Contributor)

Some loss functions require several arrays rather than a single one.
This extends fit_to_data so that it accepts multiple arrays and passes batches of each to the loss function.

It can also be quite useful to reuse the optimizer state across different runs, so I added opt_state and return_opt_state arguments to fit_to_data to facilitate that.
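For illustration, here is a minimal, self-contained sketch of the batching pattern described above; it is not FlowJAX's implementation or this PR's code, just a toy showing several arrays being batched jointly along a shared leading dimension and unpacked into a multi-argument loss.

```python
# Illustrative sketch only: a toy loss that needs two arrays per batch, and a
# helper that batches a tuple of arrays jointly along their leading dimension.
import jax.numpy as jnp


def weighted_mse_loss(params, x, weights):
    # A loss requiring several arrays rather than a single one.
    return jnp.mean(weights * (x - params) ** 2)


def iterate_batches(data, batch_size):
    # `data` is a tuple of arrays sharing the same leading dimension.
    n = data[0].shape[0]
    for start in range(0, n, batch_size):
        yield tuple(arr[start : start + batch_size] for arr in data)


x = jnp.arange(10.0)
weights = jnp.linspace(0.1, 1.0, 10)
params = jnp.array(0.0)

for batch in iterate_batches((x, weights), batch_size=4):
    loss = weighted_mse_loss(params, *batch)  # each batch maps to positional loss args
```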

@danielward27 (Owner)

I can see this could be useful, e.g. it would give a simple way to support weighted samples (#211). I appreciate the effort to maintain backwards compatibility, but it does lead to a bit of an ugly interface. If rewriting from scratch, you could consider replacing x with data: Array | tuple[Array, ...] and removing the condition argument (just pass it in the data tuple). Then we could document that train/test splitting, batching, etc. happen on all of data, and that the batched arrays should match the positional arguments of the loss function. Right now, the implementation forces condition to be the last positional argument, which, whilst fine, isn't clear until you read the code.

To me there isn't really a compelling reason why condition should be passed as a separate argument once we have chosen to allow x (or data) to be a tuple of arrays. If we go down this road, it might be best to immediately deprecate condition and x, with a warning encouraging people to use the new data positional argument (placed before x, so that no change is necessary if x was passed positionally).

The ability to pass and return the optimizer state is a good idea, but if we include it in fit_to_data, it should also be included in fit_to_key_based_loss for consistency. I would hope (I haven't checked) that it also allows avoiding recompilation when the fitting function is called twice. Again, if willing to make a breaking change, you could choose to always return the optimizer state, and users could simply ignore it if they don't need it. If writing from scratch, this is better to me: one consistent return type and one less argument. But it obviously introduces a breaking change, which I think is better to avoid, at least for now (i.e. as you have done).
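As a self-contained optax sketch (separate from FlowJAX) of why threading the optimizer state between runs matters: Adam's moment estimates live in opt_state, so re-initialising it between fitting calls discards them. Whether this also avoids recompilation depends on how the fit function JITs its update step.

```python
# Standalone optax example: initialise the optimizer state once and thread it
# through successive "runs" instead of re-initialising between them.
import jax
import jax.numpy as jnp
import optax

params = jnp.zeros(3)
optimizer = optax.adam(1e-2)
opt_state = optimizer.init(params)  # initialise once


@jax.jit
def step(params, opt_state, x):
    loss, grads = jax.value_and_grad(lambda p: jnp.mean((x - p) ** 2))(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss


# "First run" and "second run": by passing opt_state back in, the second run
# continues from the accumulated optimizer state instead of starting fresh.
for x in (jnp.ones(3), 2.0 * jnp.ones(3)):
    for _ in range(100):
        params, opt_state, loss = step(params, opt_state, x)
```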

Although I feel a little unsure about it, I think I am happy to merge this with the following additional changes:

  1. Add a data positional argument before x which can be an array or a tuple of arrays (and a default x=None).
  2. Add deprecation messages for x and condition saying they will be removed in the following version in favour of the new data positional argument, in a way that is non-breaking for now (the removal itself will be breaking, but at least there will have been some warning); see the sketch after this list.
  3. Document the "alignment" between the tuple of arrays in data, and the positional arguments in the loss functions (e.g. (target, condition) for MaximumLikelihoodLoss).
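A hedged sketch of how points 1 and 2 could fit together; the argument handling below is an assumption about one possible non-breaking path, not the merged implementation.

```python
# Hedged sketch of a non-breaking deprecation path (assumption, not the merged code).
import warnings

from jax import Array


def fit_to_data(key, dist, data: Array | tuple[Array, ...] | None = None,
                x=None, condition=None, **kwargs):
    if x is not None:
        warnings.warn("`x` is deprecated; pass it via `data`.", DeprecationWarning)
        data = x if data is None else data
    if condition is not None:
        warnings.warn(
            "`condition` is deprecated; append it to the `data` tuple.",
            DeprecationWarning,
        )
        data = (*data, condition) if isinstance(data, tuple) else (data, condition)
    if not isinstance(data, tuple):
        data = (data,)
    # ... existing training loop, splitting/batching every array in `data` and
    # unpacking each batch into the loss function's positional arguments.
```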

I'm definitely open to feedback on this, though, if you or anyone else has thoughts: there is inevitably a trade-off between maintaining backwards compatibility and simplifying the code and improving the API. Users are probably better placed to discuss this than I am; in my own applications I don't really mind quick-to-fix breaking changes (in both FlowJAX and other dependencies), but I understand that may be different for others.

@danielward27 (Owner)

An issue I have just thought of: we may need to support None in data if we assume its entries map to positional arguments of the loss. Imagine a loss taking x, condition=None, weights=None. For unconditional weighted density estimation, data would have to be (x, None, weights), and then presumably we'd need code to ensure None doesn't get passed to the train/val split and batching code (or we'd need to handle it in those functions). That is a little unsatisfying... Another possibility would be to force the use of keywords, i.e. data could be a dict of arrays matching keyword arguments of the loss rather than positional arguments, but that also feels a little unsatisfying. It is probably worth considering a few options before committing to one. Maintaining the current implementation is also somewhat reasonable: it keeps the training loop as simple as possible so that users can easily copy and modify it as needed.
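For concreteness, a hedged sketch of the two options mentioned (both hypothetical): a positional tuple whose None entries must bypass splitting/batching, versus a dict of arrays matched to the loss's keyword arguments.

```python
# Toy illustration of the two options; neither reflects an actual FlowJAX API.
import jax.numpy as jnp


def batch_tuple(data, idx):
    # Option 1: positional tuple; None entries bypass splitting and batching.
    return tuple(arr if arr is None else arr[idx] for arr in data)


def batch_dict(data, idx):
    # Option 2: dict of arrays keyed by the loss's keyword arguments.
    return {name: arr[idx] for name, arr in data.items()}


def loss(x, condition=None, weights=None):
    w = jnp.ones_like(x) if weights is None else weights
    return jnp.mean(w * x**2)


x, weights = jnp.arange(6.0), jnp.linspace(0.5, 1.0, 6)
idx = jnp.arange(3)  # stand-in for a train/val split or batch index

loss(*batch_tuple((x, None, weights), idx))            # unconditional weighted case
loss(**batch_dict({"x": x, "weights": weights}, idx))  # keyword-matched case
```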

@aseyboldt (Contributor, Author)

Yeah, I also wasn't too happy about introducing the return_opt_state parameter, but I thought I'd rather avoid a breaking change. I also added it to the other fit function.
Maybe the updated version is a nicer solution for the multiple arrays and the condition argument?
I changed the x argument to a *data argument, so you can simply pass multiple arrays as positional arguments.
Strictly speaking this is also a breaking change, because a user might have called it with a keyword argument for x, but that would at least be obvious to fix, and I would hope it isn't common.
For the condition argument, I now also raise a deprecation warning, since it can be passed as the last data argument instead.
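Roughly, the *data variant described here might look like the following; this is a sketch of the idea, not necessarily the exact code in the updated commits.

```python
# Sketch of the `*data` interface; the actual implementation lives in the PR commits.
import warnings


def fit_to_data(key, dist, *data, condition=None, **kwargs):
    if condition is not None:
        warnings.warn(
            "`condition` is deprecated; pass it as the last positional data argument.",
            DeprecationWarning,
        )
        data = (*data, condition)
    # ... split/batch every array in `data` and unpack batches into the loss.
```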
