Skip to content

Commit

Permalink
Merge pull request #428 from c-bata/support-optuna-metric-names
Browse files Browse the repository at this point in the history
Support `metric_names` introduced in Optuna
  • Loading branch information
c-bata committed Apr 27, 2023
1 parent 86d3d54 commit cdee5c2
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
17 changes: 16 additions & 1 deletion docs/errors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Warning Messages
----------------

Human-in-the-loop optimization will not work with ``_CachedStorage`` in Optuna prior to v3.2.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This warning occurs when the storage object associated with the Optuna Study is of the ``_CachedStorage`` class.

Expand All @@ -22,3 +22,18 @@ or use a following dirty hack to unwrap ``_CachedStorage`` class.
if isinstance(study._storage, optuna.storages._CachedStorage):
study._storage = study._storage._backend
``set_objective_names()`` function is deprecated. Please use ``study.set_metric_names()`` instead.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

:func:`~optuna_dashboard.set_objective_names` function has been ported to Optuna.
Please use `study.set_metric_names() <https://optuna.readthedocs.io/en/latest/reference/generated/optuna.study.Study.html#optuna.study.Study>`_ function instead.

.. list-table::

* - Deprecated APIs
- Corresponding Active APIs
* - ``optuna_dashboard.set_objective_names(study, ["objective 1", "objective 2"])``
- ``study.set_metric_names(["objective 1", "objective 2"])``

18 changes: 18 additions & 0 deletions optuna_dashboard/_named_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

from typing import Any
from typing import Optional
import warnings

import optuna


# Should be equivalent to `optuna.study.study._SYSTEM_ATTR_METRIC_NAMES`.
# See https://github.com/optuna/optuna/pull/4383 for details.
SYSTEM_ATTR_METRIC_NAMES = "study:metric_names"

SYSTEM_ATTR_NAME = "dashboard:objective_names"


Expand All @@ -22,6 +27,17 @@ def set_objective_names(study: optuna.Study, names: list[str]) -> None:
study = optuna.create_study(directions=["minimize", "minimize"])
set_objective_names(study, ["val_loss", "flops"])
"""

if hasattr(study, "set_metric_names"):
warnings.warn(
"`set_objective_names()` function is deprecated."
" Please use `study.set_metric_names()` instead."
" See https://optuna-dashboard.readthedocs.io/en/latest/errors.html for details.",
category=FutureWarning,
)
study.set_metric_names(names)
return

storage = study._storage
study_id = study._study_id

Expand All @@ -32,4 +48,6 @@ def set_objective_names(study: optuna.Study, names: list[str]) -> None:


def get_objective_names(system_attrs: dict[str, Any]) -> Optional[list[str]]:
if SYSTEM_ATTR_METRIC_NAMES in system_attrs:
return system_attrs[SYSTEM_ATTR_METRIC_NAMES]
return system_attrs.get(SYSTEM_ATTR_NAME)
24 changes: 24 additions & 0 deletions python_tests/test_metric_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import unittest

import optuna
from optuna.version import __version__ as optuna_ver
from optuna_dashboard._named_objectives import get_objective_names
from packaging import version


class MetricNamesTestCase(unittest.TestCase):
@unittest.skipIf(
version.parse(optuna_ver) < version.Version("3.2.0.dev"),
"study.set_metric_names() is not implemented yet",
)
def test_get_metric_names(self) -> None:
study = optuna.create_study(directions=["minimize", "minimize"])
# TODO(c-bata): Remove the following `type: ignore` after released Optuna v3.2.
study.set_metric_names(["val_loss", "flops"]) # type: ignore

study_system_attrs = study._storage.get_study_system_attrs(study._study_id)
metric_names = get_objective_names(study_system_attrs)

assert metric_names == ["val_loss", "flops"]

0 comments on commit cdee5c2

Please sign in to comment.