-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow multiple datasets in fit_to_data and add option to return opt_state #210
base: main
Are you sure you want to change the base?
Conversation
I can see this could be useful, e.g. it would give a simple way to support weighted samples (#211). I appreciate the effort to maintain backwards compatibility, but it does lead to a bit of an ugly interface. If rewriting from scratch, you could consider replacing To me there isn't really a compelling reason to me why condition should be passed as a separate argument once we have chosen to allow Ability to pass and return the optimizer state is a good idea, but if we include it in Although I feel a little unsure about it, I think I am happy to merge this with the addition of the following changes:
I'm definitely open to feedback on this though if you or anyone else has any thoughts: there is inevitably a trade off between maintaining backwards compatibility and simplifying the code and improving the API. Users are probably better posed to discuss this than I am - in my applications I don't really mind quick to fix breaking changes (in both FlowJAX and other dependencies), but I understand that may be different for others. |
An issue I have just thought about, is it may be required to support |
954fdd7
to
4acaa15
Compare
Yeah, I also wasn't too happy about introducing the |
4acaa15
to
06bf498
Compare
06bf498
to
ff3d30e
Compare
Some loss functions can require several arrays instead of only one.
This extends
fit_to_data
so that it passes batches of those to the loss functions.It can also be quite useful to reuse the optimizer state across different runs, so I added
opt_state
andreturn_opt_state
arguments tofit_to_data
to facilitate that.