From 2d6dcfe60e1b77c46aa2385114c6e99132a8fef2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Fri, 23 Aug 2024 15:35:38 +0100 Subject: [PATCH] [ENH] safer `get_fitted_params` default functionality to avoid exception on `getattr` (#353) This PR makes the `get_fitted_params` core functionality and defaults safer against exceptions on `getattr`. In rare cases, `getattr` can cause an exception, namely if a property is being accessed in a way that generates an exception. Examples are some fitted parameter arguments in `sklearn` that are decorated as property in newer versions, e.g., `RandomForestRegressor.estimator_` in unfitted state. In most use cases, there will be no change in behaviour; in the described case where an exception would be raised, this is now caught and suppressed, and the corresponding parameter is considered as not present. Changes occur only in cases that would have previously raised genuine exceptions, so no working code is affected, and no deprecation is necessary despite this being a change to a core interface element. --- skbase/base/_base.py | 45 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index c67b4e41..28512554 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -1292,10 +1292,47 @@ def _get_fitted_params_default(self, obj=None): fitted_params = [ attr for attr in dir(obj) if attr.endswith("_") and not attr.startswith("_") ] - # remove the "_" at the end - fitted_param_dict = { - p[:-1]: getattr(obj, p) for p in fitted_params if hasattr(obj, p) - } + + def getattr_safe(obj, attr): + """Get attribute of object, safely. + + Safe version of getattr, that returns None if attribute does not exist, + or if an exception is raised during getattr. + Also returns a boolean indicating whether the attribute was successfully + retrieved, to distinguish between None value and non-existent attribute, + or exception during getattr. + + Parameters + ---------- + obj : any object + object to get attribute from + attr : str + attribute name to get from obj + + Returns + ------- + attr : Any + attribute of obj, if it exists and does not raise on getattr; + otherwise None + success : bool + whether the attribute was successfully retrieved + """ + try: + if hasattr(obj, attr): + attr = getattr(obj, attr) + return attr, True + else: + return None, False + except Exception: + return None, False + + fitted_param_dict = {} + + for p in fitted_params: + attr, success = getattr_safe(obj, p) + if success: + p_name = p[:-1] # remove the "_" at the end to get the parameter name + fitted_param_dict[p_name] = attr return fitted_param_dict