diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index 9fcf5205..9889f64f 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -3,7 +3,6 @@ import sys import warnings from functools import lru_cache -from importlib.metadata import distributions from inspect import isclass from packaging.markers import InvalidMarker, Marker @@ -186,9 +185,19 @@ def _get_installed_packages_private(): Same as _get_installed_packages, but internal to avoid mutating the lru_cache by accident. """ + from importlib.metadata import distributions, version + dists = distributions() - packages = {dist.metadata["Name"]: dist.version for dist in dists} - return packages + package_names = {dist.metadata["Name"] for dist in dists} + package_versions = {pkg_name: version(pkg_name) for pkg_name in package_names} + # developer note: + # we cannot just use distributions naively, + # because the same top level package name may appear *twice*, + # e.g., in a situation where a virtual env overrides a base env, + # such as in deployment environments like databricks. + # the "version" contract ensures we always get the version that corresponds + # to the importable distribution, i.e., the top one in the sys.path. + return package_versions def _get_installed_packages():