Skip to content

Commit

Permalink
fix: use projects resolver to check for data connections (#1522)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas1312 authored Oct 6, 2023
1 parent 8d0da6f commit 2886b02
Show file tree
Hide file tree
Showing 12 changed files with 93 additions and 54 deletions.
4 changes: 2 additions & 2 deletions src/kili/services/copy_project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ProjectCopier: # pylint: disable=too-few-public-methods
"inputType",
"description",
"id",
"dataConnections.dataIntegrationId",
"dataConnections.id",
)
FIELDS_JSON_INTERFACE = ("jsonInterface",)
FIELDS_QUALITY_SETTINGS = (
Expand Down Expand Up @@ -87,7 +87,7 @@ def copy_project( # pylint: disable=too-many-arguments,too-many-locals

src_project = get_project(self.kili, from_project_id, fields)

if len(src_project["dataConnections"]) > 0 and copy_assets:
if src_project["dataConnections"] and copy_assets:
raise NotImplementedError("Copying projects with cloud storage is not supported.")

new_project_title = title or self._generate_project_title(src_title=src_project["title"])
Expand Down
15 changes: 2 additions & 13 deletions src/kili/services/export/format/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional, Tuple, cast

from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions
from kili.core.graphql.operations.data_connection.queries import (
DataConnectionsQuery,
DataConnectionsWhere,
)
from kili.domain.asset import AssetId
from kili.domain.project import ProjectId
from kili.orm import Asset, Label
Expand Down Expand Up @@ -206,14 +201,8 @@ def _check_and_ensure_asset_access(self) -> None:
)

def _has_data_connection(self) -> bool:
data_connections_gen = DataConnectionsQuery(
self.kili.graphql_client, self.kili.http_client
)(
where=DataConnectionsWhere(project_id=self.project_id),
fields=["id"],
options=QueryOptions(disable_tqdm=True, first=1, skip=0),
)
return len(list(data_connections_gen)) > 0
project = get_project(self.kili, self.project_id, ["dataConnections.id"])
return bool(project["dataConnections"])

def _check_geotiff_export_compatibility(self, assets: List[Asset]) -> None:
# pylint: disable=line-too-long
Expand Down
18 changes: 4 additions & 14 deletions src/kili/use_cases/asset/media_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@

from kili.adapters.http_client import HttpClient
from kili.adapters.kili_api_gateway import KiliAPIGateway
from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions
from kili.core.graphql.operations.data_connection.queries import (
DataConnectionsQuery,
DataConnectionsWhere,
)
from kili.domain.project import ProjectId
from kili.domain.types import ListOrTuple

Expand All @@ -41,19 +36,14 @@ def get_download_assets_function(
if not download_media:
return None, fields

project = kili_api_gateway.get_project(project_id=project_id, fields=("inputType",))
project = kili_api_gateway.get_project(
project_id=project_id, fields=("inputType", "dataConnections.id")
)
input_type = project["inputType"]

# We need to query the data connections to know if the assets are hosted in a cloud storage
# If so, we remove the fields "content" and "jsonContent" from the query
data_connections_gen = DataConnectionsQuery(
kili_api_gateway.graphql_client, kili_api_gateway.http_client
)(
where=DataConnectionsWhere(project_id=project_id),
fields=("id",),
options=QueryOptions(disable_tqdm=True, first=1, skip=0),
)
if len(list(data_connections_gen)) > 0:
if project["dataConnections"]:
raise DownloadNotAllowedError(
"The download of assets from a project connected to a cloud storage is not allowed."
" Asset download is disabled."
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/test_notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def process_notebook(notebook_filename: str) -> None:
pytest.param(
"tests/e2e/plugin_workflow.ipynb",
marks=pytest.mark.skipif(
"lts.cloud" in os.environ["KILI_API_ENDPOINT"],
"lts.cloud" in os.getenv("KILI_API_ENDPOINT", ""),
reason="Feature not available on premise",
),
),
Expand All @@ -50,7 +50,7 @@ def process_notebook(notebook_filename: str) -> None:
pytest.param(
"recipes/plugins_example.ipynb",
marks=pytest.mark.skipif(
"lts.cloud" in os.environ["KILI_API_ENDPOINT"],
"lts.cloud" in os.getenv("KILI_API_ENDPOINT", ""),
reason="Feature not available on premise",
),
),
Expand Down
4 changes: 4 additions & 0 deletions tests/fakes/fake_kili.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def mocked_ProjectQuery(where, _fields, _options):
"description": "This is a test project",
"jsonInterface": json_interface,
"inputType": "IMAGE",
"dataConnections": None,
}
]
elif project_id == "object_detection_video_project":
Expand Down Expand Up @@ -122,6 +123,7 @@ def mocked_ProjectQuery(where, _fields, _options):
"description": "This is a test project",
"jsonInterface": json_interface,
"inputType": "VIDEO",
"dataConnections": None,
}
]
elif project_id == "text_classification":
Expand Down Expand Up @@ -155,6 +157,7 @@ def mocked_ProjectQuery(where, _fields, _options):
"description": "This is a TC test project",
"jsonInterface": json_interface,
"inputType": "TEXT",
"dataConnections": None,
}
]
elif project_id == "semantic_segmentation":
Expand Down Expand Up @@ -198,6 +201,7 @@ def mocked_ProjectQuery(where, _fields, _options):
"description": "This is a semantic segmentation test project",
"jsonInterface": json_interface,
"inputType": "IMAGE",
"dataConnections": None,
}
]
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/use_cases/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_given_query_parameters_I_can_query_assets_and_download_their_media(
kili_api_gateway: KiliAPIGateway, mocker
):
# mocking
kili_api_gateway.get_project.return_value = {"inputType": "IMAGE"}
kili_api_gateway.get_project.return_value = {"inputType": "IMAGE", "dataConnections": None}
media_downlaoder_mock = mocker.patch.object(MediaDownloader, "__init__", return_value=None)

