Skip to content

Commit

Permalink
Merge pull request #5 from simonsobs/with_single
Browse files Browse the repository at this point in the history
Improve the `gen` attribute support of the context manager hook
  • Loading branch information
TaiSakuma authored Nov 28, 2023
2 parents cd31b92 + 147ca29 commit 2b452de
Show file tree
Hide file tree
Showing 19 changed files with 1,201 additions and 417 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/type-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ jobs:
pip install --upgrade mypy
- name: Run mypy
run: mypy --show-traceback src/ tests/
run: mypy --show-traceback src/
2 changes: 2 additions & 0 deletions src/apluggy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
'HookimplMarker',
'contextmanager',
'asynccontextmanager',
'stack_gen_ctxs',
]


Expand All @@ -14,3 +15,4 @@

from ._decorator import asynccontextmanager
from ._wrap import PluginManager
from .gen import stack_gen_ctxs
30 changes: 19 additions & 11 deletions src/apluggy/_wrap/awith.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,13 @@


class AWith:
def __init__(self, pm: PluginManager_) -> None:
def __init__(self, pm: PluginManager_, reverse: bool = False) -> None:
self.pm = pm
self.reverse = reverse

def __getattr__(self, name: str) -> Callable[..., AsyncContextManager]:
hook: HookCaller = getattr(self.pm.hook, name)
return _Call(hook)


class AWithReverse:
def __init__(self, pm: PluginManager_) -> None:
self.pm = pm

def __getattr__(self, name: str) -> Callable[..., AsyncContextManager]:
hook: HookCaller = getattr(self.pm.hook, name)
return _Call(hook, reverse=True)
return _Call(hook, reverse=self.reverse)


