Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor dates handling #121

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/sirocco/core/_tasks/icon_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import f90nml

from sirocco.core.graph_items import Task
from sirocco.parsing._yaml_data_models import ConfigIconTaskSpecs
from sirocco.parsing.yaml_data_models import ConfigIconTaskSpecs


@dataclass(kw_only=True)
Expand Down Expand Up @@ -51,7 +51,7 @@ def update_core_namelists_from_workflow(self):
self.core_namelists["icon_master.namelist"]["master_time_control_nml"].update(
{
"experimentStartDate": self.start_date.isoformat() + "Z",
"experimentStopDate": self.end_date.isoformat() + "Z",
"experimentStopDate": self.stop_date.isoformat() + "Z",
}
)
self.core_namelists["icon_master.namelist"]["master_nml"]["lrestart"] = any(
Expand Down
2 changes: 1 addition & 1 deletion src/sirocco/core/_tasks/shell_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass

from sirocco.core.graph_items import Task
from sirocco.parsing._yaml_data_models import ConfigShellTaskSpecs
from sirocco.parsing.yaml_data_models import ConfigShellTaskSpecs


@dataclass(kw_only=True)
Expand Down
61 changes: 26 additions & 35 deletions src/sirocco/core/graph_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@
from itertools import chain, product
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeAlias, TypeVar, cast

from sirocco.parsing._yaml_data_models import (
from sirocco.parsing.target_date import DateList, LagList, SameDate
from sirocco.parsing.when import WhenSpec
from sirocco.parsing.yaml_data_models import (
ConfigAvailableData,
ConfigBaseDataSpecs,
ConfigBaseTaskSpecs,
)

if TYPE_CHECKING:
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path

from termcolor._types import Color

from sirocco.parsing._yaml_data_models import (
from sirocco.parsing.cycling import CyclePoint
from sirocco.parsing.yaml_data_models import (
ConfigBaseData,
ConfigCycleTask,
ConfigCycleTaskWaitOn,
Expand Down Expand Up @@ -73,8 +75,7 @@ class Task(ConfigBaseTaskSpecs, GraphItem):
outputs: list[Data] = field(default_factory=list)
wait_on: list[Task] = field(default_factory=list)
config_rootdir: Path
start_date: datetime | None = None
end_date: datetime | None = None
cycle_point: CyclePoint

_wait_on_specs: list[ConfigCycleTaskWaitOn] = field(default_factory=list, repr=False)

Expand All @@ -90,8 +91,7 @@ def from_config(
cls,
config: ConfigTask,
config_rootdir: Path,
start_date: datetime | None,
end_date: datetime | None,
cycle_point: CyclePoint,
coordinates: dict[str, Any],
datastore: Store,
graph_spec: ConfigCycleTask,
Expand All @@ -112,8 +112,7 @@ def from_config(
new = plugin_cls(
config_rootdir=config_rootdir,
coordinates=coordinates,
start_date=start_date,
end_date=end_date,
cycle_point=cycle_point,
inputs=inputs,
outputs=outputs,
**cls_config,
Expand Down Expand Up @@ -187,29 +186,29 @@ def __getitem__(self, coordinates: dict) -> GRAPH_ITEM_T:
key = tuple(coordinates[dim] for dim in self._dims)
return self._dict[key]

def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, reference: dict) -> Iterator[GRAPH_ITEM_T]:
def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, ref_coordinates: dict) -> Iterator[GRAPH_ITEM_T]:
# Check date references
if "date" not in self._dims and (spec.lag or spec.date):
if "date" not in self._dims and isinstance(spec.target_date, DateList | LagList):
msg = f"Array {self._name} has no date dimension, cannot be referenced by dates"
raise ValueError(msg)
if "date" in self._dims and reference.get("date") is None and len(spec.date) == 0:
if "date" in self._dims and ref_coordinates.get("date") is None and not isinstance(spec.target_date, DateList):
msg = f"Array {self._name} has a date dimension, must be referenced by dates"
raise ValueError(msg)

for key in product(*(self._resolve_target_dim(spec, dim, reference) for dim in self._dims)):
for key in product(*(self._resolve_target_dim(spec, dim, ref_coordinates) for dim in self._dims)):
yield self._dict[key]

def _resolve_target_dim(self, spec: TargetNodesBaseModel, dim: str, reference: Any) -> Iterator[Any]:
def _resolve_target_dim(self, spec: TargetNodesBaseModel, dim: str, ref_coordinates: Any) -> Iterator[Any]:
if dim == "date":
if not spec.lag and not spec.date:
yield reference["date"]
if spec.lag:
for lag in spec.lag:
yield reference["date"] + lag
if spec.date:
yield from spec.date
match spec.target_date:
case SameDate():
yield ref_coordinates["date"]
case DateList():
yield from spec.target_date.dates
case LagList():
yield from spec.target_date.lags
elif spec.parameters.get(dim) == "single":
yield reference[dim]
yield ref_coordinates[dim]
else:
yield from self._axes[dim]

Expand Down Expand Up @@ -239,20 +238,12 @@ def __getitem__(self, key: tuple[str, dict]) -> GRAPH_ITEM_T:
raise KeyError(msg)
return self._dict[name][coordinates]

def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, reference: dict) -> Iterator[GRAPH_ITEM_T]:
# Check if target items should be querried at all
if (when := spec.when) is not None:
if (ref_date := reference.get("date")) is None:
msg = "Cannot use a `when` specification without a `reference date`"
raise ValueError(msg)
if (at := when.at) is not None and at != ref_date:
return
if (before := when.before) is not None and before <= ref_date:
return
if (after := when.after) is not None and after >= ref_date:
return
def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, ref_coordinates: dict) -> Iterator[GRAPH_ITEM_T]:
# Check if we need to skip this querry
if isinstance(spec.when, WhenSpec) and not spec.when.is_active(ref_coordinates.get("date")):
return
# Yield items
yield from self._dict[spec.name].iter_from_cycle_spec(spec, reference)
yield from self._dict[spec.name].iter_from_cycle_spec(spec, ref_coordinates)

def __iter__(self) -> Iterator[GRAPH_ITEM_T]:
yield from chain(*(self._dict.values()))
38 changes: 16 additions & 22 deletions src/sirocco/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
from typing import TYPE_CHECKING, Self

from sirocco.core.graph_items import Cycle, Data, Store, Task
from sirocco.parsing._yaml_data_models import (
from sirocco.parsing.cycling import DateCyclePoint, OneOffPoint
from sirocco.parsing.yaml_data_models import (
ConfigBaseData,
ConfigWorkflow,
)

if TYPE_CHECKING:
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path

from sirocco.parsing._yaml_data_models import (
from sirocco.parsing.cycling import CyclePoint
from sirocco.parsing.yaml_data_models import (
ConfigCycle,
ConfigData,
ConfigTask,
Expand Down Expand Up @@ -44,40 +45,40 @@ def __init__(
task_dict: dict[str, ConfigTask] = {task.name: task for task in tasks}

# Function to iterate over date and parameter combinations
def iter_coordinates(param_refs: list, date: datetime | None = None) -> Iterator[dict]:
space = ({} if date is None else {"date": [date]}) | {k: parameters[k] for k in param_refs}
yield from (dict(zip(space.keys(), x, strict=False)) for x in product(*space.values()))
def iter_coordinates(param_refs: list[str], cycle_point: CyclePoint) -> Iterator[dict]:
axes = {k: parameters[k] for k in param_refs}
if isinstance(cycle_point, DateCyclePoint):
axes["date"] = [cycle_point.begin_date]
yield from (dict(zip(axes.keys(), x, strict=False)) for x in product(*axes.values()))

# 1 - create availalbe data nodes
for available_data_config in data.available:
for coordinates in iter_coordinates(param_refs=available_data_config.parameters, date=None):
for coordinates in iter_coordinates(param_refs=available_data_config.parameters, cycle_point=OneOffPoint()):
self.data.add(Data.from_config(config=available_data_config, coordinates=coordinates))

# 2 - create output data nodes
for cycle_config in cycles:
for date in self.cycle_dates(cycle_config):
for cycle_point in cycle_config.cycling.iter_cycle_points():
for task_ref in cycle_config.tasks:
for data_ref in task_ref.outputs:
data_name = data_ref.name
data_config = data_dict[data_name]
for coordinates in iter_coordinates(param_refs=data_config.parameters, date=date):
for coordinates in iter_coordinates(param_refs=data_config.parameters, cycle_point=cycle_point):
self.data.add(Data.from_config(config=data_config, coordinates=coordinates))

# 3 - create cycles and tasks
for cycle_config in cycles:
cycle_name = cycle_config.name
for date in self.cycle_dates(cycle_config):
for cycle_point in cycle_config.cycling.iter_cycle_points():
cycle_tasks = []
for task_graph_spec in cycle_config.tasks:
task_name = task_graph_spec.name
task_config = task_dict[task_name]

for coordinates in iter_coordinates(param_refs=task_config.parameters, date=date):
for coordinates in iter_coordinates(param_refs=task_config.parameters, cycle_point=cycle_point):
task = Task.from_config(
config=task_config,
config_rootdir=self.config_rootdir,
start_date=cycle_config.start_date,
end_date=cycle_config.end_date,
cycle_point=cycle_point,
coordinates=coordinates,
datastore=self.data,
graph_spec=task_graph_spec,
Expand All @@ -88,21 +89,14 @@ def iter_coordinates(param_refs: list, date: datetime | None = None) -> Iterator
Cycle(
name=cycle_name,
tasks=cycle_tasks,
coordinates={} if date is None else {"date": date},
coordinates={} if isinstance(cycle_point, OneOffPoint) else {"date": cycle_point.begin_date},
)
)

# 4 - Link wait on tasks
for task in self.tasks:
task.link_wait_on_tasks(self.tasks)

@staticmethod
def cycle_dates(cycle_config: ConfigCycle) -> Iterator[datetime | None]:
yield (date := cycle_config.start_date)
if cycle_config.period is not None and date is not None and cycle_config.end_date is not None:
while (date := date + cycle_config.period) < cycle_config.end_date:
yield date

@classmethod
def from_config_file(cls: type[Self], config_path: str) -> Self:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/sirocco/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._yaml_data_models import (
from .yaml_data_models import (
ConfigWorkflow,
)

Expand Down
Loading
Loading