Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to a new thread-safe utility for catching warnings. #25626

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ py_library(
testonly = 1,
srcs = [
"_src/test_util.py",
"_src/test_warning_util.py",
],
visibility = [
":internal",
Expand Down
57 changes: 44 additions & 13 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import time
from typing import Any, TextIO
import unittest
import warnings
import zlib

from absl.testing import absltest
Expand All @@ -49,6 +48,7 @@
from jax._src import dtypes as _dtypes
from jax._src import lib as _jaxlib
from jax._src import monitoring
from jax._src import test_warning_util
from jax._src import xla_bridge
from jax._src import util
from jax._src import mesh as mesh_lib
Expand Down Expand Up @@ -118,7 +118,7 @@
)

TEST_NUM_THREADS = config.int_flag(
'jax_test_num_threads', 0,
'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')),
help='Number of threads to use for running tests. 0 means run everything '
'in the main thread. Using > 1 thread is experimental.'
)
Expand Down Expand Up @@ -1076,7 +1076,7 @@ def stopTest(self, test: unittest.TestCase):
with self.lock:
# We assume test_result is an ABSL _TextAndXMLTestResult, so we can
# override how it gets the time.
time_getter = self.test_result.time_getter
time_getter = getattr(self.test_result, "time_getter", None)
try:
self.test_result.time_getter = lambda: self.start_time
self.test_result.startTest(test)
Expand All @@ -1085,7 +1085,8 @@ def stopTest(self, test: unittest.TestCase):
self.test_result.time_getter = lambda: stop_time
self.test_result.stopTest(test)
finally:
self.test_result.time_getter = time_getter
if time_getter is not None:
self.test_result.time_getter = time_getter

def addSuccess(self, test: unittest.TestCase):
self.actions.append(lambda: self.test_result.addSuccess(test))
Expand Down Expand Up @@ -1120,6 +1121,8 @@ def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.Test
if TEST_NUM_THREADS.value <= 0:
return super().run(result)

test_warning_util.install_threadsafe_warning_handlers()

executor = ThreadPoolExecutor(TEST_NUM_THREADS.value)
lock = threading.Lock()
futures = []
Expand Down Expand Up @@ -1368,11 +1371,44 @@ def assertMultiLineStrippedEqual(self, expected, what):
self.assertMultiLineEqual(expected_clean, what_clean,
msg=f"Found\n{what}\nExpecting\n{expected}")


@contextmanager
def assertNoWarnings(self):
with warnings.catch_warnings():
warnings.simplefilter("error")
with test_warning_util.raise_on_warnings():
yield

# We replace assertWarns and assertWarnsRegex with functions that use the
# thread-safe warning utilities. Unlike the unittest versions these only
# function as context managers.
@contextmanager
def assertWarns(self, warning, *, msg=None):
with test_warning_util.record_warnings() as ws:
yield
for w in ws:
if not isinstance(w.message, warning):
continue
if msg is not None and msg not in str(w.message):
continue
return
self.fail(f"Expected warning not found {warning}:'{msg}', got "
f"{ws}")

@contextmanager
def assertWarnsRegex(self, warning, regex):
if regex is not None:
regex = re.compile(regex)

with test_warning_util.record_warnings() as ws:
yield
for w in ws:
if not isinstance(w.message, warning):
continue
if regex is not None and not regex.search(str(w.message)):
continue
return
self.fail(f"Expected warning not found {warning}:'{regex}', got "
f"{ws}")

hawkinsp marked this conversation as resolved.
Show resolved Hide resolved

def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True, tol=None,
rtol=None, atol=None, check_cache_misses=True):
Expand Down Expand Up @@ -1449,11 +1485,7 @@ def assertNotDeleted(self, x):
self.assertFalse(x.is_deleted())


@contextmanager
def ignore_warning(*, message='', category=Warning, **kw):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=message, category=category, **kw)
yield
ignore_warning = test_warning_util.ignore_warning

# -------------------- Mesh parametrization helpers --------------------

Expand Down Expand Up @@ -1768,9 +1800,8 @@ def make_axis_points(size):
logtiny = finfo.minexp / prec_dps_ratio
axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype)

with warnings.catch_warnings():
with ignore_warning(category=RuntimeWarning):
# Silence RuntimeWarning: overflow encountered in cast
warnings.simplefilter("ignore")
half_neg_line = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
half_line = -half_neg_line[::-1]
axis_points[-size - 1:-1] = half_line
Expand Down
132 changes: 132 additions & 0 deletions jax/_src/test_warning_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Thread-safe utilities for catching and testing for warnings.
#
# The Python warnings module, at least as of Python 3.13, is not thread-safe.
# The catch_warnings() feature is inherently racy, see
# https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe
#
# This module offers a thread-safe way to catch and record warnings. We install
# a custom showwarning hook with the Python warning module, and then rely on
# the CPython warnings module to call our show warning function. We then use it
# to create our own thread-safe warning filtering utilities.

import contextlib
import re
import threading
import warnings


class _WarningContext(threading.local):
"Thread-local state that contains a list of warning handlers."

def __init__(self):
self.handlers = []


_context = _WarningContext()


# Callback that applies the handlers in reverse order. If no handler matches,
# we raise an error.
def _showwarning(message, category, filename, lineno, file=None, line=None):
for handler in reversed(_context.handlers):
if handler(message, category, filename, lineno, file, line):
return
raise category(message)


@contextlib.contextmanager
def raise_on_warnings():
"Context manager that raises an exception if a warning is raised."
if warnings.showwarning is not _showwarning:
with warnings.catch_warnings():
warnings.simplefilter("error")
yield
return

def handler(message, category, filename, lineno, file=None, line=None):
raise category(message)

