Skip to content

Commit

Permalink
Add zntrack list and zntrack.get_nodes (#743)
Browse files Browse the repository at this point in the history
* add CLI option

* add get_nodes

* move from option to argument

* test cli
  • Loading branch information
PythonFZ authored Nov 9, 2023
1 parent 473005e commit f784dc5
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ print(hello_world.random_number)
> node = zntrack.from_rev(
> "HelloWorld",
> remote="https://github.com/PythonFZ/ZnTrackExamples.git",
> rev="fbb6ada",
> rev="890c714",
> )
> ```
>
Expand Down
60 changes: 59 additions & 1 deletion tests/integration/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import pathlib

import pytest
import yaml
from typer.testing import CliRunner

import zntrack.examples
from zntrack import Node, NodeConfig, nodify, zn
from zntrack import Node, NodeConfig, get_nodes, nodify, utils, zn
from zntrack.cli import app


Expand Down Expand Up @@ -106,3 +107,60 @@ def test_run_w_name(proj_path, runner):

node.load()
assert node.outs == 15


def test_list_groups(proj_path, runner):
with zntrack.Project(automatic_node_names=True) as proj:
_ = zntrack.examples.ParamsToOuts(params=15)
_ = zntrack.examples.ParamsToOuts(params=15)

with proj.group("example1"):
_ = zntrack.examples.ParamsToOuts(params=15)
_ = zntrack.examples.ParamsToOuts(params=15)

# TODO: This is not working yet
# with proj.group("nested"):
# _ = zntrack.examples.ParamsToOuts(params=15)
# _ = zntrack.examples.ParamsToOuts(params=15)

with proj.group("nested", "GRP1"):
_ = zntrack.examples.ParamsToOuts(params=15)
_ = zntrack.examples.ParamsToOuts(params=15)
with proj.group("nested", "GRP2"):
_ = zntrack.examples.ParamsToOuts(params=15)
_ = zntrack.examples.ParamsToOuts(params=15)

proj.build()

true_groups = {
"example1": [
"ParamsToOuts -> example1_ParamsToOuts",
"ParamsToOuts_1 -> example1_ParamsToOuts_1",
],
"nodes": [
"ParamsToOuts",
"ParamsToOuts_1",
],
"nested": [
{
"GRP1": [
"ParamsToOuts -> nested_GRP1_ParamsToOuts",
"ParamsToOuts_1 -> nested_GRP1_ParamsToOuts_1",
],
"GRP2": [
"ParamsToOuts -> nested_GRP2_ParamsToOuts",
"ParamsToOuts_1 -> nested_GRP2_ParamsToOuts_1",
],
}
],
}

groups, _ = utils.cli.get_groups(remote=proj_path, rev=None)
assert groups == true_groups

result = runner.invoke(app, ["list", proj_path.as_posix()])
# test stdout == yaml.dump of true_groups
groups = yaml.safe_load(result.stdout)
assert groups == true_groups

assert result.exit_code == 0
42 changes: 42 additions & 0 deletions tests/integration/test_get_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

import zntrack.examples


@pytest.mark.needs_internet
def test_get_nodes_remote(proj_path):
nodes = zntrack.get_nodes(
remote="https://github.com/PythonFZ/ZnTrackExamples.git", rev="890c714"
)

assert nodes["HelloWorld"].random_number == 123


def test_get_nodes(proj_path):
with zntrack.Project(automatic_node_names=True) as proj:
_ = zntrack.examples.ParamsToOuts(params=15)
_ = zntrack.examples.ParamsToOuts(params=15)

with proj.group("example1"):
_ = zntrack.examples.ParamsToOuts(params=15)
_ = zntrack.examples.ParamsToOuts(params=15)

with proj.group("nested", "GRP1"):
_ = zntrack.examples.ParamsToOuts(params=15)
_ = zntrack.examples.ParamsToOuts(params=15)
with proj.group("nested", "GRP2"):
_ = zntrack.examples.ParamsToOuts(params=15)
_ = zntrack.examples.ParamsToOuts(params=15)

proj.run()

nodes = zntrack.get_nodes(remote=proj_path, rev=None)

assert nodes["ParamsToOuts"].outs == 15
assert nodes["ParamsToOuts_1"].outs == 15
assert nodes["example1_ParamsToOuts"].outs == 15
assert nodes["example1_ParamsToOuts_1"].outs == 15
assert nodes["nested_GRP1_ParamsToOuts"].outs == 15
assert nodes["nested_GRP1_ParamsToOuts_1"].outs == 15
assert nodes["nested_GRP2_ParamsToOuts"].outs == 15
assert nodes["nested_GRP2_ParamsToOuts_1"].outs == 15
3 changes: 2 additions & 1 deletion zntrack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import importlib.metadata

from zntrack import exceptions, tools
from zntrack.core.load import from_rev
from zntrack.core.load import from_rev, get_nodes
from zntrack.core.node import Node
from zntrack.core.nodify import NodeConfig, nodify
from zntrack.fields import Field, FieldGroup, LazyField, dvc, meta, zn
Expand Down Expand Up @@ -44,6 +44,7 @@
"tools",
"exceptions",
"from_rev",
"get_nodes",
]

__all__ += [
Expand Down
10 changes: 10 additions & 0 deletions zntrack/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,13 @@ def init(
"""Initialize a new ZnTrack Project."""
initializer = utils.cli.Initializer(name=name, gitignore=gitignore, force=force)
initializer.run()


@app.command()
def list(
remote: str = typer.Argument(".", help="The path/url to the repository"),
rev: str = typer.Argument(None, help="The revision to list (default: HEAD)"),
):
"""List all Nodes in the Project."""
groups, _ = utils.cli.get_groups(remote, rev)
print(yaml.dump(groups))
7 changes: 7 additions & 0 deletions zntrack/core/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from zntrack.core.node import Node
from zntrack.utils import config
from zntrack.utils.cli import get_groups

T = typing.TypeVar("T", bound=Node)

Expand Down Expand Up @@ -145,3 +146,9 @@ def from_rev(name, remote=".", rev=None, **kwargs) -> T:
cls = getattr(module, cls_name)

return cls.from_rev(name, remote, rev, **kwargs)


def get_nodes(remote=".", rev=None) -> dict[str, Node]:
"""Load all nodes from the given remote and revision."""
_, node_names = get_groups(remote, rev)
return {node_name: from_rev(node_name, remote, rev) for node_name in node_names}
54 changes: 54 additions & 0 deletions zntrack/utils/cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""CLI Helpers."""

import dataclasses
import json
import pathlib
import subprocess
import urllib.request

import typer
from dvc.api import DVCFileSystem


@dataclasses.dataclass
Expand Down Expand Up @@ -71,3 +73,55 @@ def make_repo(self):
"""Initialize the repository."""
subprocess.check_call(["git", "init"])
subprocess.check_call(["dvc", "init"])


def get_groups(remote, rev) -> (dict, list):
"""Get the group names and the nodes in each group from the remote.
Arguments:
---------
remote : str
The remote to get the group names from.
rev : str
The revision to use.
Returns:
-------
groups : dict
a nested dictionary with the group names as keys and the nodes in each group as
values. Contains "short-name -> long-name" if inside a group.
node_names: list
A list of all node names in the project.
"""
fs = DVCFileSystem(url=remote, rev=rev)
with fs.open("zntrack.json") as f:
config = json.load(f)

true_groups = {}
node_names = []

def add_to_group(groups, grp_names, node_name):
if len(grp_names) == 1:
if grp_names[0] not in groups:
groups[grp_names[0]] = []
groups[grp_names[0]].append(node_name)
else:
if grp_names[0] not in groups:
groups[grp_names[0]] = [{}]
add_to_group(groups[grp_names[0]][0], grp_names[1:], node_name)

for node_name, node_config in config.items():
nwd = pathlib.Path(node_config["nwd"]["value"])
grp_names = nwd.parent.as_posix().split("/")[1:]
if len(grp_names) == 0:
node_names.append(node_name)
grp_names = ["nodes"]
else:
for grp_name in grp_names:
node_name = node_name.replace(f"{grp_name}_", "")

node_names.append(f"{'_'.join(grp_names)}_{node_name}")
node_name = f"{node_name} -> {node_names[-1]}"
add_to_group(true_groups, grp_names, node_name)

return true_groups, node_names

0 comments on commit f784dc5

Please sign in to comment.