Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ParamSpec for wrapped signatures #508

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[mypy]
files = async_lru, tests
check_untyped_defs = True
follow_imports_for_stubs = True
disallow_any_decorated = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a reason for enumerating all mypy options?
I would prefer strict = true plus maybe a few additional fine-tunes.
Maintaining all mypy opts adds a burden; I never could keep all these options in my mind without looking in docs

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This adds a lot more options than strict. This is the config we started in aiohttp a few years ago. Basically, strict isn't so strict, so we like to enable all the options that aren't too much.

At some point I'd like to find a way to make a config that's reusable across all projects without copy/paste, which will ease the maintenance burden.

disallow_any_generics = True
disallow_any_unimported = True
disallow_incomplete_defs = True
disallow_subclassing_any = True
disallow_untyped_calls = True
disallow_untyped_decorators = True
disallow_untyped_defs = True
enable_error_code = ignore-without-code, possibly-undefined, redundant-expr, redundant-self, truthy-bool, truthy-iterable, unused-awaitable
implicit_reexport = False
no_implicit_optional = True
pretty = True
show_column_numbers = True
show_error_codes = True
strict_equality = True
warn_incomplete_stub = True
warn_redundant_casts = True
warn_return_any = True
warn_unreachable = True
warn_unused_ignores = True
49 changes: 29 additions & 20 deletions async_lru/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
)


if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec


if sys.version_info >= (3, 11):
from typing import Self
else:
Expand All @@ -35,9 +41,10 @@

_T = TypeVar("_T")
_R = TypeVar("_R")
_P = ParamSpec("_P")
_Coro = Coroutine[Any, Any, _R]
_CB = Callable[..., _Coro[_R]]
_CBP = Union[_CB[_R], "partial[_Coro[_R]]", "partialmethod[_Coro[_R]]"]
_CB = Callable[_P, _Coro[_R]]
_CBP = Union[_CB[_P, _R], "partial[_Coro[_R]]", "partialmethod[_Coro[_R]]"]


@final
Expand All @@ -61,10 +68,10 @@ def cancel(self) -> None:


@final
class _LRUCacheWrapper(Generic[_R]):
class _LRUCacheWrapper(Generic[_P, _R]):
def __init__(
self,
fn: _CB[_R],
fn: _CB[_P, _R],
maxsize: Optional[int],
typed: bool,
ttl: Optional[float],
Expand Down Expand Up @@ -188,7 +195,7 @@ def _task_done_callback(

fut.set_result(task.result())

async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
async def __call__(self, /, *fn_args: _P.args, **fn_kwargs: _P.kwargs) -> _R:
if self.__closed:
raise RuntimeError(f"alru_cache is closed for {self}")

Expand All @@ -207,7 +214,7 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:

fut = loop.create_future()
coro = self.__wrapped__(*fn_args, **fn_kwargs)
task: asyncio.Task[_R] = loop.create_task(coro)
task = loop.create_task(coro)
self.__tasks.add(task)
task.add_done_callback(partial(self._task_done_callback, fut, key))

Expand All @@ -222,18 +229,18 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:

def __get__(
self, instance: _T, owner: Optional[Type[_T]]
) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_R, _T]"]:
) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_P, _R, _T]"]:
if owner is None:
return self
else:
return _LRUCacheWrapperInstanceMethod(self, instance)


@final
class _LRUCacheWrapperInstanceMethod(Generic[_R, _T]):
class _LRUCacheWrapperInstanceMethod(Generic[_P, _R, _T]):
def __init__(
self,
wrapper: _LRUCacheWrapper[_R],
wrapper: _LRUCacheWrapper[_P, _R],
instance: _T,
) -> None:
try:
Expand Down Expand Up @@ -284,16 +291,16 @@ def cache_info(self) -> _CacheInfo:
def cache_parameters(self) -> _CacheParameters:
return self.__wrapper.cache_parameters()

async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs)
async def __call__(self, /, *fn_args: _P.args, **fn_kwargs: _P.kwargs) -> _R:
return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs) # type: ignore[arg-type]


def _make_wrapper(
maxsize: Optional[int],
typed: bool,
ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
) -> Callable[[_CBP[_P, _R]], _LRUCacheWrapper[_P, _R]]:
def wrapper(fn: _CBP[_P, _R]) -> _LRUCacheWrapper[_P, _R]:
origin = fn

while isinstance(origin, (partial, partialmethod)):
Expand All @@ -306,7 +313,7 @@ def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
if hasattr(fn, "_make_unbound_method"):
fn = fn._make_unbound_method()

return _LRUCacheWrapper(cast(_CB[_R], fn), maxsize, typed, ttl)
return _LRUCacheWrapper(cast(_CB[_P, _R], fn), maxsize, typed, ttl)

return wrapper

Expand All @@ -317,28 +324,30 @@ def alru_cache(
typed: bool = False,
*,
ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
) -> Callable[[_CBP[_P, _R]], _LRUCacheWrapper[_P, _R]]:
...


@overload
def alru_cache(
maxsize: _CBP[_R],
maxsize: _CBP[_P, _R],
/,
) -> _LRUCacheWrapper[_R]:
) -> _LRUCacheWrapper[_P, _R]:
...


def alru_cache(
maxsize: Union[Optional[int], _CBP[_R]] = 128,
maxsize: Union[Optional[int], _CBP[_P, _R]] = 128,
typed: bool = False,
*,
ttl: Optional[float] = None,
) -> Union[Callable[[_CBP[_R]], _LRUCacheWrapper[_R]], _LRUCacheWrapper[_R]]:
) -> Union[
Callable[[_CBP[_P, _R]], _LRUCacheWrapper[_P, _R]], _LRUCacheWrapper[_P, _R]
]:
if maxsize is None or isinstance(maxsize, int):
return _make_wrapper(maxsize, typed, ttl)
else:
fn = cast(_CB[_R], maxsize)
fn = maxsize

if callable(fn) or hasattr(fn, "_make_unbound_method"):
return _make_wrapper(128, False, None)(fn)
Expand Down
5 changes: 0 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,3 @@ junit_family=xunit2
asyncio_mode=auto
timeout=15
xfail_strict = true

[mypy]
strict=True
pretty=True
packages=async_lru, tests
19 changes: 15 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
import sys
from functools import _CacheInfo
from typing import Callable
from typing import Callable, TypeVar

import pytest

from async_lru import _R, _LRUCacheWrapper
from async_lru import _LRUCacheWrapper


if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec


_T = TypeVar("_T")
_P = ParamSpec("_P")


@pytest.fixture
def check_lru() -> Callable[..., None]:
def check_lru() -> Callable[..., None]: # type: ignore[misc]
def _check_lru(
wrapped: _LRUCacheWrapper[_R],
wrapped: _LRUCacheWrapper[_P, _T],
*,
hits: int,
misses: int,
Expand Down
Loading