Skip to content

Commit

Permalink
exp: ignore workspace errors during push/pull (iterative#10128)
Browse files Browse the repository at this point in the history
* exp: ignore workspace errors during push/pull

* refactor fetch revs warning

* mention failed revs are skipped

* add workspace arg to brancher

* pass workspace=False only for dvc exp push/pull
  • Loading branch information
dberenbaum authored Dec 7, 2023
1 parent 1f07023 commit 7c898cf
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 5 deletions.
5 changes: 4 additions & 1 deletion dvc/repo/brancher.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def brancher(
all_tags=False,
all_commits=False,
all_experiments=False,
workspace=True,
commit_date: Optional[str] = None,
sha_only=False,
num=1,
Expand All @@ -31,6 +32,7 @@ def brancher(
all_branches (bool): iterate over all available branches.
all_commits (bool): iterate over all commits.
all_tags (bool): iterate over all available tags.
workspace (bool): include workspace.
commit_date (str): Keep experiments from the commits after(include)
a certain date. Date must match the extended
ISO 8601 format (YYYY-MM-DD).
Expand Down Expand Up @@ -73,7 +75,8 @@ def brancher(

logger.trace("switching fs to workspace")
self.fs = LocalFileSystem(url=self.root_dir)
yield "workspace"
if workspace:
yield "workspace"

revs = revs.copy() if revs else []
if "workspace" in revs:
Expand Down
4 changes: 3 additions & 1 deletion dvc/repo/experiments/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,6 @@ def _pull_cache(
refs = [refs]
revs = list(exp_commits(repo.scm, refs))
logger.debug("dvc fetch experiment '%s'", refs)
repo.fetch(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs)
repo.fetch(
jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs, workspace=False
)
4 changes: 3 additions & 1 deletion dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,6 @@ def _push_cache(
assert isinstance(repo.scm, Git)
revs = list(exp_commits(repo.scm, refs))
logger.debug("dvc push experiment '%s'", refs)
return repo.push(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs)
return repo.push(
jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs, workspace=False
)
8 changes: 6 additions & 2 deletions dvc/repo/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def _collect_indexes( # noqa: PLR0913
recursive=False,
all_commits=False,
revs=None,
workspace=True,
max_size=None,
types=None,
config=None,
Expand Down Expand Up @@ -62,6 +63,7 @@ def outs_filter(out: "Output") -> bool:
all_branches=all_branches,
all_tags=all_tags,
all_commits=all_commits,
workspace=workspace,
):
try:
repo.config.merge(config)
Expand All @@ -79,11 +81,11 @@ def outs_filter(out: "Output") -> bool:
idx.data["repo"].onerror = _make_index_onerror(onerror, rev)

indexes[rev or "workspace"] = idx
except Exception as exc:
except Exception as exc: # noqa: BLE001
if onerror:
onerror(rev, None, exc)
collection_exc = exc
logger.exception("failed to collect '%s'", rev or "workspace")
logger.warning("failed to collect '%s', skipping", rev or "workspace")

if not indexes and collection_exc:
raise collection_exc
Expand All @@ -104,6 +106,7 @@ def fetch( # noqa: PLR0913
all_commits=False,
run_cache=False,
revs=None,
workspace=True,
max_size=None,
types=None,
config=None,
Expand Down Expand Up @@ -148,6 +151,7 @@ def fetch( # noqa: PLR0913
recursive=recursive,
all_commits=all_commits,
revs=revs,
workspace=workspace,
max_size=max_size,
types=types,
config=config,
Expand Down
2 changes: 2 additions & 0 deletions dvc/repo/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def push( # noqa: PLR0913
all_commits=False,
run_cache=False,
revs=None,
workspace=True,
glob=False,
):
from fsspec.utils import tokenize
Expand Down Expand Up @@ -87,6 +88,7 @@ def push( # noqa: PLR0913
recursive=recursive,
all_commits=all_commits,
revs=revs,
workspace=workspace,
push=True,
)

Expand Down
16 changes: 16 additions & 0 deletions tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import pytest
from funcy import first

Expand Down Expand Up @@ -360,3 +362,17 @@ def test_get(tmp_dir, scm, dvc, exp_stage, erepo_dir, use_ref):
rev=exp_ref.name if use_ref else exp_rev,
)
assert (erepo_dir / "params.yaml").read_text().strip() == "foo: 2"


def test_push_pull_invalid_workspace(
tmp_dir, scm, dvc, git_upstream, exp_stage, local_remote, caplog
):
dvc.experiments.run()

with open("dvc.yaml", mode="a") as f:
f.write("\ninvalid")

with caplog.at_level(logging.WARNING, logger="dvc"):
dvc.experiments.push(git_upstream.remote, push_cache=True)
dvc.experiments.pull(git_upstream.remote, pull_cache=True)
assert "failed to collect" not in caplog.text

0 comments on commit 7c898cf

Please sign in to comment.