Skip to content

Commit

Permalink
[ENH] refactor - move StdoutMute context manager to utils (#338)
Browse files Browse the repository at this point in the history
Minor refactor in anticipation of external and further internal usage -
moves `StdoutMute` context manager to a submodule of `utils`.

Also refactors manual `stdout` handling in `_check_soft_dependencies` to
use the utility.
  • Loading branch information
fkiraly authored Jun 20, 2024
1 parent 7dfd670 commit 44d0cae
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 46 deletions.
62 changes: 24 additions & 38 deletions skbase/lookup/_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
# https://github.com/sktime/sktime/blob/main/LICENSE
import importlib
import inspect
import io
import os
import pathlib
import pkgutil
import re
import sys
import warnings
from collections.abc import Iterable
from copy import deepcopy
Expand All @@ -31,6 +29,7 @@
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union

from skbase.base import BaseObject
from skbase.utils.stdout_mute import StdoutMute
from skbase.validate import check_sequence

__all__: List[str] = ["all_objects", "get_package_metadata"]
Expand Down Expand Up @@ -335,7 +334,7 @@ def _import_module(

# if suppress_import_stdout:
# setup text trap, import
with StdoutMute(active=suppress_import_stdout):
with StdoutMuteNCatchMNF(active=suppress_import_stdout):
if isinstance(module, str):
imported_mod = importlib.import_module(module)
elif isinstance(module, importlib.machinery.SourceFileLoader):
Expand Down Expand Up @@ -865,7 +864,7 @@ class name if ``return_names=False`` and ``return_tags is not None``.
obj_types = _check_object_types(object_types, class_lookup)

# Ignore deprecation warnings triggered at import time and from walking packages
with warnings.catch_warnings(), StdoutMute(active=suppress_import_stdout):
with warnings.catch_warnings(), StdoutMuteNCatchMNF(active=suppress_import_stdout):
warnings.simplefilter("ignore", category=FutureWarning)
warnings.simplefilter("module", category=ImportWarning)
warnings.filterwarnings(
Expand Down Expand Up @@ -1025,7 +1024,7 @@ def _make_dataframe(all_objects, columns):
return pd.DataFrame(all_objects, columns=columns)


class StdoutMute:
class StdoutMuteNCatchMNF(StdoutMute):
"""A context manager to suppress stdout.
This class is used to suppress stdout when importing modules.
Expand All @@ -1042,39 +1041,26 @@ class StdoutMute:
except catch and suppress ModuleNotFoundError.
"""

def __init__(self, active=True):
self.active = active

def __enter__(self):
"""Context manager entry point."""
# capture stdout if active
# store the original stdout so it can be restored in __exit__
if self.active:
self._stdout = sys.stdout
sys.stdout = io.StringIO()

def __exit__(self, type, value, traceback): # noqa: A002
"""Context manager exit point."""
# restore stdout if active
# if not active, nothing needs to be done, since stdout was not replaced
if self.active:
sys.stdout = self._stdout

if type is not None:
# if a ModuleNotFoundError is raised,
# we suppress to a warning if "soft dependency" is in the error message
# otherwise, raise
if type is ModuleNotFoundError:
if "soft dependency" not in str(value):
return False
warnings.warn(str(value), ImportWarning, stacklevel=2)
return True

# all other exceptions are raised
return False
# if no exception was raised, return True to indicate successful exit
# return statement not needed as type was None, but included for clarity
return True
def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
"""Handle exceptions raised during __exit__.
Parameters
----------
type : type
The type of the exception raised.
Known to be not-None and Exception subtype when this method is called.
"""
# if a ModuleNotFoundError is raised,
# we suppress to a warning if "soft dependency" is in the error message
# otherwise, raise
if type is ModuleNotFoundError:
if "soft dependency" not in str(value):
return False
warnings.warn(str(value), ImportWarning, stacklevel=2)
return True

# all other exceptions are raised
return False


def _coerce_to_tuple(x):
Expand Down
5 changes: 4 additions & 1 deletion skbase/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"skbase.utils.dependencies",
"skbase.utils.dependencies._dependencies",
"skbase.utils.random_state",
"skbase.utils.stdout_mute",
"skbase.validate",
"skbase.validate._named_objects",
"skbase.validate._types",
Expand All @@ -79,6 +80,7 @@
"skbase.utils.deep_equals",
"skbase.utils.dependencies",
"skbase.utils.random_state",
"skbase.utils.stdout_mute",
"skbase.validate",
)
SKBASE_PUBLIC_CLASSES_BY_MODULE = {
Expand All @@ -99,13 +101,14 @@
"BaseMetaEstimatorMixin",
),
"skbase.base._pretty_printing._pprint": ("KeyValTuple", "KeyValTupleParam"),
"skbase.lookup._lookup": ("StdoutMute",),
"skbase.lookup._lookup": ("StdoutMuteNCatchMNF",),
"skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"),
"skbase.testing.test_all_objects": (
"BaseFixtureGenerator",
"QuickTester",
"TestAllObjects",
),
"skbase.utils.stdout_mute": ("StdoutMute",),
}
SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
SKBASE_CLASSES_BY_MODULE.update(
Expand Down
10 changes: 3 additions & 7 deletions skbase/utils/dependencies/_dependencies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
"""Utility to check soft dependency imports, and raise warnings or errors."""
import io
import sys
import warnings
from importlib import import_module
Expand All @@ -10,6 +9,8 @@
from packaging.requirements import InvalidRequirement, Requirement
from packaging.specifiers import InvalidSpecifier, SpecifierSet

from skbase.utils.stdout_mute import StdoutMute

__author__: List[str] = ["fkiraly", "mloning"]


Expand Down Expand Up @@ -130,12 +131,7 @@ def _check_soft_dependencies(
package_import_name = package_name
# attempt import - if not possible, we know we need to raise warning/exception
try:
if suppress_import_stdout:
# setup text trap, import, then restore
sys.stdout = io.StringIO()
pkg_ref = import_module(package_import_name)
sys.stdout = sys.__stdout__
else:
with StdoutMute(active=suppress_import_stdout):
pkg_ref = import_module(package_import_name)
# if package cannot be imported, make the user aware of installation requirement
except ModuleNotFoundError as e:
Expand Down
64 changes: 64 additions & 0 deletions skbase/utils/stdout_mute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
"""Context manager to suppress stdout."""

__author__ = ["fkiraly"]

import io
import sys


class StdoutMute:
"""A context manager to suppress stdout.
Exception handling on exit can be customized by overriding
the ``_handle_exit_exceptions`` method.
Parameters
----------
active : bool, default=True
Whether to suppress stdout or not.
If True, stdout is suppressed.
If False, stdout is not suppressed, and the context manager does nothing
except catch and suppress ModuleNotFoundError.
"""

def __init__(self, active=True):
self.active = active

def __enter__(self):
"""Context manager entry point."""
# capture stdout if active
# store the original stdout so it can be restored in __exit__
if self.active:
self._stdout = sys.stdout
sys.stdout = io.StringIO()

def __exit__(self, type, value, traceback): # noqa: A002
"""Context manager exit point."""
# restore stdout if active
# if not active, nothing needs to be done, since stdout was not replaced
if self.active:
sys.stdout = self._stdout

if type is not None:
return self._handle_exit_exceptions(type, value, traceback)

# if no exception was raised, return True to indicate successful exit
# return statement not needed as type was None, but included for clarity
return True

def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
"""Handle exceptions raised during __exit__.
Parameters
----------
type : type
The type of the exception raised.
Known to be not-None and Exception subtype when this method is called.
value : Exception
The exception instance raised.
traceback : traceback
The traceback object associated with the exception.
"""
# by default, all exceptions are raised
return False

0 comments on commit 44d0cae

Please sign in to comment.