From f81e493487d872198981fa6cefb3a0d93ab03c08 Mon Sep 17 00:00:00 2001 From: guacs <126393040+guacs@users.noreply.github.com> Date: Sun, 19 Nov 2023 17:14:22 +0530 Subject: [PATCH] Improve the handling of stringified annotations in `_takes_container` (#55) --- conftest.py | 1 - src/svcs/_core.py | 25 +++++++++----- tests/test_registry.py | 76 +++++++++++++++++++++++++++++++++++++++--- 3 files changed, 88 insertions(+), 14 deletions(-) diff --git a/conftest.py b/conftest.py index c17cc56..f12f920 100644 --- a/conftest.py +++ b/conftest.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: MIT - from doctest import ELLIPSIS import pytest diff --git a/src/svcs/_core.py b/src/svcs/_core.py index 386ec75..67cd627 100644 --- a/src/svcs/_core.py +++ b/src/svcs/_core.py @@ -403,9 +403,18 @@ def _takes_container(factory: Callable) -> bool: Return True if *factory* takes a svcs.Container as its first argument. """ try: - sig = inspect.signature(factory) + # Provide the locals so that `eval_str` will work even if the user places the `Container` + # under a `if TYPE_CHECKING` block + sig = inspect.signature( + factory, locals={"Container": Container}, eval_str=True + ) except Exception: # noqa: BLE001 - return False + # Retry without `eval_str` since if the annotation is "svcs.Container" the eval + # will fail due to it not finding the `svcs` module + try: + sig = inspect.signature(factory) + except Exception: # noqa: BLE001 + return False if not sig.parameters: return False @@ -415,13 +424,13 @@ def _takes_container(factory: Callable) -> bool: raise TypeError(msg) ((name, p),) = tuple(sig.parameters.items()) - if name == "svcs_container": - return True - - if (annot := p.annotation) is Container or annot == "svcs.Container": - return True - return False + return ( + name == "svcs_container" + or p.annotation is Container + or p.annotation == "svcs.Container" + or p.annotation == "Container" + ) T1 = TypeVar("T1") diff --git a/tests/test_registry.py b/tests/test_registry.py index 4b9370f..e4f3761 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -4,8 +4,10 @@ import contextlib import gc +import importlib.util import inspect import logging +import sys from unittest.mock import AsyncMock, Mock @@ -24,6 +26,23 @@ ) +@pytest.fixture(name="create_module") +def _create_module(tmp_path): + def wrapper(source): + module_name = "_svcs_testing_tmp_module" + module_path = tmp_path / f"{module_name}.py" + module_path.write_text(source) + + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + return module + + return wrapper + + class TestRegistry: def test_repr_empty(self, registry): """ @@ -316,6 +335,51 @@ def diff_name(): ... +takes_containers_annotation_string_modules = ( + """ +from __future__ import annotations + +from svcs import Container + +def factory(container: Container) -> int: + ... + """, + """ +from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from svcs import Container + +def factory(container: Container) -> int: + ... + """, + """ +def factory(container: "svcs.Container") -> int: + ... + """, + """ +from __future__ import annotations + +import svcs + +def factory(container: svcs.Container) -> int: + ... + + """, + """ +from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import svcs + +def factory(container: svcs.Container) -> int: + ... + """, +) + + class TestTakesContainer: @pytest.mark.parametrize( "factory", @@ -347,16 +411,18 @@ def factory(foo: svcs.Container): assert svcs._core._takes_container(factory) - def test_annotation_str(self): + @pytest.mark.parametrize( + "module_source", takes_containers_annotation_string_modules + ) + def test_annotation_str(self, module_source, create_module): """ - Return true if the first argument is annotated as `svcs.Container` + Return `True` if the first argument is annotated as `svcs.Container` using a string. """ - def factory(bar: "svcs.Container"): - ... + module = create_module(module_source) - assert svcs._core._takes_container(factory) + assert svcs._core._takes_container(module.factory) def test_catches_invalid_sigs(self): """