Skip to content

Commit

Permalink
fix(signed url): pass version id down to fsspec via path
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein committed Dec 29, 2024
1 parent 1bd7f8b commit 84885bc
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ tests = [
"pytest-sugar>=0.9.6",
"pytest-cov>=4.1.0",
"pytest-mock>=3.12.0",
"pytest-servers[all]>=0.5.8",
"pytest-servers[all]>=0.5.9",
"pytest-benchmark[histogram]",
"pytest-xdist>=3.3.1",
"virtualenv",
Expand Down
10 changes: 8 additions & 2 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,10 +1236,16 @@ def ls_dataset_rows(

return q.to_db_records()

def signed_url(self, source: str, path: str, client_config=None) -> str:
def signed_url(
self,
source: str,
path: str,
version_id: Optional[str] = None,
client_config=None,
) -> str:
client_config = client_config or self.client_config
client = Client.get_client(source, self.cache, **client_config)
return client.url(path)
return client.url(path, version_id=version_id)

def export_dataset_table(
self,
Expand Down
22 changes: 21 additions & 1 deletion src/datachain/client/azure.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any
from typing import Any, Optional
from urllib.parse import parse_qs, urlsplit, urlunsplit

from adlfs import AzureBlobFileSystem
from tqdm import tqdm
Expand All @@ -25,6 +26,16 @@ def info_to_file(self, v: dict[str, Any], path: str) -> File:
size=v.get("size", ""),
)

def url(self, path: str, expires: int = 3600, **kwargs) -> str:
"""
Generate a signed URL for the given path.
"""
version_id = kwargs.pop("version_id", None)
result = self.fs.sign(
self.get_full_path(path, version_id), expiration=expires, **kwargs
)
return result + (f"&versionid={version_id}" if version_id else "")

async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> None:
prefix = start_prefix
if prefix:
Expand Down Expand Up @@ -57,4 +68,13 @@ async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> Non
finally:
result_queue.put_nowait(None)

@classmethod
def version_path(cls, path: str, version_id: Optional[str]) -> str:
parts = list(urlsplit(path))
query = parse_qs(parts[3])
if "versionid" in query:
raise ValueError("path already includes a version query")

Check warning on line 76 in src/datachain/client/azure.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/client/azure.py#L76

Added line #L76 was not covered by tests
parts[3] = f"versionid={version_id}" if version_id else ""
return urlunsplit(parts)

_fetch_default = _fetch_flat
6 changes: 5 additions & 1 deletion src/datachain/client/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,11 @@ def fs(self) -> "AbstractFileSystem":
return self._fs

def url(self, path: str, expires: int = 3600, **kwargs) -> str:
return self.fs.sign(self.get_full_path(path), expiration=expires, **kwargs)
return self.fs.sign(
self.get_full_path(path, kwargs.pop("version_id", None)),
expiration=expires,
**kwargs,
)

async def get_current_etag(self, file: "File") -> str:
kwargs = {}
Expand Down
30 changes: 7 additions & 23 deletions src/datachain/client/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from collections.abc import Iterable
from datetime import datetime
from typing import Any, Optional, cast
from urllib.parse import urlsplit

from dateutil.parser import isoparse
from gcsfs import GCSFileSystem
Expand Down Expand Up @@ -39,9 +38,13 @@ def url(self, path: str, expires: int = 3600, **kwargs) -> str:
If the client is anonymous, a public URL is returned instead
(see https://cloud.google.com/storage/docs/access-public-data#api-link).
"""
version_id = kwargs.pop("version_id", None)
if self.fs.storage_options.get("token") == "anon":
return f"https://storage.googleapis.com/{self.name}/{path}"
return self.fs.sign(self.get_full_path(path), expiration=expires, **kwargs)
query = f"?generation={version_id}" if version_id else ""
return f"https://storage.googleapis.com/{self.name}/{path}{query}"
return self.fs.sign(
self.get_full_path(path, version_id), expiration=expires, **kwargs
)

@staticmethod
def parse_timestamp(timestamp: str) -> datetime:
Expand Down Expand Up @@ -133,25 +136,6 @@ def info_to_file(self, v: dict[str, Any], path: str) -> File:
size=v.get("size", ""),
)

@classmethod
def _split_version(cls, path: str) -> tuple[str, Optional[str]]:
parts = list(urlsplit(path))
scheme = parts[0]
parts = GCSFileSystem._split_path( # pylint: disable=protected-access
path, version_aware=True
)
bucket, key, generation = parts
scheme = f"{scheme}://" if scheme else ""
return f"{scheme}{bucket}/{key}", generation

@classmethod
def _join_version(cls, path: str, version_id: Optional[str]) -> str:
path, path_version = cls._split_version(path)
if path_version:
raise ValueError("path already includes an object generation")
return f"{path}#{version_id}" if version_id else path

@classmethod
def version_path(cls, path: str, version_id: Optional[str]) -> str:
path, _ = cls._split_version(path)
return cls._join_version(path, version_id)
return f"{path}#{version_id}" if version_id else path
10 changes: 10 additions & 0 deletions src/datachain/client/s3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from typing import Any, Optional, cast
from urllib.parse import parse_qs, urlsplit, urlunsplit

from botocore.exceptions import NoCredentialsError
from s3fs import S3FileSystem
Expand Down Expand Up @@ -121,6 +122,15 @@ def _entry_from_boto(self, v, bucket, versions=False) -> File:
size=v["Size"],
)

@classmethod
def version_path(cls, path: str, version_id: Optional[str]) -> str:
parts = list(urlsplit(path))
query = parse_qs(parts[3])
if "versionId" in query:
raise ValueError("path already includes a version query")
parts[3] = f"versionId={version_id}" if version_id else ""
return urlunsplit(parts)

async def _fetch_dir(
self,
prefix,
Expand Down
9 changes: 9 additions & 0 deletions tests/func/fake-service-account-credentials.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"type": "service_account",
"project_id": "gcsfs",
"private_key_id": "84e3fd6d7101ec632e7348e8940b2aca71133e71",
"private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDAJWz1KlBu2jRE\nlUahHKuJes34hj4pr8ADhgejpAguBBrubXVvSro7aSSbvyDC/GIcyDQ8Q33YK/kT\nufQvCez7iIACbtP53o6WjcrIAP+l8z9RUL9so+sBCaVRZzh74+cEMfWIbc3ACBB5\nU2BPBWQFtr3Qtbe8TUJ+liNcLb8I2JznfydHvl9cn0/50HeOB99Xho5JAY75aE0Y\nT+/aMTFlr/kUbekLRRi4pyE+uOA/ei5RmfwzqO366YLMtEC2DaHwTqSuxBWnbtTW\nu/OvYpmPHazd6own2zJLQ0Elnm5WC/d9YmxhHi/8pJFkkbVf/2CYWEBbmBI3ZOx3\n/nHQwcIPAgMBAAECggEAUztC/dYE/me10WmKLTrykTxpYTihT8RqG/ygbYGd63Tq\nx5IRlxJbJmYOrgp2IhBaXZZZjis8JXoyzBk2TXPyvChuLt+cIfYGdO/ZwZYxJ0z9\nhfdA3EoK/6mSe3cHcB8SEG6lqaHKyN6VaEC2DLTMlW8JvREiFEaxQY0+puzH/ge4\n2EypCP4pvlveH78EIIipPgWcJYGpv0bv8KErECuVHRjJv6vZqUjQdcIi73mCz/5u\nnQqLY8j9lOuCr9vBis7DZIyY2tn4vfqcqxfH9wuIFXnzIQW6Wyg0+bBQydHg1kJ2\nFOszfkBVxZ6LpcHGB4CV4c5z7Me2cMReXQz6VsyoLQKBgQD9v92rHZYDBy4/vGxx\nbpfUkAlcCGW8GXu+qsdmyhZdjSdjDLY6lav+6UoHIJgmnA7LsKPFgnEDrdn78KBb\n3wno3VHfozL5kF887q9hC/+UurwScCKIw5QkmWtsStVgjr6wPmAu6rspMz5xNjaa\nSU4YzlNcbBUUXUawhXytWPR+OwKBgQDB2bDCD00R2yfYFdjAKapqenOtMvrnihUi\nW9Se7Yizme7s25fDxF5CBPpOdKPU2EZUlqBC/5182oMUP/xYUOHJkuUhbYcvU0qr\n+BQewLwr6rs+O1QPTh/6e70SUFR+YJLaAHkDc6fvcdjtl+Zx/p02Zj+UiW3/D4Jj\nc0EqVr4qPQKBgQCbJx3a6xQ2dcWJoySLlxuvFQMkCt5pzQsk4jdaWmaifRSAM92Y\npLut+ecRxJRDx1gko7T/p2qC3WJT8iWbBx2ADRNqstcQUX5qO2dw5202+5bTj00O\nYsfKOSS96mPdzmo6SWl2RoB6CKM9hfCNFhVyhXXjJRMeiIoYlQZO1/1m0QKBgCzz\nat6FJ8z1MdcUsc9VmhPY00wdXzsjtOTjwHkeAa4MCvBXt2iI94Z9mwFoYLkxcZWZ\n3A3NMlrKXMzsTXq5PrI8Yu+Oc2OQ/+bCvv+ml7vjUYoLveFSr22pFd3STNWFVWhB\n5c3cGtwWXUQzDhfu/8umiCXMfHpBwW2IQ1srBCvNAoGATcC3oCFBC/HdGxdeJC5C\n59EoFvKdZsAdc2I5GS/DtZ1Wo9sXqubCaiUDz+4yty+ssHIZ1ikFr8rWfL6KFEs2\niTe+kgM/9FLFtftf1WDpbfIOumbz/6CiGLqsGNlO3ZaU0kYJ041SZ8RleTOYa0zO\noSTLwBo3vje+aflytEwS8SI=\n-----END PRIVATE KEY-----",
"client_email": "[email protected]",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token"
}
67 changes: 67 additions & 0 deletions tests/func/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from urllib.parse import urlparse

import pytest
import requests
import yaml
from fsspec.implementations.local import LocalFileSystem

Expand Down Expand Up @@ -993,3 +994,69 @@ def test_garbage_collect(cloud_test_catalog, from_cli, capsys):
else:
catalog.cleanup_tables(temp_tables)
assert catalog.get_temp_table_names() == []


@pytest.mark.parametrize("tree", [{"test-signed-file": "original"}], indirect=True)
@pytest.mark.parametrize(
"cloud_type, version_aware",
(["s3", False], ["azure", False], ["gs", False]),
indirect=True,
)
def test_signed_url(cloud_test_catalog, monkeypatch):
monkeypatch.setenv(
"GOOGLE_APPLICATION_CREDENTIALS",
os.path.dirname(__file__) + "/fake-service-account-credentials.json",
)

signed_url = cloud_test_catalog.catalog.signed_url(
cloud_test_catalog.src_uri, "test-signed-file"
)
content = requests.get(signed_url, timeout=10).text
assert content == "original"


@pytest.mark.parametrize(
"tree", [{"test-signed-file-versioned": "original"}], indirect=True
)
@pytest.mark.parametrize(
"cloud_type, version_aware",
(["s3", True], ["azure", True], ["gs", True]),
indirect=True,
)
def test_signed_url_versioned(cloud_test_catalog, monkeypatch):
file_uri = f"{cloud_test_catalog.src_uri}/test-signed-file-versioned"
file_original = next(
DataChain.from_storage(file_uri, session=cloud_test_catalog.session).collect(
"file"
)
)

(cloud_test_catalog.src / "test-signed-file-versioned").write_text("modified")

file_modified = next(
DataChain.from_storage(file_uri, session=cloud_test_catalog.session).collect(
"file"
)
)
monkeypatch.setenv(
"GOOGLE_APPLICATION_CREDENTIALS",
os.path.dirname(__file__) + "/fake-service-account-credentials.json",
)

signed_url = cloud_test_catalog.catalog.signed_url(
cloud_test_catalog.src_uri,
"test-signed-file-versioned",
version_id=file_original.version,
)

content = requests.get(signed_url, timeout=10).text
assert content == "original"

signed_url = cloud_test_catalog.catalog.signed_url(
cloud_test_catalog.src_uri,
"test-signed-file-versioned",
version_id=file_modified.version,
)

content = requests.get(signed_url, timeout=10).text
assert content == "modified"
8 changes: 8 additions & 0 deletions tests/unit/test_client_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,11 @@
def test_anon_url():
client = Client.get_client("gs://foo", None, anon=True)
assert client.url("bar") == "https://storage.googleapis.com/foo/bar"


def test_anon_versioned_url():
client = Client.get_client("gs://foo", None, anon=True)
assert (
client.url("bar", version_id="1234566")
== "https://storage.googleapis.com/foo/bar?generation=1234566"
)
6 changes: 6 additions & 0 deletions tests/unit/test_client_s3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from datachain.client.s3 import ClientS3
from datachain.node import DirType, Node
from datachain.nodes_thread_pool import NodeChunk

Expand Down Expand Up @@ -77,3 +78,8 @@ def test_node_bucket_full_split(nodes):
assert len(bkt[1]) == 1
assert len(bkt[2]) == 1
assert len(bkt[3]) == 1


def test_version_path_already_has_version():
with pytest.raises(ValueError):
ClientS3.version_path("s3://foo/bar?versionId=123", "456")

0 comments on commit 84885bc

Please sign in to comment.