diff --git a/.style.yapf b/.style.yapf deleted file mode 100644 index 745ec88d70..0000000000 --- a/.style.yapf +++ /dev/null @@ -1,185 +0,0 @@ -[style] -# Align closing bracket with visual indentation. -align_closing_bracket_with_visual_indent=True - -# Allow dictionary keys to exist on multiple lines. For example: -# -# x = { -# ('this is the first element of a tuple', -# 'this is the second element of a tuple'): -# value, -# } -allow_multiline_dictionary_keys=False - -# Allow lambdas to be formatted on more than one line. -allow_multiline_lambdas=False - -# Insert a blank line before a class-level docstring. -blank_line_before_class_docstring=False - -# Insert a blank line before a 'def' or 'class' immediately nested -# within another 'def' or 'class'. For example: -# -# class Foo: -# # <------ this blank line -# def method(): -# ... -blank_line_before_nested_class_or_def=False - -# Do not split consecutive brackets. Only relevant when -# dedent_closing_brackets is set. For example: -# -# call_func_that_takes_a_dict( -# { -# 'key1': 'value1', -# 'key2': 'value2', -# } -# ) -# -# would reformat to: -# -# call_func_that_takes_a_dict({ -# 'key1': 'value1', -# 'key2': 'value2', -# }) -coalesce_brackets=False - -# The column limit. -column_limit=79 - -# Indent width used for line continuations. -continuation_indent_width=4 - -# Put closing brackets on a separate line, dedented, if the bracketed -# expression can't fit in a single line. Applies to all kinds of brackets, -# including function definitions and calls. 
For example: -# -# config = { -# 'key1': 'value1', -# 'key2': 'value2', -# } # <--- this bracket is dedented and on a separate line -# -# time_series = self.remote_client.query_entity_counters( -# entity='dev3246.region1', -# key='dns.query_latency_tcp', -# transform=Transformation.AVERAGE(window=timedelta(seconds=60)), -# start_ts=now()-timedelta(days=3), -# end_ts=now(), -# ) # <--- this bracket is dedented and on a separate line -dedent_closing_brackets=True - -# Place each dictionary entry onto its own line. -each_dict_entry_on_separate_line=True - -# The regex for an i18n comment. The presence of this comment stops -# reformatting of that line, because the comments are required to be -# next to the string they translate. -i18n_comment= - -# The i18n function call names. The presence of this function stops -# reformattting on that line, because the string it has cannot be moved -# away from the i18n comment. -i18n_function_call= - -# Indent the dictionary value if it cannot fit on the same line as the -# dictionary key. For example: -# -# config = { -# 'key1': -# 'value1', -# 'key2': value1 + -# value2, -# } -indent_dictionary_value=True - -# The number of columns to use for indentation. -indent_width=4 - -# Join short lines into one line. E.g., single line 'if' statements. -join_multiple_lines=False - -# Use spaces around default or named assigns. -spaces_around_default_or_named_assign=False - -# Use spaces around the power operator. -spaces_around_power_operator=False - -# The number of spaces required before a trailing comment. -spaces_before_comment=2 - -# Insert a space between the ending comma and closing bracket of a list, -# etc. -space_between_ending_comma_and_closing_bracket=False - -# Split before arguments if the argument list is terminated by a -# comma. -split_arguments_when_comma_terminated=True - -# Set to True to prefer splitting before '&', '|' or '^' rather than -# after. 
-split_before_bitwise_operator=True - -# Split before a dictionary or set generator (comp_for). For example, note -# the split before the 'for': -# -# foo = { -# variable: 'Hello world, have a nice day!' -# for variable in bar if variable != 42 -# } -split_before_dict_set_generator=True - -# If an argument / parameter list is going to be split, then split before -# the first argument. -split_before_first_argument=False - -# Set to True to prefer splitting before 'and' or 'or' rather than -# after. -split_before_logical_operator=True - -# Split named assignments onto individual lines. -split_before_named_assigns=True - -# The penalty for splitting right after the opening bracket. -split_penalty_after_opening_bracket=30 - -# The penalty for splitting the line after a unary operator. -split_penalty_after_unary_operator=10000 - -# The penalty for splitting right before an if expression. -split_penalty_before_if_expr=0 - -# The penalty of splitting the line around the '&', '|', and '^' -# operators. -split_penalty_bitwise_operator=300 - -# The penalty for characters over the column limit. -split_penalty_excess_character=4500 - -# The penalty incurred by adding a line split to the unwrapped line. The -# more line splits added the higher the penalty. -split_penalty_for_added_line_split=30 - -# The penalty of splitting a list of "import as" names. For example: -# -# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1, -# long_argument_2, -# long_argument_3) -# -# would reformat to something like: -# -# from a_very_long_or_indented_module_name_yada_yad import ( -# long_argument_1, long_argument_2, long_argument_3) -split_penalty_import_names=0 - -# The penalty of splitting the line around the 'and' and 'or' -# operators. -split_penalty_logical_operator=0 - -# Use the Tab character for indentation. -use_tabs=False - -# Without this, yapf likes to write things like -# "foo bar {}". -# format(...) -# which is just awful. 
-split_before_dot=True diff --git a/check.sh b/check.sh index 0b66ca227b..16394ccea1 100755 --- a/check.sh +++ b/check.sh @@ -13,7 +13,7 @@ python ./trio/_tools/gen_exports.py --test \ # see https://forum.bors.tech/t/pre-test-and-pre-merge-hooks/322) # autoflake --recursive --in-place . # pyupgrade --py3-plus $(find . -name "*.py") -yapf -rpd setup.py trio \ +black --diff setup.py trio \ || EXIT_STATUS=$? # Run flake8 without pycodestyle and import-related errors @@ -31,7 +31,7 @@ Problems were found by static analysis (listed above). To fix formatting and see remaining errors, run pip install -r test-requirements.txt - yapf -rpi setup.py trio + black setup.py trio ./check.sh in your local checkout. diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index 8c6a734100..d1723c7fc5 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -133,7 +133,7 @@ in separate sections below: adding a test to make sure it stays fixed. * :ref:`pull-request-formatting`: If you changed Python code, then did - you run ``yapf -rpi setup.py trio``? (Or for other packages, replace + you run ``black setup.py trio``? (Or for other packages, replace ``trio`` with the package name.) * :ref:`pull-request-release-notes`: If your change affects @@ -285,31 +285,30 @@ of eyes can be helpful when trying to come up with devious tricks. Code formatting ~~~~~~~~~~~~~~~ -Instead of wasting time arguing about code formatting, we use `yapf -`__ to automatically format all our +Instead of wasting time arguing about code formatting, we use `black +`__ to automatically format all our code to a standard style. While you're editing code you can be as sloppy as you like about whitespace; and then before you commit, just run:: - pip install -U yapf - yapf -rpi setup.py trio + pip install -U black + black setup.py trio to fix it up. (And don't worry if you forget – when you submit a pull request then we'll automatically check and remind you.) 
Hopefully this will let you focus on more important style issues like choosing good names, writing useful comments, and making sure your docstrings are -nicely formatted. (Yapf doesn't reformat comments or docstrings.) +nicely formatted. (black doesn't reformat comments or docstrings.) -Very occasionally, yapf will generate really ugly and unreadable -formatting (usually for large literal structures like dicts nested -inside dicts). In these cases, you can add a ``# yapf: disable`` -comment to tell it to leave that particular statement alone. +Very occasionally, you'll want to override black formatting. To do so, +you can add ``# fmt: off`` and ``# fmt: on`` comments. -If you want to see what changes yapf will make, you can use:: - - yapf -rpd setup.py trio +If you want to see what changes black will make, you can use:: - black --diff setup.py trio -(``-d`` displays a diff, versus ``-i`` which fixes files in-place.) +(``--diff`` displays a diff, versus the default mode which fixes files +in-place.) ..
_pull-request-release-notes: diff --git a/pyproject.toml b/pyproject.toml index 6100edcba8..57d9c23f42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,7 @@ +[tool.black] +target-version = ['py36'] + + [tool.towncrier] # Usage: # - PRs should drop a file like "issuenumber.feature" in newsfragments diff --git a/setup.py b/setup.py index 852fa616b2..f76c36e378 100644 --- a/setup.py +++ b/setup.py @@ -89,7 +89,7 @@ # cffi 1.14 fixes memory leak inside ffi.getwinerror() # cffi is required on Windows, except on PyPy where it is built-in "cffi>=1.14; os_name == 'nt' and implementation_name != 'pypy'", - "contextvars>=2.1; python_version < '3.7'" + "contextvars>=2.1; python_version < '3.7'", ], # This means, just install *everything* you see under trio/, even if it # doesn't look like a source file, so long as it appears in MANIFEST.in: diff --git a/test-requirements.in b/test-requirements.in index f03409d484..987567acc5 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -8,7 +8,7 @@ pylint # for pylint finding all symbols tests jedi # for jedi code completion tests # Tools -yapf ==0.30.0 # formatting +black; implementation_name == "cpython" flake8 astor # code generation diff --git a/test-requirements.txt b/test-requirements.txt index f32a41c163..cb75bfb8ac 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -4,18 +4,24 @@ # # pip-compile --output-file test-requirements.txt test-requirements.in # +appdirs==1.4.4 # via black +appnope==0.1.0 # via ipython astor==0.8.1 # via -r test-requirements.in astroid==2.4.1 # via pylint async-generator==1.10 # via -r test-requirements.in -attrs==19.3.0 # via -r test-requirements.in, outcome, pytest +attrs==19.3.0 # via -r test-requirements.in, black, outcome, pytest backcall==0.1.0 # via ipython +black==19.10b0 ; implementation_name == "cpython" # via -r test-requirements.in cffi==1.14.0 # via cryptography +click==7.1.2 # via black +contextvars==2.4 ; python_version < "3.7" # via -r 
test-requirements.in, sniffio coverage==5.1 # via pytest-cov cryptography==2.9.2 # via pyopenssl, trustme decorator==4.4.2 # via ipython, traitlets flake8==3.8.1 # via -r test-requirements.in idna==2.9 # via -r test-requirements.in, trustme -immutables==0.14 # via -r test-requirements.in +immutables==0.14 # via -r test-requirements.in, contextvars +importlib-metadata==1.6.0 # via flake8, pluggy, pytest ipython-genutils==0.2.0 # via traitlets ipython==7.14.0 # via -r test-requirements.in isort==4.3.21 # via pylint @@ -26,6 +32,7 @@ more-itertools==8.3.0 # via pytest outcome==1.0.1 # via -r test-requirements.in packaging==20.4 # via pytest parso==0.7.0 # via jedi +pathspec==0.8.0 # via black pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython pluggy==0.13.1 # via pytest @@ -41,12 +48,14 @@ pyopenssl==19.1.0 # via -r test-requirements.in pyparsing==2.4.7 # via packaging pytest-cov==2.8.1 # via -r test-requirements.in pytest==5.4.2 # via -r test-requirements.in, pytest-cov +regex==2020.5.14 # via black six==1.15.0 # via astroid, cryptography, packaging, pyopenssl, traitlets sniffio==1.1.0 # via -r test-requirements.in sortedcontainers==2.1.0 # via -r test-requirements.in -toml==0.10.1 # via pylint +toml==0.10.1 # via black, pylint traitlets==4.3.3 # via ipython trustme==0.6.0 # via -r test-requirements.in +typed-ast==1.4.1 ; python_version < "3.8" and implementation_name == "cpython" # via -r test-requirements.in, astroid, black wcwidth==0.1.9 # via prompt-toolkit, pytest wrapt==1.12.1 # via astroid -yapf==0.30.0 # via -r test-requirements.in +zipp==3.1.0 # via importlib-metadata diff --git a/trio/__init__.py b/trio/__init__.py index 30e42da97e..5339d107eb 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -16,19 +16,42 @@ from ._version import __version__ from ._core import ( - TrioInternalError, RunFinishedError, WouldBlock, Cancelled, - BusyResourceError, ClosedResourceError, MultiError, run, open_nursery, - CancelScope, current_effective_deadline, 
TASK_STATUS_IGNORED, current_time, - BrokenResourceError, EndOfChannel, Nursery + TrioInternalError, + RunFinishedError, + WouldBlock, + Cancelled, + BusyResourceError, + ClosedResourceError, + MultiError, + run, + open_nursery, + CancelScope, + current_effective_deadline, + TASK_STATUS_IGNORED, + current_time, + BrokenResourceError, + EndOfChannel, + Nursery, ) from ._timeouts import ( - move_on_at, move_on_after, sleep_forever, sleep_until, sleep, fail_at, - fail_after, TooSlowError + move_on_at, + move_on_after, + sleep_forever, + sleep_until, + sleep, + fail_at, + fail_after, + TooSlowError, ) from ._sync import ( - Event, CapacityLimiter, Semaphore, Lock, StrictFIFOLock, Condition + Event, + CapacityLimiter, + Semaphore, + Lock, + StrictFIFOLock, + Condition, ) from ._threads import BlockingTrioPortal as _BlockingTrioPortal @@ -36,7 +59,9 @@ from ._highlevel_generic import aclose_forcefully, StapledStream from ._channel import ( - open_memory_channel, MemorySendChannel, MemoryReceiveChannel + open_memory_channel, + MemorySendChannel, + MemoryReceiveChannel, ) from ._signals import open_signal_receiver @@ -60,7 +85,9 @@ from ._highlevel_open_unix_stream import open_unix_socket from ._highlevel_ssl_helpers import ( - open_ssl_over_tcp_stream, open_ssl_over_tcp_listeners, serve_ssl_over_tcp + open_ssl_over_tcp_stream, + open_ssl_over_tcp_listeners, + serve_ssl_over_tcp, ) from ._deprecate import TrioDeprecationWarning @@ -71,6 +98,7 @@ from . import abc from . import from_thread from . import to_thread + # Not imported by default, but mentioned here so static analysis tools like # pylint will know that it exists. 
if False: @@ -81,85 +109,57 @@ _deprecate.enable_attribute_deprecations(__name__) __deprecated_attributes__ = { - "ssl": - _deprecate.DeprecatedAttribute( - _deprecated_ssl_reexports, - "0.11.0", - issue=852, - instead=( - "trio.SSLStream, trio.SSLListener, trio.NeedHandshakeError, " - "and the standard library 'ssl' module (minus SSLSocket and " - "wrap_socket())" - ), - ), - "subprocess": - _deprecate.DeprecatedAttribute( - _deprecated_subprocess_reexports, - "0.11.0", - issue=852, - instead=( - "trio.Process and the constants in the standard " - "library 'subprocess' module" - ), - ), - "run_sync_in_worker_thread": - _deprecate.DeprecatedAttribute( - to_thread.run_sync, - "0.12.0", - issue=810, + "ssl": _deprecate.DeprecatedAttribute( + _deprecated_ssl_reexports, + "0.11.0", + issue=852, + instead=( + "trio.SSLStream, trio.SSLListener, trio.NeedHandshakeError, " + "and the standard library 'ssl' module (minus SSLSocket and " + "wrap_socket())" ), - "current_default_worker_thread_limiter": - _deprecate.DeprecatedAttribute( - to_thread.current_default_thread_limiter, - "0.12.0", - issue=810, - ), - "BlockingTrioPortal": - _deprecate.DeprecatedAttribute( - _BlockingTrioPortal, - "0.12.0", - issue=810, - instead=from_thread, + ), + "subprocess": _deprecate.DeprecatedAttribute( + _deprecated_subprocess_reexports, + "0.11.0", + issue=852, + instead=( + "trio.Process and the constants in the standard " + "library 'subprocess' module" ), + ), + "run_sync_in_worker_thread": _deprecate.DeprecatedAttribute( + to_thread.run_sync, "0.12.0", issue=810, + ), + "current_default_worker_thread_limiter": _deprecate.DeprecatedAttribute( + to_thread.current_default_thread_limiter, "0.12.0", issue=810, + ), + "BlockingTrioPortal": _deprecate.DeprecatedAttribute( + _BlockingTrioPortal, "0.12.0", issue=810, instead=from_thread, + ), # NOTE: when you remove this, you should also remove the file # trio/hazmat.py. 
For details on why we have both, see: # # https://github.com/python-trio/trio/pull/1484#issuecomment-622574499 - "hazmat": - _deprecate.DeprecatedAttribute( - lowlevel, - "0.15.0", - issue=476, - instead="trio.lowlevel", - ), + "hazmat": _deprecate.DeprecatedAttribute( + lowlevel, "0.15.0", issue=476, instead="trio.lowlevel", + ), } _deprecate.enable_attribute_deprecations(lowlevel.__name__) lowlevel.__deprecated_attributes__ = { - "wait_socket_readable": - _deprecate.DeprecatedAttribute( - lowlevel.wait_readable, - "0.12.0", - issue=878, - ), - "wait_socket_writable": - _deprecate.DeprecatedAttribute( - lowlevel.wait_writable, - "0.12.0", - issue=878, - ), - "notify_socket_close": - _deprecate.DeprecatedAttribute( - lowlevel.notify_closing, - "0.12.0", - issue=878, - ), - "notify_fd_close": - _deprecate.DeprecatedAttribute( - lowlevel.notify_closing, - "0.12.0", - issue=878, - ), + "wait_socket_readable": _deprecate.DeprecatedAttribute( + lowlevel.wait_readable, "0.12.0", issue=878, + ), + "wait_socket_writable": _deprecate.DeprecatedAttribute( + lowlevel.wait_writable, "0.12.0", issue=878, + ), + "notify_socket_close": _deprecate.DeprecatedAttribute( + lowlevel.notify_closing, "0.12.0", issue=878, + ), + "notify_fd_close": _deprecate.DeprecatedAttribute( + lowlevel.notify_closing, "0.12.0", issue=878, + ), } # Having the public path in .__module__ attributes is important for: @@ -169,6 +169,7 @@ # - pickle # - probably other stuff from ._util import fixup_module_metadata + fixup_module_metadata(__name__, globals()) fixup_module_metadata(lowlevel.__name__, lowlevel.__dict__) fixup_module_metadata(socket.__name__, socket.__dict__) diff --git a/trio/_abc.py b/trio/_abc.py index 504c145baa..e3ccb930ee 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -9,6 +9,7 @@ class Clock(metaclass=ABCMeta): """The interface for custom run loop clocks. """ + __slots__ = () @abstractmethod @@ -63,6 +64,7 @@ class Instrument(metaclass=ABCMeta): of these methods are optional. 
This class serves mostly as documentation. """ + __slots__ = () def before_run(self): @@ -144,12 +146,11 @@ class HostnameResolver(metaclass=ABCMeta): See :func:`trio.socket.set_custom_hostname_resolver`. """ + __slots__ = () @abstractmethod - async def getaddrinfo( - self, host, port, family=0, type=0, proto=0, flags=0 - ): + async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): """A custom implementation of :func:`~trio.socket.getaddrinfo`. Called by :func:`trio.socket.getaddrinfo`. @@ -181,6 +182,7 @@ class SocketFactory(metaclass=ABCMeta): See :func:`trio.socket.set_custom_socket_factory`. """ + @abstractmethod def socket(self, family=None, type=None, proto=None): """Create and return a socket object. @@ -224,6 +226,7 @@ class AsyncResource(metaclass=ABCMeta): ``__aenter__`` and ``__aexit__`` should be adequate for all subclasses. """ + __slots__ = () @abstractmethod @@ -277,6 +280,7 @@ class SendStream(AsyncResource): :class:`SendChannel`. """ + __slots__ = () @abstractmethod @@ -382,6 +386,7 @@ class ReceiveStream(AsyncResource): byte, and the loop automatically exits when reaching end-of-file. """ + __slots__ = () @abstractmethod @@ -433,6 +438,7 @@ class Stream(SendStream, ReceiveStream): step further and implement :class:`HalfCloseableStream`. """ + __slots__ = () @@ -441,6 +447,7 @@ class HalfCloseableStream(Stream): part of the stream without closing the receive part. """ + __slots__ = () @abstractmethod @@ -519,6 +526,7 @@ class Listener(AsyncResource, Generic[T_resource]): or using an ``async with`` block. """ + __slots__ = () @abstractmethod @@ -560,6 +568,7 @@ class SendChannel(AsyncResource, Generic[SendType]): `SendStream`. """ + __slots__ = () @abstractmethod @@ -604,6 +613,7 @@ class ReceiveChannel(AsyncResource, Generic[ReceiveType]): `ReceiveStream`. 
""" + __slots__ = () @abstractmethod diff --git a/trio/_channel.py b/trio/_channel.py index 3ec404a8a2..dac7935c0c 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -70,7 +70,8 @@ def open_memory_channel(max_buffer_size): raise ValueError("max_buffer_size must be >= 0") state = MemoryChannelState(max_buffer_size) return ( - MemorySendChannel._create(state), MemoryReceiveChannel._create(state) + MemorySendChannel._create(state), + MemoryReceiveChannel._create(state), ) @@ -120,10 +121,8 @@ def __attrs_post_init__(self): self._state.open_send_channels += 1 def __repr__(self): - return ( - "".format( - id(self), id(self._state) - ) + return "".format( + id(self), id(self._state) ) def statistics(self): @@ -341,9 +340,7 @@ async def aclose(self): assert not self._state.receive_tasks for task in self._state.send_tasks: task.custom_sleep_data._tasks.remove(task) - trio.lowlevel.reschedule( - task, Error(trio.BrokenResourceError()) - ) + trio.lowlevel.reschedule(task, Error(trio.BrokenResourceError())) self._state.send_tasks.clear() self._state.data.clear() await trio.lowlevel.checkpoint() diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py index 4b3a088d1b..c28b7f4078 100644 --- a/trio/_core/__init__.py +++ b/trio/_core/__init__.py @@ -5,31 +5,59 @@ """ from ._exceptions import ( - TrioInternalError, RunFinishedError, WouldBlock, Cancelled, - BusyResourceError, ClosedResourceError, BrokenResourceError, EndOfChannel + TrioInternalError, + RunFinishedError, + WouldBlock, + Cancelled, + BusyResourceError, + ClosedResourceError, + BrokenResourceError, + EndOfChannel, ) from ._multierror import MultiError from ._ki import ( - enable_ki_protection, disable_ki_protection, currently_ki_protected + enable_ki_protection, + disable_ki_protection, + currently_ki_protected, ) # Imports that always exist from ._run import ( - Task, CancelScope, run, open_nursery, checkpoint, current_task, - current_effective_deadline, checkpoint_if_cancelled, TASK_STATUS_IGNORED, - 
current_statistics, current_trio_token, reschedule, remove_instrument, - add_instrument, current_clock, current_root_task, spawn_system_task, - current_time, wait_all_tasks_blocked, wait_readable, wait_writable, - notify_closing, Nursery + Task, + CancelScope, + run, + open_nursery, + checkpoint, + current_task, + current_effective_deadline, + checkpoint_if_cancelled, + TASK_STATUS_IGNORED, + current_statistics, + current_trio_token, + reschedule, + remove_instrument, + add_instrument, + current_clock, + current_root_task, + spawn_system_task, + current_time, + wait_all_tasks_blocked, + wait_readable, + wait_writable, + notify_closing, + Nursery, ) # Has to come after _run to resolve a circular import from ._traps import ( - cancel_shielded_checkpoint, Abort, wait_task_rescheduled, - temporarily_detach_coroutine_object, permanently_detach_coroutine_object, - reattach_detached_coroutine_object + cancel_shielded_checkpoint, + Abort, + wait_task_rescheduled, + temporarily_detach_coroutine_object, + permanently_detach_coroutine_object, + reattach_detached_coroutine_object, ) from ._entry_queue import TrioToken @@ -42,15 +70,19 @@ # Kqueue imports try: - from ._run import (current_kqueue, monitor_kevent, wait_kevent) + from ._run import current_kqueue, monitor_kevent, wait_kevent except ImportError: pass # Windows imports try: from ._run import ( - monitor_completion_key, current_iocp, register_with_iocp, - wait_overlapped, write_overlapped, readinto_overlapped + monitor_completion_key, + current_iocp, + register_with_iocp, + wait_overlapped, + write_overlapped, + readinto_overlapped, ) except ImportError: pass diff --git a/trio/_core/_entry_queue.py b/trio/_core/_entry_queue.py index 0d29e393b0..791ab8ca6e 100644 --- a/trio/_core/_entry_queue.py +++ b/trio/_core/_entry_queue.py @@ -141,7 +141,7 @@ class TrioToken(metaclass=NoPublicConstructor): """ - __slots__ = ('_reentry_queue',) + __slots__ = ("_reentry_queue",) def __init__(self, reentry_queue): self._reentry_queue 
= reentry_queue @@ -190,6 +190,4 @@ def run_sync_soon(self, sync_fn, *args, idempotent=False): exits.) """ - self._reentry_queue.run_sync_soon( - sync_fn, *args, idempotent=idempotent - ) + self._reentry_queue.run_sync_soon(sync_fn, *args, idempotent=idempotent) diff --git a/trio/_core/_exceptions.py b/trio/_core/_exceptions.py index ef70eb3df5..2754b8c838 100644 --- a/trio/_core/_exceptions.py +++ b/trio/_core/_exceptions.py @@ -65,6 +65,7 @@ class Cancelled(BaseException, metaclass=NoPublicConstructor): everywhere. """ + def __str__(self): return "Cancelled" diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index fe63a6ee0c..9583c7ff4f 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -4,25 +4,31 @@ from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED - +# fmt: off + async def wait_readable(fd): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def wait_writable(fd): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def notify_closing(fd): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) + return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + + +# fmt: on diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 059a8a95d1..ab95d1e30c 100644 --- 
a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -4,46 +4,55 @@ from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED - +# fmt: off + def current_kqueue(): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() + return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def monitor_kevent(ident, filter): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) + return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def wait_kevent(ident, filter, abort_func): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(ident, filter, abort_func) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def wait_readable(fd): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def wait_writable(fd): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def notify_closing(fd): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return 
GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) + return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + + +# fmt: on diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 78dd30db19..d6a5760374 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -4,67 +4,79 @@ from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED - +# fmt: off + async def wait_readable(sock): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def wait_writable(sock): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def notify_closing(handle): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) + return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def register_with_iocp(handle): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) + return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def wait_overlapped(handle, lpOverlapped): 
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped(handle, lpOverlapped) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def write_overlapped(handle, data, file_offset=0): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped(handle, data, file_offset) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def readinto_overlapped(handle, buffer, file_offset=0): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped(handle, buffer, file_offset) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def current_iocp(): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() + return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def monitor_completion_key(): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() + return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + + +# fmt: on diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index 75f61bfdc5..edf46fd741 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -4,7 +4,8 @@ from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED 
- +# fmt: off + def current_statistics(): """Returns an object containing run-loop-level debugging information. @@ -31,9 +32,10 @@ def current_statistics(): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.current_statistics() + return GLOBAL_RUN_CONTEXT.runner.current_statistics() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def current_time(): """Returns the current time according to Trio's internal clock. @@ -47,9 +49,10 @@ def current_time(): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.current_time() + return GLOBAL_RUN_CONTEXT.runner.current_time() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def current_clock(): """Returns the current :class:`~trio.abc.Clock`. @@ -57,9 +60,10 @@ def current_clock(): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.current_clock() + return GLOBAL_RUN_CONTEXT.runner.current_clock() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def current_root_task(): """Returns the current root :class:`Task`. 
@@ -69,9 +73,10 @@ def current_root_task(): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.current_root_task() + return GLOBAL_RUN_CONTEXT.runner.current_root_task() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def reschedule(task, next_send=_NO_SEND): """Reschedule the given task with the given @@ -93,9 +98,10 @@ def reschedule(task, next_send=_NO_SEND): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send) + return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def spawn_system_task(async_fn, *args, name=None): """Spawn a "system" task. @@ -138,9 +144,10 @@ def spawn_system_task(async_fn, *args, name=None): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.spawn_system_task(async_fn, *args, name=name) + return GLOBAL_RUN_CONTEXT.runner.spawn_system_task(async_fn, *args, name=name) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def current_trio_token(): """Retrieve the :class:`TrioToken` for the current call to @@ -149,9 +156,10 @@ def current_trio_token(): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.current_trio_token() + return GLOBAL_RUN_CONTEXT.runner.current_trio_token() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def wait_all_tasks_blocked(cushion=0.0, tiebreaker=0): """Block until there are no runnable tasks. 
@@ -217,7 +225,8 @@ async def test_lock_fairness(): try: return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion, tiebreaker) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def add_instrument(instrument): """Start instrumenting the current run loop with the given instrument. @@ -230,9 +239,10 @@ def add_instrument(instrument): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.add_instrument(instrument) + return GLOBAL_RUN_CONTEXT.runner.add_instrument(instrument) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def remove_instrument(instrument): """Stop instrumenting the current run loop with the given instrument. @@ -249,6 +259,9 @@ def remove_instrument(instrument): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.remove_instrument(instrument) + return GLOBAL_RUN_CONTEXT.runner.remove_instrument(instrument) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + + +# fmt: on diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 5d73a58c84..71f46c40a7 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -235,9 +235,7 @@ def _update_registrations(self, fd): self._epoll.modify(fd, wanted_flags | select.EPOLLONESHOT) except OSError: # If that fails, it might be a new fd; try EPOLL_CTL_ADD - self._epoll.register( - fd, wanted_flags | select.EPOLLONESHOT - ) + self._epoll.register(fd, wanted_flags | select.EPOLLONESHOT) waiters.current_flags = wanted_flags except OSError as exc: # If everything fails, probably it's a bad fd, e.g. 
because @@ -284,7 +282,7 @@ def notify_closing(self, fd): fd = fd.fileno() wake_all( self._registered[fd], - _core.ClosedResourceError("another task closed this fd") + _core.ClosedResourceError("another task closed this fd"), ) del self._registered[fd] try: diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index 593f2e353a..e2d134ce05 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -30,10 +30,7 @@ def statistics(self): tasks_waiting += 1 else: monitors += 1 - return _KqueueStatistics( - tasks_waiting=tasks_waiting, - monitors=monitors, - ) + return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors,) def close(self): self._kqueue.close() @@ -84,8 +81,7 @@ def monitor_kevent(self, ident, filter): key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( - "attempt to register multiple listeners for same " - "ident/filter pair" + "attempt to register multiple listeners for same ident/filter pair" ) q = _core.UnboundedQueue() self._registered[key] = q @@ -99,8 +95,7 @@ async def wait_kevent(self, ident, filter, abort_func): key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( - "attempt to register multiple listeners for same " - "ident/filter pair" + "attempt to register multiple listeners for same ident/filter pair" ) self._registered[key] = _core.current_task() @@ -134,9 +129,7 @@ def abort(_): # the fact... oh well, you can't have everything.) # # FreeBSD reports this using EBADF. macOS uses ENOENT. - if exc.errno in ( - errno.EBADF, errno.ENOENT - ): # pragma: no branch + if exc.errno in (errno.EBADF, errno.ENOENT,): # pragma: no branch pass else: # pragma: no cover # As far as we know, this branch can't happen. 
diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 6302094888..d1fd20c418 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -325,9 +325,7 @@ def __init__(self): self._afd = None self._iocp = _check( - kernel32.CreateIoCompletionPort( - INVALID_HANDLE_VALUE, ffi.NULL, 0, 0 - ) + kernel32.CreateIoCompletionPort(INVALID_HANDLE_VALUE, ffi.NULL, 0, 0) ) self._events = ffi.new("OVERLAPPED_ENTRY[]", MAX_EVENTS) @@ -350,9 +348,7 @@ def __init__(self): base_handle = _get_base_socket(s, which=WSAIoctls.SIO_BASE_HANDLE) # LSPs can in theory override this, but we believe that it never # actually happens in the wild. - select_handle = _get_base_socket( - s, which=WSAIoctls.SIO_BSP_HANDLE_SELECT - ) + select_handle = _get_base_socket(s, which=WSAIoctls.SIO_BSP_HANDLE_SELECT) if base_handle != select_handle: # pragma: no cover raise RuntimeError( "Unexpected network configuration detected. " @@ -400,8 +396,7 @@ def handle_io(self, timeout): try: _check( kernel32.GetQueuedCompletionStatusEx( - self._iocp, self._events, MAX_EVENTS, received, - milliseconds, 0 + self._iocp, self._events, MAX_EVENTS, received, milliseconds, 0, ) ) except OSError as exc: @@ -476,18 +471,13 @@ def handle_io(self, timeout): overlapped = int(ffi.cast("uintptr_t", entry.lpOverlapped)) transferred = entry.dwNumberOfBytesTransferred info = CompletionKeyEventInfo( - lpOverlapped=overlapped, - dwNumberOfBytesTransferred=transferred, + lpOverlapped=overlapped, dwNumberOfBytesTransferred=transferred, ) queue.put_nowait(info) def _register_with_iocp(self, handle, completion_key): handle = _handle(handle) - _check( - kernel32.CreateIoCompletionPort( - handle, self._iocp, completion_key, 0 - ) - ) + _check(kernel32.CreateIoCompletionPort(handle, self._iocp, completion_key, 0)) # Supposedly this makes things slightly faster, by disabling the # ability to do WaitForSingleObject(handle). 
We would never want to do # that anyway, so might as well get the extra speed (if any). @@ -506,11 +496,7 @@ def _refresh_afd(self, base_handle): waiters = self._afd_waiters[base_handle] if waiters.current_op is not None: try: - _check( - kernel32.CancelIoEx( - self._afd, waiters.current_op.lpOverlapped - ) - ) + _check(kernel32.CancelIoEx(self._afd, waiters.current_op.lpOverlapped)) except OSError as exc: if exc.winerror != ErrorCodes.ERROR_NOT_FOUND: # I don't think this is possible, so if it happens let's @@ -530,7 +516,7 @@ def _refresh_afd(self, base_handle): lpOverlapped = ffi.new("LPOVERLAPPED") poll_info = ffi.new("AFD_POLL_INFO *") - poll_info.Timeout = 2**63 - 1 # INT64_MAX + poll_info.Timeout = 2 ** 63 - 1 # INT64_MAX poll_info.NumberOfHandles = 1 poll_info.Exclusive = 0 poll_info.Handles[0].Handle = base_handle @@ -669,9 +655,7 @@ def abort(cancel_exc_): # We didn't request this cancellation, so assume # it happened due to the underlying handle being # closed before the operation could complete. 
- raise _core.ClosedResourceError( - "another task closed this resource" - ) + raise _core.ClosedResourceError("another task closed this resource") else: raise_winerror(code) @@ -700,7 +684,7 @@ async def write_overlapped(self, handle, data, file_offset=0): def submit_write(lpOverlapped): # yes, these are the real documented names offset_fields = lpOverlapped.DUMMYUNIONNAME.DUMMYSTRUCTNAME - offset_fields.Offset = file_offset & 0xffffffff + offset_fields.Offset = file_offset & 0xFFFFFFFF offset_fields.OffsetHigh = file_offset >> 32 _check( kernel32.WriteFile( @@ -722,7 +706,7 @@ async def readinto_overlapped(self, handle, buffer, file_offset=0): def submit_read(lpOverlapped): offset_fields = lpOverlapped.DUMMYUNIONNAME.DUMMYSTRUCTNAME - offset_fields.Offset = file_offset & 0xffffffff + offset_fields.Offset = file_offset & 0xFFFFFFFF offset_fields.OffsetHigh = file_offset >> 32 _check( kernel32.ReadFile( diff --git a/trio/_core/_ki.py b/trio/_core/_ki.py index 9406ebb7e1..d5aa63f5d9 100644 --- a/trio/_core/_ki.py +++ b/trio/_core/_ki.py @@ -10,7 +10,8 @@ if False: from typing import Any, TypeVar, Callable - F = TypeVar('F', bound=Callable[..., Any]) + + F = TypeVar("F", bound=Callable[..., Any]) # In ordinary single-threaded Python code, when you hit control-C, it raises # an exception and automatically does all the regular unwinding stuff. @@ -77,7 +78,7 @@ # We use this special string as a unique key into the frame locals dictionary. # The @ ensures it is not a valid identifier and can't clash with any possible # real local name. 
See: https://github.com/python-trio/trio/issues/469 -LOCALS_KEY_KI_PROTECTION_ENABLED = '@TRIO_KI_PROTECTION_ENABLED' +LOCALS_KEY_KI_PROTECTION_ENABLED = "@TRIO_KI_PROTECTION_ENABLED" # NB: according to the signal.signal docs, 'frame' can be None on entry to @@ -119,8 +120,7 @@ def decorator(fn): def wrapper(*args, **kwargs): # See the comment for regular generators below coro = fn(*args, **kwargs) - coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED - ] = enabled + coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled return coro return wrapper @@ -137,8 +137,7 @@ def wrapper(*args, **kwargs): # thrown into! See: # https://bugs.python.org/issue29590 gen = fn(*args, **kwargs) - gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED - ] = enabled + gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled return gen return wrapper @@ -148,8 +147,7 @@ def wrapper(*args, **kwargs): def wrapper(*args, **kwargs): # See the comment for regular generators above agen = fn(*args, **kwargs) - agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED - ] = enabled + agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled return agen return wrapper @@ -168,9 +166,7 @@ def wrapper(*args, **kwargs): enable_ki_protection = _ki_protection_decorator(True) # type: Callable[[F], F] enable_ki_protection.__name__ = "enable_ki_protection" -disable_ki_protection = _ki_protection_decorator( - False -) # type: Callable[[F], F] +disable_ki_protection = _ki_protection_decorator(False) # type: Callable[[F], F] disable_ki_protection.__name__ = "disable_ki_protection" diff --git a/trio/_core/_local.py b/trio/_core/_local.py index aaf5f6d2f7..352caa5682 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -40,8 +40,7 @@ def get(self, default=_NO_DEFAULT): try: return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] except AttributeError: - raise RuntimeError("Cannot be used outside of a run context") \ - from None + raise RuntimeError("Cannot be used 
outside of a run context") from None except KeyError: # contextvars consistency if default is not self._NO_DEFAULT: @@ -95,4 +94,4 @@ def reset(self, token): token.redeemed = True def __repr__(self): - return ("".format(self._name)) + return "".format(self._name) diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index f98540344d..cdeeac6269 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -173,6 +173,7 @@ class MultiError(BaseException): :exc:`BaseException`. """ + def __init__(self, exceptions): # Avoid recursion when exceptions[0] returned by __new__() happens # to be a MultiError and subsequently __init__() is called. @@ -186,9 +187,7 @@ def __new__(cls, exceptions): exceptions = list(exceptions) for exc in exceptions: if not isinstance(exc, BaseException): - raise TypeError( - "Expected an exception object, not {!r}".format(exc) - ) + raise TypeError("Expected an exception object, not {!r}".format(exc)) if len(exceptions) == 1: # If this lone object happens to itself be a MultiError, then # Python will implicitly call our __init__ on it again. See @@ -280,16 +279,20 @@ def controller(operation): # no missing test we could add, and no value in coverage nagging # us about adding one. if operation.opname in [ - "__getattribute__", "__getattr__" + "__getattribute__", + "__getattr__", ]: # pragma: no cover if operation.args[0] == "tb_next": return tb_next return operation.delegate() return tputil.make_proxy(controller, type(base_tb), base_tb) + + else: # ctypes it is import ctypes + # How to handle refcounting? 
I don't want to use ctypes.py_object because # I don't understand or trust it, and I don't want to use # ctypes.pythonapi.Py_{Inc,Dec}Ref because we might clash with user code @@ -374,7 +377,7 @@ def traceback_exception_init( limit=None, lookup_lines=True, capture_locals=False, - _seen=None + _seen=None, ): if _seen is None: _seen = set() @@ -388,7 +391,7 @@ def traceback_exception_init( limit=limit, lookup_lines=lookup_lines, capture_locals=capture_locals, - _seen=_seen + _seen=_seen, ) # Capture each of the exceptions in the MultiError along with each of their causes and contexts @@ -404,7 +407,7 @@ def traceback_exception_init( capture_locals=capture_locals, # copy the set of _seen exceptions so that duplicates # shared between sub-exceptions are not omitted - _seen=set(_seen) + _seen=set(_seen), ) ) self.embedded = embedded @@ -421,9 +424,7 @@ def traceback_exception_format(self, *, chain=True): for i, exc in enumerate(self.embedded): yield "\nDetails of embedded exception {}:\n\n".format(i + 1) - yield from ( - textwrap.indent(line, " " * 2) for line in exc.format(chain=chain) - ) + yield from (textwrap.indent(line, " " * 2) for line in exc.format(chain=chain)) traceback.TracebackException.format = traceback_exception_format @@ -438,6 +439,7 @@ def trio_excepthook(etype, value, tb): if "IPython" in sys.modules: import IPython + ip = IPython.get_ipython() if ip is not None: if ip.custom_exceptions != (): @@ -446,7 +448,7 @@ def trio_excepthook(etype, value, tb): "handler installed. 
I'll skip installing Trio's custom " "handler, but this means MultiErrors will not show full " "tracebacks.", - category=RuntimeWarning + category=RuntimeWarning, ) monkeypatched_or_warned = True else: @@ -477,6 +479,7 @@ def trio_show_traceback(self, etype, value, tb, tb_offset=None): # More details: https://github.com/python-trio/trio/issues/1065 if sys.excepthook.__name__ == "apport_excepthook": import apport_python_hook + assert sys.excepthook is apport_python_hook.apport_excepthook # Give it a descriptive name as a hint for anyone who's stuck trying to @@ -496,5 +499,5 @@ class TrioFakeSysModuleForApport: "You seem to already have a custom sys.excepthook handler " "installed. I'll skip installing Trio's custom handler, but this " "means MultiErrors will not show full tracebacks.", - category=RuntimeWarning + category=RuntimeWarning, ) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 41573d9280..62cbd7eb51 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -23,9 +23,11 @@ from outcome import Error, Value, capture from ._entry_queue import EntryQueue, TrioToken -from ._exceptions import (TrioInternalError, RunFinishedError, Cancelled) +from ._exceptions import TrioInternalError, RunFinishedError, Cancelled from ._ki import ( - LOCALS_KEY_KI_PROTECTION_ENABLED, ki_manager, enable_ki_protection + LOCALS_KEY_KI_PROTECTION_ENABLED, + ki_manager, + enable_ki_protection, ) from ._multierror import MultiError from ._traps import ( @@ -249,7 +251,8 @@ def close(self): @property def parent_cancellation_is_visible_to_us(self): return ( - self._parent is not None and not self._scope.shield + self._parent is not None + and not self._scope.shield and self._parent.effectively_cancelled ) @@ -370,9 +373,7 @@ def __enter__(self): if current_time() >= self._deadline: self.cancel() with self._might_change_registered_deadline(): - self._cancel_status = CancelStatus( - scope=self, parent=task._cancel_status - ) + self._cancel_status = CancelStatus(scope=self, 
parent=task._cancel_status) task._activate_cancel_status(self._cancel_status) return self @@ -423,8 +424,10 @@ def _close(self, exc): new_exc = RuntimeError( "Cancel scope stack corrupted: attempted to exit {!r} " "in {!r} that's still within its child {!r}\n{}".format( - self, scope_task, scope_task._cancel_status._scope, - MISNESTING_ADVICE + self, + scope_task, + scope_task._cancel_status._scope, + MISNESTING_ADVICE, ) ) new_exc.__context__ = exc @@ -433,7 +436,8 @@ def _close(self, exc): else: scope_task._activate_cancel_status(self._cancel_status.parent) if ( - exc is not None and self._cancel_status.effectively_cancelled + exc is not None + and self._cancel_status.effectively_cancelled and not self._cancel_status.parent_cancellation_is_visible_to_us ): exc = MultiError.filter(self._exc_filter, exc) @@ -486,12 +490,10 @@ def __repr__(self): else: state = ", deadline is {:.2f} seconds {}".format( abs(self._deadline - now), - "from now" if self._deadline >= now else "ago" + "from now" if self._deadline >= now else "ago", ) - return "".format( - id(self), binding, state - ) + return "".format(id(self), binding, state) @contextmanager @enable_ki_protection @@ -639,9 +641,7 @@ def __repr__(self): def started(self, value=None): if self._called_started: - raise RuntimeError( - "called 'started' twice on the same task status" - ) + raise RuntimeError("called 'started' twice on the same task status") self._called_started = True self._value = value @@ -700,6 +700,7 @@ class NurseryManager: and StopAsyncIteration. """ + @enable_ki_protection async def __aenter__(self): self._scope = CancelScope() @@ -769,6 +770,7 @@ class Nursery(metaclass=NoPublicConstructor): other things, e.g. if you want to explicitly cancel all children in response to some external event. 
""" + def __init__(self, parent_task, cancel_scope): self._parent_task = parent_task parent_task._child_nurseries.append(self) @@ -805,9 +807,7 @@ def _add_exc(self, exc): self.cancel_scope.cancel() def _check_nursery_closed(self): - if not any( - [self._nested_child_running, self._children, self._pending_starts] - ): + if not any([self._nested_child_running, self._children, self._pending_starts]): self._closed = True if self._parent_waiting_in_aexit: self._parent_waiting_in_aexit = False @@ -951,9 +951,7 @@ async def async_fn(arg1, arg2, \*, task_status=trio.TASK_STATUS_IGNORED): # normally. The complicated logic is all in _TaskStatus.started(). # (Any exceptions propagate directly out of the above.) if not task_status._called_started: - raise RuntimeError( - "child exited without calling task_status.started()" - ) + raise RuntimeError("child exited without calling task_status.started()") return task_status._value finally: self._pending_starts -= 1 @@ -1004,7 +1002,7 @@ class Task(metaclass=NoPublicConstructor): _schedule_points = attr.ib(default=0) def __repr__(self): - return ("".format(self.name, id(self))) + return "".format(self.name, id(self)) @property def parent_nursery(self): @@ -1256,19 +1254,13 @@ async def python_wrapper(orig_coro): return await orig_coro coro = python_wrapper(coro) - coro.cr_frame.f_locals.setdefault( - LOCALS_KEY_KI_PROTECTION_ENABLED, system_task - ) + coro.cr_frame.f_locals.setdefault(LOCALS_KEY_KI_PROTECTION_ENABLED, system_task) ###### # Set up the Task object ###### task = Task._create( - coro=coro, - parent_nursery=nursery, - runner=self, - name=name, - context=context, + coro=coro, parent_nursery=nursery, runner=self, name=name, context=context, ) self.tasks.add(task) @@ -1380,9 +1372,7 @@ async def init(self, async_fn, args): async with open_nursery() as system_nursery: self.system_nursery = system_nursery try: - self.main_task = self.spawn_impl( - async_fn, args, system_nursery, None - ) + self.main_task = 
self.spawn_impl(async_fn, args, system_nursery, None) except BaseException as exc: self.main_task_outcome = Error(exc) system_nursery.cancel_scope.cancel() @@ -1514,7 +1504,9 @@ def instrument(self, method_name, *args): self.instruments.remove(instrument) INSTRUMENT_LOGGER.exception( "Exception raised when calling %r on instrument %r. " - "Instrument has been disabled.", method_name, instrument + "Instrument has been disabled.", + method_name, + instrument, ) @_public @@ -1562,7 +1554,7 @@ def run( *args, clock=None, instruments=(), - restrict_keyboard_interrupt_to_checkpoints=False + restrict_keyboard_interrupt_to_checkpoints=False, ): """Run a Trio-flavored async function, and return the result. @@ -1662,9 +1654,7 @@ def run( # where KeyboardInterrupt would be allowed and converted into an # TrioInternalError: try: - with ki_manager( - runner.deliver_ki, restrict_keyboard_interrupt_to_checkpoints - ): + with ki_manager(runner.deliver_ki, restrict_keyboard_interrupt_to_checkpoints): try: with closing(runner): with runner.entry_queue.wakeup.wakeup_on_signals(): @@ -1706,11 +1696,7 @@ def run_impl(runner, async_fn, args): runner.instrument("before_run") runner.clock.start_clock() runner.init_task = runner.spawn_impl( - runner.init, - (async_fn, args), - None, - "", - system_task=True, + runner.init, (async_fn, args), None, "", system_task=True, ) # You know how people talk about "event loops"? This 'while' loop right diff --git a/trio/_core/_traps.py b/trio/_core/_traps.py index 33b4249416..f2cd9cf4c7 100644 --- a/trio/_core/_traps.py +++ b/trio/_core/_traps.py @@ -55,6 +55,7 @@ class Abort(enum.Enum): FAILED """ + SUCCEEDED = 1 FAILED = 2 diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py index efa0a8d1b3..9830df4bea 100644 --- a/trio/_core/_unbounded_queue.py +++ b/trio/_core/_unbounded_queue.py @@ -40,11 +40,12 @@ class UnboundedQueue(metaclass=SubclassingDeprecatedIn_v0_15_0): ... 
""" + @deprecated( "0.9.0", issue=497, thing="trio.lowlevel.UnboundedQueue", - instead="trio.open_memory_channel(math.inf)" + instead="trio.open_memory_channel(math.inf)", ) def __init__(self): self._lot = _core.ParkingLot() @@ -140,8 +141,7 @@ def statistics(self): """ return _UnboundedQueueStats( - qsize=len(self._data), - tasks_waiting=self._lot.statistics().tasks_waiting + qsize=len(self._data), tasks_waiting=self._lot.statistics().tasks_waiting, ) def __aiter__(self): diff --git a/trio/_core/_wakeup_socketpair.py b/trio/_core/_wakeup_socketpair.py index 0c37928a55..3513cc1ab3 100644 --- a/trio/_core/_wakeup_socketpair.py +++ b/trio/_core/_wakeup_socketpair.py @@ -35,9 +35,7 @@ def __init__(self): # On Windows this is a TCP socket so this might matter. On other # platforms this fails b/c AF_UNIX sockets aren't actually TCP. try: - self.write_sock.setsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY, 1 - ) + self.write_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) except OSError: pass @@ -54,7 +52,7 @@ async def wait_woken(self): def drain(self): try: while True: - self.wakeup_sock.recv(2**16) + self.wakeup_sock.recv(2 ** 16) except BlockingIOError: pass diff --git a/trio/_core/tests/test_io.py b/trio/_core/tests/test_io.py index fb459bea2d..397375503d 100644 --- a/trio/_core/tests/test_io.py +++ b/trio/_core/tests/test_io.py @@ -53,7 +53,9 @@ def fileno_wrapper(fileobj): notify_closing_options = [trio.lowlevel.notify_closing] for options_list in [ - wait_readable_options, wait_writable_options, notify_closing_options + wait_readable_options, + wait_writable_options, + notify_closing_options, ]: options_list += [using_fileno(f) for f in options_list] @@ -196,9 +198,7 @@ async def writer(): @read_socket_test @write_socket_test -async def test_socket_simultaneous_read_write( - socketpair, wait_readable, wait_writable -): +async def test_socket_simultaneous_read_write(socketpair, wait_readable, wait_writable): record = [] async def r_task(sock): @@ 
-226,9 +226,7 @@ async def w_task(sock): @read_socket_test @write_socket_test -async def test_socket_actual_streaming( - socketpair, wait_readable, wait_writable -): +async def test_socket_actual_streaming(socketpair, wait_readable, wait_writable): a, b = socketpair # Use a small send buffer on one of the sockets to increase the chance of diff --git a/trio/_core/tests/test_ki.py b/trio/_core/tests/test_ki.py index ddbf10c8e5..b7aaa76cf9 100644 --- a/trio/_core/tests/test_ki.py +++ b/trio/_core/tests/test_ki.py @@ -8,7 +8,10 @@ import time from async_generator import ( - async_generator, yield_, isasyncgenfunction, asynccontextmanager + async_generator, + yield_, + isasyncgenfunction, + asynccontextmanager, ) from ... import _core @@ -107,6 +110,7 @@ async def unprotected(): async def child(expected): import traceback + traceback.print_stack() assert _core.currently_ki_protected() == expected await _core.checkpoint() @@ -259,9 +263,7 @@ async def raiser(name, record): # If we didn't raise (b/c protected), then we *should* get # cancelled at the next opportunity try: - await _core.wait_task_rescheduled( - lambda _: _core.Abort.SUCCEEDED - ) + await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) except _core.Cancelled: record.add(name + " cancel ok") @@ -288,9 +290,7 @@ async def check_protected_kill(): async with _core.open_nursery() as nursery: nursery.start_soon(sleeper, "s1", record) nursery.start_soon(sleeper, "s2", record) - nursery.start_soon( - _core.enable_ki_protection(raiser), "r1", record - ) + nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record) # __aexit__ blocks, and then receives the KI with pytest.raises(KeyboardInterrupt): @@ -472,9 +472,7 @@ def test_ki_with_broken_threads(): @_core.enable_ki_protection async def inner(): - assert signal.getsignal( - signal.SIGINT - ) != signal.default_int_handler + assert signal.getsignal(signal.SIGINT) != signal.default_int_handler _core.run(inner) finally: @@ -522,9 +520,9 @@ def 
test_ki_wakes_us_up(): # https://bugs.python.org/issue31119 # https://bitbucket.org/pypy/pypy/issues/2623 import platform + buggy_wakeup_fd = ( - platform.python_implementation() == "CPython" and sys.version_info < - (3, 6, 2) + sys.version_info < (3, 6, 2) and platform.python_implementation() == "CPython" ) # lock is only needed to avoid an annoying race condition where the diff --git a/trio/_core/tests/test_multierror.py b/trio/_core/tests/test_multierror.py index 6debf4d45c..c1444b4c0e 100644 --- a/trio/_core/tests/test_multierror.py +++ b/trio/_core/tests/test_multierror.py @@ -1,7 +1,12 @@ import logging import pytest -from traceback import extract_tb, print_exception, format_exception, _cause_message +from traceback import ( + extract_tb, + print_exception, + format_exception, + _cause_message, +) import sys import os import re @@ -243,9 +248,9 @@ def simple_filter(exc): assert isinstance(orig.exceptions[0].exceptions[1], KeyError) # get original traceback summary orig_extracted = ( - extract_tb(orig.__traceback__) + - extract_tb(orig.exceptions[0].__traceback__) + - extract_tb(orig.exceptions[0].exceptions[1].__traceback__) + extract_tb(orig.__traceback__) + + extract_tb(orig.exceptions[0].__traceback__) + + extract_tb(orig.exceptions[0].exceptions[1].__traceback__) ) def p(exc): @@ -495,7 +500,7 @@ def test_format_exception(): r"in raiser3", r"NameError", ], - formatted + formatted, ) # Prints duplicate exceptions in sub-exceptions @@ -556,7 +561,7 @@ def raise2_raiser1(): r"in raise2_raiser1", r" KeyError: 'bar'", ], - formatted + formatted, ) @@ -572,15 +577,14 @@ def test_logging(caplog): except MultiError as exc: logging.getLogger().exception(message) # Join lines together - formatted = "".join( - format_exception(type(exc), exc, exc.__traceback__) - ) + formatted = "".join(format_exception(type(exc), exc, exc.__traceback__)) assert message in caplog.text assert formatted in caplog.text def run_script(name, use_ipython=False): import trio + trio_path = 
Path(trio.__file__).parent.parent script_path = Path(__file__).parent / "test_multierror_scripts" / name @@ -605,7 +609,7 @@ def run_script(name, use_ipython=False): "IPython", # no startup files "--quick", - "--TerminalIPythonApp.code_to_run=" + '\n'.join(lines), + "--TerminalIPythonApp.code_to_run=" + "\n".join(lines), ] else: cmd = [sys.executable, "-u", str(script_path)] @@ -629,7 +633,8 @@ def check_simple_excepthook(completed): "Details of embedded exception 2", "in exc2_fn", "KeyError", - ], completed.stdout.decode("utf-8") + ], + completed.stdout.decode("utf-8"), ) @@ -652,17 +657,18 @@ def test_custom_excepthook(): # The MultiError "MultiError:", ], - completed.stdout.decode("utf-8") + completed.stdout.decode("utf-8"), ) # This warning is triggered by ipython 7.5.0 on python 3.8 import warnings + warnings.filterwarnings( "ignore", - message=".*\"@coroutine\" decorator is deprecated", + message='.*"@coroutine" decorator is deprecated', category=DeprecationWarning, - module="IPython.*" + module="IPython.*", ) try: import IPython @@ -705,7 +711,7 @@ def test_ipython_custom_exc_handler(): "ValueError", "KeyError", ], - completed.stdout.decode("utf-8") + completed.stdout.decode("utf-8"), ) # Make sure our other warning doesn't show up assert "custom sys.excepthook" not in completed.stdout.decode("utf-8") @@ -714,7 +720,7 @@ def test_ipython_custom_exc_handler(): @slow @pytest.mark.skipif( not Path("/usr/lib/python3/dist-packages/apport_python_hook.py").exists(), - reason="need Ubuntu with python3-apport installed" + reason="need Ubuntu with python3-apport installed", ) def test_apport_excepthook_monkeypatch_interaction(): completed = run_script("apport_excepthook.py") @@ -725,10 +731,6 @@ def test_apport_excepthook_monkeypatch_interaction(): # Proper traceback assert_match_in_seq( - [ - "Details of embedded", - "KeyError", - "Details of embedded", - "ValueError", - ], stdout + ["Details of embedded", "KeyError", "Details of embedded", "ValueError",], + stdout, 
) diff --git a/trio/_core/tests/test_multierror_scripts/apport_excepthook.py b/trio/_core/tests/test_multierror_scripts/apport_excepthook.py index ac8110f36e..12e7fb0851 100644 --- a/trio/_core/tests/test_multierror_scripts/apport_excepthook.py +++ b/trio/_core/tests/test_multierror_scripts/apport_excepthook.py @@ -2,8 +2,10 @@ # python, and not available in venvs. So before we can import it we have to # make sure it's on sys.path. import sys + sys.path.append("/usr/lib/python3/dist-packages") import apport_python_hook + apport_python_hook.install() import trio diff --git a/trio/_core/tests/test_multierror_scripts/ipython_custom_exc.py b/trio/_core/tests/test_multierror_scripts/ipython_custom_exc.py index 017c5ea059..b3fd110e50 100644 --- a/trio/_core/tests/test_multierror_scripts/ipython_custom_exc.py +++ b/trio/_core/tests/test_multierror_scripts/ipython_custom_exc.py @@ -14,6 +14,7 @@ def custom_excepthook(*args): sys.excepthook = custom_excepthook import IPython + ip = IPython.get_ipython() diff --git a/trio/_core/tests/test_parking_lot.py b/trio/_core/tests/test_parking_lot.py index 95e4a96b50..1d1ecd111a 100644 --- a/trio/_core/tests/test_parking_lot.py +++ b/trio/_core/tests/test_parking_lot.py @@ -32,10 +32,7 @@ async def waiter(i, lot): assert len(record) == 6 check_sequence_matches( - record, [ - {"sleep 0", "sleep 1", "sleep 2"}, - {"wake 0", "wake 1", "wake 2"}, - ] + record, [{"sleep 0", "sleep 1", "sleep 2"}, {"wake 0", "wake 1", "wake 2"},], ) async with _core.open_nursery() as nursery: @@ -71,12 +68,7 @@ async def waiter(i, lot): lot.unpark(count=2) await wait_all_tasks_blocked() check_sequence_matches( - record, [ - "sleep 0", - "sleep 1", - "sleep 2", - {"wake 0", "wake 1"}, - ] + record, ["sleep 0", "sleep 1", "sleep 2", {"wake 0", "wake 1"},] ) lot.unpark_all() @@ -115,13 +107,7 @@ async def test_parking_lot_cancel(): assert len(record) == 6 check_sequence_matches( - record, [ - "sleep 1", - "sleep 2", - "sleep 3", - "cancelled 2", - {"wake 1", 
"wake 3"}, - ] + record, ["sleep 1", "sleep 2", "sleep 3", "cancelled 2", {"wake 1", "wake 3"},], ) @@ -160,13 +146,22 @@ async def test_parking_lot_repark(): await wait_all_tasks_blocked() assert len(lot2) == 1 assert record == [ - "sleep 1", "sleep 2", "sleep 3", "wake 1", "cancelled 2" + "sleep 1", + "sleep 2", + "sleep 3", + "wake 1", + "cancelled 2", ] lot2.unpark_all() await wait_all_tasks_blocked() assert record == [ - "sleep 1", "sleep 2", "sleep 3", "wake 1", "cancelled 2", "wake 3" + "sleep 1", + "sleep 2", + "sleep 3", + "wake 1", + "cancelled 2", + "wake 3", ] diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py index 0362f81662..d81433c4cc 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/tests/test_run.py @@ -16,8 +16,10 @@ import pytest from .tutil import ( - slow, check_sequence_matches, gc_collect_harder, - ignore_coroutine_never_awaited_warnings + slow, + check_sequence_matches, + gc_collect_harder, + ignore_coroutine_never_awaited_warnings, ) from ... 
import _core @@ -129,8 +131,7 @@ async def looper(whoami, record): nursery.start_soon(looper, "b", record) check_sequence_matches( - record, - [{("a", 0), ("b", 0)}, {("a", 1), ("b", 1)}, {("a", 2), ("b", 2)}] + record, [{("a", 0), ("b", 0)}, {("a", 1), ("b", 1)}, {("a", 2), ("b", 2)}], ) @@ -174,8 +175,10 @@ async def main(): with pytest.raises(_core.MultiError) as excinfo: _core.run(main) print(excinfo.value) - assert {type(exc) - for exc in excinfo.value.exceptions} == {ValueError, KeyError} + assert {type(exc) for exc in excinfo.value.exceptions} == { + ValueError, + KeyError, + } def test_two_child_crashes(): @@ -189,8 +192,10 @@ async def main(): with pytest.raises(_core.MultiError) as excinfo: _core.run(main) - assert {type(exc) - for exc in excinfo.value.exceptions} == {ValueError, KeyError} + assert {type(exc) for exc in excinfo.value.exceptions} == { + ValueError, + KeyError, + } async def test_child_crash_wakes_parent(): @@ -395,9 +400,9 @@ async def main(): # reschedules the task immediately upon yielding, before the # after_task_step event fires. 
expected = ( - [("before_run",), ("schedule", task)] + - [("before", task), ("schedule", task), ("after", task)] * 5 + - [("before", task), ("after", task), ("after_run",)] + [("before_run",), ("schedule", task)] + + [("before", task), ("schedule", task), ("after", task)] * 5 + + [("before", task), ("after", task), ("after_run",)] ) assert len(r1.record) > len(r2.record) > len(r3.record) assert r1.record == r2.record + r3.record @@ -433,16 +438,16 @@ async def main(): ("after", tasks["t1"]), ("before", tasks["t2"]), ("schedule", tasks["t2"]), - ("after", tasks["t2"]) + ("after", tasks["t2"]), }, { ("before", tasks["t1"]), ("after", tasks["t1"]), ("before", tasks["t2"]), - ("after", tasks["t2"]) + ("after", tasks["t2"]), }, ("after_run",), - ] # yapf: disable + ] print(list(r.filter_tasks(tasks.values()))) check_sequence_matches(list(r.filter_tasks(tasks.values())), expected) @@ -963,9 +968,7 @@ async def task2(): for exc in exc_info.value.__context__.exceptions: assert isinstance(exc, RuntimeError) assert "closed before the task exited" in str(exc) - cancelled_in_context |= isinstance( - exc.__context__, _core.Cancelled - ) + cancelled_in_context |= isinstance(exc.__context__, _core.Cancelled) assert cancelled_in_context # for the sleep_forever # Trying to exit a cancel scope from an unrelated task raises an error @@ -1231,9 +1234,14 @@ async def child2(): nursery.start_soon(child2) assert record == [ - "child1 raise", "child1 sleep", "child2 wake", "child2 sleep again", - "child1 re-raise", "child1 success", "child2 re-raise", - "child2 success" + "child1 raise", + "child1 sleep", + "child2 wake", + "child2 sleep again", + "child1 re-raise", + "child1 success", + "child2 re-raise", + "child2 success", ] @@ -1720,8 +1728,7 @@ async def async_gen(arg): # pragma: no cover yield arg with pytest.raises( - TypeError, - match="expected an async function but got an async generator" + TypeError, match="expected an async function but got an async generator", ): 
bad_call(async_gen, 0) @@ -1729,6 +1736,7 @@ async def async_gen(arg): # pragma: no cover def test_calling_asyncio_function_gives_nice_error(): async def child_xyzzy(): import asyncio + await asyncio.Future() async def misguided(): @@ -1749,6 +1757,7 @@ async def test_asyncio_function_inside_nursery_does_not_explode(): with pytest.raises(TypeError) as excinfo: async with _core.open_nursery() as nursery: import asyncio + nursery.start_soon(sleep_forever) await asyncio.Future() assert "asyncio" in str(excinfo.value) @@ -1772,10 +1781,10 @@ async def test_trivial_yields(): async with _core.open_nursery(): raise KeyError assert len(excinfo.value.exceptions) == 2 - assert {type(e) - for e in excinfo.value.exceptions} == { - KeyError, _core.Cancelled - } + assert {type(e) for e in excinfo.value.exceptions} == { + KeyError, + _core.Cancelled, + } async def test_nursery_start(autojump_clock): @@ -1787,9 +1796,7 @@ async def no_args(): # pragma: no cover with pytest.raises(TypeError): await nursery.start(no_args) - async def sleep_then_start( - seconds, *, task_status=_core.TASK_STATUS_IGNORED - ): + async def sleep_then_start(seconds, *, task_status=_core.TASK_STATUS_IGNORED): repr(task_status) # smoke test await sleep(seconds) task_status.started(seconds) @@ -1853,9 +1860,7 @@ async def just_started(task_status=_core.TASK_STATUS_IGNORED): # and if after the no-op started(), the child crashes, the error comes out # of start() - async def raise_keyerror_after_started( - task_status=_core.TASK_STATUS_IGNORED - ): + async def raise_keyerror_after_started(task_status=_core.TASK_STATUS_IGNORED,): task_status.started() raise KeyError("whoopsiedaisy") @@ -1864,8 +1869,10 @@ async def raise_keyerror_after_started( cs.cancel() with pytest.raises(_core.MultiError) as excinfo: await nursery.start(raise_keyerror_after_started) - assert {type(e) - for e in excinfo.value.exceptions} == {_core.Cancelled, KeyError} + assert {type(e) for e in excinfo.value.exceptions} == { + 
_core.Cancelled, + KeyError, + } # trying to start in a closed nursery raises an error immediately async with _core.open_nursery() as closed_nursery: @@ -2005,9 +2012,7 @@ def __aiter__(self): async def __anext__(self): nexts = self.nexts - items = [ - None, - ] * len(nexts) + items = [None,] * len(nexts) got_stop = False def handle(exc): @@ -2097,7 +2102,7 @@ async def t2(): def test_system_task_contexts(): - cvar = contextvars.ContextVar('qwilfish') + cvar = contextvars.ContextVar("qwilfish") cvar.set("water") async def system_task(): @@ -2147,7 +2152,7 @@ def test_Cancelled_init(): def test_Cancelled_str(): cancelled = _core.Cancelled._create() - assert str(cancelled) == 'Cancelled' + assert str(cancelled) == "Cancelled" def test_Cancelled_subclass(): @@ -2205,9 +2210,7 @@ async def detachable_coroutine(task_outcome, yield_value): await async_yield(yield_value) async with _core.open_nursery() as nursery: - nursery.start_soon( - detachable_coroutine, outcome.Value(None), "I'm free!" - ) + nursery.start_soon(detachable_coroutine, outcome.Value(None), "I'm free!") # If we get here then Trio thinks the task has exited... 
but the coroutine # is still iterable @@ -2222,9 +2225,7 @@ async def detachable_coroutine(task_outcome, yield_value): pdco_outcome = None with pytest.raises(KeyError): async with _core.open_nursery() as nursery: - nursery.start_soon( - detachable_coroutine, outcome.Error(KeyError()), "uh oh" - ) + nursery.start_soon(detachable_coroutine, outcome.Error(KeyError()), "uh oh") throw_in = ValueError() assert task.coro.throw(throw_in) == "uh oh" assert pdco_outcome == outcome.Error(throw_in) @@ -2234,9 +2235,7 @@ async def detachable_coroutine(task_outcome, yield_value): async def bad_detach(): async with _core.open_nursery(): with pytest.raises(RuntimeError) as excinfo: - await _core.permanently_detach_coroutine_object( - outcome.Value(None) - ) + await _core.permanently_detach_coroutine_object(outcome.Value(None)) assert "open nurser" in str(excinfo.value) async with _core.open_nursery() as nursery: @@ -2267,9 +2266,7 @@ def abort_fn(_): # pragma: no cover await async_yield(2) with pytest.raises(RuntimeError) as excinfo: - await _core.reattach_detached_coroutine_object( - unrelated_task, None - ) + await _core.reattach_detached_coroutine_object(unrelated_task, None) assert "does not match" in str(excinfo.value) await _core.reattach_detached_coroutine_object(task, "byebye") diff --git a/trio/_core/tests/test_windows.py b/trio/_core/tests/test_windows.py index 2fb8a97092..0a8179f88f 100644 --- a/trio/_core/tests/test_windows.py +++ b/trio/_core/tests/test_windows.py @@ -4,25 +4,28 @@ import pytest -on_windows = (os.name == "nt") +on_windows = os.name == "nt" # Mark all the tests in this file as being windows-only pytestmark = pytest.mark.skipif(not on_windows, reason="windows only") from .tutil import slow, gc_collect_harder from ... 
import _core, sleep, move_on_after from ...testing import wait_all_tasks_blocked + if on_windows: from .._windows_cffi import ( - ffi, kernel32, INVALID_HANDLE_VALUE, raise_winerror, FileFlags + ffi, + kernel32, + INVALID_HANDLE_VALUE, + raise_winerror, + FileFlags, ) # The undocumented API that this is testing should be changed to stop using # UnboundedQueue (or just removed until we have time to redo it), but until # then we filter out the warning. -@pytest.mark.filterwarnings( - "ignore:.*UnboundedQueue:trio.TrioDeprecationWarning" -) +@pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning") async def test_completion_key_listen(): async def post(key): iocp = ffi.cast("HANDLE", _core.current_iocp()) @@ -30,9 +33,7 @@ async def post(key): print("post", i) if i % 3 == 0: await _core.checkpoint() - success = kernel32.PostQueuedCompletionStatus( - iocp, i, key, ffi.NULL - ) + success = kernel32.PostQueuedCompletionStatus(iocp, i, key, ffi.NULL) assert success with _core.monitor_completion_key() as (key, queue): @@ -80,9 +81,7 @@ async def test_readinto_overlapped(): async def read_region(start, end): await _core.readinto_overlapped( - handle, - buffer_view[start:end], - start, + handle, buffer_view[start:end], start, ) _core.register_with_iocp(handle) @@ -124,10 +123,7 @@ async def main(): try: async with _core.open_nursery() as nursery: nursery.start_soon( - _core.readinto_overlapped, - read_handle, - target, - name="xyz" + _core.readinto_overlapped, read_handle, target, name="xyz", ) await wait_all_tasks_blocked() nursery.cancel_scope.cancel() diff --git a/trio/_core/tests/tutil.py b/trio/_core/tests/tutil.py index ac090cb8de..b569371482 100644 --- a/trio/_core/tests/tutil.py +++ b/trio/_core/tests/tutil.py @@ -10,15 +10,11 @@ # See trio/tests/conftest.py for the other half of this from trio.tests.conftest import RUN_SLOW -slow = pytest.mark.skipif( - not RUN_SLOW, - reason="use --run-slow to run slow tests", -) + +slow = 
pytest.mark.skipif(not RUN_SLOW, reason="use --run-slow to run slow tests",) try: - s = stdlib_socket.socket( - stdlib_socket.AF_INET6, stdlib_socket.SOCK_STREAM, 0 - ) + s = stdlib_socket.socket(stdlib_socket.AF_INET6, stdlib_socket.SOCK_STREAM, 0) except OSError: # pragma: no cover # Some systems don't even support creating an IPv6 socket, let alone # binding it. (ex: Linux with 'ipv6.disable=1' in the kernel command line) @@ -30,7 +26,7 @@ can_create_ipv6 = True with s: try: - s.bind(('::1', 0)) + s.bind(("::1", 0)) except OSError: can_bind_ipv6 = False else: @@ -61,9 +57,7 @@ def gc_collect_harder(): @contextmanager def ignore_coroutine_never_awaited_warnings(): with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="coroutine '.*' was never awaited" - ) + warnings.filterwarnings("ignore", message="coroutine '.*' was never awaited") try: yield finally: @@ -79,14 +73,15 @@ def check_sequence_matches(seq, template): for pattern in template: if not isinstance(pattern, set): pattern = {pattern} - got = set(seq[i:i + len(pattern)]) + got = set(seq[i : i + len(pattern)]) assert got == pattern i += len(got) # https://bugs.freebsd.org/bugzilla/show_bug.cgi?id=246350 skip_if_fbsd_pipes_broken = pytest.mark.skipif( - hasattr(os, "uname") and os.uname().sysname == "FreeBSD" + hasattr(os, "uname") + and os.uname().sysname == "FreeBSD" and os.uname().release[:4] < "12.2", - reason="hangs on FreeBSD 12.1 and earlier, due to FreeBSD bug #246350" + reason="hangs on FreeBSD 12.1 and earlier, due to FreeBSD bug #246350", ) diff --git a/trio/_deprecate.py b/trio/_deprecate.py index b1362cbc38..4f9f15ec35 100644 --- a/trio/_deprecate.py +++ b/trio/_deprecate.py @@ -117,9 +117,7 @@ def __getattr__(self, name): if instead is DeprecatedAttribute._not_set: instead = info.value thing = "{}.{}".format(self.__name__, name) - warn_deprecated( - thing, info.version, issue=info.issue, instead=instead - ) + warn_deprecated(thing, info.version, issue=info.issue, 
instead=instead) return info.value msg = "module '{}' has no attribute '{}'" diff --git a/trio/_deprecated_ssl_reexports.py b/trio/_deprecated_ssl_reexports.py index 35d22f49f4..31d5fdf9f6 100644 --- a/trio/_deprecated_ssl_reexports.py +++ b/trio/_deprecated_ssl_reexports.py @@ -12,11 +12,24 @@ # Always available from ssl import ( - cert_time_to_seconds, CertificateError, create_default_context, - DER_cert_to_PEM_cert, get_default_verify_paths, match_hostname, - PEM_cert_to_DER_cert, Purpose, SSLEOFError, SSLError, SSLSyscallError, - SSLZeroReturnError, AlertDescription, SSLErrorNumber, SSLSession, - VerifyFlags, VerifyMode, Options + cert_time_to_seconds, + CertificateError, + create_default_context, + DER_cert_to_PEM_cert, + get_default_verify_paths, + match_hostname, + PEM_cert_to_DER_cert, + Purpose, + SSLEOFError, + SSLError, + SSLSyscallError, + SSLZeroReturnError, + AlertDescription, + SSLErrorNumber, + SSLSession, + VerifyFlags, + VerifyMode, + Options, ) # Added in python 3.7 @@ -35,50 +48,98 @@ # (Real import is below) try: from ssl import ( - AF_INET, ALERT_DESCRIPTION_ACCESS_DENIED, + AF_INET, + ALERT_DESCRIPTION_ACCESS_DENIED, ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE, ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE, - ALERT_DESCRIPTION_BAD_CERTIFICATE, ALERT_DESCRIPTION_BAD_RECORD_MAC, + ALERT_DESCRIPTION_BAD_CERTIFICATE, + ALERT_DESCRIPTION_BAD_RECORD_MAC, ALERT_DESCRIPTION_CERTIFICATE_EXPIRED, ALERT_DESCRIPTION_CERTIFICATE_REVOKED, ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN, ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE, - ALERT_DESCRIPTION_CLOSE_NOTIFY, ALERT_DESCRIPTION_DECODE_ERROR, + ALERT_DESCRIPTION_CLOSE_NOTIFY, + ALERT_DESCRIPTION_DECODE_ERROR, ALERT_DESCRIPTION_DECOMPRESSION_FAILURE, - ALERT_DESCRIPTION_DECRYPT_ERROR, ALERT_DESCRIPTION_HANDSHAKE_FAILURE, + ALERT_DESCRIPTION_DECRYPT_ERROR, + ALERT_DESCRIPTION_HANDSHAKE_FAILURE, ALERT_DESCRIPTION_ILLEGAL_PARAMETER, ALERT_DESCRIPTION_INSUFFICIENT_SECURITY, - ALERT_DESCRIPTION_INTERNAL_ERROR, 
ALERT_DESCRIPTION_NO_RENEGOTIATION, - ALERT_DESCRIPTION_PROTOCOL_VERSION, ALERT_DESCRIPTION_RECORD_OVERFLOW, - ALERT_DESCRIPTION_UNEXPECTED_MESSAGE, ALERT_DESCRIPTION_UNKNOWN_CA, + ALERT_DESCRIPTION_INTERNAL_ERROR, + ALERT_DESCRIPTION_NO_RENEGOTIATION, + ALERT_DESCRIPTION_PROTOCOL_VERSION, + ALERT_DESCRIPTION_RECORD_OVERFLOW, + ALERT_DESCRIPTION_UNEXPECTED_MESSAGE, + ALERT_DESCRIPTION_UNKNOWN_CA, ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY, ALERT_DESCRIPTION_UNRECOGNIZED_NAME, ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE, ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION, - ALERT_DESCRIPTION_USER_CANCELLED, CERT_NONE, CERT_OPTIONAL, - CERT_REQUIRED, CHANNEL_BINDING_TYPES, HAS_ALPN, HAS_ECDH, - HAS_NEVER_CHECK_COMMON_NAME, HAS_NPN, HAS_SNI, OP_ALL, - OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION, OP_COOKIE_EXCHANGE, - OP_DONT_INSERT_EMPTY_FRAGMENTS, OP_EPHEMERAL_RSA, - OP_LEGACY_SERVER_CONNECT, OP_MICROSOFT_BIG_SSLV3_BUFFER, - OP_MICROSOFT_SESS_ID_BUG, OP_MSIE_SSLV2_RSA_PADDING, - OP_NETSCAPE_CA_DN_BUG, OP_NETSCAPE_CHALLENGE_BUG, + ALERT_DESCRIPTION_USER_CANCELLED, + CERT_NONE, + CERT_OPTIONAL, + CERT_REQUIRED, + CHANNEL_BINDING_TYPES, + HAS_ALPN, + HAS_ECDH, + HAS_NEVER_CHECK_COMMON_NAME, + HAS_NPN, + HAS_SNI, + OP_ALL, + OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION, + OP_COOKIE_EXCHANGE, + OP_DONT_INSERT_EMPTY_FRAGMENTS, + OP_EPHEMERAL_RSA, + OP_LEGACY_SERVER_CONNECT, + OP_MICROSOFT_BIG_SSLV3_BUFFER, + OP_MICROSOFT_SESS_ID_BUG, + OP_MSIE_SSLV2_RSA_PADDING, + OP_NETSCAPE_CA_DN_BUG, + OP_NETSCAPE_CHALLENGE_BUG, OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG, - OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG, OP_NO_QUERY_MTU, OP_PKCS1_CHECK_1, - OP_PKCS1_CHECK_2, OP_SSLEAY_080_CLIENT_DH_BUG, - OP_SSLREF2_REUSE_CERT_TYPE_BUG, OP_TLS_BLOCK_PADDING_BUG, - OP_TLS_D5_BUG, OP_TLS_ROLLBACK_BUG, SSL_ERROR_NONE, - SSL_ERROR_NO_SOCKET, OP_CIPHER_SERVER_PREFERENCE, OP_NO_COMPRESSION, - OP_NO_RENEGOTIATION, OP_NO_TICKET, OP_SINGLE_DH_USE, - OP_SINGLE_ECDH_USE, OPENSSL_VERSION_INFO, OPENSSL_VERSION_NUMBER, - OPENSSL_VERSION, PEM_FOOTER, 
PEM_HEADER, PROTOCOL_TLS_CLIENT, - PROTOCOL_TLS_SERVER, PROTOCOL_TLS, SO_TYPE, SOCK_STREAM, SOL_SOCKET, - SSL_ERROR_EOF, SSL_ERROR_INVALID_ERROR_CODE, SSL_ERROR_SSL, - SSL_ERROR_SYSCALL, SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_READ, - SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_X509_LOOKUP, - SSL_ERROR_ZERO_RETURN, VERIFY_CRL_CHECK_CHAIN, VERIFY_CRL_CHECK_LEAF, - VERIFY_DEFAULT, VERIFY_X509_STRICT, VERIFY_X509_TRUSTED_FIRST, - OP_ENABLE_MIDDLEBOX_COMPAT + OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG, + OP_NO_QUERY_MTU, + OP_PKCS1_CHECK_1, + OP_PKCS1_CHECK_2, + OP_SSLEAY_080_CLIENT_DH_BUG, + OP_SSLREF2_REUSE_CERT_TYPE_BUG, + OP_TLS_BLOCK_PADDING_BUG, + OP_TLS_D5_BUG, + OP_TLS_ROLLBACK_BUG, + SSL_ERROR_NONE, + SSL_ERROR_NO_SOCKET, + OP_CIPHER_SERVER_PREFERENCE, + OP_NO_COMPRESSION, + OP_NO_RENEGOTIATION, + OP_NO_TICKET, + OP_SINGLE_DH_USE, + OP_SINGLE_ECDH_USE, + OPENSSL_VERSION_INFO, + OPENSSL_VERSION_NUMBER, + OPENSSL_VERSION, + PEM_FOOTER, + PEM_HEADER, + PROTOCOL_TLS_CLIENT, + PROTOCOL_TLS_SERVER, + PROTOCOL_TLS, + SO_TYPE, + SOCK_STREAM, + SOL_SOCKET, + SSL_ERROR_EOF, + SSL_ERROR_INVALID_ERROR_CODE, + SSL_ERROR_SSL, + SSL_ERROR_SYSCALL, + SSL_ERROR_WANT_CONNECT, + SSL_ERROR_WANT_READ, + SSL_ERROR_WANT_WRITE, + SSL_ERROR_WANT_X509_LOOKUP, + SSL_ERROR_ZERO_RETURN, + VERIFY_CRL_CHECK_CHAIN, + VERIFY_CRL_CHECK_LEAF, + VERIFY_DEFAULT, + VERIFY_X509_STRICT, + VERIFY_X509_TRUSTED_FIRST, + OP_ENABLE_MIDDLEBOX_COMPAT, ) except ImportError: pass @@ -86,10 +147,11 @@ # Dynamically re-export whatever constants this particular Python happens to # have: import ssl as _stdlib_ssl + globals().update( { _name: getattr(_stdlib_ssl, _name) for _name in _stdlib_ssl.__dict__.keys() - if _name.isupper() and not _name.startswith('_') + if _name.isupper() and not _name.startswith("_") } ) diff --git a/trio/_deprecated_subprocess_reexports.py b/trio/_deprecated_subprocess_reexports.py index b91e28784a..2d1e4eed25 100644 --- a/trio/_deprecated_subprocess_reexports.py +++ 
b/trio/_deprecated_subprocess_reexports.py @@ -2,16 +2,27 @@ # Reexport constants and exceptions from the stdlib subprocess module from subprocess import ( - PIPE, STDOUT, DEVNULL, CalledProcessError, SubprocessError, TimeoutExpired, - CompletedProcess + PIPE, + STDOUT, + DEVNULL, + CalledProcessError, + SubprocessError, + TimeoutExpired, + CompletedProcess, ) # Windows only try: from subprocess import ( - STARTUPINFO, STD_INPUT_HANDLE, STD_OUTPUT_HANDLE, STD_ERROR_HANDLE, - SW_HIDE, STARTF_USESTDHANDLES, STARTF_USESHOWWINDOW, - CREATE_NEW_CONSOLE, CREATE_NEW_PROCESS_GROUP + STARTUPINFO, + STD_INPUT_HANDLE, + STD_OUTPUT_HANDLE, + STD_ERROR_HANDLE, + SW_HIDE, + STARTF_USESTDHANDLES, + STARTF_USESHOWWINDOW, + CREATE_NEW_CONSOLE, + CREATE_NEW_PROCESS_GROUP, ) except ImportError: pass @@ -19,10 +30,16 @@ # Windows 3.7+ only try: from subprocess import ( - ABOVE_NORMAL_PRIORITY_CLASS, BELOW_NORMAL_PRIORITY_CLASS, - HIGH_PRIORITY_CLASS, IDLE_PRIORITY_CLASS, NORMAL_PRIORITY_CLASS, - REALTIME_PRIORITY_CLASS, CREATE_NO_WINDOW, DETACHED_PROCESS, - CREATE_DEFAULT_ERROR_MODE, CREATE_BREAKAWAY_FROM_JOB + ABOVE_NORMAL_PRIORITY_CLASS, + BELOW_NORMAL_PRIORITY_CLASS, + HIGH_PRIORITY_CLASS, + IDLE_PRIORITY_CLASS, + NORMAL_PRIORITY_CLASS, + REALTIME_PRIORITY_CLASS, + CREATE_NO_WINDOW, + DETACHED_PROCESS, + CREATE_DEFAULT_ERROR_MODE, + CREATE_BREAKAWAY_FROM_JOB, ) except ImportError: pass diff --git a/trio/_file_io.py b/trio/_file_io.py index 1d3508875e..15e45711c5 100644 --- a/trio/_file_io.py +++ b/trio/_file_io.py @@ -8,43 +8,43 @@ # This list is also in the docs, make sure to keep them in sync _FILE_SYNC_ATTRS = { - 'closed', - 'encoding', - 'errors', - 'fileno', - 'isatty', - 'newlines', - 'readable', - 'seekable', - 'writable', + "closed", + "encoding", + "errors", + "fileno", + "isatty", + "newlines", + "readable", + "seekable", + "writable", # not defined in *IOBase: - 'buffer', - 'raw', - 'line_buffering', - 'closefd', - 'name', - 'mode', - 'getvalue', - 'getbuffer', + 
"buffer", + "raw", + "line_buffering", + "closefd", + "name", + "mode", + "getvalue", + "getbuffer", } # This list is also in the docs, make sure to keep them in sync _FILE_ASYNC_METHODS = { - 'flush', - 'read', - 'read1', - 'readall', - 'readinto', - 'readline', - 'readlines', - 'seek', - 'tell', - 'truncate', - 'write', - 'writelines', + "flush", + "read", + "read1", + "readall", + "readinto", + "readline", + "readlines", + "seek", + "tell", + "truncate", + "write", + "writelines", # not defined in *IOBase: - 'readinto1', - 'peek', + "readinto1", + "peek", } @@ -57,6 +57,7 @@ class AsyncIOWrapper(AsyncResource): wrapper, if they exist in the wrapped file object. """ + def __init__(self, file): self._wrapped = file @@ -88,9 +89,7 @@ async def wrapper(*args, **kwargs): def __dir__(self): attrs = set(super().__dir__()) attrs.update(a for a in _FILE_SYNC_ATTRS if hasattr(self.wrapped, a)) - attrs.update( - a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a) - ) + attrs.update(a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a)) return attrs def __aiter__(self): @@ -131,13 +130,13 @@ async def aclose(self): async def open_file( file, - mode='r', + mode="r", buffering=-1, encoding=None, errors=None, newline=None, closefd=True, - opener=None + opener=None, ): """Asynchronous version of :func:`io.open`. 
@@ -158,8 +157,7 @@ async def open_file( """ _file = wrap_file( await trio.to_thread.run_sync( - io.open, file, mode, buffering, encoding, errors, newline, closefd, - opener + io.open, file, mode, buffering, encoding, errors, newline, closefd, opener, ) ) return _file @@ -182,13 +180,14 @@ def wrap_file(file): assert await async_file.read() == 'asdf' """ + def has(attr): return hasattr(file, attr) and callable(getattr(file, attr)) - if not (has('close') and (has('read') or has('write'))): + if not (has("close") and (has("read") or has("write"))): raise TypeError( - '{} does not implement required duck-file methods: ' - 'close and (read or write)'.format(file) + "{} does not implement required duck-file methods: " + "close and (read or write)".format(file) ) return AsyncIOWrapper(file) diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index 7cdc0c75d1..d4091942db 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -37,9 +37,7 @@ async def aclose_forcefully(resource): @attr.s(eq=False, hash=False) -class StapledStream( - HalfCloseableStream, metaclass=SubclassingDeprecatedIn_v0_15_0 -): +class StapledStream(HalfCloseableStream, metaclass=SubclassingDeprecatedIn_v0_15_0): """This class `staples `__ together two unidirectional streams to make single bidirectional stream. @@ -73,6 +71,7 @@ class StapledStream( is delegated to this object. """ + send_stream = attr.ib() receive_stream = attr.ib() diff --git a/trio/_highlevel_open_tcp_listeners.py b/trio/_highlevel_open_tcp_listeners.py index 6ac44db43b..8e399a61b5 100644 --- a/trio/_highlevel_open_tcp_listeners.py +++ b/trio/_highlevel_open_tcp_listeners.py @@ -39,7 +39,7 @@ def _compute_backlog(backlog): # Many systems (Linux, BSDs, ...) store the backlog in a uint16 and are # missing overflow protection, so we apply our own overflow protection. 
# https://github.com/golang/go/issues/5030 - return min(backlog, 0xffff) + return min(backlog, 0xFFFF) async def open_tcp_listeners(port, *, host=None, backlog=None): @@ -92,10 +92,7 @@ async def open_tcp_listeners(port, *, host=None, backlog=None): backlog = _compute_backlog(backlog) addresses = await tsocket.getaddrinfo( - host, - port, - type=tsocket.SOCK_STREAM, - flags=tsocket.AI_PASSIVE, + host, port, type=tsocket.SOCK_STREAM, flags=tsocket.AI_PASSIVE, ) listeners = [] @@ -119,14 +116,10 @@ async def open_tcp_listeners(port, *, host=None, backlog=None): try: # See https://github.com/python-trio/trio/issues/39 if sys.platform != "win32": - sock.setsockopt( - tsocket.SOL_SOCKET, tsocket.SO_REUSEADDR, 1 - ) + sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_REUSEADDR, 1) if family == tsocket.AF_INET6: - sock.setsockopt( - tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, 1 - ) + sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, 1) await sock.bind(sockaddr) sock.listen(backlog) @@ -144,7 +137,7 @@ async def open_tcp_listeners(port, *, host=None, backlog=None): raise OSError( errno.EAFNOSUPPORT, "This system doesn't support any of the kinds of " - "socket that that address could use" + "socket that that address could use", ) from trio.MultiError(unsupported_address_families) return listeners @@ -157,7 +150,7 @@ async def serve_tcp( host=None, backlog=None, handler_nursery=None, - task_status=trio.TASK_STATUS_IGNORED + task_status=trio.TASK_STATUS_IGNORED, ): """Listen for incoming TCP connections, and for each one start a task running ``handler(stream)``. 
@@ -224,8 +217,5 @@ async def serve_tcp( """ listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog) await trio.serve_listeners( - handler, - listeners, - handler_nursery=handler_nursery, - task_status=task_status + handler, listeners, handler_nursery=handler_nursery, task_status=task_status, ) diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py index 11a1e1846f..847fac6a96 100644 --- a/trio/_highlevel_open_tcp_stream.py +++ b/trio/_highlevel_open_tcp_stream.py @@ -169,7 +169,7 @@ async def open_tcp_stream( port, *, # No trailing comma b/c bpo-9232 (fixed in py36) - happy_eyeballs_delay=DEFAULT_DELAY + happy_eyeballs_delay=DEFAULT_DELAY, ): """Connect to the given host and port over TCP. diff --git a/trio/_highlevel_open_unix_stream.py b/trio/_highlevel_open_unix_stream.py index 87462c1232..eea1e77ffc 100644 --- a/trio/_highlevel_open_unix_stream.py +++ b/trio/_highlevel_open_unix_stream.py @@ -6,6 +6,7 @@ try: from trio.socket import AF_UNIX + has_unix = True except ImportError: has_unix = False diff --git a/trio/_highlevel_serve_listeners.py b/trio/_highlevel_serve_listeners.py index 0a46be780a..55216c85c5 100644 --- a/trio/_highlevel_serve_listeners.py +++ b/trio/_highlevel_serve_listeners.py @@ -39,7 +39,7 @@ async def _serve_one_listener(listener, handler_nursery, handler): errno.errorcode[exc.errno], os.strerror(exc.errno), SLEEP_TIME, - exc_info=True + exc_info=True, ) await trio.sleep(SLEEP_TIME) else: @@ -49,11 +49,7 @@ async def _serve_one_listener(listener, handler_nursery, handler): async def serve_listeners( - handler, - listeners, - *, - handler_nursery=None, - task_status=trio.TASK_STATUS_IGNORED + handler, listeners, *, handler_nursery=None, task_status=trio.TASK_STATUS_IGNORED, ): r"""Listen for incoming connections on ``listeners``, and for each one start a task running ``handler(stream)``. 
@@ -118,9 +114,7 @@ async def serve_listeners( if handler_nursery is None: handler_nursery = nursery for listener in listeners: - nursery.start_soon( - _serve_one_listener, listener, handler_nursery, handler - ) + nursery.start_soon(_serve_one_listener, listener, handler_nursery, handler) # The listeners are already queueing connections when we're called, # but we wait until the end to call started() just in case we get an # error or whatever. diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index abcf951028..a707f0cac6 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -28,18 +28,14 @@ def _translate_socket_errors_to_stream_errors(): yield except OSError as exc: if exc.errno in _closed_stream_errnos: - raise trio.ClosedResourceError( - "this socket was already closed" - ) from None + raise trio.ClosedResourceError("this socket was already closed") from None else: raise trio.BrokenResourceError( "socket connection broken: {}".format(exc) ) from exc -class SocketStream( - HalfCloseableStream, metaclass=SubclassingDeprecatedIn_v0_15_0 -): +class SocketStream(HalfCloseableStream, metaclass=SubclassingDeprecatedIn_v0_15_0): """An implementation of the :class:`trio.abc.HalfCloseableStream` interface based on a raw network socket. @@ -62,6 +58,7 @@ class SocketStream( The Trio socket object that this stream wraps. """ + def __init__(self, socket): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires a Trio socket object") @@ -92,9 +89,7 @@ def __init__(self, socket): # http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1 # ). The theory is that you want it to be bandwidth * # rescheduling interval. 
- self.setsockopt( - tsocket.IPPROTO_TCP, tsocket.TCP_NOTSENT_LOWAT, 2**14 - ) + self.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NOTSENT_LOWAT, 2 ** 14) except OSError: pass @@ -106,9 +101,7 @@ async def send_all(self, data): with memoryview(data) as data: if not data: if self.socket.fileno() == -1: - raise trio.ClosedResourceError( - "socket was already closed" - ) + raise trio.ClosedResourceError("socket was already closed") await trio.lowlevel.checkpoint() return total_sent = 0 @@ -322,9 +315,7 @@ def getsockopt(self, level, option, buffersize=0): pass -class SocketListener( - Listener[SocketStream], metaclass=SubclassingDeprecatedIn_v0_15_0 -): +class SocketListener(Listener[SocketStream], metaclass=SubclassingDeprecatedIn_v0_15_0): """A :class:`~trio.abc.Listener` that uses a listening socket to accept incoming connections as :class:`SocketStream` objects. @@ -340,15 +331,14 @@ class SocketListener( The Trio socket object that this stream wraps. """ + def __init__(self, socket): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: raise ValueError("SocketListener requires a SOCK_STREAM socket") try: - listening = socket.getsockopt( - tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN - ) + listening = socket.getsockopt(tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN) except OSError: # SO_ACCEPTCONN fails on macOS; we just have to trust the user. pass diff --git a/trio/_highlevel_ssl_helpers.py b/trio/_highlevel_ssl_helpers.py index fc604d9286..61c463c65e 100644 --- a/trio/_highlevel_ssl_helpers.py +++ b/trio/_highlevel_ssl_helpers.py @@ -20,7 +20,7 @@ async def open_ssl_over_tcp_stream( https_compatible=False, ssl_context=None, # No trailing comma b/c bpo-9232 (fixed in py36) - happy_eyeballs_delay=DEFAULT_DELAY + happy_eyeballs_delay=DEFAULT_DELAY, ): """Make a TLS-encrypted Connection to the given host and port over TCP. 
@@ -49,9 +49,7 @@ async def open_ssl_over_tcp_stream( """ tcp_stream = await trio.open_tcp_stream( - host, - port, - happy_eyeballs_delay=happy_eyeballs_delay, + host, port, happy_eyeballs_delay=happy_eyeballs_delay, ) if ssl_context is None: ssl_context = ssl.create_default_context() @@ -78,15 +76,10 @@ async def open_ssl_over_tcp_listeners( backlog (int or None): See :func:`open_tcp_listeners` for details. """ - tcp_listeners = await trio.open_tcp_listeners( - port, host=host, backlog=backlog - ) + tcp_listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog) ssl_listeners = [ - trio.SSLListener( - tcp_listener, - ssl_context, - https_compatible=https_compatible, - ) for tcp_listener in tcp_listeners + trio.SSLListener(tcp_listener, ssl_context, https_compatible=https_compatible,) + for tcp_listener in tcp_listeners ] return ssl_listeners @@ -100,7 +93,7 @@ async def serve_ssl_over_tcp( https_compatible=False, backlog=None, handler_nursery=None, - task_status=trio.TASK_STATUS_IGNORED + task_status=trio.TASK_STATUS_IGNORED, ): """Listen for incoming TCP connections, and for each one start a task running ``handler(stream)``. 
@@ -154,11 +147,8 @@ async def serve_ssl_over_tcp( ssl_context, host=host, https_compatible=https_compatible, - backlog=backlog + backlog=backlog, ) await trio.serve_listeners( - handler, - listeners, - handler_nursery=handler_nursery, - task_status=task_status + handler, listeners, handler_nursery=handler_nursery, task_status=task_status, ) diff --git a/trio/_path.py b/trio/_path.py index 75f1fab615..095bdc779f 100644 --- a/trio/_path.py +++ b/trio/_path.py @@ -88,7 +88,7 @@ def __init__(cls, name, bases, attrs): def generate_forwards(cls, attrs): # forward functions of _forwards for attr_name, attr in cls._forwards.__dict__.items(): - if attr_name.startswith('_') or attr_name in attrs: + if attr_name.startswith("_") or attr_name in attrs: continue if isinstance(attr, property): @@ -103,7 +103,7 @@ def generate_wraps(cls, attrs): # generate wrappers for functions of _wraps for attr_name, attr in cls._wraps.__dict__.items(): # .z. exclude cls._wrap_iter - if attr_name.startswith('_') or attr_name in attrs: + if attr_name.startswith("_") or attr_name in attrs: continue if isinstance(attr, classmethod): wrapper = classmethod_wrapper_factory(cls, attr_name) @@ -138,18 +138,18 @@ class Path(metaclass=AsyncAutoWrapperType): _wraps = pathlib.Path _forwards = pathlib.PurePath _forward_magic = [ - '__str__', - '__bytes__', - '__truediv__', - '__rtruediv__', - '__eq__', - '__lt__', - '__le__', - '__gt__', - '__ge__', - '__hash__', + "__str__", + "__bytes__", + "__truediv__", + "__rtruediv__", + "__eq__", + "__lt__", + "__le__", + "__gt__", + "__ge__", + "__hash__", ] - _wrap_iter = ['glob', 'rglob', 'iterdir'] + _wrap_iter = ["glob", "rglob", "iterdir"] def __init__(self, *args): self._wrapped = pathlib.Path(*args) @@ -164,7 +164,7 @@ def __dir__(self): return super().__dir__() + self._forward def __repr__(self): - return 'trio.Path({})'.format(repr(str(self))) + return "trio.Path({})".format(repr(str(self))) def __fspath__(self): return os.fspath(self._wrapped) diff --git 
a/trio/_socket.py b/trio/_socket.py index 5518335b44..ccd2a6ff20 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -150,8 +150,10 @@ async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): # with the _NUMERIC_ONLY flags set, and then only spawn a thread if that # fails with EAI_NONAME: def numeric_only_failure(exc): - return isinstance(exc, _stdlib_socket.gaierror) and \ - exc.errno == _stdlib_socket.EAI_NONAME + return ( + isinstance(exc, _stdlib_socket.gaierror) + and exc.errno == _stdlib_socket.EAI_NONAME + ) async with _try_sync(numeric_only_failure): return _stdlib_socket.getaddrinfo( @@ -185,7 +187,7 @@ def numeric_only_failure(exc): type, proto, flags, - cancellable=True + cancellable=True, ) @@ -266,7 +268,7 @@ def socket( family=_stdlib_socket.AF_INET, type=_stdlib_socket.SOCK_STREAM, proto=0, - fileno=None + fileno=None, ): """Create a new Trio socket, like :func:`socket.socket`. @@ -279,9 +281,7 @@ def socket( if sf is not None: return sf.socket(family, type, proto) else: - family, type, proto = _sniff_sockopts_for_fileno( - family, type, proto, fileno - ) + family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fileno) stdlib_socket = _stdlib_socket.socket(family, type, proto, fileno) return from_stdlib_socket(stdlib_socket) @@ -296,6 +296,7 @@ def _sniff_sockopts_for_fileno(family, type, proto, fileno): if not _sys.platform == "linux": return family, type, proto from socket import SO_DOMAIN, SO_PROTOCOL, SOL_SOCKET, SO_TYPE + sockobj = _stdlib_socket.socket(family, type, proto, fileno=fileno) try: family = sockobj.getsockopt(SOL_SOCKET, SO_DOMAIN) @@ -338,15 +339,13 @@ def _make_simple_sock_method_wrapper(methname, wait_fn, maybe_avail=False): async def wrapper(self, *args, **kwargs): return await self._nonblocking_helper(fn, args, kwargs, wait_fn) - wrapper.__doc__ = ( - """Like :meth:`socket.socket.{}`, but async. + wrapper.__doc__ = f"""Like :meth:`socket.socket.{methname}`, but async. 
- """.format(methname) - ) + """ if maybe_avail: wrapper.__doc__ += ( - "Only available on platforms where :meth:`socket.socket.{}` " - "is available.".format(methname) + f"Only available on platforms where :meth:`socket.socket.{methname}` is " + "available." ) return wrapper @@ -451,7 +450,8 @@ async def bind(self, address): address = await self._resolve_local_address(address) if ( hasattr(_stdlib_socket, "AF_UNIX") - and self.family == _stdlib_socket.AF_UNIX and address[0] + and self.family == _stdlib_socket.AF_UNIX + and address[0] ): # Use a thread for the filesystem traversal (unless it's an # abstract domain socket) @@ -497,8 +497,7 @@ async def _resolve_address(self, address, flags): elif self._sock.family == _stdlib_socket.AF_INET6: if not isinstance(address, tuple) or not 2 <= len(address) <= 4: raise ValueError( - "address should be a (host, port, [flowinfo, [scopeid]]) " - "tuple" + "address should be a (host, port, [flowinfo, [scopeid]]) tuple" ) elif self._sock.family == _stdlib_socket.AF_UNIX: await trio.lowlevel.checkpoint() @@ -522,9 +521,7 @@ async def _resolve_address(self, address, flags): # no ipv6. 
# flags |= AI_ADDRCONFIG if self._sock.family == _stdlib_socket.AF_INET6: - if not self._sock.getsockopt( - IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY - ): + if not self._sock.getsockopt(IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY): flags |= _stdlib_socket.AI_V4MAPPED gai_res = await getaddrinfo( host, port, self._sock.family, self.type, self._sock.proto, flags @@ -669,9 +666,7 @@ async def connect(self, address): self._sock.close() raise # Okay, the connect finished, but it might have failed: - err = self._sock.getsockopt( - _stdlib_socket.SOL_SOCKET, _stdlib_socket.SO_ERROR - ) + err = self._sock.getsockopt(_stdlib_socket.SOL_SOCKET, _stdlib_socket.SO_ERROR) if err != 0: raise OSError(err, "Error in connect: " + _os.strerror(err)) @@ -685,17 +680,13 @@ async def connect(self, address): # recv_into ################################################################ - recv_into = _make_simple_sock_method_wrapper( - "recv_into", _core.wait_readable - ) + recv_into = _make_simple_sock_method_wrapper("recv_into", _core.wait_readable) ################################################################ # recvfrom ################################################################ - recvfrom = _make_simple_sock_method_wrapper( - "recvfrom", _core.wait_readable - ) + recvfrom = _make_simple_sock_method_wrapper("recvfrom", _core.wait_readable) ################################################################ # recvfrom_into diff --git a/trio/_ssl.py b/trio/_ssl.py index 12182c58af..41242f45c2 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -328,14 +328,12 @@ def __init__( server_hostname=None, server_side=False, https_compatible=False, - max_refill_bytes="unused and deprecated" + max_refill_bytes="unused and deprecated", ): self.transport_stream = transport_stream self._state = _State.OK if max_refill_bytes != "unused and deprecated": - warn_deprecated( - "max_refill_bytes=...", "0.12.0", issue=959, instead=None - ) + warn_deprecated("max_refill_bytes=...", "0.12.0", issue=959, 
instead=None) self._https_compatible = https_compatible self._outgoing = _stdlib_ssl.MemoryBIO() self._delayed_outgoing = None @@ -344,7 +342,7 @@ def __init__( self._incoming, self._outgoing, server_side=server_side, - server_hostname=server_hostname + server_hostname=server_hostname, ) # Tracks whether we've already done the initial handshake self._handshook = _Once(self._do_handshake) @@ -429,9 +427,7 @@ def _check_status(self): # comments, though, just make sure to think carefully if you ever have to # touch it. The big comment at the top of this file will help explain # too. - async def _retry( - self, fn, *args, ignore_want_read=False, is_handshake=False - ): + async def _retry(self, fn, *args, ignore_want_read=False, is_handshake=False): await trio.lowlevel.checkpoint_if_cancelled() yielded = False finished = False @@ -495,7 +491,9 @@ async def _retry( # # https://github.com/python-trio/trio/issues/819#issuecomment-517529763 if ( - is_handshake and not want_read and self._ssl_object.server_side + is_handshake + and not want_read + and self._ssl_object.server_side and self._ssl_object.version() == "TLSv1.3" ): assert self._delayed_outgoing is None @@ -664,11 +662,9 @@ async def receive_some(self, max_bytes=None): # For some reason, EOF before handshake sometimes raises # SSLSyscallError instead of SSLEOFError (e.g. on my linux # laptop, but not on appveyor). Thanks openssl. - if ( - self._https_compatible and isinstance( - exc.__cause__, - (_stdlib_ssl.SSLEOFError, _stdlib_ssl.SSLSyscallError) - ) + if self._https_compatible and isinstance( + exc.__cause__, + (_stdlib_ssl.SSLEOFError, _stdlib_ssl.SSLSyscallError), ): await trio.lowlevel.checkpoint() return b"" @@ -678,9 +674,7 @@ async def receive_some(self, max_bytes=None): # If we somehow have more data already in our pending buffer # than the estimate receive size, bump up our size a bit for # this read only. 
- max_bytes = max( - self._estimated_receive_size, self._incoming.pending - ) + max_bytes = max(self._estimated_receive_size, self._incoming.pending) else: max_bytes = _operator.index(max_bytes) if max_bytes < 1: @@ -693,9 +687,8 @@ async def receive_some(self, max_bytes=None): # BROKEN. But that's actually fine, because after getting an # EOF on TLS then the only thing you can do is close the # stream, and closing doesn't care about the state. - if ( - self._https_compatible - and isinstance(exc.__cause__, _stdlib_ssl.SSLEOFError) + if self._https_compatible and isinstance( + exc.__cause__, _stdlib_ssl.SSLEOFError ): await trio.lowlevel.checkpoint() return b"" @@ -740,8 +733,7 @@ async def unwrap(self): ``transport_stream.receive_some(...)``. """ - with self._outer_recv_conflict_detector, \ - self._outer_send_conflict_detector: + with self._outer_recv_conflict_detector, self._outer_send_conflict_detector: self._check_status() await self._handshook.ensure(checkpoint=False) await self._retry(self._ssl_object.unwrap) @@ -824,9 +816,7 @@ async def aclose(self): # going to be able to do a clean shutdown. If that happens, we'll # just do an unclean shutdown. try: - await self._retry( - self._ssl_object.unwrap, ignore_want_read=True - ) + await self._retry(self._ssl_object.unwrap, ignore_want_read=True) except (trio.BrokenResourceError, trio.BusyResourceError): pass except: @@ -882,9 +872,7 @@ async def wait_send_all_might_not_block(self): await self.transport_stream.wait_send_all_might_not_block() -class SSLListener( - Listener[SSLStream], metaclass=SubclassingDeprecatedIn_v0_15_0 -): +class SSLListener(Listener[SSLStream], metaclass=SubclassingDeprecatedIn_v0_15_0): """A :class:`~trio.abc.Listener` for SSL/TLS-encrypted servers. :class:`SSLListener` wraps around another Listener, and converts @@ -905,18 +893,17 @@ class SSLListener( passed to ``__init__``. 
""" + def __init__( self, transport_listener, ssl_context, *, https_compatible=False, - max_refill_bytes="unused and deprecated" + max_refill_bytes="unused and deprecated", ): if max_refill_bytes != "unused and deprecated": - warn_deprecated( - "max_refill_bytes=...", "0.12.0", issue=959, instead=None - ) + warn_deprecated("max_refill_bytes=...", "0.12.0", issue=959, instead=None) self.transport_listener = transport_listener self._ssl_context = ssl_context self._https_compatible = https_compatible diff --git a/trio/_subprocess.py b/trio/_subprocess.py index 16a869f8ef..ac51915eb0 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -9,8 +9,9 @@ from ._highlevel_generic import StapledStream from ._sync import Lock from ._subprocess_platform import ( - wait_child_exiting, create_pipe_to_child_stdin, - create_pipe_from_child_output + wait_child_exiting, + create_pipe_to_child_stdin, + create_pipe_from_child_output, ) from ._util import NoPublicConstructor import trio @@ -23,6 +24,7 @@ except ImportError: if sys.platform == "linux": import ctypes + _cdll_for_pidfd_open = ctypes.CDLL(None, use_errno=True) _cdll_for_pidfd_open.syscall.restype = ctypes.c_long # pid and flags are actually int-sized, but the syscall() function @@ -45,6 +47,7 @@ def pidfd_open(fd, flags): err = ctypes.get_errno() raise OSError(err, os.strerror(err)) return result + else: can_try_pidfd_open = False @@ -312,12 +315,11 @@ async def open_process( specified command could not be found. 
""" - for key in ('universal_newlines', 'text', 'encoding', 'errors', 'bufsize'): + for key in ("universal_newlines", "text", "encoding", "errors", "bufsize"): if options.get(key): raise TypeError( "trio.Process only supports communicating over " - "unbuffered byte streams; the '{}' option is not supported" - .format(key) + "unbuffered byte streams; the '{}' option is not supported".format(key) ) if os.name == "posix": @@ -361,7 +363,7 @@ async def open_process( stdin=stdin, stdout=stdout, stderr=stderr, - **options + **options, ) ) finally: @@ -382,9 +384,7 @@ async def _windows_deliver_cancel(p): try: p.terminate() except OSError as exc: - warnings.warn( - RuntimeWarning(f"TerminateProcess on {p!r} failed with: {exc!r}") - ) + warnings.warn(RuntimeWarning(f"TerminateProcess on {p!r} failed with: {exc!r}")) async def _posix_deliver_cancel(p): @@ -401,9 +401,7 @@ async def _posix_deliver_cancel(p): p.kill() except OSError as exc: warnings.warn( - RuntimeWarning( - f"tried to kill process {p!r}, but failed with: {exc!r}" - ) + RuntimeWarning(f"tried to kill process {p!r}, but failed with: {exc!r}") ) @@ -415,7 +413,7 @@ async def run_process( capture_stderr=False, check=True, deliver_cancel=None, - **options + **options, ): """Run ``command`` in a subprocess, wait for it to complete, and return a :class:`subprocess.CompletedProcess` instance describing @@ -638,6 +636,4 @@ async def killer(): proc.returncode, proc.args, output=stdout, stderr=stderr ) else: - return subprocess.CompletedProcess( - proc.args, proc.returncode, stdout, stderr - ) + return subprocess.CompletedProcess(proc.args, proc.returncode, stdout, stderr) diff --git a/trio/_subprocess_platform/kqueue.py b/trio/_subprocess_platform/kqueue.py index 837b556fed..17e2df5c6f 100644 --- a/trio/_subprocess_platform/kqueue.py +++ b/trio/_subprocess_platform/kqueue.py @@ -13,16 +13,11 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: KQ_NOTE_EXIT = 0x80000000 make_event = lambda flags: 
select.kevent( - process.pid, - filter=select.KQ_FILTER_PROC, - flags=flags, - fflags=KQ_NOTE_EXIT + process.pid, filter=select.KQ_FILTER_PROC, flags=flags, fflags=KQ_NOTE_EXIT, ) try: - kqueue.control( - [make_event(select.KQ_EV_ADD | select.KQ_EV_ONESHOT)], 0 - ) + kqueue.control([make_event(select.KQ_EV_ADD | select.KQ_EV_ONESHOT)], 0) except ProcessLookupError: # pragma: no cover # This can supposedly happen if the process is in the process # of exiting, and it can even be the case that kqueue says the diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py index 030c546f88..81fdf88884 100644 --- a/trio/_subprocess_platform/waitid.py +++ b/trio/_subprocess_platform/waitid.py @@ -13,10 +13,12 @@ def sync_wait_reapable(pid): waitid(os.P_PID, pid, os.WEXITED | os.WNOWAIT) + except ImportError: # pypy doesn't define os.waitid so we need to pull it out ourselves # using cffi: https://bitbucket.org/pypy/pypy/issues/2922/ import cffi + waitid_ffi = cffi.FFI() # Believe it or not, siginfo_t starts with fields in the @@ -43,7 +45,7 @@ def sync_wait_reapable(pid): def sync_wait_reapable(pid): P_PID = 1 WEXITED = 0x00000004 - if sys.platform == 'darwin': # pragma: no cover + if sys.platform == "darwin": # pragma: no cover # waitid() is not exposed on Python on Darwin but does # work through CFFI; note that we typically won't get # here since Darwin also defines kqueue @@ -75,10 +77,7 @@ async def _waitid_system_task(pid: int, event: Event) -> None: try: await to_thread_run_sync( - sync_wait_reapable, - pid, - cancellable=True, - limiter=waitid_limiter, + sync_wait_reapable, pid, cancellable=True, limiter=waitid_limiter, ) except OSError: # If waitid fails, waitpid will fail too, so it still makes diff --git a/trio/_sync.py b/trio/_sync.py index 068a5684f4..9c4fad3a18 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -156,6 +156,7 @@ class CapacityLimiter(metaclass=SubclassingDeprecatedIn_v0_15_0): just borrowed and then put back. 
""" + def __init__(self, total_tokens): self._lot = ParkingLot() self._borrowers = set() @@ -166,11 +167,8 @@ def __init__(self, total_tokens): assert self._total_tokens == total_tokens def __repr__(self): - return ( - "".format( - id(self), len(self._borrowers), self._total_tokens, - len(self._lot) - ) + return "".format( + id(self), len(self._borrowers), self._total_tokens, len(self._lot) ) @property @@ -190,9 +188,7 @@ def total_tokens(self): @total_tokens.setter def total_tokens(self, new_total_tokens): - if not isinstance( - new_total_tokens, int - ) and new_total_tokens != math.inf: + if not isinstance(new_total_tokens, int) and new_total_tokens != math.inf: raise TypeError("total_tokens must be an int or math.inf") if new_total_tokens < 1: raise ValueError("total_tokens must be >= 1") @@ -321,8 +317,7 @@ def release_on_behalf_of(self, borrower): """ if borrower not in self._borrowers: raise RuntimeError( - "this borrower isn't holding any of this CapacityLimiter's " - "tokens" + "this borrower isn't holding any of this CapacityLimiter's tokens" ) self._borrowers.remove(borrower) self._wake_waiters() @@ -381,6 +376,7 @@ class Semaphore(metaclass=SubclassingDeprecatedIn_v0_15_0): ``max_value``. 
""" + def __init__(self, initial_value, *, max_value=None): if not isinstance(initial_value, int): raise TypeError("initial_value must be an int") @@ -404,10 +400,8 @@ def __repr__(self): max_value_str = "" else: max_value_str = ", max_value={}".format(self._max_value) - return ( - "".format( - self._value, max_value_str, id(self) - ) + return "".format( + self._value, max_value_str, id(self) ) @property @@ -502,10 +496,8 @@ def __repr__(self): else: s1 = "unlocked" s2 = "" - return ( - "<{} {} object at {:#x}{}>".format( - s1, self.__class__.__name__, id(self), s2 - ) + return "<{} {} object at {:#x}{}>".format( + s1, self.__class__.__name__, id(self), s2 ) def locked(self): @@ -580,9 +572,7 @@ def statistics(self): """ return _LockStatistics( - locked=self.locked(), - owner=self._owner, - tasks_waiting=len(self._lot), + locked=self.locked(), owner=self._owner, tasks_waiting=len(self._lot), ) @@ -684,6 +674,7 @@ class Condition(metaclass=SubclassingDeprecatedIn_v0_15_0): and used. """ + def __init__(self, lock=None): if lock is None: lock = Lock() @@ -795,6 +786,5 @@ def statistics(self): """ return _ConditionStatistics( - tasks_waiting=len(self._lot), - lock_statistics=self._lock.statistics(), + tasks_waiting=len(self._lot), lock_statistics=self._lock.statistics(), ) diff --git a/trio/_threads.py b/trio/_threads.py index c03a353789..92e2b5dc0f 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -9,7 +9,12 @@ import trio from ._sync import CapacityLimiter -from ._core import enable_ki_protection, disable_ki_protection, RunVar, TrioToken +from ._core import ( + enable_ki_protection, + disable_ki_protection, + RunVar, + TrioToken, +) from ._util import coroutine_or_error # Global due to Threading API, thread local storage for trio token @@ -291,10 +296,7 @@ def worker_thread_fn(trio_token): # this case shouldn't block process exit. 
current_trio_token = trio.lowlevel.current_trio_token() thread = threading.Thread( - target=worker_thread_fn, - args=(current_trio_token,), - name=name, - daemon=True + target=worker_thread_fn, args=(current_trio_token,), name=name, daemon=True, ) thread.start() except: @@ -338,9 +340,7 @@ def _run_fn_as_system_task(cb, fn, *args, trio_token=None): except RuntimeError: pass else: - raise RuntimeError( - "this is a blocking function; call it from a thread" - ) + raise RuntimeError("this is a blocking function; call it from a thread") q = stdlib_queue.Queue() trio_token.run_sync_soon(cb, q, fn, args) @@ -380,6 +380,7 @@ def from_thread_run(afn, *args, trio_token=None): "foreign" thread, spawned using some other framework, and still want to enter Trio. """ + def callback(q, afn, args): @disable_ki_protection async def unprotected_afn(): @@ -424,6 +425,7 @@ def from_thread_run_sync(fn, *args, trio_token=None): "foreign" thread, spawned using some other framework, and still want to enter Trio. """ + def callback(q, fn, args): @disable_ki_protection def unprotected_fn(): diff --git a/trio/_timeouts.py b/trio/_timeouts.py index 9bfefe4b03..517f344bf3 100644 --- a/trio/_timeouts.py +++ b/trio/_timeouts.py @@ -37,9 +37,7 @@ async def sleep_forever(): Equivalent to calling ``await sleep(math.inf)``. """ - await trio.lowlevel.wait_task_rescheduled( - lambda _: trio.lowlevel.Abort.SUCCEEDED - ) + await trio.lowlevel.wait_task_rescheduled(lambda _: trio.lowlevel.Abort.SUCCEEDED) async def sleep_until(deadline): diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 1340d049a2..33cc736ae1 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -13,7 +13,7 @@ from textwrap import indent -PREFIX = '_generated' +PREFIX = "_generated" HEADER = """# *********************************************************** # ******* WARNING: AUTOGENERATED! 
ALL EDITS WILL BE LOST ****** @@ -21,14 +21,17 @@ from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED - +# fmt: off +""" + +FOOTER = """# fmt: on """ TEMPLATE = """locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return {} GLOBAL_RUN_CONTEXT.{}.{} + return{}GLOBAL_RUN_CONTEXT.{}.{} except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") """ @@ -36,10 +39,7 @@ def is_function(node): """Check if the AST node is either a function or an async function """ - if ( - isinstance(node, ast.FunctionDef) - or isinstance(node, ast.AsyncFunctionDef) - ): + if isinstance(node, ast.FunctionDef) or isinstance(node, ast.AsyncFunctionDef): return True return False @@ -112,21 +112,22 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: del method.body[1:] # Create the function definition including the body - func = astor.to_source(method, indent_with=' ' * 4) + func = astor.to_source(method, indent_with=" " * 4) # Create export function body template = TEMPLATE.format( - 'await' if isinstance(method, ast.AsyncFunctionDef) else '', + " await " if isinstance(method, ast.AsyncFunctionDef) else " ", lookup_path, method.name + new_args, ) # Assemble function definition arguments and body - snippet = func + indent(template, ' ' * 4) + snippet = func + indent(template, " " * 4) # Append the snippet to the corresponding module generated.append(snippet) - return "\n".join(generated) + generated.append(FOOTER) + return "\n\n".join(generated) def matches_disk_files(new_files): @@ -165,20 +166,17 @@ def process(sources_and_lookups, *, do_test): # doesn't collect coverage. 
def main(): # pragma: no cover parser = argparse.ArgumentParser( - description='Generate python code for public api wrappers' + description="Generate python code for public api wrappers" ) parser.add_argument( - '--test', - '-t', - action='store_true', - help='test if code is still up to date' + "--test", "-t", action="store_true", help="test if code is still up to date", ) parsed_args = parser.parse_args() source_root = Path.cwd() # Double-check we found the right directory assert (source_root / "LICENSE").exists() - core = source_root / 'trio/_core' + core = source_root / "trio/_core" to_wrap = [ (core / "_run.py", "runner"), (core / "_io_windows.py", "runner.io_manager"), @@ -189,5 +187,5 @@ def main(): # pragma: no cover process(to_wrap, do_test=parsed_args.test) -if __name__ == '__main__': # pragma: no cover +if __name__ == "__main__": # pragma: no cover main() diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py index f2afe8d2b1..cb63009249 100644 --- a/trio/_unix_pipes.py +++ b/trio/_unix_pipes.py @@ -105,6 +105,7 @@ class FdStream(Stream, metaclass=SubclassingDeprecatedIn_v0_15_0): Returns: A new `FdStream` object. """ + def __init__(self, fd: int): self._fd_holder = _FdHolder(fd) self._send_conflict_detector = ConflictDetector( @@ -130,9 +131,7 @@ async def send_all(self, data: bytes): try: sent += os.write(self._fd_holder.fd, remaining) except BlockingIOError: - await trio.lowlevel.wait_writable( - self._fd_holder.fd - ) + await trio.lowlevel.wait_writable(self._fd_holder.fd) except OSError as e: if e.errno == errno.EBADF: raise trio.ClosedResourceError( diff --git a/trio/_util.py b/trio/_util.py index 58ebf800ed..03b79065e2 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -123,8 +123,9 @@ def _return_value_looks_like_wrong_library(value): "Instead, you want (notice the parentheses!):\n" "\n" " trio.run({async_fn.__name__}, ...) # correct!\n" - " nursery.start_soon({async_fn.__name__}, ...) # correct!" 
- .format(async_fn=async_fn) + " nursery.start_soon({async_fn.__name__}, ...) # correct!".format( + async_fn=async_fn + ) ) from None # Give good error for: nursery.start_soon(future) @@ -148,8 +149,7 @@ def _return_value_looks_like_wrong_library(value): raise TypeError( "Trio got unexpected {!r} – are you trying to use a " "library written for asyncio/twisted/tornado or similar? " - "That won't work without some sort of compatibility shim." - .format(coro) + "That won't work without some sort of compatibility shim.".format(coro) ) if isasyncgen(coro): @@ -180,6 +180,7 @@ class ConflictDetector: tasks don't call sendall simultaneously on the same stream. """ + def __init__(self, msg): self._msg = msg self._held = False @@ -198,9 +199,10 @@ def async_wraps(cls, wrapped_cls, attr_name): """Similar to wraps, but for async wrappers of non-async functions. """ + def decorator(func): func.__name__ = attr_name - func.__qualname__ = '.'.join((cls.__qualname__, attr_name)) + func.__qualname__ = ".".join((cls.__qualname__, attr_name)) func.__doc__ = """Like :meth:`~{}.{}.{}`, but async. @@ -257,6 +259,7 @@ def open_memory_channel(max_buffer_size: int) -> Tuple[ and currently won't type-check without a mypy plugin or clever stubs, but at least it becomes possible to write those. """ + def __init__(self, fn): update_wrapper(self, fn) self._fn = fn @@ -296,6 +299,7 @@ class SomeClass(metaclass=Final): ------ - TypeError if a sub class is created """ + def __new__(cls, name, bases, cls_namespace): for base in bases: if isinstance(base, Final): @@ -313,7 +317,7 @@ def __new__(cls, name, bases, cls_namespace): f"subclassing {base.__module__}.{base.__qualname__}", "0.15.0", issue=1044, - instead="composition or delegation" + instead="composition or delegation", ) break return super().__new__(cls, name, bases, cls_namespace) @@ -337,6 +341,7 @@ class SomeClass(metaclass=NoPublicConstructor): ------ - TypeError if a sub class or an instance is created. 
""" + def __call__(self, *args, **kwargs): raise TypeError( f"{self.__module__}.{self.__qualname__} has no public constructor" diff --git a/trio/_wait_for_object.py b/trio/_wait_for_object.py index 1b209ddb26..3231e4d551 100644 --- a/trio/_wait_for_object.py +++ b/trio/_wait_for_object.py @@ -1,7 +1,13 @@ import math from . import _timeouts import trio -from ._core._windows_cffi import ffi, kernel32, ErrorCodes, raise_winerror, _handle +from ._core._windows_cffi import ( + ffi, + kernel32, + ErrorCodes, + raise_winerror, + _handle, +) async def WaitForSingleObject(obj): @@ -52,9 +58,7 @@ def WaitForMultipleObjects_sync(*handles): handle_arr = ffi.new("HANDLE[{}]".format(n)) for i in range(n): handle_arr[i] = handles[i] - timeout = 0xffffffff # INFINITE - retcode = kernel32.WaitForMultipleObjects( - n, handle_arr, False, timeout - ) # blocking + timeout = 0xFFFFFFFF # INFINITE + retcode = kernel32.WaitForMultipleObjects(n, handle_arr, False, timeout) # blocking if retcode == ErrorCodes.WAIT_FAILED: raise_winerror() diff --git a/trio/_windows_pipes.py b/trio/_windows_pipes.py index 6af69f364a..81b9834b19 100644 --- a/trio/_windows_pipes.py +++ b/trio/_windows_pipes.py @@ -41,6 +41,7 @@ class PipeSendStream(SendStream, metaclass=Final): """Represents a send stream over a Windows named pipe that has been opened in OVERLAPPED mode. 
""" + def __init__(self, handle: int) -> None: self._handle_holder = _HandleHolder(handle) self._conflict_detector = ConflictDetector( @@ -57,9 +58,7 @@ async def send_all(self, data: bytes): return try: - written = await _core.write_overlapped( - self._handle_holder.handle, data - ) + written = await _core.write_overlapped(self._handle_holder.handle, data) except BrokenPipeError as ex: raise _core.BrokenResourceError from ex # By my reading of MSDN, this assert is guaranteed to pass so long @@ -81,6 +80,7 @@ async def aclose(self): class PipeReceiveStream(ReceiveStream, metaclass=Final): """Represents a receive stream over an os.pipe object.""" + def __init__(self, handle: int) -> None: self._handle_holder = _HandleHolder(handle) self._conflict_detector = ConflictDetector( diff --git a/trio/abc.py b/trio/abc.py index e3348360e4..ce0a1f6c00 100644 --- a/trio/abc.py +++ b/trio/abc.py @@ -5,7 +5,17 @@ # implementation in an underscored module, and then re-export the public parts # here. from ._abc import ( - Clock, Instrument, AsyncResource, SendStream, ReceiveStream, Stream, - HalfCloseableStream, SocketFactory, HostnameResolver, Listener, - SendChannel, ReceiveChannel, Channel + Clock, + Instrument, + AsyncResource, + SendStream, + ReceiveStream, + Stream, + HalfCloseableStream, + SocketFactory, + HostnameResolver, + Listener, + SendChannel, + ReceiveChannel, + Channel, ) diff --git a/trio/lowlevel.py b/trio/lowlevel.py index 5fe32c03d9..21ec0597df 100644 --- a/trio/lowlevel.py +++ b/trio/lowlevel.py @@ -13,14 +13,34 @@ # Generally available symbols from ._core import ( - cancel_shielded_checkpoint, Abort, wait_task_rescheduled, - enable_ki_protection, disable_ki_protection, currently_ki_protected, Task, - checkpoint, current_task, ParkingLot, UnboundedQueue, RunVar, TrioToken, - current_trio_token, temporarily_detach_coroutine_object, - permanently_detach_coroutine_object, reattach_detached_coroutine_object, - current_statistics, reschedule, remove_instrument, 
add_instrument, - current_clock, current_root_task, checkpoint_if_cancelled, - spawn_system_task, wait_readable, wait_writable, notify_closing + cancel_shielded_checkpoint, + Abort, + wait_task_rescheduled, + enable_ki_protection, + disable_ki_protection, + currently_ki_protected, + Task, + checkpoint, + current_task, + ParkingLot, + UnboundedQueue, + RunVar, + TrioToken, + current_trio_token, + temporarily_detach_coroutine_object, + permanently_detach_coroutine_object, + reattach_detached_coroutine_object, + current_statistics, + reschedule, + remove_instrument, + add_instrument, + current_clock, + current_root_task, + checkpoint_if_cancelled, + spawn_system_task, + wait_readable, + wait_writable, + notify_closing, ) # Unix-specific symbols diff --git a/trio/socket.py b/trio/socket.py index 5951b5b099..ebbccd50ea 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -20,6 +20,7 @@ # going on. There's a test in test_exports.py to make sure that the list is # kept up to date. try: + # fmt: off from socket import ( CMSG_LEN, CMSG_SPACE, CAPI, AF_UNSPEC, AF_INET, AF_UNIX, AF_IPX, AF_APPLETALK, AF_INET6, AF_ROUTE, AF_LINK, AF_SNA, PF_SYSTEM, @@ -117,6 +118,7 @@ SCM_J1939_PRIO, SO_J1939_ERRQUEUE, SO_J1939_FILTER, SO_J1939_PROMISC, SO_J1939_SEND_PRIO, UDPLITE_RECV_CSCOV, UDPLITE_SEND_CSCOV ) + # fmt: on except ImportError: pass @@ -125,7 +127,7 @@ import socket as _stdlib_socket _bad_symbols = set() -if _sys.platform == 'win32': +if _sys.platform == "win32": # See https://github.com/python-trio/trio/issues/39 # Do not import for windows platform # (you can still get it from stdlib socket, of course, if you want it) @@ -141,9 +143,16 @@ # import the overwrites from ._socket import ( - fromfd, from_stdlib_socket, getprotobyname, socketpair, getnameinfo, - socket, getaddrinfo, set_custom_hostname_resolver, - set_custom_socket_factory, SocketType + fromfd, + from_stdlib_socket, + getprotobyname, + socketpair, + getnameinfo, + socket, + getaddrinfo, + 
set_custom_hostname_resolver, + set_custom_socket_factory, + SocketType, ) # not always available so expose only if @@ -168,9 +177,7 @@ # not always available so expose only if try: - from socket import ( - sethostname, if_nameindex, if_nametoindex, if_indextoname - ) + from socket import sethostname, if_nameindex, if_nametoindex, if_indextoname except ImportError: pass diff --git a/trio/testing/__init__.py b/trio/testing/__init__.py index df150ec62b..8c730ffeb5 100644 --- a/trio/testing/__init__.py +++ b/trio/testing/__init__.py @@ -9,13 +9,19 @@ from ._sequencer import Sequencer from ._check_streams import ( - check_one_way_stream, check_two_way_stream, check_half_closeable_stream + check_one_way_stream, + check_two_way_stream, + check_half_closeable_stream, ) from ._memory_streams import ( - MemorySendStream, MemoryReceiveStream, memory_stream_pump, - memory_stream_one_way_pair, memory_stream_pair, - lockstep_stream_one_way_pair, lockstep_stream_pair + MemorySendStream, + MemoryReceiveStream, + memory_stream_pump, + memory_stream_one_way_pair, + memory_stream_pair, + lockstep_stream_one_way_pair, + lockstep_stream_pair, ) from ._network import open_stream_to_socket_listener @@ -23,5 +29,6 @@ ################################################################ from .._util import fixup_module_metadata + fixup_module_metadata(__name__, globals()) del fixup_module_metadata diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 48880c13c4..7a9006ff43 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -130,8 +130,7 @@ async def simple_check_wait_send_all_might_not_block(scope): async with _core.open_nursery() as nursery: nursery.start_soon( - simple_check_wait_send_all_might_not_block, - nursery.cancel_scope + simple_check_wait_send_all_might_not_block, nursery.cancel_scope ) nursery.start_soon(do_receive_some, 1) @@ -398,19 +397,18 @@ async def flipped_stream_maker(): async def flipped_clogged_stream_maker(): 
return reversed(await clogged_stream_maker()) + else: flipped_clogged_stream_maker = None - await check_one_way_stream( - flipped_stream_maker, flipped_clogged_stream_maker - ) + await check_one_way_stream(flipped_stream_maker, flipped_clogged_stream_maker) async with _ForceCloseBoth(await stream_maker()) as (s1, s2): assert isinstance(s1, Stream) assert isinstance(s2, Stream) # Duplex can be a bit tricky, might as well check it as well - DUPLEX_TEST_SIZE = 2**20 - CHUNK_SIZE_MAX = 2**14 + DUPLEX_TEST_SIZE = 2 ** 20 + CHUNK_SIZE_MAX = 2 ** 14 r = random.Random(0) i = r.getrandbits(8 * DUPLEX_TEST_SIZE) diff --git a/trio/testing/_checkpoints.py b/trio/testing/_checkpoints.py index 27fd4d187d..5804295300 100644 --- a/trio/testing/_checkpoints.py +++ b/trio/testing/_checkpoints.py @@ -11,19 +11,13 @@ def _assert_yields_or_not(expected): orig_schedule = task._schedule_points try: yield - if ( - expected and ( - task._cancel_points == orig_cancel - or task._schedule_points == orig_schedule - ) + if expected and ( + task._cancel_points == orig_cancel or task._schedule_points == orig_schedule ): raise AssertionError("assert_checkpoints block did not yield!") finally: - if ( - not expected and ( - task._cancel_points != orig_cancel - or task._schedule_points != orig_schedule - ) + if not expected and ( + task._cancel_points != orig_cancel or task._schedule_points != orig_schedule ): raise AssertionError("assert_no_checkpoints block yielded!") diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index 62d63b73d9..66f7f25d97 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -73,9 +73,7 @@ async def get(self, max_bytes=None): return self._get_impl(max_bytes) -class MemorySendStream( - SendStream, metaclass=_util.SubclassingDeprecatedIn_v0_15_0 -): +class MemorySendStream(SendStream, metaclass=_util.SubclassingDeprecatedIn_v0_15_0): """An in-memory :class:`~trio.abc.SendStream`. 
Args: @@ -95,11 +93,12 @@ class MemorySendStream( you can change them at any time. """ + def __init__( self, send_all_hook=None, wait_send_all_might_not_block_hook=None, - close_hook=None + close_hook=None, ): self._conflict_detector = _util.ConflictDetector( "another task is using this stream" @@ -208,6 +207,7 @@ class MemoryReceiveStream( change them at any time. """ + def __init__(self, receive_some_hook=None, close_hook=None): self._conflict_detector = _util.ConflictDetector( "another task is using this stream" @@ -270,9 +270,7 @@ def put_eof(self): self._incoming.close() -def memory_stream_pump( - memory_send_stream, memory_receive_stream, *, max_bytes=None -): +def memory_stream_pump(memory_send_stream, memory_receive_stream, *, max_bytes=None): """Take data out of the given :class:`MemorySendStream`'s internal buffer, and put it into the given :class:`MemoryReceiveStream`'s internal buffer. diff --git a/trio/testing/_mock_clock.py b/trio/testing/_mock_clock.py index 997e701f39..843f51197f 100644 --- a/trio/testing/_mock_clock.py +++ b/trio/testing/_mock_clock.py @@ -79,6 +79,7 @@ class MockClock(Clock, metaclass=SubclassingDeprecatedIn_v0_15_0): :func:`wait_all_tasks_blocked`. """ + def __init__(self, rate=0.0, autojump_threshold=inf): # when the real clock said 'real_base', the virtual time was # 'virtual_base', and since then it's advanced at 'rate' virtual @@ -97,10 +98,8 @@ def __init__(self, rate=0.0, autojump_threshold=inf): self.autojump_threshold = autojump_threshold def __repr__(self): - return ( - "".format( - self.current_time(), self._rate, id(self) - ) + return "".format( + self.current_time(), self._rate, id(self) ) @property @@ -141,9 +140,7 @@ async def _autojumper(self): # to raise Cancelled, which is absorbed by the cancel # scope above, and effectively just causes us to skip back # to the start the loop, like a 'continue'. 
- await _core.wait_all_tasks_blocked( - self._autojump_threshold, inf - ) + await _core.wait_all_tasks_blocked(self._autojump_threshold, inf) statistics = _core.current_statistics() jump = statistics.seconds_to_next_deadline if 0 < jump < inf: diff --git a/trio/testing/_sequencer.py b/trio/testing/_sequencer.py index 21fc492dff..abecf396ce 100644 --- a/trio/testing/_sequencer.py +++ b/trio/testing/_sequencer.py @@ -61,9 +61,7 @@ async def main(): @asynccontextmanager async def __call__(self, position: int): if position in self._claimed: - raise RuntimeError( - "Attempted to re-use sequence point {}".format(position) - ) + raise RuntimeError("Attempted to re-use sequence point {}".format(position)) if self._broken: raise RuntimeError("sequence broken!") self._claimed.add(position) @@ -74,9 +72,7 @@ async def __call__(self, position: int): self._broken = True for event in self._sequence_points.values(): event.set() - raise RuntimeError( - "Sequencer wait cancelled -- sequence broken" - ) + raise RuntimeError("Sequencer wait cancelled -- sequence broken") else: if self._broken: raise RuntimeError("sequence broken!") diff --git a/trio/testing/_trio_test.py b/trio/testing/_trio_test.py index 87caa5881a..4fcaeae372 100644 --- a/trio/testing/_trio_test.py +++ b/trio/testing/_trio_test.py @@ -24,8 +24,6 @@ def wrapper(**kwargs): else: raise ValueError("too many clocks spoil the broth!") instruments = [i for i in kwargs.values() if isinstance(i, Instrument)] - return _core.run( - partial(fn, **kwargs), clock=clock, instruments=instruments - ) + return _core.run(partial(fn, **kwargs), clock=clock, instruments=instruments) return wrapper diff --git a/trio/tests/module_with_deprecations.py b/trio/tests/module_with_deprecations.py index d194b6a5bd..b0f83b1540 100644 --- a/trio/tests/module_with_deprecations.py +++ b/trio/tests/module_with_deprecations.py @@ -8,22 +8,14 @@ # attributes in between calling enable_attribute_deprecations and defining # __deprecated_attributes__: 
import sys + this_mod = sys.modules[__name__] assert this_mod.regular == "hi" assert not hasattr(this_mod, "dep1") __deprecated_attributes__ = { - "dep1": - _deprecate.DeprecatedAttribute( - "value1", - "1.1", - issue=1, - ), - "dep2": - _deprecate.DeprecatedAttribute( - "value2", - "1.2", - issue=1, - instead="instead-string", - ), + "dep1": _deprecate.DeprecatedAttribute("value1", "1.1", issue=1,), + "dep2": _deprecate.DeprecatedAttribute( + "value2", "1.2", issue=1, instead="instead-string", + ), } diff --git a/trio/tests/test_deprecate.py b/trio/tests/test_deprecate.py index 6ecd00003e..a11e9b8d1f 100644 --- a/trio/tests/test_deprecate.py +++ b/trio/tests/test_deprecate.py @@ -4,7 +4,10 @@ import warnings from .._deprecate import ( - TrioDeprecationWarning, warn_deprecated, deprecated, deprecated_alias + TrioDeprecationWarning, + warn_deprecated, + deprecated, + deprecated_alias, ) from . import module_with_deprecations @@ -25,8 +28,8 @@ def test_warn_deprecated(recwarn_always): def deprecated_thing(): warn_deprecated("ice", "1.2", issue=1, instead="water") - filename, lineno = _here() # https://github.com/google/yapf/issues/447 deprecated_thing() + filename, lineno = _here() assert len(recwarn_always) == 1 got = recwarn_always.pop(TrioDeprecationWarning) assert "ice is deprecated" in got.message.args[0] @@ -34,7 +37,7 @@ def deprecated_thing(): assert "water instead" in got.message.args[0] assert "/issues/1" in got.message.args[0] assert got.filename == filename - assert got.lineno == lineno + 1 + assert got.lineno == lineno - 1 def test_warn_deprecated_no_instead_or_issue(recwarn_always): @@ -54,7 +57,7 @@ def nested1(): def nested2(): warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3) - filename, lineno = _here() # https://github.com/google/yapf/issues/447 + filename, lineno = _here() nested1() got = recwarn_always.pop(TrioDeprecationWarning) assert got.filename == filename @@ -181,40 +184,52 @@ def docstring_test4(): # pragma: no cover def 
test_deprecated_docstring_munging(): - assert docstring_test1.__doc__ == """Hello! + assert ( + docstring_test1.__doc__ + == """Hello! .. deprecated:: 2.1 Use hi instead. For details, see `issue #1 `__. """ + ) - assert docstring_test2.__doc__ == """Hello! + assert ( + docstring_test2.__doc__ + == """Hello! .. deprecated:: 2.1 Use hi instead. """ + ) - assert docstring_test3.__doc__ == """Hello! + assert ( + docstring_test3.__doc__ + == """Hello! .. deprecated:: 2.1 For details, see `issue #1 `__. """ + ) - assert docstring_test4.__doc__ == """Hello! + assert ( + docstring_test4.__doc__ + == """Hello! .. deprecated:: 2.1 """ + ) def test_module_with_deprecations(recwarn_always): assert module_with_deprecations.regular == "hi" assert len(recwarn_always) == 0 - filename, lineno = _here() # https://github.com/google/yapf/issues/447 + filename, lineno = _here() assert module_with_deprecations.dep1 == "value1" got = recwarn_always.pop(TrioDeprecationWarning) assert got.filename == filename diff --git a/trio/tests/test_exports.py b/trio/tests/test_exports.py index af34070267..1d49f4765d 100644 --- a/trio/tests/test_exports.py +++ b/trio/tests/test_exports.py @@ -18,13 +18,12 @@ def test_core_is_properly_reexported(): # three modules: sources = [trio, trio.lowlevel, trio.testing] for symbol in dir(_core): - if symbol.startswith('_') or symbol == 'tests': + if symbol.startswith("_") or symbol == "tests": continue found = 0 for source in sources: - if ( - symbol in dir(source) - and getattr(source, symbol) is getattr(_core, symbol) + if symbol in dir(source) and getattr(source, symbol) is getattr( + _core, symbol ): found += 1 print(symbol, found) @@ -85,11 +84,13 @@ def no_underscores(symbols): if tool == "pylint": from pylint.lint import PyLinter + linter = PyLinter() ast = linter.get_ast(module.__file__, modname) static_names = no_underscores(ast) elif tool == "jedi": import jedi + # Simulate typing "import trio; trio." 
script = jedi.Script("import {}; {}.".format(modname, modname)) completions = script.complete() diff --git a/trio/tests/test_file_io.py b/trio/tests/test_file_io.py index fd4fa648b4..b40f7518a9 100644 --- a/trio/tests/test_file_io.py +++ b/trio/tests/test_file_io.py @@ -12,7 +12,7 @@ @pytest.fixture def path(tmpdir): - return os.fspath(tmpdir.join('test')) + return os.fspath(tmpdir.join("test")) @pytest.fixture @@ -58,9 +58,7 @@ def test_dir_matches_wrapped(async_file, wrapped): attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS) # all supported attrs in wrapped should be available in async_file - assert all( - attr in dir(async_file) for attr in attrs if attr in dir(wrapped) - ) + assert all(attr in dir(async_file) for attr in attrs if attr in dir(wrapped)) # all supported attrs not in wrapped should not be available in async_file assert not any( attr in dir(async_file) for attr in attrs if attr not in dir(wrapped) @@ -74,10 +72,10 @@ def unsupported_attr(self): # pragma: no cover async_file = trio.wrap_file(FakeFile()) - assert hasattr(async_file.wrapped, 'unsupported_attr') + assert hasattr(async_file.wrapped, "unsupported_attr") with pytest.raises(AttributeError): - getattr(async_file, 'unsupported_attr') + getattr(async_file, "unsupported_attr") def test_sync_attrs_forwarded(async_file, wrapped): @@ -110,10 +108,10 @@ def test_async_methods_generated_once(async_file): def test_async_methods_signature(async_file): # use read as a representative of all async methods - assert async_file.read.__name__ == 'read' - assert async_file.read.__qualname__ == 'AsyncIOWrapper.read' + assert async_file.read.__name__ == "read" + assert async_file.read.__qualname__ == "AsyncIOWrapper.read" - assert 'io.StringIO.read' in async_file.read.__doc__ + assert "io.StringIO.read" in async_file.read.__doc__ async def test_async_methods_wrap(async_file, wrapped): @@ -147,7 +145,7 @@ async def test_async_methods_match_wrapper(async_file, wrapped): async def test_open(path): - f = await 
trio.open_file(path, 'w') + f = await trio.open_file(path, "w") assert isinstance(f, AsyncIOWrapper) @@ -155,7 +153,7 @@ async def test_open(path): async def test_open_context_manager(path): - async with await trio.open_file(path, 'w') as f: + async with await trio.open_file(path, "w") as f: assert isinstance(f, AsyncIOWrapper) assert not f.closed @@ -163,7 +161,7 @@ async def test_open_context_manager(path): async def test_async_iter(): - async_file = trio.wrap_file(io.StringIO('test\nfoo\nbar')) + async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar")) expected = list(async_file.wrapped) result = [] async_file.wrapped.seek(0) @@ -176,11 +174,11 @@ async def test_async_iter(): async def test_aclose_cancelled(path): with _core.CancelScope() as cscope: - f = await trio.open_file(path, 'w') + f = await trio.open_file(path, "w") cscope.cancel() with pytest.raises(_core.Cancelled): - await f.write('a') + await f.write("a") with pytest.raises(_core.Cancelled): await f.aclose() diff --git a/trio/tests/test_highlevel_open_tcp_listeners.py b/trio/tests/test_highlevel_open_tcp_listeners.py index d11c1c6e6d..ffd708807a 100644 --- a/trio/tests/test_highlevel_open_tcp_listeners.py +++ b/trio/tests/test_highlevel_open_tcp_listeners.py @@ -6,9 +6,7 @@ import attr import trio -from trio import ( - open_tcp_listeners, serve_tcp, SocketListener, open_tcp_stream -) +from trio import open_tcp_listeners, serve_tcp, SocketListener, open_tcp_stream from trio.testing import open_stream_to_socket_listener from .. import socket as tsocket from .._core.tests.tutil import slow, creates_ipv6, binds_ipv6 @@ -79,9 +77,7 @@ async def measure_backlog(listener, limit): # has been observed to sometimes raise ConnectionResetError. 
with trio.move_on_after(0.5) as cancel_scope: try: - client_stream = await open_stream_to_socket_listener( - listener - ) + client_stream = await open_stream_to_socket_listener(listener) except ConnectionResetError: # pragma: no cover break client_streams.append(client_stream) @@ -272,28 +268,18 @@ async def handler(stream): @pytest.mark.parametrize( - "try_families", [ - {tsocket.AF_INET}, - {tsocket.AF_INET6}, - {tsocket.AF_INET, tsocket.AF_INET6}, - ] + "try_families", + [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6},], ) @pytest.mark.parametrize( - "fail_families", [ - {tsocket.AF_INET}, - {tsocket.AF_INET6}, - {tsocket.AF_INET, tsocket.AF_INET6}, - ] + "fail_families", + [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6},], ) async def test_open_tcp_listeners_some_address_families_unavailable( try_families, fail_families ): fsf = FakeSocketFactory( - 10, - raise_on_family={ - family: errno.EAFNOSUPPORT - for family in fail_families - } + 10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families}, ) tsocket.set_custom_socket_factory(fsf) tsocket.set_custom_hostname_resolver( @@ -326,13 +312,11 @@ async def test_open_tcp_listeners_socket_fails_not_afnosupport(): raise_on_family={ tsocket.AF_INET: errno.EAFNOSUPPORT, tsocket.AF_INET6: errno.EINVAL, - } + }, ) tsocket.set_custom_socket_factory(fsf) tsocket.set_custom_hostname_resolver( - FakeHostnameResolver( - [(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")] - ) + FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")]) ) with pytest.raises(OSError) as exc_info: diff --git a/trio/tests/test_highlevel_open_tcp_stream.py b/trio/tests/test_highlevel_open_tcp_stream.py index 0c28450e5e..9fd0f3992a 100644 --- a/trio/tests/test_highlevel_open_tcp_stream.py +++ b/trio/tests/test_highlevel_open_tcp_stream.py @@ -46,7 +46,11 @@ def close(self): def test_reorder_for_rfc_6555_section_5_4(): def fake4(i): return ( - AF_INET, 
SOCK_STREAM, IPPROTO_TCP, "", ("10.0.0.{}".format(i), 80) + AF_INET, + SOCK_STREAM, + IPPROTO_TCP, + "", + ("10.0.0.{}".format(i), 80), ) def fake6(i): @@ -225,7 +229,7 @@ async def run_scenario( # If this is True, we require there to be an exception, and return # (exception, scenario object) expect_error=(), - **kwargs + **kwargs, ): supported_families = set() if ipv4_supported: @@ -278,8 +282,7 @@ async def test_one_host_slow_fail(autojump_clock): async def test_one_host_failed_after_connect(autojump_clock): exc, scenario = await run_scenario( - 83, [("1.2.3.4", 1, "postconnect_fail")], - expect_error=KeyboardInterrupt + 83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt ) assert isinstance(exc, KeyboardInterrupt) diff --git a/trio/tests/test_highlevel_open_unix_stream.py b/trio/tests/test_highlevel_open_unix_stream.py index 872a43dd6d..211aff3e70 100644 --- a/trio/tests/test_highlevel_open_unix_stream.py +++ b/trio/tests/test_highlevel_open_unix_stream.py @@ -5,9 +5,7 @@ import pytest from trio import open_unix_socket, Path -from trio._highlevel_open_unix_stream import ( - close_on_error, -) +from trio._highlevel_open_unix_stream import close_on_error if not hasattr(socket, "AF_UNIX"): pytestmark = pytest.mark.skip("Needs unix socket support") @@ -30,7 +28,7 @@ def close(self): assert c.closed -@pytest.mark.parametrize('filename', [4, 4.5]) +@pytest.mark.parametrize("filename", [4, 4.5]) async def test_open_with_bad_filename_type(filename): with pytest.raises(TypeError): await open_unix_socket(filename) diff --git a/trio/tests/test_highlevel_serve_listeners.py b/trio/tests/test_highlevel_serve_listeners.py index e26a65605f..b028092eb9 100644 --- a/trio/tests/test_highlevel_serve_listeners.py +++ b/trio/tests/test_highlevel_serve_listeners.py @@ -136,8 +136,9 @@ async def connection_watcher(*, task_status=trio.TASK_STATUS_IGNORED): await nursery.start( partial( trio.serve_listeners, - handler, [listener], - handler_nursery=handler_nursery + 
handler, + [listener], + handler_nursery=handler_nursery, ) ) for _ in range(10): diff --git a/trio/tests/test_highlevel_socket.py b/trio/tests/test_highlevel_socket.py index dc19219e3e..f3570f743e 100644 --- a/trio/tests/test_highlevel_socket.py +++ b/trio/tests/test_highlevel_socket.py @@ -6,7 +6,9 @@ from .. import _core from ..testing import ( - check_half_closeable_stream, wait_all_tasks_blocked, assert_checkpoints + check_half_closeable_stream, + wait_all_tasks_blocked, + assert_checkpoints, ) from .._highlevel_socket import * from .. import socket as tsocket diff --git a/trio/tests/test_highlevel_ssl_helpers.py b/trio/tests/test_highlevel_ssl_helpers.py index 1583f4cd54..99ee46c9d0 100644 --- a/trio/tests/test_highlevel_ssl_helpers.py +++ b/trio/tests/test_highlevel_ssl_helpers.py @@ -10,7 +10,9 @@ from .test_ssl import client_ctx, SERVER_CTX from .._highlevel_ssl_helpers import ( - open_ssl_over_tcp_stream, open_ssl_over_tcp_listeners, serve_ssl_over_tcp + open_ssl_over_tcp_stream, + open_ssl_over_tcp_listeners, + serve_ssl_over_tcp, ) @@ -41,18 +43,10 @@ async def getnameinfo(self, *args): # pragma: no cover # This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners... # noqa is needed because flake8 doesn't understand how pytest fixtures work. 
-async def test_open_ssl_over_tcp_stream_and_everything_else( - client_ctx, # noqa: F811 -): +async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx,): # noqa: F811 async with trio.open_nursery() as nursery: (listener,) = await nursery.start( - partial( - serve_ssl_over_tcp, - echo_handler, - 0, - SERVER_CTX, - host="127.0.0.1" - ) + partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1",) ) sockaddr = listener.transport_listener.socket.getsockname() hostname_resolver = FakeHostnameResolver(sockaddr) @@ -67,18 +61,14 @@ async def test_open_ssl_over_tcp_stream_and_everything_else( # We have the trust but not the hostname # (checks custom ssl_context + hostname checking) stream = await open_ssl_over_tcp_stream( - "xyzzy.example.org", - 80, - ssl_context=client_ctx, + "xyzzy.example.org", 80, ssl_context=client_ctx, ) with pytest.raises(trio.BrokenResourceError): await stream.do_handshake() # This one should work! stream = await open_ssl_over_tcp_stream( - "trio-test-1.example.org", - 80, - ssl_context=client_ctx, + "trio-test-1.example.org", 80, ssl_context=client_ctx, ) assert isinstance(stream, trio.SSLStream) assert stream.server_hostname == "trio-test-1.example.org" @@ -103,9 +93,7 @@ async def test_open_ssl_over_tcp_stream_and_everything_else( async def test_open_ssl_over_tcp_listeners(): - (listener,) = await open_ssl_over_tcp_listeners( - 0, SERVER_CTX, host="127.0.0.1" - ) # yapf: disable + (listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1") async with listener: assert isinstance(listener, trio.SSLListener) tl = listener.transport_listener diff --git a/trio/tests/test_path.py b/trio/tests/test_path.py index 9bbbfd4df2..284bcf82dd 100644 --- a/trio/tests/test_path.py +++ b/trio/tests/test_path.py @@ -10,7 +10,7 @@ @pytest.fixture def path(tmpdir): - p = str(tmpdir.join('test')) + p = str(tmpdir.join("test")) return trio.Path(p) @@ -21,32 +21,33 @@ def method_pair(path, method_name): async def 
test_open_is_async_context_manager(path): - async with await path.open('w') as f: + async with await path.open("w") as f: assert isinstance(f, AsyncIOWrapper) assert f.closed async def test_magic(): - path = trio.Path('test') + path = trio.Path("test") - assert str(path) == 'test' - assert bytes(path) == b'test' + assert str(path) == "test" + assert bytes(path) == b"test" cls_pairs = [ - (trio.Path, pathlib.Path), (pathlib.Path, trio.Path), - (trio.Path, trio.Path) + (trio.Path, pathlib.Path), + (pathlib.Path, trio.Path), + (trio.Path, trio.Path), ] -@pytest.mark.parametrize('cls_a,cls_b', cls_pairs) +@pytest.mark.parametrize("cls_a,cls_b", cls_pairs) async def test_cmp_magic(cls_a, cls_b): - a, b = cls_a(''), cls_b('') + a, b = cls_a(""), cls_b("") assert a == b assert not a != b - a, b = cls_a('a'), cls_b('b') + a, b = cls_a("a"), cls_b("b") assert a < b assert b > a @@ -60,24 +61,26 @@ async def test_cmp_magic(cls_a, cls_b): # __*div__ does not properly raise NotImplementedError like the other comparison # magic, so trio.Path's implementation does not get dispatched cls_pairs = [ - (trio.Path, pathlib.Path), (trio.Path, trio.Path), (trio.Path, str), - (str, trio.Path) + (trio.Path, pathlib.Path), + (trio.Path, trio.Path), + (trio.Path, str), + (str, trio.Path), ] -@pytest.mark.parametrize('cls_a,cls_b', cls_pairs) +@pytest.mark.parametrize("cls_a,cls_b", cls_pairs) async def test_div_magic(cls_a, cls_b): - a, b = cls_a('a'), cls_b('b') + a, b = cls_a("a"), cls_b("b") result = a / b assert isinstance(result, trio.Path) - assert str(result) == os.path.join('a', 'b') + assert str(result) == os.path.join("a", "b") @pytest.mark.parametrize( - 'cls_a,cls_b', [(trio.Path, pathlib.Path), (trio.Path, trio.Path)] + "cls_a,cls_b", [(trio.Path, pathlib.Path), (trio.Path, trio.Path)] ) -@pytest.mark.parametrize('path', ["foo", "foo/bar/baz", "./foo"]) +@pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"]) async def test_hash_magic(cls_a, cls_b, path): a, b = 
cls_a(path), cls_b(path) assert hash(a) == hash(b) @@ -86,23 +89,23 @@ async def test_hash_magic(cls_a, cls_b, path): async def test_forwarded_properties(path): # use `name` as a representative of forwarded properties - assert 'name' in dir(path) - assert path.name == 'test' + assert "name" in dir(path) + assert path.name == "test" async def test_async_method_signature(path): # use `resolve` as a representative of wrapped methods - assert path.resolve.__name__ == 'resolve' - assert path.resolve.__qualname__ == 'Path.resolve' + assert path.resolve.__name__ == "resolve" + assert path.resolve.__qualname__ == "Path.resolve" - assert 'pathlib.Path.resolve' in path.resolve.__doc__ + assert "pathlib.Path.resolve" in path.resolve.__doc__ -@pytest.mark.parametrize('method_name', ['is_dir', 'is_file']) +@pytest.mark.parametrize("method_name", ["is_dir", "is_file"]) async def test_compare_async_stat_methods(method_name): - method, async_method = method_pair('.', method_name) + method, async_method = method_pair(".", method_name) result = method() async_result = await async_method() @@ -112,13 +115,13 @@ async def test_compare_async_stat_methods(method_name): async def test_invalid_name_not_wrapped(path): with pytest.raises(AttributeError): - getattr(path, 'invalid_fake_attr') + getattr(path, "invalid_fake_attr") -@pytest.mark.parametrize('method_name', ['absolute', 'resolve']) +@pytest.mark.parametrize("method_name", ["absolute", "resolve"]) async def test_async_methods_rewrap(method_name): - method, async_method = method_pair('.', method_name) + method, async_method = method_pair(".", method_name) result = method() async_result = await async_method() @@ -128,13 +131,13 @@ async def test_async_methods_rewrap(method_name): async def test_forward_methods_rewrap(path, tmpdir): - with_name = path.with_name('foo') - with_suffix = path.with_suffix('.py') + with_name = path.with_name("foo") + with_suffix = path.with_suffix(".py") assert isinstance(with_name, trio.Path) - assert 
with_name == tmpdir.join('foo') + assert with_name == tmpdir.join("foo") assert isinstance(with_suffix, trio.Path) - assert with_suffix == tmpdir.join('test.py') + assert with_suffix == tmpdir.join("test.py") async def test_forward_properties_rewrap(path): @@ -144,18 +147,18 @@ async def test_forward_properties_rewrap(path): async def test_forward_methods_without_rewrap(path, tmpdir): path = await path.parent.resolve() - assert path.as_uri().startswith('file:///') + assert path.as_uri().startswith("file:///") async def test_repr(): - path = trio.Path('.') + path = trio.Path(".") assert repr(path) == "trio.Path('.')" class MockWrapped: - unsupported = 'unsupported' - _private = 'private' + unsupported = "unsupported" + _private = "private" class MockWrapper: @@ -174,18 +177,18 @@ async def test_type_wraps_unsupported(): async def test_type_forwards_private(): - Type.generate_forwards(MockWrapper, {'unsupported': None}) + Type.generate_forwards(MockWrapper, {"unsupported": None}) - assert not hasattr(MockWrapper, '_private') + assert not hasattr(MockWrapper, "_private") async def test_type_wraps_private(): - Type.generate_wraps(MockWrapper, {'unsupported': None}) + Type.generate_wraps(MockWrapper, {"unsupported": None}) - assert not hasattr(MockWrapper, '_private') + assert not hasattr(MockWrapper, "_private") -@pytest.mark.parametrize('meth', [trio.Path.__init__, trio.Path.joinpath]) +@pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath]) async def test_path_wraps_path(path, meth): wrapped = await path.absolute() result = meth(path, wrapped) @@ -201,22 +204,22 @@ async def test_path_nonpath(): async def test_open_file_can_open_path(path): - async with await trio.open_file(path, 'w') as f: + async with await trio.open_file(path, "w") as f: assert f.name == os.fspath(path) async def test_globmethods(path): # Populate a directory tree await path.mkdir() - await (path / 'foo').mkdir() - await (path / 'foo' / '_bar.txt').write_bytes(b'') - await (path 
/ 'bar.txt').write_bytes(b'') - await (path / 'bar.dat').write_bytes(b'') + await (path / "foo").mkdir() + await (path / "foo" / "_bar.txt").write_bytes(b"") + await (path / "bar.txt").write_bytes(b"") + await (path / "bar.dat").write_bytes(b"") # Path.glob for _pattern, _results in { - '*.txt': {'bar.txt'}, - '**/*.txt': {'_bar.txt', 'bar.txt'}, + "*.txt": {"bar.txt"}, + "**/*.txt": {"_bar.txt", "bar.txt"}, }.items(): entries = set() for entry in await path.glob(_pattern): @@ -227,32 +230,32 @@ async def test_globmethods(path): # Path.rglob entries = set() - for entry in await path.rglob('*.txt'): + for entry in await path.rglob("*.txt"): assert isinstance(entry, trio.Path) entries.add(entry.name) - assert entries == {'_bar.txt', 'bar.txt'} + assert entries == {"_bar.txt", "bar.txt"} async def test_iterdir(path): # Populate a directory await path.mkdir() - await (path / 'foo').mkdir() - await (path / 'bar.txt').write_bytes(b'') + await (path / "foo").mkdir() + await (path / "bar.txt").write_bytes(b"") entries = set() for entry in await path.iterdir(): assert isinstance(entry, trio.Path) entries.add(entry.name) - assert entries == {'bar.txt', 'foo'} + assert entries == {"bar.txt", "foo"} async def test_classmethods(): assert isinstance(await trio.Path.home(), trio.Path) # pathlib.Path has only two classmethods - assert str(await trio.Path.home()) == os.path.expanduser('~') + assert str(await trio.Path.home()) == os.path.expanduser("~") assert str(await trio.Path.cwd()) == os.getcwd() # Wrapped method has docstring diff --git a/trio/tests/test_scheduler_determinism.py b/trio/tests/test_scheduler_determinism.py index ba5f469396..67b2447f0a 100644 --- a/trio/tests/test_scheduler_determinism.py +++ b/trio/tests/test_scheduler_determinism.py @@ -26,9 +26,7 @@ def test_the_trio_scheduler_is_not_deterministic(): def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch): - monkeypatch.setattr( - trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True - ) + 
monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True) traces = [] for _ in range(10): state = trio._core._run._r.getstate() diff --git a/trio/tests/test_signals.py b/trio/tests/test_signals.py index 20821b40f2..235772f900 100644 --- a/trio/tests/test_signals.py +++ b/trio/tests/test_signals.py @@ -108,6 +108,7 @@ async def test_open_signal_receiver_no_starvation(): # open_signal_receiver block might cause the signal to be # redelivered and give us a core dump instead of a traceback... import traceback + traceback.print_exc() @@ -164,9 +165,7 @@ def raise_handler(signum, _): with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler): with pytest.raises(RuntimeError) as excinfo: - with open_signal_receiver( - signal.SIGILL, signal.SIGFPE - ) as receiver: + with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: signal_raise(signal.SIGILL) signal_raise(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index 4e76711d34..2a217abdf6 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -44,9 +44,7 @@ def getaddrinfo(self, *args, **kwargs): elif bound[-1] & stdlib_socket.AI_NUMERICHOST: return self._orig_getaddrinfo(*args, **kwargs) else: - raise RuntimeError( - "gai called with unexpected arguments {}".format(bound) - ) + raise RuntimeError("gai called with unexpected arguments {}".format(bound)) @pytest.fixture @@ -117,29 +115,35 @@ def filtered(gai_list): # Simple non-blocking non-error cases, ipv4 and ipv6: with assert_checkpoints(): - res = await tsocket.getaddrinfo( - "127.0.0.1", "12345", type=tsocket.SOCK_STREAM - ) - - check(res, [ - (tsocket.AF_INET, # 127.0.0.1 is ipv4 - tsocket.SOCK_STREAM, - tsocket.IPPROTO_TCP, - "", - ("127.0.0.1", 12345)), - ]) # yapf: disable + res = await tsocket.getaddrinfo("127.0.0.1", "12345", type=tsocket.SOCK_STREAM) + + check( + res, + [ + ( + tsocket.AF_INET, # 127.0.0.1 is ipv4 + 
tsocket.SOCK_STREAM, + tsocket.IPPROTO_TCP, + "", + ("127.0.0.1", 12345), + ), + ], + ) with assert_checkpoints(): - res = await tsocket.getaddrinfo( - "::1", "12345", type=tsocket.SOCK_DGRAM - ) - check(res, [ - (tsocket.AF_INET6, - tsocket.SOCK_DGRAM, - tsocket.IPPROTO_UDP, - "", - ("::1", 12345, 0, 0)), - ]) # yapf: disable + res = await tsocket.getaddrinfo("::1", "12345", type=tsocket.SOCK_DGRAM) + check( + res, + [ + ( + tsocket.AF_INET6, + tsocket.SOCK_DGRAM, + tsocket.IPPROTO_UDP, + "", + ("::1", 12345, 0, 0), + ), + ], + ) monkeygai.set("x", b"host", "port", family=0, type=0, proto=0, flags=0) with assert_checkpoints(): @@ -276,6 +280,7 @@ async def test_socket_v6(): @pytest.mark.skipif(not _sys.platform == "linux", reason="linux only") async def test_sniff_sockopts(): from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM + # generate the combinations of families/types we're testing: sockets = [] for family in [AF_INET, AF_INET6]: @@ -392,10 +397,11 @@ async def test_SocketType_shutdown(): @pytest.mark.parametrize( - "address, socket_type", [ - ('127.0.0.1', tsocket.AF_INET), - pytest.param('::1', tsocket.AF_INET6, marks=binds_ipv6) - ] + "address, socket_type", + [ + ("127.0.0.1", tsocket.AF_INET), + pytest.param("::1", tsocket.AF_INET6, marks=binds_ipv6), + ], ) async def test_SocketType_simple_server(address, socket_type): # listen, bind, accept, connect, getpeername, getsockname @@ -448,7 +454,8 @@ class Addresses: # Direct thorough tests of the implicit resolver helpers @pytest.mark.parametrize( - "socket_type, addrs", [ + "socket_type, addrs", + [ ( tsocket.AF_INET, Addresses( @@ -470,10 +477,10 @@ class Addresses: ), marks=creates_ipv6, ), - ] + ], ) async def test_SocketType_resolve(socket_type, addrs): - v6 = (socket_type == tsocket.AF_INET6) + v6 = socket_type == tsocket.AF_INET6 # For some reason the stdlib special-cases "" to pass NULL to getaddrinfo # They also error out on None, but whatever, None is much more consistent, @@ -492,22 
+499,21 @@ async def test_SocketType_resolve(socket_type, addrs): async def res(*args): return await getattr(sock, resolver)(*args) - # yapf: disable - assert await res((addrs.arbitrary, - "http")) == (addrs.arbitrary, 80, *addrs.extra) + assert await res((addrs.arbitrary, "http")) == ( + addrs.arbitrary, + 80, + *addrs.extra, + ) if v6: assert await res(("1::2", 80, 1)) == ("1::2", 80, 1, 0) assert await res(("1::2", 80, 1, 2)) == ("1::2", 80, 1, 2) # V4 mapped addresses resolved if V6ONLY is False sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, False) - assert await res(("1.2.3.4", - "http")) == ("::ffff:1.2.3.4", 80, 0, 0) + assert await res(("1.2.3.4", "http")) == ("::ffff:1.2.3.4", 80, 0, 0,) # Check the special case, because why not - assert await res(("", - 123)) == (addrs.broadcast, 123, *addrs.extra) - # yapf: enable + assert await res(("", 123)) == (addrs.broadcast, 123, *addrs.extra,) # But not if it's true (at least on systems where getaddrinfo works # correctly) @@ -704,7 +710,7 @@ async def _resolve_remote_address(self, *args, **kwargs): sock._resolve_remote_address = _resolve_remote_address with assert_checkpoints(): with pytest.raises(_core.Cancelled): - await sock.connect('') + await sock.connect("") assert sock.fileno() == -1 @@ -850,9 +856,11 @@ async def getnameinfo(self, sockaddr, flags): (0, 0, tsocket.IPPROTO_TCP, 0), (0, 0, 0, tsocket.AI_CANONNAME), ]: - assert ( - await tsocket.getaddrinfo("localhost", "foo", *vals) == - ("custom_gai", b"localhost", "foo", *vals) + assert await tsocket.getaddrinfo("localhost", "foo", *vals) == ( + "custom_gai", + b"localhost", + "foo", + *vals, ) # IDNA encoding is handled before calling the special object @@ -860,7 +868,7 @@ async def getnameinfo(self, sockaddr, flags): expected = ("custom_gai", b"xn--f-1gaa", "foo", 0, 0, 0, 0) assert got == expected - assert (await tsocket.getnameinfo("a", 0) == ("custom_gni", "a", 0)) + assert await tsocket.getnameinfo("a", 0) == ("custom_gni", "a", 0) # We 
can set it back to None assert tsocket.set_custom_hostname_resolver(None) is cr @@ -903,9 +911,7 @@ async def test_SocketType_is_abstract(): tsocket.SocketType() -@pytest.mark.skipif( - not hasattr(tsocket, "AF_UNIX"), reason="no unix domain sockets" -) +@pytest.mark.skipif(not hasattr(tsocket, "AF_UNIX"), reason="no unix domain sockets") async def test_unix_domain_socket(): # Bind has a special branch to use a thread, since it has to do filesystem # traversal. Maybe connect should too? Not sure. diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py index 9cdad56fe3..dee8e325f7 100644 --- a/trio/tests/test_ssl.py +++ b/trio/tests/test_ssl.py @@ -146,8 +146,7 @@ async def ssl_echo_server_raw(**kwargs): # nursery context manager to exit too. with a, b: nursery.start_soon( - trio.to_thread.run_sync, - partial(ssl_echo_serve_sync, b, **kwargs) + trio.to_thread.run_sync, partial(ssl_echo_serve_sync, b, **kwargs), ) yield SocketStream(tsocket.from_stdlib_socket(a)) @@ -158,9 +157,7 @@ async def ssl_echo_server_raw(**kwargs): @asynccontextmanager async def ssl_echo_server(client_ctx, **kwargs): async with ssl_echo_server_raw(**kwargs) as sock: - yield SSLStream( - sock, client_ctx, server_hostname="trio-test-1.example.org" - ) + yield SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org") # The weird in-memory server ... thing. 
@@ -190,6 +187,7 @@ def __init__(self, sleeper=None): # Fortunately pyopenssl uses cryptography under the hood, so we can be # confident that they're using the same version of openssl from cryptography.hazmat.bindings.openssl.binding import Binding + b = Binding() if hasattr(b.lib, "SSL_OP_NO_TLSv1_3"): ctx.set_options(b.lib.SSL_OP_NO_TLSv1_3) @@ -358,9 +356,7 @@ async def test_PyOpenSSLEchoStream_gives_resource_busy_errors(): @contextmanager def virtual_ssl_echo_server(client_ctx, **kwargs): fakesock = PyOpenSSLEchoStream(**kwargs) - yield SSLStream( - fakesock, client_ctx, server_hostname="trio-test-1.example.org" - ) + yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org") def ssl_wrap_pair( @@ -369,13 +365,13 @@ def ssl_wrap_pair( server_transport, *, client_kwargs={}, - server_kwargs={} + server_kwargs={}, ): client_ssl = SSLStream( client_transport, client_ctx, server_hostname="trio-test-1.example.org", - **client_kwargs + **client_kwargs, ) server_ssl = SSLStream( server_transport, SERVER_CTX, server_side=True, **server_kwargs @@ -385,16 +381,12 @@ def ssl_wrap_pair( def ssl_memory_stream_pair(client_ctx, **kwargs): client_transport, server_transport = memory_stream_pair() - return ssl_wrap_pair( - client_ctx, client_transport, server_transport, **kwargs - ) + return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) def ssl_lockstep_stream_pair(client_ctx, **kwargs): client_transport, server_transport = lockstep_stream_pair() - return ssl_wrap_pair( - client_ctx, client_transport, server_transport, **kwargs - ) + return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) # Simple smoke test for handshake/send/receive/shutdown talking to a @@ -411,9 +403,7 @@ async def test_ssl_client_basics(client_ctx): # Didn't configure the CA file, should fail async with ssl_echo_server_raw(expect_fail=True) as sock: bad_client_ctx = ssl.create_default_context() - s = SSLStream( - sock, bad_client_ctx, 
server_hostname="trio-test-1.example.org" - ) + s = SSLStream(sock, bad_client_ctx, server_hostname="trio-test-1.example.org") assert not s.server_side with pytest.raises(BrokenResourceError) as excinfo: await s.send_all(b"x") @@ -421,9 +411,7 @@ async def test_ssl_client_basics(client_ctx): # Trusted CA, but wrong host name async with ssl_echo_server_raw(expect_fail=True) as sock: - s = SSLStream( - sock, client_ctx, server_hostname="trio-test-2.example.org" - ) + s = SSLStream(sock, client_ctx, server_hostname="trio-test-2.example.org") assert not s.server_side with pytest.raises(BrokenResourceError) as excinfo: await s.send_all(b"x") @@ -464,9 +452,7 @@ async def test_attributes(client_ctx): async with ssl_echo_server_raw(expect_fail=True) as sock: good_ctx = client_ctx bad_ctx = ssl.create_default_context() - s = SSLStream( - sock, good_ctx, server_hostname="trio-test-1.example.org" - ) + s = SSLStream(sock, good_ctx, server_hostname="trio-test-1.example.org") assert s.transport_stream is sock @@ -593,6 +579,7 @@ async def test_renegotiation_randomized(mock_clock, client_ctx): mock_clock.autojump_threshold = 0 import random + r = random.Random(0) async def sleeper(_): @@ -629,8 +616,8 @@ async def expect(expected): await clear() for i in range(100): - b1 = bytes([i % 0xff]) - b2 = bytes([(2 * i) % 0xff]) + b1 = bytes([i % 0xFF]) + b2 = bytes([(2 * i) % 0xFF]) s.transport_stream.renegotiate() async with _core.open_nursery() as nursery: nursery.start_soon(send, b1) @@ -641,8 +628,8 @@ async def expect(expected): await clear() for i in range(100): - b1 = bytes([i % 0xff]) - b2 = bytes([(2 * i) % 0xff]) + b1 = bytes([i % 0xFF]) + b2 = bytes([(2 * i) % 0xFF]) await send(b1) s.transport_stream.renegotiate() await expect(b1) @@ -668,9 +655,7 @@ async def sleep_then_wait_writable(): await trio.sleep(1000) await s.wait_send_all_might_not_block() - with virtual_ssl_echo_server( - client_ctx, sleeper=sleeper_with_slow_send_all - ) as s: + with 
virtual_ssl_echo_server(client_ctx, sleeper=sleeper_with_slow_send_all) as s: await send(b"x") s.transport_stream.renegotiate() async with _core.open_nursery() as nursery: @@ -1023,7 +1008,7 @@ async def test_ssl_bad_shutdown_but_its_ok(client_ctx): client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": True}, - client_kwargs={"https_compatible": True} + client_kwargs={"https_compatible": True}, ) async with _core.open_nursery() as nursery: @@ -1047,9 +1032,7 @@ async def test_ssl_handshake_failure_during_aclose(): async with ssl_echo_server_raw(expect_fail=True) as sock: # Don't configure trust correctly client_ctx = ssl.create_default_context() - s = SSLStream( - sock, client_ctx, server_hostname="trio-test-1.example.org" - ) + s = SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org") # It's a little unclear here whether aclose should swallow the error # or let it escape. We *do* swallow the error if it arrives when we're # sending close_notify, because both sides closing the connection @@ -1089,7 +1072,7 @@ async def test_ssl_https_compatibility_disagreement(client_ctx): client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": False}, - client_kwargs={"https_compatible": True} + client_kwargs={"https_compatible": True}, ) async with _core.open_nursery() as nursery: @@ -1112,7 +1095,7 @@ async def test_https_mode_eof_before_handshake(client_ctx): client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": True}, - client_kwargs={"https_compatible": True} + client_kwargs={"https_compatible": True}, ) async def server_expect_clean_eof(): @@ -1185,8 +1168,7 @@ async def test_selected_alpn_protocol_when_not_set(client_ctx): assert client.selected_alpn_protocol() is None assert server.selected_alpn_protocol() is None - assert client.selected_alpn_protocol() == \ - server.selected_alpn_protocol() + assert client.selected_alpn_protocol() == 
server.selected_alpn_protocol() async def test_selected_npn_protocol_before_handshake(client_ctx): @@ -1211,8 +1193,7 @@ async def test_selected_npn_protocol_when_not_set(client_ctx): assert client.selected_npn_protocol() is None assert server.selected_npn_protocol() is None - assert client.selected_npn_protocol() == \ - server.selected_npn_protocol() + assert client.selected_npn_protocol() == server.selected_npn_protocol() async def test_get_channel_binding_before_handshake(client_ctx): @@ -1235,8 +1216,7 @@ async def test_get_channel_binding_after_handshake(client_ctx): assert client.get_channel_binding() is not None assert server.get_channel_binding() is not None - assert client.get_channel_binding() == \ - server.get_channel_binding() + assert client.get_channel_binding() == server.get_channel_binding() async def test_getpeercert(client_ctx): @@ -1249,10 +1229,7 @@ async def test_getpeercert(client_ctx): assert server.getpeercert() is None print(client.getpeercert()) - assert ( - ("DNS", "trio-test-1.example.org") - in client.getpeercert()["subjectAltName"] - ) + assert ("DNS", "trio-test-1.example.org") in client.getpeercert()["subjectAltName"] async def test_SSLListener(client_ctx): @@ -1265,9 +1242,7 @@ async def setup(**kwargs): transport_client = await open_tcp_stream(*listen_sock.getsockname()) ssl_client = SSLStream( - transport_client, - client_ctx, - server_hostname="trio-test-1.example.org" + transport_client, client_ctx, server_hostname="trio-test-1.example.org", ) return listen_sock, ssl_listener, ssl_client diff --git a/trio/tests/test_subprocess.py b/trio/tests/test_subprocess.py index c5be0e393e..1f489db50c 100644 --- a/trio/tests/test_subprocess.py +++ b/trio/tests/test_subprocess.py @@ -7,8 +7,15 @@ from functools import partial from .. 
import ( - _core, move_on_after, fail_after, sleep, sleep_forever, Process, - open_process, run_process, TrioDeprecationWarning + _core, + move_on_after, + fail_after, + sleep, + sleep_forever, + Process, + open_process, + run_process, + TrioDeprecationWarning, ) from .._core.tests.tutil import slow, skip_if_fbsd_pipes_broken from ..testing import wait_all_tasks_blocked @@ -183,9 +190,7 @@ async def drain_one(stream, count, digit): assert await stream.receive_some(len(newline)) == newline nursery.start_soon(drain_one, proc.stdout, request, idx * 2) - nursery.start_soon( - drain_one, proc.stderr, request * 2, idx * 2 + 1 - ) + nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1) with fail_after(5): await proc.stdin.send_all(b"12") @@ -210,7 +215,7 @@ async def drain_one(stream, count, digit): async def test_run(): - data = bytes(random.randint(0, 255) for _ in range(2**18)) + data = bytes(random.randint(0, 255) for _ in range(2 ** 18)) result = await run_process( CAT, stdin=data, capture_stdout=True, capture_stderr=True @@ -269,8 +274,7 @@ async def test_run_check(): @skip_if_fbsd_pipes_broken async def test_run_with_broken_pipe(): result = await run_process( - [sys.executable, "-c", "import sys; sys.stdin.close()"], - stdin=b"x" * 131072, + [sys.executable, "-c", "import sys; sys.stdin.close()"], stdin=b"x" * 131072, ) assert result.returncode == 0 assert result.stdout is result.stderr is None @@ -402,6 +406,7 @@ def test_waitid_eintr(): # This only matters on PyPy (where we're coding EINTR handling # ourselves) but the test works on all waitid platforms. 
from .._subprocess_platform import wait_child_exiting + if not wait_child_exiting.__module__.endswith("waitid"): pytest.skip("waitid only") from .._subprocess_platform.waitid import sync_wait_reapable @@ -444,9 +449,7 @@ async def custom_deliver_cancel(proc): async with _core.open_nursery() as nursery: nursery.start_soon( - partial( - run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel - ) + partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel) ) await wait_all_tasks_blocked() nursery.cancel_scope.cancel() diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py index e4de476993..229dea301c 100644 --- a/trio/tests/test_sync.py +++ b/trio/tests/test_sync.py @@ -111,6 +111,7 @@ async def test_CapacityLimiter(): async def test_CapacityLimiter_inf(): from math import inf + c = CapacityLimiter(inf) repr(c) # smoke test assert c.total_tokens == inf @@ -240,9 +241,7 @@ async def test_Semaphore_bounded(): assert bs.value == 1 -@pytest.mark.parametrize( - "lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__ -) +@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__) async def test_Lock_and_StrictFIFOLock(lockcls): l = lockcls() # noqa assert not l.locked() @@ -548,7 +547,7 @@ async def loopy(name, lock_like): # The first three could be in any order due to scheduling randomness, # but after that they should repeat in the same order for i in range(LOOPS): - assert record[3 * i:3 * (i + 1)] == initial_order + assert record[3 * i : 3 * (i + 1)] == initial_order @generic_lock_test diff --git a/trio/tests/test_testing.py b/trio/tests/test_testing.py index e73624a67c..460200b28c 100644 --- a/trio/tests/test_testing.py +++ b/trio/tests/test_testing.py @@ -146,7 +146,8 @@ async def test_assert_checkpoints(recwarn): # if you have a schedule point but not a cancel point, or vice-versa, then # that's not a checkpoint. 
for partial_yield in [ - _core.checkpoint_if_cancelled, _core.cancel_shielded_checkpoint + _core.checkpoint_if_cancelled, + _core.cancel_shielded_checkpoint, ]: print(partial_yield) with pytest.raises(AssertionError): @@ -171,7 +172,8 @@ async def test_assert_no_checkpoints(recwarn): # if you have a schedule point but not a cancel point, or vice-versa, then # that doesn't make *either* version of assert_{no_,}yields happy. for partial_yield in [ - _core.checkpoint_if_cancelled, _core.cancel_shielded_checkpoint + _core.checkpoint_if_cancelled, + _core.cancel_shielded_checkpoint, ]: print(partial_yield) with pytest.raises(AssertionError): @@ -215,9 +217,7 @@ async def f2(seq): nursery.start_soon(f2, seq) async with seq(5): await wait_all_tasks_blocked() - assert record == [ - ("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4) - ] + assert record == [("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4)] seq = Sequencer() # Catches us if we try to re-use a sequence point: @@ -321,11 +321,7 @@ async def test_mock_clock_autojump(mock_clock): virtual_start = _core.current_time() real_duration = time.perf_counter() - real_start - print( - "Slept {} seconds in {} seconds".format( - 10 * sum(range(10)), real_duration - ) - ) + print("Slept {} seconds in {} seconds".format(10 * sum(range(10)), real_duration)) assert real_duration < 1 mock_clock.autojump_threshold = 0.02 @@ -571,9 +567,7 @@ def close_hook(): record.append("close_hook") mss2 = MemorySendStream( - send_all_hook, - wait_send_all_might_not_block_hook, - close_hook, + send_all_hook, wait_send_all_might_not_block_hook, close_hook, ) assert mss2.send_all_hook is send_all_hook diff --git a/trio/tests/test_threads.py b/trio/tests/test_threads.py index 6f5d2b6229..b4acae8b58 100644 --- a/trio/tests/test_threads.py +++ b/trio/tests/test_threads.py @@ -8,8 +8,11 @@ from .. 
import Event, CapacityLimiter, sleep from ..testing import wait_all_tasks_blocked from .._threads import ( - to_thread_run_sync, current_default_thread_limiter, from_thread_run, - from_thread_run_sync, BlockingTrioPortal + to_thread_run_sync, + current_default_thread_limiter, + from_thread_run, + from_thread_run_sync, + BlockingTrioPortal, ) from .._core.tests.test_ki import ki_self @@ -35,9 +38,7 @@ def threadfn(): while child_thread.is_alive(): print("yawn") await sleep(0.01) - assert record == [ - ("start", child_thread), ("f", trio_thread), expected - ] + assert record == [("start", child_thread), ("f", trio_thread), expected] token = _core.current_trio_token() @@ -53,9 +54,7 @@ def f(record): record.append(("f", threading.current_thread())) raise ValueError - await check_case( - from_thread_run_sync, f, ("error", ValueError), trio_token=token - ) + await check_case(from_thread_run_sync, f, ("error", ValueError), trio_token=token) async def f(record): assert not _core.currently_ki_protected() @@ -101,6 +100,7 @@ def trio_thread_fn(): ki_self() finally: import sys + print("finally", sys.exc_info()) async def trio_thread_afn(): @@ -332,15 +332,9 @@ def thread_fn(cancel_scope): async def run_thread(event): with _core.CancelScope() as cancel_scope: await to_thread_run_sync( - thread_fn, - cancel_scope, - limiter=limiter_arg, - cancellable=cancel + thread_fn, cancel_scope, limiter=limiter_arg, cancellable=cancel, ) - print( - "run_thread finished, cancelled:", - cancel_scope.cancelled_caught - ) + print("run_thread finished, cancelled:", cancel_scope.cancelled_caught) event.set() async with _core.open_nursery() as nursery: @@ -521,9 +515,7 @@ async def test_trio_from_thread_token_kwarg(): # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can # use an explicitly defined token def thread_fn(token): - callee_token = from_thread_run_sync( - _core.current_trio_token, trio_token=token - ) + callee_token = from_thread_run_sync(_core.current_trio_token, 
trio_token=token) return callee_token caller_token = _core.current_trio_token() @@ -541,9 +533,7 @@ async def test_from_thread_no_token(): def test_run_fn_as_system_task_catched_badly_typed_token(): with pytest.raises(RuntimeError): - from_thread_run_sync( - _core.current_time, trio_token="Not TrioTokentype" - ) + from_thread_run_sync(_core.current_time, trio_token="Not TrioTokentype") async def test_do_in_trio_thread_from_trio_thread_legacy(): @@ -580,5 +570,6 @@ def worker_thread(token): def test_BlockingTrioPortal_deprecated_export(recwarn): import trio + btp = trio.BlockingTrioPortal assert btp is BlockingTrioPortal diff --git a/trio/tests/test_unix_pipes.py b/trio/tests/test_unix_pipes.py index 5b85716b75..55dd4e3734 100644 --- a/trio/tests/test_unix_pipes.py +++ b/trio/tests/test_unix_pipes.py @@ -74,7 +74,7 @@ async def test_receive_pipe(): async def test_pipes_combined(): write, read = await make_pipe() - count = 2**20 + count = 2 ** 20 async def sender(): big = bytearray(count) @@ -195,9 +195,7 @@ async def patched_wait_readable(*args, **kwargs): await orig_wait_readable(*args, **kwargs) await r.aclose() - monkeypatch.setattr( - _core._run.TheIOManager, "wait_readable", patched_wait_readable - ) + monkeypatch.setattr(_core._run.TheIOManager, "wait_readable", patched_wait_readable) s, r = await make_pipe() async with s, r: async with _core.open_nursery() as nursery: @@ -225,9 +223,7 @@ async def patched_wait_writable(*args, **kwargs): await orig_wait_writable(*args, **kwargs) await s.aclose() - monkeypatch.setattr( - _core._run.TheIOManager, "wait_writable", patched_wait_writable - ) + monkeypatch.setattr(_core._run.TheIOManager, "wait_writable", patched_wait_writable) s, r = await make_clogged_pipe() async with s, r: async with _core.open_nursery() as nursery: @@ -243,7 +239,7 @@ async def patched_wait_writable(*args, **kwargs): # other platforms is probably good enough. 
@pytest.mark.skipif( sys.platform.startswith("freebsd"), - reason="no way to make read() return a bizarro error on FreeBSD" + reason="no way to make read() return a bizarro error on FreeBSD", ) async def test_bizarro_OSError_from_receive(): # Make sure that if the read syscall returns some bizarro error, then we diff --git a/trio/tests/test_util.py b/trio/tests/test_util.py index 009f9fa8f7..b08676e622 100644 --- a/trio/tests/test_util.py +++ b/trio/tests/test_util.py @@ -5,9 +5,14 @@ from .. import _core from .._core.tests.tutil import ignore_coroutine_never_awaited_warnings from .._util import ( - signal_raise, ConflictDetector, is_main_thread, coroutine_or_error, - generic_function, Final, NoPublicConstructor, - SubclassingDeprecatedIn_v0_15_0 + signal_raise, + ConflictDetector, + is_main_thread, + coroutine_or_error, + generic_function, + Final, + NoPublicConstructor, + SubclassingDeprecatedIn_v0_15_0, ) from ..testing import wait_all_tasks_blocked @@ -53,11 +58,12 @@ async def wait_with_ul1(): def test_module_metadata_is_fixed_up(): import trio + import trio.testing + assert trio.Cancelled.__module__ == "trio" assert trio.open_nursery.__module__ == "trio" assert trio.abc.Stream.__module__ == "trio.abc" assert trio.lowlevel.wait_task_rescheduled.__module__ == "trio.lowlevel" - import trio.testing assert trio.testing.trio_test.__module__ == "trio.testing" # Also check methods diff --git a/trio/tests/test_wait_for_object.py b/trio/tests/test_wait_for_object.py index 3c3830ea39..ac507a26f9 100644 --- a/trio/tests/test_wait_for_object.py +++ b/trio/tests/test_wait_for_object.py @@ -2,7 +2,7 @@ import pytest -on_windows = (os.name == "nt") +on_windows = os.name == "nt" # Mark all the tests in this file as being windows-only pytestmark = pytest.mark.skipif(not on_windows, reason="windows only") @@ -10,9 +10,13 @@ import trio from .. import _core from .. 
import _timeouts + if on_windows: from .._core._windows_cffi import ffi, kernel32 - from .._wait_for_object import WaitForSingleObject, WaitForMultipleObjects_sync + from .._wait_for_object import ( + WaitForSingleObject, + WaitForMultipleObjects_sync, + ) async def test_WaitForMultipleObjects_sync(): @@ -29,7 +33,7 @@ async def test_WaitForMultipleObjects_sync(): kernel32.SetEvent(handle1) WaitForMultipleObjects_sync(handle1) kernel32.CloseHandle(handle1) - print('test_WaitForMultipleObjects_sync one OK') + print("test_WaitForMultipleObjects_sync one OK") # Two handles, signal first handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -38,7 +42,7 @@ async def test_WaitForMultipleObjects_sync(): WaitForMultipleObjects_sync(handle1, handle2) kernel32.CloseHandle(handle1) kernel32.CloseHandle(handle2) - print('test_WaitForMultipleObjects_sync set first OK') + print("test_WaitForMultipleObjects_sync set first OK") # Two handles, signal second handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -47,7 +51,7 @@ async def test_WaitForMultipleObjects_sync(): WaitForMultipleObjects_sync(handle1, handle2) kernel32.CloseHandle(handle1) kernel32.CloseHandle(handle2) - print('test_WaitForMultipleObjects_sync set second OK') + print("test_WaitForMultipleObjects_sync set second OK") # Two handles, close first handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -56,7 +60,7 @@ async def test_WaitForMultipleObjects_sync(): with pytest.raises(OSError): WaitForMultipleObjects_sync(handle1, handle2) kernel32.CloseHandle(handle2) - print('test_WaitForMultipleObjects_sync close first OK') + print("test_WaitForMultipleObjects_sync close first OK") # Two handles, close second handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -65,7 +69,7 @@ async def test_WaitForMultipleObjects_sync(): with pytest.raises(OSError): WaitForMultipleObjects_sync(handle1, handle2) kernel32.CloseHandle(handle1) - 
print('test_WaitForMultipleObjects_sync close second OK') + print("test_WaitForMultipleObjects_sync close second OK") @slow @@ -89,7 +93,7 @@ async def test_WaitForMultipleObjects_sync_slow(): t1 = _core.current_time() assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT kernel32.CloseHandle(handle1) - print('test_WaitForMultipleObjects_sync_slow one OK') + print("test_WaitForMultipleObjects_sync_slow one OK") # Two handles, signal first handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -97,8 +101,7 @@ async def test_WaitForMultipleObjects_sync_slow(): t0 = _core.current_time() async with _core.open_nursery() as nursery: nursery.start_soon( - trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, - handle2 + trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2, ) await _timeouts.sleep(TIMEOUT) kernel32.SetEvent(handle1) @@ -106,7 +109,7 @@ async def test_WaitForMultipleObjects_sync_slow(): assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT kernel32.CloseHandle(handle1) kernel32.CloseHandle(handle2) - print('test_WaitForMultipleObjects_sync_slow thread-set first OK') + print("test_WaitForMultipleObjects_sync_slow thread-set first OK") # Two handles, signal second handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -114,8 +117,7 @@ async def test_WaitForMultipleObjects_sync_slow(): t0 = _core.current_time() async with _core.open_nursery() as nursery: nursery.start_soon( - trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, - handle2 + trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2, ) await _timeouts.sleep(TIMEOUT) kernel32.SetEvent(handle2) @@ -123,7 +125,7 @@ async def test_WaitForMultipleObjects_sync_slow(): assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT kernel32.CloseHandle(handle1) kernel32.CloseHandle(handle2) - print('test_WaitForMultipleObjects_sync_slow thread-set second OK') + print("test_WaitForMultipleObjects_sync_slow thread-set second OK") async def test_WaitForSingleObject(): @@ 
-135,7 +137,7 @@ async def test_WaitForSingleObject(): kernel32.SetEvent(handle) await WaitForSingleObject(handle) # should return at once kernel32.CloseHandle(handle) - print('test_WaitForSingleObject already set OK') + print("test_WaitForSingleObject already set OK") # Test already set, as int handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -143,21 +145,21 @@ async def test_WaitForSingleObject(): kernel32.SetEvent(handle) await WaitForSingleObject(handle_int) # should return at once kernel32.CloseHandle(handle) - print('test_WaitForSingleObject already set OK') + print("test_WaitForSingleObject already set OK") # Test already closed handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) kernel32.CloseHandle(handle) with pytest.raises(OSError): await WaitForSingleObject(handle) # should return at once - print('test_WaitForSingleObject already closed OK') + print("test_WaitForSingleObject already closed OK") # Not a handle with pytest.raises(TypeError): await WaitForSingleObject("not a handle") # Wrong type # with pytest.raises(OSError): # await WaitForSingleObject(99) # If you're unlucky, it actually IS a handle :( - print('test_WaitForSingleObject not a handle OK') + print("test_WaitForSingleObject not a handle OK") @slow @@ -185,7 +187,7 @@ async def signal_soon_async(handle): kernel32.CloseHandle(handle) t1 = _core.current_time() assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT - print('test_WaitForSingleObject_slow set from task OK') + print("test_WaitForSingleObject_slow set from task OK") # Test handle is SET after TIMEOUT in separate coroutine, as int @@ -200,7 +202,7 @@ async def signal_soon_async(handle): kernel32.CloseHandle(handle) t1 = _core.current_time() assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT - print('test_WaitForSingleObject_slow set from task as int OK') + print("test_WaitForSingleObject_slow set from task as int OK") # Test handle is CLOSED after 1 sec - NOPE see comment above @@ -215,4 +217,4 @@ async def 
signal_soon_async(handle): kernel32.CloseHandle(handle) t1 = _core.current_time() assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT - print('test_WaitForSingleObject_slow cancellation OK') + print("test_WaitForSingleObject_slow cancellation OK") diff --git a/trio/tests/test_windows_pipes.py b/trio/tests/test_windows_pipes.py index 864aaf768e..8fb29b632f 100644 --- a/trio/tests/test_windows_pipes.py +++ b/trio/tests/test_windows_pipes.py @@ -47,7 +47,7 @@ async def test_pipe_error_on_close(): async def test_pipes_combined(): write, read = await make_pipe() - count = 2**20 + count = 2 ** 20 replicas = 3 async def sender(): diff --git a/trio/tests/tools/test_gen_exports.py b/trio/tests/tools/test_gen_exports.py index 6c1fb0d668..e4e388c226 100644 --- a/trio/tests/tools/test_gen_exports.py +++ b/trio/tests/tools/test_gen_exports.py @@ -6,7 +6,9 @@ from shutil import copyfile from trio._tools.gen_exports import ( - get_public_methods, create_passthrough_args, process + get_public_methods, + create_passthrough_args, + process, ) SOURCE = '''from _run import _public @@ -43,7 +45,7 @@ def test_create_pass_through_args(): ("def f(one, *args)", "(one, *args)"), ( "def f(one, *args, kw1, kw2=None, **kwargs)", - "(one, *args, kw1=kw1, kw2=kw2, **kwargs)" + "(one, *args, kw1=kw1, kw2=kw2, **kwargs)", ), ]