Skip to content

Commit

Permalink
fix(signed url): pass version id down to fsspec via path
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein committed Dec 28, 2024
1 parent 1bd7f8b commit 1402c6a
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 12 deletions.
10 changes: 8 additions & 2 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,10 +1236,16 @@ def ls_dataset_rows(

return q.to_db_records()

def signed_url(self, source: str, path: str, client_config=None) -> str:
def signed_url(
self,
source: str,
path: str,
version_id: Optional[str] = None,
client_config=None,
) -> str:
client_config = client_config or self.client_config
client = Client.get_client(source, self.cache, **client_config)
return client.url(path)
return client.url(path, version_id=version_id)

Check warning on line 1248 in src/datachain/catalog/catalog.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/catalog/catalog.py#L1248

Added line #L1248 was not covered by tests

def export_dataset_table(
self,
Expand Down
24 changes: 23 additions & 1 deletion src/datachain/client/azure.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any
from typing import Any, Optional
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit

from adlfs import AzureBlobFileSystem
from tqdm import tqdm
Expand Down Expand Up @@ -57,4 +58,25 @@ async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> Non
finally:
result_queue.put_nowait(None)

@classmethod
def _split_version(cls, path: str) -> tuple[str, Optional[str]]:
    """
    Separate a ``versionid`` query parameter from ``path``.

    Args:
        path: Path, optionally carrying a ``versionid`` query parameter.

    Returns:
        A ``(path, version_id)`` tuple. When ``versionid`` is present it is
        removed from the returned path and its first value is returned as
        ``version_id``; otherwise ``version_id`` is ``None``. Any other
        query parameters are kept in the returned path.
    """
    parts = list(urlsplit(path))
    query = parse_qs(parts[3])
    # pop() reads and removes the key in one step; parse_qs maps every
    # parameter name to a non-empty list of values.
    versions = query.pop("versionid", None)
    if versions is None:
        return urlunsplit(parts), None
    # The values are lists, so doseq=True is required: without it urlencode()
    # would serialize each remaining parameter as the repr of a list
    # (e.g. "foo=%5B%27bar%27%5D") instead of "foo=bar".
    parts[3] = urlencode(query, doseq=True)
    return urlunsplit(parts), versions[0]

@classmethod
def _join_version(cls, path: str, version_id: Optional[str]) -> str:
    """
    Attach ``version_id`` to ``path`` as a ``versionid`` query parameter.

    The query component of ``path`` is replaced, not extended: with a falsy
    ``version_id`` the result carries no query at all.

    Raises:
        ValueError: if ``path`` already has a ``versionid`` query parameter.
    """
    scheme, netloc, blob_path, current_query, fragment = urlsplit(path)
    if "versionid" in parse_qs(current_query):
        raise ValueError("path already includes a version query")
    if version_id:
        new_query = f"versionid={version_id}"
    else:
        new_query = ""
    return urlunsplit((scheme, netloc, blob_path, new_query, fragment))

_fetch_default = _fetch_flat
17 changes: 15 additions & 2 deletions src/datachain/client/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,18 @@ def create_fs(cls, **kwargs) -> "AbstractFileSystem":
return fs

@classmethod
def version_path(cls, path: str, version_id: Optional[str]) -> str:
def _split_version(cls, path: str) -> tuple[str, Optional[str]]:
return path, None

Check warning on line 142 in src/datachain/client/fsspec.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/client/fsspec.py#L142

Added line #L142 was not covered by tests

@classmethod
def _join_version(cls, path: str, version_id: Optional[str]) -> str:
    """Return ``path`` unchanged; this implementation ignores ``version_id``."""
    return path

@classmethod
def version_path(cls, path: str, version_id: Optional[str]) -> str:
    """
    Return ``path`` re-versioned with ``version_id``.

    Any version found by ``cls._split_version`` is discarded, and the
    remaining path is handed to ``cls._join_version`` with ``version_id``.
    """
    unversioned = cls._split_version(path)[0]
    return cls._join_version(unversioned, version_id)

@classmethod
def from_name(
cls,
Expand Down Expand Up @@ -202,7 +211,11 @@ def fs(self) -> "AbstractFileSystem":
return self._fs

def url(self, path: str, expires: int = 3600, **kwargs) -> str:
return self.fs.sign(self.get_full_path(path), expiration=expires, **kwargs)
return self.fs.sign(

Check warning on line 214 in src/datachain/client/fsspec.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/client/fsspec.py#L214

Added line #L214 was not covered by tests
self.get_full_path(path, kwargs.pop("version_id", None)),
expiration=expires,
**kwargs,
)

async def get_current_etag(self, file: "File") -> str:
kwargs = {}
Expand Down
13 changes: 6 additions & 7 deletions src/datachain/client/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,13 @@ def url(self, path: str, expires: int = 3600, **kwargs) -> str:
If the client is anonymous, a public URL is returned instead
(see https://cloud.google.com/storage/docs/access-public-data#api-link).
"""
version_id = kwargs.pop("version_id", None)
if self.fs.storage_options.get("token") == "anon":
return f"https://storage.googleapis.com/{self.name}/{path}"
return self.fs.sign(self.get_full_path(path), expiration=expires, **kwargs)
query = f"?generation={version_id}" if version_id else ""
return f"https://storage.googleapis.com/{self.name}/{path}{query}"
return self.fs.sign(

Check warning on line 46 in src/datachain/client/gcs.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/client/gcs.py#L46

Added line #L46 was not covered by tests
self.get_full_path(path, version_id), expiration=expires, **kwargs
)

@staticmethod
def parse_timestamp(timestamp: str) -> datetime:
Expand Down Expand Up @@ -150,8 +154,3 @@ def _join_version(cls, path: str, version_id: Optional[str]) -> str:
if path_version:
raise ValueError("path already includes an object generation")
return f"{path}#{version_id}" if version_id else path

@classmethod
def version_path(cls, path: str, version_id: Optional[str]) -> str:
path, _ = cls._split_version(path)
return cls._join_version(path, version_id)
22 changes: 22 additions & 0 deletions src/datachain/client/s3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from typing import Any, Optional, cast
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit

from botocore.exceptions import NoCredentialsError
from s3fs import S3FileSystem
Expand Down Expand Up @@ -121,6 +122,27 @@ def _entry_from_boto(self, v, bucket, versions=False) -> File:
size=v["Size"],
)

@classmethod
def _split_version(cls, path: str) -> tuple[str, Optional[str]]:
    """
    Separate a ``versionId`` query parameter from ``path``.

    Args:
        path: Path, optionally carrying a ``versionId`` query parameter.

    Returns:
        A ``(path, version_id)`` tuple. When ``versionId`` is present it is
        removed from the returned path and its first value is returned as
        ``version_id``; otherwise ``version_id`` is ``None``. Any other
        query parameters are kept in the returned path.
    """
    parts = list(urlsplit(path))
    query = parse_qs(parts[3])
    # pop() reads and removes the key in one step; parse_qs maps every
    # parameter name to a non-empty list of values.
    versions = query.pop("versionId", None)
    if versions is None:
        return urlunsplit(parts), None
    # The values are lists, so doseq=True is required: without it urlencode()
    # would serialize each remaining parameter as the repr of a list
    # (e.g. "foo=%5B%27bar%27%5D") instead of "foo=bar".
    parts[3] = urlencode(query, doseq=True)
    return urlunsplit(parts), versions[0]

@classmethod
def _join_version(cls, path: str, version_id: Optional[str]) -> str:
    """
    Attach ``version_id`` to ``path`` as a ``versionId`` query parameter.

    The query component of ``path`` is replaced, not extended: with a falsy
    ``version_id`` the result carries no query at all.

    Raises:
        ValueError: if ``path`` already has a ``versionId`` query parameter.
    """
    scheme, netloc, key, current_query, fragment = urlsplit(path)
    if "versionId" in parse_qs(current_query):
        raise ValueError("path already includes a version query")
    if version_id:
        new_query = f"versionId={version_id}"
    else:
        new_query = ""
    return urlunsplit((scheme, netloc, key, new_query, fragment))

async def _fetch_dir(
self,
prefix,
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/test_client_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,11 @@
def test_anon_url():
    """A client created with ``anon=True`` yields a plain public GCS URL."""
    anon_client = Client.get_client("gs://foo", None, anon=True)
    public_url = anon_client.url("bar")
    assert public_url == "https://storage.googleapis.com/foo/bar"


def test_anon_versioned_url():
    """With ``anon=True``, ``version_id`` appears as a ``generation`` query."""
    anon_client = Client.get_client("gs://foo", None, anon=True)
    expected = "https://storage.googleapis.com/foo/bar?generation=1234566"
    assert anon_client.url("bar", version_id="1234566") == expected

0 comments on commit 1402c6a

Please sign in to comment.