Skip to content

Commit

Permalink
Introduce new zntrack.<field> and zntrack.<field>_path (#699)
Browse files Browse the repository at this point in the history
* introduce new field names

* update with new fileds

* fix **

* update README.md
  • Loading branch information
PythonFZ authored Aug 16, 2023
1 parent 9d85511 commit 91f752e
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 27 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ from random import randrange
class HelloWorld(zntrack.Node):
"""Define a ZnTrack Node"""
# parameter to be tracked
max_number: int = zntrack.zn.params()
max_number: int = zntrack.params()
# parameter to store as output
random_number: int = zntrack.zn.outs()
random_number: int = zntrack.outs()

def run(self):
"""Command to be run by DVC"""
Expand Down
25 changes: 25 additions & 0 deletions zntrack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@
from zntrack.core.node import Node
from zntrack.core.nodify import NodeConfig, nodify
from zntrack.fields import Field, FieldGroup, LazyField, dvc, meta, zn
from zntrack.fields.fields import (
deps,
deps_path,
metrics,
metrics_path,
outs,
outs_path,
params,
params_path,
plots,
plots_path,
)
from zntrack.project import Project
from zntrack.utils import config
from zntrack.utils.node_wd import nwd
Expand All @@ -32,3 +44,16 @@
"exceptions",
"from_rev",
]

__all__ += [
"outs",
"metrics",
"params",
"deps",
"plots",
"outs_path",
"metrics_path",
"params_path",
"deps_path",
"plots_path",
]
50 changes: 25 additions & 25 deletions zntrack/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
class ParamsToOuts(zntrack.Node):
"""Save params to outs."""

params = zntrack.zn.params()
outs = zntrack.zn.outs()
params = zntrack.params()
outs = zntrack.outs()

def run(self) -> None:
"""Save params to outs."""
Expand All @@ -21,8 +21,8 @@ def run(self) -> None:
class ParamsToMetrics(zntrack.Node):
"""Save params to metrics."""

params = zntrack.zn.params()
metrics = zntrack.zn.metrics()
params = zntrack.params()
metrics = zntrack.metrics()

def run(self) -> None:
"""Save params to metrics."""
Expand All @@ -32,9 +32,9 @@ def run(self) -> None:
class WritePlots(zntrack.Node):
"""Generate a plot."""

plots: pd.DataFrame = zntrack.zn.plots()
x: list = zntrack.zn.params([1, 2, 3])
y: list = zntrack.zn.params([4, 5, 6])
plots: pd.DataFrame = zntrack.plots()
x: list = zntrack.params([1, 2, 3])
y: list = zntrack.params([4, 5, 6])

def run(self):
"""Write plots."""
Expand All @@ -44,9 +44,9 @@ def run(self):
class AddNumbers(zntrack.Node):
"""Add two numbers."""

a = zntrack.zn.params()
b = zntrack.zn.params()
c = zntrack.zn.outs()
a = zntrack.params()
b = zntrack.params()
c = zntrack.outs()

def run(self):
"""Add two numbers."""
Expand All @@ -56,9 +56,9 @@ def run(self):
class AddNodes(zntrack.Node):
"""Add two nodes."""

a: AddNumbers = zntrack.zn.deps()
b: AddNumbers = zntrack.zn.deps()
c = zntrack.zn.outs()
a: AddNumbers = zntrack.deps()
b: AddNumbers = zntrack.deps()
c = zntrack.outs()

def run(self):
"""Add two nodes."""
Expand All @@ -68,9 +68,9 @@ def run(self):
class AddNodeAttributes(zntrack.Node):
"""Add two node attributes."""

a: float = zntrack.zn.deps()
b: float = zntrack.zn.deps()
c = zntrack.zn.outs()
a: float = zntrack.deps()
b: float = zntrack.deps()
c = zntrack.outs()

def run(self):
"""Add two node attributes."""
Expand All @@ -80,8 +80,8 @@ def run(self):
class AddNodeNumbers(zntrack.Node):
"""Add up all 'x.outs' from the dependencies."""

numbers: list = zntrack.zn.deps()
sum: int = zntrack.zn.outs()
numbers: list = zntrack.deps()
sum: int = zntrack.outs()

def run(self):
"""Add up all 'x.outs' from the dependencies."""
Expand All @@ -91,9 +91,9 @@ def run(self):
class SumNodeAttributes(zntrack.Node):
"""Sum a list of numbers."""

