Skip to content

Commit

Permalink
feat: support AsyncIterator[bytes] instead of requiring BinaryStream class
Browse files Browse the repository at this point in the history

BREAKING_CHANGES: xpresso.BinaryStream is removed
  • Loading branch information
adriangb committed Apr 3, 2022
1 parent 572e449 commit 01276e6
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 59 deletions.
5 changes: 1 addition & 4 deletions docs/tutorial/files.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@ This can be convenient if you know the files are not large.

## As a stream

If you want to read the bytes without buffering to disk or memory, use `xpresso.BinaryStream` as the type:
If you want to read the bytes without buffering to disk or memory, use `AsyncIterator[bytes]` as the type:

```python
--8<-- "docs_src/tutorial/files/tutorial_003.py"
```

!!! note "Implementation detail"
`xpresso.BinaryStream` is just a thin wrapper around `typing.AsyncIterator[bytes]`, but you must use `xpresso.BinaryStream` instead of `typing.AsyncIterator[bytes]` directly, otherwise Xpresso won't know how to build the argument.

## Setting the expected content-type

You can set the media type via the `media_type` parameter to `File()` and enforce it via the `enforce_media_type` parameter:
Expand Down
8 changes: 6 additions & 2 deletions docs_src/tutorial/files/tutorial_003.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from xpresso import App, BinaryStream, FromFile, Path
from typing import AsyncIterator

from xpresso import App, FromFile, Path

async def count_bytes_in_file(data: FromFile[BinaryStream]) -> int:

async def count_bytes_in_file(
data: FromFile[AsyncIterator[bytes]],
) -> int:
size = 0
async for chunk in data:
size += len(chunk)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "xpresso"
version = "0.37.0"
version = "0.38.0"
description = "A developer centric, performant Python web framework"
authors = ["Adrian Garcia Badaracco <[email protected]>"]
readme = "README.md"
Expand Down
18 changes: 10 additions & 8 deletions tests/test_request_bodies/test_file.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any, Dict, Optional
from typing import Any, AsyncIterator, Dict, Optional

import pytest
from starlette.responses import Response
from starlette.testclient import TestClient

from xpresso import App, BinaryStream, File, Path, UploadFile
from xpresso import App, File, Path, UploadFile
from xpresso.bodies import FromFile
from xpresso.typing import Annotated

Expand Down Expand Up @@ -36,7 +36,7 @@ async def endpoint(file: Annotated[UploadFile, File(consume=consume)]) -> Respon


def test_extract_into_stream():
async def endpoint(file: FromFile[BinaryStream]) -> Response:
async def endpoint(file: FromFile[AsyncIterator[bytes]]) -> Response:
data = bytearray()
async for chunk in file:
data.extend(chunk)
Expand All @@ -51,14 +51,14 @@ async def endpoint(file: FromFile[BinaryStream]) -> Response:


def test_read_into_stream():
async def endpoint(file: Annotated[BinaryStream, File(consume=False)]) -> Response:
async def endpoint(
file: Annotated[AsyncIterator[bytes], File(consume=False)]
) -> Response:
...

app = App([Path("/", post=endpoint)])

with pytest.raises(
ValueError, match="consume=False is not supported for BinaryStream"
):
with pytest.raises(ValueError, match="consume=False is not supported for streams"):
with TestClient(app):
pass

