Skip to content

Commit

Permalink
Merge pull request #58 from flaxandteal/fix/mypy-hook-errors-for-new-…
Browse files Browse the repository at this point in the history
…devs

Mypy Errors Triggered in Local Development
  • Loading branch information
philtweir authored Jan 13, 2025
2 parents 229102b + b0635bf commit a56d3d9
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ ENV/
env.bak/
venv.bak/
.pixi/
.vscode/

# Spyder project settings
.spyderproject
Expand Down
13 changes: 9 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@ repos:
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
args: [--fix]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
rev: v1.8.0
hooks:
- id: mypy
args: [--strict, --install-types, --non-interactive]
additional_dependencies: [sympy, attrs, pytest, click, dask]
args: [
--strict,
--install-types,
--allow-subclassing-any,
--non-interactive,
]
additional_dependencies: [sympy, attrs, pytest, click, dask]
2 changes: 1 addition & 1 deletion src/dewret/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def render(
if output == "-":

@contextmanager
def _opener(key: str, _: str) -> Generator[IO[Any], None, None]:
def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]:
print(" ------ ", key, " ------ ")
yield sys.stdout
print()
Expand Down
3 changes: 2 additions & 1 deletion src/dewret/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
BasicType = str | float | bool | bytes | int | None
RawType = BasicType | list["RawType"] | dict[str, "RawType"]
FirmType = RawType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...]
ExprType = (FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...]) # type: ignore
# Basic is from Sympy, which does not have type annotations, so ExprType cannot pass mypy
ExprType = (FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...]) # type: ignore # fmt: skip

U = TypeVar("U")
T = TypeVar("T")
Expand Down
24 changes: 14 additions & 10 deletions src/dewret/renderers/cwl.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,15 @@ def from_step(cls, step: BaseStep) -> "StepDefinition":
Args:
step: step to convert.
"""
out: list[str] | dict[str, "CommandInputSchema"]
if attrs_has(step.return_type) or (is_dataclass(step.return_type) and isclass(step.return_type)):
out = to_output_schema("out", step.return_type)["fields"]
else:
out = ["out"]
return cls(
name=step.name,
run=step.task.name,
out=(to_output_schema("out", step.return_type)["fields"])
if attrs_has(step.return_type) or is_dataclass(step.return_type)
else ["out"],
out=out,
in_={
key: (
ReferenceDefinition.from_reference(param)
Expand Down Expand Up @@ -463,13 +466,14 @@ def to_output_schema(
for field in attrs_fields(typ)
}
elif is_dataclass(typ):
fields = {
str(field.name): cast(
CommandInputSchema, to_output_schema(field.name, field.type)
)
for field in dataclass_fields(typ)
}

fields = {}
for field in dataclass_fields(typ):
if isinstance(field.type, type) and issubclass(field.type, RawType | AttrsInstance | DataclassProtocol):
fields[str(field.name)] = cast(
CommandInputSchema, to_output_schema(field.name, field.type)
)
else:
raise TypeError("Types of fields in results must also be valid result-types themselves (string-defined types not currently allowed)")
if fields:
output = CommandOutputSchema(
type="record",
Expand Down
6 changes: 5 additions & 1 deletion src/dewret/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,11 @@ def find_field(self: FieldableProtocol, field: str | int, fallback_type: type |
type_hints = get_type_hints(parent_type, localns={parent_type.__name__: parent_type}, include_extras=True)
field_type = type_hints.get(field)
if field_type is None:
field_type = next(iter(filter(lambda fld: fld.name == field, dataclass_fields(parent_type)))).type
dataclass_field_type = next(iter(filter(lambda fld: fld.name == field, dataclass_fields(parent_type)))).type
if isinstance(dataclass_field_type, str):
# TODO: we could ask Python to resolve the str expression for us
raise TypeError("Dataclass fields must be provided as types directly, not str")
field_type = dataclass_field_type
except StopIteration:
raise AttributeError(f"Dataclass {parent_type} does not have field {field}") from None
elif attr_has(parent_type):
Expand Down

0 comments on commit a56d3d9

Please sign in to comment.