Add support for params and refactor download response (#64)
Co-authored-by: Vetra <[email protected]>
J535D165 and SexyVetra authored Sep 18, 2023
1 parent 2af4535 commit 438051e
Showing 8 changed files with 100 additions and 56 deletions.
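In short: datahugger.get() and datahugger.info() now accept an explicit params dict in place of the old **kwargs pass-through, the command line gains a repeatable -p/--param key=value option, DatasetDownloader.download() returns a new DownloadResult object instead of the downloader itself, and the service classes drop the old DatasetResult mixin.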
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -22,7 +22,7 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        python -m pip install .
+        python -m pip install .[all]
         python -m pip install pytest pytest-xdist
     - name: Test with pytest
       run: |
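Note on the workflow change: the [all] extra installs the package's optional dependencies (for example the datasets package that the HuggingFace downloader imports later in this diff), so the test suite can exercise those paths. When reproducing locally, note that some shells require quoting the brackets: pip install '.[all]'.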
2 changes: 2 additions & 0 deletions datahugger/__init__.py
@@ -1,13 +1,15 @@
 from datahugger.api import get
 from datahugger.api import info
 from datahugger.api import parse_resource_identifier
+from datahugger.base import DownloadResult
 from datahugger.exceptions import DOIError
 from datahugger.exceptions import RepositoryNotSupportedError
 
 __all__ = [
     "get",
     "info",
     "parse_resource_identifier",
+    "DownloadResult",
     "DOIError",
     "RepositoryNotSupportedError",
 ]
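Re-exporting DownloadResult at the package root lets callers write from datahugger import DownloadResult, e.g. for isinstance checks on the value that get() now returns.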
25 changes: 25 additions & 0 deletions datahugger/__main__.py
@@ -22,6 +22,20 @@ def print_green(s):
     print(f"\u001b[32m{s}\u001b[0m")
 
 
+class KVAppendAction(argparse.Action):
+    def __call__(self, parser, args, values, option_string=None):
+        try:
+            (k, v) = values[0].split("=", 2)
+        except ValueError as err:
+            raise argparse.ArgumentError(
+                self, f"'{values[0]}' is not a valid key-value pair (key=value)"
+            ) from err
+
+        d = getattr(args, self.dest) or {}
+        d[k] = v
+        setattr(args, self.dest, d)
+
+
 def main():
     parser = argparse.ArgumentParser(
         prog="datahugger",
@@ -69,6 +83,16 @@ def main():
         help="Python based log levels. Default: WARNING.",
     )
 
+    parser.add_argument(
+        "-p",
+        "--param",
+        nargs=1,
+        action=KVAppendAction,
+        dest="params",
+        help="Add key=value params to pass to the downloader. "
+        "May appear multiple times.",
+    )
+
     # version
     parser.add_argument(
         "-V",
@@ -91,6 +115,7 @@ def main():
             unzip=args.unzip,
             progress=args.progress,
             print_only=args.print_only,
+            params=args.params,
         )
 
     except DOIError as doi_err:
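Usage sketch, not part of the diff: with the new option, a run might look like datahugger -p name=mrpc -p split=train <url> <folder> (the keys are illustrative). The snippet below mirrors the KVAppendAction added above to show how repeated pairs fold into one dict and how a malformed pair is rejected:

    import argparse

    # Mirror of the KVAppendAction class from this commit.
    class KVAppendAction(argparse.Action):
        def __call__(self, parser, args, values, option_string=None):
            try:
                (k, v) = values[0].split("=", 2)
            except ValueError as err:
                raise argparse.ArgumentError(
                    self, f"'{values[0]}' is not a valid key-value pair (key=value)"
                ) from err
            d = getattr(args, self.dest) or {}
            d[k] = v
            setattr(args, self.dest, d)

    parser = argparse.ArgumentParser(prog="datahugger")
    parser.add_argument("-p", "--param", nargs=1, action=KVAppendAction, dest="params")

    print(parser.parse_args(["-p", "name=mrpc", "-p", "split=train"]).params)
    # {'name': 'mrpc', 'split': 'train'}

    parser.parse_args(["-p", "oops"])  # exits: 'oops' is not a valid key-value pair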
14 changes: 8 additions & 6 deletions datahugger/api.py
@@ -57,16 +57,14 @@ def info(
     unzip=True,
     progress=True,
     print_only=False,
-    **kwargs,
+    params=None,
 ):
     """Get info on the content of the dataset.
 
     Arguments
     ---------
     resource: str, pathlib.Path
         The URL, DOI, or Handle of the dataset.
-    output_folder: str, pathlib.Path
-        The folder to download the dataset files to.
     max_file_size: int
         The maximum number of bytes for a single file. If exceeded,
         the file is skipped.
@@ -80,6 +78,8 @@ def info(
     print_only: bool
         Print the output of the dataset download without downloading
         the actual files (Dry run). Default: False.
+    params: dict
+        Extra parameters for the request.
 
     Returns
     -------
@@ -98,7 +98,7 @@ def info(
         unzip=unzip,
         progress=progress,
         print_only=print_only,
-        **kwargs,
+        params=params,
     )
 
 
@@ -110,7 +110,7 @@ def get(
     unzip=True,
     progress=True,
     print_only=False,
-    **kwargs,
+    params=None,
 ):
     """Get the content of repository.
@@ -136,6 +136,8 @@ def get(
     print_only: bool
         Print the output of the dataset download without downloading
         the actual files (Dry run). Default: False.
+    params: dict
+        Extra parameters for the request.
 
     Returns
     -------
@@ -151,7 +153,7 @@ def get(
         unzip=unzip,
         progress=progress,
         print_only=print_only,
-        **kwargs,
+        params=params,
     )
 
     return service.download(output_folder)
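A usage sketch for the new keyword, with an illustrative Dryad DOI and a hypothetical parameter key (whether a given downloader consumes a key depends on the service):

    import datahugger

    # Extra parameters now travel as an explicit dict instead of **kwargs.
    result = datahugger.get(
        "10.5061/dryad.x3ffbg7m8",  # resource: URL, DOI, or Handle (illustrative)
        "data",                     # output folder
        params={"version": "2"},    # hypothetical extra parameter
    )
    print(result)  # prints something like <DownloadResult n_files=3 >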
63 changes: 32 additions & 31 deletions datahugger/base.py
@@ -17,14 +17,18 @@
 from datahugger.utils import _is_url
 
 
-class DatasetResult:
+class DownloadResult:
     """Result class after downloading the dataset."""
 
     def __init__(self, dataset, output_folder):
         self.dataset = dataset
         self.output_folder = output_folder
 
+    def __str__(self):
+        return f"<{self.__class__.__name__} n_files={len(self)} >"
+
     def __len__(self):
-        return len(self.files)
+        return len(self.dataset.files)
 
     def tree(self, **kwargs):
         """Return the folder tree.
@@ -49,6 +53,7 @@ def __init__(
         progress=True,
         unzip=True,
         print_only=False,
+        params=None,
     ):
         super().__init__()
         self.resource = resource
@@ -58,6 +63,7 @@ def __init__(
         self.progress = progress
         self.unzip = unzip
         self.print_only = print_only
+        self.params = params
 
     def _get_attr_attr(self, record, jsonp):
         try:
@@ -197,20 +203,6 @@ def _unpack_single_folder(self, zip_url, output_folder):
                 zip_info.filename = os.path.basename(zip_info.filename)
                 z.extract(zip_info, output_folder)
 
-    @property
-    def _params(self):
-        if hasattr(self, "__params"):
-            return self.__params
-
-        url = _get_url(self.resource)
-
-        # if isinstance(url, str) and _is_url(url):
-        self.__params = self._parse_url(url)
-        # else:
-        #     self.__params = {"record_id": url, "version": None}
-
-        return self.__params
-
     def _pre_files(self):
         pass
 
@@ -294,16 +286,36 @@ def _get_single_file(self, url, folder_name=None, base_url=None):
             }
         ]
 
+    @property
+    def _params(self):
+        """Params including url params."""
+        if hasattr(self, "__params"):
+            return self.__params
+
+        url = _get_url(self.resource)
+        url_params = self._parse_url(url)
+        if self.params:
+            new_params = self.params.copy()
+            new_params.update(url_params)
+            self.__params = new_params
+        else:
+            self.__params = url_params
+
+        return self.__params
+
     @property
     def files(self):
         if hasattr(self, "_files"):
             return self._files
 
         self._pre_files()
 
-        uri = urlparse(_get_url(self.resource))
+        url = _get_url(self.resource)
+        uri = urlparse(url)
         base_url = uri.scheme + "://" + uri.netloc
 
+        print(self._params)
+
         if hasattr(self, "is_singleton") and self.is_singleton:
             self._files = self._get_single_file(
                 self.API_URL_META_SINGLE.format(
@@ -324,7 +336,6 @@ def files(self):
     def _get(
         self,
         output_folder: Union[Path, str],
-        **kwargs,
     ):
         if (
             len(self.files) == 1
@@ -347,26 +358,16 @@ def _get(
     def download(
         self,
         output_folder: Union[Path, str],
-        **kwargs,
     ):
-        """Download files for the given URL or record id.
+        """Download files.
 
         Arguments
         ---------
-        record_id_or_url: str
-            The identifier of the record or the url to the resource
-            to download.
         output_folder: str
             The folder to store the downloaded results.
-        version: str, int
-            The version of the dataset
         """
         Path(output_folder).mkdir(parents=True, exist_ok=True)
 
-        self._get(output_folder, **kwargs)
-
-        # store the location of the last known output folder
-        self.output_folder = output_folder
+        self._get(output_folder=output_folder)
 
-        return self
+        return DownloadResult(self, output_folder)
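Two details of the refactor above, sketched with plain dicts (all values illustrative): in the new _params property the URL-derived parameters are applied last, so they win over caller-supplied keys on a collision, and download() now hands back a DownloadResult that delegates to the downloader it wraps:

    # Merge order in _params, reduced to its essence:
    user_params = {"version": "1", "token": "abc"}         # caller-supplied (illustrative)
    url_params = {"record_id": "6614829", "version": "2"}  # parsed from the URL
    merged = user_params.copy()
    merged.update(url_params)  # URL-derived values win on collision
    print(merged)  # {'version': '2', 'token': 'abc', 'record_id': '6614829'}

    # What download() now returns:
    # result = downloader.download("data")
    # str(result)          -> '<DownloadResult n_files=N >' with N = len(result)
    # len(result)          -> len(result.dataset.files)
    # result.output_folder -> 'data'
    # result.tree()        -> folder tree, via the method at the top of this file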
31 changes: 15 additions & 16 deletions datahugger/services.py
@@ -10,11 +10,10 @@
 from jsonpath_ng import parse
 
 from datahugger.base import DatasetDownloader
-from datahugger.base import DatasetResult
 from datahugger.utils import _get_url
 
 
-class ZenodoDataset(DatasetDownloader, DatasetResult):
+class ZenodoDataset(DatasetDownloader):
     """Downloader for Zenodo repository.
 
     For Zenodo records, new versions have new identifiers.
@@ -42,7 +41,7 @@ def _get_attr_hash_type(self, record):
         return self._get_attr_attr(record, self.ATTR_HASH_JSONPATH).split(":")[0]
 
 
-class DataverseDataset(DatasetDownloader, DatasetResult):
+class DataverseDataset(DatasetDownloader):
     """Downloader for Dataverse repository."""
 
     REGEXP_ID = r"(?P<type>dataset|file)\.xhtml\?persistentId=(?P<record_id>.*)"
@@ -68,7 +67,7 @@ def _pre_files(self):
         self.is_singleton = True
 
 
-class FigShareDataset(DatasetDownloader, DatasetResult):
+class FigShareDataset(DatasetDownloader):
     """Downloader for FigShare repository."""
 
     REGEXP_ID = r"articles\/.*?\/.*?\/(?P<record_id>\d+)(?:\/(?P<version>\d+)|)"
REGEXP_ID = r"articles\/.*?\/.*?\/(?P<record_id>\d+)(?:\/(?P<version>\d+)|)"
Expand Down Expand Up @@ -96,7 +95,7 @@ class DjehutyDataset(FigShareDataset):
API_URL = "https://data.4tu.nl/v2"


class OSFDataset(DatasetDownloader, DatasetResult):
class OSFDataset(DatasetDownloader):
"""Downloader for OSF repository."""

REGEXP_ID = r"osf\.io\/(?P<record_id>.*)/"
Expand All @@ -122,7 +121,7 @@ class OSFDataset(DatasetDownloader, DatasetResult):
ATTR_HASH_TYPE_VALUE = "sha256"


class DataDryadDataset(DatasetDownloader, DatasetResult):
class DataDryadDataset(DatasetDownloader):
"""Downloader for DataDryad repository."""

REGEXP_ID = r"datadryad\.org[\:]*[43]{0,3}\/stash\/dataset\/doi:(?P<record_id>.*)"
Expand Down Expand Up @@ -180,7 +179,7 @@ def _get_attr_link(self, record):
return "https://datadryad.org" + record["_links"]["stash:file-download"]["href"]


class DataOneDataset(DatasetDownloader, DatasetResult):
class DataOneDataset(DatasetDownloader):
"""Downloader for DataOne repositories."""

REGEXP_ID = r"view/doi:(?P<record_id>.*)"
Expand Down Expand Up @@ -218,7 +217,7 @@ def files(self):
return self._files


class DSpaceDataset(DatasetDownloader, DatasetResult):
class DSpaceDataset(DatasetDownloader):
"""Downloader for DSpaceDataset repositories."""

REGEXP_ID = r"handle/(?P<record_id>\d+\/\d+)"
Expand Down Expand Up @@ -248,7 +247,7 @@ def _pre_files(self):
self.API_URL_META = base_url + res.json()["link"] + "/bitstreams"


class MendeleyDataset(DatasetDownloader, DatasetResult):
class MendeleyDataset(DatasetDownloader):
"""Downloader for Mendeley repository."""

REGEXP_ID = r"data\.mendeley\.com\/datasets\/(?P<record_id>[0-9a-z]+)(?:\/(?P<version>\d+)|)" # noqa
Expand Down Expand Up @@ -280,7 +279,7 @@ def _pre_files(self):
self.version = r_version.json()[-1]["version"]


class GitHubDataset(DatasetDownloader, DatasetResult):
class GitHubDataset(DatasetDownloader):
"""Downloader for GitHub repository."""

API_URL = "https://github.com/"
Expand All @@ -296,18 +295,17 @@ def _get(self, output_folder: Union[Path, str], *args, **kwargs):
@property
def files(self):
# at the moment, .files is not available for GitHub
raise AttributeError("'files' is not available for GitHub")
raise NotImplementedError("'files' is not available for GitHub")


class HuggingFaceDataset(DatasetDownloader, DatasetResult):
class HuggingFaceDataset(DatasetDownloader):
"""Downloader for Huggingface repository."""

REGEXP_ID = r"huggingface.co/datasets/(?P<record_id>.*)"

def _get(
self,
output_folder: Union[Path, str],
**kwargs,
):
try:
from datasets import load_dataset
@@ -317,15 +315,16 @@ def _get(
                 " or use 'pip install datahugger[all]'"
             ) from err
 
-        load_dataset(self._params["record_id"], cache_dir=output_folder, **kwargs)
+        params = self.params if self.params else {}
+        load_dataset(self._params["record_id"], cache_dir=output_folder, **params)
 
     @property
     def files(self):
         # at the moment, .files is not available for HuggingFace
-        raise AttributeError("'files' is not available for HuggingFace")
+        raise NotImplementedError("'files' is not available for HuggingFace")
 
 
-class ArXivDataset(DatasetDownloader, DatasetResult):
+class ArXivDataset(DatasetDownloader):
     """Downloader for ArXiv publication."""
 
     REGEXP_ID = r"https://arxiv\.org/abs/(?P<record_id>.*)"
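For HuggingFace, the user-supplied params dict is forwarded to datasets.load_dataset as keyword arguments, so a CLI pair such as -p name=mrpc becomes a load_dataset kwarg (CLI values always arrive as strings). A sketch under assumptions; the repository id is illustrative and name is the standard load_dataset argument selecting a dataset configuration:

    import datahugger

    # Roughly equivalent to:
    #     datasets.load_dataset("glue", cache_dir="data", name="mrpc")
    datahugger.get(
        "https://huggingface.co/datasets/glue",
        "data",
        params={"name": "mrpc"},
    )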