diff --git a/CHANGELOG.md b/CHANGELOG.md index 7971fa8..77b4ca8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,10 @@ You can find our backwards-compatibility policy [here](https://github.com/hynek/ [#72](https://github.com/hynek/svcs/issues/72) [#73](https://github.com/hynek/svcs/pull/73) +- `Registry.register_factory()` is now more lenient regarding the arguments of the factory. + It only looks at the first argument (if present) and ignores the rest. + [#110](https://github.com/hynek/svcs/pull/110) + ### Fixed diff --git a/src/svcs/_core.py b/src/svcs/_core.py index 73ba965..c9d1342 100644 --- a/src/svcs/_core.py +++ b/src/svcs/_core.py @@ -279,6 +279,9 @@ def register_factory( Can also be an async callable or an :class:`collections.abc.Awaitable`; then :meth:`svcs.Registry.aclose()` must be called. + + .. versionchanged:: 24.2.0 + *factory* now may take any amount of arguments and they are ignored. """ rs = self._register_factory( svc_type, @@ -454,14 +457,10 @@ def _takes_container(factory: Callable) -> bool: except Exception: # noqa: BLE001 return False - if not sig.parameters: - return False - - if len(sig.parameters) != 1: - msg = "Factories must take 0 or 1 parameters." - raise TypeError(msg) - - ((name, p),) = tuple(sig.parameters.items()) + try: + (name, p) = next(iter(sig.parameters.items())) + except StopIteration: + return False # 0 arguments return name == "svcs_container" or p.annotation in ( Container, diff --git a/tests/test_registry.py b/tests/test_registry.py index a11c266..8f967a2 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -432,17 +432,14 @@ def test_annotation_str(self, module_source, create_module): assert svcs._core._takes_container(module.factory) - def test_catches_invalid_sigs(self): + def test_ignores_invalid_sigs(self): """ - If the factory takes more than one parameter, raise an TypeError. + If the first parameter is anything but what we handle, ignore it. """ def factory(foo, bar): ... - with pytest.raises( - TypeError, match="Factories must take 0 or 1 parameters." - ): - svcs._core._takes_container(factory) + assert not svcs._core._takes_container(factory) def test_call_works(self): """