Skip to content

Commit

Permalink
feat: allow binding of dependencies from within lifespans (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb authored Apr 22, 2022
1 parent 26f983f commit b15dcc1
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 40 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "xpresso"
version = "0.38.3"
version = "0.39.0"
description = "A developer centric, performant Python web framework"
authors = ["Adrian Garcia Badaracco <[email protected]>"]
readme = "README.md"
Expand Down
23 changes: 23 additions & 0 deletions tests/test_dependencies/test_dependency_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,29 @@ async def endpoint(app: App) -> Response:
assert resp.status_code == 200


def test_bind_from_lifespan() -> None:
class Foo:
pass

class Bar(Foo):
pass

@asynccontextmanager
async def lifespan(app: App) -> AsyncIterator[None]:
with app.dependency_overrides as overrides:
overrides[Foo] = Bar
yield

async def endpoint(foo: Foo) -> None:
assert isinstance(foo, Bar)

app = App([Path("/", get=endpoint)], lifespan=lifespan)

with TestClient(app=app) as client:
resp = client.get("/")
assert resp.status_code == 200


def test_default_scope_for_autowired_deps() -> None:
"""Child dependencies of an "endpoint" scoped dep (often the endpoint itself)
should have a "connection" scope so that they are compatible with the default scope of Depends().
Expand Down
99 changes: 63 additions & 36 deletions xpresso/applications.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import contextlib
import functools
import inspect
import typing

import starlette.types
from di.api.dependencies import DependantBase
from di.api.solved import SolvedDependant
from di.container import Container, ContainerState, bind_by_type
from di.dependant import Dependant, JoinedDependant
from di.executors import AsyncExecutor
Expand Down Expand Up @@ -94,43 +96,64 @@ def __init__(

@contextlib.asynccontextmanager
async def lifespan_ctx(*_: typing.Any) -> typing.AsyncIterator[None]:
lifespans, lifespan_deps = self._setup()
# first run setup to find all routes, their lifespans and callbacks to solve them
lifespans, prepare_cbs = self._setup()
self._setup_run = True
original_container = self.container
placeholder = Dependant(lambda: None, scope="app")
dep: "DependantBase[typing.Any]"
executor = AsyncExecutor()
async with self._container_state.enter_scope(
"app"
) as self._container_state:
# now solve and execute all lifespans
# lifespans can get a reference to the container and create/replace binds
# so it is important that we execute them before solving the endpoints
if lifespan is not None:
dep = Dependant(
_wrap_lifespan_as_async_generator(lifespan), scope="app"
)
else:

async def null_lifespan() -> typing.AsyncIterator[None]:
yield

dep = Dependant(null_lifespan, scope="app")
dep = placeholder
solved = self.container.solve(
JoinedDependant(
dep,
siblings=[
*(
Dependant(lifespan, scope="app")
for lifespan in lifespans
),
*lifespan_deps,
Dependant(lifespan, scope="app") for lifespan in lifespans
],
),
scopes=Scopes,
)
try:
await self.container.execute_async(
solved, executor=AsyncExecutor(), state=self._container_state
solved, executor=executor, state=self._container_state
)
# now we can solve the endpoints
# we accumulate any endpoint dependencies that are part of the "app"
# scope and execute them immediately so that their setup and teardown
# run in the same task
# (the server will create separate tasks for the lifespan and endpoint,
# if we run app scoped dependencies lazily the setup would run in a different
# scope than the teardown)
lifespan_deps: "typing.List[DependantBase[typing.Any]]" = []
for cb in prepare_cbs:
prepared = cb()
lifespan_deps.extend(
d for d in prepared.dag if d.scope == "app"
)
await self.container.execute_async(
self.container.solve(
JoinedDependant(
placeholder,
siblings=lifespan_deps,
),
scopes=Scopes,
),
executor,
state=self._container_state,
)
yield
finally:
# make this cm reentrant for testing purposes
self.container = original_container
# make this context manager reentrant for testing purposes
self._setup_run = False
self._container_state = ContainerState()

Expand Down Expand Up @@ -189,7 +212,9 @@ async def __call__(
return
# http or websocket
if not self._setup_run:
self._setup()
*_, prepare_callbacks = self._setup()
for cb in prepare_callbacks:
cb()
if "extensions" not in scope:
scope["extensions"] = extensions = {}
else:
Expand All @@ -210,11 +235,13 @@ def _setup(
self,
) -> typing.Tuple[
typing.List[typing.Callable[..., typing.AsyncIterator[None]]],
typing.List[DependantBase[typing.Any]],
typing.List[typing.Callable[[], SolvedDependant[typing.Any]]],
]:
lifespans: typing.List[typing.Callable[..., typing.AsyncIterator[None]]] = []
lifespan_dependants: typing.List[DependantBase[typing.Any]] = []
seen_routers: typing.Set[typing.Any] = set()
lifespans: "typing.List[typing.Callable[..., typing.AsyncIterator[None]]]" = []
seen_routers: "typing.Set[typing.Any]" = set()
prepare_cbs: "typing.List[typing.Callable[[], SolvedDependant[typing.Any]]]" = (
[]
)
for route in visit_routes(
app_type=App, router=self.router, nodes=[self, self.router], path=""
):
Expand All @@ -232,29 +259,29 @@ def _setup(
)
if isinstance(route.route, Path):
for operation in route.route.operations.values():
operation.prepare(
prepare_cbs.append(
functools.partial(
operation.prepare,
dependencies=[
*dependencies,
*route.route.dependencies,
*operation.dependencies,
],
container=self.container,
)
)
elif isinstance(route.route, WebSocketRoute):
prepare_cbs.append(
functools.partial(
route.route.prepare,
dependencies=[
*dependencies,
*route.route.dependencies,
*operation.dependencies,
],
container=self.container,
)
for dep in operation.dependant.get_flat_subdependants():
if dep.scope == "app":
lifespan_dependants.append(dep)
elif isinstance(route.route, WebSocketRoute):
route.route.prepare(
dependencies=[
*dependencies,
*route.route.dependencies,
],
container=self.container,
)
for dep in route.route.dependant.get_flat_subdependants():
if dep.scope == "app":
lifespan_dependants.append(dep)
return lifespans, lifespan_dependants
return lifespans, prepare_cbs

def get_openapi(
self, servers: typing.List[openapi_models.Server]
Expand Down
5 changes: 3 additions & 2 deletions xpresso/routing/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from starlette.datastructures import URLPath
from starlette.requests import HTTPConnection, Request
from starlette.responses import JSONResponse, Response
from starlette.routing import BaseRoute, NoMatchFound, get_name
from starlette.routing import BaseRoute, NoMatchFound, get_name # type: ignore
from starlette.types import ASGIApp, Receive, Scope, Send

import xpresso.openapi.models as openapi_models
Expand Down Expand Up @@ -158,7 +158,7 @@ def prepare(
self,
container: Container,
dependencies: typing.Iterable[DependantBase[typing.Any]],
) -> None:
) -> SolvedDependant[typing.Any]:
self.dependant = container.solve(
JoinedDependant(
EndpointDependant(self.endpoint, sync_to_thread=self._sync_to_thread),
Expand All @@ -178,6 +178,7 @@ def prepare(
response_encoder=self._response_encoder,
response_factory=self._response_factory,
)
return self.dependant

def url_path_for(self, name: str, **path_params: str) -> URLPath:
if path_params:
Expand Down
3 changes: 2 additions & 1 deletion xpresso/routing/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def prepare(
self,
container: Container,
dependencies: typing.Iterable[DependantBase[typing.Any]],
) -> None:
) -> SolvedDependant[typing.Any]:
self.dependant = container.solve(
JoinedDependant(
EndpointDependant(self.endpoint),
Expand All @@ -108,3 +108,4 @@ def prepare(
executor=executor,
container=container,
)
return self.dependant

0 comments on commit b15dcc1

Please sign in to comment.