Add support for a regex pattern to filter files by name #92

Open · wants to merge 6 commits into main
Changes from all commits
8 changes: 8 additions & 0 deletions datahugger/__main__.py
@@ -58,6 +58,13 @@ def main():
help="Skip files larger than this size. Might not work for all services.",
)

parser.add_argument(
"--filter-files",
default=None,
type=str,
help="A regex pattern to filter files by name.",
)

parser.add_argument(
"-f", "--force-download", dest="force_download", action="store_true"
)
@@ -113,6 +120,7 @@ def main():
args.url_or_doi,
args.output_dir,
max_file_size=args.max_file_size,
filter_files=args.filter_files,
force_download=args.force_download,
unzip=args.unzip,
checksum=args.checksum,
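With the new flag in place, a name filter can be supplied directly on the command line. A minimal sketch, assuming the package is run as a module and that the positional arguments are the resource URL followed by the output directory; the record URL and pattern come from the test added in this pull request, the output directory name is illustrative:

python -m datahugger https://zenodo.org/records/6614829 output --filter-files ".*\.m"
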
8 changes: 8 additions & 0 deletions datahugger/api.py
@@ -53,6 +53,7 @@ def parse_resource_identifier(resource, resolve=True):
def info(
resource,
max_file_size=None,
filter_files=None,
force_download=False,
unzip=True,
checksum=False,
@@ -69,6 +70,8 @@ def info(
max_file_size: int
The maximum number of bytes for a single file. If exceeded,
the file is skipped.
filter_files: str
A regex pattern to filter files by name.
force_download: bool
Force the download of the dataset even if there are already
files in the destination folder. Default: False.
@@ -97,6 +100,7 @@ def info(
return service_class(
handle,
max_file_size=max_file_size,
filter_files=filter_files,
force_download=force_download,
unzip=unzip,
checksum=checksum,
@@ -110,6 +114,7 @@ def get(
resource,
output_folder,
max_file_size=None,
filter_files=None,
force_download=False,
unzip=True,
checksum=False,
@@ -131,6 +136,8 @@ def get(
max_file_size: int
The maximum number of bytes for a single file. If exceeded,
the file is skipped.
filter_files: str
A regex pattern to filter files by name.
force_download: bool
Force the download of the dataset even if there are already
files in the destination folder. Default: False.
@@ -156,6 +163,7 @@ def get(
service = info(
resource,
max_file_size=max_file_size,
filter_files=filter_files,
force_download=force_download,
unzip=unzip,
checksum=checksum,
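Through the Python API, the same keyword flows from datahugger.get into info and on to the service class. A minimal usage sketch, reusing the Zenodo record and regex from the test added in this pull request; the output folder name is illustrative:

import datahugger

# Only files whose names match the regex are downloaded; all others are skipped.
datahugger.get(
    "https://zenodo.org/records/6614829",
    "output",
    filter_files=r".*\.m",
)
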
8 changes: 8 additions & 0 deletions datahugger/base.py
@@ -51,6 +51,7 @@ def __init__(
self,
resource,
max_file_size=None,
filter_files=None,
force_download=False,
progress=True,
unzip=True,
@@ -61,6 +62,7 @@
super().__init__()
self.resource = resource
self.max_file_size = max_file_size
self.filter_files = filter_files
self.force_download = force_download
self.progress = progress
self.unzip = unzip
@@ -158,6 +160,12 @@ def download_file(
print(f"{_format_filename(file_name)}: SKIPPED")
return

if self.filter_files and not re.match(self.filter_files, file_name):
logging.info(f"Skipping file by filter {file_link}")
if self.progress:
print(f"{_format_filename(file_name)}: SKIPPED")
return

if not self.print_only:
logging.info(f"Downloading file {file_link}")
res = requests.get(file_link, stream=True)
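Note that the check above uses re.match, which anchors the pattern at the start of the file name: a bare extension pattern such as r"\.m$" would not match "quasiperiod.m", while r".*\.m" does (and also matches names like "audio.mp3"). A small sketch of the behaviour, with illustrative file names:

import re

pattern = r".*\.m"
for name in ["quasiperiod.m", "readme.txt"]:
    # download_file skips any file whose name does not match the pattern
    kept = bool(re.match(pattern, name))
    print(name, "kept" if kept else "skipped")
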
10 changes: 10 additions & 0 deletions tests/test_repositories_plus.py
@@ -1,3 +1,5 @@
from pathlib import Path

import pytest

import datahugger
@@ -14,3 +16,11 @@ def test_huggingface(tmpdir):
def test_huggingface_without_params(tmpdir):
with pytest.raises(ValueError):
datahugger.get("https://huggingface.co/datasets/wikitext", tmpdir)


def test_filter(tmpdir):
datahugger.get("https://zenodo.org/records/6614829", tmpdir, filter_files=r".*\.m")

files = [file for file in Path(tmpdir).iterdir()]
assert len(files) == 1
assert files[0].name == "quasiperiod.m"