From 438051e01d6926c4f08c6c2893ce3a59a5259f7e Mon Sep 17 00:00:00 2001
From: Jonathan de Bruin
Date: Mon, 18 Sep 2023 18:12:55 +0200
Subject: [PATCH] Add support for params and refactor download response (#64)

Co-authored-by: Vetra <78014027+SexyVetra@users.noreply.github.com>
---
 .github/workflows/python-package.yml |  2 +-
 datahugger/__init__.py               |  2 +
 datahugger/__main__.py               | 25 +++++++++++
 datahugger/api.py                    | 14 ++++---
 datahugger/base.py                   | 61 ++++++++++++++-------------
 datahugger/services.py               | 31 +++++++-------
 tests/test_repositories.py           | 15 ++++++-
 tests/test_resolver.py               |  4 +-
 8 files changed, 98 insertions(+), 56 deletions(-)

diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 94c4785..dec5926 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -22,7 +22,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .
+          python -m pip install .[all]
           python -m pip install pytest pytest-xdist
       - name: Test with pytest
         run: |
diff --git a/datahugger/__init__.py b/datahugger/__init__.py
index 06422d4..cf827fe 100644
--- a/datahugger/__init__.py
+++ b/datahugger/__init__.py
@@ -1,6 +1,7 @@
 from datahugger.api import get
 from datahugger.api import info
 from datahugger.api import parse_resource_identifier
+from datahugger.base import DownloadResult
 from datahugger.exceptions import DOIError
 from datahugger.exceptions import RepositoryNotSupportedError
 
@@ -8,6 +9,7 @@
     "get",
     "info",
     "parse_resource_identifier",
+    "DownloadResult",
    "DOIError",
     "RepositoryNotSupportedError",
 ]
diff --git a/datahugger/__main__.py b/datahugger/__main__.py
index 2a3554e..f833f0a 100644
--- a/datahugger/__main__.py
+++ b/datahugger/__main__.py
@@ -22,6 +22,20 @@ def print_green(s):
     print(f"\u001b[32m{s}\u001b[0m")
 
 
+class KVAppendAction(argparse.Action):
+    def __call__(self, parser, args, values, option_string=None):
+        try:
+            (k, v) = values[0].split("=", 1)
+        except ValueError as err:
+            raise argparse.ArgumentError(
+                self, f"'{values[0]}' is not a valid key-value pair (key=value)"
+            ) from err
+
+        d = getattr(args, self.dest) or {}
+        d[k] = v
+        setattr(args, self.dest, d)
+
+
 def main():
     parser = argparse.ArgumentParser(
         prog="datahugger",
@@ -69,6 +83,16 @@ def main():
         help="Python based log levels. Default: WARNING.",
     )
 
+    parser.add_argument(
+        "-p",
+        "--param",
+        nargs=1,
+        action=KVAppendAction,
+        dest="params",
+        help="Add key=value params to pass to the downloader. "
+        "May appear multiple times.",
+    )
+
     # version
     parser.add_argument(
         "-V",
@@ -91,6 +115,7 @@ def main():
             unzip=args.unzip,
             progress=args.progress,
             print_only=args.print_only,
+            params=args.params,
         )
 
     except DOIError as doi_err:
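Note: a minimal, runnable sketch of how the new -p/--param flag accumulates repeated
key=value pairs into a single dict, with the error handling from the action above
trimmed; the key=value pairs are illustrative.

    import argparse

    class KVAppendAction(argparse.Action):
        # Simplified version of the action added above: collect repeated
        # key=value flags into one dict on the parsed namespace.
        def __call__(self, parser, args, values, option_string=None):
            k, v = values[0].split("=", 1)
            d = getattr(args, self.dest) or {}
            d[k] = v
            setattr(args, self.dest, d)

    parser = argparse.ArgumentParser(prog="datahugger")
    parser.add_argument("-p", "--param", nargs=1, action=KVAppendAction, dest="params")

    # Repeated flags merge into one params dict for the downloader.
    args = parser.parse_args(["-p", "name=wikitext-2-v1", "-p", "split=train"])
    print(args.params)  # {'name': 'wikitext-2-v1', 'split': 'train'}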
diff --git a/datahugger/api.py b/datahugger/api.py
index f5bfba1..e130e72 100644
--- a/datahugger/api.py
+++ b/datahugger/api.py
@@ -57,7 +57,7 @@ def info(
     unzip=True,
     progress=True,
     print_only=False,
-    **kwargs,
+    params=None,
 ):
     """Get info on the content of the dataset.
 
@@ -65,8 +65,6 @@
     ---------
     resource: str, pathlib.Path
         The URL, DOI, or Handle of the dataset.
-    output_folder: str, pathlib.Path
-        The folder to download the dataset files to.
     max_file_size: int
         The maximum number of bytes for a single file. If exceeded,
         the file is skipped.
@@ -80,6 +78,8 @@
     print_only: bool
         Print the output of the dataset download without downloading
         the actual files (Dry run). Default: False.
+    params: dict
+        Extra parameters for the request.
 
     Returns
     -------
@@ -98,7 +98,7 @@ def info(
         unzip=unzip,
         progress=progress,
         print_only=print_only,
-        **kwargs,
+        params=params,
     )
 
 
@@ -110,7 +110,7 @@ def get(
     unzip=True,
     progress=True,
     print_only=False,
-    **kwargs,
+    params=None,
 ):
     """Get the content of repository.
 
@@ -136,6 +136,8 @@
     print_only: bool
         Print the output of the dataset download without downloading
         the actual files (Dry run). Default: False.
+    params: dict
+        Extra parameters for the request.
 
     Returns
     -------
@@ -151,7 +153,7 @@ def get(
         unzip=unzip,
         progress=progress,
         print_only=print_only,
-        **kwargs,
+        params=params,
     )
 
     return service.download(output_folder)
diff --git a/datahugger/base.py b/datahugger/base.py
index 6b52653..0566c8e 100644
--- a/datahugger/base.py
+++ b/datahugger/base.py
@@ -17,14 +17,18 @@
 from datahugger.utils import _is_url
 
 
-class DatasetResult:
+class DownloadResult:
     """Result class after downloading the dataset."""
 
+    def __init__(self, dataset, output_folder):
+        self.dataset = dataset
+        self.output_folder = output_folder
+
     def __str__(self):
         return f"<{self.__class__.__name__} n_files={len(self)} >"
 
     def __len__(self):
-        return len(self.files)
+        return len(self.dataset.files)
 
     def tree(self, **kwargs):
         """Return the folder tree.
@@ -49,6 +53,7 @@ def __init__(
         progress=True,
         unzip=True,
         print_only=False,
+        params=None,
     ):
         super().__init__()
         self.resource = resource
@@ -58,6 +63,7 @@ def __init__(
         self.progress = progress
         self.unzip = unzip
         self.print_only = print_only
+        self.params = params
 
     def _get_attr_attr(self, record, jsonp):
         try:
@@ -197,20 +203,6 @@ def _unpack_single_folder(self, zip_url, output_folder):
                 zip_info.filename = os.path.basename(zip_info.filename)
                 z.extract(zip_info, output_folder)
 
-    @property
-    def _params(self):
-        if hasattr(self, "__params"):
-            return self.__params
-
-        url = _get_url(self.resource)
-
-        # if isinstance(url, str) and _is_url(url):
-        self.__params = self._parse_url(url)
-        # else:
-        #     self.__params = {"record_id": url, "version": None}
-
-        return self.__params
-
     def _pre_files(self):
         pass
 
@@ -294,6 +286,23 @@ def _get_single_file(self, url, folder_name=None, base_url=None):
             }
         ]
 
+    @property
+    def _params(self):
+        """Params including url params."""
+        if hasattr(self, "__params"):
+            return self.__params
+
+        url = _get_url(self.resource)
+        url_params = self._parse_url(url)
+        if self.params:
+            new_params = self.params.copy()
+            new_params.update(url_params)
+            self.__params = new_params
+        else:
+            self.__params = url_params
+
+        return self.__params
+
     @property
     def files(self):
         if hasattr(self, "_files"):
@@ -301,9 +310,10 @@ def files(self):
 
         self._pre_files()
 
-        uri = urlparse(_get_url(self.resource))
+        url = _get_url(self.resource)
+        uri = urlparse(url)
         base_url = uri.scheme + "://" + uri.netloc
 
         if hasattr(self, "is_singleton") and self.is_singleton:
             self._files = self._get_single_file(
                 self.API_URL_META_SINGLE.format(
@@ -324,7 +334,6 @@ def files(self):
     def _get(
         self,
         output_folder: Union[Path, str],
-        **kwargs,
     ):
         if (
             len(self.files) == 1
@@ -347,26 +356,16 @@ def _get(
     def download(
         self,
         output_folder: Union[Path, str],
-        **kwargs,
     ):
-        """Download files for the given URL or record id.
+        """Download files.
 
         Arguments
         ---------
-        record_id_or_url: str
-            The identifier of the record or the url to the resource
-            to download.
         output_folder: str
             The folder to store the downloaded results.
-        version: str, int
-            The version of the dataset
         """
-
         Path(output_folder).mkdir(parents=True, exist_ok=True)
 
-        self._get(output_folder, **kwargs)
+        self._get(output_folder=output_folder)
 
-        # store the location of the last known output folder
-        self.output_folder = output_folder
-
-        return self
+        return DownloadResult(self, output_folder)
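Note: after the refactor above, download() returns a DownloadResult wrapping the
downloader and the output folder instead of returning the downloader itself. A
sketch of the resulting call pattern; the OSF record is taken from the test suite
and the output path is illustrative.

    import datahugger

    # get() builds the service downloader, downloads the files, and
    # returns DownloadResult(dataset, output_folder).
    result = datahugger.get("https://osf.io/wdzh5/", "data")

    print(len(result))           # number of files, via result.dataset.files
    print(result.output_folder)  # "data"
    print(result.dataset.files)  # file metadata from the repository API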
diff --git a/datahugger/services.py b/datahugger/services.py
index 05a5164..036fc20 100644
--- a/datahugger/services.py
+++ b/datahugger/services.py
@@ -10,11 +10,10 @@
 from jsonpath_ng import parse
 
 from datahugger.base import DatasetDownloader
-from datahugger.base import DatasetResult
 from datahugger.utils import _get_url
 
 
-class ZenodoDataset(DatasetDownloader, DatasetResult):
+class ZenodoDataset(DatasetDownloader):
     """Downloader for Zenodo repository.
 
     For Zenodo records, new versions have new identifiers.
@@ -42,7 +41,7 @@ def _get_attr_hash_type(self, record):
         return self._get_attr_attr(record, self.ATTR_HASH_JSONPATH).split(":")[0]
 
 
-class DataverseDataset(DatasetDownloader, DatasetResult):
+class DataverseDataset(DatasetDownloader):
     """Downloader for Dataverse repository."""
 
     REGEXP_ID = r"(?P<type>dataset|file)\.xhtml\?persistentId=(?P<record_id>.*)"
@@ -68,7 +67,7 @@ def _pre_files(self):
         self.is_singleton = True
 
 
-class FigShareDataset(DatasetDownloader, DatasetResult):
+class FigShareDataset(DatasetDownloader):
     """Downloader for FigShare repository."""
 
     REGEXP_ID = r"articles\/.*?\/.*?\/(?P<record_id>\d+)(?:\/(?P<version>\d+)|)"
@@ -96,7 +95,7 @@ class DjehutyDataset(FigShareDataset):
     API_URL = "https://data.4tu.nl/v2"
 
 
-class OSFDataset(DatasetDownloader, DatasetResult):
+class OSFDataset(DatasetDownloader):
     """Downloader for OSF repository."""
 
     REGEXP_ID = r"osf\.io\/(?P<record_id>.*)/"
@@ -122,7 +121,7 @@ class OSFDataset(DatasetDownloader, DatasetResult):
     ATTR_HASH_TYPE_VALUE = "sha256"
 
 
-class DataDryadDataset(DatasetDownloader, DatasetResult):
+class DataDryadDataset(DatasetDownloader):
     """Downloader for DataDryad repository."""
 
     REGEXP_ID = r"datadryad\.org[\:]*[43]{0,3}\/stash\/dataset\/doi:(?P<record_id>.*)"
@@ -180,7 +179,7 @@ def _get_attr_link(self, record):
         return "https://datadryad.org" + record["_links"]["stash:file-download"]["href"]
 
 
-class DataOneDataset(DatasetDownloader, DatasetResult):
+class DataOneDataset(DatasetDownloader):
     """Downloader for DataOne repositories."""
 
     REGEXP_ID = r"view/doi:(?P<record_id>.*)"
@@ -218,7 +217,7 @@ def files(self):
         return self._files
 
 
-class DSpaceDataset(DatasetDownloader, DatasetResult):
+class DSpaceDataset(DatasetDownloader):
     """Downloader for DSpaceDataset repositories."""
 
     REGEXP_ID = r"handle/(?P<record_id>\d+\/\d+)"
@@ -248,7 +247,7 @@ def _pre_files(self):
         self.API_URL_META = base_url + res.json()["link"] + "/bitstreams"
 
 
-class MendeleyDataset(DatasetDownloader, DatasetResult):
+class MendeleyDataset(DatasetDownloader):
     """Downloader for Mendeley repository."""
 
     REGEXP_ID = r"data\.mendeley\.com\/datasets\/(?P<record_id>[0-9a-z]+)(?:\/(?P<version>\d+)|)"  # noqa
@@ -280,7 +279,7 @@ def _pre_files(self):
         self.version = r_version.json()[-1]["version"]
 
 
-class GitHubDataset(DatasetDownloader, DatasetResult):
+class GitHubDataset(DatasetDownloader):
     """Downloader for GitHub repository."""
 
     API_URL = "https://github.com/"
@@ -296,10 +295,10 @@ def _get(self, output_folder: Union[Path, str], *args, **kwargs):
     @property
     def files(self):
         # at the moment, .files is not available for GitHub
-        raise AttributeError("'files' is not available for GitHub")
+        raise NotImplementedError("'files' is not available for GitHub")
 
 
-class HuggingFaceDataset(DatasetDownloader, DatasetResult):
+class HuggingFaceDataset(DatasetDownloader):
     """Downloader for Huggingface repository."""
 
     REGEXP_ID = r"huggingface.co/datasets/(?P<record_id>.*)"
@@ -307,7 +306,6 @@ class HuggingFaceDataset(DatasetDownloader):
     def _get(
         self,
         output_folder: Union[Path, str],
-        **kwargs,
     ):
         try:
             from datasets import load_dataset
@@ -317,15 +315,16 @@ def _get(
                 " or use 'pip install datahugger[all]'"
             ) from err
 
-        load_dataset(self._params["record_id"], cache_dir=output_folder, **kwargs)
+        params = self.params if self.params else {}
+        load_dataset(self._params["record_id"], cache_dir=output_folder, **params)
 
     @property
     def files(self):
         # at the moment, .files is not available for HuggingFace
-        raise AttributeError("'files' is not available for HuggingFace")
+        raise NotImplementedError("'files' is not available for HuggingFace")
 
 
-class ArXivDataset(DatasetDownloader, DatasetResult):
+class ArXivDataset(DatasetDownloader):
     """Downloader for ArXiv publication."""
 
     REGEXP_ID = r"https://arxiv\.org/abs/(?P<record_id>.*)"
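Note: the HuggingFace downloader now reads the params dict and forwards it as
keyword arguments to datasets.load_dataset, so dataset configs are passed
explicitly rather than through **kwargs. A sketch mirroring the new test below;
it requires the optional datasets dependency (pip install datahugger[all]) and
the output path is illustrative.

    import datahugger

    # "name" selects the dataset config and is forwarded to
    # datasets.load_dataset(..., name="wikitext-2-v1").
    datahugger.get(
        "https://huggingface.co/datasets/wikitext",
        "data",
        params={"name": "wikitext-2-v1"},
    )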
diff --git a/tests/test_repositories.py b/tests/test_repositories.py
index 56c2401..954aaab 100644
--- a/tests/test_repositories.py
+++ b/tests/test_repositories.py
@@ -95,4 +95,17 @@ def test_info_without_loading(tmpdir):
 
     dh_info = datahugger.info("https://osf.io/wdzh5/")
 
-    assert dh_get.files == dh_info.files
+    assert dh_get.dataset.files == dh_info.files
+
+
+def test_huggingface(tmpdir):
+    datahugger.get(
+        "https://huggingface.co/datasets/wikitext",
+        tmpdir,
+        params={"name": "wikitext-2-v1"},
+    )
+
+
+def test_huggingface_without_params(tmpdir):
+    with pytest.raises(ValueError):
+        datahugger.get("https://huggingface.co/datasets/wikitext", tmpdir)
diff --git a/tests/test_resolver.py b/tests/test_resolver.py
index 2a13898..40b008f 100644
--- a/tests/test_resolver.py
+++ b/tests/test_resolver.py
@@ -27,8 +27,10 @@ def test_resolve_service_via_doi_handle(tmpdir):
     doi = DOI.parse("10.34894/FXUGHW")
     doi.resolve()
 
+    r = datahugger.get(doi, tmpdir)
+
     assert isinstance(doi, DOI)
-    assert isinstance(datahugger.get(doi, tmpdir), DataverseDataset)
+    assert isinstance(r.dataset, DataverseDataset)
 
 
 def test_get_doi_metadata_cls(tmpdir):
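Note: info() shares the same params plumbing but skips the download, so the
metadata-only path exercised by test_info_without_loading can be written as
below; judging from that test, info() returns the downloader itself, and the
OSF record is taken from the test suite.

    import datahugger

    # info() builds the downloader but writes no files to disk.
    dataset = datahugger.info("https://osf.io/wdzh5/")
    print(dataset.files)  # file metadata only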