def _Call(
Expand All @@ -44,4 +36,20 @@ async def call(*args: Any, **kwargs: Any) -> AsyncIterator[list]:

yield yields

# TODO: The following commented out code is an attempt to support
# `asend()` through the `gen` attribute. It only works for
# simple cases. It doesn't work with starlette.lifespan().
# When starlette is shutting down, an exception is raised
# `RuntimeError: generator didn't stop after athrow()`.

# stop = False
# while not stop:
# sent = yield yields
# try:
# yields = await asyncio.gather(
# *[ctx.gen.asend(sent) for ctx in ctxs]
# )
# except StopAsyncIteration:
# stop = True

return call
8 changes: 4 additions & 4 deletions src/apluggy/_wrap/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from pluggy._hooks import _Plugin

from .ahook import AHook
from .awith import AWith, AWithReverse
from .with_ import With, WithReverse
from .awith import AWith
from .with_ import With


class PluginManager(PluginManager_):
Expand Down Expand Up @@ -116,9 +116,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.ahook = AHook(self)
self.with_ = With(self)
self.with_reverse = WithReverse(self)
self.with_reverse = With(self, reverse=True)
self.awith = AWith(self)
self.awith_reverse = AWithReverse(self)
self.awith_reverse = AWith(self, reverse=True)

def register(
self, plugin: _Plugin | Callable[[], _Plugin], name: Optional[str] = None
Expand Down
118 changes: 13 additions & 105 deletions src/apluggy/_wrap/with_.py
Original file line number Diff line number Diff line change
@@ -1,119 +1,27 @@
import contextlib
from collections.abc import Callable, Generator
from dataclasses import dataclass
from typing import Any, Optional
from collections.abc import Callable
from typing import Any

from exceptiongroup import BaseExceptionGroup
from pluggy import HookCaller
from pluggy import PluginManager as PluginManager_

from ..gen import stack_gen_ctxs

GenCtxManager = contextlib._GeneratorContextManager


class With:
def __init__(self, pm: PluginManager_) -> None:
self.pm = pm

def __getattr__(self, name: str) -> Callable[..., GenCtxManager]:
hook: HookCaller = getattr(self.pm.hook, name)
return _Call(hook)


class WithReverse:
def __init__(self, pm: PluginManager_) -> None:
def __init__(self, pm: PluginManager_, reverse: bool = False) -> None:
self.pm = pm
self.reverse = reverse

def __getattr__(self, name: str) -> Callable[..., GenCtxManager]:
def __getattr__(self, name: str) -> Callable[..., GenCtxManager[list]]:
hook: HookCaller = getattr(self.pm.hook, name)
return _Call(hook, reverse=True)


def _Call(
hook: Callable[..., list[GenCtxManager]], reverse: bool = False
) -> Callable[..., GenCtxManager]:
@contextlib.contextmanager
def call(*args: Any, **kwargs: Any) -> Generator[list, Any, list]:
ctxs = hook(*args, **kwargs)
if reverse:
ctxs = list(reversed(ctxs))
with contextlib.ExitStack() as stack:
yields = [stack.enter_context(ctx) for ctx in ctxs]

# yield yields

# This function could end here with the above line uncommented
# for a normal usage of context managers.

# Instead, yield from another generator method that supports
# `send()` and `throw()` and returns the return values of the
# hook implementations.

# TODO: Stop yielding from _support_gen() and simply uncomment
# above `yield yields` as Nextline no longer uses `send()` or
# `throw()`. ExitStack correctly executes the code after the yield
# statement in the reverse order of entering the contexts and
# propagates exceptions from inner contexts to outer contexts.
# _support_gen() also executes the code after the first yield in
# the reverse order. However, it might not be the most sensible
# order if `send()` is used. _support_gen() doesn't propagate the
# exceptions in the same way as ExitStack.

returns = yield from _support_gen(yields, ctxs)
return returns

return call


def _support_gen(yields: list, ctxs: list[GenCtxManager]) -> Generator[list, Any, list]:
'''This generator method
1. supports `send()` through the `gen` attribute
(https://stackoverflow.com/a/68304565/7309855),
2. supports `throw()` through the `gen` attribute,
3. and returns the return values of the hook implementations.
TODO: Support `close()`.
'''

@dataclass
class _Context:
context: GenCtxManager
stop_iteration: Optional[StopIteration] = None

contexts = [_Context(context=ctx) for ctx in ctxs]

while True:
try:
sent = yield yields
except BaseException as thrown:
# gen.throw() has been called.
# Throw the exception to all hook implementations
# that have not exited.
raised: list[BaseException] = []
for c in contexts:
if c.stop_iteration:
continue
try:
c.context.gen.throw(thrown)
except StopIteration:
pass
except BaseException as e:
raised.append(e)
if raised:
raise BaseExceptionGroup('Raised in hook implementations.', raised)
raise

yields = []
for c in reversed(contexts): # close in the reversed order after yielding
y = None
if not c.stop_iteration:
try:
y = c.context.gen.send(sent)
except StopIteration as e:
c.stop_iteration = e
yields.append(y)
def call(*args: Any, **kwargs: Any) -> GenCtxManager[list]:
ctxs = hook(*args, **kwargs)
if self.reverse:
ctxs = list(reversed(ctxs))
return stack_gen_ctxs(ctxs)

if all(c.stop_iteration for c in contexts):
# All hook implementations have exited.
# Collect return values from StopIteration.
returns = [c.stop_iteration and c.stop_iteration.value for c in contexts]
return returns
return call
154 changes: 154 additions & 0 deletions src/apluggy/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import contextlib
import sys
from collections.abc import Generator, Sequence
from typing import Any, TypeVar

T = TypeVar('T')

GenCtxMngr = contextlib._GeneratorContextManager


@contextlib.contextmanager
def stack_gen_ctxs(ctxs: Sequence[GenCtxMngr[T]]) -> Generator[list[T], Any, Any]:
'''Manage multiple context managers with the support of the `gen` attribute.
A context manager can receive values inside the `with` block with multiple
`yield` statements. You can send a value to the context manager with the
`send()` method of the `gen` attribute as explained in
https://stackoverflow.com/a/68304565/7309855.
This function lets you stack multiple context managers each with multiple
`yield` statements and send values to them.
Example: Suppose you have two context managers `ctx0` and `ctx1`:
>>> @contextlib.contextmanager
... def ctx0():
... print('ctx0: enter')
... sent = yield 'ctx0: yield 0'
... print('ctx0: received', sent)
... yield 'ctx0: yield 1'
... print('ctx0: exit')
>>> @contextlib.contextmanager
... def ctx1():
... print('ctx1: enter')
... sent = yield 'ctx1: yield 0'
... print('ctx1: received', sent)
... yield 'ctx1: yield 1'
... print('ctx1: exit')
Stack these context managers with `stack_gen_ctxs()`:
>>> with (stack := stack_gen_ctxs([ctx0(), ctx1()])) as yields:
... print('main: received', yields)
... yields = stack.gen.send('send 0')
... print('main: received', yields)
ctx0: enter
ctx1: enter
main: received ['ctx0: yield 0', 'ctx1: yield 0']
ctx1: received send 0
ctx0: received send 0
main: received ['ctx1: yield 1', 'ctx0: yield 1']
ctx1: exit
ctx0: exit
As the output indicates, the context managers are called in the reverse
order after the first `yield` statement as if they were nested with the
`with` block. In the above example, `ctx1` is the inner context manager and
`ctx0` is the outer context manager.
In addition to the `send()` method, you can also use the `throw()` and `close()`
methods of the `gen` attribute.
An exception will be propagated from an inner context manager to an outer
context manager. The propagation stops if a context manager handles the
exception.
'''

try:
# Append a context manager as it is entered and remove one as it is exited.
entered = list[GenCtxMngr]()

ys = []
for ctx in ctxs:
y = ctx.__enter__()
entered.append(ctx)
ys.append(y)

while True:
sent = None
raised = False # True if an exception is raised at `yield`
broken = False # Used to break `while` loop from inside `for` loop.
try:
sent = yield ys
except BaseException:
raised = True
exc_info = sys.exc_info()
else:
exc_info = (None, None, None)

ys = []

for ctx in list(reversed(entered)): # From the innermost to outwards.
try:
match exc_info[1]:
case val if isinstance(val, GeneratorExit):
ctx.gen.close()
raise exc_info[1].with_traceback(exc_info[2])
case val if isinstance(val, BaseException):
try:
ctx.gen.throw(*exc_info)
except StopIteration: # `ctx` has exited.
entered.remove(ctx)
exc_info = (None, None, None)
case None:
if raised:
# The exception has been handled by an inner
# context manager. However, still exit so as to
# reproduce the behavior of an reference
# implementation with `contextlib.ExitStack`
# when `gen.send()` is not used.
broken = True # Break from the outer `while` loop.
break
try:
y = ctx.gen.send(sent)
ys.append(y)
except StopIteration:
entered.remove(ctx)
case _:
raise NotImplementedError()
except BaseException:
entered.remove(ctx)
exc_info = sys.exc_info()
else:
exc_info = (None, None, None)

if broken: # broke from the inner `for` loop
break

if isinstance(exc_info[1], BaseException):
# An exception is still outstanding after the outermost context manager.
raise exc_info[1].with_traceback(exc_info[2])

if not entered:
break

except BaseException:
exc_info = sys.exc_info()
else:
exc_info = (None, None, None)
finally:
# Exit the remaining context managers from the innermost to the outermost.
while entered:
ctx = entered.pop()
try:
if ctx.__exit__(*exc_info):
# The exception is handled.
exc_info = (None, None, None)
except BaseException: # A new or the same exception is raised.
exc_info = sys.exc_info()

if isinstance(exc_info[1], BaseException):
# An exception is unhandled after all context managers have exited.
raise exc_info[1].with_traceback(exc_info[2])
Loading

0 comments on commit 2b452de

Please sign in to comment.