diff --git a/asgiref/sync.py b/asgiref/sync.py index 4427fc2a..87ee4064 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -217,8 +217,12 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: sys.exc_info(), task_context, context, - *args, - **kwargs, + # prepare an awaitable which can be passed as is to self.main_wrap, + # so that `args` and `kwargs` don't need to be + # destructured when passed to self.main_wrap + # (which is required by `ParamSpec`) + # as that may cause overlapping arguments + self.awaitable(*args, **kwargs), ) if not (self.main_event_loop and self.main_event_loop.is_running()): @@ -302,8 +306,7 @@ async def main_wrap( exc_info: "OptExcInfo", task_context: "Optional[List[asyncio.Task[Any]]]", context: List[contextvars.Context], - *args: _P.args, - **kwargs: _P.kwargs, + awaitable: Union[Coroutine[Any, Any, _R], Awaitable[_R]], ) -> None: """ Wraps the awaitable with something that puts the result into the @@ -326,9 +329,9 @@ async def main_wrap( try: raise exc_info[1] except BaseException: - result = await self.awaitable(*args, **kwargs) + result = await awaitable else: - result = await self.awaitable(*args, **kwargs) + result = await awaitable except BaseException as e: call_result.set_exception(e) else: diff --git a/tests/test_sync.py b/tests/test_sync.py index a4d2413b..0c67308c 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -7,6 +7,7 @@ import warnings from concurrent.futures import ThreadPoolExecutor from functools import wraps +from typing import Any from unittest import TestCase import pytest @@ -1174,3 +1175,36 @@ async def async_task(): assert task_complete assert task_executed + + +def test_async_to_sync_overlapping_kwargs() -> None: + """ + Tests that AsyncToSync correctly passes through kwargs to the wrapped function, + particularly in the case where the wrapped function uses same names for the parameters + as the wrapper. + """ + + @async_to_sync + async def test_function(**kwargs: Any) -> None: + assert kwargs + + # AsyncToSync.main_wrap has a param named `context`. + # So we pass the same argument here to test for the error + # "AsyncToSync.main_wrap() got multiple values for argument ''" + test_function(context=1) + + +@pytest.mark.asyncio +async def test_sync_to_async_overlapping_kwargs() -> None: + """ + Tests that SyncToAsync correctly passes through kwargs to the wrapped function, + particularly in the case where the wrapped function uses same names for the parameters + as the wrapper. + """ + + @sync_to_async + def test_function(**kwargs: Any) -> None: + assert kwargs + + # SyncToAsync.__call__.loop.run_in_executor has a param named `task_context`. + await test_function(task_context=1)