From f960194de067f1b91646eb7ab626ce237f597358 Mon Sep 17 00:00:00 2001 From: Xinyu Wu <57612792+Xinyu-Wu-0000@users.noreply.github.com> Date: Wed, 7 Aug 2024 01:07:26 +0800 Subject: [PATCH] [ENH] StderrMute (#350) Add stderr context manager. related: https://github.com/sktime/sktime/pull/6891, https://github.com/sktime/sktime/issues/6653 --- skbase/tests/conftest.py | 3 ++ skbase/utils/stderr_mute.py | 64 +++++++++++++++++++++++++++++ skbase/utils/tests/test_std_mute.py | 31 ++++++++++++++ 3 files changed, 98 insertions(+) create mode 100644 skbase/utils/stderr_mute.py create mode 100644 skbase/utils/tests/test_std_mute.py diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 9615a3b8..bb86cbf1 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -54,6 +54,7 @@ "skbase.utils.dependencies", "skbase.utils.dependencies._dependencies", "skbase.utils.random_state", + "skbase.utils.stderr_mute", "skbase.utils.stdout_mute", "skbase.validate", "skbase.validate._named_objects", @@ -80,6 +81,7 @@ "skbase.utils.deep_equals", "skbase.utils.dependencies", "skbase.utils.random_state", + "skbase.utils.stderr_mute", "skbase.utils.stdout_mute", "skbase.validate", ) @@ -108,6 +110,7 @@ "QuickTester", "TestAllObjects", ), + "skbase.utils.stderr_mute": ("StderrMute",), "skbase.utils.stdout_mute": ("StdoutMute",), } SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy() diff --git a/skbase/utils/stderr_mute.py b/skbase/utils/stderr_mute.py new file mode 100644 index 00000000..774078e8 --- /dev/null +++ b/skbase/utils/stderr_mute.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +"""Context manager to suppress stderr.""" + +__author__ = ["XinyuWu"] + +import io +import sys + + +class StderrMute: + """A context manager to suppress stderr. + + Exception handling on exit can be customized by overriding + the ``_handle_exit_exceptions`` method. + + Parameters + ---------- + active : bool, default=True + Whether to suppress stderr or not. + If True, stderr is suppressed. + If False, stderr is not suppressed, and the context manager does nothing + except catch and suppress ModuleNotFoundError. + """ + + def __init__(self, active=True): + self.active = active + + def __enter__(self): + """Context manager entry point.""" + # capture stderr if active + # store the original stderr so it can be restored in __exit__ + if self.active: + self._stderr = sys.stderr + sys.stderr = io.StringIO() + + def __exit__(self, type, value, traceback): # noqa: A002 + """Context manager exit point.""" + # restore stderr if active + # if not active, nothing needs to be done, since stderr was not replaced + if self.active: + sys.stderr = self._stderr + + if type is not None: + return self._handle_exit_exceptions(type, value, traceback) + + # if no exception was raised, return True to indicate successful exit + # return statement not needed as type was None, but included for clarity + return True + + def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002 + """Handle exceptions raised during __exit__. + + Parameters + ---------- + type : type + The type of the exception raised. + Known to be not-None and Exception subtype when this method is called. + value : Exception + The exception instance raised. + traceback : traceback + The traceback object associated with the exception. + """ + # by default, all exceptions are raised + return False diff --git a/skbase/utils/tests/test_std_mute.py b/skbase/utils/tests/test_std_mute.py new file mode 100644 index 00000000..ec139fd8 --- /dev/null +++ b/skbase/utils/tests/test_std_mute.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +"""Tests of stdout_mute and stderr_mute.""" +import io +import sys +from contextlib import redirect_stderr, redirect_stdout + +import pytest + +from skbase.utils.stderr_mute import StderrMute +from skbase.utils.stdout_mute import StdoutMute + +__author__ = ["XinyuWu"] + + +@pytest.mark.parametrize( + "mute, expected", [(True, ["", ""]), (False, ["test stdout", "test sterr"])] +) +def test_std_mute(mute, expected): + """Test StderrMute.""" + stderr_io = io.StringIO() + stdout_io = io.StringIO() + + try: + with redirect_stderr(stderr_io), redirect_stdout(stdout_io): + with StderrMute(mute), StdoutMute(mute): + sys.stdout.write("test stdout") + sys.stderr.write("test sterr") + 1 / 0 + except ZeroDivisionError: + assert expected == [stdout_io.getvalue(), stderr_io.getvalue()]