Expand Down Expand Up @@ -123,7 +123,9 @@ async def endpoint(
def test_extract_into_stream_empty_file(
data: Optional[bytes],
):
async def endpoint(file: FromFile[Optional[BinaryStream]] = None) -> Response:
async def endpoint(
file: FromFile[Optional[AsyncIterator[bytes]]] = None,
) -> Response:
assert file is None
return Response()

Expand Down
3 changes: 1 addition & 2 deletions xpresso/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Json,
Multipart,
)
from xpresso.datastructures import BinaryStream, UploadFile
from xpresso.datastructures import UploadFile
from xpresso.dependencies import Depends
from xpresso.exception_handlers import ExcHandler
from xpresso.parameters import (
Expand All @@ -40,7 +40,6 @@
from xpresso.websockets import WebSocket

__all__ = (
"BinaryStream",
"ExcHandler",
"Operation",
"Path",
Expand Down
15 changes: 13 additions & 2 deletions xpresso/_utils/pydantic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,25 @@


def model_field_from_param(
param: inspect.Parameter, alias: typing.Optional[str] = None
param: inspect.Parameter,
alias: typing.Optional[str] = None,
arbitrary_types_allowed: bool = False,
) -> ModelField:

Config = BaseConfig
if arbitrary_types_allowed:

class _Config(BaseConfig):
arbitrary_types_allowed = True

Config = _Config

return ModelField.infer(
name=alias or param.name,
value=param.default if param.default is not param.empty else ...,
annotation=param.annotation,
class_validators={},
config=BaseConfig,
config=Config,
)


Expand Down
50 changes: 32 additions & 18 deletions xpresso/binders/_binders/file_body.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import collections.abc
import enum
import inspect
import typing

from pydantic.fields import ModelField
from starlette.datastructures import UploadFile
from starlette.requests import HTTPConnection, Request

from xpresso._utils.pydantic_utils import model_field_from_param
Expand All @@ -14,10 +17,27 @@
SupportsExtractor,
SupportsOpenAPI,
)
from xpresso.datastructures import BinaryStream, UploadFile
from xpresso.openapi import models as openapi_models
from xpresso.openapi._utils import parse_examples
from xpresso.typing import Some


class FileType(enum.Enum):
    """The form in which a file request body is handed to the endpoint."""

    bytes = 1  # fully buffered into memory as ``bytes``
    uploadfile = 2  # spooled to an ``UploadFile`` (disk-backed for large bodies)
    stream = 3  # consumed lazily as an async iterator of byte chunks


STREAM_TYPES = (typing.AsyncIterator, typing.AsyncGenerator, typing.AsyncIterable, collections.abc.AsyncGenerator, collections.abc.AsyncIterable, collections.abc.AsyncIterator) # type: ignore


def get_file_type(field: ModelField) -> FileType:
    """Classify a file request-body field by its annotated target type.

    Returns the matching ``FileType`` for ``bytes``, ``UploadFile``
    subclasses, or any of the async-stream aliases in ``STREAM_TYPES``.

    Raises:
        TypeError: if the field's type is none of the supported targets.
    """
    if field.type_ is bytes:
        return FileType.bytes
    if inspect.isclass(field.type_) and issubclass(field.type_, UploadFile):
        return FileType.uploadfile
    if field.type_ in STREAM_TYPES:  # type: ignore
        return FileType.stream
    # typing constructs (e.g. unions, parameterized generics) may not expose
    # __name__; fall back to repr() so building the error can't itself raise.
    type_name = getattr(field.type_, "__name__", repr(field.type_))
    raise TypeError(f"Target type {type_name} is not recognized")


async def consume_into_bytes(request: Request) -> bytes:
Expand Down Expand Up @@ -61,8 +81,8 @@ async def read_into_uploadfile(request: Request) -> UploadFile:
return read_into_uploadfile


async def consume_into_stream(request: Request) -> BinaryStream:
return BinaryStream(request.stream())
async def consume_into_stream(request: Request) -> typing.AsyncIterator[bytes]:
    """Expose the raw request body as an async byte iterator.

    Nothing is buffered here: the caller drives the underlying stream
    chunk by chunk, so the body can only be consumed once.
    """
    body_stream = request.stream()
    return body_stream


class Extractor(typing.NamedTuple):
Expand All @@ -82,9 +102,7 @@ async def extract(self, connection: HTTPConnection) -> typing.Any:
if media_type is None and connection.headers.get("content-length", "0") == "0":
return validate_body_field(None, field=self.field, loc=("body",))
self.media_type_validator.validate(media_type)
return validate_body_field(
Some(await self.consumer(connection)), field=self.field, loc=("body",)
)
return await self.consumer(connection)


class ExtractorMarker(typing.NamedTuple):
Expand All @@ -98,27 +116,23 @@ def register_parameter(self, param: inspect.Parameter) -> SupportsExtractor:
else:
media_type_validator = MediaTypeValidator(None)
consumer: typing.Callable[[Request], typing.Any]
field = model_field_from_param(param)
if field.type_ is bytes:
field = model_field_from_param(param, arbitrary_types_allowed=True)
file_type = get_file_type(field)
if file_type is FileType.bytes:
if self.consume:
consumer = consume_into_bytes
else:
consumer = read_into_bytes
elif inspect.isclass(field.type_) and issubclass(field.type_, UploadFile):
elif file_type is FileType.uploadfile:
if self.consume:
consumer = create_consume_into_uploadfile(field.type_)
else:
consumer = create_read_into_uploadfile(field.type_)
elif field.type_ is BinaryStream:
# a stream
else: # stream
if self.consume:
consumer = consume_into_stream
else:
raise ValueError("consume=False is not supported for BinaryStream")
else:
raise TypeError(
f"Target type {field.type_.__name__} is not recognized, you must use `bytes`, `xpresso.UploadFile` or `xpresso.BinaryStream`"
)
raise ValueError("consume=False is not supported for streams")
return Extractor(
media_type_validator=media_type_validator,
consumer=consumer,
Expand Down Expand Up @@ -172,7 +186,7 @@ class OpenAPIMarker(typing.NamedTuple):
include_in_schema: bool

def register_parameter(self, param: inspect.Parameter) -> SupportsOpenAPI:
field = model_field_from_param(param)
field = model_field_from_param(param, arbitrary_types_allowed=True)
examples = parse_examples(self.examples) if self.examples else None
required = field.required is not False
return OpenAPI(
Expand Down
15 changes: 10 additions & 5 deletions xpresso/binders/_binders/form_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class FormFileExtractorMarker(typing.NamedTuple):
alias: typing.Optional[str]

def register_parameter(self, param: inspect.Parameter) -> FormFileExtractor:
field = model_field_from_param(param)
field = model_field_from_param(param, arbitrary_types_allowed=True)
repeated = is_sequence_like(field)
if field.type_ is bytes:

Expand Down Expand Up @@ -208,12 +208,13 @@ class FormFileOpenAPIMarker(typing.NamedTuple):
alias: typing.Optional[str]

def register_parameter(self, param: inspect.Parameter) -> FormFileOpenAPI:
field = model_field_from_param(param, arbitrary_types_allowed=True)
return FormFileOpenAPI(
field_name=self.alias or param.name,
media_type=self.media_type,
format=self.format,
nullable=model_field_from_param(param).allow_none,
repeated=is_sequence_like(model_field_from_param(param)),
nullable=field.allow_none,
repeated=is_sequence_like(field),
)


Expand Down Expand Up @@ -377,8 +378,12 @@ def register_parameter(self, param: inspect.Parameter) -> SupportsOpenAPI:
).register_parameter(field_param)
field_name = field_openapi.field_name
field_openapi_providers[field_name] = field_openapi
field = model_field_from_param(field_param)
if field.required is not False:
if (
model_field_from_param(
field_param, arbitrary_types_allowed=True
).required
is not False
):
required_fields.append(field_name)
examples = parse_examples(self.examples) if self.examples else None
return OpenAPI(
Expand Down
18 changes: 1 addition & 17 deletions xpresso/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, AsyncIterator, Callable, Iterable, Type
from typing import Any, Callable, Iterable, Type

from starlette.datastructures import UploadFile as StarletteUploadFile

Expand All @@ -13,19 +13,3 @@ async def read(self, size: int = -1) -> bytes:
# this is implemented just to fix the return type annotation
# which is always bytes
return await super().read(size) # type: ignore


class BinaryStream(AsyncIterator[bytes]):
    """Async-iterator wrapper around a byte stream.

    Exists so that a raw ``AsyncIterator[bytes]`` can be used as a Pydantic
    model-field type: validation is a deliberate no-op pass-through.
    """

    def __init__(self, stream: AsyncIterator[bytes]) -> None:
        # The wrapped stream that all iteration is delegated to.
        self._source = stream

    @classmethod
    def __get_validators__(cls) -> Iterable[Callable[..., Any]]:
        # Pydantic calls this to collect field validators; an empty iterator
        # means the incoming value is accepted unchanged.
        return iter(())

    def __aiter__(self) -> AsyncIterator[bytes]:
        # Hand iteration straight to the wrapped stream rather than self.
        return self._source.__aiter__()

    async def __anext__(self) -> bytes:  # pragma: no cover
        return await self._source.__anext__()

0 comments on commit 01276e6

Please sign in to comment.