Add random.binomial and random.multinomial #13327
Comments
@sharadmv, is it possible to port TFP's implementation to jax.random?
Possibly, yes. The implementation is fairly complex, though, and makes some accelerator-specific tradeoffs, IIRC. cc: @srvasude @brianwa84.
Any plans to add
Thanks for reaching out; I don't know of anyone currently working on this.
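A workaround if you want pure JAX is to take the log of your probability vector (nonnegative, sums to 1): jax.random.categorical(key, jnp.log(p)). This relies on the "contract" of jax.random.categorical, whose logits argument is documented as "Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities." That can't be the most efficient way to sample, though...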
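A minimal sketch of that workaround, assuming a recent JAX release (the probability vector is illustrative). Because softmax(log(p)) equals p whenever p sums to 1, passing jnp.log(p) as logits samples category indices with probabilities p; a zero entry becomes -inf and is simply never drawn.

```python
import jax
import jax.numpy as jnp

# Illustrative probability vector over 4 categories (nonnegative, sums to 1).
p = jnp.array([0.1, 0.2, 0.3, 0.4])

key = jax.random.PRNGKey(0)
# categorical() expects unnormalized log-probabilities, so log(p) works directly.
# shape=(1000,) draws 1000 independent category indices in [0, 4).
samples = jax.random.categorical(key, jnp.log(p), shape=(1000,))

# Empirical frequencies should approach p as the number of draws grows.
freqs = jnp.bincount(samples, length=p.shape[0]) / samples.shape[0]
```

Note that this yields categorical draws (one index per trial), not multinomial count vectors.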
FYI: you would need to sum many categorical samples to get a multinomial sample (one multinomial draw with n trials is the sum of n one-hot categorical draws).
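To make that concrete, here is a sketch of building multinomial counts from categorical draws. The helper multinomial_via_categorical is hypothetical (it is not part of jax.random), and n and num_samples must be static Python ints because they set array shapes. Work and memory grow with num_samples * n * len(p), which is why this route becomes impractical when many samples or large trial counts are needed.

```python
import jax
import jax.numpy as jnp

def multinomial_via_categorical(key, n, p, num_samples):
    """Hypothetical helper: draw `num_samples` multinomial count vectors,
    each with `n` trials over len(p) categories, by summing one-hot
    categorical draws."""
    k = p.shape[0]
    # One category index per trial: shape (num_samples, n).
    idx = jax.random.categorical(key, jnp.log(p), shape=(num_samples, n))
    # One-hot encode every trial, shape (num_samples, n, k), then sum over
    # the trial axis to get per-category counts, shape (num_samples, k).
    return jax.nn.one_hot(idx, k, dtype=jnp.int32).sum(axis=1)

counts = multinomial_via_categorical(
    jax.random.PRNGKey(1), 50, jnp.array([0.2, 0.3, 0.5]), 8
)
```

Each row of counts is one multinomial sample and sums to the number of trials (50 here).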
I would use a multinomial feature if added! The categorical workaround is a nonstarter when I need a lot of samples.
I would also benefit from a multinomial feature to experiment with speeding up the numerical simulation of jump processes in my simulation code, pySODM.
👍
I would benefit from multinomial as well. I'm currently using the NumPy one for efficient bootstrap sampling, and its absence here makes it harder to switch to JAX.
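For the bootstrap case specifically, a common NumPy pattern (assumed here, not confirmed by the comment) is np.random.multinomial(n, [1/n] * n, size=B), which gives how many times each of n observations is picked in each of B resamples. The same distribution can be sketched in pure JAX without a multinomial sampler by drawing indices uniformly with replacement and counting them. The helper bootstrap_counts and the toy data are hypothetical.

```python
import jax
import jax.numpy as jnp

def bootstrap_counts(key, n_obs, n_boot):
    """Hypothetical helper: per-observation bootstrap counts, distributed as
    Multinomial(n_obs, [1/n_obs] * n_obs). Returns an int array of shape
    (n_boot, n_obs); each row sums to n_obs."""
    # Resample n_obs indices uniformly with replacement for each replicate.
    idx = jax.random.randint(key, (n_boot, n_obs), 0, n_obs)
    # Count how often each observation was picked within each replicate.
    count_row = lambda row: jnp.bincount(row, length=n_obs)
    return jax.vmap(count_row)(idx)

data = jnp.arange(10.0)  # toy data for illustration
counts = bootstrap_counts(jax.random.PRNGKey(2), data.shape[0], 200)
# Count-weighted means give the bootstrap distribution of the sample mean.
boot_means = (counts * data).sum(axis=1) / data.shape[0]
```

This only covers equal probabilities; unequal probabilities still need the categorical route or a dedicated multinomial sampler.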
I've created a PR for this: #25688.
Add JAX counterparts of numpy.random.binomial and numpy.random.multinomial to the jax.random package. See #480 (comment) for context. A current workaround is to use the JAX substrate of TensorFlow Probability.