From f5d6026f67aa27399018eaad6223e35edeb8ca3d Mon Sep 17 00:00:00 2001 From: Elien Vandermaesen Date: Wed, 18 Dec 2024 11:48:53 +0100 Subject: [PATCH 1/4] issue #254 Clear cache when authenticate --- openeo/rest/connection.py | 2 ++ openeo/util.py | 3 +++ tests/rest/test_connection.py | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index 28de15dfd..8a6ebb95c 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -402,6 +402,7 @@ def authenticate_basic(self, username: Optional[str] = None, password: Optional[ if username is None: raise OpenEoClientException("No username/password given or found.") + self._capabilities_cache.clear() resp = self.get( '/credentials/basic', # /credentials/basic is the only endpoint that expects a Basic HTTP auth @@ -470,6 +471,7 @@ def _get_oidc_provider( f"No OIDC provider given. Using first provider {provider_id!r} as advertised by backend." ) + self._capabilities_cache.clear() provider_info = OidcProviderInfo.from_dict(provider) if parse_info else None return provider_id, provider_info diff --git a/openeo/util.py b/openeo/util.py index 44842124a..53550ccb8 100644 --- a/openeo/util.py +++ b/openeo/util.py @@ -476,6 +476,9 @@ def get(self, key: Union[str, tuple], load: Callable[[], Any]): self._cache[key] = load() return self._cache[key] + def clear(self): + self._cache = {} + def str_truncate(text: str, width: int = 64, ellipsis: str = "...") -> str: """Shorten a string (with an ellipsis) if it is longer than certain length.""" diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index 0bf36d2d7..f93bed4a9 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -513,6 +513,39 @@ def test_capabilities_caching(requests_mock): assert con.capabilities().api_version() == "1.0.0" assert m.call_count == 1 +def test_capabilities_caching_after_authenticate_basic(requests_mock): + user, pwd = "john262", "J0hndo3" + requests_mock.get(API_URL, json={"api_version": "1.0.0", "endpoints": BASIC_ENDPOINTS}) + requests_mock.get(API_URL + 'credentials/basic', text=_credentials_basic_handler(user, pwd)) + + with mock.patch('openeo.rest.connection.AuthConfig') as AuthConfig: + conn = Connection(API_URL) + conn._capabilities_cache._cache={"test":"test1"} + assert conn._capabilities_cache._cache != {} + AuthConfig.return_value.get_basic_auth.return_value = (user, pwd) + conn.authenticate_basic(user, pwd) + assert conn._capabilities_cache._cache == {} + + +def test_capabilities_caching_after_authenticate_oidc(requests_mock): + requests_mock.get(API_URL, json={"api_version": "1.0.0"}) + client_id = "myclient" + requests_mock.get(API_URL + 'credentials/oidc', json={ + "providers": [{"id": "fauth", "issuer": "https://fauth.test", "title": "Foo Auth", "scopes": ["openid", "im"]}] + }) + oidc_mock = OidcMock( + requests_mock=requests_mock, + expected_grant_type="authorization_code", + expected_client_id=client_id, + expected_fields={"scope": "im openid"}, + oidc_issuer="https://fauth.test", + scopes_supported=["openid", "im"], + ) + conn = Connection(API_URL) + conn._capabilities_cache._cache = {"test": "test1"} + conn.authenticate_oidc_authorization_code(client_id=client_id, webbrowser_open=oidc_mock.webbrowser_open) + assert conn._capabilities_cache._cache == {} + def test_file_formats(requests_mock): requests_mock.get("https://oeo.test/", json={"api_version": "1.0.0"}) From 629abe9f8bafbc955e5a6bd7f199e07cbbd2ddea Mon Sep 17 00:00:00 2001 From: Elien Vandermaesen Date: Wed, 8 Jan 2025 10:01:18 +0100 Subject: [PATCH 2/4] issue #254 move clear cache and improve tests --- openeo/rest/connection.py | 4 +-- tests/rest/test_connection.py | 67 +++++++++++++++++++++++++++++------ 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index 8a6ebb95c..86b25a976 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -402,7 +402,6 @@ def authenticate_basic(self, username: Optional[str] = None, password: Optional[ if username is None: raise OpenEoClientException("No username/password given or found.") - self._capabilities_cache.clear() resp = self.get( '/credentials/basic', # /credentials/basic is the only endpoint that expects a Basic HTTP auth @@ -410,6 +409,7 @@ def authenticate_basic(self, username: Optional[str] = None, password: Optional[ ).json() # Switch to bearer based authentication in further requests. self.auth = BasicBearerAuth(access_token=resp["access_token"]) + self._capabilities_cache.clear() return self def _get_oidc_provider( @@ -471,7 +471,6 @@ def _get_oidc_provider( f"No OIDC provider given. Using first provider {provider_id!r} as advertised by backend." ) - self._capabilities_cache.clear() provider_info = OidcProviderInfo.from_dict(provider) if parse_info else None return provider_id, provider_info @@ -545,6 +544,7 @@ def _authenticate_oidc( _log.warning("No OIDC refresh token to store.") token = tokens.access_token self.auth = OidcBearerAuth(provider_id=provider_id, access_token=token) + self._capabilities_cache.clear() self._oidc_auth_renewer = oidc_auth_renewer return self diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index f93bed4a9..c4f4066db 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -515,21 +515,51 @@ def test_capabilities_caching(requests_mock): def test_capabilities_caching_after_authenticate_basic(requests_mock): user, pwd = "john262", "J0hndo3" - requests_mock.get(API_URL, json={"api_version": "1.0.0", "endpoints": BASIC_ENDPOINTS}) + + def get_capabilities(request, context): + endpoints = BASIC_ENDPOINTS + if "Authorization" in request.headers: + endpoints.append({"path": "/account/status", "methods": ["GET"]}) + return {"api_version": "1.0.0", "endpoints": endpoints} + + get_capabilities_mock = requests_mock.get(API_URL, json=get_capabilities) requests_mock.get(API_URL + 'credentials/basic', text=_credentials_basic_handler(user, pwd)) - with mock.patch('openeo.rest.connection.AuthConfig') as AuthConfig: - conn = Connection(API_URL) - conn._capabilities_cache._cache={"test":"test1"} - assert conn._capabilities_cache._cache != {} - AuthConfig.return_value.get_basic_auth.return_value = (user, pwd) - conn.authenticate_basic(user, pwd) - assert conn._capabilities_cache._cache == {} + con = Connection(API_URL) + assert con.capabilities().capabilities == { + "api_version": "1.0.0", + "endpoints": [ + {"methods": ["GET"], "path": "/credentials/basic"}, + ], + } + assert get_capabilities_mock.call_count == 1 + con.capabilities() + assert get_capabilities_mock.call_count == 1 + + con.authenticate_basic(user, pwd) + assert get_capabilities_mock.call_count == 1 + assert con.capabilities().capabilities == { + "api_version": "1.0.0", + "endpoints": [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/account/status"}, + ], + } + assert get_capabilities_mock.call_count == 2 + def test_capabilities_caching_after_authenticate_oidc(requests_mock): requests_mock.get(API_URL, json={"api_version": "1.0.0"}) client_id = "myclient" + + def get_capabilities(request, context): + endpoints = BASIC_ENDPOINTS + if "Authorization" in request.headers: + endpoints.append({"path": "/account/status", "methods": ["GET"]}) + return {"api_version": "1.0.0", "endpoints": endpoints} + + get_capabilities_mock = requests_mock.get(API_URL, json=get_capabilities) requests_mock.get(API_URL + 'credentials/oidc', json={ "providers": [{"id": "fauth", "issuer": "https://fauth.test", "title": "Foo Auth", "scopes": ["openid", "im"]}] }) @@ -542,9 +572,26 @@ def test_capabilities_caching_after_authenticate_oidc(requests_mock): scopes_supported=["openid", "im"], ) conn = Connection(API_URL) - conn._capabilities_cache._cache = {"test": "test1"} + assert conn.capabilities().capabilities == { + "api_version": "1.0.0", + "endpoints": [ + {"methods": ["GET"], "path": "/credentials/basic"}, + ], + } + assert get_capabilities_mock.call_count == 1 + conn.capabilities() + assert get_capabilities_mock.call_count == 1 + conn.authenticate_oidc_authorization_code(client_id=client_id, webbrowser_open=oidc_mock.webbrowser_open) - assert conn._capabilities_cache._cache == {} + assert get_capabilities_mock.call_count == 1 + assert conn.capabilities().capabilities == { + "api_version": "1.0.0", + "endpoints": [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/account/status"}, + ], + } + assert get_capabilities_mock.call_count == 2 def test_file_formats(requests_mock): From 3a665fc7f25810813dcb35570ffaf163fa89b60f Mon Sep 17 00:00:00 2001 From: Elien Vandermaesen Date: Wed, 8 Jan 2025 10:56:56 +0100 Subject: [PATCH 3/4] issue #254 fix test by making copy --- tests/rest/test_connection.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index c4f4066db..7f4a4f89c 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -517,7 +517,7 @@ def test_capabilities_caching_after_authenticate_basic(requests_mock): user, pwd = "john262", "J0hndo3" def get_capabilities(request, context): - endpoints = BASIC_ENDPOINTS + endpoints = BASIC_ENDPOINTS.copy() if "Authorization" in request.headers: endpoints.append({"path": "/account/status", "methods": ["GET"]}) return {"api_version": "1.0.0", "endpoints": endpoints} @@ -550,11 +550,10 @@ def get_capabilities(request, context): def test_capabilities_caching_after_authenticate_oidc(requests_mock): - requests_mock.get(API_URL, json={"api_version": "1.0.0"}) client_id = "myclient" def get_capabilities(request, context): - endpoints = BASIC_ENDPOINTS + endpoints = BASIC_ENDPOINTS.copy() if "Authorization" in request.headers: endpoints.append({"path": "/account/status", "methods": ["GET"]}) return {"api_version": "1.0.0", "endpoints": endpoints} From 9d95feba335f93edb5699c7db506d86d4e1aaae8 Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Fri, 17 Jan 2025 18:07:35 +0100 Subject: [PATCH 4/4] Issue #254/#691 introduce _on_auth_update handler - to make sure all cases are covered - include authenticate_oidc_access_token --- openeo/rest/connection.py | 21 +++++- tests/rest/test_connection.py | 126 ++++++++++++++++++++-------------- 2 files changed, 91 insertions(+), 56 deletions(-) diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index 04a5a0242..fa4725822 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -113,6 +113,7 @@ def __init__( slow_response_threshold: Optional[float] = None, ): self._root_url = root_url + self._auth = None self.auth = auth or NullAuth() self.session = session or requests.Session() self.default_timeout = default_timeout or DEFAULT_TIMEOUT @@ -129,6 +130,18 @@ def __init__( def root_url(self): return self._root_url + @property + def auth(self) -> Union[AuthBase, None]: + return self._auth + + @auth.setter + def auth(self, auth: Union[AuthBase, None]): + self._auth = auth + self._on_auth_update() + + def _on_auth_update(self): + pass + def build_url(self, path: str): return url_join(self._root_url, path) @@ -340,12 +353,12 @@ def __init__( if "://" not in url: url = "https://" + url self._orig_url = url + self._capabilities_cache = LazyLoadCache() super().__init__( root_url=self.version_discovery(url, session=session, timeout=default_timeout), auth=auth, session=session, default_timeout=default_timeout, slow_response_threshold=slow_response_threshold, ) - self._capabilities_cache = LazyLoadCache() # Initial API version check. self._api_version.require_at_least(self._MINIMUM_API_VERSION) @@ -380,6 +393,10 @@ def version_discovery( # Be very lenient about failing on the well-known URI strategy. return url + def _on_auth_update(self): + super()._on_auth_update() + self._capabilities_cache.clear() + def _get_auth_config(self) -> AuthConfig: if self._auth_config is None: self._auth_config = AuthConfig() @@ -411,7 +428,6 @@ def authenticate_basic(self, username: Optional[str] = None, password: Optional[ ).json() # Switch to bearer based authentication in further requests. self.auth = BasicBearerAuth(access_token=resp["access_token"]) - self._capabilities_cache.clear() return self def _get_oidc_provider( @@ -546,7 +562,6 @@ def _authenticate_oidc( _log.warning("No OIDC refresh token to store.") token = tokens.access_token self.auth = OidcBearerAuth(provider_id=provider_id, access_token=token) - self._capabilities_cache.clear() self._oidc_auth_renewer = oidc_auth_renewer return self diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index 02d55e1d7..643f734c3 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -49,6 +49,7 @@ API_URL = "https://oeo.test/" +# TODO: eliminate this and replace with `build_capabilities` usage BASIC_ENDPOINTS = [{"path": "/credentials/basic", "methods": ["GET"]}] @@ -551,83 +552,102 @@ def test_capabilities_caching(requests_mock): assert con.capabilities().api_version() == "1.0.0" assert m.call_count == 1 -def test_capabilities_caching_after_authenticate_basic(requests_mock): - user, pwd = "john262", "J0hndo3" - def get_capabilities(request, context): - endpoints = BASIC_ENDPOINTS.copy() - if "Authorization" in request.headers: - endpoints.append({"path": "/account/status", "methods": ["GET"]}) - return {"api_version": "1.0.0", "endpoints": endpoints} +def _get_capabilities_auth_dependent(request, context): + capabilities = build_capabilities() + capabilities["endpoints"] = [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + ] + if "Authorization" in request.headers: + capabilities["endpoints"].append({"methods": ["GET"], "path": "/me"}) + return capabilities + - get_capabilities_mock = requests_mock.get(API_URL, json=get_capabilities) +def test_capabilities_caching_after_authenticate_basic(requests_mock): + user, pwd = "john262", "J0hndo3" + get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent) requests_mock.get(API_URL + 'credentials/basic', text=_credentials_basic_handler(user, pwd)) con = Connection(API_URL) - assert con.capabilities().capabilities == { - "api_version": "1.0.0", - "endpoints": [ - {"methods": ["GET"], "path": "/credentials/basic"}, - ], - } + assert con.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + ] assert get_capabilities_mock.call_count == 1 con.capabilities() assert get_capabilities_mock.call_count == 1 - con.authenticate_basic(user, pwd) + con.authenticate_basic(username=user, password=pwd) assert get_capabilities_mock.call_count == 1 - assert con.capabilities().capabilities == { - "api_version": "1.0.0", - "endpoints": [ - {"methods": ["GET"], "path": "/credentials/basic"}, - {"methods": ["GET"], "path": "/account/status"}, - ], - } - assert get_capabilities_mock.call_count == 2 + assert con.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + {"methods": ["GET"], "path": "/me"}, + ] + assert get_capabilities_mock.call_count == 2 -def test_capabilities_caching_after_authenticate_oidc(requests_mock): +def test_capabilities_caching_after_authenticate_oidc_refresh_token(requests_mock): client_id = "myclient" - - def get_capabilities(request, context): - endpoints = BASIC_ENDPOINTS.copy() - if "Authorization" in request.headers: - endpoints.append({"path": "/account/status", "methods": ["GET"]}) - return {"api_version": "1.0.0", "endpoints": endpoints} - - get_capabilities_mock = requests_mock.get(API_URL, json=get_capabilities) - requests_mock.get(API_URL + 'credentials/oidc', json={ - "providers": [{"id": "fauth", "issuer": "https://fauth.test", "title": "Foo Auth", "scopes": ["openid", "im"]}] - }) + refresh_token = "fr65h!" + get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent) + requests_mock.get( + API_URL + "credentials/oidc", + json={"providers": [{"id": "oi", "issuer": "https://oidc.test", "title": "OI!", "scopes": ["openid"]}]}, + ) oidc_mock = OidcMock( requests_mock=requests_mock, - expected_grant_type="authorization_code", + expected_grant_type="refresh_token", expected_client_id=client_id, - expected_fields={"scope": "im openid"}, - oidc_issuer="https://fauth.test", - scopes_supported=["openid", "im"], + expected_fields={"refresh_token": refresh_token}, ) + conn = Connection(API_URL) - assert conn.capabilities().capabilities == { - "api_version": "1.0.0", - "endpoints": [ - {"methods": ["GET"], "path": "/credentials/basic"}, - ], - } + assert conn.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + ] + assert get_capabilities_mock.call_count == 1 conn.capabilities() assert get_capabilities_mock.call_count == 1 - conn.authenticate_oidc_authorization_code(client_id=client_id, webbrowser_open=oidc_mock.webbrowser_open) + conn.authenticate_oidc_refresh_token(client_id=client_id, refresh_token=refresh_token) assert get_capabilities_mock.call_count == 1 - assert conn.capabilities().capabilities == { - "api_version": "1.0.0", - "endpoints": [ - {"methods": ["GET"], "path": "/credentials/basic"}, - {"methods": ["GET"], "path": "/account/status"}, - ], - } + assert conn.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + {"methods": ["GET"], "path": "/me"}, + ] + assert get_capabilities_mock.call_count == 2 + + +def test_capabilities_caching_after_authenticate_oidc_access_token(requests_mock): + get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent) + requests_mock.get( + API_URL + "credentials/oidc", + json={"providers": [{"id": "oi", "issuer": "https://oidc.test", "title": "OI!", "scopes": ["openid"]}]}, + ) + + conn = Connection(API_URL) + assert conn.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + ] + + assert get_capabilities_mock.call_count == 1 + conn.capabilities() + assert get_capabilities_mock.call_count == 1 + + conn.authenticate_oidc_access_token(access_token="6cc355!") + assert get_capabilities_mock.call_count == 1 + assert conn.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + {"methods": ["GET"], "path": "/me"}, + ] assert get_capabilities_mock.call_count == 2