# given parameters to query assets
Expand Down
10 changes: 3 additions & 7 deletions tests/unit/services/export/test_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
_get_coco_categories_with_mapping,
_get_coco_geometry_from_kili_bpoly,
)
from kili.services.export.format.kili import KiliExporter
from kili.services.types import Job, JobName
from kili.utils.tempfile import TemporaryDirectory

Expand Down Expand Up @@ -579,12 +578,9 @@ def test_when_exporting_to_coco_given_a_project_with_data_connection_then_it_sho
mocker.patch(
"kili.services.export.format.base.get_project", return_value=get_project_return_val
)
mocker.patch.object(KiliExporter, "_check_arguments_compatibility", return_value=None)
mocker.patch.object(KiliExporter, "_check_project_compatibility", return_value=None)
mocker.patch(
"kili.services.export.format.base.DataConnectionsQuery.__call__",
return_value=(i for i in [{"id": "fake_data_connection_id"}]),
)
mocker.patch.object(CocoExporter, "_check_arguments_compatibility", return_value=None)
mocker.patch.object(CocoExporter, "_check_project_compatibility", return_value=None)
mocker.patch.object(CocoExporter, "_has_data_connection", return_value=True)

kili = QueriesLabel()
kili.api_key = "" # type: ignore
Expand Down
67 changes: 61 additions & 6 deletions tests/unit/services/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
from kili.domain.asset import AssetFilters
from kili.entrypoints.queries.label import QueriesLabel
from kili.orm import Asset
from kili.services.export import export_labels
from kili.services.export import AbstractExporter, export_labels
from kili.services.export.exceptions import (
NoCompatibleJobError,
NotCompatibleInputType,
NotCompatibleOptions,
)
from kili.services.export.format.kili import KiliExporter
from kili.services.export.format.voc import VocExporter
from tests.fakes.fake_kili import (
FakeKili,
mocked_AssetQuery,
Expand Down Expand Up @@ -794,6 +795,7 @@ def test_export_with_asset_filter_kwargs(mocker):
"kili.services.export.format.base.get_project", return_value=get_project_return_val
)
mocker.patch.object(KiliExporter, "process_and_save", return_value=None)
mocker.patch.object(KiliExporter, "_has_data_connection", return_value=False)
kili = QueriesLabel()
kili.api_endpoint = "https://" # type: ignore
kili.api_key = "" # type: ignore
Expand Down Expand Up @@ -875,6 +877,7 @@ def test_export_with_asset_filter_kwargs_unknown_arg(mocker):
)
mocker.patch.object(KiliExporter, "_check_arguments_compatibility", return_value=None)
mocker.patch.object(KiliExporter, "_check_project_compatibility", return_value=None)
mocker.patch.object(KiliExporter, "_has_data_connection", return_value=False)
kili = QueriesLabel()
kili.api_endpoint = "https://" # type: ignore
kili.api_key = "" # type: ignore
Expand Down Expand Up @@ -917,11 +920,7 @@ def mock_kili(mocker, with_data_connection):
mocker.patch(
"kili.services.export.format.base.get_project", return_value=get_project_return_val
)
if with_data_connection:
mocker.patch(
"kili.services.export.format.base.DataConnectionsQuery.__call__",
return_value=(i for i in [{"id": "fake_data_connection_id"}]),
)
mocker.patch.object(AbstractExporter, "_has_data_connection", return_value=with_data_connection)

