Skip to content

Commit

Permalink
single and incompatible working
Browse files Browse the repository at this point in the history
  • Loading branch information
KrissiHub committed Jan 17, 2024
1 parent 3a21fae commit 722a333
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 34 deletions.
24 changes: 11 additions & 13 deletions deepcave/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import copy
import itertools
import re
import webbrowser
from collections import defaultdict
Expand Down Expand Up @@ -177,7 +178,6 @@ def register_input(
key = (id, attribute, filter, type_)
if key not in self.inputs:
self.inputs.append(key)

# We have to rearrange the inputs because `State`
# must follow all `Input`. Since all filters are `Input`, we have to
# shift them to the front.
Expand Down Expand Up @@ -313,7 +313,7 @@ def plugin_input_update(pathname: str, *inputs_list: str) -> List[str]:
update_dict(inputs, new_inputs)

# Set not used inputs
for (id, attribute, _, _) in self.inputs:
for id, attribute, _, _ in self.inputs:
if id not in inputs:
inputs[id] = {}

Expand Down Expand Up @@ -420,7 +420,7 @@ def toggle_help_modal(n: Optional[int], is_open: bool) -> Tuple[bool, str]:
return is_open

# Register callback to click on configurations
for (id, *_) in self.outputs:
for id, *_ in self.outputs:
internal_id = self.get_internal_output_id(id)

@app.callback(
Expand Down Expand Up @@ -469,8 +469,7 @@ def _inputs_changed(
# If only filters changed, then we don't need to
# calculate the results again.
if last_inputs is not None:
for (id, attribute, filter, _) in self.inputs:

for id, attribute, filter, _ in self.inputs:
if self.activate_run_selection:
if id == "run":
continue
Expand Down Expand Up @@ -524,7 +523,7 @@ def _process_raw_outputs(
# We have to add no_updates here for the mode we don't want
count_outputs = 0
count_mpl_outputs = 0
for (_, _, mpl_mode) in self.outputs:
for _, _, mpl_mode in self.outputs:
if mpl_mode:
count_mpl_outputs += 1
else:
Expand Down Expand Up @@ -569,7 +568,6 @@ def _list_to_dict(self, values: List[str], input: bool = True) -> Dict[str, Dict
mapping[id] = {}

mapping[id][attribute] = value

return mapping

@interactive
Expand Down Expand Up @@ -600,7 +598,7 @@ def _dict_to_list(
order = self.outputs # type: ignore

result: List[Optional[str]] = []
for (id, attribute, instance, *_) in order:
for id, attribute, instance, *_ in order:
if not input:
# Instance is mlp_mode in case of outputs
# Simply ignore other outputs.
Expand Down Expand Up @@ -637,7 +635,7 @@ def _dict_as_key(self, d: Dict[str, Any], remove_filters: bool = False) -> Optio

new_d = copy.deepcopy(d)
if remove_filters:
for (id, _, filter, _) in self.inputs:
for id, _, filter, _ in self.inputs:
if filter:
if id in new_d:
del new_d[id]
Expand All @@ -662,10 +660,9 @@ def _cast_inputs(self, inputs: Dict[str, Dict[str, str]]) -> Dict[str, Dict[str,
casted_inputs: Dict[str, Dict[str, str]] = defaultdict(dict)
for id, attributes in inputs.items():
for attribute in attributes:

# Find corresponding input
type = None
for (id_, attribute_, _, type_) in self.inputs:
for id_, attribute_, _, type_ in self.inputs:
if id == id_ and attribute == attribute_:
type = type_
break
Expand Down Expand Up @@ -702,7 +699,7 @@ def _clean_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""
used_ids = []
cleaned_inputs = {}
for (id, attribute, *_) in self.inputs:
for id, attribute, *_ in self.inputs:
# Since self.inputs is ordered, we use the first occuring attribute and add
# the id so it is not used again.
if id not in used_ids:
Expand Down Expand Up @@ -803,6 +800,7 @@ def __call__(self, render_button: bool = False) -> List[Component]:
else:
components += [html.H1(self.name)]


try:
self.check_runs_compatibility(self.all_runs)
except NotMergeableError as message:
Expand Down Expand Up @@ -1332,7 +1330,7 @@ def generate_inputs(self, **kwargs: Any) -> Dict[str, Any]:
The inputs for the run.
"""
mapping = {}
for (id, attribute, *_) in self.inputs:
for id, attribute, *_ in self.inputs:
# Since `self.inputs` is ordered, we use the first occuring attribute and add
# the id so it is not used again.
if id not in mapping:
Expand Down
1 change: 0 additions & 1 deletion deepcave/plugins/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def register_callbacks(self) -> None:
inputs = [Input(self.get_internal_id("update-button"), "n_clicks")]
for id, attribute, _, _ in self.inputs:
inputs.append(Input(self.get_internal_input_id(id), attribute))

# Register updates from inputs
@app.callback(outputs, inputs)
def plugin_output_update(_, *inputs_list): # type: ignore
Expand Down
47 changes: 28 additions & 19 deletions deepcave/plugins/objective/cost_over_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@
from dash import dcc, html
from dash.exceptions import PreventUpdate

from deepcave import config
from deepcave import config, notification
from deepcave.plugins.dynamic import DynamicPlugin
from deepcave.runs import AbstractRun, check_equality
from deepcave.runs.group import NotMergeableError
from deepcave.utils.layout import get_select_options, help_button
from deepcave.utils.styled_plotty import (
get_color,
get_hovertext_from_config,
save_image,
)
from deepcave.runs.group import NotMergeableError
from deepcave import notification


class CostOverTime(DynamicPlugin):
Expand All @@ -27,12 +26,21 @@ class CostOverTime(DynamicPlugin):
help = "docs/plugins/cost_over_time.rst"

def check_runs_compatibility(self, runs: List[AbstractRun]) -> None:
#If the runs are not mergeable, there should still
#be an option to look at one of the runs
# If the runs are not mergeable, there should still
# be an option to look at one of the runs
try:
check_equality(runs, objectives=True, budgets=True)
if self.activate_run_selection == True:
self.activate_run_selection = False
self.inputs.remove(('run', 'value', False, None))
self.inputs.remove(('run', 'options', False, None))
DynamicPlugin.register_callbacks(self)

except NotMergeableError:
notification.update("The runs you chose could not be combined. You can still choose to look at the Cost Over Time for one specific run though.")
notification.update(
"The runs you chose could not be combined. You can still choose to look at the Cost Over Time for one specific run though."
)
self.activate_run_selection = True

# Set some attributes here
run = runs[0]
Expand Down Expand Up @@ -177,7 +185,7 @@ def load_outputs(runs, inputs, outputs):
x = outputs["times"]
if inputs["xaxis"] == "trials":
x = outputs["ids"]

y = np.array(outputs["costs_mean"])
y_err = np.array(outputs["costs_std"])
y_upper = list(y + y_err)
Expand Down Expand Up @@ -209,17 +217,17 @@ def load_outputs(runs, inputs, outputs):
)

traces.append(
go.Scatter(
x=x,
y=y_upper,
line=dict(color=get_color(0, 0)),
line_shape="hv",
hoverinfo="skip",
showlegend=False,
marker=dict(symbol=None),
)
go.Scatter(
x=x,
y=y_upper,
line=dict(color=get_color(0, 0)),
line_shape="hv",
hoverinfo="skip",
showlegend=False,
marker=dict(symbol=None),
)

)

traces.append(
go.Scatter(
x=x,
Expand Down Expand Up @@ -258,7 +266,9 @@ def load_outputs(runs, inputs, outputs):
symbol = None
mode = "lines"
if len(config_ids) > 0:
hovertext = [get_hovertext_from_config(run, config_id) for config_id in config_ids]
hovertext = [
get_hovertext_from_config(run, config_id) for config_id in config_ids
]
hoverinfo = "text"
symbol = "circle"
mode = "lines+markers"
Expand Down Expand Up @@ -324,4 +334,3 @@ def load_outputs(runs, inputs, outputs):
save_image(figure, "cost_over_time.pdf")

return figure

1 change: 0 additions & 1 deletion deepcave/runs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,5 +986,4 @@ def check_equality(
result["objectives"] = serialized_objectives
if meta:
result["meta"]["objectives"] = serialized_objectives

return result

0 comments on commit 722a333

Please sign in to comment.