Skip to content

Commit

Permalink
Add jax.random.multinomial.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Dec 29, 2024
1 parent 6dbda90 commit cc65d03
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`#25606` for more details.
* Added {func}`jax.random.multinomial`.

* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
Expand Down
1 change: 1 addition & 0 deletions docs/jax.random.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Random Samplers
logistic
lognormal
maxwell
multinomial
multivariate_normal
normal
orthogonal
Expand Down
59 changes: 59 additions & 0 deletions jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2627,6 +2627,65 @@ def binomial(
batching.defvectorized(random_clone_p)
mlir.register_lowering(random_clone_p, lambda _, k: [k])


def multinomial(
key: Array,
n: RealArray,
p: RealArray,
axis: int = -1,
shape: Shape | None = None,
):
r"""Sample from a multinomial distribution.
The probability mass function is
.. math::
f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}
Args:
key: a PRNG key used as the random key.
n: a float array-like representing the number of trials.
p: a float array-like representing the probabilities of each outcome.
axis: axis along which probabilities are defined for each outcome.
shape: optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``n`` and ``p``.
Returns:
An array of counts for each outcome.
"""

key, _ = _check_prng_key("multinomial", key)
check_arraylike("multinomial", n, p)

if shape is not None:
p = jnp.broadcast_to(p, shape)

def f(remainder, ratio_key):
ratio, key = ratio_key
count = binomial(key, remainder, ratio)
return remainder - count, count

p = jnp.moveaxis(p, axis, 0)

p_shape = jnp.shape(p)

shape = jnp.broadcast_shapes(jnp.shape(n), p_shape[1:])
n = jnp.broadcast_to(n, shape)
p = jnp.broadcast_to(p, (p_shape[0],) + shape)

remaining_probs = lax.cumsum(p, 0, reverse=True)
ratios = p / jnp.where(remaining_probs == 0, 1, remaining_probs)

keys = split(key, ratios.shape[0])

remainder, counts = lax.scan(f, n, (ratios, keys), unroll=True)
# final remainder should be zero

counts = jnp.moveaxis(counts, 0, axis)

return counts


def clone(key):
"""Clone a key for reuse
Expand Down
1 change: 1 addition & 0 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@
loggamma as loggamma,
lognormal as lognormal,
maxwell as maxwell,
multinomial as multinomial,
multivariate_normal as multivariate_normal,
normal as normal,
orthogonal as orthogonal,
Expand Down
41 changes: 41 additions & 0 deletions tests/random_lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,47 @@ def testBinomialCornerCases(self):
self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False)
self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False)

def testMultinomial(self):
key = random.key(0)
probs = jnp.array([
[0.5, 0.2, 0.3],
[0.1, 0.2, 0.7],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
[0.5, 0.0, 0.5],
])
trials = 1e8
counts = random.multinomial(key, trials, probs)
freqs = counts / trials
self.assertAllClose(freqs, probs, atol=1e-3)

with self.subTest("test with axis=0"):
counts = random.multinomial(key, trials, probs.T, axis=0)
freqs = counts / trials
self.assertAllClose(freqs, probs.T, atol=1e-3)

@jtu.sample_product(
[
dict(shape=shape, axis=axis)
for shape in [(2, 3), (2, 3, 5)]
for ndim in [len(shape)]
for axis in range(-ndim, ndim)
]
)
def testMultinomialAxisShape(self, axis, shape):
key = random.key(0)

key, subkey = random.split(key)
exps = random.exponential(key, shape)
probs = exps / exps.sum(axis=axis, keepdims=True)

trials = 1e8
counts = random.multinomial(key, trials, probs, axis, shape)
freqs = counts / trials

self.assertAllClose(freqs, probs, atol=1e-3)

def test_batched_key_errors(self):
keys = lambda: jax.random.split(self.make_key(0))
msg = "{} accepts a single key, but was given a key array of shape.*"
Expand Down

0 comments on commit cc65d03

Please sign in to comment.