Skip to content

Commit

Permalink
mypy fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Dec 22, 2024
1 parent 6479a06 commit 4b908ba
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
23 changes: 12 additions & 11 deletions logpyle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ class _DependencyData:
@dataclass
class _WatchInfo:
parsed: ExpressionNode
expr: ExpressionNode
expr: str
dep_data: list[_DependencyData]
compiled: CompiledExpression
unit: str | None
Expand Down Expand Up @@ -885,7 +885,7 @@ def add_watches(self, watches: list[str | tuple[str, str]]) -> None:
any(dd.nonlocal_agg for dd in dep_data)

from pymbolic import compile
compiled = compile(parsed, [dd.varname for dd in dep_data])
compiled = compile(parsed, [dd.varname for dd in dep_data]) # type: ignore[no-untyped-call]

watch_info = _WatchInfo(parsed=parsed, expr=expr, dep_data=dep_data,
compiled=compiled, unit=unit, format=fmt)
Expand Down Expand Up @@ -1092,14 +1092,14 @@ def add_internal(name: str, unit: str | None, description: str | None,

self.save()

def get_expr_dataset(self, expression: ExpressionNode,
def get_expr_dataset(self, expression: str,
description: str | None = None,
unit: str | None = None) \
-> tuple[str | Any, str | Any,
list[tuple[int, Any]]]:
"""Prepare a time-series dataset for a given expression.
:arg expression: A :mod:`pymbolic` expression that may involve
:arg expression: A :mod:`pymbolic`-like expression that may involve
the time-series variables and the constants in this :class:`LogManager`.
If there is data from multiple ranks for a quantity occurring in
this expression, an aggregator may have to be specified.
Expand Down Expand Up @@ -1127,7 +1127,7 @@ def get_expr_dataset(self, expression: ExpressionNode,
if unit is None:
from pymbolic import parse, substitute

unit_dict = {dd.varname: dd.qdat.unit for dd in dep_data}
unit_dict: dict[str, Any] = {dd.varname: dd.qdat.unit for dd in dep_data}
from pytools import all
if all(v is not None for v in unit_dict.values()):
unit_dict = {k: parse(v) for k, v in unit_dict.items()}
Expand All @@ -1140,7 +1140,7 @@ def get_expr_dataset(self, expression: ExpressionNode,

# compile and evaluate
from pymbolic import compile
compiled = compile(parsed, [dd.varname for dd in dep_data])
compiled = compile(parsed, [dd.varname for dd in dep_data]) # type: ignore[no-untyped-call]

data = []

Expand All @@ -1153,7 +1153,8 @@ def get_expr_dataset(self, expression: ExpressionNode,

return (description, unit, data)

def get_joint_dataset(self, expressions: Sequence[ExpressionNode]) -> list[Any]:
def get_joint_dataset(self, expressions: Sequence[str | tuple[str, str, str]]) \
-> list[Any]:
"""Return a joint data set for a list of expressions.
:arg expressions: a list of either strings representing
Expand Down Expand Up @@ -1186,7 +1187,7 @@ def get_joint_dataset(self, expressions: Sequence[ExpressionNode]) -> list[Any]:

return zipped_dubs

def get_plot_data(self, expr_x: ExpressionNode, expr_y: ExpressionNode,
def get_plot_data(self, expr_x: str, expr_y: str,
min_step: int | None = None,
max_step: int | None = None) \
-> tuple[tuple[Any, str, str], tuple[Any, str, str]]:
Expand All @@ -1212,8 +1213,8 @@ def get_plot_data(self, expr_x: ExpressionNode, expr_y: ExpressionNode,
return (data_x, descr_x, unit_x), \
(data_y, descr_y, unit_y)

def write_datafile(self, filename: str, expr_x: ExpressionNode,
expr_y: ExpressionNode) -> None:
def write_datafile(self, filename: str, expr_x: str,
expr_y: str) -> None:
(data_x, label_x, _), (data_y, label_y, _) = self.get_plot_data(
expr_x, expr_y)

Expand All @@ -1223,7 +1224,7 @@ def write_datafile(self, filename: str, expr_x: ExpressionNode,
outf.write(f"{dx!r}\t{dy!r}\n")
outf.close()

def plot_matplotlib(self, expr_x: ExpressionNode, expr_y: ExpressionNode) -> None:
def plot_matplotlib(self, expr_x: str, expr_y: str) -> None:
from matplotlib.pyplot import plot, xlabel, ylabel

(data_x, descr_x, unit_x), (data_y, descr_y, unit_y) = \
Expand Down
1 change: 1 addition & 0 deletions logpyle/runalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def plot_cursor(self, cursor: Cursor, labels: list[str] | None = None, # noqa:

x, y = list(zip(*list(cursor), strict=False))
p = plot(x, y, *args, **kwargs)
assert p[0].axes

if isinstance(labels, list) and len(labels) == 2:
p[0].axes.set_xlabel(labels[0])
Expand Down

0 comments on commit 4b908ba

Please sign in to comment.