Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to a new thread-safe utility for catching warnings. #25626

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

hawkinsp
Copy link
Collaborator

The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.

The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
@hawkinsp hawkinsp requested a review from jakevdp December 20, 2024 02:32
@hawkinsp
Copy link
Collaborator Author

There are some test failures because there are still some uses of warnings.catch_warnings in non-test parts of JAX. I'll need to clean those up (they will certainly break in free-threading mode).

yield

@contextmanager
def assertWarnsRegex(self, expected_warning, expected_regex):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing assertWarnsRegex can be used either as a context manager or as a simple function call, passing the callable and args/kwargs directly. Should we support that in the new version?

return
self.fail(f"Expected warning not found {expected_warning}:'{expected_regex}', got "
f"{ws}")

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also probably define assertWarns(), which is used in a few places.

addopts = "--doctest-glob='*.rst' --ignore='examples/ffi'"
# We disable pytest's warning capture system, because neither it not the
# warnings module is thread-safe. Instead, we use our own utilities to catch
# and test for warnings in tests, see test_warning_util.py.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, this means that tests which do not import jax._src.test_util will no longer treat warnings as errors. I believe this includes doctests, as well as the array API test. Is there any way we can restore the previous behavior just for those cases?


def test_warning_raises(self):
with self.assertRaises(UserWarning, msg="hello"):
warnings.warn("hello", category=UserWarning)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also add a similar test DeprecationWarning? It has some different handing than UserWarning within Python's warning system, and it's important that we still error when we see it.

@nascheme
Copy link

FYI: python/cpython#128300 : Add thread-safe context manager to "warnings" module

This is going to take some time to get fixed in CPython but it might be a good idea to make a similar API so that you can switch back to the CPython version in the future, assuming it gets fixed properly. If you have any feedback on my PR, I would be happy to hear about it. My main concerns were to not break existing non-threaded non-async code that is happily using catch_watchings() and add an API that can be used in a backwards compatible way (for code that wants to support both old and newer versions of Python).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants