Skip to content

Commit

Permalink
make cancel() more robust
Browse files Browse the repository at this point in the history
Summary:
I gave a talk at pycon about what not to do with asyncio.  One thing I said was not to break the cancellation contract.   These changes make later.cancel() not break the cancellation contract, in the cases where it is cancelled while attempting to cancel a task.
If the task being cancelled does anything but cancel, like raises some other error, or returns we raise an exception.
I said I would have this out shortly after my talk, well I forgot.

Reviewed By: cooperlees

Differential Revision: D37406583

fbshipit-source-id: e7160598f5c6becb9064d7f12bdd013460d43c7f
  • Loading branch information
fried authored and facebook-github-bot committed Jun 24, 2022
1 parent 9efd44c commit 50bcc37
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 4 deletions.
44 changes: 40 additions & 4 deletions later/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,48 @@ def __init__(self):
async def cancel(fut: asyncio.Future) -> None:
"""
Cancel a future/task and await for it to cancel.
This method suppresses the CancelledError
If the fut is already done() this is a no-op
If everything goes well this returns None.
If this coroutine is cancelled, we wait for the passed in argument to cancel
but we will raise the CancelledError as per Cancellation Contract, Unless the task
doesn't cancel correctly then we could raise other exceptions.
If the task raises an exception during cancellation we re-raise it
if the task completes instead of cancelling we raise a InvalidStateError
"""
if fut.done():
return # nothing to do
fut.cancel()
await asyncio.sleep(0) # let loop cycle
with suppress(asyncio.CancelledError):
await fut
exc: Optional[asyncio.CancelledError] = None
while not fut.done():
shielded = asyncio.shield(fut)
try:
await asyncio.wait([shielded])
except asyncio.CancelledError as ex:
exc = ex
finally:
# Insure we handle the exception/value that may exist on the shielded task
# This will prevent errors logged to the asyncio logger
if (
shielded.done()
and not shielded.cancelled()
and not shielded.exception()
):
shielded.result()
if fut.cancelled():
if exc is None:
return
# we were cancelled also so honor the contract
raise exc from None
# Some exception thrown during cancellation
ex = fut.exception()
if ex is not None:
raise ex from None
# fut finished instead of cancelled, wat?
raise asyncio.InvalidStateError(
f"task didn't raise CancelledError on cancel: {fut} had result {fut.result()}"
)


def as_task(func: F) -> F:
Expand Down
88 changes: 88 additions & 0 deletions later/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,94 @@ async def test_cancel_task(self) -> None:
self.assertTrue(task.done())
self.assertTrue(task.cancelled())

async def test_cancel_raises_other_exception(self) -> None:
started = False

@later.as_task
async def _coro():
nonlocal started
started = True
try:
await asyncio.sleep(500)
except asyncio.CancelledError:
raise TypeError

task: asyncio.Task = cast(asyncio.Task, _coro())
await asyncio.sleep(0)
self.assertTrue(started)
with self.assertRaises(TypeError):
await later.cancel(task)
self.assertTrue(task.done())
self.assertFalse(task.cancelled())

async def test_cancel_already_done_task(self) -> None:
started = False

@later.as_task
async def _coro():
nonlocal started
started = True

task: asyncio.Task = cast(asyncio.Task, _coro())
await asyncio.sleep(0)
self.assertTrue(started)
self.assertTrue(task.done())
await later.cancel(task)

async def test_cancel_task_completes(self) -> None:
started = False

@later.as_task
async def _coro():
nonlocal started
started = True
try:
await asyncio.sleep(500)
except asyncio.CancelledError:
return 5

task: asyncio.Task = cast(asyncio.Task, _coro())
await asyncio.sleep(0)
self.assertTrue(started)
with self.assertRaises(asyncio.InvalidStateError):
await later.cancel(task)
self.assertTrue(task.done())
self.assertFalse(task.cancelled())

async def test_cancel_when_cancelled(self) -> None:
started, cancelled = False, False

@later.as_task
async def test():
nonlocal cancelled, started
started = True
try:
await asyncio.sleep(500)
except asyncio.CancelledError:
cancelled = True
await asyncio.sleep(0.5)
raise

# neat :P
cancel_as_task = later.as_task(later.cancel)
otask = cast(asyncio.Task, test()) # task created a scheduled.
await asyncio.sleep(0) # let test start
self.assertTrue(started)
ctask = cast(asyncio.Task, cancel_as_task(otask))
await asyncio.sleep(0) # let the cancel as task start
self.assertFalse(otask.cancelled())
ctask.cancel()
await asyncio.sleep(0) # Insure the cancel was raised in the ctask
self.assertTrue(cancelled)
# Not done yet since the orignal task is sleeping
self.assertFalse(ctask.cancelled())
# we are not cancelled yet, there is a 0.5 sleep in the cancellation flow
with self.assertRaises(asyncio.CancelledError):
# now our cancel must raise a CancelledError as per contract
await ctask
self.assertTrue(ctask.cancelled())
self.assertTrue(otask.cancelled())


class WatcherTests(TestCase):
async def test_empty_watcher(self) -> None:
Expand Down

0 comments on commit 50bcc37

Please sign in to comment.