Skip to content

Commit

Permalink
Merge pull request #17 from simonsobs/type
Browse files Browse the repository at this point in the history
Add type hints to all defs
  • Loading branch information
TaiSakuma authored Oct 20, 2023
2 parents a8e16a3 + 57f654a commit a3da610
Show file tree
Hide file tree
Showing 33 changed files with 130 additions and 106 deletions.
14 changes: 7 additions & 7 deletions nextline/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Set
from typing import TYPE_CHECKING, AsyncIterator, Set

from nextline.utils.pubsub import PubSubItem

Expand All @@ -31,15 +31,15 @@ async def close(self) -> None:
await asyncio.gather(*self._tasks)
await self._pubsub_enabled.close()

async def __aenter__(self):
async def __aenter__(self) -> 'Continuous':
await self.start()
return self

async def __aexit__(self, exc_type, exc_value, traceback):
async def __aexit__(self, exc_type, exc_value, traceback): # type: ignore
del exc_type, exc_value, traceback
await self.close()

async def _monitor_state(self):
async def _monitor_state(self) -> None:
async for state in self._nextline.subscribe_state():
if state == 'initialized' and self._tasks:
_, pending = await asyncio.wait(
Expand All @@ -54,10 +54,10 @@ async def run_and_continue(self) -> None:
self._tasks.add(task)
await started.wait()

async def run_continue_and_wait(self, started: asyncio.Event):
async def run_continue_and_wait(self, started: asyncio.Event) -> None:
await self._run_and_continue(started)

async def _run_and_continue(self, started: asyncio.Event):
async def _run_and_continue(self, started: asyncio.Event) -> None:
await self._pubsub_enabled.publish(True)
try:
async with self._nextline.run_session():
Expand All @@ -77,5 +77,5 @@ async def _run_and_continue(self, started: asyncio.Event):
def enabled(self) -> bool:
return self._pubsub_enabled.latest()

def subscribe_enabled(self):
def subscribe_enabled(self) -> AsyncIterator[bool]:
return self._pubsub_enabled.subscribe()
10 changes: 5 additions & 5 deletions nextline/count.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@
_T = TypeVar("_T", bound=int)


def RunNoCounter(start=1) -> Callable[[], RunNo]:
def RunNoCounter(start: int = 1) -> Callable[[], RunNo]:
return CastedCounter(count(start).__next__, RunNo)


def TraceNoCounter(start=1) -> Callable[[], TraceNo]:
def TraceNoCounter(start: int = 1) -> Callable[[], TraceNo]:
return CastedCounter(count(start).__next__, TraceNo)


def ThreadNoCounter(start=1) -> Callable[[], ThreadNo]:
def ThreadNoCounter(start: int = 1) -> Callable[[], ThreadNo]:
return CastedCounter(count(start).__next__, ThreadNo)


def TaskNoCounter(start=1) -> Callable[[], TaskNo]:
def TaskNoCounter(start: int = 1) -> Callable[[], TaskNo]:
return CastedCounter(count(start).__next__, TaskNo)


def PromptNoCounter(start=1) -> Callable[[], PromptNo]:
def PromptNoCounter(start: int = 1) -> Callable[[], PromptNo]:
return CastedCounter(count(start).__next__, PromptNo)


Expand Down
3 changes: 2 additions & 1 deletion nextline/disable.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import contextlib
import sys
from collections.abc import Iterator


@contextlib.contextmanager
def disable_trace():
def disable_trace() -> Iterator[None]:
'''Remove the system trace function temporarily.
Example:
Expand Down
2 changes: 1 addition & 1 deletion nextline/fsm/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@
}


def build_state_machine(model=None, graph=False, asyncio=True, markup=False) -> Machine:
def build_state_machine(model=None, graph=False, asyncio=True, markup=False) -> Machine: # type: ignore
MachineClass: Type[Machine]
if markup:
MachineClass = MarkupMachine
Expand Down
12 changes: 6 additions & 6 deletions nextline/fsm/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def __init__(self, run_no_start_from: int, statement: Statement):

assert self.state # type: ignore

def __repr__(self):
def __repr__(self) -> str:
# e.g., "<Machine 'running'>"
return f'<{self.__class__.__name__} {self.state!r}>'
return f'<{self.__class__.__name__} {self.state!r}>' # type: ignore

async def after_state_change(self, event: EventData) -> None:
if not (event.transition and event.transition.dest):
Expand Down Expand Up @@ -93,10 +93,10 @@ async def on_reset(self, event: EventData) -> None:
# TODO: Check the arguments
await self._hook.ahook.reset(*event.args, **event.kwargs)

async def __aenter__(self):
await self.initialize()
async def __aenter__(self) -> 'Machine':
await self.initialize() # type: ignore
return self

async def __aexit__(self, exc_type, exc_value, traceback):
async def __aexit__(self, exc_type, exc_value, traceback) -> None: # type: ignore
del exc_type, exc_value, traceback
await self.close()
await self.close() # type: ignore
16 changes: 8 additions & 8 deletions nextline/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
self._started = False
self._closed = False

def __repr__(self):
def __repr__(self) -> str:
# e.g., "<Nextline 'running'>"
return f'<{self.__class__.__name__} {self.state!r}>'

Expand All @@ -84,11 +84,11 @@ async def close(self) -> None:
await self._machine.close() # type: ignore
await self._continuous.close()

async def __aenter__(self):
async def __aenter__(self) -> 'Nextline':
await self.start()
return self

async def __aexit__(self, exc_type, exc_value, traceback):
async def __aexit__(self, exc_type, exc_value, traceback) -> None: # type: ignore
del exc_type, exc_value, traceback
await asyncio.wait_for(self.close(), timeout=self._timeout_on_exit)

Expand Down Expand Up @@ -181,7 +181,7 @@ def subscribe_state(self) -> AsyncIterator[str]:
return self.subscribe("state_name")

@property
def run_no(self):
def run_no(self) -> int:
"""The current run number"""
return self.get("run_no")

Expand All @@ -198,12 +198,12 @@ def trace_ids(self) -> tuple[int, ...]:
def subscribe_trace_ids(self) -> AsyncIterator[tuple[int, ...]]:
return self.subscribe("trace_nos")

def get_source(self, file_name=None):
def get_source(self, file_name: Optional[str] = None) -> list[str]:
if not file_name or file_name == self._registry.latest("script_file_name"):
return self.get("statement").split("\n")
return [e.rstrip() for e in linecache.getlines(file_name)]

def get_source_line(self, line_no, file_name=None):
def get_source_line(self, line_no: int, file_name: Optional[str] = None) -> str:
"""
based on linecache.getline()
https://github.com/python/cpython/blob/v3.9.5/Lib/linecache.py#L26
Expand Down Expand Up @@ -250,10 +250,10 @@ async def _subscribe_prompt_info_for(
assert isinstance(info, PromptInfo)
yield info

def get(self, key) -> Any:
def get(self, key: Any) -> Any:
return self._registry.latest(key)

def subscribe(self, key, last: Optional[bool] = True) -> AsyncIterator[Any]:
def subscribe(self, key: Any, last: Optional[bool] = True) -> AsyncIterator[Any]:
return self._registry.subscribe(key, last=last)

def subscribe_stdout(self) -> AsyncIterator[StdoutInfo]:
Expand Down
9 changes: 4 additions & 5 deletions nextline/plugin/plugins/session/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from apluggy import PluginManager

from nextline import spawned

from ....spawned import QueueOut
from nextline.spawned import QueueOut

# from rich import print

Expand All @@ -17,7 +16,7 @@ def __init__(self, hook: PluginManager, queue: QueueOut):
self._queue = queue
self._logger = getLogger(__name__)

async def open(self):
async def open(self) -> None:
self._task = asyncio.create_task(self._monitor())

async def close(self) -> None:
Expand All @@ -28,11 +27,11 @@ async def close(self) -> None:
await asyncio.to_thread(self._queue.put, None) # type: ignore
await self._task

async def __aenter__(self):
async def __aenter__(self) -> 'Monitor':
await self.open()
return self

async def __aexit__(self, exc_type, exc_value, traceback):
async def __aexit__(self, exc_type, exc_value, traceback): # type: ignore
del exc_type, exc_value, traceback
await self.close()

Expand Down
2 changes: 1 addition & 1 deletion nextline/plugin/plugins/session/spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
pickling_support.install()


def _call_all(*funcs) -> None:
def _call_all(*funcs: Callable) -> None:
'''Execute callables and ignore return values.
Used to call multiple initializers in ProcessPoolExecutor.
Expand Down
4 changes: 2 additions & 2 deletions nextline/plugin/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def start() -> None:


@hookspec
async def close(exc_type=None, exc_value=None, traceback=None) -> None:
async def close(exc_type=None, exc_value=None, traceback=None) -> None: # type: ignore
pass


Expand Down Expand Up @@ -63,7 +63,7 @@ async def on_initialize_run(run_arg: spawned.RunArg) -> None:

@hookspec
@apluggy.asynccontextmanager
async def run():
async def run(): # type: ignore
yield


Expand Down
5 changes: 4 additions & 1 deletion nextline/spawned/call.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import sys
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Optional

from .types import TraceFunction


@contextmanager
def sys_trace(trace_func: TraceFunction, thread: Optional[bool] = True):
def sys_trace(
trace_func: TraceFunction, thread: Optional[bool] = True
) -> Iterator[None]:
'''Trace callables in the context and all threads created during the context.
Notes
Expand Down
7 changes: 4 additions & 3 deletions nextline/spawned/plugin/plugins/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def context(self) -> Iterator[None]:
yield
finally:
self._callback.close()
self._to_end and self._on_end(self._to_end)
if self._to_end:
self._on_end(self._to_end)

@hookimpl
def filtered(self) -> None:
Expand All @@ -57,7 +58,7 @@ def _on_start(self, current: Task | Thread) -> None:
self._counter() # increment the counter
self._hook.hook.on_start_task_or_thread()

def _on_end(self, ending: Task | Thread):
def _on_end(self, ending: Task | Thread) -> None:
# The "ending" is not the "current" unless it is the main thread.
self._logger.info(f'{self.__class__.__name__}._on_end: {ending}')
self._hook.hook.on_end_task_or_thread(task_or_thread=ending)
Expand Down Expand Up @@ -89,7 +90,7 @@ def on_start_task_or_thread(self) -> None:
self._hook.hook.on_start_trace(trace_no=trace_no)

@hookimpl
def on_end_task_or_thread(self, task_or_thread: Task | Thread):
def on_end_task_or_thread(self, task_or_thread: Task | Thread) -> None:
trace_no = self._map[task_or_thread]
self._hook.hook.on_end_trace(trace_no=trace_no)
self._logger.info(f'{self.__class__.__name__} end: trace_no={trace_no}')
Expand Down
4 changes: 2 additions & 2 deletions nextline/spawned/plugin/plugins/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class FilterByModuleName:
'''Skip Python modules with names that match any of the patterns.'''

@hookimpl
def init(self, modules_to_skip: Iterable[str]):
def init(self, modules_to_skip: Iterable[str]) -> None:
self._patterns = frozenset(modules_to_skip)

# NOTE: Use lru_cache() as match_any() is slow
Expand Down Expand Up @@ -91,7 +91,7 @@ def on_cmdloop(self) -> Generator[None, str, None]:
self._add(trace_args)
yield

def _add(self, trace_args: TraceArgs):
def _add(self, trace_args: TraceArgs) -> None:
frame, _, _ = trace_args
module_name = frame.f_globals.get('__name__')
if module_name is None:
Expand Down
11 changes: 8 additions & 3 deletions nextline/spawned/plugin/plugins/global_.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from logging import getLogger
from typing import Optional
from types import FrameType
from typing import Any, Optional

from apluggy import PluginManager

Expand All @@ -15,7 +16,9 @@ def init(self, hook: PluginManager) -> None:

@hookimpl
def create_trace_func(self) -> TraceFunction:
def _trace_func(frame, event, arg) -> Optional[TraceFunction]:
def _trace_func(
frame: FrameType, event: str, arg: Any
) -> Optional[TraceFunction]:
try:
return self._hook.hook.global_trace_func(
frame=frame, event=event, arg=arg
Expand All @@ -33,7 +36,9 @@ def init(self, hook: PluginManager) -> None:
self._hook = hook

@hookimpl
def global_trace_func(self, frame, event, arg) -> Optional[TraceFunction]:
def global_trace_func(
self, frame: FrameType, event: str, arg: Any
) -> Optional[TraceFunction]:
if self._hook.hook.filter(trace_args=(frame, event, arg)):
return None
self._hook.hook.filtered(trace_args=(frame, event, arg))
Expand Down
15 changes: 9 additions & 6 deletions nextline/spawned/plugin/plugins/local_.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from collections import defaultdict
from collections.abc import Iterator
from types import FrameType
from typing import Callable, Optional
from typing import Any, Callable, Optional

from apluggy import PluginManager, contextmanager
from exceptiongroup import BaseExceptionGroup, catch
from exceptiongroup import catch

from nextline.spawned.plugin.spec import hookimpl
from nextline.spawned.types import RunResult, TraceArgs, TraceFunction
Expand All @@ -28,7 +29,9 @@ def init(self, hook: PluginManager) -> None:
self._map = defaultdict[TraceNo, TraceFunction](factory)

@hookimpl
def local_trace_func(self, frame: FrameType, event, arg) -> Optional[TraceFunction]:
def local_trace_func(
self, frame: FrameType, event: str, arg: Any
) -> Optional[TraceFunction]:
trace_no = self._hook.hook.current_trace_no()
local_trace_func = self._map[trace_no]
return local_trace_func(frame, event, arg)
Expand All @@ -53,12 +56,12 @@ def _factory() -> TraceFunction:
trace = hook.hook.create_local_trace_func()

@contextmanager
def _context(frame, event, arg):
def _context(frame: FrameType, event: str, arg: Any) -> Iterator[None]:
'''A "with" block in which "trace" is called.'''

keyboard_interrupt_raised = False

def _keyboard_interrupt(exc: BaseExceptionGroup) -> None:
def _keyboard_interrupt(exc: BaseException) -> None:
nonlocal keyboard_interrupt_raised
keyboard_interrupt_raised = True

Expand Down Expand Up @@ -97,7 +100,7 @@ def init(self, hook: PluginManager) -> None:

@hookimpl
@contextmanager
def on_trace_call(self, trace_args: TraceArgs):
def on_trace_call(self, trace_args: TraceArgs) -> Iterator[None]:
trace_no = self._hook.hook.current_trace_no()
self._traces_on_call.add(trace_no)
self._trace_args_map[trace_no] = trace_args
Expand Down
Loading

0 comments on commit a3da610

Please sign in to comment.