From 48950a8405d55b0515bacfe1123ff0e4eb8259ba Mon Sep 17 00:00:00 2001 From: Jonas Maison Date: Thu, 12 Oct 2023 12:34:12 +0200 Subject: [PATCH] refactor: user methods (#1535) --- docs/sdk/user.md | 5 +- .../adapters/kili_api_gateway/__init__.py | 2 + .../kili_api_gateway}/user/__init__.py | 0 .../adapters/kili_api_gateway/user/mappers.py | 47 ++++ .../kili_api_gateway/user/operation_mixin.py | 93 +++++++ .../kili_api_gateway/user/operations.py | 84 +++++++ .../adapters/kili_api_gateway/user/types.py | 35 +++ src/kili/client.py | 18 +- src/kili/core/graphql/graphql_client.py | 8 +- .../core/graphql/operations/user/queries.py | 60 ----- src/kili/domain/user.py | 21 ++ .../entrypoints/mutations/label/__init__.py | 7 +- .../entrypoints/mutations/user/__init__.py | 120 --------- .../entrypoints/mutations/user/fragments.py | 6 - .../entrypoints/mutations/user/queries.py | 62 ----- src/kili/entrypoints/queries/user/__init__.py | 116 --------- src/kili/exceptions.py | 4 - src/kili/presentation/client/internal.py | 5 +- src/kili/presentation/client/user.py | 229 ++++++++++++++++++ src/kili/services/asset_import/base.py | 13 +- src/kili/use_cases/user/__init__.py | 95 ++++++++ tests/e2e/test_e2e_graphql_client.py | 5 +- tests/e2e/test_query_project_users.py | 2 +- .../adapters/kili_api_gateway/test_user.py | 169 +++++++++++++ .../entrypoints/client/queries/test_users.py | 27 --- tests/integration/presentation/test_user.py | 37 +++ 26 files changed, 841 insertions(+), 429 deletions(-) rename src/kili/{core/graphql/operations => adapters/kili_api_gateway}/user/__init__.py (100%) create mode 100644 src/kili/adapters/kili_api_gateway/user/mappers.py create mode 100644 src/kili/adapters/kili_api_gateway/user/operation_mixin.py create mode 100644 src/kili/adapters/kili_api_gateway/user/operations.py create mode 100644 src/kili/adapters/kili_api_gateway/user/types.py delete mode 100644 src/kili/core/graphql/operations/user/queries.py create mode 100644 src/kili/domain/user.py delete mode 100644 src/kili/entrypoints/mutations/user/__init__.py delete mode 100644 src/kili/entrypoints/mutations/user/fragments.py delete mode 100644 src/kili/entrypoints/mutations/user/queries.py delete mode 100644 src/kili/entrypoints/queries/user/__init__.py create mode 100644 src/kili/presentation/client/user.py create mode 100644 src/kili/use_cases/user/__init__.py create mode 100644 tests/integration/adapters/kili_api_gateway/test_user.py delete mode 100644 tests/integration/entrypoints/client/queries/test_users.py create mode 100644 tests/integration/presentation/test_user.py diff --git a/docs/sdk/user.md b/docs/sdk/user.md index a44790129..1907d22ee 100644 --- a/docs/sdk/user.md +++ b/docs/sdk/user.md @@ -1,6 +1,3 @@ # User module -## Queries -::: kili.entrypoints.queries.user.__init__.QueriesUser -## Mutations -::: kili.entrypoints.mutations.user.__init__.MutationsUser +::: kili.presentation.client.user.UserClientMethods diff --git a/src/kili/adapters/kili_api_gateway/__init__.py b/src/kili/adapters/kili_api_gateway/__init__.py index 5d1a1b7ae..49e480cd9 100644 --- a/src/kili/adapters/kili_api_gateway/__init__.py +++ b/src/kili/adapters/kili_api_gateway/__init__.py @@ -8,6 +8,7 @@ from kili.adapters.kili_api_gateway.issue import IssueOperationMixin from kili.adapters.kili_api_gateway.project import ProjectOperationMixin from kili.adapters.kili_api_gateway.tag import TagOperationMixin +from kili.adapters.kili_api_gateway.user.operation_mixin import UserOperationMixin from kili.core.graphql.graphql_client import GraphQLClient @@ -17,6 +18,7 @@ class KiliAPIGateway( ProjectOperationMixin, TagOperationMixin, ApiKeyOperationMixin, + UserOperationMixin, CloudStorageOperationMixin, ): """GraphQL gateway to communicate with Kili backend.""" diff --git a/src/kili/core/graphql/operations/user/__init__.py b/src/kili/adapters/kili_api_gateway/user/__init__.py similarity index 100% rename from src/kili/core/graphql/operations/user/__init__.py rename to src/kili/adapters/kili_api_gateway/user/__init__.py diff --git a/src/kili/adapters/kili_api_gateway/user/mappers.py b/src/kili/adapters/kili_api_gateway/user/mappers.py new file mode 100644 index 000000000..22b558432 --- /dev/null +++ b/src/kili/adapters/kili_api_gateway/user/mappers.py @@ -0,0 +1,47 @@ +"""GraphQL payload data mappers for user operations.""" + +from typing import Dict + +from kili.domain.user import UserFilter + +from .types import CreateUserDataKiliGatewayInput, UserDataKiliGatewayInput + + +def user_where_mapper(filters: UserFilter) -> Dict: + """Build the GraphQL UserWhere variable to be sent in an operation.""" + return { + "activated": filters.activated, + "apiKey": filters.api_key, + "email": filters.email, + "id": filters.id, + "idIn": filters.id_in, + "organization": {"id": filters.organization_id}, + } + + +def create_user_data_mapper(data: CreateUserDataKiliGatewayInput) -> Dict: + """Build the CreateUserDataKiliGatewayInput data variable to be sent in an operation.""" + return { + "email": data.email, + "firstname": data.firstname, + "lastname": data.lastname, + "password": data.password, + "organizationRole": data.organization_role, + } + + +def update_user_data_mapper(data: UserDataKiliGatewayInput) -> Dict: + """Build the UserDataKiliGatewayInput data variable to be sent in an operation.""" + return { + "activated": data.activated, + "apiKey": data.api_key, + # "auth0Id": data.auth0_id, # refused by the backend: only used for service account # noqa: ERA001 # pylint: disable=line-too-long + "email": data.email, + "firstname": data.firstname, + "hasCompletedLabelingTour": data.has_completed_labeling_tour, + "hubspotSubscriptionStatus": data.hubspot_subscription_status, + "lastname": data.lastname, + "organization": data.organization, + "organizationId": data.organization_id, + "organizationRole": data.organization_role, + } diff --git a/src/kili/adapters/kili_api_gateway/user/operation_mixin.py b/src/kili/adapters/kili_api_gateway/user/operation_mixin.py new file mode 100644 index 000000000..decb7c4c6 --- /dev/null +++ b/src/kili/adapters/kili_api_gateway/user/operation_mixin.py @@ -0,0 +1,93 @@ +"""Mixin extending Kili API Gateway class with User related operations.""" + +from typing import Dict, Generator + +from kili.adapters.kili_api_gateway.base import BaseOperationMixin +from kili.adapters.kili_api_gateway.helpers.queries import ( + PaginatedGraphQLQuery, + QueryOptions, + fragment_builder, +) +from kili.domain.types import ListOrTuple +from kili.domain.user import UserFilter + +from .mappers import create_user_data_mapper, update_user_data_mapper, user_where_mapper +from .operations import ( + GQL_COUNT_USERS, + get_create_user_mutation, + get_current_user_query, + get_update_password_mutation, + get_update_user_mutation, + get_users_query, +) +from .types import CreateUserDataKiliGatewayInput, UserDataKiliGatewayInput + + +class UserOperationMixin(BaseOperationMixin): + """GraphQL Mixin extending GraphQL Gateway class with User related operations.""" + + def list_users( + self, user_filters: UserFilter, fields: ListOrTuple[str], options: QueryOptions + ) -> Generator[Dict, None, None]: + """Return a generator of users that match the filter.""" + fragment = fragment_builder(fields) + query = get_users_query(fragment) + where = user_where_mapper(filters=user_filters) + return PaginatedGraphQLQuery(self.graphql_client).execute_query_from_paginated_call( + query, where, options, "Retrieving users", GQL_COUNT_USERS + ) + + def count_users(self, user_filters: UserFilter) -> int: + """Return the number of users that match the filter.""" + result = self.graphql_client.execute(GQL_COUNT_USERS, user_where_mapper(user_filters)) + return result["data"] + + def get_current_user(self, fields: ListOrTuple[str]) -> Dict: + """Return the current user.""" + fragment = fragment_builder(fields) + query = get_current_user_query(fragment=fragment) + result = self.graphql_client.execute(query) + return result["data"] + + def create_user(self, data: CreateUserDataKiliGatewayInput, fields: ListOrTuple[str]) -> Dict: + """Create a user.""" + fragment = fragment_builder(fields) + query = get_create_user_mutation(fragment) + variables = create_user_data_mapper(data) + result = self.graphql_client.execute(query, variables) + return result["data"] + + def update_password( + self, + old_password: str, + new_password_1: str, + new_password_2: str, + user_filter: UserFilter, + fields: ListOrTuple[str], + ) -> Dict: + """Update user password.""" + fragment = fragment_builder(fields) + query = get_update_password_mutation(fragment) + variables = { + "data": { + "oldPassword": old_password, + "newPassword1": new_password_1, + "newPassword2": new_password_2, + }, + "where": user_where_mapper(filters=user_filter), + } + result = self.graphql_client.execute(query, variables) + return result["data"] + + def update_user( + self, user_filter: UserFilter, data: UserDataKiliGatewayInput, fields: ListOrTuple[str] + ) -> Dict: + """Update a user.""" + fragment = fragment_builder(fields) + query = get_update_user_mutation(fragment) + variables = { + "data": update_user_data_mapper(data), + "where": user_where_mapper(filters=user_filter), + } + result = self.graphql_client.execute(query, variables) + return result["data"] diff --git a/src/kili/adapters/kili_api_gateway/user/operations.py b/src/kili/adapters/kili_api_gateway/user/operations.py new file mode 100644 index 000000000..da0075a2c --- /dev/null +++ b/src/kili/adapters/kili_api_gateway/user/operations.py @@ -0,0 +1,84 @@ +"""GraphQL User operations.""" + + +def get_users_query(fragment: str) -> str: + """Return the GraphQL users query.""" + return f""" + query users($where: UserWhere!, $first: PageSize!, $skip: Int!) {{ + data: users(where: $where, first: $first, skip: $skip) {{ + {fragment} + }} + }} + """ + + +def get_current_user_query(fragment: str) -> str: + """Return the GraphQL current user query.""" + return f""" + query me {{ + data: me {{ + {fragment} + }} + }} + """ + + +def get_create_user_mutation(fragment: str) -> str: + """Return the GraphQL create user mutation.""" + return f""" + mutation( + $data: CreateUserData! + ) {{ + data: createUser( + data: $data + ) {{ + {fragment} + }} + }} + """ + + +def get_update_password_mutation(fragment: str) -> str: + """Return the GraphQL update password mutation.""" + return f""" +mutation( + $data: UpdatePasswordData! + $where: UserWhere! +) {{ + data: updatePassword( + data: $data + where: $where + ) {{ + {fragment} + }} +}} +""" + + +def get_update_user_mutation(fragment: str) -> str: + """Return the GraphQL update user mutation.""" + return f""" + mutation updatePropertiesInUser( $data: UserData!, $where: UserWhere!) {{ + data: updatePropertiesInUser( data: $data, where: $where) {{ + {fragment} + }} + }} + """ + + +GQL_COUNT_USERS = """ + query countUsers($where: UserWhere!) { + data: countUsers(where: $where) + } + """ + + +def get_reset_password_mutation(fragment: str) -> str: + """Return the GraphQL reset password mutation.""" + return f""" + mutation($where: UserWhere!) {{ + data: resetPassword(where: $where) {{ + {fragment} + }} + }} +""" diff --git a/src/kili/adapters/kili_api_gateway/user/types.py b/src/kili/adapters/kili_api_gateway/user/types.py new file mode 100644 index 000000000..df86633f3 --- /dev/null +++ b/src/kili/adapters/kili_api_gateway/user/types.py @@ -0,0 +1,35 @@ +"""Types for the user-related Kili API gateway functions.""" +from dataclasses import dataclass +from typing import Optional + +from kili.core.enums import OrganizationRole +from kili.domain.organization import OrganizationId +from kili.domain.user import HubspotSubscriptionStatus + + +@dataclass +class CreateUserDataKiliGatewayInput: + """Input type for creating a user in Kili Gateway.""" + + email: str + firstname: Optional[str] + lastname: Optional[str] + password: Optional[str] + organization_role: OrganizationRole + + +@dataclass +class UserDataKiliGatewayInput: + """Input type for updating a user in Kili Gateway.""" + + activated: Optional[bool] = None + api_key: Optional[str] = None + # auth0_id: Optional[str] = None # refused by the backend: only used for service account # noqa: ERA001 # pylint: disable=line-too-long + email: Optional[str] = None + firstname: Optional[str] = None + has_completed_labeling_tour: Optional[bool] = None + hubspot_subscription_status: Optional[HubspotSubscriptionStatus] = None + lastname: Optional[str] = None + organization: Optional[str] = None + organization_id: Optional[OrganizationId] = None + organization_role: Optional[OrganizationRole] = None diff --git a/src/kili/client.py b/src/kili/client.py index 580632d51..49d2edd41 100644 --- a/src/kili/client.py +++ b/src/kili/client.py @@ -10,7 +10,6 @@ from kili.adapters.http_client import HttpClient from kili.adapters.kili_api_gateway import KiliAPIGateway from kili.core.graphql.graphql_client import GraphQLClient, GraphQLClientName -from kili.core.graphql.operations.user.queries import GQL_ME from kili.entrypoints.mutations.asset import MutationsAsset from kili.entrypoints.mutations.issue import MutationsIssue from kili.entrypoints.mutations.label import MutationsLabel @@ -18,22 +17,21 @@ from kili.entrypoints.mutations.plugins import MutationsPlugins from kili.entrypoints.mutations.project import MutationsProject from kili.entrypoints.mutations.project_version import MutationsProjectVersion -from kili.entrypoints.mutations.user import MutationsUser from kili.entrypoints.queries.label import QueriesLabel from kili.entrypoints.queries.notification import QueriesNotification from kili.entrypoints.queries.organization import QueriesOrganization from kili.entrypoints.queries.plugins import QueriesPlugins from kili.entrypoints.queries.project_user import QueriesProjectUser from kili.entrypoints.queries.project_version import QueriesProjectVersion -from kili.entrypoints.queries.user import QueriesUser from kili.entrypoints.subscriptions.label import SubscriptionsLabel -from kili.exceptions import AuthenticationFailed, UserNotFoundError +from kili.exceptions import AuthenticationFailed from kili.presentation.client.asset import AssetClientMethods from kili.presentation.client.cloud_storage import CloudStorageClientMethods from kili.presentation.client.internal import InternalClientMethods from kili.presentation.client.issue import IssueClientMethods from kili.presentation.client.project import ProjectClientMethods from kili.presentation.client.tag import TagClientMethods +from kili.presentation.client.user import UserClientMethods from kili.use_cases.api_key import ApiKeyUseCases warnings.filterwarnings("default", module="kili", category=DeprecationWarning) @@ -58,20 +56,19 @@ class Kili( # pylint: disable=too-many-ancestors,too-many-instance-attributes MutationsPlugins, MutationsProject, MutationsProjectVersion, - MutationsUser, QueriesLabel, QueriesNotification, QueriesOrganization, QueriesPlugins, QueriesProjectUser, QueriesProjectVersion, - QueriesUser, SubscriptionsLabel, IssueClientMethods, AssetClientMethods, TagClientMethods, ProjectClientMethods, CloudStorageClientMethods, + UserClientMethods, ): """Kili Client.""" @@ -168,12 +165,3 @@ def __init__( if not skip_checks: api_key_use_cases = ApiKeyUseCases(self.kili_api_gateway) api_key_use_cases.check_expiry_of_key_is_close(api_key) - - def get_user(self) -> Dict: - # TODO: move this method - """Get the current user from the api_key provided.""" - result = self.graphql_client.execute(GQL_ME) - user = self.format_result("data", result) - if user is None or user["id"] is None or user["email"] is None: - raise UserNotFoundError("No user attached to the API key was found") - return user diff --git a/src/kili/core/graphql/graphql_client.py b/src/kili/core/graphql/graphql_client.py index 0d01b2c91..fbb0110ae 100644 --- a/src/kili/core/graphql/graphql_client.py +++ b/src/kili/core/graphql/graphql_client.py @@ -228,9 +228,13 @@ def _get_kili_app_version(self) -> Optional[str]: return response_json["version"] return None - @staticmethod - def _remove_nullable_inputs(variables: Dict) -> Dict: + @classmethod + def _remove_nullable_inputs(cls, variables: Dict) -> Dict: """Remove nullable inputs from the variables.""" + if "data" in variables and isinstance(variables["data"], dict): + variables["data"] = cls._remove_nullable_inputs(variables["data"]) + if "where" in variables and isinstance(variables["where"], dict): + variables["where"] = cls._remove_nullable_inputs(variables["where"]) return {k: v for k, v in variables.items() if v is not None} def execute( diff --git a/src/kili/core/graphql/operations/user/queries.py b/src/kili/core/graphql/operations/user/queries.py deleted file mode 100644 index 8d335f1e0..000000000 --- a/src/kili/core/graphql/operations/user/queries.py +++ /dev/null @@ -1,60 +0,0 @@ -"""GraphQL Queries of Users.""" - - -from typing import Optional - -from kili.core.graphql.queries import BaseQueryWhere, GraphQLQuery - - -class UserWhere(BaseQueryWhere): - """Tuple to be passed to the UserQuery to restrict query.""" - - def __init__( - self, - api_key: Optional[str] = None, - email: Optional[str] = None, - organization_id: Optional[str] = None, - ) -> None: - self.api_key = api_key - self.email = email - self.organization_id = organization_id - super().__init__() - - def graphql_where_builder(self): - """Build the GraphQL Where payload sent in the resolver from the SDK UserWhere.""" - return { - "apiKey": self.api_key, - "email": self.email, - "organization": {"id": self.organization_id}, - } - - -class UserQuery(GraphQLQuery): - """User query.""" - - @staticmethod - def query(fragment): - """Return the GraphQL users query.""" - return f""" - query users($where: UserWhere!, $first: PageSize!, $skip: Int!) {{ - data: users(where: $where, first: $first, skip: $skip) {{ - {fragment} - }} - }} - """ - - COUNT_QUERY = """ - query countUsers($where: UserWhere!) { - data: countUsers(where: $where) - } - """ - - -GQL_ME = """ -query Me { - data: me { - id - email - } -} -""" diff --git a/src/kili/domain/user.py b/src/kili/domain/user.py new file mode 100644 index 000000000..44a15d697 --- /dev/null +++ b/src/kili/domain/user.py @@ -0,0 +1,21 @@ +"""User domain.""" +from dataclasses import dataclass +from typing import List, Literal, NewType, Optional + +from .organization import OrganizationId + +UserId = NewType("UserId", str) + +HubspotSubscriptionStatus = Literal["SUBSCRIBED", "UNSUBSCRIBED"] + + +@dataclass +class UserFilter: + """User filters for running a users search.""" + + id: Optional[UserId] # noqa: A003 + activated: Optional[bool] = None + api_key: Optional[str] = None + email: Optional[str] = None + id_in: Optional[List[UserId]] = None + organization_id: Optional[OrganizationId] = None diff --git a/src/kili/entrypoints/mutations/label/__init__.py b/src/kili/entrypoints/mutations/label/__init__.py index 67ad2ba0d..414672150 100644 --- a/src/kili/entrypoints/mutations/label/__init__.py +++ b/src/kili/entrypoints/mutations/label/__init__.py @@ -132,8 +132,9 @@ def append_to_labels( project_id: Optional[str] = None, seconds_to_label: Optional[int] = 0, ) -> Dict[Literal["id"], str]: - """!!! danger "[DEPRECATED]" - append_to_labels method is deprecated. Please use append_labels instead. + """!!! danger "[DEPRECATED]". + + append_to_labels method is deprecated. Please use append_labels instead. This new function allows to import several labels 10 times faster. Append a label to an asset. @@ -158,7 +159,7 @@ def append_to_labels( >>> kili.append_to_labels(label_asset_id=asset_id, json_response={...}) """ if author_id is None: - user = self.get_user() # type: ignore # pylint: disable=no-member + user = self.kili_api_gateway.get_current_user(fields=("id",)) author_id = user["id"] check_asset_identifier_arguments( diff --git a/src/kili/entrypoints/mutations/user/__init__.py b/src/kili/entrypoints/mutations/user/__init__.py deleted file mode 100644 index 5720df238..000000000 --- a/src/kili/entrypoints/mutations/user/__init__.py +++ /dev/null @@ -1,120 +0,0 @@ -"""User mutations.""" - -from typing import Any, Dict, Literal, Optional - -from typeguard import typechecked - -from kili.entrypoints.base import BaseOperationEntrypointMixin -from kili.entrypoints.mutations.user.queries import ( - GQL_CREATE_USER, - GQL_UPDATE_PASSWORD, - GQL_UPDATE_PROPERTIES_IN_USER, -) - - -class MutationsUser(BaseOperationEntrypointMixin): - """Set of User mutations.""" - - # pylint: disable=too-many-arguments - @typechecked - def create_user( - self, - email: str, - password: str, - organization_role: str, - firstname: Optional[str] = None, - lastname: Optional[str] = None, - ) -> Dict[Literal["id"], str]: - """Add a user to your organization. - - Args: - email: Email of the new user, used as user's unique identifier. - password: On the first sign in, he will use this password and be able to change it. - organization_role: One of "ADMIN", "USER". - firstname: First name of the new user. - lastname: Last name of the new user. - - Returns: - A dictionary with the id of the new user. - """ - variables = { - "data": { - "email": email, - "password": password, - "organizationRole": organization_role, - } - } - if firstname is not None: - variables["data"]["firstname"] = firstname - if lastname is not None: - variables["data"]["lastname"] = lastname - result = self.graphql_client.execute(GQL_CREATE_USER, variables) - return self.format_result("data", result) - - @typechecked - def update_password( - self, email: str, old_password: str, new_password_1: str, new_password_2: str - ) -> Dict[Literal["id"], str]: - """Allow to modify the password that you use to connect to Kili. - - This resolver only works for on-premise installations without Auth0. - - Args: - email: Email of the person whose password has to be updated. - old_password: The old password - new_password_1: The new password - new_password_2: A confirmation field for the new password - - Returns: - A dict with the user id. - """ - variables = { - "data": { - "oldPassword": old_password, - "newPassword1": new_password_1, - "newPassword2": new_password_2, - }, - "where": {"email": email}, - } - result = self.graphql_client.execute(GQL_UPDATE_PASSWORD, variables) - return self.format_result("data", result) - - @typechecked - def update_properties_in_user( - self, - email: str, - firstname: Optional[str] = None, - lastname: Optional[str] = None, - organization_id: Optional[str] = None, - organization_role: Optional[str] = None, - activated: Optional[bool] = None, - ) -> Dict[Literal["id"], str]: - """Update the properties of a user. - - Args: - email: The email is the identifier of the user. - firstname:Change the first name of the user. - lastname: Change the last name of the user. - organization_id: Change the organization the user is related to. - organization_role: Change the role of the user. - One of "ADMIN", "TEAM_MANAGER", "REVIEWER", "LABELER". - activated: In case we want to deactivate a user, but keep it. - - Returns: - A dict with the user id. - """ - variables: Dict[str, Any] = { - "email": email, - } - if firstname is not None: - variables["firstname"] = firstname - if lastname is not None: - variables["lastname"] = lastname - if organization_id is not None: - variables["organizationId"] = organization_id - if organization_role is not None: - variables["organizationRole"] = organization_role - if activated is not None: - variables["activated"] = activated - result = self.graphql_client.execute(GQL_UPDATE_PROPERTIES_IN_USER, variables) - return self.format_result("data", result) diff --git a/src/kili/entrypoints/mutations/user/fragments.py b/src/kili/entrypoints/mutations/user/fragments.py deleted file mode 100644 index b241ce55c..000000000 --- a/src/kili/entrypoints/mutations/user/fragments.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Fragments of user mutations.""" - - -USER_FRAGMENT = """ -id -""" diff --git a/src/kili/entrypoints/mutations/user/queries.py b/src/kili/entrypoints/mutations/user/queries.py deleted file mode 100644 index 1771de2c3..000000000 --- a/src/kili/entrypoints/mutations/user/queries.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Queries of user mutations.""" - -from .fragments import USER_FRAGMENT - -GQL_CREATE_USER = f""" -mutation( - $data: CreateUserData! -) {{ - data: createUser( - data: $data - ) {{ - {USER_FRAGMENT} - }} -}} -""" - -GQL_UPDATE_PASSWORD = f""" -mutation( - $data: UpdatePasswordData! - $where: UserWhere! -) {{ - data: updatePassword( - data: $data - where: $where - ) {{ - {USER_FRAGMENT} - }} -}} -""" - -GQL_RESET_PASSWORD = f""" -mutation($where: UserWhere!) {{ - data: resetPassword(where: $where) {{ - {USER_FRAGMENT} - }} -}} -""" - -GQL_UPDATE_PROPERTIES_IN_USER = f""" -mutation( - $email: String! - $firstname: String - $lastname: String - $organizationId: String - $organizationRole: OrganizationRole - $activated: Boolean -) {{ - data: updatePropertiesInUser( - where: {{email: $email}} - data: {{ - firstname: $firstname - lastname: $lastname - email: $email - organizationId: $organizationId - organizationRole: $organizationRole - activated: $activated - }} - ) {{ - {USER_FRAGMENT} - }} -}} -""" diff --git a/src/kili/entrypoints/queries/user/__init__.py b/src/kili/entrypoints/queries/user/__init__.py deleted file mode 100644 index a7ba5907e..000000000 --- a/src/kili/entrypoints/queries/user/__init__.py +++ /dev/null @@ -1,116 +0,0 @@ -"""User queries.""" - -from typing import Dict, Generator, Iterable, List, Literal, Optional, overload - -from typeguard import typechecked - -from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions -from kili.core.graphql.operations.user.queries import UserQuery, UserWhere -from kili.domain.types import ListOrTuple -from kili.entrypoints.base import BaseOperationEntrypointMixin -from kili.presentation.client.helpers.common_validators import ( - disable_tqdm_if_as_generator, -) -from kili.utils.logcontext import for_all_methods, log_call - - -@for_all_methods(log_call, exclude=["__init__"]) -class QueriesUser(BaseOperationEntrypointMixin): - """Set of User queries.""" - - # pylint: disable=too-many-arguments - - @overload - def users( - self, - api_key: Optional[str] = None, - email: Optional[str] = None, - organization_id: Optional[str] = None, - fields: ListOrTuple[str] = ("email", "id", "firstname", "lastname"), - first: Optional[int] = None, - skip: int = 0, - disable_tqdm: Optional[bool] = None, - *, - as_generator: Literal[True], - ) -> Generator[Dict, None, None]: ... - - @overload - def users( - self, - api_key: Optional[str] = None, - email: Optional[str] = None, - organization_id: Optional[str] = None, - fields: ListOrTuple[str] = ("email", "id", "firstname", "lastname"), - first: Optional[int] = None, - skip: int = 0, - disable_tqdm: Optional[bool] = None, - *, - as_generator: Literal[False] = False, - ) -> List[Dict]: ... - - @typechecked - def users( - self, - api_key: Optional[str] = None, - email: Optional[str] = None, - organization_id: Optional[str] = None, - fields: ListOrTuple[str] = ("email", "id", "firstname", "lastname"), - first: Optional[int] = None, - skip: int = 0, - disable_tqdm: Optional[bool] = None, - *, - as_generator: bool = False, - ) -> Iterable[Dict]: - # pylint: disable=line-too-long - """Get a generator or a list of users given a set of criteria. - - Args: - api_key: Query an user by its API Key - email: Email of the user - organization_id: Identifier of the user's organization - fields: All the fields to request among the possible fields for the users. - See [the documentation](https://docs.kili-technology.com/reference/graphql-api#user) for all possible fields. - first: Maximum number of users to return - skip: Number of skipped users (they are ordered by creation date) - disable_tqdm: If `True`, the progress bar will be disabled - as_generator: If `True`, a generator on the users is returned. - - Returns: - An iterable of users. - - Examples: - ``` - # List all users in my organization - >>> organization = kili.organizations()[0] - >>> organization_id = organization['id'] - >>> kili.users(organization_id=organization_id) - ``` - """ - where = UserWhere(api_key=api_key, email=email, organization_id=organization_id) - disable_tqdm = disable_tqdm_if_as_generator(as_generator, disable_tqdm) - options = QueryOptions(disable_tqdm, first, skip) - users_gen = UserQuery(self.graphql_client, self.http_client)(where, fields, options) - - if as_generator: - return users_gen - return list(users_gen) - - @typechecked - def count_users( - self, - organization_id: Optional[str] = None, - api_key: Optional[str] = None, - email: Optional[str] = None, - ) -> int: - """Get user count based on a set of constraints. - - Args: - organization_id: Identifier of the user's organization. - api_key: Filter by API Key. - email: Filter by email. - - Returns: - The number of organizations with the parameters provided. - """ - where = UserWhere(api_key=api_key, email=email, organization_id=organization_id) - return UserQuery(self.graphql_client, self.http_client).count(where) diff --git a/src/kili/exceptions.py b/src/kili/exceptions.py index 3fa94fc20..0d014e900 100644 --- a/src/kili/exceptions.py +++ b/src/kili/exceptions.py @@ -63,7 +63,3 @@ class MissingArgumentError(ValueError): class IncompatibleArgumentsError(ValueError): """Raised when the user gave at least two incompatible arguments.""" - - -class UserNotFoundError(Exception): - """Raised when the user is not found.""" diff --git a/src/kili/presentation/client/internal.py b/src/kili/presentation/client/internal.py index b05b64b42..462d8063f 100644 --- a/src/kili/presentation/client/internal.py +++ b/src/kili/presentation/client/internal.py @@ -6,11 +6,11 @@ from kili.adapters.kili_api_gateway import KiliAPIGateway from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions +from kili.adapters.kili_api_gateway.user.operations import get_reset_password_mutation from kili.domain.api_key import ApiKeyFilters from kili.domain.types import ListOrTuple from kili.entrypoints.mutations.organization import MutationsOrganization from kili.entrypoints.mutations.project.queries import GQL_DELETE_PROJECT -from kili.entrypoints.mutations.user.queries import GQL_RESET_PASSWORD from kili.use_cases.api_key import ApiKeyUseCases @@ -41,8 +41,9 @@ def reset_password(self, email: str): A result object which indicates if the mutation was successful, or an error message. """ + query = get_reset_password_mutation(fragment="id") variables = {"where": {"email": email}} - result = self.kili_api_gateway.graphql_client.execute(GQL_RESET_PASSWORD, variables) + result = self.kili_api_gateway.graphql_client.execute(query, variables) return result["data"] @typechecked diff --git a/src/kili/presentation/client/user.py b/src/kili/presentation/client/user.py new file mode 100644 index 000000000..59f3e1e23 --- /dev/null +++ b/src/kili/presentation/client/user.py @@ -0,0 +1,229 @@ +"""Client presentation methods for users.""" + +from typing import Dict, Generator, Iterable, List, Literal, Optional, overload + +from typeguard import typechecked + +from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions +from kili.core.enums import OrganizationRole +from kili.domain.organization import OrganizationId +from kili.domain.types import ListOrTuple +from kili.domain.user import UserFilter +from kili.presentation.client.helpers.common_validators import ( + disable_tqdm_if_as_generator, +) +from kili.use_cases.user import UserUseCases +from kili.utils.logcontext import for_all_methods, log_call + +from .base import BaseClientMethods + + +@for_all_methods(log_call, exclude=["__init__"]) +class UserClientMethods(BaseClientMethods): + """Methods attached to the Kili client, to run actions on users.""" + + @overload + def users( + self, + api_key: Optional[str] = None, + email: Optional[str] = None, + organization_id: Optional[str] = None, + fields: ListOrTuple[str] = ("email", "id", "firstname", "lastname"), + first: Optional[int] = None, + skip: int = 0, + disable_tqdm: Optional[bool] = None, + *, + as_generator: Literal[True], + ) -> Generator[Dict, None, None]: ... + + @overload + def users( + self, + api_key: Optional[str] = None, + email: Optional[str] = None, + organization_id: Optional[str] = None, + fields: ListOrTuple[str] = ("email", "id", "firstname", "lastname"), + first: Optional[int] = None, + skip: int = 0, + disable_tqdm: Optional[bool] = None, + *, + as_generator: Literal[False] = False, + ) -> List[Dict]: ... + + @typechecked + def users( + self, + api_key: Optional[str] = None, + email: Optional[str] = None, + organization_id: Optional[str] = None, + fields: ListOrTuple[str] = ("email", "id", "firstname", "lastname"), + first: Optional[int] = None, + skip: int = 0, + disable_tqdm: Optional[bool] = None, + *, + as_generator: bool = False, + ) -> Iterable[Dict]: + # pylint: disable=line-too-long + """Get a generator or a list of users given a set of criteria. + + Args: + api_key: Query an user by its API Key + email: Email of the user + organization_id: Identifier of the user's organization + fields: All the fields to request among the possible fields for the users. + See [the documentation](https://docs.kili-technology.com/reference/graphql-api#user) for all possible fields. + first: Maximum number of users to return + skip: Number of skipped users (they are ordered by creation date) + disable_tqdm: If `True`, the progress bar will be disabled + as_generator: If `True`, a generator on the users is returned. + + Returns: + An iterable of users. + + Examples: + ``` + # List all users in my organization + >>> organization = kili.organizations()[0] + >>> organization_id = organization['id'] + >>> kili.users(organization_id=organization_id) + ``` + """ + disable_tqdm = disable_tqdm_if_as_generator(as_generator, disable_tqdm) + + users_gen = UserUseCases(self.kili_api_gateway).list_users( + filters=UserFilter( + api_key=api_key, + email=email, + organization_id=OrganizationId(organization_id) if organization_id else None, + activated=None, + id=None, + id_in=None, + ), + fields=fields, + options=QueryOptions(disable_tqdm, first, skip), + ) + + if as_generator: + return users_gen + return list(users_gen) + + @typechecked + def count_users( + self, + organization_id: Optional[str] = None, + api_key: Optional[str] = None, + email: Optional[str] = None, + ) -> int: + """Get user count based on a set of constraints. + + Args: + organization_id: Identifier of the user's organization. + api_key: Filter by API Key. + email: Filter by email. + + Returns: + The number of organizations with the parameters provided. + """ + return UserUseCases(self.kili_api_gateway).count_users( + UserFilter( + api_key=api_key, + email=email, + organization_id=OrganizationId(organization_id) if organization_id else None, + activated=None, + id=None, + id_in=None, + ) + ) + + @typechecked + def create_user( + self, + email: str, + password: str, + organization_role: OrganizationRole, + firstname: Optional[str] = None, + lastname: Optional[str] = None, + ) -> Dict[Literal["id"], str]: + """Add a user to your organization. + + Args: + email: Email of the new user, used as user's unique identifier. + password: On the first sign in, he will use this password and be able to change it. + organization_role: One of "ADMIN", "USER". + firstname: First name of the new user. + lastname: Last name of the new user. + + Returns: + A dictionary with the id of the new user. + """ + return UserUseCases(self.kili_api_gateway).create_user( + email=email, + password=password, + organization_role=organization_role, + firstname=firstname, + lastname=lastname, + fields=("id",), + ) + + @typechecked + def update_password( + self, email: str, old_password: str, new_password_1: str, new_password_2: str + ) -> Dict[Literal["id"], str]: + """Allow to modify the password that you use to connect to Kili. + + This resolver only works for on-premise installations without Auth0. + + Args: + email: Email of the person whose password has to be updated. + old_password: The old password + new_password_1: The new password + new_password_2: A confirmation field for the new password + + Returns: + A dict with the user id. + """ + return UserUseCases(self.kili_api_gateway).update_password( + old_password=old_password, + new_password_1=new_password_1, + new_password_2=new_password_2, + user_filter=UserFilter( + email=email, activated=None, api_key=None, id=None, id_in=None, organization_id=None + ), + fields=("id",), + ) + + @typechecked + def update_properties_in_user( + self, + email: str, + firstname: Optional[str] = None, + lastname: Optional[str] = None, + organization_id: Optional[str] = None, + organization_role: Optional[OrganizationRole] = None, + activated: Optional[bool] = None, + ) -> Dict[Literal["id"], str]: + """Update the properties of a user. + + Args: + email: The email is the identifier of the user. + firstname: Change the first name of the user. + lastname: Change the last name of the user. + organization_id: Change the organization the user is related to. + organization_role: Change the role of the user. + One of "ADMIN", "TEAM_MANAGER", "REVIEWER", "LABELER". + activated: In case we want to deactivate a user, but keep it. + + Returns: + A dict with the user id. + """ + return UserUseCases(self.kili_api_gateway).update_user( + user_filter=UserFilter( + email=email, activated=None, api_key=None, id=None, id_in=None, organization_id=None + ), + firstname=firstname, + lastname=lastname, + organization_id=OrganizationId(organization_id) if organization_id else None, + organization_role=organization_role, + activated=activated, + fields=("id",), + ) diff --git a/src/kili/services/asset_import/base.py b/src/kili/services/asset_import/base.py index a1d70eb6f..565aaec50 100644 --- a/src/kili/services/asset_import/base.py +++ b/src/kili/services/asset_import/base.py @@ -8,7 +8,7 @@ from itertools import repeat from json import dumps from pathlib import Path -from typing import Callable, List, NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, List, NamedTuple, Optional, Tuple, Union from uuid import uuid4 from tenacity import Retrying @@ -43,6 +43,9 @@ from kili.utils import bucket from kili.utils.tqdm import tqdm +if TYPE_CHECKING: + from kili.client import Kili + class BatchParams(NamedTuple): """Contains all parameters related to the batch to import.""" @@ -75,7 +78,7 @@ class BaseBatchImporter: # pylint: disable=too-many-instance-attributes """Base class for BatchImporters.""" def __init__( - self, kili, project_params: ProjectParams, batch_params: BatchParams, pbar: tqdm + self, kili: "Kili", project_params: ProjectParams, batch_params: BatchParams, pbar: tqdm ) -> None: self.kili = kili self.project_id = project_params.project_id @@ -86,7 +89,7 @@ def __init__( self.http_client = kili.http_client logging.basicConfig() - self.logger = logging.getLogger("kili.services.asset_import.base") + self.logger = logging.getLogger(__name__) self.logger.setLevel(logging.INFO) def import_batch(self, assets: ListOrTuple[AssetLike], verify: bool) -> List[str]: @@ -342,7 +345,7 @@ class BaseAbstractAssetImporter(abc.ABC): def __init__( self, - kili, + kili: "Kili", project_params: ProjectParams, processing_params: ProcessingParams, logger_params: LoggerParams, @@ -379,7 +382,7 @@ def is_hosted_content(assets: List[AssetLike]): return False def _can_upload_from_local_data(self): - user_me = self.kili.get_user() + user_me = self.kili.kili_api_gateway.get_current_user(fields=("email",)) where = OrganizationWhere( email=user_me["email"], ) diff --git a/src/kili/use_cases/user/__init__.py b/src/kili/use_cases/user/__init__.py new file mode 100644 index 000000000..01a4cf75f --- /dev/null +++ b/src/kili/use_cases/user/__init__.py @@ -0,0 +1,95 @@ +"""User use cases.""" +from typing import Dict, Generator, Optional + +from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions +from kili.adapters.kili_api_gateway.user.types import ( + CreateUserDataKiliGatewayInput, + UserDataKiliGatewayInput, +) +from kili.core.enums import OrganizationRole +from kili.domain.organization import OrganizationId +from kili.domain.types import ListOrTuple +from kili.domain.user import UserFilter +from kili.use_cases.base import BaseUseCases + + +class UserUseCases(BaseUseCases): + """User use cases.""" + + def list_users( + self, filters: UserFilter, fields: ListOrTuple[str], options: QueryOptions + ) -> Generator[Dict, None, None]: + """List all users.""" + return self._kili_api_gateway.list_users( + fields=fields, user_filters=filters, options=options + ) + + def count_users(self, filters: UserFilter) -> int: + """Count users.""" + return self._kili_api_gateway.count_users(user_filters=filters) + + def create_user( + self, + email: str, + password: str, + organization_role: OrganizationRole, + firstname: Optional[str], + lastname: Optional[str], + fields: ListOrTuple[str], + ) -> Dict: + """Create a user.""" + return self._kili_api_gateway.create_user( + data=CreateUserDataKiliGatewayInput( + email=email, + password=password, + organization_role=organization_role, + firstname=firstname, + lastname=lastname, + ), + fields=fields, + ) + + def update_password( + self, + old_password: str, + new_password_1: str, + new_password_2: str, + user_filter: UserFilter, + fields: ListOrTuple[str], + ) -> Dict: + """Update user password.""" + return self._kili_api_gateway.update_password( + old_password=old_password, + new_password_1=new_password_1, + new_password_2=new_password_2, + user_filter=user_filter, + fields=fields, + ) + + def update_user( + self, + user_filter: UserFilter, + firstname: Optional[str], + lastname: Optional[str], + organization_id: Optional[OrganizationId], + organization_role: Optional[OrganizationRole], + activated: Optional[bool], + fields: ListOrTuple[str], + ) -> Dict: + """Update user.""" + return self._kili_api_gateway.update_user( + user_filter=user_filter, + data=UserDataKiliGatewayInput( + activated=activated, + firstname=firstname, + lastname=lastname, + organization_id=organization_id, + organization_role=organization_role, + api_key=None, + email=None, + has_completed_labeling_tour=None, + hubspot_subscription_status=None, + organization=None, + ), + fields=fields, + ) diff --git a/tests/e2e/test_e2e_graphql_client.py b/tests/e2e/test_e2e_graphql_client.py index 9441cebc8..64bbdff6a 100644 --- a/tests/e2e/test_e2e_graphql_client.py +++ b/tests/e2e/test_e2e_graphql_client.py @@ -7,8 +7,8 @@ from gql.transport import exceptions from graphql import build_ast_schema, parse +from kili.adapters.kili_api_gateway.user.operations import get_current_user_query from kili.client import Kili -from kili.core.graphql.operations.user.queries import GQL_ME from kili.exceptions import GraphQLError @@ -88,7 +88,8 @@ def test_kili_client_can_be_used_in_multiple_threads(): NB_THREADS = 10 with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor: - futures = [executor.submit(kili.graphql_client.execute, GQL_ME) for _ in range(NB_THREADS)] + query = get_current_user_query("id email") + futures = [executor.submit(kili.graphql_client.execute, query) for _ in range(NB_THREADS)] for future in concurrent.futures.as_completed(futures): assert future.result()["data"]["id"] assert future.result()["data"]["email"] diff --git a/tests/e2e/test_query_project_users.py b/tests/e2e/test_query_project_users.py index 3e3469946..d8610d30a 100644 --- a/tests/e2e/test_query_project_users.py +++ b/tests/e2e/test_query_project_users.py @@ -30,7 +30,7 @@ def test_given_project_when_querying_project_users_it_works( ): # Given project_id, suspended_user_email = project_id_suspended_user_email - api_user = kili.get_user() + api_user = kili.kili_api_gateway.get_current_user(fields=("email",)) fields = ["activated", "deletedAt", "id", "role", "user.email", "user.id", "status"] # When diff --git a/tests/integration/adapters/kili_api_gateway/test_user.py b/tests/integration/adapters/kili_api_gateway/test_user.py new file mode 100644 index 000000000..0d7f44fc3 --- /dev/null +++ b/tests/integration/adapters/kili_api_gateway/test_user.py @@ -0,0 +1,169 @@ +import pytest_mock + +from kili.adapters.http_client import HttpClient +from kili.adapters.kili_api_gateway import KiliAPIGateway +from kili.adapters.kili_api_gateway.helpers.queries import ( + PaginatedGraphQLQuery, + QueryOptions, +) +from kili.adapters.kili_api_gateway.user.operations import ( + GQL_COUNT_USERS, + get_create_user_mutation, + get_current_user_query, + get_update_user_mutation, + get_users_query, +) +from kili.adapters.kili_api_gateway.user.types import ( + CreateUserDataKiliGatewayInput, + UserDataKiliGatewayInput, +) +from kili.core.graphql.graphql_client import GraphQLClient +from kili.domain.user import UserFilter, UserId + + +def test_given_kili_gateway_when_querying_users_list_it_calls_proper_resolver( + graphql_client: GraphQLClient, http_client: HttpClient, mocker: pytest_mock.MockerFixture +): + # Given + mocker.patch.object(PaginatedGraphQLQuery, "get_number_of_elements_to_query", return_value=1) + graphql_client.execute.return_value = {"data": [{"email": "fake_email"}]} + kili_gateway = KiliAPIGateway(graphql_client=graphql_client, http_client=http_client) + + # When + users_gen = kili_gateway.list_users( + UserFilter(id=UserId("fake_user_id")), + fields=("email",), + options=QueryOptions(disable_tqdm=True), + ) + _ = list(users_gen) + + # Then + graphql_client.execute.assert_called_once_with( + get_users_query(" email"), + { + "where": { + "activated": None, + "apiKey": None, + "email": None, + "id": "fake_user_id", + "idIn": None, + "organization": {"id": None}, + }, + "skip": 0, + "first": 1, + }, + ) + + +def test_given_kili_gateway_when_querying_count_users_it_calls_proper_resolver( + graphql_client: GraphQLClient, http_client: HttpClient +): + # Given + graphql_client.execute.return_value = {"data": 42} + kili_gateway = KiliAPIGateway(graphql_client=graphql_client, http_client=http_client) + + # When + users_count = kili_gateway.count_users(UserFilter(id=UserId("fake_user_id"))) + + # Then + assert users_count == 42 + graphql_client.execute.assert_called_once_with( + GQL_COUNT_USERS, + { + "activated": None, + "apiKey": None, + "email": None, + "id": "fake_user_id", + "idIn": None, + "organization": {"id": None}, + }, + ) + + +def test_given_kili_gateway_when_querying_current_users_it_calls_proper_resolver( + graphql_client: GraphQLClient, http_client: HttpClient +): + # Given + graphql_client.execute.return_value = {"data": {"id": "current_user_id"}} + kili_gateway = KiliAPIGateway(graphql_client=graphql_client, http_client=http_client) + + # When + current_user = kili_gateway.get_current_user(fields=("id",)) + + # Then + assert current_user == {"id": "current_user_id"} + graphql_client.execute.assert_called_once_with( + get_current_user_query(" id"), + ) + + +def test_given_kili_gateway_when_creating_user_it_calls_proper_resolver( + graphql_client: GraphQLClient, http_client: HttpClient +): + # Given + kili_gateway = KiliAPIGateway(graphql_client=graphql_client, http_client=http_client) + + # When + _ = kili_gateway.create_user( + fields=("id",), + data=CreateUserDataKiliGatewayInput( + email="fake@email.com", + firstname="john", + lastname="doe", + password="fake_pass", + organization_role="USER", + ), + ) + + # Then + graphql_client.execute.assert_called_once_with( + get_create_user_mutation(" id"), + { + "email": "fake@email.com", + "firstname": "john", + "lastname": "doe", + "password": "fake_pass", + "organizationRole": "USER", + }, + ) + + +def test_given_kili_gateway_when_updating_user_it_calls_proper_resolver( + graphql_client: GraphQLClient, http_client: HttpClient +): + # Given + kili_gateway = KiliAPIGateway(graphql_client=graphql_client, http_client=http_client) + + # When + _ = kili_gateway.update_user( + user_filter=UserFilter(id=UserId("fake_user_id")), + fields=("id",), + data=UserDataKiliGatewayInput(organization_role="USER"), + ) + + # Then + graphql_client.execute.assert_called_once_with( + get_update_user_mutation(" id"), + { + "data": { + "activated": None, + "apiKey": None, + "email": None, + "firstname": None, + "hasCompletedLabelingTour": None, + "hubspotSubscriptionStatus": None, + "lastname": None, + "organization": None, + "organizationId": None, + "organizationRole": "USER", + }, + "where": { + "activated": None, + "apiKey": None, + "email": None, + "id": "fake_user_id", + "idIn": None, + "organization": {"id": None}, + }, + }, + ) diff --git a/tests/integration/entrypoints/client/queries/test_users.py b/tests/integration/entrypoints/client/queries/test_users.py deleted file mode 100644 index f82f122cc..000000000 --- a/tests/integration/entrypoints/client/queries/test_users.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Dict, Generator, List -from unittest.mock import patch - -import pytest -from typeguard import check_type - -from kili.core.graphql.operations.user.queries import UserQuery -from kili.entrypoints.queries.user import QueriesUser - - -@pytest.mark.parametrize( - ("args", "kwargs", "expected_return_type"), - [ - ((), {}, List[Dict]), - ((), {"as_generator": True}, Generator[Dict, None, None]), - ((), {"as_generator": False}, List[Dict]), - ((), {"email": "test@kili.com", "as_generator": False}, List[Dict]), - ], -) -@patch.object(UserQuery, "__call__") -def test_users_query_return_type(mocker, args, kwargs, expected_return_type): - kili = QueriesUser() - kili.graphql_client = mocker.MagicMock() - kili.http_client = mocker.MagicMock() - - result = kili.users(*args, **kwargs) - check_type(result, expected_return_type) diff --git a/tests/integration/presentation/test_user.py b/tests/integration/presentation/test_user.py new file mode 100644 index 000000000..2cc4127c9 --- /dev/null +++ b/tests/integration/presentation/test_user.py @@ -0,0 +1,37 @@ +from typing import Dict, Generator, List + +import pytest +from typeguard import check_type + +from kili.adapters.kili_api_gateway import KiliAPIGateway +from kili.presentation.client.user import UserClientMethods +from kili.use_cases.user import UserUseCases + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_return_type"), + [ + ((), {}, List[Dict]), + ((), {"as_generator": True}, Generator[Dict, None, None]), + ((), {"as_generator": False}, List[Dict]), + ((), {"email": "test@kili.com", "as_generator": False}, List[Dict]), + ], +) +def test_given_users_query_when_i_call_it_i_get_correct_return_type( + kili_api_gateway: KiliAPIGateway, mocker, args, kwargs, expected_return_type +): + # Given + mocker.patch.object( + UserUseCases, + "list_users", + return_value=(u for u in [{"id": "fake_user_id_1"}, {"id": "fake_user_id_2"}]), + ) + kili = UserClientMethods() + kili.kili_api_gateway = kili_api_gateway + + # When + result = kili.users(*args, **kwargs) + + # Then + check_type(result, expected_return_type) + assert list(result) == [{"id": "fake_user_id_1"}, {"id": "fake_user_id_2"}]