Skip to content

Commit

Permalink
fix: add warnings in DataConnectionComputeDifferencesPayload for sync cloud storage (#1499)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas1312 authored Oct 2, 2023
1 parent 1b7ca3c commit 46e90eb
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 35 deletions.
24 changes: 17 additions & 7 deletions src/kili/services/data_connection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
GQL_COMPUTE_DATA_CONNECTION_DIFFERENCES,
GQL_VALIDATE_DATA_DIFFERENCES,
)
from kili.services.project import get_project_field

LOGGER = None

Expand Down Expand Up @@ -107,7 +108,7 @@ def compute_differences(kili, data_connection_id: str) -> Dict:

data_integration = data_connection["dataIntegration"]

blob_paths = None
blob_paths = warnings = None

# for azure using credentials, it is required to provide the blob paths to compute the diffs
if (
Expand All @@ -124,22 +125,31 @@ def compute_differences(kili, data_connection_id: str) -> Dict:

try:
# pylint: disable=import-outside-toplevel
from .azure import (
get_blob_paths_azure_data_connection_with_service_credentials,
)
from .azure import AzureBucket
except ImportError as err:
raise ImportError(
"The azure-storage-blob package is required to use Azure buckets. "
" Run `pip install kili[azure]` to install it."
) from err

blob_paths = get_blob_paths_azure_data_connection_with_service_credentials(
data_connection=data_connection, data_integration=data_integration
blob_paths, warnings = AzureBucket(
sas_token=data_integration["azureSASToken"],
connection_url=data_integration["azureConnectionURL"],
).get_blob_paths_azure_data_connection_with_service_credentials(
data_connection["selectedFolders"],
input_type=get_project_field(
kili,
project_id=get_data_connection(
kili, data_connection_id=data_connection_id, fields=("projectId",)
)["projectId"],
field="inputType",
),
)

variables: Dict[str, Any] = {"where": {"id": data_connection_id}}
if blob_paths is not None:
variables["data"] = {"blobPaths": blob_paths}
variables["data"] = {"blobPaths": blob_paths, "warnings": warnings}

result = kili.graphql_client.execute(GQL_COMPUTE_DATA_CONNECTION_DIFFERENCES, variables)
return format_result("data", result, None, kili.http_client)

Expand Down
92 changes: 65 additions & 27 deletions src/kili/services/data_connection/azure.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Code specific to Azure blob storage."""

from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse

from azure.storage.blob import BlobServiceClient

from kili.domain.project import InputType


class AzureBucket:
"""Class for Azure blob storage buckets."""
Expand Down Expand Up @@ -33,10 +35,6 @@ def _split_connection_url_into_storage_account_and_container_name(
container_name = url_connection.path.lstrip("/")
return storage_account, container_name

def get_blob_paths(self) -> List[str]:
"""List files in the Azure bucket."""
return list(self.storage_bucket.list_blob_names())

def get_blob_paths_as_tree(self) -> Dict:
"""Get a tree representation of the Azure bucket.
Expand All @@ -58,28 +56,68 @@ def get_blob_paths_as_tree(self) -> Dict:

return filetree

def get_blob_paths_azure_data_connection_with_service_credentials(
self, selected_folders: Optional[List[str]], input_type: InputType
) -> Tuple[List[str], List[Optional[str]]]:
"""Get the blob paths for an Azure data connection using service credentials."""
blob_paths = []
warnings = set()
for blob in self.storage_bucket.list_blobs():
if not hasattr(blob, "name") or not isinstance(blob.name, str):
continue

# blob_paths_in_bucket contains all blob paths in the bucket, we need to filter them
# to keep only the ones in the data connection selected folders
if isinstance(selected_folders, List) and not any(
blob.name.startswith(selected_folder) for selected_folder in selected_folders
):
continue

has_content_type_field = (
hasattr(blob, "content_settings")
and hasattr(blob.content_settings, "content_type")
and isinstance(blob.content_settings.content_type, str)
)
if not has_content_type_field:
warnings.add("Objects with missing content-type were ignored")

elif not self._is_content_type_compatible_with_input_type(
blob.content_settings.content_type, # pyright: ignore[reportGeneralTypeIssues]
input_type,
):
warnings.add(
"Objects with unsupported content-type for this type of project were ignored"
)

def get_blob_paths_azure_data_connection_with_service_credentials(
data_integration: Dict, data_connection: Dict
) -> List[str]:
"""Get the blob paths for an Azure data connection using service credentials."""
azure_client = AzureBucket(
sas_token=data_integration["azureSASToken"],
connection_url=data_integration["azureConnectionURL"],
)

blob_paths = azure_client.get_blob_paths()

# blob_paths_in_bucket contains all blob paths in the bucket, we need to filter them
# to keep only the ones in the data connection selected folders
if isinstance(data_connection["selectedFolders"], List):
blob_paths = [
blob_path
for blob_path in blob_paths
if any(
blob_path.startswith(selected_folder)
for selected_folder in data_connection["selectedFolders"]
else:
blob_paths.append(blob.name)

return blob_paths, list(warnings)

@staticmethod
def _is_content_type_compatible_with_input_type(
content_type: str, input_type: InputType
) -> bool:
"""Check if the content type is compatible with the input type."""
if input_type == "IMAGE":
return content_type.startswith("image")

if input_type == "VIDEO":
return content_type.startswith("video")

if input_type == "PDF":
return content_type.startswith("application/pdf")

if input_type == "TEXT":
return any(
content_type.startswith(text_type)
for text_type in (
"application/json",
"text/plain",
"text/html",
"text/csv",
"text/xml",
)
)
]

return blob_paths
raise ValueError(f"Unknown project input type: {input_type}")
1 change: 0 additions & 1 deletion tests/e2e/test_cloud_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def is_same_endpoint(endpoint_short_name: str, endpoint_url: str) -> bool:
("STAGING", "GCP", "f474c0170c8daa09ec2e368ce4720c73", None, 5),
],
)
@pytest.mark.skip("to fix")
def test_e2e_synchronize_cloud_storage_connection(
kili: Kili,
src_project: Dict,
Expand Down

0 comments on commit 46e90eb

Please sign in to comment.