diff --git a/.gitignore b/.gitignore
index 43e1bff..9fc37b0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -167,4 +167,5 @@ cython_debug/
# Test File
-main.py
+resty_proto
+main.py
\ No newline at end of file
diff --git a/Makefile b/Makefile
index 344506a..6396c4f 100644
--- a/Makefile
+++ b/Makefile
@@ -20,4 +20,5 @@ coverage-report: coverage;
.PHONY: format
format:
poetry run python -m black tests
- poetry run python -m black resty
\ No newline at end of file
+ poetry run python -m black --exclude "types.py" resty
+ poetry run python -m black examples
\ No newline at end of file
diff --git a/README.md b/README.md
index 87d127b..8b5416c 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@
+
@@ -38,91 +39,116 @@ poetry add resty-client
## Getting-Started
-### Schema
+See [examples](examples) for more.
+
+### Schemas
```python
-from pydantic import BaseModel
+from resty.types import Schema
-class Product(BaseModel):
- id: int | None = None
- name: str
- description: str
- code: str
-```
+class UserCreateSchema(Schema):
+ username: str
+ email: str
+ password: str
+ age: int
-### Serializer
-```python
-from resty.serializers import Serializer
+class UserReadSchema(Schema):
+ id: int
+ username: str
+ email: str
+ age: int
-class ProductSerializer(Serializer):
- schema = Product
+class UserUpdateSchema(Schema):
+ username: str = None
+ email: str = None
```
### Manager
```python
-from resty.enums import (
- Endpoint,
- Field
-)
from resty.managers import Manager
+from resty.enums import Endpoint, Field
-class ProductManager(Manager):
- serializer = ProductSerializer
+class UserManager(Manager):
endpoints = {
- Endpoint.CREATE: '/products/',
- Endpoint.READ: '/products/',
- Endpoint.READ_ONE: '/products/{pk}/',
- Endpoint.UPDATE: '/products/{pk}/',
- Endpoint.DELETE: '/products/{pk}/',
+ Endpoint.CREATE: "users/",
+ Endpoint.READ: "users/",
+ Endpoint.READ_ONE: "users/{pk}",
+ Endpoint.UPDATE: "users/{pk}",
+ Endpoint.DELETE: "users/{pk}",
}
fields = {
- Field.PRIMARY: 'id',
+ Field.PRIMARY: "id",
}
```
### CRUD
```python
-from httpx import AsyncClient
+import asyncio
+
+import httpx
from resty.clients.httpx import RESTClient
async def main():
- xclient = AsyncClient(base_url='http://localhost:8000/')
- rest_client = RESTClient(xclient=xclient)
+ client = RESTClient(httpx.AsyncClient(base_url="https://localhost:8000"))
+
+ response = await UserManager.create(
+ client=client,
+ obj=UserCreateSchema(
+ username="admin",
+ email="admin@admin.com",
+ password="admin",
+ age=19,
+ ),
+ response_type=UserReadSchema,
+ )
+ print(response) # id=1 username='admin' email='admin@admin.com' age=19
- product = Product(
- name='First prod',
- description='My Desc',
- code='123W31Q'
+ response = await UserManager.read(
+ client=client,
+ response_type=UserReadSchema,
)
- # Create
- created = await ProductManager.create(rest_client, product)
+ for obj in response:
+ print(obj) # id=1 username='admin' email='admin@admin.com' age=19
- # Read
- my_product = await ProductManager.read_one(rest_client, created.id)
+ response = await UserManager.read_one(
+ client=client,
+ obj_or_pk=1,
+ response_type=UserReadSchema,
+ )
- for prod in await ProductManager.read(rest_client):
- print(prod.name)
+ print(response) # id=1 username='admin' email='admin@admin.com' age=19
+
+ response = await UserManager.update(
+ client=client,
+ obj=UserUpdateSchema(id=1, username="admin123", ),
+ response_type=UserReadSchema,
+ )
+
+ print(response) # id=1 username='admin123' email='admin@admin.com' age=19
+
+ await UserManager.delete(
+ client=client,
+ obj_or_pk=1,
+ expected_status=204,
+ )
- # Update
- my_product.description = 'QWERTY'
- await ProductManager.update(rest_client, my_product)
- # Delete
- await ProductManager.delete(rest_client, my_product.id)
+if __name__ == "__main__":
+ asyncio.run(main())
```
## Status
-``0.0.4`` - **RELEASED**
+``0.0.5`` - **RELEASED**
## Licence
diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md
index e0ef194..e026815 100644
--- a/docs/CHANGELOG.md
+++ b/docs/CHANGELOG.md
@@ -16,4 +16,10 @@
## v0.0.4
-- Manager fixes
+- Manager important fixes!!!!
+
+## v0.0.5
+
+- Improved test coverage to 100%
+- Improved architecture
+- Added examples
diff --git a/examples/crud.py b/examples/crud.py
new file mode 100644
index 0000000..b257b78
--- /dev/null
+++ b/examples/crud.py
@@ -0,0 +1,90 @@
+import asyncio
+
+import httpx
+
+from resty.enums import Endpoint, Field
+from resty.types import Schema
+from resty.managers import Manager
+from resty.clients.httpx import RESTClient
+
+
+class UserCreateSchema(Schema):
+ username: str
+ email: str
+ password: str
+ age: int
+
+
+class UserReadSchema(Schema):
+ id: int
+ username: str
+ email: str
+ age: int
+
+
+class UserUpdateSchema(Schema):
+ username: str = None
+ email: str = None
+
+
+class UserManager(Manager):
+ endpoints = {
+ Endpoint.CREATE: "users/",
+ Endpoint.READ: "users/",
+ Endpoint.READ_ONE: "users/{pk}",
+ Endpoint.UPDATE: "users/{pk}",
+ Endpoint.DELETE: "users/{pk}",
+ }
+ fields = {
+ Field.PRIMARY: "id",
+ }
+
+
+async def main():
+ client = RESTClient(httpx.AsyncClient(base_url="https://localhost:8000"))
+
+ response = await UserManager.create(
+ client=client,
+ obj=UserCreateSchema(
+ username="admin",
+ email="admin@admin.com",
+ password="admin",
+ age=19,
+ ),
+ response_type=UserReadSchema,
+ )
+ print(response) # id=1 username='admin' email='admin@admin.com' age=19
+
+ response = await UserManager.read(
+ client=client,
+ response_type=UserReadSchema,
+ )
+
+ for obj in response:
+ print(obj) # id=1 username='admin' email='admin@admin.com' age=19
+
+ response = await UserManager.read_one(
+ client=client,
+ obj_or_pk=1,
+ response_type=UserReadSchema,
+ )
+
+ print(response) # id=1 username='admin' email='admin@admin.com' age=19
+
+ response = await UserManager.update(
+ client=client,
+ obj=UserUpdateSchema(id=1, username="admin123", ),
+ response_type=UserReadSchema,
+ )
+
+ print(response) # id=1 username='admin123' email='admin@admin.com' age=19
+
+ await UserManager.delete(
+ client=client,
+ obj_or_pk=1,
+ expected_status=204,
+ )
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/examples/crud/main.py b/examples/crud/main.py
deleted file mode 100644
index bb453b1..0000000
--- a/examples/crud/main.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import asyncio
-import httpx
-
-from resty.clients.httpx import RESTClient
-from resty.ext.django.middlewares import DjangoPagePaginationMiddleware
-
-from managers import ProductManager
-from schemas import ProductSchema
-
-
-async def main():
- xclient = httpx.AsyncClient(base_url='http://localhost:8000/')
-
- rest_client = RESTClient(xclient=xclient)
-
- rest_client.add_middleware(DjangoPagePaginationMiddleware())
-
- product = ProductSchema(
- name='My Product',
- description='My Desc',
- code='123W31QQW'
- )
-
- # Create
- created = await ProductManager.create(rest_client, product)
-
- # Read
- my_product = await ProductManager.read_one(rest_client, created.id)
-
- for prod in await ProductManager.read(rest_client):
- print(prod.name)
-
- # Update
- my_product.description = 'QWERTY'
- await ProductManager.update(rest_client, my_product)
-
- # Delete
- await ProductManager.delete(rest_client, my_product.id)
-
-
-if __name__ == '__main__':
- asyncio.run(main())
diff --git a/examples/crud/managers.py b/examples/crud/managers.py
deleted file mode 100644
index c6ddfe4..0000000
--- a/examples/crud/managers.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from resty.enums import (
- Endpoint,
- Field
-)
-from resty.managers import Manager
-
-from serializers import ProductSerializer
-
-
-class ProductManager(Manager):
- serializer = ProductSerializer
- endpoints = {
- Endpoint.CREATE: '/products/',
- Endpoint.READ: '/products/',
- Endpoint.READ_ONE: '/products/{pk}/',
- Endpoint.UPDATE: '/products/{pk}/',
- Endpoint.DELETE: '/products/{pk}/',
- }
- fields = {
- Field.PRIMARY: 'id',
- }
diff --git a/examples/crud/schemas.py b/examples/crud/schemas.py
deleted file mode 100644
index 55ca66d..0000000
--- a/examples/crud/schemas.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from pydantic import BaseModel
-
-
-class ProductSchema(BaseModel):
- id: int | None = None
- name: str
- description: str
- code: str
-
diff --git a/examples/crud/serializers.py b/examples/crud/serializers.py
deleted file mode 100644
index 38677f0..0000000
--- a/examples/crud/serializers.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from resty.serializers import Serializer
-
-from schemas import ProductSchema
-
-
-class ProductSerializer(Serializer):
- schema = ProductSchema
diff --git a/examples/crud_many_layers/managers.py b/examples/crud_many_layers/managers.py
deleted file mode 100644
index 99d08bc..0000000
--- a/examples/crud_many_layers/managers.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from resty.enums import (
- Endpoint,
- Field
-)
-from resty.managers import Manager
-
-from serializers import ProductSerializer
-
-
-class UserProductManager(Manager):
- serializer = ProductSerializer
- endpoints = {
- Endpoint.CREATE: 'users/{user_pk}/products/',
- Endpoint.READ: 'users/{user_pk}/products/',
- Endpoint.READ_ONE: 'users/{user_pk}/products/{pk}/',
- Endpoint.UPDATE: 'users/{user_pk}/products/{pk}/',
- Endpoint.DELETE: 'users/{user_pk}/products/{pk}/',
- }
- fields = {
- Field.PRIMARY: 'id',
- }
diff --git a/examples/crud_many_layers/schemas.py b/examples/crud_many_layers/schemas.py
deleted file mode 100644
index 55ca66d..0000000
--- a/examples/crud_many_layers/schemas.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from pydantic import BaseModel
-
-
-class ProductSchema(BaseModel):
- id: int | None = None
- name: str
- description: str
- code: str
-
diff --git a/examples/crud_many_layers/serializers.py b/examples/crud_many_layers/serializers.py
deleted file mode 100644
index 38677f0..0000000
--- a/examples/crud_many_layers/serializers.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from resty.serializers import Serializer
-
-from schemas import ProductSchema
-
-
-class ProductSerializer(Serializer):
- schema = ProductSchema
diff --git a/examples/httpx_client.py b/examples/httpx_client.py
new file mode 100644
index 0000000..10e6838
--- /dev/null
+++ b/examples/httpx_client.py
@@ -0,0 +1,19 @@
+import asyncio
+
+from resty.enums import Method
+from resty.types import Request
+from resty.clients.httpx import RESTClient
+
+
+async def main():
+ client = RESTClient()
+
+ response = await client.request(
+ Request(url="https://example.com", method=Method.GET)
+ )
+
+ print(response.text)
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/examples/middlewares.py b/examples/middlewares.py
new file mode 100644
index 0000000..1fea09e
--- /dev/null
+++ b/examples/middlewares.py
@@ -0,0 +1,37 @@
+import asyncio
+
+from resty.enums import Method
+from resty.types import Request, Response
+from resty.middlewares import BaseRequestMiddleware, BaseResponseMiddleware
+from resty.clients.httpx import RESTClient
+
+
+class LoggingMiddleware(BaseRequestMiddleware, BaseResponseMiddleware):
+ async def __call__(self, reqresp: Request | Response, **kwargs):
+ print(reqresp)
+
+
+class HelloWorldMiddleware(BaseRequestMiddleware):
+
+ async def __call__(self, request: Request, **kwargs):
+ print("Hello, World!")
+
+
+async def main():
+ client = RESTClient()
+
+ client.middlewares.add_middlewares(LoggingMiddleware())
+
+ await client.request(Request(url="https://example.com", method=Method.GET))
+ # Request(url='https://example.com', method=, ...)
+ # Response(request=Request(url='https://example.com', method=, ...)
+
+ with client.middlewares.middleware(HelloWorldMiddleware()):
+ await client.request(Request(url="https://example.com", method=Method.GET))
+ # Request(url='https://example.com', method=, ...)
+ # Hello, World!
+ # Response(request=Request(url='https://example.com', method=, ...)
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/examples/serializer_schemas/managers.py b/examples/serializer_schemas/managers.py
deleted file mode 100644
index c6ddfe4..0000000
--- a/examples/serializer_schemas/managers.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from resty.enums import (
- Endpoint,
- Field
-)
-from resty.managers import Manager
-
-from serializers import ProductSerializer
-
-
-class ProductManager(Manager):
- serializer = ProductSerializer
- endpoints = {
- Endpoint.CREATE: '/products/',
- Endpoint.READ: '/products/',
- Endpoint.READ_ONE: '/products/{pk}/',
- Endpoint.UPDATE: '/products/{pk}/',
- Endpoint.DELETE: '/products/{pk}/',
- }
- fields = {
- Field.PRIMARY: 'id',
- }
diff --git a/examples/serializer_schemas/schemas.py b/examples/serializer_schemas/schemas.py
deleted file mode 100644
index dc76e3a..0000000
--- a/examples/serializer_schemas/schemas.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from pydantic import BaseModel
-
-
-class ProductCreateSchema(BaseModel):
- name: str
- description: str
-
-
-class ProductReadSchema(BaseModel):
- id: int | None = None
- name: str
- description: str
- code: str
-
-
-class ProductUpdateSchema(BaseModel):
- name: str = None
- description: str = None
diff --git a/examples/serializer_schemas/serializers.py b/examples/serializer_schemas/serializers.py
deleted file mode 100644
index be5b9ca..0000000
--- a/examples/serializer_schemas/serializers.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from resty.serializers import Serializer
-from resty.enums import Endpoint
-
-from schemas import (
- ProductCreateSchema,
- ProductReadSchema,
- ProductUpdateSchema
-)
-
-
-class ProductSerializer(Serializer):
- schemas = {
- Endpoint.CREATE: ProductCreateSchema,
- Endpoint.READ: ProductReadSchema,
- Endpoint.READ_ONE: ProductReadSchema,
- Endpoint.UPDATE: ProductUpdateSchema
- }
diff --git a/pyproject.toml b/pyproject.toml
index c647c78..4a4f505 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "resty-client"
-version = "0.0.4"
+version = "0.0.5"
description = "RestyClient is a simple, easy-to-use Python library for interacting with REST APIs using Pydantic's powerful data validation and deserialization tools."
authors = ["CrazyProger1 "]
license = "MIT"
@@ -19,6 +19,7 @@ pytest = "^7.4.4"
httpx = "^0.26.0"
black = "^24.4.0"
coverage = "^7.5.0"
+pytest-asyncio = "^0.23.6"
[tool.ruff]
exclude = [
@@ -50,7 +51,6 @@ files = ["resty"]
show_error_codes = true
strict = true
-
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
diff --git a/resty/__init__.py b/resty/__init__.py
index 760653f..ef7232d 100644
--- a/resty/__init__.py
+++ b/resty/__init__.py
@@ -1,4 +1,11 @@
-try:
- import pydantic
-except ImportError:
- raise ImportError("Please install pydantic to use resty-client package")
+from resty.__version__ import (
+ __version__,
+ __title__,
+ __description__,
+)
+
+__all__ = [
+ "__version__",
+ "__title__",
+ "__description__",
+]
diff --git a/resty/__version__.py b/resty/__version__.py
new file mode 100644
index 0000000..90300ba
--- /dev/null
+++ b/resty/__version__.py
@@ -0,0 +1,4 @@
+__title__ = "Resty-Client"
+__version__ = "0.0.5"
+__description__ = """RestyClient is a simple, easy-to-use Python library for interacting with REST APIs using Pydantic's
+powerful data validation and deserialization tools."""
diff --git a/resty/clients/__init__.py b/resty/clients/__init__.py
index e69de29..085c19f 100644
--- a/resty/clients/__init__.py
+++ b/resty/clients/__init__.py
@@ -0,0 +1,5 @@
+from resty.clients.types import BaseRESTClient
+
+__all__ = [
+ "BaseRESTClient",
+]
diff --git a/resty/clients/httpx/__init__.py b/resty/clients/httpx/__init__.py
index 8b015fa..dff2eaf 100644
--- a/resty/clients/httpx/__init__.py
+++ b/resty/clients/httpx/__init__.py
@@ -1,8 +1,5 @@
-try:
- import httpx
-except ImportError:
- raise ImportError("Please install httpx to use httpx rest client")
+from resty.clients.httpx.clients import RESTClient
-from .client import RESTClient
-
-__all__ = ["RESTClient"]
+__all__ = [
+ "RESTClient",
+]
diff --git a/resty/clients/httpx/client.py b/resty/clients/httpx/client.py
deleted file mode 100644
index 6982c8a..0000000
--- a/resty/clients/httpx/client.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import json.decoder
-from typing import Container
-
-import httpx
-
-from resty.constants import DEFAULT_CODES, STATUS_ERRORS
-from resty.types import (
- BaseRESTClient,
- Request,
- Response,
- BaseMiddleware,
- BaseMiddlewareManager,
-)
-from resty.exceptions import HTTPError
-from resty.middlewares import MiddlewareManager
-
-
-class RESTClient(BaseRESTClient):
-
- def __init__(
- self,
- xclient: httpx.AsyncClient = None,
- middleware_manager: BaseMiddlewareManager = None,
- ):
-
- self._xclient = xclient or httpx.AsyncClient()
- self._middleware_manager = middleware_manager or MiddlewareManager()
-
- @staticmethod
- def _parse_xresponse(xresponse: httpx.Response) -> dict | list | None:
- try:
- data = xresponse.json()
- except json.decoder.JSONDecodeError:
- data = {}
-
- return data
-
- @staticmethod
- def _check_status(
- status: int,
- expected_status: int | Container[int],
- request: Request,
- url: str,
- data: dict = None,
- ):
- if status != expected_status:
- if isinstance(expected_status, Container) and status in expected_status:
- pass
- else:
- exc: type[HTTPError] = STATUS_ERRORS.get(status, HTTPError)
- raise exc(request=request, status=status, url=url, data=data)
-
- async def _make_xrequest(self, request: Request):
- return await self._xclient.request(
- method=request.method.value,
- url=request.url,
- headers=request.headers,
- json=request.json,
- data=request.data,
- params=request.params,
- cookies=request.cookies,
- follow_redirects=request.redirects,
- timeout=request.timeout,
- )
-
- def add_middleware(self, middleware: BaseMiddleware):
- self._middleware_manager.add_middleware(middleware=middleware)
-
- async def request(self, request: Request, **context) -> Response:
- if not isinstance(request, Request):
- raise TypeError("request is not of type Request")
-
- expected_status: int = context.pop(
- "expected_status", DEFAULT_CODES.get(request.method)
- )
- check_status: bool = context.pop("check_status", True)
-
- if not isinstance(expected_status, (int, Container)):
- raise TypeError("expected status should be type of int or Container[int]")
-
- await self._middleware_manager.call_request_middlewares(
- request=request, **context
- )
-
- xresponse = await self._make_xrequest(request=request)
-
- data = self._parse_xresponse(xresponse=xresponse)
-
- status = xresponse.status_code
-
- if check_status:
- self._check_status(
- status=status,
- expected_status=expected_status,
- request=request,
- url=str(xresponse.url),
- data=data,
- )
- response = Response(
- request=request,
- status=status,
- data=data,
- )
-
- await self._middleware_manager.call_response_middlewares(
- response=response, **context
- )
-
- return response
diff --git a/resty/clients/httpx/clients.py b/resty/clients/httpx/clients.py
new file mode 100644
index 0000000..148891c
--- /dev/null
+++ b/resty/clients/httpx/clients.py
@@ -0,0 +1,94 @@
+import json
+
+import httpx
+
+from resty.clients.types import (
+ BaseRESTClient,
+ Request,
+ Response,
+)
+from resty.middlewares import (
+ MiddlewareManager,
+ BaseRequestMiddleware,
+ BaseResponseMiddleware,
+ StatusCheckingMiddleware,
+ BaseMiddlewareManager,
+)
+from resty.exceptions import ConnectError
+
+
+class RESTClient(BaseRESTClient):
+
+ def __init__(
+ self,
+ httpx_client: httpx.AsyncClient = None,
+ check_status: bool = True,
+ middleware_manager: BaseMiddlewareManager = None,
+ ):
+ self.middlewares = middleware_manager or MiddlewareManager()
+ self._xclient = httpx_client or httpx.AsyncClient()
+
+ if check_status:
+ self.middlewares.add_middleware(StatusCheckingMiddleware())
+
+ async def _make_xrequest(self, request: Request) -> httpx.Response:
+ try:
+ return await self._xclient.request(
+ method=request.method.value,
+ url=request.url,
+ headers=request.headers,
+ json=request.json,
+ data=request.data,
+ params=request.params,
+ cookies=request.cookies,
+ follow_redirects=request.redirects,
+ timeout=request.timeout,
+ )
+ except httpx.ConnectError:
+ raise ConnectError(url=request.url)
+
+ @staticmethod
+ def _extract_json_data(xresponse: httpx.Response) -> dict | list:
+ try:
+ data = xresponse.json()
+ except json.decoder.JSONDecodeError:
+ data = {}
+
+ return data
+
+ async def _parse_xresponse(
+ self, request: Request, xresponse: httpx.Response
+ ) -> Response:
+ return Response(
+ request=request,
+ status=xresponse.status_code,
+ json=self._extract_json_data(xresponse=xresponse),
+ content=xresponse.content,
+ text=xresponse.text,
+ middleware_options=request.middleware_options,
+ )
+
+ async def _make_request(self, request: Request) -> Response:
+ xresponse = await self._make_xrequest(request=request)
+ response = await self._parse_xresponse(request=request, xresponse=xresponse)
+ return response
+
+ async def _call_middlewares(self, reqresp: Request | Response):
+ await self.middlewares(
+ reqresp,
+ base=(
+ BaseRequestMiddleware
+ if isinstance(reqresp, Request)
+ else BaseResponseMiddleware
+ ),
+ **reqresp.middleware_options
+ )
+
+ async def request(self, request: Request) -> Response:
+ await self._call_middlewares(reqresp=request)
+
+ response = await self._make_request(request)
+
+ await self._call_middlewares(reqresp=response)
+
+ return response
diff --git a/resty/clients/types.py b/resty/clients/types.py
new file mode 100644
index 0000000..3b8f696
--- /dev/null
+++ b/resty/clients/types.py
@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+
+from resty.middlewares import BaseMiddlewareManager
+from resty.types import Request, Response
+
+
+class BaseRESTClient(ABC):
+ middlewares: BaseMiddlewareManager
+
+ @abstractmethod
+ async def request(self, request: Request) -> Response: ...
diff --git a/resty/constants.py b/resty/constants.py
index e72f769..f4459b8 100644
--- a/resty/constants.py
+++ b/resty/constants.py
@@ -1,4 +1,3 @@
-from resty.enums import Method
from resty.exceptions import (
NotFoundError,
BadRequestError,
@@ -8,14 +7,6 @@
ForbiddenError,
)
-DEFAULT_CODES = {
- Method.GET: 200,
- Method.POST: {201, 200},
- Method.PUT: 200,
- Method.PATCH: 200,
- Method.DELETE: {204, 200},
-}
-
STATUS_ERRORS = {
400: BadRequestError,
401: UnauthorizedError,
diff --git a/resty/exceptions.py b/resty/exceptions.py
index d7c850a..187ec80 100644
--- a/resty/exceptions.py
+++ b/resty/exceptions.py
@@ -1,21 +1,32 @@
-from resty.types import Request
+from resty.types import Response
class RestyError(Exception):
pass
-class URLFormattingError(RestyError):
+class NetworkError(RestyError):
pass
-class HTTPError(RestyError):
- def __init__(self, request: Request, status: int, url: str, data: dict):
- self.request = request
- self.status = status
- self.url = url
- self.data = data
- super().__init__(f"{request.method.value}: {url} -> {status}")
+class URLBuildingError(RestyError):
+ pass
+
+
+class ConnectError(NetworkError, ConnectionError):
+ def __init__(self, url: str):
+ self.url = url # pragma: no cover
+ super().__init__(
+ f"Failed to establish a connection to the server {url}"
+ ) # pragma: no cover
+
+
+class HTTPError(NetworkError):
+ def __init__(self, response: Response):
+ self.response = response # pragma: no cover
+ super().__init__( # pragma: no cover
+ f"{response.request.method.value}: {response.request.url} -> {response.status}"
+ )
class BadRequestError(HTTPError):
diff --git a/resty/ext/django/__init__.py b/resty/ext/django/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/resty/ext/django/middlewares/__init__.py b/resty/ext/django/middlewares/__init__.py
deleted file mode 100644
index 8914536..0000000
--- a/resty/ext/django/middlewares/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from .pagination import (
- DjangoLimitOffsetPaginationMiddleware,
- DjangoPagePaginationMiddleware,
-)
-
-__all__ = ["DjangoLimitOffsetPaginationMiddleware", "DjangoPagePaginationMiddleware"]
diff --git a/resty/ext/django/middlewares/pagination.py b/resty/ext/django/middlewares/pagination.py
deleted file mode 100644
index 8dea486..0000000
--- a/resty/ext/django/middlewares/pagination.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from resty.middlewares import BasePaginationMiddleware
-from resty.types import Response, Request
-from resty.enums import Method
-
-
-class DjangoPagePaginationMiddleware(BasePaginationMiddleware):
- async def handle_request(self, request: Request, **context):
- if request.method in {
- Method.GET,
- }:
- page = context.pop("page", 1)
-
- request.params.update(
- {
- "page": page,
- }
- )
-
- async def handle_response(self, response: Response, **context):
- if response.request.method in {
- Method.GET,
- }:
- data = response.data
- results = data.get("results", data)
- response.data = results
-
-
-class DjangoLimitOffsetPaginationMiddleware(DjangoPagePaginationMiddleware):
- def __init__(self, page_size: int = 100):
- self._limit = page_size
-
- async def handle_request(self, request: Request, **context):
- if request.method in {
- Method.GET,
- }:
- page = context.pop("page", 1) - 1
- limit = context.pop("limit", self._limit)
- offset = context.pop("offset", page * self._limit)
-
- request.params.update({"limit": limit, "offset": offset})
diff --git a/resty/managers/__init__.py b/resty/managers/__init__.py
index ade67b7..bb71e1d 100644
--- a/resty/managers/__init__.py
+++ b/resty/managers/__init__.py
@@ -1,3 +1,10 @@
-from resty.managers.manager import Manager
+from resty.managers.types import BaseManager, BaseURLBuilder
+from resty.managers.managers import Manager
+from resty.managers.builders import URLBuilder
-__all__ = ["Manager"]
+__all__ = [
+ "BaseManager",
+ "Manager",
+ "BaseURLBuilder",
+ "URLBuilder",
+]
diff --git a/resty/managers/builders.py b/resty/managers/builders.py
new file mode 100644
index 0000000..7f3ff57
--- /dev/null
+++ b/resty/managers/builders.py
@@ -0,0 +1,47 @@
+from functools import cache
+from urllib.parse import urljoin, urlparse
+
+from resty.enums import Endpoint
+from resty.exceptions import URLBuildingError
+from resty.managers.types import BaseURLBuilder, Endpoints
+
+
+class URLBuilder(BaseURLBuilder):
+ @classmethod
+ def _get_endpoint_url(cls, endpoints: Endpoints, endpoint: Endpoint) -> str | None:
+ url = endpoints.get(endpoint, endpoints.get(Endpoint.BASE, None))
+ return url
+
+ @classmethod
+ def _inject_params(cls, url: str, **params) -> str:
+ try:
+ return url.format(**params)
+ except KeyError as e:
+ raise URLBuildingError(f"Missing '{e.args[0]}' in {url}")
+
+ @classmethod
+ @cache
+ def _normalize_url(cls, url: str | None) -> str:
+ if not url:
+ return ""
+
+ if not url.endswith('/'):
+ return url + '/'
+ return url
+
+ @classmethod
+ def build(
+ cls, endpoints: Endpoints, endpoint: Endpoint, base_url: str = None, **kwargs
+ ) -> str:
+
+ endpoint_url = cls._get_endpoint_url(endpoints=endpoints, endpoint=endpoint)
+
+ if endpoint_url:
+ url = urljoin(
+ cls._normalize_url(url=base_url),
+ endpoint_url
+ )
+ else:
+ url = base_url or ""
+
+ return cls._inject_params(url=url, **kwargs)
diff --git a/resty/managers/manager.py b/resty/managers/manager.py
deleted file mode 100644
index 56ca5bb..0000000
--- a/resty/managers/manager.py
+++ /dev/null
@@ -1,150 +0,0 @@
-from typing import Iterable
-
-from pydantic import BaseModel
-
-from resty.types import BaseManager, BaseRESTClient, Request, BaseSerializer
-from resty.enums import Endpoint, Method, Field
-from resty.exceptions import URLFormattingError
-
-
-class Manager(BaseManager):
- @classmethod
- def _get_endpoint_url(cls, endpoint: Endpoint) -> str:
- return cls.endpoints.get(endpoint, cls.endpoints.get(endpoint.BASE, ""))
-
- @classmethod
- def _get_pk_field(cls) -> str | None:
- return cls.fields.get(Field.PRIMARY)
-
- @classmethod
- def _get_pk(cls, obj) -> any:
- pk_field = cls._get_pk_field()
- if isinstance(obj, dict):
- return obj.get(pk_field)
- return getattr(obj, pk_field)
-
- @classmethod
- def _set_pk(cls, obj: BaseModel, pk: any):
- setattr(obj, cls._get_pk_field(), pk)
-
- @classmethod
- def _inject_into_url(cls, url: str, **data) -> str:
- try:
- return url.format(**data)
- except KeyError as e:
- raise URLFormattingError(f"Missing '{e.args[0]}' in {url}")
-
- @classmethod
- def _prepare_url(cls, **options) -> str:
- url = options.get("url")
-
- if url is not None:
- return url
-
- endpoint = options.get("endpoint", Endpoint.BASE)
- url = cls._get_endpoint_url(endpoint)
- return cls._inject_into_url(url, **options)
-
- @classmethod
- def _build_request(cls, **options) -> Request:
- return Request(
- method=options.get("method", Method.GET),
- url=options.get("url"),
- json=options.get("json", {}),
- headers=options.get("headers", {}),
- params=options.get("params", {}),
- cookies=options.get("cookies", {}),
- redirects=options.get("redirects", False),
- timeout=options.get("timeout", None),
- )
-
- @classmethod
- def _prepare_options(cls, endpoint: Endpoint, method: Method, **kwargs) -> dict:
- options = {
- "endpoint": endpoint,
- "method": method,
- }
- options.update(kwargs)
-
- options["url"] = cls._prepare_url(**options)
-
- return options
-
- @classmethod
- async def _make_request(cls, client: BaseRESTClient, **options):
- request = cls._build_request(**options)
- return await client.request(request=request, **options)
-
- @classmethod
- def _get_serializer(cls, **options) -> type[BaseSerializer]:
- return cls.serializer
-
- @classmethod
- def _serialize(cls, obj: BaseModel, **options) -> dict:
- serializer = cls._get_serializer(**options)
- return serializer.serialize(obj, **options)
-
- @classmethod
- def _deserialize(cls, data: list | dict, many: bool = False, **options):
- serializer = cls._get_serializer(**options)
- if many:
- return serializer.deserialize_many(data=data, **options)
- return serializer.deserialize(data=data, **options)
-
- @classmethod
- async def create(
- cls, client: BaseRESTClient, obj: BaseModel, **kwargs
- ) -> BaseModel:
-
- set_pk = kwargs.pop("set_pk", True)
-
- options = cls._prepare_options(
- endpoint=Endpoint.CREATE, method=Method.POST, **kwargs
- )
-
- options["json"] = cls._serialize(obj, **options)
-
- response = await cls._make_request(client=client, **options)
-
- if set_pk:
- cls._set_pk(obj, pk=cls._get_pk(response.data))
-
- return obj
-
- @classmethod
- async def read(cls, client: BaseRESTClient, **kwargs) -> Iterable[BaseModel]:
- options = cls._prepare_options(
- endpoint=Endpoint.READ, method=Method.GET, **kwargs
- )
-
- response = await cls._make_request(client=client, **options)
-
- return cls._deserialize(data=response.data, many=True, **options)
-
- @classmethod
- async def read_one(cls, client: BaseRESTClient, pk: any, **kwargs) -> BaseModel:
- options = cls._prepare_options(
- endpoint=Endpoint.READ_ONE, method=Method.GET, pk=pk, **kwargs
- )
-
- response = await cls._make_request(client=client, **options)
-
- return cls._deserialize(data=response.data, many=False, **options)
-
- @classmethod
- async def update(cls, client: BaseRESTClient, obj: BaseModel, **kwargs) -> None:
- options = cls._prepare_options(
- endpoint=Endpoint.UPDATE, method=Method.PATCH, pk=cls._get_pk(obj), **kwargs
- )
-
- options["json"] = cls._serialize(obj=obj, **options)
-
- await cls._make_request(client=client, **options)
-
- @classmethod
- async def delete(cls, client: BaseRESTClient, pk: any, **kwargs) -> None:
- options = cls._prepare_options(
- endpoint=Endpoint.DELETE, method=Method.DELETE, pk=pk, **kwargs
- )
-
- await cls._make_request(client=client, **options)
diff --git a/resty/managers/managers.py b/resty/managers/managers.py
new file mode 100644
index 0000000..70029e0
--- /dev/null
+++ b/resty/managers/managers.py
@@ -0,0 +1,187 @@
+import inspect
+from typing import Mapping, Iterable, Callable
+
+from resty.clients import BaseRESTClient
+from resty.enums import Endpoint, Method, Field
+from resty.types import Schema, Response, Request
+from resty.serializers import Serializer, BaseSerializer
+from resty.managers.types import BaseManager, ResponseType
+from resty.managers.builders import URLBuilder
+
+
+class Manager(BaseManager):
+ serializer_class = Serializer
+ url_builder_class = URLBuilder
+
+ @classmethod
+ def get_serializer(cls, **kwargs) -> type[BaseSerializer]:
+ serializer = kwargs.pop('serializer', cls.serializer_class)
+
+ if not serializer:
+ raise RuntimeError('Serializer not specified')
+
+ if not (
+ isinstance(serializer, BaseSerializer)
+ or inspect.isclass(serializer)
+ and issubclass(serializer, BaseSerializer)
+ ):
+ raise RuntimeError('The serializer must be a subclass of BaseSerializer')
+
+ return serializer
+
+ @classmethod
+ def get_method(cls, endpoint: Endpoint, **kwargs) -> Method:
+ method = cls.methods.get(endpoint)
+ if not method:
+ raise RuntimeError(f'Method not specified for endpoint: {endpoint}')
+
+ return method
+
+ @classmethod
+ def get_field(cls, field: Field) -> str:
+ field = cls.fields.get(field)
+
+ if not field:
+ raise RuntimeError(f'Field not specified: {field}')
+
+ return field
+
+ @classmethod
+ def get_pk(cls, obj: Schema | Mapping) -> any:
+ field = cls.get_field(Field.PRIMARY)
+
+ if isinstance(obj, Mapping):
+ return obj.get(field)
+
+ return getattr(obj, field, None)
+
+ @classmethod
+ def _get_pk(cls, obj_or_pk: any) -> any:
+ if isinstance(obj_or_pk, Mapping | Schema):
+ return cls.get_pk(obj=obj_or_pk)
+ return obj_or_pk
+
+ @classmethod
+ def _deserialize(cls, schema: type[Schema], data: any, **kwargs):
+ serializer = cls.get_serializer(**kwargs)
+ if isinstance(data, Mapping):
+ return serializer.deserialize(schema=schema, data=data, **kwargs)
+ return serializer.deserialize_many(schema=schema, data=data, **kwargs)
+
+ @classmethod
+ def _serialize(cls, obj: Schema, **kwargs):
+ serializer = cls.get_serializer(**kwargs)
+ return serializer.serialize(obj=obj, **kwargs)
+
+ @classmethod
+ async def _make_request(cls, client: BaseRESTClient, request: Request) -> Response:
+ return await client.request(request=request)
+
+ @classmethod
+ def _prepare_url(cls, endpoint: Endpoint, **kwargs) -> str:
+ url = kwargs.pop('url', None)
+ base_url = kwargs.pop('base_url', None)
+
+ if isinstance(url, str):
+ return url
+
+ return cls.url_builder_class.build(
+ endpoints=cls.endpoints,
+ endpoint=endpoint,
+ base_url=base_url or cls.url,
+ **kwargs
+ )
+
+ @classmethod
+ def _prepare_json(cls, **kwargs):
+ obj = kwargs.pop('obj', None)
+
+ if isinstance(obj, dict | list | set | tuple):
+ return obj
+ elif isinstance(obj, Schema):
+ return cls._serialize(obj, **kwargs)
+ return {}
+
+ @classmethod
+ def _prepare_request(cls, endpoint: Endpoint, **kwargs) -> Request:
+ return Request(
+ url=cls._prepare_url(endpoint=endpoint, **kwargs),
+ method=cls.get_method(endpoint, **kwargs),
+ endpoint=endpoint,
+ data=kwargs.pop('data', {}),
+ json=cls._prepare_json(**kwargs),
+ timeout=kwargs.pop('timeout', None),
+ params=kwargs.pop('params', {}),
+ headers=kwargs.pop('headers', {}),
+ cookies=kwargs.pop('cookies', {}),
+ redirects=kwargs.pop('redirects', False),
+ middleware_options=kwargs.copy(),
+ )
+
+ @classmethod
+ def _handle_response(cls, response: Response, response_type: ResponseType, **kwargs) -> any:
+ if not response:
+ return
+
+ if inspect.isclass(response_type):
+ if issubclass(response_type, dict | list | tuple | set):
+ return response_type(response.json)
+ elif issubclass(response_type, Schema):
+ return cls._deserialize(schema=response_type, data=response.json, **kwargs)
+
+ if callable(response_type):
+ try:
+ return response_type(response)
+ except TypeError:
+ pass
+
+ return response.json
+
+ @classmethod
+ async def create[T: Schema](cls, client: BaseRESTClient, obj: Schema | Mapping, response_type: ResponseType = None,
+ **kwargs) -> T | None:
+ request = cls._prepare_request(endpoint=Endpoint.CREATE, obj=obj, **kwargs)
+ response = await cls._make_request(client=client, request=request)
+ return cls._handle_response(response=response, response_type=response_type, **kwargs)
+
+ @classmethod
+ async def read[T: Schema](cls, client: BaseRESTClient, response_type: ResponseType = None, **kwargs) -> Iterable[T]:
+ request = cls._prepare_request(endpoint=Endpoint.READ, **kwargs)
+ response = await cls._make_request(client=client, request=request)
+ return cls._handle_response(response=response, response_type=response_type, **kwargs)
+
+ @classmethod
+ async def read_one[T: Schema](cls, client: BaseRESTClient, obj_or_pk: Schema | Mapping | any,
+ response_type: ResponseType = None, **kwargs) -> T:
+
+ request = cls._prepare_request(
+ endpoint=Endpoint.READ_ONE,
+ pk=cls._get_pk(obj_or_pk=obj_or_pk),
+ **kwargs
+ )
+ response = await cls._make_request(client=client, request=request)
+ return cls._handle_response(response=response, response_type=response_type, **kwargs)
+
+ @classmethod
+ async def update[T: Schema](cls, client: BaseRESTClient, obj: Schema | Mapping, response_type: ResponseType = None,
+ **kwargs) -> T | None:
+ request = cls._prepare_request(
+ endpoint=Endpoint.UPDATE,
+ pk=kwargs.pop('pk', cls.get_pk(obj)),
+ obj=obj,
+ **kwargs
+ )
+ response = await cls._make_request(client=client, request=request)
+ return cls._handle_response(response=response, response_type=response_type, **kwargs)
+
+ @classmethod
+ async def delete[T: Schema](cls, client: BaseRESTClient, obj_or_pk: Schema | Mapping | any,
+ response_type: ResponseType = None,
+ **kwargs) -> T | None:
+ request = cls._prepare_request(
+ endpoint=Endpoint.DELETE,
+ pk=cls._get_pk(obj_or_pk=obj_or_pk),
+ **kwargs
+ )
+ response = await cls._make_request(client=client, request=request)
+ return cls._handle_response(response=response, response_type=response_type, **kwargs)
diff --git a/resty/managers/types.py b/resty/managers/types.py
new file mode 100644
index 0000000..19acd51
--- /dev/null
+++ b/resty/managers/types.py
@@ -0,0 +1,104 @@
+from abc import ABC, abstractmethod
+from typing import Mapping, Iterable, Callable
+
+from resty.enums import Endpoint, Field, Method
+from resty.serializers import BaseSerializer
+from resty.clients import BaseRESTClient
+from resty.types import Schema, Response
+
+type Endpoints = Mapping[Endpoint, str]
+type Fields = Mapping[Field, str]
+type Methods = Mapping[Endpoint, Method]
+type ResponseType = type[Schema] | type[Mapping | Iterable] | Callable[[Response, ], any] | None
+
+
+class BaseURLBuilder(ABC):
+ @classmethod
+ @abstractmethod
+ def build(
+ cls,
+ base: str,
+ endpoints: Endpoints,
+ endpoint: Endpoint,
+ **kwargs,
+ ) -> str: ...
+
+
+class BaseManager(ABC):
+ url: str = None
+ methods: Methods = {
+ Endpoint.CREATE: Method.POST,
+ Endpoint.READ: Method.GET,
+ Endpoint.READ_ONE: Method.GET,
+ Endpoint.UPDATE: Method.PATCH,
+ Endpoint.DELETE: Method.DELETE,
+ }
+ endpoints: Endpoints = {}
+ fields: Fields = {}
+ serializer_class: BaseSerializer
+ url_builder_class: BaseURLBuilder
+
+ @classmethod
+ @abstractmethod
+ def get_serializer(cls, **kwargs) -> type[BaseSerializer]: ...
+
+ @classmethod
+ @abstractmethod
+ def get_method(cls, endpoint: Endpoint, **kwargs) -> Method: ...
+
+ @classmethod
+ @abstractmethod
+ def get_field(cls, field: Field) -> str: ...
+
+ @classmethod
+ @abstractmethod
+ def get_pk(cls, obj: Schema | Mapping) -> any: ...
+
+ @classmethod
+ @abstractmethod
+ async def create[T: Schema](
+ cls,
+ client: BaseRESTClient,
+ obj: Schema | Mapping,
+ response_type: ResponseType = None,
+ **kwargs,
+ ) -> T | None: ...
+
+ @classmethod
+ @abstractmethod
+ async def read[T: Schema](
+ cls,
+ client: BaseRESTClient,
+ response_type: ResponseType = None,
+ **kwargs,
+ ) -> Iterable[T]: ...
+
+ @classmethod
+ @abstractmethod
+ async def read_one[T: Schema](
+ cls,
+ client: BaseRESTClient,
+ obj_or_pk: Schema | Mapping | any,
+ response_type: ResponseType = None,
+ **kwargs,
+ ) -> T: ...
+
+ @classmethod
+ @abstractmethod
+ async def update[T: Schema](
+ cls,
+ client: BaseRESTClient,
+ obj: Schema | Mapping,
+ response_type: ResponseType = None,
+ **kwargs,
+ ) -> T | None: ...
+
+ @classmethod
+ @abstractmethod
+ async def delete[T: Schema](
+ cls,
+ client: BaseRESTClient,
+ obj_or_pk: Schema | Mapping | any,
+ response_type: ResponseType = None,
+ **kwargs,
+ ) -> T | None: ...
diff --git a/resty/middlewares/__init__.py b/resty/middlewares/__init__.py
index 75a776b..2975981 100644
--- a/resty/middlewares/__init__.py
+++ b/resty/middlewares/__init__.py
@@ -1,10 +1,17 @@
-from resty.middlewares.manager import MiddlewareManager
-from resty.middlewares.types import BasePaginationMiddleware
-from resty.types import BaseRequestMiddleware, BaseResponseMiddleware
+from resty.middlewares.types import (
+ BaseMiddleware,
+ BaseMiddlewareManager,
+ BaseResponseMiddleware,
+ BaseRequestMiddleware,
+)
+from resty.middlewares.managers import MiddlewareManager
+from resty.middlewares.status import StatusCheckingMiddleware
__all__ = [
"MiddlewareManager",
- "BasePaginationMiddleware",
+ "BaseMiddleware",
+ "BaseMiddlewareManager",
"BaseResponseMiddleware",
"BaseRequestMiddleware",
+ "StatusCheckingMiddleware",
]
diff --git a/resty/middlewares/manager.py b/resty/middlewares/manager.py
deleted file mode 100644
index 9efab42..0000000
--- a/resty/middlewares/manager.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from typing import Iterable
-
-from resty.types import (
- BaseMiddlewareManager,
- BaseMiddleware,
- BaseRequestMiddleware,
- BaseResponseMiddleware,
-)
-
-from resty.types import Request, Response
-
-
-class MiddlewareManager(BaseMiddlewareManager):
- def __init__(self, middlewares: Iterable[BaseMiddleware] = None):
- if not middlewares:
- middlewares = []
- for middleware in middlewares:
- self.add_middleware(middleware=middleware)
-
- self._middlewares = []
-
- async def call_request_middlewares(self, request: Request, **context):
- for middleware in self._middlewares:
- if isinstance(middleware, BaseRequestMiddleware):
- await middleware.handle_request(request=request, **context)
-
- async def call_response_middlewares(self, response: Response, **context):
- for middleware in self._middlewares:
- if isinstance(middleware, BaseResponseMiddleware):
- await middleware.handle_response(response=response, **context)
-
- def add_middleware(self, middleware: BaseMiddleware):
- if not isinstance(middleware, BaseMiddleware):
- raise TypeError("middleware is not of type BaseMiddleware")
- self._middlewares.append(middleware)
diff --git a/resty/middlewares/managers.py b/resty/middlewares/managers.py
new file mode 100644
index 0000000..450b5e1
--- /dev/null
+++ b/resty/middlewares/managers.py
@@ -0,0 +1,47 @@
+from contextlib import contextmanager
+from typing import Iterable
+
+from resty.middlewares.types import BaseMiddlewareManager, BaseMiddleware
+
+
+class MiddlewareManager(BaseMiddlewareManager):
+ def __init__(self, middlewares: Iterable = None):
+ self._middlewares = []
+
+ self.add_middlewares(*middlewares or ())
+
+ @property
+ def middlewares(self) -> Iterable[BaseMiddleware]:
+ return tuple(self._middlewares)
+
+ def add_middleware(self, middleware: BaseMiddleware):
+ if not isinstance(middleware, BaseMiddleware):
+ raise TypeError("Middleware must inherit the base type BaseMiddleware")
+
+ if middleware not in self._middlewares:
+ self._middlewares.append(middleware)
+
+ def add_middlewares(self, *middlewares: BaseMiddleware):
+ for middleware in middlewares:
+ self.add_middleware(middleware)
+
+ def remove_middleware(self, middleware: BaseMiddleware):
+ if middleware in self._middlewares:
+ self._middlewares.remove(middleware)
+
+ def remove_middlewares(self, *middlewares: BaseMiddleware):
+ for middleware in middlewares:
+ self.remove_middleware(middleware)
+
+ @contextmanager
+ def middleware(self, *middlewares: BaseMiddleware):
+ self.add_middlewares(*middlewares)
+ yield
+ self.remove_middlewares(*middlewares)
+
+ async def __call__(
+ self, *args, base: type[BaseMiddleware] = BaseMiddleware, **kwargs
+ ):
+ for middleware in self._middlewares:
+ if isinstance(middleware, base):
+ await middleware(*args, **kwargs)
diff --git a/resty/middlewares/status.py b/resty/middlewares/status.py
new file mode 100644
index 0000000..d4980ae
--- /dev/null
+++ b/resty/middlewares/status.py
@@ -0,0 +1,42 @@
+from typing import Container, Mapping
+
+from resty.middlewares import BaseResponseMiddleware
+from resty.types import Response
+from resty.constants import STATUS_ERRORS
+from resty.exceptions import HTTPError
+
+
+class StatusCheckingMiddleware(BaseResponseMiddleware):
+ def __init__(
+ self,
+ errors: Mapping[int, type[Exception]] = None,
+ default_error: type[Exception] = HTTPError,
+ ):
+ self._errors = errors or STATUS_ERRORS
+ self._default_error = default_error
+
+ @staticmethod
+ def _check_status(actual: int, expected: int | Container[int] = 200) -> bool:
+ if isinstance(expected, Container):
+ return actual in expected
+ return actual == expected
+
+ def _raise_error(self, status: int, *args):
+ exc = self._errors.get(status, self._default_error)
+
+ try:
+ raise exc(*args)
+ except TypeError:
+ raise exc()
+
+ async def __call__(self, response: Response, **kwargs):
+ actual_status = response.status
+ expected_status = kwargs.pop(
+ "expected_status",
+ {200, 201},
+ )
+
+ check_status = kwargs.pop("check_status", True)
+
+ if check_status and not self._check_status(actual_status, expected_status):
+ self._raise_error(actual_status, response)
diff --git a/resty/middlewares/types.py b/resty/middlewares/types.py
index ca78451..b7ef4c7 100644
--- a/resty/middlewares/types.py
+++ b/resty/middlewares/types.py
@@ -1,11 +1,47 @@
-from abc import ABC
+from typing import Iterable
+from contextlib import contextmanager
+from abc import ABC, abstractmethod
-from resty.types import BaseResponseMiddleware, BaseRequestMiddleware
+from resty.types import Request, Response
-class BasePaginationMiddleware(BaseRequestMiddleware, BaseResponseMiddleware, ABC):
- pass
+class BaseMiddleware(ABC):
+ @abstractmethod
+ def __call__(self, *args, **kwargs): ...
-class BaseFilterMiddleware(BaseRequestMiddleware, ABC):
- pass
+class BaseRequestMiddleware(BaseMiddleware, ABC):
+ @abstractmethod
+ def __call__(self, request: Request, **kwargs): ...
+
+
+class BaseResponseMiddleware(BaseMiddleware, ABC):
+ @abstractmethod
+ def __call__(self, response: Response, **kwargs): ...
+
+
+class BaseMiddlewareManager(ABC):
+ @property
+ @abstractmethod
+ def middlewares(self) -> Iterable[BaseMiddleware]: ...
+
+ @abstractmethod
+ def add_middleware(self, middleware: BaseMiddleware): ...
+
+ @abstractmethod
+ def add_middlewares(self, *middlewares: BaseMiddleware): ...
+
+ @abstractmethod
+ def remove_middleware(self, middleware: BaseMiddleware): ...
+
+ @abstractmethod
+ def remove_middlewares(self, *middlewares: BaseMiddleware): ...
+
+ @contextmanager
+ @abstractmethod
+ def middleware(self, *middlewares: BaseMiddleware): ...
+
+ @abstractmethod
+ async def __call__(
+ self, *args, base: type[BaseMiddleware] = BaseMiddleware, **kwargs
+ ): ...
diff --git a/resty/serializers/__init__.py b/resty/serializers/__init__.py
index 0dc4093..15743fe 100644
--- a/resty/serializers/__init__.py
+++ b/resty/serializers/__init__.py
@@ -1,3 +1,7 @@
-from resty.serializers.serializer import Serializer
+from resty.serializers.types import BaseSerializer
+from resty.serializers.serializers import Serializer
-__all__ = ["Serializer"]
+__all__ = [
+ "BaseSerializer",
+ "Serializer",
+]
diff --git a/resty/serializers/serializer.py b/resty/serializers/serializer.py
deleted file mode 100644
index 1ad9843..0000000
--- a/resty/serializers/serializer.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import inspect
-
-from pydantic import BaseModel
-
-from resty.enums import Endpoint
-from resty.types import BaseSerializer
-
-
-class Serializer(BaseSerializer):
-
- @classmethod
- def get_schema(cls, **context) -> type[BaseModel]:
- schema = context.get("schema")
-
- if inspect.isclass(schema) and issubclass(schema, BaseModel):
- return schema
-
- endpoint = context.get("endpoint")
- schema = cls.schema
-
- if cls.schemas is not None:
- schema = cls.schemas.get(
- endpoint,
- cls.schema or cls.schemas.get(Endpoint.BASE)
- )
-
- if schema is None:
- raise TypeError(f"Schema should be specified for {endpoint}")
-
- return schema
-
- @classmethod
- def serialize(cls, obj: BaseModel, **context) -> dict:
- schema = cls.get_schema(**context)
- if not isinstance(obj, schema):
- raise TypeError("Object must be of type {}".format(schema))
- return schema.model_dump(obj)
-
- @classmethod
- def deserialize(cls, data: dict, **context) -> BaseModel:
- schema = cls.get_schema(**context)
- return schema.model_validate(data)
-
- @classmethod
- def deserialize_many(cls, data: list[dict], **context) -> list[BaseModel]:
- return [cls.deserialize(dataset, **context) for dataset in data]
diff --git a/resty/serializers/serializers.py b/resty/serializers/serializers.py
new file mode 100644
index 0000000..4f033c5
--- /dev/null
+++ b/resty/serializers/serializers.py
@@ -0,0 +1,26 @@
+from typing import Iterable, Mapping
+
+from resty.serializers.types import BaseSerializer
+from resty.types import Schema
+
+
+class Serializer(BaseSerializer):
+ @classmethod
+ def serialize(cls, obj: Schema, **kwargs) -> Mapping:
+ return obj.model_dump()
+
+ @classmethod
+ def serialize_many(cls, objs: Iterable[Schema], **kwargs) -> Iterable:
+ return tuple(cls.serialize(obj=obj, **kwargs) for obj in objs)
+
+ @classmethod
+ def deserialize[T: Schema](cls, schema: type[T], data: Mapping, **kwargs) -> T:
+ return schema.model_validate(data)
+
+ @classmethod
+ def deserialize_many[
+ T: Schema
+ ](cls, schema: type[T], data: Iterable, **kwargs,) -> Iterable[T]:
+ return tuple(
+ cls.deserialize(schema=schema, data=dataset, **kwargs) for dataset in data
+ )
diff --git a/resty/serializers/types.py b/resty/serializers/types.py
new file mode 100644
index 0000000..e519728
--- /dev/null
+++ b/resty/serializers/types.py
@@ -0,0 +1,40 @@
+from abc import ABC, abstractmethod
+from typing import Mapping, Iterable
+
+from resty.types import Schema
+
+
+class BaseSerializer(ABC):
+ @classmethod
+ @abstractmethod
+ def serialize(
+ cls,
+ obj: Schema,
+ **kwargs,
+ ) -> Mapping: ...
+
+ @classmethod
+ @abstractmethod
+ def serialize_many(
+ cls,
+ objs: Iterable[Schema],
+ **kwargs,
+ ) -> Iterable: ...
+
+ @classmethod
+ @abstractmethod
+ def deserialize[T: Schema](
+ cls,
+ schema: type[T],
+ data: Mapping,
+ **kwargs,
+ ) -> T: ...
+
+ @classmethod
+ @abstractmethod
+ def deserialize_many[T: Schema](
+ cls,
+ schema: type[T],
+ data: Iterable,
+ **kwargs,
+ ) -> Iterable[T]: ...
diff --git a/resty/types.py b/resty/types.py
index 2e3eaf5..e3567db 100644
--- a/resty/types.py
+++ b/resty/types.py
@@ -1,110 +1,33 @@
-from abc import ABC, abstractmethod
-from typing import Iterable
from dataclasses import dataclass, field
+from typing import Mapping, Iterable
from pydantic import BaseModel
-from resty.enums import Endpoint, Field
-from resty.enums import Method
+from resty.enums import Method, Endpoint
+
+Schema = BaseModel
@dataclass
class Request:
url: str
method: Method
- data: dict = None
- json: dict = None
+ endpoint: Endpoint = None
+ data: Mapping | Iterable = None
+ json: Mapping | Iterable = None
timeout: int | None = None
params: dict = field(default_factory=dict)
headers: dict = field(default_factory=dict)
cookies: dict = field(default_factory=dict)
redirects: bool = False
+ middleware_options: dict = field(default_factory=dict)
@dataclass
class Response:
request: Request
status: int
- data: list | dict = None
-
-
-class BaseMiddleware(ABC):
- pass
-
-
-class BaseRequestMiddleware(BaseMiddleware):
- @abstractmethod
- async def handle_request(self, request: Request, **context): ...
-
-
-class BaseResponseMiddleware(BaseMiddleware):
- @abstractmethod
- async def handle_response(self, response: Response, **context): ...
-
-
-class BaseMiddlewareManager(ABC):
- @abstractmethod
- async def call_request_middlewares(self, request: Request, **context): ...
-
- @abstractmethod
- async def call_response_middlewares(self, response: Response, **context): ...
-
- @abstractmethod
- def add_middleware(self, middleware: BaseMiddleware): ...
-
-
-class BaseRESTClient(ABC):
- @abstractmethod
- def add_middleware(self, middleware: BaseMiddleware): ...
-
- @abstractmethod
- async def request(self, request: Request, **context) -> Response: ...
-
-
-class BaseSerializer:
- schema: type[BaseModel] = None
- schemas: dict[Endpoint, type[BaseModel]] = None
-
- @classmethod
- @abstractmethod
- def get_schema(cls, **context) -> type[BaseModel]: ...
-
- @classmethod
- @abstractmethod
- def serialize(cls, obj: BaseModel, **context) -> dict: ...
-
- @classmethod
- @abstractmethod
- def deserialize(cls, data: dict, **context) -> BaseModel: ...
-
- @classmethod
- @abstractmethod
- def deserialize_many(cls, data: list[dict], **context) -> list[BaseModel]: ...
-
-
-class BaseManager:
- serializer: type[BaseSerializer]
- endpoints: dict[Endpoint, str]
- fields: dict[Field, str]
-
- @classmethod
- @abstractmethod
- async def create(
- cls, client: BaseRESTClient, obj: BaseModel, **kwargs
- ) -> BaseModel: ...
-
- @classmethod
- @abstractmethod
- async def read(cls, client: BaseRESTClient, **kwargs) -> Iterable[BaseModel]: ...
-
- @classmethod
- @abstractmethod
- async def read_one(cls, client: BaseRESTClient, pk: any, **kwargs) -> BaseModel: ...
-
- @classmethod
- @abstractmethod
- async def update(cls, client: BaseRESTClient, obj: BaseModel, **kwargs) -> None: ...
-
- @classmethod
- @abstractmethod
- async def delete(cls, client: BaseRESTClient, pk: any, **kwargs) -> None: ...
+ content: bytes = None
+ text: str = None
+ json: list | dict = None
+ middleware_options: dict = field(default_factory=dict)
diff --git a/tests/clients/conftest.py b/tests/clients/conftest.py
new file mode 100644
index 0000000..bf2e46b
--- /dev/null
+++ b/tests/clients/conftest.py
@@ -0,0 +1,14 @@
+import httpx
+
+
+class HTTPXAsyncClientMock(httpx.AsyncClient): # pragma: nocover
+ def __init__(self, response=None, error: Exception = None):
+ self.response = response
+ self.error = error
+ super().__init__()
+
+ async def request(self, *args, **kwargs):
+ if self.error:
+ raise self.error
+
+ return self.response
diff --git a/tests/clients/test_httpx_client.py b/tests/clients/test_httpx_client.py
new file mode 100644
index 0000000..a552e93
--- /dev/null
+++ b/tests/clients/test_httpx_client.py
@@ -0,0 +1,70 @@
+import pytest
+import httpx
+
+from resty.enums import Method
+from resty.types import Request, Response
+from resty.clients.httpx import RESTClient
+from resty.middlewares import BaseRequestMiddleware, BaseResponseMiddleware
+from resty.exceptions import ConnectError
+
+from tests.clients.conftest import HTTPXAsyncClientMock
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "base",
+ [
+ BaseRequestMiddleware,
+ BaseResponseMiddleware,
+ ],
+)
+async def test_middlewares_calling(
+ base,
+):
+ class Mid(base):
+
+ def __init__(self):
+ self.called = False
+
+ async def __call__(self, request: Request, **kwargs):
+ self.called = True
+
+ mid = Mid()
+
+ client = RESTClient(HTTPXAsyncClientMock(httpx.Response(status_code=200)))
+ client.middlewares.add_middleware(mid)
+
+ await client.request(
+ Request(
+ url="test",
+ method=Method.GET,
+ )
+ )
+
+ assert mid.called
+
+
+@pytest.mark.asyncio
+async def test_request():
+ client = RESTClient(
+ HTTPXAsyncClientMock(httpx.Response(status_code=123)), check_status=False
+ )
+ print(client.middlewares.middlewares)
+
+ response = await client.request(Request(url="test", method=Method.GET))
+
+ assert isinstance(response, Response)
+ assert response.status == 123
+
+
+@pytest.mark.asyncio
+async def test_connection_error():
+ client = RESTClient(
+ HTTPXAsyncClientMock(
+ httpx.Response(status_code=123), error=httpx.ConnectError("test")
+ ),
+ check_status=False,
+ )
+
+ with pytest.raises(ConnectError):
+ await client.request(Request(url="test", method=Method.GET))
diff --git a/tests/managers/conftest.py b/tests/managers/conftest.py
new file mode 100644
index 0000000..e8c9477
--- /dev/null
+++ b/tests/managers/conftest.py
@@ -0,0 +1,26 @@
+import pytest
+
+from resty.clients import BaseRESTClient
+from resty.types import Request, Response
+
+
+class RESTClientMock(BaseRESTClient): # pragma: nocover
+ def __init__(self, response: Response = None, **expected):
+ self.response = response
+ self.expected = expected
+
+ async def request(self, request: Request) -> Response:
+ for key, value in self.expected.items():
+ assert getattr(request, key) == value
+
+ return self.response
+
+
+@pytest.fixture
+def client(request): # pragma: nocover
+ response, expected = request.param
+
+ return RESTClientMock(
+ response=response,
+ **expected,
+ )
diff --git a/tests/managers/test_manager.py b/tests/managers/test_manager.py
new file mode 100644
index 0000000..5a83fdb
--- /dev/null
+++ b/tests/managers/test_manager.py
@@ -0,0 +1,238 @@
+import pytest
+
+from resty.enums import Method, Endpoint, Field
+from resty.managers import Manager
+from resty.types import Response, Request, Schema
+
+from tests.managers.conftest import RESTClientMock
+
+
+class UserCreate(Schema):
+ username: str
+
+
+class UserRead(Schema):
+ id: int
+ username: str
+
+
+class UserUpdate(Schema):
+ id: int
+ username: str
+
+
+class UserManager(Manager):
+ endpoints = {
+ Endpoint.CREATE: "users/",
+ Endpoint.READ: "users/",
+ Endpoint.READ_ONE: "users/{pk}",
+ Endpoint.UPDATE: "users/{pk}",
+ Endpoint.DELETE: "users/{pk}",
+
+ }
+ fields = {
+ Field.PRIMARY: 'id',
+ }
+
+
+class UserManagerForURLBuilding(Manager):
+ endpoints = {
+ Endpoint.READ_ONE: "users/{pk}/{abc}",
+ }
+ fields = {
+ Field.PRIMARY: 'id',
+ }
+
+
+class ManagerWithoutSerializer(Manager):
+ serializer_class = None
+
+
+class ManagerWithInvalidSerializer(Manager):
+ serializer_class = 123
+
+
+class ManagerWithUnspecifiedMethods(Manager):
+ methods = {}
+
+
+class ManagerWithUnspecifiedFields(Manager):
+ pass
+
+
+class ManagerWithPkField(Manager):
+ fields = {
+ Field.PRIMARY: 'id',
+ }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client, obj", [
+ (RESTClientMock(json={"username": "test"}, method=Method.POST), UserCreate(username="test")),
+ (RESTClientMock(json={"username": "321"}, method=Method.POST), UserCreate(username="321")),
+ (RESTClientMock(json={"username": "test"}, method=Method.POST), {"username": "test"}),
+])
+async def test_create(client, obj):
+ manager = UserManager()
+
+ await manager.create(client=client, obj=obj)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("data", [
+ ({"username": "test", "id": 1}, {"username": "test123", "id": 2}, {"username": "test321", "id": 3}),
+])
+async def test_read(data):
+ client = RESTClientMock(
+ response=Response(
+ request=Request(url="", method=Method.GET), status=200, json=data, ),
+ method=Method.GET
+ )
+ manager = UserManager()
+
+ objs = await manager.read(client=client, response_type=UserRead)
+
+ assert tuple(obj.model_dump() for obj in objs) == data
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client, obj", [
+ (
+ RESTClientMock(
+ response=Response(
+ Request("", Method.GET),
+ status=200,
+ json={"username": "test", "id": 123}
+ ), method=Method.GET),
+ UserRead(username="test", id=123)
+ ),
+
+])
+async def test_read_one(client, obj):
+ manager = UserManager()
+
+ assert await manager.read_one(client=client, obj_or_pk=123, response_type=UserRead) == obj
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client, obj", [
+ (
+ RESTClientMock(
+ response=Response(
+ Request("", Method.GET),
+ status=200,
+ json={"username": "test", "id": 123}
+ ), method=Method.PATCH, url="users/123"),
+ UserUpdate(username="test", id=123)
+ ),
+
+])
+async def test_update(client, obj):
+ manager = UserManager()
+
+ await manager.update(client=client, obj=obj)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client, pk", [
+ (
+ RESTClientMock(url="users/123", method=Method.DELETE),
+ 123
+ ),
+ (
+ RESTClientMock(url="users/321", method=Method.DELETE),
+ UserRead(username="test", id=321)
+ )
+
+])
+async def test_delete(client, pk):
+ manager = UserManager()
+
+ await manager.delete(client, obj_or_pk=pk)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client, pk, abc", [
+ (RESTClientMock(url="users/123/hello"), 123, "hello"),
+ (RESTClientMock(url="users/321/world"), 321, "world"),
+])
+async def test_url_building(client, pk, abc):
+ manager = UserManagerForURLBuilding()
+
+ await manager.read_one(client=client, obj_or_pk=pk, abc=abc)
+
+
+@pytest.mark.parametrize("manager", [
+ ManagerWithoutSerializer,
+ ManagerWithInvalidSerializer,
+])
+def test_invalid_or_unspec_serializer(manager):
+ manager = manager()
+
+ with pytest.raises(RuntimeError):
+ manager.get_serializer()
+
+
+def test_get_unspec_method():
+ manager = ManagerWithUnspecifiedMethods()
+
+ with pytest.raises(RuntimeError):
+ manager.get_method(Endpoint.CREATE)
+
+
+def test_get_unspec_field():
+ manager = ManagerWithUnspecifiedFields()
+
+ with pytest.raises(RuntimeError):
+ manager.get_field(Field.PRIMARY)
+
+
+def test_get_field():
+ manager = ManagerWithPkField()
+
+ assert manager.get_field(Field.PRIMARY) == 'id'
+
+
+@pytest.mark.parametrize("obj, pk", [
+ ({'id': "test"}, "test"),
+ (UserRead(id=321, username='123'), 321),
+])
+def test_get_pk(obj, pk):
+ manager = ManagerWithPkField()
+
+ assert manager.get_pk(obj) == pk
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("url", [
+ "test",
+])
+async def test_passing_url(url):
+ client = RESTClientMock(url=url)
+ manager = UserManagerForURLBuilding()
+
+ await manager.delete(client, obj_or_pk=1, url=url)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("data, response_type, result", [
+ ({"username": "test"}, dict, {"username": "test"}),
+ (("username", "test"), list, ["username", "test"]),
+ ({"username": "test", "id": 123}, lambda r: dict.keys(r.json), dict.keys({"username": "test", "id": 123})),
+ ({"username": "test", "id": 123}, lambda r, t: dict.keys(r), {"username": "test", "id": 123})
+])
+async def test_response_type(data, response_type, result):
+ client = RESTClientMock(response=Response(
+ request=Request(
+ url="",
+ method=Method.GET,
+ ),
+ status=200,
+ json=data
+ ))
+
+ manager = UserManager()
+
+ response = await manager.read(client=client, response_type=response_type, )
+
+ assert response == result
diff --git a/tests/managers/test_url_builder.py b/tests/managers/test_url_builder.py
new file mode 100644
index 0000000..20ad907
--- /dev/null
+++ b/tests/managers/test_url_builder.py
@@ -0,0 +1,26 @@
+import pytest
+
+from resty.exceptions import URLBuildingError
+from resty.managers import URLBuilder
+from resty.enums import Endpoint
+
+
+@pytest.mark.parametrize('endpoints, endpoint, base, kwargs, expected', [
+ ({Endpoint.CREATE: "users/", }, Endpoint.CREATE, "base/", {}, "base/users/"),
+ ({Endpoint.BASE: "users/", }, Endpoint.CREATE, "base/", {}, "base/users/"),
+ ({Endpoint.BASE: "users/", }, Endpoint.CREATE, "base", {}, "base/users/"),
+ ({Endpoint.BASE: "users/", }, Endpoint.CREATE, None, {}, "users/"),
+ ({Endpoint.BASE: "users/{pk}", }, Endpoint.CREATE, None, {"pk": 123}, "users/123"),
+ ({}, Endpoint.CREATE, "base/{pk}", {"pk": 123}, "base/123"),
+ ({}, Endpoint.CREATE, None, {"pk": 123}, ""),
+])
+def test_build(endpoints, endpoint, base, kwargs, expected):
+ builder = URLBuilder()
+
+ assert builder.build(endpoints=endpoints, endpoint=endpoint, base_url=base, **kwargs) == expected
+
+
+def test_build_missing_kwargs():
+ with pytest.raises(URLBuildingError):
+ builder = URLBuilder()
+ builder.build({}, Endpoint.CREATE, "users/{pk}")
diff --git a/resty/ext/__init__.py b/tests/middlewares/conftest.py
similarity index 100%
rename from resty/ext/__init__.py
rename to tests/middlewares/conftest.py
diff --git a/tests/middlewares/test_middleware_manager.py b/tests/middlewares/test_middleware_manager.py
new file mode 100644
index 0000000..0784481
--- /dev/null
+++ b/tests/middlewares/test_middleware_manager.py
@@ -0,0 +1,80 @@
+import pytest
+
+from resty.middlewares import (
+ MiddlewareManager,
+ BaseMiddleware,
+ BaseResponseMiddleware,
+ BaseRequestMiddleware,
+)
+
+
+@pytest.mark.parametrize(
+ "base",
+ [
+ BaseMiddleware,
+ BaseResponseMiddleware,
+ BaseRequestMiddleware,
+ ],
+)
+def test_add_middleware(base):
+ manager = MiddlewareManager()
+
+ class MyMiddleware(base):
+ async def __call__(self, *args, **kwargs): ...
+
+ manager.add_middleware(MyMiddleware())
+
+ assert type(tuple(manager.middlewares)[0]) is MyMiddleware
+
+
+@pytest.mark.parametrize(
+ "base",
+ [
+ BaseMiddleware,
+ BaseResponseMiddleware,
+ BaseRequestMiddleware,
+ ],
+)
+def test_remove_middleware(base):
+ manager = MiddlewareManager()
+
+ class MyMiddleware(base):
+ async def __call__(self, *args, **kwargs): ...
+
+ middleware = MyMiddleware()
+
+ manager.add_middleware(middleware)
+
+ manager.remove_middleware(middleware)
+
+ assert len(tuple(manager.middlewares)) == 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "base",
+ [
+ BaseMiddleware,
+ BaseResponseMiddleware,
+ BaseRequestMiddleware,
+ ],
+)
+async def test_with_middleware(base):
+ manager = MiddlewareManager()
+
+ class MyMiddleware(base):
+ async def __call__(self, *args, **kwargs):
+ assert args[0] == "test"
+
+ with manager.middleware(MyMiddleware()):
+ assert len(tuple(manager.middlewares)) == 1
+
+ await manager("test")
+
+ assert len(tuple(manager.middlewares)) == 0
+
+
+def test_invalid_middleware():
+ with pytest.raises(TypeError):
+ manager = MiddlewareManager()
+ manager.add_middleware("test")
diff --git a/tests/middlewares/test_status_checking_middleware.py b/tests/middlewares/test_status_checking_middleware.py
new file mode 100644
index 0000000..b65a85a
--- /dev/null
+++ b/tests/middlewares/test_status_checking_middleware.py
@@ -0,0 +1,42 @@
+import pytest
+
+from resty.enums import Method
+from resty.middlewares import StatusCheckingMiddleware
+from resty.types import Response, Request
+
+
+class CustomError(Exception):
+ pass
+
+
+class ErrorWithConstructor(Exception):
+ def __init__(self):
+ pass
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "expected_status, actual_status, error",
+ [
+ (200, 404, CustomError),
+ ((200, 201), 401, CustomError),
+ ((200, 204), 403, ErrorWithConstructor),
+ ],
+)
+async def test_invalid_status_check(expected_status, actual_status, error):
+ mid = StatusCheckingMiddleware(errors={actual_status: error})
+ with pytest.raises(error):
+ await mid(
+ response=Response(
+ request=Request(
+ url="",
+ method=Method.GET,
+ ),
+ status=actual_status,
+ content=b"",
+ text="",
+ json={},
+ middleware_options={},
+ ),
+ expected_status=expected_status,
+ )
diff --git a/tests/resty/managers/test_manager.py b/tests/resty/managers/test_manager.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/resty/serializers/test_serializer.py b/tests/resty/serializers/test_serializer.py
deleted file mode 100644
index abfb8fd..0000000
--- a/tests/resty/serializers/test_serializer.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import pytest
-from pydantic import BaseModel
-
-from resty.enums import Endpoint
-from resty.serializers import Serializer
-
-
-class Schema1(BaseModel):
- id: int
- name: str
- private: bool
-
-
-class CreateSchema(BaseModel):
- pass
-
-
-class ReadSchema(BaseModel):
- pass
-
-
-@pytest.mark.parametrize(
- "model, data", [(Schema1, {"id": 1, "name": "test", "private": True})]
-)
-def test_serializer_serialize(model, data):
- class Ser(Serializer):
- schema = model
-
- data = Ser.serialize(model(**data))
-
- assert data == data
-
-
-@pytest.mark.parametrize(
- "model, data", [(Schema1, {"id": 1, "name": "test", "private": True})]
-)
-def test_serializer_deserialize(model, data):
- class Ser(Serializer):
- schema = model
-
- obj = Ser.deserialize(data)
-
- assert obj.model_dump() == data
-
-
-@pytest.mark.parametrize(
- "model, data", [(Schema1, [{"id": 1, "name": "test", "private": True}])]
-)
-def test_serializer_deserialize_many(model, data):
- class Ser(Serializer):
- schema = model
-
- objects = Ser.deserialize_many(data)
-
- assert len(objects) == len(data)
-
-
-@pytest.mark.parametrize(
- "models, context, expected_schema", [
- ({Endpoint.CREATE: CreateSchema}, {"endpoint": Endpoint.CREATE}, CreateSchema),
- ({Endpoint.READ: ReadSchema}, {"endpoint": Endpoint.READ}, ReadSchema),
- ({Endpoint.READ_ONE: Schema1}, {"endpoint": Endpoint.READ_ONE}, Schema1),
- ({Endpoint.BASE: Schema1}, {"endpoint": Endpoint.READ_ONE}, Schema1),
- ({Endpoint.BASE: Schema1}, {"schema": ReadSchema}, ReadSchema),
- ]
-)
-def test_get_schema(models, context, expected_schema):
- class Ser(Serializer):
- schemas = models
-
- assert Ser.get_schema(**context) == expected_schema
-
-
-def test_get_unspecified_schema():
- class Ser(Serializer):
- pass
-
- with pytest.raises(TypeError):
- Ser.get_schema()
diff --git a/tests/serializers/test_serializer.py b/tests/serializers/test_serializer.py
new file mode 100644
index 0000000..6c62429
--- /dev/null
+++ b/tests/serializers/test_serializer.py
@@ -0,0 +1,61 @@
+from resty.types import Schema
+from resty.serializers import Serializer
+
+
+def test_serialize():
+ class MySchema(Schema):
+ id: int
+ name: str
+ private: bool
+
+ serializer = Serializer()
+
+ obj = MySchema(id=1, name="test", private=False)
+
+ assert serializer.serialize(obj) == obj.model_dump()
+
+
+def test_serialize_many():
+ class MySchema(Schema):
+ id: int
+ name: str
+ private: bool
+
+ serializer = Serializer()
+
+ objs = [
+ MySchema(id=1, name="test", private=False),
+ MySchema(id=2, name="test2", private=True),
+ ]
+
+ assert serializer.serialize_many(objs) == tuple(obj.model_dump() for obj in objs)
+
+
+def test_deserialize():
+ class MySchema(Schema):
+ id: int
+ name: str
+ private: bool
+
+ serializer = Serializer()
+
+ obj = MySchema(id=1, name="test", private=False)
+
+ assert serializer.deserialize(MySchema, obj.model_dump()) == obj
+
+
+def test_deserialize_many():
+ class MySchema(Schema):
+ id: int
+ name: str
+ private: bool
+
+ serializer = Serializer()
+
+ objs = (
+ MySchema(id=1, name="test", private=False),
+ MySchema(id=2, name="test2", private=True),
+ )
+ serialized = tuple(obj.model_dump() for obj in objs)
+
+ assert serializer.deserialize_many(MySchema, serialized) == objs