kili = QueriesLabel()
kili.kili_api_gateway = mocker.MagicMock()
Expand Down Expand Up @@ -1050,3 +1049,59 @@ def test_when_exporting_geotiff_asset_with_incompatible_options_then_it_crashes(
),
):
kili.export_labels("fake_proj_id", "export.zip", fmt="kili", normalized_coordinates=False)


def test_given_kili_when_exporting_it_does_not_call_dataconnection_resolver(
mocker: pytest_mock.MockerFixture,
):
"""Test that the dataconnection resolver is not called when exporting.
Export for projects with data connections is forbidden.
But dataConnections() resolver requires high permissions.
This test ensures that the resolver is not called when exporting.
"""
# Given
project_return_val = {
"jsonInterface": {
"jobs": {
"OBJECT_DETECTION_JOB": {
"content": {
"categories": {
"GDGF": {
"children": [],
"color": "#472CED",
"name": "gdgf",
}
},
"input": "radio",
},
"instruction": "df",
"mlTask": "OBJECT_DETECTION",
"required": 1,
"tools": ["rectangle"],
"isChild": False,
}
}
},
"inputType": "IMAGE",
"title": "",
"dataConnections": None,
}
mocker.patch.object(ProjectQuery, "__call__", return_value=[project_return_val])
mocker.patch("kili.services.export.format.base.fetch_assets", return_value=[])
process_and_save_mock = mocker.patch.object(VocExporter, "process_and_save", return_value=None)
kili = QueriesLabel()
kili.api_endpoint = "https://" # type: ignore
kili.api_key = "" # type: ignore
kili.graphql_client = mocker.MagicMock()
kili.http_client = mocker.MagicMock()
kili.kili_api_gateway = mocker.MagicMock()

# When
kili.export_labels(
project_id="fake_proj_id", filename="exp.zip", fmt="pascal_voc", layout="merged"
)

# Then
process_and_save_mock.assert_called_once()
kili.graphql_client.execute.assert_not_called()
6 changes: 2 additions & 4 deletions tests/unit/services/export/test_geojson.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from kili.entrypoints.queries.label import QueriesLabel
from kili.orm import Asset
from kili.services.export.format.geojson import GeoJsonExporter