inputs: list = zntrack.zn.deps()
shift: int = zntrack.zn.params()
output: int = zntrack.zn.outs()
inputs: list = zntrack.deps()
shift: int = zntrack.params()
output: int = zntrack.outs()

def run(self) -> None:
"""Sum a list of numbers."""
Expand All @@ -103,8 +103,8 @@ def run(self) -> None:
class AddOne(zntrack.Node):
"""Add one to the number."""

number: int = zntrack.zn.deps()
outs: int = zntrack.zn.outs()
number: int = zntrack.deps()
outs: int = zntrack.outs()

def run(self) -> None:
"""Add one to the number."""
Expand All @@ -114,8 +114,8 @@ def run(self) -> None:
class WriteDVCOuts(zntrack.Node):
"""Write an output file."""

params = zntrack.zn.params()
outs = zntrack.dvc.outs(zntrack.nwd / "output.txt")
params = zntrack.params()
outs = zntrack.outs_path(zntrack.nwd / "output.txt")

def run(self):
"""Write an output file."""
Expand Down
148 changes: 148 additions & 0 deletions zntrack/fields/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""Fields that are used to define Nodes."""
from zntrack.fields import dvc, zn

# Serialized Fields


def outs():
"""Define a Node Output.
Parameters
----------
data: any
A data object that is generated by the Node.
The object is serialized and deserialized by ZnTrack
and stored in the node working directory.
see https://dvc.org/doc/command-reference/stage/add#-o
"""
return zn.outs()


def metrics():
"""Define a Node Metric.
Parameters
----------
data: dict
A dictionary that is used by DVC as a metric.
The object is serialized and deserialized by ZnTrack
and stored in the node working directory.
see https://dvc.org/doc/command-reference/stage/add#-M
"""
return zn.metrics()


def params(*data):
"""Define a Node Parameter.
Parameters
----------
data: any
A data object that is used as a parameter.
Typically, this should be a string or number.
The object is serialized and deserialized by ZnTrack
and stored in params.yaml.
see https://dvc.org/doc/command-reference/stage/add#-p
"""
return zn.params(*data)


def deps(*data):
"""Define a Node Dependency.
Parameters
----------
data: any
A data object that is used as a dependency.
This can either be a Node or an attribute of a Node.
It can not be an object that is not part of the Node graph.
see https://dvc.org/doc/command-reference/stage/add#-d
"""
return zn.deps(*data)


def plots(*data, **kwargs):
"""Define a Node Plot.
Parameters
----------
data: pd.DataFrame
A pandas DataFrame that is used as a plot.
The object is serialized and deserialized by ZnTrack
and stored in the node working directory.
see https://dvc.org/doc/command-reference/stage/add#--plots
kwargs: dict
Additional keyword arguments that are used for plotting.
"""
return zn.plots(*data, **kwargs)


# Path Fields


def outs_path(*path):
"""Define a Node Output.
Parameters
----------
path: str|Path
A file or directory that is generated by the Node.
see https://dvc.org/doc/command-reference/stage/add#-o
"""
return dvc.outs(*path)


def metrics_path(*path):
"""Define a Node Metric.
Parameters
----------
path : str|Path
A file that is used by DVC as a metric, such as *.json
see https://dvc.org/doc/command-reference/stage/add#-M
"""
return dvc.metrics(*path)


def params_path(*path):
"""Define a Node Parameter.
Parameters
----------
path : str|Path
A file that is used by DVC for reading parameters.
This includes typically json or yaml files.
see https://dvc.org/doc/command-reference/stage/add#-p
"""
return dvc.params(*path)


def deps_path(*path):
"""Define a Node Dependency.
Parameters
----------
path : str|Path
A file or directory that is defined as a dependency to the Node.
see https://dvc.org/doc/command-reference/stage/add#-d
"""
return dvc.deps(*path)


def plots_path(*path, **kwargs):
"""Define a Node Plot.
Parameters
----------
path : str|Path
A file or directory that is defined as a plot to the Node.
see https://dvc.org/doc/command-reference/stage/add#--plots
kwargs: dict
Additional keyword arguments that are used for plotting.
"""
return dvc.plots(*path, **kwargs)

0 comments on commit 91f752e

Please sign in to comment.