Skip to content

Commit

Permalink
Deprecate scipy.special.lpmn & lpmn_values
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 2, 2025
1 parent 726950b commit 804cfef
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.
* {func}`jax.scipy.special.lpmn` and {func}`jax.scipy.special.lpmn_values`
are deprecated, following their deprecation in SciPy v1.15.0. There are
no plans to replace these deprecated functions with new APIs.

* Deletions
* `jax_enable_memories` flag has been deleted and the behavior of that flag
Expand Down
26 changes: 24 additions & 2 deletions jax/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
log_softmax as log_softmax,
logit as logit,
logsumexp as logsumexp,
lpmn as lpmn,
lpmn_values as lpmn_values,
lpmn as _deprecated_lpmn,
lpmn_values as _deprecated_lpmn_values,
multigammaln as multigammaln,
ndtr as ndtr,
ndtri as ndtri,
Expand All @@ -65,3 +65,25 @@
from jax._src.third_party.scipy.special import (
fresnel as fresnel,
)

_deprecations = {
# Added Nov 20 2024
"lpmn": (
"jax.scipy.special.lpmn is deprecated; no replacement is planned.",
_deprecated_lpmn,
),
"lpmn_values": (
"jax.scipy.special.lpmn_values is deprecated; no replacement is planned.",
_deprecated_lpmn_values,
),
}

import typing as _typing
if _typing.TYPE_CHECKING:
lpmn = _deprecated_lpmn
lpmn_values = _deprecated_lpmn_values
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
2 changes: 2 additions & 0 deletions tests/lax_scipy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def scipy_fun(z):
shape=[(5,), (10,)],
dtype=float_dtypes,
)
@jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*")
def testLpmn(self, l_max, shape, dtype):
if jtu.is_device_tpu(6, "e"):
self.skipTest("TODO(b/364258243): fails on TPU v6e")
Expand All @@ -354,6 +355,7 @@ def scipy_fun(z, m=l_max, n=l_max):
shape=[(2,), (3,), (4,), (64,)],
dtype=float_dtypes,
)
@jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*")
def testNormalizedLpmnValues(self, l_max, shape, dtype):
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
args_maker = lambda: [rng(shape, dtype)]
Expand Down

0 comments on commit 804cfef

Please sign in to comment.