def test_kili_export_labels_geojson(mocker: pytest_mock.MockerFixture):
Expand All @@ -22,10 +23,7 @@ def test_kili_export_labels_geojson(mocker: pytest_mock.MockerFixture):
mocker.patch(
"kili.services.export.format.base.get_project", return_value=get_project_return_val
)
mocker.patch(
"kili.services.export.format.base.DataConnectionsQuery.__call__",
return_value=(i for i in [{"id": "fake_data_connection_id"}]),
)
mocker.patch.object(GeoJsonExporter, "_has_data_connection", return_value=False)
mocker.patch(
"kili.services.export.format.base.fetch_assets",
return_value=[
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/services/export/test_kili.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def test_kili_exporter_convert_to_pixel_coords_pdf(mocker: pytest_mock.MockerFix
def test_kili_export_labels_non_normalized_pdf(mocker: pytest_mock.MockerFixture):
get_project_return_val = {
"inputType": "PDF",
"dataConnections": None,
"id": "fake_proj_id",
"title": "fake_proj_title",
"description": "fake_proj_description",
Expand Down Expand Up @@ -273,6 +274,7 @@ def test_kili_export_labels_non_normalized_image(mocker: pytest_mock.MockerFixtu
get_project_return_val = {
"id": "fake_proj_id",
"title": "hgfhfg",
"dataConnections": None,
"inputType": "IMAGE",
"jsonInterface": {
"jobs": {
Expand Down Expand Up @@ -444,6 +446,7 @@ def test_kili_export_labels_non_normalized_video(mocker: pytest_mock.MockerFixtu
"title": "Object tracking on video",
"description": "Use bounding-box to track objects across video frames.",
"id": "fake_proj_id",
"dataConnections": None,
}

mocker.patch("kili.services.export.get_project", return_value=get_project_return_val)
Expand Down
11 changes: 6 additions & 5 deletions tests/unit/services/export/test_voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from kili.entrypoints.queries.label import QueriesLabel
from kili.services.export.exceptions import NotCompatibleOptions
from kili.services.export.format.voc import _convert_from_kili_to_voc_format
from kili.services.export.format.voc import (
VocExporter,
_convert_from_kili_to_voc_format,
)
from tests.fakes.fake_data import asset_image_1, asset_image_1_without_annotation


Expand Down Expand Up @@ -44,15 +47,13 @@ def test_when_exporting_to_voc_given_a_project_with_data_connection_then_it_shou
"inputType": "IMAGE",
"title": "",
"id": "fake_proj_id",
"dataConnections": None,
}
mocker.patch("kili.services.export.get_project", return_value=get_project_return_val)
mocker.patch(
"kili.services.export.format.base.get_project", return_value=get_project_return_val
)
mocker.patch(
"kili.services.export.format.base.DataConnectionsQuery.__call__",
return_value=(i for i in [{"id": "fake_data_connection_id"}]),
)
mocker.patch.object(VocExporter, "_has_data_connection", return_value=True)

kili = QueriesLabel()
kili.api_endpoint = "https://" # type: ignore
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/services/export/test_yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from kili.entrypoints.queries.label import QueriesLabel
from kili.orm import Asset
from kili.services.export.format.yolo import (
YoloExporter,
_convert_from_kili_to_yolo_format,
_process_asset,
_write_class_file,
Expand Down Expand Up @@ -283,6 +284,7 @@ def test_yolo_v8_merged(mocker: pytest_mock.MockerFixture):
"kili.services.export.format.base.fetch_assets",
return_value=[Asset(asset) for asset in assets],
)
mocker.patch.object(YoloExporter, "_has_data_connection", return_value=False)

kili = QueriesLabel()
kili.api_endpoint = "https://" # type: ignore
Expand Down Expand Up @@ -334,6 +336,7 @@ def test_yolo_v8_split_jobs(mocker: pytest_mock.MockerFixture):
"kili.services.export.format.base.fetch_assets",
return_value=[Asset(asset) for asset in assets],
)
mocker.patch.object(YoloExporter, "_has_data_connection", return_value=False)

kili = QueriesLabel()
kili.api_endpoint = "https://" # type: ignore
Expand Down

0 comments on commit 2886b02

Please sign in to comment.