diff --git a/.github/workflows/python-static-ci.yml b/.github/workflows/python-static-ci.yml index f1925412..14b41e89 100644 --- a/.github/workflows/python-static-ci.yml +++ b/.github/workflows/python-static-ci.yml @@ -14,7 +14,7 @@ jobs: - uses: python/mypy@v1.8.0 with: paths: "./src" - - run: pip install .[test] && python -m mypy --strict --install-types --non-interactive ./src ./tests ./example + - run: pip install .[test] && python -m mypy --strict --install-types --allow-subclassing-any --non-interactive ./src ./tests ./example audit: runs-on: ubuntu-latest diff --git a/.github/workflows/python-test-ci.yml b/.github/workflows/python-test-ci.yml index 2cba3af2..b8975e70 100644 --- a/.github/workflows/python-test-ci.yml +++ b/.github/workflows/python-test-ci.yml @@ -3,12 +3,16 @@ on: [push, pull_request] jobs: unit-pip: runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11", "3.12"] steps: - uses: actions/checkout@v4 - - name: Set up Python + + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: - python-version: '3.11' + python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip "hatchling < 1.22" @@ -17,11 +21,11 @@ jobs: run: | pip install pytest pytest-cov pip install --no-build-isolation --no-deps --disable-pip-version-check -e . - python -m pytest --doctest-modules --ignore=example + python -m pytest --doctest-modules --ignore=example --ignore=tests/_lib python -m doctest -v docs/*.md - # name: Test examples - # run: | - # (cd example; examples=$(grep "^\\$ " *.py | sed "s/.*\\$ //g"); while IFS= read -r line; do PYTHONPATH=.:$PYTHONPATH eval $line; done <<< "$examples") + name: Test examples + run: | + (cd example; examples=$(grep "^\\$ " *.py | sed "s/.*\\$ //g"); while IFS= read -r line; do PYTHONPATH=.:$PYTHONPATH eval $line; done <<< "$examples") unit-conda: runs-on: ubuntu-latest steps: @@ -42,8 +46,8 @@ jobs: conda install -c /tmp/output/noarch/*.conda --update-deps --use-local dewret -y conda install pytest $CONDA/bin/pytest - python -m pytest --doctest-modules --ignore=example + python -m pytest --doctest-modules --ignore=example --ignore=tests/_lib python -m doctest -v docs/*.md - # name: Test examples - # run: | - # (cd example; examples=$(grep "^\\$ " *.py | sed "s/.*\\$ //g"); while IFS= read -r line; do PYTHONPATH=.:$PYTHONPATH eval $line; done <<< "$examples") + name: Test examples + run: | + (cd example; examples=$(grep "^\\$ " *.py | sed "s/.*\\$ //g"); while IFS= read -r line; do PYTHONPATH=.:$PYTHONPATH eval $line; done <<< "$examples") diff --git a/docs/quickstart.md b/docs/quickstart.md index 34a5be6c..c5d04e13 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -97,14 +97,14 @@ and backends, as well as bespoke serialization or formatting. >>> >>> result = increment(num=3) >>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> cwl = render(workflow)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: increment-1-num: default: 3 - label: increment-1-num + label: num type: int outputs: out: diff --git a/docs/renderer_tutorial.md b/docs/renderer_tutorial.md index 43cec19e..2fd1d6cf 100644 --- a/docs/renderer_tutorial.md +++ b/docs/renderer_tutorial.md @@ -49,7 +49,63 @@ class WorkflowDefinition: } ``` -## 3. Create a StepDefinition. +## 3. 
Ensuring Our Module is Recognized as a Render Module
+To have our custom renderer identified by dewret as a valid render module, we need to implement the `BaseRenderModule` protocol along with one of two further protocols: `RawRenderModule` or `StructuredRenderModule`.
+
+#### Implementing BaseRenderModule
+The `BaseRenderModule` protocol defines the foundation for a custom renderer. To implement it, we need to define the `default_config()` method, which provides the default configuration for our renderer.
+
+```python
+def default_config() -> CWLRendererConfiguration:
+    """Default configuration for this renderer.
+
+    This is a hook-like call to give a configuration dict that this renderer
+    will respect, and sets any necessary default values.
+
+    Returns: a dict with (preferably) raw type structures to enable easy setting
+    from YAML/JSON.
+    """
+    return {
+        "allow_complex_types": False,
+        "factories_as_params": False,
+    }
+```
+
+After implementing `BaseRenderModule`, you need to implement either the `RawRenderModule` or the `StructuredRenderModule` protocol, depending on how you want to handle the workflow rendering.
+
+#### Implementing either RawRenderModule or StructuredRenderModule
+The `StructuredRenderModule` is designed for workflows that are rendered directly into a structured, serializable format (e.g. CWL or Snakemake). The key method to implement is `render`, which converts a workflow into a structured, serializable form.
+```python
+def render(
+    self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration
+) -> dict[str, dict[str, RawType]]:
+    """Turn a workflow into a serializable structure.
+
+    Returns: one or more subworkflows with a `__root__` key representing the outermost workflow, at least.
+    """
+    ...
+```
+In this method:
+- You receive a workflow and, potentially, some optional configuration.
+- You return a dictionary whose `__root__` key holds the primary workflow, with any additional subworkflows nested inside the returned structure.
+
+If you prefer more flexibility, and want the structuring to be handled by the user, you can implement the `RawRenderModule` protocol instead. This requires defining the `render_raw` method, which converts a workflow into raw, flat strings.
+```python
+def render_raw(
+    self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration
+) -> dict[str, str]:
+    """Turn a workflow into flat strings.
+
+    Returns: one or more subworkflows with a `__root__` key representing the outermost workflow, at least.
+    """
+    ...
+```
+In this method:
+
+- The workflow is rendered as raw, unstructured strings.
+- The user is responsible for handling the structuring of the rendered output.
+
+## 4. Create a StepDefinition.
 Create a StepDefinition class to create each of the code blocks needed for a rule (step)
 to be executable in Snakemake.
 When you have defined each block of your target workflow language's task from [step 1](#1-understand-the-target-workflow-language),
@@ -128,7 +184,7 @@ class StepDefinition:
     }
 ```
 
-## 4. Create the Separate block definitions.
+## 5. Create the Separate block definitions.
 In this step, you'll define classes to handle the rendering of each code block required for a rule (step) to be executable in the target workflow language.
 Each of these classes will encapsulate the logic for converting parts of a workflow step into the target language format.
@@ -335,7 +391,7 @@ class OutputDefinition:
 Integrate these block definitions into the StepDefinition class as demonstrated in [Step 4](#4-create-a-stepdefinition); a minimal sketch of how the blocks compose follows below.
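As a rough orientation, here is a minimal, hedged sketch of that composition; the class and field names are illustrative simplifications, not the tutorial's exact API:

```python
from attrs import define


@define
class OutputDefinition:
    """Renders the `output:` block of a Snakemake rule."""

    files: list[str]

    def render(self) -> list[str]:
        # One quoted filename per line; the step indents these under `output:`.
        return [f'"{name}",' for name in self.files]


@define
class StepDefinition:
    """Assembles the rendered blocks into a single rule."""

    name: str
    output_block: OutputDefinition

    def render(self) -> str:
        lines = [f"rule {self.name}:", "    output:"]
        lines += [f"        {line}" for line in self.output_block.render()]
        return "\n".join(lines)


print(StepDefinition("generate_report", OutputDefinition(["report.html"])).render())
```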
Each StepDefinition will use these block definitions to render the complete step in the target workflow language. -## 5. Helper methods. +## 6. Helper methods. In this step, you'll define helper methods that will assist you in converting workflow components into the target workflow language format. In our case these methods will handle type conversion, extracting method arguments, and computing relative paths. @@ -401,11 +457,11 @@ import inspect import typing from attrs import define +from dewret.utils import Raw, BasicType from dewret.workflow import Lazy -from dewret.workflow import Reference, Raw, Workflow, Step, Task -from dewret.utils import BasicType +from dewret.workflow import Reference, Workflow, Step, Task -RawType = typing.Union[BasicType, list[str], list["RawType"], dict[str, "RawType"]] +RawType = BasicType | list[str] | list["RawType"] | dict[str, "RawType"] ``` ## To run this example: @@ -416,6 +472,3 @@ RawType = typing.Union[BasicType, list[str], list["RawType"], dict[str, "RawType ```shell python snakemake_tasks.py ``` - -### Q: Should I add a brief description of dewret in step 1? Should link dewret types/docs etc here? -### A: Get details on how that happens and probably yes. \ No newline at end of file diff --git a/docs/workflows.md b/docs/workflows.md index 94b05efb..84d78fe7 100644 --- a/docs/workflows.md +++ b/docs/workflows.md @@ -63,19 +63,15 @@ In code, this would be: ... left=double(num=increment(num=23)), ... right=mod10(num=increment(num=23)) ... ) ->>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(result, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: increment-1-num: default: 23 - label: increment-1-num - type: int - increment-2-num: - default: 23 - label: increment-2-num + label: num type: int outputs: out: @@ -86,7 +82,7 @@ steps: double-1: in: num: - source: increment-2/out + source: increment-1/out out: - out run: double @@ -97,13 +93,6 @@ steps: out: - out run: increment - increment-2: - in: - num: - source: increment-2-num - out: - - out - run: increment mod10-1: in: num: @@ -157,8 +146,8 @@ This duplication can be avoided by explicitly indicating that the parameters are ... left=double(num=increment(num=num)), ... right=mod10(num=increment(num=num)) ... ) ->>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(result, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 @@ -232,8 +221,8 @@ For example: ... return (num + INPUT_NUM) % INPUT_NUM >>> >>> result = rotate(num=5) ->>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(result, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 @@ -244,7 +233,7 @@ inputs: type: int rotate-1-num: default: 5 - label: rotate-1-num + label: num type: int outputs: out: @@ -284,22 +273,24 @@ As code: ```python >>> import sys >>> import yaml ->>> from dewret.tasks import task, construct, nested_task +>>> from dewret.core import set_configuration +>>> from dewret.tasks import task, construct, workflow >>> from dewret.renderers.cwl import render >>> INPUT_NUM = 3 >>> @task() ... def rotate(num: int) -> int: -... """Rotate an integer.""" -... return (num + INPUT_NUM) % INPUT_NUM +... """Rotate an integer.""" +... 
return (num + INPUT_NUM) % INPUT_NUM >>> ->>> @nested_task() +>>> @workflow() ... def double_rotate(num: int) -> int: -... """Rotate an integer twice.""" -... return rotate(num=rotate(num=num)) +... """Rotate an integer twice.""" +... return rotate(num=rotate(num=num)) >>> ->>> result = double_rotate(num=3) ->>> workflow = construct(result, simplify_ids=True) ->>> cwl = render(workflow) +>>> with set_configuration(flatten_all_nested=True): +... result = double_rotate(num=3) +... wkflw = construct(result, simplify_ids=True) +... cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 @@ -315,7 +306,7 @@ inputs: outputs: out: label: out - outputSource: rotate-2/out + outputSource: rotate-1/out type: int steps: rotate-1: @@ -323,7 +314,7 @@ steps: INPUT_NUM: source: INPUT_NUM num: - source: num + source: rotate-2/out out: - out run: rotate @@ -332,7 +323,7 @@ steps: INPUT_NUM: source: INPUT_NUM num: - source: rotate-1/out + source: num out: - out run: rotate @@ -349,7 +340,7 @@ For example, the following code renders the same workflow as in the previous exa ```python -@nested_task() +@workflow() def double_rotate(num: int) -> int: """Rotate an integer twice.""" unused_var = increment(num=num) @@ -409,19 +400,15 @@ As code: ... left=shuffle(max_cards_per_suit=13).hearts, ... right=shuffle(max_cards_per_suit=13).diamonds ... ) ->>> workflow = construct(red_total, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(red_total, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: shuffle-1-max_cards_per_suit: default: 13 - label: shuffle-1-max_cards_per_suit - type: int - shuffle-2-max_cards_per_suit: - default: 13 - label: shuffle-2-max_cards_per_suit + label: max_cards_per_suit type: int outputs: out: @@ -447,28 +434,10 @@ steps: label: spades type: int run: shuffle - shuffle-2: - in: - max_cards_per_suit: - source: shuffle-2-max_cards_per_suit - out: - clubs: - label: clubs - type: int - diamonds: - label: diamonds - type: int - hearts: - label: hearts - type: int - spades: - label: spades - type: int - run: shuffle sum-1: in: left: - source: shuffle-2/hearts + source: shuffle-1/hearts right: source: shuffle-1/diamonds out: @@ -510,19 +479,15 @@ Here, we show the same example with `dataclasses`. ... left=shuffle(max_cards_per_suit=13).hearts, ... right=shuffle(max_cards_per_suit=13).diamonds ... ) ->>> workflow = construct(red_total, simplify_ids=True) ->>> cwl = render(workflow) +>>> wkflw = construct(red_total, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: shuffle-1-max_cards_per_suit: default: 13 - label: shuffle-1-max_cards_per_suit - type: int - shuffle-2-max_cards_per_suit: - default: 13 - label: shuffle-2-max_cards_per_suit + label: max_cards_per_suit type: int outputs: out: @@ -548,28 +513,10 @@ steps: label: spades type: int run: shuffle - shuffle-2: - in: - max_cards_per_suit: - source: shuffle-2-max_cards_per_suit - out: - clubs: - label: clubs - type: int - diamonds: - label: diamonds - type: int - hearts: - label: hearts - type: int - spades: - label: spades - type: int - run: shuffle sum-1: in: left: - source: shuffle-2/hearts + source: shuffle-1/hearts right: source: shuffle-1/diamonds out: @@ -589,7 +536,7 @@ dewret will produce multiple output workflows that reference each other. 
>>> import yaml >>> from attrs import define >>> from numpy import random ->>> from dewret.tasks import task, construct, subworkflow +>>> from dewret.tasks import task, construct, workflow >>> from dewret.renderers.cwl import render >>> @define ... class PackResult: @@ -611,21 +558,21 @@ dewret will produce multiple output workflows that reference each other. ... spades=random.randint(max_cards_per_suit), ... diamonds=random.randint(max_cards_per_suit) ... ) ->>> @subworkflow() -... def red_total(): +>>> @workflow() +... def red_total() -> int: ... return sum( ... left=shuffle(max_cards_per_suit=13).hearts, ... right=shuffle(max_cards_per_suit=13).diamonds ... ) ->>> @subworkflow() -... def black_total(): +>>> @workflow() +... def black_total() -> int: ... return sum( ... left=shuffle(max_cards_per_suit=13).spades, ... right=shuffle(max_cards_per_suit=13).clubs ... ) >>> total = sum(left=red_total(), right=black_total()) ->>> workflow = construct(total, simplify_ids=True) ->>> cwl, subworkflows = render(workflow) +>>> wkflw = construct(total, simplify_ids=True) +>>> cwl = render(wkflw)["__root__"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 @@ -667,7 +614,7 @@ as a second term. >>> import yaml >>> from attrs import define >>> from numpy import random ->>> from dewret.tasks import task, construct, subworkflow +>>> from dewret.tasks import task, construct, workflow >>> from dewret.renderers.cwl import render >>> @define ... class PackResult: @@ -689,33 +636,25 @@ as a second term. ... def sum(left: int, right: int) -> int: ... return left + right >>> ->>> @subworkflow() -... def red_total(): +>>> @workflow() +... def red_total() -> int: ... return sum( ... left=shuffle(max_cards_per_suit=13).hearts, ... right=shuffle(max_cards_per_suit=13).diamonds ... ) ->>> @subworkflow() -... def black_total(): +>>> @workflow() +... def black_total() -> int: ... return sum( ... left=shuffle(max_cards_per_suit=13).spades, ... right=shuffle(max_cards_per_suit=13).clubs ... ) >>> total = sum(left=red_total(), right=black_total()) ->>> workflow = construct(total, simplify_ids=True) ->>> cwl, subworkflows = render(workflow) ->>> yaml.dump(subworkflows["red_total-1"], sys.stdout, indent=2) +>>> wkflw = construct(total, simplify_ids=True) +>>> cwl = render(wkflw) +>>> yaml.dump(cwl["red_total-1"], sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 -inputs: - shuffle-1-1-max_cards_per_suit: - default: 13 - label: shuffle-1-1-max_cards_per_suit - type: int - shuffle-1-2-max_cards_per_suit: - default: 13 - label: shuffle-1-2-max_cards_per_suit - type: int +inputs: {} outputs: out: label: out @@ -725,25 +664,7 @@ steps: shuffle-1-1: in: max_cards_per_suit: - source: shuffle-1-1-max_cards_per_suit - out: - clubs: - label: clubs - type: int - diamonds: - label: diamonds - type: int - hearts: - label: hearts - type: int - spades: - label: spades - type: int - run: shuffle - shuffle-1-2: - in: - max_cards_per_suit: - source: shuffle-1-2-max_cards_per_suit + default: 13 out: clubs: label: clubs @@ -761,7 +682,7 @@ steps: sum-1-1: in: left: - source: shuffle-1-2/hearts + source: shuffle-1-1/hearts right: source: shuffle-1-1/diamonds out: @@ -783,7 +704,7 @@ Below is the default output, treating `Pack` as a task. 
```python
>>> import sys
>>> import yaml
->>> from dewret.tasks import subworkflow, factory, nested_task, construct, task
+>>> from dewret.tasks import workflow, factory, construct, task
>>> from attrs import define
>>> from dewret.renderers.cwl import render
>>> @define
@@ -799,39 +720,39 @@ ... def sum(left: int, right: int) -> int:
...     return left + right
>>>
->>> @nested_task()
-... def black_total(pack: PackResult):
+>>> @workflow()
+... def black_total(pack: PackResult) -> int:
...     return sum(
...         left=pack.spades,
...         right=pack.clubs
...     )
>>> pack = Pack(hearts=13, spades=13, diamonds=13, clubs=13)
->>> workflow = construct(black_total(pack=pack), simplify_ids=True)
->>> cwl = render(workflow)
+>>> wkflw = construct(black_total(pack=pack), simplify_ids=True)
+>>> cwl = render(wkflw)["__root__"]
>>> yaml.dump(cwl, sys.stdout, indent=2)
class: Workflow
cwlVersion: 1.2
inputs:
  PackResult-1-clubs:
    default: 13
-    label: PackResult-1-clubs
+    label: clubs
    type: int
  PackResult-1-diamonds:
    default: 13
-    label: PackResult-1-diamonds
+    label: diamonds
    type: int
  PackResult-1-hearts:
    default: 13
-    label: PackResult-1-hearts
+    label: hearts
    type: int
  PackResult-1-spades:
    default: 13
-    label: PackResult-1-spades
+    label: spades
    type: int
outputs:
  out:
    label: out
-    outputSource: sum-1/out
+    outputSource: black_total-1/out
    type: int
steps:
  PackResult-1:
    in:
@@ -858,15 +779,13 @@ steps:
        label: spades
        type: int
    run: PackResult
-  sum-1:
+  black_total-1:
    in:
-      left:
-        source: PackResult-1/spades
-      right:
-        source: PackResult-1/clubs
+      pack:
+        source: PackResult-1/out
    out:
    - out
-    run: sum
+    run: black_total
```

@@ -876,7 +795,7 @@ types are allowed.
```python
>>> import sys
>>> import yaml
->>> from dewret.tasks import task, factory, nested_task, construct
+>>> from dewret.tasks import task, factory, workflow, construct
>>> from attrs import define
>>> from dewret.renderers.cwl import render
>>> @define
@@ -891,36 +810,36 @@ ... def sum(left: int, right: int) -> int:
...     return left + right
>>>
->>> @nested_task()
-... def black_total(pack: PackResult):
+>>> @workflow()
+... def black_total(pack: PackResult) -> int:
...     return sum(
...         left=pack.spades,
...         right=pack.clubs
...
) >>> pack = Pack(hearts=13, spades=13, diamonds=13, clubs=13) ->>> workflow = construct(black_total(pack=pack), simplify_ids=True) ->>> cwl = render(workflow, allow_complex_types=True, factories_as_params=True) +>>> wkflw = construct(black_total(pack=pack), simplify_ids=True) +>>> cwl = render(wkflw, allow_complex_types=True, factories_as_params=True)["black_total-1"] >>> yaml.dump(cwl, sys.stdout, indent=2) class: Workflow cwlVersion: 1.2 inputs: - PackResult-1: - label: PackResult-1 + pack: + label: pack type: record outputs: out: label: out - outputSource: sum-1/out + outputSource: sum-1-1/out type: int steps: - sum-1: + sum-1-1: in: left: - source: PackResult-1/spades + source: pack/spades right: - source: PackResult-1/clubs + source: pack/clubs out: - out run: sum -``` \ No newline at end of file +``` diff --git a/example/snakemake_renderer/basic_example/snakemake_workflow.py b/example/snakemake_renderer/basic_example/snakemake_workflow.py index 8967a892..1d4d261c 100644 --- a/example/snakemake_renderer/basic_example/snakemake_workflow.py +++ b/example/snakemake_renderer/basic_example/snakemake_workflow.py @@ -172,5 +172,5 @@ def generate_report(processed_data: str, multiple_arg: str, output_file: str) -> "]": "", } ) - smk_output = yaml.dump(smk_output, indent=4).translate(trans_table) - file.write(smk_output) + smk_text = yaml.dump(smk_output, indent=4).translate(trans_table) + file.write(smk_text) diff --git a/example/workflow_complex.py b/example/workflow_complex.py index a0a2dd28..d503c272 100644 --- a/example/workflow_complex.py +++ b/example/workflow_complex.py @@ -7,13 +7,13 @@ ``` """ -from dewret.tasks import subworkflow +from dewret.tasks import workflow from workflow_tasks import sum, double, increase STARTING_NUMBER: int = 23 -@subworkflow() +@workflow() def nested_workflow() -> int | float: """Creates a complex workflow with a nested task. 
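For context on the `subworkflow`/`nested_task` → `workflow` rename applied throughout this changeset, a minimal sketch of the renamed decorator in use (the `increase` task is assumed here, mirroring `workflow_tasks` from the examples):

```python
from dewret.tasks import task, workflow, construct
from dewret.renderers.cwl import render


@task()
def increase(num: int) -> int:
    """Increase an integer by one."""
    return num + 1


@workflow()  # formerly @subworkflow() / @nested_task()
def double_increase(num: int) -> int:
    """Chain two increases as a reusable (sub)workflow."""
    return increase(num=increase(num=num))


wkflw = construct(double_increase(num=23), simplify_ids=True)
cwl = render(wkflw)["__root__"]  # render() now returns a dict of (sub)workflows
```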
diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..66dca830 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,7 @@ +[mypy] +strict = True +# Required to allow subclassing of sympy.Symbol +disallow_subclassing_any = False + +[mypy-sympy.*] +ignore_missing_imports = true diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 059ec406..92d947d2 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -20,13 +20,25 @@ """ import importlib +import importlib.util from pathlib import Path -import yaml +from contextlib import contextmanager import sys +import re +import yaml +from typing import Any, IO, Generator +from types import ModuleType import click import json -from .renderers.cwl import render as cwl_render +from .core import ( + set_configuration, + set_render_configuration, + RawRenderModule, + StructuredRenderModule, +) +from .utils import load_module_or_package +from .render import get_render_method, write_rendered_output from .tasks import Backend, construct @@ -45,11 +57,35 @@ default=Backend.DASK.name, help="Backend to use for workflow evaluation.", ) -@click.argument("workflow_py") +@click.option( + "--construct-args", + default="simplify_ids:true" +) +@click.option( + "--renderer", + default="cwl" +) +@click.option( + "--renderer-args", + default="" +) +@click.option( + "--output", + default="-" +) +@click.argument("workflow_py", type=click.Path(exists=True, path_type=Path)) @click.argument("task") @click.argument("arguments", nargs=-1) def render( - workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend + workflow_py: Path, + task: str, + arguments: list[str], + pretty: bool, + backend: Backend, + construct_args: str, + renderer: str, + renderer_args: str, + output: str, ) -> None: """Render a workflow. @@ -57,10 +93,7 @@ def render( TASK is the name of (decorated) task in workflow module. ARGUMENTS is zero or more pairs representing constant arguments to pass to the task, in the format `key:val` where val is a JSON basic type. 
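+
+    For example (illustrative; `workflow.py` and `increment` are hypothetical):
+
+        python -m dewret --pretty workflow.py increment num:3
+        python -m dewret --renderer @my_renderer.py --output "out/%.cwl" workflow.py increment num:3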
""" - sys.path.append(str(Path(workflow_py).parent)) - loader = importlib.machinery.SourceFileLoader("workflow", workflow_py) - workflow = loader.load_module() - task_fn = getattr(workflow, task) + sys.path.append(str(workflow_py.parent)) kwargs = {} for arg in arguments: if ":" not in arg: @@ -70,18 +103,72 @@ def render( key, val = arg.split(":", 1) kwargs[key] = json.loads(val) + render_module: Path | ModuleType + if mtch := re.match(r"^([a-z_0-9-.]+)$", renderer): + render_module = importlib.import_module(f"dewret.renderers.{mtch.group(1)}") + if not isinstance(render_module, RawRenderModule) and not isinstance(render_module, StructuredRenderModule): + raise NotImplementedError("The imported render module does not seem to match the `RawRenderModule` or `StructuredRenderModule` protocols.") + elif renderer.startswith("@"): + render_module = Path(renderer[1:]) + else: + raise RuntimeError("Renderer argument should be a known dewret renderer, or '@FILENAME' where FILENAME is a renderer") + + if construct_args.startswith("@"): + with Path(construct_args[1:]).open() as construct_args_f: + construct_kwargs = yaml.safe_load(construct_args_f) + elif not construct_args: + construct_kwargs = {} + else: + construct_kwargs = dict(pair.split(":") for pair in construct_args.split(",")) + + renderer_kwargs: dict[str, Any] + if renderer_args.startswith("@"): + with Path(renderer_args[1:]).open() as renderer_args_f: + renderer_kwargs = yaml.safe_load(renderer_args_f) + elif not renderer_args: + renderer_kwargs = {} + else: + renderer_kwargs = dict(pair.split(":") for pair in renderer_args.split(",")) + + if output == "-": + + @contextmanager + def _opener(key: str, _: str) -> Generator[IO[Any], None, None]: + print(" ------ ", key, " ------ ") + yield sys.stdout + print() + + opener = _opener + else: + + @contextmanager + def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]: + output_file = output.replace("%", key) + with Path(output_file).open(mode) as output_f: + yield output_f + + opener = _opener + + render = get_render_method(render_module, pretty=pretty) + pkg = "__workflow__" + workflow = load_module_or_package(pkg, workflow_py) + task_fn = getattr(workflow, task) + try: - cwl = cwl_render(construct(task_fn(**kwargs), simplify_ids=True)) + with ( + set_configuration(**construct_kwargs), + set_render_configuration(renderer_kwargs), + ): + rendered = render( + construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs + ) except Exception as exc: import traceback print(exc, exc.__cause__, exc.__context__) traceback.print_exc() else: - if pretty: - yaml.dump(cwl, sys.stdout, indent=2) - else: - print(cwl) + write_rendered_output(rendered, output, opener) render() diff --git a/src/dewret/annotations.py b/src/dewret/annotations.py new file mode 100644 index 00000000..c1449d43 --- /dev/null +++ b/src/dewret/annotations.py @@ -0,0 +1,224 @@ +# Copyright 2024- Flax & Teal Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tooling for managing annotations. + +Provides `FunctionAnalyser`, a toolkit that takes a `Callable` and can interrogate it +for annotations, with some intelligent searching beyond the obvious location. +""" + +import inspect +import ast +import sys +import importlib +from functools import lru_cache +from types import FunctionType, ModuleType +from typing import ( + Any, + TypeVar, + Annotated, + Callable, + get_origin, + get_args, + Mapping, + get_type_hints, +) + +T = TypeVar("T") +AtRender = Annotated[T, "AtRender"] +Fixed = Annotated[T, "Fixed"] + + +class FunctionAnalyser: + """Convenience class for analysing a function with reduced duplication of effort. + + Attributes: + _fn: the wrapped callable + _annotations: stored annotations for the function. + """ + + _fn: Callable[..., Any] + _annotations: dict[str, Any] + + def __init__(self, fn: Callable[..., Any]): + """Set the function. + + If `fn` is a class, it takes the constructor, and if it is a method, it takes + the `__func__` attribute. + """ + if inspect.isclass(fn): + self.fn = fn.__init__ + elif inspect.ismethod(fn): + self.fn = fn.__func__ + else: + self.fn = fn + + @property + def return_type(self) -> Any: + """Return type of the callable. + + Returns: expected type of the return value. + + Raises: + ValueError: if the return value does not appear to be type-hinted. + """ + hints = get_type_hints(inspect.unwrap(self.fn), include_extras=True) + if "return" not in hints or hints["return"] is None: + raise ValueError(f"Could not find type-hint for return value of {self.fn}") + typ = hints["return"] + return typ + + @staticmethod + def _typ_has(typ: type, annotation: type) -> bool: + """Check if the type has an annotation. + + Args: + typ: type to check. + annotation: the Annotated to compare against. + + Returns: True if the type has the given annotation, otherwise False. + """ + if not hasattr(annotation, "__metadata__"): + return False + if origin := get_origin(typ): + if ( + origin is Annotated + and hasattr(typ, "__metadata__") + and typ.__metadata__ == annotation.__metadata__ + ): + return True + if any(FunctionAnalyser._typ_has(arg, annotation) for arg in get_args(typ)): + return True + return False + + def get_all_module_names(self) -> dict[str, Any]: + """Find all of the annotations within this module.""" + return get_type_hints(sys.modules[self.fn.__module__], include_extras=True) + + def get_all_imported_names(self) -> dict[str, tuple[ModuleType, str]]: + """Find all of the annotations that were imported into this module.""" + return self._get_all_imported_names(sys.modules[self.fn.__module__]) + + @staticmethod + @lru_cache + def _get_all_imported_names(mod: ModuleType) -> dict[str, tuple[ModuleType, str]]: + """Get all of the names with this module, and their original locations. + + Args: + mod: a module in the `sys.modules`. + + Returns: + A dict whose keys are the known names in the current module, where the Callable lives, + and whose values are pairs of the module and the remote name. 
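+
+            For example (illustrative), `from os import path as p` in the
+            Callable's module maps `"p"` to `(<module 'os'>, "path")`.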
+ """ + ast_tree = ast.parse(inspect.getsource(mod)) + imported_names = {} + for node in ast.walk(ast_tree): + if isinstance(node, ast.ImportFrom): + for name in node.names: + imported_names[name.asname or name.name] = ( + importlib.import_module( + "".join(["."] * node.level) + (node.module or ""), + package=mod.__package__, + ), + name.name, + ) + return imported_names + + @property + def free_vars(self) -> dict[str, Any]: + """Get the free variables for this Callable.""" + if self.fn.__code__ and self.fn.__closure__: + return dict( + zip( + self.fn.__code__.co_freevars, + (c.cell_contents for c in self.fn.__closure__), + strict=False, + ) + ) + return {} + + def get_argument_annotation(self, arg: str, exhaustive: bool = False) -> Any: + """Retrieve the annotations for this argument. + + Args: + arg: name of the argument. + exhaustive: True if we should search outside the function itself, into the module globals, and into imported modules. + + Returns: annotation if available, else None. + """ + typ: type | None = None + if typ := self.fn.__annotations__.get(arg): + if isinstance(typ, str): + typ = get_type_hints(self.fn, include_extras=True).get(arg) + elif exhaustive: + if anns := get_type_hints( + sys.modules[self.fn.__module__], include_extras=True + ): + if typ := anns.get(arg): + ... + elif orig_pair := self.get_all_imported_names().get(arg): + orig_module, orig_name = orig_pair + typ = orig_module.__annotations__.get(orig_name) + elif value := self.free_vars.get(arg): + if not inspect.isclass(value) or inspect.isfunction(value): + raise RuntimeError( + f"Cannot use free variables - please put {arg} at the global scope" + ) + return typ + + def argument_has( + self, arg: str, annotation: type, exhaustive: bool = False + ) -> bool: + """Check if the named argument has the given annotation. + + Args: + arg: argument to retrieve. + annotation: Annotated to search for. + exhaustive: whether to check the globals and other modules. + + Returns: True if the Annotated is present in this argument's annotation. + """ + typ = self.get_argument_annotation(arg, exhaustive) + return bool(typ and self._typ_has(typ, annotation)) + + def is_at_construct_arg(self, arg: str, exhaustive: bool = False) -> bool: + """Convience function to check for `AtConstruct`, wrapping `FunctionAnalyser.argument_has`.""" + return self.argument_has(arg, AtRender, exhaustive) + + @property + def globals(self) -> Mapping[str, Any]: + """Get the globals for this Callable.""" + try: + fn_tuple = inspect.getclosurevars(self.fn) + fn_globals = dict(fn_tuple.globals) + fn_globals.update(fn_tuple.nonlocals) + # This covers the case of wrapping, rather than decorating. 
+ except TypeError: + fn_globals = {} + return fn_globals + + def with_new_globals(self, new_globals: dict[str, Any]) -> Callable[..., Any]: + """Create a Callable that will run the current Callable with new globals.""" + code = self.fn.__code__ + fn_name = self.fn.__name__ + all_globals = dict(self.globals) + all_globals.update(new_globals) + return FunctionType( + code, + all_globals, + name=fn_name, + closure=self.fn.__closure__, + argdefs=self.fn.__defaults__, + ) diff --git a/src/dewret/backends/_base.py b/src/dewret/backends/_base.py index 3ad11bec..fd1a830e 100644 --- a/src/dewret/backends/_base.py +++ b/src/dewret/backends/_base.py @@ -18,6 +18,7 @@ """ from typing import Protocol, Any +from concurrent.futures import ThreadPoolExecutor from dewret.workflow import LazyFactory, Lazy, Workflow, StepReference, Target class BackendModule(Protocol): @@ -32,12 +33,13 @@ class BackendModule(Protocol): """ lazy: LazyFactory - def run(self, workflow: Workflow, task: Lazy) -> StepReference[Any]: + def run(self, workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread_pool: ThreadPoolExecutor | None=None) -> StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any]]: """Execute a lazy task for this `Workflow`. Args: workflow: `Workflow` that is being executed. task: task that forms the output. + thread_pool: the thread pool that should be used for this execution. Returns: Reference to the final output step. diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index 017fb502..47228d20 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -18,8 +18,10 @@ """ from dask.delayed import delayed, DelayedLeaf -from dewret.workflow import Workflow, Lazy, StepReference, Target +from dask.config import config from typing import Protocol, runtime_checkable, Any, cast +from concurrent.futures import ThreadPoolExecutor +from dewret.workflow import Workflow, Lazy, StepReference, Target @runtime_checkable @@ -33,6 +35,11 @@ class Delayed(Protocol): More info: https://github.com/dask/dask/issues/7779 """ + @property + def __dask_graph__(self): # type: ignore + """Retrieve the dask graph.""" + ... + def compute(self, __workflow__: Workflow | None) -> StepReference[Any]: """Evaluate this `dask.delayed`. @@ -81,13 +88,20 @@ def is_lazy(task: Any) -> bool: Returns: True if so, False otherwise. """ - return isinstance(task, Delayed) + return isinstance(task, Delayed) or ( + isinstance(task, tuple | list) and all(is_lazy(elt) for elt in task) + ) lazy = delayed -def run(workflow: Workflow | None, task: Lazy) -> StepReference[Any]: +def run( + workflow: Workflow | None, + task: Lazy | list[Lazy] | tuple[Lazy], + thread_pool: ThreadPoolExecutor | None = None, + **kwargs: Any, +) -> Any: """Execute a task as the output of a workflow. Runs a task with dask. @@ -95,11 +109,15 @@ def run(workflow: Workflow | None, task: Lazy) -> StepReference[Any]: Args: workflow: `Workflow` in which to record the execution. task: `dask.delayed` function, wrapped by dewret, that we wish to compute. + thread_pool: custom thread pool for executing workflows, copies in correct values for contextvars to each thread before they are accessed by a dask worker. + **kwargs: any configuration arguments for this backend. """ - # We need isinstance to reassure type-checker. - if not isinstance(task, Delayed) or not is_lazy(task): - raise RuntimeError( - f"{task} is not a dask delayed, perhaps you tried to mix backends?" 
- ) - result = task.compute(__workflow__=workflow) + # def _check_delayed was here, but we decided to delegate this to dask + + if isinstance(task, Delayed) and is_lazy(task): + computable = task + else: + computable = delayed(task) + config["pool"] = thread_pool + result = computable.compute(__workflow__=workflow) return result diff --git a/src/dewret/core.py b/src/dewret/core.py new file mode 100644 index 00000000..eecf3033 --- /dev/null +++ b/src/dewret/core.py @@ -0,0 +1,626 @@ +# Copyright 2024- Flax & Teal Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base classes that need to be available everywhere. + +Mainly tooling around configuration, protocols and superclasses for References +and WorkflowComponents, that are concretized elsewhere. +""" + +from dataclasses import dataclass +import importlib +import base64 +from attrs import define +from functools import lru_cache +from typing import ( + Generic, + TypeVar, + Protocol, + Iterator, + Unpack, + TypedDict, + NotRequired, + Generator, + Any, + get_args, + get_origin, + Annotated, + Callable, + cast, + runtime_checkable, +) +from contextlib import contextmanager +from contextvars import ContextVar +from sympy import Expr, Symbol, Basic +import copy + +BasicType = str | float | bool | bytes | int | None +RawType = BasicType | list["RawType"] | dict[str, "RawType"] +FirmType = RawType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] +ExprType = (FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...]) # type: ignore + +U = TypeVar("U") +T = TypeVar("T") + + +def strip_annotations(parent_type: type) -> tuple[type, tuple[str]]: + """Discovers and removes annotations from a parent type. + + Args: + parent_type: a type, possibly Annotated. + + Returns: a parent type that is not Annotation, along with any stripped metadata, if present. + + """ + # Strip out any annotations. This should be auto-flattened, so in theory only one iteration could occur. + metadata = [] + while get_origin(parent_type) is Annotated: + parent_type, *parent_metadata = get_args(parent_type) + metadata += list(parent_metadata) + return parent_type, tuple(metadata) + + +# Generic type for configuration settings for the renderer +RenderConfiguration = dict[str, RawType] + + +class WorkflowProtocol(Protocol): + """Expected structure for a workflow. + + We do not expect various workflow implementations, but this allows us to define the + interface expected by the core classes. + """ + + def remap(self, name: str) -> str: + """Perform any name-changing for steps, etc. in the workflow. + + This enables, for example, simplifying all the IDs to an integer sequence. + + Returns: remapped name. + """ + ... + + def set_result(self, result: Basic | list[Basic] | tuple[Basic]) -> None: + """Set the step that should produce a result for the overall workflow.""" + ... 
+
+    def simplify_ids(self, infix: list[str] | None = None) -> None:
+        """Drop the non-human-readable IDs if possible, in favour of integer sequences.
+
+        Args:
+            infix: any inherited intermediary identifiers, to allow nesting, or None.
+        """
+        ...
+
+
+class BaseRenderModule(Protocol):
+    """Common routines for all renderer modules."""
+
+    @staticmethod
+    def default_config() -> dict[str, RawType]:
+        """Retrieve default settings.
+
+        These will not change during execution, but can be overridden by `dewret.core.set_render_configuration`.
+
+        Returns: a static, serializable dict.
+        """
+        ...
+
+
+@runtime_checkable
+class RawRenderModule(BaseRenderModule, Protocol):
+    """Render module that returns raw text."""
+
+    def render_raw(
+        self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration
+    ) -> dict[str, str]:
+        """Turn a workflow into flat strings.
+
+        Returns: one or more subworkflows with a `__root__` key representing the outermost workflow, at least.
+        """
+        ...
+
+
+@runtime_checkable
+class StructuredRenderModule(BaseRenderModule, Protocol):
+    """Render module that returns JSON/YAML-serializable structures."""
+
+    def render(
+        self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration
+    ) -> dict[str, dict[str, RawType]]:
+        """Turn a workflow into a serializable structure.
+
+        Returns: one or more subworkflows with a `__root__` key representing the outermost workflow, at least.
+        """
+        ...
+
+
+class RenderCall(Protocol):
+    """Callable that will render out workflow(s)."""
+
+    def __call__(
+        self, workflow: WorkflowProtocol, **kwargs: RenderConfiguration
+    ) -> dict[str, str] | dict[str, RawType]:
+        """Take a workflow and turn it into a set of serializable (sub)workflows.
+
+        Args:
+            workflow: root workflow.
+            kwargs: configuration for the renderer.
+
+        Returns: a mapping of keys to serialized workflows, containing at least `__root__`.
+        """
+        ...
+
+
+class UnevaluatableError(Exception):
+    """Signposts that a user has tried to treat a reference as the real (runtime) value.
+
+    For example, by comparing to a concrete integer or value, etc.
+    """
+
+    ...
+
+
+@define
+class ConstructConfiguration:
+    """Basic configuration of the construction process.
+
+    Holds configuration that may be relevant to `construct(...)` calls or, realistically,
+    anything prior to rendering. It should hold generic configuration that is renderer-independent.
+    """
+
+    flatten_all_nested: bool = False
+    allow_positional_args: bool = False
+    allow_plain_dict_fields: bool = False
+    field_separator: str = "/"
+    field_index_types: str = "int"
+    simplify_ids: bool = False
+
+
+class ConstructConfigurationTypedDict(TypedDict):
+    """Basic configuration of the construction process.
+
+    Holds configuration that may be relevant to `construct(...)` calls or, realistically,
+    anything prior to rendering. It should hold generic configuration that is renderer-independent.
+
+    **THIS MUST BE KEPT IDENTICAL TO ConstructConfiguration.**
+    """
+
+    flatten_all_nested: NotRequired[bool]
+    allow_positional_args: NotRequired[bool]
+    allow_plain_dict_fields: NotRequired[bool]
+    field_separator: NotRequired[str]
+    field_index_types: NotRequired[str]
+    simplify_ids: NotRequired[bool]
+
+
+@define
+class GlobalConfiguration:
+    """Overall configuration structure.
+
+    Having a single configuration dict allows us to manage only one ContextVar.
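+
+    A hedged sketch of typical use, via `set_configuration` (defined below):
+
+        with set_configuration(flatten_all_nested=True):
+            ...  # construct-time code here sees the overridden setting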
+ """ + + construct: ConstructConfiguration + render: dict[str, RawType] + + +CONFIGURATION: ContextVar[GlobalConfiguration] = ContextVar("configuration") + + +@contextmanager +def set_configuration( + **kwargs: Unpack[ConstructConfigurationTypedDict], +) -> Iterator[ContextVar[GlobalConfiguration]]: + """Sets the construct-time configuration. + + This is a context manager, so that a setting can be temporarily overridden and automatically restored. + """ + with _set_configuration() as CONFIGURATION: + for key, value in kwargs.items(): + setattr(CONFIGURATION.get().construct, key, value) + yield CONFIGURATION + + +@contextmanager +def set_render_configuration( + kwargs: dict[str, RawType], +) -> Iterator[ContextVar[GlobalConfiguration]]: + """Sets the render-time configuration. + + This is a context manager, so that a setting can be temporarily overridden and automatically restored. + + Returns: the yielded global configuration ContextVar. + """ + with _set_configuration() as CONFIGURATION: + CONFIGURATION.get().render.update(**kwargs) + yield CONFIGURATION + + +@contextmanager +def _set_configuration() -> Iterator[ContextVar[GlobalConfiguration]]: + """Prepares and tidied up the configuration for applying settings. + + This is a context manager, so that a setting can be temporarily overridden and automatically restored. + """ + try: + previous = CONFIGURATION.get() + except LookupError: + previous = GlobalConfiguration( + construct=ConstructConfiguration(), render=default_renderer_config() + ) + CONFIGURATION.set(previous) + previous = copy.deepcopy(previous) + + try: + yield CONFIGURATION + finally: + CONFIGURATION.set(previous) + + +@lru_cache +def default_renderer_config() -> RenderConfiguration: + """Gets the default renderer configuration. + + This may be called frequently, but is cached so note that any changes to the + wrapped config function will _not_ be reflected during the process. + + It is a light wrapper for `default_config` in the supplier renderer module. + + Returns: the default configuration dict for the chosen renderer. + """ + try: + # We have to use a cast as we do not know if the renderer module is valid. + render_module = cast( + BaseRenderModule, importlib.import_module("__renderer_mod__") + ) + default_config: Callable[[], RenderConfiguration] = render_module.default_config + except ImportError: + return {} + return default_config() + + +@lru_cache +def default_construct_config() -> ConstructConfiguration: + """Gets the default construct-time configuration. + + This is the primary mechanism for configuring dewret internals, so these defaults + should be carefully chosen and, if they change, that likely has an impact on backwards compatibility + from a SemVer perspective. + + Returns: configuration dictionary with default construct values. + """ + return ConstructConfiguration( + flatten_all_nested=False, + allow_positional_args=False, + allow_plain_dict_fields=False, + field_separator="/", + field_index_types="int", + ) + + +def get_configuration(key: str) -> RawType: + """Retrieve the configuration or (silently) return the default. + + Helps avoid a proliferation of `set_configuration` calls by not erroring if it has not been called. + However, the cost is that the user may accidentally put configuration-affected logic outside a + set_configuration call and be surprised that the behaviour is inexplicibly not as expected. + + Args: + key: configuration key to retrieve. + + Returns: (preferably) a JSON/YAML-serializable construct. 
+ """ + try: + return getattr(CONFIGURATION.get().construct, key) # type: ignore + except LookupError: + # TODO: Not sure what the best way to typehint this is. + return getattr(ConstructConfiguration(), key) # type: ignore + + +def get_render_configuration(key: str) -> RawType: + """Retrieve configuration for the active renderer. + + Finds the current user-set configuration, defaulting back to the chosen renderer module's declared + defaults. + + Args: + key: configuration key to retrieve. + + Returns: (preferably) a JSON/YAML-serializable construct. + """ + try: + return CONFIGURATION.get().render.get(key) + except LookupError: + return default_renderer_config().get(key) + + +class WorkflowComponent: + """Base class for anything directly tied to an individual `Workflow`. + + Attributes: + __workflow__: the `Workflow` that this is tied to. + """ + + __workflow_real__: WorkflowProtocol + + def __init__(self, *args: Any, workflow: WorkflowProtocol, **kwargs: Any): + """Tie to a `Workflow`. + + All subclasses must call this. + + Args: + workflow: the `Workflow` to tie to. + *args: remainder of arguments for other initializers. + **kwargs: remainder of arguments for other initializers. + """ + self.__workflow__ = workflow + super().__init__(*args, **kwargs) + + @property + def __workflow__(self) -> WorkflowProtocol: + """Workflow to which this reference applies.""" + return self.__workflow_real__ + + @__workflow__.setter + def __workflow__(self, workflow: WorkflowProtocol) -> None: + """Workflow to which this reference applies.""" + self.__workflow_real__ = workflow + + +class Reference(Generic[U], Symbol, WorkflowComponent): + """Superclass for all symbolic references to values.""" + + _type: type[U] | None = None + __iterated__: bool = False + + def __init__(self, *args: Any, typ: type[U] | None = None, **kwargs: Any): + """Extract any specified type. + + Args: + typ: type to override any inference, or None. + *args: any other arguments for other initializers (e.g. mixins). + **kwargs: any other arguments for other initializers (e.g. mixins). + """ + self._type = typ + super().__init__(*args, **kwargs) + + @property + def name(self) -> str: + """Printable name of the reference.""" + return self.__name__ + + def __new__(cls, *args: Any, **kwargs: Any) -> "Reference[U]": + """As all references are sympy Expressions, the real returned object must be made with Expr.""" + instance = Expr.__new__(cls) + instance._assumptions0 = {} + return cast(Reference[U], instance) + + @property + def __root_name__(self) -> str: + """Root name on which to suffix/prefix any derived names (with fields, etc.). + + For example, the base name of `add_thing-12345[3]` should be `add_thing`. + + Returns: basic name as a string. + """ + raise NotImplementedError( + "Reference must have a '__root_name__' property or override '__name__'" + ) + + @property + def __type__(self) -> type: + """Type of the reference target, if known.""" + if self._type is not None: + return self._type + raise NotImplementedError() + + def _raise_unevaluatable_error(self) -> None: + """Convenience method to consistently formulate an UnevaluatableError for this reference.""" + raise UnevaluatableError( + f"This reference, {self.__name__}, cannot be evaluated during construction." + ) + + def __eq__(self, other: object) -> Any: + """Test equality at construct-time, if sensible. + + Raises: + UnevaluatableError: if it seems the user is confusing this with a runtime check. 
+ """ + if isinstance(other, list) or other is None: + return False + if not isinstance(other, Reference) and not isinstance(other, Basic): + self._raise_unevaluatable_error() + return super().__eq__(other) + + def __float__(self) -> bool: + """Catch accidental float casts. + + Raises: + UnevaluatableError: unconditionally, as it seems the user is confusing this with a runtime check. + """ + self._raise_unevaluatable_error() + return False + + def __int__(self) -> bool: + """Catch accidental int casts. + + Raises: + UnevaluatableError: unconditionally, as it seems the user is confusing this with a runtime check. + """ + self._raise_unevaluatable_error() + return False + + def __bool__(self) -> bool: + """Catch accidental bool casts. + + Note that this means ambiguous checks such as `if ref: ...` will error, and should be `if ref is None: ...` + or similar. + + Raises: + UnevaluatableError: unconditionally, as it seems the user is confusing this with a runtime check. + """ + self._raise_unevaluatable_error() + return False + + @property + def __name__(self) -> str: + """Referral name for this reference. + + Returns: an internal name to refer to the reference target. + """ + workflow = self.__workflow__ + name = self.__root_name__ + return workflow.remap(name) if workflow is not None else name + + def __str__(self) -> str: + """Global description of the reference. + + Returns the _internal_ name. + """ + return self.__name__ + + +class IterableMixin(Reference[U]): + """Functionality for iterating over references to give new references.""" + + __fixed_len__: int | None = None + + def __init__(self, typ: type[U] | None = None, **kwargs: Any): + """Extract length, if available from type. + + Captures types of the form (e.g.) `tuple[int, float]` and records the length + as being (e.g.) 2. + """ + super().__init__(typ=typ, **kwargs) + base = strip_annotations(self.__type__)[0] + if get_origin(base) == tuple and (args := get_args(base)): + # In the special case of an explicitly-typed tuple, we can state a length. + self.__fixed_len__ = len(args) + + def __len__(self) -> int: + """Length of this iterable, if available. + + The two cases that this is likely to be available are if the reference target + has been type-hinted as a `tuple` with a specific, fixed number of type arguments, + or if the target has been annotated with `Fixed(...)` indicating that the length + of the default value can be hard-coded into the output, and therefore that it can + be used for graph-building logic. The most useful application of this is likely to + be in for-loops and generators, as we can create variable references for each iteration + but can nonetheless execute the loop as we know how many iterations occur. + + Returns: length of the iterable, if available. + """ + if self.__fixed_len__ is None: + raise TypeError( + "This iterable reference does not have a known fixed length, " + "consider using `Fixed[...]` with a default, or typehint using `tuple[int, float]` (or similar) " + "to tell dewret how long it should be." + ) + return self.__fixed_len__ + + def __iter__(self) -> Generator[Reference[U], None, None]: + """Execute the iteration. + + Note that this does _not_ return the yielded values of `__inner_iter__`, so that + it can be used with impunity to do actual iteration, and we will _always_ return + references here. + + Returns: a generator that will give a new reference for every iteration. 
+ """ + for count, _ in enumerate(self.__inner_iter__()): + yield super().__getitem__(count) + + def __inner_iter__(self) -> Generator[Any, None, None]: + """Overrideable iterator for looping over the wrapped object. + + Returns: a generator that will yield ints if this reference is known to be fixed length, or will go + forever yielding None otherwise. + """ + if self.__fixed_len__ is not None: + for i in range(self.__fixed_len__): + yield i + else: + while True: + yield None + + def __getitem__(self, attr: str | int) -> "Reference[U] | Any": + """Get a reference to an individual item/field. + + Args: + attr: index or fieldname. + + Returns: a reference to the same target as this reference but a level deeper. + """ + return super().__getitem__(attr) + + +class IteratedGenerator(Generic[U]): + """Sentinel value for capturing that an iteration has occured without performing it. + + Allows us to lazily evaluate a loop, for instance, in the renderer. This may be relevant + if the renderer wishes to postpone iteration to runtime, and simply record it is required, + rather than evaluating the iterator. + """ + + __wrapped__: Reference[U] + + def __init__(self, to_wrap: Reference[U]): + """Capture wrapped reference. + + Args: + to_wrap: reference to wrap. + """ + self.__wrapped__ = to_wrap + + def __iter__(self) -> Generator[Reference[U], None, None]: + """Loop through the wrapped reference. + + This will tag the references that are returned, so that the renderer can see this has + happened. + + Returns: a generator looping over the wrapped reference with a counter as the "field". + """ + count = -1 + for _ in self.__wrapped__.__inner_iter__(): + ref = self.__wrapped__[(count := count + 1)] + ref.__iterated__ = True + yield ref + + +@dataclass +class Raw: + """Value object for any raw types. + + This is able to hash raw types consistently and provides + a single type for validating type-consistency. + + Attributes: + value: the real value, e.g. a `str`, `int`, ... + """ + + value: RawType + + def __hash__(self) -> int: + """Provide a hash that is unique to the `value` member.""" + return hash(repr(self)) + + def __repr__(self) -> str: + """Convert to a consistent, string representation.""" + value: str + if isinstance(self.value, bytes): + value = base64.b64encode(self.value).decode("ascii") + else: + value = str(self.value) + return f"{type(self.value).__name__}|{value}" diff --git a/src/dewret/render.py b/src/dewret/render.py new file mode 100644 index 00000000..d702dcad --- /dev/null +++ b/src/dewret/render.py @@ -0,0 +1,147 @@ +# Copyright 2024- Flax & Teal Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for building renderers. + +Provides the routines for calling varied renderers in a standard way, and for +renderers to reuse to build up their own functionality. 
+""" + +import sys +from pathlib import Path +from functools import partial +from typing import TypeVar, Callable, ContextManager, IO, Any, cast +import yaml + +from .workflow import Workflow, NestedStep +from .core import ( + RawType, + RenderCall, + BaseRenderModule, + RawRenderModule, + StructuredRenderModule, + RenderConfiguration, +) +from .utils import load_module_or_package + +T = TypeVar("T") + + +def structured_to_raw(rendered: RawType, pretty: bool = False) -> str: + """Serialize a serializable structure to a string. + + Args: + rendered: a possibly-nested, static basic Python structure. + pretty: whether to attempt YAML dumping with an indent of 2. + + Returns: YAML/stringified version of the structure. + """ + if pretty: + output = yaml.safe_dump(rendered, indent=2) + else: + output = str(rendered) + return output + + +def get_render_method( + renderer: Path | RawRenderModule | StructuredRenderModule, pretty: bool = False +) -> RenderCall: + """Create a ready-made callable to render the workflow that is appropriate for the renderer module. + + Args: + renderer: a module or path to a module. + pretty: whether the renderer should attempt to YAML-format the output (if relevant). + + Returns: a callable with a consistent interface, regardless of the renderer type. + """ + render_module: BaseRenderModule + if isinstance(renderer, Path): + # Attempt to load renderer as package, falling back to a single module otherwise. + # This enables relative imports in renderers and therefore the ability to modularize. + module = load_module_or_package("__renderer__", renderer) + sys.modules["__renderer_mod__"] = module + render_module = cast(BaseRenderModule, module) + else: + render_module = renderer + + if isinstance(render_module, RawRenderModule): + return render_module.render_raw + elif isinstance(render_module, (StructuredRenderModule)): + + def _render( + workflow: Workflow, + render_module: StructuredRenderModule, + pretty: bool = False, + **kwargs: RenderConfiguration, + ) -> dict[str, str]: + rendered = render_module.render(workflow, **kwargs) + return { + key: structured_to_raw(value, pretty=pretty) + for key, value in rendered.items() + } + + return cast( + RenderCall, partial(_render, render_module=render_module, pretty=pretty) + ) + + raise NotImplementedError( + "This render module neither seems to be a structured nor a raw render module." + ) + + +def write_rendered_output( + rendered: dict[str, str] | dict[str, RawType], + output: str, + opener: Callable[[str, str], ContextManager[IO[Any]]], +) -> None: + """Utility function to handle writing rendered output to file or stdout.""" + if len(rendered) == 1: + with opener("", "w") as output_f: + output_f.write(rendered["__root__"]) + elif "%" in output: + for key, value in rendered.items(): + if key == "__root__": + key = "ROOT" + with opener(key, "w") as output_f: + output_f.write(value) + else: + with opener("ROOT", "w") as output_f: + output_f.write(rendered["__root__"]) + del rendered["__root__"] + for key, value in rendered.items(): + with opener(key, "a") as output_f: + output_f.write("\n---\n") + output_f.write(value) + + +def base_render(workflow: Workflow, build_cb: Callable[[Workflow], T]) -> dict[str, T]: + """Render to a dict-like structure. + + Args: + workflow: workflow to evaluate result. + build_cb: a callback to call for each workflow found. + + Returns: + Reduced form as a native Python dict structure for + serialization. 
+ """ + primary_workflow = build_cb(workflow) + subworkflows = {} + for step in workflow.indexed_steps.values(): + if isinstance(step, NestedStep): + nested_subworkflows = base_render(step.subworkflow, build_cb) + subworkflows.update(nested_subworkflows) + subworkflows[step.name] = nested_subworkflows["__root__"] + subworkflows["__root__"] = primary_workflow + return subworkflows diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index b51c3e58..18c086f4 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -19,31 +19,135 @@ """ from attrs import define, has as attrs_has, fields as attrs_fields, AttrsInstance -from dataclasses import dataclass, is_dataclass, fields as dataclass_fields +from dataclasses import is_dataclass, fields as dataclass_fields from collections.abc import Mapping -from contextvars import ContextVar -from typing import TypedDict, NotRequired, get_args, Union, cast, Any, Iterable, Unpack +from typing import ( + TypedDict, + NotRequired, + get_origin, + get_args, + cast, + Any, + Unpack, + Iterable, +) from types import UnionType +from inspect import isclass +from sympy import Basic, Tuple, Dict, jscode, Symbol +from dewret.core import ( + Raw, + RawType, + FirmType, +) from dewret.workflow import ( FactoryCall, - Reference, - Raw, Workflow, BaseStep, - NestedStep, StepReference, ParameterReference, + expr_to_references, +) +from dewret.utils import ( + crawl_raw, + DataclassProtocol, + firm_to_raw, + flatten_if_set, Unset, ) -from dewret.utils import RawType, flatten, DataclassProtocol +from dewret.render import base_render +from dewret.core import Reference, get_render_configuration, set_render_configuration + + +class CommandInputSchema(TypedDict): + """Structure for referring to a raw type in CWL. + + Encompasses several CWL types. In future, it may be best to + use _cwltool_ or another library for these basic structures. + + Attributes: + type: CWL type of this input. + label: name to show for this input. + fields: (for `record`) individual fields in a dict-like structure. + items: (for `array`) type that each field will have. + """ + + type: "InputSchemaType" + label: str + fields: NotRequired[dict[str, "CommandInputSchema"]] + items: NotRequired["InputSchemaType"] + default: NotRequired[RawType] + + +InputSchemaType = ( + str + | CommandInputSchema + | list[str] + | list["InputSchemaType"] + | dict[str, "str | InputSchemaType"] +) + + +def render_expression(ref: Any) -> "ReferenceDefinition": + """Turn a rich (sympy) expression into a CWL JS expression. + + Args: + ref: a structure whose elements are all string-renderable or sympy Basic. -InputSchemaType = Union[ - str, "CommandInputSchema", list[str], list["InputSchemaType"], dict[str, str] -] + Returns: a ReferenceDefinition containing a string representation of the expression in the form `$(...)`. 
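+
+    For instance (an illustrative sketch, not a guaranteed rendering), adding 1
+    to a bare step result might produce:
+
+        ReferenceDefinition(source="increment-1/out", value_from="$(self + 1)")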
+ """ + + def _render(ref: Any) -> Basic | RawType: + if not isinstance(ref, Basic): + if isinstance(ref, Mapping): + ref = Dict({key: _render(val) for key, val in ref.items()}) + elif not isinstance(ref, str | bytes) and isinstance(ref, Iterable): + ref = Tuple(*(_render(val) for val in ref)) + return ref + + expr = _render(ref) + if isinstance(expr, Basic): + values = list(expr.free_symbols) + step_syms = [sym for sym in expr.free_symbols if isinstance(sym, StepReference)] + param_syms = [ + sym for sym in expr.free_symbols if isinstance(sym, ParameterReference) + ] + + if set(values) != set(step_syms) | set(param_syms): + raise NotImplementedError( + f"Can only build expressions for step results and param results: {ref}" + ) + + if len(step_syms) > 1: + raise NotImplementedError( + f"Can only create expressions with 1 step reference: {ref}" + ) + if not (step_syms or param_syms): + ... + if values == [ref]: + if isinstance(ref, StepReference): + return ReferenceDefinition(source=to_name(ref), value_from=None) + else: + return ReferenceDefinition(source=ref.name, value_from=None) + source = None + for ref in values: + if isinstance(ref, StepReference): + field = with_field(ref) + parts = field.split("/") + base = f"/{parts[0]}" if parts and parts[0] else "" + if len(parts) > 1: + expr = expr.subs(ref, f"self.{'.'.join(parts[1:])}") + else: + expr = expr.subs(ref, "self") + source = f"{ref.__root_name__}{base}" + else: + expr = expr.subs(ref, Symbol(f"inputs.{ref.name}")) + return ReferenceDefinition( + source=source, value_from=f"$({jscode(_render(expr))})" + ) + return ReferenceDefinition(source=str(expr), value_from=None) -@dataclass class CWLRendererConfiguration(TypedDict): """Configuration for the renderer. @@ -56,30 +160,65 @@ class CWLRendererConfiguration(TypedDict): factories_as_params: NotRequired[bool] -CONFIGURATION: ContextVar[CWLRendererConfiguration] = ContextVar("cwl-configuration") +def default_config() -> CWLRendererConfiguration: + """Default configuration for this renderer. + + This is a hook-like call to give a configuration dict that this renderer + will respect, and sets any necessary default values. + Returns: a dict with (preferably) raw type structures to enable easy setting + from YAML/JSON. + """ + return { + "allow_complex_types": False, + "factories_as_params": False, + } -def set_configuration(configuration: CWLRendererConfiguration) -> None: - """Set configuration for this rendering. - Args: - configuration: overridden settings as dict. +def with_type(result: Any) -> type | Any: + """Get a Python type from a value. + + Does so either by using its `__type__` field (for example, for References) + or if unavailable, using `type()`. + + Returns: a Python type. """ - CONFIGURATION.set( - CWLRendererConfiguration( - allow_complex_types=False, - factories_as_params=False, - ) - ) - CONFIGURATION.get().update(configuration) + if hasattr(result, "__type__"): + return result.__type__ + return type(result) + + +def with_field(result: Any) -> str: + """Get a string representing any 'field' suffix of a value. + + This only makes sense in the context of a Reference, which can represent + a deep reference with a known variable (parameter or step result, say) using + its `__field__` attribute. Defaults to `"out"` as this produces compliant CWL + where every output has a "fieldname". + + Returns: a string representation of the field portion of the passed value or `"out"`. 
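+
+    For example (sketch; `my_step` is an illustrative step reference):
+
+        with_field(my_step)       # "out"
+        with_field(my_step.data)  # "data", via __field_str__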
+ """ + if hasattr(result, "__field__") and result.__field__: + return str(result.__field_str__) + else: + return "out" -def configuration(key: str) -> Any: - """Retrieve current configuration (thread/async-local).""" - current_configuration = CONFIGURATION.get() - if key not in current_configuration: - raise KeyError("Unknown configuration settings.") - return current_configuration.get(key) +def to_name(result: Reference[Any]) -> str: + """Take a reference and get a name representing it. + + The primary purpose of this method is to deal with the case where a reference is to the + whole result, as we always put this into an imagined `out` field for CWL consistency. + + Returns: the name of the reference, including any field portion, appending an `"out"` fieldname if none. + """ + if ( + hasattr(result, "__field__") + and not result.__field__ + and isinstance(result, StepReference) + ): + return f"{result.__name__}/out" + return result.__name__ @define @@ -89,10 +228,11 @@ class ReferenceDefinition: Normally points to a value or a step. """ - source: str + source: str | None + value_from: str | None @classmethod - def from_reference(cls, ref: Reference) -> "ReferenceDefinition": + def from_reference(cls, ref: Reference[Any]) -> "ReferenceDefinition": """Build from a `Reference`. Converts a `dewret.workflow.Reference` into a CWL-rendering object. @@ -100,7 +240,7 @@ def from_reference(cls, ref: Reference) -> "ReferenceDefinition": Args: ref: reference to convert. """ - return cls(source=ref.name) + return render_expression(ref) def render(self) -> dict[str, RawType]: """Render to a dict-like structure. @@ -109,7 +249,12 @@ def render(self) -> dict[str, RawType]: Reduced form as a native Python dict structure for serialization. """ - return {"source": self.source} + representation: dict[str, RawType] = {} + if self.source is not None: + representation["source"] = self.source + if self.value_from is not None: + representation["valueFrom"] = self.value_from + return representation @define @@ -168,93 +313,94 @@ def render(self) -> dict[str, RawType]: key: ( ref.render() if isinstance(ref, ReferenceDefinition) - else {"default": ref.value} + else render_expression(ref).render() + if isinstance(ref, Basic) + else {"default": firm_to_raw(ref.value)} + if hasattr(ref, "value") + else render_expression(ref).render() ) for key, ref in self.in_.items() }, - "out": flatten(self.out), + "out": crawl_raw(self.out), } -def cwl_type_from_value(val: RawType | Unset) -> str | list[str] | dict[str, Any]: +def cwl_type_from_value(label: str, val: RawType | Unset) -> CommandInputSchema: """Find a CWL type for a given (possibly Unset) value. Args: + label: the label for the variable being checked to prefill the input def and improve debugging info. val: a raw Python variable or an unset variable. Returns: - Type as a string or list of strings. + Input schema type. """ if val is not None and hasattr(val, "__type__"): raw_type = val.__type__ else: raw_type = type(val) - return to_cwl_type(raw_type) + return to_cwl_type(label, raw_type) -def to_cwl_type(typ: type) -> str | dict[str, Any] | list[str]: +def to_cwl_type(label: str, typ: type) -> CommandInputSchema: """Map Python types to CWL types. Args: + label: the label for the variable being checked to prefill the input def and improve debugging info. typ: a Python basic type. Returns: - CWL specification type name, or a list - if a union. + CWL specification type dict. 
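+
+    For example (sketch):
+
+        to_cwl_type("num", int)         # {"label": "num", "type": "int"}
+        to_cwl_type("nums", list[int])  # {"label": "nums", "type": "array", "items": "int"}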
""" - if typ == int: - return "int" - elif typ == bool: - return "boolean" - elif typ == dict or attrs_has(typ): - return "record" - elif typ == float: - return "float" - elif typ == str: - return "string" - elif typ == bytes: - return "bytes" - elif configuration("allow_complex_types"): - return typ if isinstance(typ, str) else typ.__name__ + typ_dict: CommandInputSchema = {"label": label, "type": ""} + base: Any | None = typ + args = get_args(typ) + if args: + base = get_origin(typ) + + if base == type(None): + typ_dict["type"] = "null" + elif base == int: + typ_dict["type"] = "int" + elif base == bool: + typ_dict["type"] = "boolean" + elif base == dict or (isinstance(base, type) and attrs_has(base)): + typ_dict["type"] = "record" + elif base == float: + typ_dict["type"] = "float" + elif base == str: + typ_dict["type"] = "string" + elif base == bytes: + typ_dict["type"] = "bytes" elif isinstance(typ, UnionType): - return [to_cwl_type(item) for item in get_args(typ)] - elif isinstance(typ, Iterable): + typ_dict.update( + {"type": tuple(to_cwl_type(label, item)["type"] for item in args)} + ) + elif isclass(base) and issubclass(base, Iterable): try: - basic_types = get_args(typ) - if len(basic_types) > 1: - return { - "type": "array", - "items": [{"type": to_cwl_type(t)} for t in basic_types], - } + if len(args) > 1: + typ_dict.update( + { + "type": "array", + "items": [to_cwl_type(label, t)["type"] for t in args], + } + ) + elif len(args) == 1: + typ_dict.update( + {"type": "array", "items": to_cwl_type(label, args[0])["type"]} + ) else: - return {"type": "array", "items": to_cwl_type(basic_types[0])} + typ_dict["type"] = "array" except IndexError as err: raise TypeError( - f"Cannot render complex type ({typ}) to CWL, have you enabled allow_complex_types configuration?" + f"Cannot render complex type ({typ}) to CWL for {label}, have you enabled allow_complex_types configuration?" ) from err + elif get_render_configuration("allow_complex_types"): + typ_dict["type"] = typ if isinstance(typ, str) else typ.__name__ else: - raise TypeError(f"Cannot render complex type ({typ}) to CWL") - - -class CommandInputSchema(TypedDict): - """Structure for referring to a raw type in CWL. - - Encompasses several CWL types. In future, it may be best to - use _cwltool_ or another library for these basic structures. - - Attributes: - type: CWL type of this input. - label: name to show for this input. - fields: (for `record`) individual fields in a dict-like structure. - items: (for `array`) type that each field will have. 
- """ - - type: InputSchemaType - label: str - fields: NotRequired[dict[str, "CommandInputSchema"]] - items: NotRequired[InputSchemaType] - default: NotRequired[RawType] + raise TypeError(f"Cannot render type ({typ}) to CWL for {label}") + return typ_dict class CommandOutputSchema(CommandInputSchema): @@ -268,6 +414,8 @@ class CommandOutputSchema(CommandInputSchema): """ outputSource: NotRequired[str] + expression: NotRequired[str] + source: NotRequired[list[str]] def raw_to_command_input_schema(label: str, value: RawType | Unset) -> InputSchemaType: @@ -286,7 +434,7 @@ def raw_to_command_input_schema(label: str, value: RawType | Unset) -> InputSche if isinstance(value, dict) or isinstance(value, list): return _raw_to_command_input_schema_internal(label, value) else: - return cwl_type_from_value(value) + return cwl_type_from_value(label, value) def to_output_schema( @@ -329,9 +477,10 @@ def to_output_schema( fields=fields, ) else: + # TODO: this complains because NotRequired keys are never present, + # but that does not seem like a problem here - likely a better solution. output = CommandOutputSchema( - type=to_cwl_type(typ), - label=label, + **to_cwl_type(label, typ) # type: ignore ) if output_source is not None: output["outputSource"] = output_source @@ -341,8 +490,7 @@ def to_output_schema( def _raw_to_command_input_schema_internal( label: str, value: RawType | Unset ) -> CommandInputSchema: - typ = cwl_type_from_value(value) - structure: CommandInputSchema = {"type": typ, "label": label} + structure: CommandInputSchema = cwl_type_from_value(label, value) if isinstance(value, dict): structure["fields"] = { key: _raw_to_command_input_schema_internal(key, val) @@ -351,15 +499,20 @@ def _raw_to_command_input_schema_internal( elif isinstance(value, list): typeset = set(get_args(value)) if not typeset: - typeset = {type(item) for item in value} + typeset = { + item.__type__ + if item is not None and hasattr(item, "__type__") + else type(item) + for item in value + } if len(typeset) != 1: raise RuntimeError( "For CWL, an input array must have a consistent type, " "and we need at least one element to infer it, or an explicit typehint." ) - structure["items"] = to_cwl_type(typeset.pop()) + structure["items"] = to_cwl_type(label, typeset.pop())["type"] elif not isinstance(value, Unset): - structure["default"] = value + structure["default"] = firm_to_raw(value) return structure @@ -390,7 +543,7 @@ class CommandInputParameter: @classmethod def from_parameters( - cls, parameters: list[ParameterReference | FactoryCall] + cls, parameters: list[ParameterReference[Any] | FactoryCall] ) -> "InputsDefinition": """Takes a list of parameters into a CWL structure. @@ -399,13 +552,19 @@ def from_parameters( Returns: CWL-like structure representing all workflow outputs. 
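+
+        Note that references to the same underlying parameter are deduplicated
+        before rendering, so (sketch) two references to a parameter `num`
+        produce a single `num` input.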
""" + parameters_dedup = { + p._.parameter for p in parameters if isinstance(p, ParameterReference) + } + parameters = list(parameters_dedup) + [ + p for p in parameters if not isinstance(p, ParameterReference) + ] return cls( inputs={ input.name: cls.CommandInputParameter( - label=input.name, - default=input.default, + label=input.__name__, + default=(default := flatten_if_set(input.__default__)), type=raw_to_command_input_schema( - label=input.name, value=input.default + label=input.__original_name__, value=default ), ) for input in parameters @@ -421,14 +580,11 @@ def render(self) -> dict[str, RawType]: """ result: dict[str, RawType] = {} for key, input in self.inputs.items(): - item = { - # Would rather not cast, but CommandInputSchema is dict[RawType] - # by construction, where type is seen as a TypedDict subclass. - "type": cast(RawType, input.type), - "label": input.label, - } - if not isinstance(input.default, Unset): - item["default"] = input.default + # Would rather not cast, but CommandInputSchema is dict[RawType] + # by construction, where type is seen as a TypedDict subclass. + item = firm_to_raw(cast(FirmType, input.type)) + if isinstance(item, dict) and not isinstance(input.default, Unset): + item["default"] = firm_to_raw(input.default) result[key] = item return result @@ -443,11 +599,18 @@ class OutputsDefinition: outputs: sequence of results from a workflow. """ - outputs: dict[str, "CommandOutputSchema"] + outputs: ( + dict[str, "CommandOutputSchema"] + | list["CommandOutputSchema"] + | CommandOutputSchema + ) @classmethod def from_results( - cls, results: dict[str, StepReference[Any]] + cls, + results: dict[str, StepReference[Any]] + | list[StepReference[Any]] + | tuple[StepReference[Any], ...], ) -> "OutputsDefinition": """Takes a mapping of results into a CWL structure. @@ -456,23 +619,56 @@ def from_results( Returns: CWL-like structure representing all workflow outputs. """ - return cls( - outputs={ - key: to_output_schema( - result.field, result.return_type, output_source=result.name + + def _build_results(result: Any) -> RawType: + if isinstance(result, Reference): + # TODO: need to work out how to tell mypy that a TypedDict is also dict[str, RawType] + return to_output_schema( # type: ignore + with_field(result), with_type(result), output_source=to_name(result) ) - for key, result in results.items() - } - ) + results = result + return ( + [_build_results(result) for result in results] + if isinstance(results, list | tuple | Tuple) + else {key: _build_results(result) for key, result in results.items()} + ) - def render(self) -> dict[str, RawType]: + try: + # TODO: sort out this nested type building. + return cls(outputs=_build_results(results)) # type: ignore + except AttributeError: + expr, references = expr_to_references(results) + reference_names = sorted( + { + str(ref._.parameter) + if isinstance(ref, ParameterReference) + else str(ref._.step) + for ref in references + } + ) + return cls( + outputs={ + "out": { + "type": "float", # WARNING: we assume any arithmetic expression returns a float. + "label": "out", + "expression": str(expr), + "source": reference_names, + } + } + ) + + def render(self) -> dict[str, RawType] | list[RawType]: """Render to a dict-like structure. Returns: Reduced form as a native Python dict structure for serialization. 
""" - return {key: flatten(output) for key, output in self.outputs.items()} + return ( + [crawl_raw(output) for output in self.outputs] + if isinstance(self.outputs, list) + else {key: crawl_raw(output) for key, output in self.outputs.items()} + ) @define @@ -503,25 +699,31 @@ def from_workflow( workflow: workflow to convert. name: name of this workflow, if it should have one. """ - parameters: list[ParameterReference | FactoryCall] = list( + parameters: list[ParameterReference[Any] | FactoryCall] = list( workflow.find_parameters( - include_factory_calls=not configuration("factories_as_params") + include_factory_calls=not get_render_configuration( + "factories_as_params" + ) ) ) - if configuration("factories_as_params"): + if get_render_configuration("factories_as_params"): parameters += list(workflow.find_factories().values()) return cls( steps=[ StepDefinition.from_step(step) - for step in workflow.steps + for step in workflow.indexed_steps.values() if not ( isinstance(step, FactoryCall) - and configuration("factories_as_params") + and get_render_configuration("factories_as_params") ) ], inputs=InputsDefinition.from_parameters(parameters), outputs=OutputsDefinition.from_results( - {workflow.result.field: workflow.result} if workflow.result else {} + workflow.result + if isinstance(workflow.result, list | tuple | Tuple) + else {with_field(workflow.result): workflow.result} + if workflow.has_result and workflow.result is not None + else {} ), name=name, ) @@ -544,7 +746,7 @@ def render(self) -> dict[str, RawType]: def render( workflow: Workflow, **kwargs: Unpack[CWLRendererConfiguration] -) -> dict[str, RawType] | tuple[dict[str, RawType], dict[str, dict[str, RawType]]]: +) -> dict[str, dict[str, RawType]]: """Render to a dict-like structure. Args: @@ -555,16 +757,10 @@ def render( Reduced form as a native Python dict structure for serialization. """ - set_configuration(kwargs) - primary_workflow = WorkflowDefinition.from_workflow(workflow).render() - subworkflows = {} - for step in workflow.steps: - if isinstance(step, NestedStep): - subworkflows[step.name] = WorkflowDefinition.from_workflow( - step.subworkflow - ).render() - - if subworkflows: - return primary_workflow, subworkflows - - return primary_workflow + # TODO: Again, convincing mypy that a TypedDict has RawType values. + with set_render_configuration(kwargs): # type: ignore + rendered = base_render( + workflow, + lambda workflow: WorkflowDefinition.from_workflow(workflow).render(), + ) + return rendered diff --git a/src/dewret/renderers/snakemake.py b/src/dewret/renderers/snakemake.py index c2f809dd..d9c7cd5c 100644 --- a/src/dewret/renderers/snakemake.py +++ b/src/dewret/renderers/snakemake.py @@ -25,20 +25,23 @@ import typing from attrs import define -from dewret.utils import BasicType -from dewret.workflow import ( - Reference, +from dewret.core import ( Raw, + BasicType, + Reference, +) +from dewret.workflow import ( Workflow, Task, Lazy, BaseStep, ) +from dewret.render import ( + base_render, +) -MainTypes = typing.Union[ - BasicType, list[str], list["MainTypes"], dict[str, "MainTypes"] -] +MainTypes = BasicType | list[str] | list["MainTypes"] | dict[str, "MainTypes"] @define @@ -59,7 +62,7 @@ class ReferenceDefinition: source: str @classmethod - def from_reference(cls, ref: Reference) -> "ReferenceDefinition": + def from_reference(cls, ref: Reference[typing.Any]) -> "ReferenceDefinition": """Build from a `Reference`. Converts a `dewret.workflow.Reference` into a Snakemake-rendering object. 
@@ -240,7 +243,7 @@ def from_step(cls, step: BaseStep) -> "OutputDefinition": """ # TODO: Error handling # TODO: Better way to handling input/output files - output_file = step.arguments["output_file"] + output_file = step.arguments.get("output_file", Raw("OUTPUT_FILE")) if isinstance(output_file, Raw): args = to_snakemake_type(output_file) @@ -448,7 +451,7 @@ def raw_render(workflow: Workflow) -> dict[str, MainTypes]: return WorkflowDefinition.from_workflow(workflow).render() -def render(workflow: Workflow) -> str: +def render(workflow: Workflow) -> dict[str, typing.Any]: """Render the workflow as a Snakemake (SMK) string. This function converts a Workflow object into a Snakemake-compatible yaml. @@ -468,6 +471,9 @@ def render(workflow: Workflow) -> str: } ) - return yaml.dump( - WorkflowDefinition.from_workflow(workflow).render(), indent=4 - ).translate(trans_table) + return base_render( + workflow, + lambda workflow: yaml.dump( + WorkflowDefinition.from_workflow(workflow).render(), indent=4 + ).translate(trans_table), + ) diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index a0e6c804..5786027c 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -35,30 +35,40 @@ from enum import Enum from functools import cached_property from collections.abc import Callable -from typing import Any, ParamSpec, TypeVar, cast, Generator +from typing import Any, ParamSpec, TypeVar, cast, Generator, Unpack, Literal from types import TracebackType from attrs import has as attrs_has from dataclasses import is_dataclass import traceback -from contextvars import ContextVar +from concurrent.futures import ThreadPoolExecutor +from contextvars import ContextVar, copy_context from contextlib import contextmanager -from .utils import is_raw, make_traceback +from .utils import is_firm, make_traceback, is_expr from .workflow import ( - StepReference, - ParameterReference, + expr_to_references, + unify_workflows, + UNSET, Workflow, Lazy, LazyEvaluation, Target, LazyFactory, - merge_workflows, Parameter, + ParameterReference, param, Task, is_task, ) from .backends._base import BackendModule +from .annotations import FunctionAnalyser +from .core import ( + get_configuration, + set_configuration, + IteratedGenerator, + ConstructConfigurationTypedDict, + Reference +) Param = ParamSpec("Param") RetType = TypeVar("RetType") @@ -132,21 +142,24 @@ def make_lazy(self) -> LazyFactory: """ return self.backend.lazy - def evaluate(self, task: Lazy, __workflow__: Workflow, **kwargs: Any) -> Any: + def evaluate(self, task: Lazy | list[Lazy] | tuple[Lazy], __workflow__: Workflow, thread_pool: ThreadPoolExecutor | None=None, **kwargs: Any) -> Any: """Evaluate a single task for a known workflow. Args: task: the task to evaluate. __workflow__: workflow within which this exists. + thread_pool: existing pool of threads to run this in, or None. **kwargs: any arguments to pass to the task. 
""" - result = self.backend.run(__workflow__, task, **kwargs) - result.__workflow__.set_result(result) - if __workflow__ is not None and result.__workflow__ != __workflow__: - workflow = Workflow.assimilate(__workflow__, result.__workflow__) - else: - workflow = result.__workflow__ - return workflow.result + result = self.backend.run(__workflow__, task, thread_pool=thread_pool, **kwargs) + new_result, collected_workflow = unify_workflows(result, __workflow__) + + if collected_workflow is None: + raise RuntimeError("A new workflow could not be found") + + # Then we set the result to be the whole thing + collected_workflow.set_result(new_result) + return collected_workflow.result def unwrap(self, task: Lazy) -> Target: """Unwraps a lazy-evaluated function to get the function. @@ -186,9 +199,8 @@ def ensure_lazy(self, task: Any) -> Lazy | None: def __call__( self, task: Any, - simplify_ids: bool = False, __workflow__: Workflow | None = None, - **kwargs: Any, + **kwargs: Unpack[ConstructConfigurationTypedDict], ) -> Workflow: """Execute the lazy evalution. @@ -201,7 +213,16 @@ def __call__( A reusable reference to this individual step. """ workflow = __workflow__ or Workflow() - result = self.evaluate(task, workflow, **kwargs) + + with set_configuration(**kwargs): + context = copy_context().items() + def _initializer() -> None: + for var, value in context: + var.set(value) + thread_pool = ThreadPoolExecutor(initializer=_initializer) + + result = self.evaluate(task, workflow, thread_pool=thread_pool, **kwargs) + simplify_ids = bool(get_configuration("simplify_ids")) return Workflow.from_result(result, simplify_ids=simplify_ids) @@ -295,35 +316,7 @@ def factory(fn: Callable[..., RetType]) -> Callable[..., RetType]: return task(is_factory=True)(fn) -def nested_task() -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]: - """Shortcut for marking a task as nested. - - A nested task is one which calls other tasks and does not - do anything else important. It will _not_ actually get called - at runtime, but should map entirely into the graph. As such, - arithmetic operations on results, etc. will cause errors at - render-time. Combining tasks is acceptable, and intended. The - effect of the nested task will be considered equivalent to whatever - reaching whatever step reference is returned at the end. - - ```python - >>> @task() - ... def increment(num: int) -> int: - ... return num + 1 - - >>> @nested_task() - ... def double_increment(num: int) -> int: - ... return increment(increment(num=num)) - - ``` - - Returns: - Task that runs at render, not execution, time. - """ - return task(nested=True) - - -def subworkflow() -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]: +def workflow() -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]: """Shortcut for marking a task as nested. A nested task is one which calls other tasks and does not @@ -339,7 +332,7 @@ def subworkflow() -> Callable[[Callable[Param, RetType]], Callable[Param, RetTyp ... def increment(num: int) -> int: ... return num + 1 - >>> @nested_task() + >>> @workflow() ... def double_increment(num: int) -> int: ... return increment(increment(num=num)) @@ -399,10 +392,13 @@ def _fn( __traceback__: TracebackType | None = None, **kwargs: Param.kwargs, ) -> RetType: + configuration = None + allow_positional_args = bool(get_configuration("allow_positional_args")) + try: # Ensure that all arguments are passed as keyword args and prevent positional args. # passed at all. 
- if args: + if args and not allow_positional_args: raise TypeError( f""" Calling {fn.__name__}: Arguments must _always_ be named, @@ -418,44 +414,69 @@ def add_numbers(left: int, right: int): # Ensure that the passed arguments are, at least, a Python-match for the signature. sig = inspect.signature(fn) - sig.bind(*args, **kwargs) + positional_args = {key: False for key in kwargs} + for arg, (key, _) in zip(args, sig.parameters.items(), strict=False): + if isinstance(arg, IteratedGenerator): + for inner_arg, (key, _) in zip(arg, sig.parameters.items(), strict=False): + if key in positional_args: + continue + kwargs[key] = inner_arg + positional_args[key] = True + else: + kwargs[key] = arg + positional_args[key] = True + sig.bind(**kwargs) - workflows = [ - reference.__workflow__ - for reference in kwargs.values() + def _to_param_ref(value: Any) -> ParameterReference[Any] | None: + if isinstance(value, Parameter): + return value.make_reference(workflow=__workflow__) + return None + + refs = [] + for key, val in kwargs.items(): + val, kw_refs = expr_to_references(val, remap=_to_param_ref) + refs += kw_refs + kwargs[key] = val + # Not realistically going to be other than Workflow. + workflows: list[Workflow] = [ + cast(Workflow, reference.__workflow__) + for reference in refs if hasattr(reference, "__workflow__") and reference.__workflow__ is not None ] if __workflow__ is not None: workflows.insert(0, __workflow__) if workflows: - workflow = merge_workflows(*workflows) + workflow = Workflow.assimilate(*workflows) else: workflow = Workflow() + + analyser = FunctionAnalyser(fn) + if not is_in_nested_task(): for var, value in kwargs.items(): - if is_raw(value): + if analyser.is_at_construct_arg(var): + kwargs[var] = value + elif is_firm(value): # We leave this reference dangling for a consumer to pick up ("tethered"), unless # we are in a nested task, that does not have any existence of its own. - kwargs[var] = ParameterReference( - workflow, + tethered: Literal[False] | None = ( + False if nested and ( + flatten_nested or get_configuration("flatten_all_nested") + ) else None + ) + kwargs[var] = cast( + Parameter[Any], param( var, value, - tethered=( - False if nested and flatten_nested else None - ), - autoname=True, - ), - ) - elif isinstance(value, Parameter): - kwargs[var] = ParameterReference(workflow, value) + tethered=tethered, + autoname=tethered is not False, + typ=analyser.get_argument_annotation(var) or UNSET + ) + ).make_reference(workflow=workflow) original_kwargs = dict(kwargs) - try: - fn_globals = dict(inspect.getclosurevars(fn).globals) - # This covers the case of wrapping, rather than decorating. - except TypeError: - fn_globals = {} + fn_globals = analyser.globals for var, value in fn_globals.items(): # This error is redundant as it triggers a SyntaxError in Python. 
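# Note: the executor set up in `__call__` above re-applies the caller's
# contextvars inside each worker thread. A minimal standalone sketch of the
# same pattern (names here are illustrative, not part of the patch):

from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar, copy_context

var: ContextVar[int] = ContextVar("var")
var.set(1)
context = copy_context().items()

def _initializer() -> None:
    # Re-apply every captured contextvar in the new worker thread.
    for ctxvar, value in context:
        ctxvar.set(value)

pool = ThreadPoolExecutor(initializer=_initializer)
assert pool.submit(var.get).result() == 1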
@@ -465,64 +486,97 @@ def add_numbers(left: int, right: int): # "Captured parameter {var} (global variable in task) shadows an argument" # ) if isinstance(value, Parameter): - kwargs[var] = ParameterReference(workflow, value) - elif is_raw(value): - kwargs[var] = ParameterReference( - workflow, param(var, value, tethered=False) - ) + kwargs[var] = value.make_reference(workflow=workflow) elif is_task(value) or ensure_lazy(value) is not None: if not nested and _workaround_check_value_is_task( fn, var, value ): raise TypeError( f""" - You referenced a task {var} inside another task {fn.__name__}, but it is not a nested_task + You referenced a task {var} inside another task {fn.__name__}, but it is not a workflow - this will not be found! - @task + @task() def {var}(...) -> ...: ... - @nested_task <<<--- likely what you want + @workflow() <<<--- likely what you want def {fn.__name__}(...) -> ...: ... {var}(...) ... """ ) - elif attrs_has(value) or is_dataclass(value): + # If nested, we will execute the insides, and it is reasonable and important + # to have a full set of annotations for any encountered variables. + elif nested and not analyser.get_argument_annotation(var, exhaustive=True) and not inspect.isclass(value) or inspect.isfunction(value): + raise RuntimeError(f"Could not find a type annotation for {var} for {fn.__name__}") + elif ( + analyser.is_at_construct_arg(var, exhaustive=True) or + value is evaluate or value is construct # Allow manual building. + ): + kwargs[var] = value + elif ( + inspect.isclass(value) or + inspect.isfunction(value) + ): + # We assume these are loaded at runtime. ... + elif is_firm(value) or ( + (attrs_has(value) or is_dataclass(value)) and + not inspect.isclass(value) + ): + kwargs[var] = cast( + Parameter[Any], + param( + var, + default=value, + tethered=False, + typ=analyser.get_argument_annotation(var, exhaustive=True) or UNSET + ) + ).make_reference(workflow=workflow) + elif is_expr(value) and (expr_refs := expr_to_references(value)) and len(expr_refs[1]) != 0: + kwargs[var] = value elif nested: raise NotImplementedError( f"Nested tasks must now only refer to global parameters, raw or tasks, not objects: {var}" ) if nested: - if flatten_nested: - output = fn(**original_kwargs) + if flatten_nested or get_configuration("flatten_all_nested"): + with in_nested_task(): + output = analyser.with_new_globals(kwargs)(**original_kwargs) lazy_fn = ensure_lazy(output) - if lazy_fn is None: - raise TypeError( - f"Task {fn.__name__} returned output of type {type(output)}, which is not a lazy function for this backend." - ) - step_reference = evaluate(lazy_fn, __workflow__=workflow) + if lazy_fn is not None: + with in_nested_task(): + output = evaluate(lazy_fn, __workflow__=workflow) + #raise TypeError( + # f"Task {fn.__name__} returned output of type {type(output)}, which is not a lazy function for this backend." 
+ #) + step_reference = output else: nested_workflow = Workflow(name=fn.__name__) - nested_kwargs: Param.kwargs = { - var: ParameterReference( - nested_workflow, - param( - var, typ=value.__type__, tethered=nested_workflow - ), - ) - for var, value in original_kwargs.items() + nested_globals: Param.kwargs = { + var: cast( + Parameter[Any], + param( + var, + default=value.__default__ if hasattr(value, "__default__") else UNSET, + typ=( + value.__type__ + ), + tethered=nested_workflow + ) + ).make_reference(workflow=nested_workflow) if isinstance(value, Reference) else value + for var, value in kwargs.items() } + nested_kwargs = {key: value for key, value in nested_globals.items() if key in original_kwargs} with in_nested_task(): - output = fn(**nested_kwargs) - nested_workflow = _manager(output, __workflow__=nested_workflow) + output = analyser.with_new_globals(nested_globals)(**nested_kwargs) + nested_workflow = _manager(output, __workflow__=nested_workflow) step_reference = workflow.add_nested_step( - fn.__name__, nested_workflow, kwargs + fn.__name__, nested_workflow, analyser.return_type, original_kwargs, positional_args ) - if isinstance(step_reference, StepReference): + if is_expr(step_reference): return cast(RetType, step_reference) raise TypeError( f"Nested tasks must return a step reference, not {type(step_reference)} to ensure graph makes sense." @@ -532,8 +586,9 @@ def {fn.__name__}(...) -> ...: workflow.add_step( fn, kwargs, - raw_as_parameter=is_in_nested_task(), + raw_as_parameter=not is_in_nested_task(), is_factory=is_factory, + positional_args=positional_args ), ) return step @@ -544,11 +599,15 @@ def {fn.__name__}(...) -> ...: fn, declaration_tb, __traceback__, - exc.args[0] if exc.args else "Could not call task {fn.__name__}", + exc.args[0] if exc.args else f"Could not call task {fn.__name__}", ) from exc + finally: + if configuration: + configuration.__exit__(None, None, None) _fn.__step_expression__ = True # type: ignore - return LazyEvaluation(lazy()(_fn)) + _fn.__original__ = fn # type: ignore + return LazyEvaluation(_fn) return _task diff --git a/src/dewret/utils.py b/src/dewret/utils.py index 05b3b8e8..9f0381ab 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -20,13 +20,20 @@ import hashlib import json import sys -from types import FrameType, TracebackType -from typing import Any, cast, Union, Protocol, ClassVar +import importlib +import importlib.util +from types import FrameType, TracebackType, UnionType, ModuleType +from typing import Any, cast, Protocol, ClassVar, Callable, Iterable, get_args, Hashable +from pathlib import Path from collections.abc import Sequence, Mapping +from dataclasses import asdict, is_dataclass +from sympy import Basic, Integer, Float, Rational -BasicType = str | float | bool | bytes | int | None -RawType = Union[BasicType, list["RawType"], dict[str, "RawType"]] -FirmType = BasicType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] +from .core import Reference, RawType, FirmType, Raw + + +class Unset: + """Unset variable, with no default value.""" class DataclassProtocol(Protocol): @@ -55,34 +62,136 @@ def make_traceback(skip: int = 2) -> TracebackType | None: frame = frame.f_back return tb +def load_module_or_package(target_name: str, path: Path) -> ModuleType: + """Convenience loader for modules. + + If an `__init__.py` is found in the same location as the target, it will try to load the renderer module + as if it is contained in a package and, if it cannot, will fall back to loading the single file. 
+
+    Args:
+        target_name: module name that should appear in `sys.modules`.
+        path: location of the module.
+
+    Returns: the loaded module.
+    """
+    module: None | ModuleType = None
+    exception: None | Exception = None
+    package_init = path.parent / "__init__.py"
+    # Try to import the module as a package, if possible, to allow relative imports.
+    if package_init.exists():
+        try:
+            spec = importlib.util.spec_from_file_location(target_name, str(package_init))
+            if spec is None or spec.loader is None:
+                raise ImportError(f"Could not open {path.parent} package")
+            package = importlib.util.module_from_spec(spec)
+            sys.modules[target_name] = package
+            spec.loader.exec_module(package)
+            module = importlib.import_module(f"{target_name}.{path.stem}", target_name)
+        except ImportError as exc:
+            exception = exc
+
+    if module is None:
+        try:
+            spec = importlib.util.spec_from_file_location(target_name, str(path))
+            if spec is None or spec.loader is None:
+                raise ImportError(f"Could not open {path} module")
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+        except ImportError as exc:
+            if exception:
+                raise exc from exception
+            raise exc
+
+    return module
+
+def flatten_if_set(value: Any) -> RawType | Unset:
+    """Takes a Raw-like structure and makes it RawType or Unset.
 
-def flatten(value: Any) -> RawType:
+    Flattens if the value is set, but otherwise returns the unset
+    sentinel value as-is.
+
+    Args:
+        value: value to squash
+    """
+    if isinstance(value, Unset):
+        return value
+    return crawl_raw(value)
+
+def crawl_raw(value: Any, action: Callable[[Any], Any] | None = None) -> RawType:
     """Takes a Raw-like structure and makes it RawType.
 
     Particularly useful for squashing any TypedDicts.
 
     Args:
         value: value to squash
+        action: a callback to apply to each found entry, or None.
+
+    Returns: a structure that is guaranteed to be raw.
+
+    Raises: RuntimeError if it cannot convert the value to raw.
     """
+    if action is not None:
+        value = action(value)
+
     if value is None:
         return value
     if isinstance(value, str) or isinstance(value, bytes):
         return value
     if isinstance(value, Mapping):
-        return {key: flatten(item) for key, item in value.items()}
+        return {key: crawl_raw(item, action) for key, item in value.items()}
+    if is_dataclass(value) and not isinstance(value, type):
+        return crawl_raw(asdict(value), action)
     if isinstance(value, Sequence):
-        return [flatten(item) for item in value]
+        return [crawl_raw(item, action) for item in value]
     if (raw := ensure_raw(value)) is not None:
         return raw
     raise RuntimeError(f"Could not flatten: {value}")
 
+def firm_to_raw(value: FirmType) -> RawType:
+    """Convenience wrapper for firm structures.
+
+    Turns structures that would be raw, except for tuples, into raw structures
+    by mapping any tuples to lists.
+
+    Args:
+        value: a firm structure (contains raw/tuple values).
+
+    Returns: a raw structure.
+    """
+    return crawl_raw(value, lambda entry: list(entry) if isinstance(entry, tuple) else entry)
+
+def is_expr(value: Any, permitted_references: type=Reference) -> bool:
+    """Confirms whether a structure has only raw or expression types.
+
+    Args:
+        value: a crawlable structure.
+        permitted_references: a class representing the allowed types of References.
+
+    Returns: True if valid, otherwise False.
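+
+    For example (sketch): `is_expr({"a": some_reference})` is True, while
+    `is_expr(object())` is False.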
+    """
+    return is_raw(value, lambda x: isinstance(x, Basic) or isinstance(x, tuple) or isinstance(x, permitted_references) or isinstance(x, Raw))
 
 def is_raw_type(typ: type) -> bool:
     """Check if a type counts as "raw"."""
+    if isinstance(typ, UnionType):
+        return all(is_raw_type(t) for t in get_args(typ))
     return issubclass(typ, str | float | bool | bytes | int | None | list | dict)
 
 
-def is_raw(value: Any) -> bool:
+def is_firm(value: Any, check: Callable[[Any], bool] | None = None) -> bool:
+    """Confirms whether a value is firm.
+
+    That is, whether its contents are raw or tuples.
+
+    Args:
+        value: value to check.
+        check: any additional check to apply.
+
+    Returns: True if the value is firm, else False.
+    """
+    return is_raw(value, lambda x: isinstance(x, tuple) and (check is None or check(x)))
+
+def is_raw(value: Any, check: Callable[[Any], bool] | None = None) -> bool:
     """Check if a variable counts as "raw".
 
     This works around a checking issue that isinstance of a union of types
@@ -92,10 +201,29 @@ def is_raw(value: Any) -> bool:
     # Ideally this would be:
     # isinstance(value, RawType | list[RawType] | dict[str, RawType])
     # but recursive types are problematic.
-    return isinstance(value, str | float | bool | bytes | int | None | list | dict)
+    if isinstance(value, str | float | bool | bytes | int | None | Integer | Float | Rational):
+        return True
+
+    if check is not None and check(value):
+        return True
+
+    if isinstance(value, Mapping):
+        return (
+            (isinstance(value, dict) or (check is not None and check(value))) and
+            all(is_raw(key, check) for key in value.keys()) and
+            all(is_raw(val, check) for val in value.values())
+        )
+
+    if isinstance(value, Iterable):
+        return (
+            (isinstance(value, list) or (check is not None and check(value))) and
+            all(is_raw(key, check) for key in value)
+        )
+
+    return False
 
 
-def ensure_raw(value: Any) -> RawType | None:
+def ensure_raw(value: Any, cast_tuple: bool = False) -> RawType | None:
     """Check if a variable counts as "raw".
 
     This works around a checking issue that isinstance of a union of types
@@ -120,13 +248,23 @@ def hasher(construct: FirmType) -> str:
         Hash string that should be unique to the construct. The
         limits of this uniqueness have not yet been explicitly calculated.
     """
-    if isinstance(construct, Sequence) and not isinstance(construct, bytes | str):
-        if isinstance(construct, dict):
-            construct = list([k, hasher(v)] for k, v in sorted(construct.items()))
-        else:
-            # Cast to workaround recursive type
-            construct = cast(FirmType, list(construct))
-    construct_as_string = json.dumps(construct)
+    def _make_hashable(construct: FirmType) -> Hashable:
+        hashed_construct: tuple[Hashable, ...]
+ if isinstance(construct, Sequence) and not isinstance(construct, bytes | str): + if isinstance(construct, Mapping): + hashed_construct = tuple((k, hasher(v)) for k, v in sorted(construct.items())) + else: + # Cast to workaround recursive type + hashed_construct = tuple(_make_hashable(v) for v in construct) + return hashed_construct + if not isinstance(construct, Hashable): + raise TypeError("Could not hash arguments") + return construct + if isinstance(construct, Hashable): + hashed_construct = construct + else: + hashed_construct = _make_hashable(construct) + construct_as_string = json.dumps(hashed_construct) hsh = hashlib.md5() hsh.update(construct_as_string.encode()) return hsh.hexdigest() diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 6e9e28d3..a37e6d18 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -20,48 +20,29 @@ from __future__ import annotations import inspect from collections.abc import Mapping, MutableMapping, Callable -import base64 -from attrs import define, has as attr_has, resolve_types, fields as attrs_fields +from attrs import has as attr_has, resolve_types, fields as attrs_fields from dataclasses import is_dataclass, fields as dataclass_fields -from collections import Counter -from typing import Protocol, Any, TypeVar, Generic, cast, Literal +from collections import Counter, OrderedDict +from typing import Protocol, Any, TypeVar, Generic, cast, Literal, Iterable, get_origin, get_args, Generator, Sized, Sequence, get_type_hints, TYPE_CHECKING, Hashable from uuid import uuid4 - import logging +from sympy import Symbol, Expr, Basic, Tuple + logger = logging.getLogger(__name__) -from .utils import hasher, RawType, is_raw, make_traceback, is_raw_type +from .core import IterableMixin, Reference, get_configuration, Raw, IteratedGenerator, strip_annotations, WorkflowProtocol, WorkflowComponent, ExprType +from .utils import hasher, is_raw, make_traceback, is_raw_type, is_expr, Unset T = TypeVar("T") +U = TypeVar("U") RetType = TypeVar("RetType") - -@define -class Raw: - """Value object for any raw types. - - This is able to hash raw types consistently and provides - a single type for validating type-consistency. - - Attributes: - value: the real value, e.g. a `str`, `int`, ... - """ - - value: RawType - - def __hash__(self) -> int: - """Provide a hash that is unique to the `value` member.""" - return hash(repr(self)) - - def __repr__(self) -> str: - """Convert to a consistent, string representation.""" - value: str - if isinstance(self.value, bytes): - value = base64.b64encode(self.value).decode("ascii") - else: - value = str(self.value) - return f"{type(self.value).__name__}|{value}" +CHECK_IDS = False +AVAILABLE_TYPES = { + "int": int, + "str": str +} class Lazy(Protocol): @@ -106,10 +87,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> RetType: LazyFactory = Callable[[Target], Lazy] -class Unset: - """Unset variable, with no default value.""" - - class UnsetType(Unset, Generic[T]): """Unset variable with a specific type. @@ -131,7 +108,8 @@ def __init__(self, raw_type: type[T]): UNSET = Unset() -class Parameter(Generic[T]): + +class Parameter(Generic[T], Symbol): """Global parameter. 
    Independent parameter that will be used when a task is spotted
@@ -146,9 +124,10 @@ class Parameter(Generic[T]):
     """
 
     __name__: str
+    __name_suffix__: str = ""
     __default__: T | UnsetType[T]
     __tethered__: Literal[False] | None | BaseStep | Workflow
-
+    __fixed_type__: type[T] | Unset
     autoname: bool = False
 
     def __init__(
@@ -157,6 +136,7 @@ def __init__(
         default: T | UnsetType[T],
         tethered: Literal[False] | None | Step | Workflow = None,
         autoname: bool = False,
+        typ: type[T] | Unset = UNSET
     ):
         """Construct a parameter.
 
@@ -165,19 +145,47 @@ def __init__(
             default: value to infer type, etc. from.
             tethered: a workflow or step that demands this parameter; None if not yet present, False if not desired.
             autoname: whether we should customize this name for uniqueness (it is not user-set).
+            typ: a type to override what would be automatically inferred.
         """
         self.__original_name__ = name
 
         # TODO: is using this in a step hash a risk of ambiguity? (full name is circular)
         if autoname:
-            name = f"{name}-{uuid4()}"
+            self.__name_suffix__ = f"-{uuid4()}"
         self.autoname = autoname
 
         self.__name__ = name
         self.__default__ = default
         self.__tethered__ = tethered
         self.__callers__: list[BaseStep] = []
+        self.__fixed_type__ = typ
 
+        if tethered and isinstance(tethered, BaseStep):
+            self.register_caller(tethered)
+
+    @staticmethod
+    def is_loopable(typ: type) -> bool:
+        """Checks if this type can be looped over.
+
+        In particular, checks if this is an iterable that is NOT a str or bytes, possibly disguised
+        behind an Annotated.
+
+        Args:
+            typ: type to check.
+
+        Returns: True if loopable, otherwise False.
+        """
+        base = strip_annotations(typ)[0]
+        base = get_origin(base) or base
+        return inspect.isclass(base) and issubclass(base, Iterable) and not issubclass(base, str | bytes)
+
+    @property
+    def __type__(self) -> type | Unset:
+        """Type associated with this parameter."""
+        if self.__fixed_type__ is not UNSET:
+            return self.__fixed_type__
+
+        default = self.__default__
         if (
             default is not None
             and hasattr(default, "__type__")
@@ -186,26 +194,61 @@ def __init__(
             raw_type = default.__type__
         else:
             raw_type = type(default)
-        self.__type__: type[T] = raw_type
+        return raw_type
 
-        if tethered and isinstance(tethered, BaseStep):
-            self.register_caller(tethered)
+    def __eq__(self, other: object) -> bool:
+        """Compare two parameters.
+
+        Currently, this uses the hashes.
+
+        TODO: confirm this is an iff.
+        """
+        return hash(self) == hash(other)
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> "Parameter[T]":
+        """Creates a Parameter.
+
+        Required, as Parameters are an instance of sympy Expression, so
+        we must instantiate via it.
+        """
+        instance = Expr.__new__(cls)
+        instance._assumptions0 = {}
+        return cast(Parameter[T], instance)
 
     def __hash__(self) -> int:
         """Get a unique hash for this parameter."""
-        if self.__tethered__ is None:
-            raise RuntimeError(
-                "Parameter {self.full_name} was never tethered but should have been"
-            )
+        # if self.__tethered__ is None:
+        #     raise RuntimeError(
+        #         f"Parameter {self.name} was never tethered but should have been"
+        #     )
        return hash(self.__name__)
 
+    def make_reference(self, **kwargs: Any) -> "ParameterReference[T]":
+        """Creates a new reference for the parameter.
+
+        The kwargs are passed through to the reference constructor; the
+        parameter itself and, unless overridden, its type are filled in first.
+
+        Args:
+            typ: type of the new reference's target.
+            **kwargs: arguments to pass to the constructor.
+
+        Returns: checks if `typ` is loopable and gives an IterableParameterReference if so,
+        otherwise a normal ParameterReference.
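+
+        For example (sketch; `wf` is an illustrative workflow):
+
+            Parameter("nums", default=[1, 2]).make_reference(workflow=wf)
+            # an IterableParameterReference, since list is loopable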
+ """ + kwargs["parameter"] = self + kwargs.setdefault("typ", self.__type__) + typ = kwargs["typ"] + if self.is_loopable(typ): + return IterableParameterReference(**kwargs) + return ParameterReference(**kwargs) + @property def default(self) -> T | UnsetType[T]: """Retrieve default value for this parameter, or an unset token.""" return self.__default__ @property - def full_name(self) -> str: + def name(self) -> str: """Extended name, suitable for rendering. This attempts to create a unique name by tying the parameter to a step @@ -214,7 +257,7 @@ def full_name(self) -> str: """ tethered = self.__tethered__ if tethered is False or tethered is None or self.autoname is False: - return self.__name__ + return self.__name__ + self.__name_suffix__ else: return f"{tethered.name}-{self.__original_name__}" @@ -228,14 +271,15 @@ def register_caller(self, caller: BaseStep) -> None: self.__tethered__ = caller self.__callers__.append(caller) - @property - def name(self) -> str: - """Name for this step. + def __getattr__(self, attr: str) -> Reference[T] | Any: + """Retrieve a reference to a field within this Parameter. - May be remapped by the workflow to something nicer - than the ID. + Arg: + attr: a field to find. + + Returns: a new reference with the `attr` appended to the field tuple. """ - return self.full_name + return getattr(self.make_reference(workflow=None), attr) def param( @@ -257,7 +301,7 @@ def param( raise ValueError("Must provide a default or a type") default = UnsetType[T](typ) return cast( - T, Parameter(name, default=default, tethered=tethered, autoname=autoname) + T, Parameter(name, default=default, tethered=tethered, autoname=autoname, typ=typ) ) @@ -328,29 +372,55 @@ class Workflow: result: target reference to evaluate, if yet present. """ - steps: list["BaseStep"] + _steps: list["BaseStep"] tasks: MutableMapping[str, "Task"] - result: StepReference[Any] | None + result: StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any], ...] | None _remapping: dict[str, str] | None _name: str | None def __init__(self, name: str | None = None) -> None: """Initialize a Workflow, by setting `steps` and `tasks` to empty containers.""" - self.steps = [] + self._steps = [] self.tasks = {} self.result: StepReference[Any] | None = None self._remapping = None self._name = name + @property + def steps(self) -> set[BaseStep]: + """Get deduplicated steps. + + Returns: steps for looping over without duplicates. + """ + return set(self._steps) + def __str__(self) -> str: """Name of the workflow, if available.""" if self._name is None: return super().__str__() return self.name + def __repr__(self) -> str: + """Representation of the workflow. + + This will be the `id` if the workflow is anonymous. + + Returns: string identifier for this workflow. + """ + if self._name: + return self.name + return self.id + def __hash__(self) -> int: """Hashes for finding.""" - return hash(self.name) + return hash(self.id) + + @property + def id(self) -> str: + """Consistent ID based off the step IDs.""" + comp_tup = tuple(sorted(s.id for s in self.steps)) + return f"workflow-{hasher(comp_tup)}" + def __eq__(self, other: object) -> bool: """Is this the same workflow? 
@@ -361,13 +431,23 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, Workflow): return False return ( - self.steps == other.steps + self._steps == other._steps and self.tasks == other.tasks and self.result == other.result and self._remapping == other._remapping and self._name == other._name ) + @property + def has_result(self) -> bool: + """Confirms whether this workflow has a non-empty result. + + Either None or an empty list/tuple are considered empty for this purpose. + + Returns: True if the workflow has a result, False otherwise. + """ + return not(self.result is None or (isinstance(self.result, list | tuple) and not self.result)) + @property def name(self) -> str: """Get the name of the workflow. @@ -381,11 +461,11 @@ def name(self) -> str: def find_factories(self) -> dict[str, FactoryCall]: """Steps that are factory calls.""" - return {step.id: step for step in self.steps if isinstance(step, FactoryCall)} + return {step_id: step for step_id, step in self.indexed_steps.items() if isinstance(step, FactoryCall)} def find_parameters( self, include_factory_calls: bool = True - ) -> set[ParameterReference]: + ) -> set[Parameter[Any]]: """Crawl steps for parameter references. As the workflow does not hold its own list of parameters, this @@ -394,22 +474,13 @@ def find_parameters( Returns: Set of all references to parameters across the steps. """ - return set().union( - *( - { - arg - for arg in step.arguments.values() - if ( - isinstance(arg, ParameterReference) - and (include_factory_calls or not isinstance(step, FactoryCall)) - ) - } - for step in self.steps - ) + _, references = expr_to_references( + step.arguments for step in self.steps if (include_factory_calls or not isinstance(step, FactoryCall)) ) + return {ref._.parameter for ref in references if isinstance(ref, ParameterReference)} @property - def _indexed_steps(self) -> dict[str, BaseStep]: + def indexed_steps(self) -> dict[str, BaseStep]: """Steps mapped by ID. Forces generation of IDs. Note that this effectively @@ -419,10 +490,10 @@ def _indexed_steps(self) -> dict[str, BaseStep]: Returns: Mapping of steps by ID. """ - return {step.id: step for step in self.steps} + return OrderedDict(sorted(((step.id, step) for step in self.steps), key=lambda x: x[0])) @classmethod - def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": + def assimilate(cls, *workflow_args: Workflow) -> "Workflow": """Combine two Workflows into one Workflow. Takes two workflows and unifies them by combining steps @@ -431,53 +502,73 @@ def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": This could happen if the hashing function is flawed or some Python magic to do with Targets being passed. 
-        Argument:
-            left: workflow to use as base
-            right: workflow to combine on top
+        Args:
+            workflow_args: workflows to combine; the first (ordered by ID) acts as the base.
         """
-        new = cls()
+        workflows = sorted((w for w in set(workflow_args)), key=lambda w: w.id)
+        base = workflows[0]
+
+        if len(workflows) == 1:
+            return base
 
-        new._name = left._name or right._name
+        names = sorted({w._name for w in workflows if w._name})
+        base._name = base._name or (names and names[0]) or None
 
-        left_steps = left._indexed_steps
-        right_steps = right._indexed_steps
+        #left_steps = left._indexed_steps
+        #right_steps = right._indexed_steps
+        all_steps = sorted(sum((list(w.indexed_steps.items()) for w in workflows), []), key=lambda s: s[0])
 
-        for step in list(left_steps.values()) + list(right_steps.values()):
-            step.set_workflow(new)
-            for arg in step.arguments:
-                if hasattr(arg, "__workflow__"):
-                    arg.__workflow__ = new
+        for _, step in all_steps:
+            #for step in list(left_steps.values()) + list(right_steps.values()):
+            step.set_workflow(base)
 
-        for step_id in left_steps.keys() & right_steps.keys():
-            if left_steps[step_id] != right_steps[step_id]:
+        indexed_steps: dict[str, BaseStep] = {}
+        for step_id, step in all_steps:
+            indexed_steps.setdefault(step_id, step)
+            if step != indexed_steps[step_id]:
                raise RuntimeError(
                    f"Two steps have same ID but do not match: {step_id}"
                )
 
-        for task_id in left.tasks.keys() & right.tasks.keys():
-            if left.tasks[task_id] != right.tasks[task_id]:
-                raise RuntimeError("Two tasks have same name but do not match")
+        all_tasks = sum((list(w.tasks.items()) for w in workflows), [])
+        indexed_tasks: dict[str, Task] = {}
+        for task_id, task in all_tasks:
+            indexed_tasks.setdefault(task_id, task)
+            if task != indexed_tasks[task_id]:
+                raise RuntimeError(f"Two tasks have same name {task_id} but do not match")
 
-        indexed_steps = dict(left_steps)
-        indexed_steps.update(right_steps)
-        new.steps += list(indexed_steps.values())
-        new.tasks.update(left.tasks)
-        new.tasks.update(right.tasks)
+        base._steps = sorted(indexed_steps.values(), key=lambda s: s.id)
+        base.tasks = indexed_tasks
 
-        for step in new.steps:
-            step.set_workflow(new, with_arguments=True)
+        for step in base.steps:
+            step.set_workflow(base, with_arguments=True)
 
-        # TODO: should we combine as a result array?
-        result = left.result or right.result
+        hashable_workflows: list[Workflow] = [w for w in workflows if isinstance(w.result, Hashable)]
+        if len(hashable_workflows) != len(workflows):
+            raise NotImplementedError("Some results are not hashable.")
 
-        if result:
-            new.set_result(
-                StepReference(
-                    new, result.step, typ=result.return_type, field=result.field
-                )
-            )
+        def _get_order(result: None | StepReference[Any] | Iterable[StepReference[Any]]) -> str:
+            if result is None:
+                return ""
+            if isinstance(result, StepReference):
+                return result.id
+            return "|".join(r for r in result)
+
+        results = sorted(set({w.result for w in hashable_workflows if w.has_result}), key=lambda r: _get_order(r))
+        if len(results) == 1:
+            result = results[0]
+        else:
+            list_results = [r if isinstance(r, tuple | list | Tuple) else (r,) for r in results]
+            result = sum(map(list, list_results), [])
 
-        return new
+        if result is not None and result != []:
+            unify_workflows(result, base, set_only=True)
+            base.set_result(result)
+
+        return base
 
     def remap(self, step_id: str) -> str:
         """Apply name simplification if requested.
@@ -498,7 +589,7 @@ def simplify_ids(self, infix: list[str] | None = None) -> None: counter = Counter[Task | Workflow]() self._remapping = {} infix_str = ("-".join(infix) + "-") if infix else "" - for step in self.steps: + for step in self.indexed_steps.values(): counter[step.task] += 1 self._remapping[step.id] = f"{step.task}-{infix_str}{counter[step.task]}" if isinstance(step, NestedStep): @@ -508,9 +599,9 @@ def simplify_ids(self, infix: list[str] | None = None) -> None: param_counter = Counter[str]() name_to_original: dict[str, str] = {} for name, param in { - pr.parameter.__name__: pr.parameter - for pr in self.find_parameters() - if isinstance(pr, ParameterReference) + param.__name__: param + for param in self.find_parameters() + if isinstance(param, Parameter) }.items(): if param.__original_name__ != name: param_counter[param.__original_name__] += 1 @@ -541,7 +632,7 @@ def register_task(self, fn: Lazy) -> Task: return task def add_nested_step( - self, name: str, subworkflow: Workflow, kwargs: dict[str, Any] + self, name: str, subworkflow: Workflow, return_type: type | None, kwargs: dict[str, Any], positional_args: dict[str, bool] | None = None ) -> StepReference[Any]: """Append a nested step. @@ -550,21 +641,28 @@ def add_nested_step( Args: name: name of the subworkflow. subworkflow: the subworkflow itself. + return_type: a forced type for the return, or None. kwargs: any key-value arguments to pass in the call. + positional_args: a mapping of arguments to bools, True if the argument is positional or otherwise False. + + Returns: a reference to the step that calls out to a new workflow. """ step = NestedStep(self, name, subworkflow, kwargs) - self.steps.append(step) - return_type = step.return_type + if positional_args is not None: + step.positional_args = positional_args + self._steps.append(step) + return_type = return_type or step.return_type if return_type is inspect._empty: raise TypeError("All tasks should have a type annotation.") - return StepReference(self, step, return_type) + return step.make_reference(workflow=self, typ=return_type) def add_step( self, fn: Lazy, - kwargs: dict[str, Raw | Reference], + kwargs: dict[str, Raw | Reference[Any]], raw_as_parameter: bool = False, is_factory: bool = False, + positional_args: dict[str, bool] | None = None ) -> StepReference[Any]: """Append a step. @@ -576,11 +674,14 @@ def add_step( kwargs: any key-value arguments to pass in the call. raw_as_parameter: whether to turn any discovered raw arguments into workflow parameters. is_factory: whether this step should be a Factory. + positional_args: a mapping of arguments to bools, True if the argument is positional or otherwise False. 
""" task = self.register_task(fn) step_maker = FactoryCall if is_factory else Step step = step_maker(self, task, kwargs, raw_as_parameter=raw_as_parameter) - self.steps.append(step) + if positional_args is not None: + step.positional_args = positional_args + self._steps.append(step) return_type = step.return_type if ( return_type is inspect._empty @@ -588,23 +689,38 @@ def add_step( and not inspect.isclass(fn) ): raise TypeError("All tasks should have a type annotation.") - return StepReference(self, step, return_type) + return step.make_reference(workflow=self, typ=return_type) @staticmethod def from_result( - result: StepReference[Any], simplify_ids: bool = False, nested: bool = True + result: StepReference[Any] | list[StepReference[Any]] | tuple[StepReference[Any], ...], simplify_ids: bool = False, nested: bool = True ) -> Workflow: """Create from a desired result. Starts from a result, and builds a workflow to output it. """ - workflow = result.__workflow__ - workflow.set_result(result) + expr, _refs = expr_to_references(result) + if not _refs: + raise RuntimeError( + "Attempted to build a workflow from a return-value/result/expression with no references." + ) + refs = list(_refs) + if isinstance(refs[0], Parameter): + raise RuntimeError("Attempted to build a workflow from an input parameter.") + workflow = refs[0].__workflow__ + # Ensure that we have exactly one workflow, even if multiple results. + for entry in refs[1:]: + if entry.__workflow__ != workflow: + raise RuntimeError("If multiple results, they must share a single workflow") + workflow.set_result(expr) if simplify_ids: workflow.simplify_ids() + if not isinstance(workflow, Workflow): + # This may be better as a cast, since this seems an artificial case. + raise TypeError("Can only return a workflow of the same type") return workflow - def set_result(self, result: StepReference[Any]) -> None: + def set_result(self, result: Basic | list[Basic] | tuple[Basic]) -> None: """Choose the result step. Sets a step as being the result for the entire workflow. @@ -616,29 +732,21 @@ def set_result(self, result: StepReference[Any]) -> None: Args: result: reference to the chosen step. """ - if result.step.__workflow__ != self: - raise RuntimeError("Output must be from a step in this workflow.") + _, refs = expr_to_references(result) + for entry in refs: + if entry.__workflow__ != self: + raise RuntimeError("Output must be from a step in this workflow.") self.result = result - -class WorkflowComponent: - """Base class for anything directly tied to an individual `Workflow`. - - Attributes: - __workflow__: the `Workflow` that this is tied to. - """ - - __workflow__: Workflow - - def __init__(self, workflow: Workflow): - """Tie to a `Workflow`. - - All subclasses must call this. - - Args: - workflow: the `Workflow` to tie to. - """ - self.__workflow__ = workflow + @property + def result_type(self) -> type: + """Overall return type of this workflow.""" + if self.result is None: + return type(None) + if hasattr(self.result, "__type__"): + return self.result.__type__ + # TODO: get individual types! + return type(self.result) class WorkflowLinkedComponent(Protocol): @@ -657,14 +765,206 @@ def __workflow__(self) -> Workflow: ... -class Reference: - """Superclass for all symbolic references to values.""" +class FieldableProtocol(Protocol): + """Expected interfaces for a type that can take fields. + + Attributes: + __field__: tuple representing the named fields, either strings or integers. + """ + __field__: tuple[str | int, ...] 
+
+    @property
+    def __field_sep__(self) -> str:
+        """The separator that should be used by default when rendering a name."""
+
+    @property
+    def __field_index_types__(self) -> tuple[type, ...]:
+        """Get the types that will be rendered as an index [x] rather than a field .x on stringifying.
+
+        Will be taken from the `field_index_types` construct configuration.
+        """
+
+    def __init__(self, *args: Any, field: str | None = None, **kwargs: Any):
+        """Extract the field name from the initializer arguments, if provided."""
+        super().__init__(*args, **kwargs)
+
+    @property
+    def __name__(self: FieldableProtocol) -> str:
+        """Name of the fieldable, which may not be `name` if this is needed as an internal reference."""
+        return self.name
+
+    @property
+    def __type__(self) -> type:
+        """Type of this field with the overall target."""
+        ...
 
     @property
     def name(self) -> str:
-        """Referral name for this reference."""
-        raise NotImplementedError("Reference must provide a name")
+        """The name for the target, accounting for the field."""
+        return "name"
+
+    def __make_reference__(self, *args: Any, **kwargs: Any) -> Reference[Any]:
+        """Create a reference with constructor arguments, usually to a subfield."""
+        ...
+
+# This is required so that mixins can trust that
+# superclass attributes will be available for mypy,
+# without confusing the MRO.
+if TYPE_CHECKING:
+    _Fieldable = FieldableProtocol
+else:
+    _Fieldable = object
+
+# Subclass Reference so that we know Reference methods/attrs are available.
+class FieldableMixin(_Fieldable):
+    """Tooling for enhancing a type with referenceable fields."""
+
+    def __init__(self: FieldableProtocol, *args: Any, field: str | int | tuple[str | int, ...] | None = None, **kwargs: Any):
+        """Extract the requested field, if any, from the initializer arguments.
+
+        Args:
+            field: the new subfield, either an index (int) or a fieldname (str) or a field tuple representing the whole path, or None.
+            *args: arguments to pass to the other initializers.
+            **kwargs: arguments to pass to the other initializers.
+        """
+        self.__field__: tuple[str | int, ...] = (field if isinstance(field, tuple) else (field,)) if field is not None else ()
+        super().__init__(*args, **kwargs)
+
+    @property
+    def __field_sep__(self) -> str:
+        """Get the field separator.
+
+        Will be taken from the configuration key, `field_separator`.
+        """
+        sep = get_configuration("field_separator")
+        if not isinstance(sep, str):
+            raise TypeError(f"The `field_separator` configuration argument must be a string not {type(sep)}")
+        return sep
+
+    @property
+    def __field_index_types__(self) -> tuple[type, ...]:
+        """Get the types that will be rendered as an index [x] rather than a field .x on stringifying.
+
+        Will be taken from the `field_index_types` construct configuration.
+ """ + types = get_configuration("field_index_types") + if not isinstance(types, str): + raise TypeError(f"The `field_index_types` configuration argument should be a comma-separated string of types from: {', '.join(AVAILABLE_TYPES.keys())}") + tup_all = tuple(AVAILABLE_TYPES.get(typ) for typ in types.split(",") if typ) + tup = tuple(t for t in tup_all if t is not None) + if tup_all != tup: + raise ValueError( + "Setting for fixed index types contains unavailable type: " + + f"{str(get_configuration('field_index_types'))} vs {tup}" + ) + return tup + + @property + def __field_str__(self) -> str: + """Stringified field, without the name.""" + return self.__field_suffix__.lstrip(self.__field_sep__) + + @property + def __field_suffix__(self) -> str: + """Stringified field, without the name but with an appropriate separator at the start.""" + result = "" + for cmpt in self.__field__: + if any(isinstance(cmpt, typ) for typ in self.__field_index_types__): + result += f"[{cmpt}]" + else: + result += f"{self.__field_sep__}{cmpt}" + return result + + def __field_up__(self) -> Reference[Any]: + """Get the parent field, if possible. + + Returns: reference to the field above this. + + Raises: + RuntimeError: if there is no field above this. + """ + if self.__field__: + return self.__make_reference__(field=tuple(self.__field__[:-1])) + raise RuntimeError("Cannot go upwards unless currently using a field.") + @property + def __name__(self) -> str: + """Name for this step. + + May be remapped by the workflow to something nicer + than the ID. + """ + return super().__name__ + self.__field_suffix__ + + def find_field(self: FieldableProtocol, field: str | int, fallback_type: type | None = None, **init_kwargs: Any) -> Reference[Any]: + """Field within the reference, if possible. + + Args: + field: the field to search for. + fallback_type: the type to use if we do not know a more specific one. + **init_kwargs: arguments to use for constructing a new reference (via `__make_reference__`). + + Returns: + A field-specific version of this reference. + """ + # Get new type, for the specific field. 
+ parent_type, _ = strip_annotations(self.__type__) + field_type = fallback_type + base = get_origin(parent_type) or parent_type + + if isinstance(field, int): + if not inspect.isclass(base) or not issubclass(base, Sequence): + raise AttributeError(f"Tried to index int {field} into a non-sequence type {parent_type} (base: {base})") + if not (field_type := get_args(parent_type)[0]): + raise AttributeError( + f"Tried to index int {field} into type {parent_type} but can only do so if the first type argument " + f"is the element type (args: {get_args(parent_type)}" + ) + elif field.startswith("__"): + raise AttributeError(f"We do not allow fields with dunder prefix, such as {field}, to reduce risk of clashes.") + else: + if is_dataclass(parent_type): + try: + type_hints = get_type_hints(parent_type, localns={parent_type.__name__: parent_type}, include_extras=True) + field_type = type_hints.get(field) + if field_type is None: + field_type = next(iter(filter(lambda fld: fld.name == field, dataclass_fields(parent_type)))).type + except StopIteration: + raise AttributeError(f"Dataclass {parent_type} does not have field {field}") from None + elif attr_has(parent_type): + resolve_types(parent_type) + try: + field_type = getattr(attrs_fields(parent_type), field).type + except AttributeError as exc: + raise AttributeError(f"attrs-class {parent_type} does not have field {field}") from exc + # TypedDict + elif inspect.isclass(parent_type) and issubclass(parent_type, dict) and hasattr(parent_type, "__annotations__"): + try: + field_type = get_type_hints(parent_type, include_extras=True)[field] + except KeyError as exc: + raise AttributeError(f"TypedDict {parent_type} does not have field {field}") from exc + if not field_type and get_configuration("allow_plain_dict_fields") and inspect.isclass(base) and issubclass(base, dict): + args = get_args(parent_type) + if len(args) == 2 and args[0] is str: + field_type = args[1] + else: + raise AttributeError("Can only get fields for plain dicts if annotated dict[str, TYPE]") + + if field_type: + if not issubclass(self.__class__, Reference): + raise TypeError("Only references can have a fieldable mixin") + + new_field: tuple[str | int, ...] | str | int + if self.__field__: + new_field = tuple(list(self.__field__) + [field]) + else: + new_field = field + + return self.__make_reference__(typ=field_type, field=new_field, **init_kwargs) + + raise AttributeError( + f"Could not determine the type for field {field} in type {parent_type} (type of parent type is {type(parent_type)})" + ) class BaseStep(WorkflowComponent): """Lazy-evaluated function call. @@ -679,13 +979,15 @@ class BaseStep(WorkflowComponent): _id: str | None = None task: Task | Workflow - arguments: Mapping[str, Reference | Raw] + arguments: Mapping[str, Basic | Reference[Any] | Raw] + workflow: Workflow + positional_args: dict[str, bool] | None = None def __init__( self, workflow: Workflow, task: Task | Workflow, - arguments: Mapping[str, Reference | Raw], + arguments: Mapping[str, Reference[Any] | Raw], raw_as_parameter: bool = False, ): """Initialize a step. @@ -696,7 +998,7 @@ def __init__( arguments: key-value pairs to pass to the function. raw_as_parameter: whether to turn any raw-type arguments into workflow parameters (or just keep them as default argument values). 
""" - super().__init__(workflow) + super().__init__(workflow=workflow) self.task = task self.arguments = {} for key, value in arguments.items(): @@ -705,6 +1007,9 @@ def __init__( or isinstance(value, Reference) or isinstance(value, Raw) or is_raw(value) + or is_expr(value) + or is_dataclass(value) + or attr_has(value) ): # Avoid recursive type issues if ( @@ -714,15 +1019,22 @@ def __init__( and is_raw(value) ): if raw_as_parameter: - value = ParameterReference( - workflow, param(key, value, tethered=None) - ) + # We use param for convenience but note that it is a reference in disguise. + value = cast(Parameter[Any], param(key, value, tethered=None)).make_reference(workflow=workflow) else: value = Raw(value) - if isinstance(value, ParameterReference): - parameter = value.parameter - parameter.register_caller(self) - self.arguments[key] = value + + def _to_param_ref(value: Any) -> ParameterReference[Any] | None: + if isinstance(value, Parameter): + return value.make_reference(workflow=workflow) + return None + expression, refs = expr_to_references(value, remap=_to_param_ref) + + for ref in refs: + if isinstance(ref, ParameterReference): + parameter = ref._.parameter + parameter.register_caller(self) + self.arguments[key] = expression else: raise RuntimeError( f"Non-references must be a serializable type: {key}>{value} {type(value)}" @@ -737,13 +1049,36 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, BaseStep): return False return ( - self.__workflow__ == other.__workflow__ + self.__workflow__ is other.__workflow__ and self.task == other.task and self.arguments == other.arguments ) + def make_reference(self, **kwargs: Any) -> "StepReference[T]": + """Create a reference to this step. + + Builds a reference to the (result of) this step, which will be iterable if appropriate. + + Args: + **kwargs: arguments for reference constructor, which will be supplemented appropriately. + + Returns: a reference to the result of this step. + """ + kwargs["step"] = self + kwargs.setdefault("typ", self.return_type) + typ = kwargs["typ"] + base = strip_annotations(typ)[0] + base = get_origin(base) or base + if inspect.isclass(base) and issubclass(base, Iterable) and not issubclass(base, str | bytes): + return IterableStepReference(**kwargs) + return StepReference(**kwargs) + + def __hash__(self) -> int: + """Searchable hash for this step.""" + return hash(self.id) + def set_workflow(self, workflow: Workflow, with_arguments: bool = True) -> None: - """Move the step reference to a different workflow. + """Move the step reference to another workflow. This method is primarily intended to be called by a step, allowing it to switch to a new workflow. It also updates the workflow reference for any @@ -756,11 +1091,8 @@ def set_workflow(self, workflow: Workflow, with_arguments: bool = True) -> None: self.__workflow__ = workflow if with_arguments: for argument in self.arguments.values(): - if hasattr(argument, "__workflow__"): - try: - argument.__workflow__ = workflow - except AttributeError: - ... + unify_workflows(argument, workflow, set_only=True) + self._id = None @property def return_type(self) -> Any: @@ -773,15 +1105,16 @@ def return_type(self) -> Any: Expected type of the return value. 
""" if isinstance(self.task, Workflow): - if self.task.result: - return self.task.result.return_type + if self.task.result is not None: + return self.task.result_type else: raise AttributeError( "Cannot determine return type of a workflow with an unspecified result" ) if isinstance(self.task.target, type): return self.task.target - return inspect.signature(inspect.unwrap(self.task.target)).return_annotation + ann = get_type_hints(inspect.unwrap(self.task.target), include_extras=True)["return"] + return ann @property def name(self) -> str: @@ -799,12 +1132,13 @@ def id(self) -> str: self._id = self._generate_id() return self._id - check_id = self._generate_id() - if check_id != self._id: - return self._id - raise RuntimeError( - f"Cannot change a step after requesting its ID: {self.task}" - ) + if CHECK_IDS: + check_id = self._generate_id() + if check_id != self._id: + return self._id + raise RuntimeError( + f"Cannot change a step after requesting its ID: {self.task}" + ) return self._id def _generate_id(self) -> str: @@ -813,7 +1147,7 @@ def _generate_id(self) -> str: for key, param in self.arguments.items(): components.append((key, repr(param))) - comp_tup: tuple[str | tuple[str, str], ...] = tuple(components) + comp_tup: tuple[str | tuple[str, str], ...] = tuple(sorted(components, key=lambda pair: pair[0])) return f"{self.task}-{hasher(comp_tup)}" @@ -829,7 +1163,7 @@ def __init__( workflow: Workflow, name: str, subworkflow: Workflow, - arguments: Mapping[str, Reference | Raw], + arguments: Mapping[str, Basic | Reference[Any] | Raw], raw_as_parameter: bool = False, ): """Create a NestedStep. @@ -842,10 +1176,12 @@ def __init__( raw_as_parameter: whether raw-type arguments should be made (outer) workflow parameters. """ self.__subworkflow__ = subworkflow + base_arguments: dict[str, Basic | Reference[Any] | Raw] = {p.name: p for p in subworkflow.find_parameters()} + base_arguments.update(arguments) super().__init__( workflow=workflow, task=subworkflow, - arguments=arguments, + arguments=base_arguments, raw_as_parameter=raw_as_parameter, ) @@ -864,9 +1200,10 @@ def return_type(self) -> Any: Returns: Expected type of the return value. """ - if not self.__subworkflow__.result: + result = self.__subworkflow__.result + if result is None or (isinstance(result, list | tuple) and not result): raise RuntimeError("Can only use a subworkflow if the reference exists.") - return self.__subworkflow__.result.return_type + return self.__subworkflow__.result_type class Step(BaseStep): @@ -882,7 +1219,7 @@ def __init__( self, workflow: Workflow, task: Task | Workflow, - arguments: Mapping[str, Reference | Raw], + arguments: Mapping[str, Reference[Any] | Raw], raw_as_parameter: bool = False, ): """Initialize a step. @@ -893,22 +1230,36 @@ def __init__( arguments: key-value pairs to pass to the function - for a factory call, these _must_ be raw. raw_as_parameter: whether to turn any raw-type arguments into workflow parameters (or just keep them as default argument values). 
""" - for arg in list(arguments.values()): - if not is_raw(arg) and not ( + for _, arg in arguments.items(): + if not is_expr(arg) and not ( isinstance(arg, ParameterReference) and is_raw_type(arg.__type__) ): + try: + arg_name = str(arg) + except: + arg_name = "(unnamed)" raise RuntimeError( - f"Factories must be constructed with raw types {arg} {type(arg)}" + f"Factories must be constructed with raw types in {arg_name} {type(arg)}" ) - super().__init__(workflow, task, arguments, raw_as_parameter=raw_as_parameter) + super().__init__(workflow=workflow, task=task, arguments=arguments, raw_as_parameter=raw_as_parameter) @property - def default(self) -> Unset: + def __name__(self) -> str: + """Get the name of this factory call.""" + return self.name + + @property + def __original_name__(self) -> str: + """Original name of a factory call is just its normal name.""" + return self.name + + @property + def __default__(self) -> Unset: """Dummy default property for use as property.""" return UnsetType(self.return_type) -class ParameterReference(Reference): +class ParameterReference(FieldableMixin, Reference[U], WorkflowComponent): """Reference to an individual `Parameter`. Allows us to refer to the outputs of a `Parameter` in subsequent `Parameter` @@ -916,67 +1267,131 @@ class ParameterReference(Reference): Attributes: parameter: `Parameter` referred to. - __workflow__: Related workflow. In this case, as Parameters are generic + workflow: Related workflow. In this case, as Parameters are generic but ParameterReferences are specific, this carries the actual workflow reference. Returns: Workflow that the referee is related to. """ - parameter: Parameter[RawType] - __workflow__: Workflow + class ParameterReferenceMetadata(Generic[T]): + """Holder for attributes of this reference that we do not wish to risk confusing with fieldnames. - def __init__(self, __workflow__: Workflow, parameter: Parameter[RawType]): - """Initialize the reference. - - Args: - workflow: `Workflow` that this is tied to. - parameter: `Parameter` that this refers to. + Attributes: + parameter: the parameter to which this reference refers. """ - self.parameter = parameter - self.__workflow__ = __workflow__ + parameter: Parameter[T] + + def __init__(self, parameter: Parameter[T], typ: type[U] | Unset=UNSET): + """Initialize the reference. + + Args: + workflow: `Workflow` that this is tied to. + parameter: `Parameter` that this refers to. + typ: type that should override inferred type or None. + *args: any arguments for other initializers. + **kwargs: any arguments for other initializers. + """ + self.parameter = parameter + + @property + def unique_name(self) -> str: + """Unique, machine-generated name. + + Normally this will become invisible in output, but it avoids circularity + as a step that uses this parameter will ask for this when constructing + its own hash, but we will normally want to use the step's name as part of + the parameter name to distinguish from other parameters of the same name. 
+ """ + return self.parameter.__name__ @property - def default(self) -> RawType | Unset: + def __default__(self) -> U | Unset: """Default value of the parameter.""" - return self.parameter.default + default = self._.parameter.default + if isinstance(default, Unset): + return default + + for field in self.__field__: + if isinstance(default, dict) or isinstance(field, int): + default = default[field] + else: + default = getattr(default, field) + return default @property - def __type__(self) -> type: - """Type represented by wrapped parameter.""" - return self.parameter.__type__ + def __root_name__(self) -> str: + """Reference based on the named step. - def __str__(self) -> str: - """Global description of the reference.""" - return self.parameter.full_name + May be remapped by the workflow to something nicer + than the ID. + """ + return self._.parameter.name - def __repr__(self) -> str: - """Hashable reference to the step (and field).""" - try: - typ = self.__type__.__name__ - except AttributeError: - typ = str(self.__type__) - return f"{typ}|:param:{self.unique_name}" + def __init__(self, parameter: Parameter[U], *args: Any, typ: type[U] | None=None, **kwargs: Any): + """Extract the parameter and type for setup. - @property - def unique_name(self) -> str: - """Unique, machine-generated name. + Args: + parameter: the parameter to reference. + typ: an overriding type for this reference, or None. + *args: arguments for other initializers. + **kwargs: arguments for other initializers. + """ + chosen_type = typ or parameter.__type__ + self._ = self.ParameterReferenceMetadata(parameter, typ=chosen_type) + super().__init__(*args, typ=typ, **kwargs) + + def __getitem__(self, attr: str) -> "ParameterReference[U]": + """Retrieve a field. - Normally this will become invisible in output, but it avoids circularity - as a step that uses this parameter will ask for this when constructing - its own hash, but we will normally want to use the step's name as part of - the parameter name to distinguish from other parameters of the same name. + Args: + attr: attribute to get. + + Returns: a reference to the field within this parameter, possibly nesting if we are already + referencing a field. """ - return self.parameter.__name__ + try: + return self.find_field( + field=attr, + workflow=self.__workflow__, + parameter=self._.parameter + ) + except AttributeError as exc: + raise KeyError( + f"Key not found in {self.__root_name__} ({type(self)}:{self.__type__}): {attr}" + + ( + ". This could be because you are trying to iterate/index a reference whose type is not definitely iterable - double check your typehints." + if isinstance(attr, int) else "" + ) + ) from exc @property - def name(self) -> str: - """Reference based on the named step. + def __original_name__(self) -> str: + """The name of the original parameter, without any field, etc.""" + return self._.parameter.__original_name__ - May be remapped by the workflow to something nicer - than the ID. + def __getattr__(self, attr: str) -> "ParameterReference[U] | Any": + """Retrieve a field. + + Args: + attr: attribute to get. + + Returns: a reference to the field within this parameter, possibly nesting if we are already + referencing a field. 
""" - return self.__workflow__.remap(self.parameter.name) + try: + return self[attr] + except KeyError as _: + return super().__getattribute__(attr) + + def __repr__(self) -> str: + """Hashable reference to the step (and field).""" + try: + typ = self.__type__.__name__ + except AttributeError: + typ = str(self.__type__) + name = self._.unique_name + self.__field_suffix__ + return f"{typ}|:param:{name}" def __hash__(self) -> int: """Hash to parameter. @@ -984,7 +1399,7 @@ def __hash__(self) -> int: Returns: Unique hash corresponding to the parameter. """ - return hash(self.parameter) + return hash((self._.parameter, self.__field__)) def __eq__(self, other: object) -> bool: """Compare two references. @@ -995,42 +1410,98 @@ def __eq__(self, other: object) -> bool: Returns: True if the other parameter reference is materially the same, otherwise False. """ + # We are equal to a parameter if we are a direct, fieldless, reference to it. return ( - isinstance(other, ParameterReference) and self.parameter == other.parameter + (isinstance(other, ParameterReference) and self._.parameter == other._.parameter and self.__field__ == other.__field__) ) + def __make_reference__(self, **kwargs: Any) -> "ParameterReference[U]": + """Get a reference for the same parameter.""" + if "workflow" not in kwargs: + kwargs["workflow"] = self.__workflow__ + return self._.parameter.make_reference(**kwargs) -U = TypeVar("U") +class IterableParameterReference(IterableMixin[U], ParameterReference[U]): + """Iterable form of parameter references.""" + def __iter__(self) -> Generator[Reference[U], None, None]: + """Iterate over this parameter. + + Returns: + Get values back for this parameter, normally references of the appropriate type. + """ + inner, metadata = strip_annotations(self.__type__) + if metadata and "AtRender" in metadata and isinstance(self.__default__, Iterable): + yield from self.__default__ + else: + yield from super().__iter__() + + def __inner_iter__(self) -> Generator[Any, None, None]: + """Iterate over this parameter. + This is intended for overriding as a convenience way of iterating over a parameter, + and the yielded values will be ignored, in favour of new references. -class StepReference(Generic[U], Reference): + Returns: a generator of any type. + """ + inner, metadata = strip_annotations(self.__type__) + if self.__fixed_len__ is not None: + yield from range(self.__fixed_len__) + elif metadata and "Fixed" in metadata and isinstance(self.__default__, Sized): + yield from range(len(self.__default__)) + else: + while True: + yield None + + def __len__(self) -> int: + """If it is possible to get a hard-codeable length from this iterable parameter, do so.""" + inner, metadata = strip_annotations(self.__type__) + if metadata and "Fixed" in metadata and isinstance(self.__default__, Sized): + return len(self.__default__) + return super().__len__() + +class StepReference(FieldableMixin, Reference[U]): """Reference to an individual `Step`. Allows us to refer to the outputs of a `Step` in subsequent `Step` arguments. Attributes: - step: `Step` referred to. + _: metadata wrapping the `Step` referred to. """ step: BaseStep - _tethered_workflow: Workflow | None - _field: str | None - typ: type[U] - @property - def field(self) -> str: - """Field within the result. + class StepReferenceMetadata: + """Wrapper for any metadata that we would not want to conflict with fieldnames. - Explicitly set field (within an attrs-class) or `out`. - - Returns: - Field name. + Attributes: + step: the step being wrapped. 
+            _typ: the type to return, if overriding the step's own type, or None.
         """
-        return self._field or "out"
+
+        def __init__(
+            self, step: BaseStep, typ: type[U] | None = None
+        ):
+            """Initialize the reference.
+
+            Args:
+                step: `Step` that this refers to.
+                typ: the type that the step will output, overriding the step's own annotation if given.
+            """
+            self.step = step
+            self._typ = typ
+
+        @property
+        def return_type(self) -> type:
+            """Type of this reference, which may be overridden or may be the step's own type."""
+            return self._typ or self.step.return_type
+
+    _: StepReferenceMetadata
 
     def __init__(
-        self, workflow: Workflow, step: BaseStep, typ: type[U], field: str | None = None
+        self, step: BaseStep, *args: Any, typ: type[U] | None = None, **kwargs: Any
     ):
         """Initialize the reference.
 
@@ -1039,21 +1510,26 @@ def __init__(
             step: `Step` that this refers to.
             typ: the type that the step will output.
             field: if provided, a specific field to pull out of an attrs result class.
+            *args: arguments for other initializers.
+            **kwargs: arguments for other initializers.
         """
-        self.step = step
-        self._field = field
-        self.typ = typ
-        self._tethered_workflow = None
+        typ = typ or step.return_type
+        self._ = self.StepReferenceMetadata(step, typ=typ)
+        super().__init__(*args, typ=typ, **kwargs)
 
     def __str__(self) -> str:
         """Global description of the reference."""
-        return f"{self.step.id}/{self.field}"
+        return self.__name__
 
     def __repr__(self) -> str:
         """Hashable reference to the step (and field)."""
-        return f"{self.step.id}/{self.field}"
+        return self._.step.id + self.__field_suffix__
 
-    def __getattr__(self, attr: str) -> "StepReference[Any]":
+    def __hash__(self) -> int:
+        """Hashable value for this workflow."""
+        return hash((repr(self), id(self.__workflow__)))
+
+    def __getitem__(self, attr: str) -> "StepReference[Any]":
         """Reference to a field within this result, if possible.
 
         If the result is an attrs-class or dataclass, this will pull out an individual
@@ -1069,101 +1545,79 @@ def __getattr__(self, attr: str) -> "StepReference[Any]":
             AttributeError: if this field is not present in the dataclass.
             RuntimeError: if this field is not available, or we do not have a structured result.
         """
-        if self._field is None:
-            typ: type | None
-            if attr_has(self.typ):
-                resolve_types(self.typ)
-                typ = getattr(attrs_fields(self.typ), attr).type
-            elif is_dataclass(self.typ):
-                matched = [
-                    field for field in dataclass_fields(self.typ) if field.name == attr
-                ]
-                if not matched:
-                    raise AttributeError(f"Field {attr} not present in dataclass")
-                typ = matched[0].type
-            elif isinstance(self.step, FactoryCall):
-                typ = self.step.return_type
-            else:
-                typ = None
-
-            if typ:
-                return self.__class__(
-                    workflow=self.__workflow__, step=self.step, typ=typ, field=attr
+        try:
+            return self.find_field(
+                workflow=self.__workflow__, step=self._.step, field=attr
+            )
+        except AttributeError as exc:
+            raise KeyError(
+                f"Key not found in {self.__root_name__} ({type(self)}:{self.__type__}): {attr}"
+                + (
+                    ". This could be because you are trying to iterate/index a reference whose type is not definitely iterable - double check your typehints."
+                    if isinstance(attr, int) else ""
                 )
-            raise AttributeError(
-                "Can only get attribute of a StepReference representing an attrs-class or dataclass"
-            )
+            ) from exc
 
-    @property
-    def return_type(self) -> type[U]:
-        """Type that this step reference will resolve to.
+ def __getattr__(self, attr: str) -> "StepReference[U] | Any": + """Retrieve a field within this workflow.""" + try: + return self[attr] + except KeyError as exc: + try: + return super().__getattribute__(attr) + except AttributeError as inner_exc: + raise inner_exc from exc - Returns: - Python type indicating the final result type. - """ - return self.typ + @property + def __type__(self) -> type: + """Get the type to which this step reference refers.""" + return self._.return_type @property - def name(self) -> str: + def __root_name__(self) -> str: """Reference based on the named step. May be remapped by the workflow to something nicer than the ID. """ - return f"{self.step.name}/{self.field}" - - @property - def __type__(self) -> Any: - """Type of the step's referenced value.""" - return self.step.return_type + return self._.step.name @property - def __workflow__(self) -> Workflow: + def __workflow__(self) -> WorkflowProtocol: """Related workflow. Returns: Workflow that the referee is related to. """ - return self._tethered_workflow or self.step.__workflow__ + return self._.step.__workflow__ @__workflow__.setter def __workflow__(self, workflow: Workflow) -> None: """Sets related workflow. - We update the tethered workflow. If the step is missing from - this workflow then, by construction, it should have at least - been through an indexing process once, so we should be able - to get it back by name. - Args: workflow: workflow to update the step """ - self._tethered_workflow = workflow - if self._tethered_workflow: - if self.step not in self._tethered_workflow.steps: - self.step = self._tethered_workflow._indexed_steps[self.step.id] - - @__workflow__.setter - def __workflow__(self, workflow: Workflow) -> None: - self.step.set_workflow(workflow) - - -def merge_workflows(*workflows: Workflow) -> Workflow: - """Combine several workflows into one. - - Merges a series of workflows by combining steps and tasks. - - Argument: - *workflows: series of workflows to combine. - - Returns: - One workflow with all steps. - """ - base = list(workflows).pop() - for workflow in workflows: - base = Workflow.assimilate(base, workflow) - return base - + self._.step.set_workflow(workflow) + + def __make_reference__(self, **kwargs: Any) -> "StepReference[U]": + """Create a new reference for the same step.""" + if "workflow" not in kwargs: + kwargs["workflow"] = self.__workflow__ + return self._.step.make_reference(**kwargs) + +class IterableStepReference(IterableMixin[U], StepReference[U]): + """Iterable form of a step reference.""" + def __iter__(self) -> Generator[Reference[U], None, None]: + """Gets a sentinel value for iterating over this step's results. + + Bear in mind that this means an iterable step reference will iterate exactly once, + and return this Generator. The IteratedGenerator can be used as an iterator itself, + and will yield an infinite sequence of numbered field references, which renderers can use + for zipping with a fixed length iterator, or simply prepping fieldnames for serialization. + """ + # We cast this so that we can treat a step iterator as if it really loops over results. + yield cast(Reference[U], IteratedGenerator(self)) def is_task(task: Lazy) -> bool: """Decide whether this is a task. @@ -1179,3 +1633,104 @@ def is_task(task: Lazy) -> bool: True if `task` is indeed a task. 
""" return isinstance(task, LazyEvaluation) + +def expr_to_references(expression: Any, remap: Callable[[Any], Any] | None = None) -> tuple[ExprType, list[Reference[Any] | Parameter[Any]]]: + """Pull out any references, or other free symbols, from an expression. + + Args: + expression: normally a reference that can be immediately returned, but may be a sympy expression or a dict/tuple/list/etc. of such. + remap: a callable to project certain values down before extracting symbols, or None. + + Returns: a pair of the expression with any applied simplifications/standardizations, and the list of References/Parameters found. + """ + to_check: list[Reference[Any] | Parameter[Any]] = [] + def _to_expr(value: Any) -> ExprType: + if remap and (res := remap(value)) is not None: + return _to_expr(res) + + if isinstance(value, Reference): + to_check.append(value) + return value + + if value is None: + return None + + if isinstance(value, Symbol): + return value + elif isinstance(value, Basic): + for sym in value.free_symbols: + new_sym = _to_expr(sym) + if new_sym != sym: + value = value.subs(sym, new_sym) + return value + + if is_dataclass(value) or attr_has(value): + if is_dataclass(value): + fields = list(dataclass_fields(value)) + else: + fields = list({field for field in attrs_fields(value)}) + for field in fields: + if hasattr(value, field.name) and isinstance((val := getattr(value, field.name)), Reference): + setattr(value, field.name, _to_expr(val)) + return value + + # We need to look inside a Raw, but we do not want to lose it if + # we do not need to. + retval = value + if isinstance(value, Raw): + value = value.value + + if isinstance(value, Mapping): + dct = {key: _to_expr(val) for key, val in value.items()} + if dct == value: + return retval + # We try to reinstantiate this, but there will be an error otherwise. + # TODO: we could check this with a protocol, but some care would be needed to ensure all valid + # cases are covered. + return value.__class__(dct) # type: ignore + elif not isinstance(value, str | bytes) and isinstance(value, Iterable): + lst = (tuple if isinstance(value, tuple) else list)(_to_expr(v) for v in value) + if lst == value: + return retval + return lst + return retval + + expression = _to_expr(expression) + + #if {sym for sym in symbols if not is_raw(sym)} != set(to_check): + # print(symbols, to_check, [type(r) for r in symbols]) + # raise RuntimeError("The only symbols allowed are references (to e.g. step or parameter)") + return expression, to_check + +def unify_workflows(expression: Any, base_workflow: Workflow | None, set_only: bool = False) -> tuple[Basic | None, Workflow | None]: + """Takes an expression and ensures all of its references exist in the same workflow. + + Args: + expression: any valid argument to `dewret.workflow.expr_to_references`. + base_workflow: the desired workflow to align on, or None. + set_only: whether to bother assimilating all the workflows (False), or to assume that has been done (False). + + Returns: a pair of the (standardized) expression and the workflow that now contains it. + """ + expression, to_check = expr_to_references(expression) + if not to_check: + return expression, base_workflow + + # Build a unified workflow + collected_workflow = base_workflow or next(iter(to_check)).__workflow__ + # This is a bit of an artifical check. We could rule out non-Workflow WorkflowProtocol implementers and simplify this. 
+
+def unify_workflows(expression: Any, base_workflow: Workflow | None, set_only: bool = False) -> tuple[Basic | None, Workflow | None]:
+    """Take an expression and ensure all of its references exist in the same workflow.
+
+    Args:
+        expression: any valid argument to `dewret.workflow.expr_to_references`.
+        base_workflow: the desired workflow to align on, or None.
+        set_only: whether to assume the workflows have already been assimilated (True), or to merge them here (False).
+
+    Returns: a pair of the (standardized) expression and the workflow that now contains it.
+    """
+    expression, to_check = expr_to_references(expression)
+    if not to_check:
+        return expression, base_workflow
+
+    # Build a unified workflow
+    collected_workflow = base_workflow or next(iter(to_check)).__workflow__
+    # This is a bit of an artificial check. We could rule out non-Workflow WorkflowProtocol implementers and simplify this.
+    if not isinstance(collected_workflow, Workflow):
+        raise NotImplementedError("Can only unify workflows on the same type: Workflow")
+    if not set_only:
+        for step_result in to_check:
+            new_workflow = step_result.__workflow__
+            if not isinstance(new_workflow, Workflow):
+                raise NotImplementedError("Can only unify workflows on the same type: Workflow")
+            if collected_workflow != new_workflow and collected_workflow and new_workflow:
+                collected_workflow = Workflow.assimilate(collected_workflow, new_workflow)
+
+    # Make sure all the results share it
+    for step_result in to_check:
+        step_result.__workflow__ = collected_workflow
+
+    return expression, collected_workflow
diff --git a/tests/_lib/extra.py b/tests/_lib/extra.py
index 955270de..3293044d 100644
--- a/tests/_lib/extra.py
+++ b/tests/_lib/extra.py
@@ -1,7 +1,17 @@
-from dewret.tasks import task, subworkflow
+from dewret.tasks import task, workflow
+
+from .other import nothing
 
 JUMP: float = 1.0
+test: float = nothing
+
+@workflow()
+def try_nothing() -> int:
+    """Check that we can see AtRender in another module."""
+    if nothing:
+        return increment(num=1)
+    return increment(num=0)
 
 @task()
 def increase(num: int | float) -> float:
@@ -33,7 +43,15 @@ def sum(left: int | float, right: int | float) -> int | float:
     return left + right
 
 
-@subworkflow()
+@task()
+def pi() -> float:
+    """Returns pi."""
+    import math
+
+    return math.pi
+
+
+@workflow()
 def triple_and_one(num: int | float) -> int | float:
     """Triple a number by doubling and adding again, then add 1."""
     return sum(left=sum(left=double(num=num), right=num), right=1)
@@ -43,3 +61,11 @@ def tuple_float_return() -> tuple[float, float]:
     """Return a tuple of floats."""
     return 48.856667, 2.351667
+
+@task()
+def reverse_list(to_sort: list[int | float]) -> list[int | float]:
+    return to_sort[::-1]
+
+@task()
+def max_list(lst: list[int | float]) -> int | float:
+    return max(lst)
diff --git a/tests/_lib/frender.py b/tests/_lib/frender.py
new file mode 100644
index 00000000..c2c25b49
--- /dev/null
+++ b/tests/_lib/frender.py
@@ -0,0 +1,111 @@
+"""Testing example renderer.
+
+'Friendly render', outputting human-readable descriptions.
+""" + +from textwrap import indent +from typing import Unpack, TypedDict +from dataclasses import dataclass + +from dewret.core import set_render_configuration +from dewret.workflow import Workflow, Step, NestedStep +from dewret.render import base_render + +from .extra import JUMP + +class FrenderRendererConfiguration(TypedDict): + allow_complex_types: bool + +def default_config() -> FrenderRendererConfiguration: + return FrenderRendererConfiguration({ + "allow_complex_types": True + }) + +@dataclass +class NestedStepDefinition: + name: str + subworkflow_name: str + + @classmethod + def from_nested_step(cls, nested_step: NestedStep) -> "NestedStepDefinition": + return cls( + name=nested_step.name, + subworkflow_name=nested_step.subworkflow.name + ) + + def render(self) -> str: + return \ +f""" +A portal called {self.name} to another workflow, +whose name is {self.subworkflow_name} +""" + +@dataclass +class StepDefinition: + name: str + + @classmethod + def from_step(cls, step: Step) -> "StepDefinition": + return cls( + name=step.name + ) + + def render(self) -> str: + return \ +f""" +Something called {self.name} +""" + + +@dataclass +class WorkflowDefinition: + name: str + steps: list[StepDefinition | NestedStepDefinition] + + @classmethod + def from_workflow(cls, workflow: Workflow) -> "WorkflowDefinition": + steps: list[StepDefinition | NestedStepDefinition] = [] + for step in workflow.indexed_steps.values(): + if isinstance(step, Step): + steps.append(StepDefinition.from_step(step)) + elif isinstance(step, NestedStep): + steps.append(NestedStepDefinition.from_nested_step(step)) + else: + raise RuntimeError(f"Unrecognised step type: {type(step)}") + + try: + name = workflow.name + except NameError: + name = "Work Doe" + return cls(name=name, steps=steps) + + def render(self) -> str: + steps = "\n".join('* ' + indent(step.render(), ' ')[3:] for step in self.steps) + return \ +f""" +I found a workflow called {self.name}. +It has {len(self.steps)} steps! +They are: +{steps} +It probably got made with JUMP={JUMP} +""" + +def render_raw( + workflow: Workflow, **kwargs: Unpack[FrenderRendererConfiguration] +) -> dict[str, str]: + """Render to a dict-like structure. + + Args: + workflow: workflow to evaluate result. + **kwargs: additional configuration arguments - these should match CWLRendererConfiguration. + + Returns: + Reduced form as a native Python dict structure for + serialization. + """ + # TODO: work out how to handle these hints correctly. + set_render_configuration(kwargs) # type: ignore + return base_render( + workflow, + lambda workflow: WorkflowDefinition.from_workflow(workflow).render() + ) diff --git a/tests/_lib/nonfrender.py b/tests/_lib/nonfrender.py new file mode 100644 index 00000000..95079bda --- /dev/null +++ b/tests/_lib/nonfrender.py @@ -0,0 +1,22 @@ +"""Testing example renderer. + +Correctly fails to import to show a broken module being handled. +""" + +from typing import Unpack, TypedDict +from dewret.workflow import Workflow + +from .extra import JUMP + +class NonfrenderRendererConfiguration(TypedDict): + allow_complex_types: bool + +# This should fail to load as default_config is not present. 
However it would
+# ignore the fact that the return type is not a (subtype of) dict[str, RawType]
+# def default_config() -> int:
+#     return 3
+
+def render_raw(
+    workflow: Workflow, **kwargs: Unpack[NonfrenderRendererConfiguration]
+) -> dict[str, str]:
+    return {"JUMP": str(JUMP)}
diff --git a/tests/_lib/other.py b/tests/_lib/other.py
new file mode 100644
index 00000000..891fbe56
--- /dev/null
+++ b/tests/_lib/other.py
@@ -0,0 +1,3 @@
+from dewret.annotations import AtRender
+
+nothing: AtRender[bool] = True
diff --git a/tests/_lib/unfrender.py b/tests/_lib/unfrender.py
new file mode 100644
index 00000000..5057c14c
--- /dev/null
+++ b/tests/_lib/unfrender.py
@@ -0,0 +1,26 @@
+"""Testing example renderer.
+
+Correctly fails to import to show a broken module being handled.
+"""
+
+from typing import Unpack, TypedDict
+from dewret.workflow import Workflow
+
+# This import lacking a relative dot, while extra itself
+# uses one, is what breaks the module: it cannot be both
+# a package and not-a-package. The module becomes importable
+# by adding a . before extra.
+from extra import JUMP  # type: ignore
+
+class UnfrenderRendererConfiguration(TypedDict):
+    allow_complex_types: bool
+
+def default_config() -> UnfrenderRendererConfiguration:
+    return {
+        "allow_complex_types": True
+    }
+
+def render_raw(
+    workflow: Workflow, **kwargs: Unpack[UnfrenderRendererConfiguration]
+) -> dict[str, str]:
+    return {"JUMP": str(JUMP)}
diff --git a/tests/test_annotations.py b/tests/test_annotations.py
new file mode 100644
index 00000000..f62479d8
--- /dev/null
+++ b/tests/test_annotations.py
@@ -0,0 +1,200 @@
+"""Verify we can interrogate annotations."""
+
+import pytest
+import yaml
+
+from dewret.tasks import construct, workflow, TaskException
+from dewret.renderers.cwl import render
+from dewret.annotations import AtRender, FunctionAnalyser, Fixed
+from dewret.core import set_configuration
+
+from ._lib.extra import increment, sum, try_nothing
+
+ARG1: AtRender[bool] = True
+ARG2: bool = False
+
+
+class MyClass:
+    """A mock class to wrap values and mock processing of data."""
+
+    def method(self, arg1: bool, arg2: AtRender[int]) -> float:
+        """A mock method to simulate the behavior of processing and rendering values."""
+        arg3: float = 7.0
+        arg4: AtRender[float] = 8.0
+        return arg1 + arg2 + arg3 + arg4 + int(ARG1) + int(ARG2)
+
+
+def fn(arg5: int, arg6: AtRender[int]) -> float:
+    """A mock function to simulate processing of rendered values with integer inputs."""
+    arg7: float = 7.0
+    arg8: AtRender[float] = 8.0
+    return arg5 + arg6 + arg7 + arg8 + int(ARG1) + int(ARG2)
+
+
+@workflow()
+def to_int_bad(num: int, should_double: bool) -> int | float:
+    """A mock workflow that casts to an int, with the wrong type for the `should_double` flag."""
+    return increment(num=num) if should_double else sum(left=num, right=num)
+
+
+@workflow()
+def to_int(num: int, should_double: AtRender[bool]) -> int | float:
+    """A mock workflow that casts to an int, with the right type for the `should_double` flag."""
+    return increment(num=num) if should_double else sum(left=num, right=num)
+
+
+def test_can_analyze_annotations() -> None:
+    """Test that annotations can be correctly analyzed within methods and functions.
+
+    Verifies that the `FunctionAnalyser` finds which arguments
+    and global variables are derived from `dewret.annotations.AtRender`.
+ """ + my_obj = MyClass() + + analyser = FunctionAnalyser(my_obj.method) + assert analyser.argument_has("arg1", AtRender, exhaustive=True) is False + assert analyser.argument_has("arg3", AtRender, exhaustive=True) is False + assert analyser.argument_has("ARG2", AtRender, exhaustive=True) is False + assert analyser.argument_has("arg2", AtRender, exhaustive=True) is True + assert ( + analyser.argument_has("arg4", AtRender, exhaustive=True) is False + ) # Not a global/argument + assert analyser.argument_has("ARG1", AtRender, exhaustive=True) is True + assert analyser.argument_has("ARG1", AtRender) is False + + analyser = FunctionAnalyser(fn) + assert analyser.argument_has("arg5", AtRender, exhaustive=True) is False + assert analyser.argument_has("arg7", AtRender, exhaustive=True) is False + assert analyser.argument_has("ARG2", AtRender, exhaustive=True) is False + assert analyser.argument_has("arg6", AtRender, exhaustive=True) is True + assert ( + analyser.argument_has("arg8", AtRender, exhaustive=True) is False + ) # Not a global/argument + assert analyser.argument_has("ARG1", AtRender, exhaustive=True) is True + assert analyser.argument_has("ARG1", AtRender) is False + + +def test_at_render() -> None: + """Test the rendering of workflows with `dewret.annotations.AtRender` and exceptions handling.""" + with pytest.raises(TaskException) as _: + result = to_int_bad(num=increment(num=3), should_double=True) + wkflw = construct(result, simplify_ids=True) + + result = to_int(num=increment(num=3), should_double=True) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) + rendered = subworkflows["__root__"] + assert rendered == yaml.safe_load(""" + cwlVersion: 1.2 + class: Workflow + inputs: + increment-1-num: + default: 3 + label: num + type: int + outputs: + out: + label: out + outputSource: to_int-1/out + type: + - int + - float + steps: + increment-1: + in: + num: + source: increment-1-num + out: + - out + run: increment + to_int-1: + in: + num: + source: increment-1/out + should_double: + default: True + out: + - out + run: to_int + """) + + result = to_int(num=increment(num=3), should_double=False) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) + rendered = subworkflows["__root__"] + assert rendered == yaml.safe_load(""" + cwlVersion: 1.2 + class: Workflow + inputs: + increment-1-num: + default: 3 + label: num + type: int + outputs: + out: + label: out + outputSource: to_int-1/out + type: + - int + - float + steps: + increment-1: + in: + num: + source: increment-1-num + out: + - out + run: increment + to_int-1: + in: + num: + source: increment-1/out + should_double: + default: False + out: + - out + run: to_int + """) + + +def test_at_render_between_modules() -> None: + """Test rendering of workflows across different modules using `dewret.annotations.AtRender`.""" + result = try_nothing() + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) + subworkflows["__root__"] + + +list_2: Fixed[list[int]] = [0, 1, 2, 3] + + +def test_can_loop_over_fixed_length() -> None: + """Test looping over a fixed-length list using `dewret.annotations.Fixed`.""" + + @workflow() + def loop_over_lists(list_1: list[int]) -> list[int]: + result = [] + for a, b in zip(list_1, list_2, strict=False): + result.append(a + b + len(list_2)) + return result + + with set_configuration(flatten_all_nested=True): + result = loop_over_lists(list_1=[5, 6, 7, 8]) + wkflw = 
construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) + rendered = subworkflows["__root__"] + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: {} + outputs: + out: + type: float + label: out + expression: '[4 + list_1[0] + list_2[0], 4 + list_1[1] + list_2[1], 4 + list_1[2] + list_2[2], + 4 + list_1[3] + list_2[3]]' + source: + - list_1 + - list_2 + steps: {} + """) diff --git a/tests/test_configuration.py b/tests/test_configuration.py new file mode 100644 index 00000000..1d45b71f --- /dev/null +++ b/tests/test_configuration.py @@ -0,0 +1,71 @@ +"""Check configuration is consistent and usable.""" + +import yaml +import pytest +from dewret.tasks import construct, workflow, TaskException +from dewret.renderers.cwl import render +from dewret.core import set_configuration +from dewret.annotations import AtRender +from ._lib.extra import increment + + +@workflow() +def floor(num: int, expected: AtRender[bool]) -> int: + """Converts int/float to int.""" + from dewret.core import get_configuration + + if get_configuration("flatten_all_nested") != expected: + raise AssertionError( + f"Not expected configuration: {str(get_configuration('flatten_all_nested'))} != {expected}" + ) + return increment(num=num) + + +def test_cwl_with_parameter() -> None: + """Test workflows with configuration parameters.""" + with set_configuration(flatten_all_nested=True): + result = increment(num=floor(num=3, expected=True)) + workflow = construct(result, simplify_ids=True) + + with ( + pytest.raises(TaskException) as exc, + set_configuration(flatten_all_nested=False), + ): + result = increment(num=floor(num=3, expected=True)) + workflow = construct(result, simplify_ids=True) + assert "AssertionError" in str(exc.getrepr()) + + with set_configuration(flatten_all_nested=True): + result = increment(num=floor(num=3, expected=True)) + workflow = construct(result, simplify_ids=True) + rendered = render(workflow)["__root__"] + num_param = list(workflow.find_parameters())[0] + assert num_param + + assert rendered == yaml.safe_load(""" + cwlVersion: 1.2 + class: Workflow + inputs: + num: + label: num + type: int + default: 3 + outputs: + out: + label: out + outputSource: increment-1/out + type: int + steps: + increment-2: + run: increment + in: + num: + source: num + out: [out] + increment-1: + run: increment + in: + num: + source: increment-2/out + out: [out] + """) diff --git a/tests/test_cwl.py b/tests/test_cwl.py index bdf5a5f9..17bc30ab 100644 --- a/tests/test_cwl.py +++ b/tests/test_cwl.py @@ -1,13 +1,16 @@ """Verify CWL output is OK.""" import yaml +import pytest from datetime import datetime, timedelta -from dewret.tasks import construct, task, factory +from dewret.core import set_configuration +from dewret.tasks import construct, task, factory, TaskException from dewret.renderers.cwl import render from dewret.utils import hasher from dewret.workflow import param from ._lib.extra import ( + pi, increment, double, mod10, @@ -16,15 +19,6 @@ tuple_float_return, ) - -@task() -def pi() -> float: - """Returns pi.""" - import math - - return math.pi - - @task() def floor(num: int | float) -> int: """Converts int/float to int.""" @@ -50,7 +44,7 @@ def test_basic_cwl() -> None: """ result = pi() workflow = construct(result) - rendered = render(workflow) + rendered = render(workflow)["__root__"] hsh = hasher(("pi",)) assert rendered == yaml.safe_load(f""" @@ -83,7 +77,7 @@ def get_now() -> datetime: now = factory(get_now)() result = days_in_future(now=now, 
num=3) workflow = construct(result, simplify_ids=True) - rendered = render(workflow, allow_complex_types=True, factories_as_params=True) + rendered = render(workflow, allow_complex_types=True, factories_as_params=True)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 @@ -92,7 +86,7 @@ def get_now() -> datetime: days_in_future-1-num: default: 3 type: int - label: days_in_future-1-num + label: num get_now-1: label: get_now-1 type: datetime @@ -112,7 +106,7 @@ def get_now() -> datetime: out: [out] """) - rendered = render(workflow, allow_complex_types=True) + rendered = render(workflow, allow_complex_types=True)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 @@ -121,7 +115,7 @@ def get_now() -> datetime: days_in_future-1-num: default: 3 type: int - label: days_in_future-1-num + label: num outputs: out: label: out @@ -147,20 +141,57 @@ def test_cwl_with_parameter() -> None: """Check whether we can move raw input to parameters. Produces CWL for a call with a changeable raw value, that is converted - to a parameter, if and only if we are calling from outside a nested task. + to a parameter, if and only if we are calling from outside a subworkflow. """ result = increment(num=3) workflow = construct(result) - rendered = render(workflow) + rendered = render(workflow)["__root__"] num_param = list(workflow.find_parameters())[0] - hsh = hasher(("increment", ("num", f"int|:param:{num_param.unique_name}"))) + hsh = hasher(("increment", ("num", f"int|:param:{num_param._.unique_name}"))) assert rendered == yaml.safe_load(f""" cwlVersion: 1.2 class: Workflow inputs: increment-{hsh}-num: - label: increment-{hsh}-num + label: num + type: int + default: 3 + outputs: + out: + label: out + outputSource: increment-{hsh}/out + type: int + steps: + increment-{hsh}: + run: increment + in: + num: + source: increment-{hsh}-num + out: [out] + """) + +def test_cwl_with_positional_parameter() -> None: + """Check whether we can move raw input to parameters. + + Produces CWL for a call with a changeable raw value, that is converted + to a parameter, if and only if we are calling from outside a subworkflow. 
+ """ + with pytest.raises(TaskException) as _: + result = increment(3) + with set_configuration(allow_positional_args=True): + result = increment(3) + workflow = construct(result) + rendered = render(workflow)["__root__"] + num_param = list(workflow.find_parameters())[0] + hsh = hasher(("increment", ("num", f"int|:param:{num_param._.unique_name}"))) + + assert rendered == yaml.safe_load(f""" + cwlVersion: 1.2 + class: Workflow + inputs: + increment-{hsh}-num: + label: num type: int default: 3 outputs: @@ -187,7 +218,7 @@ def test_cwl_without_default() -> None: result = increment(num=my_param) workflow = construct(result) - rendered = render(workflow) + rendered = render(workflow)["__root__"] hsh = hasher(("increment", ("num", "int|:param:my_param"))) assert rendered == yaml.safe_load(f""" @@ -217,7 +248,9 @@ def test_cwl_with_subworkflow() -> None: my_param = param("num", typ=int) result = increment(num=floor(num=triple_and_one(num=increment(num=my_param)))) workflow = construct(result, simplify_ids=True) - rendered, subworkflows = render(workflow) + subworkflows = render(workflow) + rendered = subworkflows["__root__"] + del subworkflows["__root__"] assert len(subworkflows) == 1 assert isinstance(subworkflows, dict) @@ -233,7 +266,7 @@ def test_cwl_with_subworkflow() -> None: outputs: out: label: out - outputSource: increment-2/out + outputSource: increment-1/out type: int steps: floor-1: @@ -242,13 +275,13 @@ def test_cwl_with_subworkflow() -> None: source: triple_and_one-1/out out: [out] run: floor - increment-1: + increment-2: in: num: source: num out: [out] run: increment - increment-2: + increment-1: in: num: source: floor-1/out @@ -257,7 +290,7 @@ def test_cwl_with_subworkflow() -> None: triple_and_one-1: in: num: - source: increment-1/out + source: increment-2/out out: [out] run: triple_and_one """) @@ -269,14 +302,10 @@ def test_cwl_with_subworkflow() -> None: num: label: num type: int - sum-1-2-right: - default: 1 - label: sum-1-2-right - type: int outputs: out: label: out - outputSource: sum-1-2/out + outputSource: sum-1-1/out type: - int - float @@ -288,7 +317,7 @@ def test_cwl_with_subworkflow() -> None: out: - out run: double - sum-1-1: + sum-1-2: in: left: source: double-1-1/out @@ -297,12 +326,12 @@ def test_cwl_with_subworkflow() -> None: out: - out run: sum - sum-1-2: + sum-1-1: in: left: - source: sum-1-1/out + source: sum-1-2/out right: - source: sum-1-2-right + default: 1 out: - out run: sum @@ -316,26 +345,26 @@ def test_cwl_references() -> None: """ result = double(num=increment(num=3)) workflow = construct(result) - rendered = render(workflow) + rendered = render(workflow)["__root__"] num_param = list(workflow.find_parameters())[0] hsh_increment = hasher( - ("increment", ("num", f"int|:param:{num_param.unique_name}")) + ("increment", ("num", f"int|:param:{num_param._.unique_name}")) ) - hsh_double = hasher(("double", ("num", f"increment-{hsh_increment}/out"))) + hsh_double = hasher(("double", ("num", f"increment-{hsh_increment}"))) assert rendered == yaml.safe_load(f""" cwlVersion: 1.2 class: Workflow inputs: increment-{hsh_increment}-num: - label: increment-{hsh_increment}-num + label: num type: int default: 3 outputs: out: label: out outputSource: double-{hsh_double}/out - type: + type: - int - float steps: @@ -361,25 +390,21 @@ def test_complex_cwl_references() -> None: """ result = sum(left=double(num=increment(num=23)), right=mod10(num=increment(num=23))) workflow = construct(result, simplify_ids=True) - rendered = render(workflow) + rendered = 
render(workflow)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 class: Workflow inputs: increment-1-num: - label: increment-1-num - type: int - default: 23 - increment-2-num: - label: increment-2-num + label: num type: int default: 23 outputs: out: label: out outputSource: sum-1/out - type: + type: - int - float steps: @@ -389,17 +414,11 @@ def test_complex_cwl_references() -> None: num: source: increment-1-num out: [out] - increment-2: - run: increment - in: - num: - source: increment-2-num - out: [out] double-1: run: double in: num: - source: increment-2/out + source: increment-1/out out: [out] mod10-1: run: mod10 @@ -423,8 +442,10 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: my_param = param("num", typ=int) result = increment(num=floor(num=triple_and_one(num=sum(left=my_param, right=3)))) workflow = construct(result, simplify_ids=True) - rendered, subworkflows = render(workflow) + subworkflows = render(workflow) + rendered = subworkflows["__root__"] + del subworkflows["__root__"] assert len(subworkflows) == 1 assert isinstance(subworkflows, dict) name, subworkflow = list(subworkflows.items())[0] @@ -438,7 +459,7 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: type: int sum-1-right: default: 3 - label: sum-1-right + label: right type: int outputs: out: @@ -484,14 +505,10 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: int, float ] - sum-1-2-right: - default: 1 - label: sum-1-2-right - type: int outputs: out: label: out - outputSource: sum-1-2/out + outputSource: sum-1-1/out type: - int - float @@ -503,7 +520,7 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: out: - out run: double - sum-1-1: + sum-1-2: in: left: source: double-1-1/out @@ -512,12 +529,12 @@ def test_cwl_with_subworkflow_and_raw_params() -> None: out: - out run: sum - sum-1-2: + sum-1-1: in: left: - source: sum-1-1/out + source: sum-1-2/out right: - source: sum-1-2-right + default: 1 out: - out run: sum @@ -531,8 +548,7 @@ def test_tuple_floats() -> None: """ result = tuple_float_return() workflow = construct(result, simplify_ids=True) - rendered = render(workflow) - print(yaml.dump(rendered)) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 class: Workflow @@ -541,11 +557,10 @@ def test_tuple_floats() -> None: out: label: out outputSource: tuple_float_return-1/out - type: - items: - - type: float - - type: float - type: array + items: + - float + - float + type: array steps: tuple_float_return-1: run: tuple_float_return diff --git a/tests/test_errors.py b/tests/test_errors.py index 45496c61..3b21c332 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -2,8 +2,9 @@ import pytest from dewret.workflow import Task, Lazy -from dewret.tasks import construct, task, nested_task, TaskException -from ._lib.extra import increment # noqa: F401 +from dewret.tasks import construct, task, workflow, TaskException +from dewret.annotations import AtRender +from ._lib.extra import increment, pi, reverse_list # noqa: F401 @task() # This is expected to be the line number shown below. @@ -12,10 +13,10 @@ def add_task(left: int, right: int) -> int: return left + right -ADD_TASK_LINE_NO = 9 +ADD_TASK_LINE_NO: int = 10 -@nested_task() +@workflow() def badly_add_task(left: int, right: int) -> int: """Badly attempts to add two numbers.""" return add_task(left=left) # type: ignore @@ -35,14 +36,6 @@ def __init__(self, task: Task): ... 
-@task() -def pi() -> float: - """Get pi from math package.""" - import math - - return math.pi - - @task() def pi_exported_from_math() -> float: """Get pi from math package by name.""" @@ -52,9 +45,9 @@ def pi_exported_from_math() -> float: @task() -def test_recursive() -> float: +def try_recursive() -> float: """Get pi from math package by name.""" - return test_recursive() + return try_recursive() @task() @@ -88,15 +81,15 @@ def pi_with_invisible_module_task() -> float: return extra.double(3.14 / 2) -@nested_task() +@workflow() def unacceptable_object_usage() -> int: """Invalid use of custom object within nested task.""" return MyStrangeClass(add_task(left=3, right=4)) # type: ignore -@nested_task() -def unacceptable_nested_return(int_not_global: bool) -> int | Lazy: - """Bad nested_task that fails to return a task.""" +@workflow() +def unacceptable_nested_return(int_not_global: AtRender[bool]) -> int | Lazy: + """Bad subworkflow that fails to return a task.""" add_task(left=3, right=4) return 7 if int_not_global else ADD_TASK_LINE_NO @@ -110,17 +103,16 @@ def test_missing_arguments_throw_error() -> None: WARNING: in keeping with Python principles, this does not error if types mismatch, but `mypy` should. You **must** type-check your code to catch these. """ - result = add_task(left=3) # type: ignore with pytest.raises(TaskException) as exc: - construct(result) + add_task(left=3) # type: ignore end_section = str(exc.getrepr())[-500:] assert str(exc.value) == "missing a required argument: 'right'" assert "Task add_task declared in at " in end_section assert f"test_errors.py:{ADD_TASK_LINE_NO}" in end_section -def test_missing_arguments_throw_error_in_nested_task() -> None: - """Check whether omitting a required argument within a nested_task will give an error. +def test_missing_arguments_throw_error_in_subworkflow() -> None: + """Check whether omitting a required argument within a subworkflow will give an error. Since we do not run the original function, it is up to dewret to check that the signature is, at least, acceptable to Python. @@ -128,9 +120,8 @@ def test_missing_arguments_throw_error_in_nested_task() -> None: WARNING: in keeping with Python principles, this does not error if types mismatch, but `mypy` should. You **must** type-check your code to catch these. """ - result = badly_add_task(left=3, right=4) with pytest.raises(TaskException) as exc: - construct(result) + badly_add_task(left=3, right=4) end_section = str(exc.getrepr())[-500:] assert str(exc.value) == "missing a required argument: 'right'" assert "def badly_add_task" in end_section @@ -144,9 +135,8 @@ def test_positional_arguments_throw_error() -> None: We can use default and non-default arguments, but we expect them to _always_ be named. """ - result = add_task(3, right=4) with pytest.raises(TaskException) as exc: - construct(result) + add_task(3, right=4) assert ( str(exc.value) .strip() ) -def test_nesting_non_nested_tasks_throws_error() -> None: - """Ensure nesting is only allow in nested_tasks. +def test_nesting_non_subworkflows_throws_error() -> None: + """Ensure nesting is only allowed in subworkflows. Nested tasks must be evaluated at construction time, and there is no concept of task calls that are not resolved during construction, so a task should not be called inside a non-nested task.
""" - result = badly_wrap_task() with pytest.raises(TaskException) as exc: - construct(result) + badly_wrap_task() assert ( str(exc.value) .strip() .startswith( - "You referenced a task add_task inside another task badly_wrap_task, but it is not a nested_task" + "You referenced a task add_task inside another task badly_wrap_task, but it is not a workflow" ) ) @@ -195,19 +184,18 @@ def test_nesting_does_not_identify_imports_as_nesting() -> None: pi_hidden_by_math, pi_hidden_by_math_2, ] - bad = [test_recursive, pi_with_visible_module_task] + bad = [try_recursive, pi_with_visible_module_task] for tsk in bad: - result = tsk() with pytest.raises(TaskException) as exc: - construct(result) + tsk() assert str(exc.value).strip().startswith("You referenced a task") for tsk in good: result = tsk() construct(result) -def test_normal_objects_cannot_be_used_in_nested_tasks() -> None: - """Most entities cannot appear in a nested_task, ensure we catch them. +def test_normal_objects_cannot_be_used_in_subworkflows() -> None: + """Most entities cannot appear in a subworkflow, ensure we catch them. Since the logic in nested tasks has to be embedded explicitly in the workflow, complex types are not necessarily representable, and in most cases, we would not @@ -215,34 +203,65 @@ def test_normal_objects_cannot_be_used_in_nested_tasks() -> None: Note: this may be mitigated with sympy support, to some extent. """ - result = unacceptable_object_usage() with pytest.raises(TaskException) as exc: - construct(result) + unacceptable_object_usage() assert ( str(exc.value) - == "Nested tasks must now only refer to global parameters, raw or tasks, not objects: MyStrangeClass" + == "Attempted to build a workflow from a return-value/result/expression with no references." ) -def test_nested_tasks_must_return_a_task() -> None: +def test_subworkflows_must_return_a_task() -> None: """Ensure nested tasks are lazy-evaluatable. A graph only makes sense if the edges connect, and nested tasks must therefore chain. As such, a nested task must represent a real subgraph, and return a node to pull it into the main graph. """ - result = unacceptable_nested_return(int_not_global=True) with pytest.raises(TaskException) as exc: + result = unacceptable_nested_return(int_not_global=True) construct(result) assert ( str(exc.value) - == "Task unacceptable_nested_return returned output of type , which is not a lazy function for this backend." + == "Attempted to build a workflow from a return-value/result/expression with no references." ) result = unacceptable_nested_return(int_not_global=False) + construct(result) + +bad_num = 3 +good_num: int = 4 + +def test_must_annotate_global() -> None: + """Check that globals referenced inside a workflow must carry type annotations.""" + worse_num = 3 + + @workflow() + def check_annotation() -> int | float: + return increment(num=bad_num) + with pytest.raises(TaskException) as exc: - construct(result) + check_annotation() + + assert ( + str(exc.value) + == "Could not find a type annotation for bad_num for check_annotation" + ) + + @workflow() + def check_annotation_2() -> int | float: + return increment(num=worse_num) + + with pytest.raises(TaskException) as exc: + check_annotation_2() + + assert ( + str(exc.value) - == "Task unacceptable_nested_return returned output of type , which is not a lazy function for this backend."
+ == "Cannot use free variables - please put worse_num at the global scope" ) + + @workflow() + def check_annotation_3() -> int | float: + return increment(num=good_num) + + check_annotation_3() diff --git a/tests/test_fieldable.py b/tests/test_fieldable.py new file mode 100644 index 00000000..8f7faaf3 --- /dev/null +++ b/tests/test_fieldable.py @@ -0,0 +1,410 @@ +"""Check field management works.""" + +from __future__ import annotations +import yaml +from dataclasses import dataclass + +from typing import Unpack, TypedDict + +from dewret.tasks import task, construct, workflow +from dewret.core import set_configuration +from dewret.workflow import param, StepReference +from dewret.renderers.cwl import render +from dewret.annotations import Fixed + +from ._lib.extra import mod10, sum, pi + + +@dataclass +class Sides: + """A dataclass representing the sides with `left` and `right` integers.""" + + left: int + right: int + + +SIDES: Sides = Sides(3, 6) + + +@workflow() +def sum_sides() -> float: + """Workflow that returns the sum of the `left` and `right` sides.""" + return sum(left=SIDES.left, right=SIDES.right) + + +def test_fields_of_parameters_usable() -> None: + """Test that fields of parameters can be used to construct and render a workflow correctly.""" + result = sum_sides() + wkflw = construct(result, simplify_ids=True) + rendered = render(wkflw, allow_complex_types=True)["sum_sides-1"] + + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + SIDES: + label: SIDES + default: + left: 3 + right: 6 + type: record + fields: + left: + default: 3 + label: left + type: int + right: + default: 6 + label: right + type: int + label: SIDES + outputs: + out: + label: out + outputSource: sum-1-1/out + type: + - int + - float + steps: + sum-1-1: + in: + left: + source: SIDES/left + right: + source: SIDES/right + out: + - out + run: sum + """) + + +@dataclass +class MyDataclass: + """A dataclass with nested references to itself, containing `left` and `right` fields.""" + + left: int + right: "MyDataclass" + + +def test_can_get_field_reference_from_parameter() -> None: + """Test that field references can be retrieved from a parameter when constructing a workflow.""" + my_param = param("my_param", typ=MyDataclass) + result = sum( + left=my_param.left, right=sum(left=my_param.right.left, right=my_param.left) + ) + wkflw = construct(result, simplify_ids=True) + params = {(str(p), p.__type__) for p in wkflw.find_parameters()} + + assert params == {("my_param", MyDataclass)} + rendered = render(wkflw, allow_complex_types=True)["__root__"] + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + my_param: + label: my_param + type: MyDataclass + outputs: + out: + label: out + outputSource: sum-1/out + type: + - int + - float + steps: + sum-1: + in: + left: + source: my_param/left + right: + source: sum-2/out + out: + - out + run: sum + sum-2: + in: + left: + source: my_param/right/left + right: + source: my_param/left + out: + - out + run: sum + """) + + +def test_can_get_field_reference_if_parent_type_has_field() -> None: + """Test that a field reference is retrievable if the parent type has that field.""" + + @dataclass + class MyDataclass: + left: int + + my_param = param("my_param", typ=MyDataclass) + result = sum(left=my_param.left, right=my_param.left) + wkflw = construct(result, simplify_ids=True) + param_reference = list(wkflw.find_parameters())[0] + + assert str(param_reference.left) == "my_param/left" + assert 
param_reference.left.__type__ == int + + +def test_can_go_upwards_from_a_field_reference() -> None: + """Test that it's possible to move upwards in the hierarchy from a field reference.""" + + @dataclass + class MyDataclass: + left: int + right: "MyDataclass" + + my_param = param("my_param", typ=MyDataclass) + result = sum(left=my_param.left, right=my_param.left) + construct(result, simplify_ids=True) + + back = my_param.right.left.__field_up__() # type: ignore + assert str(back) == "my_param/right" + assert back.__type__ == MyDataclass + + +def test_can_get_field_references_from_dataclass() -> None: + """Test that field references can be extracted from a dataclass and used in workflows.""" + + @dataclass + class MyDataclass: + left: int + right: float + + @workflow() + def test_dataclass(my_dataclass: MyDataclass) -> MyDataclass: + result: MyDataclass = MyDataclass(left=mod10(num=my_dataclass.left), right=pi()) + return result + + @workflow() + def get_left(my_dataclass: MyDataclass) -> int: + return my_dataclass.left + + result = get_left( + my_dataclass=test_dataclass(my_dataclass=MyDataclass(left=3, right=4.0)) + ) + wkflw = construct(result, simplify_ids=True) + + assert isinstance(wkflw.result, StepReference) + assert str(wkflw.result) == "get_left-1" + assert wkflw.result.__type__ == int + + +class MyDict(TypedDict): + """A typed dictionary with `left` as an integer and `right` as a float.""" + + left: int + right: float + + +def test_can_get_field_references_from_typed_dict() -> None: + """Test that field references can be extracted from a custom typed dictionary and used in workflows.""" + + @workflow() + def test_dict(**my_dict: Unpack[MyDict]) -> MyDict: + result: MyDict = {"left": mod10(num=my_dict["left"]), "right": pi()} + return result + + result = test_dict(left=3, right=4.0) + wkflw = construct(result, simplify_ids=True) + + assert isinstance(wkflw.result, StepReference) + assert str(wkflw.result["left"]) == "test_dict-1/left" + assert wkflw.result["left"].__type__ == int + + +@dataclass +class MyListWrapper: + """A dataclass that wraps a list of integers.""" + + my_list: list[int] + + +def test_can_iterate() -> None: + """Test iteration over a list of tasks and validate positional argument handling.""" + + @task() + def test_task(alpha: int, beta: float, charlie: bool) -> int: + """Task that adds `alpha` and `beta` and returns the integer result.""" + return int(alpha + beta) + + @task() + def test_list() -> list[int | float]: + """Task that returns a list containing an integer and a float.""" + return [1, 2.0] + + @workflow() + def test_iterated() -> int: + """Workflow that tests task iteration over a list.""" + # We ignore the type as mypy cannot confirm that the length and types match the args.
+ return test_task(*test_list()) # type: ignore + + with set_configuration(allow_positional_args=True, flatten_all_nested=True): + result = test_iterated() + wkflw = construct(result, simplify_ids=True) + + rendered = render(wkflw, allow_complex_types=True)["__root__"] + + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: {} + outputs: + out: + label: out + outputSource: test_task-1/out + type: int + steps: + test_list-1: + in: {} + out: + - out + run: test_list + test_task-1: + in: + alpha: + source: test_list-1[0] + beta: + source: test_list-1[1] + charlie: + source: test_list-1[2] + out: + - out + run: test_task + """) + + assert isinstance(wkflw.result, StepReference) + assert wkflw.result._.step.positional_args == { + "alpha": True, + "beta": True, + "charlie": True, + } + + @task() + def test_list_2() -> MyListWrapper: + """Task that returns a `MyListWrapper` containing a list of integers.""" + return MyListWrapper(my_list=[1, 2]) + + @workflow() + def test_iterated_2(my_wrapper: MyListWrapper) -> int: + """Workflow that tests iteration over a list in a `MyListWrapper`.""" + # mypy cannot confirm argument types match. + return test_task(*my_wrapper.my_list) # type: ignore + + with set_configuration(allow_positional_args=True, flatten_all_nested=True): + result = test_iterated_2(my_wrapper=test_list_2()) + wkflw = construct(result, simplify_ids=True) + + @task() + def test_list_3() -> Fixed[list[tuple[int, int]]]: + """Task that returns a list of integer tuples wrapped in a `dewret.annotations.Fixed` type.""" + return [(0, 1), (2, 3)] + + @workflow() + def test_iterated_3(param: Fixed[list[tuple[int, int]]]) -> int: + """Workflow that iterates over a list of integer tuples and performs operations.""" + # mypy cannot confirm argument types match. 
+ retval = mod10(*test_list_3()[0]) # type: ignore + for pair in param: + a, b = pair + retval += a + b + return mod10(retval) + + with set_configuration(allow_positional_args=True, flatten_all_nested=True): + result = test_iterated_3(param=[(0, 1), (2, 3)]) + wkflw = construct(result, simplify_ids=True) + + rendered = render(wkflw, allow_complex_types=True)["__root__"] + + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + param: + default: + - - 0 + - 1 + - - 2 + - 3 + items: array + label: param + type: array + outputs: + out: + label: out + outputSource: mod10-2/out + type: int + steps: + mod10-1: + in: + num: + source: test_list_3-1[0][0] + out: + - out + run: mod10 + mod10-2: + in: + num: + valueFrom: $(inputs.param[0][0] + inputs.param[0][1] + inputs.param[1][0] + inputs.param[1][1] + self) + source: mod10-1/out + out: + - out + run: mod10 + test_list_3-1: + in: {} + out: + - out + run: test_list_3 + """) + + +def test_can_use_plain_dict_fields() -> None: + """Test the use of plain dictionary fields in workflows.""" + + @workflow() + def test_dict(left: int, right: float) -> dict[str, float | int]: + result: dict[str, float | int] = {"left": mod10(num=left), "right": pi()} + return result + + with set_configuration(allow_plain_dict_fields=True): + result = test_dict(left=3, right=4.0) + wkflw = construct(result, simplify_ids=True) + assert isinstance(wkflw.result, StepReference) + assert str(wkflw.result["left"]) == "test_dict-1/left" + assert wkflw.result["left"].__type__ == int | float + + +@dataclass +class IndexTest: + """A dataclass for testing indexed fields, containing a `left` field that is a list of integers.""" + + left: Fixed[list[int]] + + +def test_can_configure_field_separator() -> None: + """Test the ability to configure the field separator in workflows.""" + + @task() + def test_sep() -> IndexTest: + return IndexTest(left=[3]) + + with set_configuration(field_index_types="int"): + result = test_sep().left[0] + wkflw = construct(result, simplify_ids=True) + render(wkflw, allow_complex_types=True)["__root__"] + assert str(wkflw.result) == "test_sep-1/left[0]" + + with set_configuration(field_index_types="int,str"): + result = test_sep().left[0] + wkflw = construct(result, simplify_ids=True) + render(wkflw, allow_complex_types=True)["__root__"] + assert str(wkflw.result) == "test_sep-1[left][0]" + + with set_configuration(field_index_types=""): + result = test_sep().left[0] + wkflw = construct(result, simplify_ids=True) + render(wkflw, allow_complex_types=True)["__root__"] + assert str(wkflw.result) == "test_sep-1/left/0" diff --git a/tests/test_modularity.py b/tests/test_modularity.py index 908a3adb..4fdc522a 100644 --- a/tests/test_modularity.py +++ b/tests/test_modularity.py @@ -1,14 +1,15 @@ """Verify CWL can be made with split up and nested calls.""" import yaml -from dewret.tasks import nested_task, construct +from dewret.tasks import workflow, construct +from dewret.core import set_configuration from dewret.renderers.cwl import render from ._lib.extra import double, sum, increase STARTING_NUMBER: int = 23 -@nested_task() +@workflow() def algorithm() -> int | float: """Creates a graph of task calls.""" left = double(num=increase(num=STARTING_NUMBER)) @@ -17,13 +18,14 @@ def algorithm() -> int | float: return sum(left=left, right=right) -def test_nested_task() -> None: +def test_subworkflow() -> None: """Check whether we can link between multiple steps and have parameters. Produces CWL that has references between multiple steps. 
""" - workflow = construct(algorithm(), simplify_ids=True) - rendered = render(workflow) + with set_configuration(flatten_all_nested=True): + workflow = construct(algorithm(), simplify_ids=True) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 @@ -33,13 +35,9 @@ def test_nested_task() -> None: label: JUMP type: float default: 1.0 - increase-3-num: + STARTING_NUMBER: default: 23 - label: increase-3-num - type: int - increase-1-num: - default: 17 - label: increase-1-num + label: STARTING_NUMBER type: int outputs: out: @@ -53,7 +51,7 @@ def test_nested_task() -> None: JUMP: source: JUMP num: - source: increase-1-num + source: STARTING_NUMBER out: [out] increase-2: run: increase @@ -61,7 +59,7 @@ def test_nested_task() -> None: JUMP: source: JUMP num: - source: increase-1/out + source: increase-3/out out: [out] increase-3: run: increase @@ -69,13 +67,13 @@ def test_nested_task() -> None: JUMP: source: JUMP num: - source: increase-3-num + default: 17 out: [out] double-1: run: double in: num: - source: increase-3/out + source: increase-1/out out: [out] sum-1: run: sum diff --git a/tests/test_multiresult_steps.py b/tests/test_multiresult_steps.py index 7457f6c6..f84f6ddf 100644 --- a/tests/test_multiresult_steps.py +++ b/tests/test_multiresult_steps.py @@ -4,7 +4,8 @@ from attr import define from dataclasses import dataclass from typing import Iterable -from dewret.tasks import task, construct, nested_task +from dewret.tasks import task, construct, workflow +from dewret.core import set_configuration from dewret.renderers.cwl import render STARTING_NUMBER: int = 23 @@ -44,19 +45,19 @@ def pair(left: int, right: float) -> tuple[int, float]: return (left, right) -@nested_task() +@workflow() def algorithm() -> float: """Sum two split values.""" return combine(left=split().first, right=split().second) -@nested_task() +@workflow() def algorithm_with_pair() -> tuple[int, float]: """Pairs two split dataclass values.""" return pair(left=split_into_dataclass().first, right=split_into_dataclass().second) -@nested_task() +@workflow() def algorithm_with_dataclasses() -> float: """Sums two split dataclass values.""" return combine( @@ -76,13 +77,13 @@ def split_into_dataclass() -> SplitResultDataclass: return SplitResultDataclass(first=1, second=2) -def test_nested_task() -> None: +def test_subworkflow() -> None: """Check whether we can link between multiple steps and have parameters. Produces CWL that has references between multiple steps. 
""" workflow = construct(split(), simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -114,10 +115,10 @@ def test_nested_task() -> None: """) -def test_field_of_nested_task() -> None: +def test_field_of_subworkflow() -> None: """Tests whether a directly-output nested task can have fields.""" workflow = construct(split().first, simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -142,10 +143,10 @@ def test_field_of_nested_task() -> None: """) -def test_field_of_nested_task_into_dataclasses() -> None: +def test_field_of_subworkflow_into_dataclasses() -> None: """Tests whether a directly-output nested task can have fields.""" workflow = construct(split_into_dataclass().first, simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -170,10 +171,11 @@ def test_field_of_nested_task_into_dataclasses() -> None: """) -def test_complex_field_of_nested_task() -> None: +def test_complex_field_of_subworkflow() -> None: """Tests whether a task can sum complex structures.""" - workflow = construct(algorithm(), simplify_ids=True) - rendered = render(workflow) + with set_configuration(flatten_all_nested=True): + workflow = construct(algorithm(), simplify_ids=True) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -206,11 +208,12 @@ def test_complex_field_of_nested_task() -> None: """) -def test_complex_field_of_nested_task_with_dataclasses() -> None: +def test_complex_field_of_subworkflow_with_dataclasses() -> None: """Tests whether a task can insert result fields into other steps.""" - result = algorithm_with_dataclasses() - workflow = construct(result, simplify_ids=True) - rendered = render(workflow) + with set_configuration(flatten_all_nested=True): + result = algorithm_with_dataclasses() + workflow = construct(result, simplify_ids=True) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -245,8 +248,9 @@ def test_complex_field_of_nested_task_with_dataclasses() -> None: def test_pair_can_be_returned_from_step() -> None: """Tests whether a task can insert result fields into other steps.""" - workflow = construct(algorithm_with_pair(), simplify_ids=True) - rendered = render(workflow) + with set_configuration(flatten_all_nested=True): + workflow = construct(algorithm_with_pair(), simplify_ids=True) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -256,11 +260,10 @@ def test_pair_can_be_returned_from_step() -> None: out: label: out outputSource: pair-1/out - type: - items: - - type: int - - type: float - type: array + items: + - int + - float + type: array steps: pair-1: in: @@ -285,8 +288,9 @@ def test_pair_can_be_returned_from_step() -> None: def test_list_can_be_returned_from_step() -> None: """Tests whether a task can insert result fields into other steps.""" - workflow = construct(list_cast(iterable=algorithm_with_pair()), simplify_ids=True) - rendered = render(workflow) + with set_configuration(flatten_all_nested=True): + workflow = construct(list_cast(iterable=algorithm_with_pair()), simplify_ids=True) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" class: Workflow @@ -296,9 +300,8 @@ def test_list_can_be_returned_from_step() -> None: out: label: 
out outputSource: list_cast-1/out - type: - items: float - type: array + items: float + type: array steps: list_cast-1: in: diff --git a/tests/test_nested.py b/tests/test_nested.py new file mode 100644 index 00000000..39848112 --- /dev/null +++ b/tests/test_nested.py @@ -0,0 +1,56 @@ +"""Check complex nested structures and expressions can mix.""" + +import yaml +import math +from dewret.workflow import param +from dewret.tasks import construct +from dewret.renderers.cwl import render + +from ._lib.extra import reverse_list, max_list + + +def test_can_supply_nested_raw() -> None: + """Check that raw values nested inside lists and expressions can be supplied.""" + pi = param("pi", math.pi) + result = reverse_list(to_sort=[1.0, 3.0, pi]) + workflow = construct(max_list(lst=result + result), simplify_ids=True) + # assert workflow.find_parameters() == { + # pi + # } + + # NB: This is not currently usefully renderable in CWL. + # However, the structures are important for future CWL rendering. + + rendered = render(workflow)["__root__"] + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + pi: + default: 3.141592653589793 + label: pi + type: float + outputs: + out: + label: out + outputSource: max_list-1/out + type: + - int + - float + steps: + max_list-1: + in: + lst: + source: reverse_list-1/out + valueFrom: $(2*self) + out: + - out + run: max_list + reverse_list-1: + in: + to_sort: + valueFrom: $((1.0, 3.0, inputs.pi)) + out: + - out + run: reverse_list + """) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 95717b87..5d32315c 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -23,7 +23,7 @@ def test_cwl_parameters() -> None: """ result = rotate(num=3) workflow = construct(result, simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] assert rendered == yaml.safe_load(""" cwlVersion: 1.2 @@ -34,7 +34,7 @@ def test_cwl_parameters() -> None: type: int default: 3 rotate-1-num: - label: rotate-1-num + label: num type: int default: 3 outputs: @@ -62,7 +62,8 @@ def test_complex_parameters() -> None: num = param("numx", 23) result = sum(left=double(num=rotate(num=num)), right=rotate(num=rotate(num=23))) workflow = construct(result, simplify_ids=True) - rendered = render(workflow) + rendered = render(workflow)["__root__"] + assert rendered == yaml.safe_load(""" cwlVersion: 1.2 class: Workflow @@ -75,8 +76,8 @@ def test_complex_parameters() -> None: label: numx type: int default: 23 - rotate-1-num: - label: rotate-1-num + rotate-2-num: + label: num type: int default: 23 outputs: @@ -91,7 +92,7 @@ def test_complex_parameters() -> None: INPUT_NUM: source: INPUT_NUM num: - source: rotate-1-num + source: rotate-2/out out: [out] double-1: run: double @@ -105,7 +106,7 @@ def test_complex_parameters() -> None: INPUT_NUM: source: INPUT_NUM num: - source: rotate-1/out + source: rotate-2-num out: [out] rotate-3: run: rotate @@ -121,6 +122,6 @@ def test_complex_parameters() -> None: left: source: double-1/out right: - source: rotate-2/out + source: rotate-1/out out: [out] """) diff --git a/tests/test_render_module.py b/tests/test_render_module.py new file mode 100644 index 00000000..19e2a510 --- /dev/null +++ b/tests/test_render_module.py @@ -0,0 +1,77 @@ +"""Check renderers can be imported live.""" + +import pytest +from pathlib import Path +from dewret.tasks import construct +from dewret.render import get_render_method + +from ._lib.extra import increment, triple_and_one + + +def test_can_load_render_module() ->
None: + """Checks if we can load a render module.""" + result = triple_and_one(num=increment(num=3)) + workflow = construct(result, simplify_ids=True) + workflow._name = "Fred" + + frender_py = Path(__file__).parent / "_lib/frender.py" + render = get_render_method(frender_py) + + assert render(workflow) == { + "__root__": """ +I found a workflow called Fred. +It has 2 steps! +They are: +* Something called increment-1 + +* A portal called triple_and_one-1 to another workflow, + whose name is triple_and_one + +It probably got made with JUMP=1.0 +""", + "triple_and_one-1": """ +I found a workflow called triple_and_one. +It has 3 steps! +They are: +* Something called double-1-1 + +* Something called sum-1-1 + +* Something called sum-1-2 + +It probably got made with JUMP=1.0 +""", + } + + +def test_can_load_cwl_render_module() -> None: + """Checks if we can load the built-in CWL render module.""" + result = triple_and_one(num=increment(num=3)) + workflow = construct(result, simplify_ids=True) + + frender_py = Path(__file__).parent.parent / "src/dewret/renderers/cwl.py" + render = get_render_method(frender_py) + assert render(workflow) == { + "__root__": "{'cwlVersion': 1.2, 'class': 'Workflow', 'inputs': {'increment-1-num': {'label': 'num', 'type': 'int', 'default': 3}}, 'outputs': {'out': {'label': 'out', 'type': ['int', 'float'], 'outputSource': 'triple_and_one-1/out'}}, 'steps': {'increment-1': {'run': 'increment', 'in': {'num': {'source': 'increment-1-num'}}, 'out': ['out']}, 'triple_and_one-1': {'run': 'triple_and_one', 'in': {'num': {'source': 'increment-1/out'}}, 'out': ['out']}}}", + "triple_and_one-1": "{'cwlVersion': 1.2, 'class': 'Workflow', 'inputs': {'num': {'label': 'num', 'type': 'int'}}, 'outputs': {'out': {'label': 'out', 'type': ['int', 'float'], 'outputSource': 'sum-1-1/out'}}, 'steps': {'double-1-1': {'run': 'double', 'in': {'num': {'source': 'num'}}, 'out': ['out']}, 'sum-1-1': {'run': 'sum', 'in': {'left': {'source': 'sum-1-2/out'}, 'right': {'default': 1}}, 'out': ['out']}, 'sum-1-2': {'run': 'sum', 'in': {'left': {'source': 'double-1-1/out'}, 'right': {'source': 'num'}}, 'out': ['out']}}}" + } + + +def test_get_correct_import_error_if_unable_to_load_render_module() -> None: + """Check that the correct import error is raised if a render module cannot be loaded.""" + unfrender_py = Path(__file__).parent / "_lib/unfrender.py" + with pytest.raises(ModuleNotFoundError) as exc: + get_render_method(unfrender_py) + + entry = exc.traceback[-1] + assert Path(entry.path).resolve() == ( + Path(__file__).parent / "_lib" / "unfrender.py" + ).resolve() + assert entry.relline == 12 + assert "No module named 'extra'" in str(exc.value) + + nonfrender_py = Path(__file__).parent / "_lib/nonfrender.py" + with pytest.raises(NotImplementedError) as nexc: + get_render_method(nonfrender_py) + + assert "This render module neither seems to be a structured nor a raw render module" in str(nexc.value) diff --git a/tests/test_subworkflows.py b/tests/test_subworkflows.py index 41b48fc2..4ddd30b8 100644 --- a/tests/test_subworkflows.py +++ b/tests/test_subworkflows.py @@ -3,21 +3,23 @@ from typing import Callable from queue import Queue import yaml -from dewret.tasks import construct, subworkflow, task, factory +from dewret.tasks import construct, workflow, task, factory +from dewret.core import set_configuration from dewret.renderers.cwl import render from dewret.workflow import param +from attrs import define -from ._lib.extra import increment, sum +from ._lib.extra import increment, sum, pi -CONSTANT = 3 +CONSTANT:
int = 3 -QueueFactory: Callable[..., "Queue[int]"] = factory(Queue) +QueueFactory: Callable[..., Queue[int]] = factory(Queue) -GLOBAL_QUEUE = QueueFactory() +GLOBAL_QUEUE: Queue[int] = QueueFactory() @task() -def pop(queue: "Queue[int]") -> int: +def pop(queue: Queue[int]) -> int: """Remove element of a queue.""" return queue.get() @@ -29,39 +31,97 @@ def to_int(num: int | float) -> int: @task() -def add_and_queue(num: int, queue: "Queue[int]") -> "Queue[int]": +def add_and_queue(num: int, queue: Queue[int]) -> Queue[int]: """Add a global constant to a number.""" queue.put(num) return queue -@subworkflow() -def make_queue(num: int | float) -> "Queue[int]": +@workflow() +def make_queue(num: int | float) -> Queue[int]: """Add a number to a queue.""" queue = QueueFactory() return add_and_queue(num=to_int(num=num), queue=queue) -@subworkflow() -def get_global_queue(num: int | float) -> "Queue[int]": +@workflow() +def get_global_queue(num: int | float) -> Queue[int]: """Add a number to a global queue.""" return add_and_queue(num=to_int(num=num), queue=GLOBAL_QUEUE) -@subworkflow() +@workflow() +def get_global_queues(num: int | float) -> list[Queue[int] | int]: + """Add a number to the global queue, returning the queue and the number plus the constant.""" + return [ + add_and_queue(num=to_int(num=num), queue=GLOBAL_QUEUE), + add_constant(num=num), + ] + + +@workflow() def add_constant(num: int | float) -> int: """Add a global constant to a number.""" return to_int(num=sum(left=num, right=CONSTANT)) +@workflow() +def add_constants(num: int | float) -> int: + """Add the global constant to a number twice.""" + return to_int(num=sum(left=sum(left=num, right=CONSTANT), right=CONSTANT)) + + +@workflow() +def get_values(num: int | float) -> tuple[int | float, int]: + """Return the number plus the constant, paired with the constant passed through add_constant.""" + return (sum(left=num, right=CONSTANT), add_constant(CONSTANT)) + + +def test_cwl_for_pairs() -> None: + """Check whether we can produce CWL of pairs.""" + + @workflow() + def pair_pi() -> tuple[float, float]: + return pi(), pi() + + with set_configuration(flatten_all_nested=True): + result = pair_pi() + wkflw = construct(result, simplify_ids=True) + rendered = render(wkflw)["__root__"] + + assert rendered == yaml.safe_load(""" + cwlVersion: 1.2 + class: Workflow + inputs: {} + outputs: [ + { + label: out, + outputSource: pi-1/out, + type: float + }, + { + label: out, + outputSource: pi-1/out, + type: float + } + ] + steps: + pi-1: + run: pi + in: {} + out: [out] + """) + + def test_subworkflows_can_use_globals() -> None: """Produce a subworkflow that uses a global.""" my_param = param("num", typ=int) result = increment(num=add_constant(num=increment(num=my_param))) - workflow = construct(result, simplify_ids=True) - rendered, subworkflows = render(workflow) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw) + rendered = subworkflows["__root__"] - assert len(subworkflows) == 1 + assert len(subworkflows) == 2 assert isinstance(subworkflows, dict) assert rendered == yaml.safe_load(""" @@ -78,16 +138,16 @@ def test_subworkflows_can_use_globals() -> None: outputs: out: label: out - outputSource: increment-2/out + outputSource: increment-1/out type: int steps: - increment-1: + increment-2: in: num: source: num out: [out] run: increment - increment-2: + increment-1: in: num: source: add_constant-1/out @@ -98,7 +158,7 @@ def test_subworkflows_can_use_globals() -> None: CONSTANT: source: CONSTANT num: - source: increment-1/out + source: increment-2/out out: [out] run: add_constant """) @@ -108,10 +168,11 @@ def
test_subworkflows_can_use_factories() -> None: """Produce a subworkflow that uses a factory.""" my_param = param("num", typ=int) result = pop(queue=make_queue(num=increment(num=my_param))) - workflow = construct(result, simplify_ids=True) - rendered, subworkflows = render(workflow, allow_complex_types=True) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) + rendered = subworkflows["__root__"] - assert len(subworkflows) == 1 + assert len(subworkflows) == 2 assert isinstance(subworkflows, dict) assert rendered == yaml.safe_load(""" @@ -152,10 +213,11 @@ def test_subworkflows_can_use_global_factories() -> None: """Check whether we can produce a subworkflow that uses a global factory.""" my_param = param("num", typ=int) result = pop(queue=get_global_queue(num=increment(num=my_param))) - workflow = construct(result, simplify_ids=True) - rendered, subworkflows = render(workflow, allow_complex_types=True) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) + rendered = subworkflows["__root__"] - assert len(subworkflows) == 1 + assert len(subworkflows) == 2 assert isinstance(subworkflows, dict) assert rendered == yaml.safe_load(""" @@ -165,6 +227,9 @@ def test_subworkflows_can_use_global_factories() -> None: num: label: num type: int + GLOBAL_QUEUE: + label: GLOBAL_QUEUE + type: Queue outputs: out: label: out @@ -181,6 +246,8 @@ def test_subworkflows_can_use_global_factories() -> None: in: num: source: increment-1/out + GLOBAL_QUEUE: + source: GLOBAL_QUEUE out: [out] run: get_global_queue pop-1: @@ -190,3 +257,375 @@ def test_subworkflows_can_use_global_factories() -> None: out: [out] run: pop """) + + +def test_subworkflows_can_return_lists() -> None: + """Check whether we can produce a subworkflow that returns a list.""" + my_param = param("num", typ=int) + result = get_global_queues(num=increment(num=my_param)) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) + rendered = subworkflows["__root__"] + del subworkflows["__root__"] + + assert len(subworkflows) == 2 + assert isinstance(subworkflows, dict) + osubworkflows = sorted(list(subworkflows.items())) + + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + num: + label: num + type: int + CONSTANT: + label: CONSTANT + default: 3 + type: int + GLOBAL_QUEUE: + label: GLOBAL_QUEUE + type: Queue + outputs: + out: + label: out + items: + - Queue + - int + outputSource: get_global_queues-1/out + type: array + steps: + increment-1: + in: + num: + source: num + out: [out] + run: increment + get_global_queues-1: + in: + num: + source: increment-1/out + CONSTANT: + source: CONSTANT + GLOBAL_QUEUE: + source: GLOBAL_QUEUE + out: [out] + run: get_global_queues + """) + + assert osubworkflows[0] == ( + "add_constant-1-1", + yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + num: + label: num + type: int + CONSTANT: + default: 3 + label: CONSTANT + type: int + outputs: + out: + label: out + outputSource: to_int-1-1-1/out + type: int + steps: + sum-1-1-1: + in: + left: + source: num + right: + source: CONSTANT + out: + - out + run: sum + to_int-1-1-1: + in: + num: + source: sum-1-1-1/out + out: + - out + run: to_int + """), + ) + + assert osubworkflows[1] == ( + "get_global_queues-1", + yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + CONSTANT: + default: 3 + label: CONSTANT + type: int + num: + label: num + type: int + GLOBAL_QUEUE: + 
label: GLOBAL_QUEUE + type: Queue + outputs: + - label: out + outputSource: add_and_queue-1-1/out + type: Queue + - label: out + outputSource: add_constant-1-1/out + type: int + steps: + add_and_queue-1-1: + in: + num: + source: to_int-1-1/out + queue: + source: GLOBAL_QUEUE + out: + - out + run: add_and_queue + add_constant-1-1: + in: + CONSTANT: + source: CONSTANT + num: + source: num + out: + - out + run: add_constant + to_int-1-1: + in: + num: + source: num + out: + - out + run: to_int + """), + ) + + +def test_can_merge_workflows() -> None: + """Check whether we can merge workflows.""" + my_param = param("num", typ=int) + value = to_int(num=increment(num=my_param)) + result = sum(left=value, right=increment(num=value)) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw, allow_complex_types=True) + rendered = subworkflows["__root__"] + del subworkflows["__root__"] + + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + num: + label: num + type: int + outputs: + out: + label: out + outputSource: sum-1/out + type: [ + int, + float + ] + steps: + increment-1: + in: + num: + source: num + out: [out] + run: increment + increment-2: + in: + num: + source: to_int-1/out + out: [out] + run: increment + sum-1: + in: + left: + source: to_int-1/out + right: + source: increment-2/out + out: [out] + run: sum + to_int-1: + in: + num: + source: increment-1/out + out: [out] + run: to_int + """) + + +def test_subworkflows_can_use_globals_in_right_scope() -> None: + """Produce a subworkflow that uses a global in the correct scope.""" + my_param = param("num", typ=int) + result = increment(num=add_constants(num=increment(num=my_param))) + wkflw = construct(result, simplify_ids=True) + subworkflows = render(wkflw) + rendered = subworkflows["__root__"] + del subworkflows["__root__"] + + assert len(subworkflows) == 1 + assert isinstance(subworkflows, dict) + osubworkflows = sorted(list(subworkflows.items())) + + assert rendered == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + CONSTANT: + default: 3 + label: CONSTANT + type: int + num: + label: num + type: int + outputs: + out: + label: out + outputSource: increment-1/out + type: int + steps: + increment-1: + in: + num: + source: add_constants-1/out + out: [out] + run: increment + increment-2: + in: + num: + source: num + out: [out] + run: increment + add_constants-1: + in: + CONSTANT: + source: CONSTANT + num: + source: increment-2/out + out: [out] + run: add_constants + """) + + assert osubworkflows[0] == ( + "add_constants-1", + yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + num: + label: num + type: int + CONSTANT: + default: 3 + label: CONSTANT + type: int + outputs: + out: + label: out + outputSource: to_int-1-1/out + type: int + steps: + sum-1-1: + in: + left: + source: num + right: + source: CONSTANT + out: + - out + run: sum + sum-1-2: + in: + left: + source: sum-1-1/out + right: + source: CONSTANT + out: + - out + run: sum + to_int-1-1: + in: + num: + source: sum-1-2/out + out: + - out + run: to_int + """), + ) + + +@define +class PackResult: + """A class representing the counts of card suits in a deck, including hearts, clubs, spades, and diamonds.""" + + hearts: int + clubs: int + spades: int + diamonds: int + + +def test_combining_attrs_and_factories() -> None: + """Check combining fields from an attrs-defined class with factory-produced instances.""" + Pack = factory(PackResult) + + @task() + def sum(left: int, right: int) -> int: + return left + right + + @workflow() + def
black_total(pack: PackResult) -> int: + return sum(left=pack.spades, right=pack.clubs) + + pack = Pack(hearts=13, spades=13, diamonds=13, clubs=13) + wkflw = construct(black_total(pack=pack), simplify_ids=True) + cwl = render(wkflw, allow_complex_types=True, factories_as_params=True) + assert cwl["__root__"] == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + PackResult-1: + label: PackResult-1 + type: record + outputs: + out: + label: out + outputSource: black_total-1/out + type: int + steps: + black_total-1: + in: + pack: + source: PackResult-1/out + out: + - out + run: black_total + """) + + assert cwl["black_total-1"] == yaml.safe_load(""" + class: Workflow + cwlVersion: 1.2 + inputs: + pack: + label: pack + type: record + outputs: + out: + label: out + outputSource: sum-1-1/out + type: int + steps: + sum-1-1: + in: + left: + source: pack/spades + right: + source: pack/clubs + out: + - out + run: sum + """)
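A pattern repeated throughout these tests is the reworked rendering API: `render()` now returns a dictionary of (sub)workflows keyed by name, with `__root__` holding the outermost workflow, and options such as `allow_positional_args` are scoped through the `set_configuration` context manager. The following is a minimal usage sketch, not part of the patch itself; it reuses only names that appear in this diff, and the body of `increment` is an assumption standing in for the helper in `tests/_lib/extra.py`:

```python
import sys

import yaml

from dewret.core import set_configuration
from dewret.renderers.cwl import render
from dewret.tasks import construct, task


@task()
def increment(num: int) -> int:
    """Assumed stand-in for the increment task in tests/_lib/extra.py."""
    return num + 1


# Configuration applies only inside the context manager; per
# test_cwl_with_positional_parameter above, a positional task call
# raises a TaskException outside a scope that allows it.
with set_configuration(allow_positional_args=True):
    result = increment(3)
    wkflw = construct(result, simplify_ids=True)

# render() returns one mapping per workflow; "__root__" is the outermost.
subworkflows = render(wkflw)
yaml.dump(subworkflows["__root__"], sys.stdout, indent=2)
```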