chore: bump version to 2.147.0 (#1489) #1513

Closed · 5 commits
24 changes: 6 additions & 18 deletions src/kili/adapters/kili_api_gateway/tag/__init__.py
@@ -14,6 +14,7 @@
     GQL_UNCHECK_TAG,
     GQL_UPDATE_TAG,
     get_list_tags_by_org_query,
+    get_list_tags_by_project_query,
 )
 from .types import UpdateTagReturnData
 
@@ -30,24 +31,11 @@ def list_tags_by_org(self, fields: ListOrTuple[str]) -> List[Dict]:
 
     def list_tags_by_project(self, project_id: ProjectId, fields: ListOrTuple[str]) -> List[Dict]:
         """Send a GraphQL request calling listTagsByProject resolver."""
-        # fragment = fragment_builder(fields=fields) # noqa: ERA001
-        # query = get_list_tags_by_project_query(fragment)# noqa: ERA001
-        # variables = {"projectId": project_id}# noqa: ERA001
-        # result = self.graphql_client.execute(query, variables)# noqa: ERA001
-        # return result["data"] # noqa: ERA001
-        # TODO: listTagsByProject is broken currently. Use listTagsByOrg instead.
-
-        fields_with_project_ids = (
-            ("checkedForProjects", *fields) if "checkedForProjects" not in fields else fields
-        )
-        tags_of_org = self.list_tags_by_org(fields=fields_with_project_ids)
-        tags_of_project = [tag for tag in tags_of_org if project_id in tag["checkedForProjects"]]
-
-        if "checkedForProjects" not in fields:
-            for tag in tags_of_project:
-                del tag["checkedForProjects"]
-
-        return tags_of_project
+        fragment = fragment_builder(fields=fields)
+        query = get_list_tags_by_project_query(fragment)
+        variables = {"projectId": project_id}
+        result = self.graphql_client.execute(query, variables)
+        return result["data"]
 
     def check_tag(self, project_id: ProjectId, tag_id: TagId) -> TagId:
         """Send a GraphQL request calling checkTag resolver."""
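Note: the restored implementation simply rebuilds the GraphQL query from the requested fields and calls the listTagsByProject resolver directly, replacing the old client-side workaround that filtered organization tags. A minimal sketch of that round trip, assuming a hypothetical fragment_builder and query template (simplified stand-ins for the SDK's real helpers):

# Hypothetical sketch of the restored resolver path; fragment_builder and the
# query template below are simplified stand-ins, not the SDK's real helpers.
from typing import Sequence

def fragment_builder(fields: Sequence[str]) -> str:
    # Join the requested field names into a GraphQL selection set.
    return " ".join(fields)

def get_list_tags_by_project_query(fragment: str) -> str:
    return f"""
    query ListTagsByProject($projectId: ID!) {{
      data: listTagsByProject(projectId: $projectId) {{
        {fragment}
      }}
    }}
    """

# list_tags_by_project("proj_1", ("id", "label")) would then execute this query
# with variables {"projectId": "proj_1"} and return result["data"].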
8 changes: 7 additions & 1 deletion src/kili/core/graphql/graphql_client.py
@@ -189,7 +189,7 @@ def _get_graphql_schema_from_endpoint(self) -> str:
             fetch_schema_from_transport=True,
             introspection_args=self._get_introspection_args(),
         ) as session:
-            return print_schema(session.client.schema)  # type: ignore
+            return print_schema(session.client.schema)  # pyright: ignore[reportGeneralTypeIssues]
 
     def _cache_graphql_schema(self, graphql_schema_path: Path, schema_str: str) -> None:
         """Cache the graphql schema on disk."""
@@ -233,6 +233,11 @@ def _get_kili_app_version(self) -> Optional[str]:
             return response_json["version"]
         return None
 
+    @staticmethod
+    def _remove_nullable_inputs(variables: Dict) -> Dict:
+        """Remove nullable inputs from the variables."""
+        return {k: v for k, v in variables.items() if v is not None}
+
     def execute(
         self, query: Union[str, DocumentNode], variables: Optional[Dict] = None, **kwargs
     ) -> Dict[str, Any]:
@@ -244,6 +249,7 @@
             kwargs: additional arguments to pass to the GraphQL client
         """
         document = query if isinstance(query, DocumentNode) else gql(query)
+        variables = self._remove_nullable_inputs(variables) if variables else None
 
         try:
             return self._execute_with_retries(document, variables, **kwargs)
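The new _remove_nullable_inputs hook strips top-level None values from the variables before every execution, so optional filters are no longer sent to the server as explicit nulls. Only the top level is filtered: nested None values (for example inside a JSON metadata dict) pass through untouched, as the new unit tests at the end of this diff confirm. A self-contained sketch of the same filtering:

# Standalone sketch of the top-level None filtering added to execute().
from typing import Dict

def remove_nullable_inputs(variables: Dict) -> Dict:
    # Drop keys whose value is None; nested dicts are kept as-is.
    return {k: v for k, v in variables.items() if v is not None}

assert remove_nullable_inputs({"id": None, "starred": True}) == {"starred": True}
assert remove_nullable_inputs({"metadata": {"key": None}}) == {"metadata": {"key": None}}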
2 changes: 2 additions & 0 deletions src/kili/entrypoints/mutations/asset/__init__.py
@@ -85,6 +85,8 @@ def append_many_to_dataset(
                 Example for one asset: `json_metadata_array = [{'imageUrl': '','text': '','url': ''}]`.
             - For VIDEO projects (and not VIDEO_LEGACY), you can specify a value with key 'processingParameters' to specify the sampling rate (default: 30).
                 Example for one asset: `json_metadata_array = [{'processingParameters': {'framesPlayedPerSecond': 10}}]`.
+            - For IMAGE projects, if you work with geotiff files, you can specify a value with key 'processingParameters' to set the minimum and maximum zoom levels.
+                Example for one asset: `json_metadata_array = [{'processingParameters': {'minZoom': 17, 'maxZoom': 19}}]`.
             disable_tqdm: If `True`, the progress bar will be disabled
             wait_until_availability: If `True`, the function will return once the assets are fully imported in Kili.
                 If `False`, the function will return faster but the assets might not be fully processed by the server.
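A hedged usage sketch of the newly documented metadata key; the project ID and file path are placeholders, and the json_metadata_array shape follows the docstring addition above:

# Placeholder example: import a geotiff into an IMAGE project with explicit
# zoom bounds, per the docstring addition above.
from kili.client import Kili

kili = Kili()
kili.append_many_to_dataset(
    project_id="my_image_project_id",
    content_array=["./tiles/area_scan.tif"],
    json_metadata_array=[{"processingParameters": {"minZoom": 17, "maxZoom": 19}}],
)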
24 changes: 17 additions & 7 deletions src/kili/services/data_connection/__init__.py
@@ -20,6 +20,7 @@
     GQL_COMPUTE_DATA_CONNECTION_DIFFERENCES,
     GQL_VALIDATE_DATA_DIFFERENCES,
 )
+from kili.services.project import get_project_field
 
 LOGGER = None
 
@@ -107,7 +108,7 @@ def compute_differences(kili, data_connection_id: str) -> Dict:
 
     data_integration = data_connection["dataIntegration"]
 
-    blob_paths = None
+    blob_paths = warnings = None
 
     # for azure using credentials, it is required to provide the blob paths to compute the diffs
     if (
@@ -124,22 +125,31 @@
 
         try:
             # pylint: disable=import-outside-toplevel
-            from .azure import (
-                get_blob_paths_azure_data_connection_with_service_credentials,
-            )
+            from .azure import AzureBucket
         except ImportError as err:
             raise ImportError(
                 "The azure-storage-blob package is required to use Azure buckets. "
                 " Run `pip install kili[azure]` to install it."
             ) from err
 
-        blob_paths = get_blob_paths_azure_data_connection_with_service_credentials(
-            data_connection=data_connection, data_integration=data_integration
-        )
+        blob_paths, warnings = AzureBucket(
+            sas_token=data_integration["azureSASToken"],
+            connection_url=data_integration["azureConnectionURL"],
+        ).get_blob_paths_azure_data_connection_with_service_credentials(
+            data_connection["selectedFolders"],
+            input_type=get_project_field(
+                kili,
+                project_id=get_data_connection(
+                    kili, data_connection_id=data_connection_id, fields=("projectId",)
+                )["projectId"],
+                field="inputType",
+            ),
+        )
 
     variables: Dict[str, Any] = {"where": {"id": data_connection_id}}
     if blob_paths is not None:
-        variables["data"] = {"blobPaths": blob_paths}
+        variables["data"] = {"blobPaths": blob_paths, "warnings": warnings}
 
     result = kili.graphql_client.execute(GQL_COMPUTE_DATA_CONNECTION_DIFFERENCES, variables)
     return format_result("data", result, None, kili.http_client)
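With this change the data-connection differences mutation payload carries both the blob paths and the warnings collected while listing them. A sketch of the resulting variables shape, with placeholder values:

# Placeholder sketch of the mutation variables built by compute_differences
# for an Azure connection using service credentials.
variables = {
    "where": {"id": "my_data_connection_id"},
    "data": {
        "blobPaths": ["train/img_001.jpg", "train/img_002.jpg"],
        "warnings": ["Objects with missing content-type were ignored"],
    },
}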
92 changes: 65 additions & 27 deletions src/kili/services/data_connection/azure.py
@@ -1,11 +1,13 @@
 """Code specific to Azure blob storage."""
 
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 from urllib.parse import urlparse
 
 from azure.storage.blob import BlobServiceClient
 
+from kili.domain.project import InputType
+
 
 class AzureBucket:
     """Class for Azure blob storage buckets."""
@@ -33,10 +35,6 @@ def _split_connection_url_into_storage_account_and_container_name(
         container_name = url_connection.path.lstrip("/")
         return storage_account, container_name
 
-    def get_blob_paths(self) -> List[str]:
-        """List files in the Azure bucket."""
-        return list(self.storage_bucket.list_blob_names())
-
     def get_blob_paths_as_tree(self) -> Dict:
         """Get a tree representation of the Azure bucket.
 
@@ -58,28 +56,68 @@ def get_blob_paths_as_tree(self) -> Dict:
 
         return filetree
 
+    def get_blob_paths_azure_data_connection_with_service_credentials(
+        self, selected_folders: Optional[List[str]], input_type: InputType
+    ) -> Tuple[List[str], List[Optional[str]]]:
+        """Get the blob paths for an Azure data connection using service credentials."""
+        blob_paths = []
+        warnings = set()
+        for blob in self.storage_bucket.list_blobs():
+            if not hasattr(blob, "name") or not isinstance(blob.name, str):
+                continue
+
+            # blob_paths_in_bucket contains all blob paths in the bucket, we need to filter them
+            # to keep only the ones in the data connection selected folders
+            if isinstance(selected_folders, List) and not any(
+                blob.name.startswith(selected_folder) for selected_folder in selected_folders
+            ):
+                continue
+
+            has_content_type_field = (
+                hasattr(blob, "content_settings")
+                and hasattr(blob.content_settings, "content_type")
+                and isinstance(blob.content_settings.content_type, str)
+            )
+            if not has_content_type_field:
+                warnings.add("Objects with missing content-type were ignored")
+
+            elif not self._is_content_type_compatible_with_input_type(
+                blob.content_settings.content_type,  # pyright: ignore[reportGeneralTypeIssues]
+                input_type,
+            ):
+                warnings.add(
+                    "Objects with unsupported content-type for this type of project were ignored"
+                )
+
+            else:
+                blob_paths.append(blob.name)
+
+        return blob_paths, list(warnings)
+
+    @staticmethod
+    def _is_content_type_compatible_with_input_type(
+        content_type: str, input_type: InputType
+    ) -> bool:
+        """Check if the content type is compatible with the input type."""
+        if input_type == "IMAGE":
+            return content_type.startswith("image")
+
+        if input_type == "VIDEO":
+            return content_type.startswith("video")
+
+        if input_type == "PDF":
+            return content_type.startswith("application/pdf")
+
+        if input_type == "TEXT":
+            return any(
+                content_type.startswith(text_type)
+                for text_type in (
+                    "application/json",
+                    "text/plain",
+                    "text/html",
+                    "text/csv",
+                    "text/xml",
+                )
+            )
+
+        raise ValueError(f"Unknown project input type: {input_type}")
-
-
-def get_blob_paths_azure_data_connection_with_service_credentials(
-    data_integration: Dict, data_connection: Dict
-) -> List[str]:
-    """Get the blob paths for an Azure data connection using service credentials."""
-    azure_client = AzureBucket(
-        sas_token=data_integration["azureSASToken"],
-        connection_url=data_integration["azureConnectionURL"],
-    )
-
-    blob_paths = azure_client.get_blob_paths()
-
-    # blob_paths_in_bucket contains all blob paths in the bucket, we need to filter them
-    # to keep only the ones in the data connection selected folders
-    if isinstance(data_connection["selectedFolders"], List):
-        blob_paths = [
-            blob_path
-            for blob_path in blob_paths
-            if any(
-                blob_path.startswith(selected_folder)
-                for selected_folder in data_connection["selectedFolders"]
-            )
-        ]
-
-    return blob_paths
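With the listing logic now a method on AzureBucket, callers build the bucket client from the integration's SAS token and connection URL and get back both the filtered blob paths and a de-duplicated list of warnings. A hedged usage sketch (the credentials and folder names are placeholders):

# Placeholder usage of the new AzureBucket method; in the real code path the
# SAS token and connection URL come from the data integration record.
bucket = AzureBucket(
    sas_token="?sv=2022-...",
    connection_url="https://myaccount.blob.core.windows.net/mycontainer",
)
blob_paths, warnings = bucket.get_blob_paths_azure_data_connection_with_service_credentials(
    selected_folders=["train/"],
    input_type="IMAGE",
)
# blob_paths: image blobs under train/; warnings: reasons any blobs were skipped.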
4 changes: 2 additions & 2 deletions src/kili/utils/bucket.py
@@ -62,8 +62,8 @@ def upload_data_via_rest(
     url_to_use_for_upload = url_with_id.split("&id=")[0]
     if "blob.core.windows.net" in url_to_use_for_upload:
         headers["x-ms-blob-type"] = "BlockBlob"
-
-    response = http_client.put(url_to_use_for_upload, data=data, headers=headers, timeout=30)
+    # We do not set a timeout here because the upload can take arbitrarily long (ML-1395)
+    response = http_client.put(url_to_use_for_upload, data=data, headers=headers)
     response.raise_for_status()
     return url_with_id
 
1 change: 0 additions & 1 deletion tests/e2e/test_cloud_storage.py
@@ -81,7 +81,6 @@ def is_same_endpoint(endpoint_short_name: str, endpoint_url: str) -> bool:
         ("STAGING", "GCP", "f474c0170c8daa09ec2e368ce4720c73", None, 5),
     ],
 )
-@pytest.mark.skip("to fix")
 def test_e2e_synchronize_cloud_storage_connection(
     kili: Kili,
     src_project: Dict,
4 changes: 0 additions & 4 deletions tests/integration/presentation/test_tag.py
@@ -1,4 +1,3 @@
-import pytest
 import pytest_mock
 
 from kili.adapters.kili_api_gateway import KiliAPIGateway
@@ -26,7 +25,6 @@ def test_when_fetching_org_tags_then_i_get_tags(mocker: pytest_mock.MockerFixture)
     )
 
 
-@pytest.mark.skip(reason="listTagsByProject is broken currently. Use listTagsByOrg instead.")
 def test_when_fetching_project_tags_then_i_get_tags(mocker: pytest_mock.MockerFixture):
     kili = TagClientMethods()
     kili.kili_api_gateway = KiliAPIGateway(
@@ -61,7 +59,6 @@ def test_given_tags_when_i_tag_project_with_tag_ids_then_it_is_tagged(
     kili.tag_project(project_id="fake_proj_id", tag_ids=["tag1_id", "tag2_id"])
 
     # Then
-    assert kili.kili_api_gateway.graphql_client.execute.call_count == len(tags)
     kili.kili_api_gateway.graphql_client.execute.assert_called_with(
         GQL_CHECK_TAG, {"data": {"tagId": "tag2_id", "projectId": "fake_proj_id"}}
     )
@@ -85,7 +82,6 @@ def test_given_tags_when_i_tag_project_with_tag_labels_then_it_is_tagged(
     kili.tag_project(project_id="fake_proj_id", tags=["tag1", "tag2"])
 
     # Then
-    assert kili.kili_api_gateway.graphql_client.execute.call_count == len(tags)
     kili.kili_api_gateway.graphql_client.execute.assert_called_with(
         GQL_CHECK_TAG, {"data": {"tagId": "tag2_id", "projectId": "fake_proj_id"}}
     )
73 changes: 73 additions & 0 deletions tests/unit/test_graphql_client.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from time import time
+from typing import Dict
 from unittest import mock
 
 import graphql
@@ -355,3 +356,75 @@ def mocked_backend_response(*args, **kwargs):
     # Then
     assert result["data"] == "all good"
     assert mocked_execute.call_count == nb_times_called == 3
+
+
+@pytest.mark.parametrize(
+    ("variables", "expected"),
+    [
+        ({"id": "123456"}, {"id": "123456"}),
+        ({"id": None}, {}),
+        (
+            {
+                "project": {"id": "project_id"},
+                "asset": {"id": None},
+                "assetIn": ["123456"],
+                "status": "some_status",
+                "type": None,
+            },
+            {
+                "project": {"id": "project_id"},
+                "asset": {"id": None},
+                "assetIn": ["123456"],
+                "status": "some_status",
+            },
+        ),
+        (
+            {
+                "id": None,
+                "searchQuery": "truc",
+                "shouldRelaunchKpiComputation": None,
+                "starred": True,
+                "updatedAtGte": None,
+                "updatedAtLte": None,
+                "createdAtGte": None,
+                "createdAtLte": None,
+                "tagIds": ["tag_id"],
+            },
+            {
+                "searchQuery": "truc",
+                "starred": True,
+                "tagIds": ["tag_id"],
+            },
+        ),
+        (  # AssetWhere
+            {
+                "externalIdStrictlyIn": ["truc"],
+                "externalIdIn": None,
+                "honeypotMarkGte": None,
+                "honeypotMarkLte": 0.0,
+                "id": "fake_asset_id",
+                "metadata": {"key": None},  # this field is a JSON graphql type; it should be kept
+                "project": {"id": "fake_proj_id"},
+                "skipped": True,
+                "updatedAtLte": None,
+            },
+            {
+                "externalIdStrictlyIn": ["truc"],
+                "honeypotMarkLte": 0.0,
+                "id": "fake_asset_id",
+                "metadata": {"key": None},
+                "project": {"id": "fake_proj_id"},
+                "skipped": True,
+            },
+        ),
+    ],
+)
+def test_given_variables_when_i_remove_null_values_then_it_works(variables: Dict, expected: Dict):
+    # Given
+    _ = variables
+
+    # When
+    output = GraphQLClient._remove_nullable_inputs(variables)
+
+    # Then
+    assert output == expected