Skip to content

Commit

Permalink
[ENH] StderrMute (#350)
Browse files Browse the repository at this point in the history
Add stderr context manager.

related: sktime/sktime#6891,
sktime/sktime#6653
  • Loading branch information
XinyuWuu authored Aug 6, 2024
1 parent 8b2134e commit f960194
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 0 deletions.
3 changes: 3 additions & 0 deletions skbase/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"skbase.utils.dependencies",
"skbase.utils.dependencies._dependencies",
"skbase.utils.random_state",
"skbase.utils.stderr_mute",
"skbase.utils.stdout_mute",
"skbase.validate",
"skbase.validate._named_objects",
Expand All @@ -80,6 +81,7 @@
"skbase.utils.deep_equals",
"skbase.utils.dependencies",
"skbase.utils.random_state",
"skbase.utils.stderr_mute",
"skbase.utils.stdout_mute",
"skbase.validate",
)
Expand Down Expand Up @@ -108,6 +110,7 @@
"QuickTester",
"TestAllObjects",
),
"skbase.utils.stderr_mute": ("StderrMute",),
"skbase.utils.stdout_mute": ("StdoutMute",),
}
SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
Expand Down
64 changes: 64 additions & 0 deletions skbase/utils/stderr_mute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
"""Context manager to suppress stderr."""

__author__ = ["XinyuWu"]

import io
import sys


class StderrMute:
"""A context manager to suppress stderr.
Exception handling on exit can be customized by overriding
the ``_handle_exit_exceptions`` method.
Parameters
----------
active : bool, default=True
Whether to suppress stderr or not.
If True, stderr is suppressed.
If False, stderr is not suppressed, and the context manager does nothing
except catch and suppress ModuleNotFoundError.
"""

def __init__(self, active=True):
self.active = active

def __enter__(self):
"""Context manager entry point."""
# capture stderr if active
# store the original stderr so it can be restored in __exit__
if self.active:
self._stderr = sys.stderr
sys.stderr = io.StringIO()

def __exit__(self, type, value, traceback): # noqa: A002
"""Context manager exit point."""
# restore stderr if active
# if not active, nothing needs to be done, since stderr was not replaced
if self.active:
sys.stderr = self._stderr

if type is not None:
return self._handle_exit_exceptions(type, value, traceback)

# if no exception was raised, return True to indicate successful exit
# return statement not needed as type was None, but included for clarity
return True

def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
"""Handle exceptions raised during __exit__.
Parameters
----------
type : type
The type of the exception raised.
Known to be not-None and Exception subtype when this method is called.
value : Exception
The exception instance raised.
traceback : traceback
The traceback object associated with the exception.
"""
# by default, all exceptions are raised
return False
31 changes: 31 additions & 0 deletions skbase/utils/tests/test_std_mute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Tests of stdout_mute and stderr_mute."""
import io
import sys
from contextlib import redirect_stderr, redirect_stdout

import pytest

from skbase.utils.stderr_mute import StderrMute
from skbase.utils.stdout_mute import StdoutMute

__author__ = ["XinyuWu"]


@pytest.mark.parametrize(
"mute, expected", [(True, ["", ""]), (False, ["test stdout", "test sterr"])]
)
def test_std_mute(mute, expected):
"""Test StderrMute."""
stderr_io = io.StringIO()
stdout_io = io.StringIO()

try:
with redirect_stderr(stderr_io), redirect_stdout(stdout_io):
with StderrMute(mute), StdoutMute(mute):
sys.stdout.write("test stdout")
sys.stderr.write("test sterr")
1 / 0
except ZeroDivisionError:
assert expected == [stdout_io.getvalue(), stderr_io.getvalue()]

0 comments on commit f960194

Please sign in to comment.