Skip to content

Commit

Permalink
dvcfs: optimize get() by reducing index.info() calls (#10540)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Aug 27, 2024
1 parent 016f285 commit 4f3fb15
Show file tree
Hide file tree
Showing 5 changed files with 1,007 additions and 26 deletions.
8 changes: 4 additions & 4 deletions dvc/dependency/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,13 @@ def download(self, to: "Output", jobs: Optional[int] = None):

files = super().download(to=to, jobs=jobs)
if not isinstance(to.fs, LocalFileSystem):
return files
return

hashes: list[tuple[str, HashInfo, dict[str, Any]]] = []
for src_path, dest_path in files:
for src_path, dest_path, *rest in files:
try:
hash_info = self.fs.info(src_path)["dvc_info"]["entry"].hash_info
info = rest[0] if rest else self.fs.info(src_path)
hash_info = info["dvc_info"]["entry"].hash_info
dest_info = to.fs.info(dest_path)
except (KeyError, AttributeError):
# If no hash info found, just keep going and output will be hashed later
Expand All @@ -112,7 +113,6 @@ def download(self, to: "Output", jobs: Optional[int] = None):
hashes.append((dest_path, hash_info, dest_info))
cache = to.cache if to.use_cache else to.local_cache
cache.state.save_many(hashes, to.fs)
return files

def update(self, rev: Optional[str] = None):
if rev:
Expand Down
25 changes: 14 additions & 11 deletions dvc/fs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import glob
from typing import Optional
from typing import Optional, Union
from urllib.parse import urlparse

from dvc.config import ConfigError as RepoConfigError
Expand Down Expand Up @@ -47,12 +47,24 @@

def download(
fs: "FileSystem", fs_path: str, to: str, jobs: Optional[int] = None
) -> list[tuple[str, str]]:
) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
from dvc.scm import lfs_prefetch

from .callbacks import TqdmCallback

with TqdmCallback(desc=f"Downloading {fs.name(fs_path)}", unit="files") as cb:
if isinstance(fs, DVCFileSystem):
lfs_prefetch(
fs,
[
f"{fs.normpath(glob.escape(fs_path))}/**"
if fs.isdir(fs_path)
else glob.escape(fs_path)
],
)
if not glob.has_magic(fs_path):
return fs._get(fs_path, to, batch_size=jobs, callback=cb)

