diff --git a/backend/connector_auth_v2/__init__.py b/backend/connector_auth_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/connector_auth_v2/admin.py b/backend/connector_auth_v2/admin.py new file mode 100644 index 000000000..014dfec3e --- /dev/null +++ b/backend/connector_auth_v2/admin.py @@ -0,0 +1,5 @@ +from django.contrib import admin + +from .models import ConnectorAuth + +admin.site.register(ConnectorAuth) diff --git a/backend/connector_auth_v2/apps.py b/backend/connector_auth_v2/apps.py new file mode 100644 index 000000000..e7cc819b7 --- /dev/null +++ b/backend/connector_auth_v2/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class ConnectorAuthConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "connector_auth_v2" diff --git a/backend/connector_auth_v2/constants.py b/backend/connector_auth_v2/constants.py new file mode 100644 index 000000000..886968d87 --- /dev/null +++ b/backend/connector_auth_v2/constants.py @@ -0,0 +1,18 @@ +class ConnectorAuthKey: + OAUTH_KEY = "oauth-key" + + +class SocialAuthConstants: + UID = "uid" + PROVIDER = "provider" + ACCESS_TOKEN = "access_token" + REFRESH_TOKEN = "refresh_token" + TOKEN_TYPE = "token_type" + AUTH_TIME = "auth_time" + EXPIRES = "expires" + + REFRESH_AFTER_FORMAT = "%d/%m/%Y %H:%M:%S" + REFRESH_AFTER = "refresh_after" # Timestamp to refresh tokens after + + GOOGLE_OAUTH = "google-oauth2" + GOOGLE_TOKEN_EXPIRY_FORMAT = "%d/%m/%Y %H:%M:%S" diff --git a/backend/connector_auth_v2/exceptions.py b/backend/connector_auth_v2/exceptions.py new file mode 100644 index 000000000..603bcc8d9 --- /dev/null +++ b/backend/connector_auth_v2/exceptions.py @@ -0,0 +1,31 @@ +from typing import Optional + +from rest_framework.exceptions import APIException + + +class CacheMissException(APIException): + status_code = 404 + default_detail = "Key doesn't exist." + + +class EnrichConnectorMetadataException(APIException): + status_code = 500 + default_detail = "Connector metadata could not be enriched" + + +class MissingParamException(APIException): + status_code = 400 + default_detail = "Bad request, missing parameter." 
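(Aside, not part of the diff.) The two `strftime` formats in `constants.py` above drive the token-refresh bookkeeping used later in this changeset: `pipeline/google.py` writes `refresh_after`, and `connector_v2/fields.py` re-parses it on every read. A minimal standalone sketch of that round trip, with an invented token lifetime:

```python
from datetime import datetime, timedelta

from connector_auth_v2.constants import SocialAuthConstants

# Provider reports the token lives roughly an hour from now (invented value)
token_expiry = datetime.now() + timedelta(seconds=3599)

# Stored into connector metadata as a string, e.g. "27/11/2024 10:15:30"
refresh_after_str = token_expiry.strftime(SocialAuthConstants.REFRESH_AFTER_FORMAT)

# Later reads parse it back and compare against "now" to decide on a refresh
refresh_after = datetime.strptime(
    refresh_after_str, SocialAuthConstants.REFRESH_AFTER_FORMAT
)
needs_refresh = datetime.now() > refresh_after
```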
+
+    def __init__(
+        self,
+        code: Optional[str] = None,
+        param: Optional[str] = None,
+    ) -> None:
+        detail = f"Bad request, missing parameter: {param}"
+        super().__init__(detail, code)
+
+
+class KeyNotConfigured(APIException):
+    status_code = 500
+    default_detail = "Key is not configured correctly"
diff --git a/backend/connector_auth_v2/models.py b/backend/connector_auth_v2/models.py
new file mode 100644
index 000000000..ca8caaefd
--- /dev/null
+++ b/backend/connector_auth_v2/models.py
@@ -0,0 +1,141 @@
+import logging
+import uuid
+from typing import Any, Optional
+
+from account_v2.models import User
+from connector_auth_v2.constants import SocialAuthConstants
+from connector_auth_v2.pipeline.google import GoogleAuthHelper
+from django.db import models
+from django.db.models.query import QuerySet
+from rest_framework.request import Request
+from social_django.fields import JSONField
+from social_django.models import AbstractUserSocialAuth, DjangoStorage
+from social_django.strategy import DjangoStrategy
+
+logger = logging.getLogger(__name__)
+
+
+class ConnectorAuthManager(models.Manager):
+    def get_queryset(self) -> QuerySet:
+        queryset = super().get_queryset()
+        # TODO PAN-83: Decrypt here
+        # for obj in queryset:
+        #     logger.info(f"Decrypting extra_data: {obj.extra_data}")
+
+        return queryset
+
+
+class ConnectorAuth(AbstractUserSocialAuth):
+    """Social Auth association model, stores tokens.
+    The relation with `account.User` is only for the library to work
+    and should NOT be used to access the secrets.
+    Use the following class methods instead
+    ```
+    @classmethod
+    def get_social_auth(cls, provider, id):
+
+    @classmethod
+    def create_social_auth(cls, user, uid, provider):
+    ```
+    """
+
+    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
+    user = models.ForeignKey(
+        User,
+        related_name="connector_auths",
+        on_delete=models.CASCADE,
+        null=True,
+    )
+
+    def __str__(self) -> str:
+        return f"ConnectorAuth(provider: {self.provider}, uid: {self.uid})"
+
+    def save(self, *args: Any, **kwargs: Any) -> Any:
+        # TODO PAN-83: Encrypt here
+        # logger.info(f"Encrypting extra_data: {self.extra_data}")
+        return super().save(*args, **kwargs)
+
+    def set_extra_data(self, extra_data=None):  # type: ignore
+        ConnectorAuth.check_credential_format(extra_data)
+        if extra_data[SocialAuthConstants.PROVIDER] == SocialAuthConstants.GOOGLE_OAUTH:
+            extra_data = GoogleAuthHelper.enrich_connector_metadata(extra_data)
+        return super().set_extra_data(extra_data)
+
+    def refresh_token(self, strategy, *args, **kwargs):  # type: ignore
+        """Override of Python Social Auth (PSA)'s refresh_token functionality
+        to store uid, provider."""
+        token = self.extra_data.get("refresh_token") or self.extra_data.get(
+            "access_token"
+        )
+        backend = self.get_backend_instance(strategy)
+        if token and backend and hasattr(backend, "refresh_token"):
+            response = backend.refresh_token(token, *args, **kwargs)
+            extra_data = backend.extra_data(self, self.uid, response, self.extra_data)
+            extra_data[SocialAuthConstants.PROVIDER] = backend.name
+            extra_data[SocialAuthConstants.UID] = self.uid
+            if self.set_extra_data(extra_data):  # type: ignore
+                self.save()
+
+    def get_and_refresh_tokens(
+        self, request: Optional[Request] = None
+    ) -> tuple[JSONField, bool]:
+        """Uses Social Auth's ability to refresh tokens if necessary.
+ + Returns: + Tuple[JSONField, bool]: JSONField of connector metadata + and flag indicating if tokens were refreshed + """ + # To avoid circular dependency error on import + from social_django.utils import load_strategy + + refreshed_token = False + strategy: DjangoStrategy = load_strategy(request=request) + existing_access_token = self.access_token + new_access_token = self.get_access_token(strategy) + if new_access_token != existing_access_token: + refreshed_token = True + related_connector_instances = self.connectorinstance_set.all() + for connector_instance in related_connector_instances: + connector_instance.connector_metadata = self.extra_data + connector_instance.save() + logger.info( + f"Refreshed access token for connector {connector_instance.id}, " + f"provider: {self.provider}, uid: {self.uid}" + ) + + return self.extra_data, refreshed_token + + @staticmethod + def check_credential_format( + oauth_credentials: dict[str, str], raise_exception: bool = True + ) -> bool: + if ( + SocialAuthConstants.PROVIDER in oauth_credentials + and SocialAuthConstants.UID in oauth_credentials + ): + return True + else: + if raise_exception: + raise ValueError( + "Auth credential should have provider, uid and connector guid" + ) + return False + + objects = ConnectorAuthManager() + + class Meta: + app_label = "connector_auth_v2" + verbose_name = "Connector Auth" + verbose_name_plural = "Connector Auths" + db_table = "connector_auth_v2" + constraints = [ + models.UniqueConstraint( + fields=[ + "provider", + "uid", + ], + name="unique_provider_uid_index", + ), + ] + + +class ConnectorDjangoStorage(DjangoStorage): + user = ConnectorAuth diff --git a/backend/connector_auth_v2/pipeline/common.py b/backend/connector_auth_v2/pipeline/common.py new file mode 100644 index 000000000..b6c4861c9 --- /dev/null +++ b/backend/connector_auth_v2/pipeline/common.py @@ -0,0 +1,111 @@ +import logging +from typing import Any, Optional + +from account_v2.models import User +from connector_auth_v2.constants import ConnectorAuthKey, SocialAuthConstants +from connector_auth_v2.models import ConnectorAuth +from connector_auth_v2.pipeline.google import GoogleAuthHelper +from django.conf import settings +from django.core.cache import cache +from rest_framework.exceptions import PermissionDenied +from social_core.backends.oauth import BaseOAuth2 + +logger = logging.getLogger(__name__) + + +def check_user_exists(backend: BaseOAuth2, user: User, **kwargs: Any) -> dict[str, str]: + """Checks if user is authenticated (will be handled in auth middleware, + present as a fail safe) + + Args: + user (account.User): User model + + Raises: + PermissionDenied: Unauthorized user + + Returns: + dict: Carrying response details for auth pipeline + """ + if not user: + raise PermissionDenied(backend) + return {**kwargs} + + +def cache_oauth_creds( + backend: BaseOAuth2, + details: dict[str, str], + response: dict[str, str], + uid: str, + user: User, + *args: Any, + **kwargs: Any, +) -> dict[str, str]: + """Used to cache the extra data JSON in redis against a key. + + This contains the access and refresh token along with details + regarding expiry, uid (unique ID given by provider) and provider. 
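(Aside, not part of the diff.) The model docstring above spells out the intended access pattern for `ConnectorAuth`: secrets are keyed on the `(provider, uid)` pair enforced by `unique_provider_uid_index`, never reached through the `User` relation. A sketch of the get-or-create flow that `pipeline/common.py` implements, with invented credential values:

```python
from account_v2.models import User
from connector_auth_v2.models import ConnectorAuth

# Invented credentials; auth_time/expires are needed for the Google provider
# because set_extra_data() enriches the metadata with expiry bookkeeping
creds = {
    "provider": "google-oauth2",
    "uid": "someone@example.com",
    "access_token": "ya29.dummy",
    "refresh_token": "1//dummy",
    "auth_time": "1718000000",  # epoch seconds, set by PSA
    "expires": "3599",
}
ConnectorAuth.check_credential_format(creds)  # ValueError if provider/uid missing

auth = ConnectorAuth.get_social_auth(provider=creds["provider"], uid=creds["uid"])
if not auth:
    user = User.objects.first()  # needed only at creation time
    auth = ConnectorAuth.create_social_auth(
        user, uid=creds["uid"], provider=creds["provider"]
    )
auth.set_extra_data(creds)  # Google expiry enrichment runs here
```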
+ """ + cache_key = kwargs.get("cache_key") or backend.strategy.session_get( + settings.SOCIAL_AUTH_FIELDS_STORED_IN_SESSION[0], + ConnectorAuthKey.OAUTH_KEY, + ) + extra_data = backend.extra_data(user, uid, response, details, *args, **kwargs) + extra_data[SocialAuthConstants.PROVIDER] = backend.name + extra_data[SocialAuthConstants.UID] = uid + + if backend.name == SocialAuthConstants.GOOGLE_OAUTH: + extra_data = GoogleAuthHelper.enrich_connector_metadata(extra_data) + + cache.set( + cache_key, + extra_data, + int(settings.SOCIAL_AUTH_EXTRA_DATA_EXPIRATION_TIME_IN_SECOND), + ) + return {**kwargs} + + +class ConnectorAuthHelper: + @staticmethod + def get_oauth_creds_from_cache( + cache_key: str, delete_key: bool = True + ) -> Optional[dict[str, str]]: + """Retrieves oauth credentials from the cache. + + Args: + cache_key (str): Key to obtain credentials from + + Returns: + Optional[dict[str,str]]: Returns credentials. None if it doesn't exist + """ + oauth_creds: dict[str, str] = cache.get(cache_key) + if delete_key: + cache.delete(cache_key) + return oauth_creds + + @staticmethod + def get_or_create_connector_auth( + oauth_credentials: dict[str, str], user: User = None # type: ignore + ) -> ConnectorAuth: + """Gets or creates a ConnectorAuth object. + + Args: + user (User): Used while creation, can be removed if not required + oauth_credentials (dict[str,str]): Needs to have provider and uid + + Returns: + ConnectorAuth: Object for the respective provider/uid + """ + ConnectorAuth.check_credential_format(oauth_credentials) + provider = oauth_credentials[SocialAuthConstants.PROVIDER] + uid = oauth_credentials[SocialAuthConstants.UID] + connector_oauth: ConnectorAuth = ConnectorAuth.get_social_auth( + provider=provider, uid=uid + ) + if not connector_oauth: + connector_oauth = ConnectorAuth.create_social_auth( + user, uid=uid, provider=provider + ) + + # TODO: Remove User's related manager access to ConnectorAuth + connector_oauth.set_extra_data(oauth_credentials) # type: ignore + return connector_oauth diff --git a/backend/connector_auth_v2/pipeline/google.py b/backend/connector_auth_v2/pipeline/google.py new file mode 100644 index 000000000..d585bd085 --- /dev/null +++ b/backend/connector_auth_v2/pipeline/google.py @@ -0,0 +1,33 @@ +from datetime import datetime, timedelta + +from connector_auth_v2.constants import SocialAuthConstants as AuthConstants +from connector_auth_v2.exceptions import EnrichConnectorMetadataException +from connector_processor.constants import ConnectorKeys + +from unstract.connectors.filesystems.google_drive.constants import GDriveConstants + + +class GoogleAuthHelper: + @staticmethod + def enrich_connector_metadata(kwargs: dict[str, str]) -> dict[str, str]: + token_expiry: datetime = datetime.now() + auth_time = kwargs.get(AuthConstants.AUTH_TIME) + expires = kwargs.get(AuthConstants.EXPIRES) + if auth_time and expires: + reference = datetime.utcfromtimestamp(float(auth_time)) + token_expiry = reference + timedelta(seconds=float(expires)) + else: + raise EnrichConnectorMetadataException + # Used by GDrive FS, apart from ACCESS_TOKEN and REFRESH_TOKEN + kwargs[GDriveConstants.TOKEN_EXPIRY] = token_expiry.strftime( + AuthConstants.GOOGLE_TOKEN_EXPIRY_FORMAT + ) + + # Used by Unstract + kwargs[ConnectorKeys.PATH] = ( + GDriveConstants.ROOT_PREFIX + ) # Acts as a prefix for all paths + kwargs[AuthConstants.REFRESH_AFTER] = token_expiry.strftime( + AuthConstants.REFRESH_AFTER_FORMAT + ) + return kwargs diff --git a/backend/connector_auth_v2/urls.py 
b/backend/connector_auth_v2/urls.py
new file mode 100644
index 000000000..55337ad20
--- /dev/null
+++ b/backend/connector_auth_v2/urls.py
@@ -0,0 +1,21 @@
+from django.urls import include, path, re_path
+from rest_framework.urlpatterns import format_suffix_patterns
+
+from .views import ConnectorAuthViewSet
+
+connector_auth_cache = ConnectorAuthViewSet.as_view(
+    {
+        "get": "cache_key",
+    }
+)
+
+urlpatterns = format_suffix_patterns(
+    [
+        path("oauth/", include("social_django.urls", namespace="social")),
+        re_path(
+            "^oauth/cache-key/(?P<backend>.+)$",
+            connector_auth_cache,
+            name="connector-cache",
+        ),
+    ]
+)
diff --git a/backend/connector_auth_v2/views.py b/backend/connector_auth_v2/views.py
new file mode 100644
index 000000000..1cc2022bf
--- /dev/null
+++ b/backend/connector_auth_v2/views.py
@@ -0,0 +1,48 @@
+import logging
+import uuid
+
+from connector_auth_v2.constants import SocialAuthConstants
+from connector_auth_v2.exceptions import KeyNotConfigured
+from django.conf import settings
+from rest_framework import status, viewsets
+from rest_framework.request import Request
+from rest_framework.response import Response
+from rest_framework.versioning import URLPathVersioning
+from utils.user_session import UserSessionUtils
+
+logger = logging.getLogger(__name__)
+
+
+class ConnectorAuthViewSet(viewsets.ViewSet):
+    """Contains methods for Connector related authentication."""
+
+    versioning_class = URLPathVersioning
+
+    def cache_key(
+        self: "ConnectorAuthViewSet", request: Request, backend: str
+    ) -> Response:
+        if backend == SocialAuthConstants.GOOGLE_OAUTH and (
+            settings.SOCIAL_AUTH_GOOGLE_OAUTH2_KEY is None
+            or settings.SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET is None
+        ):
+            msg = (
+                f"Keys not configured for {backend}, add env vars "
+                f"`GOOGLE_OAUTH2_KEY` and `GOOGLE_OAUTH2_SECRET`."
+            )
+            logger.warning(msg)
+            raise KeyNotConfigured(
+                f"{msg}\nRefer to: "
+                "https://developers.google.com/identity/protocols/oauth2#1.-"
+                "obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar."
+                "console_name-."
+ ) + + random = str(uuid.uuid4()) + user_id = request.user.user_id + org_id = UserSessionUtils.get_organization_id(request) + cache_key = f"oauth:{org_id}|{user_id}|{backend}|{random}" + logger.info(f"Generated cache key: {cache_key}") + return Response( + status=status.HTTP_200_OK, + data={"cache_key": f"{cache_key}"}, + ) diff --git a/backend/connector_v2/__init__.py b/backend/connector_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/connector_v2/admin.py b/backend/connector_v2/admin.py new file mode 100644 index 000000000..3750fb4c5 --- /dev/null +++ b/backend/connector_v2/admin.py @@ -0,0 +1,5 @@ +from django.contrib import admin + +from .models import ConnectorInstance + +admin.site.register(ConnectorInstance) diff --git a/backend/connector_v2/apps.py b/backend/connector_v2/apps.py new file mode 100644 index 000000000..646d84f0d --- /dev/null +++ b/backend/connector_v2/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class ConnectorConfig(AppConfig): + name = "connector_v2" diff --git a/backend/connector_v2/connector_instance_helper.py b/backend/connector_v2/connector_instance_helper.py new file mode 100644 index 000000000..e6a7d4307 --- /dev/null +++ b/backend/connector_v2/connector_instance_helper.py @@ -0,0 +1,330 @@ +import logging +from typing import Any, Optional + +from account_v2.models import User +from connector_v2.constants import ConnectorInstanceConstant +from connector_v2.models import ConnectorInstance +from connector_v2.unstract_account import UnstractAccount +from django.conf import settings +from utils.user_context import UserContext +from workflow_manager.workflow_v2.models.workflow import Workflow + +from unstract.connectors.filesystems.ucs import UnstractCloudStorage +from unstract.connectors.filesystems.ucs.constants import UCSKey + +logger = logging.getLogger(__name__) + + +class ConnectorInstanceHelper: + @staticmethod + def create_default_gcs_connector(workflow: Workflow, user: User) -> None: + """Method to create default storage connector. 
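(Aside.) Worth noting how the OAuth handshake pieces above fit together: `cache_key()` only mints a namespaced key; the social-auth pipeline's `cache_oauth_creds()` later stores the enriched `extra_data` under it, and the connector create API redeems it once. A hedged sketch of the redemption side, with an invented key:

```python
from connector_auth_v2.pipeline.common import ConnectorAuthHelper

# Shape produced by cache_key(): "oauth:{org_id}|{user_id}|{backend}|{uuid4}"
cache_key = "oauth:org1|user1|google-oauth2|6f3a..."  # invented

creds = ConnectorAuthHelper.get_oauth_creds_from_cache(
    cache_key=cache_key, delete_key=True  # single-use by design
)
if creds is None:
    # Never cached, expired (SOCIAL_AUTH_EXTRA_DATA_EXPIRATION_TIME_IN_SECOND)
    # or already redeemed
    raise RuntimeError("OAuth credentials not found; restart the OAuth flow")
```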
+
+        Args:
+            workflow (Workflow)
+            user (User)
+        """
+        organization_id = UserContext.get_organization_identifier()
+        if not user.project_storage_created:
+            logger.info("Creating default storage")
+            account = UnstractAccount(organization_id, user.email)
+            account.provision_s3_storage()
+            account.upload_sample_files()
+            user.project_storage_created = True
+            user.save()
+            logger.info("Default storage created successfully.")
+
+        logger.info("Adding connectors to Unstract")
+        connector_name = ConnectorInstanceConstant.USER_STORAGE
+        gcs_id = UnstractCloudStorage.get_id()
+        bucket_name = settings.UNSTRACT_FREE_STORAGE_BUCKET_NAME
+        base_path = f"{bucket_name}/{organization_id}/{user.email}"
+
+        connector_metadata = {
+            UCSKey.KEY: settings.GOOGLE_STORAGE_ACCESS_KEY_ID,
+            UCSKey.SECRET: settings.GOOGLE_STORAGE_SECRET_ACCESS_KEY,
+            UCSKey.BUCKET: bucket_name,
+            UCSKey.ENDPOINT_URL: settings.GOOGLE_STORAGE_BASE_URL,
+        }
+        connector_metadata__input = {
+            **connector_metadata,
+            UCSKey.PATH: base_path + "/input",
+        }
+        connector_metadata__output = {
+            **connector_metadata,
+            UCSKey.PATH: base_path + "/output",
+        }
+        ConnectorInstance.objects.create(
+            connector_name=connector_name,
+            workflow=workflow,
+            created_by=user,
+            connector_id=gcs_id,
+            connector_metadata=connector_metadata__input,
+            connector_type=ConnectorInstance.ConnectorType.INPUT,
+            connector_mode=ConnectorInstance.ConnectorMode.FILE_SYSTEM,
+        )
+        ConnectorInstance.objects.create(
+            connector_name=connector_name,
+            workflow=workflow,
+            created_by=user,
+            connector_id=gcs_id,
+            connector_metadata=connector_metadata__output,
+            connector_type=ConnectorInstance.ConnectorType.OUTPUT,
+            connector_mode=ConnectorInstance.ConnectorMode.FILE_SYSTEM,
+        )
+        logger.info("Connectors added successfully.")
+
+    @staticmethod
+    def get_connector_instances_by_workflow(
+        workflow_id: str,
+        connector_type: tuple[str, str],
+        connector_mode: Optional[tuple[int, str]] = None,
+        values: Optional[list[str]] = None,
+        connector_name: Optional[str] = None,
+    ) -> list[ConnectorInstance]:
+        """Method to get connector instances by workflow.
+
+        Args:
+            workflow_id (str)
+            connector_type (tuple[str, str]): Specifies input/output
+            connector_mode (Optional[tuple[int, str]], optional):
+                Specifies database/file
+            values (Optional[list[str]], optional): Defaults to None.
+            connector_name (Optional[str], optional): Defaults to None.
+
+        Returns:
+            list[ConnectorInstance]
+        """
+        logger.info(f"Filtering connector instances with mode: {connector_mode}")
+        filter_params: dict[str, Any] = {
+            "workflow": workflow_id,
+            "connector_type": connector_type,
+        }
+        if connector_mode is not None:
+            filter_params["connector_mode"] = connector_mode
+        if connector_name is not None:
+            filter_params["connector_name"] = connector_name
+
+        connector_instances = ConnectorInstance.objects.filter(**filter_params).all()
+        logger.debug(f"Retrieved connector instance values {connector_instances}")
+        if values is not None:
+            filtered_connector_instances = connector_instances.values(*values)
+            logger.info(
+                f"Returning filtered connector instance values "
+                f"{filtered_connector_instances}"
+            )
+            return list(filtered_connector_instances)
+        logger.info(f"Returning connector instances {connector_instances}")
+        return list(connector_instances)
+
+    @staticmethod
+    def get_connector_instance_by_workflow(
+        workflow_id: str,
+        connector_type: tuple[str, str],
+        connector_mode: Optional[tuple[int, str]] = None,
+        connector_name: Optional[str] = None,
+    ) -> Optional[ConnectorInstance]:
+        """Get one connector instance.
+
+        Use this method if the connector instance is unique for the given
+        filter_params.
+
+        Args:
+            workflow_id (str)
+            connector_type (tuple[str, str]): Specifies input/output
+            connector_mode (Optional[tuple[int, str]], optional):
+                Specifies database/filesystem.
+            connector_name (Optional[str], optional).
+
+        Returns:
+            Optional[ConnectorInstance]: Connector instance if found, else None
+        """
+        logger.info("Fetching connector instance by workflow")
+        filter_params: dict[str, Any] = {
+            "workflow": workflow_id,
+            "connector_type": connector_type,
+        }
+        if connector_mode is not None:
+            filter_params["connector_mode"] = connector_mode
+        if connector_name is not None:
+            filter_params["connector_name"] = connector_name
+
+        try:
+            connector_instance: Optional[ConnectorInstance] = (
+                ConnectorInstance.objects.filter(**filter_params).first()
+            )
+        except Exception as exc:
+            logger.error(f"Error occurred while fetching connector instances: {exc}")
+            raise exc
+
+        return connector_instance
+
+    @staticmethod
+    def get_input_connector_instance_by_name_for_workflow(
+        workflow_id: str,
+        connector_name: str,
+    ) -> Optional[ConnectorInstance]:
+        """Method to get an input connector instance by name for the workflow.
+
+        Args:
+            workflow_id (str)
+            connector_name (str)
+
+        Returns:
+            Optional[ConnectorInstance]
+        """
+        return ConnectorInstanceHelper.get_connector_instance_by_workflow(
+            workflow_id=workflow_id,
+            connector_type=ConnectorInstance.ConnectorType.INPUT,
+            connector_name=connector_name,
+        )
+
+    @staticmethod
+    def get_output_connector_instance_by_name_for_workflow(
+        workflow_id: str,
+        connector_name: str,
+    ) -> Optional[ConnectorInstance]:
+        """Method to get an output connector instance by name for the workflow.
+
+        Args:
+            workflow_id (str)
+            connector_name (str)
+
+        Returns:
+            Optional[ConnectorInstance]
+        """
+        return ConnectorInstanceHelper.get_connector_instance_by_workflow(
+            workflow_id=workflow_id,
+            connector_type=ConnectorInstance.ConnectorType.OUTPUT,
+            connector_name=connector_name,
+        )
+
+    @staticmethod
+    def get_input_connector_instances_by_workflow(
+        workflow_id: str,
+    ) -> list[ConnectorInstance]:
+        """Method to get input connector instances by workflow.
+
+        Args:
+            workflow_id (str)
+
+        Returns:
+            list[ConnectorInstance]
+        """
+        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
+            workflow_id, ConnectorInstance.ConnectorType.INPUT
+        )
+
+    @staticmethod
+    def get_output_connector_instances_by_workflow(
+        workflow_id: str,
+    ) -> list[ConnectorInstance]:
+        """Method to get output connector instances by workflow.
+
+        Args:
+            workflow_id (str)
+
+        Returns:
+            list[ConnectorInstance]
+        """
+        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
+            workflow_id, ConnectorInstance.ConnectorType.OUTPUT
+        )
+
+    @staticmethod
+    def get_file_system_input_connector_instances_by_workflow(
+        workflow_id: str, values: Optional[list[str]] = None
+    ) -> list[ConnectorInstance]:
+        """Method to fetch file system input connectors by workflow.
+
+        Args:
+            workflow_id (str)
+            values (Optional[list[str]], optional)
+
+        Returns:
+            list[ConnectorInstance]
+        """
+        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
+            workflow_id,
+            ConnectorInstance.ConnectorType.INPUT,
+            ConnectorInstance.ConnectorMode.FILE_SYSTEM,
+            values,
+        )
+
+    @staticmethod
+    def get_file_system_output_connector_instances_by_workflow(
+        workflow_id: str, values: Optional[list[str]] = None
+    ) -> list[ConnectorInstance]:
+        """Method to get file system output connectors by workflow.
+
+        Args:
+            workflow_id (str)
+            values (Optional[list[str]], optional)
+
+        Returns:
+            list[ConnectorInstance]
+        """
+        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
+            workflow_id,
+            ConnectorInstance.ConnectorType.OUTPUT,
+            ConnectorInstance.ConnectorMode.FILE_SYSTEM,
+            values,
+        )
+
+    @staticmethod
+    def get_database_input_connector_instances_by_workflow(
+        workflow_id: str, values: Optional[list[str]] = None
+    ) -> list[ConnectorInstance]:
+        """Method to fetch input database connectors by workflow.
+
+        Args:
+            workflow_id (str)
+            values (Optional[list[str]], optional)
+
+        Returns:
+            list[ConnectorInstance]
+        """
+        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
+            workflow_id,
+            ConnectorInstance.ConnectorType.INPUT,
+            ConnectorInstance.ConnectorMode.DATABASE,
+            values,
+        )
+
+    @staticmethod
+    def get_database_output_connector_instances_by_workflow(
+        workflow_id: str, values: Optional[list[str]] = None
+    ) -> list[ConnectorInstance]:
+        """Method to fetch output database connectors by workflow.
+
+        Args:
+            workflow_id (str)
+            values (Optional[list[str]], optional)
+
+        Returns:
+            list[ConnectorInstance]
+        """
+        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
+            workflow_id,
+            ConnectorInstance.ConnectorType.OUTPUT,
+            ConnectorInstance.ConnectorMode.DATABASE,
+            values,
+        )
+
+    @staticmethod
+    def get_input_output_connector_instances_by_workflow(
+        workflow_id: str,
+    ) -> list[ConnectorInstance]:
+        """Method to fetch input and output connectors by workflow.
+
+        Args:
+            workflow_id (str)
+
+        Returns:
+            list[ConnectorInstance]
+        """
+        filter_params: dict[str, Any] = {
+            "workflow": workflow_id,
+        }
+        connector_instances = ConnectorInstance.objects.filter(**filter_params).all()
+        return list(connector_instances)
diff --git a/backend/connector_v2/constants.py b/backend/connector_v2/constants.py
new file mode 100644
index 000000000..a02512720
--- /dev/null
+++ b/backend/connector_v2/constants.py
@@ -0,0 +1,16 @@
+class ConnectorInstanceKey:
+    CONNECTOR_ID = "connector_id"
+    CONNECTOR_NAME = "connector_name"
+    CONNECTOR_TYPE = "connector_type"
+    CONNECTOR_MODE = "connector_mode"
+    CONNECTOR_VERSION = "connector_version"
+    CONNECTOR_AUTH = "connector_auth"
+    CONNECTOR_METADATA = "connector_metadata"
+    CONNECTOR_EXISTS = (
+        "Connector with this configuration already exists in this project."
+    )
+    DUPLICATE_API = "It appears that a duplicate call may have been made."
+
+
+class ConnectorInstanceConstant:
+    USER_STORAGE = "User Storage"
diff --git a/backend/connector_v2/fields.py b/backend/connector_v2/fields.py
new file mode 100644
index 000000000..2a0f18c54
--- /dev/null
+++ b/backend/connector_v2/fields.py
@@ -0,0 +1,41 @@
+import logging
+from datetime import datetime
+
+from connector_auth_v2.constants import SocialAuthConstants
+from connector_auth_v2.models import ConnectorAuth
+from django.db import models
+
+logger = logging.getLogger(__name__)
+
+
+class ConnectorAuthJSONField(models.JSONField):
+    def from_db_value(self, value, expression, connection):  # type: ignore
+        """Overriding default function."""
+        metadata = super().from_db_value(value, expression, connection)
+        provider = metadata.get(SocialAuthConstants.PROVIDER)
+        uid = metadata.get(SocialAuthConstants.UID)
+        if not provider or not uid:
+            return metadata
+
+        refresh_after_str = metadata.get(SocialAuthConstants.REFRESH_AFTER)
+        if not refresh_after_str:
+            return metadata
+
+        refresh_after = datetime.strptime(
+            refresh_after_str, SocialAuthConstants.REFRESH_AFTER_FORMAT
+        )
+        if datetime.now() > refresh_after:
+            metadata = self._refresh_tokens(provider, uid)
+        return metadata
+
+    def _refresh_tokens(self, provider: str, uid: str) -> dict[str, str]:
+        """Retrieves PSA object and refreshes the token if necessary."""
+        connector_metadata: dict[str, str] = {}
+        connector_auth: ConnectorAuth = ConnectorAuth.get_social_auth(
+            provider=provider, uid=uid
+        )
+        if connector_auth:
+            (
+                connector_metadata,
+                _,
+            ) = connector_auth.get_and_refresh_tokens()
+        return connector_metadata  # type: ignore
diff --git a/backend/connector_v2/migrations/__init__.py b/backend/connector_v2/migrations/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/backend/connector_v2/models.py b/backend/connector_v2/models.py
new file mode 100644
index 000000000..3fd535b05
--- /dev/null
+++ b/backend/connector_v2/models.py
@@ -0,0 +1,132 @@
+import json
+import uuid
+from typing import Any
+
+from account_v2.models import User
+from connector_auth_v2.models import ConnectorAuth
+from connector_processor.connector_processor import ConnectorProcessor
+from connector_processor.constants import ConnectorKeys
+from cryptography.fernet import Fernet
+from django.conf import settings
+from django.db import models
+from utils.models.base_model import BaseModel
+from utils.models.organization_mixin import (
+    DefaultOrganizationManagerMixin,
+    DefaultOrganizationMixin,
+)
+from workflow_manager.workflow_v2.models import Workflow
+
+from backend.constants import FieldLengthConstants as FLC
+
+CONNECTOR_NAME_SIZE = 128
+VERSION_NAME_SIZE = 64
+
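(Aside, not part of the diff.) `ConnectorInstance.connector_metadata` below is a `BinaryField` holding a Fernet token rather than plain JSON. A self-contained sketch of the encrypt/decrypt round trip that the serializer's `save()` and the model's `metadata` property perform — the throwaway key stands in for `settings.ENCRYPTION_KEY`:

```python
import json

from cryptography.fernet import Fernet

encryption_key = Fernet.generate_key()  # stand-in for settings.ENCRYPTION_KEY
cipher = Fernet(encryption_key)

metadata = {"bucket": "unstract-free", "path": "org/user/input"}  # invented

# Write path (serializer.save): JSON-encode, then encrypt to a Fernet token
token: bytes = cipher.encrypt(json.dumps(metadata).encode("utf-8"))

# Read path (ConnectorInstance.metadata): decrypt, then JSON-decode
assert json.loads(cipher.decrypt(token)) == metadata
```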
+class ConnectorInstanceModelManager(DefaultOrganizationManagerMixin, models.Manager): + pass + + +class ConnectorInstance(DefaultOrganizationMixin, BaseModel): + # TODO: handle all cascade deletions + class ConnectorType(models.TextChoices): + INPUT = "INPUT", "Input" + OUTPUT = "OUTPUT", "Output" + + class ConnectorMode(models.IntegerChoices): + UNKNOWN = 0, "UNKNOWN" + FILE_SYSTEM = 1, "FILE_SYSTEM" + DATABASE = 2, "DATABASE" + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + connector_name = models.TextField( + max_length=CONNECTOR_NAME_SIZE, null=False, blank=False + ) + workflow = models.ForeignKey( + Workflow, + on_delete=models.CASCADE, + related_name="connector_workflow", + null=False, + blank=False, + ) + connector_id = models.CharField(max_length=FLC.CONNECTOR_ID_LENGTH, default="") + connector_metadata = models.BinaryField(null=True) + connector_version = models.CharField(max_length=VERSION_NAME_SIZE, default="") + connector_type = models.CharField(choices=ConnectorType.choices) + # TODO: handle connector_auth cascade deletion + connector_auth = models.ForeignKey( + ConnectorAuth, + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="connector_instances", + ) + connector_mode = models.CharField( + choices=ConnectorMode.choices, + default=ConnectorMode.UNKNOWN, + db_comment="0: UNKNOWN, 1: FILE_SYSTEM, 2: DATABASE", + ) + + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="connectors_created", + null=True, + blank=True, + ) + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="connectors_modified", + null=True, + blank=True, + ) + + # Manager + objects = ConnectorInstanceModelManager() + + def get_connector_metadata(self) -> dict[str, str]: + """Gets connector metadata and refreshes the tokens if needed in case + of OAuth.""" + tokens_refreshed = False + if self.connector_auth: + ( + self.connector_metadata, + tokens_refreshed, + ) = self.connector_auth.get_and_refresh_tokens() + if tokens_refreshed: + self.save() + return self.connector_metadata + + @staticmethod + def supportsOAuth(connector_id: str) -> bool: + return bool( + ConnectorProcessor.get_connector_data_with_key( + connector_id, ConnectorKeys.OAUTH + ) + ) + + def __str__(self) -> str: + return ( + f"Connector({self.id}, type{self.connector_type}," + f" workflow: {self.workflow})" + ) + + @property + def metadata(self) -> Any: + encryption_secret: str = settings.ENCRYPTION_KEY + cipher_suite: Fernet = Fernet(encryption_secret.encode("utf-8")) + decrypted_value = cipher_suite.decrypt( + bytes(self.connector_metadata).decode("utf-8") + ) + return json.loads(decrypted_value) + + class Meta: + db_table = "connector_instance_v2" + verbose_name = "Connector Instance" + verbose_name_plural = "Connector Instances" + constraints = [ + models.UniqueConstraint( + fields=["connector_name", "workflow", "connector_type"], + name="unique_workflow_connector", + ), + ] diff --git a/backend/connector_v2/serializers.py b/backend/connector_v2/serializers.py new file mode 100644 index 000000000..cd9727e84 --- /dev/null +++ b/backend/connector_v2/serializers.py @@ -0,0 +1,101 @@ +import json +import logging +from collections import OrderedDict +from typing import Any, Optional + +from connector_auth_v2.models import ConnectorAuth +from connector_auth_v2.pipeline.common import ConnectorAuthHelper +from connector_processor.connector_processor import ConnectorProcessor +from connector_processor.constants import ConnectorKeys 
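(Aside, before the serializer's imports continue below.) Given the model above, the intended read path for credentials is `get_connector_metadata()` rather than the raw field, so OAuth-backed connectors refresh and persist their tokens transparently. A sketch, with an invented workflow id:

```python
from connector_v2.models import ConnectorInstance

instance = ConnectorInstance.objects.filter(
    workflow="<workflow-uuid>",  # invented; use a real Workflow pk
    connector_type=ConnectorInstance.ConnectorType.INPUT,
    connector_mode=ConnectorInstance.ConnectorMode.FILE_SYSTEM,
).first()

if instance is not None:
    # Refreshes via the linked ConnectorAuth when due and saves if changed
    metadata = instance.get_connector_metadata()
```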
+from connector_processor.exceptions import OAuthTimeOut
+from connector_v2.constants import ConnectorInstanceKey as CIKey
+from cryptography.fernet import Fernet
+from django.conf import settings
+from utils.serializer_utils import SerializerUtils
+
+from backend.serializers import AuditSerializer
+from unstract.connectors.filesystems.ucs import UnstractCloudStorage
+
+from .models import ConnectorInstance
+
+logger = logging.getLogger(__name__)
+
+
+class ConnectorInstanceSerializer(AuditSerializer):
+    class Meta:
+        model = ConnectorInstance
+        fields = "__all__"
+
+    def validate_connector_metadata(self, value: dict[str, Any]) -> dict[str, Any]:
+        """Validates JSON metadata. This custom validation is to avoid
+        conflict between user input and DB binary data.
+
+        Args:
+            value (dict[str, Any]): dict of metadata
+
+        Returns:
+            dict[str, Any]: dict of metadata
+        """
+        return value
+
+    def save(self, **kwargs):  # type: ignore
+        user = self.context.get("request").user or None
+        connector_id: str = kwargs[CIKey.CONNECTOR_ID]
+        connector_oauth: Optional[ConnectorAuth] = None
+        if (
+            ConnectorInstance.supportsOAuth(connector_id=connector_id)
+            and CIKey.CONNECTOR_METADATA in kwargs
+        ):
+            try:
+                connector_oauth = ConnectorAuthHelper.get_or_create_connector_auth(
+                    user=user,  # type: ignore
+                    oauth_credentials=kwargs[CIKey.CONNECTOR_METADATA],
+                )
+                kwargs[CIKey.CONNECTOR_AUTH] = connector_oauth
+                (
+                    kwargs[CIKey.CONNECTOR_METADATA],
+                    refresh_status,
+                ) = connector_oauth.get_and_refresh_tokens()
+            except Exception as exc:
+                logger.error(
+                    "Error while obtaining ConnectorAuth for connector id "
+                    f"{connector_id}: {exc}"
+                )
+                raise OAuthTimeOut
+
+        connector_mode = ConnectorProcessor.get_connector_data_with_key(
+            connector_id, CIKey.CONNECTOR_MODE
+        )
+        kwargs[CIKey.CONNECTOR_MODE] = connector_mode.value
+
+        encryption_secret: str = settings.ENCRYPTION_KEY
+        f: Fernet = Fernet(encryption_secret.encode("utf-8"))
+        json_string: str = json.dumps(kwargs.get(CIKey.CONNECTOR_METADATA))
+        if self.validated_data:
+            self.validated_data.pop(CIKey.CONNECTOR_METADATA)
+
+        kwargs[CIKey.CONNECTOR_METADATA] = f.encrypt(json_string.encode("utf-8"))
+
+        instance = super().save(**kwargs)
+        return instance
+
+    def to_representation(self, instance: ConnectorInstance) -> dict[str, str]:
+        # to remove the sensitive fields being returned
+        rep: OrderedDict[str, Any] = super().to_representation(instance)
+        if instance.connector_id == UnstractCloudStorage.get_id():
+            rep[CIKey.CONNECTOR_METADATA] = {}
+        if SerializerUtils.check_context_for_GET_or_POST(context=self.context):
+            rep.pop(CIKey.CONNECTOR_AUTH)
+        # set icon fields for UI
+        rep[ConnectorKeys.ICON] = ConnectorProcessor.get_connector_data_with_key(
+            instance.connector_id, ConnectorKeys.ICON
+        )
+        encryption_secret: str = settings.ENCRYPTION_KEY
+        f: Fernet = Fernet(encryption_secret.encode("utf-8"))
+
+        if instance.connector_metadata:
+            adapter_metadata = json.loads(
+                f.decrypt(bytes(instance.connector_metadata).decode("utf-8"))
+            )
+            rep[CIKey.CONNECTOR_METADATA] = adapter_metadata
+        return rep
diff --git a/backend/connector_v2/tests/conftest.py b/backend/connector_v2/tests/conftest.py
new file mode 100644
index 000000000..89f1715a1
--- /dev/null
+++ b/backend/connector_v2/tests/conftest.py
@@ -0,0 +1,9 @@
+import pytest
+from django.core.management import call_command
+
+
+@pytest.fixture(scope="session")
+def django_db_setup(django_db_blocker):  # type: ignore
+    fixtures = ["./connector_v2/tests/fixtures/fixtures_0001.json"]
+    with django_db_blocker.unblock():
call_command("loaddata", *fixtures) diff --git a/backend/connector_v2/tests/connector_tests.py b/backend/connector_v2/tests/connector_tests.py new file mode 100644 index 000000000..f3e81d72f --- /dev/null +++ b/backend/connector_v2/tests/connector_tests.py @@ -0,0 +1,332 @@ +# mypy: ignore-errors +import pytest +from connector_v2.models import ConnectorInstance +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APITestCase + +pytestmark = pytest.mark.django_db + + +@pytest.mark.connector +class TestConnector(APITestCase): + def test_connector_list(self) -> None: + """Tests to List the connectors.""" + + url = reverse("connectors_v1-list") + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_connectors_detail(self) -> None: + """Tests to fetch a connector with given pk.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_connectors_detail_not_found(self) -> None: + """Tests for negative case to fetch non exiting key.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 768}) + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_connectors_create(self) -> None: + """Tests to create a new ConnectorInstance.""" + + url = reverse("connectors_v1-list") + data = { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "sample_url", + "sharable_link": True, + }, + } + response = self.client.post(url, data, format="json") + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(ConnectorInstance.objects.count(), 2) + + def test_connectors_create_with_json_list(self) -> None: + """Tests to create a new connector with list included in the json + field.""" + + url = reverse("connectors_v1-list") + data = { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "sample_url", + "sharable_link": True, + "file_name_list": ["a1", "a2"], + }, + } + response = self.client.post(url, data, format="json") + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(ConnectorInstance.objects.count(), 2) + + def test_connectors_create_with_nested_json(self) -> None: + """Tests to create a new connector with json field as nested json.""" + + url = reverse("connectors_v1-list") + data = { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "sample_url", + "sharable_link": True, + "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + }, + } + response = self.client.post(url, data, format="json") + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(ConnectorInstance.objects.count(), 2) + + def test_connectors_create_bad_request(self) -> None: + """Tests for negative case to throw error on a wrong access.""" + + url = reverse("connectors_v1-list") + data = { + "org": 5, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + 
"connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "sample_url", + "sharable_link": True, + "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + }, + } + response = self.client.post(url, data, format="json") + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_connectors_update_json_field(self) -> None: + """Tests to update connector with json field update.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "new_sample_url", + "sharable_link": True, + "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + }, + } + response = self.client.put(url, data, format="json") + drive_link = response.data["connector_metadata"]["drive_link"] + self.assertEqual(drive_link, "new_sample_url") + + def test_connectors_update(self) -> None: + """Tests to update connector update single field.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = { + "org": 1, + "project": 1, + "created_by": 1, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "new_sample_url", + "sharable_link": True, + "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + }, + } + response = self.client.put(url, data, format="json") + modified_by = response.data["modified_by"] + self.assertEqual(modified_by, 2) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_connectors_update_pk(self) -> None: + """Tests the PUT method for 400 error.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = { + "org": 2, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "new_sample_url", + "sharable_link": True, + "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + }, + } + response = self.client.put(url, data, format="json") + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_connectors_update_json_fields(self) -> None: + """Tests to update ConnectorInstance.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "new_sample_url", + "sharable_link": True, + "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + }, + } + response = self.client.put(url, data, format="json") + nested_value = response.data["connector_metadata"]["sample_metadata_json"][ + "key1" + ] + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(nested_value, "value1") + + def test_connectors_update_json_list_fields(self) -> None: + """Tests to update connector to the third second level of json.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "new_sample_url", + "sharable_link": True, + "sample_metadata_json": 
{"key1": "value1", "key2": "value2"}, + "file_list": ["a1", "a2", "a3"], + }, + } + response = self.client.put(url, data, format="json") + nested_value = response.data["connector_metadata"]["sample_metadata_json"][ + "key1" + ] + nested_list = response.data["connector_metadata"]["file_list"] + last_val = nested_list.pop() + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(nested_value, "value1") + self.assertEqual(last_val, "a3") + + # @pytest.mark.xfail(raises=KeyError) + # def test_connectors_update_json_fields_failed(self) -> None: + # """Tests to update connector to the second level of JSON with a wrong + # key.""" + + # url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + # data = { + # "org": 1, + # "project": 1, + # "created_by": 2, + # "modified_by": 2, + # "modified_at": "2023-06-14T05:28:47.759Z", + # "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + # "connector_metadata": { + # "drive_link": "new_sample_url", + # "sharable_link": True, + # "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + # }, + # } + # response = self.client.put(url, data, format="json") + # nested_value = response.data["connector_metadata"]["sample_metadata_json"][ + # "key00" + # ] + + # @pytest.mark.xfail(raises=KeyError) + # def test_connectors_update_json_nested_failed(self) -> None: + # """Tests to update connector to test a first level of json with a wrong + # key.""" + + # url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + # data = { + # "org": 1, + # "project": 1, + # "created_by": 2, + # "modified_by": 2, + # "modified_at": "2023-06-14T05:28:47.759Z", + # "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + # "connector_metadata": { + # "drive_link": "new_sample_url", + # "sharable_link": True, + # "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + # }, + # } + # response = self.client.put(url, data, format="json") + # nested_value = response.data["connector_metadata"]["sample_metadata_jsonNew"] + + def test_connectors_update_field(self) -> None: + """Tests the PATCH method.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = {"connector_id": "e3a4512m-efgb-48d5-98a9-3983ntest"} + response = self.client.patch(url, data, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + connector_id = response.data["connector_id"] + + self.assertEqual( + connector_id, + ConnectorInstance.objects.get(connector_id=connector_id).connector_id, + ) + + def test_connectors_update_json_field_patch(self) -> None: + """Tests the PATCH method.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = { + "connector_metadata": { + "drive_link": "patch_update_url", + "sharable_link": True, + "sample_metadata_json": { + "key1": "patch_update1", + "key2": "value2", + }, + } + } + + response = self.client.patch(url, data, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + drive_link = response.data["connector_metadata"]["drive_link"] + + self.assertEqual(drive_link, "patch_update_url") + + def test_connectors_delete(self) -> None: + """Tests the DELETE method.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + response = self.client.delete(url, format="json") + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) diff --git a/backend/connector_v2/tests/fixtures/fixtures_0001.json 
b/backend/connector_v2/tests/fixtures/fixtures_0001.json new file mode 100644 index 000000000..55b39e6d8 --- /dev/null +++ b/backend/connector_v2/tests/fixtures/fixtures_0001.json @@ -0,0 +1,67 @@ +[ + { + "model": "account.org", + "pk": 1, + "fields": { + "org_name": "Zipstack", + "created_by": 1, + "modified_by": 1, + "modified_at": "2023-06-14T05:28:47.739Z" + } + }, + { + "model": "account.user", + "pk": 1, + "fields": { + "org": 1, + "email": "johndoe@gmail.com", + "first_name": "John", + "last_name": "Doe", + "is_admin": true, + "created_by": null, + "modified_by": null, + "modified_at": "2023-06-14T05:28:47.744Z" + } + }, + { + "model": "account.user", + "pk": 2, + "fields": { + "org": 1, + "email": "user1@gmail.com", + "first_name": "Ron", + "last_name": "Stone", + "is_admin": false, + "created_by": 1, + "modified_by": 1, + "modified_at": "2023-06-14T05:28:47.750Z" + } + }, + { + "model": "project.project", + "pk": 1, + "fields": { + "org": 1, + "project_name": "Unstract Test", + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z" + } + }, + { + "model": "connector.connector", + "pk": 1, + "fields": { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e38a59b7-efbb-48d5-9da6-3a0cf2d882a0", + "connector_metadata": { + "connector_type": "gdrive", + "auth_type": "oauth" + } + } + } +] diff --git a/backend/connector_v2/unstract_account.py b/backend/connector_v2/unstract_account.py new file mode 100644 index 000000000..ac270f6d6 --- /dev/null +++ b/backend/connector_v2/unstract_account.py @@ -0,0 +1,71 @@ +import logging +import os + +import boto3 +from botocore.exceptions import ClientError +from django.conf import settings + +logger = logging.getLogger(__name__) + + +# TODO: UnstractAccount need to be pluggable +class UnstractAccount: + def __init__(self, tenant: str, username: str) -> None: + self.tenant = tenant + self.username = username + + def provision_s3_storage(self) -> None: + access_key = settings.GOOGLE_STORAGE_ACCESS_KEY_ID + secret_key = settings.GOOGLE_STORAGE_SECRET_ACCESS_KEY + bucket_name: str = settings.UNSTRACT_FREE_STORAGE_BUCKET_NAME + + s3 = boto3.client( + "s3", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + endpoint_url="https://storage.googleapis.com", + ) + + # Check if folder exists and create if it is not available + account_folder = f"{self.tenant}/{self.username}/input/examples/" + try: + logger.info(f"Checking if folder {account_folder} exists...") + s3.head_object(Bucket=bucket_name, Key=account_folder) + logger.info(f"Folder {account_folder} already exists") + except ClientError as e: + logger.info(f"{bucket_name} Folder {account_folder} does not exist") + if e.response["Error"]["Code"] == "404": + logger.info(f"Folder {account_folder} does not exist. 
Creating it...") + s3.put_object(Bucket=bucket_name, Key=account_folder) + account_folder_output = f"{self.tenant}/{self.username}/output/" + s3.put_object(Bucket=bucket_name, Key=account_folder_output) + else: + logger.error(f"Error checking folder {account_folder}: {e}") + raise e + + def upload_sample_files(self) -> None: + access_key = settings.GOOGLE_STORAGE_ACCESS_KEY_ID + secret_key = settings.GOOGLE_STORAGE_SECRET_ACCESS_KEY + bucket_name: str = settings.UNSTRACT_FREE_STORAGE_BUCKET_NAME + + s3 = boto3.client( + "s3", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + endpoint_url="https://storage.googleapis.com", + ) + + folder = f"{self.tenant}/{self.username}/input/examples/" + + local_path = f"{os.path.dirname(__file__)}/static" + for root, dirs, files in os.walk(local_path): + for file in files: + local_file_path = os.path.join(root, file) + s3_key = os.path.join( + folder, os.path.relpath(local_file_path, local_path) + ) + logger.info( + f"Uploading: {local_file_path} => s3://{bucket_name}/{s3_key}" + ) + s3.upload_file(local_file_path, bucket_name, s3_key) + logger.info(f"Uploaded: {local_file_path}") diff --git a/backend/connector_v2/urls.py b/backend/connector_v2/urls.py new file mode 100644 index 000000000..424033528 --- /dev/null +++ b/backend/connector_v2/urls.py @@ -0,0 +1,21 @@ +from django.urls import path +from rest_framework.urlpatterns import format_suffix_patterns + +from .views import ConnectorInstanceViewSet as CIViewSet + +connector_list = CIViewSet.as_view({"get": "list", "post": "create"}) +connector_detail = CIViewSet.as_view( + { + "get": "retrieve", + "put": "update", + "patch": "partial_update", + "delete": "destroy", + } +) + +urlpatterns = format_suffix_patterns( + [ + path("connector/", connector_list, name="connector-list"), + path("connector//", connector_detail, name="connector-detail"), + ] +) diff --git a/backend/connector_v2/views.py b/backend/connector_v2/views.py new file mode 100644 index 000000000..3f098d355 --- /dev/null +++ b/backend/connector_v2/views.py @@ -0,0 +1,124 @@ +import logging +from typing import Any, Optional + +from account_v2.custom_exceptions import DuplicateData +from connector_auth_v2.constants import ConnectorAuthKey +from connector_auth_v2.exceptions import CacheMissException, MissingParamException +from connector_auth_v2.pipeline.common import ConnectorAuthHelper +from connector_processor.exceptions import OAuthTimeOut +from connector_v2.constants import ConnectorInstanceKey as CIKey +from django.db import IntegrityError +from django.db.models import QuerySet +from rest_framework import status, viewsets +from rest_framework.response import Response +from rest_framework.versioning import URLPathVersioning +from utils.filtering import FilterHelper + +from backend.constants import RequestKey + +from .models import ConnectorInstance +from .serializers import ConnectorInstanceSerializer + +logger = logging.getLogger(__name__) + + +class ConnectorInstanceViewSet(viewsets.ModelViewSet): + versioning_class = URLPathVersioning + queryset = ConnectorInstance.objects.all() + serializer_class = ConnectorInstanceSerializer + + def get_queryset(self) -> Optional[QuerySet]: + filter_args = FilterHelper.build_filter_args( + self.request, + RequestKey.WORKFLOW, + RequestKey.CREATED_BY, + CIKey.CONNECTOR_TYPE, + CIKey.CONNECTOR_MODE, + ) + if filter_args: + queryset = ConnectorInstance.objects.filter(**filter_args) + else: + queryset = ConnectorInstance.objects.all() + return queryset + + def 
_get_connector_metadata(self, connector_id: str) -> Optional[dict[str, str]]: + """Gets connector metadata for the ConnectorInstance. + + For non oauth based - obtains from request + For oauth based - obtains from cache + + Raises: + e: MissingParamException, CacheMissException + + Returns: + dict[str, str]: Connector creds dict to connect with + """ + connector_metadata = None + if ConnectorInstance.supportsOAuth(connector_id=connector_id): + logger.info(f"Fetching oauth data for {connector_id}") + oauth_key = self.request.query_params.get(ConnectorAuthKey.OAUTH_KEY) + if oauth_key is None: + raise MissingParamException(param=ConnectorAuthKey.OAUTH_KEY) + connector_metadata = ConnectorAuthHelper.get_oauth_creds_from_cache( + cache_key=oauth_key, delete_key=True + ) + if connector_metadata is None: + raise CacheMissException( + f"Couldn't find credentials for {oauth_key} from cache" + ) + else: + connector_metadata = self.request.data.get(CIKey.CONNECTOR_METADATA) + return connector_metadata + + def perform_update(self, serializer: ConnectorInstanceSerializer) -> None: + connector_metadata = None + connector_id = self.request.data.get( + CIKey.CONNECTOR_ID, serializer.instance.connector_id + ) + try: + connector_metadata = self._get_connector_metadata(connector_id) + # TODO: Handle specific exceptions instead of using a generic Exception. + except Exception: + # Suppress here to not shout during partial updates + pass + # Take metadata from instance itself since update + # is performed on other fields of ConnectorInstance + if connector_metadata is None: + connector_metadata = serializer.instance.connector_metadata + serializer.save( + connector_id=connector_id, + connector_metadata=connector_metadata, + modified_by=self.request.user, + ) # type: ignore + + def perform_create(self, serializer: ConnectorInstanceSerializer) -> None: + connector_metadata = None + connector_id = self.request.data.get(CIKey.CONNECTOR_ID) + try: + connector_metadata = self._get_connector_metadata(connector_id=connector_id) + # TODO: Handle specific exceptions instead of using a generic Exception. + except Exception as exc: + logger.error(f"Error while obtaining ConnectorAuth: {exc}") + raise OAuthTimeOut + serializer.save( + connector_id=connector_id, + connector_metadata=connector_metadata, + created_by=self.request.user, + modified_by=self.request.user, + ) # type: ignore + + def create(self, request: Any) -> Response: + # Overriding default exception behavior + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + try: + self.perform_create(serializer) + except IntegrityError: + raise DuplicateData( + f"{CIKey.CONNECTOR_EXISTS}, \ + {CIKey.DUPLICATE_API}" + ) + headers = self.get_success_headers(serializer.data) + return Response( + serializer.data, status=status.HTTP_201_CREATED, headers=headers + )
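End-to-end, the two apps appear to combine into the following client flow for an OAuth connector: mint a cache key, run the provider's OAuth dance with that key in the session, then create the connector instance against the same key. A hedged `requests` sketch — the mount prefix, auth cookie handling, and payload values are assumptions, not part of this diff:

```python
import requests

BASE = "http://localhost:8000/api/v1"  # assumed mount point
session = requests.Session()  # assumes an already-authenticated session

# 1. Mint a one-time OAuth cache key (connector_auth_v2)
cache_key = session.get(f"{BASE}/oauth/cache-key/google-oauth2").json()["cache_key"]

# 2. Complete the provider's OAuth flow via the social_django routes under
#    oauth/; cache_oauth_creds() stores the tokens under `cache_key`.

# 3. Create the connector; the viewset redeems the key from the cache
resp = session.post(
    f"{BASE}/connector/",
    params={"oauth-key": cache_key},  # ConnectorAuthKey.OAUTH_KEY
    json={
        "connector_name": "My GDrive",        # invented payload values
        "workflow": "<workflow-uuid>",
        "connector_id": "gdrive|<connector-uuid>",
        "connector_type": "INPUT",
    },
)
resp.raise_for_status()
```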