diff --git a/README.md b/README.md index 8b5416c..730ab63 100644 --- a/README.md +++ b/README.md @@ -99,8 +99,9 @@ from resty.clients.httpx import RESTClient async def main(): client = RESTClient(httpx.AsyncClient(base_url="https://localhost:8000")) - response = await UserManager.create( - client=client, + manager = UserManager(client=client) + + response = await manager.create( obj=UserCreateSchema( username="admin", email="admin@admin.com", @@ -111,44 +112,44 @@ async def main(): ) print(response) # id=1 username='admin' email='admin@admin.com' age=19 - response = await UserManager.read( - client=client, + response = await manager.read( response_type=UserReadSchema, ) for obj in response: print(obj) # id=1 username='admin' email='admin@admin.com' age=19 - response = await UserManager.read_one( - client=client, + response = await manager.read_one( obj_or_pk=1, response_type=UserReadSchema, ) print(response) # id=1 username='admin' email='admin@admin.com' age=19 - response = await UserManager.update( - client=client, - obj=UserUpdateSchema(id=1, username="admin123", ), + response = await manager.update( + obj=UserUpdateSchema( + id=1, + username="admin123", + ), response_type=UserReadSchema, ) print(response) # id=1 username='admin123' email='admin@admin.com' age=19 - await UserManager.delete( - client=client, + await manager.delete( obj_or_pk=1, expected_status=204, ) + if __name__ == "__main__": asyncio.run(main()) ``` ## Status -``0.0.5`` - **RELEASED** +``0.0.6`` - **RELEASED** ## Licence diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index e026815..8cbf998 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -23,3 +23,37 @@ - Improved test coverage to 100% - Improved architecture - Added examples + +## v0.0.6 + +- Changed Manager API: + +Now instantiating manager is required. + +You can pass the REST Client into the constructor: + +```python +manager = UserManager(client=client) + +response = await manager.read( + response_type=UserReadSchema, +) +``` + +or specify the client explicitly during calling: + +```python +manager = UserManager() + +response = await manager.read( + response_type=UserReadSchema, + client=client, +) +``` + +- Added Django pagination middlewares: + + - `LimitOffsetPaginationMiddleware` + - `PagePaginationMiddleware` + + diff --git a/examples/crud.py b/examples/crud.py index b257b78..bb3170d 100644 --- a/examples/crud.py +++ b/examples/crud.py @@ -43,8 +43,9 @@ class UserManager(Manager): async def main(): client = RESTClient(httpx.AsyncClient(base_url="https://localhost:8000")) - response = await UserManager.create( - client=client, + manager = UserManager(client=client) + + response = await manager.create( obj=UserCreateSchema( username="admin", email="admin@admin.com", @@ -55,32 +56,31 @@ async def main(): ) print(response) # id=1 username='admin' email='admin@admin.com' age=19 - response = await UserManager.read( - client=client, + response = await manager.read( response_type=UserReadSchema, ) for obj in response: print(obj) # id=1 username='admin' email='admin@admin.com' age=19 - response = await UserManager.read_one( - client=client, + response = await manager.read_one( obj_or_pk=1, response_type=UserReadSchema, ) print(response) # id=1 username='admin' email='admin@admin.com' age=19 - response = await UserManager.update( - client=client, - obj=UserUpdateSchema(id=1, username="admin123", ), + response = await manager.update( + obj=UserUpdateSchema( + id=1, + username="admin123", + ), response_type=UserReadSchema, ) print(response) # id=1 username='admin123' email='admin@admin.com' age=19 - await UserManager.delete( - client=client, + await manager.delete( obj_or_pk=1, expected_status=204, ) diff --git a/examples/middlewares.py b/examples/middlewares/common.py similarity index 100% rename from examples/middlewares.py rename to examples/middlewares/common.py diff --git a/examples/middlewares/django.py b/examples/middlewares/django.py new file mode 100644 index 0000000..873032e --- /dev/null +++ b/examples/middlewares/django.py @@ -0,0 +1,70 @@ +import asyncio + +import httpx + +from resty.enums import Endpoint, Field +from resty.types import Schema +from resty.managers import Manager +from resty.clients.httpx import RESTClient +from resty.ext.django.middlewares.pagination import ( + LimitOffsetPaginationMiddleware, + PagePaginationMiddleware, +) + + +class UserCreateSchema(Schema): + username: str + email: str + password: str + age: int + + +class UserReadSchema(Schema): + id: int + username: str + email: str + age: int + + +class UserUpdateSchema(Schema): + username: str = None + email: str = None + + +class UserManager(Manager): + endpoints = { + Endpoint.CREATE: "users/", + Endpoint.READ: "users/", + Endpoint.READ_ONE: "users/{pk}", + Endpoint.UPDATE: "users/{pk}", + Endpoint.DELETE: "users/{pk}", + } + fields = { + Field.PRIMARY: "id", + } + + +async def main(): + client = RESTClient(httpx.AsyncClient(base_url="https://localhost:8000")) + + # Using LimitOffset pagination middleware + with client.middlewares.middleware(LimitOffsetPaginationMiddleware(limit=200)): + manager = UserManager(client=client) + + paginated_response = await manager.read( + response_type=UserReadSchema, + offset=100, + ) + + # Using Page pagination middleware + with client.middlewares.middleware(PagePaginationMiddleware()): + manager = UserManager(client=client) + + paginated_response = await manager.read( + response_type=UserReadSchema, + page=3, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 4a4f505..188c5a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "resty-client" -version = "0.0.5" +version = "0.0.6" description = "RestyClient is a simple, easy-to-use Python library for interacting with REST APIs using Pydantic's powerful data validation and deserialization tools." authors = ["CrazyProger1 "] license = "MIT" diff --git a/resty/__version__.py b/resty/__version__.py index 90300ba..5937e1a 100644 --- a/resty/__version__.py +++ b/resty/__version__.py @@ -1,4 +1,4 @@ __title__ = "Resty-Client" -__version__ = "0.0.5" +__version__ = "0.0.6" __description__ = """RestyClient is a simple, easy-to-use Python library for interacting with REST APIs using Pydantic's powerful data validation and deserialization tools.""" diff --git a/resty/clients/httpx/clients.py b/resty/clients/httpx/clients.py index 148891c..e13a094 100644 --- a/resty/clients/httpx/clients.py +++ b/resty/clients/httpx/clients.py @@ -1,4 +1,5 @@ import json +from urllib.parse import urljoin import httpx @@ -20,10 +21,10 @@ class RESTClient(BaseRESTClient): def __init__( - self, - httpx_client: httpx.AsyncClient = None, - check_status: bool = True, - middleware_manager: BaseMiddlewareManager = None, + self, + httpx_client: httpx.AsyncClient = None, + check_status: bool = True, + middleware_manager: BaseMiddlewareManager = None, ): self.middlewares = middleware_manager or MiddlewareManager() self._xclient = httpx_client or httpx.AsyncClient() @@ -45,7 +46,7 @@ async def _make_xrequest(self, request: Request) -> httpx.Response: timeout=request.timeout, ) except httpx.ConnectError: - raise ConnectError(url=request.url) + raise ConnectError(url=urljoin(str(self._xclient.base_url), request.url)) @staticmethod def _extract_json_data(xresponse: httpx.Response) -> dict | list: @@ -57,7 +58,7 @@ def _extract_json_data(xresponse: httpx.Response) -> dict | list: return data async def _parse_xresponse( - self, request: Request, xresponse: httpx.Response + self, request: Request, xresponse: httpx.Response ) -> Response: return Response( request=request, diff --git a/resty/ext/__init__.py b/resty/ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/resty/ext/django/__init__.py b/resty/ext/django/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/resty/ext/django/middlewares/__init__.py b/resty/ext/django/middlewares/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/resty/ext/django/middlewares/pagination/__init__.py b/resty/ext/django/middlewares/pagination/__init__.py new file mode 100644 index 0000000..9f7a200 --- /dev/null +++ b/resty/ext/django/middlewares/pagination/__init__.py @@ -0,0 +1,6 @@ +from .middlewares import LimitOffsetPaginationMiddleware, PagePaginationMiddleware + +__all__ = [ + "LimitOffsetPaginationMiddleware", + "PagePaginationMiddleware", +] diff --git a/resty/ext/django/middlewares/pagination/constants.py b/resty/ext/django/middlewares/pagination/constants.py new file mode 100644 index 0000000..b7c724a --- /dev/null +++ b/resty/ext/django/middlewares/pagination/constants.py @@ -0,0 +1 @@ +DEFAULT_LIMIT = 100 diff --git a/resty/ext/django/middlewares/pagination/middlewares.py b/resty/ext/django/middlewares/pagination/middlewares.py new file mode 100644 index 0000000..1dfc0ac --- /dev/null +++ b/resty/ext/django/middlewares/pagination/middlewares.py @@ -0,0 +1,58 @@ +from abc import ABC, abstractmethod +from typing import Container + +from resty.enums import Endpoint +from resty.types import Request, Response +from resty.middlewares import BaseRequestMiddleware, BaseResponseMiddleware + +from .constants import DEFAULT_LIMIT + + +class PaginationMiddleware(BaseRequestMiddleware, BaseResponseMiddleware, ABC): + def __init__(self, endpoints: Container[Endpoint] = None): + self._endpoints = endpoints or { + Endpoint.READ, + } + + @abstractmethod + async def paginate(self, request: Request, **kwargs): # pragma: nocover + ... + + async def unpaginate(self, response: Response, **kwargs): + response.json = response.json.get("results", response.json) + + async def _handle_request(self, request: Request, **kwargs): + if request.endpoint in self._endpoints: + await self.paginate(request=request, **kwargs) + + async def _handle_response(self, response: Response, **kwargs): + if response.request.endpoint in self._endpoints: + await self.unpaginate(response=response, **kwargs) + + async def __call__(self, reqresp: Request | Response, **kwargs): + if isinstance(reqresp, Request): + return await self._handle_request(request=reqresp, **kwargs) + return await self._handle_response(response=reqresp, **kwargs) + + +class LimitOffsetPaginationMiddleware(PaginationMiddleware): + def __init__(self, limit: int = DEFAULT_LIMIT, **kwargs): + self._limit = limit + super().__init__(**kwargs) + + async def paginate(self, request: Request, **kwargs): + request.params.update( + { + "limit": kwargs.pop("limit", self._limit), + "offset": kwargs.pop("offset", 0), + } + ) + + +class PagePaginationMiddleware(PaginationMiddleware): + async def paginate(self, request: Request, **kwargs): + request.params.update( + { + "page": kwargs.pop("page", 1), + } + ) diff --git a/resty/managers/builders.py b/resty/managers/builders.py index 7f3ff57..2b7abf6 100644 --- a/resty/managers/builders.py +++ b/resty/managers/builders.py @@ -25,22 +25,19 @@ def _normalize_url(cls, url: str | None) -> str: if not url: return "" - if not url.endswith('/'): - return url + '/' + if not url.endswith("/"): + return url + "/" return url @classmethod def build( - cls, endpoints: Endpoints, endpoint: Endpoint, base_url: str = None, **kwargs + cls, endpoints: Endpoints, endpoint: Endpoint, base_url: str = None, **kwargs ) -> str: endpoint_url = cls._get_endpoint_url(endpoints=endpoints, endpoint=endpoint) if endpoint_url: - url = urljoin( - cls._normalize_url(url=base_url), - endpoint_url - ) + url = urljoin(cls._normalize_url(url=base_url), endpoint_url) else: url = base_url or "" diff --git a/resty/managers/managers.py b/resty/managers/managers.py index 70029e0..6288b0f 100644 --- a/resty/managers/managers.py +++ b/resty/managers/managers.py @@ -1,5 +1,5 @@ import inspect -from typing import Mapping, Iterable, Callable +from typing import Mapping, Iterable from resty.clients import BaseRESTClient from resty.enums import Endpoint, Method, Field @@ -13,113 +13,120 @@ class Manager(BaseManager): serializer_class = Serializer url_builder_class = URLBuilder - @classmethod - def get_serializer(cls, **kwargs) -> type[BaseSerializer]: - serializer = kwargs.pop('serializer', cls.serializer_class) + def __init__(self, client: BaseRESTClient = None): + self._client = client + + def get_serializer(self, **kwargs) -> type[BaseSerializer]: + serializer = kwargs.pop("serializer", self.serializer_class) if not serializer: - raise RuntimeError('Serializer not specified') + raise RuntimeError("Serializer not specified") if not ( - isinstance(serializer, BaseSerializer) - or inspect.isclass(serializer) - and issubclass(serializer, BaseSerializer) + isinstance(serializer, BaseSerializer) + or inspect.isclass(serializer) + and issubclass(serializer, BaseSerializer) ): - raise RuntimeError('The serializer must be a subclass of BaseSerializer') + raise RuntimeError("The serializer must be a subclass of BaseSerializer") return serializer - @classmethod - def get_method(cls, endpoint: Endpoint, **kwargs) -> Method: - method = cls.methods.get(endpoint) + def get_method(self, endpoint: Endpoint, **kwargs) -> Method: + method = self.methods.get(endpoint) if not method: - raise RuntimeError(f'Method not specified for endpoint: {endpoint}') + raise RuntimeError(f"Method not specified for endpoint: {endpoint}") return method - @classmethod - def get_field(cls, field: Field) -> str: - field = cls.fields.get(field) + def get_field(self, field: Field) -> str: + field = self.fields.get(field) if not field: - raise RuntimeError(f'Field not specified: {field}') + raise RuntimeError(f"Field not specified: {field}") return field - @classmethod - def get_pk(cls, obj: Schema | Mapping) -> any: - field = cls.get_field(Field.PRIMARY) + def get_pk(self, obj: Schema | Mapping) -> any: + field = self.get_field(Field.PRIMARY) if isinstance(obj, Mapping): return obj.get(field) return getattr(obj, field, None) - @classmethod - def _get_pk(cls, obj_or_pk: any) -> any: + def _get_pk(self, obj_or_pk: any) -> any: if isinstance(obj_or_pk, Mapping | Schema): - return cls.get_pk(obj=obj_or_pk) + return self.get_pk(obj=obj_or_pk) return obj_or_pk - @classmethod - def _deserialize(cls, schema: type[Schema], data: any, **kwargs): - serializer = cls.get_serializer(**kwargs) + def _deserialize(self, schema: type[Schema], data: any, **kwargs): + serializer = self.get_serializer(**kwargs) if isinstance(data, Mapping): return serializer.deserialize(schema=schema, data=data, **kwargs) return serializer.deserialize_many(schema=schema, data=data, **kwargs) - @classmethod - def _serialize(cls, obj: Schema, **kwargs): - serializer = cls.get_serializer(**kwargs) + def _serialize(self, obj: Schema, **kwargs): + serializer = self.get_serializer(**kwargs) return serializer.serialize(obj=obj, **kwargs) - @classmethod - async def _make_request(cls, client: BaseRESTClient, request: Request) -> Response: + def _get_client(self, **kwargs) -> BaseRESTClient: + client = kwargs.pop("client", self._client) + + if not client: + raise TypeError( + "REST Client not specified. Pass it to the constructor or via kwargs" + ) + + if not isinstance(client, BaseRESTClient): + raise TypeError("Client must inherit from BaseRESTClient") + + return client + + async def _make_request(self, request: Request, **kwargs) -> Response: + client = self._get_client(**kwargs) return await client.request(request=request) - @classmethod - def _prepare_url(cls, endpoint: Endpoint, **kwargs) -> str: - url = kwargs.pop('url', None) - base_url = kwargs.pop('base_url', None) + def _prepare_url(self, endpoint: Endpoint, **kwargs) -> str: + url = kwargs.pop("url", None) + base_url = kwargs.pop("base_url", None) if isinstance(url, str): return url - return cls.url_builder_class.build( - endpoints=cls.endpoints, + return self.url_builder_class.build( + endpoints=self.endpoints, endpoint=endpoint, - base_url=base_url or cls.url, - **kwargs + base_url=base_url or self.url, + **kwargs, ) - @classmethod - def _prepare_json(cls, **kwargs): - obj = kwargs.pop('obj', None) + def _prepare_json(self, **kwargs): + obj = kwargs.pop("obj", None) if isinstance(obj, dict | list | set | tuple): return obj elif isinstance(obj, Schema): - return cls._serialize(obj, **kwargs) + return self._serialize(obj, **kwargs) return {} - @classmethod - def _prepare_request(cls, endpoint: Endpoint, **kwargs) -> Request: + def _prepare_request(self, endpoint: Endpoint, **kwargs) -> Request: return Request( - url=cls._prepare_url(endpoint=endpoint, **kwargs), - method=cls.get_method(endpoint, **kwargs), + url=self._prepare_url(endpoint=endpoint, **kwargs), + method=self.get_method(endpoint, **kwargs), endpoint=endpoint, - data=kwargs.pop('data', {}), - json=cls._prepare_json(**kwargs), - timeout=kwargs.pop('timeout', None), - params=kwargs.pop('params', {}), - headers=kwargs.pop('headers', {}), - cookies=kwargs.pop('cookies', {}), - redirects=kwargs.pop('redirects', False), + data=kwargs.pop("data", {}), + json=self._prepare_json(**kwargs), + timeout=kwargs.pop("timeout", None), + params=kwargs.pop("params", {}), + headers=kwargs.pop("headers", {}), + cookies=kwargs.pop("cookies", {}), + redirects=kwargs.pop("redirects", False), middleware_options=kwargs.copy(), ) - @classmethod - def _handle_response(cls, response: Response, response_type: ResponseType, **kwargs) -> any: + def _handle_response( + self, response: Response, response_type: ResponseType, **kwargs + ) -> any: if not response: return @@ -127,7 +134,9 @@ def _handle_response(cls, response: Response, response_type: ResponseType, **kwa if issubclass(response_type, dict | list | tuple | set): return response_type(response.json) elif issubclass(response_type, Schema): - return cls._deserialize(schema=response_type, data=response.json, **kwargs) + return self._deserialize( + schema=response_type, data=response.json, **kwargs + ) if callable(response_type): try: @@ -137,51 +146,81 @@ def _handle_response(cls, response: Response, response_type: ResponseType, **kwa return response.json - @classmethod - async def create[T: Schema](cls, client: BaseRESTClient, obj: Schema | Mapping, response_type: ResponseType = None, - **kwargs) -> T | None: - request = cls._prepare_request(endpoint=Endpoint.CREATE, obj=obj, **kwargs) - response = await cls._make_request(client=client, request=request) - return cls._handle_response(response=response, response_type=response_type, **kwargs) - - @classmethod - async def read[T: Schema](cls, client: BaseRESTClient, response_type: ResponseType = None, **kwargs) -> Iterable[T]: - request = cls._prepare_request(endpoint=Endpoint.READ, **kwargs) - response = await cls._make_request(client=client, request=request) - return cls._handle_response(response=response, response_type=response_type, **kwargs) - - @classmethod - async def read_one[T: Schema](cls, client: BaseRESTClient, obj_or_pk: Schema | Mapping | any, - response_type: ResponseType = None, **kwargs) -> T: - - request = cls._prepare_request( - endpoint=Endpoint.READ_ONE, - pk=cls._get_pk(obj_or_pk=obj_or_pk), - **kwargs + async def create[ + T: Schema + ]( + self, + obj: Schema | Mapping, + response_type: ResponseType = None, + **kwargs, + ) -> ( + T | None + ): + request = self._prepare_request(endpoint=Endpoint.CREATE, obj=obj, **kwargs) + response = await self._make_request(request=request, **kwargs) + return self._handle_response( + response=response, response_type=response_type, **kwargs + ) + + async def read[ + T: Schema + ](self, response_type: ResponseType = None, **kwargs) -> Iterable[T]: + request = self._prepare_request(endpoint=Endpoint.READ, **kwargs) + response = await self._make_request(request=request, **kwargs) + return self._handle_response( + response=response, response_type=response_type, **kwargs ) - response = await cls._make_request(client=client, request=request) - return cls._handle_response(response=response, response_type=response_type, **kwargs) - @classmethod - async def update[T: Schema](cls, client: BaseRESTClient, obj: Schema | Mapping, response_type: ResponseType = None, - **kwargs) -> T | None: - request = cls._prepare_request( + async def read_one[ + T: Schema + ]( + self, + obj_or_pk: Schema | Mapping | any, + response_type: ResponseType = None, + **kwargs, + ) -> T: + + request = self._prepare_request( + endpoint=Endpoint.READ_ONE, pk=self._get_pk(obj_or_pk=obj_or_pk), **kwargs + ) + response = await self._make_request(request=request, **kwargs) + return self._handle_response( + response=response, response_type=response_type, **kwargs + ) + + async def update[ + T: Schema + ]( + self, + obj: Schema | Mapping, + response_type: ResponseType = None, + **kwargs, + ) -> ( + T | None + ): + request = self._prepare_request( endpoint=Endpoint.UPDATE, - pk=kwargs.pop('pk', cls.get_pk(obj)), + pk=kwargs.pop("pk", self.get_pk(obj)), obj=obj, - **kwargs + **kwargs, + ) + response = await self._make_request(request=request, **kwargs) + return self._handle_response( + response=response, response_type=response_type, **kwargs + ) + + async def delete[ + T: Schema + ]( + self, + obj_or_pk: Schema | Mapping | any, + response_type: ResponseType = None, + **kwargs, + ) -> (T | None): + request = self._prepare_request( + endpoint=Endpoint.DELETE, pk=self._get_pk(obj_or_pk=obj_or_pk), **kwargs ) - response = await cls._make_request(client=client, request=request) - return cls._handle_response(response=response, response_type=response_type, **kwargs) - - @classmethod - async def delete[T: Schema](cls, client: BaseRESTClient, obj_or_pk: Schema | Mapping | any, - response_type: ResponseType = None, - **kwargs) -> T | None: - request = cls._prepare_request( - endpoint=Endpoint.DELETE, - pk=cls._get_pk(obj_or_pk=obj_or_pk), - **kwargs + response = await self._make_request(request=request, **kwargs) + return self._handle_response( + response=response, response_type=response_type, **kwargs ) - response = await cls._make_request(client=client, request=request) - return cls._handle_response(response=response, response_type=response_type, **kwargs) diff --git a/resty/managers/types.py b/resty/managers/types.py index 19acd51..93ca821 100644 --- a/resty/managers/types.py +++ b/resty/managers/types.py @@ -38,66 +38,52 @@ class BaseManager(ABC): serializer_class: BaseSerializer url_builder_class: BaseURLBuilder - @classmethod @abstractmethod - def get_serializer(cls, **kwargs) -> type[BaseSerializer]: ... + def get_serializer(self, **kwargs) -> type[BaseSerializer]: ... - @classmethod @abstractmethod - def get_method(cls, endpoint: Endpoint, **kwargs) -> Method: ... + def get_method(self, endpoint: Endpoint, **kwargs) -> Method: ... - @classmethod @abstractmethod - def get_field(cls, field: Field) -> str: ... + def get_field(self, field: Field) -> str: ... - @classmethod @abstractmethod - def get_pk(cls, obj: Schema | Mapping) -> any: ... + def get_pk(self, obj: Schema | Mapping) -> any: ... - @classmethod @abstractmethod async def create[T: Schema]( - cls, - client: BaseRESTClient, + self, obj: Schema | Mapping, response_type: ResponseType = None, **kwargs, ) -> T | None: ... - @classmethod @abstractmethod async def read[T: Schema]( - cls, - client: BaseRESTClient, + self, response_type: ResponseType = None, **kwargs, ) -> Iterable[T]: ... - @classmethod @abstractmethod async def read_one[T: Schema]( - cls, - client: BaseRESTClient, + self, obj_or_pk: Schema | Mapping | any, response_type: ResponseType = None, **kwargs, ) -> T: ... - @classmethod @abstractmethod async def update[T: Schema]( - cls, - client: BaseRESTClient, + self, obj: Schema | Mapping, response_type: ResponseType = None, **kwargs, ) -> T | None: ... - @classmethod @abstractmethod async def delete[T: Schema]( - cls, - client: BaseRESTClient, + self, obj_or_pk: Schema | Mapping | any, response_type: ResponseType = None, **kwargs, diff --git a/resty/middlewares/status.py b/resty/middlewares/status.py index d4980ae..18078f1 100644 --- a/resty/middlewares/status.py +++ b/resty/middlewares/status.py @@ -8,9 +8,9 @@ class StatusCheckingMiddleware(BaseResponseMiddleware): def __init__( - self, - errors: Mapping[int, type[Exception]] = None, - default_error: type[Exception] = HTTPError, + self, + errors: Mapping[int, type[Exception]] = None, + default_error: type[Exception] = HTTPError, ): self._errors = errors or STATUS_ERRORS self._default_error = default_error diff --git a/tests/ext/django/test_django_pagination_middlewares.py b/tests/ext/django/test_django_pagination_middlewares.py new file mode 100644 index 0000000..297cac5 --- /dev/null +++ b/tests/ext/django/test_django_pagination_middlewares.py @@ -0,0 +1,99 @@ +import pytest + +from resty.enums import Method, Endpoint +from resty.types import Request, Response +from resty.ext.django.middlewares.pagination import ( + LimitOffsetPaginationMiddleware, + PagePaginationMiddleware, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "req, kwargs, expected", + [ + ( + Request( + url="https://example.com/", method=Method.GET, endpoint=Endpoint.READ + ), + {"limit": 100, "offset": 10}, + {"limit": 100, "offset": 10}, + ) + ], +) +async def test_limit_offset_pagination(req, kwargs, expected): + middleware = LimitOffsetPaginationMiddleware() + + await middleware(reqresp=req, **kwargs) + + assert req.params == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "req, kwargs, expected", + [ + ( + Request( + url="https://example.com/", method=Method.GET, endpoint=Endpoint.READ + ), + {"page": 100}, + {"page": 100}, + ) + ], +) +async def test_page_pagination(req, kwargs, expected): + middleware = PagePaginationMiddleware() + + await middleware(reqresp=req, **kwargs) + + assert req.params == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "middleware, resp, expected", + [ + ( + LimitOffsetPaginationMiddleware(), + Response( + request=Request( + url="https://example.com/", + method=Method.GET, + endpoint=Endpoint.READ, + ), + status=200, + json={ + "results": [ + {"id": 1, "username": "josh"}, + ] + }, + ), + [ + {"id": 1, "username": "josh"}, + ], + ), + ( + PagePaginationMiddleware(), + Response( + request=Request( + url="https://example.com/", + method=Method.GET, + endpoint=Endpoint.READ, + ), + status=200, + json={ + "results": [ + {"id": 1, "username": "josh"}, + ] + }, + ), + [ + {"id": 1, "username": "josh"}, + ], + ), + ], +) +async def test_unpaginate(middleware, resp, expected): + await middleware(reqresp=resp) + assert resp.json == expected diff --git a/tests/managers/conftest.py b/tests/managers/conftest.py index e8c9477..168168a 100644 --- a/tests/managers/conftest.py +++ b/tests/managers/conftest.py @@ -17,7 +17,7 @@ async def request(self, request: Request) -> Response: @pytest.fixture -def client(request): # pragma: nocover +def client(request): # pragma: nocover response, expected = request.param return RESTClientMock( diff --git a/tests/managers/test_manager.py b/tests/managers/test_manager.py index 5a83fdb..6483464 100644 --- a/tests/managers/test_manager.py +++ b/tests/managers/test_manager.py @@ -28,10 +28,9 @@ class UserManager(Manager): Endpoint.READ_ONE: "users/{pk}", Endpoint.UPDATE: "users/{pk}", Endpoint.DELETE: "users/{pk}", - } fields = { - Field.PRIMARY: 'id', + Field.PRIMARY: "id", } @@ -40,7 +39,7 @@ class UserManagerForURLBuilding(Manager): Endpoint.READ_ONE: "users/{pk}/{abc}", } fields = { - Field.PRIMARY: 'id', + Field.PRIMARY: "id", } @@ -62,16 +61,28 @@ class ManagerWithUnspecifiedFields(Manager): class ManagerWithPkField(Manager): fields = { - Field.PRIMARY: 'id', + Field.PRIMARY: "id", } @pytest.mark.asyncio -@pytest.mark.parametrize("client, obj", [ - (RESTClientMock(json={"username": "test"}, method=Method.POST), UserCreate(username="test")), - (RESTClientMock(json={"username": "321"}, method=Method.POST), UserCreate(username="321")), - (RESTClientMock(json={"username": "test"}, method=Method.POST), {"username": "test"}), -]) +@pytest.mark.parametrize( + "client, obj", + [ + ( + RESTClientMock(json={"username": "test"}, method=Method.POST), + UserCreate(username="test"), + ), + ( + RESTClientMock(json={"username": "321"}, method=Method.POST), + UserCreate(username="321"), + ), + ( + RESTClientMock(json={"username": "test"}, method=Method.POST), + {"username": "test"}, + ), + ], +) async def test_create(client, obj): manager = UserManager() @@ -79,14 +90,24 @@ async def test_create(client, obj): @pytest.mark.asyncio -@pytest.mark.parametrize("data", [ - ({"username": "test", "id": 1}, {"username": "test123", "id": 2}, {"username": "test321", "id": 3}), -]) +@pytest.mark.parametrize( + "data", + [ + ( + {"username": "test", "id": 1}, + {"username": "test123", "id": 2}, + {"username": "test321", "id": 3}, + ), + ], +) async def test_read(data): client = RESTClientMock( response=Response( - request=Request(url="", method=Method.GET), status=200, json=data, ), - method=Method.GET + request=Request(url="", method=Method.GET), + status=200, + json=data, + ), + method=Method.GET, ) manager = UserManager() @@ -96,37 +117,49 @@ async def test_read(data): @pytest.mark.asyncio -@pytest.mark.parametrize("client, obj", [ - ( +@pytest.mark.parametrize( + "client, obj", + [ + ( RESTClientMock( response=Response( Request("", Method.GET), status=200, - json={"username": "test", "id": 123} - ), method=Method.GET), - UserRead(username="test", id=123) - ), - -]) + json={"username": "test", "id": 123}, + ), + method=Method.GET, + ), + UserRead(username="test", id=123), + ), + ], +) async def test_read_one(client, obj): manager = UserManager() - assert await manager.read_one(client=client, obj_or_pk=123, response_type=UserRead) == obj + assert ( + await manager.read_one(client=client, obj_or_pk=123, response_type=UserRead) + == obj + ) @pytest.mark.asyncio -@pytest.mark.parametrize("client, obj", [ - ( +@pytest.mark.parametrize( + "client, obj", + [ + ( RESTClientMock( response=Response( Request("", Method.GET), status=200, - json={"username": "test", "id": 123} - ), method=Method.PATCH, url="users/123"), - UserUpdate(username="test", id=123) - ), - -]) + json={"username": "test", "id": 123}, + ), + method=Method.PATCH, + url="users/123", + ), + UserUpdate(username="test", id=123), + ), + ], +) async def test_update(client, obj): manager = UserManager() @@ -134,38 +167,43 @@ async def test_update(client, obj): @pytest.mark.asyncio -@pytest.mark.parametrize("client, pk", [ - ( - RESTClientMock(url="users/123", method=Method.DELETE), - 123 - ), - ( +@pytest.mark.parametrize( + "client, pk", + [ + (RESTClientMock(url="users/123", method=Method.DELETE), 123), + ( RESTClientMock(url="users/321", method=Method.DELETE), - UserRead(username="test", id=321) - ) - -]) + UserRead(username="test", id=321), + ), + ], +) async def test_delete(client, pk): manager = UserManager() - await manager.delete(client, obj_or_pk=pk) + await manager.delete(obj_or_pk=pk, client=client) @pytest.mark.asyncio -@pytest.mark.parametrize("client, pk, abc", [ - (RESTClientMock(url="users/123/hello"), 123, "hello"), - (RESTClientMock(url="users/321/world"), 321, "world"), -]) +@pytest.mark.parametrize( + "client, pk, abc", + [ + (RESTClientMock(url="users/123/hello"), 123, "hello"), + (RESTClientMock(url="users/321/world"), 321, "world"), + ], +) async def test_url_building(client, pk, abc): manager = UserManagerForURLBuilding() await manager.read_one(client=client, obj_or_pk=pk, abc=abc) -@pytest.mark.parametrize("manager", [ - ManagerWithoutSerializer, - ManagerWithInvalidSerializer, -]) +@pytest.mark.parametrize( + "manager", + [ + ManagerWithoutSerializer, + ManagerWithInvalidSerializer, + ], +) def test_invalid_or_unspec_serializer(manager): manager = manager() @@ -190,13 +228,16 @@ def test_get_unspec_field(): def test_get_field(): manager = ManagerWithPkField() - assert manager.get_field(Field.PRIMARY) == 'id' + assert manager.get_field(Field.PRIMARY) == "id" -@pytest.mark.parametrize("obj, pk", [ - ({'id': "test"}, "test"), - (UserRead(id=321, username='123'), 321), -]) +@pytest.mark.parametrize( + "obj, pk", + [ + ({"id": "test"}, "test"), + (UserRead(id=321, username="123"), 321), + ], +) def test_get_pk(obj, pk): manager = ManagerWithPkField() @@ -204,35 +245,69 @@ def test_get_pk(obj, pk): @pytest.mark.asyncio -@pytest.mark.parametrize("url", [ - "test", -]) +@pytest.mark.parametrize( + "url", + [ + "test", + ], +) async def test_passing_url(url): client = RESTClientMock(url=url) - manager = UserManagerForURLBuilding() + manager = UserManagerForURLBuilding(client=client) - await manager.delete(client, obj_or_pk=1, url=url) + await manager.delete(obj_or_pk=1, url=url) @pytest.mark.asyncio -@pytest.mark.parametrize("data, response_type, result", [ - ({"username": "test"}, dict, {"username": "test"}), - (("username", "test"), list, ["username", "test"]), - ({"username": "test", "id": 123}, lambda r: dict.keys(r.json), dict.keys({"username": "test", "id": 123})), - ({"username": "test", "id": 123}, lambda r, t: dict.keys(r), {"username": "test", "id": 123}) -]) -async def test_response_type(data, response_type, result): - client = RESTClientMock(response=Response( - request=Request( - url="", - method=Method.GET, +@pytest.mark.parametrize( + "data, response_type, result", + [ + ({"username": "test"}, dict, {"username": "test"}), + (("username", "test"), list, ["username", "test"]), + ( + {"username": "test", "id": 123}, + lambda r: dict.keys(r.json), + dict.keys({"username": "test", "id": 123}), ), - status=200, - json=data - )) + ( + {"username": "test", "id": 123}, + lambda r, t: dict.keys(r), + {"username": "test", "id": 123}, + ), + ], +) +async def test_response_type(data, response_type, result): + client = RESTClientMock( + response=Response( + request=Request( + url="", + method=Method.GET, + ), + status=200, + json=data, + ) + ) manager = UserManager() - response = await manager.read(client=client, response_type=response_type, ) + response = await manager.read( + client=client, + response_type=response_type, + ) assert response == result + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", + [ + None, + "test", + ], +) +async def test_passing_invalid_client(client): + manager = UserManager() + + with pytest.raises(TypeError): + await manager.read(client=client) diff --git a/tests/managers/test_url_builder.py b/tests/managers/test_url_builder.py index 20ad907..346ca78 100644 --- a/tests/managers/test_url_builder.py +++ b/tests/managers/test_url_builder.py @@ -5,19 +5,65 @@ from resty.enums import Endpoint -@pytest.mark.parametrize('endpoints, endpoint, base, kwargs, expected', [ - ({Endpoint.CREATE: "users/", }, Endpoint.CREATE, "base/", {}, "base/users/"), - ({Endpoint.BASE: "users/", }, Endpoint.CREATE, "base/", {}, "base/users/"), - ({Endpoint.BASE: "users/", }, Endpoint.CREATE, "base", {}, "base/users/"), - ({Endpoint.BASE: "users/", }, Endpoint.CREATE, None, {}, "users/"), - ({Endpoint.BASE: "users/{pk}", }, Endpoint.CREATE, None, {"pk": 123}, "users/123"), - ({}, Endpoint.CREATE, "base/{pk}", {"pk": 123}, "base/123"), - ({}, Endpoint.CREATE, None, {"pk": 123}, ""), -]) +@pytest.mark.parametrize( + "endpoints, endpoint, base, kwargs, expected", + [ + ( + { + Endpoint.CREATE: "users/", + }, + Endpoint.CREATE, + "base/", + {}, + "base/users/", + ), + ( + { + Endpoint.BASE: "users/", + }, + Endpoint.CREATE, + "base/", + {}, + "base/users/", + ), + ( + { + Endpoint.BASE: "users/", + }, + Endpoint.CREATE, + "base", + {}, + "base/users/", + ), + ( + { + Endpoint.BASE: "users/", + }, + Endpoint.CREATE, + None, + {}, + "users/", + ), + ( + { + Endpoint.BASE: "users/{pk}", + }, + Endpoint.CREATE, + None, + {"pk": 123}, + "users/123", + ), + ({}, Endpoint.CREATE, "base/{pk}", {"pk": 123}, "base/123"), + ({}, Endpoint.CREATE, None, {"pk": 123}, ""), + ], +) def test_build(endpoints, endpoint, base, kwargs, expected): builder = URLBuilder() - assert builder.build(endpoints=endpoints, endpoint=endpoint, base_url=base, **kwargs) == expected + assert ( + builder.build(endpoints=endpoints, endpoint=endpoint, base_url=base, **kwargs) + == expected + ) def test_build_missing_kwargs():