# NOTE: We use dvc-objects generic.copy over fs.get since it makes file
# download atomic and avoids fsspec glob/regex path expansion.
if fs.isdir(fs_path):
Expand All @@ -69,15 +81,6 @@ def download(
from_infos = [fs_path]
to_infos = [to]

if isinstance(fs, DVCFileSystem):
lfs_prefetch(
fs,
[
f"{fs.normpath(glob.escape(fs_path))}/**"
if fs.isdir(fs_path)
else glob.escape(fs_path)
],
)
cb.set_size(len(from_infos))
jobs = jobs or fs.jobs
generic.copy(fs, from_infos, localfs, to_infos, callback=cb, batch_size=jobs)
Expand Down
148 changes: 146 additions & 2 deletions dvc/fs/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,24 @@
import threading
from collections import deque
from contextlib import ExitStack, suppress
from glob import has_magic
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from fsspec.spec import AbstractFileSystem
from fsspec.spec import DEFAULT_CALLBACK, AbstractFileSystem
from funcy import wrap_with

from dvc.log import logger
from dvc_objects.fs.base import FileSystem
from dvc.utils.threadpool import ThreadPoolExecutor
from dvc_objects.fs.base import AnyFSPath, FileSystem

from .data import DataFileSystem

if TYPE_CHECKING:
from dvc.repo import Repo
from dvc.types import DictStrAny, StrPath

from .callbacks import Callback

logger = logger.getChild(__name__)

RepoFactory = Union[Callable[..., "Repo"], type["Repo"]]
Expand Down Expand Up @@ -474,9 +478,110 @@ def _info( # noqa: C901
info["name"] = path
return info

def get(
    self,
    rpath,
    lpath,
    recursive=False,
    callback=DEFAULT_CALLBACK,
    maxdepth=None,
    batch_size=None,
    **kwargs,
):
    """Download ``rpath`` to ``lpath``.

    Public entry point that forwards every argument unchanged to
    ``self._get`` and discards the list that ``_get`` returns, so this
    method always returns ``None``.

    Args:
        rpath: source path(s) to download.
        lpath: destination path(s).
        recursive: forwarded to ``_get`` as-is.
        callback: progress callback, forwarded to ``_get`` as-is.
        maxdepth: forwarded to ``_get`` as-is.
        batch_size: forwarded to ``_get`` as-is.
        **kwargs: any extra keyword arguments, forwarded to ``_get``.
    """
    # Collect the named options next to the pass-through kwargs so that the
    # delegation below is a single, uniform call.
    options = dict(
        kwargs,
        recursive=recursive,
        callback=callback,
        maxdepth=maxdepth,
        batch_size=batch_size,
    )
    self._get(rpath, lpath, **options)

def _get(  # noqa: C901
    self,
    rpath,
    lpath,
    recursive=False,
    callback=DEFAULT_CALLBACK,
    maxdepth=None,
    batch_size=None,
    **kwargs,
) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
    """Download ``rpath`` to the local path ``lpath`` and report what was copied.

    The call is delegated to ``super().get()`` (and an empty list is
    returned) when any of the following holds: ``rpath`` or ``lpath`` is a
    list, ``rpath`` contains glob magic, ``rpath`` does not exist, or
    ``recursive`` is false.

    Otherwise a single file is fetched with ``self.get_file``; a directory
    is walked once, the destination directory tree is created, and the files
    are fetched through a thread pool.

    Args:
        rpath: source path inside this filesystem.
        lpath: local destination. If it is an existing directory or ends
            with the OS path separator, the basename of ``rpath`` is
            appended to it.
        recursive: must be true for the walk-based code path to be used.
        callback: progress callback; its size is grown by the number of
            files found in each walked directory.
        maxdepth: passed to ``self.walk`` and used to limit which
            sub-directories are created locally.
        batch_size: ``max_workers`` for the thread pool.
        **kwargs: forwarded to the underlying ``get``/``get_file`` calls.

    Returns:
        ``[]`` when delegated to ``super().get()``; ``[(rpath, lpath)]`` for
        a single file; otherwise a list of ``(src, dest, info)`` tuples in
        completion order, where ``info`` is the entry that ``walk`` yielded
        for that file.
    """
    if (
        isinstance(rpath, list)
        or isinstance(lpath, list)
        or has_magic(rpath)
        or not self.exists(rpath)
        or not recursive
    ):
        super().get(
            rpath,
            lpath,
            recursive=recursive,
            callback=callback,
            maxdepth=maxdepth,
            **kwargs,
        )
        return []

    if os.path.isdir(lpath) or lpath.endswith(os.path.sep):
        lpath = self.join(lpath, os.path.basename(rpath))

    if self.isfile(rpath):
        with callback.branched(rpath, lpath) as child:
            self.get_file(rpath, lpath, callback=child, **kwargs)
            return [(rpath, lpath)]

    _files = []
    _dirs: list[str] = []
    for root, dirs, files in self.walk(rpath, maxdepth=maxdepth, detail=True):
        if files:
            callback.set_size((callback.size or 0) + len(files))

        # Path components of ``root`` relative to ``rpath``; the top-level
        # directory itself maps onto ``lpath`` directly.
        parts = self.relparts(root, rpath)
        if parts in ((os.curdir,), ("",)):
            parts = ()
        dest_root = os.path.join(lpath, *parts)
        # Only schedule sub-directories that the depth limit still allows.
        if not maxdepth or len(parts) < maxdepth - 1:
            _dirs.extend(f"{dest_root}{os.path.sep}{d}" for d in dirs)

        key = self._get_key_from_relative(root)
        _, dvc_fs, _ = self._get_subrepo_info(key)

        # ``files`` maps each file name to the info yielded by ``walk``;
        # keeping it with the file lets ``_get_file`` and callers reuse it.
        for name, info in files.items():
            src_path = f"{root}{self.sep}{name}"
            dest_path = f"{dest_root}{os.path.sep}{name}"
            _files.append((dvc_fs, src_path, dest_path, info))

    os.makedirs(lpath, exist_ok=True)
    for d in _dirs:
        # Tolerate directories that already exist, exactly like the root
        # above; a bare ``os.mkdir`` raised FileExistsError when downloading
        # into a destination that already contained some sub-directories.
        os.makedirs(d, exist_ok=True)

    def _get_file(arg):
        # Fetch one file and return ``(src, dest, info)`` for the caller.
        dvc_fs, src, dest, info = arg
        dvc_info = info.get("dvc_info")
        if dvc_info and dvc_fs:
            # Fetch through ``dvc_fs`` directly, handing over the info
            # already obtained during the walk.
            dvc_path = dvc_info["name"]
            dvc_fs.get_file(
                dvc_path, dest, callback=callback, info=dvc_info, **kwargs
            )
        else:
            self.get_file(src, dest, callback=callback, **kwargs)
        return src, dest, info

    with ThreadPoolExecutor(max_workers=batch_size) as executor:
        return list(executor.imap_unordered(_get_file, _files))

def get_file(self, rpath, lpath, **kwargs):
key = self._get_key_from_relative(rpath)
fs_path = self._from_key(key)

dirpath = os.path.dirname(lpath)
if dirpath:
# makedirs raises error if the string is empty
os.makedirs(dirpath, exist_ok=True)

try:
return self.repo.fs.get_file(fs_path, lpath, **kwargs)
except FileNotFoundError:
Expand Down Expand Up @@ -553,6 +658,45 @@ def immutable(self):
def getcwd(self):
return self.fs.getcwd()

def _get(
    self,
    from_info: Union[AnyFSPath, list[AnyFSPath]],
    to_info: Union[AnyFSPath, list[AnyFSPath]],
    callback: "Callback" = DEFAULT_CALLBACK,
    recursive: bool = False,
    batch_size: Optional[int] = None,
    **kwargs,
) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
    """Delegate a download to ``self.fs._get`` and return its result.

    The ``recursive`` argument supplied by the caller is not used: the value
    passed on is derived from the argument types alone.

    Args:
        from_info: source path or list of source paths.
        to_info: destination path or list of destination paths.
        callback: progress callback, passed through.
        recursive: accepted for signature compatibility; overridden below.
        batch_size: passed through.
        **kwargs: passed through.

    Returns:
        Whatever ``self.fs._get`` returns.
    """
    # FileSystem.get is non-recursive by default when both arguments are
    # lists; in every other case it is recursive.
    paired_lists = isinstance(from_info, list) and isinstance(to_info, list)
    return self.fs._get(
        from_info,
        to_info,
        callback=callback,
        recursive=not paired_lists,
        batch_size=batch_size,
        **kwargs,
    )

def get(
    self,
    from_info: Union[AnyFSPath, list[AnyFSPath]],
    to_info: Union[AnyFSPath, list[AnyFSPath]],
    callback: "Callback" = DEFAULT_CALLBACK,
    recursive: bool = False,
    batch_size: Optional[int] = None,
    **kwargs,
) -> None:
    """Download ``from_info`` to ``to_info`` without returning anything.

    Forwards all arguments to ``self._get`` and drops the list it returns.

    Args:
        from_info: source path or list of source paths.
        to_info: destination path or list of destination paths.
        callback: progress callback, forwarded.
        recursive: forwarded to ``_get``.
        batch_size: forwarded to ``_get``.
        **kwargs: forwarded to ``_get``.
    """
    # Merge the named options into the pass-through kwargs and delegate once.
    forwarded = dict(
        kwargs,
        callback=callback,
        batch_size=batch_size,
        recursive=recursive,
    )
    self._get(from_info, to_info, **forwarded)

@property
def fsid(self) -> str:
return self.fs.fsid
Expand Down
47 changes: 38 additions & 9 deletions tests/func/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dvc.testing.tmp_dir import make_subrepo
from dvc.utils.fs import remove
from dvc_data.hashfile import hash
from dvc_data.index.index import DataIndexDirError
from dvc_data.index.index import DataIndex, DataIndexDirError


def test_import(tmp_dir, scm, dvc, erepo_dir):
Expand Down Expand Up @@ -725,12 +725,41 @@ def test_import_invalid_configs(tmp_dir, scm, dvc, erepo_dir):
)


def test_import_no_hash(tmp_dir, scm, dvc, erepo_dir, mocker):
@pytest.mark.parametrize(
"files,expected_info_calls",
[
({"foo": "foo"}, {("foo",)}),
(
{
"dir": {
"bar": "bar",
"subdir": {"lorem": "ipsum", "nested": {"lorem": "lorem"}},
}
},
# info calls should be made for only directories
{("dir",), ("dir", "subdir"), ("dir", "subdir", "nested")},
),
],
)
def test_import_no_hash(
tmp_dir, scm, dvc, erepo_dir, mocker, files, expected_info_calls
):
with erepo_dir.chdir():
erepo_dir.dvc_gen("foo", "foo content", commit="create foo")

spy = mocker.spy(hash, "file_md5")
stage = dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported")
assert spy.call_count == 1
for call in spy.call_args_list:
assert stage.outs[0].fs_path != call.args[0]
erepo_dir.dvc_gen(files, commit="create foo")

file_md5_spy = mocker.spy(hash, "file_md5")
index_info_spy = mocker.spy(DataIndex, "info")
name = next(iter(files))

dvc.imp(os.fspath(erepo_dir), name, "out")

local_hashes = [
call.args[0]
for call in file_md5_spy.call_args_list
if call.args[1].protocol == "local"
]
# no files should be hashed, should use existing metadata
assert not local_hashes
assert {
call.args[1] for call in index_info_spy.call_args_list
} == expected_info_calls
Loading

0 comments on commit 4f3fb15

Please sign in to comment.