From bfb1b8fae2397f436bd679628b27dbc9a58227d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 17 Aug 2024 09:25:50 +0100 Subject: [PATCH] &toandaominh1997 [BUG] fix dependency checkers in case of multiple distributions available in environment, e.g., on databricks (#352) Mirror bugfix of https://github.com/sktime/sktime/pull/6986 in preparation of deduplication refactor. --- skbase/utils/dependencies/_dependencies.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index 9fcf5205..9889f64f 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -3,7 +3,6 @@ import sys import warnings from functools import lru_cache -from importlib.metadata import distributions from inspect import isclass from packaging.markers import InvalidMarker, Marker @@ -186,9 +185,19 @@ def _get_installed_packages_private(): Same as _get_installed_packages, but internal to avoid mutating the lru_cache by accident. """ + from importlib.metadata import distributions, version + dists = distributions() - packages = {dist.metadata["Name"]: dist.version for dist in dists} - return packages + package_names = {dist.metadata["Name"] for dist in dists} + package_versions = {pkg_name: version(pkg_name) for pkg_name in package_names} + # developer note: + # we cannot just use distributions naively, + # because the same top level package name may appear *twice*, + # e.g., in a situation where a virtual env overrides a base env, + # such as in deployment environments like databricks. + # the "version" contract ensures we always get the version that corresponds + # to the importable distribution, i.e., the top one in the sys.path. + return package_versions def _get_installed_packages():