From 2e7e6e7eb8280f031bed8948f3b7da47bf769a43 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 20 Jul 2024 19:49:05 +0100 Subject: [PATCH 001/108] wip(renderers): can pick renderers dynamically --- src/dewret/__main__.py | 23 ++++-- src/dewret/render.py | 49 +++++++++++++ src/dewret/renderers/snakemake.py | 2 +- tests/_lib/frender.py | 117 ++++++++++++++++++++++++++++++ tests/test_render_module.py | 36 +++++++++ 5 files changed, 219 insertions(+), 8 deletions(-) create mode 100644 src/dewret/render.py create mode 100644 tests/_lib/frender.py create mode 100644 tests/test_render_module.py diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 059ec406..09e5e29a 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -21,12 +21,13 @@ import importlib from pathlib import Path +import re import yaml import sys import click import json -from .renderers.cwl import render as cwl_render +from .render import get_render_method, RawRenderModule, StructuredRenderModule from .tasks import Backend, construct @@ -45,11 +46,15 @@ default=Backend.DASK.name, help="Backend to use for workflow evaluation.", ) +@click.option( + "--renderer", + default="@cwl" +) @click.argument("workflow_py") @click.argument("task") @click.argument("arguments", nargs=-1) def render( - workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend + workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend, renderer: str ) -> None: """Render a workflow. @@ -70,18 +75,22 @@ def render( key, val = arg.split(":", 1) kwargs[key] = json.loads(val) + render_module: Path | RawRenderModule | StructuredRenderModule + if (mtch := re.match(r"@([a-z_0-9-.]+)", renderer)): + render_module = importlib.import_module(f"dewret.renderers.{mtch.group(1)}") + else: + render_module = Path(renderer) + + render = get_render_method(render_module, pretty=pretty) try: - cwl = cwl_render(construct(task_fn(**kwargs), simplify_ids=True)) + rendered = render(construct(task_fn(**kwargs), simplify_ids=True)) except Exception as exc: import traceback print(exc, exc.__cause__, exc.__context__) traceback.print_exc() else: - if pretty: - yaml.dump(cwl, sys.stdout, indent=2) - else: - print(cwl) + print(rendered) render() diff --git a/src/dewret/render.py b/src/dewret/render.py new file mode 100644 index 00000000..6a07dd7c --- /dev/null +++ b/src/dewret/render.py @@ -0,0 +1,49 @@ +import sys +import importlib +from pathlib import Path +from typing import Protocol, TypeVar, Any, Unpack, TypedDict +import yaml + +from .workflow import Workflow +from .utils import RawType +from .workflow import Workflow + +RenderConfiguration = TypeVar("RenderConfiguration", bound=dict[str, Any]) + +class RawRenderModule(Protocol): + def render_raw(self, workflow: Workflow, **kwargs: RenderConfiguration) -> str | tuple[str, dict[str, str]]: + ... + +class StructuredRenderModule(Protocol): + def render(self, workflow: Workflow, **kwargs: RenderConfiguration) -> RawType | tuple[str, dict[str, RawType]]: + ... 
+ +def structured_to_raw(rendered: RawType, pretty: bool=False) -> str: + if pretty: + output = yaml.dumps(rendered, indent=2) + else: + output = str(rendered) + return output + +def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule, pretty: bool=False): + render_module: RawRenderModule | StructuredRenderModule + if isinstance(renderer, Path): + if (render_dir := str(renderer.parent)) not in sys.path: + sys.path.append(render_dir) + loader = importlib.machinery.SourceFileLoader("renderer", str(renderer)) + render_module = loader.load_module() + else: + render_module = renderer + if hasattr(render_module, "render_raw"): + return render_module.render_raw + + def _render(workflow: Workflow, pretty=False, **kwargs: RenderConfiguration) -> str | tuple[str, dict[str, str]]: + rendered = render_module.render(workflow, **kwargs) + if isinstance(rendered, tuple) and len(rendered) == 2: + return structured_to_raw({ + "__root__": rendered[0], + **rendered[1] + }, pretty=pretty) + return structured_to_raw(rendered, pretty=pretty) + + return _render diff --git a/src/dewret/renderers/snakemake.py b/src/dewret/renderers/snakemake.py index c2f809dd..f60be740 100644 --- a/src/dewret/renderers/snakemake.py +++ b/src/dewret/renderers/snakemake.py @@ -240,7 +240,7 @@ def from_step(cls, step: BaseStep) -> "OutputDefinition": """ # TODO: Error handling # TODO: Better way to handling input/output files - output_file = step.arguments["output_file"] + output_file = step.arguments.get("output_file", Raw("OUTPUT_FILE")) if isinstance(output_file, Raw): args = to_snakemake_type(output_file) diff --git a/tests/_lib/frender.py b/tests/_lib/frender.py new file mode 100644 index 00000000..ac6d09a3 --- /dev/null +++ b/tests/_lib/frender.py @@ -0,0 +1,117 @@ +"""Testing example renderer. + +'Friendly render', outputting human-readable descriptions. 
+""" + +from textwrap import indent +from typing import Unpack, TypedDict, Any +from dataclasses import dataclass +from contextvars import ContextVar + +from dewret.utils import RawType +from dewret.workflow import Workflow, Step, NestedStep + +from extra import JUMP + +class FrenderRendererConfiguration(TypedDict): + allow_complex_types: bool + +CONFIGURATION: ContextVar[FrenderRendererConfiguration] = ContextVar("configuration") +CONFIGURATION.set({ + "allow_complex_types": True +}) + +@dataclass +class NestedStepDefinition: + name: str + subworkflow_name: str + + @classmethod + def from_nested_step(cls, nested_step: NestedStep): + return cls( + name=nested_step.name, + subworkflow_name=nested_step.subworkflow.name + ) + + def render(self): + return \ +f""" +A portal called {self.name} to another workflow, +whose name is {self.subworkflow_name} +""" + +@dataclass +class StepDefinition: + name: str + + @classmethod + def from_step(cls, step: Step): + return cls( + name=step.name + ) + + def render(self): + return \ +f""" +Something called {self.name} +""" + + +@dataclass +class WorkflowDefinition: + name: str + steps: list[StepDefinition | NestedStepDefinition] + + @classmethod + def from_workflow(cls, workflow: Workflow): + steps = [] + for step in workflow.steps: + if isinstance(step, Step): + steps.append(StepDefinition.from_step(step)) + elif isinstance(step, NestedStep): + steps.append(NestedStepDefinition.from_nested_step(step)) + else: + raise RuntimeError(f"Unrecognised step type: {type(step)}") + + try: + name = workflow.name + except NameError: + name = "Work Doe" + return cls(name=name, steps=steps) + + def render(self): + return \ +f""" +I found a workflow called {self.name}. +It has {len(self.steps)} steps! +They are: +{"\n".join("* " + indent(step.render(), " ")[3:] for step in self.steps)} +It probably got made with JUMP={JUMP} +""" + +def render_raw( + workflow: Workflow, **kwargs: Unpack[FrenderRendererConfiguration] +) -> str | tuple[str, dict[str, str]]: + """Render to a dict-like structure. + + Args: + workflow: workflow to evaluate result. + **kwargs: additional configuration arguments - these should match CWLRendererConfiguration. + + Returns: + Reduced form as a native Python dict structure for + serialization. + """ + CONFIGURATION.get().update(kwargs) + primary_workflow = WorkflowDefinition.from_workflow(workflow).render() + subworkflows = {} + for step in workflow.steps: + if isinstance(step, NestedStep): + subworkflows[step.name] = WorkflowDefinition.from_workflow( + step.subworkflow + ).render() + + if subworkflows: + return primary_workflow, subworkflows + + return primary_workflow diff --git a/tests/test_render_module.py b/tests/test_render_module.py new file mode 100644 index 00000000..8cac3d13 --- /dev/null +++ b/tests/test_render_module.py @@ -0,0 +1,36 @@ +from pathlib import Path +from dewret.tasks import construct, task, factory +from dewret.render import get_render_method + +from ._lib.extra import increment, double, mod10, sum, triple_and_one + +def test_can_load_render_module(): + result = triple_and_one(num=increment(num=3)) + workflow = construct(result, simplify_ids=True) + workflow._name = "Fred" + + frender_py = Path(__file__).parent / "_lib/frender.py" + render = get_render_method(frender_py) + + assert render(workflow) == (""" +I found a workflow called Fred. +It has 2 steps! 
+They are: +* Something called increment-1 + +* A portal called triple_and_one-1 to another workflow, + whose name is triple_and_one + +It probably got made with JUMP=1.0 +""", {"triple_and_one-1": """ +I found a workflow called triple_and_one. +It has 3 steps! +They are: +* Something called double-1-1 + +* Something called sum-1-1 + +* Something called sum-1-2 + +It probably got made with JUMP=1.0 +"""}) From aad87a59afd5ebe7ef7f6d0899507128d547b8ce Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 20 Jul 2024 20:43:12 +0100 Subject: [PATCH 002/108] fix(renderers): make the renderer semantics less janky --- src/dewret/__main__.py | 5 +++- src/dewret/render.py | 49 +++++++++++++++++++++++-------- src/dewret/renderers/cwl.py | 44 ++++++++++----------------- src/dewret/renderers/snakemake.py | 12 +++++--- tests/_lib/frender.py | 19 ++++-------- tests/test_cwl.py | 22 ++++++++------ tests/test_modularity.py | 2 +- tests/test_multiresult_steps.py | 10 +++---- tests/test_parameters.py | 5 ++-- tests/test_render_module.py | 6 ++-- tests/test_subworkflows.py | 15 ++++++---- 11 files changed, 103 insertions(+), 86 deletions(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 09e5e29a..dd4cad06 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -90,7 +90,10 @@ def render( print(exc, exc.__cause__, exc.__context__) traceback.print_exc() else: - print(rendered) + if len(rendered) == 1: + print(rendered["__root__"]) + else: + print(rendered) render() diff --git a/src/dewret/render.py b/src/dewret/render.py index 6a07dd7c..0a17aa00 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -1,21 +1,22 @@ import sys import importlib from pathlib import Path -from typing import Protocol, TypeVar, Any, Unpack, TypedDict +from functools import partial +from typing import Protocol, TypeVar, Any, Unpack, TypedDict, Callable import yaml -from .workflow import Workflow +from .workflow import Workflow, NestedStep from .utils import RawType from .workflow import Workflow RenderConfiguration = TypeVar("RenderConfiguration", bound=dict[str, Any]) class RawRenderModule(Protocol): - def render_raw(self, workflow: Workflow, **kwargs: RenderConfiguration) -> str | tuple[str, dict[str, str]]: + def render_raw(self, workflow: Workflow, **kwargs: RenderConfiguration) -> dict[str, str]: ... class StructuredRenderModule(Protocol): - def render(self, workflow: Workflow, **kwargs: RenderConfiguration) -> RawType | tuple[str, dict[str, RawType]]: + def render(self, workflow: Workflow, **kwargs: RenderConfiguration) -> dict[str, dict[str, RawType]]: ... 
def structured_to_raw(rendered: RawType, pretty: bool=False) -> str: @@ -37,13 +38,35 @@ def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule, if hasattr(render_module, "render_raw"): return render_module.render_raw - def _render(workflow: Workflow, pretty=False, **kwargs: RenderConfiguration) -> str | tuple[str, dict[str, str]]: + def _render(workflow: Workflow, render_module: StructuredRenderModule, pretty=False, **kwargs: RenderConfiguration) -> dict[str, str]: rendered = render_module.render(workflow, **kwargs) - if isinstance(rendered, tuple) and len(rendered) == 2: - return structured_to_raw({ - "__root__": rendered[0], - **rendered[1] - }, pretty=pretty) - return structured_to_raw(rendered, pretty=pretty) - - return _render + return { + key: structured_to_raw(value, pretty=pretty) + for key, value in rendered.items() + } + + return partial(_render, render_module=render_module) + +T = TypeVar("T") +def base_render( + workflow: Workflow, build_cb: Callable[[Workflow], T] +) -> dict[str, T]: + """Render to a dict-like structure. + + Args: + workflow: workflow to evaluate result. + **kwargs: additional configuration arguments - these should match CWLRendererConfiguration. + + Returns: + Reduced form as a native Python dict structure for + serialization. + """ + primary_workflow = build_cb(workflow) + subworkflows = {} + for step in workflow.steps: + if isinstance(step, NestedStep): + nested_subworkflows = base_render(step.subworkflow, build_cb) + subworkflows.update(nested_subworkflows) + subworkflows[step.name] = nested_subworkflows["__root__"] + subworkflows["__root__"] = primary_workflow + return subworkflows diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index b51c3e58..9832272a 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -37,6 +37,7 @@ Unset, ) from dewret.utils import RawType, flatten, DataclassProtocol +from dewret.render import base_render InputSchemaType = Union[ str, "CommandInputSchema", list[str], list["InputSchemaType"], dict[str, str] @@ -57,21 +58,12 @@ class CWLRendererConfiguration(TypedDict): CONFIGURATION: ContextVar[CWLRendererConfiguration] = ContextVar("cwl-configuration") - - -def set_configuration(configuration: CWLRendererConfiguration) -> None: - """Set configuration for this rendering. - - Args: - configuration: overridden settings as dict. - """ - CONFIGURATION.set( - CWLRendererConfiguration( - allow_complex_types=False, - factories_as_params=False, - ) - ) - CONFIGURATION.get().update(configuration) +DEFAULT_CONFIGURATION: CWLRendererConfiguration = { + "allow_complex_types": False, + "factories_as_params": False, +} +CONFIGURATION.set({}) +CONFIGURATION.get().update(DEFAULT_CONFIGURATION) def configuration(key: str) -> Any: @@ -544,7 +536,7 @@ def render(self) -> dict[str, RawType]: def render( workflow: Workflow, **kwargs: Unpack[CWLRendererConfiguration] -) -> dict[str, RawType] | tuple[dict[str, RawType], dict[str, dict[str, RawType]]]: +) -> dict[str, dict[str, RawType]]: """Render to a dict-like structure. Args: @@ -555,16 +547,10 @@ def render( Reduced form as a native Python dict structure for serialization. 
""" - set_configuration(kwargs) - primary_workflow = WorkflowDefinition.from_workflow(workflow).render() - subworkflows = {} - for step in workflow.steps: - if isinstance(step, NestedStep): - subworkflows[step.name] = WorkflowDefinition.from_workflow( - step.subworkflow - ).render() - - if subworkflows: - return primary_workflow, subworkflows - - return primary_workflow + CONFIGURATION.get().update(kwargs) + rendered = base_render( + workflow, + lambda workflow: WorkflowDefinition.from_workflow(workflow).render() + ) + CONFIGURATION.get().update(DEFAULT_CONFIGURATION) + return rendered diff --git a/src/dewret/renderers/snakemake.py b/src/dewret/renderers/snakemake.py index f60be740..f721ceb7 100644 --- a/src/dewret/renderers/snakemake.py +++ b/src/dewret/renderers/snakemake.py @@ -35,6 +35,7 @@ Lazy, BaseStep, ) +from dewret.render import base_render MainTypes = typing.Union[ BasicType, list[str], list["MainTypes"], dict[str, "MainTypes"] @@ -448,7 +449,7 @@ def raw_render(workflow: Workflow) -> dict[str, MainTypes]: return WorkflowDefinition.from_workflow(workflow).render() -def render(workflow: Workflow) -> str: +def render(workflow: Workflow) -> dict[str, typing.Any]: """Render the workflow as a Snakemake (SMK) string. This function converts a Workflow object into a Snakemake-compatible yaml. @@ -468,6 +469,9 @@ def render(workflow: Workflow) -> str: } ) - return yaml.dump( - WorkflowDefinition.from_workflow(workflow).render(), indent=4 - ).translate(trans_table) + return base_render( + workflow, + lambda workflow: yaml.dump( + WorkflowDefinition.from_workflow(workflow).render(), indent=4 + ).translate(trans_table) + ) diff --git a/tests/_lib/frender.py b/tests/_lib/frender.py index ac6d09a3..33430690 100644 --- a/tests/_lib/frender.py +++ b/tests/_lib/frender.py @@ -10,6 +10,7 @@ from dewret.utils import RawType from dewret.workflow import Workflow, Step, NestedStep +from dewret.render import base_render from extra import JUMP @@ -91,7 +92,7 @@ def render(self): def render_raw( workflow: Workflow, **kwargs: Unpack[FrenderRendererConfiguration] -) -> str | tuple[str, dict[str, str]]: +) -> dict[str, str]: """Render to a dict-like structure. Args: @@ -103,15 +104,7 @@ def render_raw( serialization. 
""" CONFIGURATION.get().update(kwargs) - primary_workflow = WorkflowDefinition.from_workflow(workflow).render() - subworkflows = {} - for step in workflow.steps: - if isinstance(step, NestedStep): - subworkflows[step.name] = WorkflowDefinition.from_workflow( - step.subworkflow - ).render() - - if subworkflows: - return primary_workflow, subworkflows - - return primary_workflow + return base_render( + workflow, + lambda workflow: WorkflowDefinition.from_workflow(workflow).render() + ) diff --git a/tests/test_cwl.py b/tests/test_cwl.py index bdf5a5f9..b29b2f47 100644 --- a/tests/test_cwl.py +++ b/tests/test_cwl.py @@ -50,7 +50,7 @@ def test_basic_cwl() -> None: """ result = pi() workflow = construct(result) - rendered = render(workflow) + rendered = render(workflow)["__root__"] hsh = hasher(("pi",)) assert rendered == yaml.safe_load(f""" @@ -83,7 +83,7 @@ def get_now() -> datetime: now = factory(get_now)() result = days_in_future(now=now, num=3) workflow = construct(result, simplify_ids=True) - rendered = render(workflow, allow_complex_types=True, factories_as_params=True) + rendered = render(workflow, allow_complex_types=True, factories_as_params=True)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 @@ -112,7 +112,7 @@ def get_now() -> datetime: out: [out] """) - rendered = render(workflow, allow_complex_types=True) + rendered = render(workflow, allow_complex_types=True)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 @@ -151,7 +151,7 @@ def test_cwl_with_parameter() -> None: """ result = increment(num=3) workflow = construct(result) - rendered = render(workflow) + rendered = render(workflow)["__root__"] num_param = list(workflow.find_parameters())[0] hsh = hasher(("increment", ("num", f"int|:param:{num_param.unique_name}"))) @@ -187,7 +187,7 @@ def test_cwl_without_default() -> None: result = increment(num=my_param) workflow = construct(result) - rendered = render(workflow) + rendered = render(workflow)["__root__"] hsh = hasher(("increment", ("num", "int|:param:my_param"))) assert rendered == yaml.safe_load(f""" @@ -217,7 +217,9 @@ def test_cwl_with_subworkflow() -> None: my_param = param("num", typ=int) result = increment(num=floor(num=triple_and_one(num=increment(num=my_param)))) workflow = construct(result, simplify_ids=True) - rendered, subworkflows = render(workflow) + subworkflows = render(workflow) + rendered = subworkflows["__root__"] + del subworkflows["__root__"] assert len(subworkflows) == 1 assert isinstance(subworkflows, dict) @@ -316,7 +318,7 @@ def test_cwl_references() -> None: """ result = double(num=increment(num=3)) workflow = construct(result) - rendered = render(workflow) + rendered = render(workflow)["__root__"] num_param = list(workflow.find_parameters())[0] hsh_increment = hasher( ("increment", ("num", f"int|:param:{num_param.unique_name}")) @@ -361,7 +363,7 @@ def test_complex_cwl_references() -> None: """ result = sum(left=double(num=increment(num=23)), right=mod10(num=increment(num=23))) workflow = construct(result, simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 @@ -423,8 +425,10 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: my_param = param("num", typ=int) result = increment(num=floor(num=triple_and_one(num=sum(left=my_param, right=3)))) workflow = construct(result, simplify_ids=True) - rendered, subworkflows = render(workflow) + subworkflows = render(workflow) + rendered = subworkflows["__root__"] + del 
subworkflows["__root__"] assert len(subworkflows) == 1 assert isinstance(subworkflows, dict) name, subworkflow = list(subworkflows.items())[0] diff --git a/tests/test_modularity.py b/tests/test_modularity.py index 908a3adb..f19a067a 100644 --- a/tests/test_modularity.py +++ b/tests/test_modularity.py @@ -23,7 +23,7 @@ def test_nested_task() -> None: Produces CWL that has references between multiple steps. """ workflow = construct(algorithm(), simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 diff --git a/tests/test_multiresult_steps.py b/tests/test_multiresult_steps.py index 7457f6c6..caefa8f8 100644 --- a/tests/test_multiresult_steps.py +++ b/tests/test_multiresult_steps.py @@ -82,7 +82,7 @@ def test_nested_task() -> None: Produces CWL that has references between multiple steps. """ workflow = construct(split(), simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -117,7 +117,7 @@ def test_nested_task() -> None: def test_field_of_nested_task() -> None: """Tests whether a directly-output nested task can have fields.""" workflow = construct(split().first, simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -145,7 +145,7 @@ def test_field_of_nested_task() -> None: def test_field_of_nested_task_into_dataclasses() -> None: """Tests whether a directly-output nested task can have fields.""" workflow = construct(split_into_dataclass().first, simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -173,7 +173,7 @@ def test_field_of_nested_task_into_dataclasses() -> None: def test_complex_field_of_nested_task() -> None: """Tests whether a task can sum complex structures.""" workflow = construct(algorithm(), simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -210,7 +210,7 @@ def test_complex_field_of_nested_task_with_dataclasses() -> None: """Tests whether a task can insert result fields into other steps.""" result = algorithm_with_dataclasses() workflow = construct(result, simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 95717b87..c52b55b7 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -23,7 +23,7 @@ def test_cwl_parameters() -> None: """ result = rotate(num=3) workflow = construct(result, simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 @@ -62,7 +62,8 @@ def test_complex_parameters() -> None: num = param("numx", 23) result = sum(left=double(num=rotate(num=num)), right=rotate(num=rotate(num=23))) workflow = construct(result, simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] + assert rendered == yaml.safe_load(""" cwlVersion: 1.2 class: Workflow diff --git a/tests/test_render_module.py b/tests/test_render_module.py index 8cac3d13..1cc36225 100644 --- a/tests/test_render_module.py +++ b/tests/test_render_module.py @@ -12,7 +12,7 @@ def test_can_load_render_module(): frender_py = Path(__file__).parent / 
"_lib/frender.py" render = get_render_method(frender_py) - assert render(workflow) == (""" + assert render(workflow) == {"__root__": """ I found a workflow called Fred. It has 2 steps! They are: @@ -22,7 +22,7 @@ def test_can_load_render_module(): whose name is triple_and_one It probably got made with JUMP=1.0 -""", {"triple_and_one-1": """ +""", "triple_and_one-1": """ I found a workflow called triple_and_one. It has 3 steps! They are: @@ -33,4 +33,4 @@ def test_can_load_render_module(): * Something called sum-1-2 It probably got made with JUMP=1.0 -"""}) +"""} diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index 41b48fc2..a0219013 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -59,9 +59,10 @@ def test_subworkflows_can_use_globals() -> None: my_param = param("num", typ=int) result = increment(num=add_constant(num=increment(num=my_param))) workflow = construct(result, simplify_ids=True) - rendered, subworkflows = render(workflow) + subworkflows = render(workflow) + rendered = subworkflows["__root__"] - assert len(subworkflows) == 1 + assert len(subworkflows) == 2 assert isinstance(subworkflows, dict) assert rendered == yaml.safe_load(""" @@ -109,9 +110,10 @@ def test_subworkflows_can_use_factories() -> None: my_param = param("num", typ=int) result = pop(queue=make_queue(num=increment(num=my_param))) workflow = construct(result, simplify_ids=True) - rendered, subworkflows = render(workflow, allow_complex_types=True) + subworkflows = render(workflow, allow_complex_types=True) + rendered = subworkflows["__root__"] - assert len(subworkflows) == 1 + assert len(subworkflows) == 2 assert isinstance(subworkflows, dict) assert rendered == yaml.safe_load(""" @@ -153,9 +155,10 @@ def test_subworkflows_can_use_global_factories() -> None: my_param = param("num", typ=int) result = pop(queue=get_global_queue(num=increment(num=my_param))) workflow = construct(result, simplify_ids=True) - rendered, subworkflows = render(workflow, allow_complex_types=True) + subworkflows = render(workflow, allow_complex_types=True) + rendered = subworkflows["__root__"] - assert len(subworkflows) == 1 + assert len(subworkflows) == 2 assert isinstance(subworkflows, dict) assert rendered == yaml.safe_load(""" From c58c17e79b19a83188cb6e1c18676b30e03167cc Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 20 Jul 2024 20:55:13 +0100 Subject: [PATCH 003/108] feat(renderers): add the ability to pass renderer arguments via CLI/YAML --- src/dewret/__main__.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index dd4cad06..a959f1fc 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -23,6 +23,7 @@ from pathlib import Path import re import yaml +from typing import Any import sys import click import json @@ -48,13 +49,17 @@ ) @click.option( "--renderer", - default="@cwl" + default="cwl" +) +@click.option( + "--renderer-args", + default="simplify_ids:true" ) @click.argument("workflow_py") @click.argument("task") @click.argument("arguments", nargs=-1) def render( - workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend, renderer: str + workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend, renderer: str, renderer_args: str ) -> None: """Render a workflow. 
@@ -76,14 +81,25 @@ def render( kwargs[key] = json.loads(val) render_module: Path | RawRenderModule | StructuredRenderModule - if (mtch := re.match(r"@([a-z_0-9-.]+)", renderer)): + if (mtch := re.match(r"^([a-z_0-9-.]+)$", renderer)): render_module = importlib.import_module(f"dewret.renderers.{mtch.group(1)}") + elif renderer.startswith("@"): + render_module = Path(renderer[1:]) + else: + raise RuntimeError("Renderer argument should be a known dewret renderer, or '@FILENAME' where FILENAME is a renderer") + + renderer_kwargs: dict[str, Any] + if renderer_args.startswith("@"): + with Path(renderer_args[1:]).open() as renderer_args_f: + renderer_kwargs = yaml.load(renderer_args_f) + elif not renderer_args: + renderer_kwargs = {} else: - render_module = Path(renderer) + renderer_kwargs = dict(pair.split(":") for pair in renderer_args.split(",")) render = get_render_method(render_module, pretty=pretty) try: - rendered = render(construct(task_fn(**kwargs), simplify_ids=True)) + rendered = render(construct(task_fn(**kwargs), **renderer_kwargs)) except Exception as exc: import traceback From d792f762600e6c23c230ffe07742716d96873e61 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 20 Jul 2024 21:25:51 +0100 Subject: [PATCH 004/108] feat(renderers): add way to output subworkflows to files --- src/dewret/__main__.py | 41 +++++++++++++++++++++++++++++++++++++---- src/dewret/render.py | 4 ++-- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index a959f1fc..77949e24 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -21,6 +21,8 @@ import importlib from pathlib import Path +from contextlib import contextmanager +import sys import re import yaml from typing import Any @@ -55,11 +57,15 @@ "--renderer-args", default="simplify_ids:true" ) +@click.option( + "--output", + default="-" +) @click.argument("workflow_py") @click.argument("task") @click.argument("arguments", nargs=-1) def render( - workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend, renderer: str, renderer_args: str + workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend, renderer: str, renderer_args: str, output: str ) -> None: """Render a workflow. 
@@ -97,6 +103,21 @@ def render( else: renderer_kwargs = dict(pair.split(":") for pair in renderer_args.split(",")) + if output == "-": + @contextmanager + def _opener(key, _): + print(" ------ ", key, " ------ ") + yield sys.stdout + print() + opener = _opener + else: + @contextmanager + def _opener(key, mode): + output_file = output.replace("%", key) + with Path(output_file).open(mode) as output_f: + yield output_f + opener = _opener + render = get_render_method(render_module, pretty=pretty) try: rendered = render(construct(task_fn(**kwargs), **renderer_kwargs)) @@ -107,9 +128,21 @@ def render( traceback.print_exc() else: if len(rendered) == 1: - print(rendered["__root__"]) + with opener("", "w") as output_f: + output_f.write(rendered["__root__"]) + elif "%" in output: + for key, value in rendered.items(): + if key == "__root__": + key = "ROOT" + with opener(key, "w") as output_f: + output_f.write(value) else: - print(rendered) - + with opener("ROOT", "w") as output_f: + output_f.write(rendered["__root__"]) + del rendered["__root__"] + for key, value in rendered.items(): + with opener(key, "a") as output_f: + output_f.write("\n---\n") + output_f.write(value) render() diff --git a/src/dewret/render.py b/src/dewret/render.py index 0a17aa00..70a34917 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -21,7 +21,7 @@ def render(self, workflow: Workflow, **kwargs: RenderConfiguration) -> dict[str, def structured_to_raw(rendered: RawType, pretty: bool=False) -> str: if pretty: - output = yaml.dumps(rendered, indent=2) + output = yaml.safe_dump(rendered, indent=2) else: output = str(rendered) return output @@ -45,7 +45,7 @@ def _render(workflow: Workflow, render_module: StructuredRenderModule, pretty=Fa for key, value in rendered.items() } - return partial(_render, render_module=render_module) + return partial(_render, render_module=render_module, pretty=pretty) T = TypeVar("T") def base_render( From b2c10e286f80482a4202cb152cda13205180c741 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Wed, 24 Jul 2024 15:32:24 +0100 Subject: [PATCH 005/108] feat: support list/tuple return values --- src/dewret/backends/_base.py | 2 +- src/dewret/backends/backend_dask.py | 23 ++-- src/dewret/renderers/cwl.py | 26 +++- src/dewret/tasks.py | 26 +++- src/dewret/workflow.py | 66 +++++++--- tests/test_subworkflows.py | 189 ++++++++++++++++++++++++++++ 6 files changed, 298 insertions(+), 34 deletions(-) diff --git a/src/dewret/backends/_base.py b/src/dewret/backends/_base.py index 3ad11bec..d2d2510c 100644 --- a/src/dewret/backends/_base.py +++ b/src/dewret/backends/_base.py @@ -32,7 +32,7 @@ class BackendModule(Protocol): """ lazy: LazyFactory - def run(self, workflow: Workflow, task: Lazy) -> StepReference[Any]: + def run(self, workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy]) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]: """Execute a lazy task for this `Workflow`. Args: diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index 017fb502..84818a0a 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -87,7 +87,7 @@ def is_lazy(task: Any) -> bool: lazy = delayed -def run(workflow: Workflow | None, task: Lazy) -> StepReference[Any]: +def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy]) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]: """Execute a task as the output of a workflow. Runs a task with dask. 
@@ -96,10 +96,19 @@ def run(workflow: Workflow | None, task: Lazy) -> StepReference[Any]: workflow: `Workflow` in which to record the execution. task: `dask.delayed` function, wrapped by dewret, that we wish to compute. """ - # We need isinstance to reassure type-checker. - if not isinstance(task, Delayed) or not is_lazy(task): - raise RuntimeError( - f"{task} is not a dask delayed, perhaps you tried to mix backends?" - ) - result = task.compute(__workflow__=workflow) + + def _check_delayed(task: Lazy | list[Lazy] | tuple[Lazy]) -> Delayed: + # We need isinstance to reassure type-checker. + if isinstance(task, list) or isinstance(task, tuple): + lst: list[Delayed] | tuple[Delayed, ...] = [_check_delayed(elt) for elt in task] + if isinstance(task, tuple): + lst = tuple(lst) + return delayed(lst) + elif not isinstance(task, Delayed) or not is_lazy(task): + raise RuntimeError( + f"{task} is not a dask delayed, perhaps you tried to mix backends?" + ) + return task + computable = _check_delayed(task) + result = computable.compute(__workflow__=workflow) return result diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 9832272a..a111b3d5 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -435,11 +435,11 @@ class OutputsDefinition: outputs: sequence of results from a workflow. """ - outputs: dict[str, "CommandOutputSchema"] + outputs: dict[str, "CommandOutputSchema"] | list["CommandOutputSchema"] @classmethod def from_results( - cls, results: dict[str, StepReference[Any]] + cls, results: dict[str, StepReference[Any]] | list[StepReference[Any]] | tuple[StepReference[Any]] ) -> "OutputsDefinition": """Takes a mapping of results into a CWL structure. @@ -449,7 +449,12 @@ def from_results( CWL-like structure representing all workflow outputs. """ return cls( - outputs={ + outputs=[ + to_output_schema( + result.field, result.return_type, output_source=result.name + ) for result in results + ] + if isinstance(results, list | tuple) else { key: to_output_schema( result.field, result.return_type, output_source=result.name ) @@ -457,14 +462,19 @@ def from_results( } ) - def render(self) -> dict[str, RawType]: + def render(self) -> dict[str, RawType] | list[RawType]: """Render to a dict-like structure. Returns: Reduced form as a native Python dict structure for serialization. """ - return {key: flatten(output) for key, output in self.outputs.items()} + return [ + flatten(output) for output in self.outputs + ] if isinstance(self.outputs, list) else { + key: flatten(output) + for key, output in self.outputs.items() + } @define @@ -513,7 +523,11 @@ def from_workflow( ], inputs=InputsDefinition.from_parameters(parameters), outputs=OutputsDefinition.from_results( - {workflow.result.field: workflow.result} if workflow.result else {} + workflow.result + if isinstance(workflow.result, list | tuple) else + {workflow.result.field: workflow.result} + if workflow.result else + {} ), name=name, ) diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index a0e6c804..bdf02ed3 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -132,7 +132,7 @@ def make_lazy(self) -> LazyFactory: """ return self.backend.lazy - def evaluate(self, task: Lazy, __workflow__: Workflow, **kwargs: Any) -> Any: + def evaluate(self, task: Lazy | list[Lazy] | tuple[Lazy], __workflow__: Workflow, **kwargs: Any) -> Any: """Evaluate a single task for a known workflow. 
Args: @@ -141,12 +141,26 @@ def evaluate(self, task: Lazy, __workflow__: Workflow, **kwargs: Any) -> Any: **kwargs: any arguments to pass to the task. """ result = self.backend.run(__workflow__, task, **kwargs) - result.__workflow__.set_result(result) - if __workflow__ is not None and result.__workflow__ != __workflow__: - workflow = Workflow.assimilate(__workflow__, result.__workflow__) + to_check: list[StepReference] | tuple[StepReference] + if isinstance(result, list | tuple): + to_check = result else: - workflow = result.__workflow__ - return workflow.result + to_check = [result] + + # Build a unified workflow + collected_workflow = __workflow__ or to_check[0].__workflow__ + for step_result in to_check: + new_workflow = step_result.__workflow__ + if collected_workflow != new_workflow and collected_workflow and new_workflow: + collected_workflow = Workflow.assimilate(collected_workflow, new_workflow) + + # Make sure all the results share it + for step_result in to_check: + step_result.__workflow__ = collected_workflow + + # Then we set the result to be the whole thing + collected_workflow.set_result(result) + return collected_workflow.result def unwrap(self, task: Lazy) -> Target: """Unwraps a lazy-evaluated function to get the function. diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 6e9e28d3..bb6cb1f1 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -330,7 +330,7 @@ class Workflow: steps: list["BaseStep"] tasks: MutableMapping[str, "Task"] - result: StepReference[Any] | None + result: StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]] | None _remapping: dict[str, str] | None _name: str | None @@ -467,15 +467,32 @@ def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": for step in new.steps: step.set_workflow(new, with_arguments=True) - # TODO: should we combine as a result array? - result = left.result or right.result + if left.result == right.result: + result = left.result + elif not left.result: + result = right.result + elif not right.result: + result = left.result + else: + if not isinstance(left.result, tuple | list): + left.result = [left.result] + if not isinstance(right.result, tuple | list): + right.result = [right.result] + result = list(left.result) + list(right.result) if result: - new.set_result( - StepReference( - new, result.step, typ=result.return_type, field=result.field + if isinstance(result, list | tuple): + new.set_result([ + StepReference( + new, entry.step, typ=entry.return_type, field=entry.field + ) for entry in result + ]) + else: + new.set_result( + StepReference( + new, result.step, typ=result.return_type, field=result.field + ) ) - ) return new @@ -592,19 +609,26 @@ def add_step( @staticmethod def from_result( - result: StepReference[Any], simplify_ids: bool = False, nested: bool = True + result: StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]], simplify_ids: bool = False, nested: bool = True ) -> Workflow: """Create from a desired result. Starts from a result, and builds a workflow to output it. """ - workflow = result.__workflow__ + if isinstance(result, list | tuple): + workflow = result[0].__workflow__ + # Ensure that we have exactly one workflow, even if multiple results. 
+ for entry in result[1:]: + if entry.__workflow__ != workflow: + raise RuntimeError("If multiple results, they must share a single workflow") + else: + workflow = result.__workflow__ workflow.set_result(result) if simplify_ids: workflow.simplify_ids() return workflow - def set_result(self, result: StepReference[Any]) -> None: + def set_result(self, result: StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]) -> None: """Choose the result step. Sets a step as being the result for the entire workflow. @@ -616,10 +640,24 @@ def set_result(self, result: StepReference[Any]) -> None: Args: result: reference to the chosen step. """ - if result.step.__workflow__ != self: - raise RuntimeError("Output must be from a step in this workflow.") + if isinstance(result, list | tuple): + to_check = result + else: + to_check = [result] + for entry in to_check: + if entry.step.__workflow__ != self: + raise RuntimeError("Output must be from a step in this workflow.") self.result = result + @property + def result_type(self): + if self.result is None: + return type(None) + if isinstance(self.result, tuple | list): + # TODO: get individual types! + return type(self.result) + return self.result.return_type + class WorkflowComponent: """Base class for anything directly tied to an individual `Workflow`. @@ -774,7 +812,7 @@ def return_type(self) -> Any: """ if isinstance(self.task, Workflow): if self.task.result: - return self.task.result.return_type + return self.task.result_type else: raise AttributeError( "Cannot determine return type of a workflow with an unspecified result" @@ -866,7 +904,7 @@ def return_type(self) -> Any: """ if not self.__subworkflow__.result: raise RuntimeError("Can only use a subworkflow if the reference exists.") - return self.__subworkflow__.result.return_type + return self.__subworkflow__.result_type class Step(BaseStep): diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index a0219013..6448d6bb 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -47,6 +47,14 @@ def get_global_queue(num: int | float) -> "Queue[int]": """Add a number to a global queue.""" return add_and_queue(num=to_int(num=num), queue=GLOBAL_QUEUE) +@subworkflow() +def get_global_queues(num: int | float) -> list["Queue[int] | int"]: + """Add a number to a global queue.""" + return [ + add_and_queue(num=to_int(num=num), queue=GLOBAL_QUEUE), + add_constant(num=num) + ] + @subworkflow() def add_constant(num: int | float) -> int: @@ -193,3 +201,184 @@ def test_subworkflows_can_use_global_factories() -> None: out: [out] run: pop """) + + +def test_subworkflows_can_return_lists() -> None: + """Check whether we can produce a subworkflow that returns a list.""" + my_param = param("num", typ=int) + result = get_global_queues(num=increment(num=my_param)) + workflow = construct(result, simplify_ids=True) + subworkflows = render(workflow, allow_complex_types=True) + rendered = subworkflows["__root__"] + del subworkflows["__root__"] + + assert len(subworkflows) == 2 + assert isinstance(subworkflows, dict) + osubworkflows = sorted(list(subworkflows.items())) + + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + num: + label: num + type: int + outputs: + out: + label: out + outputSource: get_global_queues-1/out + type: array + steps: + increment-1: + in: + num: + source: num + out: [out] + run: increment + get_global_queues-1: + in: + num: + source: increment-1/out + out: [out] + run: get_global_queues + """) + + assert 
osubworkflows[0] == ("add_constant-1-1", yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + num: + label: num + type: int + sum-1-1-1-right: + default: 3 + label: sum-1-1-1-right + type: int + outputs: + out: + label: out + outputSource: to_int-1-1-1/out + type: int + steps: + sum-1-1-1: + in: + left: + source: num + right: + source: sum-1-1-1-right + out: + - out + run: sum + to_int-1-1-1: + in: + num: + source: sum-1-1-1/out + out: + - out + run: to_int + """)) + + assert osubworkflows[1] == ("get_global_queues-1", yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + CONSTANT: + default: 3 + label: CONSTANT + type: int + num: + label: num + type: int + outputs: + - label: out + outputSource: add_and_queue-1-1/out + type: Queue[int] + - label: out + outputSource: add_constant-1-1/out + type: int + steps: + Queue-1-1: + in: {} + out: + - out + run: Queue + add_and_queue-1-1: + in: + num: + source: to_int-1-1/out + queue: + source: Queue-1-1/out + out: + - out + run: add_and_queue + add_constant-1-1: + in: + CONSTANT: + source: CONSTANT + num: + source: num + out: + - out + run: add_constant + to_int-1-1: + in: + num: + source: num + out: + - out + run: to_int + """)) + +def test_can_merge_workflows() -> None: + """Check whether we can merge workflows.""" + my_param = param("num", typ=int) + value = to_int(num=increment(num=my_param)) + result = sum(left=value, right=increment(num=value)) + workflow = construct(result, simplify_ids=True) + subworkflows = render(workflow, allow_complex_types=True) + rendered = subworkflows["__root__"] + del subworkflows["__root__"] + + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + num: + label: num + type: int + outputs: + out: + label: out + outputSource: sum-1/out + type: [ + int, + double + ] + steps: + increment-1: + in: + num: + source: num + out: [out] + run: increment + increment-2: + in: + num: + source: to_int-1/out + out: [out] + run: increment + sum-1: + in: + left: + source: to_int-1/out + right: + source: increment-2/out + out: [out] + run: sum + to_int-1: + in: + num: + source: increment-1/out + out: [out] + run: to_int + """) From f9e6c9bae7c9ff3ab5af5d142547053e5b91a6d4 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 3 Aug 2024 20:24:49 +0100 Subject: [PATCH 006/108] feat(annotations): add ability to annotate variables --- src/dewret/annotations.py | 52 ++++++++++++++ src/dewret/renderers/cwl.py | 2 +- src/dewret/tasks.py | 25 +++++-- src/dewret/workflow.py | 33 ++++++++- tests/test_annotations.py | 131 ++++++++++++++++++++++++++++++++++++ 5 files changed, 234 insertions(+), 9 deletions(-) create mode 100644 src/dewret/annotations.py create mode 100644 tests/test_annotations.py diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py new file mode 100644 index 00000000..6e6643bf --- /dev/null +++ b/src/dewret/annotations.py @@ -0,0 +1,52 @@ +import inspect +from functools import lru_cache +from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Callable, get_origin, get_args + +T = TypeVar("T") +AtConstruct = Annotated[T, "AtConstruct"] + +class FunctionAnalyser: + _fn: Callable[..., Any] + _annotations: dict[str, Any] + + def __init__(self, fn: Callable[..., Any]): + self.fn = ( + fn.__init__ + if inspect.isclass(fn) else + fn.__func__ + if inspect.ismethod(fn) else + fn + ) + + @property + @lru_cache + def all_annotations(self): + try: + self._annotations = self.fn.__globals__["__annotations__"] + except KeyError: + 
self._annotations = {} + + self._annotations.update(self.fn.__annotations__) + + return self._annotations + + @staticmethod + def _typ_has(typ: type, annotation: type) -> bool: + if not hasattr(annotation, "__metadata__"): + return False + if (origin := get_origin(typ)): + if origin is Annotated and hasattr(typ, "__metadata__") and typ.__metadata__ == annotation.__metadata__: + return True + if any(FunctionAnalyser._typ_has(arg, annotation) for arg in get_args(typ)): + return True + return False + + def argument_has(self, arg: str, annotation: type) -> bool: + if arg in self.all_annotations: + typ = self.all_annotations[arg] + if self._typ_has(typ, annotation): + return True + return False + + def is_at_construct_arg(self, arg: str) -> bool: + return self.argument_has(arg, AtConstruct) diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index a111b3d5..609ffa34 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -526,7 +526,7 @@ def from_workflow( workflow.result if isinstance(workflow.result, list | tuple) else {workflow.result.field: workflow.result} - if workflow.result else + if workflow.has_result else {} ), name=name, diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index bdf02ed3..6501124f 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -45,6 +45,7 @@ from .utils import is_raw, make_traceback from .workflow import ( + Reference, StepReference, ParameterReference, Workflow, @@ -59,6 +60,7 @@ is_task, ) from .backends._base import BackendModule +from .annotations import FunctionAnalyser Param = ParamSpec("Param") RetType = TypeVar("RetType") @@ -446,9 +448,14 @@ def add_numbers(left: int, right: int): workflow = merge_workflows(*workflows) else: workflow = Workflow() + + analyser = FunctionAnalyser(fn) + if not is_in_nested_task(): for var, value in kwargs.items(): - if is_raw(value): + if analyser.is_at_construct_arg(var): + kwargs[var] = value + elif is_raw(value): # We leave this reference dangling for a consumer to pick up ("tethered"), unless # we are in a nested task, that does not have any existence of its own. kwargs[var] = ParameterReference( @@ -478,7 +485,9 @@ def add_numbers(left: int, right: int): # raise TypeError( # "Captured parameter {var} (global variable in task) shadows an argument" # ) - if isinstance(value, Parameter): + if analyser.is_at_construct_arg(var): + kwargs[var] = value + elif isinstance(value, Parameter): kwargs[var] = ParameterReference(workflow, value) elif is_raw(value): kwargs[var] = ParameterReference( @@ -525,10 +534,14 @@ def {fn.__name__}(...) -> ...: var: ParameterReference( nested_workflow, param( - var, typ=value.__type__, tethered=nested_workflow + var, + typ=( + value.__type__ + ), + tethered=nested_workflow ), - ) - for var, value in original_kwargs.items() + ) if isinstance(var, Reference) else value + for var, value in kwargs.items() } with in_nested_task(): output = fn(**nested_kwargs) @@ -536,7 +549,7 @@ def {fn.__name__}(...) -> ...: step_reference = workflow.add_nested_step( fn.__name__, nested_workflow, kwargs ) - if isinstance(step_reference, StepReference): + if isinstance(step_reference, StepReference): # RMV: What if it's a list? return cast(RetType, step_reference) raise TypeError( f"Nested tasks must return a step reference, not {type(step_reference)} to ensure graph makes sense." 
diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index bb6cb1f1..3c90b2e0 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -24,7 +24,7 @@ from attrs import define, has as attr_has, resolve_types, fields as attrs_fields from dataclasses import is_dataclass, fields as dataclass_fields from collections import Counter -from typing import Protocol, Any, TypeVar, Generic, cast, Literal +from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated from uuid import uuid4 import logging @@ -37,6 +37,10 @@ RetType = TypeVar("RetType") +class UnevaluatableError(Exception): + ... + + @define class Raw: """Value object for any raw types. @@ -368,6 +372,10 @@ def __eq__(self, other: object) -> bool: and self._name == other._name ) + @property + def has_result(self) -> bool: + return not(self.result is None or self.result is []) + @property def name(self) -> str: """Get the name of the workflow. @@ -698,6 +706,27 @@ def __workflow__(self) -> Workflow: class Reference: """Superclass for all symbolic references to values.""" + def _raise_unevaluatable_error(self): + raise UnevaluatableError(f"This reference, {self.name}, cannot be evaluated during construction.") + + def __eq__(self, other) -> bool: + if not isinstance(other, Reference): + print(self, other) + self._raise_unevaluatable_error() + return super().__eq__(other) + + def __float__(self) -> bool: + self._raise_unevaluatable_error() + return False + + def __int__(self) -> bool: + self._raise_unevaluatable_error() + return False + + def __bool__(self) -> bool: + self._raise_unevaluatable_error() + return False + @property def name(self) -> str: """Referral name for this reference.""" @@ -902,7 +931,7 @@ def return_type(self) -> Any: Returns: Expected type of the return value. 
""" - if not self.__subworkflow__.result: + if self.__subworkflow__.result is None or self.__subworkflow__.result is []: raise RuntimeError("Can only use a subworkflow if the reference exists.") return self.__subworkflow__.result_type diff --git a/tests/test_annotations.py b/tests/test_annotations.py new file mode 100644 index 00000000..cdf2a112 --- /dev/null +++ b/tests/test_annotations.py @@ -0,0 +1,131 @@ +import pytest +import yaml + +from dewret.tasks import task, construct, subworkflow, TaskException +from dewret.renderers.cwl import render +from dewret.annotations import AtConstruct, FunctionAnalyser + +from ._lib.extra import increment, sum + +ARG1: AtConstruct[bool] = True +ARG2: bool = False + +class MyClass: + def method(self, arg1: bool, arg2: AtConstruct[int]) -> float: + arg3: float = 7.0 + arg4: AtConstruct[float] = 8.0 + return arg1 + arg2 + arg3 + arg4 + int(ARG1) + int(ARG2) + +def fn(arg5: int, arg6: AtConstruct[int]) -> float: + arg7: float = 7.0 + arg8: AtConstruct[float] = 8.0 + return arg5 + arg6 + arg7 + arg8 + int(ARG1) + int(ARG2) + + +@subworkflow() +def to_int_bad(num: int, should_double: bool) -> int | float: + """Cast to an int.""" + return increment(num=num) if should_double else sum(left=num, right=num) + +@subworkflow() +def to_int(num: int, should_double: AtConstruct[bool]) -> int | float: + """Cast to an int.""" + return increment(num=num) if should_double else sum(left=num, right=num) + +def test_can_analyze_annotations(): + my_obj = MyClass() + + analyser = FunctionAnalyser(my_obj.method) + assert analyser.argument_has("arg1", AtConstruct) is False + assert analyser.argument_has("arg3", AtConstruct) is False + assert analyser.argument_has("ARG2", AtConstruct) is False + assert analyser.argument_has("arg2", AtConstruct) is True + assert analyser.argument_has("arg4", AtConstruct) is False # Not a global/argument + assert analyser.argument_has("ARG1", AtConstruct) is True + + analyser = FunctionAnalyser(fn) + assert analyser.argument_has("arg5", AtConstruct) is False + assert analyser.argument_has("arg7", AtConstruct) is False + assert analyser.argument_has("ARG2", AtConstruct) is False + assert analyser.argument_has("arg2", AtConstruct) is True + assert analyser.argument_has("arg8", AtConstruct) is False # Not a global/argument + assert analyser.argument_has("ARG1", AtConstruct) is True + +def test_at_construct() -> None: + result = to_int_bad(num=increment(num=3), should_double=True) + with pytest.raises(TaskException) as _: + workflow = construct(result, simplify_ids=True) + + result = to_int(num=increment(num=3), should_double=True) + workflow = construct(result, simplify_ids=True) + subworkflows = render(workflow, allow_complex_types=True) + rendered = subworkflows["__root__"] + assert rendered == yaml.safe_load(""" + cwlVersion: 1.2 + class: Workflow + inputs: + increment-1-num: + default: 3 + label: increment-1-num + type: int + outputs: + out: + label: out + outputSource: to_int-1/out + type: int + steps: + increment-1: + in: + num: + source: increment-1-num + out: + - out + run: increment + to_int-1: + in: + num: + source: increment-1/out + should_double: + default: True + out: + - out + run: to_int + """) + + result = to_int(num=increment(num=3), should_double=False) + workflow = construct(result, simplify_ids=True) + subworkflows = render(workflow, allow_complex_types=True) + rendered = subworkflows["__root__"] + assert rendered == yaml.safe_load(""" + cwlVersion: 1.2 + class: Workflow + inputs: + increment-1-num: + default: 3 + label: 
increment-1-num + type: int + outputs: + out: + label: out + outputSource: to_int-1/out + type: + - int + - double + steps: + increment-1: + in: + num: + source: increment-1-num + out: + - out + run: increment + to_int-1: + in: + num: + source: increment-1/out + should_double: + default: False + out: + - out + run: to_int + """) From d5b8c65718bb56ab09906bab3a51903bc671d408 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 3 Aug 2024 20:41:03 +0100 Subject: [PATCH 007/108] fix: correct tests --- src/dewret/tasks.py | 3 ++- src/dewret/workflow.py | 9 +++++++-- tests/test_errors.py | 5 +++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 6501124f..fd94bdbb 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -540,8 +540,9 @@ def {fn.__name__}(...) -> ...: ), tethered=nested_workflow ), - ) if isinstance(var, Reference) else value + ) if isinstance(value, Reference) else value for var, value in kwargs.items() + if var in original_kwargs } with in_nested_task(): output = fn(**nested_kwargs) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 3c90b2e0..4fc7d643 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -488,7 +488,7 @@ def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": right.result = [right.result] result = list(left.result) + list(right.result) - if result: + if result is not None and result != []: if isinstance(result, list | tuple): new.set_result([ StepReference( @@ -706,12 +706,17 @@ def __workflow__(self) -> Workflow: class Reference: """Superclass for all symbolic references to values.""" + @property + def __type__(self): + raise NotImplementedError() + def _raise_unevaluatable_error(self): raise UnevaluatableError(f"This reference, {self.name}, cannot be evaluated during construction.") def __eq__(self, other) -> bool: + if isinstance(other, list) or other is None: + return False if not isinstance(other, Reference): - print(self, other) self._raise_unevaluatable_error() return super().__eq__(other) diff --git a/tests/test_errors.py b/tests/test_errors.py index 45496c61..ae49e07e 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -3,6 +3,7 @@ import pytest from dewret.workflow import Task, Lazy from dewret.tasks import construct, task, nested_task, TaskException +from dewret.annotations import AtConstruct from ._lib.extra import increment # noqa: F401 @@ -12,7 +13,7 @@ def add_task(left: int, right: int) -> int: return left + right -ADD_TASK_LINE_NO = 9 +ADD_TASK_LINE_NO = 10 @nested_task() @@ -95,7 +96,7 @@ def unacceptable_object_usage() -> int: @nested_task() -def unacceptable_nested_return(int_not_global: bool) -> int | Lazy: +def unacceptable_nested_return(int_not_global: AtConstruct[bool]) -> int | Lazy: """Bad nested_task that fails to return a task.""" add_task(left=3, right=4) return 7 if int_not_global else ADD_TASK_LINE_NO From 470de2094c7cd91e009f909607e94c9469aebe42 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 3 Aug 2024 18:02:23 +0100 Subject: [PATCH 008/108] wip: add configuration --- src/dewret/backends/backend_dask.py | 20 ++++++++++-- src/dewret/renderers/cwl.py | 1 - src/dewret/tasks.py | 31 +++++++++++++++++- tests/test_configuration.py | 49 +++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 4 deletions(-) create mode 100644 tests/test_configuration.py diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index 84818a0a..98300a6c 100644 --- 
a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -18,8 +18,13 @@ """ from dask.delayed import delayed, DelayedLeaf -from dewret.workflow import Workflow, Lazy, StepReference, Target +from dask.config import config +import contextvars +from functools import partial from typing import Protocol, runtime_checkable, Any, cast +from concurrent.futures import Executor, ThreadPoolExecutor +from dewret.workflow import Workflow, Lazy, StepReference, Target +from dewret.tasks import CONSTRUCT_CONFIGURATION @runtime_checkable @@ -86,8 +91,18 @@ def is_lazy(task: Any) -> bool: lazy = delayed +CONTEXT = [] +def _initializer(): + for var, value in CONTEXT: + var.set(value) + CONSTRUCT_CONFIGURATION.get() + +CONNECTION_POOL = ThreadPoolExecutor(initializer=_initializer) +config["pool"] = CONNECTION_POOL + def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy]) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]: + global CONTEXT """Execute a task as the output of a workflow. Runs a task with dask. @@ -110,5 +125,6 @@ def _check_delayed(task: Lazy | list[Lazy] | tuple[Lazy]) -> Delayed: ) return task computable = _check_delayed(task) - result = computable.compute(__workflow__=workflow) + CONTEXT = contextvars.copy_context().items() + result = computable.compute(__workflow__=workflow, initializer=_initializer) return result diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 609ffa34..0dfc29d8 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -44,7 +44,6 @@ ] -@dataclass class CWLRendererConfiguration(TypedDict): """Configuration for the renderer. diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index fd94bdbb..839d064e 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -32,13 +32,14 @@ import inspect import importlib import sys +from typing import TypedDict, NotRequired, Unpack from enum import Enum from functools import cached_property from collections.abc import Callable from typing import Any, ParamSpec, TypeVar, cast, Generator from types import TracebackType from attrs import has as attrs_has -from dataclasses import is_dataclass +from dataclasses import dataclass, is_dataclass import traceback from contextvars import ContextVar from contextlib import contextmanager @@ -62,6 +63,34 @@ from .backends._base import BackendModule from .annotations import FunctionAnalyser +class ConstructConfiguration(TypedDict): + flatten_all_nested: NotRequired[bool] + allow_positional_args: NotRequired[bool] + +CONSTRUCT_CONFIGURATION: ContextVar[ConstructConfiguration] = ContextVar("construct-configuration") + +@contextmanager +def set_configuration(**kwargs: Unpack[ConstructConfiguration]): + try: + previous = CONSTRUCT_CONFIGURATION.get() + except LookupError: + previous = ConstructConfiguration( + flatten_all_nested=False, + allow_positional_args=False + ) + + try: + CONSTRUCT_CONFIGURATION.set({}) + CONSTRUCT_CONFIGURATION.get().update(previous) + CONSTRUCT_CONFIGURATION.get().update(kwargs) + + yield CONSTRUCT_CONFIGURATION + finally: + CONSTRUCT_CONFIGURATION.set(previous) + +def get_configuration(key: str): + return CONSTRUCT_CONFIGURATION.get()[key] + Param = ParamSpec("Param") RetType = TypeVar("RetType") diff --git a/tests/test_configuration.py b/tests/test_configuration.py new file mode 100644 index 00000000..b827dc3f --- /dev/null +++ b/tests/test_configuration.py @@ -0,0 +1,49 @@ +import yaml +import pytest +from dewret.tasks import construct, task, factory, 
subworkflow +from dewret.renderers.cwl import render +from dewret.utils import hasher +from dewret.tasks import set_configuration +from ._lib.extra import increment, double, mod10, sum, triple_and_one + +@pytest.fixture +def configuration(): + with set_configuration() as configuration: + yield configuration.get() + +@subworkflow() +def floor(num: int | float, expected: bool) -> int: + """Converts int/float to int.""" + from dewret.tasks import get_configuration + if get_configuration("flatten_all_nested") != expected: + raise AssertionError(f"Not expected configuration: {get_configuration('flatten_all_nested')} != {expected}") + return int(num) + +def test_cwl_with_parameter(configuration) -> None: + result = increment(num=floor(num=3.1, expected=True)) + workflow = construct(result) + rendered = render(workflow)["__root__"] + num_param = list(workflow.find_parameters())[0] + hsh = hasher(("increment", ("num", f"int|:param:{num_param.unique_name}"))) + + assert rendered == yaml.safe_load(f""" + cwlVersion: 1.2 + class: Workflow + inputs: + increment-{hsh}-num: + label: increment-{hsh}-num + type: int + default: 3 + outputs: + out: + label: out + outputSource: increment-{hsh}/out + type: int + steps: + increment-{hsh}: + run: increment + in: + num: + source: increment-{hsh}-num + out: [out] + """) From be9b86216746cdd2b5c81773b11320f37791e376 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 3 Aug 2024 23:27:01 +0100 Subject: [PATCH 009/108] fix: scope thread pools to a single render call and add context copying --- src/dewret/backends/_base.py | 4 ++- src/dewret/backends/backend_dask.py | 13 +++------- src/dewret/tasks.py | 18 +++++++++---- tests/test_configuration.py | 40 +++++++++++++++++++++-------- 4 files changed, 48 insertions(+), 27 deletions(-) diff --git a/src/dewret/backends/_base.py b/src/dewret/backends/_base.py index d2d2510c..fd1a830e 100644 --- a/src/dewret/backends/_base.py +++ b/src/dewret/backends/_base.py @@ -18,6 +18,7 @@ """ from typing import Protocol, Any +from concurrent.futures import ThreadPoolExecutor from dewret.workflow import LazyFactory, Lazy, Workflow, StepReference, Target class BackendModule(Protocol): @@ -32,12 +33,13 @@ class BackendModule(Protocol): """ lazy: LazyFactory - def run(self, workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy]) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]: + def run(self, workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread_pool: ThreadPoolExecutor | None=None) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]: """Execute a lazy task for this `Workflow`. Args: workflow: `Workflow` that is being executed. task: task that forms the output. + thread_pool: the thread pool that should be used for this execution. Returns: Reference to the final output step. 
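
Aside: the thread_pool argument added to run() exists because dask may evaluate tasks on
worker threads, where ContextVars set in the calling thread are invisible unless copied
across. A minimal sketch of the copy-and-replay recipe, independent of dewret (the names
below are illustrative only):

    import contextvars
    from concurrent.futures import ThreadPoolExecutor

    cfg: contextvars.ContextVar[dict] = contextvars.ContextVar("cfg")

    def make_pool() -> ThreadPoolExecutor:
        # Snapshot the caller's context variables once, at pool creation...
        context = contextvars.copy_context().items()

        def _initializer() -> None:
            # ...and replay them in each worker thread the pool starts.
            for var, value in context:
                var.set(value)

        return ThreadPoolExecutor(initializer=_initializer)

    cfg.set({"flatten_all_nested": True})
    pool = make_pool()
    assert pool.submit(cfg.get).result() == {"flatten_all_nested": True}

Scoping the pool to a single render call, as this commit's subject says, keeps stale
context from leaking between renders.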
diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index 98300a6c..521bd591 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -92,17 +92,10 @@ def is_lazy(task: Any) -> bool: lazy = delayed CONTEXT = [] -def _initializer(): - for var, value in CONTEXT: - var.set(value) - CONSTRUCT_CONFIGURATION.get() -CONNECTION_POOL = ThreadPoolExecutor(initializer=_initializer) -config["pool"] = CONNECTION_POOL -def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy]) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]: - global CONTEXT +def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread_pool: ThreadPoolExecutor | None=None) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]: """Execute a task as the output of a workflow. Runs a task with dask. @@ -125,6 +118,6 @@ def _check_delayed(task: Lazy | list[Lazy] | tuple[Lazy]) -> Delayed: ) return task computable = _check_delayed(task) - CONTEXT = contextvars.copy_context().items() - result = computable.compute(__workflow__=workflow, initializer=_initializer) + config["pool"] = thread_pool + result = computable.compute(__workflow__=workflow) return result diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 839d064e..4cc6f9db 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -41,7 +41,8 @@ from attrs import has as attrs_has from dataclasses import dataclass, is_dataclass import traceback -from contextvars import ContextVar +from concurrent.futures import ThreadPoolExecutor +from contextvars import ContextVar, copy_context from contextlib import contextmanager from .utils import is_raw, make_traceback @@ -78,9 +79,9 @@ def set_configuration(**kwargs: Unpack[ConstructConfiguration]): flatten_all_nested=False, allow_positional_args=False ) + CONSTRUCT_CONFIGURATION.set({}) try: - CONSTRUCT_CONFIGURATION.set({}) CONSTRUCT_CONFIGURATION.get().update(previous) CONSTRUCT_CONFIGURATION.get().update(kwargs) @@ -163,7 +164,7 @@ def make_lazy(self) -> LazyFactory: """ return self.backend.lazy - def evaluate(self, task: Lazy | list[Lazy] | tuple[Lazy], __workflow__: Workflow, **kwargs: Any) -> Any: + def evaluate(self, task: Lazy | list[Lazy] | tuple[Lazy], __workflow__: Workflow, thread_pool=None, **kwargs: Any) -> Any: """Evaluate a single task for a known workflow. Args: @@ -171,7 +172,7 @@ def evaluate(self, task: Lazy | list[Lazy] | tuple[Lazy], __workflow__: Workflow __workflow__: workflow within which this exists. **kwargs: any arguments to pass to the task. """ - result = self.backend.run(__workflow__, task, **kwargs) + result = self.backend.run(__workflow__, task, thread_pool=thread_pool, **kwargs) to_check: list[StepReference] | tuple[StepReference] if isinstance(result, list | tuple): to_check = result @@ -246,7 +247,14 @@ def __call__( A reusable reference to this individual step. 
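+
+        Note: evaluation runs on a thread pool created for this call alone;
+        the caller's context variables (such as construct configuration) are
+        copied into each worker thread before any task executes.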
""" workflow = __workflow__ or Workflow() - result = self.evaluate(task, workflow, **kwargs) + + context = copy_context().items() + def _initializer(): + for var, value in context: + var.set(value) + thread_pool = ThreadPoolExecutor(initializer=_initializer) + + result = self.evaluate(task, workflow, thread_pool=thread_pool, **kwargs) return Workflow.from_result(result, simplify_ids=simplify_ids) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index b827dc3f..12d2bd72 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -1,9 +1,10 @@ import yaml import pytest -from dewret.tasks import construct, task, factory, subworkflow +from dewret.tasks import construct, task, factory, subworkflow, TaskException from dewret.renderers.cwl import render from dewret.utils import hasher from dewret.tasks import set_configuration +from dewret.annotations import AtConstruct from ._lib.extra import increment, double, mod10, sum, triple_and_one @pytest.fixture @@ -12,38 +13,55 @@ def configuration(): yield configuration.get() @subworkflow() -def floor(num: int | float, expected: bool) -> int: +def floor(num: int, expected: AtConstruct[bool]) -> int: """Converts int/float to int.""" from dewret.tasks import get_configuration if get_configuration("flatten_all_nested") != expected: raise AssertionError(f"Not expected configuration: {get_configuration('flatten_all_nested')} != {expected}") - return int(num) + return increment(num=num) def test_cwl_with_parameter(configuration) -> None: - result = increment(num=floor(num=3.1, expected=True)) - workflow = construct(result) + result = increment(num=floor(num=3, expected=True)) + + with set_configuration(flatten_all_nested=True): + workflow = construct(result, simplify_ids=True) + + with pytest.raises(TaskException) as exc, set_configuration(flatten_all_nested=False): + workflow = construct(result, simplify_ids=True) + assert "AssertionError" in str(exc.getrepr()) + + with set_configuration(flatten_all_nested=True): + result = increment(num=floor(num=3, expected=True)) + workflow = construct(result, simplify_ids=True) rendered = render(workflow)["__root__"] num_param = list(workflow.find_parameters())[0] - hsh = hasher(("increment", ("num", f"int|:param:{num_param.unique_name}"))) assert rendered == yaml.safe_load(f""" cwlVersion: 1.2 class: Workflow inputs: - increment-{hsh}-num: - label: increment-{hsh}-num + floor-1-num: + label: floor-1-num type: int default: 3 outputs: out: label: out - outputSource: increment-{hsh}/out + outputSource: increment-1/out type: int steps: - increment-{hsh}: + floor-1: + run: floor + in: + expected: + default: true + num: + source: floor-1-num + out: [out] + increment-1: run: increment in: num: - source: increment-{hsh}-num + source: floor-1/out out: [out] """) From 7c34dbed8f40b54036f11e580c9058ce0c049ecb Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 4 Aug 2024 01:19:17 +0100 Subject: [PATCH 010/108] fix: make sure the globals are updated when a subworkflow is running --- src/dewret/annotations.py | 26 ++++++- src/dewret/backends/backend_dask.py | 4 -- src/dewret/tasks.py | 12 ++-- tests/test_subworkflows.py | 103 ++++++++++++++++++++++++++++ 4 files changed, 132 insertions(+), 13 deletions(-) diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index 6e6643bf..bbdd7bc5 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -1,6 +1,7 @@ import inspect from functools import lru_cache -from typing import Protocol, Any, TypeVar, Generic, cast, 
Literal, TypeAliasType, Annotated, Callable, get_origin, get_args +from types import FunctionType +from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Callable, get_origin, get_args, Mapping T = TypeVar("T") AtConstruct = Annotated[T, "AtConstruct"] @@ -50,3 +51,26 @@ def argument_has(self, arg: str, annotation: type) -> bool: def is_at_construct_arg(self, arg: str) -> bool: return self.argument_has(arg, AtConstruct) + + @property + @lru_cache + def globals(self) -> Mapping[str, Any]: + try: + fn_globals = inspect.getclosurevars(self.fn).globals + # This covers the case of wrapping, rather than decorating. + except TypeError: + fn_globals = {} + return fn_globals + + def with_new_globals(self, new_globals: dict[str, Any]) -> Callable[..., Any]: + code = self.fn.__code__ + fn_name = self.fn.__name__ + all_globals = dict(self.globals) + all_globals.update(new_globals) + return FunctionType( + code, + all_globals, + name=fn_name, + closure=self.fn.__closure__, + argdefs=self.fn.__defaults__, + ) diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index 521bd591..d5a82194 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -91,10 +91,6 @@ def is_lazy(task: Any) -> bool: lazy = delayed -CONTEXT = [] - - - def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread_pool: ThreadPoolExecutor | None=None) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]: """Execute a task as the output of a workflow. diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 4cc6f9db..1962edf9 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -509,11 +509,7 @@ def add_numbers(left: int, right: int): elif isinstance(value, Parameter): kwargs[var] = ParameterReference(workflow, value) original_kwargs = dict(kwargs) - try: - fn_globals = dict(inspect.getclosurevars(fn).globals) - # This covers the case of wrapping, rather than decorating. - except TypeError: - fn_globals = {} + fn_globals = analyser.globals for var, value in fn_globals.items(): # This error is redundant as it triggers a SyntaxError in Python. @@ -567,7 +563,7 @@ def {fn.__name__}(...) -> ...: step_reference = evaluate(lazy_fn, __workflow__=workflow) else: nested_workflow = Workflow(name=fn.__name__) - nested_kwargs: Param.kwargs = { + nested_globals: Param.kwargs = { var: ParameterReference( nested_workflow, param( @@ -579,10 +575,10 @@ def {fn.__name__}(...) 
-> ...: ), ) if isinstance(value, Reference) else value for var, value in kwargs.items() - if var in original_kwargs } + nested_kwargs = {key: value for key, value in nested_globals.items() if key in original_kwargs} with in_nested_task(): - output = fn(**nested_kwargs) + output = analyser.with_new_globals(nested_globals)(**nested_kwargs) nested_workflow = _manager(output, __workflow__=nested_workflow) step_reference = workflow.add_nested_step( fn.__name__, nested_workflow, kwargs diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index 6448d6bb..c57b4f5c 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -59,8 +59,15 @@ def get_global_queues(num: int | float) -> list["Queue[int] | int"]: @subworkflow() def add_constant(num: int | float) -> int: """Add a global constant to a number.""" + print(CONSTANT, type(CONSTANT)) return to_int(num=sum(left=num, right=CONSTANT)) +@subworkflow() +def add_constants(num: int | float) -> int: + """Add a global constant to a number.""" + print(CONSTANT, type(CONSTANT)) + return to_int(num=sum(left=sum(left=num, right=CONSTANT), right=CONSTANT)) + def test_subworkflows_can_use_globals() -> None: """Produce a subworkflow that uses a global.""" @@ -382,3 +389,99 @@ def test_can_merge_workflows() -> None: out: [out] run: to_int """) + + +def test_subworkflows_can_use_globals_in_right_scope() -> None: + """Produce a subworkflow that uses a global.""" + my_param = param("num", typ=int) + result = increment(num=add_constants(num=increment(num=my_param))) + workflow = construct(result, simplify_ids=True) + subworkflows = render(workflow) + rendered = subworkflows["__root__"] + del subworkflows["__root__"] + + assert len(subworkflows) == 1 + assert isinstance(subworkflows, dict) + osubworkflows = sorted(list(subworkflows.items())) + + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + CONSTANT: + label: CONSTANT + default: 3 + type: int + num: + label: num + type: int + outputs: + out: + label: out + outputSource: increment-2/out + type: int + steps: + increment-1: + in: + num: + source: num + out: [out] + run: increment + increment-2: + in: + num: + source: add_constants-1/out + out: [out] + run: increment + add_constants-1: + in: + CONSTANT: + source: CONSTANT + num: + source: increment-1/out + out: [out] + run: add_constants + """) + + assert osubworkflows[0] == ("add_constants-1", yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + num: + label: num + type: int + CONSTANT: + label: CONSTANT + type: int + outputs: + out: + label: out + outputSource: to_int-1-1/out + type: int + steps: + sum-1-1: + in: + left: + source: num + right: + source: CONSTANT + out: + - out + run: sum + sum-1-2: + in: + left: + source: sum-1-1/out + right: + source: CONSTANT + out: + - out + run: sum + to_int-1-1: + in: + num: + source: sum-1-2/out + out: + - out + run: to_int + """)) From c6dc355b564a7e4a2ad06899148997ce618200cf Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 6 Aug 2024 01:06:03 +0100 Subject: [PATCH 011/108] fix: tests --- src/dewret/renderers/cwl.py | 22 ++- src/dewret/tasks.py | 20 +- src/dewret/workflow.py | 354 +++++++++++++++++++++--------------- tests/test_cwl.py | 6 +- tests/test_fieldable.py | 64 +++++++ tests/test_subworkflows.py | 9 +- 6 files changed, 304 insertions(+), 171 deletions(-) create mode 100644 tests/test_fieldable.py diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 0dfc29d8..8cf6e15c 100644 --- a/src/dewret/renderers/cwl.py 
+++ b/src/dewret/renderers/cwl.py @@ -28,6 +28,7 @@ from dewret.workflow import ( FactoryCall, Reference, + FieldableMixin, Raw, Workflow, BaseStep, @@ -91,7 +92,7 @@ def from_reference(cls, ref: Reference) -> "ReferenceDefinition": Args: ref: reference to convert. """ - return cls(source=ref.name) + return cls(source=to_name(ref)) def render(self) -> dict[str, RawType]: """Render to a dict-like structure. @@ -392,11 +393,11 @@ def from_parameters( """ return cls( inputs={ - input.name: cls.CommandInputParameter( - label=input.name, - default=input.default, + input.__name__: cls.CommandInputParameter( + label=input.__name__, + default=input.__default__, type=raw_to_command_input_schema( - label=input.name, value=input.default + label=input.__name__, value=input.__default__ ), ) for input in parameters @@ -424,6 +425,11 @@ def render(self) -> dict[str, RawType]: return result +def to_name(result: FieldableMixin): + if not result.__field__ and isinstance(result, StepReference): + return f"{result.__name__}/out" + return result.__name__ + @define class OutputsDefinition: """CWL-renderable set of workflow outputs. @@ -450,12 +456,12 @@ def from_results( return cls( outputs=[ to_output_schema( - result.field, result.return_type, output_source=result.name + "/".join(result.__field__) or "out", result.__type__, output_source=to_name(result) ) for result in results ] if isinstance(results, list | tuple) else { key: to_output_schema( - result.field, result.return_type, output_source=result.name + "/".join(result.__field__) or "out", result.__type__, output_source=to_name(result) ) for key, result in results.items() } @@ -524,7 +530,7 @@ def from_workflow( outputs=OutputsDefinition.from_results( workflow.result if isinstance(workflow.result, list | tuple) else - {workflow.result.field: workflow.result} + {"/".join(workflow.result.__field__) or "out": workflow.result} if workflow.has_result else {} ), diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 1962edf9..be124093 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -496,8 +496,8 @@ def add_numbers(left: int, right: int): # We leave this reference dangling for a consumer to pick up ("tethered"), unless # we are in a nested task, that does not have any existence of its own. kwargs[var] = ParameterReference( - workflow, - param( + workflow=workflow, + parameter=param( var, value, tethered=( @@ -507,7 +507,7 @@ def add_numbers(left: int, right: int): ), ) elif isinstance(value, Parameter): - kwargs[var] = ParameterReference(workflow, value) + kwargs[var] = ParameterReference(workflow=workflow, parameter=value) original_kwargs = dict(kwargs) fn_globals = analyser.globals @@ -518,13 +518,13 @@ def add_numbers(left: int, right: int): # raise TypeError( # "Captured parameter {var} (global variable in task) shadows an argument" # ) - if analyser.is_at_construct_arg(var): + if analyser.is_at_construct_arg(var) or isinstance(value, Reference): kwargs[var] = value elif isinstance(value, Parameter): - kwargs[var] = ParameterReference(workflow, value) - elif is_raw(value): + kwargs[var] = ParameterReference(workflow=workflow, parameter=value) + elif is_raw(value) or ((attrs_has(value) or is_dataclass(value)) and not inspect.isclass(value)): kwargs[var] = ParameterReference( - workflow, param(var, value, tethered=False) + workflow=workflow, parameter=param(var, value, tethered=False) ) elif is_task(value) or ensure_lazy(value) is not None: if not nested and _workaround_check_value_is_task( @@ -546,8 +546,6 @@ def {fn.__name__}(...) 
-> ...:
    ...
"""
                )
-            elif attrs_has(value) or is_dataclass(value):
-                ...
             elif nested:
                 raise NotImplementedError(
                     f"Nested tasks must now only refer to global parameters, raw or tasks, not objects: {var}"
                 )
         if nested:
             if flatten_nested:

diff --git a/src/dewret/workflow.py
index 3c90b2e0..9846e462 100644
--- a/src/dewret/workflow.py
+++ b/src/dewret/workflow.py
@@ -34,6 +34,7 @@
 from .utils import hasher, RawType, is_raw, make_traceback, is_raw_type
 
 T = TypeVar("T")
+U = TypeVar("U")
 RetType = TypeVar("RetType")
 
@@ -135,6 +136,7 @@ def __init__(self, raw_type: type[T]):
 
 UNSET = Unset()
 
+
 class Parameter(Generic[T]):
     """Global parameter.
 
@@ -191,15 +193,20 @@ def __init__(
         else:
             raw_type = type(default)
         self.__type__: type[T] = raw_type
+        if self.__type__ == type:
+            raise TypeError("Parameter cannot have the metatype 'type' as its type")
         if tethered and isinstance(tethered, BaseStep):
             self.register_caller(tethered)
 
+    def __eq__(self, other):
+        return hash(self) == hash(other)
+
     def __hash__(self) -> int:
         """Get a unique hash for this parameter."""
         if self.__tethered__ is None:
             raise RuntimeError(
-                "Parameter {self.full_name} was never tethered but should have been"
+                f"Parameter {self.name} was never tethered but should have been"
             )
         return hash(self.__name__)
 
@@ -209,7 +216,7 @@ def default(self) -> T | UnsetType[T]:
         return self.__default__
 
     @property
-    def full_name(self) -> str:
+    def name(self) -> str:
         """Extended name, suitable for rendering.
 
         This attempts to create a unique name by tying the parameter to a step
@@ -232,15 +239,6 @@ def register_caller(self, caller: BaseStep) -> None:
         self.__tethered__ = caller
         self.__callers__.append(caller)
 
-    @property
-    def name(self) -> str:
-        """Name for this step.
-
-        May be remapped by the workflow to something nicer
-        than the ID.
- """ - return self.full_name - def param( name: str, @@ -490,17 +488,11 @@ def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": if result is not None and result != []: if isinstance(result, list | tuple): - new.set_result([ - StepReference( - new, entry.step, typ=entry.return_type, field=entry.field - ) for entry in result - ]) + for entry in result: + entry.__workflow__ = new else: - new.set_result( - StepReference( - new, result.step, typ=result.return_type, field=result.field - ) - ) + result.__workflow__ = new + new.set_result(result) return new @@ -533,7 +525,7 @@ def simplify_ids(self, infix: list[str] | None = None) -> None: param_counter = Counter[str]() name_to_original: dict[str, str] = {} for name, param in { - pr.parameter.__name__: pr.parameter + pr._.parameter.__name__: pr._.parameter for pr in self.find_parameters() if isinstance(pr, ParameterReference) }.items(): @@ -582,7 +574,7 @@ def add_nested_step( return_type = step.return_type if return_type is inspect._empty: raise TypeError("All tasks should have a type annotation.") - return StepReference(self, step, return_type) + return StepReference(step, return_type) def add_step( self, @@ -613,7 +605,7 @@ def add_step( and not inspect.isclass(fn) ): raise TypeError("All tasks should have a type annotation.") - return StepReference(self, step, return_type) + return StepReference(step, return_type) @staticmethod def from_result( @@ -653,7 +645,7 @@ def set_result(self, result: StepReference[Any] | list[StepReference[Any]] | tup else: to_check = [result] for entry in to_check: - if entry.step.__workflow__ != self: + if entry._.step.__workflow__ != self: raise RuntimeError("Output must be from a step in this workflow.") self.result = result @@ -664,7 +656,7 @@ def result_type(self): if isinstance(self.result, tuple | list): # TODO: get individual types! return type(self.result) - return self.result.return_type + return self.result.__type__ class WorkflowComponent: @@ -676,7 +668,7 @@ class WorkflowComponent: __workflow__: Workflow - def __init__(self, workflow: Workflow): + def __init__(self, *args, workflow: Workflow, **kwargs): """Tie to a `Workflow`. All subclasses must call this. @@ -685,6 +677,7 @@ def __init__(self, workflow: Workflow): workflow: the `Workflow` to tie to. """ self.__workflow__ = workflow + super().__init__(*args, **kwargs) class WorkflowLinkedComponent(Protocol): @@ -703,15 +696,32 @@ def __workflow__(self) -> Workflow: ... 
-class Reference:
+class Reference(Generic[U]):
     """Superclass for all symbolic references to values."""
 
+    _type: type[U] | None = None
+    __workflow__: Workflow
+
+    def __init__(self, *args, typ: type[U] | None = None, **kwargs):
+        self._type = typ
+        if typ == type:
+            raise TypeError("Reference cannot carry the metatype 'type' as its type")
+        super().__init__()
+
+    @property
+    def __root_name__(self) -> str:
+        raise NotImplementedError(
+            "Reference must have a '__root_name__' property or override '__name__'"
+        )
+
     @property
     def __type__(self):
+        if self._type is not None:
+            return self._type
         raise NotImplementedError()
 
     def _raise_unevaluatable_error(self):
         raise UnevaluatableError(f"This reference, {self.__name__}, cannot be evaluated during construction.")
 
     def __eq__(self, other) -> bool:
         if isinstance(other, list) or other is None:
             return False
@@ -733,10 +743,81 @@ def __bool__(self) -> bool:
         return False
 
     @property
-    def name(self) -> str:
+    def __name__(self) -> str:
         """Referral name for this reference."""
-        raise NotImplementedError("Reference must provide a name")
+        workflow = self.__workflow__
+        name = self.__root_name__
+        return workflow.remap(name)
+
+    def __str__(self) -> str:
+        """Global description of the reference."""
+        return self.__name__
+
+
+class FieldableProtocol(Protocol):
+    __field__: tuple[str, ...]
+
+    def __init__(self, *args, field: str | None = None, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @property
+    def __type__(self):
+        ...
+
+    @property
+    def name(self):
+        return "name"
+
+# Subclass Reference so that we know Reference methods/attrs are available.
+class FieldableMixin:
+    def __init__(self: FieldableProtocol, *args, field: str | None = None, **kwargs):
+        self.__field__: tuple[str, ...] = tuple(field.split("/")) if field else ()
+        super().__init__(*args, **kwargs)
+
+    @property
+    def __name__(self: FieldableProtocol) -> str:
+        """Name for this step.
+
+        May be remapped by the workflow to something nicer
+        than the ID.
+        """
+        return "/".join([super().__name__] + list(self.__field__))
+
+    def find_field(self: FieldableProtocol, field, fallback_type: type | None = None, **init_kwargs: Any) -> Reference:
+        """Field within the reference, if possible.
+
+        Returns:
+            A field-specific version of this reference.
+        """
+
+        # Get new type, for the specific field.
+        parent_type = self.__type__
+        field_type = fallback_type
+
+        if is_dataclass(parent_type):
+            try:
+                field_type = next(iter(filter(lambda fld: fld.name == field, dataclass_fields(parent_type)))).type
+            except StopIteration:
+                raise AttributeError(f"Dataclass {parent_type} does not have field {field}")
+        elif attr_has(parent_type):
+            resolve_types(parent_type)
+            try:
+                field_type = getattr(attrs_fields(parent_type), field).type
+            except AttributeError:
+                raise AttributeError(f"attrs-class {parent_type} does not have field {field}")
+
+        if field_type:
+            if not issubclass(self.__class__, Reference):
+                raise TypeError("Only references can have a fieldable mixin")
+
+            if self.__field__:
+                field = "/".join(self.__field__) + "/" + field
+
+            return self.__class__(typ=field_type, field=field, **init_kwargs)
+
+        raise AttributeError(
+            f"Could not determine the type for field {field} in type {parent_type}"
+        )
 
 
 class BaseStep(WorkflowComponent):
     """Lazy-evaluated function call.
@@ -752,6 +833,7 @@ class BaseStep(WorkflowComponent): _id: str | None = None task: Task | Workflow arguments: Mapping[str, Reference | Raw] + workflow: Workflow def __init__( self, @@ -768,7 +850,7 @@ def __init__( arguments: key-value pairs to pass to the function. raw_as_parameter: whether to turn any raw-type arguments into workflow parameters (or just keep them as default argument values). """ - super().__init__(workflow) + super().__init__(workflow=workflow) self.task = task self.arguments = {} for key, value in arguments.items(): @@ -787,12 +869,12 @@ def __init__( ): if raw_as_parameter: value = ParameterReference( - workflow, param(key, value, tethered=None) + workflow=workflow, parameter=param(key, value, tethered=None) ) else: value = Raw(value) if isinstance(value, ParameterReference): - parameter = value.parameter + parameter = value._.parameter parameter.register_caller(self) self.arguments[key] = value else: @@ -972,15 +1054,19 @@ def __init__( raise RuntimeError( f"Factories must be constructed with raw types {arg} {type(arg)}" ) - super().__init__(workflow, task, arguments, raw_as_parameter=raw_as_parameter) + super().__init__(workflow=workflow, task=task, arguments=arguments, raw_as_parameter=raw_as_parameter) + + @property + def __name__(self): + return self.name @property - def default(self) -> Unset: + def __default__(self) -> Unset: """Dummy default property for use as property.""" return UnsetType(self.return_type) -class ParameterReference(Reference): +class ParameterReference(WorkflowComponent, FieldableMixin, Reference[U]): """Reference to an individual `Parameter`. Allows us to refer to the outputs of a `Parameter` in subsequent `Parameter` @@ -988,39 +1074,67 @@ class ParameterReference(Reference): Attributes: parameter: `Parameter` referred to. - __workflow__: Related workflow. In this case, as Parameters are generic + workflow: Related workflow. In this case, as Parameters are generic but ParameterReferences are specific, this carries the actual workflow reference. Returns: Workflow that the referee is related to. """ - parameter: Parameter[RawType] - __workflow__: Workflow + class ParameterReferenceMetadata(Generic[T]): + parameter: Parameter[T] - def __init__(self, __workflow__: Workflow, parameter: Parameter[RawType]): - """Initialize the reference. + def __init__(self, parameter: Parameter[T], *args, typ: type[U] | None=None, **kwargs): + """Initialize the reference. - Args: - workflow: `Workflow` that this is tied to. - parameter: `Parameter` that this refers to. - """ - self.parameter = parameter - self.__workflow__ = __workflow__ + Args: + workflow: `Workflow` that this is tied to. + parameter: `Parameter` that this refers to. + """ + self.parameter = parameter + + @property + def unique_name(self) -> str: + """Unique, machine-generated name. + + Normally this will become invisible in output, but it avoids circularity + as a step that uses this parameter will ask for this when constructing + its own hash, but we will normally want to use the step's name as part of + the parameter name to distinguish from other parameters of the same name. 
+ """ + return self.parameter.__name__ @property - def default(self) -> RawType | Unset: + def __default__(self) -> T | Unset: """Default value of the parameter.""" - return self.parameter.default + return self._.parameter.default @property - def __type__(self) -> type: - """Type represented by wrapped parameter.""" - return self.parameter.__type__ + def __root_name__(self) -> str: + """Reference based on the named step. - def __str__(self) -> str: - """Global description of the reference.""" - return self.parameter.full_name + May be remapped by the workflow to something nicer + than the ID. + """ + return self._.parameter.name + + def __init__(self, parameter: Parameter[U], *args, typ: type[U] | None=None, **kwargs): + typ = typ or parameter.__type__ + self._ = self.ParameterReferenceMetadata(parameter, *args, typ, **kwargs) + super().__init__(*args, typ=typ, **kwargs) + + def __getattr__(self, attr: str) -> "ParameterReference": + try: + return self.find_field( + field=attr, + workflow=self.__workflow__, + parameter=self._.parameter + ) + except AttributeError as _: + return super().__getattribute__(attr) + + def __getitem__(self, attr: str) -> "ParameterReference": + return getattr(self, attr) def __repr__(self) -> str: """Hashable reference to the step (and field).""" @@ -1028,27 +1142,8 @@ def __repr__(self) -> str: typ = self.__type__.__name__ except AttributeError: typ = str(self.__type__) - return f"{typ}|:param:{self.unique_name}" - - @property - def unique_name(self) -> str: - """Unique, machine-generated name. - - Normally this will become invisible in output, but it avoids circularity - as a step that uses this parameter will ask for this when constructing - its own hash, but we will normally want to use the step's name as part of - the parameter name to distinguish from other parameters of the same name. - """ - return self.parameter.__name__ - - @property - def name(self) -> str: - """Reference based on the named step. - - May be remapped by the workflow to something nicer - than the ID. - """ - return self.__workflow__.remap(self.parameter.name) + name = "/".join([self._.unique_name] + list(self.__field__)) + return f"{typ}|:param:{name}" def __hash__(self) -> int: """Hash to parameter. @@ -1056,7 +1151,7 @@ def __hash__(self) -> int: Returns: Unique hash corresponding to the parameter. """ - return hash(self.parameter) + return hash((self._.parameter, self.__field__)) def __eq__(self, other: object) -> bool: """Compare two references. @@ -1068,14 +1163,11 @@ def __eq__(self, other: object) -> bool: True if the other parameter reference is materially the same, otherwise False. """ return ( - isinstance(other, ParameterReference) and self.parameter == other.parameter + isinstance(other, ParameterReference) and self._.parameter == other._.parameter and self.__field__ == other.__field__ ) -U = TypeVar("U") - - -class StepReference(Generic[U], Reference): +class StepReference(FieldableMixin, Reference[U]): """Reference to an individual `Step`. Allows us to refer to the outputs of a `Step` in subsequent `Step` @@ -1087,22 +1179,30 @@ class StepReference(Generic[U], Reference): step: BaseStep _tethered_workflow: Workflow | None - _field: str | None - typ: type[U] - @property - def field(self) -> str: - """Field within the result. + class StepReferenceMetadata: + def __init__( + self, step: BaseStep, typ: type[U] | None = None + ): + """Initialize the reference. - Explicitly set field (within an attrs-class) or `out`. + Args: + workflow: `Workflow` that this is tied to. 
+ step: `Step` that this refers to. + typ: the type that the step will output. + field: if provided, a specific field to pull out of an attrs result class. + """ + self.step = step + self._typ = typ - Returns: - Field name. - """ - return self._field or "out" + @property + def return_type(self): + return self._typ or self.step.return_type + + _: StepReferenceMetadata def __init__( - self, workflow: Workflow, step: BaseStep, typ: type[U], field: str | None = None + self, step: BaseStep, *args, typ: type[U] | None = None, **kwargs ): """Initialize the reference. @@ -1112,18 +1212,18 @@ def __init__( typ: the type that the step will output. field: if provided, a specific field to pull out of an attrs result class. """ - self.step = step - self._field = field - self.typ = typ self._tethered_workflow = None + typ = typ or step.return_type + self._ = self.StepReferenceMetadata(step, typ=typ) + super().__init__(*args, typ=typ, **kwargs) def __str__(self) -> str: """Global description of the reference.""" - return f"{self.step.id}/{self.field}" + return "/".join([self._.step.id] + list(self.__field__)) def __repr__(self) -> str: """Hashable reference to the step (and field).""" - return f"{self.step.id}/{self.field}" + return "/".join([self._.step.id] + list(self.__field__)) def __getattr__(self, attr: str) -> "StepReference[Any]": """Reference to a field within this result, if possible. @@ -1141,53 +1241,25 @@ def __getattr__(self, attr: str) -> "StepReference[Any]": AttributeError: if this field is not present in the dataclass. RuntimeError: if this field is not available, or we do not have a structured result. """ - if self._field is None: - typ: type | None - if attr_has(self.typ): - resolve_types(self.typ) - typ = getattr(attrs_fields(self.typ), attr).type - elif is_dataclass(self.typ): - matched = [ - field for field in dataclass_fields(self.typ) if field.name == attr - ] - if not matched: - raise AttributeError(f"Field {attr} not present in dataclass") - typ = matched[0].type - elif isinstance(self.step, FactoryCall): - typ = self.step.return_type - else: - typ = None - - if typ: - return self.__class__( - workflow=self.__workflow__, step=self.step, typ=typ, field=attr - ) - raise AttributeError( - "Can only get attribute of a StepReference representing an attrs-class or dataclass" - ) + try: + return self.find_field( + workflow=self.__workflow__, step=self._.step, field=attr + ) + except AttributeError as _: + return super().__getattribute__(attr) @property - def return_type(self) -> type[U]: - """Type that this step reference will resolve to. - - Returns: - Python type indicating the final result type. - """ - return self.typ + def __type__(self) -> type: + return self._.return_type @property - def name(self) -> str: + def __root_name__(self) -> str: """Reference based on the named step. May be remapped by the workflow to something nicer than the ID. 
""" - return f"{self.step.name}/{self.field}" - - @property - def __type__(self) -> Any: - """Type of the step's referenced value.""" - return self.step.return_type + return self._.step.name @property def __workflow__(self) -> Workflow: @@ -1212,12 +1284,8 @@ def __workflow__(self, workflow: Workflow) -> None: """ self._tethered_workflow = workflow if self._tethered_workflow: - if self.step not in self._tethered_workflow.steps: - self.step = self._tethered_workflow._indexed_steps[self.step.id] - - @__workflow__.setter - def __workflow__(self, workflow: Workflow) -> None: - self.step.set_workflow(workflow) + if self._.step not in self._tethered_workflow.steps: + self._.step = self._tethered_workflow._indexed_steps[self.step.id] def merge_workflows(*workflows: Workflow) -> Workflow: diff --git a/tests/test_cwl.py b/tests/test_cwl.py index b29b2f47..bc89cd1e 100644 --- a/tests/test_cwl.py +++ b/tests/test_cwl.py @@ -153,7 +153,7 @@ def test_cwl_with_parameter() -> None: workflow = construct(result) rendered = render(workflow)["__root__"] num_param = list(workflow.find_parameters())[0] - hsh = hasher(("increment", ("num", f"int|:param:{num_param.unique_name}"))) + hsh = hasher(("increment", ("num", f"int|:param:{num_param._.unique_name}"))) assert rendered == yaml.safe_load(f""" cwlVersion: 1.2 @@ -321,9 +321,9 @@ def test_cwl_references() -> None: rendered = render(workflow)["__root__"] num_param = list(workflow.find_parameters())[0] hsh_increment = hasher( - ("increment", ("num", f"int|:param:{num_param.unique_name}")) + ("increment", ("num", f"int|:param:{num_param._.unique_name}")) ) - hsh_double = hasher(("double", ("num", f"increment-{hsh_increment}/out"))) + hsh_double = hasher(("double", ("num", f"increment-{hsh_increment}"))) assert rendered == yaml.safe_load(f""" cwlVersion: 1.2 diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py new file mode 100644 index 00000000..51725ace --- /dev/null +++ b/tests/test_fieldable.py @@ -0,0 +1,64 @@ +import yaml +from dataclasses import dataclass +from dewret.tasks import task, construct, subworkflow +from dewret.workflow import param +from dewret.renderers.cwl import render + +from ._lib.extra import double, mod10, sum + +@dataclass +class Sides: + left: int + right: int + +SIDES: Sides = Sides(3, 6) + +@subworkflow() +def sum_sides(): + return sum(left=SIDES.left, right=SIDES.right) + +def test_fields_of_parameters_usable() -> None: + result = sum_sides() + workflow = construct(result, simplify_ids=True) + rendered = render(workflow, allow_complex_types=True)["sum_sides-1"] + + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + SIDES/left: + label: SIDES/left + type: Sides + SIDES/right: + label: SIDES/right + type: Sides + outputs: + out: + label: out + outputSource: sum-1-1/out + type: + - int + - double + steps: + sum-1-1: + in: + left: + source: SIDES/left + right: + source: SIDES/right + out: + - out + run: sum + """) + +def test_can_get_field_reference_iff_parent_type_has_field(): + @dataclass + class MyDataclass: + left: int + my_param = param("my_param", typ=MyDataclass) + result = sum(left=my_param, right=my_param) + workflow = construct(result, simplify_ids=True) + param_reference = list(workflow.find_parameters())[0] + + assert str(param_reference.left) == "my_param/left" + assert param_reference.left.__type__ == int diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index c57b4f5c..f8da0030 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -59,13 
+59,11 @@ def get_global_queues(num: int | float) -> list["Queue[int] | int"]: @subworkflow() def add_constant(num: int | float) -> int: """Add a global constant to a number.""" - print(CONSTANT, type(CONSTANT)) return to_int(num=sum(left=num, right=CONSTANT)) @subworkflow() def add_constants(num: int | float) -> int: """Add a global constant to a number.""" - print(CONSTANT, type(CONSTANT)) return to_int(num=sum(left=sum(left=num, right=CONSTANT), right=CONSTANT)) @@ -257,9 +255,8 @@ def test_subworkflows_can_return_lists() -> None: num: label: num type: int - sum-1-1-1-right: - default: 3 - label: sum-1-1-1-right + CONSTANT: + label: CONSTANT type: int outputs: out: @@ -272,7 +269,7 @@ def test_subworkflows_can_return_lists() -> None: left: source: num right: - source: sum-1-1-1-right + source: CONSTANT out: - out run: sum From a7c1fe94c01b932194191b3e3b52a153e224df7d Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 6 Aug 2024 21:46:34 +0100 Subject: [PATCH 012/108] fix: integrate flatten_all_nested configuration --- src/dewret/tasks.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index be124093..0f0c518b 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -248,13 +248,14 @@ def __call__( """ workflow = __workflow__ or Workflow() - context = copy_context().items() - def _initializer(): - for var, value in context: - var.set(value) - thread_pool = ThreadPoolExecutor(initializer=_initializer) - - result = self.evaluate(task, workflow, thread_pool=thread_pool, **kwargs) + with set_configuration(): + context = copy_context().items() + def _initializer(): + for var, value in context: + var.set(value) + thread_pool = ThreadPoolExecutor(initializer=_initializer) + + result = self.evaluate(task, workflow, thread_pool=thread_pool, **kwargs) return Workflow.from_result(result, simplify_ids=simplify_ids) @@ -455,7 +456,7 @@ def _fn( try: # Ensure that all arguments are passed as keyword args and prevent positional args. # passed at all. - if args: + if args and not get_configuration("allow_positional_args"): raise TypeError( f""" Calling {fn.__name__}: Arguments must _always_ be named, @@ -501,7 +502,9 @@ def add_numbers(left: int, right: int): var, value, tethered=( - False if nested and flatten_nested else None + False if nested and ( + flatten_nested or get_configuration("flatten_all_nested") + ) else None ), autoname=True, ), @@ -551,7 +554,7 @@ def {fn.__name__}(...) 
-> ...: f"Nested tasks must now only refer to global parameters, raw or tasks, not objects: {var}" ) if nested: - if flatten_nested: + if flatten_nested or get_configuration("flatten_all_nested"): output = fn(**original_kwargs) lazy_fn = ensure_lazy(output) if lazy_fn is None: From dee681509027f9d4617ed9c290f46435b2301de0 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 6 Aug 2024 22:32:47 +0100 Subject: [PATCH 013/108] fix: correctly handle flattened nested tasks --- src/dewret/tasks.py | 5 +++-- tests/test_configuration.py | 18 ++++++++---------- tests/test_errors.py | 2 +- tests/test_modularity.py | 6 +++--- tests/test_subworkflows.py | 5 +++++ 5 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 0f0c518b..43938451 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -73,7 +73,7 @@ class ConstructConfiguration(TypedDict): @contextmanager def set_configuration(**kwargs: Unpack[ConstructConfiguration]): try: - previous = CONSTRUCT_CONFIGURATION.get() + previous = ConstructConfiguration(**CONSTRUCT_CONFIGURATION.get()) except LookupError: previous = ConstructConfiguration( flatten_all_nested=False, @@ -555,7 +555,8 @@ def {fn.__name__}(...) -> ...: ) if nested: if flatten_nested or get_configuration("flatten_all_nested"): - output = fn(**original_kwargs) + with in_nested_task(): + output = analyser.with_new_globals(kwargs)(**original_kwargs) lazy_fn = ensure_lazy(output) if lazy_fn is None: raise TypeError( diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 12d2bd72..f656c386 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -40,28 +40,26 @@ def test_cwl_with_parameter(configuration) -> None: cwlVersion: 1.2 class: Workflow inputs: - floor-1-num: - label: floor-1-num + num: + label: num type: int default: 3 outputs: out: label: out - outputSource: increment-1/out + outputSource: increment-2/out type: int steps: - floor-1: - run: floor + increment-1: + run: increment in: - expected: - default: true num: - source: floor-1-num + source: num out: [out] - increment-1: + increment-2: run: increment in: num: - source: floor-1/out + source: increment-1/out out: [out] """) diff --git a/tests/test_errors.py b/tests/test_errors.py index ae49e07e..7208739f 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -245,5 +245,5 @@ def test_nested_tasks_must_return_a_task() -> None: construct(result) assert ( str(exc.value) - == "Task unacceptable_nested_return returned output of type , which is not a lazy function for this backend." + == "Task unacceptable_nested_return returned output of type , which is not a lazy function for this backend." 
)

diff --git a/tests/test_modularity.py
index f19a067a..03119087 100644
--- a/tests/test_modularity.py
+++ b/tests/test_modularity.py
@@ -33,9 +33,9 @@ def test_nested_task() -> None:
             label: JUMP
             type: float
             default: 1.0
-          increase-3-num:
+          STARTING_NUMBER:
             default: 23
-            label: increase-3-num
+            label: STARTING_NUMBER
             type: int
           increase-1-num:
             default: 17
@@ -69,7 +69,7 @@ def test_nested_task() -> None:
               JUMP:
                 source: JUMP
               num:
-                source: increase-3-num
+                source: STARTING_NUMBER
             out: [out]
           double-1:
             run: double
diff --git a/tests/test_subworkflows.py
index f8da0030..0167ad54 100644
--- a/tests/test_subworkflows.py
+++ b/tests/test_subworkflows.py
@@ -66,6 +66,11 @@ def add_constants(num: int | float) -> int:
     """Add a global constant to a number."""
     return to_int(num=sum(left=sum(left=num, right=CONSTANT), right=CONSTANT))
 
+@subworkflow()
+def get_values(num: int | float) -> tuple[int | float, int]:
+    """Return a pair: the input plus the global constant, and the constant added to itself."""
+    return (sum(left=num, right=CONSTANT), add_constant(CONSTANT))
+
 
 def test_subworkflows_can_use_globals() -> None:
     """Produce a subworkflow that uses a global."""

From 363be8167990633a22218f144b6584f730236763 Mon Sep 17 00:00:00 2001
From: Phil Weir
Date: Tue, 6 Aug 2024 22:46:54 +0100
Subject: [PATCH 014/108] fix: confirm tuples can be returned

---
 src/dewret/renderers/cwl.py |  2 ++
 src/dewret/workflow.py      |  2 +-
 tests/_lib/extra.py         |  8 ++++++++
 tests/test_cwl.py           | 10 +---------
 tests/test_subworkflows.py  | 30 +++++++++++++++++++++++++++++-
 5 files changed, 41 insertions(+), 11 deletions(-)

diff --git a/src/dewret/renderers/cwl.py
index 8cf6e15c..1698a417 100644
--- a/src/dewret/renderers/cwl.py
+++ b/src/dewret/renderers/cwl.py
@@ -225,6 +225,8 @@ def to_cwl_type(typ: type) -> str | dict[str, Any] | list[str]:
         raise TypeError(
             f"Cannot render complex type ({typ}) to CWL, have you enabled allow_complex_types configuration?"
) from err + elif typ == tuple: + return "record" else: raise TypeError(f"Cannot render complex type ({typ}) to CWL") diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 9846e462..567cc467 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -891,7 +891,7 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, BaseStep): return False return ( - self.__workflow__ == other.__workflow__ + self.__workflow__ is other.__workflow__ and self.task == other.task and self.arguments == other.arguments ) diff --git a/tests/_lib/extra.py b/tests/_lib/extra.py index 955270de..c883afef 100644 --- a/tests/_lib/extra.py +++ b/tests/_lib/extra.py @@ -33,6 +33,14 @@ def sum(left: int | float, right: int | float) -> int | float: return left + right +@task() +def pi() -> float: + """Returns pi.""" + import math + + return math.pi + + @subworkflow() def triple_and_one(num: int | float) -> int | float: """Triple a number by doubling and adding again, then add 1.""" diff --git a/tests/test_cwl.py b/tests/test_cwl.py index bc89cd1e..cf795598 100644 --- a/tests/test_cwl.py +++ b/tests/test_cwl.py @@ -8,6 +8,7 @@ from dewret.workflow import param from ._lib.extra import ( + pi, increment, double, mod10, @@ -16,15 +17,6 @@ tuple_float_return, ) - -@task() -def pi() -> float: - """Returns pi.""" - import math - - return math.pi - - @task() def floor(num: int | float) -> int: """Converts int/float to int.""" diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index 0167ad54..ffee8e76 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -7,7 +7,7 @@ from dewret.renderers.cwl import render from dewret.workflow import param -from ._lib.extra import increment, sum +from ._lib.extra import increment, sum, pi CONSTANT = 3 @@ -72,6 +72,34 @@ def get_values(num: int | float) -> tuple[int | float, int]: return (sum(left=num, right=CONSTANT), add_constant(CONSTANT)) +def test_cwl_for_pairs() -> None: + """Check whether we can produce CWL of pairs.""" + + @subworkflow() + def pair_pi(): + return (pi(), pi()) + + result = pair_pi() + workflow = construct(result, simplify_ids=True) + rendered = render(workflow)["__root__"] + + assert rendered == yaml.safe_load(f""" + cwlVersion: 1.2 + class: Workflow + inputs: {{}} + outputs: + out: + label: out + outputSource: pair_pi-1/out + type: record + steps: + pair_pi-1: + run: pair_pi + in: {{}} + out: [out] + """) + + def test_subworkflows_can_use_globals() -> None: """Produce a subworkflow that uses a global.""" my_param = param("num", typ=int) From 7aeccd9fe58f64eb65ff500279f8ab1c87d60f01 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 6 Aug 2024 23:07:52 +0100 Subject: [PATCH 015/108] fix: step reference works as tuple --- src/dewret/backends/backend_dask.py | 5 ++++- src/dewret/tasks.py | 5 ++++- tests/test_subworkflows.py | 32 ++++++++++++++++++----------- 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index d5a82194..e5df1957 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -86,7 +86,10 @@ def is_lazy(task: Any) -> bool: Returns: True if so, False otherwise. 
""" - return isinstance(task, Delayed) + return isinstance(task, Delayed) or ( + isinstance(task, tuple | list) and + all(is_lazy(elt) for elt in task) + ) lazy = delayed diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 43938451..94ed2f7f 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -585,7 +585,10 @@ def {fn.__name__}(...) -> ...: step_reference = workflow.add_nested_step( fn.__name__, nested_workflow, kwargs ) - if isinstance(step_reference, StepReference): # RMV: What if it's a list? + if isinstance(step_reference, StepReference) or ( + isinstance(step_reference, tuple | list) and + all(isinstance(elt, StepReference) for elt in step_reference) + ): return cast(RetType, step_reference) raise TypeError( f"Nested tasks must return a step reference, not {type(step_reference)} to ensure graph makes sense." diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index ffee8e76..53eb23f3 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -3,7 +3,7 @@ from typing import Callable from queue import Queue import yaml -from dewret.tasks import construct, subworkflow, task, factory +from dewret.tasks import construct, subworkflow, task, factory, set_configuration from dewret.renderers.cwl import render from dewret.workflow import param @@ -76,25 +76,33 @@ def test_cwl_for_pairs() -> None: """Check whether we can produce CWL of pairs.""" @subworkflow() - def pair_pi(): - return (pi(), pi()) + def pair_pi() -> tuple[float, float]: + return pi(), pi() - result = pair_pi() - workflow = construct(result, simplify_ids=True) + with set_configuration(flatten_all_nested=True): + result = pair_pi() + workflow = construct(result, simplify_ids=True) rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(f""" cwlVersion: 1.2 class: Workflow inputs: {{}} - outputs: - out: - label: out - outputSource: pair_pi-1/out - type: record + outputs: [ + {{ + label: out, + outputSource: pi-1/out, + type: double + }}, + {{ + label: out, + outputSource: pi-1/out, + type: double + }} + ] steps: - pair_pi-1: - run: pair_pi + pi-1: + run: pi in: {{}} out: [out] """) From 03c8c8b0f6ae31bdcb5785cf28f1fe66f22c7567 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 6 Aug 2024 23:23:24 +0100 Subject: [PATCH 016/108] fix: tidy up type parsing --- src/dewret/tasks.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 94ed2f7f..cba3d85d 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -47,6 +47,7 @@ from .utils import is_raw, make_traceback from .workflow import ( + UNSET, Reference, StepReference, ParameterReference, @@ -507,6 +508,7 @@ def add_numbers(left: int, right: int): ) else None ), autoname=True, + typ=analyser.all_annotations.get(var, UNSET) ), ) elif isinstance(value, Parameter): @@ -521,13 +523,26 @@ def add_numbers(left: int, right: int): # raise TypeError( # "Captured parameter {var} (global variable in task) shadows an argument" # ) - if analyser.is_at_construct_arg(var) or isinstance(value, Reference): + if ( + analyser.is_at_construct_arg(var) or + isinstance(value, Reference) or + value is evaluate or value is construct # Allow manual building. 
+ ): kwargs[var] = value elif isinstance(value, Parameter): kwargs[var] = ParameterReference(workflow=workflow, parameter=value) - elif is_raw(value) or ((attrs_has(value) or is_dataclass(value)) and not inspect.isclass(value)): + elif is_raw(value) or ( + (attrs_has(value) or is_dataclass(value)) and + not inspect.isclass(value) + ): kwargs[var] = ParameterReference( - workflow=workflow, parameter=param(var, value, tethered=False) + workflow=workflow, + parameter=param( + var, + value, + tethered=False, + typ=analyser.all_annotations.get(var, UNSET) + ) ) elif is_task(value) or ensure_lazy(value) is not None: if not nested and _workaround_check_value_is_task( From b50b1502b4e1f0444e13585428d4c43060af73b9 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Wed, 7 Aug 2024 00:26:31 +0100 Subject: [PATCH 017/108] fix: nested raw is parseable --- src/dewret/renderers/cwl.py | 148 +++++++++++++++++--------------- src/dewret/utils.py | 17 ++++ src/dewret/workflow.py | 20 +---- tests/test_multiresult_steps.py | 18 ++-- tests/test_nested.py | 39 +++++++++ tests/test_subworkflows.py | 24 +++--- 6 files changed, 159 insertions(+), 107 deletions(-) create mode 100644 tests/test_nested.py diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 1698a417..3917f2ed 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -19,29 +19,48 @@ """ from attrs import define, has as attrs_has, fields as attrs_fields, AttrsInstance -from dataclasses import dataclass, is_dataclass, fields as dataclass_fields +from dataclasses import is_dataclass, fields as dataclass_fields from collections.abc import Mapping from contextvars import ContextVar -from typing import TypedDict, NotRequired, get_args, Union, cast, Any, Iterable, Unpack +from typing import TypedDict, NotRequired, get_origin, get_args, Union, cast, Any, Iterable, Unpack from types import UnionType +from inspect import isclass from dewret.workflow import ( FactoryCall, Reference, - FieldableMixin, Raw, Workflow, BaseStep, - NestedStep, StepReference, ParameterReference, Unset, ) -from dewret.utils import RawType, flatten, DataclassProtocol +from dewret.utils import RawType, flatten, DataclassProtocol, flatten_if_set from dewret.render import base_render +class CommandInputSchema(TypedDict): + """Structure for referring to a raw type in CWL. + + Encompasses several CWL types. In future, it may be best to + use _cwltool_ or another library for these basic structures. + + Attributes: + type: CWL type of this input. + label: name to show for this input. + fields: (for `record`) individual fields in a dict-like structure. + items: (for `array`) type that each field will have. 
+ """ + + type: "InputSchemaType" + label: str + fields: NotRequired[dict[str, "CommandInputSchema"]] + items: NotRequired["InputSchemaType"] + default: NotRequired[RawType] + + InputSchemaType = Union[ - str, "CommandInputSchema", list[str], list["InputSchemaType"], dict[str, str] + str, CommandInputSchema, list[str], list["InputSchemaType"], dict[str, "str | InputSchemaType"] ] @@ -62,8 +81,6 @@ class CWLRendererConfiguration(TypedDict): "allow_complex_types": False, "factories_as_params": False, } -CONFIGURATION.set({}) -CONFIGURATION.get().update(DEFAULT_CONFIGURATION) def configuration(key: str) -> Any: @@ -168,7 +185,7 @@ def render(self) -> dict[str, RawType]: } -def cwl_type_from_value(val: RawType | Unset) -> str | list[str] | dict[str, Any]: +def cwl_type_from_value(label: str, val: RawType | Unset) -> InputSchemaType: """Find a CWL type for a given (possibly Unset) value. Args: @@ -182,10 +199,10 @@ def cwl_type_from_value(val: RawType | Unset) -> str | list[str] | dict[str, Any else: raw_type = type(val) - return to_cwl_type(raw_type) + return to_cwl_type(label, raw_type)["type"] -def to_cwl_type(typ: type) -> str | dict[str, Any] | list[str]: +def to_cwl_type(label: str, typ: type) -> CommandInputSchema: """Map Python types to CWL types. Args: @@ -195,60 +212,54 @@ def to_cwl_type(typ: type) -> str | dict[str, Any] | list[str]: CWL specification type name, or a list if a union. """ - if typ == int: - return "int" - elif typ == bool: - return "boolean" - elif typ == dict or attrs_has(typ): - return "record" - elif typ == float: - return "float" - elif typ == str: - return "string" - elif typ == bytes: - return "bytes" - elif configuration("allow_complex_types"): - return typ if isinstance(typ, str) else typ.__name__ + typ_dict: CommandInputSchema = { + "label": label, + "type": "" + } + base: Any | None = typ + args = get_args(typ) + if args: + base = get_origin(typ) + + if base == type(None): + typ_dict["type"] = "null" + elif base == int: + typ_dict["type"] = "int" + elif base == bool: + typ_dict["type"] = "boolean" + elif base == dict or (isinstance(base, type) and attrs_has(base)): + typ_dict["type"] = "record" + elif base == float: + typ_dict["type"] = "float" + elif base == str: + typ_dict["type"] = "string" + elif base == bytes: + typ_dict["type"] = "bytes" elif isinstance(typ, UnionType): - return [to_cwl_type(item) for item in get_args(typ)] - elif isinstance(typ, Iterable): + typ_dict.update({"type": [to_cwl_type(label, item)["type"] for item in args]}) + elif isclass(base) and issubclass(base, Iterable): try: - basic_types = get_args(typ) - if len(basic_types) > 1: - return { + if len(args) > 1: + typ_dict.update({ + "type": "array", + "items": [to_cwl_type(label, t)["type"] for t in args], + }) + elif len(args) == 1: + typ_dict.update({ "type": "array", - "items": [{"type": to_cwl_type(t)} for t in basic_types], - } + "items": to_cwl_type(label, args[0])["type"] + }) else: - return {"type": "array", "items": to_cwl_type(basic_types[0])} + typ_dict["type"] = "array" except IndexError as err: raise TypeError( f"Cannot render complex type ({typ}) to CWL, have you enabled allow_complex_types configuration?" ) from err - elif typ == tuple: - return "record" + elif configuration("allow_complex_types"): + typ_dict["type"] = typ if isinstance(typ, str) else typ.__name__ else: - raise TypeError(f"Cannot render complex type ({typ}) to CWL") - - -class CommandInputSchema(TypedDict): - """Structure for referring to a raw type in CWL. - - Encompasses several CWL types. 
In future, it may be best to - use _cwltool_ or another library for these basic structures. - - Attributes: - type: CWL type of this input. - label: name to show for this input. - fields: (for `record`) individual fields in a dict-like structure. - items: (for `array`) type that each field will have. - """ - - type: InputSchemaType - label: str - fields: NotRequired[dict[str, "CommandInputSchema"]] - items: NotRequired[InputSchemaType] - default: NotRequired[RawType] + raise TypeError(f"Cannot render type ({typ}) to CWL") + return typ_dict class CommandOutputSchema(CommandInputSchema): @@ -278,9 +289,9 @@ def raw_to_command_input_schema(label: str, value: RawType | Unset) -> InputSche Structure used to define (possibly compound) basic types for input. """ if isinstance(value, dict) or isinstance(value, list): - return _raw_to_command_input_schema_internal(label, value) + return {"type": _raw_to_command_input_schema_internal(label, value)} else: - return cwl_type_from_value(value) + return cwl_type_from_value(label, value) def to_output_schema( @@ -324,8 +335,7 @@ def to_output_schema( ) else: output = CommandOutputSchema( - type=to_cwl_type(typ), - label=label, + **to_cwl_type(label, typ) ) if output_source is not None: output["outputSource"] = output_source @@ -335,7 +345,7 @@ def to_output_schema( def _raw_to_command_input_schema_internal( label: str, value: RawType | Unset ) -> CommandInputSchema: - typ = cwl_type_from_value(value) + typ = cwl_type_from_value(label, value) structure: CommandInputSchema = {"type": typ, "label": label} if isinstance(value, dict): structure["fields"] = { @@ -351,7 +361,7 @@ def _raw_to_command_input_schema_internal( "For CWL, an input array must have a consistent type, " "and we need at least one element to infer it, or an explicit typehint." ) - structure["items"] = to_cwl_type(typeset.pop()) + structure["items"] = to_cwl_type(label, typeset.pop())["type"] elif not isinstance(value, Unset): structure["default"] = value return structure @@ -397,9 +407,9 @@ def from_parameters( inputs={ input.__name__: cls.CommandInputParameter( label=input.__name__, - default=input.__default__, + default=(default := flatten_if_set(input.__default__)), type=raw_to_command_input_schema( - label=input.__name__, value=input.__default__ + label=input.__name__, value=default ), ) for input in parameters @@ -427,8 +437,8 @@ def render(self) -> dict[str, RawType]: return result -def to_name(result: FieldableMixin): - if not result.__field__ and isinstance(result, StepReference): +def to_name(result: Reference[Any]): + if hasattr(result, "__field__") and not result.__field__ and isinstance(result, StepReference): return f"{result.__name__}/out" return result.__name__ @@ -568,10 +578,12 @@ def render( Reduced form as a native Python dict structure for serialization. """ - CONFIGURATION.get().update(kwargs) + config = CWLRendererConfiguration(**DEFAULT_CONFIGURATION) + config.update(kwargs) + CONFIGURATION.set(config) rendered = base_render( workflow, lambda workflow: WorkflowDefinition.from_workflow(workflow).render() ) - CONFIGURATION.get().update(DEFAULT_CONFIGURATION) + CONFIGURATION.set(DEFAULT_CONFIGURATION) return rendered diff --git a/src/dewret/utils.py b/src/dewret/utils.py index 05b3b8e8..bcad8871 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -29,6 +29,10 @@ FirmType = BasicType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] 
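+# `Unset` moves here from dewret.workflow so that `flatten_if_set` below can
+# test for the missing-value sentinel without a circular import (workflow
+# already imports from utils).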
+class Unset: + """Unset variable, with no default value.""" + + class DataclassProtocol(Protocol): """Format of a dataclass. @@ -56,6 +60,19 @@ def make_traceback(skip: int = 2) -> TracebackType | None: return tb +def flatten_if_set(value: Any) -> RawType | Unset: + """Takes a Raw-like structure and makes it RawType or Unset. + + Flattens if the value is set, but otherwise returns the unset + sentinel value as-is. + + Args: + value: value to squash + """ + if isinstance(value, Unset): + return value + return flatten(value) + def flatten(value: Any) -> RawType: """Takes a Raw-like structure and makes it RawType. diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 567cc467..aea06ef6 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) -from .utils import hasher, RawType, is_raw, make_traceback, is_raw_type +from .utils import hasher, RawType, is_raw, make_traceback, is_raw_type, Unset T = TypeVar("T") U = TypeVar("U") @@ -111,10 +111,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> RetType: LazyFactory = Callable[[Target], Lazy] -class Unset: - """Unset variable, with no default value.""" - - class UnsetType(Unset, Generic[T]): """Unset variable with a specific type. @@ -1178,7 +1174,6 @@ class StepReference(FieldableMixin, Reference[U]): """ step: BaseStep - _tethered_workflow: Workflow | None class StepReferenceMetadata: def __init__( @@ -1212,7 +1207,6 @@ def __init__( typ: the type that the step will output. field: if provided, a specific field to pull out of an attrs result class. """ - self._tethered_workflow = None typ = typ or step.return_type self._ = self.StepReferenceMetadata(step, typ=typ) super().__init__(*args, typ=typ, **kwargs) @@ -1268,24 +1262,16 @@ def __workflow__(self) -> Workflow: Returns: Workflow that the referee is related to. """ - return self._tethered_workflow or self.step.__workflow__ + return self._.step.__workflow__ @__workflow__.setter def __workflow__(self, workflow: Workflow) -> None: """Sets related workflow. - We update the tethered workflow. If the step is missing from - this workflow then, by construction, it should have at least - been through an indexing process once, so we should be able - to get it back by name. 
- Args: workflow: workflow to update the step """ - self._tethered_workflow = workflow - if self._tethered_workflow: - if self._.step not in self._tethered_workflow.steps: - self._.step = self._tethered_workflow._indexed_steps[self.step.id] + self._.step.set_workflow(workflow) def merge_workflows(*workflows: Workflow) -> Workflow: diff --git a/tests/test_multiresult_steps.py b/tests/test_multiresult_steps.py index caefa8f8..2270ffab 100644 --- a/tests/test_multiresult_steps.py +++ b/tests/test_multiresult_steps.py @@ -246,7 +246,7 @@ def test_complex_field_of_nested_task_with_dataclasses() -> None: def test_pair_can_be_returned_from_step() -> None: """Tests whether a task can insert result fields into other steps.""" workflow = construct(algorithm_with_pair(), simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -256,11 +256,10 @@ def test_pair_can_be_returned_from_step() -> None: out: label: out outputSource: pair-1/out - type: - items: - - type: int - - type: float - type: array + items: + - int + - float + type: array steps: pair-1: in: @@ -286,7 +285,7 @@ def test_pair_can_be_returned_from_step() -> None: def test_list_can_be_returned_from_step() -> None: """Tests whether a task can insert result fields into other steps.""" workflow = construct(list_cast(iterable=algorithm_with_pair()), simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -296,9 +295,8 @@ def test_list_can_be_returned_from_step() -> None: out: label: out outputSource: list_cast-1/out - type: - items: float - type: array + items: float + type: array steps: list_cast-1: in: diff --git a/tests/test_nested.py b/tests/test_nested.py new file mode 100644 index 00000000..7cda5474 --- /dev/null +++ b/tests/test_nested.py @@ -0,0 +1,39 @@ +import yaml +from dewret.tasks import construct, task, factory +from dewret.renderers.cwl import render + +@task() +def reverse_list(to_sort: list[int]) -> list[int]: + return to_sort[::-1] + +def test_can_supply_nested_raw(): + result = reverse_list(to_sort=[1, 3, 5]) + workflow = construct(result, simplify_ids=True) + rendered = render(workflow)["__root__"] + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + reverse_list-1-to_sort: + default: [1, 3, 5] + label: reverse_list-1-to_sort + type: + type: + items: int + label: reverse_list-1-to_sort + type: array + outputs: + out: + items: int + label: out + outputSource: reverse_list-1/out + type: array + steps: + reverse_list-1: + in: + to_sort: + source: reverse_list-1-to_sort + out: + - out + run: reverse_list + """) diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index 53eb23f3..b83f5b0b 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -80,30 +80,30 @@ def pair_pi() -> tuple[float, float]: return pi(), pi() with set_configuration(flatten_all_nested=True): - result = pair_pi() - workflow = construct(result, simplify_ids=True) + result = pair_pi() + workflow = construct(result, simplify_ids=True) rendered = render(workflow)["__root__"] - assert rendered == yaml.safe_load(f""" + assert rendered == yaml.safe_load(""" cwlVersion: 1.2 class: Workflow - inputs: {{}} + inputs: {} outputs: [ - {{ + { label: out, outputSource: pi-1/out, - type: double - }}, - {{ + type: float + }, + { label: out, outputSource: pi-1/out, - type: double - }} + type: float + } ] steps: pi-1: run: pi - 
in: {{}} + in: {} out: [out] """) @@ -397,7 +397,7 @@ def test_can_merge_workflows() -> None: outputSource: sum-1/out type: [ int, - double + float ] steps: increment-1: From 109dbf9379a8b84997bafe322f8d6b5f9b19b077 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Wed, 7 Aug 2024 12:56:27 +0100 Subject: [PATCH 018/108] fix(annotations): correct name to AtRender --- src/dewret/annotations.py | 4 ++-- tests/test_annotations.py | 38 ++++++++++++++++++------------------- tests/test_configuration.py | 4 ++-- tests/test_errors.py | 4 ++-- tests/test_fieldable.py | 18 ++++++++++-------- tests/test_nested.py | 4 +++- 6 files changed, 38 insertions(+), 34 deletions(-) diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index bbdd7bc5..3211b859 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -4,7 +4,7 @@ from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Callable, get_origin, get_args, Mapping T = TypeVar("T") -AtConstruct = Annotated[T, "AtConstruct"] +AtRender = Annotated[T, "AtRender"] class FunctionAnalyser: _fn: Callable[..., Any] @@ -50,7 +50,7 @@ def argument_has(self, arg: str, annotation: type) -> bool: return False def is_at_construct_arg(self, arg: str) -> bool: - return self.argument_has(arg, AtConstruct) + return self.argument_has(arg, AtRender) @property @lru_cache diff --git a/tests/test_annotations.py b/tests/test_annotations.py index cdf2a112..e9b209ac 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -3,22 +3,22 @@ from dewret.tasks import task, construct, subworkflow, TaskException from dewret.renderers.cwl import render -from dewret.annotations import AtConstruct, FunctionAnalyser +from dewret.annotations import AtRender, FunctionAnalyser from ._lib.extra import increment, sum -ARG1: AtConstruct[bool] = True +ARG1: AtRender[bool] = True ARG2: bool = False class MyClass: - def method(self, arg1: bool, arg2: AtConstruct[int]) -> float: + def method(self, arg1: bool, arg2: AtRender[int]) -> float: arg3: float = 7.0 - arg4: AtConstruct[float] = 8.0 + arg4: AtRender[float] = 8.0 return arg1 + arg2 + arg3 + arg4 + int(ARG1) + int(ARG2) -def fn(arg5: int, arg6: AtConstruct[int]) -> float: +def fn(arg5: int, arg6: AtRender[int]) -> float: arg7: float = 7.0 - arg8: AtConstruct[float] = 8.0 + arg8: AtRender[float] = 8.0 return arg5 + arg6 + arg7 + arg8 + int(ARG1) + int(ARG2) @@ -28,7 +28,7 @@ def to_int_bad(num: int, should_double: bool) -> int | float: return increment(num=num) if should_double else sum(left=num, right=num) @subworkflow() -def to_int(num: int, should_double: AtConstruct[bool]) -> int | float: +def to_int(num: int, should_double: AtRender[bool]) -> int | float: """Cast to an int.""" return increment(num=num) if should_double else sum(left=num, right=num) @@ -36,20 +36,20 @@ def test_can_analyze_annotations(): my_obj = MyClass() analyser = FunctionAnalyser(my_obj.method) - assert analyser.argument_has("arg1", AtConstruct) is False - assert analyser.argument_has("arg3", AtConstruct) is False - assert analyser.argument_has("ARG2", AtConstruct) is False - assert analyser.argument_has("arg2", AtConstruct) is True - assert analyser.argument_has("arg4", AtConstruct) is False # Not a global/argument - assert analyser.argument_has("ARG1", AtConstruct) is True + assert analyser.argument_has("arg1", AtRender) is False + assert analyser.argument_has("arg3", AtRender) is False + assert analyser.argument_has("ARG2", AtRender) is False + assert analyser.argument_has("arg2", AtRender) 
is True + assert analyser.argument_has("arg4", AtRender) is False # Not a global/argument + assert analyser.argument_has("ARG1", AtRender) is True analyser = FunctionAnalyser(fn) - assert analyser.argument_has("arg5", AtConstruct) is False - assert analyser.argument_has("arg7", AtConstruct) is False - assert analyser.argument_has("ARG2", AtConstruct) is False - assert analyser.argument_has("arg2", AtConstruct) is True - assert analyser.argument_has("arg8", AtConstruct) is False # Not a global/argument - assert analyser.argument_has("ARG1", AtConstruct) is True + assert analyser.argument_has("arg5", AtRender) is False + assert analyser.argument_has("arg7", AtRender) is False + assert analyser.argument_has("ARG2", AtRender) is False + assert analyser.argument_has("arg2", AtRender) is True + assert analyser.argument_has("arg8", AtRender) is False # Not a global/argument + assert analyser.argument_has("ARG1", AtRender) is True def test_at_construct() -> None: result = to_int_bad(num=increment(num=3), should_double=True) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index f656c386..b6b8cf06 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -4,7 +4,7 @@ from dewret.renderers.cwl import render from dewret.utils import hasher from dewret.tasks import set_configuration -from dewret.annotations import AtConstruct +from dewret.annotations import AtRender from ._lib.extra import increment, double, mod10, sum, triple_and_one @pytest.fixture @@ -13,7 +13,7 @@ def configuration(): yield configuration.get() @subworkflow() -def floor(num: int, expected: AtConstruct[bool]) -> int: +def floor(num: int, expected: AtRender[bool]) -> int: """Converts int/float to int.""" from dewret.tasks import get_configuration if get_configuration("flatten_all_nested") != expected: diff --git a/tests/test_errors.py b/tests/test_errors.py index 7208739f..aa0a2be5 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -3,7 +3,7 @@ import pytest from dewret.workflow import Task, Lazy from dewret.tasks import construct, task, nested_task, TaskException -from dewret.annotations import AtConstruct +from dewret.annotations import AtRender from ._lib.extra import increment # noqa: F401 @@ -96,7 +96,7 @@ def unacceptable_object_usage() -> int: @nested_task() -def unacceptable_nested_return(int_not_global: AtConstruct[bool]) -> int | Lazy: +def unacceptable_nested_return(int_not_global: AtRender[bool]) -> int | Lazy: """Bad nested_task that fails to return a task.""" add_task(left=3, right=4) return 7 if int_not_global else ADD_TASK_LINE_NO diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index 51725ace..4b6ae194 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -26,12 +26,12 @@ def test_fields_of_parameters_usable() -> None: class: Workflow cwlVersion: 1.2 inputs: - SIDES/left: - label: SIDES/left - type: Sides - SIDES/right: - label: SIDES/right - type: Sides + SIDES: + label: SIDES + type: record + items: + left: int + right: int outputs: out: label: out @@ -43,9 +43,11 @@ def test_fields_of_parameters_usable() -> None: sum-1-1: in: left: - source: SIDES/left + source: SIDES + valueFrom: $(self.left) right: - source: SIDES/right + source: SIDES + valueFrom: $(self.right) out: - out run: sum diff --git a/tests/test_nested.py b/tests/test_nested.py index 7cda5474..e402d4bc 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -2,12 +2,14 @@ from dewret.tasks import construct, task, factory from dewret.renderers.cwl import render 
+from ._lib.extra import pi + @task() def reverse_list(to_sort: list[int]) -> list[int]: return to_sort[::-1] def test_can_supply_nested_raw(): - result = reverse_list(to_sort=[1, 3, 5]) + result = reverse_list(to_sort=[1, 3, pi()]) workflow = construct(result, simplify_ids=True) rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" From 6cb2ca1d9411da9ec34fb1054ed1bf29b1c9e218 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 10 Aug 2024 19:53:06 +0100 Subject: [PATCH 019/108] feat(sympy): handle arithmetic/structure manipulation of references --- src/dewret/annotations.py | 2 - src/dewret/core.py | 73 +++++++++ src/dewret/renderers/cwl.py | 17 ++- src/dewret/tasks.py | 31 +--- src/dewret/utils.py | 42 +++++- src/dewret/workflow.py | 252 ++++++++++++++++---------------- tests/_lib/extra.py | 8 + tests/test_configuration.py | 8 +- tests/test_cwl.py | 92 ++++++------ tests/test_errors.py | 5 +- tests/test_fieldable.py | 4 + tests/test_modularity.py | 14 +- tests/test_multiresult_steps.py | 2 +- tests/test_nested.py | 49 ++++--- tests/test_subworkflows.py | 8 +- 15 files changed, 358 insertions(+), 249 deletions(-) create mode 100644 src/dewret/core.py diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index 3211b859..69996668 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -20,7 +20,6 @@ def __init__(self, fn: Callable[..., Any]): ) @property - @lru_cache def all_annotations(self): try: self._annotations = self.fn.__globals__["__annotations__"] @@ -53,7 +52,6 @@ def is_at_construct_arg(self, arg: str) -> bool: return self.argument_has(arg, AtRender) @property - @lru_cache def globals(self) -> Mapping[str, Any]: try: fn_globals = inspect.getclosurevars(self.fn).globals diff --git a/src/dewret/core.py b/src/dewret/core.py new file mode 100644 index 00000000..dfa80ec2 --- /dev/null +++ b/src/dewret/core.py @@ -0,0 +1,73 @@ +from typing import Generic, TypeVar, Protocol +from sympy import Expr, Symbol + +U = TypeVar("U") + +class WorkflowProtocol(Protocol): + ... + +class UnevaluatableError(Exception): + ... 
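+# (UnevaluatableError is raised when construction-time code tries to collapse
+# a reference into a concrete value, e.g. bool(ref) or int(ref) below: such
+# values only exist once the rendered workflow actually runs.)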
+
+
+class Reference(Generic[U], Symbol):
+    """Superclass for all symbolic references to values."""
+
+    _type: type[U] | None = None
+    __workflow__: WorkflowProtocol
+
+    def __init__(self, *args, typ: type[U] | None = None, **kwargs):
+        self._type = typ
+        super().__init__()
+        self.name = self.__root_name__
+
+    def __new__(cls, *args, **kwargs):
+        instance = Expr.__new__(cls)
+        instance._assumptions0 = {}
+        return instance
+
+    @property
+    def __root_name__(self) -> str:
+        raise NotImplementedError(
+            "Reference must have a '__root_name__' property or override '__name__'"
+        )
+
+    @property
+    def __type__(self):
+        if self._type is not None:
+            return self._type
+        raise NotImplementedError()
+
+    def _raise_unevaluatable_error(self):
+        raise UnevaluatableError(f"This reference, {self.__name__}, cannot be evaluated during construction.")
+
+    def __eq__(self, other) -> bool:
+        if isinstance(other, list) or other is None:
+            return False
+        if not isinstance(other, Reference):
+            self._raise_unevaluatable_error()
+        return super().__eq__(other)
+
+    def __float__(self) -> float:
+        self._raise_unevaluatable_error()
+        return 0.0
+
+    def __int__(self) -> int:
+        self._raise_unevaluatable_error()
+        return 0
+
+    def __bool__(self) -> bool:
+        self._raise_unevaluatable_error()
+        return False
+
+    @property
+    def __name__(self) -> str:
+        """Referral name for this reference."""
+        workflow = self.__workflow__
+        name = self.__root_name__
+        return workflow.remap(name)
+
+    def __str__(self) -> str:
+        """Global description of the reference."""
+        return self.__name__
diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py
index 3917f2ed..fe89cf2c 100644
--- a/src/dewret/renderers/cwl.py
+++ b/src/dewret/renderers/cwl.py
@@ -25,6 +25,7 @@
 from typing import TypedDict, NotRequired, get_origin, get_args, Union, cast, Any, Iterable, Unpack
 from types import UnionType
 from inspect import isclass
+from sympy import Expr, Basic, Tuple
 
 from dewret.workflow import (
     FactoryCall,
@@ -36,7 +37,7 @@
     ParameterReference,
     Unset,
 )
-from dewret.utils import RawType, flatten, DataclassProtocol, flatten_if_set
+from dewret.utils import RawType, flatten, DataclassProtocol, firm_to_raw, FirmType, flatten_if_set
 from dewret.render import base_render
 
 class CommandInputSchema(TypedDict):
@@ -176,7 +177,9 @@ def render(self) -> dict[str, RawType]:
             "in": {
                 key: (
                     ref.render()
-                    if isinstance(ref, ReferenceDefinition)
+                    if isinstance(ref, ReferenceDefinition) else
+                    {"expression": f"$({ref})"}
+                    if isinstance(ref, Basic) else
                     {"default": ref.value}
                 )
                 for key, ref in self.in_.items()
@@ -236,7 +239,7 @@ def to_cwl_type(label: str, typ: type) -> CommandInputSchema:
     elif base == bytes:
         typ_dict["type"] = "bytes"
     elif isinstance(typ, UnionType):
-        typ_dict.update({"type": [to_cwl_type(label, item)["type"] for item in args]})
+        typ_dict.update({"type": tuple(to_cwl_type(label, item)["type"] for item in args)})
     elif isclass(base) and issubclass(base, Iterable):
         try:
             if len(args) > 1:
@@ -355,7 +358,7 @@ def _raw_to_command_input_schema_internal(
     elif isinstance(value, list):
         typeset = set(get_args(value))
         if not typeset:
-            typeset = {type(item) for item in value}
+            typeset = {item.__type__ if item is not None and hasattr(item, "__type__") else type(item) for item in value}
         if len(typeset) != 1:
             raise RuntimeError(
                 "For CWL, an input array must have a consistent type, "
@@ -428,7 +431,7 @@ def render(self) -> dict[str, RawType]:
             item = {
                 # Would rather not cast, but CommandInputSchema is dict[RawType]
                # by construction, where
type is seen as a TypedDict subclass. - "type": cast(RawType, input.type), + "type": firm_to_raw(cast(FirmType, input.type)), "label": input.label, } if not isinstance(input.default, Unset): @@ -471,7 +474,7 @@ def from_results( "/".join(result.__field__) or "out", result.__type__, output_source=to_name(result) ) for result in results ] - if isinstance(results, list | tuple) else { + if isinstance(results, list | tuple | Tuple) else { key: to_output_schema( "/".join(result.__field__) or "out", result.__type__, output_source=to_name(result) ) @@ -541,7 +544,7 @@ def from_workflow( inputs=InputsDefinition.from_parameters(parameters), outputs=OutputsDefinition.from_results( workflow.result - if isinstance(workflow.result, list | tuple) else + if isinstance(workflow.result, list | tuple | Tuple) else {"/".join(workflow.result.__field__) or "out": workflow.result} if workflow.has_result else {} diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index cba3d85d..64ef2377 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -45,8 +45,10 @@ from contextvars import ContextVar, copy_context from contextlib import contextmanager -from .utils import is_raw, make_traceback +from .utils import is_raw, make_traceback, is_expr from .workflow import ( + expr_to_references, + unify_workflows, UNSET, Reference, StepReference, @@ -174,22 +176,7 @@ def evaluate(self, task: Lazy | list[Lazy] | tuple[Lazy], __workflow__: Workflow **kwargs: any arguments to pass to the task. """ result = self.backend.run(__workflow__, task, thread_pool=thread_pool, **kwargs) - to_check: list[StepReference] | tuple[StepReference] - if isinstance(result, list | tuple): - to_check = result - else: - to_check = [result] - - # Build a unified workflow - collected_workflow = __workflow__ or to_check[0].__workflow__ - for step_result in to_check: - new_workflow = step_result.__workflow__ - if collected_workflow != new_workflow and collected_workflow and new_workflow: - collected_workflow = Workflow.assimilate(collected_workflow, new_workflow) - - # Make sure all the results share it - for step_result in to_check: - step_result.__workflow__ = collected_workflow + result, collected_workflow = unify_workflows(result, __workflow__) # Then we set the result to be the whole thing collected_workflow.set_result(result) @@ -475,9 +462,10 @@ def add_numbers(left: int, right: int): sig = inspect.signature(fn) sig.bind(*args, **kwargs) + _, refs = expr_to_references(kwargs.values(), include_parameters=True) workflows = [ reference.__workflow__ - for reference in kwargs.values() + for reference in refs if hasattr(reference, "__workflow__") and reference.__workflow__ is not None ] @@ -511,8 +499,6 @@ def add_numbers(left: int, right: int): typ=analyser.all_annotations.get(var, UNSET) ), ) - elif isinstance(value, Parameter): - kwargs[var] = ParameterReference(workflow=workflow, parameter=value) original_kwargs = dict(kwargs) fn_globals = analyser.globals @@ -600,10 +586,7 @@ def {fn.__name__}(...) -> ...: step_reference = workflow.add_nested_step( fn.__name__, nested_workflow, kwargs ) - if isinstance(step_reference, StepReference) or ( - isinstance(step_reference, tuple | list) and - all(isinstance(elt, StepReference) for elt in step_reference) - ): + if is_expr(step_reference): return cast(RetType, step_reference) raise TypeError( f"Nested tasks must return a step reference, not {type(step_reference)} to ensure graph makes sense." 
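
Because references are now sympy Symbols (core.py above), arithmetic over lazy
results survives evaluation as a symbolic expression whose free symbols are the
underlying step references. A minimal sketch of the intended round trip,
mirroring the test_nested.py change in this patch (helpers as defined in
tests/_lib/extra.py):

    import math
    from dewret.tasks import construct
    from dewret.workflow import param
    from tests._lib.extra import reverse_list, max_list

    pi = param("pi", math.pi)
    result = reverse_list(to_sort=[1.0, 3.0, pi])
    workflow = construct(max_list(lst=result + result), simplify_ids=True)
    # The doubled result reaches the renderer as the sympy expression
    # 2*reverse_list-<hash>, which StepDefinition.render (cwl.py above)
    # emits as {"expression": "$(...)"}.
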
diff --git a/src/dewret/utils.py b/src/dewret/utils.py index bcad8871..0979ac7b 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -21,8 +21,11 @@ import json import sys from types import FrameType, TracebackType -from typing import Any, cast, Union, Protocol, ClassVar +from typing import Any, cast, Union, Protocol, ClassVar, Callable, Iterable from collections.abc import Sequence, Mapping +from sympy import Basic, Integer, Float, Rational + +from .core import Reference BasicType = str | float | bool | bytes | int | None RawType = Union[BasicType, list["RawType"], dict[str, "RawType"]] @@ -73,7 +76,7 @@ def flatten_if_set(value: Any) -> RawType | Unset: return value return flatten(value) -def flatten(value: Any) -> RawType: +def crawl_raw(value: Any, action: Callable[[Any], Any]) -> RawType: """Takes a Raw-like structure and makes it RawType. Particularly useful for squashing any TypedDicts. @@ -81,6 +84,9 @@ def flatten(value: Any) -> RawType: Args: value: value to squash """ + + value = action(value) + if value is None: return value if isinstance(value, str) or isinstance(value, bytes): @@ -93,13 +99,21 @@ def flatten(value: Any) -> RawType: return raw raise RuntimeError(f"Could not flatten: {value}") +def firm_to_raw(value: FirmType) -> RawType: + return crawl_raw(value, lambda entry: list(entry) if isinstance(entry, tuple) else entry) + +def flatten(value: Any) -> RawType: + return crawl_raw(value, lambda entry: entry) + +def is_expr(value: Any) -> bool: + return is_raw(value, lambda x: isinstance(x, Basic)) def is_raw_type(typ: type) -> bool: """Check if a type counts as "raw".""" return issubclass(typ, str | float | bool | bytes | int | None | list | dict) -def is_raw(value: Any) -> bool: +def is_raw(value: Any, check: Callable[[Any], bool] | None = None) -> bool: """Check if a variable counts as "raw". This works around a checking issue that isinstance of a union of types @@ -109,10 +123,26 @@ def is_raw(value: Any) -> bool: # Ideally this would be: # isinstance(value, RawType | list[RawType] | dict[str, RawType]) # but recursive types are problematic. - return isinstance(value, str | float | bool | bytes | int | None | list | dict) + if isinstance(value, str | float | bool | bytes | int | None | Integer | Float | Rational): + return True + + if isinstance(value, Mapping): + return ( + (isinstance(value, dict) or (check is not None and check(value))) and + all(is_raw(key, check) for key in value.keys()) and + all(is_raw(val, check) for val in value.values()) + ) + + if isinstance(value, Iterable): + return ( + (isinstance(value, list) or (check is not None and check(value))) and + all(is_raw(key, check) for key in value) + ) + + return check is not None and check(value) -def ensure_raw(value: Any) -> RawType | None: +def ensure_raw(value: Any, cast_tuple: bool = False) -> RawType | None: """Check if a variable counts as "raw". This works around a checking issue that isinstance of a union of types @@ -138,7 +168,7 @@ def hasher(construct: FirmType) -> str: have not yet been explicitly calculated. 
""" if isinstance(construct, Sequence) and not isinstance(construct, bytes | str): - if isinstance(construct, dict): + if isinstance(construct, Mapping): construct = list([k, hasher(v)] for k, v in sorted(construct.items())) else: # Cast to workaround recursive type diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index aea06ef6..48c8b4d3 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -23,24 +23,47 @@ import base64 from attrs import define, has as attr_has, resolve_types, fields as attrs_fields from dataclasses import is_dataclass, fields as dataclass_fields -from collections import Counter -from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated +from collections import Counter, OrderedDict +from types import GeneratorType +from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Iterable from uuid import uuid4 +from sympy import Symbol, Expr, Basic, sympify, Tuple import logging logger = logging.getLogger(__name__) -from .utils import hasher, RawType, is_raw, make_traceback, is_raw_type, Unset +from .core import Reference +from .utils import hasher, RawType, is_raw, make_traceback, is_raw_type, is_expr, Unset T = TypeVar("T") U = TypeVar("U") RetType = TypeVar("RetType") -class UnevaluatableError(Exception): - ... +def all_references_from(value: Any): + all_references: set = set() + + # If Raw, we examine the internal value. + # In theory, this should not contain a reference, + # but this makes all_references_from useful for error-checking. + if isinstance(value, Raw): + value = value.value + + if isinstance(value, Reference): + all_references.add(value) + elif isinstance(value, Basic): + symbols = value.free_symbols + if not all(isinstance(sym, Reference) for sym in symbols): + raise RuntimeError("Can only use symbols that are references to e.g. step or parameter.") + all_references |= symbols + elif isinstance(value, Mapping): + all_references |= all_references_from(value.keys()) + all_references |= all_references_from(value.values()) + elif isinstance(value, Iterable) and not isinstance(value, str | bytes): + all_references |= set().union(*(all_references_from(entry) for entry in value)) + return all_references @define class Raw: @@ -133,7 +156,7 @@ def __init__(self, raw_type: type[T]): -class Parameter(Generic[T]): +class Parameter(Generic[T], Symbol): """Global parameter. Independent parameter that will be used when a task is spotted @@ -171,8 +194,8 @@ def __init__( self.__original_name__ = name # TODO: is using this in a step hash a risk of ambiguity? 
(full name is circular) - if autoname: - name = f"{name}-{uuid4()}" + #if autoname: + # name = f"{name}-{uuid4()}" self.autoname = autoname self.__name__ = name @@ -189,21 +212,25 @@ def __init__( else: raw_type = type(default) self.__type__: type[T] = raw_type - if self.__type__ == type: - asdffdsa - if tethered and isinstance(tethered, BaseStep): self.register_caller(tethered) def __eq__(self, other): + if isinstance(other, ParameterReference) and other._.parameter == self and not other.__field__: + return True return hash(self) == hash(other) + def __new__(cls, *args, **kwargs): + instance = Expr.__new__(cls) + instance._assumptions0 = {} + return instance + def __hash__(self) -> int: """Get a unique hash for this parameter.""" - if self.__tethered__ is None: - raise RuntimeError( - f"Parameter {self.name} was never tethered but should have been" - ) + # if self.__tethered__ is None: + # raise RuntimeError( + # f"Parameter {self.name} was never tethered but should have been" + # ) return hash(self.__name__) @property @@ -346,9 +373,12 @@ def __str__(self) -> str: return super().__str__() return self.name + def __repr__(self) -> str: + return self.name + def __hash__(self) -> int: """Hashes for finding.""" - return hash(self.name) + return hash(repr(self)) def __eq__(self, other: object) -> bool: """Is this the same workflow? @@ -396,19 +426,10 @@ def find_parameters( Returns: Set of all references to parameters across the steps. """ - return set().union( - *( - { - arg - for arg in step.arguments.values() - if ( - isinstance(arg, ParameterReference) - and (include_factory_calls or not isinstance(step, FactoryCall)) - ) - } - for step in self.steps - ) + references = all_references_from( + step.arguments for step in self.steps if (include_factory_calls or not isinstance(step, FactoryCall)) ) + return {ref for ref in references if isinstance(ref, ParameterReference)} @property def _indexed_steps(self) -> dict[str, BaseStep]: @@ -421,7 +442,7 @@ def _indexed_steps(self) -> dict[str, BaseStep]: Returns: Mapping of steps by ID. """ - return {step.id: step for step in self.steps} + return OrderedDict(sorted((step.id, step) for step in self.steps)) @classmethod def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": @@ -446,9 +467,8 @@ def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": for step in list(left_steps.values()) + list(right_steps.values()): step.set_workflow(new) - for arg in step.arguments: - if hasattr(arg, "__workflow__"): - arg.__workflow__ = new + for arg in step.arguments.values(): + unify_workflows(arg, new, set_only=True) for step_id in left_steps.keys() & right_steps.keys(): if left_steps[step_id] != right_steps[step_id]: @@ -483,11 +503,7 @@ def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": result = list(left.result) + list(right.result) if result is not None and result != []: - if isinstance(result, list | tuple): - for entry in result: - entry.__workflow__ = new - else: - result.__workflow__ = new + unify_workflows(result, new, set_only=True) new.set_result(result) return new @@ -611,14 +627,13 @@ def from_result( Starts from a result, and builds a workflow to output it. """ - if isinstance(result, list | tuple): - workflow = result[0].__workflow__ - # Ensure that we have exactly one workflow, even if multiple results. 
- for entry in result[1:]: - if entry.__workflow__ != workflow: - raise RuntimeError("If multiple results, they must share a single workflow") - else: - workflow = result.__workflow__ + result, refs = expr_to_references(result) + refs = list(refs) + workflow = refs[0].__workflow__ + # Ensure that we have exactly one workflow, even if multiple results. + for entry in refs[1:]: + if entry.__workflow__ != workflow: + raise RuntimeError("If multiple results, they must share a single workflow") workflow.set_result(result) if simplify_ids: workflow.simplify_ids() @@ -636,11 +651,8 @@ def set_result(self, result: StepReference[Any] | list[StepReference[Any]] | tup Args: result: reference to the chosen step. """ - if isinstance(result, list | tuple): - to_check = result - else: - to_check = [result] - for entry in to_check: + _, refs = expr_to_references(result) + for entry in refs: if entry._.step.__workflow__ != self: raise RuntimeError("Output must be from a step in this workflow.") self.result = result @@ -649,10 +661,10 @@ def set_result(self, result: StepReference[Any] | list[StepReference[Any]] | tup def result_type(self): if self.result is None: return type(None) - if isinstance(self.result, tuple | list): - # TODO: get individual types! - return type(self.result) - return self.result.__type__ + if hasattr(self.result, "__type__"): + return self.result.__type__ + # TODO: get individual types! + return type(self.result) class WorkflowComponent: @@ -692,64 +704,6 @@ def __workflow__(self) -> Workflow: ... -class Reference(Generic[U]): - """Superclass for all symbolic references to values.""" - - _type: type[U] | None = None - __workflow__: Workflow - - def __init__(self, *args, typ: type[U] | None = None, **kwargs): - self._type = typ - if typ == type: - asdf - super().__init__() - - @property - def __root_name__(self) -> str: - raise NotImplementedError( - "Reference must have a '__root_name__' property or override '__name__'" - ) - - @property - def __type__(self): - if self._type is not None: - return self._type - raise NotImplementedError() - - def _raise_unevaluatable_error(self): - raise UnevaluatableError(f"This reference, {self.__name__}, cannot be evaluated during construction.") - - def __eq__(self, other) -> bool: - if isinstance(other, list) or other is None: - return False - if not isinstance(other, Reference): - self._raise_unevaluatable_error() - return super().__eq__(other) - - def __float__(self) -> bool: - self._raise_unevaluatable_error() - return False - - def __int__(self) -> bool: - self._raise_unevaluatable_error() - return False - - def __bool__(self) -> bool: - self._raise_unevaluatable_error() - return False - - @property - def __name__(self) -> str: - """Referral name for this reference.""" - workflow = self.__workflow__ - name = self.__root_name__ - return workflow.remap(name) - - def __str__(self) -> str: - """Global description of the reference.""" - return self.__name__ - - class FieldableProtocol(Protocol): __field__: tuple[str, ...] 
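
With set_result now routed through expr_to_references (defined at the end of
this file, below), a workflow result may be any expression over step
references, so long as every referenced step belongs to this workflow. A
hedged sketch, with `step_ref` standing in for a reference produced by a step
of `workflow` (hypothetical name):

    workflow.set_result(2 * step_ref + 1)  # accepted: the step is in `workflow`
    # A reference owned by another workflow fails the ownership check with
    # RuntimeError: "Output must be from a step in this workflow."

result_type then reports the result's declared __type__ when it carries one,
falling back to the plain Python type of the expression.
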
@@ -855,6 +809,7 @@ def __init__( or isinstance(value, Reference) or isinstance(value, Raw) or is_raw(value) + or is_expr(value) ): # Avoid recursive type issues if ( @@ -869,9 +824,21 @@ def __init__( ) else: value = Raw(value) - if isinstance(value, ParameterReference): - parameter = value._.parameter - parameter.register_caller(self) + + expr, refs = expr_to_references(value, include_parameters=True) + if expr is not None: + for ref in set(refs): + if isinstance(ref, Parameter): + new_ref = ParameterReference(workflow=workflow, parameter=ref) + expr = expr.subs(ref, new_ref) + refs.remove(ref) + refs.append(new_ref) + value = expr + + for ref in refs: + if isinstance(ref, ParameterReference): + parameter = ref._.parameter + parameter.register_caller(self) self.arguments[key] = value else: raise RuntimeError( @@ -906,11 +873,7 @@ def set_workflow(self, workflow: Workflow, with_arguments: bool = True) -> None: self.__workflow__ = workflow if with_arguments: for argument in self.arguments.values(): - if hasattr(argument, "__workflow__"): - try: - argument.__workflow__ = workflow - except AttributeError: - ... + unify_workflows(argument, workflow, set_only=True) @property def return_type(self) -> Any: @@ -1126,7 +1089,9 @@ def __getattr__(self, attr: str) -> "ParameterReference": workflow=self.__workflow__, parameter=self._.parameter ) - except AttributeError as _: + except AttributeError as exc: + if not "dask_graph" in str(exc): + raise return super().__getattribute__(attr) def __getitem__(self, attr: str) -> "ParameterReference": @@ -1158,8 +1123,10 @@ def __eq__(self, other: object) -> bool: Returns: True if the other parameter reference is materially the same, otherwise False. """ + # We are equal to a parameter if we are a direct, fieldless, reference to it. return ( - isinstance(other, ParameterReference) and self._.parameter == other._.parameter and self.__field__ == other.__field__ + (isinstance(other, Parameter) and self._.parameter == other and not self.__field__) or + (isinstance(other, ParameterReference) and self._.parameter == other._.parameter and self.__field__ == other.__field__) ) @@ -1219,6 +1186,9 @@ def __repr__(self) -> str: """Hashable reference to the step (and field).""" return "/".join([self._.step.id] + list(self.__field__)) + def __hash__(self) -> int: + return hash((repr(self), id(self.__workflow__))) + def __getattr__(self, attr: str) -> "StepReference[Any]": """Reference to a field within this result, if possible. @@ -1305,3 +1275,41 @@ def is_task(task: Lazy) -> bool: True if `task` is indeed a task. """ return isinstance(task, LazyEvaluation) + +def expr_to_references(expression: Any, include_parameters: bool=False) -> tuple[Basic | None, set[Reference | Parameter]]: + if isinstance(expression, Raw) or is_raw(expression): + return expression, set() + + def _to_expr(value): + if not isinstance(value, str | bytes) and isinstance(value, Iterable): + return Tuple(*(_to_expr(entry) for entry in value)) + return sympify(value) + + if not isinstance(expression, Basic): + expression = _to_expr(expression) + + symbols = list(expression.free_symbols) + to_check = [sym for sym in symbols if isinstance(sym, Reference) or (include_parameters and isinstance(sym, Parameter))] + if {sym for sym in symbols if not is_raw(sym)} != set(to_check): + raise RuntimeError("The only symbols allowed are references (to e.g. 
step or parameter)") + return expression, to_check + +def unify_workflows(expression: Any, base_workflow: Workflow | None, set_only: bool = False) -> Workflow | None: + expression, to_check = expr_to_references(expression) + if not to_check: + return expression, base_workflow + + # Build a unified workflow + collected_workflow = base_workflow or next(iter(to_check)).__workflow__ + if not set_only: + for step_result in to_check: + new_workflow = step_result.__workflow__ + if collected_workflow != new_workflow and collected_workflow and new_workflow: + collected_workflow = Workflow.assimilate(collected_workflow, new_workflow) + + # Make sure all the results share it + for step_result in to_check: + step_result.__workflow__ = collected_workflow + expression = expression.subs(step_result, step_result) + + return expression, collected_workflow diff --git a/tests/_lib/extra.py b/tests/_lib/extra.py index c883afef..ac18c8a0 100644 --- a/tests/_lib/extra.py +++ b/tests/_lib/extra.py @@ -51,3 +51,11 @@ def triple_and_one(num: int | float) -> int | float: def tuple_float_return() -> tuple[float, float]: """Return a tuple of floats.""" return 48.856667, 2.351667 + +@task() +def reverse_list(to_sort: list[int | float]) -> list[int | float]: + return to_sort[::-1] + +@task() +def max_list(lst: list[int | float]) -> int | float: + return max(lst) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index b6b8cf06..2d99d5ba 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -47,19 +47,19 @@ def test_cwl_with_parameter(configuration) -> None: outputs: out: label: out - outputSource: increment-2/out + outputSource: increment-1/out type: int steps: - increment-1: + increment-2: run: increment in: num: source: num out: [out] - increment-2: + increment-1: run: increment in: num: - source: increment-1/out + source: increment-2/out out: [out] """) diff --git a/tests/test_cwl.py b/tests/test_cwl.py index cf795598..3c69f39a 100644 --- a/tests/test_cwl.py +++ b/tests/test_cwl.py @@ -139,7 +139,7 @@ def test_cwl_with_parameter() -> None: """Check whether we can move raw input to parameters. Produces CWL for a call with a changeable raw value, that is converted - to a parameter, if and only if we are calling from outside a nested task. + to a parameter, if and only if we are calling from outside a subworkflow. """ result = increment(num=3) workflow = construct(result) @@ -170,38 +170,38 @@ def test_cwl_with_parameter() -> None: """) -def test_cwl_without_default() -> None: - """Check whether we can produce CWL without a default value. - - Uses a manually created parameter to avoid a default. - """ - my_param = param("my_param", typ=int) - - result = increment(num=my_param) - workflow = construct(result) - rendered = render(workflow)["__root__"] - hsh = hasher(("increment", ("num", "int|:param:my_param"))) - - assert rendered == yaml.safe_load(f""" - cwlVersion: 1.2 - class: Workflow - inputs: - my_param: - label: my_param - type: int - outputs: - out: - label: out - outputSource: increment-{hsh}/out - type: int - steps: - increment-{hsh}: - run: increment - in: - num: - source: my_param - out: [out] - """) +#def test_cwl_without_default() -> None: +# """Check whether we can produce CWL without a default value. +# +# Uses a manually created parameter to avoid a default. 
+# """ +# my_param = param("my_param", typ=int) +# +# result = increment(num=my_param) +# workflow = construct(result) +# rendered = render(workflow)["__root__"] +# hsh = hasher(("increment", ("num", "int|:param:my_param"))) +# +# assert rendered == yaml.safe_load(f""" +# cwlVersion: 1.2 +# class: Workflow +# inputs: +# my_param: +# label: my_param +# type: int +# outputs: +# out: +# label: out +# outputSource: increment-{hsh}/out +# type: int +# steps: +# increment-{hsh}: +# run: increment +# in: +# num: +# source: my_param +# out: [out] +# """) def test_cwl_with_subworkflow() -> None: @@ -365,10 +365,6 @@ def test_complex_cwl_references() -> None: label: increment-1-num type: int default: 23 - increment-2-num: - label: increment-2-num - type: int - default: 23 outputs: out: label: out @@ -383,17 +379,11 @@ def test_complex_cwl_references() -> None: num: source: increment-1-num out: [out] - increment-2: - run: increment - in: - num: - source: increment-2-num - out: [out] double-1: run: double in: num: - source: increment-2/out + source: increment-1/out out: [out] mod10-1: run: mod10 @@ -480,14 +470,14 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: int, float ] - sum-1-2-right: + sum-1-1-right: default: 1 - label: sum-1-2-right + label: sum-1-1-right type: int outputs: out: label: out - outputSource: sum-1-2/out + outputSource: sum-1-1/out type: - int - float @@ -499,7 +489,7 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: out: - out run: double - sum-1-1: + sum-1-2: in: left: source: double-1-1/out @@ -508,12 +498,12 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: out: - out run: sum - sum-1-2: + sum-1-1: in: left: - source: sum-1-1/out + source: sum-1-2/out right: - source: sum-1-2-right + source: sum-1-1-right out: - out run: sum diff --git a/tests/test_errors.py b/tests/test_errors.py index aa0a2be5..8336867f 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -4,7 +4,8 @@ from dewret.workflow import Task, Lazy from dewret.tasks import construct, task, nested_task, TaskException from dewret.annotations import AtRender -from ._lib.extra import increment # noqa: F401 +from dewret.renderers.cwl import render +from ._lib.extra import increment, pi, reverse_list # noqa: F401 @task() # This is expected to be the line number shown below. 
@@ -13,7 +14,7 @@ def add_task(left: int, right: int) -> int: return left + right -ADD_TASK_LINE_NO = 10 +ADD_TASK_LINE_NO = 11 @nested_task() diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index 4b6ae194..a86e04ee 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -1,5 +1,8 @@ import yaml from dataclasses import dataclass + +import pytest + from dewret.tasks import task, construct, subworkflow from dewret.workflow import param from dewret.renderers.cwl import render @@ -17,6 +20,7 @@ class Sides: def sum_sides(): return sum(left=SIDES.left, right=SIDES.right) +@pytest.mark.skip(reason="Need expression support") def test_fields_of_parameters_usable() -> None: result = sum_sides() workflow = construct(result, simplify_ids=True) diff --git a/tests/test_modularity.py b/tests/test_modularity.py index 03119087..330c9b78 100644 --- a/tests/test_modularity.py +++ b/tests/test_modularity.py @@ -37,9 +37,9 @@ def test_nested_task() -> None: default: 23 label: STARTING_NUMBER type: int - increase-1-num: + increase-3-num: default: 17 - label: increase-1-num + label: increase-3-num type: int outputs: out: @@ -47,13 +47,13 @@ def test_nested_task() -> None: outputSource: sum-1/out type: [int, float] steps: - increase-1: + increase-3: run: increase in: JUMP: source: JUMP num: - source: increase-1-num + source: increase-3-num out: [out] increase-2: run: increase @@ -61,9 +61,9 @@ def test_nested_task() -> None: JUMP: source: JUMP num: - source: increase-1/out + source: increase-3/out out: [out] - increase-3: + increase-1: run: increase in: JUMP: @@ -75,7 +75,7 @@ def test_nested_task() -> None: run: double in: num: - source: increase-3/out + source: increase-1/out out: [out] sum-1: run: sum diff --git a/tests/test_multiresult_steps.py b/tests/test_multiresult_steps.py index 2270ffab..b8b6efb6 100644 --- a/tests/test_multiresult_steps.py +++ b/tests/test_multiresult_steps.py @@ -256,7 +256,7 @@ def test_pair_can_be_returned_from_step() -> None: out: label: out outputSource: pair-1/out - items: + items: - int - float type: array diff --git a/tests/test_nested.py b/tests/test_nested.py index e402d4bc..fb2d3c8e 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -1,40 +1,51 @@ import yaml +import pytest +import math +from dewret.workflow import param from dewret.tasks import construct, task, factory from dewret.renderers.cwl import render -from ._lib.extra import pi - -@task() -def reverse_list(to_sort: list[int]) -> list[int]: - return to_sort[::-1] +from ._lib.extra import reverse_list, max_list def test_can_supply_nested_raw(): - result = reverse_list(to_sort=[1, 3, pi()]) - workflow = construct(result, simplify_ids=True) + pi = param("pi", math.pi) + result = reverse_list(to_sort=[1., 3., pi]) + workflow = construct(max_list(lst=result + result), simplify_ids=True) + #assert workflow.find_parameters() == { + # pi + #} + + # NB: This is not currently usefully renderable in CWL. + # However, the structures are important for future CWL rendering. 
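+    # Concretely, step inputs in the expected output below come out as opaque
+    # sympy expression strings (e.g. $(2*reverse_list-<hash>)) rather than
+    # CWL-valid sources.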
+ rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow cwlVersion: 1.2 inputs: - reverse_list-1-to_sort: - default: [1, 3, 5] - label: reverse_list-1-to_sort - type: - type: - items: int - label: reverse_list-1-to_sort - type: array + pi: + default: 3.141592653589793 + label: pi + type: double outputs: out: - items: int label: out - outputSource: reverse_list-1/out - type: array + outputSource: max_list-1/out + type: + - int + - double steps: + max_list-1: + in: + lst: + expression: $(2*reverse_list-94ebd058f53d6a235643d33f3ab4c313) + out: + - out + run: max_list reverse_list-1: in: to_sort: - source: reverse_list-1-to_sort + expression: $((1.0, 3.0, pi)) out: - out run: reverse_list diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index b83f5b0b..a88e3fbd 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -459,16 +459,16 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None: outputSource: increment-2/out type: int steps: - increment-1: + increment-2: in: num: - source: num + source: add_constants-1/out out: [out] run: increment - increment-2: + increment-1: in: num: - source: add_constants-1/out + source: num out: [out] run: increment add_constants-1: From 0f846e7c056b8d33523750674cf60f1b4d656d9d Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 10 Aug 2024 23:10:07 +0100 Subject: [PATCH 020/108] feat(sympy): more flexible parameter support --- src/dewret/annotations.py | 4 ++ src/dewret/backends/backend_dask.py | 35 ++++++++----- src/dewret/renderers/cwl.py | 23 +++++++-- src/dewret/tasks.py | 7 ++- src/dewret/workflow.py | 80 ++++++++++++++++++++++------- tests/test_annotations.py | 4 +- tests/test_fieldable.py | 40 ++++++++++++++- tests/test_modularity.py | 14 ++--- tests/test_subworkflows.py | 17 +++--- 9 files changed, 169 insertions(+), 55 deletions(-) diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index 69996668..d0c69e4f 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -30,6 +30,10 @@ def all_annotations(self): return self._annotations + @property + def return_type(self): + return inspect.signature(inspect.unwrap(self.fn)).return_annotation + @staticmethod def _typ_has(typ: type, annotation: type) -> bool: if not hasattr(annotation, "__metadata__"): diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index e5df1957..74e8a8b3 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -104,19 +104,28 @@ def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread task: `dask.delayed` function, wrapped by dewret, that we wish to compute. """ - def _check_delayed(task: Lazy | list[Lazy] | tuple[Lazy]) -> Delayed: - # We need isinstance to reassure type-checker. - if isinstance(task, list) or isinstance(task, tuple): - lst: list[Delayed] | tuple[Delayed, ...] = [_check_delayed(elt) for elt in task] - if isinstance(task, tuple): - lst = tuple(lst) - return delayed(lst) - elif not isinstance(task, Delayed) or not is_lazy(task): - raise RuntimeError( - f"{task} is not a dask delayed, perhaps you tried to mix backends?" - ) - return task - computable = _check_delayed(task) + # def _check_delayed(task: Lazy | list[Lazy] | tuple[Lazy]) -> Delayed: + # # We need isinstance to reassure type-checker. + # if isinstance(task, list) or isinstance(task, tuple): + # lst: list[Delayed] | tuple[Delayed, ...] 
= [_check_delayed(elt) for elt in task] + # if isinstance(task, tuple): + # lst = tuple(lst) + # return delayed(lst) + # elif not isinstance(task, Delayed) or not is_lazy(task): + # raise RuntimeError( + # f"{task} is not a dask delayed, perhaps you tried to mix backends?" + # ) + # return task + # computable = _check_delayed(task) + # if not is_lazy(task): + # raise RuntimeError( + # f"{task} is not a dask delayed, perhaps you tried to mix backends?" + # ) + + if isinstance(task, Delayed): + computable = task + else: + computable = delayed(task) config["pool"] = thread_pool result = computable.compute(__workflow__=workflow) return result diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index fe89cf2c..0b44bb04 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -92,6 +92,17 @@ def configuration(key: str) -> Any: return current_configuration.get(key) +def with_type(result: Any) -> type: + if hasattr(result, "__type__"): + return result.__type__ + return type(result) + +def with_field(result: Any) -> str: + if hasattr(result, "__field__"): + return "/".join(result.__field__) or "out" + else: + return "out" + @define class ReferenceDefinition: """CWL-renderable internal reference. @@ -179,8 +190,10 @@ def render(self) -> dict[str, RawType]: ref.render() if isinstance(ref, ReferenceDefinition) else {"expression": f"$({ref})"} - if isinstance(ref, Basic) - else {"default": ref.value} + if isinstance(ref, Basic) else + {"default": ref.value} + if hasattr(ref, "value") + else ref ) for key, ref in self.in_.items() }, @@ -471,12 +484,12 @@ def from_results( return cls( outputs=[ to_output_schema( - "/".join(result.__field__) or "out", result.__type__, output_source=to_name(result) + with_field(result), result.__type__, output_source=to_name(result) ) for result in results ] if isinstance(results, list | tuple | Tuple) else { key: to_output_schema( - "/".join(result.__field__) or "out", result.__type__, output_source=to_name(result) + with_field(result), result.__type__, output_source=to_name(result) ) for key, result in results.items() } @@ -545,7 +558,7 @@ def from_workflow( outputs=OutputsDefinition.from_results( workflow.result if isinstance(workflow.result, list | tuple | Tuple) else - {"/".join(workflow.result.__field__) or "out": workflow.result} + {with_field(workflow.result): workflow.result} if workflow.has_result else {} ), diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 64ef2377..de40afc6 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -462,7 +462,10 @@ def add_numbers(left: int, right: int): sig = inspect.signature(fn) sig.bind(*args, **kwargs) - _, refs = expr_to_references(kwargs.values(), include_parameters=True) + refs = [] + for key, val in kwargs.items(): + _, kw_refs = expr_to_references(val, include_parameters=True) + refs += kw_refs workflows = [ reference.__workflow__ for reference in refs @@ -584,7 +587,7 @@ def {fn.__name__}(...) 
-> ...: output = analyser.with_new_globals(nested_globals)(**nested_kwargs) nested_workflow = _manager(output, __workflow__=nested_workflow) step_reference = workflow.add_nested_step( - fn.__name__, nested_workflow, kwargs + fn.__name__, nested_workflow, analyser.return_type, kwargs ) if is_expr(step_reference): return cast(RetType, step_reference) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 48c8b4d3..969f2cdc 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -27,7 +27,7 @@ from types import GeneratorType from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Iterable from uuid import uuid4 -from sympy import Symbol, Expr, Basic, sympify, Tuple +from sympy import Symbol, Expr, Basic, Tuple, Dict import logging @@ -173,7 +173,7 @@ class Parameter(Generic[T], Symbol): __name__: str __default__: T | UnsetType[T] __tethered__: Literal[False] | None | BaseStep | Workflow - + __fixed_type__: type[T] | Unset autoname: bool = False def __init__( @@ -182,6 +182,7 @@ def __init__( default: T | UnsetType[T], tethered: Literal[False] | None | Step | Workflow = None, autoname: bool = False, + typ: type[T] | Unset = UNSET ): """Construct a parameter. @@ -202,7 +203,17 @@ def __init__( self.__default__ = default self.__tethered__ = tethered self.__callers__: list[BaseStep] = [] + self.__fixed_type__ = typ + + if tethered and isinstance(tethered, BaseStep): + self.register_caller(tethered) + @property + def __type__(self): + if self.__fixed_type__ is not UNSET: + return self.__fixed_type__ + + default = self.__default__ if ( default is not None and hasattr(default, "__type__") @@ -211,9 +222,7 @@ def __init__( raw_type = default.__type__ else: raw_type = type(default) - self.__type__: type[T] = raw_type - if tethered and isinstance(tethered, BaseStep): - self.register_caller(tethered) + return raw_type def __eq__(self, other): if isinstance(other, ParameterReference) and other._.parameter == self and not other.__field__: @@ -282,7 +291,7 @@ def param( raise ValueError("Must provide a default or a type") default = UnsetType[T](typ) return cast( - T, Parameter(name, default=default, tethered=tethered, autoname=autoname) + T, Parameter(name, default=default, tethered=tethered, autoname=autoname, typ=typ) ) @@ -570,7 +579,7 @@ def register_task(self, fn: Lazy) -> Task: return task def add_nested_step( - self, name: str, subworkflow: Workflow, kwargs: dict[str, Any] + self, name: str, subworkflow: Workflow, return_type: type | None, kwargs: dict[str, Any] ) -> StepReference[Any]: """Append a nested step. 
@@ -583,10 +592,10 @@ def add_nested_step( """ step = NestedStep(self, name, subworkflow, kwargs) self.steps.append(step) - return_type = step.return_type + return_type = return_type or step.return_type if return_type is inspect._empty: raise TypeError("All tasks should have a type annotation.") - return StepReference(step, return_type) + return StepReference(step, typ=return_type) def add_step( self, @@ -653,7 +662,7 @@ def set_result(self, result: StepReference[Any] | list[StepReference[Any]] | tup """ _, refs = expr_to_references(result) for entry in refs: - if entry._.step.__workflow__ != self: + if entry.__workflow__ != self: raise RuntimeError("Output must be from a step in this workflow.") self.result = result @@ -755,6 +764,12 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None field_type = getattr(attrs_fields(parent_type), field).type except AttributeError: raise AttributeError(f"attrs-class {parent_type} does not have field {field}") + # TypedDict + elif inspect.isclass(parent_type) and issubclass(parent_type, dict) and hasattr(parent_type, "__annotations__"): + try: + field_type = parent_type.__annotations__[field] + except KeyError: + raise AttributeError(f"TypedDict {parent_type} does not have field {field}") if field_type: if not issubclass(self.__class__, Reference): @@ -810,6 +825,8 @@ def __init__( or isinstance(value, Raw) or is_raw(value) or is_expr(value) + or is_dataclass(value) + or attr_has(value) ): # Avoid recursive type issues if ( @@ -886,7 +903,7 @@ def return_type(self) -> Any: Expected type of the return value. """ if isinstance(self.task, Workflow): - if self.task.result: + if self.task.result is not None: return self.task.result_type else: raise AttributeError( @@ -977,6 +994,7 @@ def return_type(self) -> Any: Returns: Expected type of the return value. 
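+            The return annotation declared on the task (resolved via the
+            parent class) is preferred over inspecting the subworkflow's
+            result.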
""" + return super().return_type if self.__subworkflow__.result is None or self.__subworkflow__.result is []: raise RuntimeError("Can only use a subworkflow if the reference exists.") return self.__subworkflow__.result_type @@ -1180,7 +1198,7 @@ def __init__( def __str__(self) -> str: """Global description of the reference.""" - return "/".join([self._.step.id] + list(self.__field__)) + return self.__name__ def __repr__(self) -> str: """Hashable reference to the step (and field).""" @@ -1209,8 +1227,14 @@ def __getattr__(self, attr: str) -> "StepReference[Any]": return self.find_field( workflow=self.__workflow__, step=self._.step, field=attr ) - except AttributeError as _: - return super().__getattribute__(attr) + except AttributeError as exc: + try: + return super().__getattribute__(attr) + except AttributeError as inner_exc: + raise inner_exc from exc + + def __getitem__(self, attr: str) -> "StepReference": + return getattr(self, attr) @property def __type__(self) -> type: @@ -1280,18 +1304,37 @@ def expr_to_references(expression: Any, include_parameters: bool=False) -> tuple if isinstance(expression, Raw) or is_raw(expression): return expression, set() + if isinstance(expression, Reference): + return expression, {expression} + + if is_dataclass(expression) or attr_has(expression): + refs = set() + fields = dataclass_fields(expression) if is_dataclass(expression) else {field.name for field in attrs_fields(expression)} + for field in fields: + if hasattr(expression, field.name) and isinstance((val := getattr(expression, field.name)), Reference): + _, field_refs = expr_to_references(val, include_parameters=include_parameters) + refs |= field_refs + return expression, refs + def _to_expr(value): - if not isinstance(value, str | bytes) and isinstance(value, Iterable): + if hasattr(value, "__type__"): + return value + + if isinstance(value, Mapping): + dct = Dict({key: _to_expr(val) for key, val in value.items()}) + return dct + elif not isinstance(value, str | bytes) and isinstance(value, Iterable): return Tuple(*(_to_expr(entry) for entry in value)) - return sympify(value) + return value if not isinstance(expression, Basic): expression = _to_expr(expression) symbols = list(expression.free_symbols) to_check = [sym for sym in symbols if isinstance(sym, Reference) or (include_parameters and isinstance(sym, Parameter))] - if {sym for sym in symbols if not is_raw(sym)} != set(to_check): - raise RuntimeError("The only symbols allowed are references (to e.g. step or parameter)") + #if {sym for sym in symbols if not is_raw(sym)} != set(to_check): + # print(symbols, to_check, [type(r) for r in symbols]) + # raise RuntimeError("The only symbols allowed are references (to e.g. 
step or parameter)") return expression, to_check def unify_workflows(expression: Any, base_workflow: Workflow | None, set_only: bool = False) -> Workflow | None: @@ -1310,6 +1353,5 @@ def unify_workflows(expression: Any, base_workflow: Workflow | None, set_only: b # Make sure all the results share it for step_result in to_check: step_result.__workflow__ = collected_workflow - expression = expression.subs(step_result, step_result) return expression, collected_workflow diff --git a/tests/test_annotations.py b/tests/test_annotations.py index e9b209ac..b5f2694b 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -72,7 +72,9 @@ def test_at_construct() -> None: out: label: out outputSource: to_int-1/out - type: int + type: + - int + - double steps: increment-1: in: diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index a86e04ee..7353fd25 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -2,12 +2,13 @@ from dataclasses import dataclass import pytest +from typing import Unpack, TypedDict from dewret.tasks import task, construct, subworkflow from dewret.workflow import param from dewret.renderers.cwl import render -from ._lib.extra import double, mod10, sum +from ._lib.extra import double, mod10, sum, pi @dataclass class Sides: @@ -68,3 +69,40 @@ class MyDataclass: assert str(param_reference.left) == "my_param/left" assert param_reference.left.__type__ == int + +def test_can_get_field_references_from_dataclass(): + @dataclass + class MyDataclass: + left: int + right: float + + @subworkflow() + def test_dataclass(my_dataclass: MyDataclass) -> MyDataclass: + result: MyDataclass = MyDataclass(left=mod10(num=my_dataclass.left), right=pi()) + return result + + @subworkflow() + def get_left(my_dataclass: MyDataclass) -> int: + return my_dataclass.left + + result = get_left(my_dataclass=test_dataclass(my_dataclass=MyDataclass(left=3, right=4.))) + workflow = construct(result, simplify_ids=True) + + assert str(workflow.result) == "get_left-1" + assert workflow.result.__type__ == int + +def test_can_get_field_references_from_typed_dict(): + class MyDict(TypedDict): + left: int + right: float + + @subworkflow() + def test_dict(**my_dict: Unpack[MyDict]) -> MyDict: + result: MyDict = {"left": mod10(num=my_dict["left"]), "right": pi()} + return result + + result = test_dict(left=3, right=4.) 
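+    # The field lookups below resolve through MyDict.__annotations__ (the
+    # TypedDict branch of find_field), which is why "left" is typed as int.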
+ workflow = construct(result, simplify_ids=True) + + assert str(workflow.result["left"]) == "test_dict-1/left" + assert workflow.result["left"].__type__ == int diff --git a/tests/test_modularity.py b/tests/test_modularity.py index 330c9b78..deedfbc2 100644 --- a/tests/test_modularity.py +++ b/tests/test_modularity.py @@ -37,9 +37,9 @@ def test_nested_task() -> None: default: 23 label: STARTING_NUMBER type: int - increase-3-num: + increase-2-num: default: 17 - label: increase-3-num + label: increase-2-num type: int outputs: out: @@ -47,21 +47,21 @@ def test_nested_task() -> None: outputSource: sum-1/out type: [int, float] steps: - increase-3: + increase-2: run: increase in: JUMP: source: JUMP num: - source: increase-3-num + source: increase-2-num out: [out] - increase-2: + increase-3: run: increase in: JUMP: source: JUMP num: - source: increase-3/out + source: increase-2/out out: [out] increase-1: run: increase @@ -83,6 +83,6 @@ def test_nested_task() -> None: left: source: double-1/out right: - source: increase-2/out + source: increase-3/out out: [out] """) diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index a88e3fbd..ec41b416 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -11,13 +11,13 @@ CONSTANT = 3 -QueueFactory: Callable[..., "Queue[int]"] = factory(Queue) +QueueFactory: Callable[..., Queue[int]] = factory(Queue) GLOBAL_QUEUE = QueueFactory() @task() -def pop(queue: "Queue[int]") -> int: +def pop(queue: Queue[int]) -> int: """Remove element of a queue.""" return queue.get() @@ -29,26 +29,26 @@ def to_int(num: int | float) -> int: @task() -def add_and_queue(num: int, queue: "Queue[int]") -> "Queue[int]": +def add_and_queue(num: int, queue: Queue[int]) -> Queue[int]: """Add a global constant to a number.""" queue.put(num) return queue @subworkflow() -def make_queue(num: int | float) -> "Queue[int]": +def make_queue(num: int | float) -> Queue[int]: """Add a number to a queue.""" queue = QueueFactory() return add_and_queue(num=to_int(num=num), queue=queue) @subworkflow() -def get_global_queue(num: int | float) -> "Queue[int]": +def get_global_queue(num: int | float) -> Queue[int]: """Add a number to a global queue.""" return add_and_queue(num=to_int(num=num), queue=GLOBAL_QUEUE) @subworkflow() -def get_global_queues(num: int | float) -> list["Queue[int] | int"]: +def get_global_queues(num: int | float) -> list[Queue[int] | int]: """Add a number to a global queue.""" return [ add_and_queue(num=to_int(num=num), queue=GLOBAL_QUEUE), @@ -272,6 +272,9 @@ def test_subworkflows_can_return_lists() -> None: outputs: out: label: out + items: + - Queue + - int outputSource: get_global_queues-1/out type: array steps: @@ -337,7 +340,7 @@ def test_subworkflows_can_return_lists() -> None: outputs: - label: out outputSource: add_and_queue-1-1/out - type: Queue[int] + type: Queue - label: out outputSource: add_constant-1-1/out type: int From ff81ad41542faa077381cfd4634252275572d8a3 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 10 Aug 2024 23:51:20 +0100 Subject: [PATCH 021/108] fix: fix render load order --- src/dewret/__main__.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 77949e24..e1ec87e3 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -74,9 +74,6 @@ def render( ARGUMENTS is zero or more pairs representing constant arguments to pass to the task, in the format `key:val` where val is a JSON basic type. 
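+
+    An illustrative invocation (the task and argument names are examples
+    only): `python -m dewret --pretty workflow.py increment num:3`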
""" sys.path.append(str(Path(workflow_py).parent)) - loader = importlib.machinery.SourceFileLoader("workflow", workflow_py) - workflow = loader.load_module() - task_fn = getattr(workflow, task) kwargs = {} for arg in arguments: if ":" not in arg: @@ -97,7 +94,7 @@ def render( renderer_kwargs: dict[str, Any] if renderer_args.startswith("@"): with Path(renderer_args[1:]).open() as renderer_args_f: - renderer_kwargs = yaml.load(renderer_args_f) + renderer_kwargs = yaml.safe_load(renderer_args_f) elif not renderer_args: renderer_kwargs = {} else: @@ -119,6 +116,10 @@ def _opener(key, mode): opener = _opener render = get_render_method(render_module, pretty=pretty) + loader = importlib.machinery.SourceFileLoader("workflow", workflow_py) + workflow = loader.load_module() + task_fn = getattr(workflow, task) + try: rendered = render(construct(task_fn(**kwargs), **renderer_kwargs)) except Exception as exc: From e3b9b2888f8e415bd8f3b1d82355460957b75d96 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 11 Aug 2024 01:56:51 +0100 Subject: [PATCH 022/108] feat(iterated): allow spreading --- src/dewret/core.py | 40 ++++++++++++++++- src/dewret/tasks.py | 77 +++++++++++++++------------------ src/dewret/utils.py | 7 ++- src/dewret/workflow.py | 19 ++++++-- tests/test_annotations.py | 2 +- tests/test_configuration.py | 3 +- tests/test_cwl.py | 28 +++++------- tests/test_errors.py | 55 ++++++++++------------- tests/test_fieldable.py | 51 +++++++++++++++++++++- tests/test_modularity.py | 29 ++++++------- tests/test_multiresult_steps.py | 30 +++++++------ tests/test_parameters.py | 10 ++--- tests/test_subworkflows.py | 44 +++++++++++-------- 13 files changed, 239 insertions(+), 156 deletions(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index dfa80ec2..fa657105 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -1,4 +1,4 @@ -from typing import Generic, TypeVar, Protocol +from typing import Generic, TypeVar, Protocol, Iterator from sympy import Expr, Symbol U = TypeVar("U") @@ -53,6 +53,13 @@ def __float__(self) -> bool: self._raise_unevaluatable_error() return False + def __iter__(self) -> Iterator["Reference"]: + count = -1 + yield ( + Iterated(to_wrap=self, iteration=(count := iteration)) + for iteration in iter(lambda: count + 1, -1) + ) + def __int__(self) -> bool: self._raise_unevaluatable_error() return False @@ -71,3 +78,34 @@ def __name__(self) -> str: def __str__(self) -> str: """Global description of the reference.""" return self.__name__ + +class Iterated(Reference[U]): + __wrapped__: Reference[U] + __iteration__: int + + def __init__(self, to_wrap: Reference[U], iteration: int, *args, **kwargs): + self.__wrapped__ = to_wrap + self.__iteration__ = iteration + super().__init__(*args, **kwargs) + + @property + def __root_name__(self) -> str: + return f"{self.__wrapped__.__root_name__}[{self.__iteration__}]" + + @property + def __type__(self) -> type: + return Iterator[self.__wrapped__.__type__] + + def __hash__(self) -> int: + return hash(self.__root_name__) + + def __field__(self) -> str: + return str(self.__iteration__) + + @property + def __workflow__(self) -> WorkflowProtocol: + return self.__wrapped__.__workflow__ + + @__workflow__.setter + def __workflow__(self, workflow: WorkflowProtocol) -> None: + self.__wrapped__.__workflow__ = workflow diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index de40afc6..ed654419 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -37,7 +37,7 @@ from functools import cached_property from collections.abc import Callable from 
typing import Any, ParamSpec, TypeVar, cast, Generator -from types import TracebackType +from types import TracebackType, GeneratorType from attrs import has as attrs_has from dataclasses import dataclass, is_dataclass import traceback @@ -337,34 +337,6 @@ def factory(fn: Callable[..., RetType]) -> Callable[..., RetType]: return task(is_factory=True)(fn) -def nested_task() -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]: - """Shortcut for marking a task as nested. - - A nested task is one which calls other tasks and does not - do anything else important. It will _not_ actually get called - at runtime, but should map entirely into the graph. As such, - arithmetic operations on results, etc. will cause errors at - render-time. Combining tasks is acceptable, and intended. The - effect of the nested task will be considered equivalent to whatever - reaching whatever step reference is returned at the end. - - ```python - >>> @task() - ... def increment(num: int) -> int: - ... return num + 1 - - >>> @nested_task() - ... def double_increment(num: int) -> int: - ... return increment(increment(num=num)) - - ``` - - Returns: - Task that runs at render, not execution, time. - """ - return task(nested=True) - - def subworkflow() -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]: """Shortcut for marking a task as nested. @@ -381,7 +353,7 @@ def subworkflow() -> Callable[[Callable[Param, RetType]], Callable[Param, RetTyp ... def increment(num: int) -> int: ... return num + 1 - >>> @nested_task() + >>> @subworkflow() ... def double_increment(num: int) -> int: ... return increment(increment(num=num)) @@ -441,6 +413,14 @@ def _fn( __traceback__: TracebackType | None = None, **kwargs: Param.kwargs, ) -> RetType: + configuration = None + try: + allow_positional_args = get_configuration("allow_positional_args") + except LookupError: + configuration = set_configuration() + configuration.__enter__() + allow_positional_args = get_configuration("allow_positional_args") + try: # Ensure that all arguments are passed as keyword args and prevent positional args. # passed at all. @@ -460,7 +440,12 @@ def add_numbers(left: int, right: int): # Ensure that the passed arguments are, at least, a Python-match for the signature. sig = inspect.signature(fn) - sig.bind(*args, **kwargs) + positional_args = {key: False for key in kwargs} + if args and isinstance(args[0], GeneratorType): + for arg, (key, _) in zip(args[0], sig.parameters.items()): + kwargs[key] = arg + positional_args[key] = True + sig.bind(**kwargs) refs = [] for key, val in kwargs.items(): @@ -539,14 +524,14 @@ def add_numbers(left: int, right: int): ): raise TypeError( f""" - You referenced a task {var} inside another task {fn.__name__}, but it is not a nested_task + You referenced a task {var} inside another task {fn.__name__}, but it is not a workflow - this will not be found! - @task + @task() def {var}(...) -> ...: ... - @nested_task <<<--- likely what you want + @subworkflow() <<<--- likely what you want def {fn.__name__}(...) -> ...: ... {var}(...) @@ -562,11 +547,13 @@ def {fn.__name__}(...) -> ...: with in_nested_task(): output = analyser.with_new_globals(kwargs)(**original_kwargs) lazy_fn = ensure_lazy(output) - if lazy_fn is None: - raise TypeError( - f"Task {fn.__name__} returned output of type {type(output)}, which is not a lazy function for this backend." 
- ) - step_reference = evaluate(lazy_fn, __workflow__=workflow) + if lazy_fn is not None: + with in_nested_task(): + output = evaluate(lazy_fn, __workflow__=workflow) + #raise TypeError( + # f"Task {fn.__name__} returned output of type {type(output)}, which is not a lazy function for this backend." + #) + step_reference = output else: nested_workflow = Workflow(name=fn.__name__) nested_globals: Param.kwargs = { @@ -585,9 +572,9 @@ def {fn.__name__}(...) -> ...: nested_kwargs = {key: value for key, value in nested_globals.items() if key in original_kwargs} with in_nested_task(): output = analyser.with_new_globals(nested_globals)(**nested_kwargs) - nested_workflow = _manager(output, __workflow__=nested_workflow) + nested_workflow = _manager(output, __workflow__=nested_workflow) step_reference = workflow.add_nested_step( - fn.__name__, nested_workflow, analyser.return_type, kwargs + fn.__name__, nested_workflow, analyser.return_type, original_kwargs, positional_args ) if is_expr(step_reference): return cast(RetType, step_reference) @@ -599,8 +586,9 @@ def {fn.__name__}(...) -> ...: workflow.add_step( fn, kwargs, - raw_as_parameter=is_in_nested_task(), + raw_as_parameter=not is_in_nested_task(), is_factory=is_factory, + positional_args=positional_args ), ) return step @@ -613,9 +601,12 @@ def {fn.__name__}(...) -> ...: __traceback__, exc.args[0] if exc.args else "Could not call task {fn.__name__}", ) from exc + finally: + if configuration: + configuration.__exit__(None, None, None) _fn.__step_expression__ = True # type: ignore - return LazyEvaluation(lazy()(_fn)) + return LazyEvaluation(_fn) return _task diff --git a/src/dewret/utils.py b/src/dewret/utils.py index 0979ac7b..983234cd 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -106,7 +106,7 @@ def flatten(value: Any) -> RawType: return crawl_raw(value, lambda entry: entry) def is_expr(value: Any) -> bool: - return is_raw(value, lambda x: isinstance(x, Basic)) + return is_raw(value, lambda x: isinstance(x, Basic) or isinstance(x, tuple) or isinstance(x, Reference)) def is_raw_type(typ: type) -> bool: """Check if a type counts as "raw".""" @@ -126,6 +126,9 @@ def is_raw(value: Any, check: Callable[[Any], bool] | None = None) -> bool: if isinstance(value, str | float | bool | bytes | int | None | Integer | Float | Rational): return True + if check is not None and check(value): + return True + if isinstance(value, Mapping): return ( (isinstance(value, dict) or (check is not None and check(value))) and @@ -139,7 +142,7 @@ def is_raw(value: Any, check: Callable[[Any], bool] | None = None) -> bool: all(is_raw(key, check) for key in value) ) - return check is not None and check(value) + return False def ensure_raw(value: Any, cast_tuple: bool = False) -> RawType | None: diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 969f2cdc..889ed527 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -24,7 +24,6 @@ from attrs import define, has as attr_has, resolve_types, fields as attrs_fields from dataclasses import is_dataclass, fields as dataclass_fields from collections import Counter, OrderedDict -from types import GeneratorType from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Iterable from uuid import uuid4 from sympy import Symbol, Expr, Basic, Tuple, Dict @@ -579,7 +578,7 @@ def register_task(self, fn: Lazy) -> Task: return task def add_nested_step( - self, name: str, subworkflow: Workflow, return_type: type | None, kwargs: dict[str, Any] + self, name: str, 
subworkflow: Workflow, return_type: type | None, kwargs: dict[str, Any], positional_args: dict[str, bool] | None = None ) -> StepReference[Any]: """Append a nested step. @@ -591,6 +590,8 @@ def add_nested_step( kwargs: any key-value arguments to pass in the call. """ step = NestedStep(self, name, subworkflow, kwargs) + if positional_args is not None: + step.positional_args = positional_args self.steps.append(step) return_type = return_type or step.return_type if return_type is inspect._empty: @@ -603,6 +604,7 @@ def add_step( kwargs: dict[str, Raw | Reference], raw_as_parameter: bool = False, is_factory: bool = False, + positional_args: dict[str, bool] | None = None ) -> StepReference[Any]: """Append a step. @@ -618,6 +620,8 @@ def add_step( task = self.register_task(fn) step_maker = FactoryCall if is_factory else Step step = step_maker(self, task, kwargs, raw_as_parameter=raw_as_parameter) + if positional_args is not None: + step.positional_args = positional_args self.steps.append(step) return_type = step.return_type if ( @@ -637,6 +641,10 @@ def from_result( Starts from a result, and builds a workflow to output it. """ result, refs = expr_to_references(result) + if not refs: + raise RuntimeError( + "Attempted to build a workflow from a return-value/result/expression with no references." + ) refs = list(refs) workflow = refs[0].__workflow__ # Ensure that we have exactly one workflow, even if multiple results. @@ -799,6 +807,7 @@ class BaseStep(WorkflowComponent): task: Task | Workflow arguments: Mapping[str, Reference | Raw] workflow: Workflow + positional_args: dict[str, bool] | None = None def __init__( self, @@ -943,7 +952,7 @@ def _generate_id(self) -> str: for key, param in self.arguments.items(): components.append((key, repr(param))) - comp_tup: tuple[str | tuple[str, str], ...] = tuple(components) + comp_tup: tuple[str | tuple[str, str], ...] = tuple(sorted(components, key=lambda pair: pair[0])) return f"{self.task}-{hasher(comp_tup)}" @@ -972,10 +981,12 @@ def __init__( raw_as_parameter: whether raw-type arguments should be made (outer) workflow parameters. 
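+
+            Any parameters already attached to the subworkflow are merged
+            into these arguments, so they also surface as inputs of the
+            outer step.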
""" self.__subworkflow__ = subworkflow + base_arguments = {p.name: p for p in subworkflow.find_parameters()} + base_arguments.update(arguments) super().__init__( workflow=workflow, task=subworkflow, - arguments=arguments, + arguments=base_arguments, raw_as_parameter=raw_as_parameter, ) diff --git a/tests/test_annotations.py b/tests/test_annotations.py index b5f2694b..9d8f323f 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -52,8 +52,8 @@ def test_can_analyze_annotations(): assert analyser.argument_has("ARG1", AtRender) is True def test_at_construct() -> None: - result = to_int_bad(num=increment(num=3), should_double=True) with pytest.raises(TaskException) as _: + result = to_int_bad(num=increment(num=3), should_double=True) workflow = construct(result, simplify_ids=True) result = to_int(num=increment(num=3), should_double=True) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 2d99d5ba..72763d8d 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -21,12 +21,13 @@ def floor(num: int, expected: AtRender[bool]) -> int: return increment(num=num) def test_cwl_with_parameter(configuration) -> None: - result = increment(num=floor(num=3, expected=True)) with set_configuration(flatten_all_nested=True): + result = increment(num=floor(num=3, expected=True)) workflow = construct(result, simplify_ids=True) with pytest.raises(TaskException) as exc, set_configuration(flatten_all_nested=False): + result = increment(num=floor(num=3, expected=True)) workflow = construct(result, simplify_ids=True) assert "AssertionError" in str(exc.getrepr()) diff --git a/tests/test_cwl.py b/tests/test_cwl.py index 3c69f39a..3833a559 100644 --- a/tests/test_cwl.py +++ b/tests/test_cwl.py @@ -227,7 +227,7 @@ def test_cwl_with_subworkflow() -> None: outputs: out: label: out - outputSource: increment-2/out + outputSource: increment-1/out type: int steps: floor-1: @@ -236,13 +236,13 @@ def test_cwl_with_subworkflow() -> None: source: triple_and_one-1/out out: [out] run: floor - increment-1: + increment-2: in: num: source: num out: [out] run: increment - increment-2: + increment-1: in: num: source: floor-1/out @@ -251,7 +251,7 @@ def test_cwl_with_subworkflow() -> None: triple_and_one-1: in: num: - source: increment-1/out + source: increment-2/out out: [out] run: triple_and_one """) @@ -263,14 +263,10 @@ def test_cwl_with_subworkflow() -> None: num: label: num type: int - sum-1-2-right: - default: 1 - label: sum-1-2-right - type: int outputs: out: label: out - outputSource: sum-1-2/out + outputSource: sum-1-1/out type: - int - float @@ -282,7 +278,7 @@ def test_cwl_with_subworkflow() -> None: out: - out run: double - sum-1-1: + sum-1-2: in: left: source: double-1-1/out @@ -291,12 +287,12 @@ def test_cwl_with_subworkflow() -> None: out: - out run: sum - sum-1-2: + sum-1-1: in: left: - source: sum-1-1/out + source: sum-1-2/out right: - source: sum-1-2-right + default: 1 out: - out run: sum @@ -470,10 +466,6 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: int, float ] - sum-1-1-right: - default: 1 - label: sum-1-1-right - type: int outputs: out: label: out @@ -503,7 +495,7 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: left: source: sum-1-2/out right: - source: sum-1-1-right + default: 1 out: - out run: sum diff --git a/tests/test_errors.py b/tests/test_errors.py index 8336867f..96ecf108 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -2,7 +2,7 @@ import pytest from dewret.workflow import Task, Lazy -from 
dewret.tasks import construct, task, nested_task, TaskException
+from dewret.tasks import construct, task, subworkflow, TaskException
 from dewret.annotations import AtRender
 from dewret.renderers.cwl import render
 from ._lib.extra import increment, pi, reverse_list  # noqa: F401
@@ -17,7 +17,7 @@
 
 ADD_TASK_LINE_NO = 11
 
-@nested_task()
+@subworkflow()
 def badly_add_task(left: int, right: int) -> int:
     """Badly attempts to add two numbers."""
     return add_task(left=left)  # type: ignore
@@ -90,15 +90,15 @@ def pi_with_invisible_module_task() -> float:
     return extra.double(3.14 / 2)
 
 
-@nested_task()
+@subworkflow()
 def unacceptable_object_usage() -> int:
     """Invalid use of custom object within nested task."""
     return MyStrangeClass(add_task(left=3, right=4))  # type: ignore
 
 
-@nested_task()
+@subworkflow()
 def unacceptable_nested_return(int_not_global: AtRender[bool]) -> int | Lazy:
-    """Bad nested_task that fails to return a task."""
+    """Bad subworkflow that fails to return a task."""
     add_task(left=3, right=4)
     return 7 if int_not_global else ADD_TASK_LINE_NO
 
@@ -112,17 +112,16 @@ def test_missing_arguments_throw_error() -> None:
     WARNING: in keeping with Python principles, this does not error if types
     mismatch, but `mypy` should. You **must** type-check your code to catch these.
     """
-    result = add_task(left=3)  # type: ignore
     with pytest.raises(TaskException) as exc:
-        construct(result)
+        add_task(left=3)  # type: ignore
     end_section = str(exc.getrepr())[-500:]
     assert str(exc.value) == "missing a required argument: 'right'"
     assert "Task add_task declared in <module> at " in end_section
     assert f"test_errors.py:{ADD_TASK_LINE_NO}" in end_section
 
 
-def test_missing_arguments_throw_error_in_nested_task() -> None:
-    """Check whether omitting a required argument within a nested_task will give an error.
+def test_missing_arguments_throw_error_in_subworkflow() -> None:
+    """Check whether omitting a required argument within a subworkflow will give an error.
 
     Since we do not run the original function, it is up to dewret to check
     that the signature is, at least, acceptable to Python.
@@ -130,9 +129,8 @@ def test_missing_arguments_throw_error_in_nested_task() -> None:
     WARNING: in keeping with Python principles, this does not error if types
     mismatch, but `mypy` should. You **must** type-check your code to catch these.
     """
-    result = badly_add_task(left=3, right=4)
     with pytest.raises(TaskException) as exc:
-        construct(result)
+        badly_add_task(left=3, right=4)
     end_section = str(exc.getrepr())[-500:]
     assert str(exc.value) == "missing a required argument: 'right'"
     assert "def badly_add_task" in end_section
@@ -146,9 +144,8 @@ def test_positional_arguments_throw_error() -> None:
     We can use default and non-default arguments, but we expect them
     to _always_ be named.
     """
-    result = add_task(3, right=4)
     with pytest.raises(TaskException) as exc:
-        construct(result)
+        add_task(3, right=4)
     assert (
         str(exc.value)
         .strip()
@@ -156,21 +153,20 @@ def test_positional_arguments_throw_error() -> None:
     )
 
 
-def test_nesting_non_nested_tasks_throws_error() -> None:
-    """Ensure nesting is only allow in nested_tasks.
+def test_nesting_non_subworkflows_throws_error() -> None:
+    """Ensure nesting is only allowed in subworkflows.
 
     Nested tasks must be evaluated at construction time, and there
     is no concept of task calls that are not resolved during construction,
     so a task should not be called inside a non-nested task.
""" - result = badly_wrap_task() with pytest.raises(TaskException) as exc: - construct(result) + badly_wrap_task() assert ( str(exc.value) .strip() .startswith( - "You referenced a task add_task inside another task badly_wrap_task, but it is not a nested_task" + "You referenced a task add_task inside another task badly_wrap_task, but it is not a workflow" ) ) @@ -199,17 +195,16 @@ def test_nesting_does_not_identify_imports_as_nesting() -> None: ] bad = [test_recursive, pi_with_visible_module_task] for tsk in bad: - result = tsk() with pytest.raises(TaskException) as exc: - construct(result) + tsk() assert str(exc.value).strip().startswith("You referenced a task") for tsk in good: result = tsk() construct(result) -def test_normal_objects_cannot_be_used_in_nested_tasks() -> None: - """Most entities cannot appear in a nested_task, ensure we catch them. +def test_normal_objects_cannot_be_used_in_subworkflows() -> None: + """Most entities cannot appear in a subworkflow, ensure we catch them. Since the logic in nested tasks has to be embedded explicitly in the workflow, complex types are not necessarily representable, and in most cases, we would not @@ -217,34 +212,28 @@ def test_normal_objects_cannot_be_used_in_nested_tasks() -> None: Note: this may be mitigated with sympy support, to some extent. """ - result = unacceptable_object_usage() with pytest.raises(TaskException) as exc: - construct(result) + unacceptable_object_usage() assert ( str(exc.value) == "Nested tasks must now only refer to global parameters, raw or tasks, not objects: MyStrangeClass" ) -def test_nested_tasks_must_return_a_task() -> None: +def test_subworkflows_must_return_a_task() -> None: """Ensure nested tasks are lazy-evaluatable. A graph only makes sense if the edges connect, and nested tasks must therefore chain. As such, a nested task must represent a real subgraph, and return a node to pull it into the main graph. """ - result = unacceptable_nested_return(int_not_global=True) with pytest.raises(TaskException) as exc: + result = unacceptable_nested_return(int_not_global=True) construct(result) assert ( str(exc.value) - == "Task unacceptable_nested_return returned output of type , which is not a lazy function for this backend." + == "Attempted to build a workflow from a return-value/result/expression with no references." ) result = unacceptable_nested_return(int_not_global=False) - with pytest.raises(TaskException) as exc: - construct(result) - assert ( - str(exc.value) - == "Task unacceptable_nested_return returned output of type , which is not a lazy function for this backend." - ) + construct(result) diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index 7353fd25..eeba908c 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -4,7 +4,7 @@ import pytest from typing import Unpack, TypedDict -from dewret.tasks import task, construct, subworkflow +from dewret.tasks import task, construct, subworkflow, set_configuration from dewret.workflow import param from dewret.renderers.cwl import render @@ -106,3 +106,52 @@ def test_dict(**my_dict: Unpack[MyDict]) -> MyDict: assert str(workflow.result["left"]) == "test_dict-1/left" assert workflow.result["left"].__type__ == int + +def test_can_iterate(): + @task() + def test_task(alpha: int, beta: float, charlie: bool) -> int: + return int(alpha + beta) + + @task() + def test_list() -> list: + return [1, 2.] 
+ + @subworkflow() + def test_iterated() -> int: + return test_task(*test_list()) + + with set_configuration(allow_positional_args=True, flatten_all_nested=True): + result = test_iterated() + workflow = construct(result, simplify_ids=True) + + rendered = render(workflow, allow_complex_types=True)["__root__"] + + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: {} + outputs: + out: + label: out + outputSource: test_task-1/out + type: int + steps: + test_list-1: + in: {} + out: + - out + run: test_list + test_task-1: + in: + alpha: + source: test_list-1[0] + beta: + source: test_list-1[1] + charlie: + source: test_list-1[2] + out: + - out + run: test_task + """) + + assert workflow.result._.step.positional_args == {"alpha": True, "beta": True, "charlie": True} diff --git a/tests/test_modularity.py b/tests/test_modularity.py index deedfbc2..b2abb8b8 100644 --- a/tests/test_modularity.py +++ b/tests/test_modularity.py @@ -1,14 +1,14 @@ """Verify CWL can be made with split up and nested calls.""" import yaml -from dewret.tasks import nested_task, construct +from dewret.tasks import subworkflow, construct, set_configuration from dewret.renderers.cwl import render from ._lib.extra import double, sum, increase STARTING_NUMBER: int = 23 -@nested_task() +@subworkflow() def algorithm() -> int | float: """Creates a graph of task calls.""" left = double(num=increase(num=STARTING_NUMBER)) @@ -17,13 +17,14 @@ def algorithm() -> int | float: return sum(left=left, right=right) -def test_nested_task() -> None: +def test_subworkflow() -> None: """Check whether we can link between multiple steps and have parameters. Produces CWL that has references between multiple steps. """ - workflow = construct(algorithm(), simplify_ids=True) - rendered = render(workflow)["__root__"] + with set_configuration(flatten_all_nested=True): + workflow = construct(algorithm(), simplify_ids=True) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 @@ -37,39 +38,35 @@ def test_nested_task() -> None: default: 23 label: STARTING_NUMBER type: int - increase-2-num: - default: 17 - label: increase-2-num - type: int outputs: out: label: out outputSource: sum-1/out type: [int, float] steps: - increase-2: + increase-1: run: increase in: JUMP: source: JUMP num: - source: increase-2-num + source: STARTING_NUMBER out: [out] - increase-3: + increase-2: run: increase in: JUMP: source: JUMP num: - source: increase-2/out + source: increase-3/out out: [out] - increase-1: + increase-3: run: increase in: JUMP: source: JUMP num: - source: STARTING_NUMBER + default: 17 out: [out] double-1: run: double @@ -83,6 +80,6 @@ def test_nested_task() -> None: left: source: double-1/out right: - source: increase-3/out + source: increase-2/out out: [out] """) diff --git a/tests/test_multiresult_steps.py b/tests/test_multiresult_steps.py index b8b6efb6..23806658 100644 --- a/tests/test_multiresult_steps.py +++ b/tests/test_multiresult_steps.py @@ -4,7 +4,7 @@ from attr import define from dataclasses import dataclass from typing import Iterable -from dewret.tasks import task, construct, nested_task +from dewret.tasks import task, construct, subworkflow, set_configuration from dewret.renderers.cwl import render STARTING_NUMBER: int = 23 @@ -44,19 +44,19 @@ def pair(left: int, right: float) -> tuple[int, float]: return (left, right) -@nested_task() +@subworkflow() def algorithm() -> float: """Sum two split values.""" return combine(left=split().first, right=split().second) -@nested_task() 
+@subworkflow() def algorithm_with_pair() -> tuple[int, float]: """Pairs two split dataclass values.""" return pair(left=split_into_dataclass().first, right=split_into_dataclass().second) -@nested_task() +@subworkflow() def algorithm_with_dataclasses() -> float: """Sums two split dataclass values.""" return combine( @@ -76,7 +76,7 @@ def split_into_dataclass() -> SplitResultDataclass: return SplitResultDataclass(first=1, second=2) -def test_nested_task() -> None: +def test_subworkflow() -> None: """Check whether we can link between multiple steps and have parameters. Produces CWL that has references between multiple steps. @@ -114,7 +114,7 @@ def test_nested_task() -> None: """) -def test_field_of_nested_task() -> None: +def test_field_of_subworkflow() -> None: """Tests whether a directly-output nested task can have fields.""" workflow = construct(split().first, simplify_ids=True) rendered = render(workflow)["__root__"] @@ -142,7 +142,7 @@ def test_field_of_nested_task() -> None: """) -def test_field_of_nested_task_into_dataclasses() -> None: +def test_field_of_subworkflow_into_dataclasses() -> None: """Tests whether a directly-output nested task can have fields.""" workflow = construct(split_into_dataclass().first, simplify_ids=True) rendered = render(workflow)["__root__"] @@ -170,10 +170,11 @@ def test_field_of_nested_task_into_dataclasses() -> None: """) -def test_complex_field_of_nested_task() -> None: +def test_complex_field_of_subworkflow() -> None: """Tests whether a task can sum complex structures.""" - workflow = construct(algorithm(), simplify_ids=True) - rendered = render(workflow)["__root__"] + with set_configuration(flatten_all_nested=True): + workflow = construct(algorithm(), simplify_ids=True) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -206,11 +207,12 @@ def test_complex_field_of_nested_task() -> None: """) -def test_complex_field_of_nested_task_with_dataclasses() -> None: +def test_complex_field_of_subworkflow_with_dataclasses() -> None: """Tests whether a task can insert result fields into other steps.""" - result = algorithm_with_dataclasses() - workflow = construct(result, simplify_ids=True) - rendered = render(workflow)["__root__"] + with set_configuration(flatten_all_nested=True): + result = algorithm_with_dataclasses() + workflow = construct(result, simplify_ids=True) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow diff --git a/tests/test_parameters.py b/tests/test_parameters.py index c52b55b7..e19f0046 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -76,8 +76,8 @@ def test_complex_parameters() -> None: label: numx type: int default: 23 - rotate-1-num: - label: rotate-1-num + rotate-2-num: + label: rotate-2-num type: int default: 23 outputs: @@ -92,7 +92,7 @@ def test_complex_parameters() -> None: INPUT_NUM: source: INPUT_NUM num: - source: rotate-1-num + source: rotate-2/out out: [out] double-1: run: double @@ -106,7 +106,7 @@ def test_complex_parameters() -> None: INPUT_NUM: source: INPUT_NUM num: - source: rotate-1/out + source: rotate-2-num out: [out] rotate-3: run: rotate @@ -122,6 +122,6 @@ def test_complex_parameters() -> None: left: source: double-1/out right: - source: rotate-2/out + source: rotate-1/out out: [out] """) diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index ec41b416..9b3ecbb8 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -125,7 +125,6 @@ def 
test_subworkflows_can_use_globals() -> None: inputs: CONSTANT: label: CONSTANT - default: 3 type: int num: label: num @@ -133,16 +132,16 @@ def test_subworkflows_can_use_globals() -> None: outputs: out: label: out - outputSource: increment-2/out + outputSource: increment-1/out type: int steps: - increment-1: + increment-2: in: num: source: num out: [out] run: increment - increment-2: + increment-1: in: num: source: add_constant-1/out @@ -153,7 +152,7 @@ def test_subworkflows_can_use_globals() -> None: CONSTANT: source: CONSTANT num: - source: increment-1/out + source: increment-2/out out: [out] run: add_constant """) @@ -222,6 +221,9 @@ def test_subworkflows_can_use_global_factories() -> None: num: label: num type: int + GLOBAL_QUEUE: + label: GLOBAL_QUEUE + type: Queue outputs: out: label: out @@ -238,6 +240,8 @@ def test_subworkflows_can_use_global_factories() -> None: in: num: source: increment-1/out + GLOBAL_QUEUE: + source: GLOBAL_QUEUE out: [out] run: get_global_queue pop-1: @@ -269,6 +273,12 @@ def test_subworkflows_can_return_lists() -> None: num: label: num type: int + CONSTANT: + label: CONSTANT + type: int + GLOBAL_QUEUE: + label: GLOBAL_QUEUE + type: Queue outputs: out: label: out @@ -288,6 +298,10 @@ def test_subworkflows_can_return_lists() -> None: in: num: source: increment-1/out + CONSTANT: + source: CONSTANT + GLOBAL_QUEUE: + source: GLOBAL_QUEUE out: [out] run: get_global_queues """) @@ -331,12 +345,14 @@ def test_subworkflows_can_return_lists() -> None: cwlVersion: 1.2 inputs: CONSTANT: - default: 3 label: CONSTANT type: int num: label: num type: int + GLOBAL_QUEUE: + label: GLOBAL_QUEUE + type: Queue outputs: - label: out outputSource: add_and_queue-1-1/out @@ -345,17 +361,12 @@ def test_subworkflows_can_return_lists() -> None: outputSource: add_constant-1-1/out type: int steps: - Queue-1-1: - in: {} - out: - - out - run: Queue add_and_queue-1-1: in: num: source: to_int-1-1/out queue: - source: Queue-1-1/out + source: GLOBAL_QUEUE out: - out run: add_and_queue @@ -451,7 +462,6 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None: inputs: CONSTANT: label: CONSTANT - default: 3 type: int num: label: num @@ -459,16 +469,16 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None: outputs: out: label: out - outputSource: increment-2/out + outputSource: increment-1/out type: int steps: - increment-2: + increment-1: in: num: source: add_constants-1/out out: [out] run: increment - increment-1: + increment-2: in: num: source: num @@ -479,7 +489,7 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None: CONSTANT: source: CONSTANT num: - source: increment-1/out + source: increment-2/out out: [out] run: add_constants """) From 38825fae494dd4de367a87a00173d11f00eccc70 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 11 Aug 2024 02:19:30 +0100 Subject: [PATCH 023/108] feat(fields): make it possible to use a plain dict if configured for fields --- src/dewret/core.py | 34 +++++++++++++++++++++++++++++++++- src/dewret/tasks.py | 29 +---------------------------- src/dewret/workflow.py | 10 ++++++++-- tests/test_configuration.py | 1 - tests/test_fieldable.py | 12 ++++++++++++ 5 files changed, 54 insertions(+), 32 deletions(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index fa657105..0c221c12 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -1,4 +1,6 @@ -from typing import Generic, TypeVar, Protocol, Iterator +from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired +from contextlib import 
contextmanager +from contextvars import ContextVar from sympy import Expr, Symbol U = TypeVar("U") @@ -10,6 +12,36 @@ class UnevaluatableError(Exception): ... +class ConstructConfiguration(TypedDict): + flatten_all_nested: NotRequired[bool] + allow_positional_args: NotRequired[bool] + allow_plain_dict_fields: NotRequired[bool] + +CONSTRUCT_CONFIGURATION: ContextVar[ConstructConfiguration] = ContextVar("construct-configuration") + +@contextmanager +def set_configuration(**kwargs: Unpack[ConstructConfiguration]): + try: + previous = ConstructConfiguration(**CONSTRUCT_CONFIGURATION.get()) + except LookupError: + previous = ConstructConfiguration( + flatten_all_nested=False, + allow_positional_args=False, + allow_plain_dict_fields=False, + ) + CONSTRUCT_CONFIGURATION.set({}) + + try: + CONSTRUCT_CONFIGURATION.get().update(previous) + CONSTRUCT_CONFIGURATION.get().update(kwargs) + + yield CONSTRUCT_CONFIGURATION + finally: + CONSTRUCT_CONFIGURATION.set(previous) + +def get_configuration(key: str): + return CONSTRUCT_CONFIGURATION.get()[key] + class Reference(Generic[U], Symbol): """Superclass for all symbolic references to values.""" diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index ed654419..d8f3d9f2 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -66,34 +66,7 @@ ) from .backends._base import BackendModule from .annotations import FunctionAnalyser - -class ConstructConfiguration(TypedDict): - flatten_all_nested: NotRequired[bool] - allow_positional_args: NotRequired[bool] - -CONSTRUCT_CONFIGURATION: ContextVar[ConstructConfiguration] = ContextVar("construct-configuration") - -@contextmanager -def set_configuration(**kwargs: Unpack[ConstructConfiguration]): - try: - previous = ConstructConfiguration(**CONSTRUCT_CONFIGURATION.get()) - except LookupError: - previous = ConstructConfiguration( - flatten_all_nested=False, - allow_positional_args=False - ) - CONSTRUCT_CONFIGURATION.set({}) - - try: - CONSTRUCT_CONFIGURATION.get().update(previous) - CONSTRUCT_CONFIGURATION.get().update(kwargs) - - yield CONSTRUCT_CONFIGURATION - finally: - CONSTRUCT_CONFIGURATION.set(previous) - -def get_configuration(key: str): - return CONSTRUCT_CONFIGURATION.get()[key] +from .core import get_configuration, set_configuration, CONSTRUCT_CONFIGURATION Param = ParamSpec("Param") RetType = TypeVar("RetType") diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 889ed527..6d9494cc 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -24,7 +24,7 @@ from attrs import define, has as attr_has, resolve_types, fields as attrs_fields from dataclasses import is_dataclass, fields as dataclass_fields from collections import Counter, OrderedDict -from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Iterable +from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Iterable, get_origin, get_args from uuid import uuid4 from sympy import Symbol, Expr, Basic, Tuple, Dict @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) -from .core import Reference +from .core import Reference, get_configuration from .utils import hasher, RawType, is_raw, make_traceback, is_raw_type, is_expr, Unset T = TypeVar("T") @@ -778,6 +778,12 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None field_type = parent_type.__annotations__[field] except KeyError: raise AttributeError(f"TypedDict {parent_type} does not have field {field}") + if not field_type and 
get_configuration("allow_plain_dict_fields") and get_origin(parent_type) is dict: + args = get_args(parent_type) + if len(args) == 2 and args[0] is str: + field_type = args[1] + else: + raise AttributeError(f"Can only get fields for plain dicts if annotated dict[str, TYPE]") if field_type: if not issubclass(self.__class__, Reference): diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 72763d8d..53868ac2 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -21,7 +21,6 @@ def floor(num: int, expected: AtRender[bool]) -> int: return increment(num=num) def test_cwl_with_parameter(configuration) -> None: - with set_configuration(flatten_all_nested=True): result = increment(num=floor(num=3, expected=True)) workflow = construct(result, simplify_ids=True) diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index eeba908c..ff4fa65a 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -155,3 +155,15 @@ def test_iterated() -> int: """) assert workflow.result._.step.positional_args == {"alpha": True, "beta": True, "charlie": True} + +def test_can_use_plain_dict_fields(): + @subworkflow() + def test_dict(left: int, right: float) -> dict[str, float | int]: + result: dict[str, float | int] = {"left": mod10(num=left), "right": pi()} + return result + + with set_configuration(allow_plain_dict_fields=True): + result = test_dict(left=3, right=4.) + workflow = construct(result, simplify_ids=True) + assert str(workflow.result["left"]) == "test_dict-1/left" + assert workflow.result["left"].__type__ == int | float From 0f047fbadad3f34b929f3c32eb1e6e630e5bc23f Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 11 Aug 2024 02:33:44 +0100 Subject: [PATCH 024/108] fix: is_raw_type should handle union types --- src/dewret/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/dewret/utils.py b/src/dewret/utils.py index 983234cd..e3382a63 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -20,8 +20,8 @@ import hashlib import json import sys -from types import FrameType, TracebackType -from typing import Any, cast, Union, Protocol, ClassVar, Callable, Iterable +from types import FrameType, TracebackType, UnionType +from typing import Any, cast, Union, Protocol, ClassVar, Callable, Iterable, get_args from collections.abc import Sequence, Mapping from sympy import Basic, Integer, Float, Rational @@ -110,6 +110,8 @@ def is_expr(value: Any) -> bool: def is_raw_type(typ: type) -> bool: """Check if a type counts as "raw".""" + if isinstance(typ, UnionType): + return all(is_raw_type(t) for t in get_args(typ)) return issubclass(typ, str | float | bool | bytes | int | None | list | dict) From 946859581ea2ae801389c1b2f9cb12cf416ca661 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 11 Aug 2024 02:59:22 +0100 Subject: [PATCH 025/108] feat(positional): allow positional arguments if configuration override supplied --- src/dewret/core.py | 20 ++++++++++++++------ src/dewret/tasks.py | 16 +++++++++++----- tests/test_cwl.py | 41 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 65 insertions(+), 12 deletions(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index 0c221c12..62933790 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -1,4 +1,4 @@ -from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired +from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator from contextlib import contextmanager from contextvars 
import ContextVar from sympy import Expr, Symbol @@ -86,11 +86,7 @@ def __float__(self) -> bool: return False def __iter__(self) -> Iterator["Reference"]: - count = -1 - yield ( - Iterated(to_wrap=self, iteration=(count := iteration)) - for iteration in iter(lambda: count + 1, -1) - ) + yield IteratedGenerator(self) def __int__(self) -> bool: self._raise_unevaluatable_error() @@ -111,6 +107,18 @@ def __str__(self) -> str: """Global description of the reference.""" return self.__name__ +class IteratedGenerator(Generic[U]): + __wrapped__: Reference[U] + + def __init__(self, to_wrap: Reference[U]): + self.__wrapped__ = to_wrap + + def __iter__(self): + count = -1 + while True: + yield Iterated(to_wrap=self.__wrapped__, iteration=(count := count + 1)) + + class Iterated(Reference[U]): __wrapped__: Reference[U] __iteration__: int diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index d8f3d9f2..0920036e 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -37,7 +37,7 @@ from functools import cached_property from collections.abc import Callable from typing import Any, ParamSpec, TypeVar, cast, Generator -from types import TracebackType, GeneratorType +from types import TracebackType from attrs import has as attrs_has from dataclasses import dataclass, is_dataclass import traceback @@ -66,7 +66,7 @@ ) from .backends._base import BackendModule from .annotations import FunctionAnalyser -from .core import get_configuration, set_configuration, CONSTRUCT_CONFIGURATION +from .core import get_configuration, set_configuration, CONSTRUCT_CONFIGURATION, IteratedGenerator Param = ParamSpec("Param") RetType = TypeVar("RetType") @@ -414,8 +414,14 @@ def add_numbers(left: int, right: int): # Ensure that the passed arguments are, at least, a Python-match for the signature. sig = inspect.signature(fn) positional_args = {key: False for key in kwargs} - if args and isinstance(args[0], GeneratorType): - for arg, (key, _) in zip(args[0], sig.parameters.items()): + for arg, (key, _) in zip(args, sig.parameters.items()): + if isinstance(arg, IteratedGenerator): + for inner_arg, (key, _) in zip(arg, sig.parameters.items()): + if key in positional_args: + continue + kwargs[key] = inner_arg + positional_args[key] = True + else: kwargs[key] = arg positional_args[key] = True sig.bind(**kwargs) @@ -572,7 +578,7 @@ def {fn.__name__}(...) -> ...: fn, declaration_tb, __traceback__, - exc.args[0] if exc.args else "Could not call task {fn.__name__}", + exc.args[0] if exc.args else f"Could not call task {fn.__name__}", ) from exc finally: if configuration: diff --git a/tests/test_cwl.py b/tests/test_cwl.py index 3833a559..48078d87 100644 --- a/tests/test_cwl.py +++ b/tests/test_cwl.py @@ -1,8 +1,10 @@ """Verify CWL output is OK.""" import yaml +import pytest from datetime import datetime, timedelta -from dewret.tasks import construct, task, factory +from dewret.core import set_configuration +from dewret.tasks import construct, task, factory, TaskException from dewret.renderers.cwl import render from dewret.utils import hasher from dewret.workflow import param @@ -169,6 +171,43 @@ def test_cwl_with_parameter() -> None: out: [out] """) +def test_cwl_with_parameter() -> None: + """Check whether we can move raw input to parameters. + + Produces CWL for a call with a changeable raw value, that is converted + to a parameter, if and only if we are calling from outside a subworkflow. 
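+
+    Without the allow_positional_args override, the same positional call is
+    expected to raise a TaskException, which is checked first below.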
+ """ + with pytest.raises(TaskException) as exc: + result = increment(3) + with set_configuration(allow_positional_args=True): + result = increment(3) + workflow = construct(result) + rendered = render(workflow)["__root__"] + num_param = list(workflow.find_parameters())[0] + hsh = hasher(("increment", ("num", f"int|:param:{num_param._.unique_name}"))) + + assert rendered == yaml.safe_load(f""" + cwlVersion: 1.2 + class: Workflow + inputs: + increment-{hsh}-num: + label: increment-{hsh}-num + type: int + default: 3 + outputs: + out: + label: out + outputSource: increment-{hsh}/out + type: int + steps: + increment-{hsh}: + run: increment + in: + num: + source: increment-{hsh}-num + out: [out] + """) + #def test_cwl_without_default() -> None: # """Check whether we can produce CWL without a default value. From 333967f37e591b44cbf9fe4f2aaaa3034228e4cf Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 11 Aug 2024 03:23:45 +0100 Subject: [PATCH 026/108] fix(sympy): support none in dicts --- docs/renderer_tutorial.md | 6 ++--- src/dewret/core.py | 34 ++++++++++++++++++++++++- src/dewret/render.py | 2 +- src/dewret/renderers/cwl.py | 8 ++++-- src/dewret/renderers/snakemake.py | 3 +-- src/dewret/utils.py | 6 +---- src/dewret/workflow.py | 41 +++++++------------------------ tests/_lib/frender.py | 2 +- tests/test_multiresult_steps.py | 5 ++-- 9 files changed, 58 insertions(+), 49 deletions(-) diff --git a/docs/renderer_tutorial.md b/docs/renderer_tutorial.md index 43cec19e..27be6726 100644 --- a/docs/renderer_tutorial.md +++ b/docs/renderer_tutorial.md @@ -401,9 +401,9 @@ import inspect import typing from attrs import define +from dewret.utils import Raw, BasicType from dewret.workflow import Lazy -from dewret.workflow import Reference, Raw, Workflow, Step, Task -from dewret.utils import BasicType +from dewret.workflow import Reference, Workflow, Step, Task RawType = typing.Union[BasicType, list[str], list["RawType"], dict[str, "RawType"]] ``` @@ -418,4 +418,4 @@ python snakemake_tasks.py ``` ### Q: Should I add a brief description of dewret in step 1? Should link dewret types/docs etc here? -### A: Get details on how that happens and probably yes. \ No newline at end of file +### A: Get details on how that happens and probably yes. diff --git a/src/dewret/core.py b/src/dewret/core.py index 62933790..962f79ca 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -1,10 +1,16 @@ -from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator +from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union +from dataclasses import dataclass +import base64 from contextlib import contextmanager from contextvars import ContextVar from sympy import Expr, Symbol U = TypeVar("U") +BasicType = str | float | bool | bytes | int | None +RawType = Union[BasicType, list["RawType"], dict[str, "RawType"]] +FirmType = BasicType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] + class WorkflowProtocol(Protocol): ... @@ -107,6 +113,32 @@ def __str__(self) -> str: """Global description of the reference.""" return self.__name__ +@dataclass +class Raw: + """Value object for any raw types. + + This is able to hash raw types consistently and provides + a single type for validating type-consistency. + + Attributes: + value: the real value, e.g. a `str`, `int`, ... 
+ """ + + value: RawType + + def __hash__(self) -> int: + """Provide a hash that is unique to the `value` member.""" + return hash(repr(self)) + + def __repr__(self) -> str: + """Convert to a consistent, string representation.""" + value: str + if isinstance(self.value, bytes): + value = base64.b64encode(self.value).decode("ascii") + else: + value = str(self.value) + return f"{type(self.value).__name__}|{value}" + class IteratedGenerator(Generic[U]): __wrapped__: Reference[U] diff --git a/src/dewret/render.py b/src/dewret/render.py index 70a34917..2fb69ffe 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -6,7 +6,7 @@ import yaml from .workflow import Workflow, NestedStep -from .utils import RawType +from .core import RawType from .workflow import Workflow RenderConfiguration = TypeVar("RenderConfiguration", bound=dict[str, Any]) diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 0b44bb04..e03296ca 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -27,17 +27,21 @@ from inspect import isclass from sympy import Expr, Basic, Tuple +from dewret.core import ( + Raw, + RawType, + FirmType, +) from dewret.workflow import ( FactoryCall, Reference, - Raw, Workflow, BaseStep, StepReference, ParameterReference, Unset, ) -from dewret.utils import RawType, flatten, DataclassProtocol, firm_to_raw, FirmType, flatten_if_set +from dewret.utils import flatten, DataclassProtocol, firm_to_raw, flatten_if_set from dewret.render import base_render class CommandInputSchema(TypedDict): diff --git a/src/dewret/renderers/snakemake.py b/src/dewret/renderers/snakemake.py index f721ceb7..f88bce75 100644 --- a/src/dewret/renderers/snakemake.py +++ b/src/dewret/renderers/snakemake.py @@ -25,11 +25,10 @@ import typing from attrs import define -from dewret.utils import BasicType +from dewret.core import Raw, BasicType from dewret.workflow import ( Reference, - Raw, Workflow, Task, Lazy, diff --git a/src/dewret/utils.py b/src/dewret/utils.py index e3382a63..8f246f50 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -25,11 +25,7 @@ from collections.abc import Sequence, Mapping from sympy import Basic, Integer, Float, Rational -from .core import Reference - -BasicType = str | float | bool | bytes | int | None -RawType = Union[BasicType, list["RawType"], dict[str, "RawType"]] -FirmType = BasicType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] 
+from .core import Reference, BasicType, RawType, FirmType class Unset: diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 6d9494cc..28265908 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -21,19 +21,19 @@ import inspect from collections.abc import Mapping, MutableMapping, Callable import base64 -from attrs import define, has as attr_has, resolve_types, fields as attrs_fields +from attrs import has as attr_has, resolve_types, fields as attrs_fields from dataclasses import is_dataclass, fields as dataclass_fields from collections import Counter, OrderedDict from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Iterable, get_origin, get_args from uuid import uuid4 -from sympy import Symbol, Expr, Basic, Tuple, Dict +from sympy import Symbol, Expr, Basic, Tuple, Dict, nan import logging logger = logging.getLogger(__name__) -from .core import Reference, get_configuration -from .utils import hasher, RawType, is_raw, make_traceback, is_raw_type, is_expr, Unset +from .core import Reference, get_configuration, RawType, Raw +from .utils import hasher, is_raw, make_traceback, is_raw_type, is_expr, Unset T = TypeVar("T") U = TypeVar("U") @@ -64,33 +64,6 @@ def all_references_from(value: Any): return all_references -@define -class Raw: - """Value object for any raw types. - - This is able to hash raw types consistently and provides - a single type for validating type-consistency. - - Attributes: - value: the real value, e.g. a `str`, `int`, ... - """ - - value: RawType - - def __hash__(self) -> int: - """Provide a hash that is unique to the `value` member.""" - return hash(repr(self)) - - def __repr__(self) -> str: - """Convert to a consistent, string representation.""" - value: str - if isinstance(self.value, bytes): - value = base64.b64encode(self.value).decode("ascii") - else: - value = str(self.value) - return f"{type(self.value).__name__}|{value}" - - class Lazy(Protocol): """Requirements for a lazy-evaluatable function.""" @@ -1334,8 +1307,12 @@ def expr_to_references(expression: Any, include_parameters: bool=False) -> tuple return expression, refs def _to_expr(value): - if hasattr(value, "__type__"): + if value is None: + return nan + elif hasattr(value, "__type__"): return value + elif isinstance(value, Raw): + return value.value if isinstance(value, Mapping): dct = Dict({key: _to_expr(val) for key, val in value.items()}) diff --git a/tests/_lib/frender.py b/tests/_lib/frender.py index 33430690..6f81fea3 100644 --- a/tests/_lib/frender.py +++ b/tests/_lib/frender.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from contextvars import ContextVar -from dewret.utils import RawType +from dewret.core import RawType from dewret.workflow import Workflow, Step, NestedStep from dewret.render import base_render diff --git a/tests/test_multiresult_steps.py b/tests/test_multiresult_steps.py index 23806658..0a57dc70 100644 --- a/tests/test_multiresult_steps.py +++ b/tests/test_multiresult_steps.py @@ -247,8 +247,9 @@ def test_complex_field_of_subworkflow_with_dataclasses() -> None: def test_pair_can_be_returned_from_step() -> None: """Tests whether a task can insert result fields into other steps.""" - workflow = construct(algorithm_with_pair(), simplify_ids=True) - rendered = render(workflow)["__root__"] + with set_configuration(flatten_all_nested=True): + workflow = construct(algorithm_with_pair(), simplify_ids=True) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow From 
079c00bfb3c338f3d9fc451482edf73cd46ecbea Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 11 Aug 2024 03:24:24 +0100 Subject: [PATCH 027/108] fix(sympy): support none in dicts --- src/dewret/workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 28265908..5de0d759 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) -from .core import Reference, get_configuration, RawType, Raw +from .core import Reference, get_configuration, RawType from .utils import hasher, is_raw, make_traceback, is_raw_type, is_expr, Unset T = TypeVar("T") From 213c51437a1c058de299d49a64195c4ec8b1eaee Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 11 Aug 2024 03:34:26 +0100 Subject: [PATCH 028/108] fix: expressions should consider Raw as a raw type --- src/dewret/core.py | 57 +++++++++++++++++++++++------------------- src/dewret/utils.py | 4 +-- src/dewret/workflow.py | 2 +- 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index 962f79ca..fdb6f3f2 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -5,6 +5,10 @@ from contextvars import ContextVar from sympy import Expr, Symbol +BasicType = str | float | bool | bytes | int | None +RawType = Union[BasicType, list["RawType"], dict[str, "RawType"]] +FirmType = BasicType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] + U = TypeVar("U") BasicType = str | float | bool | bytes | int | None @@ -113,32 +117,6 @@ def __str__(self) -> str: """Global description of the reference.""" return self.__name__ -@dataclass -class Raw: - """Value object for any raw types. - - This is able to hash raw types consistently and provides - a single type for validating type-consistency. - - Attributes: - value: the real value, e.g. a `str`, `int`, ... - """ - - value: RawType - - def __hash__(self) -> int: - """Provide a hash that is unique to the `value` member.""" - return hash(repr(self)) - - def __repr__(self) -> str: - """Convert to a consistent, string representation.""" - value: str - if isinstance(self.value, bytes): - value = base64.b64encode(self.value).decode("ascii") - else: - value = str(self.value) - return f"{type(self.value).__name__}|{value}" - class IteratedGenerator(Generic[U]): __wrapped__: Reference[U] @@ -181,3 +159,30 @@ def __workflow__(self) -> WorkflowProtocol: @__workflow__.setter def __workflow__(self, workflow: WorkflowProtocol) -> None: self.__wrapped__.__workflow__ = workflow + + +@dataclass +class Raw: + """Value object for any raw types. + + This is able to hash raw types consistently and provides + a single type for validating type-consistency. + + Attributes: + value: the real value, e.g. a `str`, `int`, ... 
+ """ + + value: RawType + + def __hash__(self) -> int: + """Provide a hash that is unique to the `value` member.""" + return hash(repr(self)) + + def __repr__(self) -> str: + """Convert to a consistent, string representation.""" + value: str + if isinstance(self.value, bytes): + value = base64.b64encode(self.value).decode("ascii") + else: + value = str(self.value) + return f"{type(self.value).__name__}|{value}" diff --git a/src/dewret/utils.py b/src/dewret/utils.py index 8f246f50..bf9fc98d 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -25,7 +25,7 @@ from collections.abc import Sequence, Mapping from sympy import Basic, Integer, Float, Rational -from .core import Reference, BasicType, RawType, FirmType +from .core import Reference, BasicType, RawType, FirmType, Raw class Unset: @@ -102,7 +102,7 @@ def flatten(value: Any) -> RawType: return crawl_raw(value, lambda entry: entry) def is_expr(value: Any) -> bool: - return is_raw(value, lambda x: isinstance(x, Basic) or isinstance(x, tuple) or isinstance(x, Reference)) + return is_raw(value, lambda x: isinstance(x, Basic) or isinstance(x, tuple) or isinstance(x, Reference) or isinstance(x, Raw)) def is_raw_type(typ: type) -> bool: """Check if a type counts as "raw".""" diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 5de0d759..241040fd 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) -from .core import Reference, get_configuration, RawType +from .core import Reference, get_configuration, Raw, RawType from .utils import hasher, is_raw, make_traceback, is_raw_type, is_expr, Unset T = TypeVar("T") From 08db69f5493bf46b0b97770643630245f7071069 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 11 Aug 2024 03:37:05 +0100 Subject: [PATCH 029/108] fix: sorted should not break for two identical steps --- src/dewret/workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 241040fd..50e75bdc 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -423,7 +423,7 @@ def _indexed_steps(self) -> dict[str, BaseStep]: Returns: Mapping of steps by ID. 
""" - return OrderedDict(sorted((step.id, step) for step in self.steps)) + return OrderedDict(sorted(((step.id, step) for step in self.steps), key=lambda x: x[0])) @classmethod def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": From 3dd35d17f07f027979ced2a9fd42003cce011e28 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 13 Aug 2024 00:52:52 +0100 Subject: [PATCH 030/108] fix: efficiency improvements --- src/dewret/core.py | 4 +- src/dewret/renderers/cwl.py | 17 ++++-- src/dewret/tasks.py | 6 +- src/dewret/workflow.py | 117 ++++++++++++++++-------------------- tests/test_cwl.py | 16 ++--- tests/test_nested.py | 2 +- 6 files changed, 81 insertions(+), 81 deletions(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index fdb6f3f2..1ed57746 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -61,8 +61,10 @@ class Reference(Generic[U], Symbol): def __init__(self, *args, typ: type[U] | None = None, **kwargs): self._type = typ super().__init__() - self.name = self.__root_name__ + @property + def name(self): + return self.__root_name__ def __new__(cls, *args, **kwargs): instance = Expr.__new__(cls) diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index e03296ca..35147d33 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -22,10 +22,10 @@ from dataclasses import is_dataclass, fields as dataclass_fields from collections.abc import Mapping from contextvars import ContextVar -from typing import TypedDict, NotRequired, get_origin, get_args, Union, cast, Any, Iterable, Unpack +from typing import TypedDict, NotRequired, get_origin, get_args, Union, cast, Any, Unpack, Iterable from types import UnionType from inspect import isclass -from sympy import Expr, Basic, Tuple +from sympy import Expr, Basic, Tuple, sympify, Dict, jscode from dewret.core import ( Raw, @@ -68,6 +68,15 @@ class CommandInputSchema(TypedDict): str, CommandInputSchema, list[str], list["InputSchemaType"], dict[str, "str | InputSchemaType"] ] +def render_expression(ref: Any) -> str: + def _render(ref): + if not isinstance(ref, Basic): + if isinstance(ref, Mapping): + ref = Dict({key: _render(val) for key, val in ref.items()}) + elif not isinstance(ref, str | bytes) and isinstance(ref, Iterable): + ref = Tuple(*(_render(val) for val in ref)) + return ref + return f"$({jscode(_render(ref))})" class CWLRendererConfiguration(TypedDict): """Configuration for the renderer. 
@@ -193,11 +202,11 @@ def render(self) -> dict[str, RawType]: key: ( ref.render() if isinstance(ref, ReferenceDefinition) else - {"expression": f"$({ref})"} + {"expression": render_expression(ref)} if isinstance(ref, Basic) else {"default": ref.value} if hasattr(ref, "value") - else ref + else {"expression": render_expression(ref)} ) for key, ref in self.in_.items() }, diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 0920036e..fe89a146 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -426,9 +426,13 @@ def add_numbers(left: int, right: int): positional_args[key] = True sig.bind(**kwargs) + def _to_param_ref(value): + if isinstance(value, Parameter): + return ParameterReference(workflow=__workflow__, parameter=value) + refs = [] for key, val in kwargs.items(): - _, kw_refs = expr_to_references(val, include_parameters=True) + _, kw_refs = expr_to_references(val, remap=_to_param_ref) refs += kw_refs workflows = [ reference.__workflow__ diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 50e75bdc..0ae846ad 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -40,30 +40,6 @@ RetType = TypeVar("RetType") -def all_references_from(value: Any): - all_references: set = set() - - # If Raw, we examine the internal value. - # In theory, this should not contain a reference, - # but this makes all_references_from useful for error-checking. - if isinstance(value, Raw): - value = value.value - - if isinstance(value, Reference): - all_references.add(value) - elif isinstance(value, Basic): - symbols = value.free_symbols - if not all(isinstance(sym, Reference) for sym in symbols): - raise RuntimeError("Can only use symbols that are references to e.g. step or parameter.") - all_references |= symbols - elif isinstance(value, Mapping): - all_references |= all_references_from(value.keys()) - all_references |= all_references_from(value.values()) - elif isinstance(value, Iterable) and not isinstance(value, str | bytes): - all_references |= set().union(*(all_references_from(entry) for entry in value)) - - return all_references - class Lazy(Protocol): """Requirements for a lazy-evaluatable function.""" @@ -197,8 +173,6 @@ def __type__(self): return raw_type def __eq__(self, other): - if isinstance(other, ParameterReference) and other._.parameter == self and not other.__field__: - return True return hash(self) == hash(other) def __new__(cls, *args, **kwargs): @@ -407,7 +381,7 @@ def find_parameters( Returns: Set of all references to parameters across the steps. 
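# NOTE (illustrative, not part of this patch): the contract of the new remap=
# hook used by _to_param_ref above — expr_to_references walks the value and,
# for each node, remap may return a replacement (which is walked in turn) or
# None to leave the node untouched; returning a Reference ends the rewrite
# for that node, since references are collected rather than remapped again.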
""" - references = all_references_from( + _, references = expr_to_references( step.arguments for step in self.steps if (include_factory_calls or not isinstance(step, FactoryCall)) ) return {ref for ref in references if isinstance(ref, ParameterReference)} @@ -439,6 +413,9 @@ def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": left: workflow to use as base right: workflow to combine on top """ + if left == right: + return left + new = cls() new._name = left._name or right._name @@ -448,8 +425,6 @@ def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": for step in list(left_steps.values()) + list(right_steps.values()): step.set_workflow(new) - for arg in step.arguments.values(): - unify_workflows(arg, new, set_only=True) for step_id in left_steps.keys() & right_steps.keys(): if left_steps[step_id] != right_steps[step_id]: @@ -830,15 +805,10 @@ def __init__( else: value = Raw(value) - expr, refs = expr_to_references(value, include_parameters=True) - if expr is not None: - for ref in set(refs): - if isinstance(ref, Parameter): - new_ref = ParameterReference(workflow=workflow, parameter=ref) - expr = expr.subs(ref, new_ref) - refs.remove(ref) - refs.append(new_ref) - value = expr + def _to_param_ref(value): + if isinstance(value, Parameter): + return ParameterReference(workflow=workflow, parameter=value) + value, refs = expr_to_references(value, remap=_to_param_ref) for ref in refs: if isinstance(ref, ParameterReference): @@ -1133,7 +1103,6 @@ def __eq__(self, other: object) -> bool: """ # We are equal to a parameter if we are a direct, fieldless, reference to it. return ( - (isinstance(other, Parameter) and self._.parameter == other and not self.__field__) or (isinstance(other, ParameterReference) and self._.parameter == other._.parameter and self.__field__ == other.__field__) ) @@ -1290,42 +1259,58 @@ def is_task(task: Lazy) -> bool: """ return isinstance(task, LazyEvaluation) -def expr_to_references(expression: Any, include_parameters: bool=False) -> tuple[Basic | None, set[Reference | Parameter]]: - if isinstance(expression, Raw) or is_raw(expression): - return expression, set() - - if isinstance(expression, Reference): - return expression, {expression} +def expr_to_references(expression: Any, remap: Callable[[Any], Any] | None = None) -> tuple[Basic | None, set[Reference | Parameter]]: + to_check = [] + def _to_expr(value): + if remap and (res := remap(value)) is not None: + return _to_expr(res) - if is_dataclass(expression) or attr_has(expression): - refs = set() - fields = dataclass_fields(expression) if is_dataclass(expression) else {field.name for field in attrs_fields(expression)} - for field in fields: - if hasattr(expression, field.name) and isinstance((val := getattr(expression, field.name)), Reference): - _, field_refs = expr_to_references(val, include_parameters=include_parameters) - refs |= field_refs - return expression, refs + if isinstance(value, Reference): + to_check.append(value) + return value - def _to_expr(value): if value is None: - return nan - elif hasattr(value, "__type__"): + return None + + if isinstance(value, Symbol): + return value + elif isinstance(value, Basic): + for sym in value.free_symbols: + new_sym = _to_expr(sym) + if new_sym != sym: + value = value.subs(sym, new_sym) + return value + + if is_dataclass(value) or attr_has(value): + if is_dataclass(value): + fields = dataclass_fields(value) + else: + fields = {field for field in attrs_fields(value.__class__)} + for field in fields: + if hasattr(value, field.name) and 
isinstance((val := getattr(value, field.name)), Reference): + setattr(value, field.name, _to_expr(val)) return value - elif isinstance(value, Raw): - return value.value + + # We need to look inside a Raw, but we do not want to lose it if + # we do not need to. + retval = value + if isinstance(value, Raw): + value = value.value if isinstance(value, Mapping): - dct = Dict({key: _to_expr(val) for key, val in value.items()}) - return dct + dct = {key: _to_expr(val) for key, val in value.items()} + if dct == value: + return retval + return value.__class__(dct) elif not isinstance(value, str | bytes) and isinstance(value, Iterable): - return Tuple(*(_to_expr(entry) for entry in value)) - return value + lst = (tuple if isinstance(value, tuple) else list)(_to_expr(v) for v in value) + if lst == value: + return retval + return lst + return retval - if not isinstance(expression, Basic): - expression = _to_expr(expression) + expression = _to_expr(expression) - symbols = list(expression.free_symbols) - to_check = [sym for sym in symbols if isinstance(sym, Reference) or (include_parameters and isinstance(sym, Parameter))] #if {sym for sym in symbols if not is_raw(sym)} != set(to_check): # print(symbols, to_check, [type(r) for r in symbols]) # raise RuntimeError("The only symbols allowed are references (to e.g. step or parameter)") diff --git a/tests/test_cwl.py b/tests/test_cwl.py index 48078d87..043ce53d 100644 --- a/tests/test_cwl.py +++ b/tests/test_cwl.py @@ -305,7 +305,7 @@ def test_cwl_with_subworkflow() -> None: outputs: out: label: out - outputSource: sum-1-1/out + outputSource: sum-1-2/out type: - int - float @@ -317,7 +317,7 @@ def test_cwl_with_subworkflow() -> None: out: - out run: double - sum-1-2: + sum-1-1: in: left: source: double-1-1/out @@ -326,10 +326,10 @@ def test_cwl_with_subworkflow() -> None: out: - out run: sum - sum-1-1: + sum-1-2: in: left: - source: sum-1-2/out + source: sum-1-1/out right: default: 1 out: @@ -508,7 +508,7 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: outputs: out: label: out - outputSource: sum-1-1/out + outputSource: sum-1-2/out type: - int - float @@ -520,7 +520,7 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: out: - out run: double - sum-1-2: + sum-1-1: in: left: source: double-1-1/out @@ -529,10 +529,10 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: out: - out run: sum - sum-1-1: + sum-1-2: in: left: - source: sum-1-2/out + source: sum-1-1/out right: default: 1 out: diff --git a/tests/test_nested.py b/tests/test_nested.py index fb2d3c8e..fc556f8f 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -38,7 +38,7 @@ def test_can_supply_nested_raw(): max_list-1: in: lst: - expression: $(2*reverse_list-94ebd058f53d6a235643d33f3ab4c313) + expression: $(2*reverse_list-1) out: - out run: max_list From 7e8ab1ae5af9b97e33061d2b68a47396247ef96c Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 13 Aug 2024 00:54:27 +0100 Subject: [PATCH 031/108] fix: efficiency improvements --- src/dewret/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index 1ed57746..57397a9b 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -64,7 +64,7 @@ def __init__(self, *args, typ: type[U] | None = None, **kwargs): @property def name(self): - return self.__root_name__ + return self.__name__ def __new__(cls, *args, **kwargs): instance = Expr.__new__(cls) From 2aba6e6a26f7fb41c6e5f5e157f3b1bb4183a928 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 13 Aug 2024 00:57:23 
+0100 Subject: [PATCH 032/108] fix: efficiency improvements --- src/dewret/backends/backend_dask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index 74e8a8b3..4b859e59 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -94,7 +94,7 @@ def is_lazy(task: Any) -> bool: lazy = delayed -def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread_pool: ThreadPoolExecutor | None=None) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]: +def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread_pool: ThreadPoolExecutor | None=None, **kwargs: Any) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]: """Execute a task as the output of a workflow. Runs a task with dask. From 36c74f1c0a6a101c0bed389de78251de5ae0cee6 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 13 Aug 2024 01:09:35 +0100 Subject: [PATCH 033/108] feat: add construct args to CLI --- src/dewret/__main__.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index e1ec87e3..4193517c 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -49,13 +49,17 @@ default=Backend.DASK.name, help="Backend to use for workflow evaluation.", ) +@click.option( + "--construct-args", + default="simplify_ids:true" +) @click.option( "--renderer", default="cwl" ) @click.option( "--renderer-args", - default="simplify_ids:true" + default="" ) @click.option( "--output", @@ -65,7 +69,7 @@ @click.argument("task") @click.argument("arguments", nargs=-1) def render( - workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend, renderer: str, renderer_args: str, output: str + workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend, construct_args: str, renderer: str, renderer_args: str, output: str ) -> None: """Render a workflow. 
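# NOTE (illustrative, not part of this patch): an example invocation of the
# extended CLI; the workflow file and task name here are hypothetical.
#
#   python -m dewret --construct-args simplify_ids:true \
#       --renderer cwl workflow.py process num:3
#
# Both --construct-args and --renderer-args also accept @FILE, in which case
# the options are read from YAML, as the following hunk shows.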
@@ -91,6 +95,14 @@ def render( else: raise RuntimeError("Renderer argument should be a known dewret renderer, or '@FILENAME' where FILENAME is a renderer") + if construct_args.startswith("@"): + with Path(construct_args[1:]).open() as construct_args_f: + construct_kwargs = yaml.safe_load(construct_args_f) + elif not construct_args: + construct_kwargs = {} + else: + construct_kwargs = dict(pair.split(":") for pair in construct_args.split(",")) + renderer_kwargs: dict[str, Any] if renderer_args.startswith("@"): with Path(renderer_args[1:]).open() as renderer_args_f: @@ -121,7 +133,7 @@ def _opener(key, mode): task_fn = getattr(workflow, task) try: - rendered = render(construct(task_fn(**kwargs), **renderer_kwargs)) + rendered = render(construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs) except Exception as exc: import traceback From 0517594cc3d68c320f599089243cd91cd4ea5d7a Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Wed, 14 Aug 2024 23:45:22 +0100 Subject: [PATCH 034/108] fix: make field separator configurable --- src/dewret/core.py | 2 ++ src/dewret/tasks.py | 3 ++- src/dewret/workflow.py | 14 +++++++++----- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index 57397a9b..b11a8a99 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -26,6 +26,7 @@ class ConstructConfiguration(TypedDict): flatten_all_nested: NotRequired[bool] allow_positional_args: NotRequired[bool] allow_plain_dict_fields: NotRequired[bool] + field_separator: NotRequired[str] CONSTRUCT_CONFIGURATION: ContextVar[ConstructConfiguration] = ContextVar("construct-configuration") @@ -38,6 +39,7 @@ def set_configuration(**kwargs: Unpack[ConstructConfiguration]): flatten_all_nested=False, allow_positional_args=False, allow_plain_dict_fields=False, + field_separator="/" ) CONSTRUCT_CONFIGURATION.set({}) diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index fe89a146..ccc0ab82 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -432,8 +432,9 @@ def _to_param_ref(value): refs = [] for key, val in kwargs.items(): - _, kw_refs = expr_to_references(val, remap=_to_param_ref) + val, kw_refs = expr_to_references(val, remap=_to_param_ref) refs += kw_refs + kwargs[key] = val workflows = [ reference.__workflow__ for reference in refs diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 0ae846ad..b472265e 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -686,9 +686,13 @@ def name(self): # Subclass Reference so that we know Reference methods/attrs are available. class FieldableMixin: def __init__(self: FieldableProtocol, *args, field: str | None = None, **kwargs): - self.__field__: tuple[str, ...] = tuple(field.split("/")) if field else () + self.__field__: tuple[str, ...] = tuple(field.split(self.__field_sep__)) if field else () super().__init__(*args, **kwargs) + @property + def __field_sep__(self) -> str: + return get_configuration("field_separator") + @property def __name__(self: FieldableProtocol) -> str: """Name for this step. @@ -696,7 +700,7 @@ def __name__(self: FieldableProtocol) -> str: May be remapped by the workflow to something nicer than the ID. """ - return "/".join([super().__name__] + list(self.__field__)) + return self.__field_sep__.join([super().__name__] + list(self.__field__)) def find_field(self: FieldableProtocol, field, fallback_type: type | None = None, **init_kwargs: Any) -> Reference: """Field within the reference, if possible. 
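# NOTE (illustrative, not part of this patch): a sketch of the new
# field_separator option (default "/"), assuming a reference carrying the
# nested fields ("spam", "eggs").
from dewret.core import set_configuration

with set_configuration(field_separator="."):
    ...  # such a reference then renders as "step-1.spam.eggs" rather than "step-1/spam/eggs"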
@@ -738,7 +742,7 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None
             raise TypeError("Only references can have a fieldable mixin")

         if self.__field__:
-            field = "/".join(self.__field__) + "/" + field
+            field = self.__field_sep__.join(self.__field__) + self.__field_sep__ + field

         return self.__class__(typ=field_type, field=field, **init_kwargs)
@@ -1081,7 +1085,7 @@ def __repr__(self) -> str:
             typ = self.__type__.__name__
         except AttributeError:
             typ = str(self.__type__)
-        name = "/".join([self._.unique_name] + list(self.__field__))
+        name = self.__field_sep__.join([self._.unique_name] + list(self.__field__))
         return f"{typ}|:param:{name}"

     def __hash__(self) -> int:
@@ -1161,7 +1165,7 @@ def __str__(self) -> str:

     def __repr__(self) -> str:
         """Hashable reference to the step (and field)."""
-        return "/".join([self._.step.id] + list(self.__field__))
+        return self.__field_sep__.join([self._.step.id] + list(self.__field__))

     def __hash__(self) -> int:
         return hash((repr(self), id(self.__workflow__)))

From 9b5d03a633e9b2be2ad5010ae298f4e7a37d6c43 Mon Sep 17 00:00:00 2001
From: Phil Weir
Date: Wed, 14 Aug 2024 23:49:04 +0100
Subject: [PATCH 035/108] fix: construct kwargs used to set_configuration

---
 src/dewret/tasks.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py
index ccc0ab82..2aa7d343 100644
--- a/src/dewret/tasks.py
+++ b/src/dewret/tasks.py
@@ -66,7 +66,7 @@
 )
 from .backends._base import BackendModule
 from .annotations import FunctionAnalyser
-from .core import get_configuration, set_configuration, CONSTRUCT_CONFIGURATION, IteratedGenerator
+from .core import get_configuration, set_configuration, CONSTRUCT_CONFIGURATION, IteratedGenerator, ConstructConfiguration

 Param = ParamSpec("Param")
 RetType = TypeVar("RetType")
@@ -195,7 +195,7 @@ def __call__(
         self,
         task: Any,
         simplify_ids: bool = False,
         __workflow__: Workflow | None = None,
-        **kwargs: Any,
+        **kwargs: ConstructConfiguration,
     ) -> Workflow:
         """Execute the lazy evaluation.
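# NOTE (illustrative, not part of this patch): with this change, construct()
# forwards its keyword arguments into set_configuration, so overrides can be
# supplied at the call site — a sketch, assuming `result` is a lazy task call:
workflow = construct(result, simplify_ids=True, flatten_all_nested=True)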
@@ -209,7 +209,7 @@ def __call__( """ workflow = __workflow__ or Workflow() - with set_configuration(): + with set_configuration(**kwargs): context = copy_context().items() def _initializer(): for var, value in context: From 77c46bf84dcab9ae9ea6e5c848e5110d45ed36c8 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Thu, 15 Aug 2024 00:00:16 +0100 Subject: [PATCH 036/108] fix: construct kwargs used to set_configuration --- src/dewret/__main__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 4193517c..7f9d5f6a 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -30,6 +30,7 @@ import click import json +from .core import set_configuration from .render import get_render_method, RawRenderModule, StructuredRenderModule from .tasks import Backend, construct @@ -133,7 +134,8 @@ def _opener(key, mode): task_fn = getattr(workflow, task) try: - rendered = render(construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs) + with set_configuration(**construct_kwargs): + rendered = render(construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs) except Exception as exc: import traceback From 19f1a384b43ca42438ebe6c88fbf8dbcb798dd89 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Thu, 15 Aug 2024 00:02:50 +0100 Subject: [PATCH 037/108] fix: ignore annotations when looking at fieldability --- src/dewret/workflow.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index b472265e..e6123625 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -711,6 +711,9 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None # Get new type, for the specific field. parent_type = self.__type__ + # Strip out any annotations. 
+ while get_origin(parent_type) is Annotated: + parent_type = get_args(parent_type)[0] field_type = fallback_type if is_dataclass(parent_type): From ca47bb81c25b07bbafc911e333d5446c1be6c98a Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 17 Aug 2024 17:35:34 +0100 Subject: [PATCH 038/108] feat(imports): retrieve annotations from imported modules --- src/dewret/annotations.py | 61 ++++++++++++++++++++++++++------------ src/dewret/render.py | 13 ++++++-- src/dewret/tasks.py | 48 +++++++++++++++++------------- tests/_lib/extra.py | 13 ++++++++ tests/_lib/frender.py | 2 +- tests/_lib/other.py | 3 ++ tests/test_annotations.py | 40 +++++++++++++++---------- tests/test_errors.py | 2 +- tests/test_subworkflows.py | 4 +-- 9 files changed, 125 insertions(+), 61 deletions(-) create mode 100644 tests/_lib/other.py diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index d0c69e4f..0c85665b 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -1,4 +1,7 @@ import inspect +import ast +import sys +import importlib from functools import lru_cache from types import FunctionType from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Callable, get_origin, get_args, Mapping @@ -19,17 +22,6 @@ def __init__(self, fn: Callable[..., Any]): fn ) - @property - def all_annotations(self): - try: - self._annotations = self.fn.__globals__["__annotations__"] - except KeyError: - self._annotations = {} - - self._annotations.update(self.fn.__annotations__) - - return self._annotations - @property def return_type(self): return inspect.signature(inspect.unwrap(self.fn)).return_annotation @@ -45,15 +37,46 @@ def _typ_has(typ: type, annotation: type) -> bool: return True return False - def argument_has(self, arg: str, annotation: type) -> bool: - if arg in self.all_annotations: - typ = self.all_annotations[arg] - if self._typ_has(typ, annotation): - return True - return False + def get_all_module_names(self): + return sys.modules[self.fn.__module__].__annotations__ + + def get_all_imported_names(self): + return self._get_all_imported_names(sys.modules[self.fn.__module__]) + + @staticmethod + @lru_cache + def _get_all_imported_names(mod): + ast_tree = ast.parse(inspect.getsource(mod)) + imported_names = {} + for node in ast.walk(ast_tree): + if isinstance(node, ast.ImportFrom): + for name in node.names: + imported_names[name.asname or name.name] = ( + importlib.import_module("".join(["."] * node.level) + node.module, package=mod.__package__), + name.name + ) + return imported_names + + def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> type | None: + all_annotations: dict[str, type] = {} + typ: type | None = None + if (typ := self.fn.__annotations__.get(arg)): + ... + elif exhaustive: + if "__annotations__" in self.fn.__globals__: + if (typ := self.fn.__globals__["__annotations__"].get(arg)): + ... 
+ elif (orig_pair := self.get_all_imported_names().get(arg)): + orig_module, orig_name = orig_pair + typ = orig_module.__annotations__.get(orig_name) + return typ + + def argument_has(self, arg: str, annotation: type, exhaustive: bool=False) -> bool: + typ = self.get_argument_annotation(arg, exhaustive) + return bool(typ and self._typ_has(typ, annotation)) - def is_at_construct_arg(self, arg: str) -> bool: - return self.argument_has(arg, AtRender) + def is_at_construct_arg(self, arg: str, exhaustive: bool=False) -> bool: + return self.argument_has(arg, AtRender, exhaustive) @property def globals(self) -> Mapping[str, Any]: diff --git a/src/dewret/render.py b/src/dewret/render.py index 2fb69ffe..e4f6f53d 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -31,8 +31,17 @@ def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule, if isinstance(renderer, Path): if (render_dir := str(renderer.parent)) not in sys.path: sys.path.append(render_dir) - loader = importlib.machinery.SourceFileLoader("renderer", str(renderer)) - render_module = loader.load_module() + package_init = renderer.parent + + # Attempt to load renderer as package, falling back to a single module otherwise. + # This enables relative imports in renderers and therefore the ability to modularize. + try: + loader = importlib.machinery.SourceFileLoader("renderer", str(package_init / "__init__.py")) + sys.modules["renderer"] = loader.load_module(f"renderer") + render_module = importlib.import_module(f"renderer.{renderer.stem}", "renderer") + except ImportError: + loader = importlib.machinery.SourceFileLoader("renderer", str(renderer)) + render_module = loader.load_module() else: render_module = renderer if hasattr(render_module, "render_raw"): diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 2aa7d343..a5739d63 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -67,6 +67,7 @@ from .backends._base import BackendModule from .annotations import FunctionAnalyser from .core import get_configuration, set_configuration, CONSTRUCT_CONFIGURATION, IteratedGenerator, ConstructConfiguration +import ast Param = ParamSpec("Param") RetType = TypeVar("RetType") @@ -468,7 +469,7 @@ def _to_param_ref(value): ) else None ), autoname=True, - typ=analyser.all_annotations.get(var, UNSET) + typ=analyser.get_argument_annotation(var) or UNSET ), ) original_kwargs = dict(kwargs) @@ -481,27 +482,8 @@ def _to_param_ref(value): # raise TypeError( # "Captured parameter {var} (global variable in task) shadows an argument" # ) - if ( - analyser.is_at_construct_arg(var) or - isinstance(value, Reference) or - value is evaluate or value is construct # Allow manual building. - ): - kwargs[var] = value - elif isinstance(value, Parameter): + if isinstance(value, Parameter): kwargs[var] = ParameterReference(workflow=workflow, parameter=value) - elif is_raw(value) or ( - (attrs_has(value) or is_dataclass(value)) and - not inspect.isclass(value) - ): - kwargs[var] = ParameterReference( - workflow=workflow, - parameter=param( - var, - value, - tethered=False, - typ=analyser.all_annotations.get(var, UNSET) - ) - ) elif is_task(value) or ensure_lazy(value) is not None: if not nested and _workaround_check_value_is_task( fn, var, value @@ -522,6 +504,29 @@ def {fn.__name__}(...) -> ...: ... """ ) + # If nested, we will execute the insides, and it is reasonable and important + # to have a full set of annotations for any encountered variables. 
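# NOTE (illustrative, not part of this patch): the exhaustive lookup order
# implemented above in get_argument_annotation — (1) the function's own
# __annotations__, (2) module-level __annotations__ in the defining module,
# (3) annotations on names imported from other modules, resolved via the AST
# of the import statements. A module-level `CONSTANT: int = 3` is therefore
# visible even though the task body never annotates it locally.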
+            elif nested and not analyser.get_argument_annotation(var, exhaustive=True) and not (inspect.isclass(value) or inspect.isfunction(value)):
+                raise RuntimeError(f"Could not find a type annotation for {var} for {fn.__name__}")
+            elif (
+                analyser.is_at_construct_arg(var, exhaustive=True) or
+                isinstance(value, Reference) or
+                value is evaluate or value is construct  # Allow manual building.
+            ):
+                kwargs[var] = value
+            elif is_raw(value) or (
+                (attrs_has(value) or is_dataclass(value)) and
+                not inspect.isclass(value)
+            ):
+                kwargs[var] = ParameterReference(
+                    workflow=workflow,
+                    parameter=param(
+                        var,
+                        value,
+                        tethered=False,
+                        typ=analyser.get_argument_annotation(var, exhaustive=True) or UNSET
+                    )
+                )
             elif nested:
                 raise NotImplementedError(
                     f"Nested tasks must now only refer to global parameters, raw or tasks, not objects: {var}"
                 )
@@ -590,6 +595,7 @@ def {fn.__name__}(...) -> ...:
             configuration.__exit__(None, None, None)

     _fn.__step_expression__ = True  # type: ignore
+    _fn.__original__ = fn
     return LazyEvaluation(_fn)

     return _task
diff --git a/tests/_lib/extra.py b/tests/_lib/extra.py
index ac18c8a0..bb663835 100644
--- a/tests/_lib/extra.py
+++ b/tests/_lib/extra.py
@@ -1,7 +1,20 @@
 from dewret.tasks import task, subworkflow
+from dewret.annotations import AtRender
+
+from .other import nothing

 JUMP: float = 1.0
+test: float = nothing
+
+from inspect import get_annotations
+
+@subworkflow()
+def try_nothing() -> int:
+    """Check that we can see AtRender in another module."""
+    if nothing:
+        return increment(num=1)
+    return increment(num=0)

 @task()
 def increase(num: int | float) -> float:
diff --git a/tests/_lib/frender.py b/tests/_lib/frender.py
index 6f81fea3..6e92573f 100644
--- a/tests/_lib/frender.py
+++ b/tests/_lib/frender.py
@@ -12,7 +12,7 @@
 from dewret.workflow import Workflow, Step, NestedStep
 from dewret.render import base_render

-from extra import JUMP
+from .extra import JUMP

 class FrenderRendererConfiguration(TypedDict):
     allow_complex_types: bool
diff --git a/tests/_lib/other.py b/tests/_lib/other.py
new file mode 100644
index 00000000..891fbe56
--- /dev/null
+++ b/tests/_lib/other.py
@@ -0,0 +1,3 @@
+from dewret.annotations import AtRender
+
+nothing: AtRender[bool] = True
diff --git a/tests/test_annotations.py b/tests/test_annotations.py
index 9d8f323f..8de4a451 100644
--- a/tests/test_annotations.py
+++ b/tests/test_annotations.py
@@ -5,7 +5,7 @@
 from dewret.renderers.cwl import render
 from dewret.annotations import AtRender, FunctionAnalyser

-from ._lib.extra import increment, sum
+from ._lib.extra import increment, sum, try_nothing

 ARG1: AtRender[bool] = True
 ARG2: bool = False
@@ -36,22 +36,24 @@ def test_can_analyze_annotations():
     my_obj = MyClass()
     analyser = FunctionAnalyser(my_obj.method)
-    assert analyser.argument_has("arg1", AtRender) is False
-    assert analyser.argument_has("arg3", AtRender) is False
-    assert analyser.argument_has("ARG2", AtRender) is False
-    assert analyser.argument_has("arg2", AtRender) is True
-    assert analyser.argument_has("arg4", AtRender) is False # Not a global/argument
-    assert analyser.argument_has("ARG1", AtRender) is True
+    assert analyser.argument_has("arg1", AtRender, exhaustive=True) is False
+    assert analyser.argument_has("arg3", AtRender, exhaustive=True) is False
+    assert analyser.argument_has("ARG2", AtRender, exhaustive=True) is False
+    assert analyser.argument_has("arg2", AtRender, exhaustive=True) is True
+    assert analyser.argument_has("arg4", AtRender, exhaustive=True) is False # Not a global/argument
+    assert
analyser.argument_has("ARG1", AtRender, exhaustive=True) is True + assert analyser.argument_has("ARG1", AtRender) is False analyser = FunctionAnalyser(fn) - assert analyser.argument_has("arg5", AtRender) is False - assert analyser.argument_has("arg7", AtRender) is False - assert analyser.argument_has("ARG2", AtRender) is False - assert analyser.argument_has("arg2", AtRender) is True - assert analyser.argument_has("arg8", AtRender) is False # Not a global/argument - assert analyser.argument_has("ARG1", AtRender) is True - -def test_at_construct() -> None: + assert analyser.argument_has("arg5", AtRender, exhaustive=True) is False + assert analyser.argument_has("arg7", AtRender, exhaustive=True) is False + assert analyser.argument_has("ARG2", AtRender, exhaustive=True) is False + assert analyser.argument_has("arg6", AtRender, exhaustive=True) is True + assert analyser.argument_has("arg8", AtRender, exhaustive=True) is False # Not a global/argument + assert analyser.argument_has("ARG1", AtRender, exhaustive=True) is True + assert analyser.argument_has("ARG1", AtRender) is False + +def test_at_render() -> None: with pytest.raises(TaskException) as _: result = to_int_bad(num=increment(num=3), should_double=True) workflow = construct(result, simplify_ids=True) @@ -131,3 +133,11 @@ def test_at_construct() -> None: - out run: to_int """) + + +def test_at_render_between_modules() -> None: + nothing = False + result = try_nothing() + workflow = construct(result, simplify_ids=True) + subworkflows = render(workflow, allow_complex_types=True) + rendered = subworkflows["__root__"] diff --git a/tests/test_errors.py b/tests/test_errors.py index 96ecf108..7defe10b 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -14,7 +14,7 @@ def add_task(left: int, right: int) -> int: return left + right -ADD_TASK_LINE_NO = 11 +ADD_TASK_LINE_NO: int = 11 @subworkflow() diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index 9b3ecbb8..163a8e84 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -9,11 +9,11 @@ from ._lib.extra import increment, sum, pi -CONSTANT = 3 +CONSTANT: int = 3 QueueFactory: Callable[..., Queue[int]] = factory(Queue) -GLOBAL_QUEUE = QueueFactory() +GLOBAL_QUEUE: Queue = QueueFactory() @task() From 4bdcd8be2b96db97a978a6d822045b8feae176f4 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 17 Aug 2024 23:41:51 +0100 Subject: [PATCH 039/108] feat(fixed): add loopable parameters --- src/dewret/annotations.py | 17 +++++++- src/dewret/core.py | 31 ++++++++++---- src/dewret/renderers/cwl.py | 34 +++++++++------ src/dewret/tasks.py | 67 +++++++++++++++--------------- src/dewret/utils.py | 10 ++++- src/dewret/workflow.py | 82 ++++++++++++++++++++++++++----------- tests/test_annotations.py | 32 ++++++++++++++- tests/test_errors.py | 44 ++++++++++++++++++-- tests/test_fieldable.py | 2 +- tests/test_subworkflows.py | 6 +++ 10 files changed, 237 insertions(+), 88 deletions(-) diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index 0c85665b..01c375a8 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -4,10 +4,12 @@ import importlib from functools import lru_cache from types import FunctionType -from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Callable, get_origin, get_args, Mapping +from dataclasses import dataclass +from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Callable, get_origin, get_args, Mapping, TypeAliasType T = 
TypeVar("T")
AtRender = Annotated[T, "AtRender"]
+Fixed = Annotated[T, "Fixed"]

 class FunctionAnalyser:
     _fn: Callable[..., Any]
@@ -57,6 +59,12 @@ def _get_all_imported_names(mod):
             )
         return imported_names

+    @property
+    def free_vars(self):
+        if self.fn.__code__ and self.fn.__closure__:
+            return dict(zip(self.fn.__code__.co_freevars, (c.cell_contents for c in self.fn.__closure__)))
+        return {}
+
     def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> type | None:
         all_annotations: dict[str, type] = {}
         typ: type | None = None
@@ -69,6 +77,9 @@ def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> type | No
         elif (orig_pair := self.get_all_imported_names().get(arg)):
             orig_module, orig_name = orig_pair
             typ = orig_module.__annotations__.get(orig_name)
+        elif (value := self.free_vars.get(arg)):
+            if not (inspect.isclass(value) or inspect.isfunction(value)):
+                raise RuntimeError(f"Cannot use free variables - please put {arg} at the global scope")
         return typ

     def argument_has(self, arg: str, annotation: type, exhaustive: bool=False) -> bool:
         typ = self.get_argument_annotation(arg, exhaustive)
         return bool(typ and self._typ_has(typ, annotation))

     @property
     def globals(self) -> Mapping[str, Any]:
         try:
-            fn_globals = inspect.getclosurevars(self.fn).globals
+            fn_tuple = inspect.getclosurevars(self.fn)
+            fn_globals = dict(fn_tuple.globals)
+            fn_globals.update(fn_tuple.nonlocals)
         # This covers the case of wrapping, rather than decorating.
         except TypeError:
             fn_globals = {}
diff --git a/src/dewret/core.py b/src/dewret/core.py
index b11a8a99..f8dde4fd 100644
--- a/src/dewret/core.py
+++ b/src/dewret/core.py
@@ -1,4 +1,4 @@
-from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union
+from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any
 from dataclasses import dataclass
 import base64
 from contextlib import contextmanager
 from contextvars import ContextVar
 from sympy import Expr, Symbol
@@ -99,9 +99,6 @@ def __float__(self) -> bool:
         return False

-    def __iter__(self) -> Iterator["Reference"]:
-        yield IteratedGenerator(self)
-
     def __int__(self) -> bool:
         self._raise_unevaluatable_error()
         return False
@@ -121,6 +118,21 @@ def __str__(self) -> str:
         """Global description of the reference."""
         return self.__name__

+class IterableMixin(Reference[U]):
+    def __iter__(self):
+        count = -1
+        for _ in self.__inner_iter__():
+            yield Iterated(to_wrap=self, iteration=(count := count + 1))
+
+    def __inner_iter__(self) -> Generator[Any, None, None]:
+        while True:
+            yield None
+
+    def __getitem__(self, attr: str | int) -> Reference[U]:
+        if isinstance(attr, int):
+            return Iterated(to_wrap=self, iteration=attr)
+        return super().__getitem__(attr)
+
 class IteratedGenerator(Generic[U]):
     __wrapped__: Reference[U]

@@ -129,7 +141,7 @@ def __init__(self, to_wrap: Reference[U]):

     def __iter__(self):
         count = -1
-        while True:
+        for _ in self.__wrapped__.__inner_iter__():
             yield Iterated(to_wrap=self.__wrapped__, iteration=(count := count + 1))


@@ -142,6 +154,10 @@ def __init__(self, to_wrap: Reference[U], iteration: int, *args, **kwargs):
         self.__iteration__ = iteration
         super().__init__(*args, **kwargs)

+    @property
+    def _(self):
+        return self.__wrapped__._
+
     @property
     def __root_name__(self) -> str:
         return f"{self.__wrapped__.__root_name__}[{self.__iteration__}]"
@@ -153,8 +169,9 @@ def __type__(self) -> type:
     def __hash__(self) -> int:
         return hash(self.__root_name__)

-    def __field__(self) -> str:
-        return str(self.__iteration__)
+    @property
+    def
__field__(self) -> tuple[str]: + return tuple(list(self.__wrapped__.__field__) + [str(self.__iteration__)]) @property def __workflow__(self) -> WorkflowProtocol: diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 35147d33..7904ff0f 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -40,6 +40,7 @@ StepReference, ParameterReference, Unset, + expr_to_references ) from dewret.utils import flatten, DataclassProtocol, firm_to_raw, flatten_if_set from dewret.render import base_render @@ -494,19 +495,28 @@ def from_results( Returns: CWL-like structure representing all workflow outputs. """ - return cls( - outputs=[ - to_output_schema( - with_field(result), result.__type__, output_source=to_name(result) - ) for result in results - ] - if isinstance(results, list | tuple | Tuple) else { - key: to_output_schema( - with_field(result), result.__type__, output_source=to_name(result) + def _build_results(result): + if isinstance(result, Reference): + return to_output_schema( + with_field(result), with_type(result), output_source=to_name(result) ) - for key, result in results.items() - } - ) + results = result + return ( + [ + _build_results(result) for result in results + ] if isinstance(results, list | tuple | Tuple) else { + key: _build_results(result) for key, result in results.items() + } + ) + try: + return cls(outputs=_build_results(results)) + except AttributeError: + expr, references = expr_to_references(results) + references = sorted({str(ref._.parameter) for ref in references}) + return cls(outputs={ + "expression": str(expr), + "source": references + }) def render(self) -> dict[str, RawType] | list[RawType]: """Render to a dict-like structure. diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index a5739d63..57e04270 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -45,14 +45,13 @@ from contextvars import ContextVar, copy_context from contextlib import contextmanager -from .utils import is_raw, make_traceback, is_expr +from .utils import is_raw, make_traceback, is_expr, is_raw_type from .workflow import ( expr_to_references, unify_workflows, UNSET, Reference, StepReference, - ParameterReference, Workflow, Lazy, LazyEvaluation, @@ -429,7 +428,7 @@ def add_numbers(left: int, right: int): def _to_param_ref(value): if isinstance(value, Parameter): - return ParameterReference(workflow=__workflow__, parameter=value) + return value.make_reference(workflow=__workflow__) refs = [] for key, val in kwargs.items(): @@ -458,20 +457,17 @@ def _to_param_ref(value): elif is_raw(value): # We leave this reference dangling for a consumer to pick up ("tethered"), unless # we are in a nested task, that does not have any existence of its own. 
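# NOTE (illustrative, not part of this patch): the pattern the following
# hunks converge on — the Parameter picks its own reference class, so an
# iterable-typed parameter yields an iterable reference. Names as in the diff:
ref = param("nums", [1, 2, 3], tethered=False, typ=list[int]).make_reference(workflow=workflow)
# is_loopable inspects the annotation-stripped type: list[int] gives an
# IterableParameterReference, while a plain int gives a ParameterReference.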
- kwargs[var] = ParameterReference( - workflow=workflow, - parameter=param( - var, - value, - tethered=( - False if nested and ( - flatten_nested or get_configuration("flatten_all_nested") - ) else None - ), - autoname=True, - typ=analyser.get_argument_annotation(var) or UNSET + kwargs[var] = param( + var, + value, + tethered=( + False if nested and ( + flatten_nested or get_configuration("flatten_all_nested") + ) else None ), - ) + autoname=True, + typ=analyser.get_argument_annotation(var) or UNSET + ).make_reference(workflow=workflow) original_kwargs = dict(kwargs) fn_globals = analyser.globals @@ -483,7 +479,7 @@ def _to_param_ref(value): # "Captured parameter {var} (global variable in task) shadows an argument" # ) if isinstance(value, Parameter): - kwargs[var] = ParameterReference(workflow=workflow, parameter=value) + kwargs[var] = value.make_reference(workflow=workflow) elif is_task(value) or ensure_lazy(value) is not None: if not nested and _workaround_check_value_is_task( fn, var, value @@ -514,19 +510,22 @@ def {fn.__name__}(...) -> ...: value is evaluate or value is construct # Allow manual building. ): kwargs[var] = value + elif ( + inspect.isclass(value) or + inspect.isfunction(value) + ): + # We assume these are loaded at runtime. + ... elif is_raw(value) or ( (attrs_has(value) or is_dataclass(value)) and not inspect.isclass(value) ): - kwargs[var] = ParameterReference( - workflow=workflow, - parameter=param( - var, - value, - tethered=False, - typ=analyser.get_argument_annotation(var, exhaustive=True) or UNSET - ) - ) + kwargs[var] = param( + var, + default=value, + tethered=False, + typ=analyser.get_argument_annotation(var, exhaustive=True) or UNSET + ).make_reference(workflow=workflow) elif nested: raise NotImplementedError( f"Nested tasks must now only refer to global parameters, raw or tasks, not objects: {var}" @@ -546,16 +545,14 @@ def {fn.__name__}(...) -> ...: else: nested_workflow = Workflow(name=fn.__name__) nested_globals: Param.kwargs = { - var: ParameterReference( - workflow=nested_workflow, - parameter=param( - var, - typ=( - value.__type__ - ), - tethered=nested_workflow + var: param( + var, + default=value.__default__ if hasattr(value, "__default__") else UNSET, + typ=( + value.__type__ ), - ) if isinstance(value, Reference) else value + tethered=nested_workflow + ).make_reference(workflow=nested_workflow) if isinstance(value, Reference) else value for var, value in kwargs.items() } nested_kwargs = {key: value for key, value in nested_globals.items() if key in original_kwargs} diff --git a/src/dewret/utils.py b/src/dewret/utils.py index bf9fc98d..4dc921f0 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -21,7 +21,7 @@ import json import sys from types import FrameType, TracebackType, UnionType -from typing import Any, cast, Union, Protocol, ClassVar, Callable, Iterable, get_args +from typing import Any, cast, Union, Protocol, ClassVar, Callable, Iterable, get_args, get_origin, Annotated from collections.abc import Sequence, Mapping from sympy import Basic, Integer, Float, Rational @@ -104,6 +104,14 @@ def flatten(value: Any) -> RawType: def is_expr(value: Any) -> bool: return is_raw(value, lambda x: isinstance(x, Basic) or isinstance(x, tuple) or isinstance(x, Reference) or isinstance(x, Raw)) +def strip_annotations(parent_type: type) -> tuple[type, tuple]: + # Strip out any annotations. This should be auto-flattened, so in theory only one iteration could occur. 
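# NOTE (illustrative, not part of this patch): the intended contract of
# strip_annotations, continuing the comment above —
#   strip_annotations(Annotated[int, "Fixed"]) == (int, ("Fixed",))
# and, since Python auto-flattens nested Annotated,
#   strip_annotations(Annotated[Annotated[int, "A"], "B"]) == (int, ("A", "B"))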
+ metadata = [] + while get_origin(parent_type) is Annotated: + parent_type, *parent_metadata = get_args(parent_type) + metadata += list(parent_metadata) + return parent_type, tuple(metadata) + def is_raw_type(typ: type) -> bool: """Check if a type counts as "raw".""" if isinstance(typ, UnionType): diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index e6123625..7ac8a229 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -24,7 +24,7 @@ from attrs import has as attr_has, resolve_types, fields as attrs_fields from dataclasses import is_dataclass, fields as dataclass_fields from collections import Counter, OrderedDict -from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Iterable, get_origin, get_args +from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Iterable, get_origin, get_args, Generator, Sized from uuid import uuid4 from sympy import Symbol, Expr, Basic, Tuple, Dict, nan @@ -32,8 +32,8 @@ logger = logging.getLogger(__name__) -from .core import Reference, get_configuration, Raw, RawType -from .utils import hasher, is_raw, make_traceback, is_raw_type, is_expr, Unset +from .core import RawType, IterableMixin, Reference, get_configuration, Raw, IteratedGenerator +from .utils import hasher, is_raw, make_traceback, is_raw_type, is_expr, Unset, strip_annotations T = TypeVar("T") U = TypeVar("U") @@ -156,6 +156,11 @@ def __init__( if tethered and isinstance(tethered, BaseStep): self.register_caller(tethered) + @property + def is_loopable(self): + base = get_origin(strip_annotations(self.__type__)[0]) + return inspect.isclass(base) and issubclass(base, Iterable) and not issubclass(base, str | bytes) + @property def __type__(self): if self.__fixed_type__ is not UNSET: @@ -188,6 +193,11 @@ def __hash__(self) -> int: # ) return hash(self.__name__) + def make_reference(self, workflow: Workflow) -> "ParameterReference": + if self.is_loopable: + return IterableParameterReference(workflow=workflow, parameter=self) + return ParameterReference(workflow=workflow, parameter=self) + @property def default(self) -> T | UnsetType[T]: """Retrieve default value for this parameter, or an unset token.""" @@ -544,7 +554,7 @@ def add_nested_step( return_type = return_type or step.return_type if return_type is inspect._empty: raise TypeError("All tasks should have a type annotation.") - return StepReference(step, typ=return_type) + return step.make_reference(return_type=return_type) def add_step( self, @@ -578,7 +588,7 @@ def add_step( and not inspect.isclass(fn) ): raise TypeError("All tasks should have a type annotation.") - return StepReference(step, return_type) + return step.make_reference(return_type) @staticmethod def from_result( @@ -710,10 +720,7 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None """ # Get new type, for the specific field. - parent_type = self.__type__ - # Strip out any annotations. 
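
# Illustrative sketch - not part of the patch: behaviour of the
# strip_annotations() helper added above, and of is_loopable(), which uses
# it. strip_annotations() peels off Annotated[...] wrappers and collects
# their metadata (typing auto-flattens nested Annotated, so one pass
# normally suffices); is_loopable() then asks whether the underlying origin
# is a non-string Iterable. Standalone copies, with worked examples:
import inspect
from collections.abc import Iterable
from typing import Annotated, Any, get_args, get_origin

def strip_annotations_sketch(parent_type: type) -> tuple[type, tuple[Any, ...]]:
    metadata: list[Any] = []
    while get_origin(parent_type) is Annotated:
        parent_type, *parent_metadata = get_args(parent_type)
        metadata += list(parent_metadata)
    return parent_type, tuple(metadata)

def is_loopable_sketch(typ: type) -> bool:
    base = get_origin(strip_annotations_sketch(typ)[0])
    return inspect.isclass(base) and issubclass(base, Iterable) and not issubclass(base, str | bytes)

assert strip_annotations_sketch(Annotated[list[int], "Fixed"]) == (list[int], ("Fixed",))
assert is_loopable_sketch(Annotated[list[int], "Fixed"])
assert not is_loopable_sketch(str) and not is_loopable_sketch(int)
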
-        while get_origin(parent_type) is Annotated:
-            parent_type = get_args(parent_type)[0]
+        parent_type, _ = strip_annotations(self.__type__)
         field_type = fallback_type
         if is_dataclass(parent_type):
@@ -733,7 +740,7 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None
                 field_type = parent_type.__annotations__[field]
             except KeyError:
                 raise AttributeError(f"TypedDict {parent_type} does not have field {field}")
-        if not field_type and get_configuration("allow_plain_dict_fields") and get_origin(parent_type) is dict:
+        if not field_type and get_configuration("allow_plain_dict_fields") and strip_annotations(get_origin(parent_type))[0] is dict:
             args = get_args(parent_type)
             if len(args) == 2 and args[0] is str:
                 field_type = args[1]
@@ -806,15 +813,13 @@ def __init__(
             and is_raw(value)
         ):
             if raw_as_parameter:
-                value = ParameterReference(
-                    workflow=workflow, parameter=param(key, value, tethered=None)
-                )
+                value = param(key, value, tethered=None).make_reference(workflow=workflow)
             else:
                 value = Raw(value)

         def _to_param_ref(value):
             if isinstance(value, Parameter):
-                return ParameterReference(workflow=workflow, parameter=value)
+                return value.make_reference(workflow=workflow)

         value, refs = expr_to_references(value, remap=_to_param_ref)
         for ref in refs:
@@ -841,8 +846,14 @@ def __eq__(self, other: object) -> bool:
             and self.arguments == other.arguments
         )

+    def make_reference(self, return_type: type) -> "StepReference":
+        base = get_origin(strip_annotations(return_type)[0])
+        if inspect.isclass(base) and issubclass(base, Iterable) and not issubclass(base, str | bytes):
+            return IterableStepReference(step=self, typ=return_type)
+        return StepReference(step=self, typ=return_type)
+
     def set_workflow(self, workflow: Workflow, with_arguments: bool = True) -> None:
-        """Move the step reference to a different workflow.
+        """Move the step reference to another workflow.

         This method is primarily intended to be called by a step, allowing
         it to switch to a new workflow.
It also updates the workflow reference for any @@ -1067,7 +1078,7 @@ def __init__(self, parameter: Parameter[U], *args, typ: type[U] | None=None, **k self._ = self.ParameterReferenceMetadata(parameter, *args, typ, **kwargs) super().__init__(*args, typ=typ, **kwargs) - def __getattr__(self, attr: str) -> "ParameterReference": + def __getitem__(self, attr: str) -> "ParameterReference": try: return self.find_field( field=attr, @@ -1075,12 +1086,13 @@ def __getattr__(self, attr: str) -> "ParameterReference": parameter=self._.parameter ) except AttributeError as exc: - if not "dask_graph" in str(exc): - raise - return super().__getattribute__(attr) + raise KeyError(attr) from exc - def __getitem__(self, attr: str) -> "ParameterReference": - return getattr(self, attr) + def __getattr__(self, attr: str) -> "ParameterReference": + try: + return self[attr] + except KeyError as exc: + return super().__getattribute__(attr) def __repr__(self) -> str: """Hashable reference to the step (and field).""" @@ -1113,6 +1125,15 @@ def __eq__(self, other: object) -> bool: (isinstance(other, ParameterReference) and self._.parameter == other._.parameter and self.__field__ == other.__field__) ) +class IterableParameterReference(IterableMixin, ParameterReference[U]): + def __inner_iter__(self) -> Generator[Any, None, None]: + inner, metadata = strip_annotations(self.__type__) + if metadata and metadata[0] == "Fixed" and isinstance(self.__default__, Sized): + yield from range(len(self.__default__)) + else: + while True: + yield None + class StepReference(FieldableMixin, Reference[U]): """Reference to an individual `Step`. @@ -1173,7 +1194,7 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash((repr(self), id(self.__workflow__))) - def __getattr__(self, attr: str) -> "StepReference[Any]": + def __getitem__(self, attr: str) -> "StepReference[Any]": """Reference to a field within this result, if possible. If the result is an attrs-class or dataclass, this will pull out an individual @@ -1194,14 +1215,17 @@ def __getattr__(self, attr: str) -> "StepReference[Any]": workflow=self.__workflow__, step=self._.step, field=attr ) except AttributeError as exc: + raise KeyError(attr) from exc + + def __getattr__(self, attr: str) -> "StepReference": + try: + return self[attr] + except KeyError as exc: try: return super().__getattribute__(attr) except AttributeError as inner_exc: raise inner_exc from exc - def __getitem__(self, attr: str) -> "StepReference": - return getattr(self, attr) - @property def __type__(self) -> type: return self._.return_type @@ -1233,6 +1257,14 @@ def __workflow__(self, workflow: Workflow) -> None: """ self._.step.set_workflow(workflow) +class IterableStepReference(IterableMixin, StepReference[U]): + def __getitem__(self, attr: str | int) -> "StepReference"[Any]: + if isinstance(attr, int): + return Iterated(to_wrap=self, iteration=attr) + return super().__getitem__(attr) + + def __iter__(self): + yield IteratedGenerator(self) def merge_workflows(*workflows: Workflow) -> Workflow: """Combine several workflows into one. 
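
# Illustrative sketch - not from the dewret codebase: the item/attribute
# access pattern the reference classes adopt above. __getitem__ becomes the
# canonical field lookup (raising KeyError); __getattr__ delegates to it and
# only then falls back to ordinary attribute lookup. A minimal model:
class FieldLookupSketch:
    def __init__(self, fields: dict[str, int]):
        self._fields = fields

    def __getitem__(self, attr: str) -> int:
        try:
            return self._fields[attr]
        except KeyError as exc:
            raise KeyError(attr) from exc

    def __getattr__(self, attr: str):
        # Only consulted when normal attribute lookup has already failed.
        try:
            return self[attr]
        except KeyError:
            return super().__getattribute__(attr)

ref = FieldLookupSketch({"left": 3})
assert ref.left == ref["left"] == 3
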
diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 8de4a451..70211dbf 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -1,9 +1,11 @@ import pytest import yaml +from typing import Literal from dewret.tasks import task, construct, subworkflow, TaskException from dewret.renderers.cwl import render -from dewret.annotations import AtRender, FunctionAnalyser +from dewret.annotations import AtRender, FunctionAnalyser, Fixed +from dewret.core import set_configuration from ._lib.extra import increment, sum, try_nothing @@ -141,3 +143,31 @@ def test_at_render_between_modules() -> None: workflow = construct(result, simplify_ids=True) subworkflows = render(workflow, allow_complex_types=True) rendered = subworkflows["__root__"] + +list_2: Fixed[list[int]] = [0, 1, 2, 3] + +def test_can_loop_over_fixed_length() -> None: + @subworkflow() + def loop_over_lists(list_1: list[int]) -> list[int]: + result = [] + for a, b in zip(list_1, list_2): + result.append(a + b) + return result + + with set_configuration(flatten_all_nested=True): + result = loop_over_lists(list_1=[5, 6, 7, 8]) + workflow = construct(result, simplify_ids=True) + subworkflows = render(workflow, allow_complex_types=True) + rendered = subworkflows["__root__"] + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: {} + outputs: + expression: '[list_1[0] + list_2[0], list_1[1] + list_2[1], list_1[2] + list_2[2], + list_1[3] + list_2[3]]' + source: + - list_1 + - list_2 + steps: {} + """) diff --git a/tests/test_errors.py b/tests/test_errors.py index 7defe10b..6ffbf593 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -54,9 +54,9 @@ def pi_exported_from_math() -> float: @task() -def test_recursive() -> float: +def try_recursive() -> float: """Get pi from math package by name.""" - return test_recursive() + return try_recursive() @task() @@ -193,7 +193,7 @@ def test_nesting_does_not_identify_imports_as_nesting() -> None: pi_hidden_by_math, pi_hidden_by_math_2, ] - bad = [test_recursive, pi_with_visible_module_task] + bad = [try_recursive, pi_with_visible_module_task] for tsk in bad: with pytest.raises(TaskException) as exc: tsk() @@ -216,7 +216,7 @@ def test_normal_objects_cannot_be_used_in_subworkflows() -> None: unacceptable_object_usage() assert ( str(exc.value) - == "Nested tasks must now only refer to global parameters, raw or tasks, not objects: MyStrangeClass" + == "Attempted to build a workflow from a return-value/result/expression with no references." 
) @@ -237,3 +237,39 @@ def test_subworkflows_must_return_a_task() -> None: result = unacceptable_nested_return(int_not_global=False) construct(result) + +bad_num = 3 +good_num: int = 4 + +def test_must_annotate_global() -> None: + worse_num = 3 + + @subworkflow() + def check_annotation() -> int | float: + return increment(num=bad_num) + + with pytest.raises(TaskException) as exc: + result = check_annotation() + + assert ( + str(exc.value) + == "Could not find a type annotation for bad_num for check_annotation" + ) + + @subworkflow() + def check_annotation_2() -> int | float: + return increment(num=worse_num) + + with pytest.raises(TaskException) as exc: + result = check_annotation_2() + + assert ( + str(exc.value) + == "Cannot use free variables - please put worse_num at the global scope" + ) + + @subworkflow() + def check_annotation_3() -> int | float: + return increment(num=good_num) + + check_annotation_3() diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index ff4fa65a..edd2c829 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -113,7 +113,7 @@ def test_task(alpha: int, beta: float, charlie: bool) -> int: return int(alpha + beta) @task() - def test_list() -> list: + def test_list() -> list[int | float]: return [1, 2.] @subworkflow() diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index 163a8e84..cea34085 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -125,6 +125,7 @@ def test_subworkflows_can_use_globals() -> None: inputs: CONSTANT: label: CONSTANT + default: 3 type: int num: label: num @@ -275,6 +276,7 @@ def test_subworkflows_can_return_lists() -> None: type: int CONSTANT: label: CONSTANT + default: 3 type: int GLOBAL_QUEUE: label: GLOBAL_QUEUE @@ -314,6 +316,7 @@ def test_subworkflows_can_return_lists() -> None: label: num type: int CONSTANT: + default: 3 label: CONSTANT type: int outputs: @@ -345,6 +348,7 @@ def test_subworkflows_can_return_lists() -> None: cwlVersion: 1.2 inputs: CONSTANT: + default: 3 label: CONSTANT type: int num: @@ -461,6 +465,7 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None: cwlVersion: 1.2 inputs: CONSTANT: + default: 3 label: CONSTANT type: int num: @@ -502,6 +507,7 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None: label: num type: int CONSTANT: + default: 3 label: CONSTANT type: int outputs: From 5773fae3b5ae2102102f261e3a8fa19494735fc5 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 18 Aug 2024 00:20:17 +0100 Subject: [PATCH 040/108] feat(fixed): add loopable parameters --- src/dewret/workflow.py | 47 +++++++++++++++++++++++++++-------------- tests/test_fieldable.py | 16 ++++++++++++++ 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 7ac8a229..e9ae6b66 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -156,9 +156,8 @@ def __init__( if tethered and isinstance(tethered, BaseStep): self.register_caller(tethered) - @property - def is_loopable(self): - base = get_origin(strip_annotations(self.__type__)[0]) + def is_loopable(self, typ: type): + base = get_origin(strip_annotations(typ)[0]) return inspect.isclass(base) and issubclass(base, Iterable) and not issubclass(base, str | bytes) @property @@ -193,10 +192,13 @@ def __hash__(self) -> int: # ) return hash(self.__name__) - def make_reference(self, workflow: Workflow) -> "ParameterReference": - if self.is_loopable: - return IterableParameterReference(workflow=workflow, parameter=self) - return 
ParameterReference(workflow=workflow, parameter=self) + def make_reference(self, **kwargs) -> "ParameterReference": + kwargs["parameter"] = self + kwargs.setdefault("typ", self.__type__) + typ = kwargs["typ"] + if self.is_loopable(typ): + return IterableParameterReference(**kwargs) + return ParameterReference(**kwargs) @property def default(self) -> T | UnsetType[T]: @@ -554,7 +556,7 @@ def add_nested_step( return_type = return_type or step.return_type if return_type is inspect._empty: raise TypeError("All tasks should have a type annotation.") - return step.make_reference(return_type=return_type) + return step.make_reference(typ=return_type) def add_step( self, @@ -588,7 +590,7 @@ def add_step( and not inspect.isclass(fn) ): raise TypeError("All tasks should have a type annotation.") - return step.make_reference(return_type) + return step.make_reference(typ=return_type) @staticmethod def from_result( @@ -681,6 +683,7 @@ def __workflow__(self) -> Workflow: class FieldableProtocol(Protocol): __field__: tuple[str, ...] + __field_sep__: str def __init__(self, *args, field: str | None = None, **kwargs): super().__init__(*args, **kwargs) @@ -693,6 +696,9 @@ def __type__(self): def name(self): return "name" + def __make_reference__(self, *args, **kwargs) -> "FieldableProtocol": + ... + # Subclass Reference so that we know Reference methods/attrs are available. class FieldableMixin: def __init__(self: FieldableProtocol, *args, field: str | None = None, **kwargs): @@ -754,7 +760,7 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None if self.__field__: field = self.__field_sep__.join(self.__field__) + self.__field_sep__ + field - return self.__class__(typ=field_type, field=field, **init_kwargs) + return self.__make_reference__(typ=field_type, field=field, **init_kwargs) raise AttributeError( f"Could not determine the type for field {field} in type {parent_type}" @@ -846,11 +852,14 @@ def __eq__(self, other: object) -> bool: and self.arguments == other.arguments ) - def make_reference(self, return_type: type) -> "StepReference": - base = get_origin(strip_annotations(return_type)[0]) + def make_reference(self, **kwargs) -> "StepReference": + kwargs["step"] = self + kwargs.setdefault("typ", self.return_type) + typ = kwargs["typ"] + base = get_origin(strip_annotations(typ)[0]) if inspect.isclass(base) and issubclass(base, Iterable) and not issubclass(base, str | bytes): - return IterableStepReference(step=self, typ=return_type) - return StepReference(step=self, typ=return_type) + return IterableStepReference(**kwargs) + return StepReference(**kwargs) def set_workflow(self, workflow: Workflow, with_arguments: bool = True) -> None: """Move the step reference to another workflow. 
@@ -1086,7 +1095,7 @@ def __getitem__(self, attr: str) -> "ParameterReference":
                 parameter=self._.parameter
             )
         except AttributeError as exc:
-            raise KeyError(attr) from exc
+            raise KeyError(f"Key not found in {self.__root_name__} ({type(self)}:{self.__type__}): {attr}") from exc

     def __getattr__(self, attr: str) -> "ParameterReference":
         try:
@@ -1125,6 +1134,9 @@ def __eq__(self, other: object) -> bool:
             (isinstance(other, ParameterReference) and self._.parameter == other._.parameter and self.__field__ == other.__field__)
         )

+    def __make_reference__(self, **kwargs) -> "ParameterReference":
+        return self._.parameter.make_reference(**kwargs)
+
 class IterableParameterReference(IterableMixin, ParameterReference[U]):
     def __inner_iter__(self) -> Generator[Any, None, None]:
         inner, metadata = strip_annotations(self.__type__)
@@ -1227,7 +1227,7 @@ def __getitem__(self, attr: str) -> "StepReference[Any]":
                 workflow=self.__workflow__, step=self._.step, field=attr
             )
         except AttributeError as exc:
-            raise KeyError(attr) from exc
+            raise KeyError(f"Key not found in {self.__root_name__} ({type(self)}:{self.__type__}): {attr}") from exc

     def __getattr__(self, attr: str) -> "StepReference":
         try:
@@ -1257,6 +1269,9 @@ def __workflow__(self, workflow: Workflow) -> None:
         """
         self._.step.set_workflow(workflow)

+    def __make_reference__(self, **kwargs) -> "StepReference":
+        return self._.step.make_reference(**kwargs)
+
 class IterableStepReference(IterableMixin, StepReference[U]):
     def __iter__(self):
         yield IteratedGenerator(self)
diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py
index edd2c829..f1085520 100644
--- a/tests/test_fieldable.py
+++ b/tests/test_fieldable.py
@@ -156,6 +156,22 @@ def test_iterated() -> int:

     assert workflow.result._.step.positional_args == {"alpha": True, "beta": True, "charlie": True}

+    @dataclass
+    class MyListWrapper:
+        my_list: list[int]
+
+    @task()
+    def test_list_2() -> MyListWrapper:
+        return MyListWrapper(my_list=[1, 2])
+
+    @subworkflow()
+    def test_iterated_2(my_wrapper: MyListWrapper) -> int:
+        return test_task(*my_wrapper.my_list)
+
+    with set_configuration(allow_positional_args=True, flatten_all_nested=True):
+        result = test_iterated_2(my_wrapper=test_list_2())
+        workflow = construct(result, simplify_ids=True)
+
 def test_can_use_plain_dict_fields():
     @subworkflow()
     def test_dict(left: int, right: float) -> dict[str, float | int]:

From 0c2cbdc4819a62b26c52a26e9f507be54caf0556 Mon Sep 17 00:00:00 2001
From: Phil Weir
Date: Sun, 18 Aug 2024 00:32:17 +0100
Subject: [PATCH 041/108] fix: better error message for iterating a non-iterable

---
 src/dewret/workflow.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py
index e9ae6b66..ea1fb878 100644
--- a/src/dewret/workflow.py
+++ b/src/dewret/workflow.py
@@ -1095,7 +1095,11 @@ def __getitem__(self, attr: str) -> "ParameterReference":
                 parameter=self._.parameter
             )
         except AttributeError as exc:
-            raise KeyError(f"Key not found in {self.__root_name__} ({type(self)}:{self.__type__}): {attr}") from exc
+            raise KeyError(
+                f"Key not found in {self.__root_name__} ({type(self)}:{self.__type__}): {attr}" +
+                ". This could be because you are trying to iterate/index a reference whose type is not definitely iterable - double check your typehints."
+ if isinstance(attr, int) else "" + ) from exc def __getattr__(self, attr: str) -> "ParameterReference": try: @@ -1227,7 +1231,11 @@ def __getitem__(self, attr: str) -> "StepReference[Any]": workflow=self.__workflow__, step=self._.step, field=attr ) except AttributeError as exc: - raise KeyError(f"Key not found in {self.__root_name__} ({type(self)}:{self.__type__}): {attr}") from exc + raise KeyError( + f"Key not found in {self.__root_name__} ({type(self)}:{self.__type__}): {attr}" + + ". This could be because you are trying to iterate/index a reference whose type is not definitely iterable - double check your typehints." + if isinstance(attr, int) else "" + ) from exc def __getattr__(self, attr: str) -> "StepReference": try: From 00036d0ba97933cfcdb956cd0006bdd6dec599df Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 18 Aug 2024 00:41:51 +0100 Subject: [PATCH 042/108] fix: add is_firm --- src/dewret/tasks.py | 6 +++--- src/dewret/utils.py | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 57e04270..ec6a7d93 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -45,7 +45,7 @@ from contextvars import ContextVar, copy_context from contextlib import contextmanager -from .utils import is_raw, make_traceback, is_expr, is_raw_type +from .utils import is_firm, make_traceback, is_expr, is_raw_type from .workflow import ( expr_to_references, unify_workflows, @@ -454,7 +454,7 @@ def _to_param_ref(value): for var, value in kwargs.items(): if analyser.is_at_construct_arg(var): kwargs[var] = value - elif is_raw(value): + elif is_firm(value): # We leave this reference dangling for a consumer to pick up ("tethered"), unless # we are in a nested task, that does not have any existence of its own. kwargs[var] = param( @@ -516,7 +516,7 @@ def {fn.__name__}(...) -> ...: ): # We assume these are loaded at runtime. ... - elif is_raw(value) or ( + elif is_firm(value) or ( (attrs_has(value) or is_dataclass(value)) and not inspect.isclass(value) ): diff --git a/src/dewret/utils.py b/src/dewret/utils.py index 4dc921f0..1609b686 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -119,6 +119,9 @@ def is_raw_type(typ: type) -> bool: return issubclass(typ, str | float | bool | bytes | int | None | list | dict) +def is_firm(value: Any, check: Callable[[Any], bool] | None = None) -> bool: + return is_raw(value, lambda x: isinstance(x, tuple)) + def is_raw(value: Any, check: Callable[[Any], bool] | None = None) -> bool: """Check if a variable counts as "raw". 
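
# Illustrative sketch - not from the dewret codebase: the is_raw/is_firm
# distinction introduced above. "Firm" admits tuples where "raw" does not,
# so literal tuple arguments can still become parameters. is_raw_sketch is
# an assumed approximation of dewret.utils.is_raw, not the real function:
from typing import Any, Callable

def is_raw_sketch(value: Any, check: Callable[[Any], bool] | None = None) -> bool:
    if value is None or isinstance(value, str | float | bool | bytes | int):
        return True
    if isinstance(value, list):
        return all(is_raw_sketch(v, check) for v in value)
    if isinstance(value, dict):
        return all(isinstance(k, str) and is_raw_sketch(v, check) for k, v in value.items())
    return check(value) if check is not None else False

def is_firm_sketch(value: Any) -> bool:
    return is_raw_sketch(value, lambda x: isinstance(x, tuple))

assert not is_raw_sketch((0, 1))
assert is_firm_sketch((0, 1)) and is_firm_sketch([(0, 1), (2, 3)])
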
From 5b83266cc221475d73b04dab845f29c6bb90f809 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 18 Aug 2024 01:08:53 +0100 Subject: [PATCH 043/108] fix: move indexing to find_field --- src/dewret/core.py | 7 +--- src/dewret/workflow.py | 93 +++++++++++++++++++++++++----------------- 2 files changed, 58 insertions(+), 42 deletions(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index f8dde4fd..7dfde082 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -120,17 +120,14 @@ def __str__(self) -> str: class IterableMixin(Reference[U]): def __iter__(self): - count = -1 - for _ in self.__inner_iter__(): - yield Iterated(to_wrap=self, iteration=(count := count + 1)) + for count, _ in enumerate(self.__inner_iter__()): + yield super().__getitem__(count) def __inner_iter__(self) -> Generator[Any, None, None]: while True: yield None def __getitem__(self, attr: str | int) -> Reference[U]: - if isinstance(attr, int): - return Iterated(to_wrap=self, iteration=attr) return super().__getitem__(attr) class IteratedGenerator(Generic[U]): diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index ea1fb878..155d90e4 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -24,7 +24,7 @@ from attrs import has as attr_has, resolve_types, fields as attrs_fields from dataclasses import is_dataclass, fields as dataclass_fields from collections import Counter, OrderedDict -from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Iterable, get_origin, get_args, Generator, Sized +from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Iterable, get_origin, get_args, Generator, Sized, Sequence from uuid import uuid4 from sympy import Symbol, Expr, Basic, Tuple, Dict, nan @@ -701,8 +701,8 @@ def __make_reference__(self, *args, **kwargs) -> "FieldableProtocol": # Subclass Reference so that we know Reference methods/attrs are available. class FieldableMixin: - def __init__(self: FieldableProtocol, *args, field: str | None = None, **kwargs): - self.__field__: tuple[str, ...] = tuple(field.split(self.__field_sep__)) if field else () + def __init__(self: FieldableProtocol, *args, field: str | int | tuple | None = None, **kwargs): + self.__field__: tuple[str, ...] = (field if isinstance(field, tuple) else (field,)) if field is not None else () super().__init__(*args, **kwargs) @property @@ -716,7 +716,17 @@ def __name__(self: FieldableProtocol) -> str: May be remapped by the workflow to something nicer than the ID. """ - return self.__field_sep__.join([super().__name__] + list(self.__field__)) + return super().__name__ + self.__field_suffix__ + + @property + def __field_suffix__(self) -> str: + result = "" + for cmpt in self.__field__: + if isinstance(cmpt, int): + result += f"[{cmpt}]" + else: + result += f"{self.__field_sep__}{cmpt}" + return result def find_field(self: FieldableProtocol, field, fallback_type: type | None = None, **init_kwargs: Any) -> Reference: """Field within the reference, if possible. 
@@ -729,36 +739,46 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None parent_type, _ = strip_annotations(self.__type__) field_type = fallback_type - if is_dataclass(parent_type): - try: - field_type = next(iter(filter(lambda fld: fld.name == field, dataclass_fields(parent_type)))).type - except StopIteration: - raise AttributeError(f"Dataclass {parent_type} does not have field {field}") - elif attr_has(parent_type): - resolve_types(parent_type) - try: - field_type = getattr(attrs_fields(parent_type), field).type - except AttributeError: - raise AttributeError(f"attrs-class {parent_type} does not have field {field}") - # TypedDict - elif inspect.isclass(parent_type) and issubclass(parent_type, dict) and hasattr(parent_type, "__annotations__"): - try: - field_type = parent_type.__annotations__[field] - except KeyError: - raise AttributeError(f"TypedDict {parent_type} does not have field {field}") - if not field_type and get_configuration("allow_plain_dict_fields") and strip_annotations(get_origin(parent_type))[0] is dict: - args = get_args(parent_type) - if len(args) == 2 and args[0] is str: - field_type = args[1] - else: - raise AttributeError(f"Can only get fields for plain dicts if annotated dict[str, TYPE]") + if isinstance(field, int): + base = get_origin(parent_type) + if not inspect.isclass(base) or not issubclass(base, Sequence): + raise AttributeError(f"Tried to index int {field} into a non-sequence type {parent_type} (base: {base})") + if not (field_type := get_args(parent_type)[0]): + raise AttributeError( + f"Tried to index int {field} into type {parent_type} but can only do so if the first type argument " + f"is the element type (args: {get_args(parent_type)}" + ) + else: + if is_dataclass(parent_type): + try: + field_type = next(iter(filter(lambda fld: fld.name == field, dataclass_fields(parent_type)))).type + except StopIteration: + raise AttributeError(f"Dataclass {parent_type} does not have field {field}") + elif attr_has(parent_type): + resolve_types(parent_type) + try: + field_type = getattr(attrs_fields(parent_type), field).type + except AttributeError: + raise AttributeError(f"attrs-class {parent_type} does not have field {field}") + # TypedDict + elif inspect.isclass(parent_type) and issubclass(parent_type, dict) and hasattr(parent_type, "__annotations__"): + try: + field_type = parent_type.__annotations__[field] + except KeyError: + raise AttributeError(f"TypedDict {parent_type} does not have field {field}") + if not field_type and get_configuration("allow_plain_dict_fields") and strip_annotations(get_origin(parent_type))[0] is dict: + args = get_args(parent_type) + if len(args) == 2 and args[0] is str: + field_type = args[1] + else: + raise AttributeError(f"Can only get fields for plain dicts if annotated dict[str, TYPE]") if field_type: if not issubclass(self.__class__, Reference): raise TypeError("Only references can have a fieldable mixin") if self.__field__: - field = self.__field_sep__.join(self.__field__) + self.__field_sep__ + field + field = tuple(list(self.__field__) + [field]) return self.__make_reference__(typ=field_type, field=field, **init_kwargs) @@ -1097,8 +1117,10 @@ def __getitem__(self, attr: str) -> "ParameterReference": except AttributeError as exc: raise KeyError( f"Key not found in {self.__root_name__} ({type(self)}:{self.__type__}): {attr}" + - ". This could be because you are trying to iterate/index a reference whose type is not definitely iterable - double check your typehints." 
- if isinstance(attr, int) else "" + ( + ". This could be because you are trying to iterate/index a reference whose type is not definitely iterable - double check your typehints." + if isinstance(attr, int) else "" + ) ) from exc def __getattr__(self, attr: str) -> "ParameterReference": @@ -1233,8 +1255,10 @@ def __getitem__(self, attr: str) -> "StepReference[Any]": except AttributeError as exc: raise KeyError( f"Key not found in {self.__root_name__} ({type(self)}:{self.__type__}): {attr}" + - ". This could be because you are trying to iterate/index a reference whose type is not definitely iterable - double check your typehints." - if isinstance(attr, int) else "" + ( + ". This could be because you are trying to iterate/index a reference whose type is not definitely iterable - double check your typehints." + if isinstance(attr, int) else "" + ) ) from exc def __getattr__(self, attr: str) -> "StepReference": @@ -1281,11 +1305,6 @@ def __make_reference__(self, **kwargs) -> "StepReference": return self._.step.make_reference(**kwargs) class IterableStepReference(IterableMixin, StepReference[U]): - def __getitem__(self, attr: str | int) -> "StepReference"[Any]: - if isinstance(attr, int): - return Iterated(to_wrap=self, iteration=attr) - return super().__getitem__(attr) - def __iter__(self): yield IteratedGenerator(self) From fc5309e9ca2b3aa3b51151d67cce0c4540da445f Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 18 Aug 2024 02:22:10 +0100 Subject: [PATCH 044/108] fix: resolve the parameter deduplication --- src/dewret/core.py | 32 +++++++++++++++---- src/dewret/renderers/cwl.py | 19 ++++++++---- src/dewret/tasks.py | 13 ++++---- src/dewret/utils.py | 8 ----- src/dewret/workflow.py | 15 +++++---- tests/test_fieldable.py | 62 +++++++++++++++++++++++++++++++++++++ 6 files changed, 117 insertions(+), 32 deletions(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index 7dfde082..4d5028b1 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -1,6 +1,6 @@ -from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any from dataclasses import dataclass import base64 +from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated from contextlib import contextmanager from contextvars import ContextVar from sympy import Expr, Symbol @@ -11,9 +11,13 @@ U = TypeVar("U") -BasicType = str | float | bool | bytes | int | None -RawType = Union[BasicType, list["RawType"], dict[str, "RawType"]] -FirmType = BasicType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] +def strip_annotations(parent_type: type) -> tuple[type, tuple]: + # Strip out any annotations. This should be auto-flattened, so in theory only one iteration could occur. + metadata = [] + while get_origin(parent_type) is Annotated: + parent_type, *parent_metadata = get_args(parent_type) + metadata += list(parent_metadata) + return parent_type, tuple(metadata) class WorkflowProtocol(Protocol): ... @@ -119,13 +123,29 @@ def __str__(self) -> str: return self.__name__ class IterableMixin(Reference[U]): + __fixed_len__: int | None = None + + def __init__(self, typ: type[U] | None=None, **kwargs): + base = strip_annotations(typ)[0] + super().__init__(typ=typ, **kwargs) + if get_origin(base) == tuple and (args := get_args(base)): + # In the special case of an explicitly-typed tuple, we can state a length. 
+ self.__fixed_len__ = len(args) + + def __len__(self): + return self.__fixed_len__ + def __iter__(self): for count, _ in enumerate(self.__inner_iter__()): yield super().__getitem__(count) def __inner_iter__(self) -> Generator[Any, None, None]: - while True: - yield None + if self.__fixed_len__ is not None: + for i in range(self.__fixed_len__): + yield i + else: + while True: + yield None def __getitem__(self, attr: str | int) -> Reference[U]: return super().__getitem__(attr) diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 7904ff0f..dc8f41f7 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -205,7 +205,7 @@ def render(self) -> dict[str, RawType]: if isinstance(ref, ReferenceDefinition) else {"expression": render_expression(ref)} if isinstance(ref, Basic) else - {"default": ref.value} + {"default": firm_to_raw(ref.value)} if hasattr(ref, "value") else {"expression": render_expression(ref)} ) @@ -393,7 +393,7 @@ def _raw_to_command_input_schema_internal( ) structure["items"] = to_cwl_type(label, typeset.pop())["type"] elif not isinstance(value, Unset): - structure["default"] = value + structure["default"] = firm_to_raw(value) return structure @@ -433,13 +433,15 @@ def from_parameters( Returns: CWL-like structure representing all workflow outputs. """ + parameters_dedup = {p._.parameter for p in parameters if isinstance(p, ParameterReference)} + parameters = list(parameters_dedup) + [p for p in parameters if not isinstance(p, ParameterReference)] return cls( inputs={ - input.__name__: cls.CommandInputParameter( + input.name: cls.CommandInputParameter( label=input.__name__, default=(default := flatten_if_set(input.__default__)), type=raw_to_command_input_schema( - label=input.__name__, value=default + label=input.name, value=default ), ) for input in parameters @@ -462,7 +464,7 @@ def render(self) -> dict[str, RawType]: "label": input.label, } if not isinstance(input.default, Unset): - item["default"] = input.default + item["default"] = firm_to_raw(input.default) result[key] = item return result @@ -512,7 +514,12 @@ def _build_results(result): return cls(outputs=_build_results(results)) except AttributeError: expr, references = expr_to_references(results) - references = sorted({str(ref._.parameter) for ref in references}) + references = sorted( + { + str(ref._.parameter) if isinstance(ref, ParameterReference) else str(ref._.step) + for ref in references + } + ) return cls(outputs={ "expression": str(expr), "source": references diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index ec6a7d93..9ba96aa3 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -457,15 +457,16 @@ def _to_param_ref(value): elif is_firm(value): # We leave this reference dangling for a consumer to pick up ("tethered"), unless # we are in a nested task, that does not have any existence of its own. 
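
# Illustrative sketch - not from the dewret codebase: why an explicitly
# typed tuple lets IterableMixin record a fixed length (core.py above).
# tuple[int, str] carries one type argument per element, so the number of
# type args is the length; list[int] says nothing about length. Note that
# a variadic tuple[int, ...] would also count as length 2 here, a caveat
# the patch shares:
from typing import get_args, get_origin

def fixed_len_sketch(typ: type) -> int | None:
    if get_origin(typ) is tuple and (args := get_args(typ)):
        return len(args)
    return None

assert fixed_len_sketch(tuple[int, str]) == 2
assert fixed_len_sketch(list[int]) is None
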
+ tethered = ( + False if nested and ( + flatten_nested or get_configuration("flatten_all_nested") + ) else None + ) kwargs[var] = param( var, value, - tethered=( - False if nested and ( - flatten_nested or get_configuration("flatten_all_nested") - ) else None - ), - autoname=True, + tethered=tethered, + autoname=tethered is not False, typ=analyser.get_argument_annotation(var) or UNSET ).make_reference(workflow=workflow) original_kwargs = dict(kwargs) diff --git a/src/dewret/utils.py b/src/dewret/utils.py index 1609b686..894e5a3f 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -104,14 +104,6 @@ def flatten(value: Any) -> RawType: def is_expr(value: Any) -> bool: return is_raw(value, lambda x: isinstance(x, Basic) or isinstance(x, tuple) or isinstance(x, Reference) or isinstance(x, Raw)) -def strip_annotations(parent_type: type) -> tuple[type, tuple]: - # Strip out any annotations. This should be auto-flattened, so in theory only one iteration could occur. - metadata = [] - while get_origin(parent_type) is Annotated: - parent_type, *parent_metadata = get_args(parent_type) - metadata += list(parent_metadata) - return parent_type, tuple(metadata) - def is_raw_type(typ: type) -> bool: """Check if a type counts as "raw".""" if isinstance(typ, UnionType): diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 155d90e4..ab4da0b9 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -32,8 +32,8 @@ logger = logging.getLogger(__name__) -from .core import RawType, IterableMixin, Reference, get_configuration, Raw, IteratedGenerator -from .utils import hasher, is_raw, make_traceback, is_raw_type, is_expr, Unset, strip_annotations +from .core import RawType, IterableMixin, Reference, get_configuration, Raw, IteratedGenerator, strip_annotations +from .utils import hasher, is_raw, make_traceback, is_raw_type, is_expr, Unset T = TypeVar("T") U = TypeVar("U") @@ -119,6 +119,7 @@ class Parameter(Generic[T], Symbol): """ __name__: str + __name_suffix__: str = "" __default__: T | UnsetType[T] __tethered__: Literal[False] | None | BaseStep | Workflow __fixed_type__: type[T] | Unset @@ -143,8 +144,8 @@ def __init__( self.__original_name__ = name # TODO: is using this in a step hash a risk of ambiguity? 
(full name is circular) - #if autoname: - # name = f"{name}-{uuid4()}" + if autoname: + self.__name_suffix__ = f"-{uuid4()}" self.autoname = autoname self.__name__ = name @@ -215,7 +216,7 @@ def name(self) -> str: """ tethered = self.__tethered__ if tethered is False or tethered is None or self.autoname is False: - return self.__name__ + return self.__name__ + self.__name_suffix__ else: return f"{tethered.name}-{self.__original_name__}" @@ -1166,7 +1167,9 @@ def __make_reference__(self, **kwargs) -> "StepReference": class IterableParameterReference(IterableMixin, ParameterReference[U]): def __inner_iter__(self) -> Generator[Any, None, None]: inner, metadata = strip_annotations(self.__type__) - if metadata and metadata[0] == "Fixed" and isinstance(self.__default__, Sized): + if self.__fixed_len__ is not None: + yield from range(self.__fixed_len__) + elif metadata and metadata[0] == "Fixed" and isinstance(self.__default__, Sized): yield from range(len(self.__default__)) else: while True: diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index f1085520..db024f66 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -7,6 +7,7 @@ from dewret.tasks import task, construct, subworkflow, set_configuration from dewret.workflow import param from dewret.renderers.cwl import render +from dewret.annotations import Fixed from ._lib.extra import double, mod10, sum, pi @@ -172,6 +173,67 @@ def test_iterated_2(my_wrapper: MyListWrapper) -> int: result = test_iterated_2(my_wrapper=test_list_2()) workflow = construct(result, simplify_ids=True) + @task() + def test_list_3() -> Fixed[list[tuple[int, int]]]: + return [(0, 1), (2, 3)] + + @subworkflow() + def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: + retval = mod10(*test_list_3()[0]) + for pair in param: + a, b = pair + retval += a + b + return mod10(retval) + + with set_configuration(allow_positional_args=True, flatten_all_nested=True): + result = test_iterated_3(param=[(0, 1), (2, 3)]) + workflow = construct(result, simplify_ids=True) + + rendered = render(workflow, allow_complex_types=True)["__root__"] + + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + param: + default: + - - 0 + - 1 + - - 2 + - 3 + label: param + type: + type: + items: array + label: param + type: array + outputs: + out: + label: out + outputSource: mod10-1/out + type: int + steps: + mod10-1: + in: + num: + expression: $(param[0][0] + param[0][1] + param[1][0] + param[1][1] + mod10-2) + out: + - out + run: mod10 + mod10-2: + in: + num: + source: test_list_3-1[0] + out: + - out + run: mod10 + test_list_3-1: + in: {} + out: + - out + run: test_list_3 + """) + def test_can_use_plain_dict_fields(): @subworkflow() def test_dict(left: int, right: float) -> dict[str, float | int]: From 36d15c67898b121c7406cc940c9463d9c5c6eecc Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 18 Aug 2024 02:23:52 +0100 Subject: [PATCH 045/108] fix: resolve the parameter deduplication --- src/dewret/workflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index ab4da0b9..23985b71 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -1136,7 +1136,7 @@ def __repr__(self) -> str: typ = self.__type__.__name__ except AttributeError: typ = str(self.__type__) - name = self.__field_sep__.join([self._.unique_name] + list(self.__field__)) + name = self._.unique_name + self.__field_suffix__ return f"{typ}|:param:{name}" def __hash__(self) -> int: @@ 
-1230,7 +1230,7 @@ def __str__(self) -> str: def __repr__(self) -> str: """Hashable reference to the step (and field).""" - return self.__field_sep__.join([self._.step.id] + list(self.__field__)) + return self._.step.id + self.__field_suffix__ def __hash__(self) -> int: return hash((repr(self), id(self.__workflow__))) From f6de37fdd52a97a9eafbc1dceee56fc8b95c822d Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 18 Aug 2024 03:09:34 +0100 Subject: [PATCH 046/108] fix: fixups in iteration --- src/dewret/workflow.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 23985b71..2bcb76b2 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -1092,7 +1092,16 @@ def unique_name(self) -> str: @property def __default__(self) -> T | Unset: """Default value of the parameter.""" - return self._.parameter.default + default = self._.parameter.default + if isinstance(default, Unset): + return default + + for field in self.__field__: + if isinstance(default, dict) or isinstance(field, int): + default = default[field] + else: + default = getattr(default, field) + return default @property def __root_name__(self) -> str: @@ -1165,11 +1174,18 @@ def __make_reference__(self, **kwargs) -> "StepReference": return self._.parameter.make_reference(**kwargs) class IterableParameterReference(IterableMixin, ParameterReference[U]): + def __iter__(self): + inner, metadata = strip_annotations(self.__type__) + if metadata and "AtRender" in metadata and isinstance(self.__default__, Iterable): + yield from self.__default__ + else: + yield from super().__iter__() + def __inner_iter__(self) -> Generator[Any, None, None]: inner, metadata = strip_annotations(self.__type__) if self.__fixed_len__ is not None: yield from range(self.__fixed_len__) - elif metadata and metadata[0] == "Fixed" and isinstance(self.__default__, Sized): + elif metadata and "Fixed" in metadata and isinstance(self.__default__, Sized): yield from range(len(self.__default__)) else: while True: From 3dbe777e588270a00f74ca453cf8b7e8737685da Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 19 Aug 2024 10:48:57 +0100 Subject: [PATCH 047/108] fix(typing): make types consistent in core --- src/dewret/core.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index 4d5028b1..2b10a97f 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import base64 -from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated +from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated, Never from contextlib import contextmanager from contextvars import ContextVar from sympy import Expr, Symbol @@ -20,7 +20,8 @@ def strip_annotations(parent_type: type) -> tuple[type, tuple]: return parent_type, tuple(metadata) class WorkflowProtocol(Protocol): - ... + def remap(self, name: str) -> str: + ... class UnevaluatableError(Exception): ... 
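
# Illustrative sketch - not from the dewret codebase: the __default__ field
# traversal added in the "fixups in iteration" commit above, in isolation.
# The parameter's default is walked along the reference's field path, using
# item access for dicts and integer indices and getattr for anything else:
from dataclasses import dataclass

@dataclass
class SidesSketch:
    left: int
    right: int

def walk_default_sketch(default, field_path):
    for field in field_path:
        if isinstance(default, dict) or isinstance(field, int):
            default = default[field]
        else:
            default = getattr(default, field)
    return default

assert walk_default_sketch({"sides": SidesSketch(3, 6)}, ("sides", "left")) == 3
assert walk_default_sketch([(0, 1), (2, 3)], (1, 0)) == 2
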
@@ -35,7 +36,7 @@ class ConstructConfiguration(TypedDict): CONSTRUCT_CONFIGURATION: ContextVar[ConstructConfiguration] = ContextVar("construct-configuration") @contextmanager -def set_configuration(**kwargs: Unpack[ConstructConfiguration]): +def set_configuration(**kwargs: Unpack[ConstructConfiguration]) -> Iterator[ContextVar[ConstructConfiguration]]: try: previous = ConstructConfiguration(**CONSTRUCT_CONFIGURATION.get()) except LookupError: @@ -55,8 +56,8 @@ def set_configuration(**kwargs: Unpack[ConstructConfiguration]): finally: CONSTRUCT_CONFIGURATION.set(previous) -def get_configuration(key: str): - return CONSTRUCT_CONFIGURATION.get()[key] +def get_configuration(key: str) -> RawType: + return CONSTRUCT_CONFIGURATION.get().get(key) # type: ignore class Reference(Generic[U], Symbol): """Superclass for all symbolic references to values.""" @@ -126,8 +127,8 @@ class IterableMixin(Reference[U]): __fixed_len__: int | None = None def __init__(self, typ: type[U] | None=None, **kwargs): - base = strip_annotations(typ)[0] super().__init__(typ=typ, **kwargs) + base = strip_annotations(self.__type__)[0] if get_origin(base) == tuple and (args := get_args(base)): # In the special case of an explicitly-typed tuple, we can state a length. self.__fixed_len__ = len(args) @@ -181,7 +182,7 @@ def __root_name__(self) -> str: @property def __type__(self) -> type: - return Iterator[self.__wrapped__.__type__] + return self.__wrapped__.__type__ def __hash__(self) -> int: return hash(self.__root_name__) From bd7ee4a708a889386f5cb070ca0393db17bd03cc Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 19 Aug 2024 11:04:22 +0100 Subject: [PATCH 048/108] fix(typing): move @subworkflow to @workflow --- example/workflow_complex.py | 4 +-- src/dewret/tasks.py | 6 ++-- tests/_lib/extra.py | 6 ++-- tests/test_annotations.py | 26 +++++++-------- tests/test_configuration.py | 4 +-- tests/test_errors.py | 14 ++++----- tests/test_fieldable.py | 56 ++++++++++++++++----------------- tests/test_modularity.py | 4 +-- tests/test_multiresult_steps.py | 8 ++--- tests/test_subworkflows.py | 44 +++++++++++++------------- 10 files changed, 86 insertions(+), 86 deletions(-) diff --git a/example/workflow_complex.py b/example/workflow_complex.py index a0a2dd28..d503c272 100644 --- a/example/workflow_complex.py +++ b/example/workflow_complex.py @@ -7,13 +7,13 @@ ``` """ -from dewret.tasks import subworkflow +from dewret.tasks import workflow from workflow_tasks import sum, double, increase STARTING_NUMBER: int = 23 -@subworkflow() +@workflow() def nested_workflow() -> int | float: """Creates a complex workflow with a nested task. diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 9ba96aa3..80742901 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -310,7 +310,7 @@ def factory(fn: Callable[..., RetType]) -> Callable[..., RetType]: return task(is_factory=True)(fn) -def subworkflow() -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]: +def workflow() -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]: """Shortcut for marking a task as nested. A nested task is one which calls other tasks and does not @@ -326,7 +326,7 @@ def subworkflow() -> Callable[[Callable[Param, RetType]], Callable[Param, RetTyp ... def increment(num: int) -> int: ... return num + 1 - >>> @subworkflow() + >>> @workflow() ... def double_increment(num: int) -> int: ... return increment(increment(num=num)) @@ -494,7 +494,7 @@ def _to_param_ref(value): def {var}(...) -> ...: ... 
- @subworkflow() <<<--- likely what you want + @workflow() <<<--- likely what you want def {fn.__name__}(...) -> ...: ... {var}(...) diff --git a/tests/_lib/extra.py b/tests/_lib/extra.py index bb663835..921a93c0 100644 --- a/tests/_lib/extra.py +++ b/tests/_lib/extra.py @@ -1,4 +1,4 @@ -from dewret.tasks import task, subworkflow +from dewret.tasks import task, workflow from dewret.annotations import AtRender from .other import nothing @@ -8,7 +8,7 @@ from inspect import get_annotations -@subworkflow() +@workflow() def try_nothing() -> int: """Check that we can see AtRender in another module.""" @@ -54,7 +54,7 @@ def pi() -> float: return math.pi -@subworkflow() +@workflow() def triple_and_one(num: int | float) -> int | float: """Triple a number by doubling and adding again, then add 1.""" return sum(left=sum(left=double(num=num), right=num), right=1) diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 70211dbf..765df53c 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -2,7 +2,7 @@ import yaml from typing import Literal -from dewret.tasks import task, construct, subworkflow, TaskException +from dewret.tasks import task, construct, workflow, TaskException from dewret.renderers.cwl import render from dewret.annotations import AtRender, FunctionAnalyser, Fixed from dewret.core import set_configuration @@ -24,12 +24,12 @@ def fn(arg5: int, arg6: AtRender[int]) -> float: return arg5 + arg6 + arg7 + arg8 + int(ARG1) + int(ARG2) -@subworkflow() +@workflow() def to_int_bad(num: int, should_double: bool) -> int | float: """Cast to an int.""" return increment(num=num) if should_double else sum(left=num, right=num) -@subworkflow() +@workflow() def to_int(num: int, should_double: AtRender[bool]) -> int | float: """Cast to an int.""" return increment(num=num) if should_double else sum(left=num, right=num) @@ -58,11 +58,11 @@ def test_can_analyze_annotations(): def test_at_render() -> None: with pytest.raises(TaskException) as _: result = to_int_bad(num=increment(num=3), should_double=True) - workflow = construct(result, simplify_ids=True) + wkflw = construct(result, simplify_ids=True) result = to_int(num=increment(num=3), should_double=True) - workflow = construct(result, simplify_ids=True) - subworkflows = render(workflow, allow_complex_types=True) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) rendered = subworkflows["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 @@ -99,8 +99,8 @@ def test_at_render() -> None: """) result = to_int(num=increment(num=3), should_double=False) - workflow = construct(result, simplify_ids=True) - subworkflows = render(workflow, allow_complex_types=True) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) rendered = subworkflows["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 @@ -140,14 +140,14 @@ def test_at_render() -> None: def test_at_render_between_modules() -> None: nothing = False result = try_nothing() - workflow = construct(result, simplify_ids=True) - subworkflows = render(workflow, allow_complex_types=True) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) rendered = subworkflows["__root__"] list_2: Fixed[list[int]] = [0, 1, 2, 3] def test_can_loop_over_fixed_length() -> None: - @subworkflow() + @workflow() def loop_over_lists(list_1: list[int]) -> list[int]: result = [] for a, b in zip(list_1, list_2): @@ -156,8 +156,8 
@@ def loop_over_lists(list_1: list[int]) -> list[int]: with set_configuration(flatten_all_nested=True): result = loop_over_lists(list_1=[5, 6, 7, 8]) - workflow = construct(result, simplify_ids=True) - subworkflows = render(workflow, allow_complex_types=True) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) rendered = subworkflows["__root__"] assert rendered == yaml.safe_load(""" class: Workflow diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 53868ac2..01b91206 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -1,6 +1,6 @@ import yaml import pytest -from dewret.tasks import construct, task, factory, subworkflow, TaskException +from dewret.tasks import construct, task, factory, workflow, TaskException from dewret.renderers.cwl import render from dewret.utils import hasher from dewret.tasks import set_configuration @@ -12,7 +12,7 @@ def configuration(): with set_configuration() as configuration: yield configuration.get() -@subworkflow() +@workflow() def floor(num: int, expected: AtRender[bool]) -> int: """Converts int/float to int.""" from dewret.tasks import get_configuration diff --git a/tests/test_errors.py b/tests/test_errors.py index 6ffbf593..96607bce 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -2,7 +2,7 @@ import pytest from dewret.workflow import Task, Lazy -from dewret.tasks import construct, task, subworkflow, TaskException +from dewret.tasks import construct, task, workflow, TaskException from dewret.annotations import AtRender from dewret.renderers.cwl import render from ._lib.extra import increment, pi, reverse_list # noqa: F401 @@ -17,7 +17,7 @@ def add_task(left: int, right: int) -> int: ADD_TASK_LINE_NO: int = 11 -@subworkflow() +@workflow() def badly_add_task(left: int, right: int) -> int: """Badly attempts to add two numbers.""" return add_task(left=left) # type: ignore @@ -90,13 +90,13 @@ def pi_with_invisible_module_task() -> float: return extra.double(3.14 / 2) -@subworkflow() +@workflow() def unacceptable_object_usage() -> int: """Invalid use of custom object within nested task.""" return MyStrangeClass(add_task(left=3, right=4)) # type: ignore -@subworkflow() +@workflow() def unacceptable_nested_return(int_not_global: AtRender[bool]) -> int | Lazy: """Bad subworkflow that fails to return a task.""" add_task(left=3, right=4) @@ -244,7 +244,7 @@ def test_subworkflows_must_return_a_task() -> None: def test_must_annotate_global() -> None: worse_num = 3 - @subworkflow() + @workflow() def check_annotation() -> int | float: return increment(num=bad_num) @@ -256,7 +256,7 @@ def check_annotation() -> int | float: == "Could not find a type annotation for bad_num for check_annotation" ) - @subworkflow() + @workflow() def check_annotation_2() -> int | float: return increment(num=worse_num) @@ -268,7 +268,7 @@ def check_annotation_2() -> int | float: == "Cannot use free variables - please put worse_num at the global scope" ) - @subworkflow() + @workflow() def check_annotation_3() -> int | float: return increment(num=good_num) diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index db024f66..1a7f0b4c 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -4,7 +4,7 @@ import pytest from typing import Unpack, TypedDict -from dewret.tasks import task, construct, subworkflow, set_configuration +from dewret.tasks import task, construct, workflow, set_configuration from dewret.workflow import param from dewret.renderers.cwl import 
render from dewret.annotations import Fixed @@ -18,15 +18,15 @@ class Sides: SIDES: Sides = Sides(3, 6) -@subworkflow() +@workflow() def sum_sides(): return sum(left=SIDES.left, right=SIDES.right) @pytest.mark.skip(reason="Need expression support") def test_fields_of_parameters_usable() -> None: result = sum_sides() - workflow = construct(result, simplify_ids=True) - rendered = render(workflow, allow_complex_types=True)["sum_sides-1"] + wkflw = construct(result, simplify_ids=True) + rendered = render(wkflw, allow_complex_types=True)["sum_sides-1"] assert rendered == yaml.safe_load(""" class: Workflow @@ -65,8 +65,8 @@ class MyDataclass: left: int my_param = param("my_param", typ=MyDataclass) result = sum(left=my_param, right=my_param) - workflow = construct(result, simplify_ids=True) - param_reference = list(workflow.find_parameters())[0] + wkflw = construct(result, simplify_ids=True) + param_reference = list(wkflw.find_parameters())[0] assert str(param_reference.left) == "my_param/left" assert param_reference.left.__type__ == int @@ -77,36 +77,36 @@ class MyDataclass: left: int right: float - @subworkflow() + @workflow() def test_dataclass(my_dataclass: MyDataclass) -> MyDataclass: result: MyDataclass = MyDataclass(left=mod10(num=my_dataclass.left), right=pi()) return result - @subworkflow() + @workflow() def get_left(my_dataclass: MyDataclass) -> int: return my_dataclass.left result = get_left(my_dataclass=test_dataclass(my_dataclass=MyDataclass(left=3, right=4.))) - workflow = construct(result, simplify_ids=True) + wkflw = construct(result, simplify_ids=True) - assert str(workflow.result) == "get_left-1" - assert workflow.result.__type__ == int + assert str(wkflw.result) == "get_left-1" + assert wkflw.result.__type__ == int def test_can_get_field_references_from_typed_dict(): class MyDict(TypedDict): left: int right: float - @subworkflow() + @workflow() def test_dict(**my_dict: Unpack[MyDict]) -> MyDict: result: MyDict = {"left": mod10(num=my_dict["left"]), "right": pi()} return result result = test_dict(left=3, right=4.) - workflow = construct(result, simplify_ids=True) + wkflw = construct(result, simplify_ids=True) - assert str(workflow.result["left"]) == "test_dict-1/left" - assert workflow.result["left"].__type__ == int + assert str(wkflw.result["left"]) == "test_dict-1/left" + assert wkflw.result["left"].__type__ == int def test_can_iterate(): @task() @@ -117,15 +117,15 @@ def test_task(alpha: int, beta: float, charlie: bool) -> int: def test_list() -> list[int | float]: return [1, 2.] 
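
# Illustrative sketch - not from the dewret codebase: why test_list() in the
# context above is annotated list[int | float] rather than bare list.
# Integer indexing of a reference resolves the element type from the type
# arguments, and a bare list exposes none:
from typing import get_args

assert get_args(list[int | float]) == (int | float,)
assert get_args(list) == ()
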
- @subworkflow() + @workflow() def test_iterated() -> int: return test_task(*test_list()) with set_configuration(allow_positional_args=True, flatten_all_nested=True): result = test_iterated() - workflow = construct(result, simplify_ids=True) + wkflw = construct(result, simplify_ids=True) - rendered = render(workflow, allow_complex_types=True)["__root__"] + rendered = render(wkflw, allow_complex_types=True)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -155,7 +155,7 @@ def test_iterated() -> int: run: test_task """) - assert workflow.result._.step.positional_args == {"alpha": True, "beta": True, "charlie": True} + assert wkflw.result._.step.positional_args == {"alpha": True, "beta": True, "charlie": True} @dataclass class MyListWrapper: @@ -165,19 +165,19 @@ class MyListWrapper: def test_list_2() -> MyListWrapper: return MyListWrapper(my_list=[1, 2]) - @subworkflow() + @workflow() def test_iterated_2(my_wrapper: MyListWrapper) -> int: return test_task(*my_wrapper.my_list) with set_configuration(allow_positional_args=True, flatten_all_nested=True): result = test_iterated_2(my_wrapper=test_list_2()) - workflow = construct(result, simplify_ids=True) + wkflw = construct(result, simplify_ids=True) @task() def test_list_3() -> Fixed[list[tuple[int, int]]]: return [(0, 1), (2, 3)] - @subworkflow() + @workflow() def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: retval = mod10(*test_list_3()[0]) for pair in param: @@ -187,9 +187,9 @@ def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: with set_configuration(allow_positional_args=True, flatten_all_nested=True): result = test_iterated_3(param=[(0, 1), (2, 3)]) - workflow = construct(result, simplify_ids=True) + wkflw = construct(result, simplify_ids=True) - rendered = render(workflow, allow_complex_types=True)["__root__"] + rendered = render(wkflw, allow_complex_types=True)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -235,13 +235,13 @@ def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: """) def test_can_use_plain_dict_fields(): - @subworkflow() + @workflow() def test_dict(left: int, right: float) -> dict[str, float | int]: result: dict[str, float | int] = {"left": mod10(num=left), "right": pi()} return result with set_configuration(allow_plain_dict_fields=True): result = test_dict(left=3, right=4.) 
- workflow = construct(result, simplify_ids=True) - assert str(workflow.result["left"]) == "test_dict-1/left" - assert workflow.result["left"].__type__ == int | float + wkflw = construct(result, simplify_ids=True) + assert str(wkflw.result["left"]) == "test_dict-1/left" + assert wkflw.result["left"].__type__ == int | float diff --git a/tests/test_modularity.py b/tests/test_modularity.py index b2abb8b8..8b085457 100644 --- a/tests/test_modularity.py +++ b/tests/test_modularity.py @@ -1,14 +1,14 @@ """Verify CWL can be made with split up and nested calls.""" import yaml -from dewret.tasks import subworkflow, construct, set_configuration +from dewret.tasks import workflow, construct, set_configuration from dewret.renderers.cwl import render from ._lib.extra import double, sum, increase STARTING_NUMBER: int = 23 -@subworkflow() +@workflow() def algorithm() -> int | float: """Creates a graph of task calls.""" left = double(num=increase(num=STARTING_NUMBER)) diff --git a/tests/test_multiresult_steps.py b/tests/test_multiresult_steps.py index 0a57dc70..b278e84a 100644 --- a/tests/test_multiresult_steps.py +++ b/tests/test_multiresult_steps.py @@ -4,7 +4,7 @@ from attr import define from dataclasses import dataclass from typing import Iterable -from dewret.tasks import task, construct, subworkflow, set_configuration +from dewret.tasks import task, construct, workflow, set_configuration from dewret.renderers.cwl import render STARTING_NUMBER: int = 23 @@ -44,19 +44,19 @@ def pair(left: int, right: float) -> tuple[int, float]: return (left, right) -@subworkflow() +@workflow() def algorithm() -> float: """Sum two split values.""" return combine(left=split().first, right=split().second) -@subworkflow() +@workflow() def algorithm_with_pair() -> tuple[int, float]: """Pairs two split dataclass values.""" return pair(left=split_into_dataclass().first, right=split_into_dataclass().second) -@subworkflow() +@workflow() def algorithm_with_dataclasses() -> float: """Sums two split dataclass values.""" return combine( diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index cea34085..5cd86ff4 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -3,7 +3,7 @@ from typing import Callable from queue import Queue import yaml -from dewret.tasks import construct, subworkflow, task, factory, set_configuration +from dewret.tasks import construct, workflow, task, factory, set_configuration from dewret.renderers.cwl import render from dewret.workflow import param @@ -35,19 +35,19 @@ def add_and_queue(num: int, queue: Queue[int]) -> Queue[int]: return queue -@subworkflow() +@workflow() def make_queue(num: int | float) -> Queue[int]: """Add a number to a queue.""" queue = QueueFactory() return add_and_queue(num=to_int(num=num), queue=queue) -@subworkflow() +@workflow() def get_global_queue(num: int | float) -> Queue[int]: """Add a number to a global queue.""" return add_and_queue(num=to_int(num=num), queue=GLOBAL_QUEUE) -@subworkflow() +@workflow() def get_global_queues(num: int | float) -> list[Queue[int] | int]: """Add a number to a global queue.""" return [ @@ -56,17 +56,17 @@ def get_global_queues(num: int | float) -> list[Queue[int] | int]: ] -@subworkflow() +@workflow() def add_constant(num: int | float) -> int: """Add a global constant to a number.""" return to_int(num=sum(left=num, right=CONSTANT)) -@subworkflow() +@workflow() def add_constants(num: int | float) -> int: """Add a global constant to a number.""" return to_int(num=sum(left=sum(left=num, right=CONSTANT), 
right=CONSTANT)) -@subworkflow() +@workflow() def get_values(num: int | float) -> tuple[int | float, int]: """Add a global constant to a number.""" return (sum(left=num, right=CONSTANT), add_constant(CONSTANT)) @@ -75,14 +75,14 @@ def get_values(num: int | float) -> tuple[int | float, int]: def test_cwl_for_pairs() -> None: """Check whether we can produce CWL of pairs.""" - @subworkflow() + @workflow() def pair_pi() -> tuple[float, float]: return pi(), pi() with set_configuration(flatten_all_nested=True): result = pair_pi() - workflow = construct(result, simplify_ids=True) - rendered = render(workflow)["__root__"] + wkflw = construct(result, simplify_ids=True) + rendered = render(wkflw)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 @@ -112,8 +112,8 @@ def test_subworkflows_can_use_globals() -> None: """Produce a subworkflow that uses a global.""" my_param = param("num", typ=int) result = increment(num=add_constant(num=increment(num=my_param))) - workflow = construct(result, simplify_ids=True) - subworkflows = render(workflow) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw) rendered = subworkflows["__root__"] assert len(subworkflows) == 2 @@ -163,8 +163,8 @@ def test_subworkflows_can_use_factories() -> None: """Produce a subworkflow that uses a factory.""" my_param = param("num", typ=int) result = pop(queue=make_queue(num=increment(num=my_param))) - workflow = construct(result, simplify_ids=True) - subworkflows = render(workflow, allow_complex_types=True) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) rendered = subworkflows["__root__"] assert len(subworkflows) == 2 @@ -208,8 +208,8 @@ def test_subworkflows_can_use_global_factories() -> None: """Check whether we can produce a subworkflow that uses a global factory.""" my_param = param("num", typ=int) result = pop(queue=get_global_queue(num=increment(num=my_param))) - workflow = construct(result, simplify_ids=True) - subworkflows = render(workflow, allow_complex_types=True) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) rendered = subworkflows["__root__"] assert len(subworkflows) == 2 @@ -258,8 +258,8 @@ def test_subworkflows_can_return_lists() -> None: """Check whether we can produce a subworkflow that returns a list.""" my_param = param("num", typ=int) result = get_global_queues(num=increment(num=my_param)) - workflow = construct(result, simplify_ids=True) - subworkflows = render(workflow, allow_complex_types=True) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) rendered = subworkflows["__root__"] del subworkflows["__root__"] @@ -397,8 +397,8 @@ def test_can_merge_workflows() -> None: my_param = param("num", typ=int) value = to_int(num=increment(num=my_param)) result = sum(left=value, right=increment(num=value)) - workflow = construct(result, simplify_ids=True) - subworkflows = render(workflow, allow_complex_types=True) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) rendered = subworkflows["__root__"] del subworkflows["__root__"] @@ -451,8 +451,8 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None: """Produce a subworkflow that uses a global.""" my_param = param("num", typ=int) result = increment(num=add_constants(num=increment(num=my_param))) - workflow = construct(result, simplify_ids=True) - subworkflows = render(workflow) + wkflw = construct(result, 
simplify_ids=True) + subworkflows = render(wkflw) rendered = subworkflows["__root__"] del subworkflows["__root__"] From 71a35c3b98292eb4d1b20121b2da01a22838280b Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 19 Aug 2024 11:21:18 +0100 Subject: [PATCH 049/108] fix: allow workflows to come from packages --- src/dewret/__main__.py | 19 ++++++++++++++----- src/dewret/render.py | 9 ++++----- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 7f9d5f6a..3a0a235d 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -66,11 +66,11 @@ "--output", default="-" ) -@click.argument("workflow_py") +@click.argument("workflow_py", type=click.Path(exists=True)) @click.argument("task") @click.argument("arguments", nargs=-1) def render( - workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend, construct_args: str, renderer: str, renderer_args: str, output: str + workflow_py: Path, task: str, arguments: list[str], pretty: bool, backend: Backend, construct_args: str, renderer: str, renderer_args: str, output: str ) -> None: """Render a workflow. @@ -78,7 +78,7 @@ def render( TASK is the name of (decorated) task in workflow module. ARGUMENTS is zero or more pairs representing constant arguments to pass to the task, in the format `key:val` where val is a JSON basic type. """ - sys.path.append(str(Path(workflow_py).parent)) + sys.path.append(str(workflow_py.parent)) kwargs = {} for arg in arguments: if ":" not in arg: @@ -129,8 +129,17 @@ def _opener(key, mode): opener = _opener render = get_render_method(render_module, pretty=pretty) - loader = importlib.machinery.SourceFileLoader("workflow", workflow_py) - workflow = loader.load_module() + loader = importlib.machinery.SourceFileLoader("workflow", str(workflow_py)) + workflow_init = workflow_py.parent + + # Try to import the workflow as a package, if possible, to allow relative imports. + try: + loader = importlib.machinery.SourceFileLoader("__workflow__", str(workflow_py.parent / "__init__.py")) + sys.modules["workflow"] = loader.load_module(f"__workflow__") + workflow = importlib.import_module(f"__workflow__.{workflow_py.stem}", "__workflow__") + except ImportError: + loader = importlib.machinery.SourceFileLoader("__workflow__", str(workflow_py)) + workflow = loader.load_module() task_fn = getattr(workflow, task) try: diff --git a/src/dewret/render.py b/src/dewret/render.py index e4f6f53d..b19569e3 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -31,16 +31,15 @@ def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule, if isinstance(renderer, Path): if (render_dir := str(renderer.parent)) not in sys.path: sys.path.append(render_dir) - package_init = renderer.parent # Attempt to load renderer as package, falling back to a single module otherwise. # This enables relative imports in renderers and therefore the ability to modularize. 
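        # For illustration (assumed layout, not part of this patch): a renderer
        # shipped as a package,
        #
        #     my_renderers/
        #         __init__.py
        #         frender.py    # may now use e.g. `from .helpers import fmt`
        #         helpers.py
        #
        # can be selected on the CLI as `--renderer my_renderers/frender.py`;
        # only if the package import fails does loading fall back to treating
        # frender.py as a lone module.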
try: - loader = importlib.machinery.SourceFileLoader("renderer", str(package_init / "__init__.py")) - sys.modules["renderer"] = loader.load_module(f"renderer") - render_module = importlib.import_module(f"renderer.{renderer.stem}", "renderer") + loader = importlib.machinery.SourceFileLoader("__renderer__", str(renderer.parent / "__init__.py")) + sys.modules["__renderer__"] = loader.load_module(f"__renderer__") + render_module = importlib.import_module(f"__renderer__.{renderer.stem}", "__renderer__") except ImportError: - loader = importlib.machinery.SourceFileLoader("renderer", str(renderer)) + loader = importlib.machinery.SourceFileLoader("__renderer__", str(renderer)) render_module = loader.load_module() else: render_module = renderer From 03e52094dd7509179707cae273c557a1b53f6c08 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 19 Aug 2024 11:24:39 +0100 Subject: [PATCH 050/108] fix: allow workflows to come from packages --- src/dewret/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 3a0a235d..db64f8d8 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -66,7 +66,7 @@ "--output", default="-" ) -@click.argument("workflow_py", type=click.Path(exists=True)) +@click.argument("workflow_py", type=click.Path(exists=True), path_type=Path) @click.argument("task") @click.argument("arguments", nargs=-1) def render( From 7d880c06d6f70733acf149153f665d09291716fa Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 19 Aug 2024 11:28:06 +0100 Subject: [PATCH 051/108] fix: allow workflows to come from packages --- src/dewret/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index db64f8d8..74a26fa9 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -66,7 +66,7 @@ "--output", default="-" ) -@click.argument("workflow_py", type=click.Path(exists=True), path_type=Path) +@click.argument("workflow_py", type=click.Path(exists=True, path_type=Path)) @click.argument("task") @click.argument("arguments", nargs=-1) def render( From aee1ce109b893c1e4d7b50dd3f016db843a2007a Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 19 Aug 2024 13:04:13 +0100 Subject: [PATCH 052/108] fix: allow workflows to come from packages --- src/dewret/__main__.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 74a26fa9..762c5e84 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -20,6 +20,7 @@ """ import importlib +import importlib.util from pathlib import Path from contextlib import contextmanager import sys @@ -131,14 +132,19 @@ def _opener(key, mode): render = get_render_method(render_module, pretty=pretty) loader = importlib.machinery.SourceFileLoader("workflow", str(workflow_py)) workflow_init = workflow_py.parent + pkg = "__workflow__" # Try to import the workflow as a package, if possible, to allow relative imports. 
try: - loader = importlib.machinery.SourceFileLoader("__workflow__", str(workflow_py.parent / "__init__.py")) - sys.modules["workflow"] = loader.load_module(f"__workflow__") - workflow = importlib.import_module(f"__workflow__.{workflow_py.stem}", "__workflow__") + spec = importlib.util.spec_from_file_location(pkg, str(workflow_py.parent / "__init__.py")) + if spec is None or spec.loader is None: + raise ImportError(f"Could not open {pkg} package") + module = importlib.util.module_from_spec(spec) + sys.modules[pkg] = module + spec.loader.exec_module(module) + workflow = importlib.import_module(f"{pkg}.{workflow_py.stem}", pkg) except ImportError: - loader = importlib.machinery.SourceFileLoader("__workflow__", str(workflow_py)) + loader = importlib.machinery.SourceFileLoader(pkg, str(workflow_py)) workflow = loader.load_module() task_fn = getattr(workflow, task) From fc7495e6e3d5ef85a5951e88b8fe7ff153f01de0 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 19 Aug 2024 13:39:45 +0100 Subject: [PATCH 053/108] feat(iterable): provide len for Fixed --- src/dewret/workflow.py | 5 +++++ tests/test_annotations.py | 6 +++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 2bcb76b2..266e8b5c 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -1191,6 +1191,11 @@ def __inner_iter__(self) -> Generator[Any, None, None]: while True: yield None + def __len__(self): + inner, metadata = strip_annotations(self.__type__) + if metadata and "Fixed" in metadata and isinstance(self.__default__, Sized): + return len(self.__default__) + return super().__len__() class StepReference(FieldableMixin, Reference[U]): """Reference to an individual `Step`. diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 765df53c..9d70ac16 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -151,7 +151,7 @@ def test_can_loop_over_fixed_length() -> None: def loop_over_lists(list_1: list[int]) -> list[int]: result = [] for a, b in zip(list_1, list_2): - result.append(a + b) + result.append(a + b + len(list_2)) return result with set_configuration(flatten_all_nested=True): @@ -164,8 +164,8 @@ def loop_over_lists(list_1: list[int]) -> list[int]: cwlVersion: 1.2 inputs: {} outputs: - expression: '[list_1[0] + list_2[0], list_1[1] + list_2[1], list_1[2] + list_2[2], - list_1[3] + list_2[3]]' + expression: '[4 + list_1[0] + list_2[0], 4 + list_1[1] + list_2[1], 4 + list_1[2] + list_2[2], + 4 + list_1[3] + list_2[3]]' source: - list_1 - list_2 From 093b1d6ba3e44ac5399a58607a3308df5aa2dce5 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 19 Aug 2024 14:36:51 +0100 Subject: [PATCH 054/108] fix: allow fields of parameters --- src/dewret/workflow.py | 3 +++ tests/test_fieldable.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 266e8b5c..4b5d383d 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -230,6 +230,9 @@ def register_caller(self, caller: BaseStep) -> None: self.__tethered__ = caller self.__callers__.append(caller) + def __getattr__(self, attr: str) -> Reference[T]: + return getattr(self.make_reference(workflow=None), attr) + def param( name: str, diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index 1a7f0b4c..e122935f 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -59,6 +59,43 @@ def test_fields_of_parameters_usable() -> None: run: sum """) +def 
test_can_get_field_reference_from_parameter(): + @dataclass + class MyDataclass: + left: int + my_param = param("my_param", typ=MyDataclass) + result = sum(left=my_param.left, right=my_param) + wkflw = construct(result, simplify_ids=True) + param_references = {(str(p), p.__type__) for p in wkflw.find_parameters()} + + assert param_references == {("my_param/left", int), ("my_param", MyDataclass)} + rendered = render(wkflw, allow_complex_types=True)["__root__"] + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + my_param: + label: my_param + type: MyDataclass + outputs: + out: + label: out + outputSource: sum-1/out + type: + - int + - double + steps: + sum-1: + in: + left: + source: my_param/left + right: + source: my_param + out: + - out + run: sum + """) + def test_can_get_field_reference_iff_parent_type_has_field(): @dataclass class MyDataclass: From 362af32db8bf3285cb49e978c9e57f2ecf6f1f9c Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 19 Aug 2024 23:49:20 +0100 Subject: [PATCH 055/108] fix: reduce number of workflow merges --- src/dewret/workflow.py | 111 ++++++++++++++++++++++------------------- 1 file changed, 61 insertions(+), 50 deletions(-) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 4b5d383d..540e4286 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -39,6 +39,8 @@ U = TypeVar("U") RetType = TypeVar("RetType") +CHECK_IDS = False + class Lazy(Protocol): """Requirements for a lazy-evaluatable function.""" @@ -345,11 +347,18 @@ def __str__(self) -> str: return self.name def __repr__(self) -> str: - return self.name + if self._name: + return self.name + comp_tup = tuple(sorted(s.id for s in self.steps)) + + return f"workflow-{hasher(comp_tup)}" def __hash__(self) -> int: """Hashes for finding.""" - return hash(repr(self)) + return hash(( + self._name, + tuple(self.steps), + )) def __eq__(self, other: object) -> bool: """Is this the same workflow? @@ -416,7 +425,7 @@ def _indexed_steps(self) -> dict[str, BaseStep]: return OrderedDict(sorted(((step.id, step) for step in self.steps), key=lambda x: x[0])) @classmethod - def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": + def assimilate(cls, *workflows) -> "Workflow": """Combine two Workflows into one Workflow. 
Takes two workflows and unifies them by combining steps @@ -429,56 +438,56 @@ def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": left: workflow to use as base right: workflow to combine on top """ - if left == right: - return left + workflows = set(workflows) + base = next(iter(workflows)) - new = cls() + if len(workflows) == 1: + return base - new._name = left._name or right._name + names = {w._name for w in workflows if w._name} + base._name = base._name or (names and next(iter(names))) or None - left_steps = left._indexed_steps - right_steps = right._indexed_steps + #left_steps = left._indexed_steps + #right_steps = right._indexed_steps + all_steps = sum((list(w._indexed_steps.items()) for w in workflows), []) - for step in list(left_steps.values()) + list(right_steps.values()): - step.set_workflow(new) + for _, step in all_steps: + #for step in list(left_steps.values()) + list(right_steps.values()): + step.set_workflow(base) - for step_id in left_steps.keys() & right_steps.keys(): - if left_steps[step_id] != right_steps[step_id]: + indexed_steps = {} + for step_id, step in all_steps: + indexed_steps.setdefault(step_id, step) + if step != indexed_steps[step_id]: raise RuntimeError( f"Two steps have same ID but do not match: {step_id}" ) - for task_id in left.tasks.keys() & right.tasks.keys(): - if left.tasks[task_id] != right.tasks[task_id]: - raise RuntimeError("Two tasks have same name but do not match") - - indexed_steps = dict(left_steps) - indexed_steps.update(right_steps) - new.steps += list(indexed_steps.values()) - new.tasks.update(left.tasks) - new.tasks.update(right.tasks) - - for step in new.steps: - step.set_workflow(new, with_arguments=True) - - if left.result == right.result: - result = left.result - elif not left.result: - result = right.result - elif not right.result: - result = left.result + all_tasks = sum((list(w.tasks.items()) for w in workflows), []) + indexed_tasks = {} + for task_id, task in all_tasks: + indexed_tasks.setdefault(task_id, task) + if task != indexed_tasks[task_id]: + raise RuntimeError(f"Two tasks have same name {task_id} but do not match") + + base.steps = list(indexed_steps.values()) + base.tasks = indexed_tasks + + for step in base.steps: + step.set_workflow(base, with_arguments=True) + + results = set((w.result for w in workflows if w.result)) + if len(results) == 1: + result = next(iter(results)) else: - if not isinstance(left.result, tuple | list): - left.result = [left.result] - if not isinstance(right.result, tuple | list): - right.result = [right.result] - result = list(left.result) + list(right.result) + results = {r if isinstance(r, tuple | list) else (r,) for r in results} + result = sum(map(list, results), []) if result is not None and result != []: - unify_workflows(result, new, set_only=True) - new.set_result(result) + unify_workflows(result, base, set_only=True) + base.set_result(result) - return new + return base def remap(self, step_id: str) -> str: """Apply name simplification if requested. @@ -900,6 +909,7 @@ def set_workflow(self, workflow: Workflow, with_arguments: bool = True) -> None: if with_arguments: for argument in self.arguments.values(): unify_workflows(argument, workflow, set_only=True) + self._id = None @property def return_type(self) -> Any: @@ -922,6 +932,9 @@ def return_type(self) -> Any: return self.task.target return inspect.signature(inspect.unwrap(self.task.target)).return_annotation + def __hash__(self) -> int: + return hash(self.id) + @property def name(self) -> str: """Name for this step. 
@@ -938,12 +951,13 @@ def id(self) -> str: self._id = self._generate_id() return self._id - check_id = self._generate_id() - if check_id != self._id: - return self._id - raise RuntimeError( - f"Cannot change a step after requesting its ID: {self.task}" - ) + if CHECK_IDS: + check_id = self._generate_id() + if check_id != self._id: + return self._id + raise RuntimeError( + f"Cannot change a step after requesting its ID: {self.task}" + ) return self._id def _generate_id(self) -> str: @@ -1346,10 +1360,7 @@ def merge_workflows(*workflows: Workflow) -> Workflow: Returns: One workflow with all steps. """ - base = list(workflows).pop() - for workflow in workflows: - base = Workflow.assimilate(base, workflow) - return base + return Workflow.assimilate(*workflows) def is_task(task: Lazy) -> bool: From 828d58c6903ad4909004d8955a0f3a7c69653b7a Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 20 Aug 2024 01:06:16 +0100 Subject: [PATCH 056/108] fix: render configuration --- src/dewret/backends/backend_dask.py | 1 - src/dewret/core.py | 61 ++++++++++++++++++++++------- src/dewret/render.py | 1 + src/dewret/renderers/cwl.py | 40 +++++++------------ src/dewret/tasks.py | 2 +- 5 files changed, 63 insertions(+), 42 deletions(-) diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index 4b859e59..36e7fb21 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -24,7 +24,6 @@ from typing import Protocol, runtime_checkable, Any, cast from concurrent.futures import Executor, ThreadPoolExecutor from dewret.workflow import Workflow, Lazy, StepReference, Target -from dewret.tasks import CONSTRUCT_CONFIGURATION @runtime_checkable diff --git a/src/dewret/core.py b/src/dewret/core.py index 2b10a97f..eaca524a 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -1,15 +1,18 @@ from dataclasses import dataclass import base64 +from functools import lru_cache from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated, Never from contextlib import contextmanager from contextvars import ContextVar from sympy import Expr, Symbol +import copy BasicType = str | float | bool | bytes | int | None RawType = Union[BasicType, list["RawType"], dict[str, "RawType"]] FirmType = BasicType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] U = TypeVar("U") +T = TypeVar("T", bound=TypedDict) def strip_annotations(parent_type: type) -> tuple[type, tuple]: # Strip out any annotations. This should be auto-flattened, so in theory only one iteration could occur. 
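    # A sketch of the intended behaviour (illustrative; the exact return shape
    # is assumed from its uses elsewhere in this series):
    #
    #     strip_annotations(Annotated[list[int], "Fixed"])  # -> (list[int], ("Fixed",))
    #     strip_annotations(int)                            # -> (int, ())
    #
    # i.e. the bare type comes first, with any Annotated metadata gathered
    # into the second element.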
@@ -33,31 +36,59 @@ class ConstructConfiguration(TypedDict): allow_plain_dict_fields: NotRequired[bool] field_separator: NotRequired[str] -CONSTRUCT_CONFIGURATION: ContextVar[ConstructConfiguration] = ContextVar("construct-configuration") +class GlobalConfiguration(TypedDict): + construct: ConstructConfiguration + render: dict + +CONFIGURATION: ContextVar[GlobalConfiguration] = ContextVar("configuration") + +@contextmanager +def set_configuration(**kwargs: Unpack[ConstructConfiguration]) -> Iterator[ContextVar[GlobalConfiguration]]: + with _set_configuration("construct", kwargs) as var: + yield var @contextmanager -def set_configuration(**kwargs: Unpack[ConstructConfiguration]) -> Iterator[ContextVar[ConstructConfiguration]]: +def set_render_configuration(kwargs) -> Iterator[ContextVar[GlobalConfiguration]]: + with _set_configuration("render", kwargs) as var: + yield var + +@lru_cache +def default_renderer_config() -> dict: try: - previous = ConstructConfiguration(**CONSTRUCT_CONFIGURATION.get()) + from __renderer_mod__ import default_config + except ImportError: + return {} + return default_config() + +@lru_cache +def default_construct_config() -> ConstructConfiguration: + return ConstructConfiguration( + flatten_all_nested=False, + allow_positional_args=False, + allow_plain_dict_fields=False, + field_separator="/" + ) + +@contextmanager +def _set_configuration(config_group: str, kwargs: U) -> Iterator[ContextVar[GlobalConfiguration]]: + try: + previous = copy.deepcopy(GlobalConfiguration(**CONFIGURATION.get())) except LookupError: - previous = ConstructConfiguration( - flatten_all_nested=False, - allow_positional_args=False, - allow_plain_dict_fields=False, - field_separator="/" - ) - CONSTRUCT_CONFIGURATION.set({}) + previous = {"construct": default_construct_config(), "render": default_renderer_config()} + CONFIGURATION.set(previous) try: - CONSTRUCT_CONFIGURATION.get().update(previous) - CONSTRUCT_CONFIGURATION.get().update(kwargs) + CONFIGURATION.get()[config_group].update(kwargs) - yield CONSTRUCT_CONFIGURATION + yield CONFIGURATION finally: - CONSTRUCT_CONFIGURATION.set(previous) + CONFIGURATION.set(previous) def get_configuration(key: str) -> RawType: - return CONSTRUCT_CONFIGURATION.get().get(key) # type: ignore + return CONFIGURATION.get()["construct"].get(key) # type: ignore + +def get_render_configuration(key: str) -> RawType: + return CONFIGURATION.get()["render"].get(key) # type: ignore class Reference(Generic[U], Symbol): """Superclass for all symbolic references to values.""" diff --git a/src/dewret/render.py b/src/dewret/render.py index b19569e3..9162166f 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -41,6 +41,7 @@ def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule, except ImportError: loader = importlib.machinery.SourceFileLoader("__renderer__", str(renderer)) render_module = loader.load_module() + sys.modules["__renderer_mod__"] = render_module else: render_module = renderer if hasattr(render_module, "render_raw"): diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index dc8f41f7..9a636072 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -44,6 +44,7 @@ ) from dewret.utils import flatten, DataclassProtocol, firm_to_raw, flatten_if_set from dewret.render import base_render +from dewret.core import get_render_configuration, set_render_configuration class CommandInputSchema(TypedDict): """Structure for referring to a raw type in CWL. 
@@ -91,19 +92,11 @@ class CWLRendererConfiguration(TypedDict): factories_as_params: NotRequired[bool] -CONFIGURATION: ContextVar[CWLRendererConfiguration] = ContextVar("cwl-configuration") -DEFAULT_CONFIGURATION: CWLRendererConfiguration = { - "allow_complex_types": False, - "factories_as_params": False, -} - - -def configuration(key: str) -> Any: - """Retrieve current configuration (thread/async-local).""" - current_configuration = CONFIGURATION.get() - if key not in current_configuration: - raise KeyError("Unknown configuration settings.") - return current_configuration.get(key) +def default_renderer_config() -> CWLRendererConfiguration: + return { + "allow_complex_types": False, + "factories_as_params": False, + } def with_type(result: Any) -> type: @@ -285,7 +278,7 @@ def to_cwl_type(label: str, typ: type) -> CommandInputSchema: raise TypeError( f"Cannot render complex type ({typ}) to CWL, have you enabled allow_complex_types configuration?" ) from err - elif configuration("allow_complex_types"): + elif get_render_configuration("allow_complex_types"): typ_dict["type"] = typ if isinstance(typ, str) else typ.__name__ else: raise TypeError(f"Cannot render type ({typ}) to CWL") @@ -570,10 +563,10 @@ def from_workflow( """ parameters: list[ParameterReference | FactoryCall] = list( workflow.find_parameters( - include_factory_calls=not configuration("factories_as_params") + include_factory_calls=not get_render_configuration("factories_as_params") ) ) - if configuration("factories_as_params"): + if get_render_configuration("factories_as_params"): parameters += list(workflow.find_factories().values()) return cls( steps=[ @@ -581,7 +574,7 @@ def from_workflow( for step in workflow.steps if not ( isinstance(step, FactoryCall) - and configuration("factories_as_params") + and get_render_configuration("factories_as_params") ) ], inputs=InputsDefinition.from_parameters(parameters), @@ -624,12 +617,9 @@ def render( Reduced form as a native Python dict structure for serialization. 
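        For example (illustrative, mirroring this repository's tests):

            rendered = render(workflow, allow_complex_types=True)
            cwl_root = rendered["__root__"]  # primary workflow document;
                                             # subworkflows appear under their own keys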
""" - config = CWLRendererConfiguration(**DEFAULT_CONFIGURATION) - config.update(kwargs) - CONFIGURATION.set(config) - rendered = base_render( - workflow, - lambda workflow: WorkflowDefinition.from_workflow(workflow).render() - ) - CONFIGURATION.set(DEFAULT_CONFIGURATION) + with set_render_configuration(kwargs): + rendered = base_render( + workflow, + lambda workflow: WorkflowDefinition.from_workflow(workflow).render() + ) return rendered diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 80742901..640a3387 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -65,7 +65,7 @@ ) from .backends._base import BackendModule from .annotations import FunctionAnalyser -from .core import get_configuration, set_configuration, CONSTRUCT_CONFIGURATION, IteratedGenerator, ConstructConfiguration +from .core import get_configuration, set_configuration, IteratedGenerator, ConstructConfiguration import ast Param = ParamSpec("Param") From 8988037a28718407f8331ecb06ed1603d36b5ba4 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 20 Aug 2024 01:13:41 +0100 Subject: [PATCH 057/108] fix: render configuration --- src/dewret/__main__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 762c5e84..11cf4779 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -31,7 +31,7 @@ import click import json -from .core import set_configuration +from .core import set_configuration, set_render_configuration from .render import get_render_method, RawRenderModule, StructuredRenderModule from .tasks import Backend, construct @@ -149,7 +149,7 @@ def _opener(key, mode): task_fn = getattr(workflow, task) try: - with set_configuration(**construct_kwargs): + with set_configuration(**construct_kwargs), set_render_configuration(**renderer_kwargs): rendered = render(construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs) except Exception as exc: import traceback From f03b58ca5d073ad67d0b68503808d69aec812c99 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 20 Aug 2024 23:46:18 +0100 Subject: [PATCH 058/108] fix: address field and hashing consistency --- src/dewret/__main__.py | 3 +- src/dewret/core.py | 6 ++- src/dewret/render.py | 2 +- src/dewret/renderers/cwl.py | 6 +-- src/dewret/tasks.py | 3 +- src/dewret/utils.py | 4 +- src/dewret/workflow.py | 103 ++++++++++++++++++++++++------------ tests/_lib/frender.py | 2 +- tests/test_cwl.py | 16 +++--- tests/test_fieldable.py | 27 ++++++++++ 10 files changed, 119 insertions(+), 53 deletions(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 11cf4779..e3c9093a 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -130,7 +130,6 @@ def _opener(key, mode): opener = _opener render = get_render_method(render_module, pretty=pretty) - loader = importlib.machinery.SourceFileLoader("workflow", str(workflow_py)) workflow_init = workflow_py.parent pkg = "__workflow__" @@ -149,7 +148,7 @@ def _opener(key, mode): task_fn = getattr(workflow, task) try: - with set_configuration(**construct_kwargs), set_render_configuration(**renderer_kwargs): + with set_configuration(**construct_kwargs), set_render_configuration(renderer_kwargs): rendered = render(construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs) except Exception as exc: import traceback diff --git a/src/dewret/core.py b/src/dewret/core.py index eaca524a..05436145 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -35,6 +35,7 @@ class ConstructConfiguration(TypedDict): 
allow_positional_args: NotRequired[bool] allow_plain_dict_fields: NotRequired[bool] field_separator: NotRequired[str] + field_index_types: NotRequired[str] class GlobalConfiguration(TypedDict): construct: ConstructConfiguration @@ -66,7 +67,8 @@ def default_construct_config() -> ConstructConfiguration: flatten_all_nested=False, allow_positional_args=False, allow_plain_dict_fields=False, - field_separator="/" + field_separator="/", + field_index_types="int", ) @contextmanager @@ -148,7 +150,7 @@ def __name__(self) -> str: """Referral name for this reference.""" workflow = self.__workflow__ name = self.__root_name__ - return workflow.remap(name) + return workflow.remap(name) or name def __str__(self) -> str: """Global description of the reference.""" diff --git a/src/dewret/render.py b/src/dewret/render.py index 9162166f..32b79f6f 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -72,7 +72,7 @@ def base_render( """ primary_workflow = build_cb(workflow) subworkflows = {} - for step in workflow.steps: + for step in workflow.indexed_steps.values(): if isinstance(step, NestedStep): nested_subworkflows = base_render(step.subworkflow, build_cb) subworkflows.update(nested_subworkflows) diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 9a636072..57bdcb64 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -105,8 +105,8 @@ def with_type(result: Any) -> type: return type(result) def with_field(result: Any) -> str: - if hasattr(result, "__field__"): - return "/".join(result.__field__) or "out" + if hasattr(result, "__field__") and result.__field__: + return result.__field_str__ else: return "out" @@ -571,7 +571,7 @@ def from_workflow( return cls( steps=[ StepDefinition.from_step(step) - for step in workflow.steps + for step in workflow.indexed_steps.values() if not ( isinstance(step, FactoryCall) and get_render_configuration("factories_as_params") diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 640a3387..e271e082 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -507,7 +507,6 @@ def {fn.__name__}(...) -> ...: raise RuntimeError(f"Could not find a type annotation for {var} for {fn.__name__}") elif ( analyser.is_at_construct_arg(var, exhaustive=True) or - isinstance(value, Reference) or value is evaluate or value is construct # Allow manual building. ): kwargs[var] = value @@ -527,6 +526,8 @@ def {fn.__name__}(...) 
-> ...:
                     tethered=False,
                     typ=analyser.get_argument_annotation(var, exhaustive=True) or UNSET
                 ).make_reference(workflow=workflow)
+            elif is_expr(value) and expr_to_references(value)[1]:
+                kwargs[var] = value
             elif nested:
                 raise NotImplementedError(
                     f"Nested tasks must now only refer to global parameters, raw or tasks, not objects: {var}"
                 )
diff --git a/src/dewret/utils.py b/src/dewret/utils.py
index 894e5a3f..3b72cce9 100644
--- a/src/dewret/utils.py
+++ b/src/dewret/utils.py
@@ -101,8 +101,8 @@ def firm_to_raw(value: FirmType) -> RawType:
 def flatten(value: Any) -> RawType:
     return crawl_raw(value, lambda entry: entry)
 
-def is_expr(value: Any) -> bool:
-    return is_raw(value, lambda x: isinstance(x, Basic) or isinstance(x, tuple) or isinstance(x, Reference) or isinstance(x, Raw))
+def is_expr(value: Any, permitted_references: type=Reference) -> bool:
+    return is_raw(value, lambda x: isinstance(x, Basic) or isinstance(x, tuple) or isinstance(x, permitted_references) or isinstance(x, Raw))
 
 def is_raw_type(typ: type) -> bool:
     """Check if a type counts as "raw"."""
diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py
index 540e4286..3bd80c00 100644
--- a/src/dewret/workflow.py
+++ b/src/dewret/workflow.py
@@ -40,6 +40,10 @@
 RetType = TypeVar("RetType")
 
 CHECK_IDS = False
+AVAILABLE_TYPES = {
+    "int": int,
+    "str": str
+}
 
 class Lazy(Protocol):
     """Requirements for a lazy-evaluatable function."""
@@ -160,7 +164,8 @@ def __init__(
         self.register_caller(tethered)
 
     def is_loopable(self, typ: type):
-        base = get_origin(strip_annotations(typ)[0])
+        base = strip_annotations(typ)[0]
+        base = get_origin(base) or base
         return inspect.isclass(base) and issubclass(base, Iterable) and not issubclass(base, str | bytes)
 
     @property
@@ -326,7 +331,7 @@ class Workflow:
         result: target reference to evaluate, if yet present.
     """
 
-    steps: list["BaseStep"]
+    _steps: list["BaseStep"]
     tasks: MutableMapping[str, "Task"]
     result: StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]] | None
     _remapping: dict[str, str] | None
@@ -334,12 +339,16 @@ class Workflow:
     def __init__(self, name: str | None = None) -> None:
         """Initialize a Workflow, by setting `steps` and `tasks` to empty containers."""
-        self.steps = []
+        self._steps = []
         self.tasks = {}
         self.result: StepReference[Any] | None = None
         self._remapping = None
         self._name = name
 
+    @property
+    def steps(self) -> set[BaseStep]:
+        return set(self._steps)
+
     def __str__(self) -> str:
         """Name of the workflow, if available."""
         if self._name is None:
@@ -349,16 +358,17 @@ def __str__(self) -> str:
 
     def __repr__(self) -> str:
         if self._name:
             return self.name
-        comp_tup = tuple(sorted(s.id for s in self.steps))
-
-        return f"workflow-{hasher(comp_tup)}"
+        return self.id
 
     def __hash__(self) -> int:
         """Hashes for finding."""
-        return hash((
-            self._name,
-            tuple(self.steps),
-        ))
+        return hash(self.id)
+
+    @property
+    def id(self) -> str:
+        comp_tup = tuple(sorted(s.id for s in self.steps))
+        return f"workflow-{hasher(comp_tup)}"
+
     def __eq__(self, other: object) -> bool:
         """Is this the same workflow?
@@ -369,7 +379,7 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, Workflow): return False return ( - self.steps == other.steps + self._steps == other._steps and self.tasks == other.tasks and self.result == other.result and self._remapping == other._remapping @@ -393,7 +403,7 @@ def name(self) -> str: def find_factories(self) -> dict[str, FactoryCall]: """Steps that are factory calls.""" - return {step.id: step for step in self.steps if isinstance(step, FactoryCall)} + return {step_id: step for step_id, step in self.indexed_steps.items() if isinstance(step, FactoryCall)} def find_parameters( self, include_factory_calls: bool = True @@ -412,7 +422,7 @@ def find_parameters( return {ref for ref in references if isinstance(ref, ParameterReference)} @property - def _indexed_steps(self) -> dict[str, BaseStep]: + def indexed_steps(self) -> dict[str, BaseStep]: """Steps mapped by ID. Forces generation of IDs. Note that this effectively @@ -425,7 +435,7 @@ def _indexed_steps(self) -> dict[str, BaseStep]: return OrderedDict(sorted(((step.id, step) for step in self.steps), key=lambda x: x[0])) @classmethod - def assimilate(cls, *workflows) -> "Workflow": + def assimilate(cls, *workflow_args: Workflow) -> "Workflow": """Combine two Workflows into one Workflow. Takes two workflows and unifies them by combining steps @@ -438,18 +448,18 @@ def assimilate(cls, *workflows) -> "Workflow": left: workflow to use as base right: workflow to combine on top """ - workflows = set(workflows) - base = next(iter(workflows)) + workflows = sorted((w for w in set(workflow_args)), key=lambda w: w.id) + base = workflows[0] if len(workflows) == 1: return base - names = {w._name for w in workflows if w._name} - base._name = base._name or (names and next(iter(names))) or None + names = sorted({w._name for w in workflows if w._name}) + base._name = base._name or (names and names[0]) or None #left_steps = left._indexed_steps #right_steps = right._indexed_steps - all_steps = sum((list(w._indexed_steps.items()) for w in workflows), []) + all_steps = sorted(sum((list(w.indexed_steps.items()) for w in workflows), []), key=lambda s: s[0]) for _, step in all_steps: #for step in list(left_steps.values()) + list(right_steps.values()): @@ -470,17 +480,17 @@ def assimilate(cls, *workflows) -> "Workflow": if task != indexed_tasks[task_id]: raise RuntimeError(f"Two tasks have same name {task_id} but do not match") - base.steps = list(indexed_steps.values()) + base._steps = sorted(indexed_steps.values(), key=lambda s: s.id) base.tasks = indexed_tasks for step in base.steps: step.set_workflow(base, with_arguments=True) - results = set((w.result for w in workflows if w.result)) + results = sorted(set((w.result for w in workflows if w.result))) if len(results) == 1: - result = next(iter(results)) + result = results[0] else: - results = {r if isinstance(r, tuple | list) else (r,) for r in results} + results = sorted({r if isinstance(r, tuple | list) else (r,) for r in results}) result = sum(map(list, results), []) if result is not None and result != []: @@ -508,7 +518,7 @@ def simplify_ids(self, infix: list[str] | None = None) -> None: counter = Counter[Task | Workflow]() self._remapping = {} infix_str = ("-".join(infix) + "-") if infix else "" - for step in self.steps: + for key, step in self.indexed_steps.items(): counter[step.task] += 1 self._remapping[step.id] = f"{step.task}-{infix_str}{counter[step.task]}" if isinstance(step, NestedStep): @@ -565,7 +575,7 @@ def add_nested_step( step = NestedStep(self, name, 
subworkflow, kwargs) if positional_args is not None: step.positional_args = positional_args - self.steps.append(step) + self._steps.append(step) return_type = return_type or step.return_type if return_type is inspect._empty: raise TypeError("All tasks should have a type annotation.") @@ -595,7 +605,7 @@ def add_step( step = step_maker(self, task, kwargs, raw_as_parameter=raw_as_parameter) if positional_args is not None: step.positional_args = positional_args - self.steps.append(step) + self._steps.append(step) return_type = step.return_type if ( return_type is inspect._empty @@ -697,6 +707,7 @@ def __workflow__(self) -> Workflow: class FieldableProtocol(Protocol): __field__: tuple[str, ...] __field_sep__: str + __field_index_types__: tuple[type, ...] def __init__(self, *args, field: str | None = None, **kwargs): super().__init__(*args, **kwargs) @@ -722,6 +733,20 @@ def __init__(self: FieldableProtocol, *args, field: str | int | tuple | None = N def __field_sep__(self) -> str: return get_configuration("field_separator") + @property + def __field_index_types__(self) -> tuple[type, ...]: + types = get_configuration("field_index_types") + if not isinstance(types, str): + raise TypeError("Field index types must be provided as a comma-separated names") + tup_all = tuple(AVAILABLE_TYPES.get(typ) for typ in types.split(",") if typ) + tup = tuple(t for t in tup_all if t is not None) + if tup_all != tup: + raise ValueError( + "Setting for fixed index types contains unavailable type: " + + f"{str(get_configuration("field_index_types"))} vs {tup}" + ) + return tup + @property def __name__(self: FieldableProtocol) -> str: """Name for this step. @@ -731,11 +756,15 @@ def __name__(self: FieldableProtocol) -> str: """ return super().__name__ + self.__field_suffix__ + @property + def __field_str__(self) -> str: + return self.__field_suffix__.lstrip(self.__field_sep__) + @property def __field_suffix__(self) -> str: result = "" for cmpt in self.__field__: - if isinstance(cmpt, int): + if any(isinstance(cmpt, typ) for typ in self.__field_index_types__): result += f"[{cmpt}]" else: result += f"{self.__field_sep__}{cmpt}" @@ -751,9 +780,9 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None # Get new type, for the specific field. 
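    # Illustrative naming behaviour (with the defaults field_separator="/"
    # and field_index_types="int" introduced above): string components are
    # joined with the separator and int components are bracketed, so a step
    # "test_sep-1" with field path ("left", 0) is rendered as
    # "test_sep-1/left[0]".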
parent_type, _ = strip_annotations(self.__type__) field_type = fallback_type + base = get_origin(parent_type) or parent_type if isinstance(field, int): - base = get_origin(parent_type) if not inspect.isclass(base) or not issubclass(base, Sequence): raise AttributeError(f"Tried to index int {field} into a non-sequence type {parent_type} (base: {base})") if not (field_type := get_args(parent_type)[0]): @@ -779,7 +808,7 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None field_type = parent_type.__annotations__[field] except KeyError: raise AttributeError(f"TypedDict {parent_type} does not have field {field}") - if not field_type and get_configuration("allow_plain_dict_fields") and strip_annotations(get_origin(parent_type))[0] is dict: + if not field_type and get_configuration("allow_plain_dict_fields") and inspect.isclass(base) and issubclass(base, dict): args = get_args(parent_type) if len(args) == 2 and args[0] is str: field_type = args[1] @@ -889,11 +918,15 @@ def make_reference(self, **kwargs) -> "StepReference": kwargs["step"] = self kwargs.setdefault("typ", self.return_type) typ = kwargs["typ"] - base = get_origin(strip_annotations(typ)[0]) + base = strip_annotations(typ)[0] + base = get_origin(base) or base if inspect.isclass(base) and issubclass(base, Iterable) and not issubclass(base, str | bytes): return IterableStepReference(**kwargs) return StepReference(**kwargs) + def __hash__(self) -> int: + return hash(self.id) + def set_workflow(self, workflow: Workflow, with_arguments: bool = True) -> None: """Move the step reference to another workflow. @@ -1049,12 +1082,16 @@ def __init__( arguments: key-value pairs to pass to the function - for a factory call, these _must_ be raw. raw_as_parameter: whether to turn any raw-type arguments into workflow parameters (or just keep them as default argument values). 
""" - for arg in list(arguments.values()): - if not is_raw(arg) and not ( + for key, arg in arguments.items(): + if not is_expr(arg) and not ( isinstance(arg, ParameterReference) and is_raw_type(arg.__type__) ): + try: + arg_name = str(arg) + except: + arg_name = "(unnamed)" raise RuntimeError( - f"Factories must be constructed with raw types {arg} {type(arg)}" + f"Factories must be constructed with raw types in {arg_name} {type(arg)}" ) super().__init__(workflow=workflow, task=task, arguments=arguments, raw_as_parameter=raw_as_parameter) diff --git a/tests/_lib/frender.py b/tests/_lib/frender.py index 6e92573f..257de0b6 100644 --- a/tests/_lib/frender.py +++ b/tests/_lib/frender.py @@ -66,7 +66,7 @@ class WorkflowDefinition: @classmethod def from_workflow(cls, workflow: Workflow): steps = [] - for step in workflow.steps: + for step in workflow.indexed_steps.values(): if isinstance(step, Step): steps.append(StepDefinition.from_step(step)) elif isinstance(step, NestedStep): diff --git a/tests/test_cwl.py b/tests/test_cwl.py index 043ce53d..48078d87 100644 --- a/tests/test_cwl.py +++ b/tests/test_cwl.py @@ -305,7 +305,7 @@ def test_cwl_with_subworkflow() -> None: outputs: out: label: out - outputSource: sum-1-2/out + outputSource: sum-1-1/out type: - int - float @@ -317,7 +317,7 @@ def test_cwl_with_subworkflow() -> None: out: - out run: double - sum-1-1: + sum-1-2: in: left: source: double-1-1/out @@ -326,10 +326,10 @@ def test_cwl_with_subworkflow() -> None: out: - out run: sum - sum-1-2: + sum-1-1: in: left: - source: sum-1-1/out + source: sum-1-2/out right: default: 1 out: @@ -508,7 +508,7 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: outputs: out: label: out - outputSource: sum-1-2/out + outputSource: sum-1-1/out type: - int - float @@ -520,7 +520,7 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: out: - out run: double - sum-1-1: + sum-1-2: in: left: source: double-1-1/out @@ -529,10 +529,10 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: out: - out run: sum - sum-1-2: + sum-1-1: in: left: - source: sum-1-1/out + source: sum-1-2/out right: default: 1 out: diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index e122935f..14ccbed1 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -282,3 +282,30 @@ def test_dict(left: int, right: float) -> dict[str, float | int]: wkflw = construct(result, simplify_ids=True) assert str(wkflw.result["left"]) == "test_dict-1/left" assert wkflw.result["left"].__type__ == int | float + +def test_can_configure_field_separator(): + @dataclass + class IndexTest: + left: Fixed[list[int]] + + @task() + def test_sep() -> IndexTest: + return IndexTest(left=[3]) + + with set_configuration(field_index_types="int"): + result = test_sep().left[0] + wkflw = construct(result, simplify_ids=True) + rendered = render(wkflw, allow_complex_types=True)["__root__"] + assert str(wkflw.result) == "test_sep-1/left[0]" + + #with set_configuration(field_index_types="int,str"): + # result = test_sep().left[0] + # wkflw = construct(result, simplify_ids=True) + # rendered = render(wkflw, allow_complex_types=True)["__root__"] + # assert str(wkflw.result) == "test_sep-1[left][0]" + + #with set_configuration(field_index_types=""): + # result = test_sep().left[0] + # wkflw = construct(result, simplify_ids=True) + # rendered = render(wkflw, allow_complex_types=True)["__root__"] + # assert str(wkflw.result) == "test_sep-1/left/0" From 8c0024791681e85e7de8c2d8d9335ced3d3c8d24 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Wed, 
21 Aug 2024 00:19:37 +0100
Subject: [PATCH 059/108] fix: correct logic for remapping before a workflow has a name

---
 src/dewret/core.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/dewret/core.py b/src/dewret/core.py
index 05436145..b4dd8f1a 100644
--- a/src/dewret/core.py
+++ b/src/dewret/core.py
@@ -150,7 +150,7 @@ def __name__(self) -> str:
         """Referral name for this reference."""
         workflow = self.__workflow__
         name = self.__root_name__
-        return workflow.remap(name) or name
+        return workflow.remap(name) if workflow is not None else name
 
     def __str__(self) -> str:
         """Global description of the reference."""

From 0be1fea202b18b8788e83252312e6843b19c3777 Mon Sep 17 00:00:00 2001
From: Phil Weir
Date: Wed, 21 Aug 2024 00:22:06 +0100
Subject: [PATCH 060/108] fix: remove duplicate __hash__ method

---
 src/dewret/workflow.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py
index 3bd80c00..82981fe5 100644
--- a/src/dewret/workflow.py
+++ b/src/dewret/workflow.py
@@ -965,9 +965,6 @@ def return_type(self) -> Any:
             return self.task.target
         return inspect.signature(inspect.unwrap(self.task.target)).return_annotation
 
-    def __hash__(self) -> int:
-        return hash(self.id)
-
     @property
     def name(self) -> str:
         """Name for this step.

From 0bae23c4d5161d20cb75ba8dccf66771c23dd8af Mon Sep 17 00:00:00 2001
From: Phil Weir
Date: Sat, 24 Aug 2024 13:40:24 +0100
Subject: [PATCH 061/108] fix: tidy up 0.9.2 merge, mostly label corrections

---
 src/dewret/__main__.py          | 16 ++--------------
 src/dewret/render.py            |  9 ++-------
 src/dewret/utils.py             | 23 ++++++++++++++++++++++-
 tests/test_annotations.py       |  8 ++++----
 tests/test_cwl.py               | 32 +++++++++++++++-----------------
 tests/test_fieldable.py         |  4 ++--
 tests/test_multiresult_steps.py |  5 +++--
 tests/test_nested.py            |  4 ++--
 tests/test_parameters.py        |  4 ++--
 9 files changed, 54 insertions(+), 51 deletions(-)

diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py
index e3c9093a..26b49461 100644
--- a/src/dewret/__main__.py
+++ b/src/dewret/__main__.py
@@ -32,6 +32,7 @@
 import json
 
 from .core import set_configuration, set_render_configuration
+from .utils import load_module_or_package
 from .render import get_render_method, RawRenderModule, StructuredRenderModule
 from .tasks import Backend, construct
 
@@ -130,7 +131,6 @@ def _opener(key, mode):
     opener = _opener
 
     render = get_render_method(render_module, pretty=pretty)
-    workflow_init = workflow_py.parent
     pkg = "__workflow__"
-
-    # Try to import the workflow as a package, if possible, to allow relative imports.
- try: - spec = importlib.util.spec_from_file_location(pkg, str(workflow_py.parent / "__init__.py")) - if spec is None or spec.loader is None: - raise ImportError(f"Could not open {pkg} package") - module = importlib.util.module_from_spec(spec) - sys.modules[pkg] = module - spec.loader.exec_module(module) - workflow = importlib.import_module(f"{pkg}.{workflow_py.stem}", pkg) - except ImportError: - loader = importlib.machinery.SourceFileLoader(pkg, str(workflow_py)) - workflow = loader.load_module() + workflow = load_module_or_package(pkg, workflow_py) task_fn = getattr(workflow, task) try: diff --git a/src/dewret/render.py b/src/dewret/render.py index 32b79f6f..e484eb5a 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -8,6 +8,7 @@ from .workflow import Workflow, NestedStep from .core import RawType from .workflow import Workflow +from .utils import load_module_or_package RenderConfiguration = TypeVar("RenderConfiguration", bound=dict[str, Any]) @@ -34,13 +35,7 @@ def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule, # Attempt to load renderer as package, falling back to a single module otherwise. # This enables relative imports in renderers and therefore the ability to modularize. - try: - loader = importlib.machinery.SourceFileLoader("__renderer__", str(renderer.parent / "__init__.py")) - sys.modules["__renderer__"] = loader.load_module(f"__renderer__") - render_module = importlib.import_module(f"__renderer__.{renderer.stem}", "__renderer__") - except ImportError: - loader = importlib.machinery.SourceFileLoader("__renderer__", str(renderer)) - render_module = loader.load_module() + render_module = load_module_or_package("__renderer__", renderer) sys.modules["__renderer_mod__"] = render_module else: render_module = renderer diff --git a/src/dewret/utils.py b/src/dewret/utils.py index 3b72cce9..035cf21b 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -20,8 +20,11 @@ import hashlib import json import sys -from types import FrameType, TracebackType, UnionType +import importlib +import importlib.util +from types import FrameType, TracebackType, UnionType, ModuleType from typing import Any, cast, Union, Protocol, ClassVar, Callable, Iterable, get_args, get_origin, Annotated +from pathlib import Path from collections.abc import Sequence, Mapping from sympy import Basic, Integer, Float, Rational @@ -58,6 +61,24 @@ def make_traceback(skip: int = 2) -> TracebackType | None: frame = frame.f_back return tb +def load_module_or_package(target_name: str, path: Path) -> ModuleType: + # Try to import the workflow as a package, if possible, to allow relative imports. + try: + spec = importlib.util.spec_from_file_location(target_name, str(path.parent / "__init__.py")) + if spec is None or spec.loader is None: + raise ImportError(f"Could not open {path.parent} package") + module = importlib.util.module_from_spec(spec) + sys.modules[target_name] = module + spec.loader.exec_module(module) + workflow = importlib.import_module(f"{target_name}.{path.stem}", target_name) + except ImportError as exc: + spec = importlib.util.spec_from_file_location(target_name, str(path)) + if spec is None or spec.loader is None: + raise ImportError(f"Could not open {path} module") from exc + workflow = importlib.util.module_from_spec(spec) + spec.loader.exec_module(workflow) + + return workflow def flatten_if_set(value: Any) -> RawType | Unset: """Takes a Raw-like structure and makes it RawType or Unset. 
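A minimal usage sketch of the consolidated helper above (module names and paths here are invented for illustration):

    from pathlib import Path
    from dewret.utils import load_module_or_package

    # examples/pkg/__init__.py exists, so flow.py is imported as the
    # "__workflow__.flow" submodule and may use relative imports internally.
    flow = load_module_or_package("__workflow__", Path("examples/pkg/flow.py"))

    # Without a sibling __init__.py, the same call falls back to loading
    # the file as a standalone module.
    flow = load_module_or_package("__workflow__", Path("examples/flow.py"))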
diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 9d70ac16..39bc36aa 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -70,7 +70,7 @@ def test_at_render() -> None: inputs: increment-1-num: default: 3 - label: increment-1-num + label: num type: int outputs: out: @@ -78,7 +78,7 @@ def test_at_render() -> None: outputSource: to_int-1/out type: - int - - double + - float steps: increment-1: in: @@ -108,7 +108,7 @@ def test_at_render() -> None: inputs: increment-1-num: default: 3 - label: increment-1-num + label: num type: int outputs: out: @@ -116,7 +116,7 @@ def test_at_render() -> None: outputSource: to_int-1/out type: - int - - double + - float steps: increment-1: in: diff --git a/tests/test_cwl.py b/tests/test_cwl.py index 48078d87..52f7793a 100644 --- a/tests/test_cwl.py +++ b/tests/test_cwl.py @@ -86,7 +86,7 @@ def get_now() -> datetime: days_in_future-1-num: default: 3 type: int - label: days_in_future-1-num + label: num get_now-1: label: get_now-1 type: datetime @@ -115,7 +115,7 @@ def get_now() -> datetime: days_in_future-1-num: default: 3 type: int - label: days_in_future-1-num + label: num outputs: out: label: out @@ -154,7 +154,7 @@ def test_cwl_with_parameter() -> None: class: Workflow inputs: increment-{hsh}-num: - label: increment-{hsh}-num + label: num type: int default: 3 outputs: @@ -171,7 +171,7 @@ def test_cwl_with_parameter() -> None: out: [out] """) -def test_cwl_with_parameter() -> None: +def test_cwl_with_positional_parameter() -> None: """Check whether we can move raw input to parameters. Produces CWL for a call with a changeable raw value, that is converted @@ -191,7 +191,7 @@ def test_cwl_with_parameter() -> None: class: Workflow inputs: increment-{hsh}-num: - label: increment-{hsh}-num + label: num type: int default: 3 outputs: @@ -357,14 +357,14 @@ def test_cwl_references() -> None: class: Workflow inputs: increment-{hsh_increment}-num: - label: increment-{hsh_increment}-num + label: num type: int default: 3 outputs: out: label: out outputSource: double-{hsh_double}/out - type: + type: - int - float steps: @@ -397,14 +397,14 @@ def test_complex_cwl_references() -> None: class: Workflow inputs: increment-1-num: - label: increment-1-num + label: num type: int default: 23 outputs: out: label: out outputSource: sum-1/out - type: + type: - int - float steps: @@ -459,7 +459,7 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: type: int sum-1-right: default: 3 - label: sum-1-right + label: right type: int outputs: out: @@ -548,8 +548,7 @@ def test_tuple_floats() -> None: """ result = tuple_float_return() workflow = construct(result, simplify_ids=True) - rendered = render(workflow) - print(yaml.dump(rendered)) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 class: Workflow @@ -558,11 +557,10 @@ def test_tuple_floats() -> None: out: label: out outputSource: tuple_float_return-1/out - type: - items: - - type: float - - type: float - type: array + items: + - float + - float + type: array steps: tuple_float_return-1: run: tuple_float_return diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index 14ccbed1..fa6dd95a 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -44,7 +44,7 @@ def test_fields_of_parameters_usable() -> None: outputSource: sum-1-1/out type: - int - - double + - float steps: sum-1-1: in: @@ -83,7 +83,7 @@ class MyDataclass: outputSource: sum-1/out type: - int - - double + - float steps: sum-1: in: diff --git 
a/tests/test_multiresult_steps.py b/tests/test_multiresult_steps.py index b278e84a..d7d7ca8a 100644 --- a/tests/test_multiresult_steps.py +++ b/tests/test_multiresult_steps.py @@ -287,8 +287,9 @@ def test_pair_can_be_returned_from_step() -> None: def test_list_can_be_returned_from_step() -> None: """Tests whether a task can insert result fields into other steps.""" - workflow = construct(list_cast(iterable=algorithm_with_pair()), simplify_ids=True) - rendered = render(workflow)["__root__"] + with set_configuration(flatten_all_nested=True): + workflow = construct(list_cast(iterable=algorithm_with_pair()), simplify_ids=True) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow diff --git a/tests/test_nested.py b/tests/test_nested.py index fc556f8f..e507194f 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -26,14 +26,14 @@ def test_can_supply_nested_raw(): pi: default: 3.141592653589793 label: pi - type: double + type: float outputs: out: label: out outputSource: max_list-1/out type: - int - - double + - float steps: max_list-1: in: diff --git a/tests/test_parameters.py b/tests/test_parameters.py index e19f0046..5d32315c 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -34,7 +34,7 @@ def test_cwl_parameters() -> None: type: int default: 3 rotate-1-num: - label: rotate-1-num + label: num type: int default: 3 outputs: @@ -77,7 +77,7 @@ def test_complex_parameters() -> None: type: int default: 23 rotate-2-num: - label: rotate-2-num + label: num type: int default: 23 outputs: From 8e031374bae7015f28031f86b48e5230c2d09002 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 24 Aug 2024 18:07:29 +0100 Subject: [PATCH 062/108] fix: tests not working --- src/dewret/annotations.py | 11 +-- src/dewret/core.py | 55 +++---------- src/dewret/renderers/cwl.py | 149 +++++++++++++++++++++++++++--------- src/dewret/tasks.py | 7 +- src/dewret/utils.py | 3 + src/dewret/workflow.py | 22 ++++-- tests/test_annotations.py | 13 ++-- tests/test_fieldable.py | 85 ++++++++++++-------- tests/test_nested.py | 5 +- 9 files changed, 215 insertions(+), 135 deletions(-) diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index 01c375a8..c561c28b 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -5,7 +5,7 @@ from functools import lru_cache from types import FunctionType from dataclasses import dataclass -from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Callable, get_origin, get_args, Mapping, TypeAliasType +from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Callable, get_origin, get_args, Mapping, TypeAliasType, get_type_hints T = TypeVar("T") AtRender = Annotated[T, "AtRender"] @@ -26,7 +26,7 @@ def __init__(self, fn: Callable[..., Any]): @property def return_type(self): - return inspect.signature(inspect.unwrap(self.fn)).return_annotation + return get_type_hints(inspect.unwrap(self.fn), include_extras=True)["return"] @staticmethod def _typ_has(typ: type, annotation: type) -> bool: @@ -69,10 +69,11 @@ def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> type | No all_annotations: dict[str, type] = {} typ: type | None = None if (typ := self.fn.__annotations__.get(arg)): - ... 
+ if isinstance(typ, str): + typ = get_type_hints(self.fn, include_extras=True).get(arg) elif exhaustive: - if "__annotations__" in self.fn.__globals__: - if (typ := self.fn.__globals__["__annotations__"].get(arg)): + if (anns := get_type_hints(sys.modules[self.fn.__module__], include_extras=True)): + if (typ := anns.get(arg)): ... elif (orig_pair := self.get_all_imported_names().get(arg)): orig_module, orig_name = orig_pair diff --git a/src/dewret/core.py b/src/dewret/core.py index b4dd8f1a..34629ea9 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -4,12 +4,12 @@ from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated, Never from contextlib import contextmanager from contextvars import ContextVar -from sympy import Expr, Symbol +from sympy import Expr, Symbol, Basic import copy BasicType = str | float | bool | bytes | int | None RawType = Union[BasicType, list["RawType"], dict[str, "RawType"]] -FirmType = BasicType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] +FirmType = RawType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] U = TypeVar("U") T = TypeVar("T", bound=TypedDict) @@ -87,10 +87,16 @@ def _set_configuration(config_group: str, kwargs: U) -> Iterator[ContextVar[Glob CONFIGURATION.set(previous) def get_configuration(key: str) -> RawType: - return CONFIGURATION.get()["construct"].get(key) # type: ignore + try: + return CONFIGURATION.get()["construct"].get(key) # type: ignore + except LookupError: + return default_construct_config().get(key) def get_render_configuration(key: str) -> RawType: - return CONFIGURATION.get()["render"].get(key) # type: ignore + try: + return CONFIGURATION.get()["render"].get(key) # type: ignore + except LookupError: + return default_renderer_config().get(key) class Reference(Generic[U], Symbol): """Superclass for all symbolic references to values.""" @@ -129,7 +135,7 @@ def _raise_unevaluatable_error(self): def __eq__(self, other) -> bool: if isinstance(other, list) or other is None: return False - if not isinstance(other, Reference): + if not isinstance(other, Reference) and not isinstance(other, Basic): self._raise_unevaluatable_error() return super().__eq__(other) @@ -193,44 +199,7 @@ def __init__(self, to_wrap: Reference[U]): def __iter__(self): count = -1 for _ in self.__wrapped__.__inner_iter__(): - yield Iterated(to_wrap=self.__wrapped__, iteration=(count := count + 1)) - - -class Iterated(Reference[U]): - __wrapped__: Reference[U] - __iteration__: int - - def __init__(self, to_wrap: Reference[U], iteration: int, *args, **kwargs): - self.__wrapped__ = to_wrap - self.__iteration__ = iteration - super().__init__(*args, **kwargs) - - @property - def _(self): - return self.__wrapped__._ - - @property - def __root_name__(self) -> str: - return f"{self.__wrapped__.__root_name__}[{self.__iteration__}]" - - @property - def __type__(self) -> type: - return self.__wrapped__.__type__ - - def __hash__(self) -> int: - return hash(self.__root_name__) - - @property - def __field__(self) -> tuple[str]: - return tuple(list(self.__wrapped__.__field__) + [str(self.__iteration__)]) - - @property - def __workflow__(self) -> WorkflowProtocol: - return self.__wrapped__.__workflow__ - - @__workflow__.setter - def __workflow__(self, workflow: WorkflowProtocol) -> None: - self.__wrapped__.__workflow__ = workflow + yield self.__wrapped__.__make_reference__(field=(count := count + 1)) @dataclass diff --git a/src/dewret/renderers/cwl.py 
b/src/dewret/renderers/cwl.py index 57bdcb64..1b4d08e5 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -21,11 +21,10 @@ from attrs import define, has as attrs_has, fields as attrs_fields, AttrsInstance from dataclasses import is_dataclass, fields as dataclass_fields from collections.abc import Mapping -from contextvars import ContextVar from typing import TypedDict, NotRequired, get_origin, get_args, Union, cast, Any, Unpack, Iterable from types import UnionType from inspect import isclass -from sympy import Expr, Basic, Tuple, sympify, Dict, jscode +from sympy import Basic, Tuple, Dict, jscode, Symbol from dewret.core import ( Raw, @@ -70,7 +69,14 @@ class CommandInputSchema(TypedDict): str, CommandInputSchema, list[str], list["InputSchemaType"], dict[str, "str | InputSchemaType"] ] -def render_expression(ref: Any) -> str: +def render_expression(ref: Any) -> "ReferenceDefinition": + """Turn a rich (sympy) expression into a CWL JS expression. + + Args: + ref: a structure whose elements are all string-renderable or sympy Basic. + + Returns: a ReferenceDefinition containing a string representation of the expression in the form `$(...)`. + """ def _render(ref): if not isinstance(ref, Basic): if isinstance(ref, Mapping): @@ -78,7 +84,39 @@ def _render(ref): elif not isinstance(ref, str | bytes) and isinstance(ref, Iterable): ref = Tuple(*(_render(val) for val in ref)) return ref - return f"$({jscode(_render(ref))})" + expr = _render(ref) + if isinstance(expr, Basic): + values = list(expr.free_symbols) + step_syms = [sym for sym in expr.free_symbols if isinstance(sym, StepReference)] + param_syms = [sym for sym in expr.free_symbols if isinstance(sym, ParameterReference)] + + if set(values) != set(step_syms) | set(param_syms): + raise NotImplementedError(f"Can only build expressions for step results and param results: {ref}") + + if len(step_syms) > 1: + raise NotImplementedError(f"Can only create expressions with 1 step reference: {ref}") + if not (step_syms or param_syms): + ... + if values == [ref]: + if isinstance(ref, StepReference): + return ReferenceDefinition(source=to_name(ref), value_from=None) + else: + return ReferenceDefinition(source=ref.name, value_from=None) + source = None + for ref in values: + if isinstance(ref, StepReference): + field = with_field(ref) + parts = field.split("/") + base = f"/{parts[0]}" if parts and parts[0] else "" + if len(parts) > 1: + expr = expr.subs(ref, f"self.{'.'.join(parts[1:])}") + else: + expr = expr.subs(ref, "self") + source = f"{ref.__root_name__}{base}" + else: + expr = expr.subs(ref, Symbol(f"inputs.{ref.name}")) + return ReferenceDefinition(source=source, value_from=f"$({jscode(_render(expr))})") + return ReferenceDefinition(source=str(expr), value_from=None) class CWLRendererConfiguration(TypedDict): """Configuration for the renderer. @@ -93,6 +131,14 @@ class CWLRendererConfiguration(TypedDict): def default_renderer_config() -> CWLRendererConfiguration: + """Default configuration for this renderer. + + This is a hook-like call to give a configuration dict that this renderer + will respect, and sets any necessary default values. + + Returns: a dict with (preferably) raw type structures to enable easy setting + from YAML/JSON. + """ return { "allow_complex_types": False, "factories_as_params": False, @@ -100,16 +146,45 @@ def default_renderer_config() -> CWLRendererConfiguration: def with_type(result: Any) -> type: + """Get a Python type from a value. 
+ + Does so either by using its `__type__` field (for example, for References) + or if unavailable, using `type()`. + + Returns: a Python type. + """ if hasattr(result, "__type__"): return result.__type__ return type(result) def with_field(result: Any) -> str: + """Get a string representing any 'field' suffix of a value. + + This only makes sense in the context of a Reference, which can represent + a deep reference with a known variable (parameter or step result, say) using + its `__field__` attribute. Defaults to `"out"` as this produces compliant CWL + where every output has a "fieldname". + + Returns: a string representation of the field portion of the passed value or `"out"`. + """ if hasattr(result, "__field__") and result.__field__: return result.__field_str__ else: return "out" +def to_name(result: Reference[Any]): + """Take a reference and get a name representing it. + + The primary purpose of this method is to deal with the case where a reference is to the + whole result, as we always put this into an imagined `out` field for CWL consistency. + + Returns: the name of the reference, including any field portion, appending an `"out"` fieldname if none. + """ + if hasattr(result, "__field__") and not result.__field__ and isinstance(result, StepReference): + return f"{result.__name__}/out" + return result.__name__ + + @define class ReferenceDefinition: """CWL-renderable internal reference. @@ -117,7 +192,8 @@ class ReferenceDefinition: Normally points to a value or a step. """ - source: str + source: str | None + value_from: str | None @classmethod def from_reference(cls, ref: Reference) -> "ReferenceDefinition": @@ -128,7 +204,7 @@ def from_reference(cls, ref: Reference) -> "ReferenceDefinition": Args: ref: reference to convert. """ - return cls(source=to_name(ref)) + return render_expression(ref) def render(self) -> dict[str, RawType]: """Render to a dict-like structure. @@ -137,7 +213,12 @@ def render(self) -> dict[str, RawType]: Reduced form as a native Python dict structure for serialization. """ - return {"source": self.source} + representation = {} + if self.source is not None: + representation["source"] = self.source + if self.value_from is not None: + representation["valueFrom"] = self.value_from + return representation @define @@ -196,11 +277,11 @@ def render(self) -> dict[str, RawType]: key: ( ref.render() if isinstance(ref, ReferenceDefinition) else - {"expression": render_expression(ref)} + render_expression(ref).render() if isinstance(ref, Basic) else {"default": firm_to_raw(ref.value)} if hasattr(ref, "value") - else {"expression": render_expression(ref)} + else render_expression(ref).render() ) for key, ref in self.in_.items() }, @@ -208,32 +289,33 @@ def render(self) -> dict[str, RawType]: } -def cwl_type_from_value(label: str, val: RawType | Unset) -> InputSchemaType: +def cwl_type_from_value(label: str, val: RawType | Unset) -> CommandInputSchema: """Find a CWL type for a given (possibly Unset) value. Args: + label: the label for the variable being checked to prefill the input def and improve debugging info. val: a raw Python variable or an unset variable. Returns: - Type as a string or list of strings. + Input schema type. """ if val is not None and hasattr(val, "__type__"): raw_type = val.__type__ else: raw_type = type(val) - return to_cwl_type(label, raw_type)["type"] + return to_cwl_type(label, raw_type) def to_cwl_type(label: str, typ: type) -> CommandInputSchema: """Map Python types to CWL types. 
Args: + label: the label for the variable being checked to prefill the input def and improve debugging info. typ: a Python basic type. Returns: - CWL specification type name, or a list - if a union. + CWL specification type dict. """ typ_dict: CommandInputSchema = { "label": label, @@ -276,12 +358,12 @@ def to_cwl_type(label: str, typ: type) -> CommandInputSchema: typ_dict["type"] = "array" except IndexError as err: raise TypeError( - f"Cannot render complex type ({typ}) to CWL, have you enabled allow_complex_types configuration?" + f"Cannot render complex type ({typ}) to CWL for {label}, have you enabled allow_complex_types configuration?" ) from err elif get_render_configuration("allow_complex_types"): typ_dict["type"] = typ if isinstance(typ, str) else typ.__name__ else: - raise TypeError(f"Cannot render type ({typ}) to CWL") + raise TypeError(f"Cannot render type ({typ}) to CWL for {label}") return typ_dict @@ -296,6 +378,8 @@ class CommandOutputSchema(CommandInputSchema): """ outputSource: NotRequired[str] + expression: NotRequired[str] + source: NotRequired[list[str]] def raw_to_command_input_schema(label: str, value: RawType | Unset) -> InputSchemaType: @@ -312,7 +396,7 @@ def raw_to_command_input_schema(label: str, value: RawType | Unset) -> InputSche Structure used to define (possibly compound) basic types for input. """ if isinstance(value, dict) or isinstance(value, list): - return {"type": _raw_to_command_input_schema_internal(label, value)} + return _raw_to_command_input_schema_internal(label, value) else: return cwl_type_from_value(label, value) @@ -368,8 +452,7 @@ def to_output_schema( def _raw_to_command_input_schema_internal( label: str, value: RawType | Unset ) -> CommandInputSchema: - typ = cwl_type_from_value(label, value) - structure: CommandInputSchema = {"type": typ, "label": label} + structure: CommandInputSchema = cwl_type_from_value(label, value) if isinstance(value, dict): structure["fields"] = { key: _raw_to_command_input_schema_internal(key, val) @@ -434,7 +517,7 @@ def from_parameters( label=input.__name__, default=(default := flatten_if_set(input.__default__)), type=raw_to_command_input_schema( - label=input.name, value=default + label=input.__original_name__, value=default ), ) for input in parameters @@ -450,23 +533,15 @@ def render(self) -> dict[str, RawType]: """ result: dict[str, RawType] = {} for key, input in self.inputs.items(): - item = { - # Would rather not cast, but CommandInputSchema is dict[RawType] - # by construction, where type is seen as a TypedDict subclass. - "type": firm_to_raw(cast(FirmType, input.type)), - "label": input.label, - } + # Would rather not cast, but CommandInputSchema is dict[RawType] + # by construction, where type is seen as a TypedDict subclass. + item = firm_to_raw(cast(FirmType, input.type)) if not isinstance(input.default, Unset): item["default"] = firm_to_raw(input.default) result[key] = item return result -def to_name(result: Reference[Any]): - if hasattr(result, "__field__") and not result.__field__ and isinstance(result, StepReference): - return f"{result.__name__}/out" - return result.__name__ - @define class OutputsDefinition: """CWL-renderable set of workflow outputs. 
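    For example (a sketch; the concrete values follow the rendered YAML in
    the tests of this series, rather than being stated by the patch itself),
    the type mapping feeding these outputs behaves as:

        to_cwl_type("num", int)          # {"label": "num", "type": "int"}
        to_cwl_type("out", int | float)  # {"label": "out", "type": ["int", "float"]}
        cwl_type_from_value("num", 3)    # {"label": "num", "type": "int"}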
@@ -507,15 +582,19 @@ def _build_results(result): return cls(outputs=_build_results(results)) except AttributeError: expr, references = expr_to_references(results) - references = sorted( + reference_names = sorted( { str(ref._.parameter) if isinstance(ref, ParameterReference) else str(ref._.step) for ref in references } ) return cls(outputs={ - "expression": str(expr), - "source": references + "out": { + "type": "float", # WARNING: we assume any arithmetic expression returns a float. + "label": "out", + "expression": str(expr), + "source": reference_names + } }) def render(self) -> dict[str, RawType] | list[RawType]: @@ -582,7 +661,7 @@ def from_workflow( workflow.result if isinstance(workflow.result, list | tuple | Tuple) else {with_field(workflow.result): workflow.result} - if workflow.has_result else + if workflow.has_result and workflow.result is not None else {} ), name=name, diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index e271e082..c6bb30ff 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -387,12 +387,7 @@ def _fn( **kwargs: Param.kwargs, ) -> RetType: configuration = None - try: - allow_positional_args = get_configuration("allow_positional_args") - except LookupError: - configuration = set_configuration() - configuration.__enter__() - allow_positional_args = get_configuration("allow_positional_args") + allow_positional_args = get_configuration("allow_positional_args") try: # Ensure that all arguments are passed as keyword args and prevent positional args. diff --git a/src/dewret/utils.py b/src/dewret/utils.py index 035cf21b..1cc58b93 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -26,6 +26,7 @@ from typing import Any, cast, Union, Protocol, ClassVar, Callable, Iterable, get_args, get_origin, Annotated from pathlib import Path from collections.abc import Sequence, Mapping +from dataclasses import asdict, is_dataclass from sympy import Basic, Integer, Float, Rational from .core import Reference, BasicType, RawType, FirmType, Raw @@ -110,6 +111,8 @@ def crawl_raw(value: Any, action: Callable[[Any], Any]) -> RawType: return value if isinstance(value, Mapping): return {key: flatten(item) for key, item in value.items()} + if is_dataclass(value) and not isinstance(value, type): + return crawl_raw(asdict(value), action) if isinstance(value, Sequence): return [flatten(item) for item in value] if (raw := ensure_raw(value)) is not None: diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 82981fe5..7f368c45 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -24,7 +24,7 @@ from attrs import has as attr_has, resolve_types, fields as attrs_fields from dataclasses import is_dataclass, fields as dataclass_fields from collections import Counter, OrderedDict -from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Iterable, get_origin, get_args, Generator, Sized, Sequence +from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Iterable, get_origin, get_args, Generator, Sized, Sequence, get_type_hints from uuid import uuid4 from sympy import Symbol, Expr, Basic, Tuple, Dict, nan @@ -793,7 +793,10 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None else: if is_dataclass(parent_type): try: - field_type = next(iter(filter(lambda fld: fld.name == field, dataclass_fields(parent_type)))).type + type_hints = get_type_hints(parent_type, localns={parent_type.__name__: parent_type}, include_extras=True) + field_type = 
type_hints.get(field) + if field_type is None: + field_type = next(iter(filter(lambda fld: fld.name == field, dataclass_fields(parent_type)))).type except StopIteration: raise AttributeError(f"Dataclass {parent_type} does not have field {field}") elif attr_has(parent_type): @@ -805,7 +808,7 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None # TypedDict elif inspect.isclass(parent_type) and issubclass(parent_type, dict) and hasattr(parent_type, "__annotations__"): try: - field_type = parent_type.__annotations__[field] + field_type = get_type_hints(parent_type, include_extras=True)[field] except KeyError: raise AttributeError(f"TypedDict {parent_type} does not have field {field}") if not field_type and get_configuration("allow_plain_dict_fields") and inspect.isclass(base) and issubclass(base, dict): @@ -825,7 +828,7 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None return self.__make_reference__(typ=field_type, field=field, **init_kwargs) raise AttributeError( - f"Could not determine the type for field {field} in type {parent_type}" + f"Could not determine the type for field {field} in type {parent_type} (type of parent type is {type(parent_type)})" ) class BaseStep(WorkflowComponent): @@ -963,7 +966,8 @@ def return_type(self) -> Any: ) if isinstance(self.task.target, type): return self.task.target - return inspect.signature(inspect.unwrap(self.task.target)).return_annotation + ann = get_type_hints(inspect.unwrap(self.task.target), include_extras=True)["return"] + return ann @property def name(self) -> str: @@ -1096,6 +1100,10 @@ def __init__( def __name__(self): return self.name + @property + def __original_name__(self) -> str: + return self.name + @property def __default__(self) -> Unset: """Dummy default property for use as property.""" @@ -1184,6 +1192,10 @@ def __getitem__(self, attr: str) -> "ParameterReference": ) ) from exc + @property + def __original_name__(self) -> str: + return self._.parameter.__original_name__ + def __getattr__(self, attr: str) -> "ParameterReference": try: return self[attr] diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 39bc36aa..a4cbb440 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -164,10 +164,13 @@ def loop_over_lists(list_1: list[int]) -> list[int]: cwlVersion: 1.2 inputs: {} outputs: - expression: '[4 + list_1[0] + list_2[0], 4 + list_1[1] + list_2[1], 4 + list_1[2] + list_2[2], - 4 + list_1[3] + list_2[3]]' - source: - - list_1 - - list_2 + out: + type: float + label: out + expression: '[4 + list_1[0] + list_2[0], 4 + list_1[1] + list_2[1], 4 + list_1[2] + list_2[2], + 4 + list_1[3] + list_2[3]]' + source: + - list_1 + - list_2 steps: {} """) diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index fa6dd95a..18495029 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -1,3 +1,4 @@ +from __future__ import annotations import yaml from dataclasses import dataclass @@ -19,10 +20,9 @@ class Sides: SIDES: Sides = Sides(3, 6) @workflow() -def sum_sides(): +def sum_sides() -> float: return sum(left=SIDES.left, right=SIDES.right) -@pytest.mark.skip(reason="Need expression support") def test_fields_of_parameters_usable() -> None: result = sum_sides() wkflw = construct(result, simplify_ids=True) @@ -34,10 +34,20 @@ def test_fields_of_parameters_usable() -> None: inputs: SIDES: label: SIDES + default: + left: 3 + right: 6 type: record - items: - left: int - right: int + fields: + left: + default: 3 + label: left + 
type: int + right: + default: 6 + label: right + type: int + label: SIDES outputs: out: label: out @@ -49,26 +59,26 @@ def test_fields_of_parameters_usable() -> None: sum-1-1: in: left: - source: SIDES - valueFrom: $(self.left) + source: SIDES/left right: - source: SIDES - valueFrom: $(self.right) + source: SIDES/right out: - out run: sum """) +@dataclass +class MyDataclass: + left: int + right: "MyDataclass" + def test_can_get_field_reference_from_parameter(): - @dataclass - class MyDataclass: - left: int my_param = param("my_param", typ=MyDataclass) - result = sum(left=my_param.left, right=my_param) + result = sum(left=my_param.left, right=sum(left=my_param.right.left, right=my_param)) wkflw = construct(result, simplify_ids=True) param_references = {(str(p), p.__type__) for p in wkflw.find_parameters()} - assert param_references == {("my_param/left", int), ("my_param", MyDataclass)} + assert param_references == {("my_param/left", int), ("my_param", MyDataclass), ("my_param/right/left", int)} rendered = render(wkflw, allow_complex_types=True)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -80,7 +90,7 @@ class MyDataclass: outputs: out: label: out - outputSource: sum-1/out + outputSource: sum-2/out type: - int - float @@ -88,12 +98,21 @@ class MyDataclass: sum-1: in: left: - source: my_param/left + source: my_param/right/left right: source: my_param out: - out run: sum + sum-2: + in: + left: + source: my_param/left + right: + source: sum-1/out + out: + - out + run: sum """) def test_can_get_field_reference_iff_parent_type_has_field(): @@ -129,11 +148,11 @@ def get_left(my_dataclass: MyDataclass) -> int: assert str(wkflw.result) == "get_left-1" assert wkflw.result.__type__ == int -def test_can_get_field_references_from_typed_dict(): - class MyDict(TypedDict): - left: int - right: float +class MyDict(TypedDict): + left: int + right: float +def test_can_get_field_references_from_typed_dict(): @workflow() def test_dict(**my_dict: Unpack[MyDict]) -> MyDict: result: MyDict = {"left": mod10(num=my_dict["left"]), "right": pi()} @@ -145,6 +164,10 @@ def test_dict(**my_dict: Unpack[MyDict]) -> MyDict: assert str(wkflw.result["left"]) == "test_dict-1/left" assert wkflw.result["left"].__type__ == int +@dataclass +class MyListWrapper: + my_list: list[int] + def test_can_iterate(): @task() def test_task(alpha: int, beta: float, charlie: bool) -> int: @@ -194,10 +217,6 @@ def test_iterated() -> int: assert wkflw.result._.step.positional_args == {"alpha": True, "beta": True, "charlie": True} - @dataclass - class MyListWrapper: - my_list: list[int] - @task() def test_list_2() -> MyListWrapper: return MyListWrapper(my_list=[1, 2]) @@ -238,12 +257,9 @@ def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: - 1 - - 2 - 3 + items: array label: param - type: - type: - items: array - label: param - type: array + type: array outputs: out: label: out @@ -253,7 +269,8 @@ def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: mod10-1: in: num: - expression: $(param[0][0] + param[0][1] + param[1][0] + param[1][1] + mod10-2) + valueFrom: $(inputs.param[0][0] + inputs.param[0][1] + inputs.param[1][0] + inputs.param[1][1] + self) + source: mod10-2/out out: - out run: mod10 @@ -283,11 +300,11 @@ def test_dict(left: int, right: float) -> dict[str, float | int]: assert str(wkflw.result["left"]) == "test_dict-1/left" assert wkflw.result["left"].__type__ == int | float -def test_can_configure_field_separator(): - @dataclass - class IndexTest: - left: Fixed[list[int]] +@dataclass 
+class IndexTest: + left: Fixed[list[int]] +def test_can_configure_field_separator(): @task() def test_sep() -> IndexTest: return IndexTest(left=[3]) diff --git a/tests/test_nested.py b/tests/test_nested.py index e507194f..99ccf1ca 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -38,14 +38,15 @@ def test_can_supply_nested_raw(): max_list-1: in: lst: - expression: $(2*reverse_list-1) + source: reverse_list-1/out + valueFrom: $(2*self) out: - out run: max_list reverse_list-1: in: to_sort: - expression: $((1.0, 3.0, pi)) + valueFrom: $((1.0, 3.0, inputs.pi)) out: - out run: reverse_list From 173b2c72e589d5890aff0d0d6dc22daf58994a62 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 24 Aug 2024 18:16:44 +0100 Subject: [PATCH 063/108] fix: make sure we can identify an iterated ref --- src/dewret/core.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index 34629ea9..b76107ff 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -103,6 +103,7 @@ class Reference(Generic[U], Symbol): _type: type[U] | None = None __workflow__: WorkflowProtocol + __iterated__: bool = False def __init__(self, *args, typ: type[U] | None = None, **kwargs): self._type = typ @@ -199,7 +200,9 @@ def __init__(self, to_wrap: Reference[U]): def __iter__(self): count = -1 for _ in self.__wrapped__.__inner_iter__(): - yield self.__wrapped__.__make_reference__(field=(count := count + 1)) + ref = self.__wrapped__.__make_reference__(field=(count := count + 1)) + ref.__iterated__ = True + yield ref @dataclass From 0d8b3d1b31596ce7a10e4ac216423793f2c93b20 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 24 Aug 2024 22:09:59 +0100 Subject: [PATCH 064/108] docs: tidyup docstrings and structures --- src/dewret/core.py | 392 +++++++++++++++++++++++++++++++---- src/dewret/render.py | 48 +++-- src/dewret/renderers/cwl.py | 12 +- src/dewret/utils.py | 122 ++++++++--- src/dewret/workflow.py | 398 +++++++++++++++++++++++++++--------- tests/_lib/frender.py | 15 +- 6 files changed, 787 insertions(+), 200 deletions(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index b76107ff..8c7b642f 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -1,7 +1,10 @@ from dataclasses import dataclass +from abc import abstractmethod, abstractstaticmethod +import importlib import base64 +from attrs import define from functools import lru_cache -from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated, Never +from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated, Literal, Callable, cast, runtime_checkable from contextlib import contextmanager from contextvars import ContextVar from sympy import Expr, Symbol, Basic @@ -12,9 +15,17 @@ FirmType = RawType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] U = TypeVar("U") -T = TypeVar("T", bound=TypedDict) +T = TypeVar("T") def strip_annotations(parent_type: type) -> tuple[type, tuple]: + """Discovers and removes annotations from a parent type. + + Args: + parent_type: a type, possibly Annotated. + + Returns: a parent type that is not Annotation, along with any stripped metadata, if present. + + """ # Strip out any annotations. This should be auto-flattened, so in theory only one iteration could occur. 
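    # For instance (a sketch): strip_annotations(Annotated[int, "AtRender"])
    # yields (int, ("AtRender",)); nested Annotated layers keep accumulating
    # their metadata into the returned tuple.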
    metadata = []
    while get_origin(parent_type) is Annotated:
        parent_type, *parent_metadata = get_args(parent_type)
        metadata += list(parent_metadata)
    return parent_type, tuple(metadata)

+RenderConfiguration = dict[str, RawType]
+
 class WorkflowProtocol(Protocol):
+    """Expected structure for a workflow.
+
+    We do not expect multiple workflow implementations, but this protocol defines the
+    interface that the core classes expect.
+    """
     def remap(self, name: str) -> str:
+        """Perform any name-changing for steps, etc. in the workflow.
+
+        This enables, for example, simplifying all the IDs to an integer sequence.
+
+        Returns: remapped name.
+        """
+        ...
+
+    def set_result(self, result: Basic | list[Basic] | tuple[Basic]) -> None:
+        """Set the step that should produce a result for the overall workflow."""
+        ...
+
+    def simplify_ids(self, infix: list[str] | None = None) -> None:
+        """Drop the non-human-readable IDs if possible, in favour of integer sequences.
+
+        Args:
+            infix: any inherited intermediary identifiers, to allow nesting, or None.
+        """
+        ...
+
+class BaseRenderModule(Protocol):
+    """Common routines for all renderer modules."""
+    @staticmethod
+    def default_config() -> dict[str, RawType]:
+        """Retrieve default settings.
+
+        These will not change during execution, but can be overridden by `dewret.core.set_render_configuration`.
+
+        Returns: a static, serializable dict.
+        """
+        ...
+
+@runtime_checkable
+class RawRenderModule(BaseRenderModule, Protocol):
+    """Render module that returns raw text."""
+    @abstractmethod
+    def render_raw(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) -> dict[str, str]:
+        """Turn a workflow into flat strings.
+
+        Returns: one or more subworkflows, always including a `__root__` key for the outermost workflow.
+        """
+        ...
+
+@runtime_checkable
+class StructuredRenderModule(BaseRenderModule, Protocol):
+    """Render module that returns JSON/YAML-serializable structures."""
+    @abstractmethod
+    def render(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) -> dict[str, dict[str, RawType]]:
+        """Turn a workflow into a serializable structure.
+
+        Returns: one or more subworkflows, always including a `__root__` key for the outermost workflow.
+        """
+        ...
+
+class RenderCall(Protocol):
+    """Callable that will render out workflow(s)."""
+    def __call__(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) -> dict[str, str] | dict[str, RawType]:
+        """Take a workflow and turn it into a set of serializable (sub)workflows.
+
+        Args:
+            workflow: root workflow.
+            kwargs: configuration for the renderer.
+
+        Returns: a mapping of keys to serialized workflows, containing at least `__root__`.
+        """
        ...

 class UnevaluatableError(Exception):
+    """Signposts that a user has tried to treat a reference as the real (runtime) value.
+
+    For example, by comparing it to a concrete integer or other value.
+    """
     ...

-class ConstructConfiguration(TypedDict):
-    flatten_all_nested: NotRequired[bool]
-    allow_positional_args: NotRequired[bool]
-    allow_plain_dict_fields: NotRequired[bool]
-    field_separator: NotRequired[str]
-    field_index_types: NotRequired[str]
+@define
+class ConstructConfiguration:
+    """Basic configuration of the construction process.
+
+    Holds configuration that may be relevant to `construct(...)` calls or, realistically,
+    anything prior to rendering. It should hold generic configuration that is renderer-independent.
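+
+    For example (a sketch; the task is illustrative):
+
+        with set_configuration(flatten_all_nested=True):
+            workflow = construct(my_task(), simplify_ids=True)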
+ """ + flatten_all_nested: bool = False + allow_positional_args: bool = False + allow_plain_dict_fields: bool = False + field_separator: str = "/" + field_index_types: str = "int" + +class ConstructConfigurationTypedDict(TypedDict): + """Basic configuration of the construction process. + + Holds configuration that may be relevant to `construst(...)` calls or, realistically, + anything prior to rendering. It should hold generic configuration that is renderer-independent. + + **THIS MUST BE KEPT IDENTICAL TO ConstructConfiguration.** + """ + flatten_all_nested: bool + allow_positional_args: bool + allow_plain_dict_fields: bool + field_separator: str + field_index_types: str + +@define +class GlobalConfiguration: + """Overall configuration structure. -class GlobalConfiguration(TypedDict): + Having a single configuration dict allows us to manage only one ContextVar. + """ construct: ConstructConfiguration render: dict CONFIGURATION: ContextVar[GlobalConfiguration] = ContextVar("configuration") @contextmanager -def set_configuration(**kwargs: Unpack[ConstructConfiguration]) -> Iterator[ContextVar[GlobalConfiguration]]: - with _set_configuration("construct", kwargs) as var: - yield var +def set_configuration(**kwargs: Unpack[ConstructConfigurationTypedDict]) -> Iterator[ContextVar[GlobalConfiguration]]: + """Sets the construct-time configuration. + + This is a context manager, so that a setting can be temporarily overridden and automatically restored. + """ + with _set_configuration() as CONFIGURATION: + for key, value in kwargs.items(): + setattr(CONFIGURATION.get().construct, key, value) + yield CONFIGURATION @contextmanager def set_render_configuration(kwargs) -> Iterator[ContextVar[GlobalConfiguration]]: - with _set_configuration("render", kwargs) as var: - yield var + """Sets the render-time configuration. + + This is a context manager, so that a setting can be temporarily overridden and automatically restored. + + Returns: the yielded global configuration ContextVar. + """ + with _set_configuration() as CONFIGURATION: + CONFIGURATION.get().render.update(**kwargs) + yield CONFIGURATION + +@contextmanager +def _set_configuration() -> Iterator[ContextVar[GlobalConfiguration]]: + """Prepares and tidied up the configuration for applying settings. + + This is a context manager, so that a setting can be temporarily overridden and automatically restored. + """ + try: + previous = CONFIGURATION.get() + except LookupError: + previous = GlobalConfiguration(construct=ConstructConfiguration(), render=default_renderer_config()) + CONFIGURATION.set(previous) + previous = copy.deepcopy(previous) + + try: + yield CONFIGURATION + finally: + CONFIGURATION.set(previous) + @lru_cache -def default_renderer_config() -> dict: +def default_renderer_config() -> RenderConfiguration: + """Gets the default renderer configuration. + + This may be called frequently, but is cached so note that any changes to the + wrapped config function will _not_ be reflected during the process. + + It is a light wrapper for `default_config` in the supplier renderer module. + + Returns: the default configuration dict for the chosen renderer. + """ try: - from __renderer_mod__ import default_config + # We have to use a cast as we do not know if the renderer module is valid. 
+ render_module = cast(BaseRenderModule, importlib.import_module("__renderer_mod__")) + default_config: Callable[[], RenderConfiguration] = render_module.default_config except ImportError: return {} return default_config() @lru_cache def default_construct_config() -> ConstructConfiguration: + """Gets the default construct-time configuration. + + This is the primary mechanism for configuring dewret internals, so these defaults + should be carefully chosen and, if they change, that likely has an impact on backwards compatibility + from a SemVer perspective. + + Returns: configuration dictionary with default construct values. + """ return ConstructConfiguration( flatten_all_nested=False, allow_positional_args=False, @@ -71,61 +234,115 @@ def default_construct_config() -> ConstructConfiguration: field_index_types="int", ) -@contextmanager -def _set_configuration(config_group: str, kwargs: U) -> Iterator[ContextVar[GlobalConfiguration]]: - try: - previous = copy.deepcopy(GlobalConfiguration(**CONFIGURATION.get())) - except LookupError: - previous = {"construct": default_construct_config(), "render": default_renderer_config()} - CONFIGURATION.set(previous) +def get_configuration(key: str) -> RawType: + """Retrieve the configuration or (silently) return the default. - try: - CONFIGURATION.get()[config_group].update(kwargs) + Helps avoid a proliferation of `set_configuration` calls by not erroring if it has not been called. + However, the cost is that the user may accidentally put configuration-affected logic outside a + set_configuration call and be surprised that the behaviour is inexplicibly not as expected. - yield CONFIGURATION - finally: - CONFIGURATION.set(previous) + Args: + key: configuration key to retrieve. -def get_configuration(key: str) -> RawType: + Returns: (preferably) a JSON/YAML-serializable construct. + """ try: - return CONFIGURATION.get()["construct"].get(key) # type: ignore + return getattr(CONFIGURATION.get().construct, key) # type: ignore except LookupError: - return default_construct_config().get(key) + return getattr(ConstructConfiguration(), key) def get_render_configuration(key: str) -> RawType: + """Retrieve configuration for the active renderer. + + Finds the current user-set configuration, defaulting back to the chosen renderer module's declared + defaults. + + Args: + key: configuration key to retrieve. + + Returns: (preferably) a JSON/YAML-serializable construct. + """ try: - return CONFIGURATION.get()["render"].get(key) # type: ignore + return CONFIGURATION.get().render.get(key) # type: ignore except LookupError: return default_renderer_config().get(key) -class Reference(Generic[U], Symbol): +class WorkflowComponent: + """Base class for anything directly tied to an individual `Workflow`. + + Attributes: + __workflow__: the `Workflow` that this is tied to. + """ + + __workflow_real__: WorkflowProtocol + + def __init__(self, *args, workflow: WorkflowProtocol, **kwargs): + """Tie to a `Workflow`. + + All subclasses must call this. + + Args: + workflow: the `Workflow` to tie to. + *args: remainder of arguments for other initializers. + **kwargs: remainder of arguments for other initializers. 
+ """ + self.__workflow__ = workflow + super().__init__(*args, **kwargs) + + @property + def __workflow__(self) -> WorkflowProtocol: + """Workflow to which this reference applies.""" + return self.__workflow_real__ + + @__workflow__.setter + def __workflow__(self, workflow: WorkflowProtocol) -> None: + """Workflow to which this reference applies.""" + self.__workflow_real__ = workflow + + +class Reference(Generic[U], Symbol, WorkflowComponent): """Superclass for all symbolic references to values.""" _type: type[U] | None = None - __workflow__: WorkflowProtocol __iterated__: bool = False def __init__(self, *args, typ: type[U] | None = None, **kwargs): + """Extract any specified type. + + Args: + typ: type to override any inference, or None. + *args: any other arguments for other initializers (e.g. mixins). + **kwargs: any other arguments for other initializers (e.g. mixins). + """ self._type = typ - super().__init__() + super().__init__(*args, **kwargs) @property def name(self): + """Printable name of the reference.""" return self.__name__ def __new__(cls, *args, **kwargs): + """As all references are sympy Expressions, the real returned object must be made with Expr.""" instance = Expr.__new__(cls) instance._assumptions0 = {} return instance @property def __root_name__(self) -> str: + """Root name on which to suffix/prefix any derived names (with fields, etc.). + + For example, the base name of `add_thing-12345[3]` should be `add_thing`. + + Returns: basic name as a string. + """ raise NotImplementedError( "Reference must have a '__root_name__' property or override '__name__'" ) @property - def __type__(self): + def __type__(self) -> type: + """Type of the reference target, if known.""" if self._type is not None: return self._type raise NotImplementedError() @@ -134,6 +351,11 @@ def _raise_unevaluatable_error(self): raise UnevaluatableError(f"This reference, {self.__name__}, cannot be evaluated during construction.") def __eq__(self, other) -> bool: + """Test equality at construct-time, if sensible. + + Raises: + UnevaluatableError: if it seems the user is confusing this with a runtime check. + """ if isinstance(other, list) or other is None: return False if not isinstance(other, Reference) and not isinstance(other, Basic): @@ -141,46 +363,107 @@ def __eq__(self, other) -> bool: return super().__eq__(other) def __float__(self) -> bool: + """Catch accidental float casts. + + Raises: + UnevaluatableError: unconditionally, as it seems the user is confusing this with a runtime check. + """ self._raise_unevaluatable_error() return False def __int__(self) -> bool: + """Catch accidental int casts. + + Raises: + UnevaluatableError: unconditionally, as it seems the user is confusing this with a runtime check. + """ self._raise_unevaluatable_error() return False def __bool__(self) -> bool: + """Catch accidental bool casts. + + Note that this means ambiguous checks such as `if ref: ...` will error, and should be `if ref is None: ...` + or similar. + + Raises: + UnevaluatableError: unconditionally, as it seems the user is confusing this with a runtime check. + """ self._raise_unevaluatable_error() return False @property def __name__(self) -> str: - """Referral name for this reference.""" + """Referral name for this reference. + + Returns: an internal name to refer to the reference target. 
+ """ workflow = self.__workflow__ name = self.__root_name__ return workflow.remap(name) if workflow is not None else name def __str__(self) -> str: - """Global description of the reference.""" + """Global description of the reference. + + Returns the _internal_ name. + """ return self.__name__ class IterableMixin(Reference[U]): + """Functionality for iterating over references to give new references.""" __fixed_len__: int | None = None def __init__(self, typ: type[U] | None=None, **kwargs): + """Extract length, if available from type. + + Captures types of the form (e.g.) `tuple[int, float]` and records the length + as being (e.g.) 2. + """ super().__init__(typ=typ, **kwargs) base = strip_annotations(self.__type__)[0] if get_origin(base) == tuple and (args := get_args(base)): # In the special case of an explicitly-typed tuple, we can state a length. self.__fixed_len__ = len(args) - def __len__(self): + def __len__(self) -> int: + """Length of this iterable, if available. + + The two cases that this is likely to be available are if the reference target + has been type-hinted as a `tuple` with a specific, fixed number of type arguments, + or if the target has been annotated with `Fixed(...)` indicating that the length + of the default value can be hard-coded into the output, and therefore that it can + be used for graph-building logic. The most useful application of this is likely to + be in for-loops and generators, as we can create variable references for each iteration + but can nonetheless execute the loop as we know how many iterations occur. + + Returns: length of the iterable, if available. + """ + if self.__fixed_len__ is None: + raise TypeError( + "This iterable reference does not have a known fixed length, " + "consider using `Fixed[...]` with a default, or typehint using `tuple[int, float]` (or similar) " + "to tell dewret how long it should be." + ) return self.__fixed_len__ - def __iter__(self): + def __iter__(self) -> Generator[Reference[U], None, None]: + """Execute the iteration. + + Note that this does _not_ return the yielded values of `__inner_iter__`, so that + it can be used with impunity to do actual iteration, and we will _always_ return + references here. + + Returns: a generator that will give a new reference for every iteration. + """ for count, _ in enumerate(self.__inner_iter__()): yield super().__getitem__(count) def __inner_iter__(self) -> Generator[Any, None, None]: + """Overrideable iterator for looping over the wrapped object. + + Returns: a generator that will yield ints if this reference is known to be fixed length, or will go + forever yielding None otherwise. + """ if self.__fixed_len__ is not None: for i in range(self.__fixed_len__): yield i @@ -189,18 +472,43 @@ def __inner_iter__(self) -> Generator[Any, None, None]: yield None def __getitem__(self, attr: str | int) -> Reference[U]: + """Get a reference to an individual item/field. + + Args: + attr: index or fieldname. + + Returns: a reference to the same target as this reference but a level deeper. + """ return super().__getitem__(attr) class IteratedGenerator(Generic[U]): + """Sentinel value for capturing that an iteration has occured without performing it. + + Allows us to lazily evaluate a loop, for instance, in the renderer. This may be relevant + if the renderer wishes to postpone iteration to runtime, and simply record it is required, + rather than evaluating the iterator. + """ __wrapped__: Reference[U] def __init__(self, to_wrap: Reference[U]): + """Capture wrapped reference. 
+ + Args: + to_wrap: reference to wrap. + """ self.__wrapped__ = to_wrap - def __iter__(self): + def __iter__(self) -> Generator[Reference[U], None, None]: + """Loop through the wrapped reference. + + This will tag the references that are returned, so that the renderer can see this has + happened. + + Returns: a generator looping over the wrapped reference with a counter as the "field". + """ count = -1 for _ in self.__wrapped__.__inner_iter__(): - ref = self.__wrapped__.__make_reference__(field=(count := count + 1)) + ref = self.__wrapped__.__make_reference__(workflow=self.__wrapped__.__workflow__, field=(count := count + 1)) ref.__iterated__ = True yield ref diff --git a/src/dewret/render.py b/src/dewret/render.py index e484eb5a..7127fcce 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -1,45 +1,54 @@ import sys -import importlib from pathlib import Path from functools import partial -from typing import Protocol, TypeVar, Any, Unpack, TypedDict, Callable +from typing import TypeVar, Callable, cast import yaml from .workflow import Workflow, NestedStep -from .core import RawType -from .workflow import Workflow +from .core import RawType, RenderCall, BaseRenderModule, RawRenderModule, StructuredRenderModule, RenderConfiguration from .utils import load_module_or_package -RenderConfiguration = TypeVar("RenderConfiguration", bound=dict[str, Any]) +T = TypeVar("T") -class RawRenderModule(Protocol): - def render_raw(self, workflow: Workflow, **kwargs: RenderConfiguration) -> dict[str, str]: - ... +def structured_to_raw(rendered: RawType, pretty: bool=False) -> str: + """Serialize a serializable structure to a string. -class StructuredRenderModule(Protocol): - def render(self, workflow: Workflow, **kwargs: RenderConfiguration) -> dict[str, dict[str, RawType]]: - ... + Args: + rendered: a possibly-nested, static basic Python structure. + pretty: whether to attempt YAML dumping with an indent of 2. -def structured_to_raw(rendered: RawType, pretty: bool=False) -> str: + Returns: YAML/stringified version of the structure. + """ if pretty: output = yaml.safe_dump(rendered, indent=2) else: output = str(rendered) return output -def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule, pretty: bool=False): - render_module: RawRenderModule | StructuredRenderModule +def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule, pretty: bool=False) -> RenderCall: + """Create a ready-made callable to render the workflow that is appropriate for the renderer module. + + Args: + renderer: a module or path to a module. + pretty: whether the renderer should attempt to YAML-format the output (if relevant). + + Returns: a callable with a consistent interface, regardless of the renderer type. + """ + render_module: BaseRenderModule if isinstance(renderer, Path): if (render_dir := str(renderer.parent)) not in sys.path: sys.path.append(render_dir) # Attempt to load renderer as package, falling back to a single module otherwise. # This enables relative imports in renderers and therefore the ability to modularize. 
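        # A sketch of the intended use (the renderer path is hypothetical):
        #     render = get_render_method(Path("renderers/my_renderer.py"))
        #     rendered = render(workflow)  # always includes a "__root__" entry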
- render_module = load_module_or_package("__renderer__", renderer) - sys.modules["__renderer_mod__"] = render_module + module = load_module_or_package("__renderer__", renderer) + sys.modules["__renderer_mod__"] = module + render_module = cast(BaseRenderModule, module) else: render_module = renderer - if hasattr(render_module, "render_raw"): + if not isinstance(render_module, StructuredRenderModule): + if not isinstance(render_module, RawRenderModule): + raise NotImplementedError("This render module neither seems to be a structured nor a raw render module.") return render_module.render_raw def _render(workflow: Workflow, render_module: StructuredRenderModule, pretty=False, **kwargs: RenderConfiguration) -> dict[str, str]: @@ -49,9 +58,8 @@ def _render(workflow: Workflow, render_module: StructuredRenderModule, pretty=Fa for key, value in rendered.items() } - return partial(_render, render_module=render_module, pretty=pretty) + return cast(RenderCall, partial(_render, render_module=render_module, pretty=pretty)) -T = TypeVar("T") def base_render( workflow: Workflow, build_cb: Callable[[Workflow], T] ) -> dict[str, T]: @@ -59,7 +67,7 @@ def base_render( Args: workflow: workflow to evaluate result. - **kwargs: additional configuration arguments - these should match CWLRendererConfiguration. + build_cb: a callback to call for each workflow found. Returns: Reduced form as a native Python dict structure for diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 1b4d08e5..233f3a67 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -41,7 +41,7 @@ Unset, expr_to_references ) -from dewret.utils import flatten, DataclassProtocol, firm_to_raw, flatten_if_set +from dewret.utils import crawl_raw, DataclassProtocol, firm_to_raw, flatten_if_set from dewret.render import base_render from dewret.core import get_render_configuration, set_render_configuration @@ -213,7 +213,7 @@ def render(self) -> dict[str, RawType]: Reduced form as a native Python dict structure for serialization. """ - representation = {} + representation: dict[str, RawType] = {} if self.source is not None: representation["source"] = self.source if self.value_from is not None: @@ -285,7 +285,7 @@ def render(self) -> dict[str, RawType]: ) for key, ref in self.in_.items() }, - "out": flatten(self.out), + "out": crawl_raw(self.out), } @@ -536,7 +536,7 @@ def render(self) -> dict[str, RawType]: # Would rather not cast, but CommandInputSchema is dict[RawType] # by construction, where type is seen as a TypedDict subclass. item = firm_to_raw(cast(FirmType, input.type)) - if not isinstance(input.default, Unset): + if isinstance(item, dict) and not isinstance(input.default, Unset): item["default"] = firm_to_raw(input.default) result[key] = item return result @@ -605,9 +605,9 @@ def render(self) -> dict[str, RawType] | list[RawType]: serialization. 
""" return [ - flatten(output) for output in self.outputs + crawl_raw(output) for output in self.outputs ] if isinstance(self.outputs, list) else { - key: flatten(output) + key: crawl_raw(output) for key, output in self.outputs.items() } diff --git a/src/dewret/utils.py b/src/dewret/utils.py index 1cc58b93..e4c7274e 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -23,13 +23,13 @@ import importlib import importlib.util from types import FrameType, TracebackType, UnionType, ModuleType -from typing import Any, cast, Union, Protocol, ClassVar, Callable, Iterable, get_args, get_origin, Annotated +from typing import Any, cast, Protocol, ClassVar, Callable, Iterable, get_args, Hashable from pathlib import Path from collections.abc import Sequence, Mapping from dataclasses import asdict, is_dataclass from sympy import Basic, Integer, Float, Rational -from .core import Reference, BasicType, RawType, FirmType, Raw +from .core import Reference, RawType, FirmType, Raw class Unset: @@ -63,23 +63,41 @@ def make_traceback(skip: int = 2) -> TracebackType | None: return tb def load_module_or_package(target_name: str, path: Path) -> ModuleType: - # Try to import the workflow as a package, if possible, to allow relative imports. - try: - spec = importlib.util.spec_from_file_location(target_name, str(path.parent / "__init__.py")) + """Convenience loader for modules. + + If an `__init__.py` is found in the same location as the target, it will try to load the renderer module + as if it is contained in a package and, if it cannot, will fall back to loading the single file. + + Args: + target_name: module name that should appear in `sys.modules`. + path: location of the module. + + Returns: the loaded module. + """ + module: None | ModuleType = None + exception: None | Exception = None + package_init = path.parent / "__init__.py" + # Try to import the module as a package, if possible, to allow relative imports. + if package_init.exists(): + try: + spec = importlib.util.spec_from_file_location(target_name, str(package_init)) + if spec is None or spec.loader is None: + raise ImportError(f"Could not open {path.parent} package") + module = importlib.util.module_from_spec(spec) + sys.modules[target_name] = module + spec.loader.exec_module(module) + module = importlib.import_module(f"{target_name}.{path.stem}", target_name) + except ImportError as exc: + exception = exc + + if module is None: + spec = importlib.util.spec_from_file_location(target_name, str(path)) if spec is None or spec.loader is None: - raise ImportError(f"Could not open {path.parent} package") + raise ImportError(f"Could not open {path} module") from exception module = importlib.util.module_from_spec(spec) - sys.modules[target_name] = module spec.loader.exec_module(module) - workflow = importlib.import_module(f"{target_name}.{path.stem}", target_name) - except ImportError as exc: - spec = importlib.util.spec_from_file_location(target_name, str(path)) - if spec is None or spec.loader is None: - raise ImportError(f"Could not open {path} module") from exc - workflow = importlib.util.module_from_spec(spec) - spec.loader.exec_module(workflow) - return workflow + return module def flatten_if_set(value: Any) -> RawType | Unset: """Takes a Raw-like structure and makes it RawType or Unset. 
@@ -92,40 +110,60 @@ def flatten_if_set(value: Any) -> RawType | Unset:
     """
     if isinstance(value, Unset):
         return value
-    return flatten(value)
+    return crawl_raw(value)

-def crawl_raw(value: Any, action: Callable[[Any], Any]) -> RawType:
+def crawl_raw(value: Any, action: Callable[[Any], Any] | None = None) -> RawType:
     """Takes a Raw-like structure and makes it RawType.

     Particularly useful for squashing any TypedDicts.

     Args:
         value: value to squash
-    """
+        action: a callback to apply to each found entry, or None.
+
+    Returns: a structure that is guaranteed to be raw.

-    value = action(value)
+    Raises: RuntimeError if it cannot convert the value to raw.
+    """
+    if action is not None:
+        value = action(value)
     if value is None:
         return value
     if isinstance(value, str) or isinstance(value, bytes):
         return value
     if isinstance(value, Mapping):
-        return {key: flatten(item) for key, item in value.items()}
+        return {key: crawl_raw(item, action) for key, item in value.items()}
     if is_dataclass(value) and not isinstance(value, type):
         return crawl_raw(asdict(value), action)
     if isinstance(value, Sequence):
-        return [flatten(item) for item in value]
+        return [crawl_raw(item, action) for item in value]
     if (raw := ensure_raw(value)) is not None:
         return raw
     raise RuntimeError(f"Could not flatten: {value}")

 def firm_to_raw(value: FirmType) -> RawType:
-    return crawl_raw(value, lambda entry: list(entry) if isinstance(entry, tuple) else entry)
+    """Convenience wrapper for firm structures.
+
+    Turns structures that would be raw, except for tuples, into raw structures
+    by mapping any tuples to lists.

-def flatten(value: Any) -> RawType:
-    return crawl_raw(value, lambda entry: entry)
+    Args:
+        value: a firm structure (contains raw/tuple values).
+
+    Returns: a raw structure.
+    """
+    return crawl_raw(value, lambda entry: list(entry) if isinstance(entry, tuple) else entry)

 def is_expr(value: Any, permitted_references: type=Reference) -> bool:
+    """Confirms whether a structure has only raw or expression types.
+
+    Args:
+        value: a crawlable structure.
+        permitted_references: a class representing the allowed types of References.
+
+    Returns: True if valid, otherwise False.
+    """
     return is_raw(value, lambda x: isinstance(x, Basic) or isinstance(x, tuple) or isinstance(x, permitted_references) or isinstance(x, Raw))

 def is_raw_type(typ: type) -> bool:
@@ -136,7 +174,17 @@ def is_raw_type(typ: type) -> bool:

 def is_firm(value: Any, check: Callable[[Any], bool] | None = None) -> bool:
-    return is_raw(value, lambda x: isinstance(x, tuple))
+    """Confirms whether a value is firm.
+
+    That is, whether its contents are raw or tuples.
+
+    Args:
+        value: value to check.
+        check: any additional check to apply.
+
+    Returns: True if is firm, else False.
+    """
+    return is_raw(value, lambda x: isinstance(x, tuple) and (check is None or check(x)))

 def is_raw(value: Any, check: Callable[[Any], bool] | None = None) -> bool:
     """Check if a variable counts as "raw".
@@ -195,13 +243,23 @@ def hasher(construct: FirmType) -> str:
         Hash string that should be unique to the construct. The limits
         of this uniqueness have not yet been explicitly calculated.
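+
+    For example (an illustrative note): `hasher(("add", ("left", 1)))` returns a
+    stable hex digest, so the same structure always produces the same string;
+    mappings are sorted before hashing, so key order does not affect the result.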
""" - if isinstance(construct, Sequence) and not isinstance(construct, bytes | str): - if isinstance(construct, Mapping): - construct = list([k, hasher(v)] for k, v in sorted(construct.items())) - else: - # Cast to workaround recursive type - construct = cast(FirmType, list(construct)) - construct_as_string = json.dumps(construct) + def _make_hashable(construct: FirmType) -> Hashable: + hashed_construct: tuple[Hashable, ...] + if isinstance(construct, Sequence) and not isinstance(construct, bytes | str): + if isinstance(construct, Mapping): + hashed_construct = tuple((k, hasher(v)) for k, v in sorted(construct.items())) + else: + # Cast to workaround recursive type + hashed_construct = tuple(_make_hashable(v) for v in construct) + return hashed_construct + if not isinstance(construct, Hashable): + raise TypeError("Could not hash arguments") + return construct + if isinstance(construct, Hashable): + hashed_construct = construct + else: + hashed_construct = _make_hashable(construct) + construct_as_string = json.dumps(hashed_construct) hsh = hashlib.md5() hsh.update(construct_as_string.encode()) return hsh.hexdigest() diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 7f368c45..83d65624 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -20,19 +20,18 @@ from __future__ import annotations import inspect from collections.abc import Mapping, MutableMapping, Callable -import base64 from attrs import has as attr_has, resolve_types, fields as attrs_fields from dataclasses import is_dataclass, fields as dataclass_fields from collections import Counter, OrderedDict -from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Iterable, get_origin, get_args, Generator, Sized, Sequence, get_type_hints +from typing import Protocol, Any, TypeVar, Generic, cast, Literal, Iterable, get_origin, get_args, Generator, Sized, Sequence, get_type_hints, TYPE_CHECKING from uuid import uuid4 -from sympy import Symbol, Expr, Basic, Tuple, Dict, nan - import logging +from sympy import Symbol, Expr, Basic + logger = logging.getLogger(__name__) -from .core import RawType, IterableMixin, Reference, get_configuration, Raw, IteratedGenerator, strip_annotations +from .core import IterableMixin, Reference, get_configuration, Raw, IteratedGenerator, strip_annotations, WorkflowProtocol, WorkflowComponent from .utils import hasher, is_raw, make_traceback, is_raw_type, is_expr, Unset T = TypeVar("T") @@ -146,6 +145,7 @@ def __init__( default: value to infer type, etc. from. tethered: a workflow or step that demands this parameter; None if not yet present, False if not desired. autoname: whether we should customize this name for uniqueness (it is not user-set). + typ: a type to override what would be automatically inferred. """ self.__original_name__ = name @@ -163,13 +163,25 @@ def __init__( if tethered and isinstance(tethered, BaseStep): self.register_caller(tethered) - def is_loopable(self, typ: type): + @staticmethod + def is_loopable(typ: type): + """Checks if this type can be looped over. + + In particular, checks if this is an iterable that is NOT a str or bytes, possibly disguised + behind an Annotated. + + Args: + typ: type to check. + + Returns: True if loopable, otherwise False. 
+ """ base = strip_annotations(typ)[0] base = get_origin(base) or base return inspect.isclass(base) and issubclass(base, Iterable) and not issubclass(base, str | bytes) @property def __type__(self): + """Type associated with this parameter.""" if self.__fixed_type__ is not UNSET: return self.__fixed_type__ @@ -185,9 +197,20 @@ def __type__(self): return raw_type def __eq__(self, other): + """Comparing two parameters. + + Currently, this uses the hashes. + + TODO: confirm this is an iff. + """ return hash(self) == hash(other) - def __new__(cls, *args, **kwargs): + def __new__(cls, *args, **kwargs) -> "Parameter": + """Creates a Parameter. + + Required, as Parameters are an instance of sympy Expression, so + we must instantiate via it. + """ instance = Expr.__new__(cls) instance._assumptions0 = {} return instance @@ -201,6 +224,17 @@ def __hash__(self) -> int: return hash(self.__name__) def make_reference(self, **kwargs) -> "ParameterReference": + """Creates a new reference for the parameter. + + The kwargs will be passed to the constructor, but the + + Args: + typ: type of the new reference's target. + **kwargs: arguments to pass to the constructor. + + Returns: checks if `typ` is loopable and gives an IterableParameterReference if so, + otherwise a normal ParameterReference. + """ kwargs["parameter"] = self kwargs.setdefault("typ", self.__type__) typ = kwargs["typ"] @@ -238,6 +272,13 @@ def register_caller(self, caller: BaseStep) -> None: self.__callers__.append(caller) def __getattr__(self, attr: str) -> Reference[T]: + """Retrieve a reference to a field within this Parameter. + + Arg: + attr: a field to find. + + Returns: a new reference with the `attr` appended to the field tuple. + """ return getattr(self.make_reference(workflow=None), attr) @@ -333,7 +374,7 @@ class Workflow: _steps: list["BaseStep"] tasks: MutableMapping[str, "Task"] - result: StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]] | None + result: StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any], ...] | None _remapping: dict[str, str] | None _name: str | None @@ -347,6 +388,10 @@ def __init__(self, name: str | None = None) -> None: @property def steps(self) -> set[BaseStep]: + """Get deduplicated steps. + + Returns: steps for looping over without duplicates. + """ return set(self._steps) def __str__(self) -> str: @@ -356,6 +401,12 @@ def __str__(self) -> str: return self.name def __repr__(self) -> str: + """Representation of the workflow. + + This will be the `id` if the workflow is anonymous. + + Returns: string identifier for this workflow. + """ if self._name: return self.name return self.id @@ -366,6 +417,7 @@ def __hash__(self) -> int: @property def id(self) -> str: + """Consistent ID based off the step IDs.""" comp_tup = tuple(sorted(s.id for s in self.steps)) return f"workflow-{hasher(comp_tup)}" @@ -388,7 +440,13 @@ def __eq__(self, other: object) -> bool: @property def has_result(self) -> bool: - return not(self.result is None or self.result is []) + """Confirms whether this workflow has a non-empty result. + + Either None or an empty list/tuple are considered empty for this purpose. + + Returns: True if the workflow has a result, False otherwise. 
+ """ + return not(self.result is None or (isinstance(self.result, list | tuple) and not self.result)) @property def name(self) -> str: @@ -465,7 +523,7 @@ def assimilate(cls, *workflow_args: Workflow) -> "Workflow": #for step in list(left_steps.values()) + list(right_steps.values()): step.set_workflow(base) - indexed_steps = {} + indexed_steps: dict[str, BaseStep] = {} for step_id, step in all_steps: indexed_steps.setdefault(step_id, step) if step != indexed_steps[step_id]: @@ -474,7 +532,7 @@ def assimilate(cls, *workflow_args: Workflow) -> "Workflow": ) all_tasks = sum((list(w.tasks.items()) for w in workflows), []) - indexed_tasks = {} + indexed_tasks: dict[str, Task] = {} for task_id, task in all_tasks: indexed_tasks.setdefault(task_id, task) if task != indexed_tasks[task_id]: @@ -518,7 +576,7 @@ def simplify_ids(self, infix: list[str] | None = None) -> None: counter = Counter[Task | Workflow]() self._remapping = {} infix_str = ("-".join(infix) + "-") if infix else "" - for key, step in self.indexed_steps.items(): + for step in self.indexed_steps.values(): counter[step.task] += 1 self._remapping[step.id] = f"{step.task}-{infix_str}{counter[step.task]}" if isinstance(step, NestedStep): @@ -570,7 +628,11 @@ def add_nested_step( Args: name: name of the subworkflow. subworkflow: the subworkflow itself. + return_type: a forced type for the return, or None. kwargs: any key-value arguments to pass in the call. + positional_args: a mapping of arguments to bools, True if the argument is positional or otherwise False. + + Returns: a reference to the step that calls out to a new workflow. """ step = NestedStep(self, name, subworkflow, kwargs) if positional_args is not None: @@ -579,7 +641,7 @@ def add_nested_step( return_type = return_type or step.return_type if return_type is inspect._empty: raise TypeError("All tasks should have a type annotation.") - return step.make_reference(typ=return_type) + return step.make_reference(workflow=self, typ=return_type) def add_step( self, @@ -599,6 +661,7 @@ def add_step( kwargs: any key-value arguments to pass in the call. raw_as_parameter: whether to turn any discovered raw arguments into workflow parameters. is_factory: whether this step should be a Factory. + positional_args: a mapping of arguments to bools, True if the argument is positional or otherwise False. """ task = self.register_task(fn) step_maker = FactoryCall if is_factory else Step @@ -613,33 +676,38 @@ def add_step( and not inspect.isclass(fn) ): raise TypeError("All tasks should have a type annotation.") - return step.make_reference(typ=return_type) + return step.make_reference(workflow=self, typ=return_type) @staticmethod def from_result( - result: StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]], simplify_ids: bool = False, nested: bool = True + result: StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any], ...], simplify_ids: bool = False, nested: bool = True ) -> Workflow: """Create from a desired result. Starts from a result, and builds a workflow to output it. """ - result, refs = expr_to_references(result) - if not refs: + expr, _refs = expr_to_references(result) + if not _refs: raise RuntimeError( "Attempted to build a workflow from a return-value/result/expression with no references." 
) - refs = list(refs) + refs = list(_refs) + if isinstance(refs[0], Parameter): + raise RuntimeError("Attempted to build a workflow from an input parameter.") workflow = refs[0].__workflow__ # Ensure that we have exactly one workflow, even if multiple results. for entry in refs[1:]: if entry.__workflow__ != workflow: raise RuntimeError("If multiple results, they must share a single workflow") - workflow.set_result(result) + workflow.set_result(expr) if simplify_ids: workflow.simplify_ids() + if not isinstance(workflow, Workflow): + # This may be better as a cast, since this seems an artificial case. + raise TypeError("Can only return a workflow of the same type") return workflow - def set_result(self, result: StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]) -> None: + def set_result(self, result: Basic | list[Basic] | tuple[Basic]) -> None: """Choose the result step. Sets a step as being the result for the entire workflow. @@ -658,7 +726,8 @@ def set_result(self, result: StepReference[Any] | list[StepReference[Any]] | tup self.result = result @property - def result_type(self): + def result_type(self) -> type: + """Overall return type of this workflow.""" if self.result is None: return type(None) if hasattr(self.result, "__type__"): @@ -667,27 +736,6 @@ def result_type(self): return type(self.result) -class WorkflowComponent: - """Base class for anything directly tied to an individual `Workflow`. - - Attributes: - __workflow__: the `Workflow` that this is tied to. - """ - - __workflow__: Workflow - - def __init__(self, *args, workflow: Workflow, **kwargs): - """Tie to a `Workflow`. - - All subclasses must call this. - - Args: - workflow: the `Workflow` to tie to. - """ - self.__workflow__ = workflow - super().__init__(*args, **kwargs) - - class WorkflowLinkedComponent(Protocol): """Protocol for objects dynamically tied to a `Workflow`.""" @@ -705,39 +753,90 @@ def __workflow__(self) -> Workflow: class FieldableProtocol(Protocol): - __field__: tuple[str, ...] - __field_sep__: str - __field_index_types__: tuple[type, ...] + """Expected interfaces for a type that can take fields. + + Attributes: + __field__: tuple representing the named fields, either strings or integers. + """ + __field__: tuple[str | int, ...] + + @property + def __field_sep__(self) -> str: + """The separator that should be used by default when rendering a name.""" + + @property + def __field_index_types__(self) -> tuple[type, ...]: + """Get the types that will be rendered as an index [x] rather than a field .x on stringifying. + + Will be taken from the `field_index_types` construct configuration. + """ def __init__(self, *args, field: str | None = None, **kwargs): + """Extract the field name from the initializer arguments, if provided.""" super().__init__(*args, **kwargs) @property - def __type__(self): + def __name__(self: FieldableProtocol) -> str: + """Name of the fieldable, which may not be `name` if this is needed as an internal reference.""" + return self.name + + @property + def __type__(self) -> type: + """Type of this field with the overall target.""" ... @property - def name(self): + def name(self) -> str: + """The name for the target, accounting for the field.""" return "name" - def __make_reference__(self, *args, **kwargs) -> "FieldableProtocol": + def __make_reference__(self, *args, **kwargs) -> Reference[Any]: + """Create a reference with constructor arguments, usually to a subfield.""" ... 
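+
+# For example (an illustrative rendering with the default configuration): a
+# fieldable reference carrying the field tuple ("spec", 0) joins str components
+# with the "/" separator and renders int components as indexes, giving a name
+# like "my-step-1/spec[0]" (the step name here is hypothetical).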
+# This is required so that Mixins can trust that
+# superclass attributes will be available for mypy,
+# without confusing the MRO.
+if TYPE_CHECKING:
+    _Fieldable = FieldableProtocol
+else:
+    _Fieldable = object
+
 # Subclass Reference so that we know Reference methods/attrs are available.
-class FieldableMixin:
-    def __init__(self: FieldableProtocol, *args, field: str | int | tuple | None = None, **kwargs):
+class FieldableMixin(_Fieldable):
+    """Tooling for enhancing a type with referenceable fields."""
+
+    def __init__(self: FieldableProtocol, *args, field: str | int | tuple[str | int, ...] | None = None, **kwargs):
+        """Extract the requested field, if any, from the initializer arguments.
+
+        Args:
+            field: the new subfield, either an index (int) or a fieldname (str) or a field tuple representing the whole path, or None.
+            *args: arguments to pass to the other initializers.
+            **kwargs: arguments to pass to the other initializers.
+        """
         self.__field__: tuple[str, ...] = (field if isinstance(field, tuple) else (field,)) if field is not None else ()
         super().__init__(*args, **kwargs)

     @property
     def __field_sep__(self) -> str:
-        return get_configuration("field_separator")
+        """Get the field separator.
+
+        Will be taken from the configuration key, `field_separator`.
+        """
+        sep = get_configuration("field_separator")
+        if not isinstance(sep, str):
+            raise TypeError(f"The `field_separator` configuration argument must be a string not {type(sep)}")
+        return sep

     @property
     def __field_index_types__(self) -> tuple[type, ...]:
+        """Get the types that will be rendered as an index [x] rather than a field .x on stringifying.
+
+        Will be taken from the `field_index_types` construct configuration.
+        """
         types = get_configuration("field_index_types")
         if not isinstance(types, str):
-            raise TypeError("Field index types must be provided as comma-separated names")
+            raise TypeError(f"The `field_index_types` configuration argument should be a comma-separated string of types from: {", ".join(AVAILABLE_TYPES.keys())}")
         tup_all = tuple(AVAILABLE_TYPES.get(typ) for typ in types.split(",") if typ)
         tup = tuple(t for t in tup_all if t is not None)
         if tup_all != tup:
@@ -747,21 +846,14 @@ def __field_index_types__(self) -> tuple[type, ...]:
             )
         return tup

-    @property
-    def __name__(self: FieldableProtocol) -> str:
-        """Name for this step.
-
-        May be remapped by the workflow to something nicer
-        than the ID.
-        """
-        return super().__name__ + self.__field_suffix__
-
     @property
     def __field_str__(self) -> str:
+        """Stringified field, without the name."""
         return self.__field_suffix__.lstrip(self.__field_sep__)

     @property
     def __field_suffix__(self) -> str:
+        """Stringified field, without the name but with an appropriate separator at the start."""
         result = ""
         for cmpt in self.__field__:
             if any(isinstance(cmpt, typ) for typ in self.__field_index_types__):
@@ -770,13 +862,26 @@ def __field_suffix__(self) -> str:
             result += f"{self.__field_sep__}{cmpt}"
         return result

-    def find_field(self: FieldableProtocol, field, fallback_type: type | None = None, **init_kwargs: Any) -> Reference:
+    @property
+    def __name__(self) -> str:
+        """Name for this step.
+
+        May be remapped by the workflow to something nicer
+        than the ID.
+        """
+        return super().__name__ + self.__field_suffix__
+
+    def find_field(self: FieldableProtocol, field: str | int, fallback_type: type | None = None, **init_kwargs: Any) -> Reference:
         """Field within the reference, if possible.

+        Args:
+            field: the field to search for.
+ fallback_type: the type to use if we do not know a more specific one. + **init_kwargs: arguments to use for constructing a new reference (via `__make_reference__`). + Returns: A field-specific version of this reference. """ - # Get new type, for the specific field. parent_type, _ = strip_annotations(self.__type__) field_type = fallback_type @@ -798,34 +903,37 @@ def find_field(self: FieldableProtocol, field, fallback_type: type | None = None if field_type is None: field_type = next(iter(filter(lambda fld: fld.name == field, dataclass_fields(parent_type)))).type except StopIteration: - raise AttributeError(f"Dataclass {parent_type} does not have field {field}") + raise AttributeError(f"Dataclass {parent_type} does not have field {field}") from None elif attr_has(parent_type): resolve_types(parent_type) try: field_type = getattr(attrs_fields(parent_type), field).type - except AttributeError: - raise AttributeError(f"attrs-class {parent_type} does not have field {field}") + except AttributeError as exc: + raise AttributeError(f"attrs-class {parent_type} does not have field {field}") from exc # TypedDict elif inspect.isclass(parent_type) and issubclass(parent_type, dict) and hasattr(parent_type, "__annotations__"): try: field_type = get_type_hints(parent_type, include_extras=True)[field] - except KeyError: - raise AttributeError(f"TypedDict {parent_type} does not have field {field}") + except KeyError as exc: + raise AttributeError(f"TypedDict {parent_type} does not have field {field}") from exc if not field_type and get_configuration("allow_plain_dict_fields") and inspect.isclass(base) and issubclass(base, dict): args = get_args(parent_type) if len(args) == 2 and args[0] is str: field_type = args[1] else: - raise AttributeError(f"Can only get fields for plain dicts if annotated dict[str, TYPE]") + raise AttributeError("Can only get fields for plain dicts if annotated dict[str, TYPE]") if field_type: if not issubclass(self.__class__, Reference): raise TypeError("Only references can have a fieldable mixin") + new_field: tuple[str | int, ...] | str | int if self.__field__: - field = tuple(list(self.__field__) + [field]) + new_field = tuple(list(self.__field__) + [field]) + else: + new_field = field - return self.__make_reference__(typ=field_type, field=field, **init_kwargs) + return self.__make_reference__(typ=field_type, field=new_field, **init_kwargs) raise AttributeError( f"Could not determine the type for field {field} in type {parent_type} (type of parent type is {type(parent_type)})" @@ -844,7 +952,7 @@ class BaseStep(WorkflowComponent): _id: str | None = None task: Task | Workflow - arguments: Mapping[str, Reference | Raw] + arguments: Mapping[str, Basic | Reference | Raw] workflow: Workflow positional_args: dict[str, bool] | None = None @@ -884,20 +992,21 @@ def __init__( and is_raw(value) ): if raw_as_parameter: - value = param(key, value, tethered=None).make_reference(workflow=workflow) + # We use param for convenience but note that it is a reference in disguise. 
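+                    # e.g. (an illustrative note) a raw value 3 passed under the key
+                    # "count" with raw_as_parameter=True becomes a workflow parameter
+                    # named "count", and the step receives its reference instead of
+                    # the bare value.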
+ value = cast(Parameter, param(key, value, tethered=None)).make_reference(workflow=workflow) else: value = Raw(value) def _to_param_ref(value): if isinstance(value, Parameter): return value.make_parameter(workflow=workflow) - value, refs = expr_to_references(value, remap=_to_param_ref) + expression, refs = expr_to_references(value, remap=_to_param_ref) for ref in refs: if isinstance(ref, ParameterReference): parameter = ref._.parameter parameter.register_caller(self) - self.arguments[key] = value + self.arguments[key] = expression else: raise RuntimeError( f"Non-references must be a serializable type: {key}>{value} {type(value)}" @@ -918,6 +1027,15 @@ def __eq__(self, other: object) -> bool: ) def make_reference(self, **kwargs) -> "StepReference": + """Create a reference to this step. + + Builds a reference to the (result of) this step, which will be iterable if appropriate. + + Args: + **kwargs: arguments for reference constructor, which will be supplemented appropriately. + + Returns: a reference to the result of this step. + """ kwargs["step"] = self kwargs.setdefault("typ", self.return_type) typ = kwargs["typ"] @@ -928,6 +1046,7 @@ def make_reference(self, **kwargs) -> "StepReference": return StepReference(**kwargs) def __hash__(self) -> int: + """Searchable hash for this step.""" return hash(self.id) def set_workflow(self, workflow: Workflow, with_arguments: bool = True) -> None: @@ -1016,7 +1135,7 @@ def __init__( workflow: Workflow, name: str, subworkflow: Workflow, - arguments: Mapping[str, Reference | Raw], + arguments: Mapping[str, Basic | Reference | Raw], raw_as_parameter: bool = False, ): """Create a NestedStep. @@ -1029,7 +1148,7 @@ def __init__( raw_as_parameter: whether raw-type arguments should be made (outer) workflow parameters. """ self.__subworkflow__ = subworkflow - base_arguments = {p.name: p for p in subworkflow.find_parameters()} + base_arguments: dict[str, Basic | Reference | Raw] = {p.name: p for p in subworkflow.find_parameters()} base_arguments.update(arguments) super().__init__( workflow=workflow, @@ -1053,8 +1172,8 @@ def return_type(self) -> Any: Returns: Expected type of the return value. """ - return super().return_type - if self.__subworkflow__.result is None or self.__subworkflow__.result is []: + result = self.__subworkflow__.result + if result is None or (isinstance(result, list | tuple) and not result): raise RuntimeError("Can only use a subworkflow if the reference exists.") return self.__subworkflow__.result_type @@ -1098,10 +1217,12 @@ def __init__( @property def __name__(self): + """Get the name of this factory call.""" return self.name @property def __original_name__(self) -> str: + """Original name of a factory call is just its normal name.""" return self.name @property @@ -1110,7 +1231,7 @@ def __default__(self) -> Unset: return UnsetType(self.return_type) -class ParameterReference(WorkflowComponent, FieldableMixin, Reference[U]): +class ParameterReference(FieldableMixin, Reference[U], WorkflowComponent): """Reference to an individual `Parameter`. Allows us to refer to the outputs of a `Parameter` in subsequent `Parameter` @@ -1126,6 +1247,11 @@ class ParameterReference(WorkflowComponent, FieldableMixin, Reference[U]): """ class ParameterReferenceMetadata(Generic[T]): + """Holder for attributes of this reference that we do not wish to risk confusing with fieldnames. + + Attributes: + parameter: the parameter to which this reference refers. 
+ """ parameter: Parameter[T] def __init__(self, parameter: Parameter[T], *args, typ: type[U] | None=None, **kwargs): @@ -1134,6 +1260,9 @@ def __init__(self, parameter: Parameter[T], *args, typ: type[U] | None=None, **k Args: workflow: `Workflow` that this is tied to. parameter: `Parameter` that this refers to. + typ: type that should override inferred type or None. + *args: any arguments for other initializers. + **kwargs: any arguments for other initializers. """ self.parameter = parameter @@ -1149,7 +1278,7 @@ def unique_name(self) -> str: return self.parameter.__name__ @property - def __default__(self) -> T | Unset: + def __default__(self) -> U | Unset: """Default value of the parameter.""" default = self._.parameter.default if isinstance(default, Unset): @@ -1172,11 +1301,27 @@ def __root_name__(self) -> str: return self._.parameter.name def __init__(self, parameter: Parameter[U], *args, typ: type[U] | None=None, **kwargs): + """Extract the parameter and type for setup. + + Args: + parameter: the parameter to reference. + typ: an overriding type for this reference, or None. + *args: arguments for other initializers. + **kwargs: arguments for other initializers. + """ typ = typ or parameter.__type__ self._ = self.ParameterReferenceMetadata(parameter, *args, typ, **kwargs) super().__init__(*args, typ=typ, **kwargs) def __getitem__(self, attr: str) -> "ParameterReference": + """Retrieve a field. + + Args: + attr: attribute to get. + + Returns: a reference to the field within this parameter, possibly nesting if we are already + referencing a field. + """ try: return self.find_field( field=attr, @@ -1194,12 +1339,21 @@ def __getitem__(self, attr: str) -> "ParameterReference": @property def __original_name__(self) -> str: + """The name of the original parameter, without any field, etc.""" return self._.parameter.__original_name__ def __getattr__(self, attr: str) -> "ParameterReference": + """Retrieve a field. + + Args: + attr: attribute to get. + + Returns: a reference to the field within this parameter, possibly nesting if we are already + referencing a field. + """ try: return self[attr] - except KeyError as exc: + except KeyError as _: return super().__getattribute__(attr) def __repr__(self) -> str: @@ -1233,11 +1387,18 @@ def __eq__(self, other: object) -> bool: (isinstance(other, ParameterReference) and self._.parameter == other._.parameter and self.__field__ == other.__field__) ) - def __make_reference__(self, **kwargs) -> "StepReference": + def __make_reference__(self, **kwargs) -> "ParameterReference": + """Get a reference for the same parameter.""" return self._.parameter.make_reference(**kwargs) class IterableParameterReference(IterableMixin, ParameterReference[U]): - def __iter__(self): + """Iterable form of parameter references.""" + def __iter__(self) -> Generator[Reference, None, None]: + """Iterate over this parameter. + + Returns: + Get values back for this parameter, normally references of the appropriate type. + """ inner, metadata = strip_annotations(self.__type__) if metadata and "AtRender" in metadata and isinstance(self.__default__, Iterable): yield from self.__default__ @@ -1245,6 +1406,13 @@ def __iter__(self): yield from super().__iter__() def __inner_iter__(self) -> Generator[Any, None, None]: + """Iterate over this parameter. + + This is intended for overriding as a convenience way of iterating over a parameter, + and the yielded values will be ignored, in favour of new references. + + Returns: a generator of any type. 
+ """ inner, metadata = strip_annotations(self.__type__) if self.__fixed_len__ is not None: yield from range(self.__fixed_len__) @@ -1255,6 +1423,7 @@ def __inner_iter__(self) -> Generator[Any, None, None]: yield None def __len__(self): + """If it is possible to get a hard-codeable length from this iterable parameter, do so.""" inner, metadata = strip_annotations(self.__type__) if metadata and "Fixed" in metadata and isinstance(self.__default__, Sized): return len(self.__default__) @@ -1267,12 +1436,19 @@ class StepReference(FieldableMixin, Reference[U]): arguments. Attributes: - step: `Step` referred to. + _: metadata wrapping the `Step` referred to. """ step: BaseStep class StepReferenceMetadata: + """Wrapper for any metadata that we would not want to conflict with fieldnames. + + Attributes: + step: the step being wrapped. + _typ: the type to return, if overriding the step's own type, or None. + """ + def __init__( self, step: BaseStep, typ: type[U] | None = None ): @@ -1288,7 +1464,8 @@ def __init__( self._typ = typ @property - def return_type(self): + def return_type(self) -> type: + """Type of this reference, which may be overridden or may be the step's own type.""" return self._typ or self.step.return_type _: StepReferenceMetadata @@ -1303,6 +1480,8 @@ def __init__( step: `Step` that this refers to. typ: the type that the step will output. field: if provided, a specific field to pull out of an attrs result class. + *args: arguments for other initializers. + **kwargs: arguments for other initializers. """ typ = typ or step.return_type self._ = self.StepReferenceMetadata(step, typ=typ) @@ -1317,6 +1496,7 @@ def __repr__(self) -> str: return self._.step.id + self.__field_suffix__ def __hash__(self) -> int: + """Hashable value for this workflow.""" return hash((repr(self), id(self.__workflow__))) def __getitem__(self, attr: str) -> "StepReference[Any]": @@ -1349,6 +1529,7 @@ def __getitem__(self, attr: str) -> "StepReference[Any]": ) from exc def __getattr__(self, attr: str) -> "StepReference": + """Retrieve a field within this workflow.""" try: return self[attr] except KeyError as exc: @@ -1359,6 +1540,7 @@ def __getattr__(self, attr: str) -> "StepReference": @property def __type__(self) -> type: + """Get the type to which this step reference refers.""" return self._.return_type @property @@ -1371,7 +1553,7 @@ def __root_name__(self) -> str: return self._.step.name @property - def __workflow__(self) -> Workflow: + def __workflow__(self) -> WorkflowProtocol: """Related workflow. Returns: @@ -1389,11 +1571,21 @@ def __workflow__(self, workflow: Workflow) -> None: self._.step.set_workflow(workflow) def __make_reference__(self, **kwargs) -> "StepReference": + """Create a new reference for the same step.""" return self._.step.make_reference(**kwargs) class IterableStepReference(IterableMixin, StepReference[U]): - def __iter__(self): - yield IteratedGenerator(self) + """Iterable form of a step reference.""" + def __iter__(self) -> Generator[Reference, None, None]: + """Gets a sentinel value for iterating over this step's results. + + Bear in mind that this means an iterable step reference will iterate exactly once, + and return this Generator. The IteratedGenerator can be used as an iterator itself, + and will yield an infinite sequence of numbered field references, which renderers can use + for zipping with a fixed length iterator, or simply prepping fieldnames for serialization. + """ + # We cast this so that we can treat a step iterator as if it really loops over results. 
+        yield cast(Reference, IteratedGenerator(self))

 def merge_workflows(*workflows: Workflow) -> Workflow:
     """Combine several workflows into one.
@@ -1424,8 +1616,16 @@ def is_task(task: Lazy) -> bool:
     """
     return isinstance(task, LazyEvaluation)

-def expr_to_references(expression: Any, remap: Callable[[Any], Any] | None = None) -> tuple[Basic | None, set[Reference | Parameter]]:
-    to_check = []
+def expr_to_references(expression: Any, remap: Callable[[Any], Any] | None = None) -> tuple[Basic | None, list[Reference | Parameter]]:
+    """Pull out any references, or other free symbols, from an expression.
+
+    Args:
+        expression: normally a reference that can be immediately returned, but may be a sympy expression or a dict/tuple/list/etc. of such.
+        remap: a callable to project certain values down before extracting symbols, or None.
+
+    Returns: a pair of the expression with any applied simplifications/standardizations, and the list of References/Parameters found.
+    """
+    to_check: list[Reference | Parameter] = []
     def _to_expr(value):
         if remap and (res := remap(value)) is not None:
             return _to_expr(res)
@@ -1481,16 +1681,30 @@ def _to_expr(value):

     #     raise RuntimeError("The only symbols allowed are references (to e.g. step or parameter)")
     return expression, to_check

-def unify_workflows(expression: Any, base_workflow: Workflow | None, set_only: bool = False) -> Workflow | None:
+def unify_workflows(expression: Any, base_workflow: Workflow | None, set_only: bool = False) -> tuple[Basic | None, Workflow | None]:
+    """Takes an expression and ensures all of its references exist in the same workflow.
+
+    Args:
+        expression: any valid argument to `dewret.workflow.expr_to_references`.
+        base_workflow: the desired workflow to align on, or None.
+        set_only: whether to bother assimilating all the workflows (False), or to assume that has already been done (True).
+
+    Returns: a pair of the (standardized) expression and the workflow that now contains it.
+    """
     expression, to_check = expr_to_references(expression)
     if not to_check:
        return expression, base_workflow

     # Build a unified workflow
     collected_workflow = base_workflow or next(iter(to_check)).__workflow__
+    # This is a bit of an artificial check. We could rule out non-Workflow WorkflowProtocol implementers and simplify this.
+ if not isinstance(collected_workflow, Workflow): + raise NotImplementedError("Can only unify workflows on the same type: Workflow") if not set_only: for step_result in to_check: new_workflow = step_result.__workflow__ + if not isinstance(new_workflow, Workflow): + raise NotImplementedError("Can only unify workflows on the same type: Workflow") if collected_workflow != new_workflow and collected_workflow and new_workflow: collected_workflow = Workflow.assimilate(collected_workflow, new_workflow) diff --git a/tests/_lib/frender.py b/tests/_lib/frender.py index 257de0b6..83fc927c 100644 --- a/tests/_lib/frender.py +++ b/tests/_lib/frender.py @@ -4,11 +4,10 @@ """ from textwrap import indent -from typing import Unpack, TypedDict, Any +from typing import Unpack, TypedDict from dataclasses import dataclass -from contextvars import ContextVar -from dewret.core import RawType +from dewret.core import set_render_configuration from dewret.workflow import Workflow, Step, NestedStep from dewret.render import base_render @@ -17,10 +16,10 @@ class FrenderRendererConfiguration(TypedDict): allow_complex_types: bool -CONFIGURATION: ContextVar[FrenderRendererConfiguration] = ContextVar("configuration") -CONFIGURATION.set({ - "allow_complex_types": True -}) +def default_config() -> FrenderRendererConfiguration: + return FrenderRendererConfiguration({ + "allow_complex_types": True + }) @dataclass class NestedStepDefinition: @@ -103,7 +102,7 @@ def render_raw( Reduced form as a native Python dict structure for serialization. """ - CONFIGURATION.get().update(kwargs) + set_render_configuration(kwargs) return base_render( workflow, lambda workflow: WorkflowDefinition.from_workflow(workflow).render() From f1dd479c692ce865cf5296ed60c569453159029d Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 24 Aug 2024 22:19:35 +0100 Subject: [PATCH 065/108] fix: throw import exceptions --- src/dewret/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dewret/utils.py b/src/dewret/utils.py index e4c7274e..895cd3cd 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -83,9 +83,9 @@ def load_module_or_package(target_name: str, path: Path) -> ModuleType: spec = importlib.util.spec_from_file_location(target_name, str(package_init)) if spec is None or spec.loader is None: raise ImportError(f"Could not open {path.parent} package") - module = importlib.util.module_from_spec(spec) - sys.modules[target_name] = module - spec.loader.exec_module(module) + package = importlib.util.module_from_spec(spec) + sys.modules[target_name] = package + spec.loader.exec_module(package) module = importlib.import_module(f"{target_name}.{path.stem}", target_name) except ImportError as exc: exception = exc From 5c2d7a29c7d5b3080632e9e8a7a1d1de164d8fe8 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 24 Aug 2024 22:23:39 +0100 Subject: [PATCH 066/108] fix: add configuration for simplify_ids --- src/dewret/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dewret/core.py b/src/dewret/core.py index 8c7b642f..85f7ef59 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -129,6 +129,7 @@ class ConstructConfiguration: allow_plain_dict_fields: bool = False field_separator: str = "/" field_index_types: str = "int" + simplify_ids: bool = True class ConstructConfigurationTypedDict(TypedDict): """Basic configuration of the construction process. 
@@ -143,6 +144,7 @@ class ConstructConfigurationTypedDict(TypedDict): allow_plain_dict_fields: bool field_separator: str field_index_types: str + simplify_ids: bool @define class GlobalConfiguration: From 4aef9c494a1c68a6d7b60c1c660748653ee9ba97 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 24 Aug 2024 22:56:53 +0100 Subject: [PATCH 067/108] docs: tidyup docstrings and structures --- src/dewret/__main__.py | 6 +- src/dewret/annotations.py | 70 ++++++++++++++++++---- src/dewret/backends/backend_dask.py | 5 +- src/dewret/core.py | 18 +++--- src/dewret/renderers/cwl.py | 2 +- src/dewret/tasks.py | 90 ++++++++++++++++------------- src/dewret/workflow.py | 21 ++----- tests/test_nested.py | 2 +- 8 files changed, 130 insertions(+), 84 deletions(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 26b49461..1816c30b 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -27,7 +27,7 @@ import re import yaml from typing import Any -import sys +from types import ModuleType import click import json @@ -90,9 +90,11 @@ def render( key, val = arg.split(":", 1) kwargs[key] = json.loads(val) - render_module: Path | RawRenderModule | StructuredRenderModule + render_module: Path | ModuleType if (mtch := re.match(r"^([a-z_0-9-.]+)$", renderer)): render_module = importlib.import_module(f"dewret.renderers.{mtch.group(1)}") + if not isinstance(render_module, RawRenderModule) and not isinstance(render_module, StructuredRenderModule): + raise NotImplementedError("The imported render module does not seem to match the `RawRenderModule` or `StructuredRenderModule` protocols.") elif renderer.startswith("@"): render_module = Path(renderer[1:]) else: diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index c561c28b..de46b2c6 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -3,19 +3,29 @@ import sys import importlib from functools import lru_cache -from types import FunctionType -from dataclasses import dataclass -from typing import Protocol, Any, TypeVar, Generic, cast, Literal, TypeAliasType, Annotated, Callable, get_origin, get_args, Mapping, TypeAliasType, get_type_hints +from types import FunctionType, ModuleType +from typing import Any, TypeVar, Annotated, Callable, get_origin, get_args, Mapping, get_type_hints T = TypeVar("T") AtRender = Annotated[T, "AtRender"] Fixed = Annotated[T, "Fixed"] class FunctionAnalyser: + """Convenience class for analysing a function with reduced duplication of effort. + + Attributes: + _fn: the wrapped callable + _annotations: stored annotations for the function. + """ _fn: Callable[..., Any] _annotations: dict[str, Any] def __init__(self, fn: Callable[..., Any]): + """Set the function. + + If `fn` is a class, it takes the constructor, and if it is a method, it takes + the `__func__` attribute. + """ self.fn = ( fn.__init__ if inspect.isclass(fn) else @@ -25,11 +35,20 @@ def __init__(self, fn: Callable[..., Any]): ) @property - def return_type(self): + def return_type(self) -> type: + """Return type of the callable.""" return get_type_hints(inspect.unwrap(self.fn), include_extras=True)["return"] @staticmethod def _typ_has(typ: type, annotation: type) -> bool: + """Check if the type has an annotation. + + Args: + typ: type to check. + annotation: the Annotated to compare against. + + Returns: True if the type has the given annotation, otherwise False. 
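+
+        For example (illustrative): `_typ_has(AtRender[int], AtRender)` is True,
+        as is `_typ_has(list[AtRender[int]], AtRender)`, since the check recurses
+        into the type arguments.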
+ """ if not hasattr(annotation, "__metadata__"): return False if (origin := get_origin(typ)): @@ -40,33 +59,52 @@ def _typ_has(typ: type, annotation: type) -> bool: return False def get_all_module_names(self): - return sys.modules[self.fn.__module__].__annotations__ + """Find all of the annotations within this module.""" + return get_type_hints(sys.modules[self.fn.__module__], include_extras=True) def get_all_imported_names(self): + """Find all of the annotations that were imported into this module.""" return self._get_all_imported_names(sys.modules[self.fn.__module__]) @staticmethod @lru_cache - def _get_all_imported_names(mod): + def _get_all_imported_names(mod: ModuleType) -> dict[str, tuple[ModuleType, str]]: + """Get all of the names with this module, and their original locations. + + Args: + mod: a module in the `sys.modules`. + + Returns: + A dict whose keys are the known names in the current module, where the Callable lives, + and whose values are pairs of the module and the remote name. + """ ast_tree = ast.parse(inspect.getsource(mod)) imported_names = {} for node in ast.walk(ast_tree): if isinstance(node, ast.ImportFrom): for name in node.names: imported_names[name.asname or name.name] = ( - importlib.import_module("".join(["."] * node.level) + node.module, package=mod.__package__), + importlib.import_module("".join(["."] * node.level) + (node.module or ""), package=mod.__package__), name.name ) return imported_names @property - def free_vars(self): + def free_vars(self) -> dict[str, Any]: + """Get the free variables for this Callable.""" if self.fn.__code__ and self.fn.__closure__: - return dict(zip(self.fn.__code__.co_freevars, (c.cell_contents for c in self.fn.__closure__))) + return dict(zip(self.fn.__code__.co_freevars, (c.cell_contents for c in self.fn.__closure__), strict=False)) return {} def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> type | None: - all_annotations: dict[str, type] = {} + """Retrieve the annotations for this argument. + + Args: + arg: name of the argument. + exhaustive: True if we should search outside the function itself, into the module globals, and into imported modules. + + Returns: annotation if available, else None. + """ typ: type | None = None if (typ := self.fn.__annotations__.get(arg)): if isinstance(typ, str): @@ -84,14 +122,25 @@ def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> type | No return typ def argument_has(self, arg: str, annotation: type, exhaustive: bool=False) -> bool: + """Check if the named argument has the given annotation. + + Args: + arg: argument to retrieve. + annotation: Annotated to search for. + exhaustive: whether to check the globals and other modules. + + Returns: True if the Annotated is present in this argument's annotation. 
+ """ typ = self.get_argument_annotation(arg, exhaustive) return bool(typ and self._typ_has(typ, annotation)) def is_at_construct_arg(self, arg: str, exhaustive: bool=False) -> bool: + """Convience function to check for `AtConstruct`, wrapping `FunctionAnalyser.argument_has`.""" return self.argument_has(arg, AtRender, exhaustive) @property def globals(self) -> Mapping[str, Any]: + """Get the globals for this Callable.""" try: fn_tuple = inspect.getclosurevars(self.fn) fn_globals = dict(fn_tuple.globals) @@ -102,6 +151,7 @@ def globals(self) -> Mapping[str, Any]: return fn_globals def with_new_globals(self, new_globals: dict[str, Any]) -> Callable[..., Any]: + """Create a Callable that will run the current Callable with new globals.""" code = self.fn.__code__ fn_name = self.fn.__name__ all_globals = dict(self.globals) diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index 36e7fb21..918ba242 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -19,10 +19,8 @@ from dask.delayed import delayed, DelayedLeaf from dask.config import config -import contextvars -from functools import partial from typing import Protocol, runtime_checkable, Any, cast -from concurrent.futures import Executor, ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor from dewret.workflow import Workflow, Lazy, StepReference, Target @@ -102,7 +100,6 @@ def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread workflow: `Workflow` in which to record the execution. task: `dask.delayed` function, wrapped by dewret, that we wish to compute. """ - # def _check_delayed(task: Lazy | list[Lazy] | tuple[Lazy]) -> Delayed: # # We need isinstance to reassure type-checker. # if isinstance(task, list) or isinstance(task, tuple): diff --git a/src/dewret/core.py b/src/dewret/core.py index 85f7ef59..266630f1 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from abc import abstractmethod, abstractstaticmethod +from abc import abstractmethod import importlib import base64 from attrs import define from functools import lru_cache -from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated, Literal, Callable, cast, runtime_checkable +from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, Annotated, Callable, cast, runtime_checkable from contextlib import contextmanager from contextvars import ContextVar from sympy import Expr, Symbol, Basic @@ -129,7 +129,7 @@ class ConstructConfiguration: allow_plain_dict_fields: bool = False field_separator: str = "/" field_index_types: str = "int" - simplify_ids: bool = True + simplify_ids: bool = False class ConstructConfigurationTypedDict(TypedDict): """Basic configuration of the construction process. 
@@ -139,12 +139,12 @@ class ConstructConfigurationTypedDict(TypedDict): **THIS MUST BE KEPT IDENTICAL TO ConstructConfiguration.** """ - flatten_all_nested: bool - allow_positional_args: bool - allow_plain_dict_fields: bool - field_separator: str - field_index_types: str - simplify_ids: bool + flatten_all_nested: NotRequired[bool] + allow_positional_args: NotRequired[bool] + allow_plain_dict_fields: NotRequired[bool] + field_separator: NotRequired[str] + field_index_types: NotRequired[str] + simplify_ids: NotRequired[bool] @define class GlobalConfiguration: diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 233f3a67..23b4214e 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -556,7 +556,7 @@ class OutputsDefinition: @classmethod def from_results( - cls, results: dict[str, StepReference[Any]] | list[StepReference[Any]] | tuple[StepReference[Any]] + cls, results: dict[str, StepReference[Any]] | list[StepReference[Any]] | tuple[StepReference[Any], ...] ) -> "OutputsDefinition": """Takes a mapping of results into a CWL structure. diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index c6bb30ff..31d55120 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -32,32 +32,29 @@ import inspect import importlib import sys -from typing import TypedDict, NotRequired, Unpack from enum import Enum from functools import cached_property from collections.abc import Callable -from typing import Any, ParamSpec, TypeVar, cast, Generator +from typing import Any, ParamSpec, TypeVar, cast, Generator, Unpack, Literal from types import TracebackType from attrs import has as attrs_has -from dataclasses import dataclass, is_dataclass +from dataclasses import is_dataclass import traceback from concurrent.futures import ThreadPoolExecutor from contextvars import ContextVar, copy_context from contextlib import contextmanager -from .utils import is_firm, make_traceback, is_expr, is_raw_type +from .utils import is_firm, make_traceback, is_expr from .workflow import ( expr_to_references, unify_workflows, UNSET, Reference, - StepReference, Workflow, Lazy, LazyEvaluation, Target, LazyFactory, - merge_workflows, Parameter, param, Task, @@ -65,8 +62,7 @@ ) from .backends._base import BackendModule from .annotations import FunctionAnalyser -from .core import get_configuration, set_configuration, IteratedGenerator, ConstructConfiguration -import ast +from .core import get_configuration, set_configuration, IteratedGenerator, ConstructConfigurationTypedDict Param = ParamSpec("Param") RetType = TypeVar("RetType") @@ -146,13 +142,17 @@ def evaluate(self, task: Lazy | list[Lazy] | tuple[Lazy], __workflow__: Workflow Args: task: the task to evaluate. __workflow__: workflow within which this exists. + thread_pool: existing pool of threads to run this in, or None. **kwargs: any arguments to pass to the task. 
""" result = self.backend.run(__workflow__, task, thread_pool=thread_pool, **kwargs) - result, collected_workflow = unify_workflows(result, __workflow__) + new_result, collected_workflow = unify_workflows(result, __workflow__) + + if collected_workflow is None: + raise RuntimeError("A new workflow could not be found") # Then we set the result to be the whole thing - collected_workflow.set_result(result) + collected_workflow.set_result(new_result) return collected_workflow.result def unwrap(self, task: Lazy) -> Target: @@ -193,9 +193,8 @@ def ensure_lazy(self, task: Any) -> Lazy | None: def __call__( self, task: Any, - simplify_ids: bool = False, __workflow__: Workflow | None = None, - **kwargs: ConstructConfiguration, + **kwargs: Unpack[ConstructConfigurationTypedDict], ) -> Workflow: """Execute the lazy evalution. @@ -217,6 +216,7 @@ def _initializer(): thread_pool = ThreadPoolExecutor(initializer=_initializer) result = self.evaluate(task, workflow, thread_pool=thread_pool, **kwargs) + simplify_ids = bool(get_configuration("simplify_ids")) return Workflow.from_result(result, simplify_ids=simplify_ids) @@ -387,12 +387,12 @@ def _fn( **kwargs: Param.kwargs, ) -> RetType: configuration = None - allow_positional_args = get_configuration("allow_positional_args") + allow_positional_args = bool(get_configuration("allow_positional_args")) try: # Ensure that all arguments are passed as keyword args and prevent positional args. # passed at all. - if args and not get_configuration("allow_positional_args"): + if args and not allow_positional_args: raise TypeError( f""" Calling {fn.__name__}: Arguments must _always_ be named, @@ -409,9 +409,9 @@ def add_numbers(left: int, right: int): # Ensure that the passed arguments are, at least, a Python-match for the signature. sig = inspect.signature(fn) positional_args = {key: False for key in kwargs} - for arg, (key, _) in zip(args, sig.parameters.items()): + for arg, (key, _) in zip(args, sig.parameters.items(), strict=False): if isinstance(arg, IteratedGenerator): - for inner_arg, (key, _) in zip(arg, sig.parameters.items()): + for inner_arg, (key, _) in zip(arg, sig.parameters.items(), strict=False): if key in positional_args: continue kwargs[key] = inner_arg @@ -430,8 +430,9 @@ def _to_param_ref(value): val, kw_refs = expr_to_references(val, remap=_to_param_ref) refs += kw_refs kwargs[key] = val - workflows = [ - reference.__workflow__ + # Not realistically going to be other than Workflow. + workflows: list[Workflow] = [ + cast(Workflow, reference.__workflow__) for reference in refs if hasattr(reference, "__workflow__") and reference.__workflow__ is not None @@ -439,7 +440,7 @@ def _to_param_ref(value): if __workflow__ is not None: workflows.insert(0, __workflow__) if workflows: - workflow = merge_workflows(*workflows) + workflow = Workflow.assimilate(*workflows) else: workflow = Workflow() @@ -452,17 +453,20 @@ def _to_param_ref(value): elif is_firm(value): # We leave this reference dangling for a consumer to pick up ("tethered"), unless # we are in a nested task, that does not have any existence of its own. 
- tethered = (
+ tethered: Literal[False] | None = (
 False
 if nested
 and (
 flatten_nested
 or get_configuration("flatten_all_nested")
 )
 else None
 )
- kwargs[var] = param(
- var,
- value,
- tethered=tethered,
- autoname=tethered is not False,
- typ=analyser.get_argument_annotation(var) or UNSET
+ kwargs[var] = cast(
+ Parameter,
+ param(
+ var,
+ value,
+ tethered=tethered,
+ autoname=tethered is not False,
+ typ=analyser.get_argument_annotation(var) or UNSET
+ )
 ).make_reference(workflow=workflow)
 original_kwargs = dict(kwargs)
 fn_globals = analyser.globals
@@ -515,13 +519,16 @@ def {fn.__name__}(...) -> ...:
 (attrs_has(value) or is_dataclass(value))
 and not inspect.isclass(value)
 ):
- kwargs[var] = param(
- var,
- default=value,
- tethered=False,
- typ=analyser.get_argument_annotation(var, exhaustive=True) or UNSET
+ kwargs[var] = cast(
+ Parameter,
+ param(
+ var,
+ default=value,
+ tethered=False,
+ typ=analyser.get_argument_annotation(var, exhaustive=True) or UNSET
+ )
 ).make_reference(workflow=workflow)
- elif is_expr(value) and expr_to_references(value)[1] is not []:
+ elif is_expr(value) and (expr_refs := expr_to_references(value)) and len(expr_refs[1]) != 0:
 kwargs[var] = value
 elif nested:
 raise NotImplementedError(
@@ -542,13 +549,16 @@ def {fn.__name__}(...) -> ...:
 else:
 nested_workflow = Workflow(name=fn.__name__)
 nested_globals: Param.kwargs = {
- var: param(
- var,
- default=value.__default__ if hasattr(value, "__default__") else UNSET,
- typ=(
- value.__type__
- ),
- tethered=nested_workflow
+ var: cast(
+ Parameter,
+ param(
+ var,
+ default=value.__default__ if hasattr(value, "__default__") else UNSET,
+ typ=(
+ value.__type__
+ ),
+ tethered=nested_workflow
+ )
 ).make_reference(workflow=nested_workflow) if isinstance(value, Reference) else value
 for var, value in kwargs.items()
 }
@@ -589,7 +599,7 @@ def {fn.__name__}(...) -> ...:
 configuration.__exit__(None, None, None)
 _fn.__step_expression__ = True # type: ignore
- _fn.__original__ = fn
+ _fn.__original__ = fn # type: ignore
 return LazyEvaluation(_fn)
 return _task
diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py
index 83d65624..eb97c9cd 100644
--- a/src/dewret/workflow.py
+++ b/src/dewret/workflow.py
@@ -502,9 +502,8 @@ def assimilate(cls, *workflow_args: Workflow) -> "Workflow":
 This could happen if the hashing function is flawed or some
 Python magic to do with Targets being passed.
- Argument:
- left: workflow to use as base
- right: workflow to combine on top
+ Args:
+ workflow_args: workflows to use as base
 """
 workflows = sorted((w for w in set(workflow_args)), key=lambda w: w.id)
 base = workflows[0]
@@ -1587,20 +1586,6 @@ def __iter__(self) -> Generator[Reference, None, None]:
 # We cast this so that we can treat a step iterator as if it really loops over results.
 yield cast(Reference, IteratedGenerator(self))
-def merge_workflows(*workflows: Workflow) -> Workflow:
- """Combine several workflows into one.
-
- Merges a series of workflows by combining steps and tasks.
-
- Argument:
- *workflows: series of workflows to combine.
-
- Returns:
- One workflow with all steps.
- """
- return Workflow.assimilate(*workflows)
-
-
def is_task(task: Lazy) -> bool:
 """Decide whether this is a task.
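With `merge_workflows` removed above in favour of the classmethod it wrapped, call
sites migrate as in this minimal sketch (the workflow variables are assumed):

```python
# Before this patch:
# combined = merge_workflows(wf_a, wf_b)

# After: Workflow.assimilate combines any number of workflows directly.
combined = Workflow.assimilate(wf_a, wf_b)
```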
diff --git a/tests/test_nested.py b/tests/test_nested.py index 99ccf1ca..c05f85c2 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -2,7 +2,7 @@ import pytest import math from dewret.workflow import param -from dewret.tasks import construct, task, factory +from dewret.tasks import construct from dewret.renderers.cwl import render from ._lib.extra import reverse_list, max_list From 8c84ab2462596add385e340a96068c90868dc5c9 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 24 Aug 2024 23:06:28 +0100 Subject: [PATCH 068/108] chore: fix up typehints and disable mypy sympy checking, as sympy does not have typehints --- pyproject.toml | 6 ++++++ src/dewret/annotations.py | 20 ++++++++++++++++++++ src/dewret/backends/backend_dask.py | 2 ++ src/dewret/core.py | 20 ++++++++++++++++++++ src/dewret/render.py | 20 ++++++++++++++++++++ src/dewret/workflow.py | 2 +- 6 files changed, 69 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2c44f274..b8992c2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,12 @@ select = ["D", "F", "B"] [tool.ruff.lint.pydocstyle] convention = "google" +[[tool.mypy.overrides]] +module = [ + "sympy", +] +ignore_missing_imports = true + [project] name = "dewret" description = "DEclarative Workflow REndering Tool" diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index de46b2c6..74db06f6 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -1,3 +1,23 @@ +# Copyright 2024- Flax & Teal Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tooling for managing annotations. + +Provides `FunctionAnalyser`, a toolkit that takes a `Callable` and can interrogate it +for annotations, with some intelligent searching beyond the obvious location. +""" + import inspect import ast import sys diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index 918ba242..50806e54 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -99,6 +99,8 @@ def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread Args: workflow: `Workflow` in which to record the execution. task: `dask.delayed` function, wrapped by dewret, that we wish to compute. + thread_pool: thread pool for executing the workflows, to allow initialization of configuration contextvars. + **kwargs: any configuration arguments for this backend. """ # def _check_delayed(task: Lazy | list[Lazy] | tuple[Lazy]) -> Delayed: # # We need isinstance to reassure type-checker. diff --git a/src/dewret/core.py b/src/dewret/core.py index 266630f1..9fb7d630 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -1,3 +1,23 @@ +# Copyright 2024- Flax & Teal Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base classes that need to be available everywhere. + +Mainly tooling around configuration, protocols and superclasses for References +and WorkflowComponents, that are concretized elsewhere. +""" + from dataclasses import dataclass from abc import abstractmethod import importlib diff --git a/src/dewret/render.py b/src/dewret/render.py index 7127fcce..35e23b0d 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -1,3 +1,23 @@ +# Copyright 2024- Flax & Teal Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for building renderers. + +Provides the routines for calling varied renderers in a standard way, and for +renderers to reuse to build up their own functionality. +""" + import sys from pathlib import Path from functools import partial diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index eb97c9cd..351cd80e 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -1203,7 +1203,7 @@ def __init__( arguments: key-value pairs to pass to the function - for a factory call, these _must_ be raw. raw_as_parameter: whether to turn any raw-type arguments into workflow parameters (or just keep them as default argument values). 
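+
+ Example (a sketch; ``factory`` and ``PackResult`` are borrowed from the
+ docs elsewhere in this series):
+
+ >>> Pack = factory(PackResult)
+ >>> pack = Pack(hearts=13, spades=13, diamonds=13, clubs=13)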
""" - for key, arg in arguments.items(): + for _, arg in arguments.items(): if not is_expr(arg) and not ( isinstance(arg, ParameterReference) and is_raw_type(arg.__type__) ): From b4a4cbc0b663f78cd716f01e3f4ea41d258287f0 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 24 Aug 2024 23:07:49 +0100 Subject: [PATCH 069/108] fix: unbreak the doublequotes --- src/dewret/workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 351cd80e..71c380c3 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -837,7 +837,7 @@ def __field_index_types__(self) -> tuple[type, ...]: """ types = get_configuration("field_index_types") if not isinstance(types, str): - raise TypeError(f"The `field_index_types` configuration argument should be a comma-separated string of types from: {", ".join(AVAILABLE_TYPES.keys())}") + raise TypeError(f"The `field_index_types` configuration argument should be a comma-separated string of types from: {', '.join(AVAILABLE_TYPES.keys())}") tup_all = tuple(AVAILABLE_TYPES.get(typ) for typ in types.split(",") if typ) tup = tuple(t for t in tup_all if t is not None) if tup_all != tup: From 394e5a60d62fd53a7e0078d6baade006e3f74e51 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 24 Aug 2024 23:09:27 +0100 Subject: [PATCH 070/108] fix: unbreak the doublequotes --- src/dewret/workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 71c380c3..c59e3bfb 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -843,7 +843,7 @@ def __field_index_types__(self) -> tuple[type, ...]: if tup_all != tup: raise ValueError( "Setting for fixed index types contains unavailable type: " + - f"{str(get_configuration("field_index_types"))} vs {tup}" + f"{str(get_configuration('field_index_types'))} vs {tup}" ) return tup From 88472e4eb221ff683c2ed9c2381de955c6f35467 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 24 Aug 2024 23:10:55 +0100 Subject: [PATCH 071/108] fix: unbreak the doublequotes --- tests/_lib/frender.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/_lib/frender.py b/tests/_lib/frender.py index 83fc927c..a7e8cfab 100644 --- a/tests/_lib/frender.py +++ b/tests/_lib/frender.py @@ -85,7 +85,7 @@ def render(self): I found a workflow called {self.name}. It has {len(self.steps)} steps! They are: -{"\n".join("* " + indent(step.render(), " ")[3:] for step in self.steps)} +{'\n'.join('* ' + indent(step.render(), ' ')[3:] for step in self.steps)} It probably got made with JUMP={JUMP} """ From 861420b2ae2c3cb8c4e31e5021a9687c6bf82c6d Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 24 Aug 2024 23:13:13 +0100 Subject: [PATCH 072/108] fix: unbreak the doublequotes --- tests/_lib/frender.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/_lib/frender.py b/tests/_lib/frender.py index a7e8cfab..d04b0fe2 100644 --- a/tests/_lib/frender.py +++ b/tests/_lib/frender.py @@ -80,12 +80,13 @@ def from_workflow(cls, workflow: Workflow): return cls(name=name, steps=steps) def render(self): + steps = "\n".join('* ' + indent(step.render(), ' ')[3:] for step in self.steps) return \ f""" I found a workflow called {self.name}. It has {len(self.steps)} steps! 
They are: -{'\n'.join('* ' + indent(step.render(), ' ')[3:] for step in self.steps)} +{steps} It probably got made with JUMP={JUMP} """ From 99f26fdbf184130a5bebf4aa825946a853d0b073 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sat, 24 Aug 2024 23:27:23 +0100 Subject: [PATCH 073/108] wip: try 3.12 in test --- .github/workflows/python-test-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-test-ci.yml b/.github/workflows/python-test-ci.yml index 2cba3af2..3a6764fe 100644 --- a/.github/workflows/python-test-ci.yml +++ b/.github/workflows/python-test-ci.yml @@ -8,7 +8,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.11' + python-version: '3.12' - name: Install dependencies run: | python -m pip install --upgrade pip "hatchling < 1.22" From 3274ea9b253fea0311d5b39a1efbc46e871e4211 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 25 Aug 2024 00:57:06 +0100 Subject: [PATCH 074/108] fix: sort examples - TODO: fix explanatory text --- docs/quickstart.md | 4 +- docs/workflows.md | 233 ++++++++++++------------------------- src/dewret/annotations.py | 13 ++- src/dewret/workflow.py | 14 +-- tests/test_fieldable.py | 4 +- tests/test_subworkflows.py | 70 +++++++++++ 6 files changed, 168 insertions(+), 170 deletions(-) diff --git a/docs/quickstart.md b/docs/quickstart.md index 34a5be6c..c5d04e13 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -97,14 +97,14 @@ and backends, as well as bespoke serialization or formatting. >>> >>> result = increment(num=3) >>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> cwl = render(workflow)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: increment-1-num: default: 3 - label: increment-1-num + label: num type: int outputs: out: diff --git a/docs/workflows.md b/docs/workflows.md index 94b05efb..84d78fe7 100644 --- a/docs/workflows.md +++ b/docs/workflows.md @@ -63,19 +63,15 @@ In code, this would be: ... left=double(num=increment(num=23)), ... right=mod10(num=increment(num=23)) ... ) ->>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(result, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: increment-1-num: default: 23 - label: increment-1-num - type: int - increment-2-num: - default: 23 - label: increment-2-num + label: num type: int outputs: out: @@ -86,7 +82,7 @@ steps: double-1: in: num: - source: increment-2/out + source: increment-1/out out: - out run: double @@ -97,13 +93,6 @@ steps: out: - out run: increment - increment-2: - in: - num: - source: increment-2-num - out: - - out - run: increment mod10-1: in: num: @@ -157,8 +146,8 @@ This duplication can be avoided by explicitly indicating that the parameters are ... left=double(num=increment(num=num)), ... right=mod10(num=increment(num=num)) ... ) ->>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(result, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 @@ -232,8 +221,8 @@ For example: ... 
return (num + INPUT_NUM) % INPUT_NUM >>> >>> result = rotate(num=5) ->>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(result, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 @@ -244,7 +233,7 @@ inputs: type: int rotate-1-num: default: 5 - label: rotate-1-num + label: num type: int outputs: out: @@ -284,22 +273,24 @@ As code: ```python >>> import sys >>> import yaml ->>> from dewret.tasks import task, construct, nested_task +>>> from dewret.core import set_configuration +>>> from dewret.tasks import task, construct, workflow >>> from dewret.renderers.cwl import render >>> INPUT_NUM = 3 >>> @task() ... def rotate(num: int) -> int: -... """Rotate an integer.""" -... return (num + INPUT_NUM) % INPUT_NUM +... """Rotate an integer.""" +... return (num + INPUT_NUM) % INPUT_NUM >>> ->>> @nested_task() +>>> @workflow() ... def double_rotate(num: int) -> int: -... """Rotate an integer twice.""" -... return rotate(num=rotate(num=num)) +... """Rotate an integer twice.""" +... return rotate(num=rotate(num=num)) >>> ->>> result = double_rotate(num=3) ->>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> with set_configuration(flatten_all_nested=True): +... result = double_rotate(num=3) +... wkflw = construct(result, simplify_ids=True) +... cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 @@ -315,7 +306,7 @@ inputs: outputs: out: label: out - outputSource: rotate-2/out + outputSource: rotate-1/out type: int steps: rotate-1: @@ -323,7 +314,7 @@ steps: INPUT_NUM: source: INPUT_NUM num: - source: num + source: rotate-2/out out: - out run: rotate @@ -332,7 +323,7 @@ steps: INPUT_NUM: source: INPUT_NUM num: - source: rotate-1/out + source: num out: - out run: rotate @@ -349,7 +340,7 @@ For example, the following code renders the same workflow as in the previous exa ```python -@nested_task() +@workflow() def double_rotate(num: int) -> int: """Rotate an integer twice.""" unused_var = increment(num=num) @@ -409,19 +400,15 @@ As code: ... left=shuffle(max_cards_per_suit=13).hearts, ... right=shuffle(max_cards_per_suit=13).diamonds ... ) ->>> workflow = construct(red_total, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(red_total, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: shuffle-1-max_cards_per_suit: default: 13 - label: shuffle-1-max_cards_per_suit - type: int - shuffle-2-max_cards_per_suit: - default: 13 - label: shuffle-2-max_cards_per_suit + label: max_cards_per_suit type: int outputs: out: @@ -447,28 +434,10 @@ steps: label: spades type: int run: shuffle - shuffle-2: - in: - max_cards_per_suit: - source: shuffle-2-max_cards_per_suit - out: - clubs: - label: clubs - type: int - diamonds: - label: diamonds - type: int - hearts: - label: hearts - type: int - spades: - label: spades - type: int - run: shuffle sum-1: in: left: - source: shuffle-2/hearts + source: shuffle-1/hearts right: source: shuffle-1/diamonds out: @@ -510,19 +479,15 @@ Here, we show the same example with `dataclasses`. ... left=shuffle(max_cards_per_suit=13).hearts, ... right=shuffle(max_cards_per_suit=13).diamonds ... 
) ->>> workflow = construct(red_total, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(red_total, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: shuffle-1-max_cards_per_suit: default: 13 - label: shuffle-1-max_cards_per_suit - type: int - shuffle-2-max_cards_per_suit: - default: 13 - label: shuffle-2-max_cards_per_suit + label: max_cards_per_suit type: int outputs: out: @@ -548,28 +513,10 @@ steps: label: spades type: int run: shuffle - shuffle-2: - in: - max_cards_per_suit: - source: shuffle-2-max_cards_per_suit - out: - clubs: - label: clubs - type: int - diamonds: - label: diamonds - type: int - hearts: - label: hearts - type: int - spades: - label: spades - type: int - run: shuffle sum-1: in: left: - source: shuffle-2/hearts + source: shuffle-1/hearts right: source: shuffle-1/diamonds out: @@ -589,7 +536,7 @@ dewret will produce multiple output workflows that reference each other. >>> import yaml >>> from attrs import define >>> from numpy import random ->>> from dewret.tasks import task, construct, subworkflow +>>> from dewret.tasks import task, construct, workflow >>> from dewret.renderers.cwl import render >>> @define ... class PackResult: @@ -611,21 +558,21 @@ dewret will produce multiple output workflows that reference each other. ... spades=random.randint(max_cards_per_suit), ... diamonds=random.randint(max_cards_per_suit) ... ) ->>> @subworkflow() -... def red_total(): +>>> @workflow() +... def red_total() -> int: ... return sum( ... left=shuffle(max_cards_per_suit=13).hearts, ... right=shuffle(max_cards_per_suit=13).diamonds ... ) ->>> @subworkflow() -... def black_total(): +>>> @workflow() +... def black_total() -> int: ... return sum( ... left=shuffle(max_cards_per_suit=13).spades, ... right=shuffle(max_cards_per_suit=13).clubs ... ) >>> total = sum(left=red_total(), right=black_total()) ->>> workflow = construct(total, simplify_ids=True) ->>> cwl, subworkflows = render(workflow) +>>> wkflw = construct(total, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 @@ -667,7 +614,7 @@ as a second term. >>> import yaml >>> from attrs import define >>> from numpy import random ->>> from dewret.tasks import task, construct, subworkflow +>>> from dewret.tasks import task, construct, workflow >>> from dewret.renderers.cwl import render >>> @define ... class PackResult: @@ -689,33 +636,25 @@ as a second term. ... def sum(left: int, right: int) -> int: ... return left + right >>> ->>> @subworkflow() -... def red_total(): +>>> @workflow() +... def red_total() -> int: ... return sum( ... left=shuffle(max_cards_per_suit=13).hearts, ... right=shuffle(max_cards_per_suit=13).diamonds ... ) ->>> @subworkflow() -... def black_total(): +>>> @workflow() +... def black_total() -> int: ... return sum( ... left=shuffle(max_cards_per_suit=13).spades, ... right=shuffle(max_cards_per_suit=13).clubs ... 
) >>> total = sum(left=red_total(), right=black_total()) ->>> workflow = construct(total, simplify_ids=True) ->>> cwl, subworkflows = render(workflow) ->>> yaml.dump(subworkflows["red_total-1"], sys.stdout, indent=2) +>>> wkflw = construct(total, simplify_ids=True) +>>> cwl = render(wkflw) +>>> yaml.dump(cwl["red_total-1"], sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 -inputs: - shuffle-1-1-max_cards_per_suit: - default: 13 - label: shuffle-1-1-max_cards_per_suit - type: int - shuffle-1-2-max_cards_per_suit: - default: 13 - label: shuffle-1-2-max_cards_per_suit - type: int +inputs: {} outputs: out: label: out @@ -725,25 +664,7 @@ steps: shuffle-1-1: in: max_cards_per_suit: - source: shuffle-1-1-max_cards_per_suit - out: - clubs: - label: clubs - type: int - diamonds: - label: diamonds - type: int - hearts: - label: hearts - type: int - spades: - label: spades - type: int - run: shuffle - shuffle-1-2: - in: - max_cards_per_suit: - source: shuffle-1-2-max_cards_per_suit + default: 13 out: clubs: label: clubs @@ -761,7 +682,7 @@ steps: sum-1-1: in: left: - source: shuffle-1-2/hearts + source: shuffle-1-1/hearts right: source: shuffle-1-1/diamonds out: @@ -783,7 +704,7 @@ Below is the default output, treating `Pack` as a task. ```python >>> import sys >>> import yaml ->>> from dewret.tasks import subworkflow, factory, nested_task, construct, task +>>> from dewret.tasks import workflow, factory, workflow, construct, task >>> from attrs import define >>> from dewret.renderers.cwl import render >>> @define @@ -799,39 +720,39 @@ Below is the default output, treating `Pack` as a task. ... def sum(left: int, right: int) -> int: ... return left + right >>> ->>> @nested_task() -... def black_total(pack: PackResult): +>>> @workflow() +... def black_total(pack: PackResult) -> int: ... return sum( ... left=pack.spades, ... right=pack.clubs ... ) >>> pack = Pack(hearts=13, spades=13, diamonds=13, clubs=13) ->>> workflow = construct(black_total(pack=pack), simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(black_total(pack=pack), simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: PackResult-1-clubs: default: 13 - label: PackResult-1-clubs + label: clubs type: int PackResult-1-diamonds: default: 13 - label: PackResult-1-diamonds + label: diamonds type: int PackResult-1-hearts: default: 13 - label: PackResult-1-hearts + label: hearts type: int PackResult-1-spades: default: 13 - label: PackResult-1-spades + label: spades type: int outputs: out: label: out - outputSource: sum-1/out + outputSource: black_total-1/out type: int steps: PackResult-1: @@ -858,15 +779,13 @@ steps: label: spades type: int run: PackResult - sum-1: + black_total-1: in: - left: - source: PackResult-1/spades - right: - source: PackResult-1/clubs + pack: + source: PackResult-1/out out: - out - run: sum + run: black_total ``` @@ -876,7 +795,7 @@ types are allowed. ```python >>> import sys >>> import yaml ->>> from dewret.tasks import task, factory, nested_task, construct +>>> from dewret.tasks import task, factory, workflow, construct >>> from attrs import define >>> from dewret.renderers.cwl import render >>> @define @@ -891,36 +810,36 @@ types are allowed. ... def sum(left: int, right: int) -> int: ... return left + right >>> ->>> @nested_task() -... def black_total(pack: PackResult): +>>> @workflow() +... def black_total(pack: PackResult) -> int: ... return sum( ... left=pack.spades, ... right=pack.clubs ... 
) >>> pack = Pack(hearts=13, spades=13, diamonds=13, clubs=13) ->>> workflow = construct(black_total(pack=pack), simplify_ids=True) ->>> cwl = render(workflow, allow_complex_types=True, factories_as_params=True) +>>> wkflw = construct(black_total(pack=pack), simplify_ids=True) +>>> cwl = render(wkflw, allow_complex_types=True, factories_as_params=True)["black_total-1"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: - PackResult-1: - label: PackResult-1 + pack: + label: pack type: record outputs: out: label: out - outputSource: sum-1/out + outputSource: sum-1-1/out type: int steps: - sum-1: + sum-1-1: in: left: - source: PackResult-1/spades + source: pack/spades right: - source: PackResult-1/clubs + source: pack/clubs out: - out run: sum -``` \ No newline at end of file +``` diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index 74db06f6..44897289 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -56,8 +56,17 @@ def __init__(self, fn: Callable[..., Any]): @property def return_type(self) -> type: - """Return type of the callable.""" - return get_type_hints(inspect.unwrap(self.fn), include_extras=True)["return"] + """Return type of the callable. + + Returns: expected type of the return value. + + Raises: + ValueError: if the return value does not appear to be type-hinted. + """ + hints = get_type_hints(inspect.unwrap(self.fn), include_extras=True) + if "return" not in hints: + raise ValueError(f"Could not find type-hint for return value of {self.fn}") + return hints["return"] @staticmethod def _typ_has(typ: type, annotation: type) -> bool: diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index c59e3bfb..30268d91 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -465,7 +465,7 @@ def find_factories(self) -> dict[str, FactoryCall]: def find_parameters( self, include_factory_calls: bool = True - ) -> set[ParameterReference]: + ) -> set[Parameter]: """Crawl steps for parameter references. 
As the workflow does not hold its own list of parameters, this @@ -477,7 +477,7 @@ def find_parameters( _, references = expr_to_references( step.arguments for step in self.steps if (include_factory_calls or not isinstance(step, FactoryCall)) ) - return {ref for ref in references if isinstance(ref, ParameterReference)} + return {ref._.parameter for ref in references if isinstance(ref, ParameterReference)} @property def indexed_steps(self) -> dict[str, BaseStep]: @@ -545,7 +545,7 @@ def assimilate(cls, *workflow_args: Workflow) -> "Workflow": for step in base.steps: step.set_workflow(base, with_arguments=True) - results = sorted(set((w.result for w in workflows if w.result))) + results = sorted(set((w.result for w in workflows if w.has_result))) if len(results) == 1: result = results[0] else: @@ -587,9 +587,9 @@ def simplify_ids(self, infix: list[str] | None = None) -> None: param_counter = Counter[str]() name_to_original: dict[str, str] = {} for name, param in { - pr._.parameter.__name__: pr._.parameter - for pr in self.find_parameters() - if isinstance(pr, ParameterReference) + param.__name__: param + for param in self.find_parameters() + if isinstance(param, Parameter) }.items(): if param.__original_name__ != name: param_counter[param.__original_name__] += 1 @@ -1000,7 +1000,7 @@ def __init__( def _to_param_ref(value): if isinstance(value, Parameter): - return value.make_parameter(workflow=workflow) + return value.make_reference(workflow=workflow) expression, refs = expr_to_references(value, remap=_to_param_ref) for ref in refs: diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index 18495029..5ef7c93d 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -76,9 +76,9 @@ def test_can_get_field_reference_from_parameter(): my_param = param("my_param", typ=MyDataclass) result = sum(left=my_param.left, right=sum(left=my_param.right.left, right=my_param)) wkflw = construct(result, simplify_ids=True) - param_references = {(str(p), p.__type__) for p in wkflw.find_parameters()} + params = {(str(p), p.__type__) for p in wkflw.find_parameters()} - assert param_references == {("my_param/left", int), ("my_param", MyDataclass), ("my_param/right/left", int)} + assert params == {("my_param", MyDataclass)} rendered = render(wkflw, allow_complex_types=True)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index 5cd86ff4..7339ce25 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -6,6 +6,7 @@ from dewret.tasks import construct, workflow, task, factory, set_configuration from dewret.renderers.cwl import render from dewret.workflow import param +from attrs import define from ._lib.extra import increment, sum, pi @@ -542,3 +543,72 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None: - out run: to_int """)) + +@define +class PackResult: + hearts: int + clubs: int + spades: int + diamonds: int + +def test_combining_attrs_and_factories(): + Pack = factory(PackResult) + + @task() + def sum(left: int, right: int) -> int: + return left + right + + @workflow() + def black_total(pack: PackResult) -> int: + return sum( + left=pack.spades, + right=pack.clubs + ) + pack = Pack(hearts=13, spades=13, diamonds=13, clubs=13) + wkflw = construct(black_total(pack=pack), simplify_ids=True) + cwl = render(wkflw, allow_complex_types=True, factories_as_params=True) + assert cwl["__root__"] == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + PackResult-1: + 
label: PackResult-1 + type: record + outputs: + out: + label: out + outputSource: black_total-1/out + type: int + steps: + black_total-1: + in: + pack: + source: PackResult-1/out + out: + - out + run: black_total + """) + + assert cwl["black_total-1"] == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + pack: + label: pack + type: record + outputs: + out: + label: out + outputSource: sum-1-1/out + type: int + steps: + sum-1-1: + in: + left: + source: pack/spades + right: + source: pack/clubs + out: + - out + run: sum + """) From 1dc5692c7c59104de47400bc6ba304b8528bb81a Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 25 Aug 2024 01:06:16 +0100 Subject: [PATCH 075/108] chore: fix test linting - TODO: add docstrings --- tests/_lib/extra.py | 3 --- tests/test_annotations.py | 17 ++++++++++----- tests/test_configuration.py | 12 +++++++---- tests/test_cwl.py | 2 +- tests/test_errors.py | 16 ++++---------- tests/test_fieldable.py | 43 +++++++++++++++++++++++++------------ tests/test_nested.py | 4 +++- tests/test_render_module.py | 7 ++++-- tests/test_subworkflows.py | 2 ++ 9 files changed, 64 insertions(+), 42 deletions(-) diff --git a/tests/_lib/extra.py b/tests/_lib/extra.py index 921a93c0..3293044d 100644 --- a/tests/_lib/extra.py +++ b/tests/_lib/extra.py @@ -1,17 +1,14 @@ from dewret.tasks import task, workflow -from dewret.annotations import AtRender from .other import nothing JUMP: float = 1.0 test: float = nothing -from inspect import get_annotations @workflow() def try_nothing() -> int: """Check that we can see AtRender in another module.""" - if nothing: return increment(num=1) return increment(num=0) diff --git a/tests/test_annotations.py b/tests/test_annotations.py index a4cbb440..9a709003 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -1,8 +1,9 @@ +"""Verify we can interrogate annotations.""" + import pytest import yaml -from typing import Literal -from dewret.tasks import task, construct, workflow, TaskException +from dewret.tasks import construct, workflow, TaskException from dewret.renderers.cwl import render from dewret.annotations import AtRender, FunctionAnalyser, Fixed from dewret.core import set_configuration @@ -13,12 +14,15 @@ ARG2: bool = False class MyClass: + """TODO: Docstring.""" def method(self, arg1: bool, arg2: AtRender[int]) -> float: + """TODO: Docstring.""" arg3: float = 7.0 arg4: AtRender[float] = 8.0 return arg1 + arg2 + arg3 + arg4 + int(ARG1) + int(ARG2) def fn(arg5: int, arg6: AtRender[int]) -> float: + """TODO: Docstring.""" arg7: float = 7.0 arg8: AtRender[float] = 8.0 return arg5 + arg6 + arg7 + arg8 + int(ARG1) + int(ARG2) @@ -35,6 +39,7 @@ def to_int(num: int, should_double: AtRender[bool]) -> int | float: return increment(num=num) if should_double else sum(left=num, right=num) def test_can_analyze_annotations(): + """TODO: Docstring.""" my_obj = MyClass() analyser = FunctionAnalyser(my_obj.method) @@ -56,6 +61,7 @@ def test_can_analyze_annotations(): assert analyser.argument_has("ARG1", AtRender) is False def test_at_render() -> None: + """TODO: Docstring.""" with pytest.raises(TaskException) as _: result = to_int_bad(num=increment(num=3), should_double=True) wkflw = construct(result, simplify_ids=True) @@ -138,19 +144,20 @@ def test_at_render() -> None: def test_at_render_between_modules() -> None: - nothing = False + """TODO: Docstring.""" result = try_nothing() wkflw = construct(result, simplify_ids=True) subworkflows = render(wkflw, allow_complex_types=True) - rendered = subworkflows["__root__"] + 
subworkflows["__root__"] list_2: Fixed[list[int]] = [0, 1, 2, 3] def test_can_loop_over_fixed_length() -> None: + """TODO: Docstring.""" @workflow() def loop_over_lists(list_1: list[int]) -> list[int]: result = [] - for a, b in zip(list_1, list_2): + for a, b in zip(list_1, list_2, strict=False): result.append(a + b + len(list_2)) return result diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 01b91206..b94e57e4 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -1,14 +1,16 @@ +"""Check configuration is consistent and usable.""" + import yaml import pytest -from dewret.tasks import construct, task, factory, workflow, TaskException +from dewret.tasks import construct, workflow, TaskException from dewret.renderers.cwl import render -from dewret.utils import hasher from dewret.tasks import set_configuration from dewret.annotations import AtRender -from ._lib.extra import increment, double, mod10, sum, triple_and_one +from ._lib.extra import increment @pytest.fixture def configuration(): + """TODO: Docstring.""" with set_configuration() as configuration: yield configuration.get() @@ -21,6 +23,7 @@ def floor(num: int, expected: AtRender[bool]) -> int: return increment(num=num) def test_cwl_with_parameter(configuration) -> None: + """TODO: Docstring.""" with set_configuration(flatten_all_nested=True): result = increment(num=floor(num=3, expected=True)) workflow = construct(result, simplify_ids=True) @@ -35,8 +38,9 @@ def test_cwl_with_parameter(configuration) -> None: workflow = construct(result, simplify_ids=True) rendered = render(workflow)["__root__"] num_param = list(workflow.find_parameters())[0] + assert num_param - assert rendered == yaml.safe_load(f""" + assert rendered == yaml.safe_load(""" cwlVersion: 1.2 class: Workflow inputs: diff --git a/tests/test_cwl.py b/tests/test_cwl.py index 52f7793a..4188424f 100644 --- a/tests/test_cwl.py +++ b/tests/test_cwl.py @@ -177,7 +177,7 @@ def test_cwl_with_positional_parameter() -> None: Produces CWL for a call with a changeable raw value, that is converted to a parameter, if and only if we are calling from outside a subworkflow. """ - with pytest.raises(TaskException) as exc: + with pytest.raises(TaskException) as _: result = increment(3) with set_configuration(allow_positional_args=True): result = increment(3) diff --git a/tests/test_errors.py b/tests/test_errors.py index 96607bce..3b21c332 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -4,7 +4,6 @@ from dewret.workflow import Task, Lazy from dewret.tasks import construct, task, workflow, TaskException from dewret.annotations import AtRender -from dewret.renderers.cwl import render from ._lib.extra import increment, pi, reverse_list # noqa: F401 @@ -14,7 +13,7 @@ def add_task(left: int, right: int) -> int: return left + right -ADD_TASK_LINE_NO: int = 11 +ADD_TASK_LINE_NO: int = 10 @workflow() @@ -37,14 +36,6 @@ def __init__(self, task: Task): ... 
-@task() -def pi() -> float: - """Get pi from math package.""" - import math - - return math.pi - - @task() def pi_exported_from_math() -> float: """Get pi from math package by name.""" @@ -242,6 +233,7 @@ def test_subworkflows_must_return_a_task() -> None: good_num: int = 4 def test_must_annotate_global() -> None: + """TODO: Docstrings.""" worse_num = 3 @workflow() @@ -249,7 +241,7 @@ def check_annotation() -> int | float: return increment(num=bad_num) with pytest.raises(TaskException) as exc: - result = check_annotation() + check_annotation() assert ( str(exc.value) @@ -261,7 +253,7 @@ def check_annotation_2() -> int | float: return increment(num=worse_num) with pytest.raises(TaskException) as exc: - result = check_annotation_2() + check_annotation_2() assert ( str(exc.value) diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index 5ef7c93d..e669f63f 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -1,8 +1,9 @@ +"""Check field management works.""" + from __future__ import annotations import yaml from dataclasses import dataclass -import pytest from typing import Unpack, TypedDict from dewret.tasks import task, construct, workflow, set_configuration @@ -10,10 +11,11 @@ from dewret.renderers.cwl import render from dewret.annotations import Fixed -from ._lib.extra import double, mod10, sum, pi +from ._lib.extra import mod10, sum, pi @dataclass class Sides: + """TODO: Docstring.""" left: int right: int @@ -21,9 +23,11 @@ class Sides: @workflow() def sum_sides() -> float: + """TODO: Docstring.""" return sum(left=SIDES.left, right=SIDES.right) def test_fields_of_parameters_usable() -> None: + """TODO: Docstring.""" result = sum_sides() wkflw = construct(result, simplify_ids=True) rendered = render(wkflw, allow_complex_types=True)["sum_sides-1"] @@ -69,10 +73,12 @@ def test_fields_of_parameters_usable() -> None: @dataclass class MyDataclass: + """TODO: Docstring.""" left: int right: "MyDataclass" def test_can_get_field_reference_from_parameter(): + """TODO: Docstring.""" my_param = param("my_param", typ=MyDataclass) result = sum(left=my_param.left, right=sum(left=my_param.right.left, right=my_param)) wkflw = construct(result, simplify_ids=True) @@ -116,6 +122,7 @@ def test_can_get_field_reference_from_parameter(): """) def test_can_get_field_reference_iff_parent_type_has_field(): + """TODO: Docstring.""" @dataclass class MyDataclass: left: int @@ -128,6 +135,7 @@ class MyDataclass: assert param_reference.left.__type__ == int def test_can_get_field_references_from_dataclass(): + """TODO: Docstring.""" @dataclass class MyDataclass: left: int @@ -149,10 +157,12 @@ def get_left(my_dataclass: MyDataclass) -> int: assert wkflw.result.__type__ == int class MyDict(TypedDict): + """TODO: Docstring.""" left: int right: float def test_can_get_field_references_from_typed_dict(): + """TODO: Docstring.""" @workflow() def test_dict(**my_dict: Unpack[MyDict]) -> MyDict: result: MyDict = {"left": mod10(num=my_dict["left"]), "right": pi()} @@ -166,9 +176,11 @@ def test_dict(**my_dict: Unpack[MyDict]) -> MyDict: @dataclass class MyListWrapper: + """TODO: Docstring.""" my_list: list[int] def test_can_iterate(): + """TODO: Docstring.""" @task() def test_task(alpha: int, beta: float, charlie: bool) -> int: return int(alpha + beta) @@ -289,6 +301,7 @@ def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: """) def test_can_use_plain_dict_fields(): + """TODO: Docstring.""" @workflow() def test_dict(left: int, right: float) -> dict[str, float | int]: result: dict[str, float | 
int] = {"left": mod10(num=left), "right": pi()} @@ -302,9 +315,11 @@ def test_dict(left: int, right: float) -> dict[str, float | int]: @dataclass class IndexTest: + """TODO: Docstring.""" left: Fixed[list[int]] def test_can_configure_field_separator(): + """TODO: Docstring.""" @task() def test_sep() -> IndexTest: return IndexTest(left=[3]) @@ -312,17 +327,17 @@ def test_sep() -> IndexTest: with set_configuration(field_index_types="int"): result = test_sep().left[0] wkflw = construct(result, simplify_ids=True) - rendered = render(wkflw, allow_complex_types=True)["__root__"] + render(wkflw, allow_complex_types=True)["__root__"] assert str(wkflw.result) == "test_sep-1/left[0]" - #with set_configuration(field_index_types="int,str"): - # result = test_sep().left[0] - # wkflw = construct(result, simplify_ids=True) - # rendered = render(wkflw, allow_complex_types=True)["__root__"] - # assert str(wkflw.result) == "test_sep-1[left][0]" - - #with set_configuration(field_index_types=""): - # result = test_sep().left[0] - # wkflw = construct(result, simplify_ids=True) - # rendered = render(wkflw, allow_complex_types=True)["__root__"] - # assert str(wkflw.result) == "test_sep-1/left/0" + with set_configuration(field_index_types="int,str"): + result = test_sep().left[0] + wkflw = construct(result, simplify_ids=True) + render(wkflw, allow_complex_types=True)["__root__"] + assert str(wkflw.result) == "test_sep-1[left][0]" + + with set_configuration(field_index_types=""): + result = test_sep().left[0] + wkflw = construct(result, simplify_ids=True) + render(wkflw, allow_complex_types=True)["__root__"] + assert str(wkflw.result) == "test_sep-1/left/0" diff --git a/tests/test_nested.py b/tests/test_nested.py index c05f85c2..3d901d8c 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -1,5 +1,6 @@ +"""Check complex nested structures and expressions can mix.""" + import yaml -import pytest import math from dewret.workflow import param from dewret.tasks import construct @@ -8,6 +9,7 @@ from ._lib.extra import reverse_list, max_list def test_can_supply_nested_raw(): + """TODO: Docstrings.""" pi = param("pi", math.pi) result = reverse_list(to_sort=[1., 3., pi]) workflow = construct(max_list(lst=result + result), simplify_ids=True) diff --git a/tests/test_render_module.py b/tests/test_render_module.py index 1cc36225..c8760aab 100644 --- a/tests/test_render_module.py +++ b/tests/test_render_module.py @@ -1,10 +1,13 @@ +"""Check renderers can be imported live.""" + from pathlib import Path -from dewret.tasks import construct, task, factory +from dewret.tasks import construct from dewret.render import get_render_method -from ._lib.extra import increment, double, mod10, sum, triple_and_one +from ._lib.extra import increment, triple_and_one def test_can_load_render_module(): + """TODO: Docstrings.""" result = triple_and_one(num=increment(num=3)) workflow = construct(result, simplify_ids=True) workflow._name = "Fred" diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index 7339ce25..0fdd43cd 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -546,12 +546,14 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None: @define class PackResult: + """TODO: Docstrings.""" hearts: int clubs: int spades: int diamonds: int def test_combining_attrs_and_factories(): + """TODO: Docstrings.""" Pack = factory(PackResult) @task() From 3871ef4efa6e13aac87798ea8767bb1bdde88baa Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 25 Aug 2024 01:31:11 +0100 Subject: [PATCH 
076/108] fix: make result ordering consistent

---
 src/dewret/workflow.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py
index 30268d91..38b0a6b1 100644
--- a/src/dewret/workflow.py
+++ b/src/dewret/workflow.py
@@ -23,11 +23,11 @@
 from attrs import has as attr_has, resolve_types, fields as attrs_fields
 from dataclasses import is_dataclass, fields as dataclass_fields
 from collections import Counter, OrderedDict
-from typing import Protocol, Any, TypeVar, Generic, cast, Literal, Iterable, get_origin, get_args, Generator, Sized, Sequence, get_type_hints, TYPE_CHECKING
+from typing import Protocol, Any, TypeVar, Generic, cast, Literal, Iterable, get_origin, get_args, Generator, Sized, Sequence, get_type_hints, TYPE_CHECKING, Hashable
 from uuid import uuid4
 import logging

-from sympy import Symbol, Expr, Basic
+from sympy import Symbol, Expr, Basic, Tuple

 logger = logging.getLogger(__name__)

@@ -545,12 +545,24 @@ def assimilate(cls, *workflow_args: Workflow) -> "Workflow":
 for step in base.steps:
 step.set_workflow(base, with_arguments=True)

- results = sorted(set((w.result for w in workflows if w.has_result)))
+ hashable_workflows: list[Workflow] = [w for w in workflows if isinstance(w.result, Hashable)]
+ if len(hashable_workflows) != len(workflows):
+ raise NotImplementedError("Some results are not hashable.")
+
+ def _get_order(result: None | StepReference | Iterable[StepReference]) -> str:
+ if result is None:
+ return ""
+ if isinstance(result, StepReference):
+ return result.id
+ return "|".join(r.id for r in result)
+
+
+ results = sorted(set({w.result for w in hashable_workflows if w.has_result}), key=lambda r: _get_order(r))
 if len(results) == 1:
 result = results[0]
 else:
- results = sorted({r if isinstance(r, tuple | list) else (r,) for r in results})
- result = sum(map(list, results), [])
+ list_results = [r if isinstance(r, tuple | list | Tuple) else (r,) for r in results]
+ result = sum(map(list, list_results), [])
 if result is not None and result != []:
 unify_workflows(result, base, set_only=True)


From fc88fb4df0e45669f7b40bd432cbb0fbdb7d14c8 Mon Sep 17 00:00:00 2001
From: Phil Weir
Date: Sun, 25 Aug 2024 01:42:30 +0100
Subject: [PATCH 077/108] fix: ensure that fieldability doesn't mean dask
 matches references as a Delayed

---
 src/dewret/backends/backend_dask.py | 9 +++++++--
 src/dewret/workflow.py | 2 ++
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py
index 50806e54..b85ae776 100644
--- a/src/dewret/backends/backend_dask.py
+++ b/src/dewret/backends/backend_dask.py
@@ -17,7 +17,7 @@
 Lazy-evaluation via `dask.delayed`.
 """

-from dask.delayed import delayed, DelayedLeaf
+from dask.delayed import delayed, DelayedLeaf, Graph
 from dask.config import config
 from typing import Protocol, runtime_checkable, Any, cast
 from concurrent.futures import ThreadPoolExecutor
@@ -35,6 +35,11 @@ class Delayed(Protocol):
 More info: https://github.com/dask/dask/issues/7779
 """

+ @property
+ def __dask_graph__(self) -> Graph:
+ """Retrieve the dask graph."""
+ ...
+
 def compute(self, __workflow__: Workflow | None) -> StepReference[Any]:
 """Evaluate this `dask.delayed`.
@@ -120,7 +125,7 @@ def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread
 # f"{task} is not a dask delayed, perhaps you tried to mix backends?"
# ) - if isinstance(task, Delayed): + if isinstance(task, Delayed) and is_lazy(task): computable = task else: computable = delayed(task) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 38b0a6b1..94ce4d74 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -908,6 +908,8 @@ def find_field(self: FieldableProtocol, field: str | int, fallback_type: type | f"Tried to index int {field} into type {parent_type} but can only do so if the first type argument " f"is the element type (args: {get_args(parent_type)}" ) + elif field.startswith("__"): + raise AttributeError(f"We do not allow fields with dunder prefix, such as {field}, to reduce risk of clashes.") else: if is_dataclass(parent_type): try: From 780a7d6d037e0b2f8600f5f694af9a8506136379 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 25 Aug 2024 01:43:03 +0100 Subject: [PATCH 078/108] fix: restore 3.11 in test --- .github/workflows/python-test-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-test-ci.yml b/.github/workflows/python-test-ci.yml index 3a6764fe..2cba3af2 100644 --- a/.github/workflows/python-test-ci.yml +++ b/.github/workflows/python-test-ci.yml @@ -8,7 +8,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.12' + python-version: '3.11' - name: Install dependencies run: | python -m pip install --upgrade pip "hatchling < 1.22" From 3e589ee78f14b428fec317a23ef59bc4e6916b9f Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 25 Aug 2024 01:44:35 +0100 Subject: [PATCH 079/108] ci: add 3.12 to the build matrix --- .github/workflows/python-test-ci.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-test-ci.yml b/.github/workflows/python-test-ci.yml index 2cba3af2..859ea89d 100644 --- a/.github/workflows/python-test-ci.yml +++ b/.github/workflows/python-test-ci.yml @@ -3,12 +3,16 @@ on: [push, pull_request] jobs: unit-pip: runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11", "3.12"] steps: - uses: actions/checkout@v4 - - name: Set up Python + + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: - python-version: '3.11' + python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip "hatchling < 1.22" From 0c45e51d69d4129e24a3a981fdfe9fa7f48232f3 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 25 Aug 2024 01:48:20 +0100 Subject: [PATCH 080/108] fix: restore commented test --- tests/test_cwl.py | 64 +++++++++++++++++++++++------------------------ 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/tests/test_cwl.py b/tests/test_cwl.py index 4188424f..17bc30ab 100644 --- a/tests/test_cwl.py +++ b/tests/test_cwl.py @@ -209,38 +209,38 @@ def test_cwl_with_positional_parameter() -> None: """) -#def test_cwl_without_default() -> None: -# """Check whether we can produce CWL without a default value. -# -# Uses a manually created parameter to avoid a default. 
-# """ -# my_param = param("my_param", typ=int) -# -# result = increment(num=my_param) -# workflow = construct(result) -# rendered = render(workflow)["__root__"] -# hsh = hasher(("increment", ("num", "int|:param:my_param"))) -# -# assert rendered == yaml.safe_load(f""" -# cwlVersion: 1.2 -# class: Workflow -# inputs: -# my_param: -# label: my_param -# type: int -# outputs: -# out: -# label: out -# outputSource: increment-{hsh}/out -# type: int -# steps: -# increment-{hsh}: -# run: increment -# in: -# num: -# source: my_param -# out: [out] -# """) +def test_cwl_without_default() -> None: + """Check whether we can produce CWL without a default value. + + Uses a manually created parameter to avoid a default. + """ + my_param = param("my_param", typ=int) + + result = increment(num=my_param) + workflow = construct(result) + rendered = render(workflow)["__root__"] + hsh = hasher(("increment", ("num", "int|:param:my_param"))) + + assert rendered == yaml.safe_load(f""" + cwlVersion: 1.2 + class: Workflow + inputs: + my_param: + label: my_param + type: int + outputs: + out: + label: out + outputSource: increment-{hsh}/out + type: int + steps: + increment-{hsh}: + run: increment + in: + num: + source: my_param + out: [out] + """) def test_cwl_with_subworkflow() -> None: From 269e591677eafe53e193af526bcc4b3f150d7641 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 25 Aug 2024 01:50:19 +0100 Subject: [PATCH 081/108] fix: mypy issue in tests on 3.11 --- tests/test_configuration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index b94e57e4..0720e751 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -19,7 +19,7 @@ def floor(num: int, expected: AtRender[bool]) -> int: """Converts int/float to int.""" from dewret.tasks import get_configuration if get_configuration("flatten_all_nested") != expected: - raise AssertionError(f"Not expected configuration: {get_configuration('flatten_all_nested')} != {expected}") + raise AssertionError(f"Not expected configuration: {str(get_configuration('flatten_all_nested'))} != {expected}") return increment(num=num) def test_cwl_with_parameter(configuration) -> None: From 4dfa8bdc1c152d0c8c11e779b9750dc63a63f9cd Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 25 Aug 2024 03:14:06 +0100 Subject: [PATCH 082/108] chore: fix tests and types for mypy and ruff on 3.11 and 3.12 --- mypy.ini | 7 +++ pyproject.toml | 6 -- src/dewret/__main__.py | 10 +-- src/dewret/annotations.py | 13 ++-- src/dewret/backends/backend_dask.py | 5 +- src/dewret/core.py | 31 +++++----- src/dewret/render.py | 2 +- src/dewret/renderers/cwl.py | 41 +++++++----- src/dewret/renderers/snakemake.py | 13 ++-- src/dewret/tasks.py | 23 ++++--- src/dewret/workflow.py | 96 +++++++++++++++-------------- tests/_lib/frender.py | 17 ++--- tests/test_annotations.py | 2 +- tests/test_configuration.py | 12 +--- tests/test_fieldable.py | 46 ++++++++------ tests/test_modularity.py | 3 +- tests/test_multiresult_steps.py | 3 +- tests/test_nested.py | 2 +- tests/test_render_module.py | 2 +- tests/test_subworkflows.py | 7 ++- 20 files changed, 188 insertions(+), 153 deletions(-) create mode 100644 mypy.ini diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..66dca830 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,7 @@ +[mypy] +strict = True +# Required to allow subclassing of sympy.Symbol +disallow_subclassing_any = False + +[mypy-sympy.*] +ignore_missing_imports = true diff --git a/pyproject.toml 
b/pyproject.toml index b8992c2c..2c44f274 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,12 +27,6 @@ select = ["D", "F", "B"] [tool.ruff.lint.pydocstyle] convention = "google" -[[tool.mypy.overrides]] -module = [ - "sympy", -] -ignore_missing_imports = true - [project] name = "dewret" description = "DEclarative Workflow REndering Tool" diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 1816c30b..ced66e51 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -26,14 +26,14 @@ import sys import re import yaml -from typing import Any +from typing import Any, IO, Generator from types import ModuleType import click import json -from .core import set_configuration, set_render_configuration +from .core import set_configuration, set_render_configuration, RawRenderModule, StructuredRenderModule from .utils import load_module_or_package -from .render import get_render_method, RawRenderModule, StructuredRenderModule +from .render import get_render_method from .tasks import Backend, construct @@ -119,14 +119,14 @@ def render( if output == "-": @contextmanager - def _opener(key, _): + def _opener(key: str, _: str) -> Generator[IO[Any], None, None]: print(" ------ ", key, " ------ ") yield sys.stdout print() opener = _opener else: @contextmanager - def _opener(key, mode): + def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]: output_file = output.replace("%", key) with Path(output_file).open(mode) as output_f: yield output_f diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index 44897289..62a9ee36 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -55,7 +55,7 @@ def __init__(self, fn: Callable[..., Any]): ) @property - def return_type(self) -> type: + def return_type(self) -> Any: """Return type of the callable. Returns: expected type of the return value. @@ -64,9 +64,10 @@ def return_type(self) -> type: ValueError: if the return value does not appear to be type-hinted. """ hints = get_type_hints(inspect.unwrap(self.fn), include_extras=True) - if "return" not in hints: + if "return" not in hints or hints["return"] is None: raise ValueError(f"Could not find type-hint for return value of {self.fn}") - return hints["return"] + typ = hints["return"] + return typ @staticmethod def _typ_has(typ: type, annotation: type) -> bool: @@ -87,11 +88,11 @@ def _typ_has(typ: type, annotation: type) -> bool: return True return False - def get_all_module_names(self): + def get_all_module_names(self) -> dict[str, Any]: """Find all of the annotations within this module.""" return get_type_hints(sys.modules[self.fn.__module__], include_extras=True) - def get_all_imported_names(self): + def get_all_imported_names(self) -> dict[str, tuple[ModuleType, str]]: """Find all of the annotations that were imported into this module.""" return self._get_all_imported_names(sys.modules[self.fn.__module__]) @@ -125,7 +126,7 @@ def free_vars(self) -> dict[str, Any]: return dict(zip(self.fn.__code__.co_freevars, (c.cell_contents for c in self.fn.__closure__), strict=False)) return {} - def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> type | None: + def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> Any: """Retrieve the annotations for this argument. Args: diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index b85ae776..cc22ae34 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -17,7 +17,8 @@ Lazy-evaluation via `dask.delayed`. 
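+
+ A sketch of direct use (assuming the ``lazy``/``run`` helpers defined below):
+
+ >>> computable = lazy(lambda num: num + 1)(num=3)
+ >>> run(None, computable)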
""" -from dask.delayed import delayed, DelayedLeaf, Graph +from dask.delayed import delayed, DelayedLeaf +from dask.typing import Graph from dask.config import config from typing import Protocol, runtime_checkable, Any, cast from concurrent.futures import ThreadPoolExecutor @@ -96,7 +97,7 @@ def is_lazy(task: Any) -> bool: lazy = delayed -def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread_pool: ThreadPoolExecutor | None=None, **kwargs: Any) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]: +def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread_pool: ThreadPoolExecutor | None=None, **kwargs: Any) -> Any: """Execute a task as the output of a workflow. Runs a task with dask. diff --git a/src/dewret/core.py b/src/dewret/core.py index 9fb7d630..ee782367 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -33,11 +33,12 @@ BasicType = str | float | bool | bytes | int | None RawType = Union[BasicType, list["RawType"], dict[str, "RawType"]] FirmType = RawType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] +ExprType = FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...] # type: ignore U = TypeVar("U") T = TypeVar("T") -def strip_annotations(parent_type: type) -> tuple[type, tuple]: +def strip_annotations(parent_type: type) -> tuple[type, tuple[str]]: """Discovers and removes annotations from a parent type. Args: @@ -173,7 +174,7 @@ class GlobalConfiguration: Having a single configuration dict allows us to manage only one ContextVar. """ construct: ConstructConfiguration - render: dict + render: dict[str, RawType] CONFIGURATION: ContextVar[GlobalConfiguration] = ContextVar("configuration") @@ -189,7 +190,7 @@ def set_configuration(**kwargs: Unpack[ConstructConfigurationTypedDict]) -> Iter yield CONFIGURATION @contextmanager -def set_render_configuration(kwargs) -> Iterator[ContextVar[GlobalConfiguration]]: +def set_render_configuration(kwargs: dict[str, RawType]) -> Iterator[ContextVar[GlobalConfiguration]]: """Sets the render-time configuration. This is a context manager, so that a setting can be temporarily overridden and automatically restored. @@ -271,7 +272,8 @@ def get_configuration(key: str) -> RawType: try: return getattr(CONFIGURATION.get().construct, key) # type: ignore except LookupError: - return getattr(ConstructConfiguration(), key) + # TODO: Not sure what the best way to typehint this is. + return getattr(ConstructConfiguration(), key) # type: ignore def get_render_configuration(key: str) -> RawType: """Retrieve configuration for the active renderer. @@ -285,7 +287,7 @@ def get_render_configuration(key: str) -> RawType: Returns: (preferably) a JSON/YAML-serializable construct. """ try: - return CONFIGURATION.get().render.get(key) # type: ignore + return CONFIGURATION.get().render.get(key) except LookupError: return default_renderer_config().get(key) @@ -298,7 +300,7 @@ class WorkflowComponent: __workflow_real__: WorkflowProtocol - def __init__(self, *args, workflow: WorkflowProtocol, **kwargs): + def __init__(self, *args: Any, workflow: WorkflowProtocol, **kwargs: Any): """Tie to a `Workflow`. All subclasses must call this. @@ -328,7 +330,7 @@ class Reference(Generic[U], Symbol, WorkflowComponent): _type: type[U] | None = None __iterated__: bool = False - def __init__(self, *args, typ: type[U] | None = None, **kwargs): + def __init__(self, *args: Any, typ: type[U] | None = None, **kwargs: Any): """Extract any specified type. 
Args: @@ -340,15 +342,15 @@ def __init__(self, *args, typ: type[U] | None = None, **kwargs): super().__init__(*args, **kwargs) @property - def name(self): + def name(self) -> str: """Printable name of the reference.""" return self.__name__ - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> "Reference[U]": """As all references are sympy Expressions, the real returned object must be made with Expr.""" instance = Expr.__new__(cls) instance._assumptions0 = {} - return instance + return cast(Reference[U], instance) @property def __root_name__(self) -> str: @@ -369,10 +371,11 @@ def __type__(self) -> type: return self._type raise NotImplementedError() - def _raise_unevaluatable_error(self): + def _raise_unevaluatable_error(self) -> None: + """Convenience method to consistently formulate an UnevaluatableError for this reference.""" raise UnevaluatableError(f"This reference, {self.__name__}, cannot be evaluated during construction.") - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> Any: """Test equality at construct-time, if sensible. Raises: @@ -435,7 +438,7 @@ class IterableMixin(Reference[U]): """Functionality for iterating over references to give new references.""" __fixed_len__: int | None = None - def __init__(self, typ: type[U] | None=None, **kwargs): + def __init__(self, typ: type[U] | None=None, **kwargs: Any): """Extract length, if available from type. Captures types of the form (e.g.) `tuple[int, float]` and records the length @@ -493,7 +496,7 @@ def __inner_iter__(self) -> Generator[Any, None, None]: while True: yield None - def __getitem__(self, attr: str | int) -> Reference[U]: + def __getitem__(self, attr: str | int) -> "Reference[U] | Any": """Get a reference to an individual item/field. Args: diff --git a/src/dewret/render.py b/src/dewret/render.py index 35e23b0d..31b9e3da 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -71,7 +71,7 @@ def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule, raise NotImplementedError("This render module neither seems to be a structured nor a raw render module.") return render_module.render_raw - def _render(workflow: Workflow, render_module: StructuredRenderModule, pretty=False, **kwargs: RenderConfiguration) -> dict[str, str]: + def _render(workflow: Workflow, render_module: StructuredRenderModule, pretty: bool=False, **kwargs: RenderConfiguration) -> dict[str, str]: rendered = render_module.render(workflow, **kwargs) return { key: structured_to_raw(value, pretty=pretty) diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 23b4214e..ee882a68 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -33,17 +33,19 @@ ) from dewret.workflow import ( FactoryCall, - Reference, Workflow, BaseStep, StepReference, ParameterReference, - Unset, expr_to_references ) -from dewret.utils import crawl_raw, DataclassProtocol, firm_to_raw, flatten_if_set +from dewret.utils import crawl_raw, DataclassProtocol, firm_to_raw, flatten_if_set, Unset from dewret.render import base_render -from dewret.core import get_render_configuration, set_render_configuration +from dewret.core import ( + Reference, + get_render_configuration, + set_render_configuration +) class CommandInputSchema(TypedDict): """Structure for referring to a raw type in CWL. @@ -77,7 +79,7 @@ def render_expression(ref: Any) -> "ReferenceDefinition": Returns: a ReferenceDefinition containing a string representation of the expression in the form `$(...)`. 
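For instance, the test suite shows a mixed parameter/step expression being
rendered along the lines of:

    $(inputs.param[0][0] + inputs.param[0][1] + self)

where parameter references are addressed via `inputs` and the single
permitted step reference becomes `self`.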
""" - def _render(ref): + def _render(ref: Any) -> Basic | RawType: if not isinstance(ref, Basic): if isinstance(ref, Mapping): ref = Dict({key: _render(val) for key, val in ref.items()}) @@ -145,7 +147,7 @@ def default_renderer_config() -> CWLRendererConfiguration: } -def with_type(result: Any) -> type: +def with_type(result: Any) -> type | Any: """Get a Python type from a value. Does so either by using its `__type__` field (for example, for References) @@ -168,11 +170,11 @@ def with_field(result: Any) -> str: Returns: a string representation of the field portion of the passed value or `"out"`. """ if hasattr(result, "__field__") and result.__field__: - return result.__field_str__ + return str(result.__field_str__) else: return "out" -def to_name(result: Reference[Any]): +def to_name(result: Reference[Any]) -> str: """Take a reference and get a name representing it. The primary purpose of this method is to deal with the case where a reference is to the @@ -196,7 +198,7 @@ class ReferenceDefinition: value_from: str | None @classmethod - def from_reference(cls, ref: Reference) -> "ReferenceDefinition": + def from_reference(cls, ref: Reference[Any]) -> "ReferenceDefinition": """Build from a `Reference`. Converts a `dewret.workflow.Reference` into a CWL-rendering object. @@ -441,8 +443,10 @@ def to_output_schema( fields=fields, ) else: + # TODO: this complains because NotRequired keys are never present, + # but that does not seem like a problem here - likely a better solution. output = CommandOutputSchema( - **to_cwl_type(label, typ) + **to_cwl_type(label, typ) # type: ignore ) if output_source is not None: output["outputSource"] = output_source @@ -500,7 +504,7 @@ class CommandInputParameter: @classmethod def from_parameters( - cls, parameters: list[ParameterReference | FactoryCall] + cls, parameters: list[ParameterReference[Any] | FactoryCall] ) -> "InputsDefinition": """Takes a list of parameters into a CWL structure. @@ -552,7 +556,7 @@ class OutputsDefinition: outputs: sequence of results from a workflow. """ - outputs: dict[str, "CommandOutputSchema"] | list["CommandOutputSchema"] + outputs: dict[str, "CommandOutputSchema"] | list["CommandOutputSchema"] | CommandOutputSchema @classmethod def from_results( @@ -565,9 +569,10 @@ def from_results( Returns: CWL-like structure representing all workflow outputs. """ - def _build_results(result): + def _build_results(result: Any) -> RawType: if isinstance(result, Reference): - return to_output_schema( + # TODO: need to work out how to tell mypy that a TypedDict is also dict[str, RawType] + return to_output_schema( # type: ignore with_field(result), with_type(result), output_source=to_name(result) ) results = result @@ -579,7 +584,8 @@ def _build_results(result): } ) try: - return cls(outputs=_build_results(results)) + # TODO: sort out this nested type building. + return cls(outputs=_build_results(results)) # type: ignore except AttributeError: expr, references = expr_to_references(results) reference_names = sorted( @@ -640,7 +646,7 @@ def from_workflow( workflow: workflow to convert. name: name of this workflow, if it should have one. """ - parameters: list[ParameterReference | FactoryCall] = list( + parameters: list[ParameterReference[Any] | FactoryCall] = list( workflow.find_parameters( include_factory_calls=not get_render_configuration("factories_as_params") ) @@ -696,7 +702,8 @@ def render( Reduced form as a native Python dict structure for serialization. 
""" - with set_render_configuration(kwargs): + # TODO: Again, convincing mypy that a TypedDict has RawType values. + with set_render_configuration(kwargs): # type: ignore rendered = base_render( workflow, lambda workflow: WorkflowDefinition.from_workflow(workflow).render() diff --git a/src/dewret/renderers/snakemake.py b/src/dewret/renderers/snakemake.py index f88bce75..8bbbdca8 100644 --- a/src/dewret/renderers/snakemake.py +++ b/src/dewret/renderers/snakemake.py @@ -26,15 +26,20 @@ from attrs import define -from dewret.core import Raw, BasicType -from dewret.workflow import ( +from dewret.core import ( + Raw, + BasicType, Reference, +) +from dewret.workflow import ( Workflow, Task, Lazy, BaseStep, ) -from dewret.render import base_render +from dewret.render import ( + base_render, +) MainTypes = typing.Union[ BasicType, list[str], list["MainTypes"], dict[str, "MainTypes"] @@ -59,7 +64,7 @@ class ReferenceDefinition: source: str @classmethod - def from_reference(cls, ref: Reference) -> "ReferenceDefinition": + def from_reference(cls, ref: Reference[typing.Any]) -> "ReferenceDefinition": """Build from a `Reference`. Converts a `dewret.workflow.Reference` into a Snakemake-rendering object. diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 31d55120..5786027c 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -49,20 +49,26 @@ expr_to_references, unify_workflows, UNSET, - Reference, Workflow, Lazy, LazyEvaluation, Target, LazyFactory, Parameter, + ParameterReference, param, Task, is_task, ) from .backends._base import BackendModule from .annotations import FunctionAnalyser -from .core import get_configuration, set_configuration, IteratedGenerator, ConstructConfigurationTypedDict +from .core import ( + get_configuration, + set_configuration, + IteratedGenerator, + ConstructConfigurationTypedDict, + Reference +) Param = ParamSpec("Param") RetType = TypeVar("RetType") @@ -136,7 +142,7 @@ def make_lazy(self) -> LazyFactory: """ return self.backend.lazy - def evaluate(self, task: Lazy | list[Lazy] | tuple[Lazy], __workflow__: Workflow, thread_pool=None, **kwargs: Any) -> Any: + def evaluate(self, task: Lazy | list[Lazy] | tuple[Lazy], __workflow__: Workflow, thread_pool: ThreadPoolExecutor | None=None, **kwargs: Any) -> Any: """Evaluate a single task for a known workflow. Args: @@ -210,7 +216,7 @@ def __call__( with set_configuration(**kwargs): context = copy_context().items() - def _initializer(): + def _initializer() -> None: for var, value in context: var.set(value) thread_pool = ThreadPoolExecutor(initializer=_initializer) @@ -421,9 +427,10 @@ def add_numbers(left: int, right: int): positional_args[key] = True sig.bind(**kwargs) - def _to_param_ref(value): + def _to_param_ref(value: Any) -> ParameterReference[Any] | None: if isinstance(value, Parameter): return value.make_reference(workflow=__workflow__) + return None refs = [] for key, val in kwargs.items(): @@ -459,7 +466,7 @@ def _to_param_ref(value): ) else None ) kwargs[var] = cast( - Parameter, + Parameter[Any], param( var, value, @@ -520,7 +527,7 @@ def {fn.__name__}(...) -> ...: not inspect.isclass(value) ): kwargs[var] = cast( - Parameter, + Parameter[Any], param( var, default=value, @@ -550,7 +557,7 @@ def {fn.__name__}(...) 
-> ...: nested_workflow = Workflow(name=fn.__name__) nested_globals: Param.kwargs = { var: cast( - Parameter, + Parameter[Any], param( var, default=value.__default__ if hasattr(value, "__default__") else UNSET, diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 94ce4d74..7b2bb2b2 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) -from .core import IterableMixin, Reference, get_configuration, Raw, IteratedGenerator, strip_annotations, WorkflowProtocol, WorkflowComponent +from .core import IterableMixin, Reference, get_configuration, Raw, IteratedGenerator, strip_annotations, WorkflowProtocol, WorkflowComponent, ExprType from .utils import hasher, is_raw, make_traceback, is_raw_type, is_expr, Unset T = TypeVar("T") @@ -164,7 +164,7 @@ def __init__( self.register_caller(tethered) @staticmethod - def is_loopable(typ: type): + def is_loopable(typ: type) -> bool: """Checks if this type can be looped over. In particular, checks if this is an iterable that is NOT a str or bytes, possibly disguised @@ -180,7 +180,7 @@ def is_loopable(typ: type): return inspect.isclass(base) and issubclass(base, Iterable) and not issubclass(base, str | bytes) @property - def __type__(self): + def __type__(self) -> type | Unset: """Type associated with this parameter.""" if self.__fixed_type__ is not UNSET: return self.__fixed_type__ @@ -196,7 +196,7 @@ def __type__(self): raw_type = type(default) return raw_type - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """Comparing two parameters. Currently, this uses the hashes. @@ -205,7 +205,7 @@ def __eq__(self, other): """ return hash(self) == hash(other) - def __new__(cls, *args, **kwargs) -> "Parameter": + def __new__(cls, *args: Any, **kwargs: Any) -> "Parameter[T]": """Creates a Parameter. Required, as Parameters are an instance of sympy Expression, so @@ -213,7 +213,7 @@ def __new__(cls, *args, **kwargs) -> "Parameter": """ instance = Expr.__new__(cls) instance._assumptions0 = {} - return instance + return cast(Parameter[T], instance) def __hash__(self) -> int: """Get a unique hash for this parameter.""" @@ -223,7 +223,7 @@ def __hash__(self) -> int: # ) return hash(self.__name__) - def make_reference(self, **kwargs) -> "ParameterReference": + def make_reference(self, **kwargs: Any) -> "ParameterReference[T]": """Creates a new reference for the parameter. The kwargs will be passed to the constructor, but the @@ -271,7 +271,7 @@ def register_caller(self, caller: BaseStep) -> None: self.__tethered__ = caller self.__callers__.append(caller) - def __getattr__(self, attr: str) -> Reference[T]: + def __getattr__(self, attr: str) -> Reference[T] | Any: """Retrieve a reference to a field within this Parameter. Arg: @@ -465,7 +465,7 @@ def find_factories(self) -> dict[str, FactoryCall]: def find_parameters( self, include_factory_calls: bool = True - ) -> set[Parameter]: + ) -> set[Parameter[Any]]: """Crawl steps for parameter references. 
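For example, mirroring the test suite (a sketch, given a constructed
workflow `wkflw`):

    params = {(str(p), p.__type__) for p in wkflw.find_parameters()}
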
As the workflow does not hold its own list of parameters, this @@ -549,7 +549,7 @@ def assimilate(cls, *workflow_args: Workflow) -> "Workflow": if len(hashable_workflows) != len(workflows): raise NotImplementedError("Some results are not hashable.") - def _get_order(result: None | StepReference | Iterable[StepReference]) -> str: + def _get_order(result: None | StepReference[Any] | Iterable[StepReference[Any]]) -> str: if result is None: return "" if isinstance(result, StepReference): @@ -659,7 +659,7 @@ def add_nested_step( def add_step( self, fn: Lazy, - kwargs: dict[str, Raw | Reference], + kwargs: dict[str, Raw | Reference[Any]], raw_as_parameter: bool = False, is_factory: bool = False, positional_args: dict[str, bool] | None = None @@ -784,7 +784,7 @@ def __field_index_types__(self) -> tuple[type, ...]: Will be taken from the `field_index_types` construct configuration. """ - def __init__(self, *args, field: str | None = None, **kwargs): + def __init__(self, *args: Any, field: str | None = None, **kwargs: Any): """Extract the field name from the initializer arguments, if provided.""" super().__init__(*args, **kwargs) @@ -803,7 +803,7 @@ def name(self) -> str: """The name for the target, accounting for the field.""" return "name" - def __make_reference__(self, *args, **kwargs) -> Reference[Any]: + def __make_reference__(self, *args: Any, **kwargs: Any) -> Reference[Any]: """Create a reference with constructor arguments, usually to a subfield.""" ... @@ -819,7 +819,7 @@ def __make_reference__(self, *args, **kwargs) -> Reference[Any]: class FieldableMixin(_Fieldable): """Tooling for enhancing a type with referenceable fields.""" - def __init__(self: FieldableProtocol, *args, field: str | int | tuple[str | int, ...] | None = None, **kwargs): + def __init__(self: FieldableProtocol, *args: Any, field: str | int | tuple[str | int, ...] | None = None, **kwargs: Any): """Extract the requested field, if any, from the initializer arguments. Args: @@ -884,7 +884,7 @@ def __name__(self) -> str: """ return super().__name__ + self.__field_suffix__ - def find_field(self: FieldableProtocol, field: str | int, fallback_type: type | None = None, **init_kwargs: Any) -> Reference: + def find_field(self: FieldableProtocol, field: str | int, fallback_type: type | None = None, **init_kwargs: Any) -> Reference[Any]: """Field within the reference, if possible. Args: @@ -967,7 +967,7 @@ class BaseStep(WorkflowComponent): _id: str | None = None task: Task | Workflow - arguments: Mapping[str, Basic | Reference | Raw] + arguments: Mapping[str, Basic | Reference[Any] | Raw] workflow: Workflow positional_args: dict[str, bool] | None = None @@ -975,7 +975,7 @@ def __init__( self, workflow: Workflow, task: Task | Workflow, - arguments: Mapping[str, Reference | Raw], + arguments: Mapping[str, Reference[Any] | Raw], raw_as_parameter: bool = False, ): """Initialize a step. @@ -1008,13 +1008,14 @@ def __init__( ): if raw_as_parameter: # We use param for convenience but note that it is a reference in disguise. 
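# (Sketch of the two branches below: with raw_as_parameter=True a literal
# such as 3 is promoted to a workflow-level Parameter and replaced by its
# ParameterReference; otherwise it is wrapped as Raw(3) and inlined.)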
- value = cast(Parameter, param(key, value, tethered=None)).make_reference(workflow=workflow) + value = cast(Parameter[Any], param(key, value, tethered=None)).make_reference(workflow=workflow) else: value = Raw(value) - def _to_param_ref(value): + def _to_param_ref(value: Any) -> ParameterReference[Any] | None: if isinstance(value, Parameter): return value.make_reference(workflow=workflow) + return None expression, refs = expr_to_references(value, remap=_to_param_ref) for ref in refs: @@ -1041,7 +1042,7 @@ def __eq__(self, other: object) -> bool: and self.arguments == other.arguments ) - def make_reference(self, **kwargs) -> "StepReference": + def make_reference(self, **kwargs: Any) -> "StepReference[T]": """Create a reference to this step. Builds a reference to the (result of) this step, which will be iterable if appropriate. @@ -1150,7 +1151,7 @@ def __init__( workflow: Workflow, name: str, subworkflow: Workflow, - arguments: Mapping[str, Basic | Reference | Raw], + arguments: Mapping[str, Basic | Reference[Any] | Raw], raw_as_parameter: bool = False, ): """Create a NestedStep. @@ -1163,7 +1164,7 @@ def __init__( raw_as_parameter: whether raw-type arguments should be made (outer) workflow parameters. """ self.__subworkflow__ = subworkflow - base_arguments: dict[str, Basic | Reference | Raw] = {p.name: p for p in subworkflow.find_parameters()} + base_arguments: dict[str, Basic | Reference[Any] | Raw] = {p.name: p for p in subworkflow.find_parameters()} base_arguments.update(arguments) super().__init__( workflow=workflow, @@ -1206,7 +1207,7 @@ def __init__( self, workflow: Workflow, task: Task | Workflow, - arguments: Mapping[str, Reference | Raw], + arguments: Mapping[str, Reference[Any] | Raw], raw_as_parameter: bool = False, ): """Initialize a step. @@ -1231,7 +1232,7 @@ def __init__( super().__init__(workflow=workflow, task=task, arguments=arguments, raw_as_parameter=raw_as_parameter) @property - def __name__(self): + def __name__(self) -> str: """Get the name of this factory call.""" return self.name @@ -1269,7 +1270,7 @@ class ParameterReferenceMetadata(Generic[T]): """ parameter: Parameter[T] - def __init__(self, parameter: Parameter[T], *args, typ: type[U] | None=None, **kwargs): + def __init__(self, parameter: Parameter[T], typ: type[U] | Unset=UNSET): """Initialize the reference. Args: @@ -1315,7 +1316,7 @@ def __root_name__(self) -> str: """ return self._.parameter.name - def __init__(self, parameter: Parameter[U], *args, typ: type[U] | None=None, **kwargs): + def __init__(self, parameter: Parameter[U], *args: Any, typ: type[U] | None=None, **kwargs: Any): """Extract the parameter and type for setup. Args: @@ -1324,11 +1325,11 @@ def __init__(self, parameter: Parameter[U], *args, typ: type[U] | None=None, **k *args: arguments for other initializers. **kwargs: arguments for other initializers. """ - typ = typ or parameter.__type__ - self._ = self.ParameterReferenceMetadata(parameter, *args, typ, **kwargs) + chosen_type = typ or parameter.__type__ + self._ = self.ParameterReferenceMetadata(parameter, typ=chosen_type) super().__init__(*args, typ=typ, **kwargs) - def __getitem__(self, attr: str) -> "ParameterReference": + def __getitem__(self, attr: str) -> "ParameterReference[U]": """Retrieve a field. 
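For example (matching the behaviour exercised in the tests):

    my_param = param("my_param", typ=MyDataclass)
    my_param["left"]  # a reference rendering as "my_param/left", typed int
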
Args: @@ -1357,7 +1358,7 @@ def __original_name__(self) -> str: """The name of the original parameter, without any field, etc.""" return self._.parameter.__original_name__ - def __getattr__(self, attr: str) -> "ParameterReference": + def __getattr__(self, attr: str) -> "ParameterReference[U] | Any": """Retrieve a field. Args: @@ -1402,13 +1403,13 @@ def __eq__(self, other: object) -> bool: (isinstance(other, ParameterReference) and self._.parameter == other._.parameter and self.__field__ == other.__field__) ) - def __make_reference__(self, **kwargs) -> "ParameterReference": + def __make_reference__(self, **kwargs: Any) -> "ParameterReference[U]": """Get a reference for the same parameter.""" return self._.parameter.make_reference(**kwargs) -class IterableParameterReference(IterableMixin, ParameterReference[U]): +class IterableParameterReference(IterableMixin[U], ParameterReference[U]): """Iterable form of parameter references.""" - def __iter__(self) -> Generator[Reference, None, None]: + def __iter__(self) -> Generator[Reference[U], None, None]: """Iterate over this parameter. Returns: @@ -1437,7 +1438,7 @@ def __inner_iter__(self) -> Generator[Any, None, None]: while True: yield None - def __len__(self): + def __len__(self) -> int: """If it is possible to get a hard-codeable length from this iterable parameter, do so.""" inner, metadata = strip_annotations(self.__type__) if metadata and "Fixed" in metadata and isinstance(self.__default__, Sized): @@ -1486,7 +1487,7 @@ def return_type(self) -> type: _: StepReferenceMetadata def __init__( - self, step: BaseStep, *args, typ: type[U] | None = None, **kwargs + self, step: BaseStep, *args: Any, typ: type[U] | None = None, **kwargs: Any ): """Initialize the reference. @@ -1543,7 +1544,7 @@ def __getitem__(self, attr: str) -> "StepReference[Any]": ) ) from exc - def __getattr__(self, attr: str) -> "StepReference": + def __getattr__(self, attr: str) -> "StepReference[U] | Any": """Retrieve a field within this workflow.""" try: return self[attr] @@ -1585,13 +1586,13 @@ def __workflow__(self, workflow: Workflow) -> None: """ self._.step.set_workflow(workflow) - def __make_reference__(self, **kwargs) -> "StepReference": + def __make_reference__(self, **kwargs: Any) -> "StepReference[U]": """Create a new reference for the same step.""" return self._.step.make_reference(**kwargs) -class IterableStepReference(IterableMixin, StepReference[U]): +class IterableStepReference(IterableMixin[U], StepReference[U]): """Iterable form of a step reference.""" - def __iter__(self) -> Generator[Reference, None, None]: + def __iter__(self) -> Generator[Reference[U], None, None]: """Gets a sentinel value for iterating over this step's results. Bear in mind that this means an iterable step reference will iterate exactly once, @@ -1600,7 +1601,7 @@ def __iter__(self) -> Generator[Reference, None, None]: for zipping with a fixed length iterator, or simply prepping fieldnames for serialization. """ # We cast this so that we can treat a step iterator as if it really loops over results. - yield cast(Reference, IteratedGenerator(self)) + yield cast(Reference[U], IteratedGenerator(self)) def is_task(task: Lazy) -> bool: """Decide whether this is a task. 
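For example (a sketch):

    is_task(increment)   # True for a @task()-decorated callable
    is_task(lambda: 1)   # False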
@@ -1617,7 +1618,7 @@ def is_task(task: Lazy) -> bool: """ return isinstance(task, LazyEvaluation) -def expr_to_references(expression: Any, remap: Callable[[Any], Any] | None = None) -> tuple[Basic | None, list[Reference | Parameter]]: +def expr_to_references(expression: Any, remap: Callable[[Any], Any] | None = None) -> tuple[ExprType, list[Reference[Any] | Parameter[Any]]]: """Pull out any references, or other free symbols, from an expression. Args: @@ -1626,8 +1627,8 @@ def expr_to_references(expression: Any, remap: Callable[[Any], Any] | None = Non Returns: a pair of the expression with any applied simplifications/standardizations, and the list of References/Parameters found. """ - to_check: list[Reference | Parameter] = [] - def _to_expr(value): + to_check: list[Reference[Any] | Parameter[Any]] = [] + def _to_expr(value: Any) -> ExprType: if remap and (res := remap(value)) is not None: return _to_expr(res) @@ -1649,9 +1650,9 @@ def _to_expr(value): if is_dataclass(value) or attr_has(value): if is_dataclass(value): - fields = dataclass_fields(value) + fields = list(dataclass_fields(value)) else: - fields = {field for field in attrs_fields(value.__class__)} + fields = list({field for field in attrs_fields(value)}) for field in fields: if hasattr(value, field.name) and isinstance((val := getattr(value, field.name)), Reference): setattr(value, field.name, _to_expr(val)) @@ -1667,7 +1668,10 @@ def _to_expr(value): dct = {key: _to_expr(val) for key, val in value.items()} if dct == value: return retval - return value.__class__(dct) + # We try to reinstantiate this, but there will be an error otherwise. + # TODO: we could check this with a protocol, but some care would be needed to ensure all valid + # cases are covered. + return value.__class__(dct) # type: ignore elif not isinstance(value, str | bytes) and isinstance(value, Iterable): lst = (tuple if isinstance(value, tuple) else list)(_to_expr(v) for v in value) if lst == value: diff --git a/tests/_lib/frender.py b/tests/_lib/frender.py index d04b0fe2..c2c25b49 100644 --- a/tests/_lib/frender.py +++ b/tests/_lib/frender.py @@ -27,13 +27,13 @@ class NestedStepDefinition: subworkflow_name: str @classmethod - def from_nested_step(cls, nested_step: NestedStep): + def from_nested_step(cls, nested_step: NestedStep) -> "NestedStepDefinition": return cls( name=nested_step.name, subworkflow_name=nested_step.subworkflow.name ) - def render(self): + def render(self) -> str: return \ f""" A portal called {self.name} to another workflow, @@ -45,12 +45,12 @@ class StepDefinition: name: str @classmethod - def from_step(cls, step: Step): + def from_step(cls, step: Step) -> "StepDefinition": return cls( name=step.name ) - def render(self): + def render(self) -> str: return \ f""" Something called {self.name} @@ -63,8 +63,8 @@ class WorkflowDefinition: steps: list[StepDefinition | NestedStepDefinition] @classmethod - def from_workflow(cls, workflow: Workflow): - steps = [] + def from_workflow(cls, workflow: Workflow) -> "WorkflowDefinition": + steps: list[StepDefinition | NestedStepDefinition] = [] for step in workflow.indexed_steps.values(): if isinstance(step, Step): steps.append(StepDefinition.from_step(step)) @@ -79,7 +79,7 @@ def from_workflow(cls, workflow: Workflow): name = "Work Doe" return cls(name=name, steps=steps) - def render(self): + def render(self) -> str: steps = "\n".join('* ' + indent(step.render(), ' ')[3:] for step in self.steps) return \ f""" @@ -103,7 +103,8 @@ def render_raw( Reduced form as a native Python dict structure for 
serialization. """ - set_render_configuration(kwargs) + # TODO: work out how to handle these hints correctly. + set_render_configuration(kwargs) # type: ignore return base_render( workflow, lambda workflow: WorkflowDefinition.from_workflow(workflow).render() diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 9a709003..252b50a1 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -38,7 +38,7 @@ def to_int(num: int, should_double: AtRender[bool]) -> int | float: """Cast to an int.""" return increment(num=num) if should_double else sum(left=num, right=num) -def test_can_analyze_annotations(): +def test_can_analyze_annotations() -> None: """TODO: Docstring.""" my_obj = MyClass() diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 0720e751..a74cb66c 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -4,25 +4,19 @@ import pytest from dewret.tasks import construct, workflow, TaskException from dewret.renderers.cwl import render -from dewret.tasks import set_configuration +from dewret.core import set_configuration from dewret.annotations import AtRender from ._lib.extra import increment -@pytest.fixture -def configuration(): - """TODO: Docstring.""" - with set_configuration() as configuration: - yield configuration.get() - @workflow() def floor(num: int, expected: AtRender[bool]) -> int: """Converts int/float to int.""" - from dewret.tasks import get_configuration + from dewret.core import get_configuration if get_configuration("flatten_all_nested") != expected: raise AssertionError(f"Not expected configuration: {str(get_configuration('flatten_all_nested'))} != {expected}") return increment(num=num) -def test_cwl_with_parameter(configuration) -> None: +def test_cwl_with_parameter() -> None: """TODO: Docstring.""" with set_configuration(flatten_all_nested=True): result = increment(num=floor(num=3, expected=True)) diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index e669f63f..19c511f6 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -6,8 +6,9 @@ from typing import Unpack, TypedDict -from dewret.tasks import task, construct, workflow, set_configuration -from dewret.workflow import param +from dewret.tasks import task, construct, workflow +from dewret.core import set_configuration +from dewret.workflow import param, StepReference from dewret.renderers.cwl import render from dewret.annotations import Fixed @@ -77,10 +78,10 @@ class MyDataclass: left: int right: "MyDataclass" -def test_can_get_field_reference_from_parameter(): +def test_can_get_field_reference_from_parameter() -> None: """TODO: Docstring.""" my_param = param("my_param", typ=MyDataclass) - result = sum(left=my_param.left, right=sum(left=my_param.right.left, right=my_param)) + result = sum(left=my_param.left, right=sum(left=my_param.right.left, right=my_param.left)) wkflw = construct(result, simplify_ids=True) params = {(str(p), p.__type__) for p in wkflw.find_parameters()} @@ -96,7 +97,7 @@ def test_can_get_field_reference_from_parameter(): outputs: out: label: out - outputSource: sum-2/out + outputSource: sum-1/out type: - int - float @@ -104,37 +105,37 @@ def test_can_get_field_reference_from_parameter(): sum-1: in: left: - source: my_param/right/left + source: my_param/left right: - source: my_param + source: sum-2/out out: - out run: sum sum-2: in: left: - source: my_param/left + source: my_param/right/left right: - source: sum-1/out + source: my_param/left out: - out run: sum """) -def 
test_can_get_field_reference_iff_parent_type_has_field(): +def test_can_get_field_reference_iff_parent_type_has_field() -> None: """TODO: Docstring.""" @dataclass class MyDataclass: left: int my_param = param("my_param", typ=MyDataclass) - result = sum(left=my_param, right=my_param) + result = sum(left=my_param.left, right=my_param.left) wkflw = construct(result, simplify_ids=True) param_reference = list(wkflw.find_parameters())[0] assert str(param_reference.left) == "my_param/left" assert param_reference.left.__type__ == int -def test_can_get_field_references_from_dataclass(): +def test_can_get_field_references_from_dataclass() -> None: """TODO: Docstring.""" @dataclass class MyDataclass: @@ -153,6 +154,7 @@ def get_left(my_dataclass: MyDataclass) -> int: result = get_left(my_dataclass=test_dataclass(my_dataclass=MyDataclass(left=3, right=4.))) wkflw = construct(result, simplify_ids=True) + assert isinstance(wkflw.result, StepReference) assert str(wkflw.result) == "get_left-1" assert wkflw.result.__type__ == int @@ -161,7 +163,7 @@ class MyDict(TypedDict): left: int right: float -def test_can_get_field_references_from_typed_dict(): +def test_can_get_field_references_from_typed_dict() -> None: """TODO: Docstring.""" @workflow() def test_dict(**my_dict: Unpack[MyDict]) -> MyDict: @@ -171,6 +173,7 @@ def test_dict(**my_dict: Unpack[MyDict]) -> MyDict: result = test_dict(left=3, right=4.) wkflw = construct(result, simplify_ids=True) + assert isinstance(wkflw.result, StepReference) assert str(wkflw.result["left"]) == "test_dict-1/left" assert wkflw.result["left"].__type__ == int @@ -179,7 +182,7 @@ class MyListWrapper: """TODO: Docstring.""" my_list: list[int] -def test_can_iterate(): +def test_can_iterate() -> None: """TODO: Docstring.""" @task() def test_task(alpha: int, beta: float, charlie: bool) -> int: @@ -191,7 +194,8 @@ def test_list() -> list[int | float]: @workflow() def test_iterated() -> int: - return test_task(*test_list()) + # We ignore the type as mypy cannot confirm that the length and types match the args. + return test_task(*test_list()) # type: ignore with set_configuration(allow_positional_args=True, flatten_all_nested=True): result = test_iterated() @@ -227,6 +231,7 @@ def test_iterated() -> int: run: test_task """) + assert isinstance(wkflw.result, StepReference) assert wkflw.result._.step.positional_args == {"alpha": True, "beta": True, "charlie": True} @task() @@ -235,7 +240,8 @@ def test_list_2() -> MyListWrapper: @workflow() def test_iterated_2(my_wrapper: MyListWrapper) -> int: - return test_task(*my_wrapper.my_list) + # mypy cannot confirm argument types match. + return test_task(*my_wrapper.my_list) # type: ignore with set_configuration(allow_positional_args=True, flatten_all_nested=True): result = test_iterated_2(my_wrapper=test_list_2()) @@ -247,7 +253,8 @@ def test_list_3() -> Fixed[list[tuple[int, int]]]: @workflow() def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: - retval = mod10(*test_list_3()[0]) + # mypy cannot confirm argument types match. 
+ retval = mod10(*test_list_3()[0]) # type: ignore for pair in param: a, b = pair retval += a + b @@ -300,7 +307,7 @@ def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: run: test_list_3 """) -def test_can_use_plain_dict_fields(): +def test_can_use_plain_dict_fields() -> None: """TODO: Docstring.""" @workflow() def test_dict(left: int, right: float) -> dict[str, float | int]: @@ -310,6 +317,7 @@ def test_dict(left: int, right: float) -> dict[str, float | int]: with set_configuration(allow_plain_dict_fields=True): result = test_dict(left=3, right=4.) wkflw = construct(result, simplify_ids=True) + assert isinstance(wkflw.result, StepReference) assert str(wkflw.result["left"]) == "test_dict-1/left" assert wkflw.result["left"].__type__ == int | float @@ -318,7 +326,7 @@ class IndexTest: """TODO: Docstring.""" left: Fixed[list[int]] -def test_can_configure_field_separator(): +def test_can_configure_field_separator() -> None: """TODO: Docstring.""" @task() def test_sep() -> IndexTest: diff --git a/tests/test_modularity.py b/tests/test_modularity.py index 8b085457..4fdc522a 100644 --- a/tests/test_modularity.py +++ b/tests/test_modularity.py @@ -1,7 +1,8 @@ """Verify CWL can be made with split up and nested calls.""" import yaml -from dewret.tasks import workflow, construct, set_configuration +from dewret.tasks import workflow, construct +from dewret.core import set_configuration from dewret.renderers.cwl import render from ._lib.extra import double, sum, increase diff --git a/tests/test_multiresult_steps.py b/tests/test_multiresult_steps.py index d7d7ca8a..f84f6ddf 100644 --- a/tests/test_multiresult_steps.py +++ b/tests/test_multiresult_steps.py @@ -4,7 +4,8 @@ from attr import define from dataclasses import dataclass from typing import Iterable -from dewret.tasks import task, construct, workflow, set_configuration +from dewret.tasks import task, construct, workflow +from dewret.core import set_configuration from dewret.renderers.cwl import render STARTING_NUMBER: int = 23 diff --git a/tests/test_nested.py b/tests/test_nested.py index 3d901d8c..3d126326 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -8,7 +8,7 @@ from ._lib.extra import reverse_list, max_list -def test_can_supply_nested_raw(): +def test_can_supply_nested_raw() -> None: """TODO: Docstrings.""" pi = param("pi", math.pi) result = reverse_list(to_sort=[1., 3., pi]) diff --git a/tests/test_render_module.py b/tests/test_render_module.py index c8760aab..2cb7b9e1 100644 --- a/tests/test_render_module.py +++ b/tests/test_render_module.py @@ -6,7 +6,7 @@ from ._lib.extra import increment, triple_and_one -def test_can_load_render_module(): +def test_can_load_render_module() -> None: """TODO: Docstrings.""" result = triple_and_one(num=increment(num=3)) workflow = construct(result, simplify_ids=True) diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index 0fdd43cd..73f3c35d 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -3,7 +3,8 @@ from typing import Callable from queue import Queue import yaml -from dewret.tasks import construct, workflow, task, factory, set_configuration +from dewret.tasks import construct, workflow, task, factory +from dewret.core import set_configuration from dewret.renderers.cwl import render from dewret.workflow import param from attrs import define @@ -14,7 +15,7 @@ QueueFactory: Callable[..., Queue[int]] = factory(Queue) -GLOBAL_QUEUE: Queue = QueueFactory() +GLOBAL_QUEUE: Queue[int] = QueueFactory() @task() @@ -552,7 +553,7 @@ class 
PackResult: spades: int diamonds: int -def test_combining_attrs_and_factories(): +def test_combining_attrs_and_factories() -> None: """TODO: Docstrings.""" Pack = factory(PackResult) From 33874a33ca84759c440b28dc7ce55f1b1ff2b288 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 25 Aug 2024 03:31:30 +0100 Subject: [PATCH 083/108] ci: allow subclassing any for mypy to enable subclassing Symbol --- .github/workflows/python-static-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-static-ci.yml b/.github/workflows/python-static-ci.yml index f1925412..14b41e89 100644 --- a/.github/workflows/python-static-ci.yml +++ b/.github/workflows/python-static-ci.yml @@ -14,7 +14,7 @@ jobs: - uses: python/mypy@v1.8.0 with: paths: "./src" - - run: pip install .[test] && python -m mypy --strict --install-types --non-interactive ./src ./tests ./example + - run: pip install .[test] && python -m mypy --strict --install-types --allow-subclassing-any --non-interactive ./src ./tests ./example audit: runs-on: ubuntu-latest From 3ec7e394cb5ce58c6dc28e3ca4b06e5784a03f4b Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 25 Aug 2024 03:32:43 +0100 Subject: [PATCH 084/108] fix: make __make_reference__ pass workflow if none --- src/dewret/workflow.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 7b2bb2b2..81ba3914 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -1405,6 +1405,8 @@ def __eq__(self, other: object) -> bool: def __make_reference__(self, **kwargs: Any) -> "ParameterReference[U]": """Get a reference for the same parameter.""" + if "workflow" not in kwargs: + kwargs["workflow"] = self.__workflow__ return self._.parameter.make_reference(**kwargs) class IterableParameterReference(IterableMixin[U], ParameterReference[U]): @@ -1588,6 +1590,8 @@ def __workflow__(self, workflow: Workflow) -> None: def __make_reference__(self, **kwargs: Any) -> "StepReference[U]": """Create a new reference for the same step.""" + if "workflow" not in kwargs: + kwargs["workflow"] = self.__workflow__ return self._.step.make_reference(**kwargs) class IterableStepReference(IterableMixin[U], StepReference[U]): From d0e0d37408127dbee3560147b3676257bda20c79 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 25 Aug 2024 03:35:52 +0100 Subject: [PATCH 085/108] fix: remove variable re-use --- .../snakemake_renderer/basic_example/snakemake_workflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/snakemake_renderer/basic_example/snakemake_workflow.py b/example/snakemake_renderer/basic_example/snakemake_workflow.py index 8967a892..1d4d261c 100644 --- a/example/snakemake_renderer/basic_example/snakemake_workflow.py +++ b/example/snakemake_renderer/basic_example/snakemake_workflow.py @@ -172,5 +172,5 @@ def generate_report(processed_data: str, multiple_arg: str, output_file: str) -> "]": "", } ) - smk_output = yaml.dump(smk_output, indent=4).translate(trans_table) - file.write(smk_output) + smk_text = yaml.dump(smk_output, indent=4).translate(trans_table) + file.write(smk_text) From e27db971c1c333e11e02c3cf8414e4f0e752caa5 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 25 Aug 2024 03:44:41 +0100 Subject: [PATCH 086/108] feat: can go upwards in fields --- src/dewret/workflow.py | 12 ++++++++++++ tests/test_fieldable.py | 14 ++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 81ba3914..a37e6d18 100644 --- 
a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -875,6 +875,18 @@ def __field_suffix__(self) -> str: result += f"{self.__field_sep__}{cmpt}" return result + def __field_up__(self) -> Reference[Any]: + """Get the parent field, if possible. + + Returns: reference to the field above this. + + Raises: + RuntimeError: if there is no field above this. + """ + if self.__field__: + return self.__make_reference__(field=tuple(self.__field__[:-1])) + raise RuntimeError("Cannot go upwards unless currently using a field.") + @property def __name__(self) -> str: """Name for this step. diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index 19c511f6..41f86ed4 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -135,6 +135,20 @@ class MyDataclass: assert str(param_reference.left) == "my_param/left" assert param_reference.left.__type__ == int +def test_can_get_go_upwards_from_a_field_reference() -> None: + """TODO: Docstring.""" + @dataclass + class MyDataclass: + left: int + right: "MyDataclass" + my_param = param("my_param", typ=MyDataclass) + result = sum(left=my_param.left, right=my_param.left) + construct(result, simplify_ids=True) + + back = my_param.right.left.__field_up__() # type: ignore + assert str(back) == "my_param/right" + assert back.__type__ == MyDataclass + def test_can_get_field_references_from_dataclass() -> None: """TODO: Docstring.""" @dataclass From d890bd37ebb1016fa17be5501d5ed62e4bee558b Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Sun, 25 Aug 2024 03:48:55 +0100 Subject: [PATCH 087/108] fix: do not lose the field structure when iterating --- src/dewret/core.py | 2 +- tests/test_fieldable.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index ee782367..09e14e33 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -533,7 +533,7 @@ def __iter__(self) -> Generator[Reference[U], None, None]: """ count = -1 for _ in self.__wrapped__.__inner_iter__(): - ref = self.__wrapped__.__make_reference__(workflow=self.__wrapped__.__workflow__, field=(count := count + 1)) + ref = self.__wrapped__[(count := count + 1)] ref.__iterated__ = True yield ref diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index 41f86ed4..d6f95ee7 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -296,21 +296,21 @@ def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: outputs: out: label: out - outputSource: mod10-1/out + outputSource: mod10-2/out type: int steps: mod10-1: in: num: - valueFrom: $(inputs.param[0][0] + inputs.param[0][1] + inputs.param[1][0] + inputs.param[1][1] + self) - source: mod10-2/out + source: test_list_3-1[0][0] out: - out run: mod10 mod10-2: in: num: - source: test_list_3-1[0] + valueFrom: $(inputs.param[0][0] + inputs.param[0][1] + inputs.param[1][0] + inputs.param[1][1] + self) + source: mod10-1/out out: - out run: mod10 From 351d344467ca6e262b43cba6329d8dc6647e523b Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Thu, 29 Aug 2024 09:02:49 +0100 Subject: [PATCH 088/108] fix: add a catch for a bug in the first import --- src/dewret/utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/dewret/utils.py b/src/dewret/utils.py index 895cd3cd..9f0381ab 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -91,11 +91,16 @@ def load_module_or_package(target_name: str, path: Path) -> ModuleType: exception = exc if module is None: - spec = importlib.util.spec_from_file_location(target_name, 
str(path)) - if spec is None or spec.loader is None: - raise ImportError(f"Could not open {path} module") from exception - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + try: + spec = importlib.util.spec_from_file_location(target_name, str(path)) + if spec is None or spec.loader is None: + raise ImportError(f"Could not open {path} module") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + except ImportError as exc: + if exception: + raise exc from exception + raise exc return module From 3bd55e41e10fef89f0d151599928db464a6de4f6 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Fri, 30 Aug 2024 11:39:25 +0100 Subject: [PATCH 089/108] fix: remove unnecessary abstractmethod annotation --- src/dewret/core.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index 09e14e33..f70a8904 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -98,7 +98,6 @@ def default_config() -> dict[str, RawType]: @runtime_checkable class RawRenderModule(BaseRenderModule, Protocol): """Render module that returns raw text.""" - @abstractmethod def render_raw(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) -> dict[str, str]: """Turn a workflow into flat strings. @@ -109,7 +108,6 @@ def render_raw(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) @runtime_checkable class StructuredRenderModule(BaseRenderModule, Protocol): """Render module that returns JSON/YAML-serializable structures.""" - @abstractmethod def render(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) -> dict[str, dict[str, RawType]]: """Turn a workflow into a serializable structure. From a7c0d8d9e21673534482e42175f3e0e5e9dde4df Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Fri, 30 Aug 2024 12:04:28 +0100 Subject: [PATCH 090/108] fix: remove unnecessary abstractmethod annotation --- src/dewret/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dewret/core.py b/src/dewret/core.py index f70a8904..ba570cd8 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -19,7 +19,6 @@ """ from dataclasses import dataclass -from abc import abstractmethod import importlib import base64 from attrs import define From c32469e3e4d681a03964376473139468e4acdb39 Mon Sep 17 00:00:00 2001 From: KamenDimitrov97 Date: Mon, 2 Sep 2024 10:39:50 +0300 Subject: [PATCH 091/108] chore: Replace all Unions with | for consistency --- docs/renderer_tutorial.md | 2 +- src/dewret/core.py | 95 ++++++++++++--- src/dewret/renderers/cwl.py | 189 +++++++++++++++++++----------- src/dewret/renderers/snakemake.py | 6 +- 4 files changed, 204 insertions(+), 88 deletions(-) diff --git a/docs/renderer_tutorial.md b/docs/renderer_tutorial.md index 27be6726..f046b6ea 100644 --- a/docs/renderer_tutorial.md +++ b/docs/renderer_tutorial.md @@ -405,7 +405,7 @@ from dewret.utils import Raw, BasicType from dewret.workflow import Lazy from dewret.workflow import Reference, Workflow, Step, Task -RawType = typing.Union[BasicType, list[str], list["RawType"], dict[str, "RawType"]] +RawType = BasicType | list[str] | list["RawType"] | dict[str, "RawType"] ``` ## To run this example: diff --git a/src/dewret/core.py b/src/dewret/core.py index ba570cd8..72be23ae 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -23,20 +23,37 @@ import base64 from attrs import define from functools import lru_cache -from typing import Generic, TypeVar, Protocol, Iterator, Unpack, TypedDict, NotRequired, Generator, Union, Any, get_args, get_origin, 
Annotated, Callable, cast, runtime_checkable +from typing import ( + Generic, + TypeVar, + Protocol, + Iterator, + Unpack, + TypedDict, + NotRequired, + Generator, + Any, + get_args, + get_origin, + Annotated, + Callable, + cast, + runtime_checkable, +) from contextlib import contextmanager from contextvars import ContextVar from sympy import Expr, Symbol, Basic import copy BasicType = str | float | bool | bytes | int | None -RawType = Union[BasicType, list["RawType"], dict[str, "RawType"]] +RawType = BasicType | list["RawType"] | dict[str, "RawType"] FirmType = RawType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] -ExprType = FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...] # type: ignore +ExprType = (FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...]) # type: ignore U = TypeVar("U") T = TypeVar("T") + def strip_annotations(parent_type: type) -> tuple[type, tuple[str]]: """Discovers and removes annotations from a parent type. @@ -53,14 +70,17 @@ def strip_annotations(parent_type: type) -> tuple[type, tuple[str]]: metadata += list(parent_metadata) return parent_type, tuple(metadata) + RenderConfiguration = dict[str, RawType] + class WorkflowProtocol(Protocol): """Expected structure for a workflow. We do not expect various workflow implementations, but this allows us to define the interface expected by the core classes. """ + def remap(self, name: str) -> str: """Perform any name-changing for steps, etc. in the workflow. @@ -82,8 +102,10 @@ def simplify_ids(self, infix: list[str] | None = None) -> None: """ ... + class BaseRenderModule(Protocol): """Common routines for all renderer modules.""" + @staticmethod def default_config() -> dict[str, RawType]: """Retrieve default settings. @@ -94,29 +116,41 @@ def default_config() -> dict[str, RawType]: """ ... + @runtime_checkable class RawRenderModule(BaseRenderModule, Protocol): """Render module that returns raw text.""" - def render_raw(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) -> dict[str, str]: + + def render_raw( + self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration + ) -> dict[str, str]: """Turn a workflow into flat strings. Returns: one or more subworkflows with a `__root__` key representing the outermost workflow, at least. """ ... + @runtime_checkable class StructuredRenderModule(BaseRenderModule, Protocol): """Render module that returns JSON/YAML-serializable structures.""" - def render(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) -> dict[str, dict[str, RawType]]: + + def render( + self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration + ) -> dict[str, dict[str, RawType]]: """Turn a workflow into a serializable structure. Returns: one or more subworkflows with a `__root__` key representing the outermost workflow, at least. """ ... + class RenderCall(Protocol): """Callable that will render out workflow(s).""" - def __call__(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) -> dict[str, str] | dict[str, RawType]: + + def __call__( + self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration + ) -> dict[str, str] | dict[str, RawType]: """Take a workflow and turn it into a set of serializable (sub)workflows. Args: @@ -127,11 +161,13 @@ def __call__(self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration) -> """ ... + class UnevaluatableError(Exception): """Signposts that a user has tried to treat a reference as the real (runtime) value. 
For example, by comparing to a concrete integer or value, etc. """ + ... @@ -142,6 +178,7 @@ class ConstructConfiguration: Holds configuration that may be relevant to `construst(...)` calls or, realistically, anything prior to rendering. It should hold generic configuration that is renderer-independent. """ + flatten_all_nested: bool = False allow_positional_args: bool = False allow_plain_dict_fields: bool = False @@ -149,6 +186,7 @@ class ConstructConfiguration: field_index_types: str = "int" simplify_ids: bool = False + class ConstructConfigurationTypedDict(TypedDict): """Basic configuration of the construction process. @@ -157,6 +195,7 @@ class ConstructConfigurationTypedDict(TypedDict): **THIS MUST BE KEPT IDENTICAL TO ConstructConfiguration.** """ + flatten_all_nested: NotRequired[bool] allow_positional_args: NotRequired[bool] allow_plain_dict_fields: NotRequired[bool] @@ -164,19 +203,25 @@ class ConstructConfigurationTypedDict(TypedDict): field_index_types: NotRequired[str] simplify_ids: NotRequired[bool] + @define class GlobalConfiguration: """Overall configuration structure. Having a single configuration dict allows us to manage only one ContextVar. """ + construct: ConstructConfiguration render: dict[str, RawType] + CONFIGURATION: ContextVar[GlobalConfiguration] = ContextVar("configuration") + @contextmanager -def set_configuration(**kwargs: Unpack[ConstructConfigurationTypedDict]) -> Iterator[ContextVar[GlobalConfiguration]]: +def set_configuration( + **kwargs: Unpack[ConstructConfigurationTypedDict], +) -> Iterator[ContextVar[GlobalConfiguration]]: """Sets the construct-time configuration. This is a context manager, so that a setting can be temporarily overridden and automatically restored. @@ -186,8 +231,11 @@ def set_configuration(**kwargs: Unpack[ConstructConfigurationTypedDict]) -> Iter setattr(CONFIGURATION.get().construct, key, value) yield CONFIGURATION + @contextmanager -def set_render_configuration(kwargs: dict[str, RawType]) -> Iterator[ContextVar[GlobalConfiguration]]: +def set_render_configuration( + kwargs: dict[str, RawType], +) -> Iterator[ContextVar[GlobalConfiguration]]: """Sets the render-time configuration. This is a context manager, so that a setting can be temporarily overridden and automatically restored. @@ -198,6 +246,7 @@ def set_render_configuration(kwargs: dict[str, RawType]) -> Iterator[ContextVar[ CONFIGURATION.get().render.update(**kwargs) yield CONFIGURATION + @contextmanager def _set_configuration() -> Iterator[ContextVar[GlobalConfiguration]]: """Prepares and tidied up the configuration for applying settings. @@ -207,7 +256,9 @@ def _set_configuration() -> Iterator[ContextVar[GlobalConfiguration]]: try: previous = CONFIGURATION.get() except LookupError: - previous = GlobalConfiguration(construct=ConstructConfiguration(), render=default_renderer_config()) + previous = GlobalConfiguration( + construct=ConstructConfiguration(), render=default_renderer_config() + ) CONFIGURATION.set(previous) previous = copy.deepcopy(previous) @@ -230,12 +281,15 @@ def default_renderer_config() -> RenderConfiguration: """ try: # We have to use a cast as we do not know if the renderer module is valid. 
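# (The cast target is deliberately minimal: a BaseRenderModule only promises
# a zero-argument default_config() returning a plain dict -- see the Protocol
# definitions above.)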
- render_module = cast(BaseRenderModule, importlib.import_module("__renderer_mod__")) + render_module = cast( + BaseRenderModule, importlib.import_module("__renderer_mod__") + ) default_config: Callable[[], RenderConfiguration] = render_module.default_config except ImportError: return {} return default_config() + @lru_cache def default_construct_config() -> ConstructConfiguration: """Gets the default construct-time configuration. @@ -254,6 +308,7 @@ def default_construct_config() -> ConstructConfiguration: field_index_types="int", ) + def get_configuration(key: str) -> RawType: """Retrieve the configuration or (silently) return the default. @@ -267,10 +322,11 @@ def get_configuration(key: str) -> RawType: Returns: (preferably) a JSON/YAML-serializable construct. """ try: - return getattr(CONFIGURATION.get().construct, key) # type: ignore + return getattr(CONFIGURATION.get().construct, key) # type: ignore except LookupError: # TODO: Not sure what the best way to typehint this is. - return getattr(ConstructConfiguration(), key) # type: ignore + return getattr(ConstructConfiguration(), key) # type: ignore + def get_render_configuration(key: str) -> RawType: """Retrieve configuration for the active renderer. @@ -288,6 +344,7 @@ def get_render_configuration(key: str) -> RawType: except LookupError: return default_renderer_config().get(key) + class WorkflowComponent: """Base class for anything directly tied to an individual `Workflow`. @@ -312,12 +369,12 @@ def __init__(self, *args: Any, workflow: WorkflowProtocol, **kwargs: Any): @property def __workflow__(self) -> WorkflowProtocol: - """Workflow to which this reference applies.""" + """Workflow to which this reference applies.""" return self.__workflow_real__ @__workflow__.setter def __workflow__(self, workflow: WorkflowProtocol) -> None: - """Workflow to which this reference applies.""" + """Workflow to which this reference applies.""" self.__workflow_real__ = workflow @@ -370,7 +427,9 @@ def __type__(self) -> type: def _raise_unevaluatable_error(self) -> None: """Convenience method to consistently formulate an UnevaluatableError for this reference.""" - raise UnevaluatableError(f"This reference, {self.__name__}, cannot be evaluated during construction.") + raise UnevaluatableError( + f"This reference, {self.__name__}, cannot be evaluated during construction." + ) def __eq__(self, other: object) -> Any: """Test equality at construct-time, if sensible. @@ -431,11 +490,13 @@ def __str__(self) -> str: """ return self.__name__ + class IterableMixin(Reference[U]): """Functionality for iterating over references to give new references.""" + __fixed_len__: int | None = None - def __init__(self, typ: type[U] | None=None, **kwargs: Any): + def __init__(self, typ: type[U] | None = None, **kwargs: Any): """Extract length, if available from type. Captures types of the form (e.g.) `tuple[int, float]` and records the length @@ -503,6 +564,7 @@ def __getitem__(self, attr: str | int) -> "Reference[U] | Any": """ return super().__getitem__(attr) + class IteratedGenerator(Generic[U]): """Sentinel value for capturing that an iteration has occured without performing it. @@ -510,6 +572,7 @@ class IteratedGenerator(Generic[U]): if the renderer wishes to postpone iteration to runtime, and simply record it is required, rather than evaluating the iterator. 
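For example (illustrative only): when a task's result is unpacked with `*`
into another task's arguments, the renderer may be handed one of these
wrappers so that it can emit a runtime loop, rather than having dewret
expand each item while the workflow is being built.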
""" + __wrapped__: Reference[U] def __init__(self, to_wrap: Reference[U]): diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index ee882a68..8e229c8a 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -21,7 +21,16 @@ from attrs import define, has as attrs_has, fields as attrs_fields, AttrsInstance from dataclasses import is_dataclass, fields as dataclass_fields from collections.abc import Mapping -from typing import TypedDict, NotRequired, get_origin, get_args, Union, cast, Any, Unpack, Iterable +from typing import ( + TypedDict, + NotRequired, + get_origin, + get_args, + cast, + Any, + Unpack, + Iterable, +) from types import UnionType from inspect import isclass from sympy import Basic, Tuple, Dict, jscode, Symbol @@ -37,15 +46,18 @@ BaseStep, StepReference, ParameterReference, - expr_to_references + expr_to_references, ) -from dewret.utils import crawl_raw, DataclassProtocol, firm_to_raw, flatten_if_set, Unset -from dewret.render import base_render -from dewret.core import ( - Reference, - get_render_configuration, - set_render_configuration +from dewret.utils import ( + crawl_raw, + DataclassProtocol, + firm_to_raw, + flatten_if_set, + Unset, ) +from dewret.render import base_render +from dewret.core import Reference, get_render_configuration, set_render_configuration + class CommandInputSchema(TypedDict): """Structure for referring to a raw type in CWL. @@ -67,9 +79,14 @@ class CommandInputSchema(TypedDict): default: NotRequired[RawType] -InputSchemaType = Union[ - str, CommandInputSchema, list[str], list["InputSchemaType"], dict[str, "str | InputSchemaType"] -] +InputSchemaType = ( + str + | CommandInputSchema + | list[str] + | list["InputSchemaType"] + | dict[str, "str | InputSchemaType"] +) + def render_expression(ref: Any) -> "ReferenceDefinition": """Turn a rich (sympy) expression into a CWL JS expression. @@ -79,6 +96,7 @@ def render_expression(ref: Any) -> "ReferenceDefinition": Returns: a ReferenceDefinition containing a string representation of the expression in the form `$(...)`. """ + def _render(ref: Any) -> Basic | RawType: if not isinstance(ref, Basic): if isinstance(ref, Mapping): @@ -86,17 +104,24 @@ def _render(ref: Any) -> Basic | RawType: elif not isinstance(ref, str | bytes) and isinstance(ref, Iterable): ref = Tuple(*(_render(val) for val in ref)) return ref + expr = _render(ref) if isinstance(expr, Basic): values = list(expr.free_symbols) step_syms = [sym for sym in expr.free_symbols if isinstance(sym, StepReference)] - param_syms = [sym for sym in expr.free_symbols if isinstance(sym, ParameterReference)] + param_syms = [ + sym for sym in expr.free_symbols if isinstance(sym, ParameterReference) + ] if set(values) != set(step_syms) | set(param_syms): - raise NotImplementedError(f"Can only build expressions for step results and param results: {ref}") + raise NotImplementedError( + f"Can only build expressions for step results and param results: {ref}" + ) if len(step_syms) > 1: - raise NotImplementedError(f"Can only create expressions with 1 step reference: {ref}") + raise NotImplementedError( + f"Can only create expressions with 1 step reference: {ref}" + ) if not (step_syms or param_syms): ... 
if values == [ref]: @@ -117,9 +142,12 @@ def _render(ref: Any) -> Basic | RawType: source = f"{ref.__root_name__}{base}" else: expr = expr.subs(ref, Symbol(f"inputs.{ref.name}")) - return ReferenceDefinition(source=source, value_from=f"$({jscode(_render(expr))})") + return ReferenceDefinition( + source=source, value_from=f"$({jscode(_render(expr))})" + ) return ReferenceDefinition(source=str(expr), value_from=None) + class CWLRendererConfiguration(TypedDict): """Configuration for the renderer. @@ -159,6 +187,7 @@ def with_type(result: Any) -> type | Any: return result.__type__ return type(result) + def with_field(result: Any) -> str: """Get a string representing any 'field' suffix of a value. @@ -174,6 +203,7 @@ def with_field(result: Any) -> str: else: return "out" + def to_name(result: Reference[Any]) -> str: """Take a reference and get a name representing it. @@ -182,7 +212,11 @@ def to_name(result: Reference[Any]) -> str: Returns: the name of the reference, including any field portion, appending an `"out"` fieldname if none. """ - if hasattr(result, "__field__") and not result.__field__ and isinstance(result, StepReference): + if ( + hasattr(result, "__field__") + and not result.__field__ + and isinstance(result, StepReference) + ): return f"{result.__name__}/out" return result.__name__ @@ -278,10 +312,10 @@ def render(self) -> dict[str, RawType]: "in": { key: ( ref.render() - if isinstance(ref, ReferenceDefinition) else - render_expression(ref).render() - if isinstance(ref, Basic) else - {"default": firm_to_raw(ref.value)} + if isinstance(ref, ReferenceDefinition) + else render_expression(ref).render() + if isinstance(ref, Basic) + else {"default": firm_to_raw(ref.value)} if hasattr(ref, "value") else render_expression(ref).render() ) @@ -319,10 +353,7 @@ def to_cwl_type(label: str, typ: type) -> CommandInputSchema: Returns: CWL specification type dict. """ - typ_dict: CommandInputSchema = { - "label": label, - "type": "" - } + typ_dict: CommandInputSchema = {"label": label, "type": ""} base: Any | None = typ args = get_args(typ) if args: @@ -343,19 +374,22 @@ def to_cwl_type(label: str, typ: type) -> CommandInputSchema: elif base == bytes: typ_dict["type"] = "bytes" elif isinstance(typ, UnionType): - typ_dict.update({"type": tuple(to_cwl_type(label, item)["type"] for item in args)}) + typ_dict.update( + {"type": tuple(to_cwl_type(label, item)["type"] for item in args)} + ) elif isclass(base) and issubclass(base, Iterable): try: if len(args) > 1: - typ_dict.update({ - "type": "array", - "items": [to_cwl_type(label, t)["type"] for t in args], - }) + typ_dict.update( + { + "type": "array", + "items": [to_cwl_type(label, t)["type"] for t in args], + } + ) elif len(args) == 1: - typ_dict.update({ - "type": "array", - "items": to_cwl_type(label, args[0])["type"] - }) + typ_dict.update( + {"type": "array", "items": to_cwl_type(label, args[0])["type"]} + ) else: typ_dict["type"] = "array" except IndexError as err: @@ -446,7 +480,7 @@ def to_output_schema( # TODO: this complains because NotRequired keys are never present, # but that does not seem like a problem here - likely a better solution. 
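# (Explanatory note: to_cwl_type supplies the label/type portion of the
# schema below; the CWL-specific outputSource key is only layered on top
# when a source is given.)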
output = CommandOutputSchema( - **to_cwl_type(label, typ) # type: ignore + **to_cwl_type(label, typ) # type: ignore ) if output_source is not None: output["outputSource"] = output_source @@ -465,7 +499,12 @@ def _raw_to_command_input_schema_internal( elif isinstance(value, list): typeset = set(get_args(value)) if not typeset: - typeset = {item.__type__ if item is not None and hasattr(item, "__type__") else type(item) for item in value} + typeset = { + item.__type__ + if item is not None and hasattr(item, "__type__") + else type(item) + for item in value + } if len(typeset) != 1: raise RuntimeError( "For CWL, an input array must have a consistent type, " @@ -513,8 +552,12 @@ def from_parameters( Returns: CWL-like structure representing all workflow outputs. """ - parameters_dedup = {p._.parameter for p in parameters if isinstance(p, ParameterReference)} - parameters = list(parameters_dedup) + [p for p in parameters if not isinstance(p, ParameterReference)] + parameters_dedup = { + p._.parameter for p in parameters if isinstance(p, ParameterReference) + } + parameters = list(parameters_dedup) + [ + p for p in parameters if not isinstance(p, ParameterReference) + ] return cls( inputs={ input.name: cls.CommandInputParameter( @@ -556,11 +599,18 @@ class OutputsDefinition: outputs: sequence of results from a workflow. """ - outputs: dict[str, "CommandOutputSchema"] | list["CommandOutputSchema"] | CommandOutputSchema + outputs: ( + dict[str, "CommandOutputSchema"] + | list["CommandOutputSchema"] + | CommandOutputSchema + ) @classmethod def from_results( - cls, results: dict[str, StepReference[Any]] | list[StepReference[Any]] | tuple[StepReference[Any], ...] + cls, + results: dict[str, StepReference[Any]] + | list[StepReference[Any]] + | tuple[StepReference[Any], ...], ) -> "OutputsDefinition": """Takes a mapping of results into a CWL structure. @@ -569,39 +619,43 @@ def from_results( Returns: CWL-like structure representing all workflow outputs. """ + def _build_results(result: Any) -> RawType: if isinstance(result, Reference): # TODO: need to work out how to tell mypy that a TypedDict is also dict[str, RawType] - return to_output_schema( # type: ignore + return to_output_schema( # type: ignore with_field(result), with_type(result), output_source=to_name(result) ) results = result return ( - [ - _build_results(result) for result in results - ] if isinstance(results, list | tuple | Tuple) else { - key: _build_results(result) for key, result in results.items() - } + [_build_results(result) for result in results] + if isinstance(results, list | tuple | Tuple) + else {key: _build_results(result) for key, result in results.items()} ) + try: # TODO: sort out this nested type building. - return cls(outputs=_build_results(results)) # type: ignore + return cls(outputs=_build_results(results)) # type: ignore except AttributeError: expr, references = expr_to_references(results) reference_names = sorted( { - str(ref._.parameter) if isinstance(ref, ParameterReference) else str(ref._.step) + str(ref._.parameter) + if isinstance(ref, ParameterReference) + else str(ref._.step) for ref in references } ) - return cls(outputs={ - "out": { - "type": "float", # WARNING: we assume any arithmetic expression returns a float. - "label": "out", - "expression": str(expr), - "source": reference_names + return cls( + outputs={ + "out": { + "type": "float", # WARNING: we assume any arithmetic expression returns a float. 
+ "label": "out", + "expression": str(expr), + "source": reference_names, + } } - }) + ) def render(self) -> dict[str, RawType] | list[RawType]: """Render to a dict-like structure. @@ -610,12 +664,11 @@ def render(self) -> dict[str, RawType] | list[RawType]: Reduced form as a native Python dict structure for serialization. """ - return [ - crawl_raw(output) for output in self.outputs - ] if isinstance(self.outputs, list) else { - key: crawl_raw(output) - for key, output in self.outputs.items() - } + return ( + [crawl_raw(output) for output in self.outputs] + if isinstance(self.outputs, list) + else {key: crawl_raw(output) for key, output in self.outputs.items()} + ) @define @@ -648,7 +701,9 @@ def from_workflow( """ parameters: list[ParameterReference[Any] | FactoryCall] = list( workflow.find_parameters( - include_factory_calls=not get_render_configuration("factories_as_params") + include_factory_calls=not get_render_configuration( + "factories_as_params" + ) ) ) if get_render_configuration("factories_as_params"): @@ -665,10 +720,10 @@ def from_workflow( inputs=InputsDefinition.from_parameters(parameters), outputs=OutputsDefinition.from_results( workflow.result - if isinstance(workflow.result, list | tuple | Tuple) else - {with_field(workflow.result): workflow.result} - if workflow.has_result and workflow.result is not None else - {} + if isinstance(workflow.result, list | tuple | Tuple) + else {with_field(workflow.result): workflow.result} + if workflow.has_result and workflow.result is not None + else {} ), name=name, ) @@ -703,9 +758,9 @@ def render( serialization. """ # TODO: Again, convincing mypy that a TypedDict has RawType values. - with set_render_configuration(kwargs): # type: ignore + with set_render_configuration(kwargs): # type: ignore rendered = base_render( workflow, - lambda workflow: WorkflowDefinition.from_workflow(workflow).render() + lambda workflow: WorkflowDefinition.from_workflow(workflow).render(), ) return rendered diff --git a/src/dewret/renderers/snakemake.py b/src/dewret/renderers/snakemake.py index 8bbbdca8..d9c7cd5c 100644 --- a/src/dewret/renderers/snakemake.py +++ b/src/dewret/renderers/snakemake.py @@ -41,9 +41,7 @@ base_render, ) -MainTypes = typing.Union[ - BasicType, list[str], list["MainTypes"], dict[str, "MainTypes"] -] +MainTypes = BasicType | list[str] | list["MainTypes"] | dict[str, "MainTypes"] @define @@ -477,5 +475,5 @@ def render(workflow: Workflow) -> dict[str, typing.Any]: workflow, lambda workflow: yaml.dump( WorkflowDefinition.from_workflow(workflow).render(), indent=4 - ).translate(trans_table) + ).translate(trans_table), ) From f6f478bddbbe467efa41ea861695e184a2d26c0b Mon Sep 17 00:00:00 2001 From: KamenDimitrov97 Date: Mon, 2 Sep 2024 10:55:10 +0300 Subject: [PATCH 092/108] chore: Refactored render.py:get_render_method for better readability, added explanitory comments --- docs/renderer_tutorial.md | 3 --- src/dewret/annotations.py | 13 ++++++------- src/dewret/backends/backend_dask.py | 20 ++------------------ src/dewret/core.py | 5 ++++- src/dewret/render.py | 21 +++++++++++---------- 5 files changed, 23 insertions(+), 39 deletions(-) diff --git a/docs/renderer_tutorial.md b/docs/renderer_tutorial.md index f046b6ea..f4d0f702 100644 --- a/docs/renderer_tutorial.md +++ b/docs/renderer_tutorial.md @@ -416,6 +416,3 @@ RawType = BasicType | list[str] | list["RawType"] | dict[str, "RawType"] ```shell python snakemake_tasks.py ``` - -### Q: Should I add a brief description of dewret in step 1? Should link dewret types/docs etc here? 
-### A: Get details on how that happens and probably yes. diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py index 62a9ee36..a541ec0d 100644 --- a/src/dewret/annotations.py +++ b/src/dewret/annotations.py @@ -46,13 +46,12 @@ def __init__(self, fn: Callable[..., Any]): If `fn` is a class, it takes the constructor, and if it is a method, it takes the `__func__` attribute. """ - self.fn = ( - fn.__init__ - if inspect.isclass(fn) else - fn.__func__ - if inspect.ismethod(fn) else - fn - ) + if inspect.isclass(fn): + self.fn = fn.__init__ + elif inspect.ismethod(fn): + self.fn = fn.__func__ + else: + self.fn = fn @property def return_type(self) -> Any: diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index cc22ae34..06ae47f0 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -105,26 +105,10 @@ def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread Args: workflow: `Workflow` in which to record the execution. task: `dask.delayed` function, wrapped by dewret, that we wish to compute. - thread_pool: thread pool for executing the workflows, to allow initialization of configuration contextvars. + thread_pool: custom thread pool for executing workflows, copies in correct values for contextvars to each thread before they are accessed by a dask worker. **kwargs: any configuration arguments for this backend. """ - # def _check_delayed(task: Lazy | list[Lazy] | tuple[Lazy]) -> Delayed: - # # We need isinstance to reassure type-checker. - # if isinstance(task, list) or isinstance(task, tuple): - # lst: list[Delayed] | tuple[Delayed, ...] = [_check_delayed(elt) for elt in task] - # if isinstance(task, tuple): - # lst = tuple(lst) - # return delayed(lst) - # elif not isinstance(task, Delayed) or not is_lazy(task): - # raise RuntimeError( - # f"{task} is not a dask delayed, perhaps you tried to mix backends?" - # ) - # return task - # computable = _check_delayed(task) - # if not is_lazy(task): - # raise RuntimeError( - # f"{task} is not a dask delayed, perhaps you tried to mix backends?" - # ) + # def _check_delayed was here, but we decided to delegate this to dask if isinstance(task, Delayed) and is_lazy(task): computable = task diff --git a/src/dewret/core.py b/src/dewret/core.py index 72be23ae..fbc575f2 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -48,7 +48,9 @@ BasicType = str | float | bool | bytes | int | None RawType = BasicType | list["RawType"] | dict[str, "RawType"] FirmType = RawType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] -ExprType = (FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...]) # type: ignore +ExprType = ( + FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...] 
+) # type: ignore

U = TypeVar("U")
T = TypeVar("T")
@@ -71,6 +73,7 @@ def strip_annotations(parent_type: type) -> tuple[type, tuple[str]]:
return parent_type, tuple(metadata)


+# Generic type for configuration settings for the renderer
RenderConfiguration = dict[str, RawType]


diff --git a/src/dewret/render.py b/src/dewret/render.py
index 31b9e3da..f2fc3e3d 100644
--- a/src/dewret/render.py
+++ b/src/dewret/render.py
@@ -66,19 +66,20 @@ def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule,
render_module = cast(BaseRenderModule, module)
else:
render_module = renderer
- if not isinstance(render_module, StructuredRenderModule):
- if not isinstance(render_module, RawRenderModule):
- raise NotImplementedError("This render module neither seems to be a structured nor a raw render module.")
+
+ if isinstance(render_module, RawRenderModule):
return render_module.render_raw
+ elif isinstance(render_module, (StructuredRenderModule)):
+ def _render(workflow: Workflow, render_module: StructuredRenderModule, pretty: bool=False, **kwargs: RenderConfiguration) -> dict[str, str]:
+ rendered = render_module.render(workflow, **kwargs)
+ return {
+ key: structured_to_raw(value, pretty=pretty)
+ for key, value in rendered.items()
+ }
- def _render(workflow: Workflow, render_module: StructuredRenderModule, pretty: bool=False, **kwargs: RenderConfiguration) -> dict[str, str]:
- rendered = render_module.render(workflow, **kwargs)
- return {
- key: structured_to_raw(value, pretty=pretty)
- for key, value in rendered.items()
- }
+ return cast(RenderCall, partial(_render, render_module=render_module, pretty=pretty))
- return cast(RenderCall, partial(_render, render_module=render_module, pretty=pretty))
+ raise NotImplementedError("This render module neither seems to be a structured nor a raw render module.")

def base_render(
workflow: Workflow, build_cb: Callable[[Workflow], T]

From e76a0a9e05def35ac9db41cc4a49830f421939f2 Mon Sep 17 00:00:00 2001
From: KamenDimitrov97
Date: Mon, 2 Sep 2024 10:59:33 +0300
Subject: [PATCH 093/108] chore: Formatted methods

---
 src/dewret/annotations.py | 62 +++++++++++++++++++++--------
 src/dewret/backends/backend_dask.py | 11 +++--
 src/dewret/core.py | 4 +-
 src/dewret/render.py | 38 +++++++++++++-----
 4 files changed, 84 insertions(+), 31 deletions(-)

diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py
index a541ec0d..c1449d43 100644
--- a/src/dewret/annotations.py
+++ b/src/dewret/annotations.py
@@ -24,12 +24,22 @@
import importlib
from functools import lru_cache
from types import FunctionType, ModuleType
-from typing import Any, TypeVar, Annotated, Callable, get_origin, get_args, Mapping, get_type_hints
+from typing import (
+    Any,
+    TypeVar,
+    Annotated,
+    Callable,
+    get_origin,
+    get_args,
+    Mapping,
+    get_type_hints,
+)

T = TypeVar("T")
AtRender = Annotated[T, "AtRender"]
Fixed = Annotated[T, "Fixed"]

+
class FunctionAnalyser:
"""Convenience class for analysing a function with reduced duplication of effort.

Attributes:
_fn: the wrapped callable
_annotations: stored annotations for the function.
"""
+
_fn: Callable[..., Any]
_annotations: dict[str, Any]

@@ -79,9 +90,13 @@ def _typ_has(typ: type, annotation: type) -> bool:
Returns: True if the type has the given annotation, otherwise False.
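
For instance (illustrative only), given that `AtRender = Annotated[T, "AtRender"]`:

    FunctionAnalyser._typ_has(AtRender[int], AtRender)  # True
    FunctionAnalyser._typ_has(int, AtRender)            # False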
""" if not hasattr(annotation, "__metadata__"): - return False - if (origin := get_origin(typ)): - if origin is Annotated and hasattr(typ, "__metadata__") and typ.__metadata__ == annotation.__metadata__: + return False + if origin := get_origin(typ): + if ( + origin is Annotated + and hasattr(typ, "__metadata__") + and typ.__metadata__ == annotation.__metadata__ + ): return True if any(FunctionAnalyser._typ_has(arg, annotation) for arg in get_args(typ)): return True @@ -113,8 +128,11 @@ def _get_all_imported_names(mod: ModuleType) -> dict[str, tuple[ModuleType, str] if isinstance(node, ast.ImportFrom): for name in node.names: imported_names[name.asname or name.name] = ( - importlib.import_module("".join(["."] * node.level) + (node.module or ""), package=mod.__package__), - name.name + importlib.import_module( + "".join(["."] * node.level) + (node.module or ""), + package=mod.__package__, + ), + name.name, ) return imported_names @@ -122,10 +140,16 @@ def _get_all_imported_names(mod: ModuleType) -> dict[str, tuple[ModuleType, str] def free_vars(self) -> dict[str, Any]: """Get the free variables for this Callable.""" if self.fn.__code__ and self.fn.__closure__: - return dict(zip(self.fn.__code__.co_freevars, (c.cell_contents for c in self.fn.__closure__), strict=False)) + return dict( + zip( + self.fn.__code__.co_freevars, + (c.cell_contents for c in self.fn.__closure__), + strict=False, + ) + ) return {} - def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> Any: + def get_argument_annotation(self, arg: str, exhaustive: bool = False) -> Any: """Retrieve the annotations for this argument. Args: @@ -135,22 +159,28 @@ def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> Any: Returns: annotation if available, else None. """ typ: type | None = None - if (typ := self.fn.__annotations__.get(arg)): + if typ := self.fn.__annotations__.get(arg): if isinstance(typ, str): typ = get_type_hints(self.fn, include_extras=True).get(arg) elif exhaustive: - if (anns := get_type_hints(sys.modules[self.fn.__module__], include_extras=True)): - if (typ := anns.get(arg)): + if anns := get_type_hints( + sys.modules[self.fn.__module__], include_extras=True + ): + if typ := anns.get(arg): ... - elif (orig_pair := self.get_all_imported_names().get(arg)): + elif orig_pair := self.get_all_imported_names().get(arg): orig_module, orig_name = orig_pair typ = orig_module.__annotations__.get(orig_name) - elif (value := self.free_vars.get(arg)): + elif value := self.free_vars.get(arg): if not inspect.isclass(value) or inspect.isfunction(value): - raise RuntimeError(f"Cannot use free variables - please put {arg} at the global scope") + raise RuntimeError( + f"Cannot use free variables - please put {arg} at the global scope" + ) return typ - def argument_has(self, arg: str, annotation: type, exhaustive: bool=False) -> bool: + def argument_has( + self, arg: str, annotation: type, exhaustive: bool = False + ) -> bool: """Check if the named argument has the given annotation. 
Args: @@ -163,7 +193,7 @@ def argument_has(self, arg: str, annotation: type, exhaustive: bool=False) -> bo typ = self.get_argument_annotation(arg, exhaustive) return bool(typ and self._typ_has(typ, annotation)) - def is_at_construct_arg(self, arg: str, exhaustive: bool=False) -> bool: + def is_at_construct_arg(self, arg: str, exhaustive: bool = False) -> bool: """Convience function to check for `AtConstruct`, wrapping `FunctionAnalyser.argument_has`.""" return self.argument_has(arg, AtRender, exhaustive) diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index 06ae47f0..48d55859 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -90,14 +90,19 @@ def is_lazy(task: Any) -> bool: True if so, False otherwise. """ return isinstance(task, Delayed) or ( - isinstance(task, tuple | list) and - all(is_lazy(elt) for elt in task) + isinstance(task, tuple | list) and all(is_lazy(elt) for elt in task) ) lazy = delayed -def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread_pool: ThreadPoolExecutor | None=None, **kwargs: Any) -> Any: + +def run( + workflow: Workflow | None, + task: Lazy | list[Lazy] | tuple[Lazy], + thread_pool: ThreadPoolExecutor | None = None, + **kwargs: Any, +) -> Any: """Execute a task as the output of a workflow. Runs a task with dask. diff --git a/src/dewret/core.py b/src/dewret/core.py index fbc575f2..eecf3033 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -48,9 +48,7 @@ BasicType = str | float | bool | bytes | int | None RawType = BasicType | list["RawType"] | dict[str, "RawType"] FirmType = RawType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] -ExprType = ( - FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...] -) # type: ignore +ExprType = (FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...]) # type: ignore U = TypeVar("U") T = TypeVar("T") diff --git a/src/dewret/render.py b/src/dewret/render.py index f2fc3e3d..a313613e 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -25,12 +25,20 @@ import yaml from .workflow import Workflow, NestedStep -from .core import RawType, RenderCall, BaseRenderModule, RawRenderModule, StructuredRenderModule, RenderConfiguration +from .core import ( + RawType, + RenderCall, + BaseRenderModule, + RawRenderModule, + StructuredRenderModule, + RenderConfiguration, +) from .utils import load_module_or_package T = TypeVar("T") -def structured_to_raw(rendered: RawType, pretty: bool=False) -> str: + +def structured_to_raw(rendered: RawType, pretty: bool = False) -> str: """Serialize a serializable structure to a string. Args: @@ -45,7 +53,10 @@ def structured_to_raw(rendered: RawType, pretty: bool=False) -> str: output = str(rendered) return output -def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule, pretty: bool=False) -> RenderCall: + +def get_render_method( + renderer: Path | RawRenderModule | StructuredRenderModule, pretty: bool = False +) -> RenderCall: """Create a ready-made callable to render the workflow that is appropriate for the renderer module. 
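
For example (hypothetical path, mirroring the test suite's usage):

    render = get_render_method(Path("my_renderer.py"), pretty=True)
    rendered = render(workflow)  # e.g. {"__root__": "...", ...}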
Args: @@ -70,20 +81,29 @@ def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule, if isinstance(render_module, RawRenderModule): return render_module.render_raw elif isinstance(render_module, (StructuredRenderModule)): - def _render(workflow: Workflow, render_module: StructuredRenderModule, pretty: bool=False, **kwargs: RenderConfiguration) -> dict[str, str]: + + def _render( + workflow: Workflow, + render_module: StructuredRenderModule, + pretty: bool = False, + **kwargs: RenderConfiguration, + ) -> dict[str, str]: rendered = render_module.render(workflow, **kwargs) return { key: structured_to_raw(value, pretty=pretty) for key, value in rendered.items() } - return cast(RenderCall, partial(_render, render_module=render_module, pretty=pretty)) + return cast( + RenderCall, partial(_render, render_module=render_module, pretty=pretty) + ) + + raise NotImplementedError( + "This render module neither seems to be a structured nor a raw render module." + ) - raise NotImplementedError("This render module neither seems to be a structured nor a raw render module.") -def base_render( - workflow: Workflow, build_cb: Callable[[Workflow], T] -) -> dict[str, T]: +def base_render(workflow: Workflow, build_cb: Callable[[Workflow], T]) -> dict[str, T]: """Render to a dict-like structure. Args: From a0655c0d3e444d491ef32df5976b9b0c64579294 Mon Sep 17 00:00:00 2001 From: KamenDimitrov97 Date: Mon, 2 Sep 2024 13:48:51 +0300 Subject: [PATCH 094/108] docs: Added docstrings to annotation tests --- tests/test_annotations.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 252b50a1..b0741f54 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -14,15 +14,15 @@ ARG2: bool = False class MyClass: - """TODO: Docstring.""" + """A mock class to wrap values and mock processing of data.""" def method(self, arg1: bool, arg2: AtRender[int]) -> float: - """TODO: Docstring.""" + """A mock method to simulate the behavior of processing and rendering values.""" arg3: float = 7.0 arg4: AtRender[float] = 8.0 return arg1 + arg2 + arg3 + arg4 + int(ARG1) + int(ARG2) def fn(arg5: int, arg6: AtRender[int]) -> float: - """TODO: Docstring.""" + """A mock function to simulate processing of rendered values with integer inputs.""" arg7: float = 7.0 arg8: AtRender[float] = 8.0 return arg5 + arg6 + arg7 + arg8 + int(ARG1) + int(ARG2) @@ -30,16 +30,19 @@ def fn(arg5: int, arg6: AtRender[int]) -> float: @workflow() def to_int_bad(num: int, should_double: bool) -> int | float: - """Cast to an int.""" + """A mock workflow that casts to an int with a wrong type for handling doubles.""" return increment(num=num) if should_double else sum(left=num, right=num) @workflow() def to_int(num: int, should_double: AtRender[bool]) -> int | float: - """Cast to an int.""" + """A mock workflow that casts to an int with a right type for handling doubles.""" return increment(num=num) if should_double else sum(left=num, right=num) def test_can_analyze_annotations() -> None: - """TODO: Docstring.""" + """Test that annotations can be correctly analyzed within methods and functions. + Verifies that the `FunctionAnalyser` finds which arguments + and global variables are derived from `dewret.annotations.AtRender`. 
+ """ my_obj = MyClass() analyser = FunctionAnalyser(my_obj.method) @@ -61,7 +64,7 @@ def test_can_analyze_annotations() -> None: assert analyser.argument_has("ARG1", AtRender) is False def test_at_render() -> None: - """TODO: Docstring.""" + """Test the rendering of workflows with `dewret.annotations.AtRender` and exceptions handling.""" with pytest.raises(TaskException) as _: result = to_int_bad(num=increment(num=3), should_double=True) wkflw = construct(result, simplify_ids=True) @@ -144,7 +147,7 @@ def test_at_render() -> None: def test_at_render_between_modules() -> None: - """TODO: Docstring.""" + """Test rendering of workflows across different modules using `dewret.annotations.AtRender`.""" result = try_nothing() wkflw = construct(result, simplify_ids=True) subworkflows = render(wkflw, allow_complex_types=True) @@ -153,7 +156,7 @@ def test_at_render_between_modules() -> None: list_2: Fixed[list[int]] = [0, 1, 2, 3] def test_can_loop_over_fixed_length() -> None: - """TODO: Docstring.""" + """Test looping over a fixed-length list using `dewret.annotations.Fixed`.""" @workflow() def loop_over_lists(list_1: list[int]) -> list[int]: result = [] From aa7d6e4d0166ed94c8b8288fa425ea0d81f9c945 Mon Sep 17 00:00:00 2001 From: KamenDimitrov97 Date: Mon, 2 Sep 2024 15:36:34 +0300 Subject: [PATCH 095/108] docs: Added docstrings for tests --- tests/test_configuration.py | 2 +- tests/test_fieldable.py | 39 ++++++++++++++++++++++--------------- tests/test_nested.py | 2 +- tests/test_render_module.py | 2 +- tests/test_subworkflows.py | 4 ++-- 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index a74cb66c..2d293519 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -17,7 +17,7 @@ def floor(num: int, expected: AtRender[bool]) -> int: return increment(num=num) def test_cwl_with_parameter() -> None: - """TODO: Docstring.""" + """Test workflows with configuration parameters.""" with set_configuration(flatten_all_nested=True): result = increment(num=floor(num=3, expected=True)) workflow = construct(result, simplify_ids=True) diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index d6f95ee7..d30d6edc 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -16,7 +16,7 @@ @dataclass class Sides: - """TODO: Docstring.""" + """A dataclass representing the sides with `left` and `right` integers.""" left: int right: int @@ -24,11 +24,11 @@ class Sides: @workflow() def sum_sides() -> float: - """TODO: Docstring.""" + """Workflow that returns the sum of the `left` and `right` sides.""" return sum(left=SIDES.left, right=SIDES.right) def test_fields_of_parameters_usable() -> None: - """TODO: Docstring.""" + """Test that fields of parameters can be used to construct and render a workflow correctly.""" result = sum_sides() wkflw = construct(result, simplify_ids=True) rendered = render(wkflw, allow_complex_types=True)["sum_sides-1"] @@ -74,12 +74,12 @@ def test_fields_of_parameters_usable() -> None: @dataclass class MyDataclass: - """TODO: Docstring.""" + """A dataclass with nested references to itself, containing `left` and `right` fields.""" left: int right: "MyDataclass" def test_can_get_field_reference_from_parameter() -> None: - """TODO: Docstring.""" + """Test that field references can be retrieved from a parameter when constructing a workflow.""" my_param = param("my_param", typ=MyDataclass) result = sum(left=my_param.left, right=sum(left=my_param.right.left, right=my_param.left)) wkflw = 
construct(result, simplify_ids=True) @@ -122,8 +122,8 @@ def test_can_get_field_reference_from_parameter() -> None: run: sum """) -def test_can_get_field_reference_iff_parent_type_has_field() -> None: - """TODO: Docstring.""" +def test_can_get_field_reference_if_parent_type_has_field() -> None: + """Test that a field reference is retrievable if the parent type has that field.""" @dataclass class MyDataclass: left: int @@ -136,7 +136,7 @@ class MyDataclass: assert param_reference.left.__type__ == int def test_can_get_go_upwards_from_a_field_reference() -> None: - """TODO: Docstring.""" + """Test that it's possible to move upwards in the hierarchy from a field reference.""" @dataclass class MyDataclass: left: int @@ -150,7 +150,7 @@ class MyDataclass: assert back.__type__ == MyDataclass def test_can_get_field_references_from_dataclass() -> None: - """TODO: Docstring.""" + """Test that field references can be extracted from a dataclass and used in workflows.""" @dataclass class MyDataclass: left: int @@ -173,12 +173,12 @@ def get_left(my_dataclass: MyDataclass) -> int: assert wkflw.result.__type__ == int class MyDict(TypedDict): - """TODO: Docstring.""" + """A typed dictionary with `left` as an integer and `right` as a float.""" left: int right: float def test_can_get_field_references_from_typed_dict() -> None: - """TODO: Docstring.""" + """Test that field references can be extracted from a custom typed dictionary and used in workflows.""" @workflow() def test_dict(**my_dict: Unpack[MyDict]) -> MyDict: result: MyDict = {"left": mod10(num=my_dict["left"]), "right": pi()} @@ -193,21 +193,24 @@ def test_dict(**my_dict: Unpack[MyDict]) -> MyDict: @dataclass class MyListWrapper: - """TODO: Docstring.""" + """A dataclass that wraps a list of integers.""" my_list: list[int] def test_can_iterate() -> None: - """TODO: Docstring.""" + """Test iteration over a list of tasks and validate positional argument handling.""" @task() def test_task(alpha: int, beta: float, charlie: bool) -> int: + """Task that adds `alpha` and `beta` and returns the integer result.""" return int(alpha + beta) @task() def test_list() -> list[int | float]: + """Task that returns a list containing an integer and a float.""" return [1, 2.] @workflow() def test_iterated() -> int: + """Workflow that tests task iteration over a list.""" # We ignore the type as mypy cannot confirm that the length and types match the args. return test_task(*test_list()) # type: ignore @@ -250,10 +253,12 @@ def test_iterated() -> int: @task() def test_list_2() -> MyListWrapper: + """Task that returns a `MyListWrapper` containing a list of integers.""" return MyListWrapper(my_list=[1, 2]) @workflow() def test_iterated_2(my_wrapper: MyListWrapper) -> int: + """Workflow that tests iteration over a list in a `MyListWrapper`.""" # mypy cannot confirm argument types match. return test_task(*my_wrapper.my_list) # type: ignore @@ -263,10 +268,12 @@ def test_iterated_2(my_wrapper: MyListWrapper) -> int: @task() def test_list_3() -> Fixed[list[tuple[int, int]]]: + """Task that returns a list of integer tuples wrapped in a `dewret.annotations.Fixed` type.""" return [(0, 1), (2, 3)] @workflow() def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: + """Workflow that iterates over a list of integer tuples and performs operations.""" # mypy cannot confirm argument types match. 
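# (Illustrative note: `param` is Fixed, so its length is known while the
# workflow graph is built, which is what lets this loop be evaluated at
# construct time.)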
retval = mod10(*test_list_3()[0]) # type: ignore
for pair in param:
a, b = pair
retval += a + b
@@ -322,7 +329,7 @@ def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int:
""")

def test_can_use_plain_dict_fields() -> None:
- """TODO: Docstring."""
+ """Test the use of plain dictionary fields in workflows."""
@workflow()
def test_dict(left: int, right: float) -> dict[str, float | int]:
result: dict[str, float | int] = {"left": mod10(num=left), "right": pi()}
@@ -337,11 +344,11 @@ def test_dict(left: int, right: float) -> dict[str, float | int]:

@dataclass
class IndexTest:
- """TODO: Docstring."""
+ """A dataclass for testing indexed fields, containing a `left` field that is a list of integers."""
left: Fixed[list[int]]

def test_can_configure_field_separator() -> None:
- """TODO: Docstring."""
+ """Test the ability to configure the field separator in workflows."""
@task()
def test_sep() -> IndexTest:
return IndexTest(left=[3])
diff --git a/tests/test_nested.py b/tests/test_nested.py
index 3d126326..4a989a8e 100644
--- a/tests/test_nested.py
+++ b/tests/test_nested.py
@@ -9,7 +9,7 @@
from ._lib.extra import reverse_list, max_list

def test_can_supply_nested_raw() -> None:
- """TODO: Docstrings."""
+ """TODO: The structures are important for future CWL rendering."""
pi = param("pi", math.pi)
result = reverse_list(to_sort=[1., 3., pi])
workflow = construct(max_list(lst=result + result), simplify_ids=True)
diff --git a/tests/test_render_module.py b/tests/test_render_module.py
index 2cb7b9e1..70390a23 100644
--- a/tests/test_render_module.py
+++ b/tests/test_render_module.py
@@ -7,7 +7,7 @@
from ._lib.extra import increment, triple_and_one

def test_can_load_render_module() -> None:
- """TODO: Docstrings."""
+ """Checks if we can load a render module"""
result = triple_and_one(num=increment(num=3))
workflow = construct(result, simplify_ids=True)
workflow._name = "Fred"
diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py
index 73f3c35d..a2ee8e7d 100644
--- a/tests/test_subworkflows.py
+++ b/tests/test_subworkflows.py
@@ -547,14 +547,14 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None:

@define
class PackResult:
- """TODO: Docstrings."""
+ """A class representing the counts of card suits in a deck, including hearts, clubs, spades, and diamonds."""
hearts: int
clubs: int
spades: int
diamonds: int

def test_combining_attrs_and_factories() -> None:
- """TODO: Docstrings."""
+ """Check combining attributes from a dataclass with factory-produced instances."""
Pack = factory(PackResult)

@task()

From 3bff5e5179afbf5e304574e10dc1ac9029e117aa Mon Sep 17 00:00:00 2001
From: KamenDimitrov97
Date: Mon, 2 Sep 2024 15:43:12 +0300
Subject: [PATCH 096/108] fix: Formatting

---
 tests/test_annotations.py | 18 ++++++++++--
 tests/test_configuration.py | 12 ++++++--
 tests/test_fieldable.py | 58 ++++++++++++++++++++++++++++-------
 tests/test_nested.py | 7 +++--
 tests/test_render_module.py | 12 +++++---
 tests/test_subworkflows.py | 38 ++++++++++++++++--------
 6 files changed, 112 insertions(+), 33 deletions(-)

diff --git a/tests/test_annotations.py b/tests/test_annotations.py
index 252b50a1..f62479d8 100644
--- a/tests/test_annotations.py
+++ b/tests/test_annotations.py
@@ -13,14 +13,17 @@
ARG1: AtRender[bool] = True
ARG2: bool = False

+
class MyClass:
"""A mock class to wrap values and mock processing of data."""
+
def method(self, arg1: bool, arg2: AtRender[int]) -> float:
"""A mock method to simulate the behavior of processing and rendering values."""
arg3: float = 7.0
arg4: AtRender[float] = 8.0
return
arg1 + arg2 + arg3 + arg4 + int(ARG1) + int(ARG2) + def fn(arg5: int, arg6: AtRender[int]) -> float: """A mock function to simulate processing of rendered values with integer inputs.""" arg7: float = 7.0 @@ -33,13 +36,16 @@ def to_int_bad(num: int, should_double: bool) -> int | float: """A mock workflow that casts to an int with a wrong type for handling doubles.""" return increment(num=num) if should_double else sum(left=num, right=num) + @workflow() def to_int(num: int, should_double: AtRender[bool]) -> int | float: """A mock workflow that casts to an int with a right type for handling doubles.""" return increment(num=num) if should_double else sum(left=num, right=num) + def test_can_analyze_annotations() -> None: """Test that annotations can be correctly analyzed within methods and functions. + Verifies that the `FunctionAnalyser` finds which arguments and global variables are derived from `dewret.annotations.AtRender`. """ @@ -50,7 +56,9 @@ def test_can_analyze_annotations() -> None: assert analyser.argument_has("arg3", AtRender, exhaustive=True) is False assert analyser.argument_has("ARG2", AtRender, exhaustive=True) is False assert analyser.argument_has("arg2", AtRender, exhaustive=True) is True - assert analyser.argument_has("arg4", AtRender, exhaustive=True) is False # Not a global/argument + assert ( + analyser.argument_has("arg4", AtRender, exhaustive=True) is False + ) # Not a global/argument assert analyser.argument_has("ARG1", AtRender, exhaustive=True) is True assert analyser.argument_has("ARG1", AtRender) is False @@ -59,10 +67,13 @@ def test_can_analyze_annotations() -> None: assert analyser.argument_has("arg7", AtRender, exhaustive=True) is False assert analyser.argument_has("ARG2", AtRender, exhaustive=True) is False assert analyser.argument_has("arg6", AtRender, exhaustive=True) is True - assert analyser.argument_has("arg8", AtRender, exhaustive=True) is False # Not a global/argument + assert ( + analyser.argument_has("arg8", AtRender, exhaustive=True) is False + ) # Not a global/argument assert analyser.argument_has("ARG1", AtRender, exhaustive=True) is True assert analyser.argument_has("ARG1", AtRender) is False + def test_at_render() -> None: """Test the rendering of workflows with `dewret.annotations.AtRender` and exceptions handling.""" with pytest.raises(TaskException) as _: @@ -153,10 +164,13 @@ def test_at_render_between_modules() -> None: subworkflows = render(wkflw, allow_complex_types=True) subworkflows["__root__"] + list_2: Fixed[list[int]] = [0, 1, 2, 3] + def test_can_loop_over_fixed_length() -> None: """Test looping over a fixed-length list using `dewret.annotations.Fixed`.""" + @workflow() def loop_over_lists(list_1: list[int]) -> list[int]: result = [] diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 2d293519..1d45b71f 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -8,21 +8,29 @@ from dewret.annotations import AtRender from ._lib.extra import increment + @workflow() def floor(num: int, expected: AtRender[bool]) -> int: """Converts int/float to int.""" from dewret.core import get_configuration + if get_configuration("flatten_all_nested") != expected: - raise AssertionError(f"Not expected configuration: {str(get_configuration('flatten_all_nested'))} != {expected}") + raise AssertionError( + f"Not expected configuration: {str(get_configuration('flatten_all_nested'))} != {expected}" + ) return increment(num=num) + def test_cwl_with_parameter() -> None: """Test workflows with configuration 
parameters.""" with set_configuration(flatten_all_nested=True): result = increment(num=floor(num=3, expected=True)) workflow = construct(result, simplify_ids=True) - with pytest.raises(TaskException) as exc, set_configuration(flatten_all_nested=False): + with ( + pytest.raises(TaskException) as exc, + set_configuration(flatten_all_nested=False), + ): result = increment(num=floor(num=3, expected=True)) workflow = construct(result, simplify_ids=True) assert "AssertionError" in str(exc.getrepr()) diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py index d30d6edc..8f7faaf3 100644 --- a/tests/test_fieldable.py +++ b/tests/test_fieldable.py @@ -14,19 +14,24 @@ from ._lib.extra import mod10, sum, pi + @dataclass class Sides: """A dataclass representing the sides with `left` and `right` integers.""" + left: int right: int + SIDES: Sides = Sides(3, 6) + @workflow() def sum_sides() -> float: """Workflow that returns the sum of the `left` and `right` sides.""" return sum(left=SIDES.left, right=SIDES.right) + def test_fields_of_parameters_usable() -> None: """Test that fields of parameters can be used to construct and render a workflow correctly.""" result = sum_sides() @@ -72,16 +77,21 @@ def test_fields_of_parameters_usable() -> None: run: sum """) + @dataclass class MyDataclass: """A dataclass with nested references to itself, containing `left` and `right` fields.""" + left: int right: "MyDataclass" + def test_can_get_field_reference_from_parameter() -> None: """Test that field references can be retrieved from a parameter when constructing a workflow.""" my_param = param("my_param", typ=MyDataclass) - result = sum(left=my_param.left, right=sum(left=my_param.right.left, right=my_param.left)) + result = sum( + left=my_param.left, right=sum(left=my_param.right.left, right=my_param.left) + ) wkflw = construct(result, simplify_ids=True) params = {(str(p), p.__type__) for p in wkflw.find_parameters()} @@ -122,11 +132,14 @@ def test_can_get_field_reference_from_parameter() -> None: run: sum """) + def test_can_get_field_reference_if_parent_type_has_field() -> None: """Test that a field reference is retrievable if the parent type has that field.""" + @dataclass class MyDataclass: left: int + my_param = param("my_param", typ=MyDataclass) result = sum(left=my_param.left, right=my_param.left) wkflw = construct(result, simplify_ids=True) @@ -135,22 +148,27 @@ class MyDataclass: assert str(param_reference.left) == "my_param/left" assert param_reference.left.__type__ == int + def test_can_get_go_upwards_from_a_field_reference() -> None: """Test that it's possible to move upwards in the hierarchy from a field reference.""" + @dataclass class MyDataclass: left: int right: "MyDataclass" + my_param = param("my_param", typ=MyDataclass) result = sum(left=my_param.left, right=my_param.left) construct(result, simplify_ids=True) - back = my_param.right.left.__field_up__() # type: ignore + back = my_param.right.left.__field_up__() # type: ignore assert str(back) == "my_param/right" assert back.__type__ == MyDataclass + def test_can_get_field_references_from_dataclass() -> None: """Test that field references can be extracted from a dataclass and used in workflows.""" + @dataclass class MyDataclass: left: int @@ -165,39 +183,49 @@ def test_dataclass(my_dataclass: MyDataclass) -> MyDataclass: def get_left(my_dataclass: MyDataclass) -> int: return my_dataclass.left - result = get_left(my_dataclass=test_dataclass(my_dataclass=MyDataclass(left=3, right=4.))) + result = get_left( + 
my_dataclass=test_dataclass(my_dataclass=MyDataclass(left=3, right=4.0)) + ) wkflw = construct(result, simplify_ids=True) assert isinstance(wkflw.result, StepReference) assert str(wkflw.result) == "get_left-1" assert wkflw.result.__type__ == int + class MyDict(TypedDict): """A typed dictionary with `left` as an integer and `right` as a float.""" + left: int right: float + def test_can_get_field_references_from_typed_dict() -> None: """Test that field references can be extracted from a custom typed dictionary and used in workflows.""" + @workflow() def test_dict(**my_dict: Unpack[MyDict]) -> MyDict: result: MyDict = {"left": mod10(num=my_dict["left"]), "right": pi()} return result - result = test_dict(left=3, right=4.) + result = test_dict(left=3, right=4.0) wkflw = construct(result, simplify_ids=True) assert isinstance(wkflw.result, StepReference) assert str(wkflw.result["left"]) == "test_dict-1/left" assert wkflw.result["left"].__type__ == int + @dataclass class MyListWrapper: """A dataclass that wraps a list of integers.""" + my_list: list[int] + def test_can_iterate() -> None: """Test iteration over a list of tasks and validate positional argument handling.""" + @task() def test_task(alpha: int, beta: float, charlie: bool) -> int: """Task that adds `alpha` and `beta` and returns the integer result.""" @@ -206,13 +234,13 @@ def test_task(alpha: int, beta: float, charlie: bool) -> int: @task() def test_list() -> list[int | float]: """Task that returns a list containing an integer and a float.""" - return [1, 2.] + return [1, 2.0] @workflow() def test_iterated() -> int: """Workflow that tests task iteration over a list.""" # We ignore the type as mypy cannot confirm that the length and types match the args. - return test_task(*test_list()) # type: ignore + return test_task(*test_list()) # type: ignore with set_configuration(allow_positional_args=True, flatten_all_nested=True): result = test_iterated() @@ -249,7 +277,11 @@ def test_iterated() -> int: """) assert isinstance(wkflw.result, StepReference) - assert wkflw.result._.step.positional_args == {"alpha": True, "beta": True, "charlie": True} + assert wkflw.result._.step.positional_args == { + "alpha": True, + "beta": True, + "charlie": True, + } @task() def test_list_2() -> MyListWrapper: @@ -260,7 +292,7 @@ def test_list_2() -> MyListWrapper: def test_iterated_2(my_wrapper: MyListWrapper) -> int: """Workflow that tests iteration over a list in a `MyListWrapper`.""" # mypy cannot confirm argument types match. - return test_task(*my_wrapper.my_list) # type: ignore + return test_task(*my_wrapper.my_list) # type: ignore with set_configuration(allow_positional_args=True, flatten_all_nested=True): result = test_iterated_2(my_wrapper=test_list_2()) @@ -275,7 +307,7 @@ def test_list_3() -> Fixed[list[tuple[int, int]]]: def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: """Workflow that iterates over a list of integer tuples and performs operations.""" # mypy cannot confirm argument types match. 
- retval = mod10(*test_list_3()[0]) # type: ignore + retval = mod10(*test_list_3()[0]) # type: ignore for pair in param: a, b = pair retval += a + b @@ -328,27 +360,33 @@ def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: run: test_list_3 """) + def test_can_use_plain_dict_fields() -> None: """Test the use of plain dictionary fields in workflows.""" + @workflow() def test_dict(left: int, right: float) -> dict[str, float | int]: result: dict[str, float | int] = {"left": mod10(num=left), "right": pi()} return result with set_configuration(allow_plain_dict_fields=True): - result = test_dict(left=3, right=4.) + result = test_dict(left=3, right=4.0) wkflw = construct(result, simplify_ids=True) assert isinstance(wkflw.result, StepReference) assert str(wkflw.result["left"]) == "test_dict-1/left" assert wkflw.result["left"].__type__ == int | float + @dataclass class IndexTest: """A dataclass for testing indexed fields, containing a `left` field that is a list of integers.""" + left: Fixed[list[int]] + def test_can_configure_field_separator() -> None: """Test the ability to configure the field separator in workflows.""" + @task() def test_sep() -> IndexTest: return IndexTest(left=[3]) diff --git a/tests/test_nested.py b/tests/test_nested.py index 4a989a8e..39848112 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -8,14 +8,15 @@ from ._lib.extra import reverse_list, max_list + def test_can_supply_nested_raw() -> None: """TODO: The structures are important for future CWL rendering.""" pi = param("pi", math.pi) - result = reverse_list(to_sort=[1., 3., pi]) + result = reverse_list(to_sort=[1.0, 3.0, pi]) workflow = construct(max_list(lst=result + result), simplify_ids=True) - #assert workflow.find_parameters() == { + # assert workflow.find_parameters() == { # pi - #} + # } # NB: This is not currently usefully renderable in CWL. # However, the structures are important for future CWL rendering. diff --git a/tests/test_render_module.py b/tests/test_render_module.py index 70390a23..56caa17b 100644 --- a/tests/test_render_module.py +++ b/tests/test_render_module.py @@ -6,8 +6,9 @@ from ._lib.extra import increment, triple_and_one + def test_can_load_render_module() -> None: - """Checks if we can load a render module""" + """Checks if we can load a render module.""" result = triple_and_one(num=increment(num=3)) workflow = construct(result, simplify_ids=True) workflow._name = "Fred" @@ -15,7 +16,8 @@ def test_can_load_render_module() -> None: frender_py = Path(__file__).parent / "_lib/frender.py" render = get_render_method(frender_py) - assert render(workflow) == {"__root__": """ + assert render(workflow) == { + "__root__": """ I found a workflow called Fred. It has 2 steps! They are: @@ -25,7 +27,8 @@ def test_can_load_render_module() -> None: whose name is triple_and_one It probably got made with JUMP=1.0 -""", "triple_and_one-1": """ +""", + "triple_and_one-1": """ I found a workflow called triple_and_one. It has 3 steps! 
They are: @@ -36,4 +39,5 @@ def test_can_load_render_module() -> None: * Something called sum-1-2 It probably got made with JUMP=1.0 -"""} +""", + } diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index a2ee8e7d..4ddd30b8 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -49,12 +49,13 @@ def get_global_queue(num: int | float) -> Queue[int]: """Add a number to a global queue.""" return add_and_queue(num=to_int(num=num), queue=GLOBAL_QUEUE) + @workflow() def get_global_queues(num: int | float) -> list[Queue[int] | int]: """Add a number to a global queue.""" return [ add_and_queue(num=to_int(num=num), queue=GLOBAL_QUEUE), - add_constant(num=num) + add_constant(num=num), ] @@ -63,11 +64,13 @@ def add_constant(num: int | float) -> int: """Add a global constant to a number.""" return to_int(num=sum(left=num, right=CONSTANT)) + @workflow() def add_constants(num: int | float) -> int: """Add a global constant to a number.""" return to_int(num=sum(left=sum(left=num, right=CONSTANT), right=CONSTANT)) + @workflow() def get_values(num: int | float) -> tuple[int | float, int]: """Add a global constant to a number.""" @@ -310,7 +313,9 @@ def test_subworkflows_can_return_lists() -> None: run: get_global_queues """) - assert osubworkflows[0] == ("add_constant-1-1", yaml.safe_load(""" + assert osubworkflows[0] == ( + "add_constant-1-1", + yaml.safe_load(""" class: Workflow cwlVersion: 1.2 inputs: @@ -343,9 +348,12 @@ def test_subworkflows_can_return_lists() -> None: out: - out run: to_int - """)) + """), + ) - assert osubworkflows[1] == ("get_global_queues-1", yaml.safe_load(""" + assert osubworkflows[1] == ( + "get_global_queues-1", + yaml.safe_load(""" class: Workflow cwlVersion: 1.2 inputs: @@ -392,7 +400,9 @@ def test_subworkflows_can_return_lists() -> None: out: - out run: to_int - """)) + """), + ) + def test_can_merge_workflows() -> None: """Check whether we can merge workflows.""" @@ -501,7 +511,9 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None: run: add_constants """) - assert osubworkflows[0] == ("add_constants-1", yaml.safe_load(""" + assert osubworkflows[0] == ( + "add_constants-1", + yaml.safe_load(""" class: Workflow cwlVersion: 1.2 inputs: @@ -543,30 +555,32 @@ def test_subworkflows_can_use_globals_in_right_scope() -> None: out: - out run: to_int - """)) + """), + ) + @define class PackResult: """A class representing the counts of card suits in a deck, including hearts, clubs, spades, and diamonds.""" + hearts: int clubs: int spades: int diamonds: int + def test_combining_attrs_and_factories() -> None: """Check combining attributes from a dataclass with factory-produced instances.""" Pack = factory(PackResult) @task() def sum(left: int, right: int) -> int: - return left + right + return left + right @workflow() def black_total(pack: PackResult) -> int: - return sum( - left=pack.spades, - right=pack.clubs - ) + return sum(left=pack.spades, right=pack.clubs) + pack = Pack(hearts=13, spades=13, diamonds=13, clubs=13) wkflw = construct(black_total(pack=pack), simplify_ids=True) cwl = render(wkflw, allow_complex_types=True, factories_as_params=True) From 7ae5a7a335ab5c5da826f2cbbb1b278526f17f85 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 9 Sep 2024 21:59:53 +0100 Subject: [PATCH 097/108] fix: double-check importing correctly errors if render neither behaves as a package nor isolated module --- tests/_lib/unfrender.py | 29 +++++++++++++++++++++++++++++ tests/test_render_module.py | 14 ++++++++++++++ 2 files changed, 43 
insertions(+)
 create mode 100644 tests/_lib/unfrender.py

diff --git a/tests/_lib/unfrender.py b/tests/_lib/unfrender.py
new file mode 100644
index 00000000..8bdb4bfe
--- /dev/null
+++ b/tests/_lib/unfrender.py
@@ -0,0 +1,29 @@
+"""Testing example renderer.
+
+Correctly fails to import to show a broken module being handled.
+"""
+
+from typing import Unpack, TypedDict
+from dewret.workflow import Workflow
+
+# This lacking a relative import, while extra itself
+# uses one is what breaks the module. It cannot be both
+# a package and not-a-package. This is importable by
+# adding a . before extra. If instead, you try to avoid
+# relative imports altogether, and change .other -> other
+# in extra, it will also break, as the directory of this
+# file is not in the PATH.
+from extra import JUMP  # type: ignore
+
+class UnfrenderRendererConfiguration(TypedDict):
+    allow_complex_types: bool
+
+def default_config() -> UnfrenderRendererConfiguration:
+    return {
+        "allow_complex_types": True
+    }
+
+def render_raw(
+    workflow: Workflow, **kwargs: Unpack[UnfrenderRendererConfiguration]
+) -> dict[str, str]:
+    return {"JUMP": str(JUMP)}
diff --git a/tests/test_render_module.py b/tests/test_render_module.py
index 56caa17b..eec1a278 100644
--- a/tests/test_render_module.py
+++ b/tests/test_render_module.py
@@ -1,5 +1,6 @@
 """Check renderers can be imported live."""
 
+import pytest
 from pathlib import Path
 from dewret.tasks import construct
 from dewret.render import get_render_method
@@ -41,3 +42,16 @@ def test_can_load_render_module() -> None:
 It probably got made with JUMP=1.0
 """,
     }
+
+def test_get_correct_import_error_if_unable_to_load_render_module() -> None:
+    """TODO: Docstrings."""
+    unfrender_py = Path(__file__).parent / "_lib/unfrender.py"
+    with pytest.raises(ImportError) as exc:
+        get_render_method(unfrender_py)
+
+    entry = exc.traceback[-1]
+    assert Path(entry.path).resolve() == (
+        Path(__file__).parent / "_lib" / "extra.py"
+    ).resolve()
+    assert entry.relline == 2
+    assert "attempted relative import with no known parent package" in str(exc.value)

From 2a2d7902b5c7fa4e85ec0362439edebbf851e8fb Mon Sep 17 00:00:00 2001
From: Phil Weir
Date: Mon, 9 Sep 2024 22:10:22 +0100
Subject: [PATCH 098/108] fix: check that importing a non-compliant render
 module throws an error

---
 tests/_lib/nonfrender.py    | 22 ++++++++++++++++++++++
 tests/test_render_module.py |  6 ++++++
 2 files changed, 28 insertions(+)
 create mode 100644 tests/_lib/nonfrender.py

diff --git a/tests/_lib/nonfrender.py b/tests/_lib/nonfrender.py
new file mode 100644
index 00000000..95079bda
--- /dev/null
+++ b/tests/_lib/nonfrender.py
@@ -0,0 +1,22 @@
+"""Testing example renderer.
+
+Imports correctly, but fails to conform to the render module protocols, to show a non-compliant module being handled.
+"""
+
+from typing import Unpack, TypedDict
+from dewret.workflow import Workflow
+
+from .extra import JUMP
+
+class NonfrenderRendererConfiguration(TypedDict):
+    allow_complex_types: bool
+
+# This should fail to load as default_config is not present.
However it would +# ignore the fact that the return type is not a (subtype of) dict[str, RawType] +# def default_config() -> int: +# return 3 + +def render_raw( + workflow: Workflow, **kwargs: Unpack[NonfrenderRendererConfiguration] +) -> dict[str, str]: + return {"JUMP": str(JUMP)} diff --git a/tests/test_render_module.py b/tests/test_render_module.py index eec1a278..aeac0424 100644 --- a/tests/test_render_module.py +++ b/tests/test_render_module.py @@ -55,3 +55,9 @@ def test_get_correct_import_error_if_unable_to_load_render_module() -> None: ).resolve() assert entry.relline == 2 assert "attempted relative import with no known parent package" in str(exc.value) + + nonfrender_py = Path(__file__).parent / "_lib/nonfrender.py" + with pytest.raises(NotImplementedError) as nexc: + get_render_method(nonfrender_py) + + assert "This render module neither seems to be a structured nor a raw render module" in str(nexc.value) From 241f3383eabad4d07ecef080dc09851ec45450c2 Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 9 Sep 2024 22:39:57 +0100 Subject: [PATCH 099/108] fix: double-check importing correctly errors if render neither behaves as a package nor isolated module --- src/dewret/render.py | 3 --- tests/_lib/unfrender.py | 5 +---- tests/test_render_module.py | 6 +++--- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/dewret/render.py b/src/dewret/render.py index a313613e..cf2bc95d 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -67,9 +67,6 @@ def get_render_method( """ render_module: BaseRenderModule if isinstance(renderer, Path): - if (render_dir := str(renderer.parent)) not in sys.path: - sys.path.append(render_dir) - # Attempt to load renderer as package, falling back to a single module otherwise. # This enables relative imports in renderers and therefore the ability to modularize. module = load_module_or_package("__renderer__", renderer) diff --git a/tests/_lib/unfrender.py b/tests/_lib/unfrender.py index 8bdb4bfe..5057c14c 100644 --- a/tests/_lib/unfrender.py +++ b/tests/_lib/unfrender.py @@ -9,10 +9,7 @@ # This lacking a relative import, while extra itself # uses one is what breaks the module. It cannot be both # a package and not-a-package. This is importable by -# adding a . before extra. If instead, you try to avoid -# relative imports altogether, and change .other -> other -# in extra, it will also break, as the directory of this -# file is not in the PATH. +# adding a . before extra. 
from extra import JUMP # type: ignore class UnfrenderRendererConfiguration(TypedDict): diff --git a/tests/test_render_module.py b/tests/test_render_module.py index aeac0424..ba9b5a68 100644 --- a/tests/test_render_module.py +++ b/tests/test_render_module.py @@ -51,10 +51,10 @@ def test_get_correct_import_error_if_unable_to_load_render_module() -> None: entry = exc.traceback[-1] assert Path(entry.path).resolve() == ( - Path(__file__).parent / "_lib" / "extra.py" + Path(__file__).parent / "_lib" / "unfrender.py" ).resolve() - assert entry.relline == 2 - assert "attempted relative import with no known parent package" in str(exc.value) + assert entry.relline == 12 + assert "No module named 'extra'" in str(exc.value) nonfrender_py = Path(__file__).parent / "_lib/nonfrender.py" with pytest.raises(NotImplementedError) as nexc: From fe5253c3f9da17dd0a23339a050767833e777d5c Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 9 Sep 2024 22:43:15 +0100 Subject: [PATCH 100/108] fix: double-check importing correctly errors if render neither behaves as a package nor isolated module --- tests/test_render_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_render_module.py b/tests/test_render_module.py index ba9b5a68..63fb0dce 100644 --- a/tests/test_render_module.py +++ b/tests/test_render_module.py @@ -46,7 +46,7 @@ def test_can_load_render_module() -> None: def test_get_correct_import_error_if_unable_to_load_render_module() -> None: """TODO: Docstrings.""" unfrender_py = Path(__file__).parent / "_lib/unfrender.py" - with pytest.raises(ImportError) as exc: + with pytest.raises(ModuleNotFoundError) as exc: get_render_method(unfrender_py) entry = exc.traceback[-1] From c28a31eb4895a9cf3d580a40acc3aa6611be459a Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 9 Sep 2024 22:51:14 +0100 Subject: [PATCH 101/108] fix: ensure pytest ignores tests._lib --- .github/workflows/python-test-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-test-ci.yml b/.github/workflows/python-test-ci.yml index 859ea89d..d5957b64 100644 --- a/.github/workflows/python-test-ci.yml +++ b/.github/workflows/python-test-ci.yml @@ -21,7 +21,7 @@ jobs: run: | pip install pytest pytest-cov pip install --no-build-isolation --no-deps --disable-pip-version-check -e . 
- python -m pytest --doctest-modules --ignore=example + python -m pytest --doctest-modules --ignore=example --ignore=tests/_lib python -m doctest -v docs/*.md # name: Test examples # run: | @@ -46,7 +46,7 @@ jobs: conda install -c /tmp/output/noarch/*.conda --update-deps --use-local dewret -y conda install pytest $CONDA/bin/pytest - python -m pytest --doctest-modules --ignore=example + python -m pytest --doctest-modules --ignore=example --ignore=tests/_lib python -m doctest -v docs/*.md # name: Test examples # run: | From 713882af358a70acfb318c2d89fdad628147877c Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 10 Sep 2024 10:03:50 +0100 Subject: [PATCH 102/108] fix: correct name of default_config --- src/dewret/renderers/cwl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 8e229c8a..18c086f4 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -160,7 +160,7 @@ class CWLRendererConfiguration(TypedDict): factories_as_params: NotRequired[bool] -def default_renderer_config() -> CWLRendererConfiguration: +def default_config() -> CWLRendererConfiguration: """Default configuration for this renderer. This is a hook-like call to give a configuration dict that this renderer From 67147868da1057611e8b317507f8c5979587f64b Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 10 Sep 2024 23:08:45 +0100 Subject: [PATCH 103/108] fix: remove Graph import from dask.typing in backend_dask as it seems to cause errors --- src/dewret/backends/backend_dask.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index 48d55859..47228d20 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -18,7 +18,6 @@ """ from dask.delayed import delayed, DelayedLeaf -from dask.typing import Graph from dask.config import config from typing import Protocol, runtime_checkable, Any, cast from concurrent.futures import ThreadPoolExecutor @@ -37,7 +36,7 @@ class Delayed(Protocol): """ @property - def __dask_graph__(self) -> Graph: + def __dask_graph__(self): # type: ignore """Retrieve the dask graph.""" ... From dd278010435e5227400e12eea44c03054fa64b3a Mon Sep 17 00:00:00 2001 From: KamenDimitrov97 Date: Tue, 17 Sep 2024 11:20:50 +0300 Subject: [PATCH 104/108] ops: enable example tests in CI --- .github/workflows/python-test-ci.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/python-test-ci.yml b/.github/workflows/python-test-ci.yml index d5957b64..b8975e70 100644 --- a/.github/workflows/python-test-ci.yml +++ b/.github/workflows/python-test-ci.yml @@ -23,9 +23,9 @@ jobs: pip install --no-build-isolation --no-deps --disable-pip-version-check -e . 
python -m pytest --doctest-modules --ignore=example --ignore=tests/_lib python -m doctest -v docs/*.md - # name: Test examples - # run: | - # (cd example; examples=$(grep "^\\$ " *.py | sed "s/.*\\$ //g"); while IFS= read -r line; do PYTHONPATH=.:$PYTHONPATH eval $line; done <<< "$examples") + name: Test examples + run: | + (cd example; examples=$(grep "^\\$ " *.py | sed "s/.*\\$ //g"); while IFS= read -r line; do PYTHONPATH=.:$PYTHONPATH eval $line; done <<< "$examples") unit-conda: runs-on: ubuntu-latest steps: @@ -48,6 +48,6 @@ jobs: $CONDA/bin/pytest python -m pytest --doctest-modules --ignore=example --ignore=tests/_lib python -m doctest -v docs/*.md - # name: Test examples - # run: | - # (cd example; examples=$(grep "^\\$ " *.py | sed "s/.*\\$ //g"); while IFS= read -r line; do PYTHONPATH=.:$PYTHONPATH eval $line; done <<< "$examples") + name: Test examples + run: | + (cd example; examples=$(grep "^\\$ " *.py | sed "s/.*\\$ //g"); while IFS= read -r line; do PYTHONPATH=.:$PYTHONPATH eval $line; done <<< "$examples") From b6677c02a612905e22f50b6bcc8f71d72b275535 Mon Sep 17 00:00:00 2001 From: KamenDimitrov97 Date: Tue, 17 Sep 2024 14:18:27 +0300 Subject: [PATCH 105/108] feat: Added a test for default renderer --- tests/test_render_module.py | 39 ++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/tests/test_render_module.py b/tests/test_render_module.py index 63fb0dce..648191e7 100644 --- a/tests/test_render_module.py +++ b/tests/test_render_module.py @@ -1,5 +1,6 @@ """Check renderers can be imported live.""" +import yaml import pytest from pathlib import Path from dewret.tasks import construct @@ -43,8 +44,44 @@ def test_can_load_render_module() -> None: """, } +def test_can_load_cwl_render_module() -> None: + """Checks if we can load a render module.""" + result = triple_and_one(num=increment(num=3)) + workflow = construct(result, simplify_ids=True) + workflow._name = "Fred" + + frender_py = Path(__file__).parent.parent / "src/dewret/renderers/cwl.py" + render = get_render_method(frender_py) + assert yaml.safe_load(render(workflow)["__root__"]) == yaml.safe_load(""" + cwlVersion: 1.2 + class: Workflow + inputs: + increment-1-num: + label: num + type: int + default: 3 + outputs: + out: + label: out + type: [int, float] + outputSource: triple_and_one-1/out + steps: + increment-1: + run: increment + in: + num: + source: increment-1-num + out: [out] + triple_and_one-1: + run: triple_and_one + in: + num: + source: increment-1/out + out: [out] + """) + def test_get_correct_import_error_if_unable_to_load_render_module() -> None: - """TODO: Docstrings.""" + """Check if the correct import error will be logged if unable to load render module.""" unfrender_py = Path(__file__).parent / "_lib/unfrender.py" with pytest.raises(ModuleNotFoundError) as exc: get_render_method(unfrender_py) From 3b9ccfeeeca4690989093527784894986f478e1d Mon Sep 17 00:00:00 2001 From: KamenDimitrov97 Date: Tue, 17 Sep 2024 14:18:27 +0300 Subject: [PATCH 106/108] feat: Added a test for default renderer --- tests/test_render_module.py | 37 +++++++------------------------------ 1 file changed, 7 insertions(+), 30 deletions(-) diff --git a/tests/test_render_module.py b/tests/test_render_module.py index 648191e7..19e2a510 100644 --- a/tests/test_render_module.py +++ b/tests/test_render_module.py @@ -1,6 +1,5 @@ """Check renderers can be imported live.""" -import yaml import pytest from pathlib import Path from dewret.tasks import construct @@ -44,42 +43,20 @@ def 
test_can_load_render_module() -> None: """, } + def test_can_load_cwl_render_module() -> None: """Checks if we can load a render module.""" result = triple_and_one(num=increment(num=3)) workflow = construct(result, simplify_ids=True) - workflow._name = "Fred" frender_py = Path(__file__).parent.parent / "src/dewret/renderers/cwl.py" render = get_render_method(frender_py) - assert yaml.safe_load(render(workflow)["__root__"]) == yaml.safe_load(""" - cwlVersion: 1.2 - class: Workflow - inputs: - increment-1-num: - label: num - type: int - default: 3 - outputs: - out: - label: out - type: [int, float] - outputSource: triple_and_one-1/out - steps: - increment-1: - run: increment - in: - num: - source: increment-1-num - out: [out] - triple_and_one-1: - run: triple_and_one - in: - num: - source: increment-1/out - out: [out] - """) - + assert render(workflow) == { + "__root__": "{'cwlVersion': 1.2, 'class': 'Workflow', 'inputs': {'increment-1-num': {'label': 'num', 'type': 'int', 'default': 3}}, 'outputs': {'out': {'label': 'out', 'type': ['int', 'float'], 'outputSource': 'triple_and_one-1/out'}}, 'steps': {'increment-1': {'run': 'increment', 'in': {'num': {'source': 'increment-1-num'}}, 'out': ['out']}, 'triple_and_one-1': {'run': 'triple_and_one', 'in': {'num': {'source': 'increment-1/out'}}, 'out': ['out']}}}", + "triple_and_one-1": "{'cwlVersion': 1.2, 'class': 'Workflow', 'inputs': {'num': {'label': 'num', 'type': 'int'}}, 'outputs': {'out': {'label': 'out', 'type': ['int', 'float'], 'outputSource': 'sum-1-1/out'}}, 'steps': {'double-1-1': {'run': 'double', 'in': {'num': {'source': 'num'}}, 'out': ['out']}, 'sum-1-1': {'run': 'sum', 'in': {'left': {'source': 'sum-1-2/out'}, 'right': {'default': 1}}, 'out': ['out']}, 'sum-1-2': {'run': 'sum', 'in': {'left': {'source': 'double-1-1/out'}, 'right': {'source': 'num'}}, 'out': ['out']}}}" + } + + def test_get_correct_import_error_if_unable_to_load_render_module() -> None: """Check if the correct import error will be logged if unable to load render module.""" unfrender_py = Path(__file__).parent / "_lib/unfrender.py" From f519c47ae4b6bb6d78e5610da83f570f89bf838f Mon Sep 17 00:00:00 2001 From: KamenDimitrov97 Date: Mon, 30 Sep 2024 17:49:19 +0300 Subject: [PATCH 107/108] feat: Abstracted write_rendered_output from the main dewret CLI --- src/dewret/__main__.py | 53 ++++++++++++++++++++++++------------------ src/dewret/render.py | 27 ++++++++++++++++++++- 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index ced66e51..92d947d2 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -31,9 +31,14 @@ import click import json -from .core import set_configuration, set_render_configuration, RawRenderModule, StructuredRenderModule +from .core import ( + set_configuration, + set_render_configuration, + RawRenderModule, + StructuredRenderModule, +) from .utils import load_module_or_package -from .render import get_render_method +from .render import get_render_method, write_rendered_output from .tasks import Backend, construct @@ -72,7 +77,15 @@ @click.argument("task") @click.argument("arguments", nargs=-1) def render( - workflow_py: Path, task: str, arguments: list[str], pretty: bool, backend: Backend, construct_args: str, renderer: str, renderer_args: str, output: str + workflow_py: Path, + task: str, + arguments: list[str], + pretty: bool, + backend: Backend, + construct_args: str, + renderer: str, + renderer_args: str, + output: str, ) -> None: """Render a workflow. 
@@ -91,7 +104,7 @@ def render( kwargs[key] = json.loads(val) render_module: Path | ModuleType - if (mtch := re.match(r"^([a-z_0-9-.]+)$", renderer)): + if mtch := re.match(r"^([a-z_0-9-.]+)$", renderer): render_module = importlib.import_module(f"dewret.renderers.{mtch.group(1)}") if not isinstance(render_module, RawRenderModule) and not isinstance(render_module, StructuredRenderModule): raise NotImplementedError("The imported render module does not seem to match the `RawRenderModule` or `StructuredRenderModule` protocols.") @@ -118,18 +131,22 @@ def render( renderer_kwargs = dict(pair.split(":") for pair in renderer_args.split(",")) if output == "-": + @contextmanager def _opener(key: str, _: str) -> Generator[IO[Any], None, None]: print(" ------ ", key, " ------ ") yield sys.stdout print() + opener = _opener else: + @contextmanager def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]: output_file = output.replace("%", key) with Path(output_file).open(mode) as output_f: yield output_f + opener = _opener render = get_render_method(render_module, pretty=pretty) @@ -138,30 +155,20 @@ def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]: task_fn = getattr(workflow, task) try: - with set_configuration(**construct_kwargs), set_render_configuration(renderer_kwargs): - rendered = render(construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs) + with ( + set_configuration(**construct_kwargs), + set_render_configuration(renderer_kwargs), + ): + rendered = render( + construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs + ) except Exception as exc: import traceback print(exc, exc.__cause__, exc.__context__) traceback.print_exc() else: - if len(rendered) == 1: - with opener("", "w") as output_f: - output_f.write(rendered["__root__"]) - elif "%" in output: - for key, value in rendered.items(): - if key == "__root__": - key = "ROOT" - with opener(key, "w") as output_f: - output_f.write(value) - else: - with opener("ROOT", "w") as output_f: - output_f.write(rendered["__root__"]) - del rendered["__root__"] - for key, value in rendered.items(): - with opener(key, "a") as output_f: - output_f.write("\n---\n") - output_f.write(value) + write_rendered_output(rendered, output, opener) + render() diff --git a/src/dewret/render.py b/src/dewret/render.py index cf2bc95d..d702dcad 100644 --- a/src/dewret/render.py +++ b/src/dewret/render.py @@ -21,7 +21,7 @@ import sys from pathlib import Path from functools import partial -from typing import TypeVar, Callable, cast +from typing import TypeVar, Callable, ContextManager, IO, Any, cast import yaml from .workflow import Workflow, NestedStep @@ -100,6 +100,31 @@ def _render( ) +def write_rendered_output( + rendered: dict[str, str] | dict[str, RawType], + output: str, + opener: Callable[[str, str], ContextManager[IO[Any]]], +) -> None: + """Utility function to handle writing rendered output to file or stdout.""" + if len(rendered) == 1: + with opener("", "w") as output_f: + output_f.write(rendered["__root__"]) + elif "%" in output: + for key, value in rendered.items(): + if key == "__root__": + key = "ROOT" + with opener(key, "w") as output_f: + output_f.write(value) + else: + with opener("ROOT", "w") as output_f: + output_f.write(rendered["__root__"]) + del rendered["__root__"] + for key, value in rendered.items(): + with opener(key, "a") as output_f: + output_f.write("\n---\n") + output_f.write(value) + + def base_render(workflow: Workflow, build_cb: Callable[[Workflow], T]) -> dict[str, T]: """Render to a dict-like 
structure.

From a25d3c2235506b3b27ac11b04260277b18ac81ed Mon Sep 17 00:00:00 2001
From: KamenDimitrov97
Date: Mon, 30 Sep 2024 18:06:24 +0300
Subject: [PATCH 108/108] docs: Updated renderer tutorial with the protocols
 implementation needed for a custom renderer to be considered a dewret
 renderer

---
 docs/renderer_tutorial.md | 62 +++++++++++++++++++++++++++++++++++++--
 1 file changed, 59 insertions(+), 3 deletions(-)

diff --git a/docs/renderer_tutorial.md b/docs/renderer_tutorial.md
index f4d0f702..2fd1d6cf 100644
--- a/docs/renderer_tutorial.md
+++ b/docs/renderer_tutorial.md
@@ -49,7 +49,63 @@ class WorkflowDefinition:
     }
 ```

-## 3. Create a StepDefinition.
+## 3. Ensuring Our Module is Recognized as a Render Module
+To have our custom renderer identified by Dewret as a valid renderer, we need to implement the `BaseRenderModule` protocol along with one of two further protocols: `RawRenderModule` or `StructuredRenderModule`.
+
+#### Implementing BaseRenderModule
+The `BaseRenderModule` protocol defines the foundation for a custom renderer. To implement it, we need to define a `default_config()` function, which provides the default configuration for our renderer.
+
+```python
+def default_config() -> CWLRendererConfiguration:
+    """Default configuration for this renderer.
+
+    This is a hook-like call to give a configuration dict that this renderer
+    will respect, and sets any necessary default values.
+
+    Returns: a dict with (preferably) raw type structures to enable easy setting
+    from YAML/JSON.
+    """
+    return {
+        "allow_complex_types": False,
+        "factories_as_params": False,
+    }
+```
+
+After implementing `BaseRenderModule`, you need to implement either the `RawRenderModule` or the `StructuredRenderModule` protocol, depending on how you want to handle the workflow rendering.
+
+#### Implementing either RawRenderModule or StructuredRenderModule
+The `StructuredRenderModule` protocol is designed for renderers that produce a structured representation of the workflow, ready to be serialized directly into the target format (e.g. CWL or Snakemake). The key method to implement is `render`, which converts a workflow into a structured, serializable form:
+```python
+def render(
+    self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration
+) -> dict[str, dict[str, RawType]]:
+    """Turn a workflow into a serializable structure.
+
+    Returns: one or more subworkflows with a `__root__` key representing the outermost workflow, at least.
+    """
+    ...
+```
+In this method:
+- You receive a workflow and, potentially, some optional configuration.
+- You return a dictionary whose `__root__` key holds the outermost workflow, with any additional subworkflows nested inside the returned structure.
+
+If you prefer more flexibility, and want to leave the structuring to the caller, you can implement the `RawRenderModule` protocol instead. This requires defining the `render_raw` method, which converts a workflow into raw, flat strings:
+```python
+def render_raw(
+    self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration
+) -> dict[str, str]:
+    """Turn a workflow into flat strings.
+
+    Returns: one or more subworkflows with a `__root__` key representing the outermost workflow, at least.
+    """
+    ...
+```
+In this method:
+
+- The workflow is rendered as raw, unstructured strings.
+- The caller is responsible for structuring the rendered output.
+
+## 4. Create a StepDefinition.
 Create a StepDefinition class to create each of the code blocks needed for a rule (step) to be executable in Snakemake.
When you have defined each block in your target workflow language task from [step 1](#1-understand-the-target-workflow-language),
@@ -128,7 +184,7 @@ class StepDefinition:
     }
 ```

-## 4. Create the Separate block definitions.
+## 5. Create the separate block definitions.
 In this step, you'll define classes to handle the rendering of each code block required for a rule (step) to be executable in the target workflow language. Each of these classes will encapsulate the logic for converting parts of a workflow step into the target language format.
@@ -335,7 +391,7 @@ class OutputDefinition:
-Integrate these block definitions into the StepDefinition class as demonstrated in [Step 3](#3-create-a-stepsdefinition).
+Integrate these block definitions into the StepDefinition class as demonstrated in [Step 4](#4-create-a-stepdefinition).
 Each StepDefinition will use these block definitions to render the complete step in the target workflow language.

-## 5. Helper methods.
+## 6. Helper methods.
 In this step, you'll define helper methods that will assist you in converting workflow components into the target workflow language format. In our case, these methods will handle type conversion, extracting method arguments, and computing relative paths.
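+
+As a rough sketch, a type-conversion helper for a Snakemake-style renderer might look like the following. This assumes, as in the bundled Snakemake renderer, that raw values arrive wrapped in a `Raw` object exposing a `.value` attribute (here assumed importable from `dewret.workflow`); the exact mapping is up to you and depends on which literals your target language accepts.
+
+```python
+from dewret.workflow import Raw
+
+def to_snakemake_type(param: Raw) -> str:
+    """Render a raw dewret value as a Snakemake-compatible literal."""
+    value = param.value
+    if isinstance(value, str):
+        # Quote strings so they survive as literals in the rendered rule.
+        return f'"{value}"'
+    if isinstance(value, bool):
+        return "True" if value else "False"
+    if isinstance(value, (int, float)):
+        return str(value)
+    raise TypeError(f"Cannot convert type {type(value).__name__} to Snakemake")
+```
+
+Note that `bool` is tested before `int`: Python booleans are a subclass of `int`, so checking numbers first would emit `0`/`1` instead of `False`/`True`.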