Skip to content

Commit

Permalink
plots: standardise across DVC/Studio/VS Code (#9931)
Browse files Browse the repository at this point in the history
* standardise plot data output

* add data required for new anchors to renderer properties

* feed dvc-render updates into API output

* update --json --split api with new data types (send split anchor definitions as strings)

* accomodate zoom and pan anchor

* switch anchor_revs to revs_with_datapoints

* accomodate height and width anchors

* drop terrible idea of holding all data as strings

* fix plot having multiple x fields

* move inferred properties off vega converter class
  • Loading branch information
mattseddon authored Dec 6, 2023
1 parent 1aed230 commit 5bb72b3
Show file tree
Hide file tree
Showing 12 changed files with 488 additions and 406 deletions.
46 changes: 1 addition & 45 deletions dvc/commands/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,49 +59,6 @@ def _show_json(
ui.write_json(compact({"errors": all_errors, "data": data}), highlight=False)


def _adjust_vega_renderers(renderers):
from dvc.render import REVISION_FIELD, VERSION_FIELD
from dvc_render import VegaRenderer

for r in renderers:
if isinstance(r, VegaRenderer):
if _data_versions_count(r) > 1:
summary = _summarize_version_infos(r)
for dp in r.datapoints:
vi = dp.pop(VERSION_FIELD, {})
keys = list(vi.keys())
for key in keys:
if not (len(summary.get(key, set())) > 1):
vi.pop(key)
if vi:
dp["rev"] = "::".join(vi.values())
else:
for dp in r.datapoints:
dp[REVISION_FIELD] = dp[VERSION_FIELD]["revision"]
dp.pop(VERSION_FIELD, {})


def _summarize_version_infos(renderer):
from collections import defaultdict

from dvc.render import VERSION_FIELD

result = defaultdict(set)

for dp in renderer.datapoints:
for key, value in dp.get(VERSION_FIELD, {}).items():
result[key].add(value)
return dict(result)


def _data_versions_count(renderer):
from itertools import product

summary = _summarize_version_infos(renderer)
x = product(summary.get("filename", {None}), summary.get("field", {None}))
return len(set(x))


class CmdPlots(CmdBase):
def _func(self, *args, **kwargs):
raise NotImplementedError
Expand Down Expand Up @@ -175,11 +132,10 @@ def run(self) -> int: # noqa: C901, PLR0911, PLR0912
return 0

renderers = [r.renderer for r in renderers_with_errors]
_adjust_vega_renderers(renderers)
if self.args.show_vega:
renderer = first(filter(lambda r: r.TYPE == "vega", renderers))
if renderer:
ui.write_json(renderer.get_filled_template(as_string=False))
ui.write_json(renderer.get_filled_template())
return 0

output_file: Path = (Path.cwd() / out).resolve() / "index.html"
Expand Down
13 changes: 7 additions & 6 deletions dvc/render/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
INDEX_FIELD = "step"
REVISION_FIELD = "rev"
FILENAME_FIELD = "filename"
VERSION_FIELD = "dvc_data_version_info"
REVISIONS_KEY = "revisions"
INDEX = "step"
REVISION = "rev"
FILENAME = "filename"
FIELD = "field"
REVISIONS = "revisions"
ANCHOR_DEFINITIONS = "anchor_definitions"
TYPE_KEY = "type"
SRC_FIELD = "src"
SRC = "src"
45 changes: 18 additions & 27 deletions dvc/render/convert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections import defaultdict
from typing import Dict, List, Union

from dvc.render import REVISION_FIELD, REVISIONS_KEY, SRC_FIELD, TYPE_KEY, VERSION_FIELD
from dvc.render import REVISION, REVISIONS, SRC, TYPE_KEY
from dvc.render.converter.image import ImageConverter
from dvc.render.converter.vega import VegaConverter

Expand All @@ -19,39 +18,31 @@ def _get_converter(
raise ValueError(f"Invalid renderer class {renderer_class}")


def _group_by_rev(datapoints):
grouped = defaultdict(list)
for datapoint in datapoints:
rev = datapoint.get(VERSION_FIELD, {}).get("revision")
grouped[rev].append(datapoint)
return dict(grouped)


def to_json(renderer, split: bool = False) -> List[Dict]:
if renderer.TYPE == "vega":
grouped = _group_by_rev(renderer.datapoints)
if not renderer.datapoints:
return []
revs = renderer.get_revs()
if split:
content = renderer.get_filled_template(
skip_anchors=["data"], as_string=False
)
content, split_content = renderer.get_partial_filled_template()
else:
content = renderer.get_filled_template(as_string=False)
if grouped:
return [
{
TYPE_KEY: renderer.TYPE,
REVISIONS_KEY: sorted(grouped.keys()),
"content": content,
"datapoints": grouped,
}
]
return []
content = renderer.get_filled_template()
split_content = {}

return [
{
TYPE_KEY: renderer.TYPE,
REVISIONS: revs,
"content": content,
**split_content,
}
]
if renderer.TYPE == "image":
return [
{
TYPE_KEY: renderer.TYPE,
REVISIONS_KEY: [datapoint.get(REVISION_FIELD)],
"url": datapoint.get(SRC_FIELD),
REVISIONS: [datapoint.get(REVISION)],
"url": datapoint.get(SRC),
}
for datapoint in renderer.datapoints
]
Expand Down
8 changes: 4 additions & 4 deletions dvc/render/converter/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

from dvc.render import FILENAME_FIELD, REVISION_FIELD, SRC_FIELD
from dvc.render import FILENAME, REVISION, SRC

from . import Converter

Expand Down Expand Up @@ -58,9 +58,9 @@ def flat_datapoints(self, revision: str) -> Tuple[List[Dict], Dict]:
else:
src = self._encode_image(image_data)
datapoint = {
REVISION_FIELD: revision,
FILENAME_FIELD: filename,
SRC_FIELD: src,
REVISION: revision,
FILENAME: filename,
SRC: src,
}
datapoints.append(datapoint)
return datapoints, properties
67 changes: 41 additions & 26 deletions dvc/render/converter/vega.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from funcy import first, last

from dvc.exceptions import DvcException
from dvc.render import FILENAME_FIELD, INDEX_FIELD, VERSION_FIELD
from dvc.render import FIELD, FILENAME, INDEX, REVISION

from . import Converter

Expand Down Expand Up @@ -112,42 +112,45 @@ def __init__(
):
super().__init__(plot_id, data, properties)
self.plot_id = plot_id
self.inferred_properties: Dict = {}

def _infer_y_from_data(self):
if self.plot_id in self.data:
for lst in _lists(self.data[self.plot_id]):
if all(isinstance(item, dict) for item in lst):
datapoint = first(lst)
field = last(datapoint.keys())
self.inferred_properties["y"] = {self.plot_id: field}
break
return {self.plot_id: field}
return None

def _infer_x_y(self):
x = self.properties.get("x", None)
y = self.properties.get("y", None)

inferred_properties: Dict = {}

# Infer x.
if isinstance(x, str):
self.inferred_properties["x"] = {}
inferred_properties["x"] = {}
# If multiple y files, duplicate x for each file.
if isinstance(y, dict):
for file, fields in y.items():
# Duplicate x for each y.
if isinstance(fields, list):
self.inferred_properties["x"][file] = [x] * len(fields)
inferred_properties["x"][file] = [x] * len(fields)
else:
self.inferred_properties["x"][file] = x
inferred_properties["x"][file] = x
# Otherwise use plot ID as file.
else:
self.inferred_properties["x"][self.plot_id] = x
inferred_properties["x"][self.plot_id] = x

# Infer y.
if y is None:
self._infer_y_from_data()
inferred_properties["y"] = self._infer_y_from_data()
# If y files not provided, use plot ID as file.
elif not isinstance(y, dict):
self.inferred_properties["y"] = {self.plot_id: y}
inferred_properties["y"] = {self.plot_id: y}

return inferred_properties

def _find_datapoints(self):
result = {}
Expand Down Expand Up @@ -182,7 +185,7 @@ def infer_x_label(properties):

x = properties.get("x", None)
if not isinstance(x, dict):
return INDEX_FIELD
return INDEX

fields = {field for _, field in _file_field(x)}
if len(fields) == 1:
Expand All @@ -192,23 +195,25 @@ def infer_x_label(properties):
def flat_datapoints(self, revision): # noqa: C901, PLR0912
file2datapoints, properties = self.convert()

props_update = {}
props_update: Dict[str, Union[str, List[Dict[str, str]]]] = {}

xs = list(_get_xs(properties, file2datapoints))

# assign "step" if no x provided
if not xs:
x_file, x_field = (
None,
INDEX_FIELD,
INDEX,
)
else:
x_file, x_field = xs[0]
props_update["x"] = x_field

num_xs = len(xs)
multiple_x_fields = num_xs > 1 and len({x[1] for x in xs}) > 1
props_update["x"] = "dvc_inferred_x_value" if multiple_x_fields else x_field

ys = list(_get_ys(properties, file2datapoints))

num_xs = len(xs)
num_ys = len(ys)
if num_xs > 1 and num_xs != num_ys:
raise DvcException(
Expand Down Expand Up @@ -237,6 +242,14 @@ def flat_datapoints(self, revision): # noqa: C901, PLR0912
else:
common_prefix_len = 0

props_update["anchors_y_definitions"] = [
{
FILENAME: _get_short_y_file(y_file, common_prefix_len),
FIELD: y_field,
}
for y_file, y_field in ys
]

for i, (y_file, y_field) in enumerate(ys):
if num_xs > 1:
x_file, x_field = xs[i]
Expand All @@ -249,15 +262,16 @@ def flat_datapoints(self, revision): # noqa: C901, PLR0912
source_field=y_field,
)

if x_field == INDEX_FIELD and x_file is None:
_update_from_index(datapoints, INDEX_FIELD)
if x_field == INDEX and x_file is None:
_update_from_index(datapoints, INDEX)
else:
x_datapoints = file2datapoints.get(x_file, [])
try:
_update_from_field(
datapoints,
field=x_field,
field="dvc_inferred_x_value" if multiple_x_fields else x_field,
source_datapoints=x_datapoints,
source_field=x_field,
)
except IndexError:
raise DvcException( # noqa: B904
Expand All @@ -266,15 +280,12 @@ def flat_datapoints(self, revision): # noqa: C901, PLR0912
"They have to have same length."
)

y_file_short = y_file[common_prefix_len:].strip("/\\")
_update_all(
datapoints,
update_dict={
VERSION_FIELD: {
"revision": revision,
FILENAME_FIELD: y_file_short,
"field": y_field,
}
REVISION: revision,
FILENAME: _get_short_y_file(y_file, common_prefix_len),
FIELD: y_field,
},
)

Expand All @@ -295,17 +306,21 @@ def convert(
generated datapoints and updated properties. `x`, `y` values and labels
are inferred and always provided.
"""
self._infer_x_y()
inferred_properties = self._infer_x_y()

datapoints = self._find_datapoints()
properties = {**self.properties, **self.inferred_properties}
properties = {**self.properties, **inferred_properties}

properties["y_label"] = self.infer_y_label(properties)
properties["x_label"] = self.infer_x_label(properties)

return datapoints, properties


def _get_short_y_file(y_file, common_prefix_len):
return y_file[common_prefix_len:].strip("/\\")


def _update_from_field(
target_datapoints: List[Dict],
field: str,
Expand Down
18 changes: 12 additions & 6 deletions dvc/render/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def match_defs_renderers( # noqa: C901, PLR0912
for plot_id, group in plots_data.group_definitions().items():
plot_datapoints: List[Dict] = []
props = _squash_plots_properties(group)
final_props: Dict = {}
first_props: Dict = {}

def_errors: Dict[str, Exception] = {}
src_errors: DefaultDict[str, Dict[str, Exception]] = defaultdict(dict)
Expand All @@ -90,6 +90,7 @@ def match_defs_renderers( # noqa: C901, PLR0912
if templates_dir is not None:
props["template_dir"] = templates_dir

revs = []
for rev, inner_id, plot_definition in group:
plot_sources = infer_data_sources(inner_id, plot_definition)
definitions_data = plots_data.get_definition_data(plot_sources, rev)
Expand All @@ -109,19 +110,24 @@ def match_defs_renderers( # noqa: C901, PLR0912

try:
dps, rev_props = converter.flat_datapoints(rev)
if dps and rev not in revs:
revs.append(rev)
except Exception as e: # noqa: BLE001
logger.warning("In %r, %s", rev, str(e).lower())
def_errors[rev] = e
continue

if not final_props and rev_props:
final_props = rev_props
if not first_props and rev_props:
first_props = rev_props
plot_datapoints.extend(dps)

if "title" not in final_props:
final_props["title"] = renderer_id
if "title" not in first_props:
first_props["title"] = renderer_id

if revs:
first_props["revs_with_datapoints"] = revs

if renderer_cls is not None:
renderer = renderer_cls(plot_datapoints, renderer_id, **final_props)
renderer = renderer_cls(plot_datapoints, renderer_id, **first_props)
renderers.append(RendererWithErrors(renderer, dict(src_errors), def_errors))
return renderers
Loading

0 comments on commit 5bb72b3

Please sign in to comment.