From 4ace441d356f2587603435280c5bbe2e9113b739 Mon Sep 17 00:00:00 2001 From: Hynek Schlawack Date: Wed, 3 Jul 2024 11:30:45 +0200 Subject: [PATCH] Enter/exit sync contextmanagers in aget() (#93) Co-authored-by: Adrian Schneider <17550019+adrianschneider94@users.noreply.github.com> --- CHANGELOG.md | 6 ++++++ pyproject.toml | 4 ++++ src/svcs/_core.py | 6 ++++++ tests/test_container.py | 20 ++++++++++++++++++++ 4 files changed, 36 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1abd8df..e0e967f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,12 @@ You can find our backwards-compatibility policy [here](https://github.com/hynek/ [#73](https://github.com/hynek/svcs/pull/73) +### Fixed + +- `Container.aget()` now also enters and exits synchronous context managers. + [#93](https://github.com/hynek/svcs/pull/93) + + ## [24.1.0](https://github.com/hynek/svcs/compare/23.21.0...24.1.0) - 2024-01-25 ### Fixed diff --git a/pyproject.toml b/pyproject.toml index d3174c4..5d30b8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,6 +182,10 @@ ignore_missing_imports = true module = "tests.*" ignore_errors = true +[[tool.mypy.overrides]] +module = "conftest" +ignore_errors = true + [[tool.mypy.overrides]] module = "tests.typing.*" ignore_errors = false diff --git a/src/svcs/_core.py b/src/svcs/_core.py index 42b44f4..2c39ff8 100644 --- a/src/svcs/_core.py +++ b/src/svcs/_core.py @@ -995,6 +995,9 @@ async def aget(self, *svc_types: type) -> object: Also works with synchronous services, so in an async application, just use this. + + .. versionchanged:: 24.2.0 + Synchronous context managers are now entered/exited, too. """ rv = [] for svc_type in svc_types: @@ -1006,6 +1009,9 @@ async def aget(self, *svc_types: type) -> object: if enter and isinstance(svc, AbstractAsyncContextManager): self._on_close.append((name, svc)) svc = await svc.__aenter__() + elif enter and isinstance(svc, AbstractContextManager): + self._on_close.append((name, svc)) + svc = svc.__enter__() elif isawaitable(svc): svc = await svc diff --git a/tests/test_container.py b/tests/test_container.py index bc38204..6e9f269 100644 --- a/tests/test_container.py +++ b/tests/test_container.py @@ -109,6 +109,26 @@ def scope(): "Container was garbage-collected with pending cleanups.", ) == recwarn.list[0].message.args + @pytest.mark.asyncio() + async def test_aget_enters_sync_contextmanagers(self, container): + """ + aget enters (and exits) synchronous context managers. + """ + is_closed = False + + def factory(): + yield 42 + + nonlocal is_closed + is_closed = True + + container.registry.register_factory(int, factory) + + async with container: + assert 42 == await container.aget(int) + + assert is_closed + class TestServicePing: def test_ping(self, registry, container, close_me):