Skip to content

Commit

Permalink
Improve the handling of stringified annotations in _takes_container (
Browse files Browse the repository at this point in the history
  • Loading branch information
guacs authored Nov 19, 2023
1 parent 4b05ab8 commit f81e493
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 14 deletions.
1 change: 0 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# SPDX-License-Identifier: MIT


from doctest import ELLIPSIS

import pytest
Expand Down
25 changes: 17 additions & 8 deletions src/svcs/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,9 +403,18 @@ def _takes_container(factory: Callable) -> bool:
Return True if *factory* takes a svcs.Container as its first argument.
"""
try:
sig = inspect.signature(factory)
# Provide the locals so that `eval_str` will work even if the user places the `Container`
# under a `if TYPE_CHECKING` block
sig = inspect.signature(
factory, locals={"Container": Container}, eval_str=True
)
except Exception: # noqa: BLE001
return False
# Retry without `eval_str` since if the annotation is "svcs.Container" the eval
# will fail due to it not finding the `svcs` module
try:
sig = inspect.signature(factory)
except Exception: # noqa: BLE001
return False

if not sig.parameters:
return False
Expand All @@ -415,13 +424,13 @@ def _takes_container(factory: Callable) -> bool:
raise TypeError(msg)

((name, p),) = tuple(sig.parameters.items())
if name == "svcs_container":
return True

if (annot := p.annotation) is Container or annot == "svcs.Container":
return True

return False
return (
name == "svcs_container"
or p.annotation is Container
or p.annotation == "svcs.Container"
or p.annotation == "Container"
)


T1 = TypeVar("T1")
Expand Down
76 changes: 71 additions & 5 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

import contextlib
import gc
import importlib.util
import inspect
import logging
import sys

from unittest.mock import AsyncMock, Mock

Expand All @@ -24,6 +26,23 @@
)


@pytest.fixture(name="create_module")
def _create_module(tmp_path):
def wrapper(source):
module_name = "_svcs_testing_tmp_module"
module_path = tmp_path / f"{module_name}.py"
module_path.write_text(source)

spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)

return module

return wrapper


class TestRegistry:
def test_repr_empty(self, registry):
"""
Expand Down Expand Up @@ -316,6 +335,51 @@ def diff_name():
...


takes_containers_annotation_string_modules = (
"""
from __future__ import annotations
from svcs import Container
def factory(container: Container) -> int:
...
""",
"""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from svcs import Container
def factory(container: Container) -> int:
...
""",
"""
def factory(container: "svcs.Container") -> int:
...
""",
"""
from __future__ import annotations
import svcs
def factory(container: svcs.Container) -> int:
...
""",
"""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import svcs
def factory(container: svcs.Container) -> int:
...
""",
)


class TestTakesContainer:
@pytest.mark.parametrize(
"factory",
Expand Down Expand Up @@ -347,16 +411,18 @@ def factory(foo: svcs.Container):

assert svcs._core._takes_container(factory)

def test_annotation_str(self):
@pytest.mark.parametrize(
"module_source", takes_containers_annotation_string_modules
)
def test_annotation_str(self, module_source, create_module):
"""
Return true if the first argument is annotated as `svcs.Container`
Return `True` if the first argument is annotated as `svcs.Container`
using a string.
"""

def factory(bar: "svcs.Container"):
...
module = create_module(module_source)

assert svcs._core._takes_container(factory)
assert svcs._core._takes_container(module.factory)

def test_catches_invalid_sigs(self):
"""
Expand Down

0 comments on commit f81e493

Please sign in to comment.