diff --git a/backend/adapter_processor/adapter_processor.py b/backend/adapter_processor/adapter_processor.py index f8c4bb1ef..786280367 100644 --- a/backend/adapter_processor/adapter_processor.py +++ b/backend/adapter_processor/adapter_processor.py @@ -8,6 +8,7 @@ InternalServiceError, InValidAdapterId, TestAdapterError, + TestAdapterInputError, ) from django.conf import settings from django.core.exceptions import ObjectDoesNotExist @@ -97,6 +98,12 @@ def test_adapter(adapter_id: str, adapter_metadata: dict[str, Any]) -> bool: test_result: bool = adapter_instance.test_connection() logger.info(f"{adapter_id} test result: {test_result}") return test_result + # HACK: Remove after error is explicitly handled in VertexAI adapter + except json.JSONDecodeError: + raise TestAdapterInputError( + "Credentials is not a valid service account JSON, " + "please provide a valid JSON." + ) except AdapterError as e: raise TestAdapterError(str(e)) diff --git a/backend/adapter_processor/exceptions.py b/backend/adapter_processor/exceptions.py index 5bddbb3c2..5e1bbf948 100644 --- a/backend/adapter_processor/exceptions.py +++ b/backend/adapter_processor/exceptions.py @@ -2,8 +2,6 @@ from rest_framework.exceptions import APIException -from backend.exceptions import UnstractBaseException - class IdIsMandatory(APIException): status_code = 400 @@ -46,11 +44,16 @@ class UniqueConstraintViolation(APIException): default_detail = "Unique constraint violated" -class TestAdapterError(UnstractBaseException): +class TestAdapterError(APIException): status_code = 500 default_detail = "Error while testing adapter" +class TestAdapterInputError(APIException): + status_code = 400 + default_detail = "Error while testing adapter, please check the configuration." + + class DeleteAdapterInUseError(APIException): status_code = 409 diff --git a/backend/adapter_processor/views.py b/backend/adapter_processor/views.py index 333b05d71..162e4aaf0 100644 --- a/backend/adapter_processor/views.py +++ b/backend/adapter_processor/views.py @@ -115,17 +115,13 @@ def test(self, request: Request) -> Response: adapter_metadata[AdapterKeys.ADAPTER_TYPE] = serializer.validated_data.get( AdapterKeys.ADAPTER_TYPE ) - try: - test_result = AdapterProcessor.test_adapter( - adapter_id=adapter_id, adapter_metadata=adapter_metadata - ) - return Response( - {AdapterKeys.IS_VALID: test_result}, - status=status.HTTP_200_OK, - ) - except Exception as e: - logger.error(f"Error testing adapter : {str(e)}") - raise e + test_result = AdapterProcessor.test_adapter( + adapter_id=adapter_id, adapter_metadata=adapter_metadata + ) + return Response( + {AdapterKeys.IS_VALID: test_result}, + status=status.HTTP_200_OK, + ) class AdapterInstanceViewSet(ModelViewSet): diff --git a/backend/adapter_processor_v2/__init__.py b/backend/adapter_processor_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/adapter_processor_v2/adapter_processor.py b/backend/adapter_processor_v2/adapter_processor.py new file mode 100644 index 000000000..dd9d44ff3 --- /dev/null +++ b/backend/adapter_processor_v2/adapter_processor.py @@ -0,0 +1,257 @@ +import json +import logging +from typing import Any, Optional + +from account_v2.models import User +from adapter_processor_v2.constants import AdapterKeys +from adapter_processor_v2.exceptions import ( + InternalServiceError, + InValidAdapterId, + TestAdapterError, +) +from django.conf import settings +from django.core.exceptions import ObjectDoesNotExist +from platform_settings_v2.platform_auth_service import 
PlatformAuthenticationService +from unstract.adapters.adapterkit import Adapterkit +from unstract.adapters.base import Adapter +from unstract.adapters.enums import AdapterTypes +from unstract.adapters.exceptions import AdapterError +from unstract.adapters.x2text.constants import X2TextConstants + +from .models import AdapterInstance, UserDefaultAdapter + +logger = logging.getLogger(__name__) + + +class AdapterProcessor: + @staticmethod + def get_json_schema(adapter_id: str) -> dict[str, Any]: + """Function to return JSON Schema for Adapters.""" + schema_details: dict[str, Any] = {} + updated_adapters = AdapterProcessor.__fetch_adapters_by_key_value( + AdapterKeys.ID, adapter_id + ) + if len(updated_adapters) != 0: + try: + schema_details[AdapterKeys.JSON_SCHEMA] = json.loads( + updated_adapters[0].get(AdapterKeys.JSON_SCHEMA) + ) + except Exception as exc: + logger.error(f"Error occured while parsing JSON Schema : {exc}") + raise InternalServiceError() + else: + logger.error( + f"Invalid adapter Id : {adapter_id} while fetching JSON Schema" + ) + raise InValidAdapterId() + return schema_details + + @staticmethod + def get_all_supported_adapters(type: str) -> list[dict[Any, Any]]: + """Function to return list of all supported adapters.""" + supported_adapters = [] + updated_adapters = [] + updated_adapters = AdapterProcessor.__fetch_adapters_by_key_value( + AdapterKeys.ADAPTER_TYPE, type + ) + for each_adapter in updated_adapters: + supported_adapters.append( + { + AdapterKeys.ID: each_adapter.get(AdapterKeys.ID), + AdapterKeys.NAME: each_adapter.get(AdapterKeys.NAME), + AdapterKeys.DESCRIPTION: each_adapter.get(AdapterKeys.DESCRIPTION), + AdapterKeys.ICON: each_adapter.get(AdapterKeys.ICON), + AdapterKeys.ADAPTER_TYPE: each_adapter.get( + AdapterKeys.ADAPTER_TYPE + ), + } + ) + return supported_adapters + + @staticmethod + def get_adapter_data_with_key(adapter_id: str, key_value: str) -> Any: + """Generic Function to get adapter data with provided key.""" + updated_adapters = AdapterProcessor.__fetch_adapters_by_key_value( + "id", adapter_id + ) + if len(updated_adapters) == 0: + logger.error(f"Invalid adapter ID {adapter_id} while invoking utility") + raise InValidAdapterId() + return updated_adapters[0].get(key_value) + + @staticmethod + def test_adapter(adapter_id: str, adapter_metadata: dict[str, Any]) -> bool: + logger.info(f"Testing adapter: {adapter_id}") + try: + adapter_class = Adapterkit().get_adapter_class_by_adapter_id(adapter_id) + + if adapter_metadata.pop(AdapterKeys.ADAPTER_TYPE) == AdapterKeys.X2TEXT: + adapter_metadata[X2TextConstants.X2TEXT_HOST] = settings.X2TEXT_HOST + adapter_metadata[X2TextConstants.X2TEXT_PORT] = settings.X2TEXT_PORT + platform_key = PlatformAuthenticationService.get_active_platform_key() + adapter_metadata[X2TextConstants.PLATFORM_SERVICE_API_KEY] = str( + platform_key.key + ) + + adapter_instance = adapter_class(adapter_metadata) + test_result: bool = adapter_instance.test_connection() + logger.info(f"{adapter_id} test result: {test_result}") + return test_result + except AdapterError as e: + raise TestAdapterError(str(e)) + + @staticmethod + def __fetch_adapters_by_key_value(key: str, value: Any) -> Adapter: + """Fetches a list of adapters that have an attribute matching key and + value.""" + logger.info(f"Fetching adapter list for {key} with {value}") + adapter_kit = Adapterkit() + adapters = adapter_kit.get_adapters_list() + return [iterate for iterate in adapters if iterate[key] == value] + + @staticmethod + def 
set_default_triad(default_triad: dict[str, str], user: User) -> None: + try: + ( + user_default_adapter, + created, + ) = UserDefaultAdapter.objects.get_or_create(user=user) + + if default_triad.get(AdapterKeys.LLM_DEFAULT, None): + user_default_adapter.default_llm_adapter = AdapterInstance.objects.get( + pk=default_triad[AdapterKeys.LLM_DEFAULT] + ) + if default_triad.get(AdapterKeys.EMBEDDING_DEFAULT, None): + user_default_adapter.default_embedding_adapter = ( + AdapterInstance.objects.get( + pk=default_triad[AdapterKeys.EMBEDDING_DEFAULT] + ) + ) + + if default_triad.get(AdapterKeys.VECTOR_DB_DEFAULT, None): + user_default_adapter.default_vector_db_adapter = ( + AdapterInstance.objects.get( + pk=default_triad[AdapterKeys.VECTOR_DB_DEFAULT] + ) + ) + + if default_triad.get(AdapterKeys.X2TEXT_DEFAULT, None): + user_default_adapter.default_x2text_adapter = ( + AdapterInstance.objects.get( + pk=default_triad[AdapterKeys.X2TEXT_DEFAULT] + ) + ) + + user_default_adapter.save() + + logger.info("Changed defaults successfully") + except Exception as e: + logger.error(f"Unable to save defaults because: {e}") + if isinstance(e, InValidAdapterId): + raise e + else: + raise InternalServiceError() + + @staticmethod + def get_adapter_instance_by_id(adapter_instance_id: str) -> Adapter: + """Get the adapter instance by its ID. + + Parameters: + - adapter_instance_id (str): The ID of the adapter instance. + + Returns: + - Adapter: The adapter instance with the specified ID. + + Raises: + - Exception: If there is an error while fetching the adapter instance. + """ + try: + adapter = AdapterInstance.objects.get(id=adapter_instance_id) + except Exception as e: + logger.error(f"Unable to fetch adapter: {e}") + if not adapter: + logger.error("Unable to fetch adapter") + return adapter.adapter_name + + @staticmethod + def get_adapters_by_type( + adapter_type: AdapterTypes, user: User + ) -> list[AdapterInstance]: + """Get a list of adapters by their type. + + Parameters: + - adapter_type (AdapterTypes): The type of adapters to retrieve. + - user: Logged in User + + Returns: + - list[AdapterInstance]: A list of AdapterInstance objects that match + the specified adapter type. + """ + + adapters: list[AdapterInstance] = AdapterInstance.objects.for_user(user).filter( + adapter_type=adapter_type.value, + ) + return adapters + + @staticmethod + def get_adapter_by_name_and_type( + adapter_type: AdapterTypes, + adapter_name: Optional[str] = None, + ) -> Optional[AdapterInstance]: + """Get the adapter instance by its name and type. + + Parameters: + - adapter_name (str): The name of the adapter instance. + - adapter_type (AdapterTypes): The type of the adapter instance. + + Returns: + - AdapterInstance: The adapter with the specified name and type. + """ + if adapter_name: + adapter: AdapterInstance = AdapterInstance.objects.get( + adapter_name=adapter_name, adapter_type=adapter_type.value + ) + else: + try: + adapter = AdapterInstance.objects.get( + adapter_type=adapter_type.value, is_default=True + ) + except AdapterInstance.DoesNotExist: + return None + return adapter + + @staticmethod + def get_default_adapters(user: User) -> list[AdapterInstance]: + """Retrieve a list of default adapter instances. This method queries + the database to fetch all adapter instances marked as default. + + Raises: + InternalServiceError: If an unexpected error occurs during + the database query. + + Returns: + list[AdapterInstance]: A list of AdapterInstance objects that are + marked as default. 
+ """ + try: + adapters: list[AdapterInstance] = [] + default_adapter = UserDefaultAdapter.objects.get(user=user) + + if default_adapter.default_embedding_adapter: + adapters.append(default_adapter.default_embedding_adapter) + if default_adapter.default_llm_adapter: + adapters.append(default_adapter.default_llm_adapter) + if default_adapter.default_vector_db_adapter: + adapters.append(default_adapter.default_vector_db_adapter) + if default_adapter.default_x2text_adapter: + adapters.append(default_adapter.default_x2text_adapter) + + return adapters + except ObjectDoesNotExist as e: + logger.error(f"No default adapters found: {e}") + raise InternalServiceError( + "No default adapters found, " "configure them through Platform Settings" + ) + except Exception as e: + logger.error(f"Error occurred while fetching default adapters: {e}") + raise InternalServiceError("Error fetching default adapters") diff --git a/backend/adapter_processor_v2/constants.py b/backend/adapter_processor_v2/constants.py new file mode 100644 index 000000000..e35f2d0ac --- /dev/null +++ b/backend/adapter_processor_v2/constants.py @@ -0,0 +1,28 @@ +class AdapterKeys: + JSON_SCHEMA = "json_schema" + ADAPTER_TYPE = "adapter_type" + IS_DEFAULT = "is_default" + LLM = "LLM" + X2TEXT = "X2TEXT" + OCR = "OCR" + VECTOR_DB = "VECTOR_DB" + EMBEDDING = "EMBEDDING" + NAME = "name" + DESCRIPTION = "description" + ICON = "icon" + ADAPTER_ID = "adapter_id" + ADAPTER_METADATA = "adapter_metadata" + ADAPTER_METADATA_B = "adapter_metadata_b" + ID = "id" + IS_VALID = "is_valid" + LLM_DEFAULT = "llm_default" + VECTOR_DB_DEFAULT = "vector_db_default" + EMBEDDING_DEFAULT = "embedding_default" + X2TEXT_DEFAULT = "x2text_default" + SHARED_USERS = "shared_users" + ADAPTER_NAME_EXISTS = ( + "Configuration with this Name already exists. " + "Please try with a different Name" + ) + ADAPTER_CREATED_BY = "created_by_email" + ADAPTER_CONTEXT_WINDOW_SIZE = "context_window_size" diff --git a/backend/adapter_processor_v2/exceptions.py b/backend/adapter_processor_v2/exceptions.py new file mode 100644 index 000000000..5bddbb3c2 --- /dev/null +++ b/backend/adapter_processor_v2/exceptions.py @@ -0,0 +1,68 @@ +from typing import Optional + +from rest_framework.exceptions import APIException + +from backend.exceptions import UnstractBaseException + + +class IdIsMandatory(APIException): + status_code = 400 + default_detail = "ID is Mandatory." + + +class InValidType(APIException): + status_code = 400 + default_detail = "Type is not Valid." + + +class InValidAdapterId(APIException): + status_code = 400 + default_detail = "Adapter ID is not Valid." + + +class InvalidEncryptionKey(APIException): + status_code = 403 + default_detail = ( + "Platform encryption key for storing adapter credentials has changed! " + "Please inform the organization admin to contact support." + ) + + +class InternalServiceError(APIException): + status_code = 500 + default_detail = "Internal Service error" + + +class CannotDeleteDefaultAdapter(APIException): + status_code = 500 + default_detail = ( + "This is configured as default and cannot be deleted. " + "Please configure a different default before you try again!" 
+ ) + + +class UniqueConstraintViolation(APIException): + status_code = 400 + default_detail = "Unique constraint violated" + + +class TestAdapterError(UnstractBaseException): + status_code = 500 + default_detail = "Error while testing adapter" + + +class DeleteAdapterInUseError(APIException): + status_code = 409 + + def __init__( + self, + detail: Optional[str] = None, + code: Optional[str] = None, + adapter_name: str = "adapter", + ): + if detail is None: + detail = ( + f"Cannot delete {adapter_name}. " + "It is used in a workflow or a prompt studio project" + ) + super().__init__(detail, code) diff --git a/backend/adapter_processor_v2/models.py b/backend/adapter_processor_v2/models.py new file mode 100644 index 000000000..5e63056c5 --- /dev/null +++ b/backend/adapter_processor_v2/models.py @@ -0,0 +1,208 @@ +import json +import logging +import uuid +from typing import Any + +from account_v2.models import User +from cryptography.fernet import Fernet +from django.conf import settings +from django.db import models +from django.db.models import QuerySet +from tenant_account_v2.models import OrganizationMember +from unstract.adapters.adapterkit import Adapterkit +from unstract.adapters.enums import AdapterTypes +from unstract.adapters.exceptions import AdapterError +from utils.models.base_model import BaseModel +from utils.models.organization_mixin import ( + DefaultOrganizationManagerMixin, + DefaultOrganizationMixin, +) + +logger = logging.getLogger(__name__) + +ADAPTER_NAME_SIZE = 128 +VERSION_NAME_SIZE = 64 +ADAPTER_ID_LENGTH = 128 + +logger = logging.getLogger(__name__) + + +class AdapterInstanceModelManager(DefaultOrganizationManagerMixin, models.Manager): + def get_queryset(self) -> QuerySet[Any]: + return super().get_queryset() + + def for_user(self, user: User) -> QuerySet[Any]: + return ( + self.get_queryset() + .filter( + models.Q(created_by=user) + | models.Q(shared_users=user) + | models.Q(shared_to_org=True) + | models.Q(is_friction_less=True) + ) + .distinct("id") + ) + + +class AdapterInstance(DefaultOrganizationMixin, BaseModel): + id = models.UUIDField( + primary_key=True, + default=uuid.uuid4, + editable=False, + db_comment="Unique identifier for the Adapter Instance", + ) + adapter_name = models.TextField( + max_length=ADAPTER_NAME_SIZE, + null=False, + blank=False, + db_comment="Name of the Adapter Instance", + ) + adapter_id = models.CharField( + max_length=ADAPTER_ID_LENGTH, + default="", + db_comment="Unique identifier of the Adapter", + ) + + # TODO to be removed once the migration for encryption + adapter_metadata = models.JSONField( + db_column="adapter_metadata", + null=False, + blank=False, + default=dict, + db_comment="JSON adapter metadata submitted by the user", + ) + adapter_metadata_b = models.BinaryField(null=True) + adapter_type = models.CharField( + choices=[(tag.value, tag.name) for tag in AdapterTypes], + db_comment="Type of adapter LLM/EMBEDDING/VECTOR_DB", + ) + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="adapters_created", + null=True, + blank=True, + ) + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="adapters_modified", + null=True, + blank=True, + ) + + is_active = models.BooleanField( + default=False, + db_comment="Is the adapter instance currently being used", + ) + shared_to_org = models.BooleanField( + default=False, + db_comment="Is the adapter shared to entire org", + ) + + is_friction_less = models.BooleanField( + default=False, + db_comment="Was the adapter 
created through frictionless onboarding", + ) + + # Can be used if the adapter usage gets exhausted + # Can also be used in other possible scenarios in feature + is_usable = models.BooleanField( + default=True, + db_comment="Is the Adpater Usable", + ) + + # Introduced field to establish M2M relation between users and adapters. + # This will introduce intermediary table which relates both the models. + shared_users = models.ManyToManyField(User, related_name="shared_adapters_instance") + description = models.TextField(blank=True, null=True, default=None) + + objects = AdapterInstanceModelManager() + + class Meta: + verbose_name = "adapter instance" + verbose_name_plural = "adapter instances" + db_table = "adapter_instance_v2" + constraints = [ + models.UniqueConstraint( + fields=["adapter_name", "adapter_type", "organization"], + name="unique_organization_adapter", + ), + ] + + def create_adapter(self) -> None: + + encryption_secret: str = settings.ENCRYPTION_KEY + f: Fernet = Fernet(encryption_secret.encode("utf-8")) + + self.adapter_metadata_b = f.encrypt( + json.dumps(self.adapter_metadata).encode("utf-8") + ) + self.adapter_metadata = {} + + self.save() + + def get_adapter_meta_data(self) -> Any: + encryption_secret: str = settings.ENCRYPTION_KEY + f: Fernet = Fernet(encryption_secret.encode("utf-8")) + + adapter_metadata = json.loads( + f.decrypt(bytes(self.adapter_metadata_b).decode("utf-8")) + ) + return adapter_metadata + + def get_context_window_size(self) -> int: + + adapter_metadata = self.get_adapter_meta_data() + # Get the adapter_instance + adapter_class = Adapterkit().get_adapter_class_by_adapter_id(self.adapter_id) + try: + adapter_instance = adapter_class(adapter_metadata) + return adapter_instance.get_context_window_size() + except AdapterError as e: + logger.warning(f"Unable to retrieve context window size - {e}") + return 0 + + +class UserDefaultAdapter(BaseModel): + user = models.OneToOneField( + User, on_delete=models.CASCADE, related_name="organization_default_adapters" + ) + organization_member = models.OneToOneField( + OrganizationMember, + on_delete=models.CASCADE, + default=None, + null=True, + db_comment="Foreign key reference to the OrganizationMember model.", + related_name="default_adapters", + ) + default_llm_adapter = models.ForeignKey( + AdapterInstance, + on_delete=models.SET_NULL, + null=True, + related_name="user_default_llm_adapter", + ) + default_embedding_adapter = models.ForeignKey( + AdapterInstance, + on_delete=models.SET_NULL, + null=True, + related_name="user_default_embedding_adapter", + ) + default_vector_db_adapter = models.ForeignKey( + AdapterInstance, + on_delete=models.SET_NULL, + null=True, + related_name="user_default_vector_db_adapter", + ) + + default_x2text_adapter = models.ForeignKey( + AdapterInstance, + on_delete=models.SET_NULL, + null=True, + related_name="user_default_x2text_adapter", + ) + + class Meta: + verbose_name = "Default Adapter for Organization User" + verbose_name_plural = "Default Adapters for Organization Users" + db_table = "default_organization_user_adapter_v2" diff --git a/backend/adapter_processor_v2/serializers.py b/backend/adapter_processor_v2/serializers.py new file mode 100644 index 000000000..927fe7e91 --- /dev/null +++ b/backend/adapter_processor_v2/serializers.py @@ -0,0 +1,164 @@ +import json +from typing import Any + +from account_v2.serializer import UserSerializer +from adapter_processor_v2.adapter_processor import AdapterProcessor +from adapter_processor_v2.constants import AdapterKeys +from 
adapter_processor_v2.exceptions import InvalidEncryptionKey +from cryptography.fernet import Fernet, InvalidToken +from django.conf import settings +from rest_framework import serializers +from rest_framework.serializers import ModelSerializer +from unstract.adapters.constants import Common as common +from unstract.adapters.enums import AdapterTypes + +from backend.constants import FieldLengthConstants as FLC +from backend.serializers import AuditSerializer + +from .models import AdapterInstance, UserDefaultAdapter + + +class TestAdapterSerializer(serializers.Serializer): + adapter_id = serializers.CharField(max_length=FLC.ADAPTER_ID_LENGTH) + adapter_metadata = serializers.JSONField() + adapter_type = serializers.JSONField() + + +class BaseAdapterSerializer(AuditSerializer): + class Meta: + model = AdapterInstance + fields = "__all__" + + +class DefaultAdapterSerializer(serializers.Serializer): + llm_default = serializers.CharField(max_length=FLC.UUID_LENGTH, required=False) + embedding_default = serializers.CharField( + max_length=FLC.UUID_LENGTH, required=False + ) + vector_db_default = serializers.CharField( + max_length=FLC.UUID_LENGTH, required=False + ) + + +class AdapterInstanceSerializer(BaseAdapterSerializer): + """Inherits BaseAdapterSerializer. + + Used for CRUD other than listing + """ + + def to_internal_value(self, data: dict[str, Any]) -> dict[str, Any]: + if data.get(AdapterKeys.ADAPTER_METADATA, None): + encryption_secret: str = settings.ENCRYPTION_KEY + f: Fernet = Fernet(encryption_secret.encode("utf-8")) + json_string: str = json.dumps(data.pop(AdapterKeys.ADAPTER_METADATA)) + + data[AdapterKeys.ADAPTER_METADATA_B] = f.encrypt( + json_string.encode("utf-8") + ) + + return data + + def to_representation(self, instance: AdapterInstance) -> dict[str, str]: + rep: dict[str, str] = super().to_representation(instance) + + rep.pop(AdapterKeys.ADAPTER_METADATA_B) + + try: + adapter_metadata = instance.get_adapter_meta_data() + except InvalidToken: + raise InvalidEncryptionKey + rep[AdapterKeys.ADAPTER_METADATA] = adapter_metadata + # Retrieve context window if adapter is a LLM + # For other adapter types, context_window is not relevant. + if instance.adapter_type == AdapterTypes.LLM.value: + adapter_metadata[AdapterKeys.ADAPTER_CONTEXT_WINDOW_SIZE] = ( + instance.get_context_window_size() + ) + + rep[common.ICON] = AdapterProcessor.get_adapter_data_with_key( + instance.adapter_id, common.ICON + ) + rep[AdapterKeys.ADAPTER_CREATED_BY] = instance.created_by.email + + return rep + + +class AdapterInfoSerializer(BaseAdapterSerializer): + + context_window_size = serializers.SerializerMethodField() + + class Meta(BaseAdapterSerializer.Meta): + model = AdapterInstance + fields = ( + "id", + "adapter_id", + "adapter_name", + "adapter_type", + "created_by", + "context_window_size", + ) # type: ignore + + def get_context_window_size(self, obj: AdapterInstance) -> int: + return obj.get_context_window_size() + + +class AdapterListSerializer(BaseAdapterSerializer): + """Inherits BaseAdapterSerializer. 
+ + Used for listing adapters + """ + + class Meta(BaseAdapterSerializer.Meta): + model = AdapterInstance + fields = ( + "id", + "adapter_id", + "adapter_name", + "adapter_type", + "created_by", + "description", + ) # type: ignore + + def to_representation(self, instance: AdapterInstance) -> dict[str, str]: + rep: dict[str, str] = super().to_representation(instance) + rep[common.ICON] = AdapterProcessor.get_adapter_data_with_key( + instance.adapter_id, common.ICON + ) + adapter_metadata = instance.get_adapter_meta_data() + model = adapter_metadata.get("model") + if model: + rep["model"] = model + + if instance.is_friction_less: + rep["created_by_email"] = "Unstract" + else: + rep["created_by_email"] = instance.created_by.email + + return rep + + +class SharedUserListSerializer(BaseAdapterSerializer): + """Inherits BaseAdapterSerializer. + + Used for listing adapter users + """ + + shared_users = UserSerializer(many=True) + created_by = UserSerializer() + + class Meta(BaseAdapterSerializer.Meta): + model = AdapterInstance + fields = ( + "id", + "adapter_id", + "adapter_name", + "adapter_type", + "created_by", + "shared_users", + ) # type: ignore + + +class UserDefaultAdapterSerializer(ModelSerializer): + class Meta: + model = UserDefaultAdapter + fields = "__all__" diff --git a/backend/adapter_processor_v2/urls.py b/backend/adapter_processor_v2/urls.py new file mode 100644 index 000000000..741d3c3e1 --- /dev/null +++ b/backend/adapter_processor_v2/urls.py @@ -0,0 +1,42 @@ +from adapter_processor_v2.views import ( + AdapterInstanceViewSet, + AdapterViewSet, + DefaultAdapterViewSet, +) +from django.urls import path +from rest_framework.urlpatterns import format_suffix_patterns + +default_triad = DefaultAdapterViewSet.as_view( + {"post": "configure_default_triad", "get": "get_default_triad"} +) +adapter = AdapterViewSet.as_view({"get": "list"}) +adapter_schema = AdapterViewSet.as_view({"get": "get_adapter_schema"}) +adapter_test = AdapterViewSet.as_view({"post": "test"}) +adapter_list = AdapterInstanceViewSet.as_view({"post": "create", "get": "list"}) +adapter_detail = AdapterInstanceViewSet.as_view( + { + "get": "retrieve", + "put": "update", + "patch": "partial_update", + "delete": "destroy", + } +) + +adapter_users = AdapterInstanceViewSet.as_view({"get": "list_of_shared_users"}) +adapter_info = AdapterInstanceViewSet.as_view({"get": "adapter_info"}) +urlpatterns = format_suffix_patterns( + [ + path("adapter_schema/", adapter_schema, name="get_adapter_schema"), + path("supported_adapters/", adapter, name="adapter-list"), + path("adapter/", adapter_list, name="adapter-list"), + path("adapter/default_triad/", default_triad, name="default_triad"), + path("adapter//", adapter_detail, name="adapter_detail"), + path("adapter/info//", adapter_info, name="adapter_info"), + path("test_adapters/", adapter_test, name="adapter-test"), + path( + "adapter/users//", + adapter_users, + name="adapter-users", + ), + ] +) diff --git a/backend/adapter_processor_v2/views.py b/backend/adapter_processor_v2/views.py new file mode 100644 index 000000000..93cddf5aa --- /dev/null +++ b/backend/adapter_processor_v2/views.py @@ -0,0 +1,321 @@ +import logging +import uuid +from typing import Any, Optional + +from adapter_processor_v2.adapter_processor import AdapterProcessor +from adapter_processor_v2.constants import AdapterKeys +from adapter_processor_v2.exceptions import ( + CannotDeleteDefaultAdapter, + DeleteAdapterInUseError, + IdIsMandatory, + InValidType, + UniqueConstraintViolation, +) +from 
adapter_processor_v2.serializers import ( + AdapterInfoSerializer, + AdapterInstanceSerializer, + AdapterListSerializer, + DefaultAdapterSerializer, + SharedUserListSerializer, + TestAdapterSerializer, + UserDefaultAdapterSerializer, +) +from django.db import IntegrityError +from django.db.models import ProtectedError, QuerySet +from django.http import HttpRequest +from django.http.response import HttpResponse +from permissions.permission import ( + IsFrictionLessAdapter, + IsFrictionLessAdapterDelete, + IsOwner, + IsOwnerOrSharedUserOrSharedToOrg, +) +from rest_framework import status +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.serializers import ModelSerializer +from rest_framework.versioning import URLPathVersioning +from rest_framework.viewsets import GenericViewSet, ModelViewSet +from tenant_account_v2.organization_member_service import OrganizationMemberService +from utils.filtering import FilterHelper + +from .constants import AdapterKeys as constant +from .exceptions import InternalServiceError +from .models import AdapterInstance, UserDefaultAdapter + +logger = logging.getLogger(__name__) + + +class DefaultAdapterViewSet(ModelViewSet): + versioning_class = URLPathVersioning + serializer_class = DefaultAdapterSerializer + + def configure_default_triad( + self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> HttpResponse: + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + # Convert request data to json + default_triad = request.data + AdapterProcessor.set_default_triad(default_triad, request.user) + return Response(status=status.HTTP_200_OK) + + def get_default_triad( + self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> HttpResponse: + try: + user_default_adapter = UserDefaultAdapter.objects.get(user=request.user) + serializer = UserDefaultAdapterSerializer(user_default_adapter).data + return Response(serializer) + + except UserDefaultAdapter.DoesNotExist: + # Handle the case when no records are found + return Response(status=status.HTTP_200_OK, data={}) + + +class AdapterViewSet(GenericViewSet): + versioning_class = URLPathVersioning + serializer_class = TestAdapterSerializer + + def list( + self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> HttpResponse: + if request.method == "GET": + adapter_type = request.GET.get(AdapterKeys.ADAPTER_TYPE) + if ( + adapter_type == AdapterKeys.LLM + or adapter_type == AdapterKeys.EMBEDDING + or adapter_type == AdapterKeys.VECTOR_DB + or adapter_type == AdapterKeys.X2TEXT + or adapter_type == AdapterKeys.OCR + ): + json_schema = AdapterProcessor.get_all_supported_adapters( + type=adapter_type + ) + return Response(json_schema, status=status.HTTP_200_OK) + else: + raise InValidType + + def get_adapter_schema( + self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> HttpResponse: + if request.method == "GET": + adapter_name = request.GET.get(AdapterKeys.ID) + if adapter_name is None or adapter_name == "": + raise IdIsMandatory() + json_schema = AdapterProcessor.get_json_schema(adapter_id=adapter_name) + return Response(data=json_schema, status=status.HTTP_200_OK) + + def test(self, request: Request) -> Response: + """Tests the connector against the credentials passed.""" + serializer: AdapterInstanceSerializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + 
adapter_id = serializer.validated_data.get(AdapterKeys.ADAPTER_ID) + adapter_metadata = serializer.validated_data.get(AdapterKeys.ADAPTER_METADATA) + adapter_metadata[AdapterKeys.ADAPTER_TYPE] = serializer.validated_data.get( + AdapterKeys.ADAPTER_TYPE + ) + try: + test_result = AdapterProcessor.test_adapter( + adapter_id=adapter_id, adapter_metadata=adapter_metadata + ) + return Response( + {AdapterKeys.IS_VALID: test_result}, + status=status.HTTP_200_OK, + ) + except Exception as e: + logger.error(f"Error testing adapter : {str(e)}") + raise e + + +class AdapterInstanceViewSet(ModelViewSet): + + serializer_class = AdapterInstanceSerializer + + def get_permissions(self) -> list[Any]: + + if self.action in ["update", "retrieve"]: + return [IsFrictionLessAdapter()] + + elif self.action == "destroy": + return [IsFrictionLessAdapterDelete()] + + elif self.action in ["list_of_shared_users", "adapter_info"]: + return [IsOwnerOrSharedUserOrSharedToOrg()] + + # Hack for friction-less onboarding + # User cant view/update metadata but can delete/share etc + return [IsOwner()] + + def get_queryset(self) -> Optional[QuerySet]: + if filter_args := FilterHelper.build_filter_args( + self.request, + constant.ADAPTER_TYPE, + ): + queryset = AdapterInstance.objects.for_user(self.request.user).filter( + **filter_args + ) + else: + queryset = AdapterInstance.objects.for_user(self.request.user) + return queryset + + def get_serializer_class( + self, + ) -> ModelSerializer: + if self.action == "list": + return AdapterListSerializer + return AdapterInstanceSerializer + + def create(self, request: Any) -> Response: + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + try: + instance = serializer.save() + + # Check to see if there is a default configured + # for this adapter_type and for the current user + ( + user_default_adapter, + created, + ) = UserDefaultAdapter.objects.get_or_create(user=request.user) + + adapter_type = serializer.validated_data.get(AdapterKeys.ADAPTER_TYPE) + if (adapter_type == AdapterKeys.LLM) and ( + not user_default_adapter.default_llm_adapter + ): + user_default_adapter.default_llm_adapter = instance + + elif (adapter_type == AdapterKeys.EMBEDDING) and ( + not user_default_adapter.default_embedding_adapter + ): + user_default_adapter.default_embedding_adapter = instance + elif (adapter_type == AdapterKeys.VECTOR_DB) and ( + not user_default_adapter.default_vector_db_adapter + ): + user_default_adapter.default_vector_db_adapter = instance + elif (adapter_type == AdapterKeys.X2TEXT) and ( + not user_default_adapter.default_x2text_adapter + ): + user_default_adapter.default_x2text_adapter = instance + + organization_member = OrganizationMemberService.get_user_by_id( + request.user.id + ) + user_default_adapter.organization_member = organization_member + + user_default_adapter.save() + + except IntegrityError: + raise UniqueConstraintViolation(f"{AdapterKeys.ADAPTER_NAME_EXISTS}") + except Exception as e: + logger.error(f"Error saving adapter to DB: {e}") + raise InternalServiceError + headers = self.get_success_headers(serializer.data) + return Response( + serializer.data, status=status.HTTP_201_CREATED, headers=headers + ) + + def destroy( + self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> Response: + adapter_instance: AdapterInstance = self.get_object() + adapter_type = adapter_instance.adapter_type + try: + user_default_adapter: UserDefaultAdapter = UserDefaultAdapter.objects.get( + user=request.user + ) + + if ( 
+ ( + adapter_type == AdapterKeys.LLM + and adapter_instance == user_default_adapter.default_llm_adapter + ) + or ( + adapter_type == AdapterKeys.EMBEDDING + and adapter_instance + == user_default_adapter.default_embedding_adapter + ) + or ( + adapter_type == AdapterKeys.VECTOR_DB + and adapter_instance + == user_default_adapter.default_vector_db_adapter + ) + or ( + adapter_type == AdapterKeys.X2TEXT + and adapter_instance == user_default_adapter.default_x2text_adapter + ) + ): + logger.error("Cannot delete a default adapter") + raise CannotDeleteDefaultAdapter() + except UserDefaultAdapter.DoesNotExist: + # We can go head and remove adapter here + logger.info("User default adpater doesnt not exist") + + try: + super().perform_destroy(adapter_instance) + except ProtectedError: + logger.error( + f"Failed to delete adapter: {adapter_instance.adapter_id}" + f" named {adapter_instance.adapter_name}" + ) + # TODO: Provide details of adpter usage with exception object + raise DeleteAdapterInUseError(adapter_name=adapter_instance.adapter_name) + return Response(status=status.HTTP_204_NO_CONTENT) + + def partial_update( + self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> Response: + if AdapterKeys.SHARED_USERS in request.data: + # find the deleted users + adapter = self.get_object() + shared_users = { + int(user_id) for user_id in request.data.get("shared_users", {}) + } + current_users = {user.id for user in adapter.shared_users.all()} + removed_users = current_users.difference(shared_users) + + # if removed user use this adapter as default + # Remove the same from his default + for user_id in removed_users: + try: + user_default_adapter = UserDefaultAdapter.objects.get( + user_id=user_id + ) + + if user_default_adapter.default_llm_adapter == adapter: + user_default_adapter.default_llm_adapter = None + elif user_default_adapter.default_embedding_adapter == adapter: + user_default_adapter.default_embedding_adapter = None + elif user_default_adapter.default_vector_db_adapter == adapter: + user_default_adapter.default_vector_db_adapter = None + elif user_default_adapter.default_x2text_adapter == adapter: + user_default_adapter.default_x2text_adapter = None + + user_default_adapter.save() + except UserDefaultAdapter.DoesNotExist: + logger.debug( + "User id : %s doesnt have default adapters configured", + user_id, + ) + continue + + return super().partial_update(request, *args, **kwargs) + + @action(detail=True, methods=["get"]) + def list_of_shared_users(self, request: HttpRequest, pk: Any = None) -> Response: + + adapter = self.get_object() + + serialized_instances = SharedUserListSerializer(adapter).data + + return Response(serialized_instances) + + @action(detail=True, methods=["get"]) + def adapter_info(self, request: HttpRequest, pk: uuid) -> Response: + + adapter = self.get_object() + + serialized_instances = AdapterInfoSerializer(adapter).data + + return Response(serialized_instances) diff --git a/backend/api_v2/__init__.py b/backend/api_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/api_v2/admin.py b/backend/api_v2/admin.py new file mode 100644 index 000000000..37f0837a7 --- /dev/null +++ b/backend/api_v2/admin.py @@ -0,0 +1,5 @@ +from django.contrib import admin + +from .models import APIDeployment, APIKey + +admin.site.register([APIDeployment, APIKey]) diff --git a/backend/api_v2/api_deployment_views.py b/backend/api_v2/api_deployment_views.py new file mode 100644 index 000000000..802d8f8bc --- /dev/null +++ 
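The partial_update handler above works out which users lost access by set difference and then clears any of their default-adapter slots that still point at the shared adapter. A minimal, self-contained sketch of that bookkeeping, using a plain dataclass in place of the UserDefaultAdapter model (the UserDefaults container and helper below are illustrative stand-ins, not the actual ORM code):

from dataclasses import dataclass, fields
from typing import Optional

@dataclass
class UserDefaults:
    # Stand-in for UserDefaultAdapter; holds adapter ids instead of FK instances.
    default_llm_adapter: Optional[str] = None
    default_embedding_adapter: Optional[str] = None
    default_vector_db_adapter: Optional[str] = None
    default_x2text_adapter: Optional[str] = None

def clear_defaults_for_removed_users(
    adapter_id: str,
    current_user_ids: set[int],
    requested_user_ids: set[int],
    defaults_by_user: dict[int, UserDefaults],
) -> set[int]:
    """Return the removed users, unsetting any default that pointed at the adapter."""
    removed_users = current_user_ids - requested_user_ids
    for user_id in removed_users:
        defaults = defaults_by_user.get(user_id)
        if defaults is None:
            continue  # user never configured defaults, nothing to clear
        for f in fields(defaults):
            if getattr(defaults, f.name) == adapter_id:
                setattr(defaults, f.name, None)
    return removed_users

# User 7 is dropped from shared_users, so their default LLM pointing at "a1" is cleared.
defaults = {7: UserDefaults(default_llm_adapter="a1")}
removed = clear_defaults_for_removed_users("a1", {5, 7}, {5}, defaults)
assert removed == {7} and defaults[7].default_llm_adapter is None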
b/backend/api_v2/api_deployment_views.py @@ -0,0 +1,152 @@ +import json +import logging +from typing import Any, Optional + +from api_v2.constants import ApiExecution +from api_v2.deployment_helper import DeploymentHelper +from api_v2.exceptions import InvalidAPIRequest, NoActiveAPIKeyError +from api_v2.models import APIDeployment +from api_v2.postman_collection.dto import PostmanCollection +from api_v2.serializers import ( + APIDeploymentListSerializer, + APIDeploymentSerializer, + DeploymentResponseSerializer, + ExecutionRequestSerializer, +) +from django.db.models import QuerySet +from django.http import HttpResponse +from permissions.permission import IsOwner +from rest_framework import serializers, status, views, viewsets +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.serializers import Serializer +from utils.enums import CeleryTaskState +from workflow_manager.workflow_v2.dto import ExecutionResponse + +logger = logging.getLogger(__name__) + + +class DeploymentExecution(views.APIView): + def initialize_request( + self, request: Request, *args: Any, **kwargs: Any + ) -> Request: + """To remove csrf request for public API. + + Args: + request (Request): _description_ + + Returns: + Request: _description_ + """ + setattr(request, "csrf_processing_done", True) + return super().initialize_request(request, *args, **kwargs) + + @DeploymentHelper.validate_api_key + def post( + self, request: Request, org_name: str, api_name: str, api: APIDeployment + ) -> Response: + file_objs = request.FILES.getlist(ApiExecution.FILES_FORM_DATA) + serializer = ExecutionRequestSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + timeout = serializer.get_timeout(serializer.validated_data) + include_metadata = ( + request.data.get(ApiExecution.INCLUDE_METADATA, "false").lower() == "true" + ) + if not file_objs or len(file_objs) == 0: + raise InvalidAPIRequest("File shouldn't be empty") + response = DeploymentHelper.execute_workflow( + organization_name=org_name, + api=api, + file_objs=file_objs, + timeout=timeout, + include_metadata=include_metadata, + ) + if "error" in response and response["error"]: + return Response( + {"message": response}, + status=status.HTTP_422_UNPROCESSABLE_ENTITY, + ) + return Response({"message": response}, status=status.HTTP_200_OK) + + @DeploymentHelper.validate_api_key + def get( + self, request: Request, org_name: str, api_name: str, api: APIDeployment + ) -> Response: + execution_id = request.query_params.get("execution_id") + if not execution_id: + raise InvalidAPIRequest("execution_id shouldn't be empty") + response: ExecutionResponse = DeploymentHelper.get_execution_status( + execution_id=execution_id + ) + if response.execution_status != CeleryTaskState.SUCCESS.value: + return Response( + { + "status": response.execution_status, + "message": response.result, + }, + status=status.HTTP_422_UNPROCESSABLE_ENTITY, + ) + return Response( + {"status": response.execution_status, "message": response.result}, + status=status.HTTP_200_OK, + ) + + +class APIDeploymentViewSet(viewsets.ModelViewSet): + permission_classes = [IsOwner] + + def get_queryset(self) -> Optional[QuerySet]: + return APIDeployment.objects.filter(created_by=self.request.user) + + def get_serializer_class(self) -> serializers.Serializer: + if self.action in ["list"]: + return APIDeploymentListSerializer + return APIDeploymentSerializer + + @action(detail=True, methods=["get"]) + def 
fetch_one(self, request: Request, pk: Optional[str] = None) -> Response: + """Custom action to fetch a single instance.""" + instance = self.get_object() + serializer = self.get_serializer(instance) + return Response(serializer.data) + + def create( + self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> Response: + serializer: Serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + self.perform_create(serializer) + api_key = DeploymentHelper.create_api_key(serializer=serializer) + response_serializer = DeploymentResponseSerializer( + {"api_key": api_key.api_key, **serializer.data} + ) + + headers = self.get_success_headers(serializer.data) + return Response( + response_serializer.data, + status=status.HTTP_201_CREATED, + headers=headers, + ) + + @action(detail=True, methods=["get"]) + def download_postman_collection( + self, request: Request, pk: Optional[str] = None + ) -> Response: + """Downloads a Postman Collection of the API deployment instance.""" + instance = self.get_object() + api_key_inst = instance.apikey_set.filter(is_active=True).first() + if not api_key_inst: + logger.error(f"No active API key set for deployment {instance.pk}") + raise NoActiveAPIKeyError(deployment_name=instance.display_name) + + postman_collection = PostmanCollection.create( + instance=instance, api_key=api_key_inst.api_key + ) + response = HttpResponse( + json.dumps(postman_collection.to_dict()), content_type="application/json" + ) + response["Content-Disposition"] = ( + f'attachment; filename="{instance.display_name}.json"' + ) + return response diff --git a/backend/api_v2/api_key_views.py b/backend/api_v2/api_key_views.py new file mode 100644 index 000000000..906127c40 --- /dev/null +++ b/backend/api_v2/api_key_views.py @@ -0,0 +1,28 @@ +from api_v2.deployment_helper import DeploymentHelper +from api_v2.exceptions import APINotFound +from api_v2.key_helper import KeyHelper +from api_v2.models import APIKey +from api_v2.serializers import APIKeyListSerializer, APIKeySerializer +from rest_framework import serializers, viewsets +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.response import Response + + +class APIKeyViewSet(viewsets.ModelViewSet): + queryset = APIKey.objects.all() + + def get_serializer_class(self) -> serializers.Serializer: + if self.action in ["api_keys"]: + return APIKeyListSerializer + return APIKeySerializer + + @action(detail=True, methods=["get"]) + def api_keys(self, request: Request, api_id: str) -> Response: + """Custom action to fetch api keys of an api deployment.""" + api = DeploymentHelper.get_api_by_id(api_id=api_id) + if not api: + raise APINotFound() + keys = KeyHelper.list_api_keys_of_api(api_instance=api) + serializer = self.get_serializer(keys, many=True) + return Response(serializer.data) diff --git a/backend/api_v2/apps.py b/backend/api_v2/apps.py new file mode 100644 index 000000000..16cbd4ad5 --- /dev/null +++ b/backend/api_v2/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class ApiConfig(AppConfig): + name = "api_v2" diff --git a/backend/api_v2/constants.py b/backend/api_v2/constants.py new file mode 100644 index 000000000..0ec324cc9 --- /dev/null +++ b/backend/api_v2/constants.py @@ -0,0 +1,6 @@ +class ApiExecution: + PATH: str = "deployment/api" + MAXIMUM_TIMEOUT_IN_SEC: int = 300 # 5 minutes + FILES_FORM_DATA: str = "files" + TIMEOUT_FORM_DATA: str = "timeout" + INCLUDE_METADATA: str = "include_metadata" diff --git 
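With the execution view and the ApiExecution constants in place, a client can exercise a deployed API over plain HTTP. A hedged sketch using the requests library; the host, organization and API names below are placeholders, and the Bearer key stands for one of the deployment's active API keys:

import requests

API_KEY = "REPLACE_WITH_API_KEY"
BASE_URL = "https://unstract.example.com"  # assumption: your deployment's origin URL
# ApiExecution.PATH + organization + api_name, mirroring the deployment's api_endpoint
ENDPOINT = f"{BASE_URL}/deployment/api/my_org/my_api/"

headers = {"Authorization": f"Bearer {API_KEY}"}

# Process a document (POST); "files", "timeout" and "include_metadata" correspond to
# ApiExecution.FILES_FORM_DATA, TIMEOUT_FORM_DATA and INCLUDE_METADATA.
with open("invoice.pdf", "rb") as fh:
    response = requests.post(
        ENDPOINT,
        headers=headers,
        files={"files": fh},
        data={"timeout": 300, "include_metadata": "false"},
    )
print(response.status_code, response.json())

# Poll execution status (GET) using the execution_id returned by the POST above.
status = requests.get(
    ENDPOINT,
    headers=headers,
    params={"execution_id": "REPLACE_WITH_EXECUTION_ID"},
)
print(status.json())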
a/backend/api_v2/deployment_helper.py b/backend/api_v2/deployment_helper.py new file mode 100644 index 000000000..e94c84e46 --- /dev/null +++ b/backend/api_v2/deployment_helper.py @@ -0,0 +1,244 @@ +import logging +import uuid +from functools import wraps +from typing import Any, Optional +from urllib.parse import urlencode + +from api_v2.constants import ApiExecution +from api_v2.exceptions import ( + ApiKeyCreateException, + APINotFound, + Forbidden, + InactiveAPI, + UnauthorizedKey, +) +from api_v2.key_helper import KeyHelper +from api_v2.models import APIDeployment, APIKey +from api_v2.serializers import APIExecutionResponseSerializer +from django.core.files.uploadedfile import UploadedFile +from django.db import connection +from rest_framework import status +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.serializers import Serializer +from rest_framework.utils.serializer_helpers import ReturnDict +from utils.constants import Account +from utils.local_context import StateStore +from workflow_manager.endpoint_v2.destination import DestinationConnector +from workflow_manager.endpoint_v2.source import SourceConnector +from workflow_manager.workflow_v2.dto import ExecutionResponse +from workflow_manager.workflow_v2.models.workflow import Workflow +from workflow_manager.workflow_v2.workflow_helper import WorkflowHelper + +logger = logging.getLogger(__name__) + + +class DeploymentHelper: + @staticmethod + def validate_api_key(func: Any) -> Any: + """Decorator that validates the API key. + + Sample header: + Authorization: Bearer 123e4567-e89b-12d3-a456-426614174001 + Args: + func (Any): Function to wrap for validation + """ + + @wraps(func) + def wrapper(self: Any, request: Request, *args: Any, **kwargs: Any) -> Any: + """Wrapper to validate the inputs and key. + + Args: + request (Request): Request context + + Raises: + Forbidden: _description_ + APINotFound: _description_ + + Returns: + Any: _description_ + """ + try: + authorization_header = request.headers.get("Authorization") + api_key = None + if authorization_header and authorization_header.startswith("Bearer "): + api_key = authorization_header.split(" ")[1] + if not api_key: + raise Forbidden("Missing api key") + org_name = kwargs.get("org_name") or request.data.get("org_name") + api_name = kwargs.get("api_name") or request.data.get("api_name") + if not api_name: + raise Forbidden("Missing api_name") + # Set organization in state store for API + StateStore.set(Account.ORGANIZATION_ID, org_name) + + api_deployment = DeploymentHelper.get_deployment_by_api_name( + api_name=api_name + ) + DeploymentHelper.validate_api( + api_deployment=api_deployment, api_key=api_key + ) + kwargs["api"] = api_deployment + return func(self, request, *args, **kwargs) + + except (UnauthorizedKey, InactiveAPI, APINotFound): + raise + except Exception as exception: + logger.error(f"Exception: {exception}") + return Response( + {"error": str(exception)}, status=status.HTTP_403_FORBIDDEN + ) + + return wrapper + + @staticmethod + def validate_api(api_deployment: Optional[APIDeployment], api_key: str) -> None: + """Validating API and API key. 
+ + Args: + api_deployment (Optional[APIDeployment]): _description_ + api_key (str): _description_ + + Raises: + APINotFound: _description_ + InactiveAPI: _description_ + """ + if not api_deployment: + raise APINotFound() + if not api_deployment.is_active: + raise InactiveAPI() + KeyHelper.validate_api_key(api_key=api_key, api_instance=api_deployment) + + @staticmethod + def validate_and_get_workflow(workflow_id: str) -> Workflow: + """Validate that the specified workflow_id exists in the Workflow + model.""" + return WorkflowHelper.get_workflow_by_id(workflow_id) + + @staticmethod + def get_api_by_id(api_id: str) -> Optional[APIDeployment]: + try: + api_deployment: APIDeployment = APIDeployment.objects.get(pk=api_id) + return api_deployment + except APIDeployment.DoesNotExist: + return None + + @staticmethod + def construct_complete_endpoint(api_name: str) -> str: + """Constructs the complete API endpoint by appending organization + schema, endpoint path, and Django app backend URL. + + Parameters: + - endpoint (str): The endpoint path to be appended to the complete URL. + + Returns: + - str: The complete API endpoint URL. + """ + org_schema = connection.get_tenant().schema_name + return f"{ApiExecution.PATH}/{org_schema}/{api_name}/" + + @staticmethod + def construct_status_endpoint(api_endpoint: str, execution_id: str) -> str: + """Construct a complete status endpoint URL by appending the + execution_id as a query parameter. + + Args: + api_endpoint (str): The base API endpoint. + execution_id (str): The execution ID to be included as + a query parameter. + + Returns: + str: The complete status endpoint URL. + """ + query_parameters = urlencode({"execution_id": execution_id}) + complete_endpoint = f"/{api_endpoint}?{query_parameters}" + return complete_endpoint + + @staticmethod + def get_deployment_by_api_name( + api_name: str, + ) -> Optional[APIDeployment]: + """Get and return the APIDeployment object by api_name.""" + try: + api: APIDeployment = APIDeployment.objects.get(api_name=api_name) + return api + except APIDeployment.DoesNotExist: + return None + + @staticmethod + def create_api_key(serializer: Serializer) -> APIKey: + """To make API key for an API. + + Args: + serializer (Serializer): Request serializer + + Raises: + ApiKeyCreateException: Exception + """ + api_deployment: APIDeployment = serializer.instance + try: + api_key: APIKey = KeyHelper.create_api_key(api_deployment) + return api_key + except Exception as error: + logger.error(f"Error while creating API key error: {str(error)}") + api_deployment.delete() + logger.info("Deleted the deployment instance") + raise ApiKeyCreateException() + + @staticmethod + def execute_workflow( + organization_name: str, + api: APIDeployment, + file_objs: list[UploadedFile], + timeout: int, + ) -> ReturnDict: + """Execute workflow by api. 
+ + Args: + organization_name (str): organization name + api (APIDeployment): api model object + file_obj (UploadedFile): input file + + Returns: + ReturnDict: execution status/ result + """ + workflow_id = api.workflow.id + pipeline_id = api.id + execution_id = str(uuid.uuid4()) + hash_values_of_files = SourceConnector.add_input_file_to_api_storage( + workflow_id=workflow_id, + execution_id=execution_id, + file_objs=file_objs, + ) + try: + result = WorkflowHelper.execute_workflow_async( + workflow_id=workflow_id, + pipeline_id=pipeline_id, + hash_values_of_files=hash_values_of_files, + timeout=timeout, + execution_id=execution_id, + ) + result.status_api = DeploymentHelper.construct_status_endpoint( + api_endpoint=api.api_endpoint, execution_id=execution_id + ) + except Exception: + DestinationConnector.delete_api_storage_dir( + workflow_id=workflow_id, execution_id=execution_id + ) + raise + return APIExecutionResponseSerializer(result).data + + @staticmethod + def get_execution_status(execution_id: str) -> ExecutionResponse: + """Current status of api execution. + + Args: + execution_id (str): execution id + + Returns: + ReturnDict: status/result of execution + """ + execution_response: ExecutionResponse = WorkflowHelper.get_status_of_async_task( + execution_id=execution_id + ) + return execution_response diff --git a/backend/api_v2/exceptions.py b/backend/api_v2/exceptions.py new file mode 100644 index 000000000..c3a58ff05 --- /dev/null +++ b/backend/api_v2/exceptions.py @@ -0,0 +1,55 @@ +from typing import Optional + +from rest_framework.exceptions import APIException + + +class MandatoryWorkflowId(APIException): + status_code = 400 + default_detail = "Workflow ID is mandatory" + + +class ApiKeyCreateException(APIException): + status_code = 500 + default_detail = "Exception while create API key" + + +class Forbidden(APIException): + status_code = 403 + default_detail = ( + "User is forbidden from performing this action. Please contact admin" + ) + + +class APINotFound(APIException): + status_code = 404 + default_detail = "API not found" + + +class InvalidAPIRequest(APIException): + status_code = 400 + default_detail = "Bad request" + + +class InactiveAPI(APIException): + status_code = 404 + default_detail = "API not found or Inactive" + + +class UnauthorizedKey(APIException): + status_code = 401 + default_detail = "Unauthorized" + + +class NoActiveAPIKeyError(APIException): + status_code = 409 + default_detail = "No active API keys configured for this deployment" + + def __init__( + self, + detail: Optional[str] = None, + code: Optional[str] = None, + deployment_name: str = "this deployment", + ): + if detail is None: + detail = f"No active API keys configured for {deployment_name}" + super().__init__(detail, code) diff --git a/backend/api_v2/key_helper.py b/backend/api_v2/key_helper.py new file mode 100644 index 000000000..2a1a430a9 --- /dev/null +++ b/backend/api_v2/key_helper.py @@ -0,0 +1,72 @@ +import logging + +from api_v2.exceptions import Forbidden, UnauthorizedKey +from api_v2.models import APIDeployment, APIKey +from api_v2.serializers import APIKeySerializer +from workflow_manager.workflow_v2.workflow_helper import WorkflowHelper + +logger = logging.getLogger(__name__) + + +class KeyHelper: + @staticmethod + def validate_api_key(api_key: str, api_instance: APIDeployment) -> None: + """Validate api key. 
+ + Args: + api_key (str): api key from request + api_instance (APIDeployment): api deployment instance + + Raises: + Forbidden: _description_ + """ + try: + api_key_instance: APIKey = APIKey.objects.get(api_key=api_key) + if not KeyHelper.has_access(api_key_instance, api_instance): + raise UnauthorizedKey() + except APIKey.DoesNotExist: + raise UnauthorizedKey() + except APIDeployment.DoesNotExist: + raise Forbidden("API not found.") + + @staticmethod + def list_api_keys_of_api(api_instance: APIDeployment) -> list[APIKey]: + api_keys: list[APIKey] = APIKey.objects.filter(api=api_instance).all() + return api_keys + + @staticmethod + def has_access(api_key: APIKey, api_instance: APIDeployment) -> bool: + """Check if the provided API key has access to the specified API + instance. + + Args: + api_key (APIKey): api key associated with the api + api_instance (APIDeployment): api model + + Returns: + bool: True if allowed to execute, False otherwise + """ + if not api_key.is_active: + return False + if isinstance(api_key.api, APIDeployment): + return api_key.api == api_instance + return False + + @staticmethod + def validate_workflow_exists(workflow_id: str) -> None: + """Validate that the specified workflow_id exists in the Workflow + model.""" + WorkflowHelper.get_workflow_by_id(workflow_id) + + @staticmethod + def create_api_key(deployment: APIDeployment) -> APIKey: + """Create an APIKey entity with the data from the provided + APIDeployment instance.""" + # Create an instance of the APIKey model + api_key_serializer = APIKeySerializer( + data={"api": deployment.id, "description": "Initial Access Key"}, + context={"deployment": deployment}, + ) + api_key_serializer.is_valid(raise_exception=True) + api_key: APIKey = api_key_serializer.save() + return api_key diff --git a/backend/api_v2/models.py b/backend/api_v2/models.py new file mode 100644 index 000000000..545426537 --- /dev/null +++ b/backend/api_v2/models.py @@ -0,0 +1,167 @@ +import uuid +from typing import Any + +from account_v2.models import User +from api_v2.constants import ApiExecution +from django.db import models +from utils.models.base_model import BaseModel +from utils.models.organization_mixin import ( + DefaultOrganizationManagerMixin, + DefaultOrganizationMixin, +) +from utils.user_context import UserContext +from workflow_manager.workflow_v2.models.workflow import Workflow + +API_NAME_MAX_LENGTH = 30 +DESCRIPTION_MAX_LENGTH = 255 +API_ENDPOINT_MAX_LENGTH = 255 + + +class APIDeploymentModelManager(DefaultOrganizationManagerMixin, models.Manager): + pass + + +class APIDeployment(DefaultOrganizationMixin, BaseModel): + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + display_name = models.CharField( + max_length=API_NAME_MAX_LENGTH, + default="default api", + db_comment="User-given display name for the API.", + ) + description = models.CharField( + max_length=DESCRIPTION_MAX_LENGTH, + blank=True, + default="", + db_comment="User-given description for the API.", + ) + workflow = models.ForeignKey( + Workflow, + on_delete=models.CASCADE, + db_comment="Foreign key reference to the Workflow model.", + related_name="apis", + ) + is_active = models.BooleanField( + default=True, + db_comment="Flag indicating whether the API is active or not.", + ) + api_endpoint = models.CharField( + max_length=API_ENDPOINT_MAX_LENGTH, + unique=True, + editable=False, + db_comment="URL endpoint for the API deployment.", + ) + api_name = models.CharField( + max_length=API_NAME_MAX_LENGTH, + default=uuid.uuid4, + 
db_comment="Short name for the API deployment.", + ) + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="apis_created", + null=True, + blank=True, + editable=False, + ) + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="apis_modified", + null=True, + blank=True, + editable=False, + ) + + # Manager + objects = APIDeploymentModelManager() + + def __str__(self) -> str: + return f"{self.id} - {self.display_name}" + + def save(self, *args: Any, **kwargs: Any) -> None: + """Save hook to update api_endpoint. + + Custom save hook for updating the 'api_endpoint' based on + 'api_name'. If the instance is being updated, it checks for + changes in 'api_name' and adjusts 'api_endpoint' + accordingly. If the instance is new, 'api_endpoint' is set + based on 'api_name' and the current database schema. + """ + if self.pk is not None: + organization_id = UserContext.get_organization_identifier() + try: + original = APIDeployment.objects.get(pk=self.pk) + if original.api_name != self.api_name: + self.api_endpoint = ( + f"{ApiExecution.PATH}/{organization_id}/{self.api_name}/" + ) + except APIDeployment.DoesNotExist: + self.api_endpoint = ( + f"{ApiExecution.PATH}/{organization_id}/{self.api_name}/" + ) + super().save(*args, **kwargs) + + class Meta: + verbose_name = "Api Deployment" + verbose_name_plural = "Api Deployments" + db_table = "api_deployment_v2" + constraints = [ + models.UniqueConstraint( + fields=["api_name", "organization"], + name="unique_api_name", + ), + ] + + +class APIKey(BaseModel): + id = models.UUIDField( + primary_key=True, + editable=False, + default=uuid.uuid4, + db_comment="Unique identifier for the API key.", + ) + api_key = models.UUIDField( + default=uuid.uuid4, + editable=False, + unique=True, + db_comment="Actual key UUID.", + ) + api = models.ForeignKey( + APIDeployment, + on_delete=models.CASCADE, + db_comment="Foreign key reference to the APIDeployment model.", + related_name="api_keys", + ) + description = models.CharField( + max_length=DESCRIPTION_MAX_LENGTH, + null=True, + db_comment="Description of the API key.", + ) + is_active = models.BooleanField( + default=True, + db_comment="Flag indicating whether the API key is active or not.", + ) + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="api_keys_created", + null=True, + blank=True, + editable=False, + ) + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="api_keys_modified", + null=True, + blank=True, + editable=False, + ) + + def __str__(self) -> str: + return f"{self.api.api_name} - {self.id} - {self.api_key}" + + class Meta: + verbose_name = "Api Deployment key" + verbose_name_plural = "Api Deployment keys" + db_table = "api_deployment_key_v2" diff --git a/backend/api_v2/postman_collection/__init__.py b/backend/api_v2/postman_collection/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/api_v2/postman_collection/constants.py b/backend/api_v2/postman_collection/constants.py new file mode 100644 index 000000000..ffb3ef729 --- /dev/null +++ b/backend/api_v2/postman_collection/constants.py @@ -0,0 +1,6 @@ +class CollectionKey: + POSTMAN_COLLECTION_V210 = "https://schema.getpostman.com/json/collection/v2.1.0/collection.json" # noqa: E501 + EXECUTE_API_KEY = "Process document" + STATUS_API_KEY = "Execution status" + STATUS_EXEC_ID_DEFAULT = "REPLACE_WITH_EXECUTION_ID" + AUTH_QUERY_PARAM_DEFAULT = "REPLACE_WITH_API_KEY" diff --git 
a/backend/api_v2/postman_collection/dto.py b/backend/api_v2/postman_collection/dto.py new file mode 100644 index 000000000..08f2294f1 --- /dev/null +++ b/backend/api_v2/postman_collection/dto.py @@ -0,0 +1,139 @@ +from dataclasses import asdict, dataclass, field +from typing import Any, Optional +from urllib.parse import urlencode, urljoin + +from api_v2.constants import ApiExecution +from api_v2.models import APIDeployment +from api_v2.postman_collection.constants import CollectionKey +from django.conf import settings +from utils.request import HTTPMethod + + +@dataclass +class HeaderItem: + key: str + value: str + + +@dataclass +class FormDataItem: + key: str + type: str + src: Optional[str] = None + value: Optional[str] = None + + def __post_init__(self) -> None: + if self.type == "file": + if self.src is None: + raise ValueError("src must be provided for type 'file'") + elif self.type == "text": + if self.value is None: + raise ValueError("value must be provided for type 'text'") + else: + raise ValueError(f"Unsupported type for form data: {self.type}") + + +@dataclass +class BodyItem: + formdata: list[FormDataItem] + mode: str = "formdata" + + +@dataclass +class RequestItem: + method: HTTPMethod + url: str + header: list[HeaderItem] + body: Optional[BodyItem] = None + + +@dataclass +class PostmanItem: + name: str + request: RequestItem + + +@dataclass +class PostmanInfo: + name: str = "Unstract's API deployment" + schema: str = CollectionKey.POSTMAN_COLLECTION_V210 + description: str = "Contains APIs meant for using the deployed Unstract API" + + +@dataclass +class PostmanCollection: + info: PostmanInfo + item: list[PostmanItem] = field(default_factory=list) + + @classmethod + def create( + cls, + instance: APIDeployment, + api_key: str = CollectionKey.AUTH_QUERY_PARAM_DEFAULT, + ) -> "PostmanCollection": + """Creates a PostmanCollection instance. + + This instance can help represent Postman collections (v2 format) that + can be used to easily invoke workflows deployed as APIs + + Args: + instance (APIDeployment): API deployment to generate collection for + api_key (str, optional): Active API key used to authenticate requests for + deployed APIs. Defaults to CollectionKey.AUTH_QUERY_PARAM_DEFAULT. + + Returns: + PostmanCollection: Instance representing PostmanCollection + """ + postman_info = PostmanInfo( + name=instance.display_name, description=instance.description + ) + header_list = [HeaderItem(key="Authorization", value=f"Bearer {api_key}")] + abs_api_endpoint = urljoin(settings.WEB_APP_ORIGIN_URL, instance.api_endpoint) + + # API execution API + execute_body = BodyItem( + formdata=[ + FormDataItem( + key=ApiExecution.FILES_FORM_DATA, type="file", src="/path_to_file" + ), + FormDataItem( + key=ApiExecution.TIMEOUT_FORM_DATA, + type="text", + value=ApiExecution.MAXIMUM_TIMEOUT_IN_SEC, + ), + FormDataItem( + key=ApiExecution.INCLUDE_METADATA, + type="text", + value=False, + ), + ] + ) + execute_request = RequestItem( + method=HTTPMethod.POST, + header=header_list, + body=execute_body, + url=abs_api_endpoint, + ) + + # Status API + status_query_param = {"execution_id": CollectionKey.STATUS_EXEC_ID_DEFAULT} + status_query_str = urlencode(status_query_param) + status_url = urljoin(abs_api_endpoint, "?" 
+ status_query_str) + status_request = RequestItem( + method=HTTPMethod.GET, header=header_list, url=status_url + ) + + postman_item_list = [ + PostmanItem(name=CollectionKey.EXECUTE_API_KEY, request=execute_request), + PostmanItem(name=CollectionKey.STATUS_API_KEY, request=status_request), + ] + return cls(info=postman_info, item=postman_item_list) + + def to_dict(self) -> dict[str, Any]: + """Convert PostmanCollection instance to a dict. + + Returns: + dict[str, Any]: PostmanCollection as a dict + """ + collection_dict = asdict(self) + return collection_dict diff --git a/backend/api_v2/serializers.py b/backend/api_v2/serializers.py new file mode 100644 index 000000000..522e5a23e --- /dev/null +++ b/backend/api_v2/serializers.py @@ -0,0 +1,123 @@ +from collections import OrderedDict +from typing import Any, Union + +from api_v2.constants import ApiExecution +from api_v2.models import APIDeployment, APIKey +from django.core.validators import RegexValidator +from rest_framework.serializers import ( + CharField, + IntegerField, + JSONField, + ModelSerializer, + Serializer, + ValidationError, +) + +from backend.serializers import AuditSerializer + + +class APIDeploymentSerializer(AuditSerializer): + class Meta: + model = APIDeployment + fields = "__all__" + + def validate_api_name(self, value: str) -> str: + api_name_validator = RegexValidator( + regex=r"^[a-zA-Z0-9_-]+$", + message="Only letters, numbers, hyphen and \ + underscores are allowed.", + code="invalid_api_name", + ) + api_name_validator(value) + return value + + +class APIKeySerializer(AuditSerializer): + class Meta: + model = APIKey + fields = "__all__" + + def to_representation(self, instance: APIKey) -> OrderedDict[str, Any]: + """Override the to_representation method to include additional + context.""" + context = self.context.get("context", {}) + deployment: APIDeployment = context.get("deployment") + representation: OrderedDict[str, Any] = super().to_representation(instance) + if deployment: + representation["api"] = deployment.id + representation["description"] = f"API Key for {deployment.name}" + representation["is_active"] = True + + return representation + + +class ExecutionRequestSerializer(Serializer): + """Execution request serializer + timeout: 0: maximum value of timeout, -1: async execution + """ + + timeout = IntegerField( + min_value=-1, max_value=ApiExecution.MAXIMUM_TIMEOUT_IN_SEC, default=-1 + ) + + def validate_timeout(self, value: Any) -> int: + if not isinstance(value, int): + raise ValidationError("timeout must be a integer.") + if value == 0: + value = ApiExecution.MAXIMUM_TIMEOUT_IN_SEC + return value + + def get_timeout(self, validated_data: dict[str, Union[int, None]]) -> int: + value = validated_data.get(ApiExecution.TIMEOUT_FORM_DATA, -1) + if not isinstance(value, int): + raise ValidationError("timeout must be a integer.") + return value + + +class APIDeploymentListSerializer(ModelSerializer): + workflow_name = CharField(source="workflow.workflow_name", read_only=True) + + class Meta: + model = APIDeployment + fields = [ + "id", + "workflow", + "workflow_name", + "display_name", + "description", + "is_active", + "api_endpoint", + "api_name", + "created_by", + ] + + +class APIKeyListSerializer(ModelSerializer): + class Meta: + model = APIKey + fields = [ + "id", + "created_at", + "modified_at", + "api_key", + "is_active", + "description", + "api", + ] + + +class DeploymentResponseSerializer(Serializer): + is_active = CharField() + id = CharField() + api_key = CharField() + api_endpoint = CharField() 
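The execute/status requests captured in this generated collection map directly onto a plain HTTP client. The sketch below is a hedged illustration using `requests`: the base URL shape, the `files`/`timeout` form-field names and the API key are placeholders standing in for the deployment's real `api_endpoint` and the `ApiExecution` constants defined outside this hunk.

```
# Client-side sketch mirroring the generated Postman collection. Base URL,
# form-field names and the API key are illustrative placeholders.
import requests

api_url = "https://app.example.com/deployment/api/acme-org/my-api/"  # assumed shape
headers = {"Authorization": "Bearer REPLACE_WITH_API_KEY"}

# "Process document": multipart upload with an optional timeout in seconds
# (-1, the serializer default, runs the workflow asynchronously).
with open("invoice.pdf", "rb") as fh:
    execute = requests.post(
        api_url,
        headers=headers,
        files={"files": fh},    # assumed to match ApiExecution.FILES_FORM_DATA
        data={"timeout": 300},  # assumed to match ApiExecution.TIMEOUT_FORM_DATA
    )
execute.raise_for_status()
# APIExecutionResponseSerializer fields: execution_status, status_api, error, result
print(execute.json())

# "Execution status": poll with the execution id (also reachable via status_api).
status = requests.get(
    api_url,
    headers=headers,
    params={"execution_id": "REPLACE_WITH_EXECUTION_ID"},
)
print(status.json())
```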
+ display_name = CharField() + description = CharField() + api_name = CharField() + + +class APIExecutionResponseSerializer(Serializer): + execution_status = CharField() + status_api = CharField() + error = CharField() + result = JSONField() diff --git a/backend/api_v2/urls.py b/backend/api_v2/urls.py new file mode 100644 index 000000000..065617d66 --- /dev/null +++ b/backend/api_v2/urls.py @@ -0,0 +1,63 @@ +from api_v2.api_deployment_views import APIDeploymentViewSet, DeploymentExecution +from api_v2.api_key_views import APIKeyViewSet +from django.urls import path, re_path +from rest_framework.urlpatterns import format_suffix_patterns + +deployment = APIDeploymentViewSet.as_view( + { + "get": APIDeploymentViewSet.list.__name__, + "post": APIDeploymentViewSet.create.__name__, + } +) +deployment_details = APIDeploymentViewSet.as_view( + { + "get": APIDeploymentViewSet.retrieve.__name__, + "put": APIDeploymentViewSet.update.__name__, + "patch": APIDeploymentViewSet.partial_update.__name__, + "delete": APIDeploymentViewSet.destroy.__name__, + } +) +download_postman_collection = APIDeploymentViewSet.as_view( + { + "get": APIDeploymentViewSet.download_postman_collection.__name__, + } +) + +execute = DeploymentExecution.as_view() + +key_details = APIKeyViewSet.as_view( + { + "get": APIKeyViewSet.retrieve.__name__, + "put": APIKeyViewSet.update.__name__, + "delete": APIKeyViewSet.destroy.__name__, + } +) +api_key = APIKeyViewSet.as_view( + { + "get": APIKeyViewSet.api_keys.__name__, + "post": APIKeyViewSet.create.__name__, + } +) + +urlpatterns = format_suffix_patterns( + [ + path("deployment/", deployment, name="api_deployment"), + path( + "deployment//", + deployment_details, + name="api_deployment_details", + ), + path( + "postman_collection//", + download_postman_collection, + name="download_postman_collection", + ), + re_path( + r"^api/(?P[\w-]+)/(?P[\w-]+)/?$", + execute, + name="api_deployment_execution", + ), + path("keys//", key_details, name="key_details"), + path("keys/api//", api_key, name="api_key"), + ] +) diff --git a/backend/connector_auth_v2/__init__.py b/backend/connector_auth_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/connector_auth_v2/admin.py b/backend/connector_auth_v2/admin.py new file mode 100644 index 000000000..014dfec3e --- /dev/null +++ b/backend/connector_auth_v2/admin.py @@ -0,0 +1,5 @@ +from django.contrib import admin + +from .models import ConnectorAuth + +admin.site.register(ConnectorAuth) diff --git a/backend/connector_auth_v2/apps.py b/backend/connector_auth_v2/apps.py new file mode 100644 index 000000000..e7cc819b7 --- /dev/null +++ b/backend/connector_auth_v2/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class ConnectorAuthConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "connector_auth_v2" diff --git a/backend/connector_auth_v2/constants.py b/backend/connector_auth_v2/constants.py new file mode 100644 index 000000000..886968d87 --- /dev/null +++ b/backend/connector_auth_v2/constants.py @@ -0,0 +1,18 @@ +class ConnectorAuthKey: + OAUTH_KEY = "oauth-key" + + +class SocialAuthConstants: + UID = "uid" + PROVIDER = "provider" + ACCESS_TOKEN = "access_token" + REFRESH_TOKEN = "refresh_token" + TOKEN_TYPE = "token_type" + AUTH_TIME = "auth_time" + EXPIRES = "expires" + + REFRESH_AFTER_FORMAT = "%d/%m/%Y %H:%M:%S" + REFRESH_AFTER = "refresh_after" # Timestamp to refresh tokens after + + GOOGLE_OAUTH = "google-oauth2" + GOOGLE_TOKEN_EXPIRY_FORMAT = "%d/%m/%Y %H:%M:%S" diff 
--git a/backend/connector_auth_v2/exceptions.py b/backend/connector_auth_v2/exceptions.py new file mode 100644 index 000000000..603bcc8d9 --- /dev/null +++ b/backend/connector_auth_v2/exceptions.py @@ -0,0 +1,31 @@ +from typing import Optional + +from rest_framework.exceptions import APIException + + +class CacheMissException(APIException): + status_code = 404 + default_detail = "Key doesn't exist." + + +class EnrichConnectorMetadataException(APIException): + status_code = 500 + default_detail = "Connector metadata could not be enriched" + + +class MissingParamException(APIException): + status_code = 400 + default_detail = "Bad request, missing parameter." + + def __init__( + self, + code: Optional[str] = None, + param: Optional[str] = None, + ) -> None: + detail = f"Bad request, missing parameter: {param}" + super().__init__(detail, code) + + +class KeyNotConfigured(APIException): + status_code = 500 + default_detail = "Key is not configured correctly" diff --git a/backend/connector_auth_v2/models.py b/backend/connector_auth_v2/models.py new file mode 100644 index 000000000..ca8caaefd --- /dev/null +++ b/backend/connector_auth_v2/models.py @@ -0,0 +1,141 @@ +import logging +import uuid +from typing import Any + +from account_v2.models import User +from connector_auth_v2.constants import SocialAuthConstants +from connector_auth_v2.pipeline.google import GoogleAuthHelper +from django.db import models +from django.db.models.query import QuerySet +from rest_framework.request import Request +from social_django.fields import JSONField +from social_django.models import AbstractUserSocialAuth, DjangoStorage +from social_django.strategy import DjangoStrategy + +logger = logging.getLogger(__name__) + + +class ConnectorAuthManager(models.Manager): + def get_queryset(self) -> QuerySet: + queryset = super().get_queryset() + # TODO PAN-83: Decrypt here + # for obj in queryset: + # logger.info(f"Decrypting extra_data: {obj.extra_data}") + + return queryset + + +class ConnectorAuth(AbstractUserSocialAuth): + """Social Auth association model, stores tokens. + The relation with `account.User` is only for the library to work + and should be NOT be used to access the secrets. 
+ Use the following static methods instead + ``` + @classmethod + def get_social_auth(cls, provider, id): + + @classmethod + def create_social_auth(cls, user, uid, provider): + ``` + """ + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + user = models.ForeignKey( + User, + related_name="connector_auths", + on_delete=models.CASCADE, + null=True, + ) + + def __str__(self) -> str: + return f"ConnectorAuth(provider: {self.provider}, uid: {self.uid})" + + def save(self, *args: Any, **kwargs: Any) -> Any: + # TODO PAN-83: Encrypt here + # logger.info(f"Encrypting extra_data: {self.extra_data}") + return super().save(*args, **kwargs) + + def set_extra_data(self, extra_data=None): # type: ignore + ConnectorAuth.check_credential_format(extra_data) + if extra_data[SocialAuthConstants.PROVIDER] == SocialAuthConstants.GOOGLE_OAUTH: + extra_data = GoogleAuthHelper.enrich_connector_metadata(extra_data) + return super().set_extra_data(extra_data) + + def refresh_token(self, strategy, *args, **kwargs): # type: ignore + """Override of Python Social Auth (PSA)'s refresh_token functionality + to store uid, provider.""" + token = self.extra_data.get("refresh_token") or self.extra_data.get( + "access_token" + ) + backend = self.get_backend_instance(strategy) + if token and backend and hasattr(backend, "refresh_token"): + response = backend.refresh_token(token, *args, **kwargs) + extra_data = backend.extra_data(self, self.uid, response, self.extra_data) + extra_data[SocialAuthConstants.PROVIDER] = backend.name + extra_data[SocialAuthConstants.UID] = self.uid + if self.set_extra_data(extra_data): # type: ignore + self.save() + + def get_and_refresh_tokens(self, request: Request = None) -> tuple[JSONField, bool]: + """Uses Social Auth's ability to refresh tokens if necessary. 
+ + Returns: + Tuple[JSONField, bool]: JSONField of connector metadata + and flag indicating if tokens were refreshed + """ + # To avoid circular dependency error on import + from social_django.utils import load_strategy + + refreshed_token = False + strategy: DjangoStrategy = load_strategy(request=request) + existing_access_token = self.access_token + new_access_token = self.get_access_token(strategy) + if new_access_token != existing_access_token: + refreshed_token = True + related_connector_instances = self.connectorinstance_set.all() + for connector_instance in related_connector_instances: + connector_instance.connector_metadata = self.extra_data + connector_instance.save() + logger.info( + f"Refreshed access token for connector {connector_instance.id}, " + f"provider: {self.provider}, uid: {self.uid}" + ) + + return self.extra_data, refreshed_token + + @staticmethod + def check_credential_format( + oauth_credentials: dict[str, str], raise_exception: bool = True + ) -> bool: + if ( + SocialAuthConstants.PROVIDER in oauth_credentials + and SocialAuthConstants.UID in oauth_credentials + ): + return True + else: + if raise_exception: + raise ValueError( + "Auth credential should have provider, uid and connector guid" + ) + return False + + objects = ConnectorAuthManager() + + class Meta: + app_label = "connector_auth_v2" + verbose_name = "Connector Auth" + verbose_name_plural = "Connector Auths" + db_table = "connector_auth_v2" + constraints = [ + models.UniqueConstraint( + fields=[ + "provider", + "uid", + ], + name="unique_provider_uid_index", + ), + ] + + +class ConnectorDjangoStorage(DjangoStorage): + user = ConnectorAuth diff --git a/backend/connector_auth_v2/pipeline/common.py b/backend/connector_auth_v2/pipeline/common.py new file mode 100644 index 000000000..b6c4861c9 --- /dev/null +++ b/backend/connector_auth_v2/pipeline/common.py @@ -0,0 +1,111 @@ +import logging +from typing import Any, Optional + +from account_v2.models import User +from connector_auth_v2.constants import ConnectorAuthKey, SocialAuthConstants +from connector_auth_v2.models import ConnectorAuth +from connector_auth_v2.pipeline.google import GoogleAuthHelper +from django.conf import settings +from django.core.cache import cache +from rest_framework.exceptions import PermissionDenied +from social_core.backends.oauth import BaseOAuth2 + +logger = logging.getLogger(__name__) + + +def check_user_exists(backend: BaseOAuth2, user: User, **kwargs: Any) -> dict[str, str]: + """Checks if user is authenticated (will be handled in auth middleware, + present as a fail safe) + + Args: + user (account.User): User model + + Raises: + PermissionDenied: Unauthorized user + + Returns: + dict: Carrying response details for auth pipeline + """ + if not user: + raise PermissionDenied(backend) + return {**kwargs} + + +def cache_oauth_creds( + backend: BaseOAuth2, + details: dict[str, str], + response: dict[str, str], + uid: str, + user: User, + *args: Any, + **kwargs: Any, +) -> dict[str, str]: + """Used to cache the extra data JSON in redis against a key. + + This contains the access and refresh token along with details + regarding expiry, uid (unique ID given by provider) and provider. 
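The cache interaction described above boils down to a set-then-pop round trip through Django's cache framework (backed by Redis in deployment). A self-contained sketch under that assumption, with an in-memory backend and an illustrative key, payload and TTL:

```
# Standalone sketch of the set-then-pop pattern used for cached OAuth extra data.
# The locmem backend and the key/TTL values are stand-ins for the real
# Redis-backed cache and the settings-driven expiry.
from django.conf import settings

settings.configure(
    CACHES={"default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"}}
)

from django.core.cache import cache  # imported after settings are configured

cache_key = "oauth-key:123e4567"  # illustrative key
extra_data = {
    "provider": "google-oauth2",
    "uid": "user@example.com",
    "access_token": "ya29.dummy",
    "refresh_token": "1//dummy",
}
cache.set(cache_key, extra_data, 600)  # expire after 10 minutes

# Later, the connector-creation flow pops the credentials exactly once.
creds = cache.get(cache_key)
cache.delete(cache_key)
assert creds is not None and creds["provider"] == "google-oauth2"
```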
+ """ + cache_key = kwargs.get("cache_key") or backend.strategy.session_get( + settings.SOCIAL_AUTH_FIELDS_STORED_IN_SESSION[0], + ConnectorAuthKey.OAUTH_KEY, + ) + extra_data = backend.extra_data(user, uid, response, details, *args, **kwargs) + extra_data[SocialAuthConstants.PROVIDER] = backend.name + extra_data[SocialAuthConstants.UID] = uid + + if backend.name == SocialAuthConstants.GOOGLE_OAUTH: + extra_data = GoogleAuthHelper.enrich_connector_metadata(extra_data) + + cache.set( + cache_key, + extra_data, + int(settings.SOCIAL_AUTH_EXTRA_DATA_EXPIRATION_TIME_IN_SECOND), + ) + return {**kwargs} + + +class ConnectorAuthHelper: + @staticmethod + def get_oauth_creds_from_cache( + cache_key: str, delete_key: bool = True + ) -> Optional[dict[str, str]]: + """Retrieves oauth credentials from the cache. + + Args: + cache_key (str): Key to obtain credentials from + + Returns: + Optional[dict[str,str]]: Returns credentials. None if it doesn't exist + """ + oauth_creds: dict[str, str] = cache.get(cache_key) + if delete_key: + cache.delete(cache_key) + return oauth_creds + + @staticmethod + def get_or_create_connector_auth( + oauth_credentials: dict[str, str], user: User = None # type: ignore + ) -> ConnectorAuth: + """Gets or creates a ConnectorAuth object. + + Args: + user (User): Used while creation, can be removed if not required + oauth_credentials (dict[str,str]): Needs to have provider and uid + + Returns: + ConnectorAuth: Object for the respective provider/uid + """ + ConnectorAuth.check_credential_format(oauth_credentials) + provider = oauth_credentials[SocialAuthConstants.PROVIDER] + uid = oauth_credentials[SocialAuthConstants.UID] + connector_oauth: ConnectorAuth = ConnectorAuth.get_social_auth( + provider=provider, uid=uid + ) + if not connector_oauth: + connector_oauth = ConnectorAuth.create_social_auth( + user, uid=uid, provider=provider + ) + + # TODO: Remove User's related manager access to ConnectorAuth + connector_oauth.set_extra_data(oauth_credentials) # type: ignore + return connector_oauth diff --git a/backend/connector_auth_v2/pipeline/google.py b/backend/connector_auth_v2/pipeline/google.py new file mode 100644 index 000000000..d585bd085 --- /dev/null +++ b/backend/connector_auth_v2/pipeline/google.py @@ -0,0 +1,33 @@ +from datetime import datetime, timedelta + +from connector_auth_v2.constants import SocialAuthConstants as AuthConstants +from connector_auth_v2.exceptions import EnrichConnectorMetadataException +from connector_processor.constants import ConnectorKeys + +from unstract.connectors.filesystems.google_drive.constants import GDriveConstants + + +class GoogleAuthHelper: + @staticmethod + def enrich_connector_metadata(kwargs: dict[str, str]) -> dict[str, str]: + token_expiry: datetime = datetime.now() + auth_time = kwargs.get(AuthConstants.AUTH_TIME) + expires = kwargs.get(AuthConstants.EXPIRES) + if auth_time and expires: + reference = datetime.utcfromtimestamp(float(auth_time)) + token_expiry = reference + timedelta(seconds=float(expires)) + else: + raise EnrichConnectorMetadataException + # Used by GDrive FS, apart from ACCESS_TOKEN and REFRESH_TOKEN + kwargs[GDriveConstants.TOKEN_EXPIRY] = token_expiry.strftime( + AuthConstants.GOOGLE_TOKEN_EXPIRY_FORMAT + ) + + # Used by Unstract + kwargs[ConnectorKeys.PATH] = ( + GDriveConstants.ROOT_PREFIX + ) # Acts as a prefix for all paths + kwargs[AuthConstants.REFRESH_AFTER] = token_expiry.strftime( + AuthConstants.REFRESH_AFTER_FORMAT + ) + return kwargs diff --git a/backend/connector_auth_v2/urls.py 
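The expiry enrichment in `GoogleAuthHelper` is plain timestamp arithmetic: the provider supplies `auth_time` (epoch seconds) and `expires` (a lifetime in seconds), and both `token_expiry` and `refresh_after` are rendered with the same `%d/%m/%Y %H:%M:%S` format. A standalone sketch with made-up values:

```
# Deriving the token-expiry / refresh-after timestamps from the provider's
# auth_time (epoch seconds) and expires (lifetime in seconds). Sample values
# below are made up.
from datetime import datetime, timedelta

auth_time = "1718000000"  # epoch seconds from the OAuth response
expires = "3599"          # token lifetime in seconds

reference = datetime.utcfromtimestamp(float(auth_time))
token_expiry = reference + timedelta(seconds=float(expires))

# Rendered with the formats used in this PR (both happen to be identical).
print(token_expiry.strftime("%d/%m/%Y %H:%M:%S"))
```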
b/backend/connector_auth_v2/urls.py new file mode 100644 index 000000000..55337ad20 --- /dev/null +++ b/backend/connector_auth_v2/urls.py @@ -0,0 +1,21 @@ +from django.urls import include, path, re_path +from rest_framework.urlpatterns import format_suffix_patterns + +from .views import ConnectorAuthViewSet + +connector_auth_cache = ConnectorAuthViewSet.as_view( + { + "get": "cache_key", + } +) + +urlpatterns = format_suffix_patterns( + [ + path("oauth/", include("social_django.urls", namespace="social")), + re_path( + "^oauth/cache-key/(?P.+)$", + connector_auth_cache, + name="connector-cache", + ), + ] +) diff --git a/backend/connector_auth_v2/views.py b/backend/connector_auth_v2/views.py new file mode 100644 index 000000000..1cc2022bf --- /dev/null +++ b/backend/connector_auth_v2/views.py @@ -0,0 +1,48 @@ +import logging +import uuid + +from connector_auth_v2.constants import SocialAuthConstants +from connector_auth_v2.exceptions import KeyNotConfigured +from django.conf import settings +from rest_framework import status, viewsets +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.versioning import URLPathVersioning +from utils.user_session import UserSessionUtils + +logger = logging.getLogger(__name__) + + +class ConnectorAuthViewSet(viewsets.ViewSet): + """Contains methods for Connector related authentication.""" + + versioning_class = URLPathVersioning + + def cache_key( + self: "ConnectorAuthViewSet", request: Request, backend: str + ) -> Response: + if backend == SocialAuthConstants.GOOGLE_OAUTH and ( + settings.SOCIAL_AUTH_GOOGLE_OAUTH2_KEY is None + or settings.SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET is None + ): + msg = ( + f"Keys not configured for {backend}, add env vars " + f"`GOOGLE_OAUTH2_KEY` and `GOOGLE_OAUTH2_SECRET`." + ) + logger.warn(msg) + raise KeyNotConfigured( + f"{msg}\nRefer to: " + "https://developers.google.com/identity/protocols/oauth2#1.-" + "obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar." + "console_name-." 
+ ) + + random = str(uuid.uuid4()) + user_id = request.user.user_id + org_id = UserSessionUtils.get_organization_id(request) + cache_key = f"oauth:{org_id}|{user_id}|{backend}|{random}" + logger.info(f"Generated cache key: {cache_key}") + return Response( + status=status.HTTP_200_OK, + data={"cache_key": f"{cache_key}"}, + ) diff --git a/backend/connector_processor/connector_processor.py b/backend/connector_processor/connector_processor.py index b2346a9d4..770d03c49 100644 --- a/backend/connector_processor/connector_processor.py +++ b/backend/connector_processor/connector_processor.py @@ -3,23 +3,30 @@ import logging from typing import Any, Optional -from connector.constants import ConnectorInstanceKey as CIKey -from connector_auth.constants import ConnectorAuthKey -from connector_auth.pipeline.common import ConnectorAuthHelper from connector_processor.constants import ConnectorKeys from connector_processor.exceptions import ( - InternalServiceError, InValidConnectorId, InValidConnectorMode, OAuthTimeOut, TestConnectorInputError, ) +from backend.constants import FeatureFlag from unstract.connectors.base import UnstractConnector from unstract.connectors.connectorkit import Connectorkit from unstract.connectors.enums import ConnectorMode from unstract.connectors.exceptions import ConnectorError from unstract.connectors.filesystems.ucs import UnstractCloudStorage +from unstract.flags.feature_flag import check_feature_flag_status + +if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): + from connector_auth_v2.constants import ConnectorAuthKey + from connector_auth_v2.pipeline.common import ConnectorAuthHelper + from connector_v2.constants import ConnectorInstanceKey as CIKey +else: + from connector.constants import ConnectorInstanceKey as CIKey + from connector_auth.constants import ConnectorAuthKey + from connector_auth.pipeline.common import ConnectorAuthHelper logger = logging.getLogger(__name__) @@ -45,26 +52,27 @@ def get_json_schema(connector_id: str) -> dict: updated_connectors = fetch_connectors_by_key_value( ConnectorKeys.ID, connector_id ) - if len(updated_connectors) != 0: - connector = updated_connectors[0] - schema_details[ConnectorKeys.OAUTH] = connector.get(ConnectorKeys.OAUTH) - schema_details[ConnectorKeys.SOCIAL_AUTH_URL] = connector.get( - ConnectorKeys.SOCIAL_AUTH_URL - ) - try: - schema_details[ConnectorKeys.JSON_SCHEMA] = json.loads( - connector.get(ConnectorKeys.JSON_SCHEMA) - ) - except Exception as exc: - logger.error(f"Error occurred while parsing JSON Schema: {exc}") - raise InternalServiceError() - else: + if len(updated_connectors) == 0: logger.error( f"Invalid connector Id : {connector_id} " f"while fetching " f"JSON Schema" ) raise InValidConnectorId() + + connector = updated_connectors[0] + schema_details[ConnectorKeys.OAUTH] = connector.get(ConnectorKeys.OAUTH) + schema_details[ConnectorKeys.SOCIAL_AUTH_URL] = connector.get( + ConnectorKeys.SOCIAL_AUTH_URL + ) + try: + schema_details[ConnectorKeys.JSON_SCHEMA] = json.loads( + connector.get(ConnectorKeys.JSON_SCHEMA) + ) + except Exception as exc: + logger.error(f"Error occurred decoding JSON for {connector_id}: {exc}") + raise exc + return schema_details @staticmethod diff --git a/backend/connector_processor/views.py b/backend/connector_processor/views.py index edca86ba1..20c9277a7 100644 --- a/backend/connector_processor/views.py +++ b/backend/connector_processor/views.py @@ -1,4 +1,3 @@ -from connector.constants import ConnectorInstanceKey as CIKey from connector_processor.connector_processor import 
ConnectorProcessor from connector_processor.constants import ConnectorKeys from connector_processor.exceptions import IdIsMandatory, InValidType @@ -13,6 +12,14 @@ from rest_framework.versioning import URLPathVersioning from rest_framework.viewsets import GenericViewSet +from backend.constants import FeatureFlag +from unstract.flags.feature_flag import check_feature_flag_status + +if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): + from connector_v2.constants import ConnectorInstanceKey as CIKey +else: + from connector.constants import ConnectorInstanceKey as CIKey + @api_view(("GET",)) def get_connector_schema(request: HttpRequest) -> HttpResponse: diff --git a/backend/connector_v2/__init__.py b/backend/connector_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/connector_v2/admin.py b/backend/connector_v2/admin.py new file mode 100644 index 000000000..3750fb4c5 --- /dev/null +++ b/backend/connector_v2/admin.py @@ -0,0 +1,5 @@ +from django.contrib import admin + +from .models import ConnectorInstance + +admin.site.register(ConnectorInstance) diff --git a/backend/connector_v2/apps.py b/backend/connector_v2/apps.py new file mode 100644 index 000000000..646d84f0d --- /dev/null +++ b/backend/connector_v2/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class ConnectorConfig(AppConfig): + name = "connector_v2" diff --git a/backend/connector_v2/connector_instance_helper.py b/backend/connector_v2/connector_instance_helper.py new file mode 100644 index 000000000..e6a7d4307 --- /dev/null +++ b/backend/connector_v2/connector_instance_helper.py @@ -0,0 +1,330 @@ +import logging +from typing import Any, Optional + +from account_v2.models import User +from connector_v2.constants import ConnectorInstanceConstant +from connector_v2.models import ConnectorInstance +from connector_v2.unstract_account import UnstractAccount +from django.conf import settings +from utils.user_context import UserContext +from workflow_manager.workflow_v2.models.workflow import Workflow + +from unstract.connectors.filesystems.ucs import UnstractCloudStorage +from unstract.connectors.filesystems.ucs.constants import UCSKey + +logger = logging.getLogger(__name__) + + +class ConnectorInstanceHelper: + @staticmethod + def create_default_gcs_connector(workflow: Workflow, user: User) -> None: + """Method to create default storage connector. 
+ + Args: + org_id (str) + workflow (Workflow) + user (User) + """ + organization_id = UserContext.get_organization_identifier() + if not user.project_storage_created: + logger.info("Creating default storage") + account = UnstractAccount(organization_id, user.email) + account.provision_s3_storage() + account.upload_sample_files() + user.project_storage_created = True + user.save() + logger.info("default storage created successfully.") + + logger.info("Adding connectors to Unstract") + connector_name = ConnectorInstanceConstant.USER_STORAGE + gcs_id = UnstractCloudStorage.get_id() + bucket_name = settings.UNSTRACT_FREE_STORAGE_BUCKET_NAME + base_path = f"{bucket_name}/{organization_id}/{user.email}" + + connector_metadata = { + UCSKey.KEY: settings.GOOGLE_STORAGE_ACCESS_KEY_ID, + UCSKey.SECRET: settings.GOOGLE_STORAGE_SECRET_ACCESS_KEY, + UCSKey.BUCKET: bucket_name, + UCSKey.ENDPOINT_URL: settings.GOOGLE_STORAGE_BASE_URL, + } + connector_metadata__input = { + **connector_metadata, + UCSKey.PATH: base_path + "/input", + } + connector_metadata__output = { + **connector_metadata, + UCSKey.PATH: base_path + "/output", + } + ConnectorInstance.objects.create( + connector_name=connector_name, + workflow=workflow, + created_by=user, + connector_id=gcs_id, + connector_metadata=connector_metadata__input, + connector_type=ConnectorInstance.ConnectorType.INPUT, + connector_mode=ConnectorInstance.ConnectorMode.FILE_SYSTEM, + ) + ConnectorInstance.objects.create( + connector_name=connector_name, + workflow=workflow, + created_by=user, + connector_id=gcs_id, + connector_metadata=connector_metadata__output, + connector_type=ConnectorInstance.ConnectorType.OUTPUT, + connector_mode=ConnectorInstance.ConnectorMode.FILE_SYSTEM, + ) + logger.info("Connectors added successfully.") + + @staticmethod + def get_connector_instances_by_workflow( + workflow_id: str, + connector_type: tuple[str, str], + connector_mode: Optional[tuple[int, str]] = None, + values: Optional[list[str]] = None, + connector_name: Optional[str] = None, + ) -> list[ConnectorInstance]: + """Method to get connector instances by workflow. + + Args: + workflow_id (str) + connector_type (tuple[str, str]): Specifies input/output + connector_mode (Optional[tuple[int, str]], optional): + Specifies database/file + values (Optional[list[str]], optional): Defaults to None. + connector_name (Optional[str], optional): Defaults to None. 
+ + Returns: + list[ConnectorInstance] + """ + logger.info(f"Setting connector mode to {connector_mode}") + filter_params: dict[str, Any] = { + "workflow": workflow_id, + "connector_type": connector_type, + } + if connector_mode is not None: + filter_params["connector_mode"] = connector_mode + if connector_name is not None: + filter_params["connector_name"] = connector_name + + connector_instances = ConnectorInstance.objects.filter(**filter_params).all() + logger.debug(f"Retrieved connector instance values {connector_instances}") + if values is not None: + filtered_connector_instances = connector_instances.values(*values) + logger.info( + f"Returning filtered \ + connector instance value {filtered_connector_instances}" + ) + return list(filtered_connector_instances) + logger.info(f"Returning connector instances {connector_instances}") + return list(connector_instances) + + @staticmethod + def get_connector_instance_by_workflow( + workflow_id: str, + connector_type: tuple[str, str], + connector_mode: Optional[tuple[int, str]] = None, + connector_name: Optional[str] = None, + ) -> Optional[ConnectorInstance]: + """Get one connector instance. + + Use this method if the connector instance is unique for \ + filter_params + Args: + workflow_id (str): _description_ + connector_type (tuple[str, str]): Specifies input/output + connector_mode (Optional[tuple[int, str]], optional). + Specifies database/filesystem. + connector_name (Optional[str], optional). + + Returns: + list[ConnectorInstance]: _description_ + """ + logger.info("Fetching connector instance by workflow") + filter_params: dict[str, Any] = { + "workflow": workflow_id, + "connector_type": connector_type, + } + if connector_mode is not None: + filter_params["connector_mode"] = connector_mode + if connector_name is not None: + filter_params["connector_name"] = connector_name + + try: + connector_instance: ConnectorInstance = ConnectorInstance.objects.filter( + **filter_params + ).first() + except Exception as exc: + logger.error(f"Error occured while fetching connector instances {exc}") + raise exc + + return connector_instance + + @staticmethod + def get_input_connector_instance_by_name_for_workflow( + workflow_id: str, + connector_name: str, + ) -> Optional[ConnectorInstance]: + """Method to get Input connector instance name from the workflow. + + Args: + workflow_id (str) + connector_name (str) + + Returns: + Optional[ConnectorInstance] + """ + return ConnectorInstanceHelper.get_connector_instance_by_workflow( + workflow_id=workflow_id, + connector_type=ConnectorInstance.ConnectorType.INPUT, + connector_name=connector_name, + ) + + @staticmethod + def get_output_connector_instance_by_name_for_workflow( + workflow_id: str, + connector_name: str, + ) -> Optional[ConnectorInstance]: + """Method to get output connector name by Workflow. + + Args: + workflow_id (str) + connector_name (str) + + Returns: + Optional[ConnectorInstance] + """ + return ConnectorInstanceHelper.get_connector_instance_by_workflow( + workflow_id=workflow_id, + connector_type=ConnectorInstance.ConnectorType.OUTPUT, + connector_name=connector_name, + ) + + @staticmethod + def get_input_connector_instances_by_workflow( + workflow_id: str, + ) -> list[ConnectorInstance]: + """Method to get connector instances by workflow. 
+ + Args: + workflow_id (str) + + Returns: + list[ConnectorInstance] + """ + return ConnectorInstanceHelper.get_connector_instances_by_workflow( + workflow_id, ConnectorInstance.ConnectorType.INPUT + ) + + @staticmethod + def get_output_connector_instances_by_workflow( + workflow_id: str, + ) -> list[ConnectorInstance]: + """Method to get output connector instances by workflow. + + Args: + workflow_id (str): _description_ + + Returns: + list[ConnectorInstance]: _description_ + """ + return ConnectorInstanceHelper.get_connector_instances_by_workflow( + workflow_id, ConnectorInstance.ConnectorType.OUTPUT + ) + + @staticmethod + def get_file_system_input_connector_instances_by_workflow( + workflow_id: str, values: Optional[list[str]] = None + ) -> list[ConnectorInstance]: + """Method to fetch file system connector by workflow. + + Args: + workflow_id (str): + values (Optional[list[str]], optional) + + Returns: + list[ConnectorInstance] + """ + return ConnectorInstanceHelper.get_connector_instances_by_workflow( + workflow_id, + ConnectorInstance.ConnectorType.INPUT, + ConnectorInstance.ConnectorMode.FILE_SYSTEM, + values, + ) + + @staticmethod + def get_file_system_output_connector_instances_by_workflow( + workflow_id: str, values: Optional[list[str]] = None + ) -> list[ConnectorInstance]: + """Method to get file system output connector by workflow. + + Args: + workflow_id (str) + values (Optional[list[str]], optional) + + Returns: + list[ConnectorInstance] + """ + return ConnectorInstanceHelper.get_connector_instances_by_workflow( + workflow_id, + ConnectorInstance.ConnectorType.OUTPUT, + ConnectorInstance.ConnectorMode.FILE_SYSTEM, + values, + ) + + @staticmethod + def get_database_input_connector_instances_by_workflow( + workflow_id: str, values: Optional[list[str]] = None + ) -> list[ConnectorInstance]: + """Method to fetch input database connectors by workflow. + + Args: + workflow_id (str) + values (Optional[list[str]], optional) + + Returns: + list[ConnectorInstance] + """ + return ConnectorInstanceHelper.get_connector_instances_by_workflow( + workflow_id, + ConnectorInstance.ConnectorType.INPUT, + ConnectorInstance.ConnectorMode.DATABASE, + values, + ) + + @staticmethod + def get_database_output_connector_instances_by_workflow( + workflow_id: str, values: Optional[list[str]] = None + ) -> list[ConnectorInstance]: + """Method to fetch output database connectors by workflow. + + Args: + workflow_id (str) + values (Optional[list[str]], optional) + + Returns: + list[ConnectorInstance] + """ + return ConnectorInstanceHelper.get_connector_instances_by_workflow( + workflow_id, + ConnectorInstance.ConnectorType.OUTPUT, + ConnectorInstance.ConnectorMode.DATABASE, + values, + ) + + @staticmethod + def get_input_output_connector_instances_by_workflow( + workflow_id: str, + ) -> list[ConnectorInstance]: + """Method to fetch input and output connectors by workflow. 
+ + Args: + workflow_id (str) + + Returns: + list[ConnectorInstance] + """ + filter_params: dict[str, Any] = { + "workflow": workflow_id, + } + connector_instances = ConnectorInstance.objects.filter(**filter_params).all() + return list(connector_instances) diff --git a/backend/connector_v2/constants.py b/backend/connector_v2/constants.py new file mode 100644 index 000000000..a02512720 --- /dev/null +++ b/backend/connector_v2/constants.py @@ -0,0 +1,16 @@ +class ConnectorInstanceKey: + CONNECTOR_ID = "connector_id" + CONNECTOR_NAME = "connector_name" + CONNECTOR_TYPE = "connector_type" + CONNECTOR_MODE = "connector_mode" + CONNECTOR_VERSION = "connector_version" + CONNECTOR_AUTH = "connector_auth" + CONNECTOR_METADATA = "connector_metadata" + CONNECTOR_EXISTS = ( + "Connector with this configuration already exists in this project." + ) + DUPLICATE_API = "It appears that a duplicate call may have been made." + + +class ConnectorInstanceConstant: + USER_STORAGE = "User Storage" diff --git a/backend/connector_v2/fields.py b/backend/connector_v2/fields.py new file mode 100644 index 000000000..2a0f18c54 --- /dev/null +++ b/backend/connector_v2/fields.py @@ -0,0 +1,41 @@ +import logging +from datetime import datetime + +from connector_auth_v2.constants import SocialAuthConstants +from connector_auth_v2.models import ConnectorAuth +from django.db import models + +logger = logging.getLogger(__name__) + + +class ConnectorAuthJSONField(models.JSONField): + def from_db_value(self, value, expression, connection): # type: ignore + """Overriding default function.""" + metadata = super().from_db_value(value, expression, connection) + provider = metadata.get(SocialAuthConstants.PROVIDER) + uid = metadata.get(SocialAuthConstants.UID) + if not provider or not uid: + return metadata + + refresh_after_str = metadata.get(SocialAuthConstants.REFRESH_AFTER) + if not refresh_after_str: + return metadata + + refresh_after = datetime.strptime( + refresh_after_str, SocialAuthConstants.REFRESH_AFTER_FORMAT + ) + if datetime.now() > refresh_after: + metadata = self._refresh_tokens(provider, uid) + return metadata + + def _refresh_tokens(self, provider: str, uid: str) -> dict[str, str]: + """Retrieves PSA object and refreshes the token if necessary.""" + connector_auth: ConnectorAuth = ConnectorAuth.get_social_auth( + provider=provider, uid=uid + ) + if connector_auth: + ( + connector_metadata, + _, + ) = connector_auth.get_and_refresh_tokens() + return connector_metadata # type: ignore diff --git a/backend/connector_v2/migrations/__init__.py b/backend/connector_v2/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/connector_v2/models.py b/backend/connector_v2/models.py new file mode 100644 index 000000000..3fd535b05 --- /dev/null +++ b/backend/connector_v2/models.py @@ -0,0 +1,132 @@ +import json +import uuid +from typing import Any + +from account_v2.models import User +from connector_auth_v2.models import ConnectorAuth +from connector_processor.connector_processor import ConnectorProcessor +from connector_processor.constants import ConnectorKeys +from cryptography.fernet import Fernet +from django.conf import settings +from django.db import models +from utils.models.base_model import BaseModel +from utils.models.organization_mixin import ( + DefaultOrganizationManagerMixin, + DefaultOrganizationMixin, +) +from workflow_manager.workflow_v2.models import Workflow + +from backend.constants import FieldLengthConstants as FLC + +CONNECTOR_NAME_SIZE = 128 +VERSION_NAME_SIZE = 64 + + 
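`ConnectorAuthJSONField` above refreshes OAuth tokens lazily: when metadata is read from the database, the stored `refresh_after` timestamp is compared against the current time before deciding whether to call back into `ConnectorAuth`. A minimal sketch of just that check, with illustrative metadata:

```
# Standalone sketch of the lazy "refresh after" check performed when connector
# metadata is loaded from the DB. Values are illustrative.
from datetime import datetime

REFRESH_AFTER_FORMAT = "%d/%m/%Y %H:%M:%S"

metadata = {
    "provider": "google-oauth2",
    "uid": "user@example.com",
    "refresh_after": "01/01/2024 00:00:00",
}

refresh_after = datetime.strptime(metadata["refresh_after"], REFRESH_AFTER_FORMAT)
if datetime.now() > refresh_after:
    # In the actual field, this is where ConnectorAuth.get_and_refresh_tokens()
    # would be invoked to swap in fresh tokens.
    print("tokens stale, refresh needed")
else:
    print("tokens still valid")
```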
+class ConnectorInstanceModelManager(DefaultOrganizationManagerMixin, models.Manager): + pass + + +class ConnectorInstance(DefaultOrganizationMixin, BaseModel): + # TODO: handle all cascade deletions + class ConnectorType(models.TextChoices): + INPUT = "INPUT", "Input" + OUTPUT = "OUTPUT", "Output" + + class ConnectorMode(models.IntegerChoices): + UNKNOWN = 0, "UNKNOWN" + FILE_SYSTEM = 1, "FILE_SYSTEM" + DATABASE = 2, "DATABASE" + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + connector_name = models.TextField( + max_length=CONNECTOR_NAME_SIZE, null=False, blank=False + ) + workflow = models.ForeignKey( + Workflow, + on_delete=models.CASCADE, + related_name="connector_workflow", + null=False, + blank=False, + ) + connector_id = models.CharField(max_length=FLC.CONNECTOR_ID_LENGTH, default="") + connector_metadata = models.BinaryField(null=True) + connector_version = models.CharField(max_length=VERSION_NAME_SIZE, default="") + connector_type = models.CharField(choices=ConnectorType.choices) + # TODO: handle connector_auth cascade deletion + connector_auth = models.ForeignKey( + ConnectorAuth, + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="connector_instances", + ) + connector_mode = models.CharField( + choices=ConnectorMode.choices, + default=ConnectorMode.UNKNOWN, + db_comment="0: UNKNOWN, 1: FILE_SYSTEM, 2: DATABASE", + ) + + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="connectors_created", + null=True, + blank=True, + ) + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="connectors_modified", + null=True, + blank=True, + ) + + # Manager + objects = ConnectorInstanceModelManager() + + def get_connector_metadata(self) -> dict[str, str]: + """Gets connector metadata and refreshes the tokens if needed in case + of OAuth.""" + tokens_refreshed = False + if self.connector_auth: + ( + self.connector_metadata, + tokens_refreshed, + ) = self.connector_auth.get_and_refresh_tokens() + if tokens_refreshed: + self.save() + return self.connector_metadata + + @staticmethod + def supportsOAuth(connector_id: str) -> bool: + return bool( + ConnectorProcessor.get_connector_data_with_key( + connector_id, ConnectorKeys.OAUTH + ) + ) + + def __str__(self) -> str: + return ( + f"Connector({self.id}, type{self.connector_type}," + f" workflow: {self.workflow})" + ) + + @property + def metadata(self) -> Any: + encryption_secret: str = settings.ENCRYPTION_KEY + cipher_suite: Fernet = Fernet(encryption_secret.encode("utf-8")) + decrypted_value = cipher_suite.decrypt( + bytes(self.connector_metadata).decode("utf-8") + ) + return json.loads(decrypted_value) + + class Meta: + db_table = "connector_instance_v2" + verbose_name = "Connector Instance" + verbose_name_plural = "Connector Instances" + constraints = [ + models.UniqueConstraint( + fields=["connector_name", "workflow", "connector_type"], + name="unique_workflow_connector", + ), + ] diff --git a/backend/connector_v2/serializers.py b/backend/connector_v2/serializers.py new file mode 100644 index 000000000..cd9727e84 --- /dev/null +++ b/backend/connector_v2/serializers.py @@ -0,0 +1,101 @@ +import json +import logging +from collections import OrderedDict +from typing import Any, Optional + +from connector_auth_v2.models import ConnectorAuth +from connector_auth_v2.pipeline.common import ConnectorAuthHelper +from connector_processor.connector_processor import ConnectorProcessor +from connector_processor.constants import ConnectorKeys 
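Connector metadata is persisted as a Fernet-encrypted JSON blob in a `BinaryField`, keyed by `settings.ENCRYPTION_KEY`: the model's `metadata` property decrypts it, and the serializer that follows encrypts it on save. A round-trip sketch where a freshly generated key stands in for the setting:

```
# Round trip of the connector_metadata encryption used by the model's
# `metadata` property and the serializer's save(). A freshly generated key
# stands in for settings.ENCRYPTION_KEY.
import json

from cryptography.fernet import Fernet

encryption_key = Fernet.generate_key()  # settings.ENCRYPTION_KEY in the app
cipher_suite = Fernet(encryption_key)

metadata = {"bucket": "my-bucket", "path": "org/user/input"}

# Write path: JSON -> bytes -> encrypted token stored in the BinaryField.
encrypted = cipher_suite.encrypt(json.dumps(metadata).encode("utf-8"))

# Read path: decrypt -> JSON.
decrypted = json.loads(cipher_suite.decrypt(encrypted))
assert decrypted == metadata
```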
+from connector_processor.exceptions import OAuthTimeOut +from connector_v2.constants import ConnectorInstanceKey as CIKey +from cryptography.fernet import Fernet +from django.conf import settings +from utils.serializer_utils import SerializerUtils + +from backend.serializers import AuditSerializer +from unstract.connectors.filesystems.ucs import UnstractCloudStorage + +from .models import ConnectorInstance + +logger = logging.getLogger(__name__) + + +class ConnectorInstanceSerializer(AuditSerializer): + class Meta: + model = ConnectorInstance + fields = "__all__" + + def validate_connector_metadata(self, value: dict[Any]) -> dict[Any]: + """Validating Json metadata This custom validation is to avoid conflict + with user input and db binary data. + + Args: + value (Any): dict of metadata + + Returns: + dict[Any]: dict of metadata + """ + return value + + def save(self, **kwargs): # type: ignore + user = self.context.get("request").user or None + connector_id: str = kwargs[CIKey.CONNECTOR_ID] + connector_oauth: Optional[ConnectorAuth] = None + if ( + ConnectorInstance.supportsOAuth(connector_id=connector_id) + and CIKey.CONNECTOR_METADATA in kwargs + ): + try: + connector_oauth = ConnectorAuthHelper.get_or_create_connector_auth( + user=user, # type: ignore + oauth_credentials=kwargs[CIKey.CONNECTOR_METADATA], + ) + kwargs[CIKey.CONNECTOR_AUTH] = connector_oauth + ( + kwargs[CIKey.CONNECTOR_METADATA], + refresh_status, + ) = connector_oauth.get_and_refresh_tokens() + except Exception as exc: + logger.error( + "Error while obtaining ConnectorAuth for connector id " + f"{connector_id}: {exc}" + ) + raise OAuthTimeOut + + connector_mode = ConnectorProcessor.get_connector_data_with_key( + connector_id, CIKey.CONNECTOR_MODE + ) + kwargs[CIKey.CONNECTOR_MODE] = connector_mode.value + + encryption_secret: str = settings.ENCRYPTION_KEY + f: Fernet = Fernet(encryption_secret.encode("utf-8")) + json_string: str = json.dumps(kwargs.get(CIKey.CONNECTOR_METADATA)) + if self.validated_data: + self.validated_data.pop(CIKey.CONNECTOR_METADATA) + + kwargs[CIKey.CONNECTOR_METADATA] = f.encrypt(json_string.encode("utf-8")) + + instance = super().save(**kwargs) + return instance + + def to_representation(self, instance: ConnectorInstance) -> dict[str, str]: + # to remove the sensitive fields being returned + rep: OrderedDict[str, Any] = super().to_representation(instance) + if instance.connector_id == UnstractCloudStorage.get_id(): + rep[CIKey.CONNECTOR_METADATA] = {} + if SerializerUtils.check_context_for_GET_or_POST(context=self.context): + rep.pop(CIKey.CONNECTOR_AUTH) + # set icon fields for UI + rep[ConnectorKeys.ICON] = ConnectorProcessor.get_connector_data_with_key( + instance.connector_id, ConnectorKeys.ICON + ) + encryption_secret: str = settings.ENCRYPTION_KEY + f: Fernet = Fernet(encryption_secret.encode("utf-8")) + + if instance.connector_metadata: + adapter_metadata = json.loads( + f.decrypt(bytes(instance.connector_metadata).decode("utf-8")) + ) + rep[CIKey.CONNECTOR_METADATA] = adapter_metadata + return rep diff --git a/backend/connector_v2/tests/conftest.py b/backend/connector_v2/tests/conftest.py new file mode 100644 index 000000000..89f1715a1 --- /dev/null +++ b/backend/connector_v2/tests/conftest.py @@ -0,0 +1,9 @@ +import pytest +from django.core.management import call_command + + +@pytest.fixture(scope="session") +def django_db_setup(django_db_blocker): # type: ignore + fixtures = ["./connector/tests/fixtures/fixtures_0001.json"] + with django_db_blocker.unblock(): + 
call_command("loaddata", *fixtures) diff --git a/backend/connector_v2/tests/connector_tests.py b/backend/connector_v2/tests/connector_tests.py new file mode 100644 index 000000000..f3e81d72f --- /dev/null +++ b/backend/connector_v2/tests/connector_tests.py @@ -0,0 +1,332 @@ +# mypy: ignore-errors +import pytest +from connector_v2.models import ConnectorInstance +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APITestCase + +pytestmark = pytest.mark.django_db + + +@pytest.mark.connector +class TestConnector(APITestCase): + def test_connector_list(self) -> None: + """Tests to List the connectors.""" + + url = reverse("connectors_v1-list") + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_connectors_detail(self) -> None: + """Tests to fetch a connector with given pk.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_connectors_detail_not_found(self) -> None: + """Tests for negative case to fetch non exiting key.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 768}) + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_connectors_create(self) -> None: + """Tests to create a new ConnectorInstance.""" + + url = reverse("connectors_v1-list") + data = { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "sample_url", + "sharable_link": True, + }, + } + response = self.client.post(url, data, format="json") + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(ConnectorInstance.objects.count(), 2) + + def test_connectors_create_with_json_list(self) -> None: + """Tests to create a new connector with list included in the json + field.""" + + url = reverse("connectors_v1-list") + data = { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "sample_url", + "sharable_link": True, + "file_name_list": ["a1", "a2"], + }, + } + response = self.client.post(url, data, format="json") + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(ConnectorInstance.objects.count(), 2) + + def test_connectors_create_with_nested_json(self) -> None: + """Tests to create a new connector with json field as nested json.""" + + url = reverse("connectors_v1-list") + data = { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "sample_url", + "sharable_link": True, + "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + }, + } + response = self.client.post(url, data, format="json") + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(ConnectorInstance.objects.count(), 2) + + def test_connectors_create_bad_request(self) -> None: + """Tests for negative case to throw error on a wrong access.""" + + url = reverse("connectors_v1-list") + data = { + "org": 5, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + 
"connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "sample_url", + "sharable_link": True, + "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + }, + } + response = self.client.post(url, data, format="json") + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_connectors_update_json_field(self) -> None: + """Tests to update connector with json field update.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "new_sample_url", + "sharable_link": True, + "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + }, + } + response = self.client.put(url, data, format="json") + drive_link = response.data["connector_metadata"]["drive_link"] + self.assertEqual(drive_link, "new_sample_url") + + def test_connectors_update(self) -> None: + """Tests to update connector update single field.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = { + "org": 1, + "project": 1, + "created_by": 1, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "new_sample_url", + "sharable_link": True, + "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + }, + } + response = self.client.put(url, data, format="json") + modified_by = response.data["modified_by"] + self.assertEqual(modified_by, 2) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_connectors_update_pk(self) -> None: + """Tests the PUT method for 400 error.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = { + "org": 2, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "new_sample_url", + "sharable_link": True, + "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + }, + } + response = self.client.put(url, data, format="json") + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_connectors_update_json_fields(self) -> None: + """Tests to update ConnectorInstance.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "new_sample_url", + "sharable_link": True, + "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + }, + } + response = self.client.put(url, data, format="json") + nested_value = response.data["connector_metadata"]["sample_metadata_json"][ + "key1" + ] + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(nested_value, "value1") + + def test_connectors_update_json_list_fields(self) -> None: + """Tests to update connector to the third second level of json.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + "connector_metadata": { + "drive_link": "new_sample_url", + "sharable_link": True, + "sample_metadata_json": 
{"key1": "value1", "key2": "value2"}, + "file_list": ["a1", "a2", "a3"], + }, + } + response = self.client.put(url, data, format="json") + nested_value = response.data["connector_metadata"]["sample_metadata_json"][ + "key1" + ] + nested_list = response.data["connector_metadata"]["file_list"] + last_val = nested_list.pop() + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(nested_value, "value1") + self.assertEqual(last_val, "a3") + + # @pytest.mark.xfail(raises=KeyError) + # def test_connectors_update_json_fields_failed(self) -> None: + # """Tests to update connector to the second level of JSON with a wrong + # key.""" + + # url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + # data = { + # "org": 1, + # "project": 1, + # "created_by": 2, + # "modified_by": 2, + # "modified_at": "2023-06-14T05:28:47.759Z", + # "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + # "connector_metadata": { + # "drive_link": "new_sample_url", + # "sharable_link": True, + # "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + # }, + # } + # response = self.client.put(url, data, format="json") + # nested_value = response.data["connector_metadata"]["sample_metadata_json"][ + # "key00" + # ] + + # @pytest.mark.xfail(raises=KeyError) + # def test_connectors_update_json_nested_failed(self) -> None: + # """Tests to update connector to test a first level of json with a wrong + # key.""" + + # url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + # data = { + # "org": 1, + # "project": 1, + # "created_by": 2, + # "modified_by": 2, + # "modified_at": "2023-06-14T05:28:47.759Z", + # "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", + # "connector_metadata": { + # "drive_link": "new_sample_url", + # "sharable_link": True, + # "sample_metadata_json": {"key1": "value1", "key2": "value2"}, + # }, + # } + # response = self.client.put(url, data, format="json") + # nested_value = response.data["connector_metadata"]["sample_metadata_jsonNew"] + + def test_connectors_update_field(self) -> None: + """Tests the PATCH method.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = {"connector_id": "e3a4512m-efgb-48d5-98a9-3983ntest"} + response = self.client.patch(url, data, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + connector_id = response.data["connector_id"] + + self.assertEqual( + connector_id, + ConnectorInstance.objects.get(connector_id=connector_id).connector_id, + ) + + def test_connectors_update_json_field_patch(self) -> None: + """Tests the PATCH method.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + data = { + "connector_metadata": { + "drive_link": "patch_update_url", + "sharable_link": True, + "sample_metadata_json": { + "key1": "patch_update1", + "key2": "value2", + }, + } + } + + response = self.client.patch(url, data, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + drive_link = response.data["connector_metadata"]["drive_link"] + + self.assertEqual(drive_link, "patch_update_url") + + def test_connectors_delete(self) -> None: + """Tests the DELETE method.""" + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + response = self.client.delete(url, format="json") + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + + url = reverse("connectors_v1-detail", kwargs={"pk": 1}) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) diff --git a/backend/connector_v2/tests/fixtures/fixtures_0001.json 
b/backend/connector_v2/tests/fixtures/fixtures_0001.json new file mode 100644 index 000000000..55b39e6d8 --- /dev/null +++ b/backend/connector_v2/tests/fixtures/fixtures_0001.json @@ -0,0 +1,67 @@ +[ + { + "model": "account.org", + "pk": 1, + "fields": { + "org_name": "Zipstack", + "created_by": 1, + "modified_by": 1, + "modified_at": "2023-06-14T05:28:47.739Z" + } + }, + { + "model": "account.user", + "pk": 1, + "fields": { + "org": 1, + "email": "johndoe@gmail.com", + "first_name": "John", + "last_name": "Doe", + "is_admin": true, + "created_by": null, + "modified_by": null, + "modified_at": "2023-06-14T05:28:47.744Z" + } + }, + { + "model": "account.user", + "pk": 2, + "fields": { + "org": 1, + "email": "user1@gmail.com", + "first_name": "Ron", + "last_name": "Stone", + "is_admin": false, + "created_by": 1, + "modified_by": 1, + "modified_at": "2023-06-14T05:28:47.750Z" + } + }, + { + "model": "project.project", + "pk": 1, + "fields": { + "org": 1, + "project_name": "Unstract Test", + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z" + } + }, + { + "model": "connector.connector", + "pk": 1, + "fields": { + "org": 1, + "project": 1, + "created_by": 2, + "modified_by": 2, + "modified_at": "2023-06-14T05:28:47.759Z", + "connector_id": "e38a59b7-efbb-48d5-9da6-3a0cf2d882a0", + "connector_metadata": { + "connector_type": "gdrive", + "auth_type": "oauth" + } + } + } +] diff --git a/backend/connector_v2/unstract_account.py b/backend/connector_v2/unstract_account.py new file mode 100644 index 000000000..ac270f6d6 --- /dev/null +++ b/backend/connector_v2/unstract_account.py @@ -0,0 +1,71 @@ +import logging +import os + +import boto3 +from botocore.exceptions import ClientError +from django.conf import settings + +logger = logging.getLogger(__name__) + + +# TODO: UnstractAccount need to be pluggable +class UnstractAccount: + def __init__(self, tenant: str, username: str) -> None: + self.tenant = tenant + self.username = username + + def provision_s3_storage(self) -> None: + access_key = settings.GOOGLE_STORAGE_ACCESS_KEY_ID + secret_key = settings.GOOGLE_STORAGE_SECRET_ACCESS_KEY + bucket_name: str = settings.UNSTRACT_FREE_STORAGE_BUCKET_NAME + + s3 = boto3.client( + "s3", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + endpoint_url="https://storage.googleapis.com", + ) + + # Check if folder exists and create if it is not available + account_folder = f"{self.tenant}/{self.username}/input/examples/" + try: + logger.info(f"Checking if folder {account_folder} exists...") + s3.head_object(Bucket=bucket_name, Key=account_folder) + logger.info(f"Folder {account_folder} already exists") + except ClientError as e: + logger.info(f"{bucket_name} Folder {account_folder} does not exist") + if e.response["Error"]["Code"] == "404": + logger.info(f"Folder {account_folder} does not exist. 
Creating it...") + s3.put_object(Bucket=bucket_name, Key=account_folder) + account_folder_output = f"{self.tenant}/{self.username}/output/" + s3.put_object(Bucket=bucket_name, Key=account_folder_output) + else: + logger.error(f"Error checking folder {account_folder}: {e}") + raise e + + def upload_sample_files(self) -> None: + access_key = settings.GOOGLE_STORAGE_ACCESS_KEY_ID + secret_key = settings.GOOGLE_STORAGE_SECRET_ACCESS_KEY + bucket_name: str = settings.UNSTRACT_FREE_STORAGE_BUCKET_NAME + + s3 = boto3.client( + "s3", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + endpoint_url="https://storage.googleapis.com", + ) + + folder = f"{self.tenant}/{self.username}/input/examples/" + + local_path = f"{os.path.dirname(__file__)}/static" + for root, dirs, files in os.walk(local_path): + for file in files: + local_file_path = os.path.join(root, file) + s3_key = os.path.join( + folder, os.path.relpath(local_file_path, local_path) + ) + logger.info( + f"Uploading: {local_file_path} => s3://{bucket_name}/{s3_key}" + ) + s3.upload_file(local_file_path, bucket_name, s3_key) + logger.info(f"Uploaded: {local_file_path}") diff --git a/backend/connector_v2/urls.py b/backend/connector_v2/urls.py new file mode 100644 index 000000000..424033528 --- /dev/null +++ b/backend/connector_v2/urls.py @@ -0,0 +1,21 @@ +from django.urls import path +from rest_framework.urlpatterns import format_suffix_patterns + +from .views import ConnectorInstanceViewSet as CIViewSet + +connector_list = CIViewSet.as_view({"get": "list", "post": "create"}) +connector_detail = CIViewSet.as_view( + { + "get": "retrieve", + "put": "update", + "patch": "partial_update", + "delete": "destroy", + } +) + +urlpatterns = format_suffix_patterns( + [ + path("connector/", connector_list, name="connector-list"), + path("connector//", connector_detail, name="connector-detail"), + ] +) diff --git a/backend/connector_v2/views.py b/backend/connector_v2/views.py new file mode 100644 index 000000000..3f098d355 --- /dev/null +++ b/backend/connector_v2/views.py @@ -0,0 +1,124 @@ +import logging +from typing import Any, Optional + +from account_v2.custom_exceptions import DuplicateData +from connector_auth_v2.constants import ConnectorAuthKey +from connector_auth_v2.exceptions import CacheMissException, MissingParamException +from connector_auth_v2.pipeline.common import ConnectorAuthHelper +from connector_processor.exceptions import OAuthTimeOut +from connector_v2.constants import ConnectorInstanceKey as CIKey +from django.db import IntegrityError +from django.db.models import QuerySet +from rest_framework import status, viewsets +from rest_framework.response import Response +from rest_framework.versioning import URLPathVersioning +from utils.filtering import FilterHelper + +from backend.constants import RequestKey + +from .models import ConnectorInstance +from .serializers import ConnectorInstanceSerializer + +logger = logging.getLogger(__name__) + + +class ConnectorInstanceViewSet(viewsets.ModelViewSet): + versioning_class = URLPathVersioning + queryset = ConnectorInstance.objects.all() + serializer_class = ConnectorInstanceSerializer + + def get_queryset(self) -> Optional[QuerySet]: + filter_args = FilterHelper.build_filter_args( + self.request, + RequestKey.WORKFLOW, + RequestKey.CREATED_BY, + CIKey.CONNECTOR_TYPE, + CIKey.CONNECTOR_MODE, + ) + if filter_args: + queryset = ConnectorInstance.objects.filter(**filter_args) + else: + queryset = ConnectorInstance.objects.all() + return queryset + + def 
_get_connector_metadata(self, connector_id: str) -> Optional[dict[str, str]]: + """Gets connector metadata for the ConnectorInstance. + + For non oauth based - obtains from request + For oauth based - obtains from cache + + Raises: + e: MissingParamException, CacheMissException + + Returns: + dict[str, str]: Connector creds dict to connect with + """ + connector_metadata = None + if ConnectorInstance.supportsOAuth(connector_id=connector_id): + logger.info(f"Fetching oauth data for {connector_id}") + oauth_key = self.request.query_params.get(ConnectorAuthKey.OAUTH_KEY) + if oauth_key is None: + raise MissingParamException(param=ConnectorAuthKey.OAUTH_KEY) + connector_metadata = ConnectorAuthHelper.get_oauth_creds_from_cache( + cache_key=oauth_key, delete_key=True + ) + if connector_metadata is None: + raise CacheMissException( + f"Couldn't find credentials for {oauth_key} from cache" + ) + else: + connector_metadata = self.request.data.get(CIKey.CONNECTOR_METADATA) + return connector_metadata + + def perform_update(self, serializer: ConnectorInstanceSerializer) -> None: + connector_metadata = None + connector_id = self.request.data.get( + CIKey.CONNECTOR_ID, serializer.instance.connector_id + ) + try: + connector_metadata = self._get_connector_metadata(connector_id) + # TODO: Handle specific exceptions instead of using a generic Exception. + except Exception: + # Suppress here to not shout during partial updates + pass + # Take metadata from instance itself since update + # is performed on other fields of ConnectorInstance + if connector_metadata is None: + connector_metadata = serializer.instance.connector_metadata + serializer.save( + connector_id=connector_id, + connector_metadata=connector_metadata, + modified_by=self.request.user, + ) # type: ignore + + def perform_create(self, serializer: ConnectorInstanceSerializer) -> None: + connector_metadata = None + connector_id = self.request.data.get(CIKey.CONNECTOR_ID) + try: + connector_metadata = self._get_connector_metadata(connector_id=connector_id) + # TODO: Handle specific exceptions instead of using a generic Exception. 
+ except Exception as exc: + logger.error(f"Error while obtaining ConnectorAuth: {exc}") + raise OAuthTimeOut + serializer.save( + connector_id=connector_id, + connector_metadata=connector_metadata, + created_by=self.request.user, + modified_by=self.request.user, + ) # type: ignore + + def create(self, request: Any) -> Response: + # Overriding default exception behavior + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + try: + self.perform_create(serializer) + except IntegrityError: + raise DuplicateData( + f"{CIKey.CONNECTOR_EXISTS}, \ + {CIKey.DUPLICATE_API}" + ) + headers = self.get_success_headers(serializer.data) + return Response( + serializer.data, status=status.HTTP_201_CREATED, headers=headers + ) diff --git a/backend/file_management/file_management_helper.py b/backend/file_management/file_management_helper.py index 0c13812b4..6293f2583 100644 --- a/backend/file_management/file_management_helper.py +++ b/backend/file_management/file_management_helper.py @@ -6,7 +6,6 @@ from typing import Any import magic -from connector.models import ConnectorInstance from django.conf import settings from django.http import StreamingHttpResponse from file_management.exceptions import ( @@ -24,8 +23,15 @@ from fsspec import AbstractFileSystem from pydrive2.files import ApiRequestError +from backend.constants import FeatureFlag from unstract.connectors.filesystems import connectors as fs_connectors from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem +from unstract.flags.feature_flag import check_feature_flag_status + +if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): + from connector_v2.models import ConnectorInstance +else: + from connector.models import ConnectorInstance class FileManagerHelper: diff --git a/backend/file_management/views.py b/backend/file_management/views.py index 7ff72a777..3ecfd0904 100644 --- a/backend/file_management/views.py +++ b/backend/file_management/views.py @@ -1,7 +1,6 @@ import logging from typing import Any -from connector.models import ConnectorInstance from django.http import HttpRequest from file_management.exceptions import ( ConnectorInstanceNotFound, @@ -17,15 +16,24 @@ FileUploadSerializer, ) from oauth2client.client import HttpAccessTokenRefreshError -from prompt_studio.prompt_studio_document_manager.models import DocumentManager from rest_framework import serializers, status, viewsets from rest_framework.decorators import action from rest_framework.response import Response from rest_framework.versioning import URLPathVersioning from utils.user_session import UserSessionUtils +from backend.constants import FeatureFlag from unstract.connectors.exceptions import ConnectorError from unstract.connectors.filesystems.local_storage.local_storage import LocalStorageFS +from unstract.flags.feature_flag import check_feature_flag_status + +if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): + from connector_v2.models import ConnectorInstance + from prompt_studio.prompt_studio_document_manager_v2.models import DocumentManager + +else: + from connector.models import ConnectorInstance + from prompt_studio.prompt_studio_document_manager.models import DocumentManager logger = logging.getLogger(__name__) diff --git a/backend/pipeline_v2/__init__.py b/backend/pipeline_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/pipeline_v2/constants.py b/backend/pipeline_v2/constants.py new file mode 100644 index 000000000..2475ad79f --- /dev/null +++ 
b/backend/pipeline_v2/constants.py @@ -0,0 +1,66 @@ +class PipelineConstants: + """Constants for Pipelines.""" + + TYPE = "type" + ETL_PIPELINE = "ETL" + TASK_PIPELINE = "TASK" + ETL = "etl" + TASK = "task" + CREATE_ACTION = "create" + UPDATE_ACTION = "update" + PIPELINE_GUID = "id" + ACTION = "action" + NOT_CONFIGURED = "Connector not configured." + SOURCE_NOT_CONFIGURED = "Source not configured." + DESTINATION_NOT_CONFIGURED = "Destination not configured." + SOURCE_ICON = "source_icon" + DESTINATION_ICON = "destination_icon" + SOURCE_NAME = "source_name" + DESTINATION_NAME = "destination_name" + INPUT_FILE = "input_file_connector" + INPUT_DB = "input_db_connector" + OUTPUT_FILE = "output_file_connector" + OUTPUT_DB = "output_db_connector" + SOURCE = "source" + DEST = "dest" + + +class PipelineExecutionKey: + PIPELINE = "pipeline" + EXECUTION = "execution" + + +class PipelineKey: + """Constants for the Pipeline model.""" + + PIPELINE_GUID = "id" + PIPELINE_NAME = "pipeline_name" + WORKFLOW = "workflow" + APP_ID = "app_id" + ACTIVE = "active" + SCHEDULED = "scheduled" + PIPELINE_TYPE = "pipeline_type" + RUN_COUNT = "run_count" + LAST_RUN_TIME = "last_run_time" + LAST_RUN_STATUS = "last_run_status" + # Used by serializer + CRON_DATA = "cron_data" + WORKFLOW_NAME = "workflow_name" + WORKFLOW_ID = "workflow_id" + CRON_STRING = "cron_string" + PIPELINE_ID = "pipeline_id" + + +class PipelineErrors: + PIPELINE_EXISTS = "Pipeline with this configuration might already exist or some mandatory field is missing." # noqa: E501 + DUPLICATE_API = "It appears that a duplicate call may have been made." + INVALID_WF = "The provided workflow does not exist" + + +class PipelineURL: + """Constants for URL names.""" + + DETAIL = "pipeline-detail" + EXECUTIONS = "pipeline-executions" + LIST = "pipeline-list" + EXECUTE = "tenant:pipeline-execute" diff --git a/backend/pipeline_v2/exceptions.py b/backend/pipeline_v2/exceptions.py new file mode 100644 index 000000000..8da26c481 --- /dev/null +++ b/backend/pipeline_v2/exceptions.py @@ -0,0 +1,46 @@ +from typing import Optional + +from rest_framework.exceptions import APIException + + +class WorkflowTriggerError(APIException): + status_code = 400 + default_detail = "Error triggering workflow. 
Pipeline created" + + +class PipelineExecuteError(APIException): + status_code = 500 + default_detail = "Error executing pipline" + + +class InactivePipelineError(APIException): + status_code = 422 + default_detail = "Pipeline is inactive, please activate the pipeline" + + def __init__( + self, + pipeline_name: Optional[str] = None, + detail: Optional[str] = None, + code: Optional[str] = None, + ): + if pipeline_name: + self.default_detail = ( + f"Pipeline '{pipeline_name}' is inactive, " + "please activate the pipeline" + ) + super().__init__(detail, code) + + +class MandatoryPipelineType(APIException): + status_code = 400 + default_detail = "Pipeline type is mandatory" + + +class MandatoryWorkflowId(APIException): + status_code = 400 + default_detail = "Workflow ID is mandatory" + + +class MandatoryCronSchedule(APIException): + status_code = 400 + default_detail = "Cron schedule is mandatory" diff --git a/backend/pipeline_v2/execution_view.py b/backend/pipeline_v2/execution_view.py new file mode 100644 index 000000000..fa514615a --- /dev/null +++ b/backend/pipeline_v2/execution_view.py @@ -0,0 +1,35 @@ +from permissions.permission import IsOwner +from pipeline_v2.serializers.execute import DateRangeSerializer +from rest_framework import viewsets +from rest_framework.versioning import URLPathVersioning +from utils.pagination import CustomPagination +from workflow_manager.workflow_v2.models.execution import WorkflowExecution +from workflow_manager.workflow_v2.serializers import WorkflowExecutionSerializer + + +class PipelineExecutionViewSet(viewsets.ModelViewSet): + versioning_class = URLPathVersioning + permission_classes = [IsOwner] + serializer_class = WorkflowExecutionSerializer + pagination_class = CustomPagination + + CREATED_AT_FIELD_DESC = "-created_at" + START_DATE_FIELD = "start_date" + END_DATE_FIELD = "end_date" + + def get_queryset(self): + # Get the pipeline_id from the URL path + pipeline_id = self.kwargs.get("pk") + queryset = WorkflowExecution.objects.filter(pipeline_id=pipeline_id) + + # Validate start_date and end_date parameters using DateRangeSerializer + date_range_serializer = DateRangeSerializer(data=self.request.query_params) + date_range_serializer.is_valid(raise_exception=True) + start_date = date_range_serializer.validated_data.get(self.START_DATE_FIELD) + end_date = date_range_serializer.validated_data.get(self.END_DATE_FIELD) + + if start_date and end_date: + queryset = queryset.filter(created_at__range=(start_date, end_date)) + + queryset = queryset.order_by(self.CREATED_AT_FIELD_DESC) + return queryset diff --git a/backend/pipeline_v2/manager.py b/backend/pipeline_v2/manager.py new file mode 100644 index 000000000..d5297107d --- /dev/null +++ b/backend/pipeline_v2/manager.py @@ -0,0 +1,59 @@ +import logging +from typing import Any, Optional + +from django.conf import settings +from django.urls import reverse +from pipeline_v2.constants import PipelineKey, PipelineURL +from pipeline_v2.models import Pipeline +from pipeline_v2.pipeline_processor import PipelineProcessor +from rest_framework.request import Request +from rest_framework.response import Response +from utils.request.constants import RequestConstants +from workflow_manager.workflow_v2.constants import WorkflowExecutionKey, WorkflowKey +from workflow_manager.workflow_v2.views import WorkflowViewSet + +from backend.constants import RequestHeader + +logger = logging.getLogger(__name__) + + +class PipelineManager: + """Helps manage the execution and scheduling of pipelines.""" + + @staticmethod + def 
execute_pipeline( + request: Request, + pipeline_id: str, + execution_id: Optional[str] = None, + ) -> Response: + """Used to execute a pipeline. + + Args: + pipeline_id (str): UUID of the pipeline to execute + execution_id (Optional[str], optional): + Uniquely identifies an execution. Defaults to None. + """ + logger.info(f"Executing pipeline {pipeline_id}, execution: {execution_id}") + pipeline: Pipeline = PipelineProcessor.initialize_pipeline_sync(pipeline_id) + # TODO: Use DRF's request and as_view() instead + request.data[WorkflowKey.WF_ID] = pipeline.workflow.id + if execution_id is not None: + request.data[WorkflowExecutionKey.EXECUTION_ID] = execution_id + wf_viewset = WorkflowViewSet() + return wf_viewset.execute(request=request, pipeline_guid=str(pipeline.pk)) + + @staticmethod + def get_pipeline_execution_data_for_scheduled_run( + pipeline_id: str, + ) -> Optional[dict[str, Any]]: + """Gets the required data to be passed while executing a pipeline Any + changes to pipeline execution needs to be propagated here.""" + callback_url = settings.DJANGO_APP_BACKEND_URL + reverse(PipelineURL.EXECUTE) + job_headers = {RequestHeader.X_API_KEY: settings.INTERNAL_SERVICE_API_KEY} + job_kwargs = { + RequestConstants.VERB: "POST", + RequestConstants.URL: callback_url, + RequestConstants.HEADERS: job_headers, + RequestConstants.DATA: {PipelineKey.PIPELINE_ID: pipeline_id}, + } + return job_kwargs diff --git a/backend/pipeline_v2/models.py b/backend/pipeline_v2/models.py new file mode 100644 index 000000000..35ef1914c --- /dev/null +++ b/backend/pipeline_v2/models.py @@ -0,0 +1,111 @@ +import uuid + +from account_v2.models import User +from django.db import models +from utils.models.base_model import BaseModel +from utils.models.organization_mixin import ( + DefaultOrganizationManagerMixin, + DefaultOrganizationMixin, +) +from workflow_manager.workflow_v2.models.workflow import Workflow + +from backend.constants import FieldLengthConstants as FieldLength + +APP_ID_LENGTH = 32 +PIPELINE_NAME_LENGTH = 32 + + +class PipelineModelManager(DefaultOrganizationManagerMixin, models.Manager): + pass + + +class Pipeline(DefaultOrganizationMixin, BaseModel): + """Model to hold data related to Pipelines.""" + + class PipelineType(models.TextChoices): + ETL = "ETL", "ETL" + TASK = "TASK", "TASK" + DEFAULT = "DEFAULT", "Default" + APP = "APP", "App" + + class PipelineStatus(models.TextChoices): + SUCCESS = "SUCCESS", "Success" + FAILURE = "FAILURE", "Failure" + INPROGRESS = "INPROGRESS", "Inprogress" + YET_TO_START = "YET_TO_START", "Yet to start" + RESTARTING = "RESTARTING", "Restarting" + PAUSED = "PAUSED", "Paused" + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + pipeline_name = models.CharField(max_length=PIPELINE_NAME_LENGTH, default="") + workflow = models.ForeignKey( + Workflow, + on_delete=models.CASCADE, + related_name="pipelines", + null=False, + blank=False, + ) + # Added as text field until a model for App is included. 
+ app_id = models.TextField(null=True, blank=True, max_length=APP_ID_LENGTH) + active = models.BooleanField(default=False) # TODO: Add dbcomment + scheduled = models.BooleanField(default=False) # TODO: Add dbcomment + cron_string = models.TextField( + db_comment="UNIX cron string", + null=False, + blank=False, + max_length=FieldLength.CRON_LENGTH, + ) + pipeline_type = models.CharField( + choices=PipelineType.choices, default=PipelineType.DEFAULT + ) + run_count = models.IntegerField(default=0) + last_run_time = models.DateTimeField(null=True, blank=True) + last_run_status = models.CharField( + choices=PipelineStatus.choices, default=PipelineStatus.YET_TO_START + ) + app_icon = models.URLField( + null=True, blank=True, db_comment="Field to store icon url for Apps" + ) + app_url = models.URLField( + null=True, blank=True, db_comment="Stores deployed URL for App" + ) + # TODO: Change this to a Forgein key once the bundle is created. + access_control_bundle_id = models.TextField(null=True, blank=True) + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="pipelines_created", + null=True, + blank=True, + ) + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="pipelines_modified", + null=True, + blank=True, + ) + + # Manager + objects = PipelineModelManager() + + def __str__(self) -> str: + return f"Pipeline({self.id})" + + class Meta: + verbose_name = "Pipeline" + verbose_name_plural = "Pipelines" + db_table = "pipeline_v2" + constraints = [ + models.UniqueConstraint( + fields=["id", "pipeline_type"], + name="unique_pipeline_entity", + ), + models.UniqueConstraint( + fields=["pipeline_name", "organization"], + name="unique_pipeline_name", + ), + ] + + def is_active(self) -> bool: + return bool(self.active) diff --git a/backend/pipeline_v2/pipeline_processor.py b/backend/pipeline_v2/pipeline_processor.py new file mode 100644 index 000000000..fa3328db2 --- /dev/null +++ b/backend/pipeline_v2/pipeline_processor.py @@ -0,0 +1,81 @@ +import logging +from typing import Optional + +from django.utils import timezone +from pipeline_v2.exceptions import InactivePipelineError +from pipeline_v2.models import Pipeline + +logger = logging.getLogger(__name__) + + +class PipelineProcessor: + @staticmethod + def initialize_pipeline_sync(pipeline_id: str) -> Pipeline: + """Fetches and initializes the sync for a pipeline. + + Args: + pipeline_id (str): UUID of the pipeline to sync + """ + pipeline: Pipeline = PipelineProcessor.fetch_pipeline(pipeline_id) + pipeline.run_count = pipeline.run_count + 1 + return PipelineProcessor._update_pipeline_status( + pipeline=pipeline, + status=Pipeline.PipelineStatus.RESTARTING, + is_end=False, + ) + + @staticmethod + def fetch_pipeline(pipeline_id: str, check_active: bool = True) -> Pipeline: + """Retrieves and checks for an active pipeline. + + Raises: + InactivePipelineError: If an active pipeline is not found + """ + pipeline: Pipeline = Pipeline.objects.get(pk=pipeline_id) + if check_active and not pipeline.is_active(): + logger.error(f"Inactive pipeline fetched: {pipeline_id}") + raise InactivePipelineError(pipeline_name=pipeline.pipeline_name) + return pipeline + + @staticmethod + def _update_pipeline_status( + pipeline: Pipeline, + status: tuple[str, str], + is_end: bool, + is_active: Optional[bool] = None, + ) -> Pipeline: + """Updates pipeline status during execution. 
+ + Raises: + PipelineSaveError: Exception while saving a pipeline + + Returns: + Pipeline: Updated pipeline + """ + if is_end: + pipeline.last_run_time = timezone.now() + if status: + pipeline.last_run_status = status + if is_active is not None: + pipeline.active = is_active + + pipeline.save() + return pipeline + + @staticmethod + def update_pipeline( + pipeline_guid: Optional[str], + status: tuple[str, str], + is_active: Optional[bool] = None, + ) -> None: + if not pipeline_guid: + return + # Skip check if we are enabling an inactive pipeline + check_active = not is_active + pipeline: Pipeline = PipelineProcessor.fetch_pipeline( + pipeline_id=pipeline_guid, check_active=check_active + ) + PipelineProcessor._update_pipeline_status( + pipeline=pipeline, is_end=True, status=status, is_active=is_active + ) + logger.info(f"Updated pipeline {pipeline_guid} status: {status}") diff --git a/backend/pipeline_v2/serializers/crud.py b/backend/pipeline_v2/serializers/crud.py new file mode 100644 index 000000000..bb051abef --- /dev/null +++ b/backend/pipeline_v2/serializers/crud.py @@ -0,0 +1,104 @@ +import logging +from collections import OrderedDict +from typing import Any + +from connector_processor.connector_processor import ConnectorProcessor +from connector_v2.connector_instance_helper import ConnectorInstanceHelper +from connector_v2.models import ConnectorInstance +from pipeline_v2.constants import PipelineConstants as PC +from pipeline_v2.constants import PipelineKey as PK +from pipeline_v2.models import Pipeline +from scheduler.helper import SchedulerHelper +from utils.serializer_utils import SerializerUtils + +from backend.serializers import AuditSerializer +from unstract.connectors.connectorkit import Connectorkit + +logger = logging.getLogger(__name__) + + +class PipelineSerializer(AuditSerializer): + + class Meta: + model = Pipeline + fields = "__all__" + + def create(self, validated_data: dict[str, Any]) -> Any: + # TODO: Deduce pipeline type based on WF? + validated_data[PK.ACTIVE] = True # Add this as default instead? + validated_data[PK.SCHEDULED] = True + return super().create(validated_data) + + def save(self, **kwargs: Any) -> Pipeline: + pipeline: Pipeline = super().save(**kwargs) + SchedulerHelper.add_job( + str(pipeline.pk), + cron_string=pipeline.cron_string, + ) + return pipeline + + def _get_name_and_icon(self, connectors: list[Any], connector_id: Any) -> Any: + for obj in connectors: + if obj["id"] == connector_id: + return obj["name"], obj["icon"] + return PC.NOT_CONFIGURED, None + + def _add_connector_data( + self, + repr: OrderedDict[str, Any], + connector_instance_list: list[Any], + connectors: list[Any], + ) -> OrderedDict[str, Any]: + """Adds connector Input/Output data. 
+ + Args: + sef (_type_): _description_ + repr (OrderedDict[str, Any]): _description_ + + Returns: + OrderedDict[str, Any]: _description_ + """ + repr[PC.SOURCE_NAME] = PC.NOT_CONFIGURED + repr[PC.DESTINATION_NAME] = PC.NOT_CONFIGURED + for instance in connector_instance_list: + if instance.connector_type == "INPUT": + repr[PC.SOURCE_NAME], repr[PC.SOURCE_ICON] = self._get_name_and_icon( + connectors=connectors, + connector_id=instance.connector_id, + ) + if instance.connector_type == "OUTPUT": + repr[PC.DESTINATION_NAME], repr[PC.DESTINATION_ICON] = ( + self._get_name_and_icon( + connectors=connectors, + connector_id=instance.connector_id, + ) + ) + return repr + + def to_representation(self, instance: Pipeline) -> OrderedDict[str, Any]: + """To set Source, Destination & Agency for Pipelines.""" + repr: OrderedDict[str, Any] = super().to_representation(instance) + + connector_kit = Connectorkit() + connectors = connector_kit.get_connectors_list() + + if SerializerUtils.check_context_for_GET_or_POST(context=self.context): + workflow = instance.workflow + connector_instance_list = ConnectorInstanceHelper.get_input_output_connector_instances_by_workflow( # noqa + workflow.id + ) + repr[PK.WORKFLOW_ID] = workflow.id + repr[PK.WORKFLOW_NAME] = workflow.workflow_name + repr[PK.CRON_STRING] = repr.pop(PK.CRON_STRING) + repr = self._add_connector_data( + repr=repr, + connector_instance_list=connector_instance_list, + connectors=connectors, + ) + + return repr + + def get_connector_data(self, connector: ConnectorInstance, key: str) -> Any: + return ConnectorProcessor.get_connector_data_with_key( + connector.connector_id, key + ) diff --git a/backend/pipeline_v2/serializers/execute.py b/backend/pipeline_v2/serializers/execute.py new file mode 100644 index 000000000..4e5a7fd58 --- /dev/null +++ b/backend/pipeline_v2/serializers/execute.py @@ -0,0 +1,24 @@ +import logging + +from pipeline_v2.models import Pipeline +from rest_framework import serializers + +logger = logging.getLogger(__name__) + + +class PipelineExecuteSerializer(serializers.Serializer): + # TODO: Add pipeline as a read_only related field + pipeline_id = serializers.UUIDField() + execution_id = serializers.UUIDField(required=False) + + def validate_pipeline_id(self, value: str) -> str: + try: + Pipeline.objects.get(pk=value) + except Pipeline.DoesNotExist: + raise serializers.ValidationError("Invalid pipeline ID") + return value + + +class DateRangeSerializer(serializers.Serializer): + start_date = serializers.DateTimeField(required=False) + end_date = serializers.DateTimeField(required=False) diff --git a/backend/pipeline_v2/serializers/update.py b/backend/pipeline_v2/serializers/update.py new file mode 100644 index 000000000..6fc9f66dc --- /dev/null +++ b/backend/pipeline_v2/serializers/update.py @@ -0,0 +1,14 @@ +from pipeline_v2.models import Pipeline +from rest_framework import serializers + + +class PipelineUpdateSerializer(serializers.Serializer): + pipeline_id = serializers.UUIDField(required=True) + active = serializers.BooleanField(required=True) + + def validate_pipeline_id(self, value: str) -> str: + try: + Pipeline.objects.get(pk=value) + except Pipeline.DoesNotExist: + raise serializers.ValidationError("Invalid pipeline ID") + return value diff --git a/backend/pipeline_v2/urls.py b/backend/pipeline_v2/urls.py new file mode 100644 index 000000000..d6950633d --- /dev/null +++ b/backend/pipeline_v2/urls.py @@ -0,0 +1,41 @@ +from django.urls import path +from pipeline_v2.constants import PipelineURL +from 
pipeline_v2.execution_view import PipelineExecutionViewSet +from pipeline_v2.views import PipelineViewSet +from rest_framework.urlpatterns import format_suffix_patterns + +pipeline_list = PipelineViewSet.as_view( + { + "get": "list", + "post": "create", + } +) +execution_list = PipelineExecutionViewSet.as_view( + { + "get": "list", + } +) +pipeline_detail = PipelineViewSet.as_view( + { + "get": "retrieve", + "put": "update", + "patch": "partial_update", + "delete": "destroy", + } +) + +pipeline_execute = PipelineViewSet.as_view({"post": "execute"}) + + +urlpatterns = format_suffix_patterns( + [ + path("pipeline/", pipeline_list, name=PipelineURL.LIST), + path("pipeline/<uuid:pk>/", pipeline_detail, name=PipelineURL.DETAIL), + path( + "pipeline/<uuid:pk>/executions/", + execution_list, + name=PipelineURL.EXECUTIONS, + ), + path("pipeline/execute/", pipeline_execute, name=PipelineURL.EXECUTE), + ] +) diff --git a/backend/pipeline_v2/views.py b/backend/pipeline_v2/views.py new file mode 100644 index 000000000..553b0179a --- /dev/null +++ b/backend/pipeline_v2/views.py @@ -0,0 +1,118 @@ +import logging +from typing import Any, Optional + +from account_v2.custom_exceptions import DuplicateData +from django.db import IntegrityError +from django.db.models import QuerySet +from permissions.permission import IsOwner +from pipeline_v2.constants import ( + PipelineConstants, + PipelineErrors, + PipelineExecutionKey, +) +from pipeline_v2.constants import PipelineKey as PK +from pipeline_v2.manager import PipelineManager +from pipeline_v2.models import Pipeline +from pipeline_v2.pipeline_processor import PipelineProcessor +from pipeline_v2.serializers.crud import PipelineSerializer +from pipeline_v2.serializers.execute import ( + PipelineExecuteSerializer as ExecuteSerializer, +) +from pipeline_v2.serializers.update import PipelineUpdateSerializer +from rest_framework import serializers, status, viewsets +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.versioning import URLPathVersioning +from scheduler.helper import SchedulerHelper + +logger = logging.getLogger(__name__) + + +class PipelineViewSet(viewsets.ModelViewSet): + versioning_class = URLPathVersioning + queryset = Pipeline.objects.all() + permission_classes = [IsOwner] + serializer_class = PipelineSerializer + + def get_queryset(self) -> Optional[QuerySet]: + type = self.request.query_params.get(PipelineConstants.TYPE) + if type is not None: + queryset = Pipeline.objects.filter( + created_by=self.request.user, pipeline_type=type + ) + return queryset + elif type is None: + queryset = Pipeline.objects.filter(created_by=self.request.user) + return queryset + + def get_serializer_class(self) -> serializers.Serializer: + if self.action == "execute": + return ExecuteSerializer + else: + return PipelineSerializer + + # TODO: Refactor to perform an action with explicit arguments + # For eg, passing pipeline ID and with_log=False -> executes pipeline + # For FE however we call the same API twice + # (first call generates execution ID) + def execute(self, request: Request) -> Response: + serializer: ExecuteSerializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + execution_id = serializer.validated_data.get("execution_id", None) + pipeline_id = serializer.validated_data[PK.PIPELINE_ID] + + execution = PipelineManager.execute_pipeline( + request=request, + pipeline_id=pipeline_id, + execution_id=execution_id, + ) + pipeline: Pipeline =
PipelineProcessor.fetch_pipeline(pipeline_id) + serializer = PipelineSerializer(pipeline) + response_data = { + PipelineExecutionKey.PIPELINE: serializer.data, + PipelineExecutionKey.EXECUTION: execution.data, + } + return Response(data=response_data, status=status.HTTP_200_OK) + + def create(self, request: Request) -> Response: + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + try: + serializer.save() + except IntegrityError: + raise DuplicateData( + f"{PipelineErrors.PIPELINE_EXISTS}, " f"{PipelineErrors.DUPLICATE_API}" + ) + return Response(data=serializer.data, status=status.HTTP_201_CREATED) + + def perform_destroy(self, instance: Pipeline) -> None: + pipeline_to_remove = str(instance.pk) + super().perform_destroy(instance) + return SchedulerHelper.remove_job(pipeline_to_remove) + + def partial_update(self, request: Request, pk: Any = None) -> Response: + serializer = PipelineUpdateSerializer(data=request.data) + if serializer.is_valid(): + pipeline_id = serializer.validated_data.get("pipeline_id") + active = serializer.validated_data.get("active") + try: + if active: + SchedulerHelper.resume_job(pipeline_id) + else: + SchedulerHelper.pause_job(pipeline_id) + except Exception as e: + logger.error(f"Failed to update pipeline status: {e}") + return Response( + {"error": "Failed to update pipeline status"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + return Response( + { + "status": "success", + "message": f"Pipeline {pipeline_id} status updated", + }, + status=status.HTTP_200_OK, + ) + else: + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/backend/prompt_studio/prompt_profile_manager_v2/__init__.py b/backend/prompt_studio/prompt_profile_manager_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/prompt_studio/prompt_profile_manager_v2/admin.py b/backend/prompt_studio/prompt_profile_manager_v2/admin.py new file mode 100644 index 000000000..6878f3f43 --- /dev/null +++ b/backend/prompt_studio/prompt_profile_manager_v2/admin.py @@ -0,0 +1,5 @@ +from django.contrib import admin + +from .models import ProfileManager + +admin.site.register(ProfileManager) diff --git a/backend/prompt_studio/prompt_profile_manager_v2/apps.py b/backend/prompt_studio/prompt_profile_manager_v2/apps.py new file mode 100644 index 000000000..02d4929c4 --- /dev/null +++ b/backend/prompt_studio/prompt_profile_manager_v2/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class ProfileManager(AppConfig): + name = "prompt_studio.prompt_profile_manager_v2" diff --git a/backend/prompt_studio/prompt_profile_manager_v2/constants.py b/backend/prompt_studio/prompt_profile_manager_v2/constants.py new file mode 100644 index 000000000..6540b58ee --- /dev/null +++ b/backend/prompt_studio/prompt_profile_manager_v2/constants.py @@ -0,0 +1,18 @@ +class ProfileManagerKeys: + CREATED_BY = "created_by" + TOOL_ID = "tool_id" + PROMPTS = "prompts" + ADAPTER_NAME = "adapter_name" + LLM = "llm" + VECTOR_STORE = "vector_store" + EMBEDDING_MODEL = "embedding_model" + X2TEXT = "x2text" + PROMPT_STUDIO_TOOL = "prompt_studio_tool" + MAX_PROFILE_COUNT = 4 + + +class ProfileManagerErrors: + SERIALIZATION_FAILED = "Data Serialization Failed." + PROFILE_NAME_EXISTS = "A profile with this name already exists." + DUPLICATE_API = "It appears that a duplicate call may have been made." + PLATFORM_ERROR = "Seems an error occured in Platform Service." 
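Note: the snippet below is an illustrative usage sketch and not part of the patch. It shows one way the PATCH handler added in PipelineViewSet.partial_update could be exercised through DRF's test client; the route prefix, the authenticated user object and the pipeline UUID are assumptions.

    # Illustrative only: toggling a pipeline's schedule via the new PATCH endpoint.
    # Assumes an existing Pipeline row, an authenticated `user`, and the
    # "pipeline/<uuid:pk>/" route from pipeline_v2/urls.py (prefix may differ).
    from rest_framework.test import APIClient

    def toggle_pipeline(user, pipeline_id: str, active: bool):
        client = APIClient()
        client.force_authenticate(user=user)
        return client.patch(
            f"/pipeline/{pipeline_id}/",
            {"pipeline_id": pipeline_id, "active": active},
            format="json",
        )

A valid payload resumes or pauses the scheduler job and returns HTTP 200 with a {"status": "success", ...} body; serializer errors come back as HTTP 400.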
diff --git a/backend/prompt_studio/prompt_profile_manager_v2/exceptions.py b/backend/prompt_studio/prompt_profile_manager_v2/exceptions.py new file mode 100644 index 000000000..023f6ad1c --- /dev/null +++ b/backend/prompt_studio/prompt_profile_manager_v2/exceptions.py @@ -0,0 +1,6 @@ +from rest_framework.exceptions import APIException + + +class PlatformServiceError(APIException): + status_code = 400 + default_detail = "Seems an error occured in Platform Service." diff --git a/backend/prompt_studio/prompt_profile_manager_v2/models.py b/backend/prompt_studio/prompt_profile_manager_v2/models.py new file mode 100644 index 000000000..303e96f0f --- /dev/null +++ b/backend/prompt_studio/prompt_profile_manager_v2/models.py @@ -0,0 +1,115 @@ +import uuid + +from account_v2.models import User +from adapter_processor_v2.models import AdapterInstance +from django.db import models +from prompt_studio.prompt_studio_core_v2.exceptions import DefaultProfileError +from prompt_studio.prompt_studio_core_v2.models import CustomTool +from utils.models.base_model import BaseModel + + +class ProfileManager(BaseModel): + """Model to store the LLM Triad management details for Prompt.""" + + class RetrievalStrategy(models.TextChoices): + SIMPLE = "simple", "Simple retrieval" + SUBQUESTION = "subquestion", "Subquestion retrieval" + + profile_id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + profile_name = models.TextField(blank=False) + vector_store = models.ForeignKey( + AdapterInstance, + db_comment="Field to store the chosen vector store.", + blank=False, + null=False, + on_delete=models.PROTECT, + related_name="profiles_vector_store", + ) + embedding_model = models.ForeignKey( + AdapterInstance, + blank=False, + null=False, + on_delete=models.PROTECT, + related_name="profiles_embedding_model", + ) + llm = models.ForeignKey( + AdapterInstance, + db_comment="Field to store the LLM chosen by the user", + blank=False, + null=False, + on_delete=models.PROTECT, + related_name="profiles_llm", + ) + x2text = models.ForeignKey( + AdapterInstance, + db_comment="Field to store the X2Text Adapter chosen by the user", + blank=False, + null=False, + on_delete=models.PROTECT, + related_name="profiles_x2text", + ) + chunk_size = models.IntegerField(null=True, blank=True) + chunk_overlap = models.IntegerField(null=True, blank=True) + reindex = models.BooleanField(default=False) + retrieval_strategy = models.TextField( + choices=RetrievalStrategy.choices, + blank=True, + db_comment="Field to store the retrieval strategy for prompts", + ) + similarity_top_k = models.IntegerField( + blank=True, + null=True, + db_comment="Field to store number of top embeddings to take into context", # noqa: E501 + ) + section = models.TextField( + blank=True, null=True, db_comment="Field to store limit to section" + ) + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="profile_managers_created", + null=True, + blank=True, + editable=False, + ) + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="profile_managers_modified", + null=True, + blank=True, + editable=False, + ) + + prompt_studio_tool = models.ForeignKey( + CustomTool, on_delete=models.CASCADE, null=True, related_name="profile_managers" + ) + is_default = models.BooleanField( + default=False, + db_comment="Default LLM Profile used in prompt", + ) + + is_summarize_llm = models.BooleanField( + default=False, + db_comment="Default LLM Profile used for summarizing", + ) + + class Meta: + 
verbose_name = "Profile Manager" + verbose_name_plural = "Profile Managers" + db_table = "profile_manager_v2" + constraints = [ + models.UniqueConstraint( + fields=["prompt_studio_tool", "profile_name"], + name="unique_prompt_studio_tool_profile_name_index", + ), + ] + + @staticmethod + def get_default_llm_profile(tool: CustomTool) -> "ProfileManager": + try: + return ProfileManager.objects.get( # type: ignore + prompt_studio_tool=tool, is_default=True + ) + except ProfileManager.DoesNotExist: + raise DefaultProfileError diff --git a/backend/prompt_studio/prompt_profile_manager_v2/profile_manager_helper.py b/backend/prompt_studio/prompt_profile_manager_v2/profile_manager_helper.py new file mode 100644 index 000000000..48c731b08 --- /dev/null +++ b/backend/prompt_studio/prompt_profile_manager_v2/profile_manager_helper.py @@ -0,0 +1,11 @@ +from prompt_studio.prompt_profile_manager_v2.models import ProfileManager + + +class ProfileManagerHelper: + + @classmethod + def get_profile_manager(cls, profile_manager_id: str) -> ProfileManager: + try: + return ProfileManager.objects.get(profile_id=profile_manager_id) + except ProfileManager.DoesNotExist: + raise ValueError("ProfileManager does not exist.") diff --git a/backend/prompt_studio/prompt_profile_manager_v2/serializers.py b/backend/prompt_studio/prompt_profile_manager_v2/serializers.py new file mode 100644 index 000000000..6810918e1 --- /dev/null +++ b/backend/prompt_studio/prompt_profile_manager_v2/serializers.py @@ -0,0 +1,53 @@ +import logging + +from adapter_processor_v2.adapter_processor import AdapterProcessor +from prompt_studio.prompt_profile_manager_v2.constants import ProfileManagerKeys +from prompt_studio.prompt_studio_core_v2.exceptions import MaxProfilesReachedError + +from backend.serializers import AuditSerializer + +from .models import ProfileManager + +logger = logging.getLogger(__name__) + + +class ProfileManagerSerializer(AuditSerializer): + class Meta: + model = ProfileManager + fields = "__all__" + + def to_representation(self, instance): # type: ignore + rep: dict[str, str] = super().to_representation(instance) + llm = rep[ProfileManagerKeys.LLM] + embedding = rep[ProfileManagerKeys.EMBEDDING_MODEL] + vector_db = rep[ProfileManagerKeys.VECTOR_STORE] + x2text = rep[ProfileManagerKeys.X2TEXT] + if llm: + rep[ProfileManagerKeys.LLM] = AdapterProcessor.get_adapter_instance_by_id( + llm + ) + if embedding: + rep[ProfileManagerKeys.EMBEDDING_MODEL] = ( + AdapterProcessor.get_adapter_instance_by_id(embedding) + ) + if vector_db: + rep[ProfileManagerKeys.VECTOR_STORE] = ( + AdapterProcessor.get_adapter_instance_by_id(vector_db) + ) + if x2text: + rep[ProfileManagerKeys.X2TEXT] = ( + AdapterProcessor.get_adapter_instance_by_id(x2text) + ) + return rep + + def validate(self, data): + prompt_studio_tool = data.get(ProfileManagerKeys.PROMPT_STUDIO_TOOL) + + profile_count = ProfileManager.objects.filter( + prompt_studio_tool=prompt_studio_tool + ).count() + + if profile_count >= ProfileManagerKeys.MAX_PROFILE_COUNT: + raise MaxProfilesReachedError() + + return data diff --git a/backend/prompt_studio/prompt_profile_manager_v2/urls.py b/backend/prompt_studio/prompt_profile_manager_v2/urls.py new file mode 100644 index 000000000..ae95f1fb9 --- /dev/null +++ b/backend/prompt_studio/prompt_profile_manager_v2/urls.py @@ -0,0 +1,24 @@ +from django.urls import path +from rest_framework.urlpatterns import format_suffix_patterns + +from .views import ProfileManagerView + +profile_manager_detail = ProfileManagerView.as_view( + { + "get": 
"retrieve", + "put": "update", + "patch": "partial_update", + "delete": "destroy", + } +) + + +urlpatterns = format_suffix_patterns( + [ + path( + "profile-manager//", + profile_manager_detail, + name="profile-manager-detail", + ), + ] +) diff --git a/backend/prompt_studio/prompt_profile_manager_v2/views.py b/backend/prompt_studio/prompt_profile_manager_v2/views.py new file mode 100644 index 000000000..3bfd4a5fd --- /dev/null +++ b/backend/prompt_studio/prompt_profile_manager_v2/views.py @@ -0,0 +1,50 @@ +from typing import Any, Optional + +from account_v2.custom_exceptions import DuplicateData +from django.db import IntegrityError +from django.db.models import QuerySet +from django.http import HttpRequest +from permissions.permission import IsOwner +from prompt_studio.prompt_profile_manager_v2.constants import ( + ProfileManagerErrors, + ProfileManagerKeys, +) +from prompt_studio.prompt_profile_manager_v2.serializers import ProfileManagerSerializer +from rest_framework import status, viewsets +from rest_framework.response import Response +from rest_framework.versioning import URLPathVersioning +from utils.filtering import FilterHelper + +from .models import ProfileManager + + +class ProfileManagerView(viewsets.ModelViewSet): + """Viewset to handle all Custom tool related operations.""" + + versioning_class = URLPathVersioning + permission_classes = [IsOwner] + serializer_class = ProfileManagerSerializer + + def get_queryset(self) -> Optional[QuerySet]: + filter_args = FilterHelper.build_filter_args( + self.request, + ProfileManagerKeys.CREATED_BY, + ) + if filter_args: + queryset = ProfileManager.objects.filter(**filter_args) + else: + queryset = ProfileManager.objects.all() + return queryset + + def create( + self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> Response: + serializer: ProfileManagerSerializer = self.get_serializer(data=request.data) + # Overriding default exception behaviour + # TO DO : Handle model related exceptions. 
+ serializer.is_valid(raise_exception=True) + try: + self.perform_create(serializer) + except IntegrityError: + raise DuplicateData(ProfileManagerErrors.PROFILE_NAME_EXISTS) + return Response(serializer.data, status=status.HTTP_201_CREATED) diff --git a/backend/prompt_studio/prompt_studio_core_v2/__init__.py b/backend/prompt_studio/prompt_studio_core_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/prompt_studio/prompt_studio_core_v2/admin.py b/backend/prompt_studio/prompt_studio_core_v2/admin.py new file mode 100644 index 000000000..e6e1457ea --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core_v2/admin.py @@ -0,0 +1,5 @@ +from django.contrib import admin + +from .models import CustomTool + +admin.site.register(CustomTool) diff --git a/backend/prompt_studio/prompt_studio_core_v2/apps.py b/backend/prompt_studio/prompt_studio_core_v2/apps.py new file mode 100644 index 000000000..085567889 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core_v2/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class CustomTool(AppConfig): + name = "prompt_studio.prompt_studio_core_v2" diff --git a/backend/prompt_studio/prompt_studio_core_v2/constants.py b/backend/prompt_studio/prompt_studio_core_v2/constants.py new file mode 100644 index 000000000..55d61e32e --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core_v2/constants.py @@ -0,0 +1,134 @@ +from enum import Enum + + +class ToolStudioKeys: + CREATED_BY = "created_by" + TOOL_ID = "tool_id" + PROMPTS = "prompts" + PLATFORM_SERVICE_API_KEY = "PLATFORM_SERVICE_API_KEY" + SUMMARIZE_LLM_PROFILE = "summarize_llm_profile" + DEFAULT_PROFILE = "default_profile" + + +class ToolStudioErrors: + SERIALIZATION_FAILED = "Data Serialization Failed." + TOOL_NAME_EXISTS = "Tool with the name already exists" + DUPLICATE_API = "It appears that a duplicate call may have been made." + PLATFORM_ERROR = "Seems an error occured in Platform Service." 
+ PROMPT_NAME_EXISTS = "Prompt with the name already exists" + + +class ToolStudioPromptKeys: + CREATED_BY = "created_by" + TOOL_ID = "tool_id" + RUN_ID = "run_id" + NUMBER = "Number" + FLOAT = "Float" + PG_VECTOR = "Postgres pg_vector" + ANSWERS = "answers" + UNIQUE_FILE_ID = "unique_file_id" + ID = "id" + FILE_NAME = "file_name" + FILE_HASH = "file_hash" + TOOL_ID = "tool_id" + NAME = "name" + ACTIVE = "active" + PROMPT = "prompt" + CHUNK_SIZE = "chunk-size" + PROMPTX = "promptx" + VECTOR_DB = "vector-db" + EMBEDDING = "embedding" + X2TEXT_ADAPTER = "x2text_adapter" + CHUNK_OVERLAP = "chunk-overlap" + LLM = "llm" + IS_ASSERT = "is_assert" + ASSERTION_FAILURE_PROMPT = "assertion_failure_prompt" + RETRIEVAL_STRATEGY = "retrieval-strategy" + SIMPLE = "simple" + TYPE = "type" + NUMBER = "number" + EMAIL = "email" + DATE = "date" + BOOLEAN = "boolean" + JSON = "json" + PREAMBLE = "preamble" + SIMILARITY_TOP_K = "similarity-top-k" + PROMPT_TOKENS = "prompt_tokens" + COMPLETION_TOKENS = "completion_tokens" + TOTAL_TOKENS = "total_tokens" + RESPONSE = "response" + POSTAMBLE = "postamble" + GRAMMAR = "grammar" + WORD = "word" + SYNONYMS = "synonyms" + OUTPUTS = "outputs" + ASSERT_PROMPT = "assert_prompt" + SECTION = "section" + DEFAULT = "default" + REINDEX = "reindex" + EMBEDDING_SUFFIX = "embedding_suffix" + EVAL_METRIC_PREFIX = "eval_" + EVAL_RESULT_DELIM = "__" + EVAL_SETTINGS = "eval_settings" + EVAL_SETTINGS_EVALUATE = "evaluate" + EVAL_SETTINGS_MONITOR_LLM = "monitor_llm" + EVAL_SETTINGS_EXCLUDE_FAILED = "exclude_failed" + SUMMARIZE = "summarize" + SUMMARIZED_RESULT = "summarized_result" + DOCUMENT_ID = "document_id" + EXTRACT = "extract" + TOOL_SETTINGS = "tool_settings" + ENABLE_CHALLENGE = "enable_challenge" + CHALLENGE_LLM = "challenge_llm" + SINGLE_PASS_EXTRACTION_MODE = "single_pass_extraction_mode" + SINGLE_PASS_EXTRACTION = "single_pass_extraction" + NOTES = "NOTES" + OUTPUT = "output" + SEQUENCE_NUMBER = "sequence_number" + PROFILE_MANAGER_ID = "profile_manager" + CONTEXT = "context" + METADATA = "metadata" + + +class FileViewTypes: + ORIGINAL = "ORIGINAL" + EXTRACT = "EXTRACT" + SUMMARIZE = "SUMMARIZE" + + +class LogLevels: + INFO = "INFO" + ERROR = "ERROR" + DEBUG = "DEBUG" + RUN = "RUN" + + +class LogLevel(Enum): + DEBUG = "DEBUG" + INFO = "INFO" + WARN = "WARN" + ERROR = "ERROR" + FATAL = "FATAL" + + +class IndexingStatus(Enum): + PENDING_STATUS = "pending" + COMPLETED_STATUS = "completed" + STARTED_STATUS = "started" + DOCUMENT_BEING_INDEXED = "Document is being indexed" + + +class DefaultPrompts: + PREAMBLE = ( + "Your ability to extract and summarize this context accurately " + "is essential for effective analysis. " + "Pay close attention to the context's language, structure, and any " + "cross-references to ensure a comprehensive and precise extraction " + "of information. Do not use prior knowledge or information from " + "outside the context to answer the questions. Only use the " + "information provided in the context to answer the questions." + ) + POSTAMBLE = ( + "Do not include any explanation in the reply. " + "Only include the extracted information in the reply." 
+ ) diff --git a/backend/prompt_studio/prompt_studio_core_v2/document_indexing_service.py b/backend/prompt_studio/prompt_studio_core_v2/document_indexing_service.py new file mode 100644 index 000000000..a323b73fc --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core_v2/document_indexing_service.py @@ -0,0 +1,53 @@ +from typing import Optional + +from django.conf import settings +from prompt_studio.prompt_studio_core_v2.constants import IndexingStatus +from utils.cache_service import CacheService + + +class DocumentIndexingService: + CACHE_PREFIX = "document_indexing:" + + @classmethod + def set_document_indexing(cls, org_id: str, user_id: str, doc_id_key: str) -> None: + CacheService.set_key( + cls._cache_key(org_id, user_id, doc_id_key), + IndexingStatus.STARTED_STATUS.value, + expire=settings.INDEXING_FLAG_TTL, + ) + + @classmethod + def is_document_indexing(cls, org_id: str, user_id: str, doc_id_key: str) -> bool: + return ( + CacheService.get_key(cls._cache_key(org_id, user_id, doc_id_key)) + == IndexingStatus.STARTED_STATUS.value + ) + + @classmethod + def mark_document_indexed( + cls, org_id: str, user_id: str, doc_id_key: str, doc_id: str + ) -> None: + CacheService.set_key( + cls._cache_key(org_id, user_id, doc_id_key), + doc_id, + expire=settings.INDEXING_FLAG_TTL, + ) + + @classmethod + def get_indexed_document_id( + cls, org_id: str, user_id: str, doc_id_key: str + ) -> Optional[str]: + result = CacheService.get_key(cls._cache_key(org_id, user_id, doc_id_key)) + if result and result != IndexingStatus.STARTED_STATUS.value: + return result + return None + + @classmethod + def remove_document_indexing( + cls, org_id: str, user_id: str, doc_id_key: str + ) -> None: + CacheService.delete_a_key(cls._cache_key(org_id, user_id, doc_id_key)) + + @classmethod + def _cache_key(cls, org_id: str, user_id: str, doc_id_key: str) -> str: + return f"{cls.CACHE_PREFIX}{org_id}:{user_id}:{doc_id_key}" diff --git a/backend/prompt_studio/prompt_studio_core_v2/exceptions.py b/backend/prompt_studio/prompt_studio_core_v2/exceptions.py new file mode 100644 index 000000000..a58f672d2 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core_v2/exceptions.py @@ -0,0 +1,69 @@ +from prompt_studio.prompt_profile_manager_v2.constants import ProfileManagerKeys +from prompt_studio.prompt_studio_core_v2.constants import ToolStudioErrors +from rest_framework.exceptions import APIException + + +class PlatformServiceError(APIException): + status_code = 400 + default_detail = ToolStudioErrors.PLATFORM_ERROR + + +class ToolNotValid(APIException): + status_code = 400 + default_detail = "Custom tool is not valid." + + +class IndexingAPIError(APIException): + status_code = 500 + default_detail = "Error while indexing file" + + +class AnswerFetchError(APIException): + status_code = 500 + default_detail = "Error occurred while fetching a response for the prompt" + + +class DefaultProfileError(APIException): + status_code = 500 + default_detail = ( + "Default LLM profile is not configured. " + "Please set an LLM profile as default to continue." + ) + + +class EnvRequired(APIException): + status_code = 404 + default_detail = "Environment variable not set" + + +class OutputSaveError(APIException): + status_code = 500 + default_detail = "Unable to store the output."
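Note: the snippet below is an illustrative sketch and not part of the patch. It walks through the indexing-status lifecycle implemented by the DocumentIndexingService added above; the identifiers are made up, and a configured Django cache backend (as used by CacheService) plus the INDEXING_FLAG_TTL setting are assumed.

    # Illustrative only: typical lifecycle of the per-document indexing flag.
    from prompt_studio.prompt_studio_core_v2.document_indexing_service import (
        DocumentIndexingService,
    )

    org_id, user_id, doc_key = "org-1", "user-1", "doc-abc"  # made-up identifiers

    # Mark the document as being indexed before extraction/embedding starts.
    DocumentIndexingService.set_document_indexing(org_id, user_id, doc_key)

    # Concurrent callers can poll instead of re-indexing; the cache entry
    # expires after settings.INDEXING_FLAG_TTL.
    if DocumentIndexingService.is_document_indexing(org_id, user_id, doc_key):
        pass  # e.g. wait and retry

    # Once indexing succeeds, store the resulting document id under the same key
    # so later lookups can reuse it without re-indexing.
    DocumentIndexingService.mark_document_indexed(org_id, user_id, doc_key, doc_id="idx-123")
    assert DocumentIndexingService.get_indexed_document_id(org_id, user_id, doc_key) == "idx-123"

    # Clear the entry when the document or its index is invalidated.
    DocumentIndexingService.remove_document_indexing(org_id, user_id, doc_key)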
+
+
+class ToolDeleteError(APIException):
+    status_code = 500
+    default_detail = "Failed to delete the tool"
+
+
+class NoPromptsFound(APIException):
+    status_code = 404
+    default_detail = "No prompts available to process"
+
+
+class PermissionError(APIException):
+    status_code = 403
+    default_detail = "You do not have permission to perform this action."
+
+
+class EmptyPromptError(APIException):
+    status_code = 422
+    default_detail = "Prompt(s) cannot be empty"
+
+
+class MaxProfilesReachedError(APIException):
+    status_code = 403
+    default_detail = (
+        f"Maximum number of profiles (max {ProfileManagerKeys.MAX_PROFILE_COUNT})"
+        " per prompt studio project has been reached."
+    )
diff --git a/backend/prompt_studio/prompt_studio_core_v2/models.py b/backend/prompt_studio/prompt_studio_core_v2/models.py
new file mode 100644
index 000000000..f4784e47d
--- /dev/null
+++ b/backend/prompt_studio/prompt_studio_core_v2/models.py
@@ -0,0 +1,158 @@
+import logging
+import shutil
+import uuid
+from typing import Any
+
+from account_v2.models import User
+from adapter_processor_v2.models import AdapterInstance
+from django.db import models
+from django.db.models import QuerySet
+from file_management.file_management_helper import FileManagerHelper
+from prompt_studio.prompt_studio_core_v2.constants import DefaultPrompts
+from utils.models.base_model import BaseModel
+from utils.models.organization_mixin import (
+    DefaultOrganizationManagerMixin,
+    DefaultOrganizationMixin,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class CustomToolModelManager(DefaultOrganizationManagerMixin, models.Manager):
+
+    def for_user(self, user: User) -> QuerySet[Any]:
+        return (
+            self.get_queryset()
+            .filter(models.Q(created_by=user) | models.Q(shared_users=user))
+            .distinct("tool_id")
+        )
+
+
+class CustomTool(DefaultOrganizationMixin, BaseModel):
+    """Model to store the custom tools designed in the tool studio."""
+
+    tool_id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
+    tool_name = models.TextField(blank=False, null=False)
+    description = models.TextField(blank=False, null=False)
+    author = models.TextField(
+        blank=False,
+        null=False,
+        db_comment="Specific to the user who created the tool.",
+    )
+    icon = models.TextField(
+        blank=True,
+        db_comment="Field to store icon url for the custom tool.",
+    )
+    output = models.TextField(
+        db_comment="Field to store the output format type.",
+        blank=True,
+    )
+    log_id = models.UUIDField(
+        default=uuid.uuid4,
+        db_comment="Field to store unique log_id for polling",
+    )
+
+    summarize_context = models.BooleanField(
+        default=False, db_comment="Flag to summarize content"
+    )
+    summarize_as_source = models.BooleanField(
+        default=False, db_comment="Flag to use summarized content as source"
+    )
+    summarize_prompt = models.TextField(
+        blank=True,
+        db_comment="Field to store the summarize prompt",
+        unique=False,
+    )
+    preamble = models.TextField(
+        blank=True,
+        db_comment="Preamble to the prompts",
+        default=DefaultPrompts.PREAMBLE,
+    )
+    postamble = models.TextField(
+        blank=True,
+        db_comment="Appended as postamble to prompts.",
+        default=DefaultPrompts.POSTAMBLE,
+    )
+    prompt_grammer = models.JSONField(
+        null=True, blank=True, db_comment="Synonymous words used in prompt"
+    )
+    monitor_llm = models.ForeignKey(
+        AdapterInstance,
+        on_delete=models.PROTECT,
+        db_comment="Field to store monitor llm",
+        null=True,
+        blank=True,
+        related_name="custom_tools_monitor",
+    )
+    created_by = models.ForeignKey(
+        User,
+        on_delete=models.SET_NULL,
null=True, + blank=True, + editable=False, + related_name="custom_tools_created", + ) + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + null=True, + blank=True, + editable=False, + related_name="custom_tools_modified", + ) + + exclude_failed = models.BooleanField( + db_comment="Flag to make the answer null if it is incorrect", + default=True, + ) + single_pass_extraction_mode = models.BooleanField( + db_comment="Flag to enable or disable single pass extraction mode", + default=False, + ) + challenge_llm = models.ForeignKey( + AdapterInstance, + on_delete=models.PROTECT, + db_comment="Field to store challenge llm", + null=True, + blank=True, + related_name="custom_tools_challenge", + ) + enable_challenge = models.BooleanField( + db_comment="Flag to enable or disable challenge", default=False + ) + + # Introduced field to establish M2M relation between users and custom_tool. + # This will introduce intermediary table which relates both the models. + shared_users = models.ManyToManyField(User, related_name="shared_custom_tools") + + objects = CustomToolModelManager() + + def delete(self, organization_id=None, *args, **kwargs): + # Delete the documents associated with the tool + file_path = FileManagerHelper.handle_sub_directory_for_tenants( + organization_id, + is_create=False, + user_id=self.created_by.user_id, + tool_id=str(self.tool_id), + ) + if organization_id: + try: + shutil.rmtree(file_path) + except FileNotFoundError: + logger.error(f"The folder {file_path} does not exist.") + except OSError as e: + logger.error(f"Error: {file_path} : {e.strerror}") + # Continue with the deletion of the tool + super().delete(*args, **kwargs) + + class Meta: + verbose_name = "Custom Tool" + verbose_name_plural = "Custom Tools" + db_table = "custom_tool_v2" + constraints = [ + models.UniqueConstraint( + fields=["tool_name", "organization"], + name="unique_tool_name", + ), + ] diff --git a/backend/prompt_studio/prompt_studio_core_v2/prompt_ide_base_tool.py b/backend/prompt_studio/prompt_studio_core_v2/prompt_ide_base_tool.py new file mode 100644 index 000000000..d2b403974 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core_v2/prompt_ide_base_tool.py @@ -0,0 +1,42 @@ +import os + +from platform_settings_v2.platform_auth_service import PlatformAuthenticationService +from prompt_studio.prompt_studio_core_v2.constants import LogLevel, ToolStudioKeys +from unstract.sdk.tool.stream import StreamMixin + + +class PromptIdeBaseTool(StreamMixin): + def __init__(self, log_level: LogLevel = LogLevel.INFO, org_id: str = "") -> None: + """ + Args: + tool (UnstractAbstractTool): Instance of UnstractAbstractTool + Notes: + - PLATFORM_SERVICE_API_KEY environment variable is required. + """ + self.log_level = log_level + self.org_id = org_id + super().__init__(log_level=log_level) + + def get_env_or_die(self, env_key: str) -> str: + """Returns the value of an env variable. 
+ + If its empty or None, raises an error and exits + + Args: + env_key (str): Key to retrieve + + Returns: + str: Value of the env + """ + # HACK: Adding platform key for multitenancy + if env_key == ToolStudioKeys.PLATFORM_SERVICE_API_KEY: + platform_key = PlatformAuthenticationService.get_active_platform_key( + self.org_id + ) + key: str = str(platform_key.key) + return key + else: + env_value = os.environ.get(env_key) + if env_value is None or env_value == "": + self.stream_error_and_exit(f"Env variable {env_key} is required") + return env_value # type:ignore diff --git a/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py b/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py new file mode 100644 index 000000000..a597d1a4f --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py @@ -0,0 +1,972 @@ +import json +import logging +import os +import uuid +from pathlib import Path +from typing import Any, Optional + +from account_v2.constants import Common +from account_v2.models import User +from adapter_processor_v2.constants import AdapterKeys +from adapter_processor_v2.models import AdapterInstance +from django.conf import settings +from django.db.models.manager import BaseManager +from file_management.file_management_helper import FileManagerHelper +from prompt_studio.prompt_profile_manager_v2.models import ProfileManager +from prompt_studio.prompt_profile_manager_v2.profile_manager_helper import ( + ProfileManagerHelper, +) +from prompt_studio.prompt_studio_core_v2.constants import IndexingStatus, LogLevels +from prompt_studio.prompt_studio_core_v2.constants import ( + ToolStudioPromptKeys as TSPKeys, +) +from prompt_studio.prompt_studio_core_v2.document_indexing_service import ( + DocumentIndexingService, +) +from prompt_studio.prompt_studio_core_v2.exceptions import ( + AnswerFetchError, + DefaultProfileError, + EmptyPromptError, + IndexingAPIError, + NoPromptsFound, + PermissionError, + ToolNotValid, +) +from prompt_studio.prompt_studio_core_v2.models import CustomTool +from prompt_studio.prompt_studio_core_v2.prompt_ide_base_tool import PromptIdeBaseTool +from prompt_studio.prompt_studio_document_manager_v2.models import DocumentManager +from prompt_studio.prompt_studio_index_manager_v2.prompt_studio_index_helper import ( # noqa: E501 + PromptStudioIndexHelper, +) +from prompt_studio.prompt_studio_output_manager_v2.output_manager_helper import ( + OutputManagerHelper, +) +from prompt_studio.prompt_studio_v2.models import ToolStudioPrompt +from unstract.sdk.constants import LogLevel +from unstract.sdk.exceptions import IndexingError, SdkError +from unstract.sdk.index import Index +from unstract.sdk.prompt import PromptTool +from unstract.sdk.utils.tool_utils import ToolUtils +from utils.local_context import StateStore + +from unstract.core.pubsub_helper import LogPublisher + +CHOICES_JSON = "/static/select_choices.json" +ERROR_MSG = "User %s doesn't have access to adapter %s" + +logger = logging.getLogger(__name__) + + +class PromptStudioHelper: + """Helper class for Custom tool operations.""" + + @staticmethod + def create_default_profile_manager(user: User, tool_id: uuid) -> None: + """Create a default profile manager for a given user and tool. + + Args: + user (User): The user for whom the default profile manager is + created. + tool_id (uuid): The ID of the tool for which the default profile + manager is created. 
+
+        Raises:
+            AdapterInstance.DoesNotExist: If no suitable adapter instance is
+                found for creating the default profile manager.
+
+        Returns:
+            None
+        """
+        try:
+            AdapterInstance.objects.get(
+                is_friction_less=True,
+                is_usable=True,
+                adapter_type=AdapterKeys.LLM,
+            )
+
+            default_adapters: BaseManager[AdapterInstance] = (
+                AdapterInstance.objects.filter(is_friction_less=True)
+            )
+
+            profile_manager = ProfileManager(
+                prompt_studio_tool=CustomTool.objects.get(pk=tool_id),
+                is_default=True,
+                created_by=user,
+                modified_by=user,
+                chunk_size=0,
+                profile_name="sample profile",
+                chunk_overlap=0,
+                section="Default",
+                retrieval_strategy="simple",
+                similarity_top_k=3,
+            )
+
+            for adapter in default_adapters:
+                if adapter.adapter_type == AdapterKeys.LLM:
+                    profile_manager.llm = adapter
+                elif adapter.adapter_type == AdapterKeys.VECTOR_DB:
+                    profile_manager.vector_store = adapter
+                elif adapter.adapter_type == AdapterKeys.X2TEXT:
+                    profile_manager.x2text = adapter
+                elif adapter.adapter_type == AdapterKeys.EMBEDDING:
+                    profile_manager.embedding_model = adapter
+
+            profile_manager.save()
+
+        except AdapterInstance.DoesNotExist:
+            logger.info("Skipping default profile creation")
+
+    @staticmethod
+    def validate_adapter_status(
+        profile_manager: ProfileManager,
+    ) -> None:
+        """Helper method to validate the status of adapters in the profile manager.
+
+        Args:
+            profile_manager (ProfileManager): The profile manager instance to
+                validate.
+
+        Raises:
+            PermissionError: If any configured adapter is no longer usable
+                (free trial usage exhausted).
+        """
+
+        error_msg = "Permission Error: Free usage for the configured trial adapter is exhausted. Please connect your own service accounts to continue. Please see our documentation for more details: https://docs.unstract.com/unstract_platform/setup_accounts/whats_needed"  # noqa: E501
+        adapters = [
+            profile_manager.llm,
+            profile_manager.vector_store,
+            profile_manager.embedding_model,
+            profile_manager.x2text,
+        ]
+
+        for adapter in adapters:
+            if not adapter.is_usable:
+                raise PermissionError(error_msg)
+
+    @staticmethod
+    def validate_profile_manager_owner_access(
+        profile_manager: ProfileManager,
+    ) -> None:
+        """Helper method to validate the owner's access to the profile manager.
+
+        Args:
+            profile_manager (ProfileManager): The profile manager instance to
+                validate.
+
+        Raises:
+            PermissionError: If the owner does not have permission to perform
+            the action.
+ """ + profile_manager_owner = profile_manager.created_by + + is_llm_owned = ( + profile_manager.llm.shared_to_org + or profile_manager.llm.created_by == profile_manager_owner + or profile_manager.llm.shared_users.filter( + pk=profile_manager_owner.pk + ).exists() + ) + is_vector_store_owned = ( + profile_manager.vector_store.shared_to_org + or profile_manager.vector_store.created_by == profile_manager_owner + or profile_manager.vector_store.shared_users.filter( + pk=profile_manager_owner.pk + ).exists() + ) + is_embedding_model_owned = ( + profile_manager.embedding_model.shared_to_org + or profile_manager.embedding_model.created_by == profile_manager_owner + or profile_manager.embedding_model.shared_users.filter( + pk=profile_manager_owner.pk + ).exists() + ) + is_x2text_owned = ( + profile_manager.x2text.shared_to_org + or profile_manager.x2text.created_by == profile_manager_owner + or profile_manager.x2text.shared_users.filter( + pk=profile_manager_owner.pk + ).exists() + ) + + if not ( + is_llm_owned + and is_vector_store_owned + and is_embedding_model_owned + and is_x2text_owned + ): + adapter_names = set() + if not is_llm_owned: + logger.error( + ERROR_MSG, + profile_manager_owner.user_id, + profile_manager.llm.id, + ) + adapter_names.add(profile_manager.llm.adapter_name) + if not is_vector_store_owned: + logger.error( + ERROR_MSG, + profile_manager_owner.user_id, + profile_manager.vector_store.id, + ) + adapter_names.add(profile_manager.vector_store.adapter_name) + if not is_embedding_model_owned: + logger.error( + ERROR_MSG, + profile_manager_owner.user_id, + profile_manager.embedding_model.id, + ) + adapter_names.add(profile_manager.embedding_model.adapter_name) + if not is_x2text_owned: + logger.error( + ERROR_MSG, + profile_manager_owner.user_id, + profile_manager.x2text.id, + ) + adapter_names.add(profile_manager.x2text.adapter_name) + if len(adapter_names) > 1: + error_msg = ( + f"Multiple permission errors were encountered with {', '.join(adapter_names)}", # noqa: E501 + ) + else: + error_msg = ( + f"Permission Error: You do not have access to {adapter_names.pop()}", # noqa: E501 + ) + + raise PermissionError(error_msg) + + @staticmethod + def _publish_log( + component: dict[str, str], level: str, state: str, message: str + ) -> None: + LogPublisher.publish( + StateStore.get(Common.LOG_EVENTS_ID), + LogPublisher.log_prompt(component, level, state, message), + ) + + @staticmethod + def get_select_fields() -> dict[str, Any]: + """Method to fetch dropdown field values for frontend. + + Returns: + dict[str, Any]: Dict for dropdown data + """ + f = open(f"{os.path.dirname(__file__)}{CHOICES_JSON}") + choices = f.read() + f.close() + response: dict[str, Any] = json.loads(choices) + return response + + @staticmethod + def _fetch_prompt_from_id(id: str) -> ToolStudioPrompt: + """Internal function used to fetch prompt from ID. + + Args: + id (_type_): UUID of the prompt + + Returns: + ToolStudioPrompt: Instance of the model + """ + prompt_instance: ToolStudioPrompt = ToolStudioPrompt.objects.get(pk=id) + return prompt_instance + + @staticmethod + def fetch_prompt_from_tool(tool_id: str) -> list[ToolStudioPrompt]: + """Internal function used to fetch mapped prompts from ToolID. 
+ + Args: + tool_id (_type_): UUID of the tool + + Returns: + List[ToolStudioPrompt]: List of instance of the model + """ + prompt_instances: list[ToolStudioPrompt] = ToolStudioPrompt.objects.filter( + tool_id=tool_id + ).order_by(TSPKeys.SEQUENCE_NUMBER) + return prompt_instances + + @staticmethod + def index_document( + tool_id: str, + file_name: str, + org_id: str, + user_id: str, + document_id: str, + is_summary: bool = False, + run_id: str = None, + ) -> Any: + """Method to index a document. + + Args: + tool_id (str): Id of the tool + file_name (str): File to parse + org_id (str): The ID of the organization to which the user belongs. + user_id (str): The ID of the user who uploaded the document. + is_summary (bool, optional): Whether the document is a summary + or not. Defaults to False. + + Raises: + ToolNotValid + IndexingError + """ + tool: CustomTool = CustomTool.objects.get(pk=tool_id) + if is_summary: + profile_manager: ProfileManager = ProfileManager.objects.get( + prompt_studio_tool=tool, is_summarize_llm=True + ) + default_profile = profile_manager + file_path = file_name + else: + default_profile = ProfileManager.get_default_llm_profile(tool) + file_path = FileManagerHelper.handle_sub_directory_for_tenants( + org_id, + is_create=False, + user_id=user_id, + tool_id=tool_id, + ) + file_path = str(Path(file_path) / file_name) + + if not tool: + logger.error(f"No tool instance found for the ID {tool_id}") + raise ToolNotValid() + + logger.info(f"[{tool_id}] Indexing started for doc: {file_name}") + PromptStudioHelper._publish_log( + {"tool_id": tool_id, "run_id": run_id, "doc_name": file_name}, + LogLevels.INFO, + LogLevels.RUN, + "Indexing started", + ) + + # Validate the status of adapter in profile manager + PromptStudioHelper.validate_adapter_status(default_profile) + # Need to check the user who created profile manager + # has access to adapters configured in profile manager + PromptStudioHelper.validate_profile_manager_owner_access(default_profile) + + doc_id = PromptStudioHelper.dynamic_indexer( + profile_manager=default_profile, + tool_id=tool_id, + file_path=file_path, + org_id=org_id, + document_id=document_id, + is_summary=is_summary, + reindex=True, + run_id=run_id, + user_id=user_id, + ) + + logger.info(f"[{tool_id}] Indexing successful for doc: {file_name}") + PromptStudioHelper._publish_log( + {"tool_id": tool_id, "run_id": run_id, "doc_name": file_name}, + LogLevels.INFO, + LogLevels.RUN, + "Indexing successful", + ) + + return doc_id.get("output") + + @staticmethod + def prompt_responder( + tool_id: str, + org_id: str, + user_id: str, + document_id: str, + id: Optional[str] = None, + run_id: str = None, + profile_manager_id: Optional[str] = None, + ) -> Any: + """Execute chain/single run of the prompts. Makes a call to prompt + service and returns the dict of response. 
+ + Args: + tool_id (str): ID of tool created in prompt studio + org_id (str): Organization ID + user_id (str): User's ID + document_id (str): UUID of the document uploaded + id (Optional[str]): ID of the prompt + profile_manager_id (Optional[str]): UUID of the profile manager + + Raises: + AnswerFetchError: Error from prompt-service + + Returns: + Any: Dictionary containing the response from prompt-service + """ + document: DocumentManager = DocumentManager.objects.get(pk=document_id) + doc_name: str = document.document_name + doc_path = PromptStudioHelper._get_document_path( + org_id, user_id, tool_id, doc_name + ) + + if id: + return PromptStudioHelper._execute_single_prompt( + id, + doc_path, + doc_name, + tool_id, + org_id, + user_id, + document_id, + run_id, + profile_manager_id, + ) + else: + return PromptStudioHelper._execute_prompts_in_single_pass( + doc_path, tool_id, org_id, user_id, document_id, run_id + ) + + @staticmethod + def _execute_single_prompt( + id, + doc_path, + doc_name, + tool_id, + org_id, + user_id, + document_id, + run_id, + profile_manager_id, + ): + prompt_instance = PromptStudioHelper._fetch_prompt_from_id(id) + prompt_name = prompt_instance.prompt_key + PromptStudioHelper._publish_log( + { + "tool_id": tool_id, + "run_id": run_id, + "prompt_key": prompt_name, + "doc_name": doc_name, + }, + LogLevels.INFO, + LogLevels.RUN, + "Executing single prompt", + ) + prompts = [prompt_instance] + tool = prompt_instance.tool_id + + if tool.summarize_as_source: + directory, filename = os.path.split(doc_path) + doc_path = os.path.join( + directory, TSPKeys.SUMMARIZE, os.path.splitext(filename)[0] + ".txt" + ) + + PromptStudioHelper._publish_log( + { + "tool_id": tool_id, + "run_id": run_id, + "prompt_key": prompt_name, + "doc_name": doc_name, + }, + LogLevels.DEBUG, + LogLevels.RUN, + "Invoking prompt service", + ) + + try: + response = PromptStudioHelper._fetch_response( + doc_path=doc_path, + doc_name=doc_name, + tool=tool, + prompt=prompt_instance, + org_id=org_id, + document_id=document_id, + run_id=run_id, + profile_manager_id=profile_manager_id, + user_id=user_id, + ) + return PromptStudioHelper._handle_response( + response, run_id, prompts, document_id, False, profile_manager_id + ) + except Exception as e: + logger.error( + f"[{tool.tool_id}] Error while fetching response for " + f"prompt {id} and doc {document_id}: {e}" + ) + msg = str(e) + PromptStudioHelper._publish_log( + { + "tool_id": tool_id, + "run_id": run_id, + "prompt_key": prompt_name, + "doc_name": doc_name, + }, + LogLevels.ERROR, + LogLevels.RUN, + msg, + ) + raise e + + @staticmethod + def _execute_prompts_in_single_pass( + doc_path, tool_id, org_id, user_id, document_id, run_id + ): + prompts = PromptStudioHelper.fetch_prompt_from_tool(tool_id) + prompts = [prompt for prompt in prompts if prompt.prompt_type != TSPKeys.NOTES] + if not prompts: + logger.error(f"[{tool_id or 'NA'}] No prompts found for id: {id}") + raise NoPromptsFound() + + PromptStudioHelper._publish_log( + {"tool_id": tool_id, "run_id": run_id, "prompt_id": str(id)}, + LogLevels.INFO, + LogLevels.RUN, + "Executing prompts in single pass", + ) + + try: + tool = prompts[0].tool_id + response = PromptStudioHelper._fetch_single_pass_response( + file_path=doc_path, + tool=tool, + prompts=prompts, + org_id=org_id, + document_id=document_id, + run_id=run_id, + user_id=user_id, + ) + return PromptStudioHelper._handle_response( + response, run_id, prompts, document_id, True + ) + except Exception as e: + logger.error( + f"[{tool.tool_id}] 
Error while fetching single pass response: {e}" + ) + PromptStudioHelper._publish_log( + { + "tool_id": tool_id, + "run_id": run_id, + "prompt_id": str(id), + }, + LogLevels.ERROR, + LogLevels.RUN, + f"Failed to fetch single pass response. {e}", + ) + raise e + + @staticmethod + def _get_document_path(org_id, user_id, tool_id, doc_name): + doc_path = FileManagerHelper.handle_sub_directory_for_tenants( + org_id=org_id, + user_id=user_id, + tool_id=tool_id, + is_create=False, + ) + return str(Path(doc_path) / doc_name) + + @staticmethod + def _handle_response( + response, run_id, prompts, document_id, is_single_pass, profile_manager_id=None + ): + if response.get("status") == IndexingStatus.PENDING_STATUS.value: + return { + "status": IndexingStatus.PENDING_STATUS.value, + "message": IndexingStatus.DOCUMENT_BEING_INDEXED.value, + } + + OutputManagerHelper.handle_prompt_output_update( + run_id=run_id, + prompts=prompts, + outputs=response["output"], + document_id=document_id, + is_single_pass_extract=is_single_pass, + profile_manager_id=profile_manager_id, + context=response["metadata"].get("context"), + ) + return response + + @staticmethod + def _fetch_response( + tool: CustomTool, + doc_path: str, + doc_name: str, + prompt: ToolStudioPrompt, + org_id: str, + document_id: str, + run_id: str, + user_id: str, + profile_manager_id: Optional[str] = None, + ) -> Any: + """Utility function to invoke prompt service. Used internally. + + Args: + tool (CustomTool): CustomTool instance (prompt studio project) + doc_path (str): Path to the document + doc_name (str): Name of the document + prompt (ToolStudioPrompt): ToolStudioPrompt instance to fetch response + org_id (str): UUID of the organization + document_id (str): UUID of the document + profile_manager_id (Optional[str]): UUID of the profile manager + user_id (str): The ID of the user who uploaded the document + + + Raises: + DefaultProfileError: If no default profile is selected + AnswerFetchError: Due to failures in prompt service + + Returns: + Any: Output from LLM + """ + + # Fetch the ProfileManager instance using the profile_manager_id if provided + profile_manager = prompt.profile_manager + if profile_manager_id: + profile_manager = ProfileManagerHelper.get_profile_manager( + profile_manager_id=profile_manager_id + ) + + monitor_llm_instance: Optional[AdapterInstance] = tool.monitor_llm + monitor_llm: Optional[str] = None + challenge_llm_instance: Optional[AdapterInstance] = tool.challenge_llm + challenge_llm: Optional[str] = None + + if monitor_llm_instance: + monitor_llm = str(monitor_llm_instance.id) + else: + # Using default profile manager llm if monitor_llm is None + default_profile = ProfileManager.get_default_llm_profile(tool) + monitor_llm = str(default_profile.llm.id) + + # Using default profile manager llm if challenge_llm is None + if challenge_llm_instance: + challenge_llm = str(challenge_llm_instance.id) + else: + default_profile = ProfileManager.get_default_llm_profile(tool) + challenge_llm = str(default_profile.llm.id) + + # Need to check the user who created profile manager + PromptStudioHelper.validate_adapter_status(profile_manager) + # Need to check the user who created profile manager + # has access to adapters + PromptStudioHelper.validate_profile_manager_owner_access(profile_manager) + # Not checking reindex here as there might be + # change in Profile Manager + vector_db = str(profile_manager.vector_store.id) + embedding_model = str(profile_manager.embedding_model.id) + llm = str(profile_manager.llm.id) + x2text = 
str(profile_manager.x2text.id) + if not profile_manager: + raise DefaultProfileError() + index_result = PromptStudioHelper.dynamic_indexer( + profile_manager=profile_manager, + file_path=doc_path, + tool_id=str(tool.tool_id), + org_id=org_id, + document_id=document_id, + is_summary=tool.summarize_as_source, + run_id=run_id, + user_id=user_id, + ) + if index_result.get("status") == IndexingStatus.PENDING_STATUS.value: + return { + "status": IndexingStatus.PENDING_STATUS.value, + "message": IndexingStatus.DOCUMENT_BEING_INDEXED.value, + } + + output: dict[str, Any] = {} + outputs: list[dict[str, Any]] = [] + grammer_dict = {} + grammar_list = [] + # Adding validations + prompt_grammer = tool.prompt_grammer + if prompt_grammer: + for word, synonyms in prompt_grammer.items(): + synonyms = prompt_grammer[word] + grammer_dict[TSPKeys.WORD] = word + grammer_dict[TSPKeys.SYNONYMS] = synonyms + grammar_list.append(grammer_dict) + grammer_dict = {} + + output[TSPKeys.PROMPT] = prompt.prompt + output[TSPKeys.ACTIVE] = prompt.active + output[TSPKeys.CHUNK_SIZE] = profile_manager.chunk_size + output[TSPKeys.VECTOR_DB] = vector_db + output[TSPKeys.EMBEDDING] = embedding_model + output[TSPKeys.CHUNK_OVERLAP] = profile_manager.chunk_overlap + output[TSPKeys.LLM] = llm + output[TSPKeys.TYPE] = prompt.enforce_type + output[TSPKeys.NAME] = prompt.prompt_key + output[TSPKeys.RETRIEVAL_STRATEGY] = profile_manager.retrieval_strategy + output[TSPKeys.SIMILARITY_TOP_K] = profile_manager.similarity_top_k + output[TSPKeys.SECTION] = profile_manager.section + output[TSPKeys.X2TEXT_ADAPTER] = x2text + # Eval settings for the prompt + output[TSPKeys.EVAL_SETTINGS] = {} + output[TSPKeys.EVAL_SETTINGS][TSPKeys.EVAL_SETTINGS_EVALUATE] = prompt.evaluate + output[TSPKeys.EVAL_SETTINGS][TSPKeys.EVAL_SETTINGS_MONITOR_LLM] = [monitor_llm] + output[TSPKeys.EVAL_SETTINGS][ + TSPKeys.EVAL_SETTINGS_EXCLUDE_FAILED + ] = tool.exclude_failed + for attr in dir(prompt): + if attr.startswith(TSPKeys.EVAL_METRIC_PREFIX): + attr_val = getattr(prompt, attr) + output[TSPKeys.EVAL_SETTINGS][attr] = attr_val + + outputs.append(output) + + tool_settings = {} + tool_settings[TSPKeys.ENABLE_CHALLENGE] = tool.enable_challenge + tool_settings[TSPKeys.CHALLENGE_LLM] = challenge_llm + tool_settings[TSPKeys.SINGLE_PASS_EXTRACTION_MODE] = ( + tool.single_pass_extraction_mode + ) + tool_settings[TSPKeys.PREAMBLE] = tool.preamble + tool_settings[TSPKeys.POSTAMBLE] = tool.postamble + tool_settings[TSPKeys.GRAMMAR] = grammar_list + + tool_id = str(tool.tool_id) + + file_hash = ToolUtils.get_hash_from_file(file_path=doc_path) + + payload = { + TSPKeys.TOOL_SETTINGS: tool_settings, + TSPKeys.OUTPUTS: outputs, + TSPKeys.TOOL_ID: tool_id, + TSPKeys.RUN_ID: run_id, + TSPKeys.FILE_NAME: doc_name, + TSPKeys.FILE_HASH: file_hash, + Common.LOG_EVENTS_ID: StateStore.get(Common.LOG_EVENTS_ID), + } + + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) + + responder = PromptTool( + tool=util, + prompt_host=settings.PROMPT_HOST, + prompt_port=settings.PROMPT_PORT, + ) + + answer = responder.answer_prompt(payload) + # TODO: Make use of dataclasses + if answer["status"] == "ERROR": + # TODO: Publish to FE logs from here + error_message = answer.get("error", "") + raise AnswerFetchError( + "Error while fetching response for " + f"'{prompt.prompt_key}' with '{doc_name}'. 
{error_message}" + ) + output_response = json.loads(answer["structure_output"]) + return output_response + + @staticmethod + def dynamic_indexer( + profile_manager: ProfileManager, + tool_id: str, + file_path: str, + org_id: str, + document_id: str, + user_id: str, + is_summary: bool = False, + reindex: bool = False, + run_id: str = None, + ) -> Any: + """Used to index a file based on the passed arguments. + + This is useful when a file needs to be indexed dynamically as the + parameters meant for indexing changes. The file + + Args: + profile_manager (ProfileManager): Profile manager instance that hold + values such as chunk size, chunk overlap and adapter IDs + tool_id (str): UUID of the prompt studio tool + file_path (str): Path to the file that needs to be indexed + org_id (str): ID of the organization + is_summary (bool, optional): Flag to ensure if extracted contents + need to be persisted. Defaults to False. + user_id (str): The ID of the user who uploaded the document + + Returns: + str: Index key for the combination of arguments + """ + embedding_model = str(profile_manager.embedding_model.id) + vector_db = str(profile_manager.vector_store.id) + x2text_adapter = str(profile_manager.x2text.id) + extract_file_path: Optional[str] = None + + if not is_summary: + directory, filename = os.path.split(file_path) + extract_file_path = os.path.join( + directory, "extract", os.path.splitext(filename)[0] + ".txt" + ) + else: + profile_manager.chunk_size = 0 + + try: + + usage_kwargs = {"run_id": run_id} + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) + tool_index = Index(tool=util) + doc_id_key = tool_index.generate_file_id( + tool_id=tool_id, + vector_db=vector_db, + embedding=embedding_model, + x2text=x2text_adapter, + chunk_size=str(profile_manager.chunk_size), + chunk_overlap=str(profile_manager.chunk_overlap), + file_path=file_path, + file_hash=None, + ) + if not reindex: + indexed_doc_id = DocumentIndexingService.get_indexed_document_id( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ) + if indexed_doc_id: + return { + "status": IndexingStatus.COMPLETED_STATUS.value, + "output": indexed_doc_id, + } + # Polling if document is already being indexed + if DocumentIndexingService.is_document_indexing( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ): + return { + "status": IndexingStatus.PENDING_STATUS.value, + "output": IndexingStatus.DOCUMENT_BEING_INDEXED.value, + } + + # Set the document as being indexed + DocumentIndexingService.set_document_indexing( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ) + doc_id: str = tool_index.index( + tool_id=tool_id, + embedding_instance_id=embedding_model, + vector_db_instance_id=vector_db, + x2text_instance_id=x2text_adapter, + file_path=file_path, + chunk_size=profile_manager.chunk_size, + chunk_overlap=profile_manager.chunk_overlap, + reindex=reindex, + output_file_path=extract_file_path, + usage_kwargs=usage_kwargs.copy(), + ) + + PromptStudioIndexHelper.handle_index_manager( + document_id=document_id, + is_summary=is_summary, + profile_manager=profile_manager, + doc_id=doc_id, + ) + DocumentIndexingService.mark_document_indexed( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key, doc_id=doc_id + ) + return {"status": IndexingStatus.COMPLETED_STATUS.value, "output": doc_id} + except (IndexingError, IndexingAPIError, SdkError) as e: + doc_name = os.path.split(file_path)[1] + PromptStudioHelper._publish_log( + {"tool_id": tool_id, "run_id": run_id, "doc_name": doc_name}, + LogLevels.ERROR, + 
LogLevels.RUN, + f"Indexing failed : {e}", + ) + raise IndexingAPIError( + f"Error while indexing '{doc_name}'. {str(e)}" + ) from e + + @staticmethod + def _fetch_single_pass_response( + tool: CustomTool, + file_path: str, + prompts: list[ToolStudioPrompt], + org_id: str, + user_id: str, + document_id: str, + run_id: str = None, + ) -> Any: + tool_id: str = str(tool.tool_id) + outputs: list[dict[str, Any]] = [] + grammar: list[dict[str, Any]] = [] + prompt_grammar = tool.prompt_grammer + default_profile = ProfileManager.get_default_llm_profile(tool) + challenge_llm_instance: Optional[AdapterInstance] = tool.challenge_llm + challenge_llm: Optional[str] = None + # Using default profile manager llm if challenge_llm is None + if challenge_llm_instance: + challenge_llm = str(challenge_llm_instance.id) + else: + challenge_llm = str(default_profile.llm.id) + # Need to check the user who created profile manager + PromptStudioHelper.validate_adapter_status(default_profile) + # has access to adapters configured in profile manager + PromptStudioHelper.validate_profile_manager_owner_access(default_profile) + default_profile.chunk_size = 0 # To retrive full context + + if prompt_grammar: + for word, synonyms in prompt_grammar.items(): + grammar.append({TSPKeys.WORD: word, TSPKeys.SYNONYMS: synonyms}) + + if not default_profile: + raise DefaultProfileError() + + index_result = PromptStudioHelper.dynamic_indexer( + profile_manager=default_profile, + file_path=file_path, + tool_id=tool_id, + org_id=org_id, + is_summary=tool.summarize_as_source, + document_id=document_id, + run_id=run_id, + user_id=user_id, + ) + if index_result.get("status") == IndexingStatus.PENDING_STATUS.value: + return { + "status": IndexingStatus.PENDING_STATUS.value, + "message": IndexingStatus.DOCUMENT_BEING_INDEXED.value, + } + + vector_db = str(default_profile.vector_store.id) + embedding_model = str(default_profile.embedding_model.id) + llm = str(default_profile.llm.id) + x2text = str(default_profile.x2text.id) + tool_settings = {} + tool_settings[TSPKeys.PREAMBLE] = tool.preamble + tool_settings[TSPKeys.POSTAMBLE] = tool.postamble + tool_settings[TSPKeys.GRAMMAR] = grammar + tool_settings[TSPKeys.LLM] = llm + tool_settings[TSPKeys.X2TEXT_ADAPTER] = x2text + tool_settings[TSPKeys.VECTOR_DB] = vector_db + tool_settings[TSPKeys.EMBEDDING] = embedding_model + tool_settings[TSPKeys.CHUNK_SIZE] = default_profile.chunk_size + tool_settings[TSPKeys.CHUNK_OVERLAP] = default_profile.chunk_overlap + tool_settings[TSPKeys.ENABLE_CHALLENGE] = tool.enable_challenge + tool_settings[TSPKeys.CHALLENGE_LLM] = challenge_llm + + for prompt in prompts: + if not prompt.prompt: + raise EmptyPromptError() + output: dict[str, Any] = {} + output[TSPKeys.PROMPT] = prompt.prompt + output[TSPKeys.ACTIVE] = prompt.active + output[TSPKeys.TYPE] = prompt.enforce_type + output[TSPKeys.NAME] = prompt.prompt_key + outputs.append(output) + + if tool.summarize_as_source: + path = Path(file_path) + file_path = str(path.parent / TSPKeys.SUMMARIZE / (path.stem + ".txt")) + file_hash = ToolUtils.get_hash_from_file(file_path=file_path) + + payload = { + TSPKeys.TOOL_SETTINGS: tool_settings, + TSPKeys.OUTPUTS: outputs, + TSPKeys.TOOL_ID: tool_id, + TSPKeys.RUN_ID: run_id, + TSPKeys.FILE_HASH: file_hash, + Common.LOG_EVENTS_ID: StateStore.get(Common.LOG_EVENTS_ID), + } + + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) + + responder = PromptTool( + tool=util, + prompt_host=settings.PROMPT_HOST, + prompt_port=settings.PROMPT_PORT, + ) + + answer = 
responder.single_pass_extraction(payload) + # TODO: Make use of dataclasses + if answer["status"] == "ERROR": + error_message = answer.get("error", None) + raise AnswerFetchError( + f"Error while fetching response for prompt. {error_message}" + ) + output_response = json.loads(answer["structure_output"]) + return output_response diff --git a/backend/prompt_studio/prompt_studio_core_v2/serializers.py b/backend/prompt_studio/prompt_studio_core_v2/serializers.py new file mode 100644 index 000000000..2e6b53ba1 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core_v2/serializers.py @@ -0,0 +1,116 @@ +import logging +from typing import Any + +from account_v2.models import User +from account_v2.serializer import UserSerializer +from django.core.exceptions import ObjectDoesNotExist +from file_management.constants import FileInformationKey +from prompt_studio.prompt_profile_manager_v2.models import ProfileManager +from prompt_studio.prompt_studio_core_v2.constants import ToolStudioKeys as TSKeys +from prompt_studio.prompt_studio_core_v2.exceptions import DefaultProfileError +from prompt_studio.prompt_studio_v2.models import ToolStudioPrompt +from prompt_studio.prompt_studio_v2.serializers import ToolStudioPromptSerializer +from rest_framework import serializers +from utils.FileValidator import FileValidator + +from backend.serializers import AuditSerializer + +from .models import CustomTool + +logger = logging.getLogger(__name__) + + +class CustomToolSerializer(AuditSerializer): + shared_users = serializers.PrimaryKeyRelatedField( + queryset=User.objects.all(), required=False, allow_null=True, many=True + ) + + class Meta: + model = CustomTool + fields = "__all__" + + def to_representation(self, instance): # type: ignore + data = super().to_representation(instance) + try: + profile_manager = ProfileManager.objects.get( + prompt_studio_tool=instance, is_summarize_llm=True + ) + data[TSKeys.SUMMARIZE_LLM_PROFILE] = profile_manager.profile_id + except ObjectDoesNotExist: + logger.info( + "Summarize LLM profile doesnt exist for prompt tool %s", + str(instance.tool_id), + ) + try: + profile_manager = ProfileManager.get_default_llm_profile(instance) + data[TSKeys.DEFAULT_PROFILE] = profile_manager.profile_id + except DefaultProfileError: + logger.info( + "Default LLM profile doesnt exist for prompt tool %s", + str(instance.tool_id), + ) + try: + prompt_instance: ToolStudioPrompt = ToolStudioPrompt.objects.filter( + tool_id=data.get(TSKeys.TOOL_ID) + ).order_by("sequence_number") + data[TSKeys.PROMPTS] = [] + output: list[Any] = [] + # Appending prompt instances of the tool for FE Processing + if prompt_instance.count() != 0: + for prompt in prompt_instance: + prompt_serializer = ToolStudioPromptSerializer(prompt) + output.append(prompt_serializer.data) + data[TSKeys.PROMPTS] = output + except Exception as e: + logger.error(f"Error occured while appending prompts {e}") + return data + + data["created_by_email"] = instance.created_by.email + + return data + + +class PromptStudioIndexSerializer(serializers.Serializer): + document_id = serializers.CharField() + + +class PromptStudioResponseSerializer(serializers.Serializer): + file_name = serializers.CharField() + tool_id = serializers.CharField() + id = serializers.CharField() + + +class SharedUserListSerializer(serializers.ModelSerializer): + """Used for listing users of Custom tool.""" + + created_by = UserSerializer() + shared_users = UserSerializer(many=True) + + class Meta: + model = CustomTool + fields = ( + "tool_id", + "tool_name", + 
"created_by", + "shared_users", + ) + + +class FileInfoIdeSerializer(serializers.Serializer): + document_id = serializers.CharField() + view_type = serializers.CharField(required=False) + + +class FileUploadIdeSerializer(serializers.Serializer): + file = serializers.ListField( + child=serializers.FileField(), + required=True, + validators=[ + FileValidator( + allowed_extensions=FileInformationKey.FILE_UPLOAD_ALLOWED_EXT, + allowed_mimetypes=FileInformationKey.FILE_UPLOAD_ALLOWED_MIME, + min_size=0, + max_size=FileInformationKey.FILE_UPLOAD_MAX_SIZE, + ) + ], + ) diff --git a/backend/prompt_studio/prompt_studio_core_v2/static/select_choices.json b/backend/prompt_studio/prompt_studio_core_v2/static/select_choices.json new file mode 100644 index 000000000..1e9e22ef4 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core_v2/static/select_choices.json @@ -0,0 +1,33 @@ +{ + "combined_output": + { + "JSON":"JSON", + "YAML":"YAML" + }, + "choose_llm":{ + "AZURE":"Azure OpenAI" + }, + "output_type":{ + "string":"Text", + "number":"number", + "email":"email", + "date":"date", + "boolean":"boolean", + "json":"json" + }, + "output_processing":{ + "DEFAULT":"Default" + }, + "embedding":{ + "azure_openai_embedding":"azure_openai_embedding", + "openai_embedding":"openai_embedding" + }, + "retrieval_strategy":{ + "simple":"simple", + "subquestion":"subquestion" + }, + "vector_store":{ + "Postgres pg_vector":"Postgres pg_vector", + "qdrant":"qdrant" + } +} diff --git a/backend/prompt_studio/prompt_studio_core_v2/urls.py b/backend/prompt_studio/prompt_studio_core_v2/urls.py new file mode 100644 index 000000000..0abc1c436 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core_v2/urls.py @@ -0,0 +1,118 @@ +from django.db import transaction +from django.urls import path +from django.utils.decorators import method_decorator +from rest_framework.urlpatterns import format_suffix_patterns + +from .views import PromptStudioCoreView + +prompt_studio_list = PromptStudioCoreView.as_view({"get": "list", "post": "create"}) +prompt_studio_detail = PromptStudioCoreView.as_view( + { + "get": "retrieve", + "put": "update", + "patch": "partial_update", + "delete": "destroy", + } +) +prompt_studio_choices = PromptStudioCoreView.as_view({"get": "get_select_choices"}) +prompt_studio_profiles = PromptStudioCoreView.as_view( + {"get": "list_profiles", "patch": "make_profile_default"} +) + +prompt_studio_prompts = PromptStudioCoreView.as_view({"post": "create_prompt"}) + +prompt_studio_profilemanager = PromptStudioCoreView.as_view( + {"post": "create_profile_manager"} +) + +prompt_studio_prompt_index = PromptStudioCoreView.as_view({"post": "index_document"}) +prompt_studio_prompt_response = PromptStudioCoreView.as_view({"post": "fetch_response"}) +prompt_studio_adapter_choices = PromptStudioCoreView.as_view( + {"get": "get_adapter_choices"} +) +prompt_studio_single_pass_extraction = PromptStudioCoreView.as_view( + {"post": "single_pass_extraction"} +) +prompt_studio_users = PromptStudioCoreView.as_view({"get": "list_of_shared_users"}) + + +prompt_studio_file = PromptStudioCoreView.as_view( + { + "post": "upload_for_ide", + "get": "fetch_contents_ide", + "delete": "delete_for_ide", + } +) + +prompt_studio_export = PromptStudioCoreView.as_view( + {"post": "export_tool", "get": "export_tool_info"} +) + + +urlpatterns = format_suffix_patterns( + [ + path("prompt-studio/", prompt_studio_list, name="prompt-studio-list"), + path( + "prompt-studio//", + prompt_studio_detail, + name="tool-studio-detail", + ), + path( + 
"prompt-studio/select_choices/", + prompt_studio_choices, + name="prompt-studio-choices", + ), + path( + "prompt-studio/prompt-studio-profile//", + prompt_studio_profiles, + name="prompt-studio-profiles", + ), + path( + "prompt-studio/prompt-studio-prompt//", + prompt_studio_prompts, + name="prompt-studio-prompts", + ), + path( + "prompt-studio/profilemanager/", + prompt_studio_profilemanager, + name="prompt-studio-profilemanager", + ), + path( + "prompt-studio/index-document/", + method_decorator(transaction.non_atomic_requests)( + prompt_studio_prompt_index + ), + name="prompt-studio-prompt-index", + ), + path( + "prompt-studio/fetch_response/", + prompt_studio_prompt_response, + name="prompt-studio-prompt-response", + ), + path( + "prompt-studio/adapter-choices/", + prompt_studio_adapter_choices, + name="prompt-studio-adapter-choices", + ), + path( + "prompt-studio/single-pass-extraction/", + prompt_studio_single_pass_extraction, + name="prompt-studio-single-pass-extraction", + ), + path( + "prompt-studio/users/", + prompt_studio_users, + name="prompt-studio-users", + ), + path( + "prompt-studio/file/", + prompt_studio_file, + name="prompt_studio_file", + ), + path( + "prompt-studio/export/", + prompt_studio_export, + name="prompt_studio_export", + ), + ] +) diff --git a/backend/prompt_studio/prompt_studio_core_v2/views.py b/backend/prompt_studio/prompt_studio_core_v2/views.py new file mode 100644 index 000000000..e2b1e1d2a --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core_v2/views.py @@ -0,0 +1,525 @@ +import logging +import uuid +from typing import Any, Optional + +from account_v2.custom_exceptions import DuplicateData +from django.db import IntegrityError +from django.db.models import QuerySet +from django.http import HttpRequest +from file_management.exceptions import FileNotFound +from file_management.file_management_helper import FileManagerHelper +from permissions.permission import IsOwner, IsOwnerOrSharedUser +from prompt_studio.processor_loader import ProcessorConfig, load_plugins +from prompt_studio.prompt_profile_manager_v2.constants import ProfileManagerErrors +from prompt_studio.prompt_profile_manager_v2.models import ProfileManager +from prompt_studio.prompt_profile_manager_v2.serializers import ProfileManagerSerializer +from prompt_studio.prompt_studio_core_v2.constants import ( + FileViewTypes, + ToolStudioErrors, + ToolStudioKeys, + ToolStudioPromptKeys, +) +from prompt_studio.prompt_studio_core_v2.document_indexing_service import ( + DocumentIndexingService, +) +from prompt_studio.prompt_studio_core_v2.exceptions import ( + IndexingAPIError, + ToolDeleteError, +) +from prompt_studio.prompt_studio_core_v2.prompt_studio_helper import PromptStudioHelper +from prompt_studio.prompt_studio_document_manager_v2.models import DocumentManager +from prompt_studio.prompt_studio_document_manager_v2.prompt_studio_document_helper import ( # noqa: E501 + PromptStudioDocumentHelper, +) +from prompt_studio.prompt_studio_index_manager_v2.models import IndexManager +from prompt_studio.prompt_studio_registry_v2.prompt_studio_registry_helper import ( + PromptStudioRegistryHelper, +) +from prompt_studio.prompt_studio_registry_v2.serializers import ( + ExportToolRequestSerializer, + PromptStudioRegistryInfoSerializer, +) +from prompt_studio.prompt_studio_v2.constants import ToolStudioPromptErrors +from prompt_studio.prompt_studio_v2.serializers import ToolStudioPromptSerializer +from rest_framework import status, viewsets +from rest_framework.decorators import action +from 
rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.versioning import URLPathVersioning +from tool_instance.models import ToolInstance +from unstract.sdk.utils.common_utils import CommonUtils +from utils.user_session import UserSessionUtils + +from unstract.connectors.filesystems.local_storage.local_storage import LocalStorageFS + +from .models import CustomTool +from .serializers import ( + CustomToolSerializer, + FileInfoIdeSerializer, + FileUploadIdeSerializer, + PromptStudioIndexSerializer, + SharedUserListSerializer, +) + +logger = logging.getLogger(__name__) + + +class PromptStudioCoreView(viewsets.ModelViewSet): + """Viewset to handle all Custom tool related operations.""" + + versioning_class = URLPathVersioning + + serializer_class = CustomToolSerializer + + processor_plugins = load_plugins() + + def get_permissions(self) -> list[Any]: + if self.action == "destroy": + return [IsOwner()] + + return [IsOwnerOrSharedUser()] + + def get_queryset(self) -> Optional[QuerySet]: + return CustomTool.objects.for_user(self.request.user) + + def create(self, request: HttpRequest) -> Response: + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + try: + self.perform_create(serializer) + except IntegrityError: + raise DuplicateData( + f"{ToolStudioErrors.TOOL_NAME_EXISTS}, \ + {ToolStudioErrors.DUPLICATE_API}" + ) + PromptStudioHelper.create_default_profile_manager( + request.user, serializer.data["tool_id"] + ) + return Response(serializer.data, status=status.HTTP_201_CREATED) + + def perform_destroy(self, instance: CustomTool) -> None: + organization_id = UserSessionUtils.get_organization_id(self.request) + instance.delete(organization_id) + + def destroy( + self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> Response: + instance: CustomTool = self.get_object() + # Checks if tool is exported + if hasattr(instance, "prompt_studio_registry"): + exported_tool_instances_in_use = ToolInstance.objects.filter( + tool_id__exact=instance.prompt_studio_registry.pk + ) + dependent_wfs = set() + for tool_instance in exported_tool_instances_in_use: + dependent_wfs.add(tool_instance.workflow_id) + if len(dependent_wfs) > 0: + logger.info( + f"Cannot destroy custom tool {instance.tool_id}," + f" depended by workflows {dependent_wfs}" + ) + raise ToolDeleteError( + "Failed to delete tool, its used in other workflows. " + "Delete its usages first" + ) + return super().destroy(request, *args, **kwargs) + + def partial_update( + self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> Response: + summarize_llm_profile_id = request.data.get( + ToolStudioKeys.SUMMARIZE_LLM_PROFILE, None + ) + if summarize_llm_profile_id: + prompt_tool = self.get_object() + + ProfileManager.objects.filter(prompt_studio_tool=prompt_tool).update( + is_summarize_llm=False + ) + profile_manager = ProfileManager.objects.get(pk=summarize_llm_profile_id) + profile_manager.is_summarize_llm = True + profile_manager.save() + + return super().partial_update(request, *args, **kwargs) + + @action(detail=True, methods=["get"]) + def get_select_choices(self, request: HttpRequest) -> Response: + """Method to return all static dropdown field values. + + The field values are retrieved from `./static/select_choices.json`. 
+ + Returns: + Response: Reponse of dropdown dict + """ + try: + select_choices: dict[str, Any] = PromptStudioHelper.get_select_fields() + return Response(select_choices, status=status.HTTP_200_OK) + except Exception as e: + logger.error(f"Error occured while fetching select fields {e}") + return Response(select_choices, status=status.HTTP_204_NO_CONTENT) + + @action(detail=True, methods=["get"]) + def list_profiles(self, request: HttpRequest, pk: Any = None) -> Response: + prompt_tool = ( + self.get_object() + ) # Assuming you have a get_object method in your viewset + + profile_manager_instances = ProfileManager.objects.filter( + prompt_studio_tool=prompt_tool + ) + + serialized_instances = ProfileManagerSerializer( + profile_manager_instances, many=True + ).data + + return Response(serialized_instances) + + @action(detail=True, methods=["patch"]) + def make_profile_default(self, request: HttpRequest, pk: Any = None) -> Response: + prompt_tool = ( + self.get_object() + ) # Assuming you have a get_object method in your viewset + + ProfileManager.objects.filter(prompt_studio_tool=prompt_tool).update( + is_default=False + ) + + profile_manager = ProfileManager.objects.get(pk=request.data["default_profile"]) + profile_manager.is_default = True + profile_manager.save() + + return Response( + status=status.HTTP_200_OK, + data={"default_profile": profile_manager.profile_id}, + ) + + @action(detail=True, methods=["post"]) + def index_document(self, request: HttpRequest, pk: Any = None) -> Response: + """API Entry point method to index input file. + + Args: + request (HttpRequest) + + Raises: + IndexingError + ValidationError + + Returns: + Response + """ + tool = self.get_object() + serializer = PromptStudioIndexSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + document_id: str = serializer.validated_data.get( + ToolStudioPromptKeys.DOCUMENT_ID + ) + document: DocumentManager = DocumentManager.objects.get(pk=document_id) + file_name: str = document.document_name + # Generate a run_id + run_id = CommonUtils.generate_uuid() + unique_id = PromptStudioHelper.index_document( + tool_id=str(tool.tool_id), + file_name=file_name, + org_id=UserSessionUtils.get_organization_id(request), + user_id=tool.created_by.user_id, + document_id=document_id, + run_id=run_id, + ) + + usage_kwargs: dict[Any, Any] = dict() + usage_kwargs[ToolStudioPromptKeys.RUN_ID] = run_id + for processor_plugin in self.processor_plugins: + cls = processor_plugin[ProcessorConfig.METADATA][ + ProcessorConfig.METADATA_SERVICE_CLASS + ] + cls.process( + tool_id=str(tool.tool_id), + file_name=file_name, + org_id=UserSessionUtils.get_organization_id(request), + user_id=tool.created_by.user_id, + document_id=document_id, + usage_kwargs=usage_kwargs.copy(), + ) + + if unique_id: + return Response( + {"message": "Document indexed successfully."}, + status=status.HTTP_200_OK, + ) + else: + logger.error("Error occured while indexing. Unique ID is not valid.") + raise IndexingAPIError() + + @action(detail=True, methods=["post"]) + def fetch_response(self, request: HttpRequest, pk: Any = None) -> Response: + """API Entry point method to fetch response to prompt. 
+ + Args: + request (HttpRequest): _description_ + + Raises: + FilenameMissingError: _description_ + + Returns: + Response + """ + custom_tool = self.get_object() + tool_id: str = str(custom_tool.tool_id) + document_id: str = request.data.get(ToolStudioPromptKeys.DOCUMENT_ID) + id: str = request.data.get(ToolStudioPromptKeys.ID) + run_id: str = request.data.get(ToolStudioPromptKeys.RUN_ID) + profile_manager: str = request.data.get(ToolStudioPromptKeys.PROFILE_MANAGER_ID) + if not run_id: + # Generate a run_id + run_id = CommonUtils.generate_uuid() + + response: dict[str, Any] = PromptStudioHelper.prompt_responder( + id=id, + tool_id=tool_id, + org_id=UserSessionUtils.get_organization_id(request), + user_id=custom_tool.created_by.user_id, + document_id=document_id, + run_id=run_id, + profile_manager_id=profile_manager, + ) + return Response(response, status=status.HTTP_200_OK) + + @action(detail=True, methods=["post"]) + def single_pass_extraction(self, request: HttpRequest, pk: uuid) -> Response: + """API Entry point method to fetch response to prompt. + + Args: + request (HttpRequest): _description_ + pk (Any): Primary key of the CustomTool + + Returns: + Response + """ + # TODO: Handle fetch_response and single_pass_ + # extraction using common function + custom_tool = self.get_object() + tool_id: str = str(custom_tool.tool_id) + document_id: str = request.data.get(ToolStudioPromptKeys.DOCUMENT_ID) + run_id: str = request.data.get(ToolStudioPromptKeys.RUN_ID) + if not run_id: + # Generate a run_id + run_id = CommonUtils.generate_uuid() + response: dict[str, Any] = PromptStudioHelper.prompt_responder( + tool_id=tool_id, + org_id=UserSessionUtils.get_organization_id(request), + user_id=custom_tool.created_by.user_id, + document_id=document_id, + run_id=run_id, + ) + return Response(response, status=status.HTTP_200_OK) + + @action(detail=True, methods=["get"]) + def list_of_shared_users(self, request: HttpRequest, pk: Any = None) -> Response: + + custom_tool = ( + self.get_object() + ) # Assuming you have a get_object method in your viewset + + serialized_instances = SharedUserListSerializer(custom_tool).data + + return Response(serialized_instances) + + @action(detail=True, methods=["post"]) + def create_prompt(self, request: HttpRequest, pk: Any = None) -> Response: + context = super().get_serializer_context() + serializer = ToolStudioPromptSerializer(data=request.data, context=context) + serializer.is_valid(raise_exception=True) + try: + # serializer.save() + self.perform_create(serializer) + except IntegrityError: + raise DuplicateData( + f"{ToolStudioPromptErrors.PROMPT_NAME_EXISTS}, \ + {ToolStudioPromptErrors.DUPLICATE_API}" + ) + return Response(serializer.data, status=status.HTTP_201_CREATED) + + @action(detail=True, methods=["post"]) + def create_profile_manager(self, request: HttpRequest, pk: Any = None) -> Response: + context = super().get_serializer_context() + serializer = ProfileManagerSerializer(data=request.data, context=context) + + serializer.is_valid(raise_exception=True) + try: + self.perform_create(serializer) + except IntegrityError: + raise DuplicateData( + f"{ProfileManagerErrors.PROFILE_NAME_EXISTS}, \ + {ProfileManagerErrors.DUPLICATE_API}" + ) + return Response(serializer.data, status=status.HTTP_201_CREATED) + + @action(detail=True, methods=["get"]) + def fetch_contents_ide(self, request: HttpRequest, pk: Any = None) -> Response: + custom_tool = self.get_object() + serializer = FileInfoIdeSerializer(data=request.GET) + serializer.is_valid(raise_exception=True) + 
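Both fetch_response and single_pass_extraction above reduce to a single call into PromptStudioHelper.prompt_responder: passing a prompt id runs one prompt, omitting it runs every prompt in a single pass. A minimal sketch of that dispatch, assuming an initialized Django application and existing tool and document records (all ids below are placeholders), and not part of this changeset:

# Illustrative sketch (not part of the diff): both view actions reduce to one
# helper call; `id` selects single-prompt mode, omitting it selects single-pass.
# The tool / document / org / user values are placeholders.
from prompt_studio.prompt_studio_core_v2.constants import IndexingStatus
from prompt_studio.prompt_studio_core_v2.prompt_studio_helper import PromptStudioHelper
from unstract.sdk.utils.common_utils import CommonUtils

run_id = CommonUtils.generate_uuid()

# Single prompt: equivalent to POSTing a prompt `id` to fetch_response.
single = PromptStudioHelper.prompt_responder(
    tool_id="<tool-uuid>",
    org_id="<org-id>",
    user_id="<user-id>",
    document_id="<document-uuid>",
    id="<prompt-uuid>",
    run_id=run_id,
)

# All prompts in one shot: equivalent to single_pass_extraction (no `id`).
single_pass = PromptStudioHelper.prompt_responder(
    tool_id="<tool-uuid>",
    org_id="<org-id>",
    user_id="<user-id>",
    document_id="<document-uuid>",
    run_id=run_id,
)

# Either call may report that indexing is still in flight.
if single_pass.get("status") == IndexingStatus.PENDING_STATUS.value:
    print("Document is still being indexed; retry later.")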
document_id: str = serializer.validated_data.get("document_id") + document: DocumentManager = DocumentManager.objects.get(pk=document_id) + file_name: str = document.document_name + view_type: str = serializer.validated_data.get("view_type") + + filename_without_extension = file_name.rsplit(".", 1)[0] + if view_type == FileViewTypes.EXTRACT: + file_name = ( + f"{FileViewTypes.EXTRACT.lower()}/" f"{filename_without_extension}.txt" + ) + if view_type == FileViewTypes.SUMMARIZE: + file_name = ( + f"{FileViewTypes.SUMMARIZE.lower()}/" + f"{filename_without_extension}.txt" + ) + + file_path = file_path = FileManagerHelper.handle_sub_directory_for_tenants( + UserSessionUtils.get_organization_id(request), + is_create=True, + user_id=custom_tool.created_by.user_id, + tool_id=str(custom_tool.tool_id), + ) + file_system = LocalStorageFS(settings={"path": file_path}) + if not file_path.endswith("/"): + file_path += "/" + file_path += file_name + # Temporary Hack for frictionless onboarding as the user id will be empty + try: + contents = FileManagerHelper.fetch_file_contents(file_system, file_path) + except FileNotFound: + file_path = file_path = FileManagerHelper.handle_sub_directory_for_tenants( + UserSessionUtils.get_organization_id(request), + is_create=True, + user_id="", + tool_id=str(custom_tool.tool_id), + ) + if not file_path.endswith("/"): + file_path += "/" + file_path += file_name + contents = FileManagerHelper.fetch_file_contents(file_system, file_path) + + return Response({"data": contents}, status=status.HTTP_200_OK) + + @action(detail=True, methods=["post"]) + def upload_for_ide(self, request: HttpRequest, pk: Any = None) -> Response: + custom_tool = self.get_object() + serializer = FileUploadIdeSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + uploaded_files: Any = serializer.validated_data.get("file") + + file_path = FileManagerHelper.handle_sub_directory_for_tenants( + UserSessionUtils.get_organization_id(request), + is_create=True, + user_id=custom_tool.created_by.user_id, + tool_id=str(custom_tool.tool_id), + ) + file_system = LocalStorageFS(settings={"path": file_path}) + + documents = [] + for uploaded_file in uploaded_files: + file_name = uploaded_file.name + + # Create a record in the db for the file + document = PromptStudioDocumentHelper.create( + tool_id=str(custom_tool.tool_id), document_name=file_name + ) + # Create a dictionary to store document data + doc = { + "document_id": document.document_id, + "document_name": document.document_name, + "tool": document.tool.tool_id, + } + # Store file + logger.info( + f"Uploading file: {file_name}" if file_name else "Uploading file" + ) + FileManagerHelper.upload_file( + file_system, + file_path, + uploaded_file, + file_name, + ) + documents.append(doc) + return Response({"data": documents}) + + @action(detail=True, methods=["delete"]) + def delete_for_ide(self, request: HttpRequest, pk: uuid) -> Response: + custom_tool = self.get_object() + serializer = FileInfoIdeSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + document_id: str = serializer.validated_data.get( + ToolStudioPromptKeys.DOCUMENT_ID + ) + org_id = UserSessionUtils.get_organization_id(request) + user_id = custom_tool.created_by.user_id + document: DocumentManager = DocumentManager.objects.get(pk=document_id) + file_name: str = document.document_name + file_path = FileManagerHelper.handle_sub_directory_for_tenants( + org_id=org_id, + is_create=False, + user_id=user_id, + tool_id=str(custom_tool.tool_id), + ) + path 
= file_path + file_system = LocalStorageFS(settings={"path": path}) + try: + # Delete indexed flags in redis + index_managers = IndexManager.objects.filter(document_manager=document_id) + for index_manager in index_managers: + raw_index_id = index_manager.raw_index_id + DocumentIndexingService.remove_document_indexing( + org_id=org_id, user_id=user_id, doc_id_key=raw_index_id + ) + # Delete the document record + document.delete() + # Delete the files + FileManagerHelper.delete_file(file_system, path, file_name) + # Directories to delete the text files + directories = ["extract/", "summarize/"] + FileManagerHelper.delete_related_files( + file_system, path, file_name, directories + ) + return Response( + {"data": "File deleted successfully."}, + status=status.HTTP_200_OK, + ) + except Exception as exc: + logger.error(f"Exception thrown from file deletion, error {exc}") + return Response( + {"data": "File deletion failed."}, + status=status.HTTP_400_BAD_REQUEST, + ) + + @action(detail=True, methods=["post"]) + def export_tool(self, request: Request, pk: Any = None) -> Response: + """API Endpoint for exporting required jsons for the custom tool.""" + custom_tool = self.get_object() + serializer = ExportToolRequestSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + is_shared_with_org: bool = serializer.validated_data.get("is_shared_with_org") + user_ids = set(serializer.validated_data.get("user_id")) + + PromptStudioRegistryHelper.update_or_create_psr_tool( + custom_tool=custom_tool, + shared_with_org=is_shared_with_org, + user_ids=user_ids, + ) + return Response( + {"message": "Custom tool exported successfully."}, + status=status.HTTP_200_OK, + ) + + @action(detail=True, methods=["get"]) + def export_tool_info(self, request: Request, pk: Any = None) -> Response: + custom_tool = self.get_object() + serialized_instances = None + if hasattr(custom_tool, "prompt_studio_registry"): + serialized_instances = PromptStudioRegistryInfoSerializer( + custom_tool.prompt_studio_registry + ).data + + return Response(serialized_instances) + else: + return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/backend/prompt_studio/prompt_studio_registry_v2/__init__.py b/backend/prompt_studio/prompt_studio_registry_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/prompt_studio/prompt_studio_registry_v2/admin.py b/backend/prompt_studio/prompt_studio_registry_v2/admin.py new file mode 100644 index 000000000..9f6bdeb83 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_registry_v2/admin.py @@ -0,0 +1,5 @@ +from django.contrib import admin + +from .models import PromptStudioRegistry + +admin.site.register(PromptStudioRegistry) diff --git a/backend/prompt_studio/prompt_studio_registry_v2/apps.py b/backend/prompt_studio/prompt_studio_registry_v2/apps.py new file mode 100644 index 000000000..e70823fb7 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_registry_v2/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class PromptStudioRegistry(AppConfig): + name = "prompt_studio.prompt_studio_registry_v2" diff --git a/backend/prompt_studio/prompt_studio_registry_v2/constants.py b/backend/prompt_studio/prompt_studio_registry_v2/constants.py new file mode 100644 index 000000000..6dbde1c09 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_registry_v2/constants.py @@ -0,0 +1,109 @@ +class PromptStudioRegistryKeys: + CREATED_BY = "created_by" + TOOL_ID = "tool_id" + NUMBER = "Number" + FLOAT = "Float" + PG_VECTOR = "Postgres pg_vector" +
ANSWERS = "answers" + UNIQUE_FILE_ID = "unique_file_id" + PROMPT_REGISTRY_ID = "prompt_registry_id" + FILE_NAME = "file_name" + UNDEFINED = "undefined" + + +class PromptStudioRegistryErrors: + SERIALIZATION_FAILED = "Data Serialization Failed." + DUPLICATE_API = "It appears that a duplicate call may have been made." + CUSTOM_TOOL_EXISTS = "Custom tool with similiar configuration already exists" + + +class LogLevels: + INFO = "INFO" + ERROR = "ERROR" + DEBUG = "DEBUG" + RUN = "RUN" + + +# TODO: Update prompt studio constants to have a single source of truth +class JsonSchemaKey: + TYPE = "type" + TITLE = "title" + DEFAULT = "default" + ENUM = "enum" + DESCRIPTION = "description" + REQUIRED = "required" + STRING = "string" + PROCESSOR_TO_USE = "Processor to use" + AZURE_OPEN_AI = "Azure OpenAI" + PROPERTIES = "properties" + DISPLAY_NAME = "display_name" + FUNCTION_NAME = "function_name" + PARAMETERS = "parameters" + VERSIONS = "versions" + OUTPUT_TYPE = "output_type" + INPUT_TYPE = "input_type" + IS_CACHABLE = "is_cacheable" + REQUIRES = "requires" + DEFAULT_DESCRIPTION_PROCESSOR = "Use Unstract processor \ + if you do not want to use a cloud provider for privacy reasons" + NAME = "name" + ACTIVE = "active" + PROMPT = "prompt" + CHUNK_SIZE = "chunk-size" + PROMPTX = "promptx" + VECTOR_DB = "vector-db" + EMBEDDING = "embedding" + X2TEXT_ADAPTER = "x2text_adapter" + CHUNK_OVERLAP = "chunk-overlap" + LLM = "llm" + RETRIEVAL_STRATEGY = "retrieval-strategy" + SIMPLE = "simple" + TYPE = "type" + NUMBER = "number" + EMAIL = "email" + DATE = "date" + BOOLEAN = "boolean" + JSON = "json" + PREAMBLE = "preamble" + SIMILARITY_TOP_K = "similarity-top-k" + PROMPT_TOKENS = "prompt_tokens" + COMPLETION_TOKENS = "completion_tokens" + TOTAL_TOKENS = "total_tokens" + RESPONSE = "response" + POSTAMBLE = "postamble" + GRAMMAR = "grammar" + WORD = "word" + SYNONYMS = "synonyms" + OUTPUTS = "outputs" + SECTION = "section" + DEFAULT = "default" + AUTHOR = "author" + ICON = "icon" + REINDEX = "reindex" + TOOL_ID = "tool_id" + EMBEDDING_SUFFIX = "embedding_suffix" + FUNCTION_NAME = "function_name" + PROMPT_REGISTRY_ID = "prompt_registry_id" + NOTES = "NOTES" + TOOL_SETTINGS = "tool_settings" + ENABLE_CHALLENGE = "enable_challenge" + CHALLENGE_LLM = "challenge_llm" + ENABLE_SINGLE_PASS_EXTRACTION = "enable_single_pass_extraction" + IMAGE_URL = "image_url" + IMAGE_NAME = "image_name" + IMAGE_TAG = "image_tag" + SUMMARIZE_PROMPT = "summarize_prompt" + SUMMARIZE_AS_SOURCE = "summarize_as_source" + ENABLE_HIGHLIGHT = "enable_highlight" + + +class SpecKey: + PROCESSOR = "processor" + SPEC = "spec" + OUTPUT_FOLDER = "outputFolder" + CREATE_OUTPUT_DOCUMENT = "createOutputDocument" + USE_CACHE = "useCache" + EMBEDDING_TRANSFORMER = "embeddingTransformer" + VECTOR_STORE = "vectorstore" + OUTPUT_TYPE = "outputType" + OUTPUT_PROCESSING = "outputProcessing" diff --git a/backend/prompt_studio/prompt_studio_registry_v2/exceptions.py b/backend/prompt_studio/prompt_studio_registry_v2/exceptions.py new file mode 100644 index 000000000..12a182991 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_registry_v2/exceptions.py @@ -0,0 +1,32 @@ +from rest_framework.exceptions import APIException + + +class InternalError(APIException): + status_code = 500 + default_detail = "Internal service error." + + +class ToolDoesNotExist(APIException): + status_code = 500 + default_detail = "Tool does not exist." + + +class ToolSaveError(APIException): + status_code = 500 + default_detail = "Error while saving the tool." 
+ + +class EmptyToolExportError(APIException): + status_code = 500 + default_detail = ( + "Empty Prompt Studio project without prompts cannot be exported. " + "Try adding a prompt and executing it." + ) + + +class InValidCustomToolError(APIException): + status_code = 500 + default_detail = ( + "This prompt studio project cannot be exported. It probably " + "has some empty or unexecuted prompts." + ) diff --git a/backend/prompt_studio/prompt_studio_registry_v2/fields.py b/backend/prompt_studio/prompt_studio_registry_v2/fields.py new file mode 100644 index 000000000..fcf8330b7 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_registry_v2/fields.py @@ -0,0 +1,31 @@ +import logging + +from django.db import models + +logger = logging.getLogger(__name__) + + +class ToolBaseJSONField(models.JSONField): + def from_db_value(self, value, expression, connection): # type: ignore + metadata = super().from_db_value(value, expression, connection) + return metadata + + +# TODO: Investigate if ToolBaseJSONField can replace the need for ToolPropertyJSONField, +# ToolSpecJSONField, ToolVariablesJSONField, and ToolMetadataJSONField classes. + + +class ToolPropertyJSONField(ToolBaseJSONField, models.JSONField): + pass + + +class ToolSpecJSONField(ToolBaseJSONField, models.JSONField): + pass + + +class ToolVariablesJSONField(ToolBaseJSONField, models.JSONField): + pass + + +class ToolMetadataJSONField(ToolBaseJSONField, models.JSONField): + pass diff --git a/backend/prompt_studio/prompt_studio_registry_v2/models.py b/backend/prompt_studio/prompt_studio_registry_v2/models.py new file mode 100644 index 000000000..ec768a854 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_registry_v2/models.py @@ -0,0 +1,110 @@ +import logging +import uuid +from typing import Any + +from account_v2.models import User +from django.db import models +from django.db.models import QuerySet +from prompt_studio.prompt_studio_registry_v2.fields import ( + ToolMetadataJSONField, + ToolPropertyJSONField, + ToolSpecJSONField, +) +from prompt_studio.prompt_studio_v2.models import CustomTool +from utils.models.base_model import BaseModel +from utils.models.organization_mixin import ( + DefaultOrganizationManagerMixin, + DefaultOrganizationMixin, +) + +logger = logging.getLogger(__name__) + + +class PromptStudioRegistryModelManager(DefaultOrganizationManagerMixin, models.Manager): + def get_queryset(self) -> QuerySet[Any]: + return super().get_queryset() + + def list_tools(self, user: User) -> QuerySet[Any]: + return ( + self.get_queryset() + .filter( + models.Q(created_by=user) + | models.Q(shared_users=user) + | models.Q(shared_to_org=True) + ) + .distinct("prompt_registry_id") + ) + + +class PromptStudioRegistry(DefaultOrganizationMixin, BaseModel): + """Data model to export JSON fields needed for registering the Custom tool + to the tool registry. + + By default the tools will be added to private tool hub. 
+ """ + + prompt_registry_id = models.UUIDField( + primary_key=True, default=uuid.uuid4, editable=False + ) + name = models.CharField(editable=False, default="") + description = models.CharField(editable=False, default="") + tool_property = ToolPropertyJSONField( + db_column="tool_property", + db_comment="PROPERTIES of the tool", + null=False, + blank=False, + default=dict, + ) + tool_spec = ToolSpecJSONField( + db_column="tool_spec", + db_comment="SPEC of the tool", + null=False, + blank=False, + default=dict, + ) + tool_metadata = ToolMetadataJSONField( + db_column="tool_metadata", + db_comment="Metadata from Prompt Studio", + null=False, + blank=False, + default=dict, + ) + icon = models.CharField(db_comment="Tool icon in svg format", editable=False) + url = models.CharField(editable=False) + custom_tool = models.OneToOneField( + CustomTool, + on_delete=models.CASCADE, + related_name="prompt_studio_registries", + editable=False, + null=True, + ) + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="prompt_registries_created", + null=True, + blank=True, + editable=False, + ) + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="prompt_registries_modified", + null=True, + blank=True, + editable=False, + ) + shared_to_org = models.BooleanField( + default=False, + db_comment="Is the exported tool shared with entire org", + ) + # Introduced field to establish M2M relation between users and tools. + # This will introduce intermediary table which relates both the models. + shared_users = models.ManyToManyField(User, related_name="prompt_registries") + + objects = PromptStudioRegistryModelManager() + + class Meta: + verbose_name = "Prompt Studio Registry" + verbose_name_plural = "Prompt Studio Registries" + db_table = "prompt_studio_registry_v2" diff --git a/backend/prompt_studio/prompt_studio_registry_v2/prompt_studio_registry_helper.py b/backend/prompt_studio/prompt_studio_registry_v2/prompt_studio_registry_helper.py new file mode 100644 index 000000000..e9d267e4d --- /dev/null +++ b/backend/prompt_studio/prompt_studio_registry_v2/prompt_studio_registry_helper.py @@ -0,0 +1,382 @@ +import logging +from typing import Any, Optional + +from account_v2.models import User +from adapter_processor_v2.models import AdapterInstance +from django.conf import settings +from django.db import IntegrityError +from prompt_studio.prompt_profile_manager_v2.models import ProfileManager +from prompt_studio.prompt_studio_core_v2.models import CustomTool +from prompt_studio.prompt_studio_core_v2.prompt_studio_helper import PromptStudioHelper +from prompt_studio.prompt_studio_output_manager_v2.models import ( + PromptStudioOutputManager, +) +from prompt_studio.prompt_studio_v2.models import ToolStudioPrompt +from unstract.tool_registry.dto import Properties, Spec, Tool + +from .constants import JsonSchemaKey +from .exceptions import ( + EmptyToolExportError, + InternalError, + InValidCustomToolError, + ToolSaveError, +) +from .models import PromptStudioRegistry +from .serializers import PromptStudioRegistrySerializer + +logger = logging.getLogger(__name__) + + +class PromptStudioRegistryHelper: + """Class to register custom tools to tool studio registry. + + By default the exported tools will be private and will be executed + with the help of a proto tool. + """ + + @staticmethod + def frame_spec(tool: CustomTool) -> Spec: + """Method to return spec of the Custom tool. 
+ + Args: + tool (CustomTool): Saved tool data + + Returns: + dict: spec dict + """ + properties = { + "challenge_llm": { + "type": "string", + "title": "Challenge LLM", + "adapterType": "LLM", + "description": "LLM to use for challenge", + "adapterIdKey": "challenge_llm_adapter_id", + }, + "enable_challenge": { + "type": "boolean", + "title": "Enable challenge", + "default": False, + "description": "Enables Challenge", + }, + "summarize_as_source": { + "type": "boolean", + "title": "Summarize and use summary as source", + "default": False, + "description": "Enables summary and use summarized content as source", + }, + "single_pass_extraction_mode": { + "type": "boolean", + "title": "Enable Single pass extraction", + "default": False, + "description": "Enables single pass extraction", + }, + } + + spec = Spec( + title=str(tool.tool_id), + description=tool.description, + required=[JsonSchemaKey.CHALLENGE_LLM], + properties=properties, + ) + return spec + + @staticmethod + def frame_properties(tool: CustomTool) -> Properties: + """Method to return properties of the tool. + + Args: + tool (CustomTool): Saved custom tool data. + + Returns: + dict: Properties dict + """ + # TODO: Update for new architecture + tool_props = Properties( + display_name=tool.tool_name, + function_name=str(tool.tool_id), + description=tool.description, + ) + return tool_props + + @staticmethod + def get_tool_by_prompt_registry_id( + prompt_registry_id: str, + ) -> Optional[Tool]: + """Gets the `Tool` associated with a prompt registry ID if it exists. + + Args: + prompt_registry_id (str): Prompt registry ID to fetch for + + Returns: + Optional[Tool]: The `Tool` exported from Prompt Studio + """ + try: + prompt_registry_tool = PromptStudioRegistry.objects.get( + pk=prompt_registry_id + ) + # Suppress all exceptions to allow processing + except Exception as e: + logger.warning( + "Error while fetching for prompt registry " + f"ID {prompt_registry_id}: {e} " + ) + return None + # The below properties are introduced after 0.20.0 + # So defaulting to 0.20.0 if the properties are not found + image_url = prompt_registry_tool.tool_metadata.get( + JsonSchemaKey.IMAGE_URL, "docker:unstract/tool-structure:0.0.20" + ) + image_name = prompt_registry_tool.tool_metadata.get( + JsonSchemaKey.IMAGE_NAME, "unstract/tool-structure" + ) + image_tag = prompt_registry_tool.tool_metadata.get( + JsonSchemaKey.IMAGE_TAG, "0.0.20" + ) + return Tool( + tool_uid=prompt_registry_tool.prompt_registry_id, + properties=Properties.from_dict(prompt_registry_tool.tool_property), + spec=Spec.from_dict(prompt_registry_tool.tool_spec), + icon=prompt_registry_tool.icon, + image_url=image_url, + image_name=image_name, + image_tag=image_tag, + ) + + @staticmethod + def update_or_create_psr_tool( + custom_tool: CustomTool, shared_with_org: bool, user_ids: set[int] + ) -> PromptStudioRegistry: + """Updates or creates the PromptStudioRegistry record. + + This appears as a separate tool in the workflow and is mapped + 1:1 with the `CustomTool`. + + Args: + tool_id (str): ID of the custom tool. 
+ + Raises: + ToolSaveError + InternalError + + Returns: + obj: PromptStudioRegistry instance that was updated or created + """ + try: + properties: Properties = PromptStudioRegistryHelper.frame_properties( + tool=custom_tool + ) + spec: Spec = PromptStudioRegistryHelper.frame_spec(tool=custom_tool) + prompts: list[ToolStudioPrompt] = PromptStudioHelper.fetch_prompt_from_tool( + tool_id=custom_tool.tool_id + ) + metadata = PromptStudioRegistryHelper.frame_export_json( + tool=custom_tool, prompts=prompts + ) + + obj: PromptStudioRegistry + created: bool + obj, created = PromptStudioRegistry.objects.update_or_create( + custom_tool=custom_tool, + created_by=custom_tool.created_by, + modified_by=custom_tool.modified_by, + defaults={ + "name": custom_tool.tool_name, + "tool_property": properties.to_dict(), + "tool_spec": spec.to_dict(), + "tool_metadata": metadata, + "icon": custom_tool.icon, + "description": custom_tool.description, + }, + ) + if created: + logger.info(f"PSR {obj.prompt_registry_id} was created") + else: + logger.info(f"PSR {obj.prompt_registry_id} was updated") + + obj.shared_to_org = shared_with_org + if not shared_with_org: + obj.shared_users.clear() + obj.shared_users.add(*user_ids) + # add prompt studio users + # for shared_user in custom_tool.shared_users: + obj.shared_users.add( + *custom_tool.shared_users.all().values_list("id", flat=True) + ) + # add prompt studio owner + obj.shared_users.add(custom_tool.created_by) + else: + obj.shared_users.clear() + obj.save() + return obj + except IntegrityError as error: + logger.error( + "Integrity Error - Error occurred while " + f"exporting custom tool : {error}" + ) + raise ToolSaveError + + @staticmethod + def frame_export_json( + tool: CustomTool, prompts: list[ToolStudioPrompt] + ) -> dict[str, Any]: + export_metadata = {} + + prompt_grammer = tool.prompt_grammer + grammar_list = [] + grammer_dict = {} + outputs: list[dict[str, Any]] = [] + output: dict[str, Any] = {} + invalidated_prompts: list[str] = [] + invalidated_outputs: list[str] = [] + + if not prompts: + raise EmptyToolExportError() + + if prompt_grammer: + for word, synonyms in prompt_grammer.items(): + synonyms = prompt_grammer[word] + grammer_dict[JsonSchemaKey.WORD] = word + grammer_dict[JsonSchemaKey.SYNONYMS] = synonyms + grammar_list.append(grammer_dict) + grammer_dict = {} + + export_metadata[JsonSchemaKey.NAME] = tool.tool_name + export_metadata[JsonSchemaKey.DESCRIPTION] = tool.description + export_metadata[JsonSchemaKey.AUTHOR] = tool.author + export_metadata[JsonSchemaKey.TOOL_ID] = str(tool.tool_id) + export_metadata[JsonSchemaKey.IMAGE_URL] = settings.STRUCTURE_TOOL_IMAGE_URL + export_metadata[JsonSchemaKey.IMAGE_NAME] = settings.STRUCTURE_TOOL_IMAGE_NAME + export_metadata[JsonSchemaKey.IMAGE_TAG] = settings.STRUCTURE_TOOL_IMAGE_TAG + + default_llm_profile = ProfileManager.get_default_llm_profile(tool) + challenge_llm_instance: Optional[AdapterInstance] = tool.challenge_llm + challenge_llm: Optional[str] = None + # Using default profile manager llm if challenge_llm is None + if challenge_llm_instance: + challenge_llm = str(challenge_llm_instance.id) + else: + challenge_llm = str(default_llm_profile.llm.id) + + embedding_suffix = "" + adapter_id = "" + vector_db = str(default_llm_profile.vector_store.id) + embedding_model = str(default_llm_profile.embedding_model.id) + llm = str(default_llm_profile.llm.id) + x2text = str(default_llm_profile.x2text.id) + + # Tool settings + tool_settings = {} + tool_settings[JsonSchemaKey.SUMMARIZE_PROMPT] = 
tool.summarize_prompt + tool_settings[JsonSchemaKey.SUMMARIZE_AS_SOURCE] = tool.summarize_as_source + tool_settings[JsonSchemaKey.PREAMBLE] = tool.preamble + tool_settings[JsonSchemaKey.POSTAMBLE] = tool.postamble + tool_settings[JsonSchemaKey.GRAMMAR] = grammar_list + tool_settings[JsonSchemaKey.LLM] = llm + tool_settings[JsonSchemaKey.X2TEXT_ADAPTER] = x2text + tool_settings[JsonSchemaKey.VECTOR_DB] = vector_db + tool_settings[JsonSchemaKey.EMBEDDING] = embedding_model + tool_settings[JsonSchemaKey.CHUNK_SIZE] = default_llm_profile.chunk_size + tool_settings[JsonSchemaKey.CHUNK_OVERLAP] = default_llm_profile.chunk_overlap + tool_settings[JsonSchemaKey.ENABLE_CHALLENGE] = tool.enable_challenge + tool_settings[JsonSchemaKey.CHALLENGE_LLM] = challenge_llm + tool_settings[JsonSchemaKey.ENABLE_SINGLE_PASS_EXTRACTION] = ( + tool.single_pass_extraction_mode + ) + + for prompt in prompts: + + if not prompt.prompt: + invalidated_prompts.append(prompt.prompt_key) + continue + + prompt_output = PromptStudioOutputManager.objects.filter( + tool_id=tool.tool_id, + prompt_id=prompt.prompt_id, + profile_manager=prompt.profile_manager, + ).all() + + if not prompt_output: + invalidated_outputs.append(prompt.prompt_key) + continue + + if prompt.prompt_type == JsonSchemaKey.NOTES: + continue + if not prompt.profile_manager: + prompt.profile_manager = default_llm_profile + + vector_db = str(prompt.profile_manager.vector_store.id) + embedding_model = str(prompt.profile_manager.embedding_model.id) + llm = str(prompt.profile_manager.llm.id) + x2text = str(prompt.profile_manager.x2text.id) + adapter_id = str(prompt.profile_manager.embedding_model.adapter_id) + embedding_suffix = adapter_id.split("|")[0] + + output[JsonSchemaKey.PROMPT] = prompt.prompt + output[JsonSchemaKey.ACTIVE] = prompt.active + output[JsonSchemaKey.CHUNK_SIZE] = prompt.profile_manager.chunk_size + output[JsonSchemaKey.VECTOR_DB] = vector_db + output[JsonSchemaKey.EMBEDDING] = embedding_model + output[JsonSchemaKey.X2TEXT_ADAPTER] = x2text + output[JsonSchemaKey.CHUNK_OVERLAP] = prompt.profile_manager.chunk_overlap + output[JsonSchemaKey.LLM] = llm + output[JsonSchemaKey.PREAMBLE] = tool.preamble + output[JsonSchemaKey.POSTAMBLE] = tool.postamble + output[JsonSchemaKey.GRAMMAR] = grammar_list + output[JsonSchemaKey.TYPE] = prompt.enforce_type + output[JsonSchemaKey.NAME] = prompt.prompt_key + output[JsonSchemaKey.RETRIEVAL_STRATEGY] = ( + prompt.profile_manager.retrieval_strategy + ) + output[JsonSchemaKey.SIMILARITY_TOP_K] = ( + prompt.profile_manager.similarity_top_k + ) + output[JsonSchemaKey.SECTION] = prompt.profile_manager.section + output[JsonSchemaKey.REINDEX] = prompt.profile_manager.reindex + output[JsonSchemaKey.EMBEDDING_SUFFIX] = embedding_suffix + outputs.append(output) + output = {} + vector_db = "" + embedding_suffix = "" + adapter_id = "" + llm = "" + embedding_model = "" + + if invalidated_prompts: + raise InValidCustomToolError( + f"Cannot export tool. Prompt(s): {', '.join(invalidated_prompts)} " + "are empty. Please enter a valid prompt." + ) + if invalidated_outputs: + raise InValidCustomToolError( + f"Cannot export tool. Prompt(s): {', '.join(invalidated_outputs)} " + "were not run. Please run them before exporting." 
+ ) + export_metadata[JsonSchemaKey.TOOL_SETTINGS] = tool_settings + export_metadata[JsonSchemaKey.OUTPUTS] = outputs + return export_metadata + + @staticmethod + def fetch_json_for_registry(user: User) -> list[dict[str, Any]]: + try: + # filter the Prompt studio registry based on the users and org flag + prompt_studio_tools = PromptStudioRegistry.objects.list_tools(user) + pi_serializer = PromptStudioRegistrySerializer( + instance=prompt_studio_tools, many=True + ) + except Exception as error: + logger.error(f"Error occured while fetching tool for tool_id: {error}") + raise InternalError() + tool_metadata: dict[str, Any] = {} + tool_list = [] + for prompts in pi_serializer.data: + tool_metadata[JsonSchemaKey.NAME] = prompts.get(JsonSchemaKey.NAME) + tool_metadata[JsonSchemaKey.DESCRIPTION] = prompts.get( + JsonSchemaKey.DESCRIPTION + ) + tool_metadata[JsonSchemaKey.ICON] = prompts.get(JsonSchemaKey.ICON) + tool_metadata[JsonSchemaKey.FUNCTION_NAME] = prompts.get( + JsonSchemaKey.PROMPT_REGISTRY_ID + ) + tool_list.append(tool_metadata) + tool_metadata = {} + return tool_list diff --git a/backend/prompt_studio/prompt_studio_registry_v2/serializers.py b/backend/prompt_studio/prompt_studio_registry_v2/serializers.py new file mode 100644 index 000000000..e995bb266 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_registry_v2/serializers.py @@ -0,0 +1,38 @@ +from typing import Any + +from account_v2.serializer import UserSerializer +from rest_framework import serializers + +from backend.serializers import AuditSerializer + +from .models import PromptStudioRegistry + + +class PromptStudioRegistrySerializer(AuditSerializer): + class Meta: + model = PromptStudioRegistry + fields = "__all__" + + +class PromptStudioRegistryInfoSerializer(AuditSerializer): + shared_users = UserSerializer(many=True) + prompt_studio_users = serializers.SerializerMethodField() + + class Meta: + model = PromptStudioRegistry + fields = ( + "name", + "shared_users", + "shared_to_org", + "prompt_studio_users", + ) + + def get_prompt_studio_users(self, obj: PromptStudioRegistry) -> Any: + + prompt_studio_users = obj.custom_tool.shared_users + return UserSerializer(prompt_studio_users, many=True).data + + +class ExportToolRequestSerializer(serializers.Serializer): + is_shared_with_org = serializers.BooleanField(default=False) + user_id = serializers.ListField(child=serializers.IntegerField(), required=False) diff --git a/backend/prompt_studio/prompt_studio_registry_v2/urls.py b/backend/prompt_studio/prompt_studio_registry_v2/urls.py new file mode 100644 index 000000000..9cc93e05e --- /dev/null +++ b/backend/prompt_studio/prompt_studio_registry_v2/urls.py @@ -0,0 +1,3 @@ +from rest_framework.urlpatterns import format_suffix_patterns + +urlpatterns = format_suffix_patterns([]) diff --git a/backend/prompt_studio/prompt_studio_registry_v2/views.py b/backend/prompt_studio/prompt_studio_registry_v2/views.py new file mode 100644 index 000000000..6b8a5d7e7 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_registry_v2/views.py @@ -0,0 +1,35 @@ +import logging +from typing import Optional + +from django.db.models import QuerySet +from prompt_studio.prompt_studio_registry_v2.constants import PromptStudioRegistryKeys +from prompt_studio.prompt_studio_registry_v2.serializers import ( + PromptStudioRegistrySerializer, +) +from rest_framework import viewsets +from rest_framework.versioning import URLPathVersioning +from utils.filtering import FilterHelper + +from .models import PromptStudioRegistry + +logger = 
logging.getLogger(__name__) + + +class PromptStudioRegistryView(viewsets.ModelViewSet): + """Driver class to handle export and registering of custom tools to private + tool hub.""" + + versioning_class = URLPathVersioning + queryset = PromptStudioRegistry.objects.all() + serializer_class = PromptStudioRegistrySerializer + + def get_queryset(self) -> Optional[QuerySet]: + filterArgs = FilterHelper.build_filter_args( + self.request, + PromptStudioRegistryKeys.PROMPT_REGISTRY_ID, + ) + queryset = None + if filterArgs: + queryset = PromptStudioRegistry.objects.filter(**filterArgs) + + return queryset diff --git a/backend/prompt_studio/prompt_studio_v2/__init__.py b/backend/prompt_studio/prompt_studio_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/prompt_studio/prompt_studio_v2/admin.py b/backend/prompt_studio/prompt_studio_v2/admin.py new file mode 100644 index 000000000..bbcd76c6e --- /dev/null +++ b/backend/prompt_studio/prompt_studio_v2/admin.py @@ -0,0 +1,5 @@ +from django.contrib import admin + +from .models import ToolStudioPrompt + +admin.site.register(ToolStudioPrompt) diff --git a/backend/prompt_studio/prompt_studio_v2/apps.py b/backend/prompt_studio/prompt_studio_v2/apps.py new file mode 100644 index 000000000..0a6f1e517 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_v2/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class ToolStudioPrompt(AppConfig): + name = "prompt_studio.prompt_studio_v2" diff --git a/backend/prompt_studio/prompt_studio_v2/constants.py b/backend/prompt_studio/prompt_studio_v2/constants.py new file mode 100644 index 000000000..6554a9f8c --- /dev/null +++ b/backend/prompt_studio/prompt_studio_v2/constants.py @@ -0,0 +1,32 @@ +class ToolStudioPromptKeys: + CREATED_BY = "created_by" + TOOL_ID = "tool_id" + NUMBER = "Number" + FLOAT = "Float" + PG_VECTOR = "Postgres pg_vector" + ANSWERS = "answers" + UNIQUE_FILE_ID = "unique_file_id" + ID = "id" + FILE_NAME = "file_name" + UNDEFINED = "undefined" + ACTIVE = "active" + PROMPT_KEY = "prompt_key" + EVAL_METRIC_PREFIX = "eval_" + EVAL_RESULT_DELIM = "__" + SEQUENCE_NUMBER = "sequence_number" + START_SEQUENCE_NUMBER = "start_sequence_number" + END_SEQUENCE_NUMBER = "end_sequence_number" + PROMPT_ID = "prompt_id" + + +class ToolStudioPromptErrors: + SERIALIZATION_FAILED = "Data Serialization Failed." + DUPLICATE_API = "It appears that a duplicate call may have been made." + PROMPT_NAME_EXISTS = "Prompt with the name already exists" + + +class LogLevels: + INFO = "INFO" + ERROR = "ERROR" + DEBUG = "DEBUG" + RUN = "RUN" diff --git a/backend/prompt_studio/prompt_studio_v2/controller.py b/backend/prompt_studio/prompt_studio_v2/controller.py new file mode 100644 index 000000000..30cacfe29 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_v2/controller.py @@ -0,0 +1,58 @@ +import logging + +from prompt_studio.prompt_studio_v2.constants import ToolStudioPromptKeys +from prompt_studio.prompt_studio_v2.helper import PromptStudioHelper +from prompt_studio.prompt_studio_v2.models import ToolStudioPrompt +from prompt_studio.prompt_studio_v2.serializers import ReorderPromptsSerializer +from rest_framework import status +from rest_framework.request import Request +from rest_framework.response import Response + +logger = logging.getLogger(__name__) + + +class PromptStudioController: + def reorder_prompts(self, request: Request) -> Response: + """Reorder the sequence of prompts based on the start and end sequence + numbers. 
+ + This action handles the reordering of prompts by updating their sequence + numbers. It increments or decrements the sequence numbers of the relevant + prompts to reflect the new order. If the start and end sequence numbers + are equal, it returns a bad request response. + + Args: + request (Request): The HTTP request object containing the data to + reorder prompts. + + Returns: + Response: A Response object with the status of the reordering operation. + """ + try: + # Validate request data + serializer = ReorderPromptsSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + + # Extract validated data from the serializer + start_sequence_number = serializer.validated_data.get( + ToolStudioPromptKeys.START_SEQUENCE_NUMBER + ) + end_sequence_number = serializer.validated_data.get( + ToolStudioPromptKeys.END_SEQUENCE_NUMBER + ) + prompt_id = serializer.validated_data.get(ToolStudioPromptKeys.PROMPT_ID) + + filtered_prompts_data = PromptStudioHelper.reorder_prompts_helper( + prompt_id=prompt_id, + start_sequence_number=start_sequence_number, + end_sequence_number=end_sequence_number, + ) + + logger.info("Re-ordering completed successfully.") + return Response(status=status.HTTP_200_OK, data=filtered_prompts_data) + + except ToolStudioPrompt.DoesNotExist: + logger.error(f"Prompt with ID {prompt_id} not found.") + return Response( + status=status.HTTP_404_NOT_FOUND, data={"detail": "Prompt not found."} + ) diff --git a/backend/prompt_studio/prompt_studio_v2/exceptions.py b/backend/prompt_studio/prompt_studio_v2/exceptions.py new file mode 100644 index 000000000..c78c3a740 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_v2/exceptions.py @@ -0,0 +1,16 @@ +from rest_framework.exceptions import APIException + + +class IndexingError(APIException): + status_code = 400 + default_detail = "Error while indexing file" + + +class AnswerFetchError(APIException): + status_code = 400 + default_detail = "Error occured while fetching response for the prompt" + + +class ToolNotValid(APIException): + status_code = 400 + default_detail = "Custom tool is not valid." diff --git a/backend/prompt_studio/prompt_studio_v2/helper.py b/backend/prompt_studio/prompt_studio_v2/helper.py new file mode 100644 index 000000000..8491332a4 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_v2/helper.py @@ -0,0 +1,111 @@ +import logging + +from prompt_studio.prompt_studio_v2.models import ToolStudioPrompt + +logger = logging.getLogger(__name__) + + +class PromptStudioHelper: + @staticmethod + def reorder_prompts_helper( + prompt_id: str, start_sequence_number: int, end_sequence_number: int + ) -> list[dict[str, int]]: + """Helper method to reorder prompts based on sequence numbers. + + Args: + prompt_id (str): The ID of the prompt to be reordered. + start_sequence_number (int): The initial sequence number of the prompt. + end_sequence_number (int): The new sequence number of the prompt. + + Returns: + list[dict[str, int]]: A list of updated prompt data with their IDs + and new sequence numbers. + """ + prompt_instance: ToolStudioPrompt = ToolStudioPrompt.objects.get(pk=prompt_id) + filtered_prompts_data = [] + tool_id = prompt_instance.tool_id + + # Determine the direction of sequence adjustment based on start + # and end sequence numbers + if start_sequence_number < end_sequence_number: + logger.info( + "Start sequence number is less than end sequence number. " + "Decrementing sequence numbers." 
+ ) + filters = { + "sequence_number__gt": start_sequence_number, + "sequence_number__lte": end_sequence_number, + "tool_id": tool_id, + } + increment = False + + elif start_sequence_number > end_sequence_number: + logger.info( + "Start sequence number is greater than end sequence number. " + "Incrementing sequence numbers." + ) + filters = { + "sequence_number__lt": start_sequence_number, + "sequence_number__gte": end_sequence_number, + "tool_id": tool_id, + } + increment = True + + # Call helper method to update sequence numbers and get filtered prompt data + filtered_prompts_data = PromptStudioHelper.update_sequence_numbers( + filters, increment + ) + + # Update the sequence number of the moved prompt + prompt_instance.sequence_number = end_sequence_number + prompt_instance.save() + + # Append the updated prompt instance data to the response + filtered_prompts_data.append( + { + "id": prompt_instance.prompt_id, + "sequence_number": prompt_instance.sequence_number, + } + ) + + return filtered_prompts_data + + @staticmethod + def update_sequence_numbers(filters: dict, increment: bool) -> list[dict[str, int]]: + """Update the sequence numbers for prompts based on the provided + filters and increment flag. + + Args: + filters (dict): The filter criteria for selecting prompts. + increment (bool): Whether to increment (True) or decrement (False) + the sequence numbers. + + Returns: + list[dict[str, int]]: A list of updated prompt data with their IDs + and new sequence numbers. + """ + # Filter prompts based on the provided filters + filtered_prompts = ToolStudioPrompt.objects.filter(**filters) + + # List to hold updated prompt data + filtered_prompts_data = [] + + # Prepare updates and collect data + for prompt in filtered_prompts: + if increment: + prompt.sequence_number += 1 + else: + prompt.sequence_number -= 1 + + # Append prompt data to the list + filtered_prompts_data.append( + { + "id": prompt.prompt_id, + "sequence_number": prompt.sequence_number, + } + ) + + # Bulk update the sequence numbers + ToolStudioPrompt.objects.bulk_update(filtered_prompts, ["sequence_number"]) + + return filtered_prompts_data diff --git a/backend/prompt_studio/prompt_studio_v2/models.py b/backend/prompt_studio/prompt_studio_v2/models.py new file mode 100644 index 000000000..ba928f7c5 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_v2/models.py @@ -0,0 +1,127 @@ +import uuid + +from account_v2.models import User +from django.db import models +from prompt_studio.prompt_profile_manager_v2.models import ProfileManager +from prompt_studio.prompt_studio_core_v2.models import CustomTool +from utils.models.base_model import BaseModel + + +class ToolStudioPrompt(BaseModel): + """Model class while store Prompt data for Custom Tool Studio. + + It has Many to one relation with CustomTool for ToolStudio. 
+ """ + + class EnforceType(models.TextChoices): + TEXT = "Text", "Response sent as Text" + NUMBER = "number", "Response sent as number" + EMAIL = "email", "Response sent as email" + DATE = "date", "Response sent as date" + BOOLEAN = "boolean", "Response sent as boolean" + JSON = "json", "Response sent as json" + + class PromptType(models.TextChoices): + PROMPT = "PROMPT", "Response sent as Text" + NOTES = "NOTES", "Response sent as float" + + class Mode(models.TextChoices): + DEFAULT = "Default", "Default choice for output" + + prompt_id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + prompt_key = models.TextField( + blank=False, + db_comment="Field to store the prompt key", + ) + enforce_type = models.TextField( + blank=True, + db_comment="Field to store the type in \ + which the response to be returned.", + choices=EnforceType.choices, + default=EnforceType.TEXT, + ) + prompt = models.TextField( + blank=True, db_comment="Field to store the prompt", unique=False + ) + tool_id = models.ForeignKey( + CustomTool, + on_delete=models.CASCADE, + related_name="tool_studio_prompts", + null=True, + blank=True, + ) + sequence_number = models.IntegerField(null=True, blank=True) + prompt_type = models.TextField( + blank=True, + db_comment="Field to store the type of the input prompt", + choices=PromptType.choices, + ) + profile_manager = models.ForeignKey( + ProfileManager, + on_delete=models.SET_NULL, + related_name="tool_studio_prompts", + null=True, + blank=True, + ) + output = models.TextField(blank=True) + # TODO: Remove below 3 fields related to assertion + assert_prompt = models.TextField( + blank=True, + null=True, + db_comment="Field to store the asserted prompt", + unique=False, + ) + assertion_failure_prompt = models.TextField( + blank=True, + null=True, + db_comment="Field to store the prompt key", + unique=False, + ) + is_assert = models.BooleanField(default=False) + active = models.BooleanField(default=True, null=False, blank=False) + output_metadata = models.JSONField( + db_column="output_metadata", + null=False, + blank=False, + default=dict, + db_comment="JSON adapter metadata for the FE to load the pagination", + ) + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="tool_studio_prompts_created", + null=True, + blank=True, + editable=False, + ) + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="tool_studio_prompts_modified", + null=True, + blank=True, + editable=False, + ) + # Eval settings for the prompt + # NOTE: + # - Field name format is eval__ + # - Metric name alone should be UNIQUE across all eval metrics + evaluate = models.BooleanField(default=True) + eval_quality_faithfulness = models.BooleanField(default=True) + eval_quality_correctness = models.BooleanField(default=True) + eval_quality_relevance = models.BooleanField(default=True) + eval_security_pii = models.BooleanField(default=True) + eval_guidance_toxicity = models.BooleanField(default=True) + eval_guidance_completeness = models.BooleanField(default=True) + # + + class Meta: + verbose_name = "Tool Studio Prompt" + verbose_name_plural = "Tool Studio Prompts" + db_table = "tool_studio_prompt_v2" + constraints = [ + models.UniqueConstraint( + fields=["prompt_key", "tool_id"], + name="unique_prompt_key_tool_id_index", + ), + ] diff --git a/backend/prompt_studio/prompt_studio_v2/serializers.py b/backend/prompt_studio/prompt_studio_v2/serializers.py new file mode 100644 index 000000000..e1adddc33 --- /dev/null +++ 
b/backend/prompt_studio/prompt_studio_v2/serializers.py @@ -0,0 +1,33 @@ +from rest_framework import serializers + +from backend.serializers import AuditSerializer + +from .models import ToolStudioPrompt + + +class ToolStudioPromptSerializer(AuditSerializer): + class Meta: + model = ToolStudioPrompt + fields = "__all__" + + +class ToolStudioIndexSerializer(serializers.Serializer): + file_name = serializers.CharField() + tool_id = serializers.CharField() + + +class ReorderPromptsSerializer(serializers.Serializer): + start_sequence_number = serializers.IntegerField(required=True) + end_sequence_number = serializers.IntegerField(required=True) + prompt_id = serializers.CharField(required=True) + + def validate(self, data): + start_sequence_number = data.get("start_sequence_number") + end_sequence_number = data.get("end_sequence_number") + + if start_sequence_number == end_sequence_number: + raise serializers.ValidationError( + "Start and end sequence numbers cannot be the same." + ) + + return data diff --git a/backend/prompt_studio/prompt_studio_v2/urls.py b/backend/prompt_studio/prompt_studio_v2/urls.py new file mode 100644 index 000000000..3d8c3d74c --- /dev/null +++ b/backend/prompt_studio/prompt_studio_v2/urls.py @@ -0,0 +1,30 @@ +from django.urls import path +from rest_framework.urlpatterns import format_suffix_patterns + +from .views import ToolStudioPromptView + +prompt_studio_prompt_detail = ToolStudioPromptView.as_view( + { + "get": "retrieve", + "put": "update", + "patch": "partial_update", + "delete": "destroy", + } +) + +reorder_prompts = ToolStudioPromptView.as_view({"post": "reorder_prompts"}) + +urlpatterns = format_suffix_patterns( + [ + path( + "prompt//", + prompt_studio_prompt_detail, + name="tool-studio-prompt-detail", + ), + path( + "prompt/reorder", + reorder_prompts, + name="reorder_prompts", + ), + ] +) diff --git a/backend/prompt_studio/prompt_studio_v2/views.py b/backend/prompt_studio/prompt_studio_v2/views.py new file mode 100644 index 000000000..ccf63aa0d --- /dev/null +++ b/backend/prompt_studio/prompt_studio_v2/views.py @@ -0,0 +1,56 @@ +from typing import Optional + +from django.db.models import QuerySet +from prompt_studio.permission import PromptAcesssToUser +from prompt_studio.prompt_studio_v2.constants import ToolStudioPromptKeys +from prompt_studio.prompt_studio_v2.controller import PromptStudioController +from prompt_studio.prompt_studio_v2.models import ToolStudioPrompt +from prompt_studio.prompt_studio_v2.serializers import ToolStudioPromptSerializer +from rest_framework import viewsets +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.versioning import URLPathVersioning +from utils.filtering import FilterHelper + + +class ToolStudioPromptView(viewsets.ModelViewSet): + """Viewset to handle all Tool Studio prompt related API logics. 
+ + Args: + viewsets (_type_) + + Raises: + DuplicateData + FilenameMissingError + IndexingError + ValidationError + """ + + versioning_class = URLPathVersioning + serializer_class = ToolStudioPromptSerializer + permission_classes: list[type[PromptAcesssToUser]] = [PromptAcesssToUser] + + def get_queryset(self) -> Optional[QuerySet]: + filter_args = FilterHelper.build_filter_args( + self.request, + ToolStudioPromptKeys.TOOL_ID, + ) + if filter_args: + queryset = ToolStudioPrompt.objects.filter(**filter_args) + else: + queryset = ToolStudioPrompt.objects.all() + return queryset + + @action(detail=True, methods=["post"]) + def reorder_prompts(self, request: Request) -> Response: + """Reorder the sequence of prompts based on the provided data. + + Args: + request (Request): The HTTP request containing the reorder data. + + Returns: + Response: The HTTP response indicating the status of the reorder operation. + """ + prompt_studio_controller = PromptStudioController() + return prompt_studio_controller.reorder_prompts(request) diff --git a/backend/scheduler/helper.py b/backend/scheduler/helper.py index 30c479b61..7cc30e07a 100644 --- a/backend/scheduler/helper.py +++ b/backend/scheduler/helper.py @@ -3,8 +3,6 @@ from typing import Any from django.db import connection -from pipeline.models import Pipeline -from pipeline.pipeline_processor import PipelineProcessor from rest_framework.serializers import ValidationError from scheduler.constants import SchedulerConstants as SC from scheduler.exceptions import JobDeletionError, JobSchedulingError @@ -15,9 +13,21 @@ disable_task, enable_task, ) -from workflow_manager.workflow.constants import WorkflowExecutionKey, WorkflowKey -from workflow_manager.workflow.serializers import ExecuteWorkflowSerializer +from backend.constants import FeatureFlag +from unstract.flags.feature_flag import check_feature_flag_status + +if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): + from pipeline_v2.models import Pipeline + from pipeline_v2.pipeline_processor import PipelineProcessor + from utils.user_context import UserContext + from workflow_manager.workflow_v2.constants import WorkflowExecutionKey, WorkflowKey + from workflow_manager.workflow_v2.serializers import ExecuteWorkflowSerializer +else: + from pipeline.models import Pipeline + from pipeline.pipeline_processor import PipelineProcessor + from workflow_manager.workflow.constants import WorkflowExecutionKey, WorkflowKey + from workflow_manager.workflow.serializers import ExecuteWorkflowSerializer logger = logging.getLogger(__name__) @@ -46,7 +56,10 @@ def _schedule_task_job(pipeline_id: str, job_data: Any) -> None: workflow_id = serializer.get_workflow_id(serializer.validated_data) # TODO: Remove unused argument in execute_pipeline_task execution_action = serializer.get_execution_action(serializer.validated_data) - org_schema = connection.tenant.schema_name + if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): + organization_id = UserContext.get_organization_identifier() + else: + organization_id = connection.tenant.schema_name create_periodic_task( cron_string=cron_string, @@ -54,7 +67,7 @@ def _schedule_task_job(pipeline_id: str, job_data: Any) -> None: task_path="scheduler.tasks.execute_pipeline_task", task_args=[ str(workflow_id), - org_schema, + organization_id, execution_action or "", execution_id, str(pipeline.pk), diff --git a/backend/scheduler/serializer.py b/backend/scheduler/serializer.py index ef19de11f..65c2c1910 100644 --- a/backend/scheduler/serializer.py +++ 
b/backend/scheduler/serializer.py @@ -1,11 +1,17 @@ import logging from typing import Any -from pipeline.manager import PipelineManager from rest_framework import serializers from scheduler.constants import SchedulerConstants as SC +from backend.constants import FeatureFlag from backend.constants import FieldLengthConstants as FieldLength +from unstract.flags.feature_flag import check_feature_flag_status + +if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): + from pipeline_v2.manager import PipelineManager +else: + from pipeline.manager import PipelineManager logger = logging.getLogger(__name__) diff --git a/backend/scheduler/tasks.py b/backend/scheduler/tasks.py index dbc6f6101..abbc7cf91 100644 --- a/backend/scheduler/tasks.py +++ b/backend/scheduler/tasks.py @@ -2,15 +2,28 @@ import logging from typing import Any -from account.models import Organization -from account.subscription_loader import load_plugins, validate_etl_run from celery import shared_task from django_celery_beat.models import CrontabSchedule, PeriodicTask from django_tenants.utils import get_tenant_model, tenant_context -from pipeline.models import Pipeline -from pipeline.pipeline_processor import PipelineProcessor -from workflow_manager.workflow.models.workflow import Workflow -from workflow_manager.workflow.workflow_helper import WorkflowHelper + +from backend.constants import FeatureFlag +from unstract.flags.feature_flag import check_feature_flag_status + +if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): + from account_v2.subscription_loader import load_plugins, validate_etl_run + from pipeline_v2.models import Pipeline + from pipeline_v2.pipeline_processor import PipelineProcessor + from utils.user_context import UserContext + from workflow_manager.workflow_v2.models.workflow import Workflow + from workflow_manager.workflow_v2.workflow_helper import WorkflowHelper +else: + from account.models import Organization + from account.subscription_loader import load_plugins, validate_etl_run + from pipeline.models import Pipeline + from pipeline.pipeline_processor import PipelineProcessor + from workflow_manager.workflow.models.workflow import Workflow + from workflow_manager.workflow.workflow_helper import WorkflowHelper + logger = logging.getLogger(__name__) subscription_loader = load_plugins() @@ -58,6 +71,15 @@ def execute_pipeline_task( with_logs: Any, name: Any, ) -> None: + if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): + execute_pipeline_task_v2( + workflow_id=workflow_id, + organization_id=org_schema, + execution_id=execution_id, + pipeline_id=pipepline_id, + pipeline_name=name, + ) + return logger.info(f"Executing pipeline name: {name}") try: logger.info(f"Executing workflow id: {workflow_id}") @@ -92,6 +114,61 @@ def execute_pipeline_task( logger.error(f"Failed to execute pipeline: {name}. Error: {e}") +def execute_pipeline_task_v2( + workflow_id: Any, + organization_id: Any, + execution_id: Any, + pipeline_id: Any, + pipeline_name: Any, +) -> None: + """V2 of execute_pipeline method. 
+ + Args: + workflow_id (Any): UID of workflow entity + organization_id (Any): Organization Identifier + execution_id (Any): UID of execution entity + pipeline_id (Any): UID of pipeline entity + pipeline_name (Any): Pipeline name + """ + try: + logger.info( + f"Executing workflow id: {workflow_id} for pipeline {pipeline_name}" + ) + # Set organization in state store for execution + UserContext.set_organization_identifier(organization_id) + if ( + subscription_loader + and subscription_loader[0] + and not validate_etl_run(organization_id) + ): + try: + logger.info(f"Disabling ETL task: {pipeline_id}") + disable_task(pipeline_id) + except Exception as e: + logger.warning(f"Failed to disable task: {pipeline_id}. Error: {e}") + return + workflow = WorkflowHelper.get_workflow_by_id( + id=workflow_id, organization_id=organization_id + ) + logger.info(f"Executing workflow: {workflow}") + PipelineProcessor.update_pipeline( + pipeline_id, Pipeline.PipelineStatus.INPROGRESS + ) + execution_response = WorkflowHelper.complete_execution( + workflow, execution_id, pipeline_id + ) + logger.info( + f"Execution response for pipeline {pipeline_name} of organization " + f"{organization_id}: {execution_response}" + ) + logger.info( + f"Execution completed for pipeline {pipeline_name} of organization: " + f"{organization_id}" + ) + except Exception as e: + logger.error(f"Failed to execute pipeline: {pipeline_name}. Error: {e}") + + +def delete_periodic_task(task_name: str) -> None: + try: + task = PeriodicTask.objects.get(name=task_name) diff --git a/backend/tenant_account_v2/__init__.py b/backend/tenant_account_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tenant_account_v2/admin.py b/backend/tenant_account_v2/admin.py new file mode 100644 index 000000000..846f6b406 --- /dev/null +++ b/backend/tenant_account_v2/admin.py @@ -0,0 +1 @@ +# Register your models here. diff --git a/backend/tenant_account_v2/apps.py b/backend/tenant_account_v2/apps.py new file mode 100644 index 000000000..cd128f028 --- /dev/null +++ b/backend/tenant_account_v2/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class TenantAccountV2Config(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "tenant_account_v2" diff --git a/backend/tenant_account_v2/constants.py b/backend/tenant_account_v2/constants.py new file mode 100644 index 000000000..29ba63684 --- /dev/null +++ b/backend/tenant_account_v2/constants.py @@ -0,0 +1,14 @@ +class PlatformServiceConstants: + IS_ACTIVE = "is_active" + KEY = "key" + ORGANIZATION = "organization" + ID = "id" + ACTIVATE = "ACTIVATE" + DEACTIVATE = "DEACTIVATE" + ACTION = "action" + KEY_NAME = "key_name" + + +class ErrorMessage: + KEY_EXIST = "Key name already exists" + DUPLICATE_API = "It appears that a duplicate call may have been made."
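Note on the scheduler changes above: they all follow one migration pattern, where imports and the organization identifier are switched on the MULTI_TENANCY_V2 feature flag through check_feature_flag_status, and the v2 path resolves the organization from UserContext rather than the django-tenants schema. The sketch below is illustrative only and is not part of this diff; the helper name resolve_organization_identifier is hypothetical, while the imports and the two lookups mirror what scheduler/helper.py and scheduler/tasks.py do.

import logging

from django.db import connection

from backend.constants import FeatureFlag
from unstract.flags.feature_flag import check_feature_flag_status

# Imports gated on the feature flag, mirroring scheduler/helper.py,
# scheduler/serializer.py and scheduler/tasks.py in this diff.
if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2):
    from utils.user_context import UserContext

logger = logging.getLogger(__name__)


def resolve_organization_identifier() -> str:
    # Hypothetical helper (not in the diff): picks the organization identifier
    # the same way _schedule_task_job does, depending on the active flag.
    if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2):
        # V2 reads the identifier from the user context set for the request.
        return UserContext.get_organization_identifier()
    # Legacy path relies on the django-tenants connection's schema name.
    return connection.tenant.schema_name

Whichever branch runs, the resulting identifier is what gets passed as the second element of task_args to scheduler.tasks.execute_pipeline_task, which forwards it to execute_pipeline_task_v2 as organization_id when the flag is enabled.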
diff --git a/backend/tenant_account_v2/dto.py b/backend/tenant_account_v2/dto.py new file mode 100644 index 000000000..16c37ea58 --- /dev/null +++ b/backend/tenant_account_v2/dto.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass + + +@dataclass +class OrganizationLoginResponse: + name: str + display_name: str + organization_id: str + created_at: str + + +@dataclass +class ResetUserPasswordDto: + status: bool + message: str diff --git a/backend/tenant_account_v2/enums.py b/backend/tenant_account_v2/enums.py new file mode 100644 index 000000000..d8209ec2d --- /dev/null +++ b/backend/tenant_account_v2/enums.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class UserRole(Enum): + USER = "user" + ADMIN = "admin" diff --git a/backend/tenant_account_v2/invitation_urls.py b/backend/tenant_account_v2/invitation_urls.py new file mode 100644 index 000000000..aed1f37a1 --- /dev/null +++ b/backend/tenant_account_v2/invitation_urls.py @@ -0,0 +1,20 @@ +from django.urls import path +from tenant_account_v2.invitation_views import InvitationViewSet + +invitation_list = InvitationViewSet.as_view( + { + "get": InvitationViewSet.list_invitations.__name__, + } +) + +invitation_details = InvitationViewSet.as_view( + { + "delete": InvitationViewSet.delete_invitation.__name__, + } +) + + +urlpatterns = [ + path("", invitation_list, name="invitation_list"), + path("/", invitation_details, name="invitation_details"), +] diff --git a/backend/tenant_account_v2/invitation_views.py b/backend/tenant_account_v2/invitation_views.py new file mode 100644 index 000000000..adbd72695 --- /dev/null +++ b/backend/tenant_account_v2/invitation_views.py @@ -0,0 +1,46 @@ +import logging + +from account_v2.authentication_controller import AuthenticationController +from account_v2.dto import MemberInvitation +from rest_framework import status, viewsets +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.response import Response +from tenant_account_v2.serializer import ListInvitationsResponseSerializer +from utils.user_session import UserSessionUtils + +Logger = logging.getLogger(__name__) + + +class InvitationViewSet(viewsets.ViewSet): + @action(detail=False, methods=["GET"]) + def list_invitations(self, request: Request) -> Response: + auth_controller = AuthenticationController() + invitations: list[MemberInvitation] = auth_controller.get_user_invitations( + organization_id=UserSessionUtils.get_organization_id(request), + ) + serialized_members = ListInvitationsResponseSerializer( + invitations, many=True + ).data + return Response( + status=status.HTTP_200_OK, + data={"message": "success", "members": serialized_members}, + ) + + @action(detail=False, methods=["DELETE"]) + def delete_invitation(self, request: Request, id: str) -> Response: + auth_controller = AuthenticationController() + is_deleted: bool = auth_controller.delete_user_invitation( + organization_id=UserSessionUtils.get_organization_id(request), + invitation_id=id, + ) + if is_deleted: + return Response( + status=status.HTTP_204_NO_CONTENT, + data={"status": "success", "message": "success"}, + ) + else: + return Response( + status=status.HTTP_404_NOT_FOUND, + data={"status": "failed", "message": "failed"}, + ) diff --git a/backend/tenant_account_v2/models.py b/backend/tenant_account_v2/models.py new file mode 100644 index 000000000..c2c5b3cdc --- /dev/null +++ b/backend/tenant_account_v2/models.py @@ -0,0 +1,48 @@ +from account_v2.models import User +from django.db import models +from 
utils.models.organization_mixin import ( + DefaultOrganizationManagerMixin, + DefaultOrganizationMixin, +) + + +class OrganizationMemberModelManager(DefaultOrganizationManagerMixin, models.Manager): + pass + + +class OrganizationMember(DefaultOrganizationMixin): + member_id = models.BigAutoField(primary_key=True) + user = models.ForeignKey( + User, + on_delete=models.CASCADE, + default=None, + related_name="organization_user", + ) + role = models.CharField() + is_login_onboarding_msg = models.BooleanField( + default=True, + db_comment="Flag to indicate whether the onboarding messages are shown", + ) + is_prompt_studio_onboarding_msg = models.BooleanField( + default=True, + db_comment="Flag to indicate whether the prompt studio messages are shown", + ) + + def __str__(self): # type: ignore + return ( + f"OrganizationMember(" + f"{self.member_id}, role: {self.role}, userId: {self.user.user_id})" + ) + + objects = OrganizationMemberModelManager() + + class Meta: + db_table = "organization_member_v2" + verbose_name = "Organization Member" + verbose_name_plural = "Organization Members" + constraints = [ + models.UniqueConstraint( + fields=["organization", "user"], + name="unique_organization_member", + ), + ] diff --git a/backend/tenant_account_v2/organization_member_service.py b/backend/tenant_account_v2/organization_member_service.py new file mode 100644 index 000000000..ae9dd085c --- /dev/null +++ b/backend/tenant_account_v2/organization_member_service.py @@ -0,0 +1,144 @@ +from typing import Any, Optional + +from tenant_account_v2.models import OrganizationMember +from utils.cache_service import CacheService + + +class OrganizationMemberService: + + @staticmethod + def get_user_by_email(email: str) -> Optional[OrganizationMember]: + try: + return OrganizationMember.objects.get(user__email=email) # type: ignore + except OrganizationMember.DoesNotExist: + return None + + @staticmethod + def get_user_by_user_id(user_id: str) -> Optional[OrganizationMember]: + try: + return OrganizationMember.objects.get(user__user_id=user_id) # type: ignore + except OrganizationMember.DoesNotExist: + return None + + @staticmethod + def get_user_by_id(id: str) -> Optional[OrganizationMember]: + try: + return OrganizationMember.objects.get(user=id) # type: ignore + except OrganizationMember.DoesNotExist: + return None + + @staticmethod + def get_members() -> list[OrganizationMember]: + return OrganizationMember.objects.all() + + @staticmethod + def get_members_by_user_email( + user_emails: list[str], values_list_fields: list[str] + ) -> list[dict[str, Any]]: + """Get members by user emails. + + Parameters: + user_emails (list[str]): The emails of the users to get. + values_list_fields (list[str]): The fields to include in the result. + + Returns: + list[dict[str, Any]]: The members. + """ + if not user_emails: + return [] + queryset = OrganizationMember.objects.filter(user__email__in=user_emails) + if values_list_fields is None: + users = queryset.values() + else: + users = queryset.values_list(*values_list_fields) + + return list(users) + + @staticmethod + def delete_user(user: OrganizationMember) -> None: + """Delete a user from an organization. + + Parameters: + user (OrganizationMember): The user to delete. + """ + user.delete() + + @staticmethod + def remove_users_by_user_pks(user_pks: list[str]) -> None: + """Remove a users from an organization. + + Parameters: + user_pks (list[str]): The primary keys of the users to remove. 
+ """ + OrganizationMember.objects.filter(user__in=user_pks).delete() + + @classmethod + def remove_user_by_user_id(cls, user_id: str) -> None: + """Remove a user from an organization. + + Parameters: + user_id (str): The user_id of the user to remove. + """ + user = cls.get_user_by_user_id(user_id) + if user: + cls.delete_user(user) + + @staticmethod + def get_organization_user_cache_key(user_id: str, organization_id: str) -> str: + """Get the cache key for a user in an organization. + + Parameters: + organization_id (str): The ID of the organization. + + Returns: + str: The cache key for a user in the organization. + """ + return f"user_organization:{user_id}:{organization_id}" + + @classmethod + def check_user_membership_in_organization_cache( + cls, user_id: str, organization_id: str + ) -> bool: + """Check if a user exists in an organization. + + Parameters: + user_id (str): The ID of the user to check. + organization_id (str): The ID of the organization to check. + + Returns: + bool: True if the user exists in the organization, False otherwise. + """ + user_organization_key = cls.get_organization_user_cache_key( + user_id, organization_id + ) + return CacheService.check_a_key_exist(user_organization_key) + + @classmethod + def set_user_membership_in_organization_cache( + cls, user_id: str, organization_id: str + ) -> None: + """Set a user's membership in an organization in the cache. + + Parameters: + user_id (str): The ID of the user. + organization_id (str): The ID of the organization. + """ + user_organization_key = cls.get_organization_user_cache_key( + user_id, organization_id + ) + CacheService.set_key(user_organization_key, {}) + + @classmethod + def remove_user_membership_in_organization_cache( + cls, user_id: str, organization_id: str + ) -> None: + """Remove a user's membership in an organization from the cache. + + Parameters: + user_id (str): The ID of the user. + organization_id (str): The ID of the organization. 
+ """ + user_organization_key = cls.get_organization_user_cache_key( + user_id, organization_id + ) + CacheService.delete_a_key(user_organization_key) diff --git a/backend/tenant_account_v2/serializer.py b/backend/tenant_account_v2/serializer.py new file mode 100644 index 000000000..32165ef19 --- /dev/null +++ b/backend/tenant_account_v2/serializer.py @@ -0,0 +1,159 @@ +from collections import OrderedDict +from typing import Any, Optional, Union, cast + +from account_v2.constants import Common +from rest_framework import serializers +from rest_framework.exceptions import ValidationError +from rest_framework.serializers import ModelSerializer +from tenant_account_v2.models import OrganizationMember + + +class OrganizationCallbackSerializer(serializers.Serializer): + id = serializers.CharField(required=False) + + +class OrganizationLoginResponseSerializer(serializers.Serializer): + name = serializers.CharField() + display_name = serializers.CharField() + organization_id = serializers.CharField() + created_at = serializers.CharField() + + +class UserInviteResponseSerializer(serializers.Serializer): + email = serializers.CharField(required=True) + status = serializers.CharField(required=True) + message = serializers.CharField(required=False) + + +class OrganizationMemberSerializer(serializers.ModelSerializer): + email = serializers.CharField(source="user.email", read_only=True) + id = serializers.CharField(source="user.id", read_only=True) + + class Meta: + model = OrganizationMember + fields = ("id", "email", "role") + + +class LimitedUserEmailListSerializer(serializers.ListSerializer): + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.max_elements: int = kwargs.pop("max_elements", Common.MAX_EMAIL_IN_REQUEST) + super().__init__(*args, **kwargs) + + def validate(self, data: list[str]) -> Any: + if len(data) > self.max_elements: + raise ValidationError( + f"Exceeded maximum number of elements ({self.max_elements})" + ) + return data + + +class LimitedUserListSerializer(serializers.ListSerializer): + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.max_elements: int = kwargs.pop("max_elements", Common.MAX_EMAIL_IN_REQUEST) + super().__init__(*args, **kwargs) + + def validate( + self, data: list[dict[str, Union[str, None]]] + ) -> list[dict[str, Union[str, None]]]: + if len(data) > self.max_elements: + raise ValidationError( + f"Exceeded maximum number of elements ({self.max_elements})" + ) + + for item in data: + if not isinstance(item, dict): + raise ValidationError("Each item in the list must be a dictionary.") + if "email" not in item: + raise ValidationError("Each item in the list must have 'email' key.") + if "role" not in item: + item["role"] = None + + return data + + +class InviteUserSerializer(serializers.Serializer): + users = LimitedUserListSerializer( + required=True, + child=serializers.DictField( + child=serializers.CharField(max_length=255, required=True), + required=False, # Make 'role' field optional + ), + max_elements=Common.MAX_EMAIL_IN_REQUEST, + ) + + def get_users( + self, validated_data: dict[str, Any] + ) -> list[dict[str, Union[str, None]]]: + return validated_data.get("users", []) + + +class RemoveUserFromOrganizationSerializer(serializers.Serializer): + emails = LimitedUserEmailListSerializer( + required=True, + child=serializers.EmailField(required=True), + max_elements=Common.MAX_EMAIL_IN_REQUEST, + ) + + def get_user_emails( + self, validated_data: dict[str, Union[list[str], None]] + ) -> list[str]: + return cast(list[str], 
validated_data.get(Common.USER_EMAILS, [])) + + +class ChangeUserRoleRequestSerializer(serializers.Serializer): + email = serializers.EmailField(required=True) + role = serializers.CharField(required=True) + + def get_user_email( + self, validated_data: dict[str, Union[str, None]] + ) -> Optional[str]: + return validated_data.get(Common.USER_EMAIL) + + def get_user_role( + self, validated_data: dict[str, Union[str, None]] + ) -> Optional[str]: + return validated_data.get(Common.USER_ROLE) + + +class DeleteInvitationRequestSerializer(serializers.Serializer): + id = serializers.EmailField(required=True) + + def get_id(self, validated_data: dict[str, Union[str, None]]) -> Optional[str]: + return validated_data.get(Common.ID) + + +class UserInfoSerializer(serializers.Serializer): + id = serializers.CharField() + email = serializers.CharField() + name = serializers.CharField() + display_name = serializers.CharField() + family_name = serializers.CharField() + picture = serializers.CharField() + + +class GetRolesResponseSerializer(serializers.Serializer): + id = serializers.CharField() + name = serializers.CharField() + description = serializers.CharField() + + def to_representation(self, instance: Any) -> OrderedDict[str, Any]: + data: OrderedDict[str, Any] = super().to_representation(instance) + return data + + +class ListInvitationsResponseSerializer(serializers.Serializer): + id = serializers.CharField() + email = serializers.CharField() + created_at = serializers.CharField() + expires_at = serializers.CharField() + + def to_representation(self, instance: Any) -> OrderedDict[str, Any]: + data: OrderedDict[str, Any] = super().to_representation(instance) + return data + + +class UpdateFlagSerializer(ModelSerializer): + + class Meta: + model = OrganizationMember + fields = ("is_login_onboarding_msg", "is_prompt_studio_onboarding_msg") diff --git a/backend/tenant_account_v2/tests.py b/backend/tenant_account_v2/tests.py new file mode 100644 index 000000000..a39b155ac --- /dev/null +++ b/backend/tenant_account_v2/tests.py @@ -0,0 +1 @@ +# Create your tests here. 
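For reference, `InviteUserSerializer` above accepts a `users` list of dicts in which `email` is mandatory, `role` is optional (filled in as `None`), and the list length is capped at `Common.MAX_EMAIL_IN_REQUEST`. A minimal usage sketch, not part of this diff; the payload values are illustrative and a configured Django/DRF environment (e.g. a test case or shell) is assumed:

```python
from tenant_account_v2.serializer import InviteUserSerializer

# Illustrative payload; emails and roles are made-up values.
payload = {
    "users": [
        {"email": "admin@example.com", "role": "admin"},
        {"email": "viewer@example.com"},  # "role" omitted, defaulted to None
    ]
}

serializer = InviteUserSerializer(data=payload)
# Raises ValidationError if an item lacks "email" or the list exceeds
# Common.MAX_EMAIL_IN_REQUEST elements.
serializer.is_valid(raise_exception=True)
users = serializer.get_users(serializer.validated_data)
# users == [{"email": "admin@example.com", "role": "admin"},
#           {"email": "viewer@example.com", "role": None}]
```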
diff --git a/backend/tenant_account_v2/urls.py b/backend/tenant_account_v2/urls.py new file mode 100644 index 000000000..50e46ab08 --- /dev/null +++ b/backend/tenant_account_v2/urls.py @@ -0,0 +1,11 @@ +from django.urls import include, path +from tenant_account_v2 import invitation_urls, users_urls +from tenant_account_v2.views import get_organization, get_roles, reset_password + +urlpatterns = [ + path("roles", get_roles, name="roles"), + path("users/", include(users_urls)), + path("invitation/", include(invitation_urls)), + path("organization", get_organization, name="get_organization"), + path("reset_password", reset_password, name="reset_password"), +] diff --git a/backend/tenant_account_v2/users_urls.py b/backend/tenant_account_v2/users_urls.py new file mode 100644 index 000000000..7b2765f3f --- /dev/null +++ b/backend/tenant_account_v2/users_urls.py @@ -0,0 +1,37 @@ +from django.urls import path +from tenant_account_v2.users_view import OrganizationUserViewSet + +organization_user_role = OrganizationUserViewSet.as_view( + { + "post": OrganizationUserViewSet.assign_organization_role_to_user.__name__, + "delete": OrganizationUserViewSet.remove_organization_role_from_user.__name__, + } +) + +user_profile = OrganizationUserViewSet.as_view( + { + "get": OrganizationUserViewSet.get_user_profile.__name__, + "put": OrganizationUserViewSet.update_flags.__name__, + } +) + +invite_user = OrganizationUserViewSet.as_view( + { + "post": OrganizationUserViewSet.invite_user.__name__, + } +) + +organization_users = OrganizationUserViewSet.as_view( + { + "get": OrganizationUserViewSet.get_organization_members.__name__, + "delete": OrganizationUserViewSet.remove_members_from_organization.__name__, + } +) + + +urlpatterns = [ + path("", organization_users, name="organization_user"), + path("profile/", user_profile, name="user_profile"), + path("role/", organization_user_role, name="organization_user_role"), + path("invite/", invite_user, name="invite_user"), +] diff --git a/backend/tenant_account_v2/users_view.py b/backend/tenant_account_v2/users_view.py new file mode 100644 index 000000000..b9deeea4f --- /dev/null +++ b/backend/tenant_account_v2/users_view.py @@ -0,0 +1,198 @@ +import logging + +from account_v2.authentication_controller import AuthenticationController +from account_v2.exceptions import BadRequestException +from rest_framework import status, viewsets +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.response import Response +from tenant_account_v2.models import OrganizationMember +from tenant_account_v2.organization_member_service import OrganizationMemberService +from tenant_account_v2.serializer import ( + ChangeUserRoleRequestSerializer, + InviteUserSerializer, + OrganizationMemberSerializer, + RemoveUserFromOrganizationSerializer, + UpdateFlagSerializer, + UserInfoSerializer, + UserInviteResponseSerializer, +) +from utils.user_session import UserSessionUtils + +Logger = logging.getLogger(__name__) + + +class OrganizationUserViewSet(viewsets.ViewSet): + @action(detail=False, methods=["POST"]) + def assign_organization_role_to_user(self, request: Request) -> Response: + serializer = ChangeUserRoleRequestSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + user_email = serializer.get_user_email(serializer.validated_data) + role = serializer.get_user_role(serializer.validated_data) + if not (user_email and role): + raise BadRequestException + org_id: str = UserSessionUtils.get_organization_id(request) + 
auth_controller = AuthenticationController() + update_status = auth_controller.add_user_role( + request.user, org_id, user_email, role + ) + if update_status: + return Response( + status=status.HTTP_200_OK, + data={"status": "success", "message": "success"}, + ) + else: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"status": "failed", "message": "failed"}, + ) + + @action(detail=False, methods=["DELETE"]) + def remove_organization_role_from_user(self, request: Request) -> Response: + serializer = ChangeUserRoleRequestSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + user_email = serializer.get_user_email(serializer.validated_data) + role = serializer.get_user_role(serializer.validated_data) + if not (user_email and role): + raise BadRequestException + org_id: str = UserSessionUtils.get_organization_id(request) + auth_controller = AuthenticationController() + update_status = auth_controller.remove_user_role( + request.user, org_id, user_email, role + ) + if update_status: + return Response( + status=status.HTTP_200_OK, + data={"status": "success", "message": "success"}, + ) + else: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"status": "failed", "message": "failed"}, + ) + + @action(detail=False, methods=["GET"]) + def get_user_profile(self, request: Request) -> Response: + auth_controller = AuthenticationController() + try: + user_info = auth_controller.get_user_info(request) + role = auth_controller.get_organization_members_by_user(request.user) + if not user_info: + return Response( + status=status.HTTP_404_NOT_FOUND, + data={"message": "User Not Found"}, + ) + serialized_user_info = UserInfoSerializer(user_info).data + # Temporary fix for getting user role along with user info. + # Proper implementation would be adding role field to UserInfo.
+ serialized_user_info["is_admin"] = auth_controller.is_admin_by_role( + role.role + ) + # changes for displaying onboarding msgs + org_member = OrganizationMemberService.get_user_by_id(id=request.user.id) + serialized_user_info["login_onboarding_message_displayed"] = ( + org_member.is_login_onboarding_msg + ) + serialized_user_info["prompt_onboarding_message_displayed"] = ( + org_member.is_prompt_studio_onboarding_msg + ) + + return Response( + status=status.HTTP_200_OK, data={"user": serialized_user_info} + ) + except Exception as error: + Logger.error(f"Error while getting user: {error}") + return Response( + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + data={"message": "Internal Error"}, + ) + + @action(detail=False, methods=["POST"]) + def invite_user(self, request: Request) -> Response: + serializer = InviteUserSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + user_list = serializer.get_users(serializer.validated_data) + auth_controller = AuthenticationController() + invite_response = auth_controller.invite_user( + admin=request.user, + org_id=UserSessionUtils.get_organization_id(request), + user_list=user_list, + ) + + response_serializer = UserInviteResponseSerializer(invite_response, many=True) + + if invite_response and len(invite_response) != 0: + response = Response( + status=status.HTTP_200_OK, + data={"message": response_serializer.data}, + ) + else: + response = Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": "failed"}, + ) + return response + + @action(detail=False, methods=["DELETE"]) + def remove_members_from_organization(self, request: Request) -> Response: + serializer = RemoveUserFromOrganizationSerializer(data=request.data) + + serializer.is_valid(raise_exception=True) + user_emails = serializer.get_user_emails(serializer.validated_data) + organization_id: str = UserSessionUtils.get_organization_id(request) + + auth_controller = AuthenticationController() + is_updated = auth_controller.remove_users_from_organization( + admin=request.user, + organization_id=organization_id, + user_emails=user_emails, + ) + if is_updated: + return Response( + status=status.HTTP_200_OK, + data={"status": "success", "message": "success"}, + ) + else: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"status": "failed", "message": "failed"}, + ) + + @action(detail=False, methods=["GET"]) + def get_organization_members(self, request: Request) -> Response: + auth_controller = AuthenticationController() + if UserSessionUtils.get_organization_id(request): + members: list[OrganizationMember] = ( + auth_controller.get_organization_members_by_org_id() + ) + serialized_members = OrganizationMemberSerializer(members, many=True).data + return Response( + status=status.HTTP_200_OK, + data={"message": "success", "members": serialized_members}, + ) + return Response( + status=status.HTTP_401_UNAUTHORIZED, + data={"message": "cookie not found"}, + ) + + @action(detail=False, methods=["PUT"]) + def update_flags(self, request: Request) -> Response: + serializer = UpdateFlagSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + org_member = OrganizationMemberService.get_user_by_id(id=request.user.id) + org_member.is_login_onboarding_msg = serializer.validated_data.get( + "is_login_onboarding_msg" + ) + + org_member.is_prompt_studio_onboarding_msg = serializer.validated_data.get( + "is_prompt_studio_onboarding_msg" + ) + org_member.save() + return Response( + status=status.HTTP_200_OK, + data={"status": "success", "message":
"success"}, + ) diff --git a/backend/tenant_account_v2/views.py b/backend/tenant_account_v2/views.py new file mode 100644 index 000000000..e64cfd4c2 --- /dev/null +++ b/backend/tenant_account_v2/views.py @@ -0,0 +1,89 @@ +import logging +from typing import Any + +from account_v2.authentication_controller import AuthenticationController +from account_v2.dto import UserRoleData +from account_v2.models import Organization +from rest_framework import status +from rest_framework.decorators import api_view +from rest_framework.request import Request +from rest_framework.response import Response +from tenant_account_v2.dto import OrganizationLoginResponse, ResetUserPasswordDto +from tenant_account_v2.serializer import ( + GetRolesResponseSerializer, + OrganizationLoginResponseSerializer, +) +from utils.user_session import UserSessionUtils + +logger = logging.getLogger(__name__) + + +@api_view(["GET"]) +def logout(request: Request) -> Response: + auth_controller = AuthenticationController() + return auth_controller.user_logout(request) + + +@api_view(["GET"]) +def get_roles(request: Request) -> Response: + auth_controller = AuthenticationController() + roles: list[UserRoleData] = auth_controller.get_user_roles() + serialized_members = GetRolesResponseSerializer(roles, many=True).data + return Response( + status=status.HTTP_200_OK, + data={"message": "success", "members": serialized_members}, + ) + + +@api_view(["POST"]) +def reset_password(request: Request) -> Response: + auth_controller = AuthenticationController() + data: ResetUserPasswordDto = auth_controller.reset_user_password(request.user) + if data.status: + return Response( + status=status.HTTP_200_OK, + data={"status": "success", "message": data.message}, + ) + else: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"status": "failed", "message": data.message}, + ) + + +@api_view(["GET"]) +def get_organization(request: Request) -> Response: + auth_controller = AuthenticationController() + try: + organization_id = UserSessionUtils.get_organization_id(request) + org_data = auth_controller.get_organization_info(organization_id) + if not org_data: + return Response( + status=status.HTTP_404_NOT_FOUND, + data={"message": "Org Not Found"}, + ) + response = makeSignupResponse(org_data) + return Response( + status=status.HTTP_201_CREATED, + data={"message": "success", "organization": response}, + ) + + except Exception as error: + logger.error(f"Error while get User : {error}") + return Response( + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + data={"message": "Internal Error"}, + ) + + +def makeSignupResponse( + organization: Organization, +) -> Any: + return OrganizationLoginResponseSerializer( + OrganizationLoginResponse( + organization.name, + organization.display_name, + organization.organization_id, + organization.created_at, + ) + ).data diff --git a/backend/tool_instance_v2/__init__.py b/backend/tool_instance_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tool_instance_v2/admin.py b/backend/tool_instance_v2/admin.py new file mode 100644 index 000000000..1e1c975ca --- /dev/null +++ b/backend/tool_instance_v2/admin.py @@ -0,0 +1,5 @@ +from django.contrib import admin + +from .models import ToolInstance + +admin.site.register(ToolInstance) diff --git a/backend/tool_instance_v2/apps.py b/backend/tool_instance_v2/apps.py new file mode 100644 index 000000000..b75d98958 --- /dev/null +++ b/backend/tool_instance_v2/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class 
ToolInstanceConfig(AppConfig): + name = "tool_instance_v2" diff --git a/backend/tool_instance_v2/constants.py b/backend/tool_instance_v2/constants.py new file mode 100644 index 000000000..17a3891c5 --- /dev/null +++ b/backend/tool_instance_v2/constants.py @@ -0,0 +1,45 @@ +class ToolInstanceKey: + """Dict keys for ToolInstance model.""" + + PK = "id" + TOOL_ID = "tool_id" + VERSION = "version" + METADATA = "metadata" + STEP = "step" + STATUS = "status" + WORKFLOW = "workflow" + INPUT = "input" + OUTPUT = "output" + TI_COUNT = "tool_instance_count" + + +class JsonSchemaKey: + """Dict Keys for Tool's Json schema.""" + + PROPERTIES = "properties" + THEN = "then" + INPUT_FILE_CONNECTOR = "inputFileConnector" + OUTPUT_FILE_CONNECTOR = "outputFileConnector" + OUTPUT_FOLDER = "outputFolder" + ROOT_FOLDER = "rootFolder" + TENANT_ID = "tenant_id" + INPUT_DB_CONNECTOR = "inputDBConnector" + OUTPUT_DB_CONNECTOR = "outputDBConnector" + ENUM = "enum" + PROJECT_DEFAULT = "Project Default" + + +class ToolInstanceErrors: + TOOL_EXISTS = "Tool with this configuration already exists." + DUPLICATE_API = "It appears that a duplicate call may have been made." + + +class ToolKey: + """Dict keys for a Tool.""" + + NAME = "name" + DESCRIPTION = "description" + ICON = "icon" + FUNCTION_NAME = "function_name" + OUTPUT_TYPE = "output_type" + INPUT_TYPE = "input_type" diff --git a/backend/tool_instance_v2/exceptions.py b/backend/tool_instance_v2/exceptions.py new file mode 100644 index 000000000..69c2c26a5 --- /dev/null +++ b/backend/tool_instance_v2/exceptions.py @@ -0,0 +1,46 @@ +from typing import Optional + +from rest_framework.exceptions import APIException + + +class ToolInstanceBaseException(APIException): + def __init__( + self, + detail: Optional[str] = None, + code: Optional[int] = None, + tool_name: Optional[str] = None, + ) -> None: + detail = detail or self.default_detail + if tool_name is not None: + detail = f"{detail} Tool: {tool_name}" + super().__init__(detail, code) + + +class ToolFunctionIsMandatory(ToolInstanceBaseException): + status_code = 400 + default_detail = "Tool function is mandatory." + + +class ToolDoesNotExist(ToolInstanceBaseException): + status_code = 400 + default_detail = "Tool doesn't exist." + + +class FetchToolListFailed(ToolInstanceBaseException): + status_code = 400 + default_detail = "Failed to fetch tool list." + + +class ToolInstantiationError(ToolInstanceBaseException): + status_code = 500 + default_detail = "Error instantiating tool." + + +class BadRequestException(ToolInstanceBaseException): + status_code = 400 + default_detail = "Invalid input." + + +class ToolSettingValidationError(APIException): + status_code = 400 + default_detail = "Error while validating tool's setting." 
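`ToolInstanceBaseException` above folds the offending tool's name into the error detail whenever `tool_name` is passed, so every subclass reports which tool failed. A small sketch of the resulting messages, not part of this diff; the tool name is illustrative and a configured Django/DRF environment is assumed:

```python
from tool_instance_v2.exceptions import ToolDoesNotExist, ToolInstantiationError

# Default detail from the subclass, with the tool name appended.
exc = ToolDoesNotExist(tool_name="classifier")
print(exc.status_code, exc.detail)  # 400 Tool doesn't exist. Tool: classifier

# An explicit detail overrides the default but is suffixed the same way.
exc = ToolInstantiationError(detail="Missing settings.", tool_name="classifier")
print(exc.status_code, exc.detail)  # 500 Missing settings. Tool: classifier
```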
diff --git a/backend/tool_instance_v2/models.py b/backend/tool_instance_v2/models.py new file mode 100644 index 000000000..a7609566d --- /dev/null +++ b/backend/tool_instance_v2/models.py @@ -0,0 +1,99 @@ +import uuid + +from account_v2.models import User +from connector_v2.models import ConnectorInstance +from django.db import models +from django.db.models import QuerySet +from utils.models.base_model import BaseModel +from workflow_manager.workflow_v2.models.workflow import Workflow + +TOOL_ID_LENGTH = 64 +TOOL_VERSION_LENGTH = 16 +TOOL_STATUS_LENGTH = 32 + + +class ToolInstanceManager(models.Manager): + def get_instances_for_workflow( + self, workflow: uuid.UUID + ) -> QuerySet["ToolInstance"]: + return self.filter(workflow=workflow) + + +class ToolInstance(BaseModel): + class Status(models.TextChoices): + PENDING = "PENDING", "Settings Not Configured" + READY = "READY", "Ready to Start" + INITIATED = "INITIATED", "Initialization in Progress" + COMPLETED = "COMPLETED", "Process Completed" + ERROR = "ERROR", "Error Encountered" + + workflow = models.ForeignKey( + Workflow, + on_delete=models.CASCADE, + related_name="tool_instances", + null=False, + blank=False, + ) + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + tool_id = models.CharField( + max_length=TOOL_ID_LENGTH, + db_comment="Function name of the tool being used", + ) + input = models.JSONField(null=True, db_comment="Provisional WF input to a tool") + output = models.JSONField(null=True, db_comment="Provisional WF output to a tool") + version = models.CharField(max_length=TOOL_VERSION_LENGTH) + metadata = models.JSONField(db_comment="Stores config for a tool") + step = models.IntegerField() + # TODO: Make as an enum supporting fixed values once we have clarity + status = models.CharField(max_length=TOOL_STATUS_LENGTH, default="Ready to start") + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="tool_instances_created", + null=True, + blank=True, + ) + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="tool_instances_modified", + null=True, + blank=True, + ) + # Added these connectors separately + # for file and db for scalability + input_file_connector = models.ForeignKey( + ConnectorInstance, + on_delete=models.SET_NULL, + related_name="input_file_connectors", + null=True, + blank=True, + ) + output_file_connector = models.ForeignKey( + ConnectorInstance, + on_delete=models.SET_NULL, + related_name="output_file_connectors", + null=True, + blank=True, + ) + input_db_connector = models.ForeignKey( + ConnectorInstance, + on_delete=models.SET_NULL, + related_name="input_db_connectors", + null=True, + blank=True, + ) + output_db_connector = models.ForeignKey( + ConnectorInstance, + on_delete=models.SET_NULL, + related_name="output_db_connectors", + null=True, + blank=True, + ) + + objects = ToolInstanceManager() + + class Meta: + verbose_name = "Tool Instance" + verbose_name_plural = "Tool Instances" + db_table = "tool_instance_v2" diff --git a/backend/tool_instance_v2/serializers.py b/backend/tool_instance_v2/serializers.py new file mode 100644 index 000000000..86801545d --- /dev/null +++ b/backend/tool_instance_v2/serializers.py @@ -0,0 +1,129 @@ +import logging +import uuid +from typing import Any + +from prompt_studio.prompt_studio_registry_v2.constants import PromptStudioRegistryKeys +from rest_framework.serializers import ListField, Serializer, UUIDField, ValidationError +from tool_instance_v2.constants import ToolInstanceKey as 
TIKey +from tool_instance_v2.constants import ToolKey +from tool_instance_v2.exceptions import ToolDoesNotExist +from tool_instance_v2.models import ToolInstance +from tool_instance_v2.tool_instance_helper import ToolInstanceHelper +from tool_instance_v2.tool_processor import ToolProcessor +from unstract.tool_registry.dto import Tool +from workflow_manager.workflow_v2.constants import WorkflowKey +from workflow_manager.workflow_v2.models.workflow import Workflow + +from backend.constants import RequestKey +from backend.serializers import AuditSerializer + +logger = logging.getLogger(__name__) + + +class ToolInstanceSerializer(AuditSerializer): + workflow_id = UUIDField(write_only=True) + + class Meta: + model = ToolInstance + fields = "__all__" + extra_kwargs = { + TIKey.WORKFLOW: { + "required": False, + }, + TIKey.VERSION: { + "required": False, + }, + TIKey.METADATA: { + "required": False, + }, + TIKey.STEP: { + "required": False, + }, + } + + def to_representation(self, instance: ToolInstance) -> dict[str, str]: + rep: dict[str, Any] = super().to_representation(instance) + tool_function = rep.get(TIKey.TOOL_ID) + + if tool_function is None: + raise ToolDoesNotExist() + try: + tool: Tool = ToolProcessor.get_tool_by_uid(tool_function) + except ToolDoesNotExist: + return rep + rep[ToolKey.ICON] = tool.icon + rep[ToolKey.NAME] = tool.properties.display_name + # Need to Change it into better method + if self.context.get(RequestKey.REQUEST): + metadata = ToolInstanceHelper.get_altered_metadata(instance) + if metadata: + rep[TIKey.METADATA] = metadata + return rep + + def create(self, validated_data: dict[str, Any]) -> Any: + workflow_id = validated_data.pop(WorkflowKey.WF_ID) + try: + workflow = Workflow.objects.get(pk=workflow_id) + except Workflow.DoesNotExist: + raise ValidationError(f"Workflow with ID {workflow_id} does not exist.") + validated_data[TIKey.WORKFLOW] = workflow + + tool_uid = validated_data.get(TIKey.TOOL_ID) + if not tool_uid: + raise ToolDoesNotExist() + + tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uid) + # TODO: Handle other fields once tools SDK is out + validated_data[TIKey.PK] = uuid.uuid4() + # TODO: Use version from tool props + validated_data[TIKey.VERSION] = "" + validated_data[TIKey.METADATA] = { + # TODO: Review and remove tool instance ID + WorkflowKey.WF_TOOL_INSTANCE_ID: str(validated_data[TIKey.PK]), + PromptStudioRegistryKeys.PROMPT_REGISTRY_ID: str(tool_uid), + **ToolProcessor.get_default_settings(tool), + } + if TIKey.STEP not in validated_data: + validated_data[TIKey.STEP] = workflow.tool_instances.count() + 1 + # Workflow will get activated on adding tools to workflow + if not workflow.is_active: + workflow.is_active = True + workflow.save() + return super().create(validated_data) + + +class ToolInstanceReorderSerializer(Serializer): + workflow_id = UUIDField() + tool_instances = ListField(child=UUIDField()) + + def validate(self, data: dict[str, Any]) -> dict[str, Any]: + workflow_id = data.get(WorkflowKey.WF_ID) + tool_instances = data.get(WorkflowKey.WF_TOOL_INSTANCES, []) + + # Check if the workflow exists + try: + workflow = Workflow.objects.get(pk=workflow_id) + except Workflow.DoesNotExist: + raise ValidationError(f"Workflow with ID {workflow_id} does not exist.") + + # Check if the number of tool instances matches the actual count + tool_instance_count = workflow.tool_instances.count() + if len(tool_instances) != tool_instance_count: + msg = ( + f"Incorrect number of tool instances passed: " + f"{len(tool_instances)}, expected: 
{tool_instance_count}" + ) + logger.error(msg) + raise ValidationError(detail=msg) + + # Check if each tool instance exists in the workflow + existing_tool_instance_ids = workflow.tool_instances.values_list( + "id", flat=True + ) + for tool_instance_id in tool_instances: + if tool_instance_id not in existing_tool_instance_ids: + raise ValidationError( + "One or more tool instances do not exist in the workflow." + ) + + return data diff --git a/backend/tool_instance_v2/tool_instance_helper.py b/backend/tool_instance_v2/tool_instance_helper.py new file mode 100644 index 000000000..a1ba2adda --- /dev/null +++ b/backend/tool_instance_v2/tool_instance_helper.py @@ -0,0 +1,458 @@ +import logging +import os +import uuid +from typing import Any, Optional + +from account_v2.models import User +from adapter_processor_v2.adapter_processor import AdapterProcessor +from adapter_processor_v2.models import AdapterInstance +from connector_v2.connector_instance_helper import ConnectorInstanceHelper +from django.core.exceptions import PermissionDenied +from django.core.exceptions import ValidationError as DjangoValidationError +from jsonschema.exceptions import ValidationError as JSONValidationError +from prompt_studio.prompt_studio_registry_v2.models import PromptStudioRegistry +from tool_instance_v2.constants import JsonSchemaKey +from tool_instance_v2.exceptions import ToolSettingValidationError +from tool_instance_v2.models import ToolInstance +from tool_instance_v2.tool_processor import ToolProcessor +from unstract.adapters.enums import AdapterTypes +from unstract.sdk.tool.validator import DefaultsGeneratingValidator +from unstract.tool_registry.constants import AdapterPropertyKey +from unstract.tool_registry.dto import Spec, Tool +from unstract.tool_registry.tool_utils import ToolUtils +from workflow_manager.workflow_v2.constants import WorkflowKey + +logger = logging.getLogger(__name__) + + +class ToolInstanceHelper: + @staticmethod + def get_tool_instances_by_workflow( + workflow_id: str, + order_by: str, + lookup: Optional[dict[str, Any]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, + ) -> list[ToolInstance]: + wf_filter = {} + if lookup: + wf_filter = lookup + wf_filter[WorkflowKey.WF_ID] = workflow_id + + if limit: + offset_value = 0 if not offset else offset + to = offset_value + limit + return list( + ToolInstance.objects.filter(**wf_filter)[offset_value:to].order_by( + order_by + ) + ) + return list(ToolInstance.objects.filter(**wf_filter).all().order_by(order_by)) + + @staticmethod + def update_instance_metadata( + org_id: str, tool_instance: ToolInstance, metadata: dict[str, Any] + ) -> None: + if ( + JsonSchemaKey.OUTPUT_FILE_CONNECTOR in metadata + and JsonSchemaKey.OUTPUT_FOLDER in metadata + ): + output_connector_name = metadata[JsonSchemaKey.OUTPUT_FILE_CONNECTOR] + output_connector = ConnectorInstanceHelper.get_output_connector_instance_by_name_for_workflow( # noqa + tool_instance.workflow_id, output_connector_name + ) + if output_connector and "path" in output_connector.metadata: + metadata[JsonSchemaKey.OUTPUT_FOLDER] = os.path.join( + output_connector.metadata["path"], + *(metadata[JsonSchemaKey.OUTPUT_FOLDER].split("/")), + ) + if ( + JsonSchemaKey.INPUT_FILE_CONNECTOR in metadata + and JsonSchemaKey.ROOT_FOLDER in metadata + ): + input_connector_name = metadata[JsonSchemaKey.INPUT_FILE_CONNECTOR] + input_connector = ConnectorInstanceHelper.get_input_connector_instance_by_name_for_workflow( # noqa + tool_instance.workflow_id, input_connector_name + ) + + 
if input_connector and "path" in input_connector.metadata: + metadata[JsonSchemaKey.ROOT_FOLDER] = os.path.join( + input_connector.metadata["path"], + *(metadata[JsonSchemaKey.ROOT_FOLDER].split("/")), + ) + ToolInstanceHelper.update_metadata_with_adapter_instances( + metadata, tool_instance.tool_id + ) + metadata[JsonSchemaKey.TENANT_ID] = org_id + tool_instance.metadata = metadata + tool_instance.save() + + @staticmethod + def update_metadata_with_adapter_properties( + metadata: dict[str, Any], + adapter_key: str, + adapter_property: dict[str, Any], + adapter_type: AdapterTypes, + ) -> None: + """Update the metadata dictionary with adapter properties. + + Parameters: + metadata (dict[str, Any]): + The metadata dictionary to be updated with adapter properties. + adapter_key (str): + The key in the metadata dictionary corresponding to the adapter. + adapter_property (dict[str, Any]): + The properties of the adapter. + adapter_type (AdapterTypes): + The type of the adapter. + + Returns: + None + """ + if adapter_key in metadata: + adapter_name = metadata[adapter_key] + adapter = AdapterProcessor.get_adapter_by_name_and_type( + adapter_type=adapter_type, adapter_name=adapter_name + ) + adapter_id = str(adapter.id) if adapter else None + metadata_key_for_id = adapter_property.get( + AdapterPropertyKey.ADAPTER_ID_KEY, AdapterPropertyKey.ADAPTER_ID + ) + metadata[metadata_key_for_id] = adapter_id + + @staticmethod + def update_metadata_with_adapter_instances( + metadata: dict[str, Any], tool_uid: str + ) -> None: + """ + Update the metadata dictionary with adapter instances. + Parameters: + metadata (dict[str, Any]): + The metadata dictionary to be updated with adapter instances. + + Returns: + None + """ + tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uid) + schema: Spec = ToolUtils.get_json_schema_for_tool(tool) + llm_properties = schema.get_llm_adapter_properties() + embedding_properties = schema.get_embedding_adapter_properties() + vector_db_properties = schema.get_vector_db_adapter_properties() + x2text_properties = schema.get_text_extractor_adapter_properties() + ocr_properties = schema.get_ocr_adapter_properties() + + for adapter_key, adapter_property in llm_properties.items(): + ToolInstanceHelper.update_metadata_with_adapter_properties( + metadata=metadata, + adapter_key=adapter_key, + adapter_property=adapter_property, + adapter_type=AdapterTypes.LLM, + ) + + for adapter_key, adapter_property in embedding_properties.items(): + ToolInstanceHelper.update_metadata_with_adapter_properties( + metadata=metadata, + adapter_key=adapter_key, + adapter_property=adapter_property, + adapter_type=AdapterTypes.EMBEDDING, + ) + + for adapter_key, adapter_property in vector_db_properties.items(): + ToolInstanceHelper.update_metadata_with_adapter_properties( + metadata=metadata, + adapter_key=adapter_key, + adapter_property=adapter_property, + adapter_type=AdapterTypes.VECTOR_DB, + ) + + for adapter_key, adapter_property in x2text_properties.items(): + ToolInstanceHelper.update_metadata_with_adapter_properties( + metadata=metadata, + adapter_key=adapter_key, + adapter_property=adapter_property, + adapter_type=AdapterTypes.X2TEXT, + ) + + for adapter_key, adapter_property in ocr_properties.items(): + ToolInstanceHelper.update_metadata_with_adapter_properties( + metadata=metadata, + adapter_key=adapter_key, + adapter_property=adapter_property, + adapter_type=AdapterTypes.OCR, + ) + + # TODO: Review if adding this metadata is still required + @staticmethod + def get_altered_metadata( + 
tool_instance: ToolInstance, + ) -> Optional[dict[str, Any]]: + """Get altered metadata by resolving relative paths. + + This method retrieves the metadata from the given tool instance + and checks if there are output and input file connectors. + If output and input file connectors exist in the metadata, + it resolves the relative paths using connector instances. + + Args: + tool_instance (ToolInstance). + + Returns: + Optional[dict[str, Any]]: Altered metadata with resolved relative \ + paths. + """ + metadata: dict[str, Any] = tool_instance.metadata + if ( + JsonSchemaKey.OUTPUT_FILE_CONNECTOR in metadata + and JsonSchemaKey.OUTPUT_FOLDER in metadata + ): + output_connector_name = metadata[JsonSchemaKey.OUTPUT_FILE_CONNECTOR] + output_connector = ConnectorInstanceHelper.get_output_connector_instance_by_name_for_workflow( # noqa + tool_instance.workflow_id, output_connector_name + ) + if output_connector and "path" in output_connector.metadata: + relative_path = ToolInstanceHelper.get_relative_path( + metadata[JsonSchemaKey.OUTPUT_FOLDER], + output_connector.metadata["path"], + ) + metadata[JsonSchemaKey.OUTPUT_FOLDER] = relative_path + if ( + JsonSchemaKey.INPUT_FILE_CONNECTOR in metadata + and JsonSchemaKey.ROOT_FOLDER in metadata + ): + input_connector_name = metadata[JsonSchemaKey.INPUT_FILE_CONNECTOR] + input_connector = ConnectorInstanceHelper.get_input_connector_instance_by_name_for_workflow( # noqa + tool_instance.workflow_id, input_connector_name + ) + if input_connector and "path" in input_connector.metadata: + relative_path = ToolInstanceHelper.get_relative_path( + metadata[JsonSchemaKey.ROOT_FOLDER], + input_connector.metadata["path"], + ) + metadata[JsonSchemaKey.ROOT_FOLDER] = relative_path + return metadata + + @staticmethod + def update_metadata_with_default_adapter( + adapter_type: AdapterTypes, + schema_spec: Spec, + adapter: AdapterInstance, + metadata: dict[str, Any], + ) -> None: + """Update the metadata of a tool instance with default values for + enabled adapters. + + Parameters: + adapter_type (AdapterTypes): The type of adapter to update + the metadata for. + schema_spec (Spec): The schema specification for the tool. + adapter (AdapterInstance): The adapter instance to use for updating + the metadata. + metadata (dict[str, Any]): The metadata dictionary to update. + + Returns: + None + """ + properties = {} + if adapter_type == AdapterTypes.LLM: + properties = schema_spec.get_llm_adapter_properties() + if adapter_type == AdapterTypes.EMBEDDING: + properties = schema_spec.get_embedding_adapter_properties() + if adapter_type == AdapterTypes.VECTOR_DB: + properties = schema_spec.get_vector_db_adapter_properties() + if adapter_type == AdapterTypes.X2TEXT: + properties = schema_spec.get_text_extractor_adapter_properties() + if adapter_type == AdapterTypes.OCR: + properties = schema_spec.get_ocr_adapter_properties() + for adapter_key, adapter_property in properties.items(): + metadata_key_for_id = adapter_property.get( + AdapterPropertyKey.ADAPTER_ID_KEY, AdapterPropertyKey.ADAPTER_ID + ) + metadata[adapter_key] = adapter.adapter_name + metadata[metadata_key_for_id] = str(adapter.id) + + @staticmethod + def update_metadata_with_default_values( + tool_instance: ToolInstance, user: User + ) -> None: + """Update the metadata of a tool instance with default values for + enabled adapters. + + Parameters: + tool_instance (ToolInstance): The tool instance to update the + metadata. 
+ + Returns: + None + """ + metadata: dict[str, Any] = tool_instance.metadata + tool_uuid = tool_instance.tool_id + + tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uuid) + schema: Spec = ToolUtils.get_json_schema_for_tool(tool) + + default_adapters = AdapterProcessor.get_default_adapters(user=user) + for adapter in default_adapters: + try: + adapter_type = AdapterTypes(adapter.adapter_type) + ToolInstanceHelper.update_metadata_with_default_adapter( + adapter_type=adapter_type, + schema_spec=schema, + adapter=adapter, + metadata=metadata, + ) + except ValueError: + logger.warning(f"Invalid AdapterType {adapter.adapter_type}") + tool_instance.metadata = metadata + tool_instance.save() + + @staticmethod + def get_relative_path(absolute_path: str, base_path: str) -> str: + if absolute_path.startswith(base_path): + relative_path = os.path.relpath(absolute_path, base_path) + else: + relative_path = absolute_path + if relative_path == ".": + relative_path = "" + return relative_path + + @staticmethod + def reorder_tool_instances(instances_to_reorder: list[uuid.UUID]) -> None: + """Reorders tool instances based on the list of tool UUIDs received. + Saves the instance in the DB. + + Args: + instances_to_reorder (list[uuid.UUID]): Desired order of tool UUIDs + """ + logger.info(f"Reordering instances: {instances_to_reorder}") + for step, tool_instance_id in enumerate(instances_to_reorder): + tool_instance = ToolInstance.objects.get(pk=tool_instance_id) + tool_instance.step = step + 1 + tool_instance.save() + + @staticmethod + def validate_tool_settings( + user: User, tool_uid: str, tool_meta: dict[str, Any] + ) -> bool: + """Function to validate Tools settings.""" + + # check if exported tool is valid for the user who created workflow + ToolInstanceHelper.validate_tool_access(user=user, tool_uid=tool_uid) + ToolInstanceHelper.validate_adapter_permissions( + user=user, tool_uid=tool_uid, tool_meta=tool_meta + ) + + tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uid) + tool_name: str = ( + tool.properties.display_name if tool.properties.display_name else tool_uid + ) + schema_json: dict[str, Any] = ToolProcessor.get_json_schema_for_tool( + tool_uid=tool_uid, user=user + ) + try: + DefaultsGeneratingValidator(schema_json).validate(tool_meta) + except JSONValidationError as e: + logger.error(e, stack_info=True, exc_info=True) + err_msg = e.message + # TODO: Support other JSON validation errors or consider following + # https://github.com/networknt/json-schema-validator/blob/master/doc/cust-msg.md + if e.validator == "required": + for validator_val in e.validator_value: + required_prop = e.schema.get("properties").get(validator_val) + required_display_name = required_prop.get("title") + err_msg = err_msg.replace(validator_val, required_display_name) + else: + logger.warning(f"Unformatted exception sent to user: {err_msg}") + raise ToolSettingValidationError( + f"Error validating tool settings for '{tool_name}': {err_msg}" + ) + return True + + @staticmethod + def validate_adapter_permissions( + user: User, tool_uid: str, tool_meta: dict[str, Any] + ) -> None: + tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uid) + adapter_ids: set[str] = set() + + for llm in tool.properties.adapter.language_models: + if llm.is_enabled and llm.adapter_id: + adapter_id = tool_meta[llm.adapter_id] + elif llm.is_enabled: + adapter_id = tool_meta[AdapterPropertyKey.DEFAULT_LLM_ADAPTER_ID] + + adapter_ids.add(adapter_id) + for vdb in tool.properties.adapter.vector_stores: + if vdb.is_enabled and 
vdb.adapter_id: + adapter_id = tool_meta[vdb.adapter_id] + elif vdb.is_enabled: + adapter_id = tool_meta[AdapterPropertyKey.DEFAULT_VECTOR_DB_ADAPTER_ID] + + adapter_ids.add(adapter_id) + for embedding in tool.properties.adapter.embedding_services: + if embedding.is_enabled and embedding.adapter_id: + adapter_id = tool_meta[embedding.adapter_id] + elif embedding.is_enabled: + adapter_id = tool_meta[AdapterPropertyKey.DEFAULT_EMBEDDING_ADAPTER_ID] + + adapter_ids.add(adapter_id) + for text_extractor in tool.properties.adapter.text_extractors: + if text_extractor.is_enabled and text_extractor.adapter_id: + adapter_id = tool_meta[text_extractor.adapter_id] + elif text_extractor.is_enabled: + adapter_id = tool_meta[AdapterPropertyKey.DEFAULT_X2TEXT_ADAPTER_ID] + + adapter_ids.add(adapter_id) + + ToolInstanceHelper.validate_adapter_access(user=user, adapter_ids=adapter_ids) + + @staticmethod + def validate_adapter_access( + user: User, + adapter_ids: set[str], + ) -> None: + adapter_instances = AdapterInstance.objects.filter(id__in=adapter_ids).all() + + for adapter_instance in adapter_instances: + if not adapter_instance.is_usable: + logger.error( + "Free usage for the configured sample adapter %s exhausted", + adapter_instance.id, + ) + error_msg = "Permission Error: Free usage for the configured trial adapter exhausted. Please connect your own service accounts to continue. Please see our documentation for more details: https://docs.unstract.com/unstract_platform/setup_accounts/whats_needed" # noqa: E501 + + raise PermissionDenied(error_msg) + + if not ( + adapter_instance.shared_to_org + or adapter_instance.created_by == user + or adapter_instance.shared_users.filter(pk=user.pk).exists() + ): + logger.error( + "User %s doesn't have access to adapter %s", + user.user_id, + adapter_instance.id, + ) + raise PermissionDenied( + "You don't have permission to perform this action." + ) + + @staticmethod + def validate_tool_access( + user: User, + tool_uid: str, + ) -> None: + # HACK: Assume tool_uid is a prompt studio exported tool and query it. + # We suppress ValidationError when tool_uid is of a static tool.
+ try: + prompt_registry_tool = PromptStudioRegistry.objects.get(pk=tool_uid) + except DjangoValidationError: + logger.info(f"Not validating tool access for tool: {tool_uid}") + return + + if ( + prompt_registry_tool.shared_to_org + or prompt_registry_tool.shared_users.filter(pk=user.pk).exists() + ): + return + else: + raise PermissionDenied("You don't have permission to perform this action.") diff --git a/backend/tool_instance_v2/tool_processor.py b/backend/tool_instance_v2/tool_processor.py new file mode 100644 index 000000000..20cb05953 --- /dev/null +++ b/backend/tool_instance_v2/tool_processor.py @@ -0,0 +1,137 @@ +import logging +from typing import Any, Optional + +from account_v2.models import User +from adapter_processor_v2.adapter_processor import AdapterProcessor +from prompt_studio.prompt_studio_registry_v2.prompt_studio_registry_helper import ( + PromptStudioRegistryHelper, +) +from tool_instance_v2.exceptions import ToolDoesNotExist +from unstract.adapters.enums import AdapterTypes +from unstract.tool_registry.dto import Spec, Tool +from unstract.tool_registry.tool_registry import ToolRegistry +from unstract.tool_registry.tool_utils import ToolUtils + +logger = logging.getLogger(__name__) + + +class ToolProcessor: + TOOL_NOT_IN_REGISTRY_MESSAGE = "Tool does not exist in registry" + tool_registry = ToolRegistry() + + @staticmethod + def get_tool_by_uid(tool_uid: str) -> Tool: + """Function to get and instantiate a tool for a given tool + settingsId.""" + tool_registry = ToolRegistry() + tool: Optional[Tool] = tool_registry.get_tool_by_uid(tool_uid) + # HACK: Assume tool_uid is prompt_registry_id for fetching a dynamic + # tool made with Prompt Studio. + if not tool: + tool = PromptStudioRegistryHelper.get_tool_by_prompt_registry_id( + prompt_registry_id=tool_uid + ) + if not tool: + raise ToolDoesNotExist( + f"{ToolProcessor.TOOL_NOT_IN_REGISTRY_MESSAGE}: {tool_uid}" + ) + return tool + + @staticmethod + def get_default_settings(tool: Tool) -> dict[str, str]: + """Function to make and fill settings with default values. + + Args: + tool (ToolSettings): tool + + Returns: + dict[str, str]: tool settings + """ + tool_metadata: dict[str, str] = ToolUtils.get_default_settings(tool) + return tool_metadata + + @staticmethod + def get_json_schema_for_tool(tool_uid: str, user: User) -> dict[str, str]: + """Function to Get JSON Schema for Tools.""" + tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uid) + schema: Spec = ToolUtils.get_json_schema_for_tool(tool) + ToolProcessor.update_schema_with_adapter_configurations( + schema=schema, user=user + ) + schema_json: dict[str, Any] = schema.to_dict() + return schema_json + + @staticmethod + def update_schema_with_adapter_configurations(schema: Spec, user: User) -> None: + """Updates the JSON schema with the available adapter configurations + for the LLM, embedding, and vector DB adapters. + + Args: + schema (Spec): The JSON schema object to be updated. + + Returns: + None. The `schema` object is updated in-place. 
+ """ + llm_keys = schema.get_llm_adapter_properties_keys() + embedding_keys = schema.get_embedding_adapter_properties_keys() + vector_db_keys = schema.get_vector_db_adapter_properties_keys() + x2text_keys = schema.get_text_extractor_adapter_properties_keys() + ocr_keys = schema.get_ocr_adapter_properties_keys() + + if llm_keys: + adapters = AdapterProcessor.get_adapters_by_type( + AdapterTypes.LLM, user=user + ) + for key in llm_keys: + adapter_names = map(lambda adapter: str(adapter.adapter_name), adapters) + schema.properties[key]["enum"] = list(adapter_names) + + if embedding_keys: + adapters = AdapterProcessor.get_adapters_by_type( + AdapterTypes.EMBEDDING, user=user + ) + for key in embedding_keys: + adapter_names = map(lambda adapter: str(adapter.adapter_name), adapters) + schema.properties[key]["enum"] = list(adapter_names) + + if vector_db_keys: + adapters = AdapterProcessor.get_adapters_by_type( + AdapterTypes.VECTOR_DB, user=user + ) + for key in vector_db_keys: + adapter_names = map(lambda adapter: str(adapter.adapter_name), adapters) + schema.properties[key]["enum"] = list(adapter_names) + + if x2text_keys: + adapters = AdapterProcessor.get_adapters_by_type( + AdapterTypes.X2TEXT, user=user + ) + for key in x2text_keys: + adapter_names = map(lambda adapter: str(adapter.adapter_name), adapters) + schema.properties[key]["enum"] = list(adapter_names) + + if ocr_keys: + adapters = AdapterProcessor.get_adapters_by_type( + AdapterTypes.OCR, user=user + ) + for key in ocr_keys: + adapter_names = map(lambda adapter: str(adapter.adapter_name), adapters) + schema.properties[key]["enum"] = list(adapter_names) + + @staticmethod + def get_tool_list(user: User) -> list[dict[str, Any]]: + """Function to get a list of tools.""" + tool_registry = ToolRegistry() + prompt_studio_tools: list[dict[str, Any]] = ( + PromptStudioRegistryHelper.fetch_json_for_registry(user) + ) + tool_list: list[dict[str, Any]] = tool_registry.fetch_tools_descriptions() + tool_list = tool_list + prompt_studio_tools + return tool_list + + @staticmethod + def get_registry_tools() -> list[Tool]: + """Function to get a list of tools.""" + tool_registry = ToolRegistry() + tool_list: list[Tool] = tool_registry.fetch_all_tools() + return tool_list diff --git a/backend/tool_instance_v2/urls.py b/backend/tool_instance_v2/urls.py new file mode 100644 index 000000000..411725079 --- /dev/null +++ b/backend/tool_instance_v2/urls.py @@ -0,0 +1,46 @@ +from django.urls import path +from rest_framework.urlpatterns import format_suffix_patterns +from tool_instance_v2.views import ToolInstanceViewSet + +from . 
import views + +tool_instance_list = ToolInstanceViewSet.as_view( + { + "get": "list", + "post": "create", + } +) +tool_instance_detail = ToolInstanceViewSet.as_view( + # fmt: off + { + "get": "retrieve", + "put": "update", + "patch": "partial_update", + "delete": "destroy" + } + # fmt: on +) + +tool_instance_reorder = ToolInstanceViewSet.as_view({"post": "reorder"}) + +urlpatterns = format_suffix_patterns( + [ + path("tool_instance/", tool_instance_list, name="tool-instance-list"), + path( + "tool_instance//", + tool_instance_detail, + name="tool-instance-detail", + ), + path( + "tool_settings_schema/", + views.tool_settings_schema, + name="tool_settings_schema", + ), + path( + "tool_instance/reorder/", + tool_instance_reorder, + name="tool_instance_reorder", + ), + path("tool/", views.get_tool_list, name="tool_list"), + ] +) diff --git a/backend/tool_instance_v2/views.py b/backend/tool_instance_v2/views.py new file mode 100644 index 000000000..7d56c4ec9 --- /dev/null +++ b/backend/tool_instance_v2/views.py @@ -0,0 +1,167 @@ +import logging +import uuid +from typing import Any + +from account_v2.custom_exceptions import DuplicateData +from django.db import IntegrityError +from django.db.models.query import QuerySet +from rest_framework import serializers, status, viewsets +from rest_framework.decorators import api_view +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.versioning import URLPathVersioning +from tool_instance_v2.constants import ToolInstanceErrors +from tool_instance_v2.constants import ToolInstanceKey as TIKey +from tool_instance_v2.constants import ToolKey +from tool_instance_v2.exceptions import FetchToolListFailed, ToolFunctionIsMandatory +from tool_instance_v2.models import ToolInstance +from tool_instance_v2.serializers import ( + ToolInstanceReorderSerializer as TIReorderSerializer, +) +from tool_instance_v2.serializers import ToolInstanceSerializer +from tool_instance_v2.tool_instance_helper import ToolInstanceHelper +from tool_instance_v2.tool_processor import ToolProcessor +from utils.filtering import FilterHelper +from utils.user_session import UserSessionUtils +from workflow_manager.workflow_v2.constants import WorkflowKey + +from backend.constants import RequestKey + +logger = logging.getLogger(__name__) + + +@api_view(["GET"]) +def tool_settings_schema(request: Request) -> Response: + if request.method == "GET": + tool_function = request.GET.get(ToolKey.FUNCTION_NAME) + if tool_function is None or tool_function == "": + raise ToolFunctionIsMandatory() + + json_schema = ToolProcessor.get_json_schema_for_tool( + tool_uid=tool_function, user=request.user + ) + return Response(data=json_schema, status=status.HTTP_200_OK) + + +@api_view(("GET",)) +def get_tool_list(request: Request) -> Response: + """Get tool list. 
+ + Fetches a list of tools available in the Tool registry + """ + if request.method == "GET": + try: + logger.info("Fetching tools from the tool registry...") + return Response( + data=ToolProcessor.get_tool_list(request.user), + status=status.HTTP_200_OK, + ) + except Exception as exc: + logger.error(f"Failed to fetch tools: {exc}") + raise FetchToolListFailed + + +class ToolInstanceViewSet(viewsets.ModelViewSet): + versioning_class = URLPathVersioning + queryset = ToolInstance.objects.all() + serializer_class = ToolInstanceSerializer + + def get_queryset(self) -> QuerySet: + filter_args = FilterHelper.build_filter_args( + self.request, + RequestKey.PROJECT, + RequestKey.CREATED_BY, + RequestKey.WORKFLOW, + ) + if filter_args: + queryset = ToolInstance.objects.filter( + created_by=self.request.user, **filter_args + ) + else: + queryset = ToolInstance.objects.filter(created_by=self.request.user) + return queryset + + def get_serializer_class(self) -> serializers.Serializer: + if self.action == "reorder": + return TIReorderSerializer + else: + return ToolInstanceSerializer + + def create(self, request: Any) -> Response: + """Create tool instance. + + Creates a tool instance, useful to add them directly to a + workflow. Its an alternative to creating tool instances through + the LLM response. + """ + + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + try: + self.perform_create(serializer) + except IntegrityError: + raise DuplicateData( + f"{ToolInstanceErrors.TOOL_EXISTS}, " + f"{ToolInstanceErrors.DUPLICATE_API}" + ) + instance: ToolInstance = serializer.instance + ToolInstanceHelper.update_metadata_with_default_values( + instance, user=request.user + ) + headers = self.get_success_headers(serializer.data) + return Response( + serializer.data, status=status.HTTP_201_CREATED, headers=headers + ) + + def perform_destroy(self, instance: ToolInstance) -> None: + """Deletes a tool instance and decrements successor instance's steps. + + Args: + instance (ToolInstance): Instance being deleted. + """ + lookup = {"step__gt": instance.step} + next_tool_instances: list[ToolInstance] = ( + ToolInstanceHelper.get_tool_instances_by_workflow( + instance.workflow.id, TIKey.STEP, lookup=lookup + ) + ) + super().perform_destroy(instance) + + for instance in next_tool_instances: + instance.step = instance.step - 1 + instance.save() + + def partial_update(self, request: Request, *args: Any, **kwargs: Any) -> Response: + """Allows partial updates on a tool instance.""" + instance: ToolInstance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=True) + serializer.is_valid(raise_exception=True) + if serializer.validated_data.get(TIKey.METADATA): + metadata: dict[str, Any] = serializer.validated_data.get(TIKey.METADATA) + + # TODO: Move update logic into serializer + organization_id = UserSessionUtils.get_organization_id(request) + ToolInstanceHelper.update_instance_metadata( + organization_id, + instance, + metadata, + ) + return Response(serializer.data) + return super().partial_update(request, *args, **kwargs) + + def reorder(self, request: Any, **kwargs: Any) -> Response: + """Reorder tool instances. + + Reorders the tool instances based on a list of UUIDs. 
+ """ + serializer: TIReorderSerializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + wf_id = serializer.validated_data[WorkflowKey.WF_ID] + instances_to_reorder: list[uuid.UUID] = serializer.validated_data[ + WorkflowKey.WF_TOOL_INSTANCES + ] + + ToolInstanceHelper.reorder_tool_instances(instances_to_reorder) + tool_instances = ToolInstance.objects.get_instances_for_workflow(workflow=wf_id) + ti_serializer = ToolInstanceSerializer(instance=tool_instances, many=True) + return Response(ti_serializer.data, status=status.HTTP_200_OK) diff --git a/backend/usage_v2/__init__.py b/backend/usage_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/usage_v2/admin.py b/backend/usage_v2/admin.py new file mode 100644 index 000000000..c7469a4fb --- /dev/null +++ b/backend/usage_v2/admin.py @@ -0,0 +1,5 @@ +from django.contrib import admin + +from .models import Usage + +admin.site.register(Usage) diff --git a/backend/usage_v2/apps.py b/backend/usage_v2/apps.py new file mode 100644 index 000000000..d2845981e --- /dev/null +++ b/backend/usage_v2/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class UsageConfig(AppConfig): + name = "usage_v2" diff --git a/backend/usage_v2/constants.py b/backend/usage_v2/constants.py new file mode 100644 index 000000000..8da54da05 --- /dev/null +++ b/backend/usage_v2/constants.py @@ -0,0 +1,7 @@ +class UsageKeys: + RUN_ID = "run_id" + EMBEDDING_TOKENS = "embedding_tokens" + PROMPT_TOKENS = "prompt_tokens" + COMPLETION_TOKENS = "completion_tokens" + TOTAL_TOKENS = "total_tokens" + COST_IN_DOLLARS = "cost_in_dollars" diff --git a/backend/usage_v2/helper.py b/backend/usage_v2/helper.py new file mode 100644 index 000000000..0bfab7556 --- /dev/null +++ b/backend/usage_v2/helper.py @@ -0,0 +1,64 @@ +import logging + +from django.db.models import Sum +from rest_framework.exceptions import APIException + +from .constants import UsageKeys +from .models import Usage + +logger = logging.getLogger(__name__) + + +class UsageHelper: + @staticmethod + def get_aggregated_token_count(run_id: str) -> dict: + """Retrieve aggregated token counts for the given run_id. + + Args: + run_id (str): The identifier for the token usage. + + Returns: + dict: A dictionary containing aggregated token counts + for different token types. + Keys: + - 'embedding_tokens': Total embedding tokens. + - 'prompt_tokens': Total prompt tokens. + - 'completion_tokens': Total completion tokens. + - 'total_tokens': Total tokens. + + Raises: + APIException: For unexpected errors during database operations. 
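+
+        Note:
+            Each aggregated value is None when no Usage rows match the given
+            run_id, since a Sum aggregate over an empty queryset returns None.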
+ """ + try: + # Aggregate the token counts for the given run_id + usage_summary = Usage.objects.filter(run_id=run_id).aggregate( + embedding_tokens=Sum(UsageKeys.EMBEDDING_TOKENS), + prompt_tokens=Sum(UsageKeys.PROMPT_TOKENS), + completion_tokens=Sum(UsageKeys.COMPLETION_TOKENS), + total_tokens=Sum(UsageKeys.TOTAL_TOKENS), + cost_in_dollars=Sum(UsageKeys.COST_IN_DOLLARS), + ) + + logger.info(f"Token counts aggregated successfully for run_id: {run_id}") + + # Prepare the result dictionary with None as the default value + result = { + UsageKeys.EMBEDDING_TOKENS: usage_summary.get( + UsageKeys.EMBEDDING_TOKENS + ), + UsageKeys.PROMPT_TOKENS: usage_summary.get(UsageKeys.PROMPT_TOKENS), + UsageKeys.COMPLETION_TOKENS: usage_summary.get( + UsageKeys.COMPLETION_TOKENS + ), + UsageKeys.TOTAL_TOKENS: usage_summary.get(UsageKeys.TOTAL_TOKENS), + UsageKeys.COST_IN_DOLLARS: usage_summary.get(UsageKeys.COST_IN_DOLLARS), + } + return result + except Usage.DoesNotExist: + # Handle the case where no usage data is found for the given run_id + logger.warning(f"Usage data not found for the specified run_id: {run_id}") + return {} + except Exception as e: + # Handle any other exceptions that might occur during the execution + logger.error(f"An unexpected error occurred for run_id {run_id}: {str(e)}") + raise APIException("An unexpected error occurred") diff --git a/backend/usage_v2/models.py b/backend/usage_v2/models.py new file mode 100644 index 000000000..8d4a35069 --- /dev/null +++ b/backend/usage_v2/models.py @@ -0,0 +1,82 @@ +import uuid + +from django.db import models +from utils.models.base_model import BaseModel +from utils.models.organization_mixin import ( + DefaultOrganizationManagerMixin, + DefaultOrganizationMixin, +) + + +class UsageType(models.TextChoices): + LLM = "llm", "LLM Usage" + EMBEDDING = "embedding", "Embedding Usage" + + +class LLMUsageReason(models.TextChoices): + EXTRACTION = "extraction", "Extraction" + CHALLENGE = "challenge", "Challenge" + SUMMARIZE = "summarize", "Summarize" + + +class UsageModelManager(DefaultOrganizationManagerMixin, models.Manager): + pass + + +class Usage(DefaultOrganizationMixin, BaseModel): + id = models.UUIDField( + primary_key=True, + default=uuid.uuid4, + editable=False, + db_comment="Primary key for the usage entry, automatically generated UUID", + ) + workflow_id = models.CharField( + max_length=255, null=True, blank=True, db_comment="Identifier for the workflow" + ) + execution_id = models.CharField( + max_length=255, + null=True, + blank=True, + db_comment="Identifier for the execution instance", + ) + adapter_instance_id = models.CharField( + max_length=255, db_comment="Identifier for the adapter instance" + ) + run_id = models.CharField( + max_length=255, null=True, blank=True, db_comment="Identifier for the run" + ) + usage_type = models.CharField( + max_length=255, + choices=UsageType.choices, + db_comment="Type of usage, either 'llm' or 'embedding'", + ) + llm_usage_reason = models.CharField( + max_length=255, + choices=LLMUsageReason.choices, + null=True, + blank=True, + db_comment="Reason for LLM usage. Empty if usage_type is 'embedding'. 
", + ) + model_name = models.CharField(max_length=255, db_comment="Name of the model used") + embedding_tokens = models.IntegerField( + db_comment="Number of tokens used for embedding" + ) + prompt_tokens = models.IntegerField( + db_comment="Number of tokens used for the prompt" + ) + completion_tokens = models.IntegerField( + db_comment="Number of tokens used for the completion" + ) + total_tokens = models.IntegerField(db_comment="Total number of tokens used") + cost_in_dollars = models.FloatField(db_comment="Total number of tokens used") + # Manager + objects = UsageModelManager() + + def __str__(self): + return str(self.id) + + class Meta: + db_table = "token_usage_v2" + indexes = [ + models.Index(fields=["run_id"]), + ] diff --git a/backend/usage_v2/serializers.py b/backend/usage_v2/serializers.py new file mode 100644 index 000000000..eb1f2c326 --- /dev/null +++ b/backend/usage_v2/serializers.py @@ -0,0 +1,5 @@ +from rest_framework import serializers + + +class GetUsageSerializer(serializers.Serializer): + run_id = serializers.CharField(required=True) diff --git a/backend/usage_v2/urls.py b/backend/usage_v2/urls.py new file mode 100644 index 000000000..a0fa0bfa8 --- /dev/null +++ b/backend/usage_v2/urls.py @@ -0,0 +1,18 @@ +from django.urls import path +from rest_framework.urlpatterns import format_suffix_patterns + +from .views import UsageView + +get_token_usage = UsageView.as_view({"get": "get_token_usage"}) + +# TODO: Refactor URL to avoid using action-specific verbs like get. + +urlpatterns = format_suffix_patterns( + [ + path( + "get_token_usage/", + get_token_usage, + name="get-token-usage", + ), + ] +) diff --git a/backend/usage_v2/views.py b/backend/usage_v2/views.py new file mode 100644 index 000000000..1fcebfaa0 --- /dev/null +++ b/backend/usage_v2/views.py @@ -0,0 +1,63 @@ +import logging + +from django.http import HttpRequest +from rest_framework import status, viewsets +from rest_framework.decorators import action +from rest_framework.exceptions import APIException, ValidationError +from rest_framework.response import Response + +from .constants import UsageKeys +from .helper import UsageHelper +from .serializers import GetUsageSerializer + +logger = logging.getLogger(__name__) + + +class UsageView(viewsets.ModelViewSet): + """Viewset for managing Usage-related operations.""" + + @action(detail=True, methods=["get"]) + def get_token_usage(self, request: HttpRequest) -> Response: + """Retrieves the aggregated token usage for a given run_id. + + This method validates the 'run_id' query parameter, aggregates the token + usage statistics for the specified run_id, and returns the results. + + Args: + request (HttpRequest): The HTTP request object containing the + query parameters. + + Returns: + Response: A Response object containing the aggregated token usage data + with HTTP 200 OK status if successful, or an error message and + appropriate HTTP status if an error occurs. + + Raises: + ValidationError: If the 'run_id' query parameter is missing or invalid. + APIException: If an unexpected error occurs during the execution. + """ + + try: + # Validate the query parameters using the serializer + # This ensures that 'run_id' is present and valid + serializer = GetUsageSerializer(data=self.request.query_params) + serializer.is_valid(raise_exception=True) + run_id = serializer.validated_data.get(UsageKeys.RUN_ID) + + # Retrieve aggregated token count for the given run_id. 
+ result: dict = UsageHelper.get_aggregated_token_count(run_id=run_id) + + # Log the successful completion of the operation + logger.info(f"Token usage retrieved successfully for run_id: {run_id}") + + # Return the result + return Response(status=status.HTTP_200_OK, data=result) + except ValidationError as e: + # Handle validation errors specifically + logger.error(f"Validation error: {e.detail}") + raise ValidationError(detail=f"Validation error: {str(e)}") + except Exception as e: + # Handle any other exceptions that might occur during the execution + error_msg = "An unexpected error occurred while fetching the token usage" + logger.error(f"{error_msg}: {e}") + raise APIException(detail=error_msg) diff --git a/backend/workflow_manager/endpoint_v2/__init__ .py b/backend/workflow_manager/endpoint_v2/__init__ .py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/workflow_manager/endpoint_v2/apps.py b/backend/workflow_manager/endpoint_v2/apps.py new file mode 100644 index 000000000..bf6010bf0 --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class WorkflowEndpointConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "workflow_manager.endpoint_v2" diff --git a/backend/workflow_manager/endpoint_v2/base_connector.py b/backend/workflow_manager/endpoint_v2/base_connector.py new file mode 100644 index 000000000..3207537e6 --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/base_connector.py @@ -0,0 +1,91 @@ +import json +from typing import Any + +from django.conf import settings +from fsspec import AbstractFileSystem +from unstract.workflow_execution.execution_file_handler import ExecutionFileHandler +from utils.constants import Common +from utils.user_context import UserContext + +from unstract.connectors.filesystems import connectors +from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem + + +class BaseConnector(ExecutionFileHandler): + """Base class for connectors providing common methods and utilities.""" + + def __init__( + self, workflow_id: str, execution_id: str, organization_id: str + ) -> None: + """Initialize the BaseConnector class. + + This class serves as a base for connectors and provides common + utilities. + """ + if not (settings.API_STORAGE_DIR and settings.WORKFLOW_DATA_DIR): + raise ValueError("Missed env API_STORAGE_DIR or WORKFLOW_DATA_DIR") + super().__init__(workflow_id, execution_id, organization_id) + # Directory path for storing execution-related files for API + self.api_storage_dir: str = self.create_execution_dir_path( + workflow_id, execution_id, organization_id, settings.API_STORAGE_DIR + ) + + def get_fsspec( + self, settings: dict[str, Any], connector_id: str + ) -> AbstractFileSystem: + """Get an fsspec file system based on the specified connector. + + Parameters: + - settings (dict): Connector-specific settings. + - connector_id (str): Identifier for the desired connector. + + Returns: + AbstractFileSystem: An fsspec file system instance. + + Raises: + KeyError: If the connector_id is not found in the connectors dictionary. 
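+            ValueError: If the connector_id is not registered among the available
+                connectors (raised by the explicit check in the implementation).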
+ """ + if connector_id not in connectors: + raise ValueError(f"Invalid connector_id: {connector_id}") + connector = connectors[connector_id][Common.METADATA][Common.CONNECTOR] + connector_class: UnstractFileSystem = connector(settings) + return connector_class.get_fsspec_fs() + + @classmethod + def get_json_schema(cls, file_path: str) -> dict[str, Any]: + """Load and return a JSON schema from the specified file path. + + Parameters: + - file_path (str): The path to the JSON schema file. + + Returns: + dict: The loaded JSON schema. + + Raises: + json.JSONDecodeError: If there is an issue decoding the JSON file. + """ + try: + with open(file_path, encoding="utf-8") as file: + schema: dict[str, Any] = json.load(file) + except OSError: + schema = {} + return schema + + @classmethod + def get_api_storage_dir_path(cls, workflow_id: str, execution_id: str) -> str: + """Get the directory path for storing api files. + + Parameters: + - workflow_id (str): Identifier for the workflow. + - execution_id (str): Identifier for the execution. + - organization_id (Optional[str]): Identifier for the organization + (default: None). + + Returns: + str: The directory path for the execution. + """ + organization_id = UserContext.get_organization_identifier() + api_storage_dir: str = cls.create_execution_dir_path( + workflow_id, execution_id, organization_id, settings.API_STORAGE_DIR + ) + return api_storage_dir diff --git a/backend/workflow_manager/endpoint_v2/constants.py b/backend/workflow_manager/endpoint_v2/constants.py new file mode 100644 index 000000000..d9553245d --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/constants.py @@ -0,0 +1,105 @@ +class TableColumns: + CREATED_BY = "created_by" + CREATED_AT = "created_at" + PERMANENT_COLUMNS = ["created_by", "created_at"] + + +class DBConnectionClass: + SNOWFLAKE = "SnowflakeDB" + BIGQUERY = "BigQuery" + MSSQL = "MSSQL" + + +class Snowflake: + COLUMN_TYPES = [ + "VARCHAR", + "CHAR", + "CHARACTER", + "STRING", + "TEXT", + "BINARY", + "VARBINARY", + "DATE", + "DATETIME", + "TIME", + "TIMESTAMP", + "TIMESTAMP_LTZ", + "TIMESTAMP_NTZ", + "TIMESTAMP_TZ", + "BOOLEAN", + ] + + +class FileSystemConnector: + MAX_FILES = 100 + + +class WorkflowFileType: + SOURCE = "SOURCE" + INFILE = "INFILE" + METADATA_JSON = "METADATA.json" + + +class SourceKey: + FILE_EXTENSIONS = "fileExtensions" + PROCESS_SUB_DIRECTORIES = "processSubDirectories" + MAX_FILES = "maxFiles" + ROOT_FOLDER = "rootFolder" + + +class DestinationKey: + TABLE = "table" + INCLUDE_AGENT = "includeAgent" + INCLUDE_TIMESTAMP = "includeTimestamp" + AGENT_NAME = "agentName" + COLUMN_MODE = "columnMode" + SINGLE_COLUMN_NAME = "singleColumnName" + PATH = "path" + OUTPUT_FOLDER = "outputFolder" + OVERWRITE_OUTPUT_DOCUMENT = "overwriteOutput" + + +class OutputJsonKey: + JSON_RESULT_KEY = "result" + + +class FileType: + PDF_DOCUMENTS = "PDF documents" + TEXT_DOCUMENTS = "Text documents" + IMAGES = "Images" + + +class FilePattern: + PDF_DOCUMENTS = ["*.pdf"] + TEXT_DOCUMENTS = ["*.txt"] + IMAGES = ["*.jpg", "*.jpeg", "*.png", "*.gif", "*.bmp"] + + +class SourceConstant: + MAX_RECURSIVE_DEPTH = 10 + + +class ApiDeploymentResultStatus: + SUCCESS = "Success" + FAILED = "Failed" + + +class BigQuery: + """In big query, table name has to be in the format {db}.{schema}.{table} + Throws error if any of the params not set. 
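+    e.g. "my_project.my_dataset.my_table" (illustrative identifiers).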
+ + When converted to list table size should be 3 + """ + + TABLE_NAME_SIZE = 3 + COLUMN_TYPES = [ + "DATE", + "DATETIME", + "TIME", + "TIMESTAMP", + ] + + +class QueueResultStatus: + SUCCESS = "Success" + FAILED = "Failed" diff --git a/backend/workflow_manager/endpoint_v2/database_utils.py b/backend/workflow_manager/endpoint_v2/database_utils.py new file mode 100644 index 000000000..00b9fbf73 --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/database_utils.py @@ -0,0 +1,350 @@ +import datetime +import json +import logging +import uuid +from typing import Any, Optional + +from utils.constants import Common +from workflow_manager.endpoint_v2.constants import ( + BigQuery, + DBConnectionClass, + TableColumns, +) +from workflow_manager.endpoint_v2.db_connector_helper import DBConnectorQueryHelper +from workflow_manager.endpoint_v2.exceptions import ( + BigQueryTableNotFound, + UnstractDBException, +) +from workflow_manager.workflow_v2.enums import AgentName, ColumnModes + +from unstract.connectors.databases import connectors as db_connectors +from unstract.connectors.databases.exceptions import UnstractDBConnectorException +from unstract.connectors.databases.unstract_db import UnstractDB +from unstract.connectors.exceptions import ConnectorError + +logger = logging.getLogger(__name__) + + +class DatabaseUtils: + @staticmethod + def get_sql_values_for_query( + values: dict[str, Any], column_types: dict[str, str], cls_name: str + ) -> dict[str, str]: + """Making Sql Columns and Values for Query. + + Args: + values (dict[str, Any]): dictionary of columns and values + column_types (dict[str,str]): types of columns + cls (Any, optional): The database connection class (e.g., + DBConnectionClass.SNOWFLAKE) for handling database-specific + queries. + Defaults to None. + + Returns: + list[str]: _description_ + + Note: + - If `cls` is not provided or is None, the function assumes a + Default SQL database and makes values accordingly. + - If `cls` is provided and matches DBConnectionClass.SNOWFLAKE, + the function makes values using Snowflake-specific syntax. + + - Unstract creates id by default if table not exists. + If there is column 'id' in db table, it will insert + 'id' as uuid into the db table. + Else it will GET table details from INFORMATION SCHEMA and + insert into the table accordingly + """ + sql_values: dict[str, Any] = {} + for column in values: + if cls_name == DBConnectionClass.SNOWFLAKE: + col = column.lower() + type_x = column_types[col] + if type_x == "VARIANT": + values[column] = values[column].replace("'", "\\'") + sql_values[column] = f"parse_json($${values[column]}$$)" + else: + sql_values[column] = f"{values[column]}" + else: + # Default to Other SQL DBs + # TODO: Handle numeric types with no quotes + sql_values[column] = f"{values[column]}" + if column_types.get("id"): + uuid_id = str(uuid.uuid4()) + sql_values["id"] = f"{uuid_id}" + return sql_values + + @staticmethod + def get_column_types_util(columns_with_types: Any) -> dict[str, str]: + """Converts db results columns_with_types to dict. + + Args: + columns_with_types (Any): _description_ + + Returns: + dict[str, str]: _description_ + """ + column_types: dict[str, str] = {} + for column_name, data_type in columns_with_types: + column_types[column_name] = data_type + return column_types + + @staticmethod + def get_column_types( + cls_name: Any, + table_name: str, + connector_id: str, + connector_settings: dict[str, Any], + ) -> Any: + """Get db column name and types. 
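+        Dispatches on cls_name: DESCRIBE TABLE is used for Snowflake, while an
+        INFORMATION_SCHEMA.COLUMNS lookup is used for BigQuery and other SQL DBs.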
+ + Args: + cls (Any): _description_ + table_name (str): _description_ + connector_id (str): _description_ + connector_settings (dict[str, Any]): _description_ + + Raises: + ValueError: _description_ + e: _description_ + + Returns: + Any: _description_ + """ + column_types: dict[str, str] = {} + try: + if cls_name == DBConnectionClass.SNOWFLAKE: + query = f"describe table {table_name}" + results = DatabaseUtils.execute_and_fetch_data( + connector_id=connector_id, + connector_settings=connector_settings, + query=query, + ) + for column in results: + column_types[column[0].lower()] = column[1].split("(")[0] + elif cls_name == DBConnectionClass.BIGQUERY: + bigquery_table_name = str.lower(table_name).split(".") + if len(bigquery_table_name) != BigQuery.TABLE_NAME_SIZE: + raise BigQueryTableNotFound() + database = bigquery_table_name[0] + schema = bigquery_table_name[1] + table = bigquery_table_name[2] + query = ( + "SELECT column_name, data_type FROM " + f"{database}.{schema}.INFORMATION_SCHEMA.COLUMNS WHERE " + f"table_name = '{table}'" + ) + results = DatabaseUtils.execute_and_fetch_data( + connector_id=connector_id, + connector_settings=connector_settings, + query=query, + ) + column_types = DatabaseUtils.get_column_types_util(results) + else: + table_name = str.lower(table_name) + query = ( + "SELECT column_name, data_type FROM " + "information_schema.columns WHERE " + f"table_name = '{table_name}'" + ) + results = DatabaseUtils.execute_and_fetch_data( + connector_id=connector_id, + connector_settings=connector_settings, + query=query, + ) + column_types = DatabaseUtils.get_column_types_util(results) + except Exception as e: + logger.error(f"Error getting column types for {table_name}: {str(e)}") + raise e + return column_types + + @staticmethod + def get_columns_and_values( + column_mode_str: str, + data: Any, + include_timestamp: bool = False, + include_agent: bool = False, + agent_name: Optional[str] = AgentName.UNSTRACT_DBWRITER.value, + single_column_name: str = "data", + ) -> dict[str, Any]: + """Generate a dictionary of columns and values based on specified + parameters. + + Args: + column_mode_str (str): The string representation of the column mode, + which determines how data is stored in the dictionary. + data (Any): The data to be stored in the dictionary. + include_timestamp (bool, optional): Whether to include the + current timestamp in the dictionary. Defaults to False. + include_agent (bool, optional): Whether to include the agent's name + in the dictionary. Defaults to False. + agent_name (str, optional): The name of the agent when include_agent + is true. Defaults to AgentName.UNSTRACT_DBWRITER. + single_column_name (str, optional): The name of the single column + when using 'WRITE_JSON_TO_A_SINGLE_COLUMN' mode. + Defaults to "data". + + Returns: + dict: A dictionary containing columns and values based on + the specified parameters. 
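+
+        Example (illustrative):
+            In WRITE_JSON_TO_A_SINGLE_COLUMN mode with include_agent=True and
+            single_column_name="data", a dict {"k": "v"} is returned as
+            {"created_by": <agent_name>, "data": '{"k": "v"}'}.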
+ """ + + values: dict[str, Any] = {} + try: + column_mode = ColumnModes(column_mode_str) + except ValueError: + # Handle the case where the string is not a valid enum value + column_mode = ColumnModes.WRITE_JSON_TO_A_SINGLE_COLUMN + + if include_agent and agent_name: + values[TableColumns.CREATED_BY] = agent_name + + if include_timestamp: + values[TableColumns.CREATED_AT] = datetime.datetime.now() + + if column_mode == ColumnModes.WRITE_JSON_TO_A_SINGLE_COLUMN: + if isinstance(data, str): + values[single_column_name] = data + else: + values[single_column_name] = json.dumps(data) + if column_mode == ColumnModes.SPLIT_JSON_INTO_COLUMNS: + if isinstance(data, dict): + values.update(data) + elif isinstance(data, str): + values[single_column_name] = data + else: + values[single_column_name] = json.dumps(data) + + return values + + @staticmethod + def get_sql_query_data( + cls_name: str, + connector_id: str, + connector_settings: dict[str, Any], + table_name: str, + values: dict[str, Any], + ) -> dict[str, Any]: + """Generate SQL columns and values for an insert query based on the + provided values and table schema. + + Args: + connector_id: The connector id of the connector provided + connector_settings: Connector settings provided by user + table_name (str): The name of the target table for the insert query. + values (dict[str, Any]): A dictionary containing column-value pairs + for the insert query. + + Returns: + list[str]: A list of SQL values suitable for use in an insert query. + + Note: + - This function determines the database type based on the + `engine` parameter. + - If the database is Snowflake (DBConnectionClass.SNOWFLAKE), + it handles Snowflake-specific SQL generation. + - For other SQL databases, it uses default SQL generation + based on column types. + """ + column_types: dict[str, str] = DatabaseUtils.get_column_types( + cls_name=cls_name, + table_name=table_name, + connector_id=connector_id, + connector_settings=connector_settings, + ) + sql_columns_and_values = DatabaseUtils.get_sql_values_for_query( + values=values, + column_types=column_types, + cls_name=cls_name, + ) + return sql_columns_and_values + + @staticmethod + def execute_write_query( + db_class: UnstractDB, + engine: Any, + table_name: str, + sql_keys: list[str], + sql_values: Any, + ) -> None: + """Execute Insert Query. + + Args: + engine (Any): _description_ + table_name (str): table name + sql_keys (list[str]): columns + sql_values (list[str]): values + Notes: + - Snowflake does not support INSERT INTO ... VALUES ... + syntax when VARIANT columns are present (JSON). + So we need to use INSERT INTO ... SELECT ... syntax + - sql values can contain data with single quote. 
It needs to + """ + cls_name = db_class.__class__.__name__ + sql = DBConnectorQueryHelper.build_sql_insert_query( + cls_name=cls_name, table_name=table_name, sql_keys=sql_keys + ) + logger.debug(f"inserting into table {table_name} with: {sql} query") + + sql_values = DBConnectorQueryHelper.prepare_sql_values( + cls_name=cls_name, sql_values=sql_values, sql_keys=sql_keys + ) + logger.debug(f"sql_values: {sql_values}") + + try: + db_class.execute_query( + engine=engine, + sql_query=sql, + sql_values=sql_values, + table_name=table_name, + sql_keys=sql_keys, + ) + except UnstractDBConnectorException as e: + raise UnstractDBException(detail=e.detail) from e + logger.debug(f"sucessfully inserted into table {table_name} with: {sql} query") + + @staticmethod + def get_db_class( + connector_id: str, connector_settings: dict[str, Any] + ) -> UnstractDB: + connector = db_connectors[connector_id][Common.METADATA][Common.CONNECTOR] + connector_class: UnstractDB = connector(connector_settings) + return connector_class + + @staticmethod + def execute_and_fetch_data( + connector_id: str, connector_settings: dict[str, Any], query: str + ) -> Any: + connector = db_connectors[connector_id][Common.METADATA][Common.CONNECTOR] + connector_class: UnstractDB = connector(connector_settings) + try: + return connector_class.execute(query=query) + except ConnectorError as e: + raise UnstractDBException(detail=e.message) from e + + @staticmethod + def create_table_if_not_exists( + db_class: UnstractDB, + engine: Any, + table_name: str, + database_entry: dict[str, Any], + ) -> None: + """Creates table if not exists. + + Args: + class_name (UnstractDB): Type of Unstract DB connector + table_name (str): _description_ + database_entry (dict[str, Any]): _description_ + + Raises: + e: _description_ + """ + sql = DBConnectorQueryHelper.create_table_query( + conn_cls=db_class, table=table_name, database_entry=database_entry + ) + logger.debug(f"creating table {table_name} with: {sql} query") + try: + db_class.execute_query(engine=engine, sql_query=sql, sql_values=None) + except UnstractDBConnectorException as e: + raise UnstractDBException(detail=e.detail) from e + logger.debug(f"successfully created table {table_name} with: {sql} query") diff --git a/backend/workflow_manager/endpoint_v2/db_connector_helper.py b/backend/workflow_manager/endpoint_v2/db_connector_helper.py new file mode 100644 index 000000000..6c1e92de3 --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/db_connector_helper.py @@ -0,0 +1,77 @@ +from typing import Any + +from google.cloud import bigquery +from workflow_manager.endpoint_v2.constants import DBConnectionClass, TableColumns + +from unstract.connectors.databases.unstract_db import UnstractDB + + +class DBConnectorQueryHelper: + """A class that helps to generate query for connector table operations.""" + + @staticmethod + def create_table_query( + conn_cls: UnstractDB, table: str, database_entry: dict[str, Any] + ) -> Any: + sql_query = "" + """Generate a SQL query to create a table, based on the provided + database entry. + + Args: + conn_cls (str): The database connector class. + Should be one of 'BIGQUERY', 'SNOWFLAKE', or other. + table (str): The name of the table to be created. + database_entry (dict[str, Any]): + A dictionary containing column names as keys + and their corresponding values. + + These values are used to determine the data types, + for the columns in the table. + + Returns: + str: A SQL query string to create a table with the specified name, + and column definitions. 
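+
+        Example (illustrative):
+            For database_entry {"data": "hello"}, the base query returned by
+            get_create_table_query(table=...) is extended with a "data <mapped type>"
+            column definition and terminated with ");".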
+ + Note: + - Each conn_cls have it's implementation for SQL create table query + Based on the implementation, a base SQL create table query will be + created containing Permanent columns + - Each conn_cls also has a mapping to convert python datatype to + corresponding column type (string, VARCHAR etc) + - keys in database_entry will be converted to column type, and + column values will be the valus in database_entry + - base SQL create table will be appended based column type and + values, and generates a complete SQL create table query + """ + create_table_query = conn_cls.get_create_table_query(table=table) + sql_query += create_table_query + + for key, val in database_entry.items(): + if key not in TableColumns.PERMANENT_COLUMNS: + sql_type = conn_cls.sql_to_db_mapping(val) + sql_query += f"{key} {sql_type}, " + + return sql_query.rstrip(", ") + ");" + + @staticmethod + def build_sql_insert_query( + cls_name: str, table_name: str, sql_keys: list[str] + ) -> str: + keys_str = ",".join(sql_keys) + if cls_name == DBConnectionClass.BIGQUERY: + values_placeholder = ",".join(["@" + key for key in sql_keys]) + else: + values_placeholder = ",".join(["%s" for _ in sql_keys]) + return f"INSERT INTO {table_name} ({keys_str}) VALUES ({values_placeholder})" + + @staticmethod + def prepare_sql_values(cls_name: str, sql_values: Any, sql_keys: list[str]) -> Any: + if cls_name == DBConnectionClass.MSSQL: + return tuple(sql_values) + elif cls_name == DBConnectionClass.BIGQUERY: + query_parameters = [ + bigquery.ScalarQueryParameter(key, "STRING", value) + for key, value in zip(sql_keys, sql_values) + ] + return bigquery.QueryJobConfig(query_parameters=query_parameters) + return sql_values diff --git a/backend/workflow_manager/endpoint_v2/destination.py b/backend/workflow_manager/endpoint_v2/destination.py new file mode 100644 index 000000000..c882c1b0e --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/destination.py @@ -0,0 +1,532 @@ +import ast +import base64 +import json +import logging +import os +from typing import Any, Optional + +import fsspec +import magic +from connector_v2.models import ConnectorInstance +from fsspec.implementations.local import LocalFileSystem +from unstract.sdk.constants import ToolExecKey +from unstract.workflow_execution.constants import ToolOutputType +from utils.user_context import UserContext +from workflow_manager.endpoint_v2.base_connector import BaseConnector +from workflow_manager.endpoint_v2.constants import ( + ApiDeploymentResultStatus, + DestinationKey, + QueueResultStatus, + WorkflowFileType, +) +from workflow_manager.endpoint_v2.database_utils import DatabaseUtils +from workflow_manager.endpoint_v2.exceptions import ( + DestinationConnectorNotConfigured, + InvalidDestinationConnectionType, + InvalidToolOutputType, + MissingDestinationConnectionType, + ToolOutputTypeMismatch, +) +from workflow_manager.endpoint_v2.models import WorkflowEndpoint +from workflow_manager.endpoint_v2.queue_utils import QueueResult, QueueUtils +from workflow_manager.workflow_v2.enums import ExecutionStatus +from workflow_manager.workflow_v2.file_history_helper import FileHistoryHelper +from workflow_manager.workflow_v2.models.file_history import FileHistory +from workflow_manager.workflow_v2.models.workflow import Workflow + +logger = logging.getLogger(__name__) + + +class DestinationConnector(BaseConnector): + """A class representing a Destination connector for a workflow. 
+ + This class extends the BaseConnector class and provides methods for + interacting with different types of destination connectors, + such as file system connectors and API connectors and DB connectors. + + Attributes: + workflow (Workflow): The workflow associated with + the destination connector. + """ + + def __init__(self, workflow: Workflow, execution_id: str) -> None: + """Initialize a DestinationConnector object. + + Args: + workflow (Workflow): _description_ + """ + organization_id = UserContext.get_organization_identifier() + super().__init__(workflow.id, execution_id, organization_id) + self.endpoint = self._get_endpoint_for_workflow(workflow=workflow) + self.source_endpoint = self._get_source_endpoint_for_workflow(workflow=workflow) + self.execution_id = execution_id + self.api_results: list[dict[str, Any]] = [] + self.queue_results: list[dict[str, Any]] = [] + + def _get_endpoint_for_workflow( + self, + workflow: Workflow, + ) -> WorkflowEndpoint: + """Get WorkflowEndpoint instance. + + Args: + workflow (Workflow): Workflow associated with the + destination connector. + + Returns: + WorkflowEndpoint: WorkflowEndpoint instance. + """ + endpoint: WorkflowEndpoint = WorkflowEndpoint.objects.get( + workflow=workflow, + endpoint_type=WorkflowEndpoint.EndpointType.DESTINATION, + ) + if endpoint.connector_instance: + endpoint.connector_instance.connector_metadata = ( + endpoint.connector_instance.metadata + ) + return endpoint + + def _get_source_endpoint_for_workflow( + self, + workflow: Workflow, + ) -> WorkflowEndpoint: + """Get WorkflowEndpoint instance. + + Args: + workflow (Workflow): Workflow associated with the + destination connector. + + Returns: + WorkflowEndpoint: WorkflowEndpoint instance. + """ + endpoint: WorkflowEndpoint = WorkflowEndpoint.objects.get( + workflow=workflow, + endpoint_type=WorkflowEndpoint.EndpointType.SOURCE, + ) + if endpoint.connector_instance: + endpoint.connector_instance.connector_metadata = ( + endpoint.connector_instance.metadata + ) + return endpoint + + def validate(self) -> None: + connection_type = self.endpoint.connection_type + connector: ConnectorInstance = self.endpoint.connector_instance + if connection_type is None: + raise MissingDestinationConnectionType() + if connection_type not in WorkflowEndpoint.ConnectionType.values: + raise InvalidDestinationConnectionType() + if ( + connection_type != WorkflowEndpoint.ConnectionType.API + and connection_type != WorkflowEndpoint.ConnectionType.MANUALREVIEW + and connector is None + ): + raise DestinationConnectorNotConfigured() + + def handle_output( + self, + file_name: str, + file_hash: str, + workflow: Workflow, + file_history: Optional[FileHistory] = None, + error: Optional[str] = None, + input_file_path: Optional[str] = None, + ) -> None: + """Handle the output based on the connection type.""" + connection_type = self.endpoint.connection_type + result: Optional[str] = None + meta_data: Optional[str] = None + if error: + if connection_type == WorkflowEndpoint.ConnectionType.API: + self._handle_api_result(file_name=file_name, error=error, result=result) + return + if connection_type == WorkflowEndpoint.ConnectionType.FILESYSTEM: + self.copy_output_to_output_directory() + elif connection_type == WorkflowEndpoint.ConnectionType.DATABASE: + self.insert_into_db(file_history) + elif connection_type == WorkflowEndpoint.ConnectionType.API: + result = self.get_result(file_history) + meta_data = self.get_metadata(file_history) + self._handle_api_result( + file_name=file_name, error=error, 
result=result, meta_data=meta_data + ) + elif connection_type == WorkflowEndpoint.ConnectionType.MANUALREVIEW: + result = self.get_result(file_history) + meta_data = self.get_metadata(file_history) + self._push_to_queue( + file_name=file_name, + workflow=workflow, + result=result, + input_file_path=input_file_path, + meta_data=meta_data, + ) + if not file_history: + FileHistoryHelper.create_file_history( + cache_key=file_hash, + workflow=workflow, + status=ExecutionStatus.COMPLETED, + result=result, + metadata=meta_data, + file_name=file_name, + ) + + def copy_output_to_output_directory(self) -> None: + """Copy output to the destination directory.""" + connector: ConnectorInstance = self.endpoint.connector_instance + connector_settings: dict[str, Any] = connector.connector_metadata + destination_configurations: dict[str, Any] = self.endpoint.configuration + root_path = str(connector_settings.get(DestinationKey.PATH)) + output_folder = str( + destination_configurations.get(DestinationKey.OUTPUT_FOLDER, "/") + ) + overwrite = bool( + destination_configurations.get( + DestinationKey.OVERWRITE_OUTPUT_DOCUMENT, False + ) + ) + output_directory = os.path.join(root_path, output_folder) + + destination_volume_path = os.path.join( + self.execution_dir, ToolExecKey.OUTPUT_DIR + ) + + connector_fs = self.get_fsspec( + settings=connector_settings, connector_id=connector.connector_id + ) + if not connector_fs.isdir(output_directory): + connector_fs.mkdir(output_directory) + + # Traverse local directory and create the same structure in the + # output_directory + for root, dirs, files in os.walk(destination_volume_path): + for dir_name in dirs: + connector_fs.mkdir( + os.path.join( + output_directory, + os.path.relpath(root, destination_volume_path), + dir_name, + ) + ) + + for file_name in files: + source_path = os.path.join(root, file_name) + destination_path = os.path.join( + output_directory, + os.path.relpath(root, destination_volume_path), + file_name, + ) + normalized_path = os.path.normpath(destination_path) + with open(source_path, "rb") as source_file: + connector_fs.write_bytes( + normalized_path, source_file.read(), overwrite=overwrite + ) + + def insert_into_db(self, file_history: Optional[FileHistory]) -> None: + """Insert data into the database.""" + connector_instance: ConnectorInstance = self.endpoint.connector_instance + connector_settings: dict[str, Any] = connector_instance.metadata + destination_configurations: dict[str, Any] = self.endpoint.configuration + table_name: str = str(destination_configurations.get(DestinationKey.TABLE)) + include_agent: bool = bool( + destination_configurations.get(DestinationKey.INCLUDE_AGENT, False) + ) + include_timestamp = bool( + destination_configurations.get(DestinationKey.INCLUDE_TIMESTAMP, False) + ) + agent_name = str(destination_configurations.get(DestinationKey.AGENT_NAME)) + column_mode = str(destination_configurations.get(DestinationKey.COLUMN_MODE)) + single_column_name = str( + destination_configurations.get(DestinationKey.SINGLE_COLUMN_NAME, "data") + ) + + data = self.get_result(file_history) + values = DatabaseUtils.get_columns_and_values( + column_mode_str=column_mode, + data=data, + include_timestamp=include_timestamp, + include_agent=include_agent, + agent_name=agent_name, + single_column_name=single_column_name, + ) + db_class = DatabaseUtils.get_db_class( + connector_id=connector_instance.connector_id, + connector_settings=connector_settings, + ) + engine = db_class.get_engine() + # If data is None, don't execute CREATE or INSERT 
query + if data is None: + return + DatabaseUtils.create_table_if_not_exists( + db_class=db_class, + engine=engine, + table_name=table_name, + database_entry=values, + ) + cls_name = db_class.__class__.__name__ + sql_columns_and_values = DatabaseUtils.get_sql_query_data( + cls_name=cls_name, + connector_id=connector_instance.connector_id, + connector_settings=connector_settings, + table_name=table_name, + values=values, + ) + DatabaseUtils.execute_write_query( + db_class=db_class, + engine=engine, + table_name=table_name, + sql_keys=list(sql_columns_and_values.keys()), + sql_values=list(sql_columns_and_values.values()), + ) + + def _handle_api_result( + self, + file_name: str, + error: Optional[str] = None, + result: Optional[str] = None, + meta_data: Optional[dict[str, Any]] = None, + ) -> None: + """Handle the API result. + + This method is responsible for handling the API result. + It appends the file name and result to the 'results' list for API resp. + + Args: + file_name (str): The name of the file. + result (Optional[str], optional): The result of the API call. + Defaults to None. + + Returns: + None + """ + api_result: dict[str, Any] = {"file": file_name} + if error: + api_result.update( + {"status": ApiDeploymentResultStatus.FAILED, "error": error} + ) + else: + if result: + api_result.update( + { + "status": ApiDeploymentResultStatus.SUCCESS, + "result": result, + "metadata": meta_data, + } + ) + else: + api_result.update( + {"status": ApiDeploymentResultStatus.SUCCESS, "result": ""} + ) + self.api_results.append(api_result) + + def parse_string(self, original_string: str) -> Any: + """Parse the given string, attempting to evaluate it as a Python + literal. + ex: a json string to dict method + Parameters: + - original_string (str): The input string to be parsed. + + Returns: + - Any: The parsed result. If the string can be evaluated as a Python + literal, the result of the evaluation is returned. + If not, the original string is returned unchanged. + + Note: + This function uses `ast.literal_eval` to attempt parsing the string as a + Python literal. If parsing fails due to a SyntaxError or ValueError, + the original string is returned. + + Example: + >>> parser.parse_string("42") + 42 + >>> parser.parse_string("[1, 2, 3]") + [1, 2, 3] + >>> parser.parse_string("Hello, World!") + 'Hello, World!' + """ + try: + # Try to evaluate as a Python literal + python_literal = ast.literal_eval(original_string) + return python_literal + except (SyntaxError, ValueError): + # If evaluating as a Python literal fails, + # assume it's a plain string + return original_string + + def get_result(self, file_history: Optional[FileHistory]) -> Optional[Any]: + """Get result data from the output file. + + Returns: + Union[dict[str, Any], str]: Result data. + """ + if file_history and file_history.result: + return self.parse_string(file_history.result) + output_file = os.path.join(self.execution_dir, WorkflowFileType.INFILE) + metadata: dict[str, Any] = self.get_workflow_metadata() + output_type = self.get_output_type(metadata) + result: Optional[Any] = None + try: + # TODO: SDK handles validation; consider removing here. 
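+            # python-magic sniffs the output file's type description so that a
+            # JSON/TXT mismatch is raised before the file is parsed below.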
+ mime = magic.Magic() + file_type = mime.from_file(output_file) + if output_type == ToolOutputType.JSON: + if "JSON" not in file_type: + logger.error(f"Output type json mismatched {file_type}") + raise ToolOutputTypeMismatch() + with open(output_file) as file: + result = json.load(file) + elif output_type == ToolOutputType.TXT: + if "JSON" in file_type: + logger.error(f"Output type txt mismatched {file_type}") + raise ToolOutputTypeMismatch() + with open(output_file) as file: + result = file.read() + result = result.encode("utf-8").decode("unicode-escape") + else: + raise InvalidToolOutputType() + except (FileNotFoundError, json.JSONDecodeError) as err: + logger.error(f"Error while getting result {err}") + return result + + def get_metadata( + self, file_history: Optional[FileHistory] + ) -> Optional[dict[str, Any]]: + """Get meta_data from the output file. + + Returns: + Union[dict[str, Any], str]: Meta data. + """ + if file_history and file_history.meta_data: + return self.parse_string(file_history.meta_data) + metadata: dict[str, Any] = self.get_workflow_metadata() + + return metadata + + def delete_execution_directory(self) -> None: + """Delete the execution directory. + + Returns: + None + """ + fs: LocalFileSystem = fsspec.filesystem("file") + fs.rm(self.execution_dir, recursive=True) + self.delete_api_storage_dir(self.workflow_id, self.execution_id) + + @classmethod + def delete_api_storage_dir(cls, workflow_id: str, execution_id: str) -> None: + """Delete the api storage path. + + Returns: + None + """ + api_storage_dir = cls.get_api_storage_dir_path( + workflow_id=workflow_id, execution_id=execution_id + ) + fs: LocalFileSystem = fsspec.filesystem("file") + fs.rm(api_storage_dir, recursive=True) + + @classmethod + def create_endpoint_for_workflow( + cls, + workflow: Workflow, + ) -> None: + """Create a workflow endpoint for the destination. + + Args: + workflow (Workflow): Workflow for which the endpoint is created. + """ + endpoint = WorkflowEndpoint( + workflow=workflow, + endpoint_type=WorkflowEndpoint.EndpointType.DESTINATION, + ) + endpoint.save() + + @classmethod + def get_json_schema_for_database(cls) -> dict[str, Any]: + """Get JSON schema for the database. + + Returns: + dict[str, Any]: JSON schema for the database. + """ + schema_path = os.path.join( + os.path.dirname(__file__), "static", "dest", "db.json" + ) + return cls.get_json_schema(file_path=schema_path) + + @classmethod + def get_json_schema_for_file_system(cls) -> dict[str, Any]: + """Get JSON schema for the file system. + + Returns: + dict[str, Any]: JSON schema for the file system. + """ + schema_path = os.path.join( + os.path.dirname(__file__), "static", "dest", "file.json" + ) + return cls.get_json_schema(file_path=schema_path) + + @classmethod + def get_json_schema_for_api(cls) -> dict[str, Any]: + """Json schema for api. + + Returns: + dict[str, Any]: _description_ + """ + schema_path = os.path.join( + os.path.dirname(__file__), "static", "dest", "api.json" + ) + return cls.get_json_schema(file_path=schema_path) + + def _push_to_queue( + self, + file_name: str, + workflow: Workflow, + result: Optional[str] = None, + input_file_path: Optional[str] = None, + meta_data: Optional[dict[str, Any]] = None, + ) -> None: + """Handle the Manual Review QUEUE result. + + This method is responsible for pushing the input file and result to + review queue. + Args: + file_name (str): The name of the file. + workflow (Workflow): The workflow object containing + details about the workflow. 
+ result (Optional[str], optional): The result of the API call. + Defaults to None. + input_file_path (Optional[str], optional): + The path to the input file. + Defaults to None. + meta_data (Optional[dict[str, Any]], optional): + A dictionary containing additional + metadata related to the file. Defaults to None. + + Returns: + None + """ + if not result: + return + connector: ConnectorInstance = self.source_endpoint.connector_instance + connector_settings: dict[str, Any] = connector.connector_metadata + + source_fs = self.get_fsspec( + settings=connector_settings, connector_id=connector.connector_id + ) + with source_fs.open(input_file_path, "rb") as remote_file: + file_content = remote_file.read() + # Convert file content to a base64 encoded string + file_content_base64 = base64.b64encode(file_content).decode("utf-8") + q_name = f"review_queue_{self.organization_id}_{workflow.workflow_name}" + queue_result = QueueResult( + file=file_name, + whisper_hash=meta_data["whisper-hash"], + status=QueueResultStatus.SUCCESS, + result=result, + workflow_id=str(self.workflow_id), + file_content=file_content_base64, + ) + # Convert the result dictionary to a JSON string + queue_result_json = json.dumps(queue_result) + conn = QueueUtils.get_queue_inst() + # Enqueue the JSON string + conn.enqueue(queue_name=q_name, message=queue_result_json) diff --git a/backend/workflow_manager/endpoint_v2/endpoint_utils.py b/backend/workflow_manager/endpoint_v2/endpoint_utils.py new file mode 100644 index 000000000..70207b3a9 --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/endpoint_utils.py @@ -0,0 +1,30 @@ +from workflow_manager.endpoint_v2.destination import DestinationConnector +from workflow_manager.endpoint_v2.models import WorkflowEndpoint +from workflow_manager.endpoint_v2.source import SourceConnector +from workflow_manager.workflow_v2.models.workflow import Workflow +from workflow_manager.workflow_v2.workflow_helper import WorkflowHelper + + +class WorkflowEndpointUtils: + @staticmethod + def create_endpoints_for_workflow(workflow: Workflow) -> None: + """Create endpoints for a given workflow. This method creates both + source and destination endpoints for the specified workflow. + + Parameters: + workflow (Workflow): The workflow for which + the endpoints need to be created. + + Returns: + None + """ + SourceConnector.create_endpoint_for_workflow(workflow) + DestinationConnector.create_endpoint_for_workflow(workflow) + + @staticmethod + def get_endpoints_for_workflow(workflow_id: str) -> list[WorkflowEndpoint]: + workflow = WorkflowHelper.get_workflow_by_id(workflow_id) + endpoints: list[WorkflowEndpoint] = WorkflowEndpoint.objects.filter( + workflow=workflow + ) + return endpoints diff --git a/backend/workflow_manager/endpoint_v2/exceptions.py b/backend/workflow_manager/endpoint_v2/exceptions.py new file mode 100644 index 000000000..7028eec28 --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/exceptions.py @@ -0,0 +1,87 @@ +from rest_framework.exceptions import APIException + + +class InvalidInputDirectory(APIException): + status_code = 400 + default_detail = "The provided directory is invalid." + + +class InvalidSourceConnectionType(APIException): + status_code = 400 + default_detail = "The provided source connection type is invalid." + + +class InvalidDestinationConnectionType(APIException): + status_code = 400 + default_detail = "The provided destination connection type is invalid." 
+
+
+class MissingSourceConnectionType(APIException):
+    status_code = 400
+    default_detail = "The source connection type is missing."
+
+
+class MissingDestinationConnectionType(APIException):
+    status_code = 400
+    default_detail = "The destination connection type is missing."
+
+
+class SourceConnectorNotConfigured(APIException):
+    status_code = 400
+    default_detail = "The source connector is not configured"
+
+
+class DestinationConnectorNotConfigured(APIException):
+    status_code = 400
+    default_detail = "The destination connector is not configured"
+
+
+class FileHashNotFound(APIException):
+    status_code = 500
+    default_detail = "Internal server error: File hash not found."
+
+
+class ToolMetadataNotFound(APIException):
+    status_code = 500
+    default_detail = "Internal server error: Tool metadata not found."
+
+
+class OrganizationIdNotFound(APIException):
+    status_code = 404
+    default_detail = "The organization ID could not be found"
+
+
+class InvalidToolOutputType(APIException):
+    status_code = 500
+    default_detail = "Invalid output type is returned from tool"
+
+
+class ToolOutputTypeMismatch(APIException):
+    status_code = 400
+    default_detail = (
+        "The data type of the tool's output does not match the expected type."
+    )
+
+
+class BigQueryTableNotFound(APIException):
+    status_code = 400
+    default_detail = (
+        "Please enter a valid BigQuery table in the form "
+        "{database}.{schema}.{table}."
+    )
+
+
+class UnstractDBException(APIException):
+    default_detail = "Error creating/inserting to database. "
+
+    def __init__(self, detail: str = default_detail) -> None:
+        status_code = 500
+        super().__init__(detail=detail, code=status_code)
+
+
+class UnstractQueueException(APIException):
+    default_detail = "Error creating/inserting to Queue. 
" + + def __init__(self, detail: str = default_detail) -> None: + status_code = 500 + super().__init__(detail=detail, code=status_code) diff --git a/backend/workflow_manager/endpoint_v2/models.py b/backend/workflow_manager/endpoint_v2/models.py new file mode 100644 index 000000000..6b067160e --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/models.py @@ -0,0 +1,54 @@ +import uuid + +from connector_v2.models import ConnectorInstance +from django.db import models +from utils.models.base_model import BaseModel +from workflow_manager.workflow_v2.models.workflow import Workflow + + +class WorkflowEndpoint(BaseModel): + class EndpointType(models.TextChoices): + SOURCE = "SOURCE", "Source connector" + DESTINATION = "DESTINATION", "Destination Connector" + + class ConnectionType(models.TextChoices): + FILESYSTEM = "FILESYSTEM", "FileSystem connector" + DATABASE = "DATABASE", "Database Connector" + API = "API", "API Connector" + APPDEPLOYMENT = "APPDEPLOYMENT", "App Deployment" + MANUALREVIEW = "MANUALREVIEW", "Manual Review Queue Connector" + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + workflow = models.ForeignKey( + Workflow, + on_delete=models.CASCADE, + db_index=True, + editable=False, + db_comment="Foreign key from Workflow model", + ) + endpoint_type = models.CharField( + choices=EndpointType.choices, + editable=False, + db_comment="Endpoint type (source or destination)", + ) + connection_type = models.CharField( + choices=ConnectionType.choices, + blank=True, + db_comment="Connection type (Filesystem, Database, API or Manualreview)", + ) + configuration = models.JSONField( + blank=True, null=True, db_comment="Configuration in JSON format" + ) + connector_instance = models.ForeignKey( + ConnectorInstance, + on_delete=models.CASCADE, + db_index=True, + null=True, + db_comment="Foreign key from ConnectorInstance model", + related_name="workflow_endpoints", + ) + + class Meta: + verbose_name = "Workflow Endpoint" + verbose_name_plural = "Workflow Endpoints" + db_table = "workflow_endpoints_v2" diff --git a/backend/workflow_manager/endpoint_v2/queue_utils.py b/backend/workflow_manager/endpoint_v2/queue_utils.py new file mode 100644 index 000000000..fde790e42 --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/queue_utils.py @@ -0,0 +1,41 @@ +import logging +from dataclasses import dataclass +from enum import Enum +from typing import Any + +from utils.constants import Common +from workflow_manager.endpoint_v2.exceptions import UnstractQueueException + +from unstract.connectors.queues import connectors as queue_connectors +from unstract.connectors.queues.unstract_queue import UnstractQueue + +logger = logging.getLogger(__name__) + + +class QueueResultStatus(Enum): + SUCCESS = "success" + FAILURE = "failure" + # Add other statuses as needed + + +class QueueUtils: + @staticmethod + def get_queue_inst(connector_settings: dict[str, Any] = {}) -> UnstractQueue: + if not queue_connectors: + raise UnstractQueueException(detail="Queue connector not exists") + queue_connector_key = next(iter(queue_connectors)) + connector = queue_connectors[queue_connector_key][Common.METADATA][ + Common.CONNECTOR + ] + connector_class: UnstractQueue = connector(connector_settings) + return connector_class + + +@dataclass +class QueueResult: + file: str + whisper_hash: str + status: QueueResultStatus + result: Any + workflow_id: str + file_content: str diff --git a/backend/workflow_manager/endpoint_v2/serializers.py b/backend/workflow_manager/endpoint_v2/serializers.py new 
file mode 100644 index 000000000..2ce3d4a8f --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/serializers.py @@ -0,0 +1,12 @@ +import logging + +from rest_framework.serializers import ModelSerializer +from workflow_manager.endpoint_v2.models import WorkflowEndpoint + +logger = logging.getLogger(__name__) + + +class WorkflowEndpointSerializer(ModelSerializer): + class Meta: + model = WorkflowEndpoint + fields = "__all__" diff --git a/backend/workflow_manager/endpoint_v2/source.py b/backend/workflow_manager/endpoint_v2/source.py new file mode 100644 index 000000000..57ef8d3f5 --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/source.py @@ -0,0 +1,461 @@ +import fnmatch +import logging +import os +import shutil +from hashlib import md5, sha256 +from io import BytesIO +from pathlib import Path +from typing import Any, Optional + +import fsspec +from connector_processor.constants import ConnectorKeys +from connector_v2.models import ConnectorInstance +from django.core.files.uploadedfile import UploadedFile +from unstract.workflow_execution.enums import LogState +from utils.user_context import UserContext +from workflow_manager.endpoint_v2.base_connector import BaseConnector +from workflow_manager.endpoint_v2.constants import ( + FilePattern, + FileSystemConnector, + FileType, + SourceConstant, + SourceKey, + WorkflowFileType, +) +from workflow_manager.endpoint_v2.exceptions import ( + FileHashNotFound, + InvalidInputDirectory, + InvalidSourceConnectionType, + MissingSourceConnectionType, + OrganizationIdNotFound, + SourceConnectorNotConfigured, +) +from workflow_manager.endpoint_v2.models import WorkflowEndpoint +from workflow_manager.workflow_v2.execution import WorkflowExecutionServiceHelper +from workflow_manager.workflow_v2.models.workflow import Workflow + +logger = logging.getLogger(__name__) + + +class SourceConnector(BaseConnector): + """A class representing a source connector for a workflow. + + This class extends the BaseConnector class and provides methods for + interacting with different types of source connectors, + such as file system connectors and API connectors. + It allows listing files from the source connector, + adding files to the execution volume, and retrieving JSON schemas for + different types of connectors. + + Attributes: + workflow (Workflow): The workflow associated with the source connector. + """ + + def __init__( + self, + workflow: Workflow, + execution_id: str, + organization_id: Optional[str] = None, + execution_service: Optional[WorkflowExecutionServiceHelper] = None, + ) -> None: + """Initialize a SourceConnector object. + + Args: + workflow (Workflow): _description_ + """ + organization_id = organization_id or UserContext.get_organization_identifier() + if not organization_id: + raise OrganizationIdNotFound() + super().__init__(workflow.id, execution_id, organization_id) + self.endpoint = self._get_endpoint_for_workflow(workflow=workflow) + self.workflow = workflow + self.execution_id = execution_id + self.organization_id = organization_id + self.hash_value_of_file_content: Optional[str] = None + self.execution_service = execution_service + + def _get_endpoint_for_workflow( + self, + workflow: Workflow, + ) -> WorkflowEndpoint: + """Get WorkflowEndpoint instance. 
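+        Fetches the SOURCE endpoint of the workflow and mirrors the connector's
+        stored metadata onto connector_metadata.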
+
+        Args:
+            workflow (Workflow): Workflow to get the source endpoint for
+
+        Returns:
+            WorkflowEndpoint: Source endpoint configured for the workflow
+        """
+        endpoint: WorkflowEndpoint = WorkflowEndpoint.objects.get(
+            workflow=workflow,
+            endpoint_type=WorkflowEndpoint.EndpointType.SOURCE,
+        )
+        if endpoint.connector_instance:
+            endpoint.connector_instance.connector_metadata = (
+                endpoint.connector_instance.metadata
+            )
+        return endpoint
+
+    def validate(self) -> None:
+        connection_type = self.endpoint.connection_type
+        connector: ConnectorInstance = self.endpoint.connector_instance
+        if connection_type is None:
+            raise MissingSourceConnectionType()
+        if connection_type not in WorkflowEndpoint.ConnectionType.values:
+            raise InvalidSourceConnectionType()
+        if connection_type != WorkflowEndpoint.ConnectionType.API and connector is None:
+            raise SourceConnectorNotConfigured()
+
+    def valid_file_patterns(self, required_patterns: list[Any]) -> list[str]:
+        patterns = {
+            FileType.PDF_DOCUMENTS: FilePattern.PDF_DOCUMENTS,
+            FileType.TEXT_DOCUMENTS: FilePattern.TEXT_DOCUMENTS,
+            FileType.IMAGES: FilePattern.IMAGES,
+        }
+        wildcard = []
+        if not required_patterns:
+            wildcard.append("*")
+        else:
+            for pattern in required_patterns:
+                wildcard.extend(patterns.get(pattern, []))
+        return wildcard
+
+    def list_file_from_api_storage(self) -> list[str]:
+        """List all files from the api_storage_dir directory."""
+        files: list[str] = []
+        if not self.api_storage_dir:
+            return files
+        for file in os.listdir(self.api_storage_dir):
+            file_path = os.path.join(self.api_storage_dir, file)
+            if os.path.isfile(file_path):
+                files.append(file_path)
+        return files
+
+    def list_files_from_file_connector(self) -> list[str]:
+        """List files from the configured filesystem connector.
+
+        Raises:
+            InvalidInputDirectory: If the configured input path is not a directory
+
+        Returns:
+            list[str]: Paths of the files that match the configured patterns
+        """
+        connector: ConnectorInstance = self.endpoint.connector_instance
+        connector_settings: dict[str, Any] = connector.connector_metadata
+        source_configurations: dict[str, Any] = self.endpoint.configuration
+        required_patterns = source_configurations.get(SourceKey.FILE_EXTENSIONS, [])
+        recursive = bool(
+            source_configurations.get(SourceKey.PROCESS_SUB_DIRECTORIES, False)
+        )
+        limit = int(
+            source_configurations.get(
+                SourceKey.MAX_FILES, FileSystemConnector.MAX_FILES
+            )
+        )
+        root_dir_path = connector_settings.get(ConnectorKeys.PATH, "")
+        input_directory = str(source_configurations.get(SourceKey.ROOT_FOLDER, ""))
+        if root_dir_path:  # user needs to manually type the optional file path
+            input_directory = str(Path(root_dir_path, input_directory.lstrip("/")))
+        if not isinstance(required_patterns, list):
+            required_patterns = [required_patterns]
+
+        source_fs = self.get_fsspec(
+            settings=connector_settings, connector_id=connector.connector_id
+        )
+        patterns = self.valid_file_patterns(required_patterns=required_patterns)
+        is_directory = source_fs.isdir(input_directory)
+        if not is_directory:
+            raise InvalidInputDirectory()
+        matched_files = self._get_matched_files(
+            source_fs, input_directory, patterns, recursive, limit
+        )
+        self.publish_input_output_list_file_logs(input_directory, matched_files)
+        return matched_files
+
+    def publish_input_output_list_file_logs(
+        self, input_directory: str, matched_files: list[str]
+    ) -> None:
+        if not self.execution_service:
+            return None
+        input_log = f"##Input folder:\n\n `{os.path.basename(input_directory)}`\n\n"
+        self.execution_service.publish_update_log(
+            state=LogState.INPUT_UPDATE, message=input_log
+        )
+        output_log = self._matched_files_component_log(matched_files)
+        self.execution_service.publish_update_log(
+            state=LogState.OUTPUT_UPDATE, message=output_log
+        )
+
+    def publish_input_file_content(self, input_file_path: str, input_text: str) -> None:
+        if self.execution_service:
+            output_log_message = f"##Input text:\n\n```text\n{input_text}\n```\n\n"
+            input_log_message = (
+                "##Input file:\n\n```text\n"
+                f"{os.path.basename(input_file_path)}\n```\n\n"
+            )
+            self.execution_service.publish_update_log(
+                state=LogState.INPUT_UPDATE, message=input_log_message
+            )
+            self.execution_service.publish_update_log(
+                state=LogState.OUTPUT_UPDATE, message=output_log_message
+            )
+
+    def _matched_files_component_log(self, matched_files: list[str]) -> str:
+        output_log = "### Matched files \n```text\n\n\n"
+        for file in matched_files[:20]:
+            output_log += f"{file}\n"
+        output_log += "```\n\n"
+        output_log += f"""Total matched files: {len(matched_files)}
+        \n\nPlease note that only the first 20 files are shown.\n\n"""
+        return output_log
+
+    def _get_matched_files(
+        self,
+        source_fs: Any,
+        input_directory: str,
+        patterns: list[str],
+        recursive: bool,
+        limit: int,
+    ) -> list[str]:
+        """Get a list of matched files based on patterns in a directory.
+
+        This method searches for files in the specified `input_directory` that
+        match any of the given `patterns`.
+        The search can be performed recursively if `recursive` is set to True.
+        The number of matched files returned is limited by `limit`.
+
+        Args:
+            source_fs (Any): The file system object used for searching.
+            input_directory (str): The directory to search for files.
+            patterns (list[str]): The patterns to match against file names.
+            recursive (bool): Whether to perform a recursive search.
+            limit (int): The maximum number of matched files to return.
+
+        Returns:
+            list[str]: A list of matched file paths.
+        """
+        matched_files = []
+        count = 0
+        max_depth = int(SourceConstant.MAX_RECURSIVE_DEPTH) if recursive else 1
+
+        for root, dirs, files in source_fs.walk(input_directory, maxdepth=max_depth):
+            if count >= limit:
+                break
+            for file in files:
+                if not file:
+                    continue
+                if count >= limit:
+                    break
+                if any(fnmatch.fnmatch(file, pattern) for pattern in patterns):
+                    file_path = os.path.join(root, file)
+                    file_path = f"{file_path}"
+                    matched_files.append(file_path)
+                    count += 1
+
+        return matched_files
+
+    def list_files_from_source(self) -> list[str]:
+        """List files from the source connector.
+
+        Returns:
+            list[str]: List of input file paths
+
+        Raises:
+            InvalidSourceConnectionType: If the connection type is unsupported
+        """
+        connection_type = self.endpoint.connection_type
+        if connection_type == WorkflowEndpoint.ConnectionType.FILESYSTEM:
+            return self.list_files_from_file_connector()
+        elif connection_type == WorkflowEndpoint.ConnectionType.API:
+            return self.list_file_from_api_storage()
+        raise InvalidSourceConnectionType()
+
+    @classmethod
+    def hash_str(cls, string_to_hash: Any, hash_method: str = "sha256") -> str:
+        """Computes the hash for a given input string.
+
+        Useful to hash strings needed for caching and other purposes.
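+
+        Example (illustrative; the digests shown are the standard SHA-256 and
+        MD5 hashes of "hello"):
+
+            >>> SourceConnector.hash_str("hello")
+            '2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824'
+            >>> SourceConnector.hash_str("hello", hash_method="md5")
+            '5d41402abc4b2a76b9719d911017c592'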
+
+        Args:
+            string_to_hash (str): String to be hashed
+            hash_method (str): Hash method to use, defaults to "sha256".
+                Supported methods: "sha256", "md5"
+
+        Returns:
+            str: Hashed string
+        """
+        if hash_method == "md5":
+            if isinstance(string_to_hash, bytes):
+                return str(md5(string_to_hash).hexdigest())
+            return str(md5(string_to_hash.encode()).hexdigest())
+        elif hash_method == "sha256":
+            if isinstance(string_to_hash, (bytes, bytearray)):
+                return str(sha256(string_to_hash).hexdigest())
+            return str(sha256(string_to_hash.encode()).hexdigest())
+        else:
+            raise ValueError(f"Unsupported hash_method: {hash_method}")
+
+    def add_input_from_connector_to_volume(self, input_file_path: str) -> str:
+        """Add input file to execution directory.
+
+        Args:
+            input_file_path (str): The path of the input file.
+        Returns:
+            str: The hash value of the file content.
+        """
+        connector: ConnectorInstance = self.endpoint.connector_instance
+        connector_settings: dict[str, Any] = connector.connector_metadata
+        source_file_path = os.path.join(self.execution_dir, WorkflowFileType.SOURCE)
+        infile_path = os.path.join(self.execution_dir, WorkflowFileType.INFILE)
+        source_file = f"file://{source_file_path}"
+
+        source_fs = self.get_fsspec(
+            settings=connector_settings, connector_id=connector.connector_id
+        )
+        with (
+            source_fs.open(input_file_path, "rb") as remote_file,
+            fsspec.open(source_file, "wb") as local_file,
+        ):
+            file_content = remote_file.read()
+            hash_value_of_file_content = self.hash_str(file_content)
+            logger.info(
+                f"hash_value_of_file {source_file} is "
+                f": {hash_value_of_file_content}"
+            )
+            input_log = (
+                file_content[:500].decode("utf-8", errors="replace") + "...(truncated)"
+            )
+            self.publish_input_file_content(input_file_path, input_log)
+
+            local_file.write(file_content)
+        shutil.copyfile(source_file_path, infile_path)
+        logger.info(f"{input_file_path} is added into the execution directory")
+        return hash_value_of_file_content
+
+    def add_input_from_api_storage_to_volume(self, input_file_path: str) -> None:
+        """Add input file to execution directory from api storage."""
+        infile_path = os.path.join(self.execution_dir, WorkflowFileType.INFILE)
+        source_path = os.path.join(self.execution_dir, WorkflowFileType.SOURCE)
+        shutil.copyfile(input_file_path, infile_path)
+        shutil.copyfile(input_file_path, source_path)
+
+    def add_file_to_volume(
+        self, input_file_path: str, hash_values_of_files: dict[str, str] = {}
+    ) -> tuple[str, str]:
+        """Add input file to execution directory.
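+
+        For filesystem sources the content hash is computed from the downloaded
+        file; for API sources it is looked up in `hash_values_of_files` and
+        `FileHashNotFound` is raised when the entry is missing.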
+
+        Args:
+            input_file_path (str): Path of the source file to copy
+            hash_values_of_files (dict[str, str]): Pre-computed content hashes,
+                keyed by file name (used for API sources)
+
+        Raises:
+            FileHashNotFound: If an API file's hash is missing from
+                hash_values_of_files
+            InvalidSourceConnectionType: If the connection type is unsupported
+
+        Returns:
+            tuple[str, str]: File name and hash of the file content
+        """
+        connection_type = self.endpoint.connection_type
+        file_name = os.path.basename(input_file_path)
+        if connection_type == WorkflowEndpoint.ConnectionType.FILESYSTEM:
+            file_content_hash = self.add_input_from_connector_to_volume(
+                input_file_path=input_file_path,
+            )
+        elif connection_type == WorkflowEndpoint.ConnectionType.API:
+            self.add_input_from_api_storage_to_volume(input_file_path=input_file_path)
+            if file_name not in hash_values_of_files:
+                raise FileHashNotFound()
+            file_content_hash = hash_values_of_files[file_name]
+        else:
+            raise InvalidSourceConnectionType()
+
+        self.add_metadata_to_volume(
+            input_file_path=input_file_path, source_hash=file_content_hash
+        )
+        return file_name, file_content_hash
+
+    def handle_final_result(
+        self,
+        results: list[dict[str, Any]],
+        file_name: str,
+        result: Optional[str],
+    ) -> None:
+        connection_type = self.endpoint.connection_type
+        if connection_type == WorkflowEndpoint.ConnectionType.API:
+            results.append({"file": file_name, "result": result})
+
+    def load_file(self, input_file_path: str) -> tuple[str, BytesIO]:
+        connector: ConnectorInstance = self.endpoint.connector_instance
+        connector_settings: dict[str, Any] = connector.connector_metadata
+        source_fs: fsspec.AbstractFileSystem = self.get_fsspec(
+            settings=connector_settings, connector_id=connector.connector_id
+        )
+        with source_fs.open(input_file_path, "rb") as remote_file:
+            file_content = remote_file.read()
+            file_stream = BytesIO(file_content)
+
+        return remote_file.key, file_stream
+
+    @classmethod
+    def add_input_file_to_api_storage(
+        cls, workflow_id: str, execution_id: str, file_objs: list[UploadedFile]
+    ) -> dict[str, str]:
+        """Add input file to api storage.
+
+        Args:
+            workflow_id (str): workflow id
+            execution_id (str): execution_id
+            file_objs (list[UploadedFile]): api file objects
+
+        Returns:
+            dict[str, str]: Hash of each uploaded file's content, keyed by file name
+        """
+        api_storage_dir = cls.get_api_storage_dir_path(
+            workflow_id=workflow_id, execution_id=execution_id
+        )
+        file_hashes: dict[str, str] = {}
+        for file in file_objs:
+            file_name = file.name
+            destination_path = os.path.join(api_storage_dir, file_name)
+            os.makedirs(os.path.dirname(destination_path), exist_ok=True)
+            with open(destination_path, "wb") as f:
+                buffer = bytearray()
+                for chunk in file.chunks():
+                    buffer.extend(chunk)
+                f.write(buffer)
+            file_hashes.update({file_name: cls.hash_str(buffer)})
+        return file_hashes
+
+    @classmethod
+    def create_endpoint_for_workflow(
+        cls,
+        workflow: Workflow,
+    ) -> None:
+        """Create a source WorkflowEndpoint entity for the workflow."""
+        endpoint = WorkflowEndpoint(
+            workflow=workflow,
+            endpoint_type=WorkflowEndpoint.EndpointType.SOURCE,
+        )
+        endpoint.save()
+
+    @classmethod
+    def get_json_schema_for_api(cls) -> dict[str, Any]:
+        """JSON schema for the API source.
+
+        Returns:
+            dict[str, Any]: JSON schema loaded from static/src/api.json
+        """
+        schema_path = os.path.join(
+            os.path.dirname(__file__), "static", "src", "api.json"
+        )
+        return cls.get_json_schema(file_path=schema_path)
+
+    @classmethod
+    def get_json_schema_for_file_system(cls) -> dict[str, Any]:
+        """JSON schema for the filesystem source.
+
+        Returns:
+            dict[str, Any]: JSON schema loaded from static/src/file.json
+        """
+        schema_path = os.path.join(
+            os.path.dirname(__file__), "static", "src", "file.json"
+        )
+        return cls.get_json_schema(file_path=schema_path)
diff --git a/backend/workflow_manager/endpoint_v2/static/dest/db.json b/backend/workflow_manager/endpoint_v2/static/dest/db.json
new file mode 100644
index 000000000..70286c704
--- /dev/null
+++ b/backend/workflow_manager/endpoint_v2/static/dest/db.json
@@ -0,0 +1,54 @@
+{
+    "title": "Workflow DB Destination",
+    "description": "Settings for DB Destination",
+    "type": "object",
+    "required": [
+        "table",
+        "includeAgent",
+        "includeTimestamp",
+        "columnMode"
+    ],
+    "properties": {
+        "table": {
+            "type": "string",
+            "title": "Table",
+            "default": "",
+            "description": "Table to store the output. If your database supports schemas, use the format schema.table"
+        },
+        "includeAgent": {
+            "type": "boolean",
+            "title": "Include 'created_by' column",
+            "default": false,
+            "description": "Include the 'created_by' in the output row"
+        },
+        "agentName": {
+            "type": "string",
+            "title": "Agent Name",
+            "enum": [
+                "Unstract/DBWriter"
+            ],
+            "default": "Unstract/DBWriter",
+            "description": "Name of the agent to use as the 'created_by' value"
+        },
+        "includeTimestamp": {
+            "type": "boolean",
+            "title": "Include 'created_at' column",
+            "default": false,
+            "description": "Include the 'created_at' in the output row"
+        },
+        "columnMode": {
+            "type": "string",
+            "title": "Select how you want to write the output",
+            "enum": [
+                "Write JSON to a single column"
+            ],
+            "default": "Write JSON to a single column"
+        },
+        "singleColumnName": {
+            "type": "string",
+            "title": "Single Column Name",
+            "default": "data",
+            "description": "Name of the column to write the JSON to"
+        }
+    }
+}
diff --git a/backend/workflow_manager/endpoint_v2/static/dest/file.json b/backend/workflow_manager/endpoint_v2/static/dest/file.json
new file mode 100644
index 000000000..aa3a1263e
--- /dev/null
+++ b/backend/workflow_manager/endpoint_v2/static/dest/file.json
@@ -0,0 +1,25 @@
+{
+    "title": "Workflow File Destination",
+    "description": "Settings for File Destination",
+    "type": "object",
+    "required": [
+        "outputFolder"
+    ],
+    "properties": {
+        "outputFolder": {
+            "type": "string",
+            "title": "Output folder",
+            "default": "output",
+            "description": "Folder to store the output",
+            "minLength": 1,
+            "maxLength": 100,
+            "format": "file-path"
+        },
+        "overwriteOutput": {
+            "type": "boolean",
+            "title": "Overwrite existing output",
+            "default": true,
+            "description": "Used to overwrite output document"
+        }
+    }
+}
diff --git a/backend/workflow_manager/endpoint_v2/static/src/api.json b/backend/workflow_manager/endpoint_v2/static/src/api.json
new file mode 100644
index 000000000..8d4a9d022
--- /dev/null
+++ b/backend/workflow_manager/endpoint_v2/static/src/api.json
@@ -0,0 +1,21 @@
+{
+    "title": "Workflow API Source",
+    "description": "Settings for API Source",
+    "type": "object",
+    "required": [],
+    "properties": {
+        "fileExtensions": {
+            "type": "array",
+            "title": "File types to process",
+            "description": "Limit the file types to process. Leave it empty to process all files",
+            "items": {
+                "type": "string",
+                "enum": [
+                    "PDF documents",
+                    "Text documents",
+                    "Images"
+                ]
+            }
+        }
+    }
+}
diff --git a/backend/workflow_manager/endpoint_v2/static/src/file.json b/backend/workflow_manager/endpoint_v2/static/src/file.json
new file mode 100644
index 000000000..d603b7e53
--- /dev/null
+++ b/backend/workflow_manager/endpoint_v2/static/src/file.json
@@ -0,0 +1,39 @@
+{
+    "title": "Workflow File Source",
+    "description": "Settings for File Source",
+    "type": "object",
+    "required": [],
+    "properties": {
+        "rootFolder": {
+            "type": "string",
+            "title": "Folder to process",
+            "default": "",
+            "description": "The root folder to start processing files from. Leave it empty to use the root folder"
+        },
+        "processSubDirectories": {
+            "type": "boolean",
+            "title": "Process sub directories",
+            "default": true,
+            "description": "Process sub directories recursively"
+        },
+        "fileExtensions": {
+            "type": "array",
+            "title": "File types to process",
+            "description": "Limit the file types to process. Leave it empty to process all files",
+            "items": {
+                "type": "string",
+                "enum": [
+                    "PDF documents",
+                    "Text documents",
+                    "Images"
+                ]
+            }
+        },
+        "maxFiles": {
+            "type": "number",
+            "title": "Max files to process",
+            "default": 100,
+            "description": "The maximum number of files to process"
+        }
+    }
+}
diff --git a/backend/workflow_manager/endpoint_v2/tests/__init__.py b/backend/workflow_manager/endpoint_v2/tests/__init__.py
new file mode 100644
index 000000000..fca2b2401
--- /dev/null
+++ b/backend/workflow_manager/endpoint_v2/tests/__init__.py
@@ -0,0 +1,3 @@
+from backend.celery import app as celery_app
+
+__all__ = ["celery_app"]
diff --git a/backend/workflow_manager/endpoint_v2/tests/test_database_utils/__init__.py b/backend/workflow_manager/endpoint_v2/tests/test_database_utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/backend/workflow_manager/endpoint_v2/tests/test_database_utils/base_test_db.py b/backend/workflow_manager/endpoint_v2/tests/test_database_utils/base_test_db.py
new file mode 100644
index 000000000..4040eb624
--- /dev/null
+++ b/backend/workflow_manager/endpoint_v2/tests/test_database_utils/base_test_db.py
@@ -0,0 +1,160 @@
+import datetime
+import json
+import os
+from typing import Any
+
+import pytest  # type: ignore
+from dotenv import load_dotenv
+
+from unstract.connectors.databases.bigquery import BigQuery
+from unstract.connectors.databases.mariadb import MariaDB
+from unstract.connectors.databases.mssql import MSSQL
+from unstract.connectors.databases.mysql import MySQL
+from unstract.connectors.databases.postgresql import PostgreSQL
+from unstract.connectors.databases.redshift import Redshift
+from unstract.connectors.databases.snowflake import SnowflakeDB
+from unstract.connectors.databases.unstract_db import UnstractDB
+
+load_dotenv("test.env")
+
+
+class BaseTestDB:
+    @pytest.fixture(autouse=True)
+    def base_setup(self) -> None:
+        self.postgres_creds = {
+            "user": os.getenv("DB_USER"),
+            "password": os.getenv("DB_PASSWORD"),
+            "host": os.getenv("DB_HOST"),
+            "port": os.getenv("DB_PORT"),
+            "database": os.getenv("DB_NAME"),
+        }
+        self.redshift_creds = {
+            "user": os.getenv("REDSHIFT_USER"),
+            "password": os.getenv("REDSHIFT_PASSWORD"),
+            "host": os.getenv("REDSHIFT_HOST"),
+            "port": os.getenv("REDSHIFT_PORT"),
+            "database": os.getenv("REDSHIFT_DB"),
+        }
+        self.snowflake_creds = {
+            "user": os.getenv("SNOWFLAKE_USER"),
+            "password": os.getenv("SNOWFLAKE_PASSWORD"),
+            "account": os.getenv("SNOWFLAKE_ACCOUNT"),
os.getenv("SNOWFLAKE_ACCOUNT"), + "role": os.getenv("SNOWFLAKE_ROLE"), + "database": os.getenv("SNOWFLAKE_DB"), + "schema": os.getenv("SNOWFLAKE_SCHEMA"), + "warehouse": os.getenv("SNOWFLAKE_WAREHOUSE"), + } + self.mssql_creds = { + "user": os.getenv("MSSQL_USER"), + "password": os.getenv("MSSQL_PASSWORD"), + "server": os.getenv("MSSQL_SERVER"), + "port": os.getenv("MSSQL_PORT"), + "database": os.getenv("MSSQL_DB"), + } + self.mysql_creds = { + "user": os.getenv("MYSQL_USER"), + "password": os.getenv("MYSQL_PASSWORD"), + "host": os.getenv("MYSQL_SERVER"), + "port": os.getenv("MYSQL_PORT"), + "database": os.getenv("MYSQL_DB"), + } + self.mariadb_creds = { + "user": os.getenv("MARIADB_USER"), + "password": os.getenv("MARIADB_PASSWORD"), + "host": os.getenv("MARIADB_SERVER"), + "port": os.getenv("MARIADB_PORT"), + "database": os.getenv("MARIADB_DB"), + } + self.database_entry = { + "created_by": "Unstract/DBWriter", + "created_at": datetime.datetime(2024, 5, 20, 7, 46, 57, 307998), + "data": '{"input_file": "simple.pdf", "result": "report"}', + } + valid_schema_name = "public" + invalid_schema_name = "public_1" + self.valid_postgres_creds = {**self.postgres_creds, "schema": valid_schema_name} + self.invalid_postgres_creds = { + **self.postgres_creds, + "schema": invalid_schema_name, + } + self.valid_redshift_creds = {**self.redshift_creds, "schema": valid_schema_name} + self.invalid_redshift_creds = { + **self.redshift_creds, + "schema": invalid_schema_name, + } + self.invalid_syntax_table_name = "invalid-syntax.name.test_output" + self.invalid_wrong_table_name = "database.schema.test_output" + self.valid_table_name = "test_output" + bigquery_json_str = os.getenv("BIGQUERY_CREDS", "{}") + self.bigquery_settings = json.loads(bigquery_json_str) + self.bigquery_settings["json_credentials"] = bigquery_json_str + self.valid_bigquery_table_name = "pandoras-tamer.bigquery_test.bigquery_output" + self.invalid_snowflake_db = {**self.snowflake_creds, "database": "invalid"} + self.invalid_snowflake_schema = {**self.snowflake_creds, "schema": "invalid"} + self.invalid_snowflake_warehouse = { + **self.snowflake_creds, + "warehouse": "invalid", + } + + # Gets all valid db instances except + # Bigquery (table name needs to be writted separately for bigquery) + @pytest.fixture( + params=[ + ("valid_postgres_creds", PostgreSQL), + ("snowflake_creds", SnowflakeDB), + ("mssql_creds", MSSQL), + ("mysql_creds", MySQL), + ("mariadb_creds", MariaDB), + ("valid_redshift_creds", Redshift), + ] + ) + def valid_dbs_instance(self, request: Any) -> Any: + return self.get_db_instance(request=request) + + # Gets all valid db instances except: + # Bigquery (table name needs to be writted separately for bigquery) + # Redshift (can't process more than 64KB character type) + @pytest.fixture( + params=[ + ("valid_postgres_creds", PostgreSQL), + ("snowflake_creds", SnowflakeDB), + ("mssql_creds", MSSQL), + ("mysql_creds", MySQL), + ("mariadb_creds", MariaDB), + ] + ) + def valid_dbs_instance_to_handle_large_doc(self, request: Any) -> Any: + return self.get_db_instance(request=request) + + def get_db_instance(self, request: Any) -> UnstractDB: + creds_name, db_class = request.param + creds = getattr(self, creds_name) + if not creds: + pytest.fail(f"Unknown credentials: {creds_name}") + db_instance = db_class(settings=creds) + return db_instance + + # Gets all invalid-db instances for postgres, redshift: + @pytest.fixture( + params=[ + ("invalid_postgres_creds", PostgreSQL), + ("invalid_redshift_creds", Redshift), + ] + ) + def 
invalid_dbs_instance(self, request: Any) -> Any: + return self.get_db_instance(request=request) + + @pytest.fixture + def valid_bigquery_db_instance(self) -> Any: + return BigQuery(settings=self.bigquery_settings) + + # Gets all invalid-db instances for snowflake: + @pytest.fixture( + params=[ + ("invalid_snowflake_db", SnowflakeDB), + ("invalid_snowflake_schema", SnowflakeDB), + ("invalid_snowflake_warehouse", SnowflakeDB), + ] + ) + def invalid_snowflake_db_instance(self, request: Any) -> Any: + return self.get_db_instance(request=request) diff --git a/backend/workflow_manager/endpoint_v2/tests/test_database_utils/static/large_doc.txt b/backend/workflow_manager/endpoint_v2/tests/test_database_utils/static/large_doc.txt new file mode 100644 index 000000000..3a3b67a00 --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/tests/test_database_utils/static/large_doc.txt @@ -0,0 +1 @@ +"\n\n UNITED STATES \n SECURITIES AND EXCHANGE COMMISSION \n Washington, D.C. 20549 \n\n FORM 10-Q \n\n(Mark One) \n [X] X QUARTERLY REPORT PURSUANT TO SECTION 13 OR 15(d) OF THE SECURITIES EXCHANGE ACT OF 1934 \n For the quarterly period ended December 30, 2023 \n or \n [ ] TRANSITION REPORT PURSUANT TO SECTION 13 OR 15(d) OF THE SECURITIES EXCHANGE ACT OF 1934 \n For the transition period from to \n Commission File Number: 001-36743 \n\n Apple Inc. \n (Exact name of Registrant as specified in its charter) \n\n California 94-2404110 \n (State or other jurisdiction (I.R.S. Employer Identification No.) \n of incorporation or organization) \n\n One Apple Park Way \n Cupertino, California 95014 \n (Address of principal executive offices) (Zip Code) \n (408) 996-1010 \n (Registrant\'s telephone number, including area code) \n\n Securities registered pursuant to Section 12(b) of the Act \n\n Title of each class symbol(s) Trading Name of each exchange on which registered \n Common Stock, $0.00001 par value per share AAPL The Nasdaq Stock Market LLC \n 0.000% Notes due 2025 The Nasdaq Stock Market LLC \n 0.875% Notes due 2025 The Nasdaq Stock Market LLC \n 1.625% Notes due 2026 The Nasdaq Stock Market LLC \n 2.000% Notes due 2027 The Nasdaq Stock Market LLC \n 1.375% Notes due 2029 The Nasdaq Stock Market LLC \n 3.050% Notes due 2029 The Nasdaq Stock Market LLC \n 0.500% Notes due 2031 The Nasdaq Stock Market LLC \n 3.600% Notes due 2042 The Nasdaq Stock Market LLC \n\nIndicate by check mark whether the Registrant (1) has filed all reports required to be filed by Section 13 or 15(d) of the Securities Exchange Act \nof 1934 during the preceding 12 months (or for such shorter period that the Registrant was required to file such reports), and (2) has been \nsubject to such filing requirements for the past 90 days. \n Yes [X] No [ ] \n\nIndicate by check mark whether the Registrant has submitted electronically every Interactive Data File required to be submitted pursuant to Rule \n405 of Regulation S-T (§232.405 of this chapter) during the preceding 12 months (or for such shorter period that the Registrant was required to \nsubmit such files). \n Yes [X] No [ ] \n<<<\n\n\nIndicate by check mark whether the Registrant is a large accelerated filer, an accelerated filer, a non-accelerated filer, a smaller reporting \ncompany, or an emerging growth company. See the definitions of "large accelerated filer," "accelerated filer," "smaller reporting company," and \n"emerging growth company" in Rule 12b-2 of the Exchange Act. 
\n\n Large accelerated filer [X] Accelerated filer [ ] \n Non-accelerated filer [ ] Smaller reporting company [ ] \n Emerging growth company [ ] \n\nIf an emerging growth company, indicate by check mark if the Registrant has elected not to use the extended transition period for complying with \nany new or revised financial accounting standards provided pursuant to Section 13(a) of the Exchange Act. [ ] \n\nIndicate by check mark whether the Registrant is a shell company (as defined in Rule 12b-2 of the Exchange Act). \n Yes [ ] No [X] \n\n 15,441,881,000 shares of common stock were issued and outstanding as of January 19, 2024. \n<<<\n\n\n Apple Inc. \n\n Form 10-Q \n\n For the Fiscal Quarter Ended December 30, 2023 \n TABLE OF CONTENTS \n\n Page \n Part I \nItem 1. Financial Statements 1 \nItem 2. Management\'s Discussion and Analysis of Financial Condition and Results of Operations 13 \nItem 3. Quantitative and Qualitative Disclosures About Market Risk 18 \nItem 4. Controls and Procedures 18 \n Part II \nItem 1. Legal Proceedings 19 \nItem 1A. Risk Factors 19 \nItem 2. Unregistered Sales of Equity Securities and Use of Proceeds 20 \nItem 3. Defaults Upon Senior Securities 21 \nItem 4. Mine Safety Disclosures 21 \nItem 5. Other Information 21 \nItem 6. Exhibits 21 \n<<<\n\n\nPARTI - FINANCIAL INFORMATION \n\nItem 1. Financial Statements \n\n Apple Inc. \n\n CONDENSED CONSOLIDATED STATEMENTS OF OPERATIONS (Unaudited) \n (In millions, except number of shares, which are reflected in thousands, and per-share amounts) \n\n Three Months Ended \n December 2023 30, December 2022 31, \n\n Net sales: \n Products $ 96,458 $ 96,388 \n Services 23,117 20,766 \n Total net sales 119,575 117,154 \n\n Cost of sales: \n Products 58,440 60,765 \n Services 6,280 6,057 \n Total cost of sales 64,720 66,822 \n Gross margin 54,855 50,332 \n\n Operating expenses: \n Research and development 7,696 7,709 \n Selling, general and administrative 6,786 6,607 \n Total operating expenses 14,482 14,316 \n\n Operating income 40,373 36,016 \n Other income/(expense), net (50) (393) \n Income before provision for income taxes 40,323 35,623 \n Provision for income taxes 6,407 5,625 \n Net income $ 33,916 $ 29,998 \n\n Earnings per share: \n Basic $ 2.19 $ 1.89 \n Diluted $ 2.18 $ 1.88 \n\n Shares used in computing earnings per share: \n Basic 15,509,763 15,892,723 \n Diluted 15,576,641 15,955,718 \n\n See accompanying Notes to Condensed Consolidated Financial Statements. \n\n Apple Inc. IQ1 2024 Form 10-Q | 1 \n<<<\n\n\n Apple Inc. 
\n\n CONDENSED CONSOLIDATED STATEMENTS OF COMPREHENSIVE INCOME (Unaudited) \n (In millions) \n\n Three Months Ended \n December 2023 30, December 2022 31, \n\nNet income $ 33,916 $ 29,998 \nOther comprehensive income/(loss): \n Change in foreign currency translation, net of tax 308 (14) \n\n Change in unrealized gains/losses on derivative instruments, net of tax: \n Change in fair value of derivative instruments (531) (988) \n Adjustment for net (gains)/losses realized and included in net income (823) (1,766) \n Total change in unrealized gains/losses on derivative instruments (1,354) (2,754) \n\n Change in unrealized gains/losses on marketable debt securities, net of tax: \n Change in fair value of marketable debt securities 3,045 900 \n Adjustment for net (gains)/losses realized and included in net income 75 65 \n Total change in unrealized gains/losses on marketable debt securities 3,120 965 \n\nTotal other comprehensive income/(loss) 2,074 (1,803) \nTotal comprehensive income $ 35,990 $ 28,195 \n\n See accompanying Notes to Condensed Consolidated Financial Statements. \n\n Apple Inc. I Q1 2024 Form 10-Q 12 \n<<<\n\n\n Apple Inc. \n\n CONDENSED CONSOLIDATED BALANCE SHEETS (Unaudited) \n (In millions, except number of shares, which are reflected in thousands, and par value) \n\n December 2023 30, September 2023 30, \n\n ASSETS: \nCurrent assets: \n Cash and cash equivalents $ 40,760 $ 29,965 \n Marketable securities 32,340 31,590 \n Accounts receivable, net 23,194 29,508 \n Vendor non-trade receivables 26,908 31,477 \n Inventories 6,511 6,331 \n Other current assets 13,979 14,695 \n Total current assets 143,692 143,566 \n\nNon-current assets: \n Marketable securities 99,475 100,544 \n Property, plant and equipment, net 43,666 43,715 \n Other non-current assets 66,681 64,758 \n Total non-current assets 209,822 209,017 \n Total assets $ 353,514 $ 352,583 \n\n LIABILITIES AND SHAREHOLDERS\' EQUITY: \nCurrent liabilities: \n Accounts payable $ 58,146 $ 62,611 \n Other current liabilities 54,611 58,829 \n Deferred revenue 8,264 8,061 \n Commercial paper 1,998 5,985 \n Term debt 10,954 9,822 \n Total current liabilities 133,973 145,308 \n\nNon-current liabilities : \n Term debt 95,088 95,281 \n Other non-current liabilities 50,353 49,848 \n Total non-current liabilities 145,441 145,129 \n Total liabilities 279,414 290,437 \n\nCommitments and contingencies \n\nShareholders\' equity: \n Common stock and additional paid-in capital, $0.00001 par value: 50,400,000 shares \n authorized; 15,460,223 and 15,550,061 shares issued and outstanding, respectively 75,236 73,812 \n Retained earnings/(Accumulated deficit) 8,242 (214) \n Accumulated other comprehensive loss (9,378) (11,452) \n Total shareholders\' equity 74,100 62,146 \n Total liabilities and shareholders\' equity $ 353,514 $ 352,583 \n\n See accompanying Notes to Condensed Consolidated Financial Statements. \n\n Apple Inc. IQ1 2024 Form 10-Q 13 \n<<<\n\n\n Apple Inc. 
\n\n CONDENSED CONSOLIDATED STATEMENTS OF SHAREHOLDERS\' EQUITY (Unaudited) \n (In millions, except per-share amounts) \n\n Three Months Ended \n December 2023 30, December 2022 31, \n\nTotal shareholders\' equity, beginning balances $ 62,146 $ 50,672 \n\nCommon stock and additional paid-in capital: \n Beginning balances 73,812 64,849 \n Common stock withheld related to net share settlement of equity awards (1,660) (1,434) \n Share-based compensation 3,084 2,984 \n Ending balances 75,236 66,399 \n\nRetained earnings/(Accumulated deficit): \n Beginning balances (214) (3,068) \n Net income 33,916 29,998 \n Dividends and dividend equivalents declared (3,774) (3,712) \n Common stock withheld related to net share settlement of equity awards (1,018) (978) \n Common stock repurchased (20,668) (19,000) \n Ending balances 8,242 3,240 \n\nAccumulated other comprehensive income/(loss): \n Beginning balances (11,452) (11,109) \n Other comprehensive income/(loss) 2,074 (1,803) \n Ending balances (9,378) (12,912) \n\nTotal shareholders\' equity, ending balances $ 74,100 $ 56,727 \n\nDividends and dividend equivalents declared per share or RSU $ 0.24 $ 0.23 \n\n See accompanying Notes to Condensed Consolidated Financial Statements. \n\n Apple Inc. 2024 Form 10-Q 14 \n<<<\n\n\n Apple Inc. \n\n CONDENSED CONSOLIDATED STATEMENTS OF CASH FLOWS (Unaudited) \n (In millions) \n\n Three Months Ended \n December 2023 30, December 2022 31, \n\nCash, cash equivalents and restricted cash, beginning balances $ 30,737 $ 24,977 \n\nOperating activities: \n Net income 33,916 29,998 \n Adjustments to reconcile net income to cash generated by operating activities: \n Depreciation and amortization 2,848 2,916 \n Share-based compensation expense 2,997 2,905 \n Other (989) (317) \n Changes in operating assets and liabilities: \n Accounts receivable, net 6,555 4,275 \n Vendor non-trade receivables 4,569 2,320 \n Inventories (137) (1,807) \n Other current and non-current assets (1,457) (4,099) \n Accounts payable (4,542) (6,075) \n Other current and non-current liabilities (3,865) 3,889 \n Cash generated by operating activities 39,895 34,005 \n\nInvesting activities: \n Purchases of marketable securities (9,780) (5,153) \n Proceeds from maturities of marketable securities 13,046 7,127 \n Proceeds from sales of marketable securities 1,337 509 \n Payments for acquisition of property, plant and equipment (2,392) (3,787) \n Other (284) (141) \n Cash generated by/(used in) investing activities 1,927 (1,445) \n\nFinancing activities: \n Payments for taxes related to net share settlement of equity awards (2,591) (2,316) \n Payments for dividends and dividend equivalents (3,825) (3,768) \n Repurchases of common stock (20,139) (19,475) \n Repayments of term debt (1,401) \n Repayments of commercial paper, net (3,984) (8,214) \n Other (46) (389) \n Cash used in financing activities (30,585) (35,563) \n\nIncrease/(Decrease) in cash, cash equivalents and restricted cash 11,237 (3,003) \nCash, cash equivalents and restricted cash, ending balances $ 41,974 $ 21,974 \n\nSupplemental cash flow disclosure: \n Cash paid for income taxes, net $ 7,255 $ 828 \n\n See accompanying Notes to Condensed Consolidated Financial Statements. \n\n Apple Inc. IQ1 2024 Form 10-Q 1 5 \n<<<\n\n\n Apple Inc. \n\n Notes to Condensed Consolidated Financial Statements (Unaudited) \n\nNote 1 - Summary of Significant Accounting Policies \n\nBasis of Presentation and Preparation \nThe condensed consolidated financial statements include the accounts of Apple Inc. 
and its wholly owned subsidiaries \n(collectively "Apple" or the "Company"). In the opinion of the Company\'s management, the condensed consolidated financial \nstatements reflect all adjustments, which are normal and recurring in nature, necessary for fair financial statement presentation. \nThe preparation of these condensed consolidated financial statements and accompanying notes in conformity with U.S. generally \naccepted accounting principles ("GAAP") requires the use of management estimates. Certain prior period amounts in the \ncondensed consolidated financial statements and accompanying notes have been reclassified to conform to the current period\'s \npresentation. These condensed consolidated financial statements and accompanying notes should be read in conjunction with \nthe Company\'s annual consolidated financial statements and accompanying notes included in its Annual Report on Form 10-K \nfor the fiscal year ended September 30, 2023 (the "2023 Form 10-K"). \n\nThe Company\'s fiscal year is the 52- or 53-week period that ends on the last Saturday of September. An additional week is \nincluded in the first fiscal quarter every five or six years to realign the Company\'s fiscal quarters with calendar quarters, which \noccurred in the first fiscal quarter of 2023. The Company\'s fiscal years 2024 and 2023 span 52 and 53 weeks, respectively. \nUnless otherwise stated, references to particular years, quarters, months and periods refer to the Company\'s fiscal years ended \nin September and the associated quarters, months and periods of those fiscal years. \n\nNote 2 - Revenue \nNet sales disaggregated by significant products and services for the three months ended December 30, 2023 and December 31, \n2022 were as follows (in millions): \n Three Months Ended \n December 2023 30, December 2022 31, \n iPhone® $ 69,702 $ 65,775 \n Mac® 7,780 7,735 \n iPad® 7,023 9,396 \n Wearables, Home and Accessories 11,953 13,482 \n Services 23,117 20,766 \n Total net sales $ 119,575 $ 117,154 \n\nTotal net sales include $3.5 billion of revenue recognized in the three months ended December 30, 2023 that was included in \ndeferred revenue as of September 30, 2023 and $3.4 billion of revenue recognized in the three months ended December 31, \n2022 that was included in deferred revenue as of September 24, 2022. \n\nThe Company\'s proportion of net sales by disaggregated revenue source was generally consistent for each reportable segment \nin Note 10, "Segment Information and Geographic Data" for the three months ended December 30, 2023 and December 31, \n2022, except in Greater China, where iPhone revenue represented a moderately higher proportion of net sales. \n\nAs of December 30, 2023 and September 30, 2023, the Company had total deferred revenue of $12.5 billion and $12.1 billion, \nrespectively. As of December 30, 2023, the Company expects 66% of total deferred revenue to be realized in less than a year, \n26% within one-to-two years, 7% within two-to-three years and 1% in greater than three years. \n\n Apple Inc. 
I Q1 2024 Form 10-Q | 6 \n<<<\n\n\nNote 3 - Earnings Per Share \nThe following table shows the computation of basic and diluted earnings per share for the three months ended December 30, \n2023 and December 31, 2022 (net income in millions and shares in thousands): \n Three Months Ended \n December 2023 30, December 2022 31, \n\nNumerator: \n Net income $ 33,916 $ 29,998 \n\nDenominator: \n Weighted-average basic shares outstanding 15,509,763 15,892,723 \n Effect of dilutive share-based awards 66,878 62,995 \n Weighted-average diluted shares 15,576,641 15,955,718 \n\nBasic earnings per share $ 2.19 $ 1.89 \nDiluted earnings per share $ 2.18 $ 1.88 \n\nApproximately 89 million restricted stock units ("RSUs") were excluded from the computation of diluted earnings per share for the \nthree months ended December 31, 2022 because their effect would have been antidilutive. \n\nNote 4 - Financial Instruments \n\nCash, Cash Equivalents and Marketable Securities \nThe following tables show the Company\'s cash, cash equivalents and marketable securities by significant investment category \nas of December 30, 2023 and September 30, 2023 (in millions): \n December 30, 2023 \n Cash and Current Non-Current \n Adjusted Cost Unrealized Unrealized Fair Cash Marketable Marketable \n Gains Losses Value Equivalents Securities Securities \n Cash $ 29,542 $ $ - $ 29,542 $ 29,542 $ $ \n\nLevel 1: \n Money market funds 2,000 2,000 2,000 \n Mutual funds 448 35 (11) 472 472 \n Subtotal 2,448 35 (11) 2,472 2,000 472 \n\nLevel 2 (1): \n U.S. Treasury securities 24,041 12 (920) 23,133 7,303 4,858 10,972 \n U.S. agency securities 5,791 (448) 5,343 243 98 5,002 \n Non-U.S. government securities 17,326 54 (675) 16,705 11,175 5,530 \n Certificates of deposit and time deposits 1,448 - - 1,448 1,119 329 \n Commercial paper 1,361 1,361 472 889 \n Corporate debt securities 75,360 112 (3,964) 71,508 81 13,909 57,518 \n Municipal securities 562 (14) 548 185 363 \n Mortgage- and asset-backed securities 22,369 53 (1,907) 20,515 425 20,090 \n Subtotal 148,258 231 (7,928) 140,561 9,218 31,868 99,475 \n Total (2) $ 180,248 $ 266 $ (7,939) $ 172,575 $ 40,760 $ 32,340 $ 99,475 \n\n Apple Inc. IQ1 2024 Form 10-Q 1 7 \n<<<\n\n\n September 30, 2023 \n Cash and Current Non-Current \n Adjusted Cost Unrealized Gains Unrealized Value Fair Cash Marketable Marketable Securities \n Losses Equivalents Securities \n Cash $ 28,359 $ $ $ 28,359 $ 28,359 $ $ \n Level 1: \n Money market funds 481 481 481 \n Mutual funds and equity securities 442 12 (26) 428 428 \n Subtotal 923 12 (26) 909 481 428 \n Level 2 (1): \n U.S. Treasury securities 19,406 (1,292) 18,114 35 5,468 12,611 \n U.S. agency securities 5,736 (600) 5,136 36 271 4,829 \n Non-U.S. government securities 17,533 6 (1,048) 16,491 11,332 5,159 \n Certificates of deposit and time deposits 1,354 - 1,354 1,034 320 \n Commercial paper 608 608 608 \n Corporate debt securities 76,840 6 (5,956) 70,890 20 12,627 58,243 \n Municipal securities 628 (26) 602 192 410 \n Mortgage- and asset-backed securities 22,365 6 (2,735) 19,636 344 19,292 \n Subtotal 144,470 18 (11,657) 132,831 1,125 31,162 100,544 \n Total (2) $ 173,752 $ 30 $ (11,683) $ 162,099 $ 29,965 $ 31,590 $ 100,544 \n\n (1) The valuation techniques used to measure the fair values of the Company\'s Level 2 financial instruments, which generally \n have counterparties with high credit ratings, are based on quoted market prices or model-driven valuations using significant \n inputs derived from or corroborated by observable market data. 
\n (2) As of December 30, 2023 and September 30, 2023, total marketable securities included $13.9 billion and $13.8 billion, \n respectively, that were restricted from general use, related to the European Commission decision finding that Ireland granted \n state aid to the Company, and other agreements. \n\nThe following table shows the fair value of the Company\'s non-current marketable debt securities, by contractual maturity, as of \nDecember 30, 2023 (in millions): \n\n Due after 1 year through 5 years $ 72,994 \nDue after 5 years through 10 years 9,368 \nDue after 10 years 17,113 \n Total fair value $ 99,475 \n\nDerivative Instruments and Hedging \nThe Company may use derivative instruments to partially offset its business exposure to foreign exchange and interest rate risk. \nHowever, the Company may choose not to hedge certain exposures for a variety of reasons, including accounting considerations \nor the prohibitive economic cost of hedging particular exposures. There can be no assurance the hedges will offset more than a \nportion of the financial impact resulting from movements in foreign exchange or interest rates. \n\nForeign Exchange Rate Risk \nTo protect gross margins from fluctuations in foreign exchange rates, the Company may use forwards, options or other \ninstruments, and may designate these instruments as cash flow hedges. The Company generally hedges portions of its \nforecasted foreign currency exposure associated with revenue and inventory purchases, typically for up to 12 months. \n\nTo protect the Company\'s foreign currency-denominated term debt or marketable securities from fluctuations in foreign \nexchange rates, the Company may use forwards, cross-currency swaps or other instruments. The Company designates these \ninstruments as either cash flow or fair value hedges. As of December 30, 2023, the maximum length of time over which the \nCompany is hedging its exposure to the variability in future cash flows for term debt-related foreign currency transactions is 19 \nyears. \n\nThe Company may also use derivative instruments that are not designated as accounting hedges to protect gross margins from \ncertain fluctuations in foreign exchange rates, as well as to offset a portion of the foreign currency gains and losses generated by \nthe remeasurement of certain assets and liabilities denominated in non-functional currencies. \n\n Apple Inc. IQ1 2024 Form 10-Q 18 \n<<<\n\n\nInterest Rate Risk \nTo protect the Company\'s term debt or marketable securities from fluctuations in interest rates, the Company may use interest \nrate swaps, options or other instruments. The Company designates these instruments as either cash flow or fair value hedges. 
\n\nThe notional amounts of the Company\'s outstanding derivative instruments as of December 30, 2023 and September 30, 2023 \nwere as follows (in millions): \n December 2023 30, September 2023 30, \n\n Derivative instruments designated as accounting hedges: \n Foreign exchange contracts $ 66,735 $ 74,730 \n Interest rate contracts $ 19,375 $ 19,375 \n\n Derivative instruments not designated as accounting hedges : \n Foreign exchange contracts $ 102,108 $ 104,777 \n\n The carrying amounts of the Company\'s hedged items in fair value hedges as of December 30, 2023 and September 30, 2023 \nwere as follows (in millions): \n December 2023 30, September 2023 30, \n\n Hedged assets/(liabilities): \n Current and non-current marketable securities $ 15,102 $ 14,433 \n Current and non-current term debt $ (18,661) $ (18,247) \n\nAccounts Receivable \n\n Trade Receivables \n The Company\'s third-party cellular network carriers accounted for 34% and 41% of total trade receivables as of December 30, \n2023 and September 30, 2023, respectively. The Company requires third-party credit support or collateral from certain \ncustomers to limit credit risk. \n\n Vendor Non-Trade Receivables \nThe Company has non-trade receivables from certain of its manufacturing vendors resulting from the sale of components to \nthese vendors who manufacture subassemblies or assemble final products for the Company. The Company purchases these \ncomponents directly from suppliers. The Company does not reflect the sale of these components in products net sales. Rather, \nthe Company recognizes any gain on these sales as a reduction of products cost of sales when the related final products are \nsold by the Company. As of December 30, 2023, the Company had two vendors that individually represented 10% or more of \ntotal vendor non-trade receivables, which accounted for 50% and 20%. As of September 30, 2023, the Company had two \nvendors that individually represented 10% or more of total vendor non-trade receivables, which accounted for 48% and 23%. \n\nNote 5 - Condensed Consolidated Financial Statement Details \nThe following table shows the Company\'s condensed consolidated financial statement details as of December 30, 2023 and \nSeptember 30, 2023 (in millions): \n\nProperty, Plant and Equipment, Net \n December 2023 30, September 2023 30, \n\n Gross property, plant and equipment $ 116,176 $ 114,599 \n Accumulated depreciation (72,510) (70,884) \n Total property, plant and equipment, net $ 43,666 $ 43,715 \n\n Apple Inc. IQ1 2024 Form 10-Q 19 \n<<<\n\n\nNote 6 - Debt \n\nCommercial Paper \nThe Company issues unsecured short-term promissory notes pursuant to a commercial paper program. The Company uses net \nproceeds from the commercial paper program for general corporate purposes, including dividends and share repurchases. As of \nDecember 30, 2023 and September 30, 2023, the Company had $2.0 billion and $6.0 billion of commercial paper outstanding, \nrespectively. 
The following table provides a summary of cash flows associated with the issuance and maturities of commercial \npaper for the three months ended December 30, 2023 and December 31, 2022 (in millions): \n Three Months Ended \n December 2023 30, December 2022 31, \n\n Maturities 90 days or less: \n Repayments of commercial paper, net $ (3,984) $ (5,569) \n\n Maturities greater than 90 days: \n Repayments of commercial paper - (2,645) \n\n Total repayments of commercial paper, net $ (3,984) $ (8,214) \n\nTerm Debt \nAs of December 30, 2023 and September 30, 2023, the Company had outstanding fixed-rate notes with varying maturities for an \naggregate carrying amount of $106.0 billion and $105.1 billion, respectively (collectively the "Notes"). As of December 30, 2023 \nand September 30, 2023, the fair value of the Company\'s Notes, based on Level 2 inputs, was $96.7 billion and $90.8 billion, \nrespectively. \n\nNote 7 - Shareholders\' Equity \n\nShare Repurchase Program \nDuring the three months ended December 30, 2023, the Company repurchased 118 million shares of its common stock for $20.5 \nbillion. The Company\'s share repurchase program does not obligate the Company to acquire a minimum amount of shares. \nUnder the program, shares may be repurchased in privately negotiated or open market transactions, including under plans \ncomplying with Rule 10b5-1 under the Securities Exchange Act of 1934, as amended (the "Exchange Act"). \n\nNote 8 - Share-Based Compensation \n\nRestricted Stock Units \nA summary of the Company\'s RSU activity and related information for the three months ended December 30, 2023 is as follows: \n Number RSUs of Weighted-Average Grant Date Fair Aggregate Fair Value \n (in thousands) Value Per RSU (in millions) \n Balance as of September 30, 2023 180,247 $ 135.91 \n RSUs granted 74,241 $ 171.58 \n RSUs vested (42,490) $ 110.75 \n RSUs canceled (3,026) $ 109.05 \n Balance as of December 30, 2023 208,972 $ 154.09 $ 40,233 \n\nThe fair value as of the respective vesting dates of RSUs was $7.7 billion and $6.8 billion for the three months ended December \n30, 2023 and December 31, 2022, respectively. \n\n Apple Inc. I Q1 2024 Form 10-Q | 10 \n<<<\n\n\nShare-Based Compensation \nThe following table shows share-based compensation expense and the related income tax benefit included in the Condensed \nConsolidated Statements of Operations for the three months ended December 30, 2023 and December 31, 2022 (in millions): \n Three Months Ended \n December 2023 30, December 2022 31, \n\nShare-based compensation expense $ 2,997 $ 2,905 \nIncome tax benefit related to share-based compensation expense $ (1,235) $ (1,178) \n\nAs of December 30, 2023, the total unrecognized compensation cost related to outstanding RSUs was $27.4 billion, which the \nCompany expects to recognize over a weighted-average period of 2.9 years. \n\nNote 9 - Contingencies \nThe Company is subject to various legal proceedings and claims that have arisen in the ordinary course of business and that \nhave not been fully resolved. The outcome of litigation is inherently uncertain. In the opinion of management, there was not at \nleast a reasonable possibility the Company may have incurred a material loss, or a material loss greater than a recorded accrual, \nconcerning loss contingencies for asserted legal and other claims. 
\n\nNote 10 - Segment Information and Geographic Data \nThe following table shows information by reportable segment for the three months ended December 30, 2023 and December 31, \n2022 (in millions): \n Three Months Ended \n December 2023 30, December 2022 31, \n\nAmericas: \n Net sales $ 50,430 $ 49,278 \n Operating income $ 20,357 $ 17,864 \n\n Europe: \n Net sales $ 30,397 $ 27,681 \n Operating income $ 12,711 $ 10,017 \n\nGreater China: \n Net sales $ 20,819 $ 23,905 \n Operating income $ 8,622 $ 10,437 \n\nJapan: \n Net sales $ 7,767 $ 6,755 \n Operating income $ 3,819 $ 3,236 \n\nRest of Asia Pacific: \n Net sales $ 10,162 $ 9,535 \n Operating income $ 4,579 $ 3,851 \n\n Apple Inc. I Q1 2024 Form 10-Q | 11 \n<<<\n\n\nA reconciliation of the Company\'s segment operating income to the Condensed Consolidated Statements of Operations for the \nthree months ended December 30, 2023 and December 31, 2022 is as follows (in millions): \n Three Months Ended \n December 2023 30, December 2022 31, \n\n Segment operating income $ 50,088 $ 45,405 \n Research and development expense (7,696) (7,709) \n Other corporate expenses, net (2,019) (1,680) \n Total operating income $ 40,373 $ 36,016 \n\n Apple Inc. I Q1 2024 Form 10-Q | 12 \n<<<\n\n\nItem 2. Management\'s Discussion and Analysis of Financial Condition and Results of Operations \n\nThis Item and other sections of this Quarterly Report on Form 10-Q ("Form 10-Q") contain forward-looking statements, within \nthe meaning of the Private Securities Litigation Reform Act of 1995, that involve risks and uncertainties. Forward-looking \nstatements provide current expectations of future events based on certain assumptions and include any statement that does \nnot directly relate to any historical or current fact. For example, statements in this Form 10-Q regarding the potential future \nimpact of macroeconomic conditions on the Company\'s business and results of operations are forward-looking statements. \nForward-looking statements can also be identified by words such as "future," "anticipates," "believes," "estimates," "expects," \n"intends," "plans," "predicts," "will," "would," "could," "can," "may," and similar terms. Forward-looking statements are not \nguarantees of future performance and the Company\'s actual results may differ significantly from the results discussed in the \nforward-looking statements. Factors that might cause such differences include, but are not limited to, those discussed in Part I, \nItem 1A of the 2023 Form 10-K under the heading "Risk Factors." The Company assumes no obligation to revise or update any \nforward-looking statements for any reason, except as required by law. \n\nUnless otherwise stated, all information presented herein is based on the Company\'s fiscal calendar, and references to \nparticular years, quarters, months or periods refer to the Company\'s fiscal years ended in September and the associated \nquarters, months and periods of those fiscal years. \n\nThe following discussion should be read in conjunction with the 2023 Form 10-K filed with the U.S. Securities and Exchange \nCommission (the "SEC") and the condensed consolidated financial statements and accompanying notes included in Part I, \nItem 1 of this Form 10-Q. \n\nAvailable Information \nThe Company periodically provides certain information for investors on its corporate website, www.apple.com, and its investor \nrelations website, investor.apple.com. 
This includes press releases and other information about financial performance, \ninformation on environmental, social and governance matters, and details related to the Company\'s annual meeting of \nshareholders. The information contained on the websites referenced in this Form 10-Q is not incorporated by reference into this \nfiling. Further, the Company\'s references to website URLs are intended to be inactive textual references only. \n\nBusiness Seasonality and Product Introductions \nThe Company has historically experienced higher net sales in its first quarter compared to other quarters in its fiscal year due in \npart to seasonal holiday demand. Additionally, new product and service introductions can significantly impact net sales, cost of \nsales and operating expenses. The timing of product introductions can also impact the Company\'s net sales to its indirect \ndistribution channels as these channels are filled with new inventory following a product launch, and channel inventory of an \nolder product often declines as the launch of a newer product approaches. Net sales can also be affected when consumers and \ndistributors anticipate a product introduction. \n\nFiscal Period \nThe Company\'s fiscal year is the 52- or 53-week period that ends on the last Saturday of September. An additional week is \nincluded in the first fiscal quarter every five or six years to realign the Company\'s fiscal quarters with calendar quarters, which \noccurred in the first quarter of 2023. The Company\'s fiscal years 2024 and 2023 span 52 and 53 weeks, respectively. \n\nQuarterly Highlights \nThe Company\'s first quarter of 2024 included 13 weeks, compared to 14 weeks during the first quarter of 2023. \n\nThe Company\'s total net sales increased 2% or $2.4 billion during the first quarter of 2024 compared to the same quarter in \n2023, driven primarily by higher net sales of iPhone and Services, partially offset by lower net sales of iPad and Wearables, \nHome and Accessories. \n\nDuring the first quarter of 2024, the Company announced an updated MacBook Pro® 14-in., MacBook Pro 16-in. and iMac®. \n\nThe Company repurchased $20.5 billion of its common stock and paid dividends and dividend equivalents of $3.8 billion during \nthe first quarter of 2024. \n\nMacroeconomic Conditions \nMacroeconomic conditions, including inflation, changes in interest rates, and currency fluctuations, have directly and indirectly \nimpacted, and could in the future materially impact, the Company\'s results of operations and financial condition. \n\n Apple Inc. I Q1 2024 Form 10-Q | 13 \n<<<\n\n\nSegment Operating Performance \nThe following table shows net sales by reportable segment for the three months ended December 30, 2023 and December 31, \n2022 (dollars in millions): \n Three Months Ended \n December 2023 30, December 2022 31, Change \n\n Net sales by reportable segment: \n Americas $ 50,430 $ 49,278 2 % \n Europe 30,397 27,681 10 % \n Greater China 20,819 23,905 (13)% \n Japan 7,767 6,755 15 % \n Rest of Asia Pacific 10,162 9,535 7 % \n Total net sales $ 119,575 $ 117,154 2 % \n\nAmericas \nAmericas net sales increased 2% or $1.2 billion during the first quarter of 2024 compared to the same quarter in 2023 due \nprimarily to higher net sales of Services and iPhone, partially offset by lower net sales of iPad. The strength in foreign currencies \nrelative to the U.S. dollar had a net favorable year-over-year impact on Americas net sales during the first quarter of 2024. 
\n\nEurope \nEurope net sales increased 10% or $2.7 billion during the first quarter of 2024 compared to the same quarter in 2023 due \nprimarily to higher net sales of iPhone. The strength in foreign currencies relative to the U.S. dollar had a net favorable year- \nover-year impact on Europe net sales during the first quarter of 2024. \n\nGreater China \nGreater China net sales decreased 13% or $3.1 billion during the first quarter of 2024 compared to the same quarter in 2023 due \nprimarily to lower net sales of iPhone, iPad and Wearables, Home and Accessories. The weakness in the renminbi relative to the \nU.S. dollar had an unfavorable year-over-year impact on Greater China net sales during the first quarter of 2024. \n\nJapan \nJapan net sales increased 15% or $1.0 billion during the first quarter of 2024 compared to the same quarter in 2023 due primarily \nto higher net sales of iPhone. The weakness in the yen relative to the U.S. dollar had an unfavorable year-over-year impact on \nJapan net sales during the first quarter of 2024. \n\nRest of Asia Pacific \nRest of Asia Pacific net sales increased 7% or $627 million during the first quarter of 2024 compared to the same quarter in 2023 \ndue primarily to higher net sales of iPhone, partially offset by lower net sales of Wearables, Home and Accessories. \n\n Apple Inc. I Q1 2024 Form 10-Q | 14 \n<<<\n\n\nProducts and Services Performance \nThe following table shows net sales by category for the three months ended December 30, 2023 and December 31, 2022 \n(dollars in millions): \n Three Months Ended \n December 2023 30, December 2022 31, Change \n\nNet sales by category: \n iPhone $ 69,702 $ 65,775 6 % \n Mac 7,780 7,735 1 % \n iPad 7,023 9,396 (25)% \n Wearables, Home and Accessories 11,953 13,482 (11)% \n Services 23,117 20,766 11 % \n Total net sales $ 119,575 $ 117,154 2 % \n\niPhone \niPhone net sales increased 6% or $3.9 billion during the first quarter of 2024 compared to the same quarter in 2023 due primarily \nto higher net sales of Pro models, partially offset by lower net sales of other models. \n\nMac \nMac net sales were relatively flat during the first quarter of 2024 compared to the same quarter in 2023. \n\niPad \niPad net sales decreased 25% or $2.4 billion during the first quarter of 2024 compared to the same quarter in 2023 due primarily \nto lower net sales of iPad Pro, iPad 9th generation and iPad Air. \n\nWearables, Home and Accessories \nWearables, Home and Accessories net sales decreased 11% or $1.5 billion during the first quarter of 2024 compared to the \nsame quarter in 2023 due primarily to lower net sales of Wearables and Accessories. \n\nServices \nServices net sales increased 11% or $2.4 billion during the first quarter of 2024 compared to the same quarter in 2023 due \nprimarily to higher net sales from advertising, video and cloud services. \n\n Apple Inc. 
I Q1 2024 Form 10-Q | 15 \n<<<\n\n\nGross Margin \nProducts and Services gross margin and gross margin percentage for the three months ended December 30, 2023 and \nDecember 31, 2022 were as follows (dollars in millions): \n Three Months Ended \n December 2023 30, December 2022 31, \n\nGross margin: \n Products $ 38,018 $ 35,623 \n Services 16,837 14,709 \n Total gross margin $ 54,855 $ 50,332 \n\n Gross margin percentage: \n Products 39.4% 37.0% \n Services 72.8% 70.8% \n Total gross margin percentage 45.9% 43.0% \n\nProducts Gross Margin \nProducts gross margin increased during the first quarter of 2024 compared to the same quarter in 2023 due primarily to cost \nsavings and a different Products mix, partially offset by the weakness in foreign currencies relative to the U.S. dollar and lower \nProducts volume. \n\nProducts gross margin percentage increased during the first quarter of 2024 compared to the same quarter in 2023 due primarily \nto cost savings and a different Products mix, partially offset by the weakness in foreign currencies relative to the U.S. dollar. \n\nServices Gross Margin \nServices gross margin increased during the first quarter of 2024 compared to the same quarter in 2023 due primarily to higher \nServices net sales and a different Services mix. \n\nServices gross margin percentage increased during the first quarter of 2024 compared to the same quarter in 2023 due primarily \nto a different Services mix. \n\nThe Company\'s future gross margins can be impacted by a variety of factors, as discussed in Part I, Item 1A of the 2023 Form \n10-K under the heading "Risk Factors." As a result, the Company believes, in general, gross margins will be subject to volatility \nand downward pressure. \n\n Apple Inc. I Q1 2024 Form 10-Q | 16 \n<<<\n\n\nOperating Expenses \nOperating expenses for the three months ended December 30, 2023 and December 31, 2022 were as follows (dollars in \nmillions): \n Three Months Ended \n December 2023 30, December 2022 31, \n\n Research and development $ 7,696 $ 7,709 \n Percentage of total net sales 6% 7% \n\n Selling, general and administrative $ 6,786 $ 6,607 \n Percentage of total net sales 6% 6% \n Total operating expenses $ 14,482 $ 14,316 \n Percentage of total net sales 12% 12% \n\nResearch and Development \nResearch and development ("R&D") expense was relatively flat during the first quarter of 2024 compared to the same quarter in \n2023. \n\nSelling, General and Administrative \nSelling, general and administrative expense increased 3% or $179 million during the first quarter of 2024 compared to the same \nquarter in 2023. \n\nProvision for Income Taxes \nProvision for income taxes, effective tax rate and statutory federal income tax rate for the three months ended December 30, \n2023 and December 31, 2022 were as follows (dollars in millions): \n Three Months Ended \n December 2023 30, December 2022 31, \n\n Provision for income taxes $ 6,407 $ 5,625 \n Effective tax rate 15.9% 15.8% \n Statutory federal income tax rate 21% 21% \n\nThe Company\'s effective tax rate for the first quarter of 2024 was lower than the statutory federal income tax rate due primarily to \na lower effective tax rate on foreign earnings, tax benefits from share-based compensation, and the impact of the U.S. federal \nR&D credit, partially offset by state income taxes. \n\nThe Company\'s effective tax rate for the first quarter of 2024 was relatively flat compared to the same quarter in 2023. 
\n\nLiquidity and Capital Resources \nThe Company believes its balances of cash, cash equivalents and unrestricted marketable securities, along with cash generated \nby ongoing operations and continued access to debt markets, will be sufficient to satisfy its cash requirements and capital return \nprogram over the next 12 months and beyond. \n\nThe Company\'s contractual cash requirements have not changed materially since the 2023 Form 10-K, except for manufacturing \npurchase obligations. \n\nManufacturing Purchase Obligations \nThe Company utilizes several outsourcing partners to manufacture subassemblies for the Company\'s products and to perform \nfinal assembly and testing of finished products. The Company also obtains individual components for its products from a wide \nvariety of individual suppliers. As of December 30, 2023, the Company had manufacturing purchase obligations of $38.0 billion, \nwith $37.9 billion payable within 12 months. \n\n Apple Inc. I Q1 2024 Form 10-Q | 17 \n<<<\n\n\nCapital Return Program \nIn addition to its contractual cash requirements, the Company has an authorized share repurchase program. The program does \nnot obligate the Company to acquire a minimum amount of shares. As of December 30, 2023, the Company\'s quarterly cash \ndividend was $0.24 per share. The Company intends to increase its dividend on an annual basis, subject to declaration by the \nBoard of Directors. \n\nRecent Accounting Pronouncements \n\nIncome Taxes \nIn December 2023, the Financial Accounting Standards Board (the "FASB") issued Accounting Standards Update ("ASU") No. \n2023-09, Income Taxes (Topic 740): Improvements to Income Tax Disclosures ("ASU 2023-09"), which will require the Company \nto disclose specified additional information in its income tax rate reconciliation and provide additional information for reconciling \nitems that meet a quantitative threshold. ASU 2023-09 will also require the Company to disaggregate its income taxes paid \ndisclosure by federal, state and foreign taxes, with further disaggregation required for significant individual jurisdictions. The \nCompany will adopt ASU 2023-09 in its fourth quarter of 2026. ASU 2023-09 allows for adoption using either a prospective or \nretrospective transition method. \n\nSegment Reporting \nIn November 2023, the FASB issued ASU No. 2023-07, Segment Reporting (Topic 280): Improvements to Reportable Segment \nDisclosures ("ASU 2023-07\'), which will require the Company to disclose segment expenses that are significant and regularly \nprovided to the Company\'s chief operating decision maker ("CODM"). In addition, ASU 2023-07 will require the Company to \ndisclose the title and position of its CODM and how the CODM uses segment profit or loss information in assessing segment \nperformance and deciding how to allocate resources. The Company will adopt ASU 2023-07 in its fourth quarter of 2025 using a \nretrospective transition method. \n\nCritical Accounting Estimates \nThe preparation of financial statements and related disclosures in conformity with GAAP and the Company\'s discussion and \nanalysis of its financial condition and operating results require the Company\'s management to make judgments, assumptions \nand estimates that affect the amounts reported. 
Note 1, "Summary of Significant Accounting Policies" of the Notes to Condensed \nConsolidated Financial Statements in Part I, Item 1 of this Form 10-Q and in the Notes to Consolidated Financial Statements in \nPart II, Item 8 of the 2023 Form 10-K describe the significant accounting policies and methods used in the preparation of the \nCompany\'s condensed consolidated financial statements. There have been no material changes to the Company\'s critical \naccounting estimates since the 2023 Form 10-K. \n\nItem 3. Quantitative and Qualitative Disclosures About Market Risk \n\nThere have been no material changes to the Company\'s market risk during the first three months of 2024. For a discussion of the \nCompany\'s exposure to market risk, refer to the Company\'s market risk disclosures set forth in Part II, Item 7A, "Quantitative and \nQualitative Disclosures About Market Risk" of the 2023 Form 10-K. \n\nItem 4. Controls and Procedures \n\nEvaluation of Disclosure Controls and Procedures \nBased on an evaluation under the supervision and with the participation of the Company\'s management, the Company\'s principal \nexecutive officer and principal financial officer have concluded that the Company\'s disclosure controls and procedures as defined \nin Rules 13a-15(e) and 15d-15(e) under the Exchange Act were effective as of December 30, 2023 to provide reasonable \nassurance that information required to be disclosed by the Company in reports that it files or submits under the Exchange Act is \n(i) recorded, processed, summarized and reported within the time periods specified in the SEC rules and forms and \n(ii) accumulated and communicated to the Company\'s management, including its principal executive officer and principal \nfinancial officer, as appropriate to allow timely decisions regarding required disclosure. \n\nChanges in Internal Control over Financial Reporting \nThere were no changes in the Company\'s internal control over financial reporting during the first quarter of 2024, which were \nidentified in connection with management\'s evaluation required by paragraph (d) of Rules 13a-15 and 15d-15 under the \nExchange Act, that have materially affected, or are reasonably likely to materially affect, the Company\'s internal control over \nfinancial reporting. \n\n Apple Inc. I Q1 2024 Form 10-Q | 18 \n<<<\n\n\nPART II - OTHER INFORMATION \n\nItem 1. Legal Proceedings \n\nEpic Games \nEpic Games, Inc. ("Epic") filed a lawsuit in the U.S. District Court for the Northern District of California (the "District Court") \nagainst the Company alleging violations of federal and state antitrust laws and California\'s unfair competition law based upon the \nCompany\'s operation of its App Store®. On September 10, 2021, the District Court ruled in favor of the Company with respect to \nnine out of the ten counts included in Epic\'s claim. The District Court found that certain provisions of the Company\'s App Store \nReview Guidelines violate California\'s unfair competition law and issued an injunction enjoining the Company from prohibiting \ndevelopers from including in their apps external links that direct customers to purchasing mechanisms other than Apple in-app \npurchasing. The injunction applies to apps on the U.S. storefront of the iOS and iPadOS® App Store. On April 24, 2023, the U.S. \nCourt of Appeals for the Ninth Circuit (the "Circuit Court") affirmed the District Court\'s ruling. 
On June 7, 2023, the Company and \nEpic filed petitions with the Circuit Court requesting further review of the decision. On June 30, 2023, the Circuit Court denied \nboth petitions. On July 17, 2023, the Circuit Court granted Apple\'s motion to stay enforcement of the injunction pending appeal to \nthe U.S. Supreme Court (the "Supreme Court"). On January 16, 2024, the Supreme Court denied both the Company\'s and Epic\'s \npetitions and the stay terminated. The Supreme Court\'s denial of Epic\'s petition confirms the District Court\'s ruling in favor of the \nCompany with respect to all of the antitrust claims. Following termination of the stay, the Company implemented a plan to comply \nwith the injunction and filed a statement of compliance with the District Court. On January 31, 2024, Epic filed a notice with the \nDistrict Court indicating its intent to dispute the Company\'s compliance plan. \n\nMasimo \nMasimo Corporation and Cercacor Laboratories, Inc. (together, "Masimo") filed a complaint before the U.S. International Trade \nCommission (the "ITC") alleging infringement by the Company of five patents relating to the functionality of the blood oxygen \nfeature in Apple Watch® Series 6 and 7. In its complaint, Masimo sought a permanent exclusion order prohibiting importation to \nthe U.S. of certain Apple Watch models that include blood oxygen sensing functionality. On October 26, 2023, the ITC entered a \nlimited exclusion order (the "Order") prohibiting importation and sales in the U.S. of Apple Watch models with blood oxygen \nsensing functionality, which includes Apple Watch Series 9 and Apple Watch Ultra™ 2. The Company subsequently proposed a \nredesign of Apple Watch Series 9 and Apple Watch Ultra 2 to the U.S. Customs and Border Protection (the "CBP") and appealed \nthe Order. On January 12, 2024, the CBP found that the Company\'s proposed redesign of Apple Watch Series 9 and Apple \nWatch Ultra 2 falls outside the scope of the Order, permitting the Company to import and sell the models in the U.S. \n\nOther Legal Proceedings \nThe Company is subject to other legal proceedings and claims that have not been fully resolved and that have arisen in the \nordinary course of business. The Company settled certain matters during the first quarter of 2024 that did not individually or in \nthe aggregate have a material impact on the Company\'s financial condition or operating results. The outcome of litigation is \ninherently uncertain. If one or more legal matters were resolved against the Company in a reporting period for amounts above \nmanagement\'s expectations, the Company\'s financial condition and operating results for that reporting period could be materially \nadversely affected. \n\nItem 1A. Risk Factors \n\nThe Company\'s business, reputation, results of operations, financial condition and stock price can be affected by a number of \nfactors, whether currently known or unknown, including those described in Part I, Item 1A of the 2023 Form 10-K under the \nheading "Risk Factors." When any one or more of these risks materialize from time to time, the Company\'s business, reputation, \nresults of operations, financial condition and stock price can be materially and adversely affected. Except as set forth below, \nthere have been no material changes to the Company\'s risk factors since the 2023 Form 10-K. 
\n\nThe technology industry, including, in some instances, the Company, is subject to intense media, political and regulatory \nscrutiny, which exposes the Company to increasing regulation, government investigations, legal actions and penalties. \nFrom time to time, the Company has made changes to its App Store, including actions taken in response to litigation, \ncompetition, market conditions and legal and regulatory requirements. The Company expects to make further business changes \nin the future. For example, in the U.S. the Company has implemented changes to how developers communicate with consumers \nwithin apps on the U.S. storefront of the iOS and iPadOS App Store regarding alternative purchasing mechanisms. \n\n Apple Inc. I Q1 2024 Form 10-Q | 19 \n<<<\n\n\n In January 2024, the Company announced changes to iOS, the App Store and Safari® in the European Union to comply with the \n Digital Markets Act (the "DMA"), including new business terms and alternative fee structures for iOS apps, alternative methods of \n distribution for iOS apps, alternative payment processing for apps across the Company\'s operating systems, and additional tools \n and application programming interfaces ("APIs") for developers. Although the Company\'s compliance plan is intended to address \n the DMA\'s obligations, it is still subject to potential challenge by the European Commission or private litigants. In addition, other \n jurisdictions may seek to require the Company to make changes to its business. While the changes introduced by the Company \n in the European Union are intended to reduce new privacy and security risks the DMA poses to European Union users, many \n risks will remain. \n\n The Company is also currently subject to antitrust investigations in various jurisdictions around the world, which can result in \n legal proceedings and claims against the Company that could, individually or in the aggregate, have a materially adverse impact \n on the Company\'s business, results of operations and financial condition. For example, the Company is the subject of \n investigations in Europe and other jurisdictions relating to App Store terms and conditions. If such investigations result in adverse \nfindings against the Company, the Company could be exposed to significant fines and may be required to make further changes \n to its App Store business, all of which could materially adversely affect the Company\'s business, results of operations and \n financial condition. \n\n Further, the Company has commercial relationships with other companies in the technology industry that are or may become \n subject to investigations and litigation that, if resolved against those other companies, could materially adversely affect the \n Company\'s commercial relationships with those business partners and materially adversely affect the Company\'s business, \n results of operations and financial condition. For example, the Company earns revenue from licensing arrangements with other \n companies to offer their search services on the Company\'s platforms and applications, and certain of these arrangements are \n currently subject to government investigations and legal proceedings. \n\n There can be no assurance the Company\'s business will not be materially adversely affected, individually or in the aggregate, by \nthe outcomes of such investigations, litigation or changes to laws and regulations in the future. 
Changes to the Company\'s \n business practices to comply with new laws and regulations or in connection with other legal proceedings can negatively impact \n the reputation of the Company\'s products for privacy and security and otherwise adversely affect the experience for users of the \n Company\'s products and services, and result in harm to the Company\'s reputation, loss of competitive advantage, poor market \n acceptance, reduced demand for products and services, and lost sales. \n\n Item 2. Unregistered Sales of Equity Securities and Use of Proceeds \n\n Purchases of Equity Securities by the Issuer and Affiliated Purchasers \n Share repurchase activity during the three months ended December 30, 2023 was as follows (in millions, except number of \n shares, which are reflected in thousands, and per-share amounts): \n Total of Shares Number \n Purchased as Dollar Approximate Value of \n Total Number Average Price Part Announced of Publicly Yet Shares Be Purchased That May \n of Shares Paid Per Plans or Under the Plans \n Periods Purchased Share Programs or Programs (1) \n October 1, 2023 to November 4, 2023: \n August 2023 ASRs 6,498 (2) 6,498 \n Open market and privately negotiated purchases 45,970 $ 174.03 45,970 \n\n November 5, 2023 to December 2, 2023: \n Open market and privately negotiated purchases 33,797 $ 187.14 33,797 \n\n December 3, 2023 to December 30, 2023: \n Open market and privately negotiated purchases 31,782 $ 194.29 31,782 \n Total 118,047 $ 53,569 \n\n (1) As of December 30, 2023, the Company was authorized by the Board of Directors to purchase up to $90 billion of the \n Company\'s common stock under a share repurchase program announced on May 4, 2023, of which $36.4 billion had been \n utilized. The program does not obligate the Company to acquire a minimum amount of shares. Under the program, shares \n may be repurchased in privately negotiated or open market transactions, including under plans complying with Rule 10b5-1 \n under the Exchange Act. \n (2) In August 2023, the Company entered into accelerated share repurchase agreements ("ASRs") to purchase up to a total of \n $5.0 billion of the Company\'s common stock. In October 2023, the purchase periods for these ASRs ended and an additional \n 6 million shares were delivered and retired. In total, 29 million shares were delivered under these ASRs at an average \n repurchase price of $174.93 per share. \n\n Apple Inc. I Q1 2024 Form 10-Q | 20 \n<<<\n\n\nItem 3. Defaults Upon Senior Securities \n\nNone. \n\nItem 4. Mine Safety Disclosures \n\nNot applicable. \n\nItem 5. Other Information \n\nInsider Trading Arrangements \nOn November 11, 2023 and November 27, 2023, respectively, Luca Maestri, the Company\'s Senior Vice President and Chief \nFinancial Officer, and Katherine L. Adams, the Company\'s Senior Vice President and General Counsel, each entered into a \ntrading plan intended to satisfy the affirmative defense conditions of Rule 10b5-1(c) under the Exchange Act. The plans provide \nfor the sale of all shares vested during the duration of the plans pursuant to certain equity awards granted to Mr. Maestri and Ms. \nAdams, respectively, excluding any shares withheld by the Company to satisfy income tax withholding and remittance \nobligations. Mr. Maestri\'s plan will expire on December 31, 2024, and Ms. Adams\'s plan will expire on November 1, 2024, subject \nto early termination for certain specified events set forth in the plans. \n\nItem 6. 
Exhibits \n Incorporated by Reference \n\n Number Exhibit Form Filing Period Date/ End \n Exhibit Description Exhibit Date \n 31.1* Rule 13a-14(a) / 15d-14(a) Certification of Chief Executive Officer. \n 31.2* Rule 13a-14(a) / 15d-14(a) Certification of Chief Financial Officer. \n 32.1 ** Section 1350 Certifications of Chief Executive Officer and Chief Financial Officer. \n 101* Inline XBRL Document Set for the condensed consolidated financial statements \n and accompanying notes in Part I, Item 1, "Financial Statements" of this \n Quarterly Report on Form 10-Q. \n 104* Inline the Exhibit XBRL for 101 the Inline cover XBRL page Document of this Quarterly Set. Report on Form 10-Q, included in \n\n * Filed herewith. \n ** Furnished herewith. \n\n Apple Inc. I Q1 2024 Form 10-Q | 21 \n<<<\n\n\n SIGNATURE \n\n Pursuant to the requirements of the Securities Exchange Act of 1934, the Registrant has duly caused this report to be signed on \nits behalf by the undersigned thereunto duly authorized. \n\n Date: February 1, 2024 Apple Inc. \n\n By: /s/ Luca Maestri \n Luca Maestri \n Senior Vice President, \n Chief Financial Officer \n\n Apple Inc. I Q1 2024 Form 10-Q | 22 \n<<<\n\n\n Exhibit 31.1 \n\n CERTIFICATION \n\nI, Timothy D. Cook, certify that: \n\n1. I have reviewed this quarterly report on Form 10-Q of Apple Inc .; \n\n2. Based on my knowledge, this report does not contain any untrue statement of a material fact or omit to state a material fact \n necessary to make the statements made, in light of the circumstances under which such statements were made, not \n misleading with respect to the period covered by this report; \n\n3. Based on my knowledge, the financial statements, and other financial information included in this report, fairly present in all \n material respects the financial condition, results of operations and cash flows of the Registrant as of, and for, the periods \n presented in this report; \n\n4. 
The Registrant\'s other certifying officer(s) and I are responsible for establishing and maintaining disclosure controls and \n procedures (as defined in Exchange Act Rules 13a-15(e) and 15d-15(e)) and internal control over financial reporting (as \n defined in Exchange Act Rules 13a-15(f) and 15d-15(f)) for the Registrant and have: \n\n (a) Designed such disclosure controls and procedures, or caused such disclosure controls and procedures to be \n designed under our supervision, to ensure that material information relating to the Registrant, including its \n consolidated subsidiaries, is made known to us by others within those entities, particularly during the period in \n which this report is being prepared; \n\n (b) Designed such internal control over financial reporting, or caused such internal control over financial reporting \n to be designed under our supervision, to provide reasonable assurance regarding the reliability of financial \n reporting and the preparation of financial statements for external purposes in accordance with generally \n accepted accounting principles; \n\n (c) Evaluated the effectiveness of the Registrant\'s disclosure controls and procedures and presented in this report \n our conclusions about the effectiveness of the disclosure controls and procedures, as of the end of the period \n covered by this report based on such evaluation; and \n\n (d) Disclosed in this report any change in the Registrant\'s internal control over financial reporting that occurred \n during the Registrant\'s most recent fiscal quarter (the Registrant\'s fourth fiscal quarter in the case of an annual \n report) that has materially affected, or is reasonably likely to materially affect, the Registrant\'s internal control \n over financial reporting; and \n\n5. The Registrant\'s other certifying officer(s) and I have disclosed, based on our most recent evaluation of internal control over \n financial reporting, to the Registrant\'s auditors and the audit committee of the Registrant\'s board of directors (or persons \n performing the equivalent functions): \n\n (a) All significant deficiencies and material weaknesses in the design or operation of internal control over financial \n reporting which are reasonably likely to adversely affect the Registrant\'s ability to record, process, summarize \n and report financial information; and \n\n (b) Any fraud, whether or not material, that involves management or other employees who have a significant role \n in the Registrant\'s internal control over financial reporting. \n\nDate: February 1, 2024 \n\n By: /s/ Timothy D. Cook \n Timothy D. Cook \n Chief Executive Officer \n<<<\n\n\n Exhibit 31.2 \n\n CERTIFICATION \n\nI, Luca Maestri, certify that: \n\n1. I have reviewed this quarterly report on Form 10-Q of Apple Inc .; \n\n2. Based on my knowledge, this report does not contain any untrue statement of a material fact or omit to state a material fact \n necessary to make the statements made, in light of the circumstances under which such statements were made, not \n misleading with respect to the period covered by this report; \n\n3. Based on my knowledge, the financial statements, and other financial information included in this report, fairly present in all \n material respects the financial condition, results of operations and cash flows of the Registrant as of, and for, the periods \n presented in this report; \n\n4. 
The Registrant\'s other certifying officer(s) and I are responsible for establishing and maintaining disclosure controls and \n procedures (as defined in Exchange Act Rules 13a-15(e) and 15d-15(e)) and internal control over financial reporting (as \n defined in Exchange Act Rules 13a-15(f) and 15d-15(f)) for the Registrant and have: \n\n (a) Designed such disclosure controls and procedures, or caused such disclosure controls and procedures to be \n designed under our supervision, to ensure that material information relating to the Registrant, including its \n consolidated subsidiaries, is made known to us by others within those entities, particularly during the period in \n which this report is being prepared; \n\n (b) Designed such internal control over financial reporting, or caused such internal control over financial reporting \n to be designed under our supervision, to provide reasonable assurance regarding the reliability of financial \n reporting and the preparation of financial statements for external purposes in accordance with generally \n accepted accounting principles; \n\n (c) Evaluated the effectiveness of the Registrant\'s disclosure controls and procedures and presented in this report \n our conclusions about the effectiveness of the disclosure controls and procedures, as of the end of the period \n covered by this report based on such evaluation; and \n\n (d) Disclosed in this report any change in the Registrant\'s internal control over financial reporting that occurred \n during the Registrant\'s most recent fiscal quarter (the Registrant\'s fourth fiscal quarter in the case of an annual \n report) that has materially affected, or is reasonably likely to materially affect, the Registrant\'s internal control \n over financial reporting; and \n\n5. The Registrant\'s other certifying officer(s) and I have disclosed, based on our most recent evaluation of internal control over \n financial reporting, to the Registrant\'s auditors and the audit committee of the Registrant\'s board of directors (or persons \n performing the equivalent functions): \n\n (a) All significant deficiencies and material weaknesses in the design or operation of internal control over financial \n reporting which are reasonably likely to adversely affect the Registrant\'s ability to record, process, summarize \n and report financial information; and \n\n (b) Any fraud, whether or not material, that involves management or other employees who have a significant role \n in the Registrant\'s internal control over financial reporting. \n\nDate: February 1, 2024 \n\n By: /s/ Luca Maestri \n Luca Maestri \n Senior Vice President, \n Chief Financial Officer \n<<<\n\n\n Exhibit 32.1 \n\n CERTIFICATIONS OF CHIEF EXECUTIVE OFFICER AND CHIEF FINANCIAL OFFICER \n PURSUANT TO \n 18 U.S.C. SECTION 1350, \n AS ADOPTED PURSUANT TO \n SECTION 906 OF THE SARBANES-OXLEY ACT OF 2002 \n\nI, Timothy D. Cook, certify, as of the date hereof, pursuant to 18 U.S.C. Section 1350, as adopted pursuant to Section 906 of the \nSarbanes-Oxley Act of 2002, that the Quarterly Report of Apple Inc. on Form 10-Q for the period ended December 30, 2023 fully \ncomplies with the requirements of Section 13(a) or 15(d) of the Securities Exchange Act of 1934 and that information contained \nin such Form 10-Q fairly presents in all material respects the financial condition and results of operations of Apple Inc. at the \ndates and for the periods indicated. \n\nDate: February 1, 2024 \n\n By: /s/ Timothy D. Cook \n Timothy D. 
Cook \n Chief Executive Officer \n\nI, Luca Maestri, certify, as of the date hereof, pursuant to 18 U.S.C. Section 1350, as adopted pursuant to Section 906 of the \nSarbanes-Oxley Act of 2002, that the Quarterly Report of Apple Inc. on Form 10-Q for the period ended December 30, 2023 fully \ncomplies with the requirements of Section 13(a) or 15(d) of the Securities Exchange Act of 1934 and that information contained \nin such Form 10-Q fairly presents in all material respects the financial condition and results of operations of Apple Inc. at the \ndates and for the periods indicated. \n\nDate: February 1, 2024 \n\n By: /s/ Luca Maestri \n Luca Maestri \n Senior Vice President, \n Chief Financial Officer \n\nA signed original of this written statement required by Section 906 has been provided to Apple Inc. and will be retained by Apple \nInc. and furnished to the Securities and Exchange Commission or its staff upon request. \n<<<\n" diff --git a/backend/workflow_manager/endpoint_v2/tests/test_database_utils/test_create_table_if_not_exists.py b/backend/workflow_manager/endpoint_v2/tests/test_database_utils/test_create_table_if_not_exists.py new file mode 100644 index 000000000..cbf4f1619 --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/tests/test_database_utils/test_create_table_if_not_exists.py @@ -0,0 +1,97 @@ +import pytest # type: ignore +from workflow_manager.endpoint_v2.database_utils import DatabaseUtils +from workflow_manager.endpoint_v2.exceptions import UnstractDBException + +from unstract.connectors.databases.unstract_db import UnstractDB + +from .base_test_db import BaseTestDB + + +class TestCreateTableIfNotExists(BaseTestDB): + def test_create_table_if_not_exists_valid( + self, valid_dbs_instance: UnstractDB + ) -> None: + engine = valid_dbs_instance.get_engine() + result = DatabaseUtils.create_table_if_not_exists( + db_class=valid_dbs_instance, + engine=engine, + table_name=self.valid_table_name, + database_entry=self.database_entry, + ) + assert result is None + + def test_create_table_if_not_exists_bigquery_valid( + self, valid_bigquery_db_instance: UnstractDB + ) -> None: + engine = valid_bigquery_db_instance.get_engine() + result = DatabaseUtils.create_table_if_not_exists( + db_class=valid_bigquery_db_instance, + engine=engine, + table_name=self.valid_bigquery_table_name, + database_entry=self.database_entry, + ) + assert result is None + + def test_create_table_if_not_exists_invalid_schema( + self, invalid_dbs_instance: UnstractDB + ) -> None: + engine = invalid_dbs_instance.get_engine() + with pytest.raises(UnstractDBException): + DatabaseUtils.create_table_if_not_exists( + db_class=invalid_dbs_instance, + engine=engine, + table_name=self.valid_table_name, + database_entry=self.database_entry, + ) + + def test_create_table_if_not_exists_invalid_syntax( + self, valid_dbs_instance: UnstractDB + ) -> None: + engine = valid_dbs_instance.get_engine() + with pytest.raises(UnstractDBException): + DatabaseUtils.create_table_if_not_exists( + db_class=valid_dbs_instance, + engine=engine, + table_name=self.invalid_syntax_table_name, + database_entry=self.database_entry, + ) + + def test_create_table_if_not_exists_wrong_table_name( + self, valid_dbs_instance: UnstractDB + ) -> None: + engine = valid_dbs_instance.get_engine() + with pytest.raises(UnstractDBException): + DatabaseUtils.create_table_if_not_exists( + db_class=valid_dbs_instance, + engine=engine, + table_name=self.invalid_wrong_table_name, + database_entry=self.database_entry, + ) + + def 
test_create_table_if_not_exists_feature_not_supported( + self, invalid_dbs_instance: UnstractDB + ) -> None: + engine = invalid_dbs_instance.get_engine() + with pytest.raises(UnstractDBException): + DatabaseUtils.create_table_if_not_exists( + db_class=invalid_dbs_instance, + engine=engine, + table_name=self.invalid_wrong_table_name, + database_entry=self.database_entry, + ) + + def test_create_table_if_not_exists_invalid_snowflake_db( + self, invalid_snowflake_db_instance: UnstractDB + ) -> None: + engine = invalid_snowflake_db_instance.get_engine() + with pytest.raises(UnstractDBException): + DatabaseUtils.create_table_if_not_exists( + db_class=invalid_snowflake_db_instance, + engine=engine, + table_name=self.invalid_wrong_table_name, + database_entry=self.database_entry, + ) + + +if __name__ == "__main__": + pytest.main() diff --git a/backend/workflow_manager/endpoint_v2/tests/test_database_utils/test_execute_write_query.py b/backend/workflow_manager/endpoint_v2/tests/test_database_utils/test_execute_write_query.py new file mode 100644 index 000000000..22c26dcfd --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/tests/test_database_utils/test_execute_write_query.py @@ -0,0 +1,169 @@ +import os +import uuid +from typing import Any + +import pytest # type: ignore +from workflow_manager.endpoint_v2.database_utils import DatabaseUtils +from workflow_manager.endpoint_v2.exceptions import UnstractDBException + +from unstract.connectors.databases.redshift import Redshift +from unstract.connectors.databases.unstract_db import UnstractDB + +from .base_test_db import BaseTestDB + + +class TestExecuteWriteQuery(BaseTestDB): + @pytest.fixture(autouse=True) + def setup(self, base_setup: Any) -> None: + self.sql_columns_and_values = { + "created_by": "Unstract/DBWriter", + "created_at": "2024-05-20 10:36:25.362609", + "data": '{"input_file": "simple.pdf", "result": "report"}', + "id": str(uuid.uuid4()), + } + + def test_execute_write_query_valid(self, valid_dbs_instance: Any) -> None: + engine = valid_dbs_instance.get_engine() + result = DatabaseUtils.execute_write_query( + db_class=valid_dbs_instance, + engine=engine, + table_name=self.valid_table_name, + sql_keys=list(self.sql_columns_and_values.keys()), + sql_values=list(self.sql_columns_and_values.values()), + ) + assert result is None + + def test_execute_write_query_invalid_schema( + self, invalid_dbs_instance: Any + ) -> None: + engine = invalid_dbs_instance.get_engine() + with pytest.raises(UnstractDBException): + DatabaseUtils.execute_write_query( + db_class=invalid_dbs_instance, + engine=engine, + table_name=self.valid_table_name, + sql_keys=list(self.sql_columns_and_values.keys()), + sql_values=list(self.sql_columns_and_values.values()), + ) + + def test_execute_write_query_invalid_syntax(self, valid_dbs_instance: Any) -> None: + engine = valid_dbs_instance.get_engine() + with pytest.raises(UnstractDBException): + DatabaseUtils.execute_write_query( + db_class=valid_dbs_instance, + engine=engine, + table_name=self.invalid_syntax_table_name, + sql_keys=list(self.sql_columns_and_values.keys()), + sql_values=list(self.sql_columns_and_values.values()), + ) + + def test_execute_write_query_feature_not_supported( + self, invalid_dbs_instance: Any + ) -> None: + engine = invalid_dbs_instance.get_engine() + with pytest.raises(UnstractDBException): + DatabaseUtils.execute_write_query( + db_class=invalid_dbs_instance, + engine=engine, + table_name=self.invalid_wrong_table_name, + sql_keys=list(self.sql_columns_and_values.keys()), + 
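+                    # sql_keys and sql_values are derived from the same dict,
+                    # so the column names stay aligned with their values.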
sql_values=list(self.sql_columns_and_values.values()), + ) + + def load_text_to_sql_values(self) -> dict[str, Any]: + file_path = os.path.join(os.path.dirname(__file__), "static", "large_doc.txt") + with open(file_path, encoding="utf-8") as file: + content = file.read() + sql_columns_and_values = self.sql_columns_and_values.copy() + sql_columns_and_values["data"] = content + return sql_columns_and_values + + @pytest.fixture + def valid_redshift_db_instance(self) -> Any: + return Redshift(self.valid_redshift_creds) + + def test_execute_write_query_datatype_too_large_redshift( + self, valid_redshift_db_instance: Any + ) -> None: + engine = valid_redshift_db_instance.get_engine() + sql_columns_and_values = self.load_text_to_sql_values() + with pytest.raises(UnstractDBException): + DatabaseUtils.execute_write_query( + db_class=valid_redshift_db_instance, + engine=engine, + table_name=self.valid_table_name, + sql_keys=list(sql_columns_and_values.keys()), + sql_values=list(sql_columns_and_values.values()), + ) + + def test_execute_write_query_bigquery_valid( + self, valid_bigquery_db_instance: Any + ) -> None: + engine = valid_bigquery_db_instance.get_engine() + result = DatabaseUtils.execute_write_query( + db_class=valid_bigquery_db_instance, + engine=engine, + table_name=self.valid_bigquery_table_name, + sql_keys=list(self.sql_columns_and_values.keys()), + sql_values=list(self.sql_columns_and_values.values()), + ) + assert result is None + + def test_execute_write_query_wrong_table_name( + self, valid_dbs_instance: UnstractDB + ) -> None: + engine = valid_dbs_instance.get_engine() + with pytest.raises(UnstractDBException): + DatabaseUtils.execute_write_query( + db_class=valid_dbs_instance, + engine=engine, + table_name=self.invalid_wrong_table_name, + sql_keys=list(self.sql_columns_and_values.keys()), + sql_values=list(self.sql_columns_and_values.values()), + ) + + def test_execute_write_query_bigquery_large_doc( + self, valid_bigquery_db_instance: Any + ) -> None: + engine = valid_bigquery_db_instance.get_engine() + sql_columns_and_values = self.load_text_to_sql_values() + result = DatabaseUtils.execute_write_query( + db_class=valid_bigquery_db_instance, + engine=engine, + table_name=self.valid_bigquery_table_name, + sql_keys=list(sql_columns_and_values.keys()), + sql_values=list(sql_columns_and_values.values()), + ) + assert result is None + + def test_execute_write_query_invalid_snowflake_db( + self, invalid_snowflake_db_instance: UnstractDB + ) -> None: + engine = invalid_snowflake_db_instance.get_engine() + with pytest.raises(UnstractDBException): + DatabaseUtils.execute_write_query( + db_class=invalid_snowflake_db_instance, + engine=engine, + table_name=self.invalid_wrong_table_name, + sql_keys=list(self.sql_columns_and_values.keys()), + sql_values=list(self.sql_columns_and_values.values()), + ) + + # Make this function at last to cover all large doc + def test_execute_write_query_large_doc( + self, valid_dbs_instance_to_handle_large_doc: Any + ) -> None: + engine = valid_dbs_instance_to_handle_large_doc.get_engine() + sql_columns_and_values = self.load_text_to_sql_values() + result = DatabaseUtils.execute_write_query( + db_class=valid_dbs_instance_to_handle_large_doc, + engine=engine, + table_name=self.valid_table_name, + sql_keys=list(sql_columns_and_values.keys()), + sql_values=list(sql_columns_and_values.values()), + ) + assert result is None + + +if __name__ == "__main__": + pytest.main() diff --git a/backend/workflow_manager/endpoint_v2/urls.py 
b/backend/workflow_manager/endpoint_v2/urls.py new file mode 100644 index 000000000..522859e23 --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/urls.py @@ -0,0 +1,23 @@ +from django.urls import path +from workflow_manager.endpoint_v2.views import WorkflowEndpointViewSet + +workflow_endpoint_list = WorkflowEndpointViewSet.as_view( + {"get": "workflow_endpoint_list"} +) +endpoint_list = WorkflowEndpointViewSet.as_view({"get": "list"}) +workflow_endpoint_detail = WorkflowEndpointViewSet.as_view( + {"get": "retrieve", "put": "update"} +) +endpoint_settings_detail = WorkflowEndpointViewSet.as_view( + {"get": WorkflowEndpointViewSet.get_settings.__name__} +) + +urlpatterns = [ + path("", endpoint_list, name="endpoint-list"), + path("/", workflow_endpoint_detail, name="workflow-endpoint-detail"), + path( + "/settings/", + endpoint_settings_detail, + name="workflow-endpoint-detail", + ), +] diff --git a/backend/workflow_manager/endpoint_v2/views.py b/backend/workflow_manager/endpoint_v2/views.py new file mode 100644 index 000000000..6e82a0e36 --- /dev/null +++ b/backend/workflow_manager/endpoint_v2/views.py @@ -0,0 +1,82 @@ +from django.db.models import QuerySet +from rest_framework import status, viewsets +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.response import Response +from workflow_manager.endpoint_v2.destination import DestinationConnector +from workflow_manager.endpoint_v2.endpoint_utils import WorkflowEndpointUtils +from workflow_manager.endpoint_v2.models import WorkflowEndpoint +from workflow_manager.endpoint_v2.source import SourceConnector +from workflow_manager.workflow_v2.serializers import WorkflowEndpointSerializer + + +class WorkflowEndpointViewSet(viewsets.ModelViewSet): + queryset = WorkflowEndpoint.objects.all() + serializer_class = WorkflowEndpointSerializer + + def get_queryset(self) -> QuerySet: + + queryset = ( + WorkflowEndpoint.objects.all() + .select_related("workflow") + .filter(workflow__created_by=self.request.user) + ) + endpoint_type_filter = self.request.query_params.get("endpoint_type", None) + connection_type_filter = self.request.query_params.get("connection_type", None) + if endpoint_type_filter: + queryset = queryset.filter(endpoint_type=endpoint_type_filter) + if connection_type_filter: + queryset = queryset.filter(connection_type=connection_type_filter) + return queryset + + @action(detail=True, methods=["get"]) + def get_settings(self, request: Request, pk: str) -> Response: + """Retrieve the settings/schema for a specific workflow endpoint. + + Parameters: + request (Request): The HTTP request object. + pk (str): The primary key of the workflow endpoint. + + Returns: + Response: The HTTP response containing the settings/schema for + the endpoint. 
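+
+            Illustrative response body (the "schema" value varies with the
+            endpoint's connection type and is only sketched here):
+            {"status": 200, "schema": {...connector JSON schema...}}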
+ """ + endpoint: WorkflowEndpoint = self.get_object() + connection_type = endpoint.connection_type + endpoint_type = endpoint.endpoint_type + schema = None + if endpoint_type == WorkflowEndpoint.EndpointType.SOURCE: + if connection_type == WorkflowEndpoint.ConnectionType.API: + schema = SourceConnector.get_json_schema_for_api() + if connection_type == WorkflowEndpoint.ConnectionType.FILESYSTEM: + schema = SourceConnector.get_json_schema_for_file_system() + if endpoint_type == WorkflowEndpoint.EndpointType.DESTINATION: + if connection_type == WorkflowEndpoint.ConnectionType.DATABASE: + schema = DestinationConnector.get_json_schema_for_database() + if connection_type == WorkflowEndpoint.ConnectionType.FILESYSTEM: + schema = DestinationConnector.get_json_schema_for_file_system() + if connection_type == WorkflowEndpoint.ConnectionType.API: + schema = DestinationConnector.get_json_schema_for_api() + + return Response( + { + "status": status.HTTP_200_OK, + "schema": schema, + } + ) + + @action(detail=True, methods=["get"]) + def workflow_endpoint_list(self, request: Request, pk: str) -> Response: + """Retrieve a list of endpoints for a specific workflow. + + Parameters: + request (Request): The HTTP request object. + pk (str): The primary key of the workflow. + + Returns: + Response: The HTTP response containing the serialized list of + endpoints. + """ + endpoints = WorkflowEndpointUtils.get_endpoints_for_workflow(pk) + serializer = WorkflowEndpointSerializer(endpoints, many=True) + return Response(serializer.data) diff --git a/backend/workflow_manager/workflow_v2/__init__.py b/backend/workflow_manager/workflow_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/workflow_manager/workflow_v2/admin.py b/backend/workflow_manager/workflow_v2/admin.py new file mode 100644 index 000000000..6ab58abaf --- /dev/null +++ b/backend/workflow_manager/workflow_v2/admin.py @@ -0,0 +1,4 @@ +from django.contrib import admin +from workflow_manager.workflow_v2.models.workflow import Workflow + +admin.site.register(Workflow) diff --git a/backend/workflow_manager/workflow_v2/apps.py b/backend/workflow_manager/workflow_v2/apps.py new file mode 100644 index 000000000..255090664 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/apps.py @@ -0,0 +1,13 @@ +from django.apps import AppConfig + + +class WorkflowConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "workflow_manager.workflow_v2" + + def ready(self): + from workflow_manager.workflow_v2.execution_log_utils import ( + create_log_consumer_scheduler_if_not_exists, + ) + + create_log_consumer_scheduler_if_not_exists() diff --git a/backend/workflow_manager/workflow_v2/constants.py b/backend/workflow_manager/workflow_v2/constants.py new file mode 100644 index 000000000..95aab13e3 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/constants.py @@ -0,0 +1,60 @@ +class WorkflowKey: + """Dict keys related to workflows.""" + + PROMPT_TEXT = "prompt_text" + LLM_RESPONSE = "llm_response" + WF_STEPS = "steps" + WF_TOOL = "tool" + WF_INSTANCE_SETTINGS = "instance_settings" + WF_TOOL_INSTANCE_ID = "tool_instance_id" + WF_CONNECTOR_CLASS = "connector_class" + WF_INPUT = "input" + WF_OUTPUT = "output" + WF_TOOL_UUID = "id" + WF_ID = "workflow_id" + WF_NAME = "workflow_name" + WF_OWNER = "workflow_owner" + WF_TOOL_INSTANCES = "tool_instances" + WF_IS_ACTIVE = "is_active" + EXECUTION_ACTION = "execution_action" + # Keys from provisional workflow + PWF_RESULT = "result" + PWF_OUTPUT = "output" + 
PWF_COST_TYPE = "cost_type" + PWF_COST = "cost" + PWF_TIME_TAKEN = "time_taken" + WF_CACHE_PATTERN = r"^cache:{?\w{8}-?\w{4}-?\w{4}-?\w{4}-?\w{12}}?$" + WF_PROJECT_GUID = "guid" + + +class WorkflowExecutionKey: + WORKFLOW_EXECUTION_ID_PREFIX = "workflow" + EXECUTION_ID = "execution_id" + LOG_GUID = "log_guid" + WITH_LOG = "with_log" + + +class WorkflowErrors: + WORKFLOW_EXISTS = "Workflow with this configuration already exists." + DUPLICATE_API = "It appears that a duplicate call may have been made." + INVALID_EXECUTION_ID = "Invalid execution_id" + + +class CeleryConfigurations: + INTERVAL = 2 + + +class Tool: + APIOPS = "apiops" + + +class WorkflowMessages: + CACHE_CLEAR_SUCCESS = "Cache cleared successfully." + CACHE_CLEAR_FAILED = "Failed to clear cache." + CACHE_EMPTY = "Cache is already empty." + CELERY_TIMEOUT_MESSAGE = ( + "Your request is being processed. Please wait." + "You can check the status using the status API." + ) + FILE_MARKER_CLEAR_SUCCESS = "File marker cleared successfully." + FILE_MARKER_CLEAR_FAILED = "Failed to clear file marker." diff --git a/backend/workflow_manager/workflow_v2/dto.py b/backend/workflow_manager/workflow_v2/dto.py new file mode 100644 index 000000000..5f2e0db43 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/dto.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass +from typing import Any, Optional + +from celery.result import AsyncResult +from workflow_manager.workflow_v2.constants import WorkflowKey + + +@dataclass +class ProvisionalWorkflow: + result: str + output: dict[str, str] + cost_type: str + cost: str + time_taken: float + + def __init__(self, input_dict: dict[str, Any]) -> None: + self.result = input_dict.get(WorkflowKey.PWF_RESULT, "") + self.output = input_dict.get(WorkflowKey.PWF_OUTPUT, {}) + self.cost_type = input_dict.get(WorkflowKey.PWF_COST_TYPE, "") + self.cost = input_dict.get(WorkflowKey.PWF_COST, "") + self.time_taken = input_dict.get(WorkflowKey.PWF_TIME_TAKEN, 0.0) + + +@dataclass +class ExecutionResponse: + workflow_id: str + execution_id: str + execution_status: str + log_id: Optional[str] = None + status_api: Optional[str] = None + error: Optional[str] = None + mode: Optional[str] = None + result: Optional[Any] = None + message: Optional[str] = None + + def __post_init__(self) -> None: + self.log_id = self.log_id or None + self.mode = self.mode or None + self.error = self.error or None + self.result = self.result or None + self.message = self.message or None + self.status_api = self.status_api or None + + +@dataclass +class AsyncResultData: + id: str + status: str + result: Any + is_ready: bool + is_failed: bool + info: Any + + def __init__(self, async_result: AsyncResult): + self.id = async_result.id + self.status = async_result.status + self.result = async_result.result + self.is_ready = async_result.ready() + self.is_failed = async_result.failed() + self.info = async_result.info + if isinstance(self.result, Exception): + self.result = str(self.result) + + def to_dict(self) -> dict[str, Any]: + return { + "status": self.status, + "result": self.result, + } diff --git a/backend/workflow_manager/workflow_v2/enums.py b/backend/workflow_manager/workflow_v2/enums.py new file mode 100644 index 000000000..cc13a1e35 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/enums.py @@ -0,0 +1,73 @@ +from enum import Enum + +from utils.common_utils import ModelEnum + + +class WorkflowExecutionMethod(Enum): + INSTANT = "INSTANT" + QUEUED = "QUEUED" + + +class ExecutionStatus(ModelEnum): + """An enumeration representing 
the various statuses of an execution + process. + + Statuses: + PENDING: The execution's entry has been created in the database. + QUEUED: The execution task is queued for asynchronous execution + INITIATED: The execution has been initiated. + READY: The execution is ready for the build phase. + EXECUTING: The execution is currently in progress. + COMPLETED: The execution has been successfully completed. + STOPPED: The execution was stopped by the user + (applicable to step executions). + ERROR: An error occurred during the execution process. + + Note: + Intermediate statuses might not be experienced due to + Django's query triggering once all processes are completed. + """ + + PENDING = "PENDING" + INITIATED = "INITIATED" + QUEUED = "QUEUED" + READY = "READY" + EXECUTING = "EXECUTING" + COMPLETED = "COMPLETED" + STOPPED = "STOPPED" + ERROR = "ERROR" + + +class SchemaType(Enum): + """Possible types for workflow module's JSON schema. + + Values: + src: Refers to the source module's schema + dest: Refers to the destination module's schema + """ + + SRC = "src" + DEST = "dest" + + +class SchemaEntity(Enum): + """Possible entities for workflow module's JSON schema. + + Values: + file: Refers to schema for file based sources + api: Refers to schema for API based sources + db: Refers to schema for DB based destinations + """ + + FILE = "file" + API = "api" + DB = "db" + + +class ColumnModes(Enum): + WRITE_JSON_TO_A_SINGLE_COLUMN = "Write JSON to a single column" + SPLIT_JSON_INTO_COLUMNS = "Split JSON into columns" + + +class AgentName(Enum): + UNSTRACT_DBWRITER = "Unstract/DBWriter" diff --git a/backend/workflow_manager/workflow_v2/exceptions.py b/backend/workflow_manager/workflow_v2/exceptions.py new file mode 100644 index 000000000..39087ab8e --- /dev/null +++ b/backend/workflow_manager/workflow_v2/exceptions.py @@ -0,0 +1,56 @@ +from rest_framework.exceptions import APIException + + +class WorkflowGenerationError(APIException): + status_code = 500 + default_detail = "Error generating workflow." + + +class WorkflowRegenerationError(APIException): + status_code = 500 + default_detail = "Error regenerating workflow." + + +class WorkflowExecutionError(APIException): + status_code = 500 + default_detail = "Error executing workflow." + + +class WorkflowDoesNotExistError(APIException): + status_code = 404 + default_detail = "Workflow does not exist" + + +class TaskDoesNotExistError(APIException): + status_code = 404 + default_detail = "Task does not exist" + + +class DuplicateActionError(APIException): + status_code = 400 + default_detail = "Action is running" + + +class InvalidRequest(APIException): + status_code = 400 + default_detail = "Invalid Request" + + +class MissingEnvException(APIException): + status_code = 500 + default_detail = "At least one active platform key should be available." + + +class InternalException(APIException): + """Internal Error. 
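+    Raised for unexpected server-side failures that do not map to a more
+    specific workflow exception.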
+ + Args: + APIException (_type_): _description_ + """ + + status_code = 500 + + +class WorkflowExecutionNotExist(APIException): + status_code = 404 + default_detail = "Workflow execution does not exist" diff --git a/backend/workflow_manager/workflow_v2/execution.py b/backend/workflow_manager/workflow_v2/execution.py new file mode 100644 index 000000000..fd64cba50 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/execution.py @@ -0,0 +1,420 @@ +import logging +import time +from typing import Optional + +from account_v2.constants import Common +from api_v2.exceptions import InvalidAPIRequest +from platform_settings_v2.platform_auth_service import PlatformAuthenticationService +from tool_instance_v2.constants import JsonSchemaKey +from tool_instance_v2.models import ToolInstance +from tool_instance_v2.tool_processor import ToolProcessor +from unstract.tool_registry.dto import Tool +from unstract.workflow_execution import WorkflowExecutionService +from unstract.workflow_execution.dto import ToolInstance as ToolInstanceDataClass +from unstract.workflow_execution.dto import WorkflowDto +from unstract.workflow_execution.enums import ExecutionType, LogComponent, LogState +from unstract.workflow_execution.exceptions import StopExecution +from utils.local_context import StateStore +from utils.user_context import UserContext +from workflow_manager.workflow_v2.constants import WorkflowKey +from workflow_manager.workflow_v2.enums import ExecutionStatus +from workflow_manager.workflow_v2.exceptions import WorkflowExecutionError +from workflow_manager.workflow_v2.models import Workflow, WorkflowExecution +from workflow_manager.workflow_v2.models.execution import EXECUTION_ERROR_LENGTH +from workflow_manager.workflow_v2.models.file_history import FileHistory + +logger = logging.getLogger(__name__) + + +class WorkflowExecutionServiceHelper(WorkflowExecutionService): + def __init__( + self, + workflow: Workflow, + tool_instances: list[ToolInstance], + organization_id: Optional[str] = None, + pipeline_id: Optional[str] = None, + single_step: bool = False, + scheduled: bool = False, + mode: tuple[str, str] = WorkflowExecution.Mode.INSTANT, + workflow_execution: Optional[WorkflowExecution] = None, + include_metadata: bool = False, + ) -> None: + tool_instances_as_dto = [] + for tool_instance in tool_instances: + tool_instances_as_dto.append( + self.convert_tool_instance_model_to_data_class(tool_instance) + ) + workflow_as_dto: WorkflowDto = self.convert_workflow_model_to_data_class( + workflow=workflow + ) + organization_id = organization_id or UserContext.get_organization_identifier() + if not organization_id: + raise WorkflowExecutionError(detail="invalid Organization ID") + + platform_key = PlatformAuthenticationService.get_active_platform_key() + super().__init__( + organization_id=organization_id, + workflow_id=workflow.id, + workflow=workflow_as_dto, + tool_instances=tool_instances_as_dto, + platform_service_api_key=str(platform_key.key), + ignore_processed_entities=False, + include_metadata=include_metadata, + ) + if not workflow_execution: + # Use pipline_id for pipelines / API deployment + # since session might not be present. 
+ log_events_id = StateStore.get(Common.LOG_EVENTS_ID) + self.execution_log_id = log_events_id if log_events_id else pipeline_id + self.execution_mode = mode + self.execution_method: tuple[str, str] = ( + WorkflowExecution.Method.SCHEDULED + if scheduled + else WorkflowExecution.Method.DIRECT + ) + self.execution_type: tuple[str, str] = ( + WorkflowExecution.Type.STEP + if single_step + else WorkflowExecution.Type.COMPLETE + ) + workflow_execution = WorkflowExecution( + pipeline_id=pipeline_id, + workflow_id=workflow.id, + execution_mode=mode, + execution_method=self.execution_method, + execution_type=self.execution_type, + status=ExecutionStatus.INITIATED.value, + execution_log_id=self.execution_log_id, + ) + workflow_execution.save() + else: + self.execution_mode = workflow_execution.execution_mode + self.execution_method = workflow_execution.execution_method + self.execution_type = workflow_execution.execution_type + self.execution_log_id = workflow_execution.execution_log_id + + self.set_messaging_channel(str(self.execution_log_id)) + project_settings = {} + project_settings[WorkflowKey.WF_PROJECT_GUID] = str(self.execution_log_id) + self.workflow_id = workflow.id + self.project_settings = project_settings + self.pipeline_id = pipeline_id + self.execution_id = str(workflow_execution.id) + logger.info( + f"Executing for Pipeline ID: {pipeline_id}, " + f"workflow ID: {self.workflow_id}, execution ID: {self.execution_id}, " + f"web socket messaging channel ID: {self.execution_log_id}" + ) + + self.compilation_result = self.compile_workflow(execution_id=self.execution_id) + + def _initiate_api_execution( + self, tool_instance: ToolInstance, execution_path: Optional[str] + ) -> None: + if not execution_path: + raise InvalidAPIRequest("File shouldn't be empty") + tool_instance.metadata[JsonSchemaKey.ROOT_FOLDER] = execution_path + + @staticmethod + def create_workflow_execution( + workflow_id: str, + pipeline_id: Optional[str] = None, + single_step: bool = False, + scheduled: bool = False, + log_events_id: Optional[str] = None, + execution_id: Optional[str] = None, + mode: tuple[str, str] = WorkflowExecution.Mode.INSTANT, + ) -> WorkflowExecution: + execution_method: tuple[str, str] = ( + WorkflowExecution.Method.SCHEDULED + if scheduled + else WorkflowExecution.Method.DIRECT + ) + execution_type: tuple[str, str] = ( + WorkflowExecution.Type.STEP + if single_step + else WorkflowExecution.Type.COMPLETE + ) + execution_log_id = log_events_id if log_events_id else pipeline_id + # TODO: Using objects.create() instead + workflow_execution = WorkflowExecution( + pipeline_id=pipeline_id, + workflow_id=workflow_id, + execution_mode=mode, + execution_method=execution_method, + execution_type=execution_type, + status=ExecutionStatus.PENDING.value, + execution_log_id=execution_log_id, + ) + if execution_id: + workflow_execution.id = execution_id + workflow_execution.save() + return workflow_execution + + def update_execution( + self, + status: Optional[ExecutionStatus] = None, + execution_time: Optional[float] = None, + error: Optional[str] = None, + increment_attempt: bool = False, + ) -> None: + execution = WorkflowExecution.objects.get(pk=self.execution_id) + + if status is not None: + execution.status = status.value + if execution_time is not None: + execution.execution_time = execution_time + if error: + execution.error_message = error + if increment_attempt: + execution.attempts += 1 + + execution.save() + + def has_successful_compilation(self) -> bool: + return self.compilation_result["success"] 
is True + + def get_execution_instance(self) -> WorkflowExecution: + execution: WorkflowExecution = WorkflowExecution.objects.get( + pk=self.execution_id + ) + return execution + + def build(self) -> None: + if self.compilation_result["success"] is True: + self.build_workflow() + self.update_execution(status=ExecutionStatus.READY) + else: + logger.error( + "Errors while compiling workflow " + f"{self.compilation_result['problems']}" + ) + self.update_execution( + status=ExecutionStatus.ERROR, + error=self.compilation_result["problems"][0], + ) + raise WorkflowExecutionError(self.compilation_result["problems"][0]) + + def execute(self, single_step: bool = False) -> None: + execution_type = ExecutionType.COMPLETE + if single_step: + execution_type = ExecutionType.STEP + + if self.compilation_result["success"] is False: + error_message = ( + f"Errors while compiling workflow " + f"{self.compilation_result['problems'][0]}" + ) + raise WorkflowExecutionError(error_message) + + if self.execution_mode not in ( + WorkflowExecution.Mode.INSTANT, + WorkflowExecution.Mode.QUEUE, + ): + error_message = f"Unknown Execution Method {self.execution_mode}" + raise WorkflowExecutionError(error_message) + + start_time = time.time() + try: + self.execute_workflow(execution_type=execution_type) + end_time = time.time() + execution_time = end_time - start_time + except StopExecution as exception: + end_time = time.time() + execution_time = end_time - start_time + logger.info(f"Execution {self.execution_id} stopped") + raise exception + except Exception as exception: + end_time = time.time() + execution_time = end_time - start_time + message = str(exception)[:EXECUTION_ERROR_LENGTH] + logger.info( + f"Execution {self.execution_id} ran for {execution_time:.4f}s, " + f" Error {exception}" + ) + raise WorkflowExecutionError(message) from exception + + def publish_initial_workflow_logs(self, total_files: int) -> None: + """Publishes the initial logs for the workflow. + + Args: + total_files (int): The total number of matched files. + + Returns: + None + """ + self.publish_log(f"Total matched files: {total_files}") + self.publish_update_log(LogState.BEGIN_WORKFLOW, "1", LogComponent.STATUS_BAR) + self.publish_update_log( + LogState.RUNNING, "Ready for execution", LogComponent.WORKFLOW + ) + + def publish_final_workflow_logs( + self, total_files: int, processed_files: int + ) -> None: + """Publishes the final logs for the workflow. + + Returns: + None + """ + self.publish_update_log(LogState.END_WORKFLOW, "1", LogComponent.STATUS_BAR) + self.publish_update_log( + LogState.SUCCESS, "Executed successfully", LogComponent.WORKFLOW + ) + self.publish_log( + f"Execution completed for {processed_files} files out of {total_files}" + ) + + def publish_initial_tool_execution_logs( + self, current_file_idx: int, total_files: int, file_name: str + ) -> None: + """Publishes the initial logs for tool execution. + + Args: + current_file_idx (int): 1-based index for the current file being processed + total_files (int): The total number of files to process + file_name (str): The name of the file being processed. + + Returns: + None + """ + self.publish_update_log( + component=LogComponent.STATUS_BAR, + state=LogState.MESSAGE, + message=f"Processing file {file_name} {current_file_idx}/{total_files}", + ) + self.publish_log(f"Processing file {file_name}") + + def execute_input_file( + self, + file_name: str, + single_step: bool, + file_history: Optional[FileHistory] = None, + ) -> bool: + """Executes the input file. 
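+ If a completed FileHistory entry exists for the file, the tool run is + skipped and a cache-hit log is published instead.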
+ + Args: + file_name (str): The name of the file to be executed. + single_step (bool): Flag indicating whether to execute in + single step mode. + file_history (Optional[FileHistory], optional): + The file history object. Defaults to None. + Returns: + bool: Flag indicating whether the file was executed. + """ + execution_type = ExecutionType.COMPLETE + is_executed = False + if single_step: + execution_type = ExecutionType.STEP + if not (file_history and file_history.is_completed()): + self.execute_uncached_input(file_name=file_name, single_step=single_step) + self.publish_log(f"Tool executed successfully for {file_name}") + is_executed = True + else: + self.publish_log( + f"Skipping file {file_name} as it is already processed. " + "Clear the cache to process it again" + ) + self._handle_execution_type(execution_type) + return is_executed + + def execute_uncached_input(self, file_name: str, single_step: bool) -> None: + """Executes the uncached input file. + + Args: + file_name (str): The name of the file to be executed. + single_step (bool): Flag indicating whether to execute in + single step mode. + + Returns: + None + """ + self.publish_log("No entries found in cache, executing the tools") + self.publish_update_log( + state=LogState.SUCCESS, + message=f"{file_name} sent for execution", + component=LogComponent.SOURCE, + ) + self.execute(single_step) + + def initiate_tool_execution( + self, + current_file_idx: int, + total_files: int, + file_name: str, + single_step: bool, + ) -> None: + """Initiates the execution of a tool for a specific file in the + workflow. + + Args: + current_file_idx (int): 1-based index for the current file being processed + total_files (int): The total number of files to process in the workflow + file_name (str): The name of the file being processed + single_step (bool): Flag indicating whether the execution is in + single-step mode + + Returns: + None + + Raises: + None + """ + execution_type = ExecutionType.COMPLETE + if single_step: + execution_type = ExecutionType.STEP + self.publish_initial_tool_execution_logs( + current_file_idx, total_files, file_name + ) + self._handle_execution_type(execution_type) + + source_status_message = ( + f"({current_file_idx}/{total_files}) Processing file {file_name}" + ) + self.publish_update_log( + state=LogState.RUNNING, + message=source_status_message, + component=LogComponent.SOURCE, + ) + self.publish_log("Trying to fetch results from cache") + + @staticmethod + def update_execution_status(execution_id: str, status: ExecutionStatus) -> None: + try: + execution = WorkflowExecution.objects.get(pk=execution_id) + execution.status = status.value + execution.save() + except WorkflowExecution.DoesNotExist: + logger.error(f"Execution {execution_id} does not exist") + + @staticmethod + def update_execution_task(execution_id: str, task_id: str) -> None: + try: + execution = WorkflowExecution.objects.get(pk=execution_id) + execution.task_id = task_id + execution.save() + except WorkflowExecution.DoesNotExist: + logger.error(f"Execution {execution_id} does not exist") + + @staticmethod + def convert_tool_instance_model_to_data_class( + tool_instance: ToolInstance, + ) -> ToolInstanceDataClass: + tool: Tool = ToolProcessor.get_tool_by_uid(tool_instance.tool_id) + tool_dto = ToolInstanceDataClass( + id=tool_instance.id, + tool_id=tool_instance.tool_id, + workflow=tool_instance.workflow.id, + metadata=tool_instance.metadata, + step=tool_instance.step, + properties=tool.properties, + image_name=tool.image_name, + image_tag=tool.image_tag,
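+ # properties and image details above come from the registry Tool, while + # metadata and step come from the stored ToolInstance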
+ ) + return tool_dto + + @staticmethod + def convert_workflow_model_to_data_class( + workflow: Workflow, + ) -> WorkflowDto: + return WorkflowDto(id=workflow.id) diff --git a/backend/workflow_manager/workflow_v2/execution_log_utils.py b/backend/workflow_manager/workflow_v2/execution_log_utils.py new file mode 100644 index 000000000..5b1ef6079 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/execution_log_utils.py @@ -0,0 +1,107 @@ +import logging +import sys +from collections import defaultdict + +from celery import shared_task +from django.db import IntegrityError +from django.db.utils import ProgrammingError +from django_celery_beat.models import IntervalSchedule, PeriodicTask +from utils.cache_service import CacheService +from utils.constants import ExecutionLogConstants +from utils.dto import LogDataDTO +from workflow_manager.workflow_v2.models.execution_log import ExecutionLog + +logger = logging.getLogger(__name__) + + +@shared_task(bind=True, name=ExecutionLogConstants.TASK_V2) +def consume_log_history(self): + organization_logs = defaultdict(list) + logs_count = 0 + + while logs_count < ExecutionLogConstants.LOGS_BATCH_LIMIT: + log = CacheService.lpop(ExecutionLogConstants.LOG_QUEUE_NAME) + if not log: + break + + log_data = LogDataDTO.from_json(log) + if not log_data: + continue + + organization_id = log_data.organization_id + organization_logs[organization_id].append( + ExecutionLog( + execution_id=log_data.execution_id, + data=log_data.data, + event_time=log_data.event_time, + ) + ) + logs_count += 1 + logger.info(f"Logs count: {logs_count}") + for organization_id, logs in organization_logs.items(): + store_to_db(organization_id, logs) + + +def create_log_consumer_scheduler_if_not_exists() -> None: + try: + interval, _ = IntervalSchedule.objects.get_or_create( + every=ExecutionLogConstants.CONSUMER_INTERVAL, + period=IntervalSchedule.SECONDS, + ) + except ProgrammingError as error: + logger.warning( + "ProgrammingError occurred while creating " + "log consumer scheduler. 
If you are currently running " + "migrations for a new environment, you can ignore this warning" + ) + if all(arg not in sys.argv for arg in ("migrate", "makemigrations")): + logger.warning(f"ProgrammingError details: {error}") + return + except IntervalSchedule.MultipleObjectsReturned as error: + logger.error(f"Error occurred while getting interval schedule: {error}") + interval = IntervalSchedule.objects.filter( + every=ExecutionLogConstants.CONSUMER_INTERVAL, + period=IntervalSchedule.SECONDS, + ).first() + try: + # Create the scheduler + task, created = PeriodicTask.objects.get_or_create( + name=ExecutionLogConstants.PERIODIC_TASK_NAME_V2, + task=ExecutionLogConstants.TASK_V2, + defaults={ + "interval": interval, + "queue": ExecutionLogConstants.CELERY_QUEUE_NAME, + "enabled": ExecutionLogConstants.IS_ENABLED, + }, + ) + if not created: + task.enabled = ExecutionLogConstants.IS_ENABLED + task.interval = interval + task.queue = ExecutionLogConstants.CELERY_QUEUE_NAME + task.save() + logger.info("Log consumer scheduler updated successfully.") + else: + logger.info("Log consumer scheduler created successfully.") + except IntegrityError as error: + logger.error(f"Error occurred while creating log consumer scheduler: {error}") + + +def store_to_db(organization_id: str, execution_logs: list[ExecutionLog]) -> None: + + # Store the log data in the database within tenant context + ExecutionLog.objects.bulk_create(objs=execution_logs, ignore_conflicts=True) + + +class ExecutionLogUtils: + + @staticmethod + def get_execution_logs_by_execution_id(execution_id) -> list[ExecutionLog]: + """Get all execution logs for a given execution ID. + + Args: + execution_id (str): The ID of the execution. + + Returns: + list[ExecutionLog]: A list of ExecutionLog objects. + """ + return ExecutionLog.objects.filter(execution_id=execution_id) diff --git a/backend/workflow_manager/workflow_v2/execution_log_view.py b/backend/workflow_manager/workflow_v2/execution_log_view.py new file mode 100644 index 000000000..96f557b77 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/execution_log_view.py @@ -0,0 +1,27 @@ +import logging + +from permissions.permission import IsOwner +from rest_framework import viewsets +from rest_framework.versioning import URLPathVersioning +from utils.pagination import CustomPagination +from workflow_manager.workflow_v2.models.execution_log import ExecutionLog +from workflow_manager.workflow_v2.serializers import WorkflowExecutionLogSerializer + +logger = logging.getLogger(__name__) + + +class WorkflowExecutionLogViewSet(viewsets.ModelViewSet): + versioning_class = URLPathVersioning + permission_classes = [IsOwner] + serializer_class = WorkflowExecutionLogSerializer + pagination_class = CustomPagination + + EVENT_TIME_FIELD_ASC = "event_time" + + def get_queryset(self): + # Get the execution_id (pk) from the URL path + execution_id = self.kwargs.get("pk") + queryset = ExecutionLog.objects.filter(execution_id=execution_id).order_by( + self.EVENT_TIME_FIELD_ASC + ) + return queryset diff --git a/backend/workflow_manager/workflow_v2/execution_view.py b/backend/workflow_manager/workflow_v2/execution_view.py new file mode 100644 index 000000000..190a72617 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/execution_view.py @@ -0,0 +1,25 @@ +import logging + +from permissions.permission import IsOwner +from rest_framework import viewsets +from rest_framework.versioning import URLPathVersioning +from workflow_manager.workflow_v2.models.execution import WorkflowExecution +from 
workflow_manager.workflow_v2.serializers import WorkflowExecutionSerializer + +logger = logging.getLogger(__name__) + + +class WorkflowExecutionViewSet(viewsets.ModelViewSet): + versioning_class = URLPathVersioning + permission_classes = [IsOwner] + serializer_class = WorkflowExecutionSerializer + + CREATED_AT_FIELD_DESC = "-created_at" + + def get_queryset(self): + # Get the uuid:pk from the URL path + workflow_id = self.kwargs.get("pk") + queryset = WorkflowExecution.objects.filter(workflow_id=workflow_id).order_by( + self.CREATED_AT_FIELD_DESC + ) + return queryset diff --git a/backend/workflow_manager/workflow_v2/file_history_helper.py b/backend/workflow_manager/workflow_v2/file_history_helper.py new file mode 100644 index 000000000..7d84cf0da --- /dev/null +++ b/backend/workflow_manager/workflow_v2/file_history_helper.py @@ -0,0 +1,89 @@ +import logging +from typing import Any, Optional + +from django.db.utils import IntegrityError +from workflow_manager.workflow_v2.enums import ExecutionStatus +from workflow_manager.workflow_v2.models.file_history import FileHistory +from workflow_manager.workflow_v2.models.workflow import Workflow + +logger = logging.getLogger(__name__) + + +class FileHistoryHelper: + """A helper class for managing file history related operations.""" + + @staticmethod + def get_file_history( + workflow: Workflow, cache_key: Optional[str] = None + ) -> Optional[FileHistory]: + """Retrieve a file history record based on the cache key. + + Args: + cache_key (Optional[str]): The cache key to search for. + + Returns: + Optional[FileHistory]: The matching file history record, if found. + """ + if not cache_key: + return None + try: + file_history: FileHistory = FileHistory.objects.get( + cache_key=cache_key, workflow=workflow + ) + except FileHistory.DoesNotExist: + return None + return file_history + + @staticmethod + def create_file_history( + cache_key: str, + workflow: Workflow, + status: ExecutionStatus, + result: Any, + metadata: Any, + error: Optional[str] = None, + file_name: Optional[str] = None, + ) -> FileHistory: + """Create a new file history record. + + Args: + cache_key (str): The cache key for the file. + workflow (Workflow): The associated workflow. + status (ExecutionStatus): The execution status. + result (Any): The result from the execution. + + Returns: + FileHistory: The newly created file history record. + """ + try: + file_history: FileHistory = FileHistory.objects.create( + workflow=workflow, + cache_key=cache_key, + status=status.value, + result=str(result), + meta_data=str(metadata), + error=str(error) if error else "", + ) + except IntegrityError: + # TODO: Need to find why duplicate insert is coming + logger.warning( + "Trying to insert duplication data for filename %s for workflow %s", + file_name, + workflow, + ) + file_history = FileHistoryHelper.get_file_history( + workflow=workflow, cache_key=cache_key + ) + + return file_history + + @staticmethod + def clear_history_for_workflow( + workflow: Workflow, + ) -> None: + """Clear all file history records associated with a workflow. + + Args: + workflow (Workflow): The workflow to clear the history for. 
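+ + Note: + Once cleared, previously processed files are no longer treated as + cache hits and will be re-executed on the next run.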
+ """ + FileHistory.objects.filter(workflow=workflow).delete() diff --git a/backend/workflow_manager/workflow_v2/generator.py b/backend/workflow_manager/workflow_v2/generator.py new file mode 100644 index 000000000..6543aae7f --- /dev/null +++ b/backend/workflow_manager/workflow_v2/generator.py @@ -0,0 +1,113 @@ +import logging +import uuid +from typing import Any + +from rest_framework.request import Request +from tool_instance_v2.constants import ToolInstanceKey as TIKey +from tool_instance_v2.exceptions import ToolInstantiationError +from tool_instance_v2.tool_processor import ToolProcessor +from unstract.tool_registry.dto import Tool +from workflow_manager.workflow_v2.constants import WorkflowKey +from workflow_manager.workflow_v2.dto import ProvisionalWorkflow +from workflow_manager.workflow_v2.exceptions import WorkflowGenerationError +from workflow_manager.workflow_v2.models.workflow import Workflow as WorkflowModel + +from unstract.core.llm_workflow_generator.llm_interface import LLMInterface + +logger = logging.getLogger(__name__) + + +class WorkflowGenerator: + """Helps with generating a workflow using the LLM.""" + + def __init__(self, workflow_id: str = str(uuid.uuid4())) -> None: + self._request: Request = {} + self._llm_response = "" + self._workflow_id = workflow_id + self._provisional_wf: ProvisionalWorkflow + + @property + def llm_response(self) -> dict[str, Any]: + output: dict[str, str] = self._provisional_wf.output + return output + + @property + def provisional_wf(self) -> ProvisionalWorkflow: + return self._provisional_wf + + def _get_provisional_workflow(self, tools: list[Tool]) -> ProvisionalWorkflow: + """Helper to generate the provisional workflow Gets stored as + `workflow.Workflow.llm_response` eventually.""" + provisional_wf: ProvisionalWorkflow + try: + if not self._request: + raise WorkflowGenerationError( + "Unable to generate a workflow: missing request" + ) + llm_interface = LLMInterface() + + provisional_wf_dict = llm_interface.get_provisional_workflow_from_llm( + workflow_id=self._workflow_id, + tools=tools, + user_prompt=self._request.data.get(WorkflowKey.PROMPT_TEXT), + use_cache=True, + ) + provisional_wf = ProvisionalWorkflow(provisional_wf_dict) + if provisional_wf.result != "OK": + raise WorkflowGenerationError( + f"Unable to generate a workflow: {provisional_wf.output}" + ) + except Exception as e: + logger.error(f"{e}") + raise WorkflowGenerationError + return provisional_wf + + def set_request(self, request: Request) -> None: + self._request = request + + def generate_workflow(self, tools: list[Tool]) -> None: + """Used to talk to the GPT model through core and obtain a provisional + workflow for the user to work with.""" + self._provisional_wf = self._get_provisional_workflow(tools) + + @staticmethod + def get_tool_instance_data_from_llm( + workflow: WorkflowModel, + ) -> list[dict[str, str]]: + """Used to generate the dict of tool instances for a given workflow. 
+ + Call with ToolInstanceSerializer(data=tool_instance_data_list,many=True) + """ + tool_instance_data_list = [] + for step, tool_step in enumerate( + workflow.llm_response.get(WorkflowKey.WF_STEPS, []) + ): + step = step + 1 + logger.info(f"Building tool instance data for step: {step}") + tool_function: str = tool_step[WorkflowKey.WF_TOOL] + wf_input: str = tool_step[WorkflowKey.WF_INPUT] + wf_output: str = tool_step[WorkflowKey.WF_OUTPUT] + try: + tool: Tool = ToolProcessor.get_tool_by_uid(tool_function) + # TODO: Mark optional fields in model and handle in ToolInstance serializer # noqa + tool_instance_data = { + TIKey.PK: tool_step[WorkflowKey.WF_TOOL_UUID], + TIKey.WORKFLOW: workflow.id, + # Added to support changes for UN-154 + WorkflowKey.WF_ID: workflow.id, + TIKey.TOOL_ID: tool_function, + TIKey.METADATA: { + WorkflowKey.WF_TOOL_INSTANCE_ID: tool_step[ + WorkflowKey.WF_TOOL_UUID + ], + **ToolProcessor.get_default_settings(tool), + }, + TIKey.STEP: str(step), + TIKey.INPUT: wf_input, + TIKey.OUTPUT: wf_output, + } + tool_instance_data_list.append(tool_instance_data) + except Exception as e: + logger.error(f"Error while getting data for {tool_function}: {e}") + raise ToolInstantiationError(tool_name=tool_function) + return tool_instance_data_list diff --git a/backend/workflow_manager/workflow_v2/models/__init__.py b/backend/workflow_manager/workflow_v2/models/__init__.py new file mode 100644 index 000000000..938907762 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/models/__init__.py @@ -0,0 +1,4 @@ +from .execution import WorkflowExecution # noqa: F401 +from .execution_log import ExecutionLog # noqa: F401 +from .file_history import FileHistory # noqa: F401 +from .workflow import Workflow # noqa: F401 diff --git a/backend/workflow_manager/workflow_v2/models/execution.py b/backend/workflow_manager/workflow_v2/models/execution.py new file mode 100644 index 000000000..23ceb2bbd --- /dev/null +++ b/backend/workflow_manager/workflow_v2/models/execution.py @@ -0,0 +1,81 @@ +import uuid + +from django.db import models +from utils.models.base_model import BaseModel + +EXECUTION_ERROR_LENGTH = 256 + + +class WorkflowExecution(BaseModel): + class Mode(models.TextChoices): + INSTANT = "INSTANT", "will be executed immediately" + QUEUE = "QUEUE", "will be placed in a queue" + + class Method(models.TextChoices): + DIRECT = "DIRECT", " Execution triggered manually" + SCHEDULED = "SCHEDULED", "Scheduled execution" + + class Type(models.TextChoices): + COMPLETE = "COMPLETE", "For complete execution" + STEP = "STEP", "For step-by-step execution " + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + # TODO: Make as foreign key to access the instance directly + pipeline_id = models.UUIDField( + editable=False, + null=True, + db_comment="ID of the associated pipeline, if applicable", + ) + task_id = models.UUIDField( + editable=False, + null=True, + db_comment="task id of asynchronous execution", + ) + # We can remove workflow_id if it not required + workflow_id = models.UUIDField( + editable=False, db_comment="Id of workflow to be executed" + ) + project_settings_id = models.UUIDField( + editable=False, + default=uuid.uuid4, + db_comment="Id of project settings used while execution", + ) + execution_mode = models.CharField( + choices=Mode.choices, db_comment="Mode of execution" + ) + execution_method = models.CharField( + choices=Method.choices, db_comment="Method of execution" + ) + execution_type = models.CharField( + choices=Type.choices, db_comment="Type of 
execution" + ) + execution_log_id = models.CharField( + default="", editable=False, db_comment="Execution log events Id" + ) + # TODO: Restrict with an enum + status = models.CharField(default="", db_comment="Current status of execution") + error_message = models.CharField( + max_length=EXECUTION_ERROR_LENGTH, + blank=True, + default="", + db_comment="Details of encountered errors", + ) + attempts = models.IntegerField(default=0, db_comment="number of attempts taken") + execution_time = models.FloatField( + default=0, db_comment="execution time in seconds" + ) + + def __str__(self) -> str: + return ( + f"Workflow execution: {self.id} (" + f"pipeline ID: {self.pipeline_id}, " + f"workflow iD: {self.workflow_id}, " + f"execution method: {self.execution_method}, " + f"status: {self.status}, " + f"error message: {self.error_message})" + ) + + class Meta: + verbose_name = "Workflow Execution" + verbose_name_plural = "Workflow Executions" + db_table = "workflow_execution_v2" diff --git a/backend/workflow_manager/workflow_v2/models/execution_log.py b/backend/workflow_manager/workflow_v2/models/execution_log.py new file mode 100644 index 000000000..6bec96009 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/models/execution_log.py @@ -0,0 +1,22 @@ +import uuid + +from django.db import models +from utils.models.base_model import BaseModel + + +class ExecutionLog(BaseModel): + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + execution_id = models.UUIDField( + editable=False, + db_comment="Execution ID", + ) + data = models.JSONField(db_comment="Execution log data") + event_time = models.DateTimeField(db_comment="Execution log event time") + + def __str__(self): + return f"Execution ID: {self.execution_id}, Message: {self.data}" + + class Meta: + verbose_name = "Execution Log" + verbose_name_plural = "Execution Logs" + db_table = "execution_log_v2" diff --git a/backend/workflow_manager/workflow_v2/models/file_history.py b/backend/workflow_manager/workflow_v2/models/file_history.py new file mode 100644 index 000000000..5ad521cd0 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/models/file_history.py @@ -0,0 +1,52 @@ +import uuid + +from django.db import models +from utils.models.base_model import BaseModel +from workflow_manager.workflow_v2.enums import ExecutionStatus +from workflow_manager.workflow_v2.models.workflow import Workflow + +HASH_LENGTH = 64 + + +class FileHistory(BaseModel): + def is_completed(self) -> bool: + """Check if the execution status is completed. + + Returns: + bool: True if the execution status is completed, False otherwise. 
+ """ + return ( + self.status is not None and self.status == ExecutionStatus.COMPLETED.value + ) + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + cache_key = models.CharField( + max_length=HASH_LENGTH, + db_comment="Hash value of file contents, WF and tool modified times", + ) + workflow = models.ForeignKey( + Workflow, + on_delete=models.CASCADE, + related_name="file_histories", + ) + status = models.TextField( + choices=ExecutionStatus.choices(), + db_comment="Latest status of execution", + ) + error = models.TextField( + blank=True, + default="", + db_comment="Error message", + ) + result = models.TextField(blank=True, db_comment="Result from execution") + + class Meta: + verbose_name = "File History" + verbose_name_plural = "File Histories" + db_table = "file_history_v2" + constraints = [ + models.UniqueConstraint( + fields=["workflow", "cache_key"], + name="unique_workflow_cacheKey", + ), + ] diff --git a/backend/workflow_manager/workflow_v2/models/workflow.py b/backend/workflow_manager/workflow_v2/models/workflow.py new file mode 100644 index 000000000..dcdaa5dae --- /dev/null +++ b/backend/workflow_manager/workflow_v2/models/workflow.py @@ -0,0 +1,95 @@ +import uuid + +from account_v2.models import User +from django.db import models +from utils.models.base_model import BaseModel +from utils.models.organization_mixin import ( + DefaultOrganizationManagerMixin, + DefaultOrganizationMixin, +) + +PROMPT_NAME_LENGTH = 32 +WORKFLOW_STATUS_LENGTH = 16 +EXECUTION_ERROR_LENGTH = 256 +DESCRIPTION_FIELD_LENGTH = 490 +WORKFLOW_NAME_SIZE = 128 + + +class WorkflowModelManager(DefaultOrganizationManagerMixin, models.Manager): + pass + + +class Workflow(DefaultOrganizationMixin, BaseModel): + class WorkflowType(models.TextChoices): + DEFAULT = "DEFAULT", "Not ready yet" + ETL = "ETL", "ETL pipeline" + TASK = "TASK", "TASK pipeline" + API = "API", "API deployment" + APP = "APP", "App deployment" + + class ExecutionAction(models.TextChoices): + START = "START", "Start the Execution" + NEXT = "NEXT", "Execute next tool" + STOP = "STOP", "Stop the execution" + CONTINUE = "CONTINUE", "Continue to full execution" + + # TODO Make this guid as primaryId instaed of current id bigint + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + # TODO: Move prompt fields as a One-One relationship/into Prompt instead + prompt_name = models.CharField(max_length=PROMPT_NAME_LENGTH, default="") + description = models.TextField(max_length=DESCRIPTION_FIELD_LENGTH, default="") + workflow_name = models.CharField(max_length=WORKFLOW_NAME_SIZE) + prompt_text = models.TextField(default="") + is_active = models.BooleanField(default=False) + status = models.CharField(max_length=WORKFLOW_STATUS_LENGTH, default="") + llm_response = models.TextField() + workflow_owner = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="workflows_owned", + null=True, + blank=True, + ) + deployment_type = models.CharField( + choices=WorkflowType.choices, + db_comment="Type of workflow deployment", + default=WorkflowType.DEFAULT, + ) + source_settings = models.JSONField( + null=True, db_comment="Settings for the Source module" + ) + destination_settings = models.JSONField( + null=True, db_comment="Settings for the Destination module" + ) + + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="workflows_created", + null=True, + blank=True, + ) + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + 
related_name="workflows_modified", + null=True, + blank=True, + ) + + # Manager + objects = WorkflowModelManager() + + def __str__(self) -> str: + return f"{self.id}, name: {self.workflow_name}" + + class Meta: + verbose_name = "Workflow" + verbose_name_plural = "Workflows" + db_table = "workflow_v2" + constraints = [ + models.UniqueConstraint( + fields=["workflow_name", "organization"], + name="unique_workflow_name", + ), + ] diff --git a/backend/workflow_manager/workflow_v2/serializers.py b/backend/workflow_manager/workflow_v2/serializers.py new file mode 100644 index 000000000..0f4c8c3d6 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/serializers.py @@ -0,0 +1,165 @@ +import logging +from typing import Any, Optional, Union + +from project.constants import ProjectKey +from rest_framework.serializers import ( + CharField, + ChoiceField, + JSONField, + ModelSerializer, + Serializer, + UUIDField, + ValidationError, +) +from tool_instance_v2.serializers import ToolInstanceSerializer +from tool_instance_v2.tool_instance_helper import ToolInstanceHelper +from workflow_manager.endpoint_v2.models import WorkflowEndpoint +from workflow_manager.workflow_v2.constants import WorkflowExecutionKey, WorkflowKey +from workflow_manager.workflow_v2.exceptions import WorkflowGenerationError +from workflow_manager.workflow_v2.generator import WorkflowGenerator +from workflow_manager.workflow_v2.models.execution import WorkflowExecution +from workflow_manager.workflow_v2.models.execution_log import ExecutionLog +from workflow_manager.workflow_v2.models.workflow import Workflow + +from backend.constants import RequestKey +from backend.serializers import AuditSerializer + +logger = logging.getLogger(__name__) + + +class WorkflowSerializer(AuditSerializer): + tool_instances = ToolInstanceSerializer(many=True, read_only=True) + + class Meta: + model = Workflow + fields = "__all__" + extra_kwargs = { + WorkflowKey.LLM_RESPONSE: { + "required": False, + }, + } + + def to_representation(self, instance: Workflow) -> dict[str, str]: + representation: dict[str, str] = super().to_representation(instance) + representation[WorkflowKey.WF_NAME] = instance.workflow_name + representation[WorkflowKey.WF_TOOL_INSTANCES] = ToolInstanceSerializer( + ToolInstanceHelper.get_tool_instances_by_workflow( + workflow_id=instance.id, order_by="step" + ), + many=True, + context=self.context, + ).data + representation["created_by_email"] = instance.created_by.email + return representation + + def create(self, validated_data: dict[str, Any]) -> Any: + if self.context.get(RequestKey.REQUEST): + validated_data[WorkflowKey.WF_OWNER] = self.context.get( + RequestKey.REQUEST + ).user + return super().create(validated_data) + + def update(self, instance: Any, validated_data: dict[str, Any]) -> Any: + if validated_data.get(WorkflowKey.PROMPT_TEXT): + instance.workflow_tool.all().delete() + return super().update(instance, validated_data) + + def save(self, **kwargs: Any) -> Workflow: + workflow: Workflow = super().save(**kwargs) + if self.validated_data.get(WorkflowKey.PROMPT_TEXT): + try: + tool_serializer = ToolInstanceSerializer( + data=WorkflowGenerator.get_tool_instance_data_from_llm( + workflow=workflow + ), + many=True, + context=self.context, + ) + tool_serializer.is_valid(raise_exception=True) + tool_serializer.save() + except Exception as exc: + logger.error(f"Error while generating tool instances: {exc}") + raise WorkflowGenerationError + + request = self.context.get("request") + if not request: + return workflow + return 
workflow + + +class ExecuteWorkflowSerializer(Serializer): + workflow_id = UUIDField(required=False) + project_id = UUIDField(required=False) + execution_action = ChoiceField( + choices=Workflow.ExecutionAction.choices, required=False + ) + execution_id = UUIDField(required=False) + log_guid = UUIDField(required=False) + # TODO: Add other fields to handle WFExecution method, mode .etc. + + def get_workflow_id( + self, validated_data: dict[str, Union[str, None]] + ) -> Optional[str]: + return validated_data.get(WorkflowKey.WF_ID) + + def get_project_id( + self, validated_data: dict[str, Union[str, None]] + ) -> Optional[str]: + return validated_data.get(ProjectKey.PROJECT_ID) + + def get_execution_id( + self, validated_data: dict[str, Union[str, None]] + ) -> Optional[str]: + return validated_data.get(WorkflowExecutionKey.EXECUTION_ID) + + def get_log_guid( + self, validated_data: dict[str, Union[str, None]] + ) -> Optional[str]: + return validated_data.get(WorkflowExecutionKey.LOG_GUID) + + def get_execution_action( + self, validated_data: dict[str, Union[str, None]] + ) -> Optional[str]: + return validated_data.get(WorkflowKey.EXECUTION_ACTION) + + def validate( + self, data: dict[str, Union[str, None]] + ) -> dict[str, Union[str, None]]: + workflow_id = data.get(WorkflowKey.WF_ID) + project_id = data.get(ProjectKey.PROJECT_ID) + + if not workflow_id and not project_id: + raise ValidationError( + "At least one of 'workflow_id' or 'project_id' is required." + ) + + return data + + +class ExecuteWorkflowResponseSerializer(Serializer): + workflow_id = UUIDField() + execution_id = UUIDField() + execution_status = CharField() + log_id = CharField() + error = CharField() + result = JSONField() + + +class WorkflowEndpointSerializer(ModelSerializer): + workflow_name = CharField(source="workflow.workflow_name", read_only=True) + + class Meta: + model = WorkflowEndpoint + fields = "__all__" + + +class WorkflowExecutionSerializer(ModelSerializer): + class Meta: + model = WorkflowExecution + fields = "__all__" + + +class WorkflowExecutionLogSerializer(ModelSerializer): + class Meta: + model = ExecutionLog + fields = "__all__" diff --git a/backend/workflow_manager/workflow_v2/urls.py b/backend/workflow_manager/workflow_v2/urls.py new file mode 100644 index 000000000..78e1b9aff --- /dev/null +++ b/backend/workflow_manager/workflow_v2/urls.py @@ -0,0 +1,77 @@ +from django.urls import path +from rest_framework.urlpatterns import format_suffix_patterns +from workflow_manager.workflow_v2.execution_log_view import WorkflowExecutionLogViewSet +from workflow_manager.workflow_v2.execution_view import WorkflowExecutionViewSet +from workflow_manager.workflow_v2.views import WorkflowViewSet + +workflow_list = WorkflowViewSet.as_view( + { + "get": "list", + "post": "create", + } +) +workflow_detail = WorkflowViewSet.as_view( + # fmt: off + { + 'get': 'retrieve', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy' + } + # fmt: on +) +workflow_execute = WorkflowViewSet.as_view({"post": "execute", "put": "activate"}) +execution_entity = WorkflowExecutionViewSet.as_view({"get": "retrieve"}) +execution_list = WorkflowExecutionViewSet.as_view({"get": "list"}) +execution_log_list = WorkflowExecutionLogViewSet.as_view({"get": "list"}) +workflow_clear_cache = WorkflowViewSet.as_view({"get": "clear_cache"}) +workflow_clear_file_marker = WorkflowViewSet.as_view({"get": "clear_file_marker"}) +workflow_schema = WorkflowViewSet.as_view({"get": "get_schema"}) +can_update = WorkflowViewSet.as_view({"get": 
"can_update"}) +urlpatterns = format_suffix_patterns( + [ + path("", workflow_list, name="workflow-list"), + path("/", workflow_detail, name="workflow-detail"), + path( + "/clear-cache/", + workflow_clear_cache, + name="clear-cache", + ), + path( + "/clear-file-marker/", + workflow_clear_file_marker, + name="clear-file-marker", + ), + path( + "/can-update/", + can_update, + name="can-update", + ), + path("execute/", workflow_execute, name="execute-workflow"), + path( + "active//", + workflow_execute, + name="active-workflow", + ), + path( + "/execution/", + execution_list, + name="execution-list", + ), + path( + "execution//", + execution_entity, + name="workflow-detail", + ), + path( + "execution//logs/", + execution_log_list, + name="execution-log", + ), + path( + "schema/", + workflow_schema, + name="workflow-schema", + ), + ] +) diff --git a/backend/workflow_manager/workflow_v2/views.py b/backend/workflow_manager/workflow_v2/views.py new file mode 100644 index 000000000..f9b9c50f4 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/views.py @@ -0,0 +1,296 @@ +import logging +from typing import Any, Optional + +from connector_v2.connector_instance_helper import ConnectorInstanceHelper +from django.conf import settings +from django.db.models.query import QuerySet +from permissions.permission import IsOwner +from pipeline_v2.models import Pipeline +from pipeline_v2.pipeline_processor import PipelineProcessor +from rest_framework import serializers, status, viewsets +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.versioning import URLPathVersioning +from tool_instance_v2.tool_processor import ToolProcessor +from unstract.tool_registry.dto import Tool +from utils.filtering import FilterHelper +from workflow_manager.endpoint_v2.destination import DestinationConnector +from workflow_manager.endpoint_v2.endpoint_utils import WorkflowEndpointUtils +from workflow_manager.endpoint_v2.source import SourceConnector +from workflow_manager.workflow_v2.constants import WorkflowKey +from workflow_manager.workflow_v2.dto import ExecutionResponse +from workflow_manager.workflow_v2.enums import SchemaEntity, SchemaType +from workflow_manager.workflow_v2.exceptions import ( + WorkflowDoesNotExistError, + WorkflowGenerationError, + WorkflowRegenerationError, +) +from workflow_manager.workflow_v2.generator import WorkflowGenerator +from workflow_manager.workflow_v2.models.workflow import Workflow +from workflow_manager.workflow_v2.serializers import ( + ExecuteWorkflowResponseSerializer, + ExecuteWorkflowSerializer, + WorkflowSerializer, +) +from workflow_manager.workflow_v2.workflow_helper import ( + WorkflowHelper, + WorkflowSchemaHelper, +) + +from backend.constants import RequestKey + +logger = logging.getLogger(__name__) + + +def make_execution_response(response: ExecutionResponse) -> Any: + return ExecuteWorkflowResponseSerializer(response).data + + +class WorkflowViewSet(viewsets.ModelViewSet): + versioning_class = URLPathVersioning + permission_classes = [IsOwner] + queryset = Workflow.objects.all() + + def get_queryset(self) -> QuerySet: + filter_args = FilterHelper.build_filter_args( + self.request, + RequestKey.PROJECT, + WorkflowKey.WF_OWNER, + WorkflowKey.WF_IS_ACTIVE, + ) + queryset = ( + Workflow.objects.filter(created_by=self.request.user, **filter_args) + if filter_args + else Workflow.objects.filter(created_by=self.request.user) + ) + order_by = 
self.request.query_params.get("order_by") + if order_by == "desc": + queryset = queryset.order_by("-modified_at") + elif order_by == "asc": + queryset = queryset.order_by("modified_at") + + return queryset + + def get_serializer_class(self) -> serializers.Serializer: + if self.action == "execute": + return ExecuteWorkflowSerializer + else: + return WorkflowSerializer + + def _generate_workflow(self, workflow_id: str) -> WorkflowGenerator: + registry_tools: list[Tool] = ToolProcessor.get_registry_tools() + generator = WorkflowGenerator(workflow_id=workflow_id) + generator.set_request(self.request) + generator.generate_workflow(registry_tools) + return generator + + def perform_update(self, serializer: WorkflowSerializer) -> Workflow: + """To edit a workflow. Regenerates the tool instances for a new prompt. + + Raises: WorkflowGenerationError + """ + kwargs = {} + if serializer.validated_data.get(WorkflowKey.PROMPT_TEXT): + workflow: Workflow = self.get_object() + generator = self._generate_workflow(workflow_id=workflow.id) + kwargs = { + WorkflowKey.LLM_RESPONSE: generator.llm_response, + WorkflowKey.WF_IS_ACTIVE: True, + } + try: + workflow = serializer.save(**kwargs) + return workflow + except Exception as e: + logger.error(f"Error saving workflow to DB: {e}") + raise WorkflowRegenerationError + + def perform_create(self, serializer: WorkflowSerializer) -> Workflow: + """To create a new workflow. Creates the Workflow instance first and + uses it to generate the tool instances. + + Raises: WorkflowGenerationError + """ + try: + workflow = serializer.save( + is_active=True, + ) + WorkflowEndpointUtils.create_endpoints_for_workflow(workflow) + + # Enable GCS configurations to create GCS while creating a workflow + if ( + settings.GOOGLE_STORAGE_ACCESS_KEY_ID + and settings.UNSTRACT_FREE_STORAGE_BUCKET_NAME + ): + ConnectorInstanceHelper.create_default_gcs_connector( + workflow, self.request.user + ) + + except Exception as e: + logger.error(f"Error saving workflow to DB: {e}") + raise WorkflowGenerationError + return workflow + + def get_execution(self, request: Request, pk: str) -> Response: + execution = WorkflowHelper.get_current_execution(pk) + return Response(make_execution_response(execution), status=status.HTTP_200_OK) + + def get_workflow_by_id_or_project_id( + self, + workflow_id: Optional[str] = None, + project_id: Optional[str] = None, + ) -> Workflow: + """Retrieve workflow by workflow id or project Id. + + Args: + workflow_id (Optional[str], optional): workflow Id. + project_id (Optional[str], optional): Project Id. 
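+ If both are provided, workflow_id takes precedence.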
+ + Raises: + WorkflowDoesNotExistError: _description_ + + Returns: + Workflow: workflow + """ + if workflow_id: + workflow = WorkflowHelper.get_workflow_by_id(workflow_id) + elif project_id: + workflow = WorkflowHelper.get_active_workflow_by_project_id(project_id) + else: + raise WorkflowDoesNotExistError() + return workflow + + def execute( + self, + request: Request, + pipeline_guid: Optional[str] = None, + ) -> Response: + self.serializer_class = ExecuteWorkflowSerializer + serializer = ExecuteWorkflowSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + workflow_id = serializer.get_workflow_id(serializer.validated_data) + project_id = serializer.get_project_id(serializer.validated_data) + execution_id = serializer.get_execution_id(serializer.validated_data) + execution_action = serializer.get_execution_action(serializer.validated_data) + file_objs = request.FILES.getlist("files") + include_metadata = ( + request.data.get("include_metadata", "false").lower() == "true" + ) + hashes_of_files = {} + if file_objs and execution_id and workflow_id: + hashes_of_files = SourceConnector.add_input_file_to_api_storage( + workflow_id=workflow_id, + execution_id=execution_id, + file_objs=file_objs, + ) + + try: + workflow = self.get_workflow_by_id_or_project_id( + workflow_id=workflow_id, project_id=project_id + ) + execution_response = self.execute_workflow( + workflow=workflow, + execution_action=execution_action, + execution_id=execution_id, + pipeline_guid=pipeline_guid, + hash_values_of_files=hashes_of_files, + include_metadata=include_metadata, + ) + return Response( + make_execution_response(execution_response), + status=status.HTTP_200_OK, + ) + except Exception as exception: + logger.error(f"Error while executing workflow: {exception}") + if file_objs and execution_id and workflow_id: + DestinationConnector.delete_api_storage_dir( + workflow_id=workflow_id, execution_id=execution_id + ) + raise exception + + def execute_workflow( + self, + workflow: Workflow, + execution_action: Optional[str] = None, + execution_id: Optional[str] = None, + pipeline_guid: Optional[str] = None, + hash_values_of_files: dict[str, str] = {}, + include_metadata: bool = False, + ) -> ExecutionResponse: + if execution_action is not None: + # Step execution + execution_response = WorkflowHelper.step_execution( + workflow, + execution_action, + execution_id=execution_id, + hash_values_of_files=hash_values_of_files, + include_metadata=include_metadata, + ) + elif pipeline_guid: + # pipeline execution + PipelineProcessor.update_pipeline( + pipeline_guid, Pipeline.PipelineStatus.INPROGRESS + ) + execution_response = WorkflowHelper.complete_execution( + workflow=workflow, + execution_id=execution_id, + pipeline_id=pipeline_guid, + hash_values_of_files=hash_values_of_files, + ) + else: + execution_response = WorkflowHelper.complete_execution( + workflow=workflow, + execution_id=execution_id, + hash_values_of_files=hash_values_of_files, + include_metadata=include_metadata, + ) + return execution_response + + def activate(self, request: Request, pk: str) -> Response: + workflow = WorkflowHelper.active_project_workflow(pk) + serializer = WorkflowSerializer(workflow) + return Response(serializer.data, status=status.HTTP_200_OK) + + @action(detail=True, methods=["get"]) + def clear_cache(self, request: Request, *args: Any, **kwargs: Any) -> Response: + workflow = self.get_object() + response: dict[str, Any] = WorkflowHelper.clear_cache(workflow_id=workflow.id) + return Response(response.get("message"), 
status=response.get("status")) + + @action(detail=True, methods=["get"]) + def can_update(self, request: Request, pk: str) -> Response: + response: dict[str, Any] = WorkflowHelper.can_update_workflow(pk) + return Response(response, status=status.HTTP_200_OK) + + @action(detail=True, methods=["get"]) + def clear_file_marker( + self, request: Request, *args: Any, **kwargs: Any + ) -> Response: + workflow = self.get_object() + response: dict[str, Any] = WorkflowHelper.clear_file_marker( + workflow_id=workflow.id + ) + return Response(response.get("message"), status=response.get("status")) + + @action(detail=False, methods=["get"]) + def get_schema(self, request: Request, *args: Any, **kwargs: Any) -> Response: + """Retrieves the JSON schema for source/destination type modules for + entities file/API/DB. + + Takes query params `type` (defaults to "src") and + `entity` (defaults to "file"). + + Returns: + Response: JSON schema for the request made + """ + schema_type = request.query_params.get("type", SchemaType.SRC.value) + schema_entity = request.query_params.get("entity", SchemaEntity.FILE.value) + + WorkflowSchemaHelper.validate_request( + schema_type=SchemaType(schema_type), + schema_entity=SchemaEntity(schema_entity), + ) + json_schema = WorkflowSchemaHelper.get_json_schema( + schema_type=schema_type, schema_entity=schema_entity + ) + return Response(data=json_schema, status=status.HTTP_200_OK) diff --git a/backend/workflow_manager/workflow_v2/workflow_helper.py b/backend/workflow_manager/workflow_v2/workflow_helper.py new file mode 100644 index 000000000..89b4babf0 --- /dev/null +++ b/backend/workflow_manager/workflow_v2/workflow_helper.py @@ -0,0 +1,819 @@ +import json +import logging +import os +import traceback +import uuid +from typing import Any, Optional + +from account_v2.constants import Common +from api_v2.models import APIDeployment +from celery import current_task +from celery import exceptions as celery_exceptions +from celery import shared_task +from celery.result import AsyncResult +from django.db import IntegrityError +from pipeline_v2.models import Pipeline +from pipeline_v2.pipeline_processor import PipelineProcessor +from rest_framework import serializers +from tool_instance_v2.constants import ToolInstanceKey +from tool_instance_v2.models import ToolInstance +from tool_instance_v2.tool_instance_helper import ToolInstanceHelper +from unstract.workflow_execution.enums import LogComponent, LogLevel, LogState +from unstract.workflow_execution.exceptions import StopExecution +from utils.cache_service import CacheService +from utils.constants import Account +from utils.local_context import StateStore +from utils.user_context import UserContext +from workflow_manager.endpoint_v2.destination import DestinationConnector +from workflow_manager.endpoint_v2.source import SourceConnector +from workflow_manager.workflow_v2.constants import ( + CeleryConfigurations, + WorkflowErrors, + WorkflowExecutionKey, + WorkflowMessages, +) +from workflow_manager.workflow_v2.dto import AsyncResultData, ExecutionResponse +from workflow_manager.workflow_v2.enums import ExecutionStatus, SchemaEntity, SchemaType +from workflow_manager.workflow_v2.exceptions import ( + InvalidRequest, + TaskDoesNotExistError, + WorkflowDoesNotExistError, + WorkflowExecutionNotExist, +) +from workflow_manager.workflow_v2.execution import WorkflowExecutionServiceHelper +from workflow_manager.workflow_v2.file_history_helper import FileHistoryHelper +from workflow_manager.workflow_v2.models.execution import 
WorkflowExecution +from workflow_manager.workflow_v2.models.workflow import Workflow + +logger = logging.getLogger(__name__) + + +class WorkflowHelper: + @staticmethod + def get_workflow_by_id(id: str) -> Workflow: + try: + workflow: Workflow = Workflow.objects.get(pk=id) + if not workflow or workflow is None: + raise WorkflowDoesNotExistError() + return workflow + except Workflow.DoesNotExist: + logger.error(f"Error getting workflow: {id}") + raise WorkflowDoesNotExistError() + + @staticmethod + def get_active_workflow_by_project_id(project_id: str) -> Workflow: + try: + workflow: Workflow = Workflow.objects.filter( + project_id=project_id, is_active=True + ).first() + if not workflow or workflow is None: + raise WorkflowDoesNotExistError() + return workflow + except Workflow.DoesNotExist: + raise WorkflowDoesNotExistError() + + @staticmethod + def active_project_workflow(workflow_id: str) -> Workflow: + workflow: Workflow = WorkflowHelper.get_workflow_by_id(workflow_id) + workflow.is_active = True + workflow.save() + return workflow + + @staticmethod + def build_workflow_execution_service( + organization_id: Optional[str], + workflow: Workflow, + tool_instances: list[ToolInstance], + pipeline_id: Optional[str], + single_step: bool, + scheduled: bool, + execution_mode: tuple[str, str], + workflow_execution: Optional[WorkflowExecution], + include_metadata: bool = False, + ) -> WorkflowExecutionServiceHelper: + workflow_execution_service = WorkflowExecutionServiceHelper( + organization_id=organization_id, + workflow=workflow, + tool_instances=tool_instances, + pipeline_id=pipeline_id, + single_step=single_step, + scheduled=scheduled, + mode=execution_mode, + workflow_execution=workflow_execution, + include_metadata=include_metadata, + ) + workflow_execution_service.build() + return workflow_execution_service + + @staticmethod + def process_input_files( + workflow: Workflow, + source: SourceConnector, + destination: DestinationConnector, + execution_service: WorkflowExecutionServiceHelper, + single_step: bool, + hash_values_of_files: dict[str, str] = {}, + ) -> WorkflowExecution: + input_files = source.list_files_from_source() + total_files = len(input_files) + processed_files = 0 + error_raised = 0 + execution_service.publish_initial_workflow_logs(total_files) + execution_service.update_execution( + ExecutionStatus.EXECUTING, increment_attempt=True + ) + for index, input_file in enumerate(input_files): + file_number = index + 1 + try: + is_executed, error = WorkflowHelper.process_file( + current_file_idx=file_number, + total_files=total_files, + input_file=input_file, + workflow=workflow, + source=source, + destination=destination, + execution_service=execution_service, + single_step=single_step, + hash_values_of_files=hash_values_of_files, + ) + if is_executed: + processed_files += 1 + if error: + error_raised += 1 + except StopExecution as exception: + execution_service.update_execution( + ExecutionStatus.STOPPED, error=str(exception) + ) + break + if error_raised and error_raised == total_files: + execution_service.update_execution(ExecutionStatus.ERROR) + else: + execution_service.update_execution(ExecutionStatus.COMPLETED) + + execution_service.publish_final_workflow_logs( + total_files=total_files, processed_files=processed_files + ) + return execution_service.get_execution_instance() + + @staticmethod + def process_file( + current_file_idx: int, + total_files: int, + input_file: str, + workflow: Workflow, + source: SourceConnector, + destination: DestinationConnector, + 
execution_service: WorkflowExecutionServiceHelper, + single_step: bool, + hash_values_of_files: dict[str, str], + ) -> tuple[bool, Optional[str]]: + file_history = None + error = None + is_executed = False + file_name, file_hash = source.add_file_to_volume( + input_file_path=input_file, + hash_values_of_files=hash_values_of_files, + ) + try: + execution_service.initiate_tool_execution( + current_file_idx, total_files, file_name, single_step + ) + file_history = FileHistoryHelper.get_file_history( + workflow=workflow, cache_key=file_hash + ) + is_executed = execution_service.execute_input_file( + file_name=file_name, + single_step=single_step, + file_history=file_history, + ) + except StopExecution: + raise + except Exception as e: + execution_service.publish_log( + f"Error processing file {input_file}: {str(e)}", + level=LogLevel.ERROR, + ) + error = str(e) + execution_service.publish_update_log( + LogState.RUNNING, + f"Processing output for {file_name}", + LogComponent.DESTINATION, + ) + destination.handle_output( + file_name=file_name, + file_hash=file_hash, + workflow=workflow, + file_history=file_history, + error=error, + input_file_path=input_file, + ) + execution_service.publish_update_log( + LogState.SUCCESS, + f"{file_name}'s output is processed successfully", + LogComponent.DESTINATION, + ) + return is_executed, error + + @staticmethod + def validate_tool_instances_meta( + tool_instances: list[ToolInstance], + ) -> None: + for tool in tool_instances: + ToolInstanceHelper.validate_tool_settings( + user=tool.workflow.created_by, + tool_uid=tool.tool_id, + tool_meta=tool.metadata, + ) + + @staticmethod + def run_workflow( + workflow: Workflow, + hash_values_of_files: dict[str, str] = {}, + organization_id: Optional[str] = None, + pipeline_id: Optional[str] = None, + scheduled: bool = False, + single_step: bool = False, + workflow_execution: Optional[WorkflowExecution] = None, + execution_mode: Optional[tuple[str, str]] = None, + include_metadata: bool = False, + ) -> ExecutionResponse: + tool_instances: list[ToolInstance] = ( + ToolInstanceHelper.get_tool_instances_by_workflow( + workflow.id, ToolInstanceKey.STEP + ) + ) + + WorkflowHelper.validate_tool_instances_meta(tool_instances=tool_instances) + execution_mode = execution_mode or WorkflowExecution.Mode.INSTANT + execution_service = WorkflowHelper.build_workflow_execution_service( + organization_id=organization_id, + workflow=workflow, + tool_instances=tool_instances, + pipeline_id=pipeline_id, + single_step=single_step, + scheduled=scheduled, + execution_mode=execution_mode, + workflow_execution=workflow_execution, + include_metadata=include_metadata, + ) + execution_id = execution_service.execution_id + source = SourceConnector( + organization_id=organization_id, + workflow=workflow, + execution_id=execution_id, + execution_service=execution_service, + ) + destination = DestinationConnector(workflow=workflow, execution_id=execution_id) + # Validating endpoints + source.validate() + destination.validate() + # Execution Process + try: + workflow_execution = WorkflowHelper.process_input_files( + workflow, + source, + destination, + execution_service, + single_step=single_step, + hash_values_of_files=hash_values_of_files, + ) + # TODO: Update through signals + WorkflowHelper._update_pipeline_status( + pipeline_id=pipeline_id, workflow_execution=workflow_execution + ) + return ExecutionResponse( + str(workflow.id), + str(workflow_execution.id), + workflow_execution.status, + log_id=str(execution_service.execution_log_id), + 
error=workflow_execution.error_message, + mode=workflow_execution.execution_mode, + result=destination.api_results, + ) + finally: + destination.delete_execution_directory() + + @staticmethod + def _update_pipeline_status( + pipeline_id: Optional[str], workflow_execution: WorkflowExecution + ) -> None: + try: + if pipeline_id: + # Update pipeline status + if workflow_execution.status != ExecutionStatus.ERROR.value: + PipelineProcessor.update_pipeline( + pipeline_id, Pipeline.PipelineStatus.SUCCESS + ) + else: + PipelineProcessor.update_pipeline( + pipeline_id, Pipeline.PipelineStatus.FAILURE + ) + # Expected exception since API deployments are not tracked in Pipeline + except Pipeline.DoesNotExist: + pass + except Exception as e: + logger.warning( + f"Error updating pipeline {pipeline_id} status: {e}, " + f"with workflow execution: {workflow_execution}" + ) + + @staticmethod + def get_status_of_async_task( + execution_id: str, + ) -> ExecutionResponse: + """Get celery task status. + + Args: + execution_id (str): workflow execution id + + Raises: + TaskDoesNotExistError: Not found exception + + Returns: + ExecutionResponse: _description_ + """ + execution = WorkflowExecution.objects.get(id=execution_id) + + if not execution.task_id: + raise TaskDoesNotExistError() + + result = AsyncResult(str(execution.task_id)) + + task = AsyncResultData(async_result=result) + return ExecutionResponse( + execution.workflow_id, + execution_id, + task.status, + result=task.result, + ) + + @staticmethod + def execute_workflow_async( + workflow_id: str, + execution_id: str, + hash_values_of_files: dict[str, str], + timeout: int = -1, + pipeline_id: Optional[str] = None, + include_metadata: bool = False, + ) -> ExecutionResponse: + """Adding a workflow to the queue for execution. + + Args: + workflow_id (str): workflowId + execution_id (str): Execution ID + timeout (int): Celery timeout (timeout -1 : async execution) + pipeline_id (Optional[str], optional): Optional pipeline. Defaults to None. 
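# --- Illustrative sketch (not part of this diff): how get_status_of_async_task above maps a
# --- stored task_id onto Celery's result API. The app name and broker/backend URLs are
# --- assumptions for the sketch.
from celery import Celery
from celery.result import AsyncResult

app = Celery(
    "workflow",
    broker="redis://localhost:6379/0",
    backend="redis://localhost:6379/1",
)


def task_status(task_id: str) -> dict:
    result = AsyncResult(task_id, app=app)
    return {
        "id": result.id,
        "status": result.status,  # PENDING / STARTED / SUCCESS / FAILURE ...
        "is_ready": result.ready(),
        "result": result.result if result.ready() else None,
    }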
+ include_metadata (bool): Whether to include metadata in the prompt output + + Returns: + ExecutionResponse: Existing status of execution + """ + try: + org_schema = UserContext.get_organization_identifier() + log_events_id = StateStore.get(Common.LOG_EVENTS_ID) + async_execution = WorkflowHelper.execute_bin.delay( + org_schema, + workflow_id, + hash_values_of_files=hash_values_of_files, + execution_id=execution_id, + pipeline_id=pipeline_id, + log_events_id=log_events_id, + include_metadata=include_metadata, + ) + if timeout > -1: + async_execution.wait( + timeout=timeout, + interval=CeleryConfigurations.INTERVAL, + ) + task = AsyncResultData(async_result=async_execution) + logger.info(f"Job {async_execution} enqueued.") + celery_result = task.to_dict() + task_result = celery_result.get("result") + return ExecutionResponse( + workflow_id, + execution_id, + task.status, + result=task_result, + ) + except celery_exceptions.TimeoutError: + return ExecutionResponse( + workflow_id, + execution_id, + async_execution.status, + message=WorkflowMessages.CELERY_TIMEOUT_MESSAGE, + ) + except Exception as error: + WorkflowExecutionServiceHelper.update_execution_status( + execution_id, ExecutionStatus.ERROR + ) + logger.error(f"Errors while job enqueueing {str(error)}") + logger.error(f"Error {traceback.format_exc()}") + return ExecutionResponse( + workflow_id, + execution_id, + ExecutionStatus.ERROR.value, + error=str(error), + ) + + @staticmethod + @shared_task( + name="async_execute_bin", + acks_late=True, + autoretry_for=(Exception,), + max_retries=1, + retry_backoff=True, + retry_backoff_max=500, + retry_jitter=True, + ) + def execute_bin( + schema_name: str, + workflow_id: str, + execution_id: str, + hash_values_of_files: dict[str, str], + scheduled: bool = False, + execution_mode: Optional[tuple[str, str]] = None, + pipeline_id: Optional[str] = None, + include_metadata: bool = False, + **kwargs: dict[str, Any], + ) -> Optional[list[Any]]: + """Asynchronous Execution By celery. + + Args: + schema_name (str): schema name to get Data + workflow_id (str): Workflow Id + execution_id (str): Id of the execution + scheduled (bool, optional): Represents if it is a scheduled execution + Defaults to False + execution_mode (Optional[WorkflowExecution.Mode]): WorkflowExecution Mode + Defaults to None + pipeline_id (Optional[str], optional): Id of pipeline. 
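# --- Illustrative sketch (not part of this diff): the enqueue-and-optionally-wait pattern used
# --- by execute_workflow_async/execute_bin above, with a trivial task body. The broker URL and
# --- demo_execute_bin are assumptions; the retry options mirror the decorator above.
from celery import Celery, shared_task
from celery.exceptions import TimeoutError as CeleryTimeout

app = Celery(
    "workflow",
    broker="redis://localhost:6379/0",
    backend="redis://localhost:6379/1",
)


@shared_task(
    name="demo_execute_bin",
    acks_late=True,             # re-deliver if the worker dies mid-task
    autoretry_for=(Exception,),
    max_retries=1,
    retry_backoff=True,
)
def demo_execute_bin(workflow_id: str) -> str:
    return f"executed {workflow_id}"


def enqueue(workflow_id: str, timeout: int = -1) -> str:
    async_result = demo_execute_bin.delay(workflow_id)
    if timeout > -1:
        try:
            async_result.wait(timeout=timeout, interval=0.5)
        except CeleryTimeout:
            return "PENDING"  # caller keeps polling, as the view layer does above
    return async_result.status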
Defaults to None + include_metadata (bool): Whether to include metadata in the prompt output + + Kwargs: + log_events_id (str): Session ID of the user, + helps establish WS connection for streaming logs to the FE + + Returns: + dict[str, list[Any]]: Returns a dict with result from workflow execution + """ + task_id = current_task.request.id + # Set organization in state store for execution + StateStore.set(Account.ORGANIZATION_ID, schema_name) + return WorkflowHelper.execute_workflow( + organization_id=schema_name, + task_id=task_id, + workflow_id=workflow_id, + execution_id=execution_id, + hash_values_of_files=hash_values_of_files, + scheduled=scheduled, + execution_mode=execution_mode, + pipeline_id=pipeline_id, + include_metadata=include_metadata, + **kwargs, + ) + + @staticmethod + def execute_workflow( + organization_id: str, + task_id: str, + workflow_id: str, + execution_id: str, + hash_values_of_files: dict[str, str], + scheduled: bool = False, + execution_mode: Optional[tuple[str, str]] = None, + pipeline_id: Optional[str] = None, + include_metadata: bool = False, + **kwargs: dict[str, Any], + ) -> Optional[list[Any]]: + """Asynchronous Execution By celery. + + Args: + schema_name (str): schema name to get Data + workflow_id (str): Workflow Id + execution_id (Optional[str], optional): Id of the execution. + Defaults to None. + scheduled (bool, optional): Represents if it is a scheduled + execution. Defaults to False. + execution_mode (Optional[WorkflowExecution.Mode]): + WorkflowExecution Mode. Defaults to None. + pipeline_id (Optional[str], optional): Id of pipeline. + Defaults to None. + include_metadata (bool): Whether to include metadata in the prompt output + + Kwargs: + log_events_id (str): Session ID of the user, helps establish + WS connection for streaming logs to the FE + + Returns: + dict[str, list[Any]]: Returns a dict with result from + workflow execution + """ + workflow = Workflow.objects.get(id=workflow_id) + try: + workflow_execution = ( + WorkflowExecutionServiceHelper.create_workflow_execution( + workflow_id=workflow_id, + single_step=False, + pipeline_id=pipeline_id, + mode=WorkflowExecution.Mode.QUEUE, + execution_id=execution_id, + **kwargs, # type: ignore + ) + ) + except IntegrityError: + # Use existing instance on retry attempt + workflow_execution = WorkflowExecution.objects.get(pk=execution_id) + WorkflowExecutionServiceHelper.update_execution_task( + execution_id=execution_id, task_id=task_id + ) + result = WorkflowHelper.run_workflow( + workflow=workflow, + organization_id=organization_id, + pipeline_id=pipeline_id, + scheduled=scheduled, + workflow_execution=workflow_execution, + execution_mode=execution_mode, + hash_values_of_files=hash_values_of_files, + include_metadata=include_metadata, + ).result + return result + + @staticmethod + def complete_execution( + workflow: Workflow, + execution_id: Optional[str] = None, + pipeline_id: Optional[str] = None, + hash_values_of_files: dict[str, str] = {}, + include_metadata: bool = False, + ) -> ExecutionResponse: + if pipeline_id: + logger.info(f"Executing pipeline: {pipeline_id}") + response: ExecutionResponse = WorkflowHelper.execute_workflow_async( + workflow_id=workflow.id, + pipeline_id=pipeline_id, + execution_id=str(uuid.uuid4()), + hash_values_of_files=hash_values_of_files, + ) + return response + if execution_id is None: + # Creating execution entity and return + return WorkflowHelper.create_and_make_execution_response( + workflow_id=workflow.id, pipeline_id=pipeline_id + ) + try: + # Normal 
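# --- Illustrative sketch (not part of this diff): the "create once, reuse on retry" pattern that
# --- execute_workflow applies to WorkflowExecution, shown with sqlite3 instead of the Django ORM.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE execution (id TEXT PRIMARY KEY, status TEXT)")


def create_or_get_execution(execution_id: str) -> tuple[str, str]:
    try:
        with conn:
            conn.execute(
                "INSERT INTO execution (id, status) VALUES (?, ?)",
                (execution_id, "PENDING"),
            )
    except sqlite3.IntegrityError:
        # A Celery retry hits the same primary key; fall back to the existing row
        # instead of failing, exactly like the `except IntegrityError` branch above.
        pass
    return conn.execute(
        "SELECT id, status FROM execution WHERE id = ?", (execution_id,)
    ).fetchone()


print(create_or_get_execution("exec-1"))  # inserted
print(create_or_get_execution("exec-1"))  # reused on "retry"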
execution + workflow_execution = WorkflowExecution.objects.get(pk=execution_id) + if ( + workflow_execution.status != ExecutionStatus.PENDING.value + or workflow_execution.execution_type != WorkflowExecution.Type.COMPLETE + ): + raise InvalidRequest(WorkflowErrors.INVALID_EXECUTION_ID) + return WorkflowHelper.run_workflow( + workflow=workflow, + workflow_execution=workflow_execution, + hash_values_of_files=hash_values_of_files, + include_metadata=include_metadata, + ) + except WorkflowExecution.DoesNotExist: + return WorkflowHelper.create_and_make_execution_response( + workflow_id=workflow.id, pipeline_id=pipeline_id + ) + + @staticmethod + def get_current_execution(execution_id: str) -> ExecutionResponse: + try: + workflow_execution = WorkflowExecution.objects.get(pk=execution_id) + return ExecutionResponse( + workflow_execution.workflow_id, + workflow_execution.id, + workflow_execution.status, + log_id=workflow_execution.execution_log_id, + error=workflow_execution.error_message, + mode=workflow_execution.execution_mode, + ) + except WorkflowExecution.DoesNotExist: + raise WorkflowExecutionNotExist() + + @staticmethod + def step_execution( + workflow: Workflow, + execution_action: str, + execution_id: Optional[str] = None, + hash_values_of_files: dict[str, str] = {}, + include_metadata: bool = False, + ) -> ExecutionResponse: + if execution_action is Workflow.ExecutionAction.START.value: # type: ignore + if execution_id is None: + return WorkflowHelper.create_and_make_execution_response( + workflow_id=workflow.id, single_step=True + ) + try: + workflow_execution = WorkflowExecution.objects.get(pk=execution_id) + return WorkflowHelper.run_workflow( + workflow=workflow, + single_step=True, + workflow_execution=workflow_execution, + hash_values_of_files=hash_values_of_files, + include_metadata=include_metadata, + ) + except WorkflowExecution.DoesNotExist: + return WorkflowHelper.create_and_make_execution_response( + workflow_id=workflow.id, single_step=True + ) + + else: + if execution_id is None: + raise InvalidRequest("execution_id is missed") + try: + workflow_execution = WorkflowExecution.objects.get(pk=execution_id) + except WorkflowExecution.DoesNotExist: + raise WorkflowExecutionNotExist(WorkflowErrors.INVALID_EXECUTION_ID) + if ( + workflow_execution.status != ExecutionStatus.PENDING.value + or workflow_execution.execution_type != WorkflowExecution.Type.STEP + ): + raise InvalidRequest(WorkflowErrors.INVALID_EXECUTION_ID) + current_action: Optional[str] = CacheService.get_key(execution_id) + logger.info(f"workflow_execution.current_action {current_action}") + if current_action is None: + raise InvalidRequest(WorkflowErrors.INVALID_EXECUTION_ID) + CacheService.set_key(execution_id, execution_action) + workflow_execution = WorkflowExecution.objects.get(pk=execution_id) + + return ExecutionResponse( + workflow.id, + execution_id, + workflow_execution.status, + log_id=workflow_execution.execution_log_id, + error=workflow_execution.error_message, + mode=workflow_execution.execution_mode, + ) + + @staticmethod + def create_and_make_execution_response( + workflow_id: str, + pipeline_id: Optional[str] = None, + single_step: bool = False, + mode: tuple[str, str] = WorkflowExecution.Mode.INSTANT, + ) -> ExecutionResponse: + log_events_id = StateStore.get(Common.LOG_EVENTS_ID) + workflow_execution = WorkflowExecutionServiceHelper.create_workflow_execution( + workflow_id=workflow_id, + single_step=single_step, + pipeline_id=pipeline_id, + mode=mode, + log_events_id=log_events_id, + ) + return 
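# --- Illustrative sketch (not part of this diff): step_execution above coordinates the UI and
# --- the runner through a per-execution cache key. The in-memory dict and the action names
# --- ("START", "NEXT") are assumptions standing in for CacheService and ExecutionAction.
from typing import Optional

_step_actions: dict[str, str] = {}


def request_action(execution_id: str, action: str) -> None:
    """API side: like CacheService.set_key(execution_id, execution_action)."""
    if execution_id not in _step_actions and action != "START":
        raise ValueError("no active step execution for this id")
    _step_actions[execution_id] = action


def poll_action(execution_id: str) -> Optional[str]:
    """Runner side: like CacheService.get_key(execution_id) between tools."""
    return _step_actions.get(execution_id)


request_action("exec-1", "START")
request_action("exec-1", "NEXT")
print(poll_action("exec-1"))  # NEXT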
ExecutionResponse( + workflow_execution.workflow_id, + workflow_execution.id, + workflow_execution.status, + log_id=workflow_execution.execution_log_id, + error=workflow_execution.error_message, + mode=workflow_execution.execution_mode, + ) + + # TODO: Access cache through a manager + @staticmethod + def clear_cache(workflow_id: str) -> dict[str, Any]: + """Function to clear cache with a specific pattern.""" + response: dict[str, Any] = {} + try: + key_pattern = f"*:cache:{workflow_id}:*" + CacheService.clear_cache(key_pattern) + response["message"] = WorkflowMessages.CACHE_CLEAR_SUCCESS + response["status"] = 200 + return response + except Exception as exc: + logger.error(f"Error occurred while clearing cache : {exc}") + response["message"] = WorkflowMessages.CACHE_CLEAR_FAILED + response["status"] = 400 + return response + + @staticmethod + def clear_file_marker(workflow_id: str) -> dict[str, Any]: + """Function to clear file marker from the cache.""" + # Clear file history from the table + response: dict[str, Any] = {} + workflow = Workflow.objects.get(id=workflow_id) + try: + FileHistoryHelper.clear_history_for_workflow(workflow=workflow) + response["message"] = WorkflowMessages.FILE_MARKER_CLEAR_SUCCESS + response["status"] = 200 + return response + except Exception as exc: + logger.error(f"Error occurred while clearing file marker : {exc}") + response["message"] = WorkflowMessages.FILE_MARKER_CLEAR_FAILED + response["status"] = 400 + return response + + @staticmethod + def get_workflow_execution_id(execution_id: str) -> str: + wf_exec_prefix = WorkflowExecutionKey.WORKFLOW_EXECUTION_ID_PREFIX + workflow_execution_id = f"{wf_exec_prefix}-{execution_id}" + return workflow_execution_id + + @staticmethod + def get_execution_by_id(execution_id: str) -> WorkflowExecution: + try: + execution: WorkflowExecution = WorkflowExecution.objects.get( + id=execution_id + ) + return execution + except WorkflowExecution.DoesNotExist: + raise WorkflowDoesNotExistError() + + @staticmethod + def make_async_result(obj: AsyncResult) -> dict[str, Any]: + return { + "id": obj.id, + "status": obj.status, + "result": obj.result, + "is_ready": obj.ready(), + "is_failed": obj.failed(), + "info": obj.info, + } + + @staticmethod + def can_update_workflow(workflow_id: str) -> dict[str, Any]: + try: + workflow: Workflow = Workflow.objects.get(pk=workflow_id) + if not workflow or workflow is None: + raise WorkflowDoesNotExistError() + used_count = Pipeline.objects.filter(workflow=workflow).count() + if used_count == 0: + used_count = APIDeployment.objects.filter(workflow=workflow).count() + return {"can_update": used_count == 0} + except Workflow.DoesNotExist: + logger.error(f"Error getting workflow: {id}") + raise WorkflowDoesNotExistError() + + +class WorkflowSchemaHelper: + """Helper class for workflow schema related methods.""" + + @staticmethod + def validate_request(schema_type: SchemaType, schema_entity: SchemaEntity) -> bool: + """Validates the given args for reading the JSON schema. 
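# --- Illustrative sketch (not part of this diff): one way CacheService.clear_cache(pattern)
# --- could be realised with redis-py, using SCAN rather than KEYS so a large cache is not
# --- blocked. The Redis connection details are assumptions.
import redis


def clear_cache_by_pattern(client: redis.Redis, workflow_id: str) -> int:
    pattern = f"*:cache:{workflow_id}:*"  # same pattern as clear_cache above
    deleted = 0
    for key in client.scan_iter(match=pattern, count=500):
        deleted += client.delete(key)
    return deleted


if __name__ == "__main__":
    client = redis.Redis(host="localhost", port=6379, db=0)
    print(f"Removed {clear_cache_by_pattern(client, 'wf-123')} cached entries")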
+
+        Schema type of `src` allows entities `file` and `api`;
+        schema type of `dest` allows entity `db`.
+
+        Args:
+            schema_type (SchemaType): Enum with values `src`, `dest`
+            schema_entity (SchemaEntity): Enum with values `file`, `api`, `db`
+
+        Raises:
+            serializers.ValidationError: If an invalid value or
+                combination of values is requested
+
+        Returns:
+            bool: True if the type-entity combination is valid
+        """
+        possible_types = [e.value for e in SchemaType]
+        possible_entities = [e.value for e in SchemaEntity]
+
+        if schema_type.value not in possible_types:
+            raise serializers.ValidationError(
+                f"Invalid value for 'type': {schema_type.value}, "
+                f"should be one of {possible_types}"
+            )
+
+        if schema_entity.value not in possible_entities:
+            raise serializers.ValidationError(
+                f"Invalid value for 'entity': {schema_entity.value}, "
+                f"should be one of {possible_entities}"
+            )
+
+        if (schema_type == SchemaType.SRC and schema_entity == SchemaEntity.DB) or (
+            schema_type == SchemaType.DEST and schema_entity != SchemaEntity.DB
+        ):
+            raise serializers.ValidationError(
+                f"Invalid combination of 'type': {schema_type.value} and "
+                f"'entity': {schema_entity.value}. "
+                f"'type': {SchemaType.SRC.value} allows "
+                f"{SchemaEntity.FILE.value} and {SchemaEntity.API.value}; "
+                f"'type': {SchemaType.DEST.value} allows "
+                f"{SchemaEntity.DB.value}."
+            )
+        return True
+
+    @staticmethod
+    def get_json_schema(
+        schema_type: SchemaType, schema_entity: SchemaEntity
+    ) -> dict[str, Any]:
+        """Reads and returns the JSON schema for the given args.
+
+        Args:
+            schema_type (SchemaType): Enum with values `src`, `dest`
+            schema_entity (SchemaEntity): Enum with values `file`, `api`, `db`
+
+        Returns:
+            dict[str, Any]: JSON schema for the requested entity
+        """
+        schema_path = (
+            f"{os.path.dirname(__file__)}/static/{schema_type}/{schema_entity}.json"
+        )
+        with open(schema_path, encoding="utf-8") as file:
+            schema = json.load(file)
+        return schema  # type: ignore
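# --- Illustrative sketch (not part of this diff): the validation matrix enforced by
# --- WorkflowSchemaHelper.validate_request, reduced to plain enums so it runs without Django.
from enum import Enum


class SchemaType(Enum):
    SRC = "src"
    DEST = "dest"


class SchemaEntity(Enum):
    FILE = "file"
    API = "api"
    DB = "db"


ALLOWED = {
    SchemaType.SRC: {SchemaEntity.FILE, SchemaEntity.API},
    SchemaType.DEST: {SchemaEntity.DB},
}


def is_valid_request(schema_type: SchemaType, schema_entity: SchemaEntity) -> bool:
    return schema_entity in ALLOWED[schema_type]


print(is_valid_request(SchemaType.SRC, SchemaEntity.FILE))   # True
print(is_valid_request(SchemaType.DEST, SchemaEntity.FILE))  # False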