_context.handlers.append(handler)
try:
yield
finally:
_context.handlers.pop()


@contextlib.contextmanager
def record_warnings():
"Context manager that yields a list of warnings that are raised."
if warnings.showwarning is not _showwarning:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
yield w
return

log = []

def handler(message, category, filename, lineno, file=None, line=None):
log.append(warnings.WarningMessage(message, category, filename, lineno, file, line))
return True

_context.handlers.append(handler)
try:
yield log
finally:
_context.handlers.pop()


@contextlib.contextmanager
def ignore_warning(*, message: str | None = None, category: type = Warning):
"Context manager that ignores any matching warnings."
if warnings.showwarning is not _showwarning:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="" if message is None else message, category=category)
yield
return

if message:
message_re = re.compile(message)
else:
message_re = None

category_cls = category

def handler(message, category, filename, lineno, file=None, line=None):
text = str(message) if isinstance(message, Warning) else message
if (message_re is None or message_re.match(text)) and issubclass(
category, category_cls
):
return True
return False

_context.handlers.append(handler)
try:
yield
finally:
_context.handlers.pop()


def install_threadsafe_warning_handlers():
# Hook the showwarning method. The warnings module explicitly notes that
# this is a function that users may replace.
warnings.showwarning = _showwarning

# Set the warnings module to always display warnings. We hook into it by
# overriding the "showwarning" method, so it's important that all warnings
# are "shown" by the usual mechanism.
warnings.simplefilter("always")
8 changes: 8 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,14 @@ jax_py_test(
],
)

jax_py_test(
name = "warnings_util_test",
srcs = ["warnings_util_test.py"],
deps = [
"//jax:test_util",
] + py_deps("absl/testing"),
)

jax_py_test(
name = "xla_bridge_test",
srcs = ["xla_bridge_test.py"],
Expand Down
48 changes: 23 additions & 25 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import unittest
from unittest import mock
from unittest import SkipTest
import warnings

from absl.testing import absltest
from absl.testing import parameterized
Expand All @@ -39,6 +38,7 @@
from jax._src import monitoring
from jax._src import path as pathlib
from jax._src import test_util as jtu
from jax._src import test_warning_util
from jax._src import xla_bridge
from jax._src.compilation_cache_interface import CacheInterface
from jax._src.lib import xla_client as xc
Expand Down Expand Up @@ -232,21 +232,20 @@ def test_cache_write_warning(self):
with (
config.raise_persistent_cache_errors(False),
mock.patch.object(cc._get_cache(backend).__class__, "put") as mock_put,
warnings.catch_warnings(record=True) as w,
test_warning_util.record_warnings() as w,
):
warnings.simplefilter("always")
mock_put.side_effect = RuntimeError("test error")
self.assertEqual(f(2).item(), 4)
if len(w) != 1:
print("Warnings:", [str(w_) for w_ in w], flush=True)
self.assertLen(w, 1)
self.assertIn(
(
"Error writing persistent compilation cache entry "
"for 'jit__lambda_': RuntimeError: test error"
),
str(w[0].message),
)
if len(w) != 1:
print("Warnings:", [str(w_) for w_ in w], flush=True)
self.assertLen(w, 1)
self.assertIn(
(
"Error writing persistent compilation cache entry "
"for 'jit__lambda_': RuntimeError: test error"
),
str(w[0].message),
)

def test_cache_read_warning(self):
f = jit(lambda x: x * x)
Expand All @@ -255,23 +254,22 @@ def test_cache_read_warning(self):
with (
config.raise_persistent_cache_errors(False),
mock.patch.object(cc._get_cache(backend).__class__, "get") as mock_get,
warnings.catch_warnings(record=True) as w,
test_warning_util.record_warnings() as w,
):
warnings.simplefilter("always")
mock_get.side_effect = RuntimeError("test error")
# Calling assertEqual with the jitted f will generate two PJIT
# executables: Equal and the lambda function itself.
self.assertEqual(f(2).item(), 4)
if len(w) != 1:
print("Warnings:", [str(w_) for w_ in w], flush=True)
self.assertLen(w, 1)
self.assertIn(
(
"Error reading persistent compilation cache entry "
"for 'jit__lambda_': RuntimeError: test error"
),
str(w[0].message),
)
if len(w) != 1:
print("Warnings:", [str(w_) for w_ in w], flush=True)
self.assertLen(w, 1)
self.assertIn(
(
"Error reading persistent compilation cache entry "
"for 'jit__lambda_': RuntimeError: test error"
),
str(w[0].message),
)

def test_min_entry_size(self):
with (
Expand Down
6 changes: 2 additions & 4 deletions tests/deprecation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

from absl.testing import absltest
from jax._src import deprecations
from jax._src import test_util as jtu
from jax._src import test_warning_util
from jax._src.internal_test_util import deprecation_module as m

class DeprecationTest(absltest.TestCase):

def testModuleDeprecation(self):
with warnings.catch_warnings():
warnings.simplefilter("error")
with test_warning_util.raise_on_warnings():
self.assertEqual(m.x, 42)

with self.assertWarnsRegex(DeprecationWarning, "Please use x"):
Expand Down
2 changes: 1 addition & 1 deletion tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def testReducer(self, name, rng_factory, shape, dtype, out_dtype,
rng = rng_factory(self.rng())
@jtu.ignore_warning(category=NumpyComplexWarning)
@jtu.ignore_warning(category=RuntimeWarning,
message="mean of empty slice.*")
message="Mean of empty slice.*")
@jtu.ignore_warning(category=RuntimeWarning,
message="overflow encountered.*")
def np_fun(x):
Expand Down
Loading
Loading