diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index deb24de93..3a8eddfa7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -140,7 +140,7 @@ repos: - id: markdownlint-fix args: [--disable, MD013] - repo: https://github.com/pdm-project/pdm - rev: 2.15.4 + rev: 2.16.1 hooks: - id: pdm-lock-check # - repo: local diff --git a/CONTRIBUTE.md b/CONTRIBUTE.md index fda4aa395..762700414 100644 --- a/CONTRIBUTE.md +++ b/CONTRIBUTE.md @@ -15,7 +15,7 @@ Use LLMs to eliminate manual processes involving unstructured data. Just run the `run-platform.sh` launch script to get started in few minutes. -The launch script does env setup with default values, pulls public docker images or builds them locally and finally runs them in containers. +The launch script configures the env with sane defaults, pulls public docker images or builds them locally and finally runs them in containers. ```bash # Pull and run entire Unstract platform with default env config. @@ -45,6 +45,7 @@ The launch script does env setup with default values, pulls public docker images Now visit [http://frontend.unstract.localhost](http://frontend.unstract.localhost) in your browser. +NOTE: Modify the `.env` files present in each service folder to update its runtime behaviour. Run docker compose up again for the changes to take effect.``` That's all. Enjoy! ## Authentication diff --git a/backend/account_v2/ReadMe.md b/backend/account_v2/ReadMe.md new file mode 100644 index 000000000..35695cfb5 --- /dev/null +++ b/backend/account_v2/ReadMe.md @@ -0,0 +1,26 @@ +# Basic WorkFlow + +`We can Add Workflows Here` + +## Login + +### Step + +1. Login +2. Get Organizations +3. Set Organization +4. Use organizational APIs /unstract// + +## Switch organization + +1. Get Organizations +2. Set Organization +3. 
Use organizational APIs /unstract// + +## Get current user and Organization data + +- Use Get User Profile and Get Organization Info APIs + +## Signout + +1.signout APi diff --git a/backend/account_v2/__init__.py b/backend/account_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/account_v2/admin.py b/backend/account_v2/admin.py new file mode 100644 index 000000000..e0b96cce8 --- /dev/null +++ b/backend/account_v2/admin.py @@ -0,0 +1,5 @@ +from django.contrib import admin + +from .models import Organization, User + +admin.site.register([Organization, User]) diff --git a/backend/account_v2/apps.py b/backend/account_v2/apps.py new file mode 100644 index 000000000..65343132c --- /dev/null +++ b/backend/account_v2/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class AccountConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "account_v2" diff --git a/backend/account_v2/authentication_controller.py b/backend/account_v2/authentication_controller.py new file mode 100644 index 000000000..0aec22294 --- /dev/null +++ b/backend/account_v2/authentication_controller.py @@ -0,0 +1,489 @@ +import logging +from typing import Any, Optional, Union + +from account_v2.authentication_helper import AuthenticationHelper +from account_v2.authentication_plugin_registry import AuthenticationPluginRegistry +from account_v2.authentication_service import AuthenticationService +from account_v2.constants import ( + AuthorizationErrorCode, + Common, + Cookie, + ErrorMessage, + OrganizationMemberModel, +) +from account_v2.custom_exceptions import ( + DuplicateData, + Forbidden, + MethodNotImplemented, + UserNotExistError, +) +from account_v2.dto import ( + MemberInvitation, + OrganizationData, + UserInfo, + UserInviteResponse, + UserRoleData, +) +from account_v2.exceptions import OrganizationNotExist +from account_v2.models import Organization, User +from account_v2.organization import OrganizationService +from 
account_v2.serializer import ( + GetOrganizationsResponseSerializer, + OrganizationSerializer, + SetOrganizationsResponseSerializer, +) +from account_v2.user import UserService +from django.conf import settings +from django.contrib.auth import logout as django_logout +from django.db.utils import IntegrityError +from django.middleware import csrf +from django.shortcuts import redirect +from rest_framework import status +from rest_framework.request import Request +from rest_framework.response import Response +from tenant_account_v2.models import OrganizationMember as OrganizationMember +from tenant_account_v2.organization_member_service import OrganizationMemberService +from utils.cache_service import CacheService +from utils.local_context import StateStore +from utils.user_context import UserContext +from utils.user_session import UserSessionUtils + +logger = logging.getLogger(__name__) + + +class AuthenticationController: + """Authentication Controller This controller class manages user + authentication processes.""" + + def __init__(self) -> None: + """This method initializes the controller by selecting the appropriate + authentication plugin based on availability.""" + self.authentication_helper = AuthenticationHelper() + if AuthenticationPluginRegistry.is_plugin_available(): + self.auth_service: AuthenticationService = ( + AuthenticationPluginRegistry.get_plugin() + ) + else: + self.auth_service = AuthenticationService() + + def user_login( + self, + request: Request, + ) -> Any: + return self.auth_service.user_login(request) + + def user_signup(self, request: Request) -> Any: + return self.auth_service.user_signup(request) + + def authorization_callback( + self, request: Request, backend: str = settings.DEFAULT_MODEL_BACKEND + ) -> Any: + """Handle authorization callback. + + This function processes the authorization callback from + an external service. + + Args: + request (Request): Request instance + backend (str, optional): backend used to use login. 
+ Defaults: settings.DEFAULT_MODEL_BACKEND. + + Returns: + Any: Redirect response + """ + try: + return self.auth_service.handle_authorization_callback( + request=request, backend=backend + ) + except Exception as ex: + logger.error(f"Error while handling authorization callback: {ex}") + return redirect(f"{settings.ERROR_URL}") + + def user_organizations(self, request: Request) -> Any: + """List a user's organizations. + + Args: + user (User): User instance + z_code (str): _description_ + + Returns: + list[OrganizationData]: _description_ + """ + + try: + organizations = self.auth_service.user_organizations(request) + except Exception as ex: + # + self.user_logout(request) + + response = Response( + status=status.HTTP_412_PRECONDITION_FAILED, + ) + if hasattr(ex, "code") and ex.code in { + AuthorizationErrorCode.USF, + AuthorizationErrorCode.USR, + AuthorizationErrorCode.INE001, + AuthorizationErrorCode.INE002, + }: # type: ignore + response.data = ({"domain": ex.data.get("domain"), "code": ex.code},) + return response + # Return in case even if missed unknown exception in + # self.auth_service.user_organizations(request) + return response + + user: User = request.user + org_ids = {org.id for org in organizations} + + CacheService.set_user_organizations(user.user_id, list(org_ids)) + + serialized_organizations = GetOrganizationsResponseSerializer( + organizations, many=True + ).data + response = Response( + status=status.HTTP_200_OK, + data={ + "message": "success", + "organizations": serialized_organizations, + }, + ) + if Cookie.CSRFTOKEN not in request.COOKIES: + csrf_token = csrf.get_token(request) + response.set_cookie(Cookie.CSRFTOKEN, csrf_token) + + return response + + def set_user_organization(self, request: Request, organization_id: str) -> Response: + user: User = request.user + new_organization = False + organization_ids = CacheService.get_user_organizations(user.user_id) + if not organization_ids: + z_organizations: list[OrganizationData] = ( + 
self.auth_service.get_organizations_by_user_id(user.user_id) + ) + organization_ids = {org.id for org in z_organizations} + if organization_id and organization_id in organization_ids: + # Set organization in user context + UserContext.set_organization_identifier(organization_id) + organization = OrganizationService.get_organization_by_org_id( + organization_id + ) + if not organization: + try: + organization_data: OrganizationData = ( + self.auth_service.get_organization_by_org_id(organization_id) + ) + except ValueError: + raise OrganizationNotExist() + try: + organization = OrganizationService.create_organization( + organization_data.name, + organization_data.display_name, + organization_data.id, + ) + new_organization = True + except IntegrityError: + raise DuplicateData( + f"{ErrorMessage.ORGANIZATION_EXIST}, \ + {ErrorMessage.DUPLICATE_API}" + ) + self.create_tenant_user(organization=organization, user=user) + + if new_organization: + try: + self.auth_service.hubspot_signup_api(request=request) + except MethodNotImplemented: + logger.info("hubspot_signup_api not implemented") + + try: + self.auth_service.frictionless_onboarding( + organization=organization, user=user + ) + except MethodNotImplemented: + logger.info("frictionless_onboarding not implemented") + + self.authentication_helper.create_initial_platform_key( + user=user, organization=organization + ) + logger.info( + f"New organization created with Id {organization_id}", + ) + + user_info: Optional[UserInfo] = self.get_user_info(request) + serialized_user_info = SetOrganizationsResponseSerializer(user_info).data + organization_info = OrganizationSerializer(organization).data + response: Response = Response( + status=status.HTTP_200_OK, + data={ + "user": serialized_user_info, + "organization": organization_info, + f"{Common.LOG_EVENTS_ID}": StateStore.get(Common.LOG_EVENTS_ID), + }, + ) + current_organization_id = UserSessionUtils.get_organization_id(request) + if current_organization_id: + 
OrganizationMemberService.remove_user_membership_in_organization_cache( + user_id=user.user_id, + organization_id=current_organization_id, + ) + UserSessionUtils.set_organization_id(request, organization_id) + OrganizationMemberService.set_user_membership_in_organization_cache( + user_id=user.user_id, organization_id=organization_id + ) + return response + return Response(status=status.HTTP_403_FORBIDDEN) + + def get_user_info(self, request: Request) -> Optional[UserInfo]: + return self.auth_service.get_user_info(request) + + def is_admin_by_role(self, role: str) -> bool: + """Check the role is act as admin in the context of authentication + plugin. + + Args: + role (str): role + + Returns: + bool: _description_ + """ + return self.auth_service.is_admin_by_role(role=role) + + def get_organization_info(self, org_id: str) -> Optional[Organization]: + organization = OrganizationService.get_organization_by_org_id(org_id=org_id) + return organization + + def make_organization_and_add_member( + self, + user_id: str, + user_name: str, + organization_name: Optional[str] = None, + display_name: Optional[str] = None, + ) -> Optional[OrganizationData]: + return self.auth_service.make_organization_and_add_member( + user_id, user_name, organization_name, display_name + ) + + def make_user_organization_name(self) -> str: + return self.auth_service.make_user_organization_name() + + def make_user_organization_display_name(self, user_name: str) -> str: + return self.auth_service.make_user_organization_display_name(user_name) + + def user_logout(self, request: Request) -> Response: + response = self.auth_service.user_logout(request=request) + organization_id = UserSessionUtils.get_organization_id(request) + user_id = UserSessionUtils.get_user_id(request) + if organization_id: + OrganizationMemberService.remove_user_membership_in_organization_cache( + user_id=user_id, organization_id=organization_id + ) + django_logout(request) + return response + + def 
get_organization_members_by_org_id( + self, organization_id: Optional[str] = None + ) -> list[OrganizationMember]: + members: list[OrganizationMember] = OrganizationMemberService.get_members() + return members + + def get_organization_members_by_user(self, user: User) -> OrganizationMember: + """Get organization member by user. This method will return + organization member object for given user. + + Args: + user (User): UserEntity + + Returns: + OrganizationMember: OrganizationMemberEntity + """ + member: OrganizationMember = OrganizationMemberService.get_user_by_id( + id=user.id + ) + return member + + def get_user_roles(self) -> list[UserRoleData]: + return self.auth_service.get_roles() + + def get_user_invitations(self, organization_id: str) -> list[MemberInvitation]: + return self.auth_service.get_invitations(organization_id=organization_id) + + def delete_user_invitation(self, organization_id: str, invitation_id: str) -> bool: + return self.auth_service.delete_invitation( + organization_id=organization_id, invitation_id=invitation_id + ) + + def reset_user_password(self, user: User) -> Response: + return self.auth_service.reset_user_password(user) + + def invite_user( + self, + admin: User, + org_id: str, + user_list: list[dict[str, Union[str, None]]], + ) -> list[UserInviteResponse]: + """Invites users to join an organization. + + Args: + admin (User): Admin user initiating the invitation. + org_id (str): ID of the organization to which users are invited. + user_list (list[dict[str, Union[str, None]]]): + List of user details for invitation. + Returns: + list[UserInviteResponse]: List of responses for each + user invitation. 
+ """ + admin_user = OrganizationMemberService.get_user_by_id(id=admin.id) + if not self.auth_service.is_organization_admin(admin_user): + raise Forbidden() + response = [] + for user_item in user_list: + email = user_item.get("email") + role = user_item.get("role") + if email: + user = OrganizationMemberService.get_user_by_email(email=email) + user_response = {} + user_response["email"] = email + status = False + message = "User is already part of current organization" + # Check if user is already part of current organization + if not user: + status = self.auth_service.invite_user( + admin_user, org_id, email, role=role + ) + message = "User invitation successful." + + response.append( + UserInviteResponse( + email=email, + status="success" if status else "failed", + message=message, + ) + ) + return response + + def remove_users_from_organization( + self, admin: User, organization_id: str, user_emails: list[str] + ) -> bool: + admin_user = OrganizationMemberService.get_user_by_id(id=admin.id) + user_ids = OrganizationMemberService.get_members_by_user_email( + user_emails=user_emails, + values_list_fields=[ + OrganizationMemberModel.USER_ID, + OrganizationMemberModel.ID, + ], + ) + user_ids_list: list[str] = [] + pk_list: list[str] = [] + for user in user_ids: + user_ids_list.append(user[0]) + pk_list.append(user[1]) + if len(user_ids_list) > 0: + is_removed = self.auth_service.remove_users_from_organization( + admin=admin_user, + organization_id=organization_id, + user_ids=user_ids_list, + ) + else: + is_removed = False + if is_removed: + AuthenticationHelper.remove_users_from_organization_by_pks(pk_list) + for user_id in user_ids_list: + OrganizationMemberService.remove_user_membership_in_organization_cache( + user_id, organization_id + ) + + return is_removed + + def add_user_role( + self, admin: User, org_id: str, email: str, role: str + ) -> Optional[str]: + admin_user = OrganizationMemberService.get_user_by_id(id=admin.id) + user = 
OrganizationMemberService.get_user_by_email(email=email) + if user: + current_roles = self.auth_service.add_organization_user_role( + admin_user, org_id, user.user.user_id, [role] + ) + if current_roles: + self.save_organization_user_role( + user_uid=user.user.user.id, role=current_roles[0] + ) + return current_roles[0] + else: + return None + + def remove_user_role( + self, admin: User, org_id: str, email: str, role: str + ) -> Optional[str]: + admin_user = OrganizationMemberService.get_user_by_id(id=admin.id) + organization_member = OrganizationMemberService.get_user_by_email(email=email) + if organization_member: + current_roles = self.auth_service.remove_organization_user_role( + admin_user, org_id, organization_member.user.user_id, [role] + ) + if current_roles: + self.save_organization_user_role( + user_uid=organization_member.user.id, + role=current_roles[0], + ) + return current_roles[0] + else: + return None + + def save_organization_user_role(self, user_uid: str, role: str) -> None: + organization_user = OrganizationMemberService.get_user_by_id(id=user_uid) + if organization_user: + # consider single role + organization_user.role = role + organization_user.save() + + def create_tenant_user(self, organization: Organization, user: User) -> None: + existing_tenant_user = OrganizationMemberService.get_user_by_id(id=user.id) + + if existing_tenant_user: + return None + + account_user = self.get_or_create_user(user=user) + if not account_user: + raise UserNotExistError() + + logger.info(f"Creating account for {user.email}") + user_roles = self.auth_service.get_organization_role_of_user( + user_id=account_user.user_id, + organization_id=organization.organization_id, + ) + user_role = user_roles[0] + try: + tenant_user: OrganizationMember = OrganizationMember( + user=user, + role=user_role, + is_login_onboarding_msg=False, + is_prompt_studio_onboarding_msg=False, + ) + tenant_user.save() + logger.info( + f"{tenant_user.user.email} added in to the organization " + 
f"{organization.organization_id}" + ) + except IntegrityError: + logger.warning(f"Account already exists for {user.email}") + + def get_or_create_user( + self, user: User + ) -> Optional[Union[User, OrganizationMember]]: + user_service = UserService() + if user.id: + account_user: Optional[User] = user_service.get_user_by_id(user.id) + if account_user: + return account_user + elif user.email: + account_user = user_service.get_user_by_email(email=user.email) + if account_user: + return account_user + if user.user_id: + user.save() + return user + elif user.email and user.user_id: + account_user = user_service.create_user( + email=user.email, user_id=user.user_id + ) + return account_user + return None diff --git a/backend/account_v2/authentication_helper.py b/backend/account_v2/authentication_helper.py new file mode 100644 index 000000000..e303445fc --- /dev/null +++ b/backend/account_v2/authentication_helper.py @@ -0,0 +1,121 @@ +import logging +from typing import Any + +from account_v2.dto import MemberData +from account_v2.models import Organization, User +from account_v2.user import UserService +from platform_settings_v2.platform_auth_service import PlatformAuthenticationService +from tenant_account_v2.organization_member_service import OrganizationMemberService + +logger = logging.getLogger(__name__) + + +class AuthenticationHelper: + def __init__(self) -> None: + pass + + def list_of_members_from_user_model( + self, model_data: list[Any] + ) -> list[MemberData]: + members: list[MemberData] = [] + for data in model_data: + user_id = data.user_id + email = data.email + name = data.username + + members.append(MemberData(user_id=user_id, email=email, name=name)) + + return members + + @staticmethod + def get_or_create_user_by_email(user_id: str, email: str) -> User: + """Get or create a user with the given email. + + If a user with the given email already exists, return that user. + Otherwise, create a new user with the given email and return it. 
+ + Parameters: + user_id (str): The ID of the user. + email (str): The email of the user. + + Returns: + User: The user with the given email. + """ + user_service = UserService() + user = user_service.get_user_by_email(email) + if user and not user.user_id: + user = user_service.update_user(user, user_id) + if not user: + user = user_service.create_user(email, user_id) + return user + + def create_initial_platform_key( + self, user: User, organization: Organization + ) -> None: + """Create an initial platform key for the given user and organization. + + This method generates a new platform key with the specified parameters + and saves it to the database. The generated key is set as active and + assigned the name "Key #1". The key is associated with the provided + user and organization. + + Parameters: + user (User): The user for whom the platform key is being created. + organization (Organization): + The organization to which the platform key belongs. + + Raises: + Exception: If an error occurs while generating the platform key. + + Returns: + None + """ + try: + PlatformAuthenticationService.generate_platform_key( + is_active=True, + key_name="Key #1", + user=user, + organization=organization, + ) + except Exception: + logger.error( + "Failed to create default platform key for " + f"organization {organization.organization_id}" + ) + + @staticmethod + def remove_users_from_organization_by_pks( + user_pks: list[str], + ) -> None: + """Remove users from an organization by their primary keys. + + Parameters: + user_pks (list[str]): The primary keys of the users to remove. 
+ """ + # removing user from organization + OrganizationMemberService.remove_users_by_user_pks(user_pks) + # removing user m2m relations , while removing user + for user_pk in user_pks: + User.objects.get(pk=user_pk).prompt_registries.clear() + User.objects.get(pk=user_pk).shared_custom_tools.clear() + User.objects.get(pk=user_pk).shared_adapters_instance.clear() + + @staticmethod + def remove_user_from_organization_by_user_id( + user_id: str, organization_id: str + ) -> None: + """Remove users from an organization by their user_id. + + Parameters: + user_id (str): The user_id of the users to remove. + """ + # removing user from organization + OrganizationMemberService.remove_user_by_user_id(user_id) + # removing user m2m relations , while removing user + User.objects.get(user_id=user_id).prompt_registries.clear() + User.objects.get(user_id=user_id).shared_custom_tools.clear() + User.objects.get(user_id=user_id).shared_adapters_instance.clear() + # removing user from organization cache + OrganizationMemberService.remove_user_membership_in_organization_cache( + user_id=user_id, organization_id=organization_id + ) diff --git a/backend/account_v2/authentication_plugin_registry.py b/backend/account_v2/authentication_plugin_registry.py new file mode 100644 index 000000000..4521af1cb --- /dev/null +++ b/backend/account_v2/authentication_plugin_registry.py @@ -0,0 +1,96 @@ +import logging +import os +from importlib import import_module +from typing import Any + +from account_v2.constants import PluginConfig +from django.apps import apps + +logger = logging.getLogger(__name__) + + +def _load_plugins() -> dict[str, dict[str, Any]]: + """Iterating through the Authentication plugins and register their + metadata.""" + auth_app = apps.get_app_config(PluginConfig.PLUGINS_APP) + auth_package_path = auth_app.module.__package__ + auth_dir = os.path.join(auth_app.path, PluginConfig.AUTH_PLUGIN_DIR) + auth_package_path = f"{auth_package_path}.{PluginConfig.AUTH_PLUGIN_DIR}" + 
auth_modules = {} + + for item in os.listdir(auth_dir): + # Loads a plugin only if name starts with `auth`. + if not item.startswith(PluginConfig.AUTH_MODULE_PREFIX): + continue + # Loads a plugin if it is in a directory. + if os.path.isdir(os.path.join(auth_dir, item)): + auth_module_name = item + # Loads a plugin if it is a shared library. + # Module name is extracted from shared library name. + # `auth.platform_architecture.so` will be file name and + # `auth` will be the module name. + elif item.endswith(".so"): + auth_module_name = item.split(".")[0] + else: + continue + try: + full_module_path = f"{auth_package_path}.{auth_module_name}" + module = import_module(full_module_path) + metadata = getattr(module, PluginConfig.AUTH_METADATA, {}) + if metadata.get(PluginConfig.METADATA_IS_ACTIVE, False): + auth_modules[auth_module_name] = { + PluginConfig.AUTH_MODULE: module, + PluginConfig.AUTH_METADATA: module.metadata, + } + logger.info( + "Loaded auth plugin: %s, is_active: %s", + module.metadata["name"], + module.metadata["is_active"], + ) + else: + logger.warning( + "Metadata is not active for %s authentication module.", + auth_module_name, + ) + except ModuleNotFoundError as exception: + logger.error( + "Error while importing authentication module : %s", + exception, + ) + + if len(auth_modules) > 1: + raise ValueError( + "Multiple authentication modules found." + "Only one authentication method is allowed." + ) + elif len(auth_modules) == 0: + logger.warning( + "No authentication modules found." + "Application will start without authentication module" + ) + return auth_modules + + +class AuthenticationPluginRegistry: + auth_modules: dict[str, dict[str, Any]] = _load_plugins() + + @classmethod + def is_plugin_available(cls) -> bool: + """Check if any authentication plugin is available. + + Returns: + bool: True if a plugin is available, False otherwise. 
+ """ + return len(cls.auth_modules) > 0 + + @classmethod + def get_plugin(cls) -> Any: + """Get the selected authentication plugin. + + Returns: + AuthenticationService: Selected authentication plugin instance. + """ + chosen_auth_module = next(iter(cls.auth_modules.values())) + chosen_metadata = chosen_auth_module[PluginConfig.AUTH_METADATA] + service_class_name = chosen_metadata[PluginConfig.METADATA_SERVICE_CLASS] + return service_class_name() diff --git a/backend/account_v2/authentication_service.py b/backend/account_v2/authentication_service.py new file mode 100644 index 000000000..273ce4fdb --- /dev/null +++ b/backend/account_v2/authentication_service.py @@ -0,0 +1,395 @@ +import logging +import uuid +from typing import Any, Optional + +from account_v2.authentication_helper import AuthenticationHelper +from account_v2.constants import DefaultOrg, ErrorMessage, UserLoginTemplate +from account_v2.custom_exceptions import Forbidden, MethodNotImplemented +from account_v2.dto import ( + CallbackData, + MemberData, + MemberInvitation, + OrganizationData, + ResetUserPasswordDto, + UserInfo, + UserRoleData, +) +from account_v2.enums import UserRole +from account_v2.models import Organization, User +from account_v2.organization import OrganizationService +from account_v2.serializer import LoginRequestSerializer +from django.conf import settings +from django.contrib.auth import authenticate, login, logout +from django.contrib.auth.hashers import make_password +from django.http import HttpRequest +from django.shortcuts import redirect, render +from rest_framework.request import Request +from rest_framework.response import Response +from tenant_account_v2.models import OrganizationMember as OrganizationMember +from tenant_account_v2.organization_member_service import OrganizationMemberService + +Logger = logging.getLogger(__name__) + + +class AuthenticationService: + def __init__(self) -> None: + self.authentication_helper = AuthenticationHelper() + 
self.default_organization: Organization = self.user_organization() + + def user_login(self, request: Request) -> Any: + """Authenticate and log in a user. + + Args: + request (Request): The HTTP request object. + + Returns: + Any: The response object. + + Raises: + ValueError: If there is an error in the login credentials. + """ + if request.method == "GET": + return self.render_login_page(request) + try: + validated_data = self.validate_login_credentials(request) + username = validated_data.get("username") + password = validated_data.get("password") + except ValueError as e: + return render( + request, + UserLoginTemplate.TEMPLATE, + {UserLoginTemplate.ERROR_PLACE_HOLDER: str(e)}, + ) + if self.authenticate_and_login(request, username, password): + return redirect(settings.WEB_APP_ORIGIN_URL) + + return self.render_login_page_with_error(request, ErrorMessage.USER_LOGIN_ERROR) + + def is_authenticated(self, request: HttpRequest) -> bool: + """Check if the user is authenticated. + + Args: + request (Request): The HTTP request object. + + Returns: + bool: True if the user is authenticated, False otherwise. + """ + return request.user.is_authenticated + + def authenticate_and_login( + self, request: Request, username: str, password: str + ) -> bool: + """Authenticate and log in a user. + + Args: + request (Request): The HTTP request object. + username (str): The username of the user. + password (str): The password of the user. + + Returns: + bool: True if the user is successfully authenticated and logged in, + False otherwise. 
+ """ + user = authenticate(request, username=username, password=password) + if user: + # To avoid conflicts with django superuser + if user.is_superuser: + return False + login(request, user) + return True + # Attempt to initiate default user and authenticate again + if self.set_default_user(username, password): + user = authenticate(request, username=username, password=password) + if user: + login(request, user) + return True + return False + + def render_login_page(self, request: Request) -> Any: + return render(request, UserLoginTemplate.TEMPLATE) + + def render_login_page_with_error(self, request: Request, error_message: str) -> Any: + return render( + request, + UserLoginTemplate.TEMPLATE, + {UserLoginTemplate.ERROR_PLACE_HOLDER: error_message}, + ) + + def validate_login_credentials(self, request: Request) -> Any: + """Validate the login credentials. + + Args: + request (Request): The HTTP request object. + + Returns: + dict: The validated login credentials. + + Raises: + ValueError: If the login credentials are invalid. + """ + serializer = LoginRequestSerializer(data=request.POST) + if not serializer.is_valid(): + error_messages = { + field: errors[0] for field, errors in serializer.errors.items() + } + first_error_message = list(error_messages.values())[0] + raise ValueError(first_error_message) + return serializer.validated_data + + def user_signup(self, request: HttpRequest) -> Any: + raise MethodNotImplemented() + + def is_admin_by_role(self, role: str) -> bool: + """Check the role with actual admin Role. 
+ + Args: + role (str): input string + + Returns: + bool: _description_ + """ + try: + return UserRole(role.lower()) == UserRole.ADMIN + except ValueError: + return False + + def get_callback_data(self, request: Request) -> CallbackData: + return CallbackData( + user_id=request.user.user_id, + email=request.user.email, + token="", + ) + + def user_organization(self) -> Organization: + return Organization( + name=DefaultOrg.ORGANIZATION_NAME, + display_name=DefaultOrg.ORGANIZATION_NAME, + organization_id=DefaultOrg.ORGANIZATION_NAME, + schema_name=DefaultOrg.ORGANIZATION_NAME, + ) + + def handle_invited_user_while_callback( + self, request: Request, user: User + ) -> MemberData: + member_data: MemberData = MemberData( + user_id=user.user_id, + organization_id=self.default_organization.organization_id, + role=[UserRole.ADMIN.value], + ) + + return member_data + + def handle_authorization_callback(self, request: Request, backend: str) -> Response: + raise MethodNotImplemented() + + def add_to_organization( + self, + request: Request, + user: User, + data: Optional[dict[str, Any]] = None, + ) -> MemberData: + member_data: MemberData = MemberData( + user_id=user.user_id, + organization_id=self.default_organization.organization_id, + ) + + return member_data + + def remove_users_from_organization( + self, + admin: OrganizationMember, + organization_id: str, + user_ids: list[str], + ) -> bool: + raise MethodNotImplemented() + + def user_organizations(self, request: Request) -> list[OrganizationData]: + organizationData: OrganizationData = OrganizationData( + id=self.default_organization.organization_id, + display_name=self.default_organization.display_name, + name=self.default_organization.name, + ) + return [organizationData] + + def get_organizations_by_user_id(self, id: str) -> list[OrganizationData]: + organizationData: OrganizationData = OrganizationData( + id=self.default_organization.organization_id, + display_name=self.default_organization.display_name, + 
name=self.default_organization.name, + ) + return [organizationData] + + def get_organization_role_of_user( + self, user_id: str, organization_id: str + ) -> list[str]: + return [UserRole.ADMIN.value] + + def is_organization_admin(self, member: OrganizationMember) -> bool: + """Check if the organization member has administrative privileges. + + Args: + member (OrganizationMember): The organization member to check. + + Returns: + bool: True if the user has administrative privileges, + False otherwise. + """ + try: + return UserRole(member.role) == UserRole.ADMIN + except ValueError: + return False + + def check_user_organization_association(self, user_email: str) -> None: + """Check if the user is already associated with any organizations. + + Raises: + - UserAlreadyAssociatedException: + If the user is already associated with organizations. + """ + return None + + def get_roles(self) -> list[UserRoleData]: + return [ + UserRoleData(name=UserRole.ADMIN.value), + UserRoleData(name=UserRole.USER.value), + ] + + def get_invitations(self, organization_id: str) -> list[MemberInvitation]: + raise MethodNotImplemented() + + def frictionless_onboarding(self, organization: Organization, user: User) -> None: + raise MethodNotImplemented() + + def hubspot_signup_api(self, request: Request) -> None: + raise MethodNotImplemented() + + def delete_invitation(self, organization_id: str, invitation_id: str) -> bool: + raise MethodNotImplemented() + + def add_organization_user_role( + self, + admin: User, + organization_id: str, + user_id: str, + role_ids: list[str], + ) -> list[str]: + if admin.role == UserRole.ADMIN.value: + return role_ids + raise Forbidden + + def remove_organization_user_role( + self, + admin: User, + organization_id: str, + user_id: str, + role_ids: list[str], + ) -> list[str]: + if admin.role == UserRole.ADMIN.value: + return role_ids + raise Forbidden + + def get_organization_by_org_id(self, id: str) -> OrganizationData: + organizationData: OrganizationData = 
OrganizationData( + id=DefaultOrg.ORGANIZATION_NAME, + display_name=DefaultOrg.ORGANIZATION_NAME, + name=DefaultOrg.ORGANIZATION_NAME, + ) + return organizationData + + def set_default_user(self, username: str, password: str) -> bool: + """Set the default user for authentication. + + This method creates a default user with the provided username and + password if the username and password match the default values defined + in the 'DefaultOrg' class. The default user is saved in the database. + + Args: + username (str): The username of the default user. + password (str): The password of the default user. + + Returns: + bool: True if the default user is successfully created and saved, + False otherwise. + """ + if ( + username != DefaultOrg.MOCK_USER + or password != DefaultOrg.MOCK_USER_PASSWORD + ): + return False + + user, created = User.objects.get_or_create(username=DefaultOrg.MOCK_USER) + if created: + user.password = make_password(DefaultOrg.MOCK_USER_PASSWORD) + else: + user.user_id = DefaultOrg.MOCK_USER_ID + user.email = DefaultOrg.MOCK_USER_EMAIL + user.password = make_password(DefaultOrg.MOCK_USER_PASSWORD) + user.save() + return True + + def get_user_info(self, request: Request) -> Optional[UserInfo]: + user: User = request.user + if user: + return UserInfo( + id=user.id, + user_id=user.user_id, + name=user.username, + display_name=user.username, + email=user.email, + ) + else: + return None + + def get_organization_info(self, org_id: str) -> Optional[Organization]: + return OrganizationService.get_organization_by_org_id(org_id=org_id) + + def make_organization_and_add_member( + self, + user_id: str, + user_name: str, + organization_name: Optional[str] = None, + display_name: Optional[str] = None, + ) -> Optional[OrganizationData]: + organization: OrganizationData = OrganizationData( + id=str(uuid.uuid4()), + display_name=DefaultOrg.MOCK_ORG, + name=DefaultOrg.MOCK_ORG, + ) + return organization + + def make_user_organization_name(self) -> str: + return 
str(uuid.uuid4()) + + def make_user_organization_display_name(self, user_name: str) -> str: + name = f"{user_name}'s" if user_name else "Your" + return f"{name} organization" + + def user_logout(self, request: HttpRequest) -> Response: + """Log out the user. + + Args: + request (HttpRequest): The HTTP request object. + + Returns: + Response: The redirect response to the web app origin URL. + """ + logout(request) + return redirect(settings.WEB_APP_ORIGIN_URL) + + def get_organization_members_by_org_id( + self, organization_id: str + ) -> list[MemberData]: + users: list[OrganizationMember] = OrganizationMemberService.get_members() + return self.authentication_helper.list_of_members_from_user_model(users) + + def reset_user_password(self, user: User) -> ResetUserPasswordDto: + raise MethodNotImplemented() + + def invite_user( + self, + admin: OrganizationMember, + org_id: str, + email: str, + role: Optional[str] = None, + ) -> bool: + raise MethodNotImplemented() diff --git a/backend/account_v2/constants.py b/backend/account_v2/constants.py new file mode 100644 index 000000000..15c113790 --- /dev/null +++ b/backend/account_v2/constants.py @@ -0,0 +1,89 @@ +from django.conf import settings + + +class LoginConstant: + INVITATION = "invitation" + ORGANIZATION = "organization" + ORGANIZATION_NAME = "organization_name" + + +class Common: + NEXT_URL_VARIABLE = "next" + PUBLIC_SCHEMA_NAME = "public" + ID = "id" + USER_ID = "user_id" + USER_EMAIL = "email" + USER_EMAILS = "emails" + USER_IDS = "user_ids" + USER_ROLE = "role" + MAX_EMAIL_IN_REQUEST = 10 + LOG_EVENTS_ID = "log_events_id" + + +class UserModel: + USER_ID = "user_id" + ID = "id" + + +class OrganizationMemberModel: + USER_ID = "user__user_id" + ID = "user__id" + + +class Cookie: + ORG_ID = "org_id" + Z_CODE = "z_code" + CSRFTOKEN = "csrftoken" + + +class ErrorMessage: + ORGANIZATION_EXIST = "Organization already exists" + DUPLICATE_API = "It appears that a duplicate call may have been made." 
+ USER_LOGIN_ERROR = "Invalid username or password. Please try again." + + +class DefaultOrg: + ORGANIZATION_NAME = "mock_org" + MOCK_ORG = "mock_org" + MOCK_USER = settings.DEFAULT_AUTH_USERNAME + MOCK_USER_ID = "mock_user_id" + MOCK_USER_EMAIL = "email@mock.com" + MOCK_USER_PASSWORD = settings.DEFAULT_AUTH_PASSWORD + + +class UserLoginTemplate: + TEMPLATE = "login.html" + ERROR_PLACE_HOLDER = "error_message" + + +class PluginConfig: + PLUGINS_APP = "plugins" + AUTH_MODULE_PREFIX = "auth" + AUTH_PLUGIN_DIR = "authentication" + AUTH_MODULE = "module" + AUTH_METADATA = "metadata" + METADATA_SERVICE_CLASS = "service_class" + METADATA_IS_ACTIVE = "is_active" + + +class AuthorizationErrorCode: + """Error codes + IDM: INVITATION DENIED MESSAGE (Unauthorized invitation) + INF: INVITATION NOT FOUND (Invitation is either invalid or has expired) + UMM: USER MEMBERSHIP MISCONDUCT + USF: USER FOUND (User Account Already Exists for Organization) + INE001: INVALID EMAIL Exception code when an invalid email address is used + like disposable. + INE002: INVALID EMAIL Exception code when an invalid email address format. + + Error code reference : + frontend/src/components/error/GenericError/GenericError.jsx. 
+ """ + + IDM = "IDM" + UMM = "UMM" + INF = "INF" + USF = "USF" + USR = "USR" + INE001 = "INE001" + INE002 = "INE002" diff --git a/backend/account_v2/custom_auth_middleware.py b/backend/account_v2/custom_auth_middleware.py new file mode 100644 index 000000000..db49d7352 --- /dev/null +++ b/backend/account_v2/custom_auth_middleware.py @@ -0,0 +1,51 @@ +from account_v2.authentication_plugin_registry import AuthenticationPluginRegistry +from account_v2.authentication_service import AuthenticationService +from account_v2.constants import Common +from django.conf import settings +from django.http import HttpRequest, HttpResponse, JsonResponse +from utils.constants import Account +from utils.local_context import StateStore +from utils.user_session import UserSessionUtils + +from backend.constants import RequestHeader + + +class CustomAuthMiddleware: + def __init__(self, get_response: HttpResponse): + self.get_response = get_response + # One-time configuration and initialization. + + def __call__(self, request: HttpRequest) -> HttpResponse: + # Returns result without authenticated if added in whitelisted paths + if any(request.path.startswith(path) for path in settings.WHITELISTED_PATHS): + return self.get_response(request) + + # Authenticating With API_KEY + x_api_key = request.headers.get(RequestHeader.X_API_KEY) + if ( + settings.INTERNAL_SERVICE_API_KEY + and x_api_key == settings.INTERNAL_SERVICE_API_KEY + ): # Should API Key be in settings or just env alone? 
+ return self.get_response(request) + + if AuthenticationPluginRegistry.is_plugin_available(): + auth_service: AuthenticationService = ( + AuthenticationPluginRegistry.get_plugin() + ) + else: + auth_service = AuthenticationService() + + is_authenticated = auth_service.is_authenticated(request) + + if is_authenticated: + StateStore.set(Common.LOG_EVENTS_ID, request.session.session_key) + StateStore.set( + Account.ORGANIZATION_ID, + UserSessionUtils.get_organization_id(request=request), + ) + response = self.get_response(request) + StateStore.clear(Account.ORGANIZATION_ID) + StateStore.clear(Common.LOG_EVENTS_ID) + + return response + return JsonResponse({"message": "Unauthorized"}, status=401) diff --git a/backend/account_v2/custom_authentication.py b/backend/account_v2/custom_authentication.py new file mode 100644 index 000000000..1f7cdcb6e --- /dev/null +++ b/backend/account_v2/custom_authentication.py @@ -0,0 +1,13 @@ +from typing import Any + +from django.http import HttpRequest +from rest_framework.exceptions import AuthenticationFailed + + +def api_login_required(view_func: Any) -> Any: + def wrapper(request: HttpRequest, *args: Any, **kwargs: Any) -> Any: + if request.user and request.session and "user" in request.session: + return view_func(request, *args, **kwargs) + raise AuthenticationFailed("Unauthorized") + + return wrapper diff --git a/backend/account_v2/custom_cache.py b/backend/account_v2/custom_cache.py new file mode 100644 index 000000000..182f980f5 --- /dev/null +++ b/backend/account_v2/custom_cache.py @@ -0,0 +1,12 @@ +from django_redis import get_redis_connection + + +class CustomCache: + def __init__(self) -> None: + self.cache = get_redis_connection("default") + + def rpush(self, key: str, value: str) -> None: + self.cache.rpush(key, value) + + def lrem(self, key: str, value: str) -> None: + self.cache.lrem(key, value) diff --git a/backend/account_v2/custom_exceptions.py b/backend/account_v2/custom_exceptions.py new file mode 100644 index 
000000000..bec24e16c --- /dev/null +++ b/backend/account_v2/custom_exceptions.py @@ -0,0 +1,60 @@ +from typing import Optional + +from rest_framework.exceptions import APIException + + +class ConflictError(Exception): + def __init__(self, message: str) -> None: + self.message = message + super().__init__(self.message) + + +class MethodNotImplemented(APIException): + status_code = 501 + default_detail = "Method Not Implemented" + + +class DuplicateData(APIException): + status_code = 400 + default_detail = "Duplicate Data" + + def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): + if detail is not None: + self.detail = detail + if code is not None: + self.code = code + super().__init__(detail, code) + + +class TableNotExistError(APIException): + status_code = 400 + default_detail = "Unknown Table" + + def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): + if detail is not None: + self.detail = detail + if code is not None: + self.code = code + super().__init__() + + +class UserNotExistError(APIException): + status_code = 400 + default_detail = "Unknown User" + + def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): + if detail is not None: + self.detail = detail + if code is not None: + self.code = code + super().__init__() + + +class Forbidden(APIException): + status_code = 403 + default_detail = "Do not have permission to perform this action." + + +class UserAlreadyAssociatedException(APIException): + status_code = 400 + default_detail = "User is already associated with one organization." 
diff --git a/backend/account_v2/dto.py b/backend/account_v2/dto.py new file mode 100644 index 000000000..1554901fb --- /dev/null +++ b/backend/account_v2/dto.py @@ -0,0 +1,131 @@ +from dataclasses import dataclass +from typing import Any, Optional + + +@dataclass +class MemberData: + user_id: str + email: Optional[str] = None + name: Optional[str] = None + picture: Optional[str] = None + role: Optional[list[str]] = None + organization_id: Optional[str] = None + + +@dataclass +class OrganizationData: + id: str + display_name: str + name: str + + +@dataclass +class CallbackData: + user_id: str + email: str + token: Any + + +@dataclass +class OrganizationSignupRequestBody: + name: str + display_name: str + organization_id: str + + +@dataclass +class OrganizationSignupResponse: + name: str + display_name: str + organization_id: str + created_at: str + + +@dataclass +class UserInfo: + email: str + user_id: str + id: Optional[str] = None + name: Optional[str] = None + display_name: Optional[str] = None + family_name: Optional[str] = None + picture: Optional[str] = None + + +@dataclass +class UserSessionInfo: + id: str + user_id: str + email: str + organization_id: str + user: UserInfo + + @staticmethod + def from_dict(data: dict[str, Any]) -> "UserSessionInfo": + return UserSessionInfo( + id=data["id"], + user_id=data["user_id"], + email=data["email"], + organization_id=data["organization_id"], + ) + + def to_dict(self) -> Any: + return { + "id": self.id, + "user_id": self.user_id, + "email": self.email, + "organization_id": self.organization_id, + } + + +@dataclass +class GetUserReposne: + user: UserInfo + organizations: list[OrganizationData] + + +@dataclass +class ResetUserPasswordDto: + status: bool + message: str + + +@dataclass +class UserInviteResponse: + email: str + status: str + message: Optional[str] = None + + +@dataclass +class UserRoleData: + name: str + id: Optional[str] = None + description: Optional[str] = None + + +@dataclass +class MemberInvitation: + 
"""Represents an invitation to join an organization. + + Attributes: + id (str): The unique identifier for the invitation. + email (str): The user email. + roles (List[str]): The roles assigned to the invitee. + created_at (Optional[str]): The timestamp when the invitation + was created. + expires_at (Optional[str]): The timestamp when the invitation expires. + """ + + id: str + email: str + roles: list[str] + created_at: Optional[str] = None + expires_at: Optional[str] = None + + +@dataclass +class UserOrganizationRole: + user_id: str + role: UserRoleData + organization_id: str diff --git a/backend/account_v2/enums.py b/backend/account_v2/enums.py new file mode 100644 index 000000000..d8209ec2d --- /dev/null +++ b/backend/account_v2/enums.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class UserRole(Enum): + USER = "user" + ADMIN = "admin" diff --git a/backend/account_v2/exceptions.py b/backend/account_v2/exceptions.py new file mode 100644 index 000000000..2f5d34a84 --- /dev/null +++ b/backend/account_v2/exceptions.py @@ -0,0 +1,31 @@ +from rest_framework.exceptions import APIException + + +class UserIdNotExist(APIException): + status_code = 404 + default_detail = "User ID does not exist" + + +class UserAlreadyExistInOrganization(APIException): + status_code = 403 + default_detail = "User allready exist in the organization" + + +class OrganizationNotExist(APIException): + status_code = 404 + default_detail = "Organization does not exist" + + +class UnknownException(APIException): + status_code = 500 + default_detail = "An unexpected error occurred" + + +class BadRequestException(APIException): + status_code = 400 + default_detail = "Bad Request" + + +class Unauthorized(APIException): + status_code = 401 + default_detail = "Unauthorized" diff --git a/backend/account_v2/migrations/__init__.py b/backend/account_v2/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/account_v2/models.py b/backend/account_v2/models.py new file 
mode 100644 index 000000000..61428bda0 --- /dev/null +++ b/backend/account_v2/models.py @@ -0,0 +1,144 @@ +import uuid + +from django.contrib.auth.models import AbstractUser, Group, Permission +from django.db import models + +from backend.constants import FieldLengthConstants as FieldLength + +NAME_SIZE = 64 +KEY_SIZE = 64 + + +class Organization(models.Model): + """Stores data related to an organization. + + The fields created_by and modified_by is updated after a + :model:`account.User` is created. + """ + + name = models.CharField(max_length=NAME_SIZE) + display_name = models.CharField(max_length=NAME_SIZE) + organization_id = models.CharField( + max_length=FieldLength.ORG_NAME_SIZE, unique=True + ) + created_by = models.ForeignKey( + "User", + on_delete=models.SET_NULL, + related_name="orgs_created", + null=True, + blank=True, + ) + modified_by = models.ForeignKey( + "User", + on_delete=models.SET_NULL, + related_name="orgs_modified", + null=True, + blank=True, + ) + modified_at = models.DateTimeField(auto_now=True) + created_at = models.DateTimeField(auto_now=True) + allowed_token_limit = models.IntegerField( + default=-1, + db_comment="token limit set in case of frition less onbaoarded org", + ) + + class Meta: + verbose_name = "Organization" + verbose_name_plural = "Organizations" + db_table = "organization_v2" + + +class User(AbstractUser): + """Stores data related to a user belonging to any organization. + + Every org, user is assumed to be unique. 
+ """ + + # Third Party Authentication User ID + user_id = models.CharField() + project_storage_created = models.BooleanField(default=False) + created_by = models.ForeignKey( + "User", + on_delete=models.SET_NULL, + related_name="users_created", + null=True, + blank=True, + ) + modified_by = models.ForeignKey( + "User", + on_delete=models.SET_NULL, + related_name="users_modified", + null=True, + blank=True, + ) + modified_at = models.DateTimeField(auto_now=True) + created_at = models.DateTimeField(auto_now_add=True) + + # Specify a unique related_name for the groups field + groups = models.ManyToManyField( + Group, + related_name="users", + related_query_name="user", + blank=True, + ) + + # Specify a unique related_name for the user_permissions field + user_permissions = models.ManyToManyField( + Permission, + related_name="users", + related_query_name="user", + blank=True, + ) + + def __str__(self): # type: ignore + return f"User({self.id}, email: {self.email}, userId: {self.user_id})" + + class Meta: + verbose_name = "User" + verbose_name_plural = "Users" + db_table = "user_v2" + + +class PlatformKey(models.Model): + """Model to hold details of Platform keys. + + Only users with admin role are allowed to perform any operation + related keys. 
+ """ + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + key = models.UUIDField(default=uuid.uuid4) + key_name = models.CharField(max_length=KEY_SIZE, null=False, blank=True, default="") + is_active = models.BooleanField(default=False) + organization = models.ForeignKey( + "Organization", + on_delete=models.SET_NULL, + related_name="platform_keys", + null=True, + blank=True, + ) + created_by = models.ForeignKey( + "User", + on_delete=models.SET_NULL, + related_name="platform_keys_created", + null=True, + blank=True, + ) + modified_by = models.ForeignKey( + "User", + on_delete=models.SET_NULL, + related_name="platform_keys_modified", + null=True, + blank=True, + ) + + class Meta: + verbose_name = "Platform Key" + verbose_name_plural = "Platform Keys" + db_table = "platform_key_v2" + constraints = [ + models.UniqueConstraint( + fields=["key_name", "organization"], + name="unique_key_name_organization", + ), + ] diff --git a/backend/account_v2/organization.py b/backend/account_v2/organization.py new file mode 100644 index 000000000..72ef69200 --- /dev/null +++ b/backend/account_v2/organization.py @@ -0,0 +1,45 @@ +import logging +from typing import Optional + +from account_v2.models import Organization +from account_v2.subscription_loader import SubscriptionConfig, load_plugins +from django.db import IntegrityError + +Logger = logging.getLogger(__name__) + +subscription_loader = load_plugins() + + +class OrganizationService: + def __init__(self): # type: ignore + pass + + @staticmethod + def get_organization_by_org_id(org_id: str) -> Optional[Organization]: + try: + return Organization.objects.get(organization_id=org_id) # type: ignore + except Organization.DoesNotExist: + return None + + @staticmethod + def create_organization( + name: str, display_name: str, organization_id: str + ) -> Organization: + try: + organization: Organization = Organization( + name=name, + display_name=display_name, + organization_id=organization_id, + ) + 
organization.save() + + for subscription_plugin in subscription_loader: + cls = subscription_plugin[SubscriptionConfig.METADATA][ + SubscriptionConfig.METADATA_SERVICE_CLASS + ] + cls.add(organization_id=organization_id) + + except IntegrityError as error: + Logger.info(f"[Duplicate Id] Failed to create Organization Error: {error}") + raise error + return organization diff --git a/backend/account_v2/serializer.py b/backend/account_v2/serializer.py new file mode 100644 index 000000000..87c531a4a --- /dev/null +++ b/backend/account_v2/serializer.py @@ -0,0 +1,117 @@ +import re +from typing import Optional + +from account_v2.models import Organization, User +from rest_framework import serializers + + +class OrganizationSignupSerializer(serializers.Serializer): + name = serializers.CharField(required=True, max_length=150) + display_name = serializers.CharField(required=True, max_length=150) + organization_id = serializers.CharField(required=True, max_length=30) + + def validate_organization_id(self, value): # type: ignore + if not re.match(r"^[a-z0-9_-]+$", value): + raise serializers.ValidationError( + "organization_code should only contain " + "alphanumeric characters,_ and -." 
+ ) + return value + + +class OrganizationCallbackSerializer(serializers.Serializer): + id = serializers.CharField(required=False) + + +class GetOrganizationsResponseSerializer(serializers.Serializer): + id = serializers.CharField() + display_name = serializers.CharField() + name = serializers.CharField() + # Add more fields as needed + + def to_representation(self, instance): # type: ignore + data = super().to_representation(instance) + # Modify the representation if needed + return data + + +class GetOrganizationMembersResponseSerializer(serializers.Serializer): + user_id = serializers.CharField() + email = serializers.CharField() + name = serializers.CharField() + picture = serializers.CharField() + # Add more fields as needed + + def to_representation(self, instance): # type: ignore + data = super().to_representation(instance) + # Modify the representation if needed + return data + + +class OrganizationSerializer(serializers.Serializer): + name = serializers.CharField() + organization_id = serializers.CharField() + + +class SetOrganizationsResponseSerializer(serializers.Serializer): + id = serializers.CharField() + email = serializers.CharField() + name = serializers.CharField() + display_name = serializers.CharField() + family_name = serializers.CharField() + picture = serializers.CharField() + # Add more fields as needed + + def to_representation(self, instance): # type: ignore + data = super().to_representation(instance) + # Modify the representation if needed + return data + + +class ModelTenantSerializer(serializers.ModelSerializer): + class Meta: + model = Organization + fields = fields = ("name", "created_on") + + +class UserSerializer(serializers.ModelSerializer): + class Meta: + model = User + fields = ("id", "username") + + +class OrganizationSignupResponseSerializer(serializers.Serializer): + name = serializers.CharField() + display_name = serializers.CharField() + organization_id = serializers.CharField() + created_at = serializers.CharField() + + 
+class LoginRequestSerializer(serializers.Serializer): + username = serializers.CharField(required=True) + password = serializers.CharField(required=True) + + def validate_username(self, value: Optional[str]) -> str: + """Check that the username is not empty and has at least 3 + characters.""" + if not value or len(value) < 3: + raise serializers.ValidationError( + "Username must be at least 3 characters long." + ) + return value + + def validate_password(self, value: Optional[str]) -> str: + """Check that the password is not empty and has at least 3 + characters.""" + if not value or len(value) < 3: + raise serializers.ValidationError( + "Password must be at least 3 characters long." + ) + return value + + +class UserSessionResponseSerializer(serializers.Serializer): + id = serializers.IntegerField() + user_id = serializers.CharField() + email = serializers.CharField() + organization_id = serializers.CharField() diff --git a/backend/account_v2/subscription_loader.py b/backend/account_v2/subscription_loader.py new file mode 100644 index 000000000..d380ed19d --- /dev/null +++ b/backend/account_v2/subscription_loader.py @@ -0,0 +1,77 @@ +import logging +import os +from importlib import import_module +from typing import Any + +from django.apps import apps + +logger = logging.getLogger(__name__) + + +class SubscriptionConfig: + """Loader config for subscription plugins.""" + + PLUGINS_APP = "plugins" + PLUGIN_DIR = "subscription" + MODULE = "module" + METADATA = "metadata" + METADATA_NAME = "name" + METADATA_SERVICE_CLASS = "service_class" + METADATA_IS_ACTIVE = "is_active" + + +def load_plugins() -> list[Any]: + """Iterate through the subscription plugins and register them.""" + plugins_app = apps.get_app_config(SubscriptionConfig.PLUGINS_APP) + package_path = plugins_app.module.__package__ + subscription_dir = os.path.join(plugins_app.path, SubscriptionConfig.PLUGIN_DIR) + subscription_package_path = f"{package_path}.{SubscriptionConfig.PLUGIN_DIR}" + 
subscription_plugins: list[Any] = [] + + if not os.path.exists(subscription_dir): + return subscription_plugins + + for item in os.listdir(subscription_dir): + # Loads a plugin if it is in a directory. + if os.path.isdir(os.path.join(subscription_dir, item)): + subscription_module_name = item + # Loads a plugin if it is a shared library. + # Module name is extracted from shared library name. + # `subscription.platform_architecture.so` will be file name and + # `subscription` will be the module name. + elif item.endswith(".so"): + subscription_module_name = item.split(".")[0] + else: + continue + try: + full_module_path = f"{subscription_package_path}.{subscription_module_name}" + module = import_module(full_module_path) + metadata = getattr(module, SubscriptionConfig.METADATA, {}) + + if metadata.get(SubscriptionConfig.METADATA_IS_ACTIVE, False): + subscription_plugins.append( + { + SubscriptionConfig.MODULE: module, + SubscriptionConfig.METADATA: module.metadata, + } + ) + logger.info( + "Loaded subscription plugin: %s, is_active: %s", + module.metadata[SubscriptionConfig.METADATA_NAME], + module.metadata[SubscriptionConfig.METADATA_IS_ACTIVE], + ) + else: + logger.info( + "subscription plugin %s is not active.", + subscription_module_name, + ) + except ModuleNotFoundError as exception: + logger.error( + "Error while importing subscription plugin: %s", + exception, + ) + + if len(subscription_plugins) == 0: + logger.info("No subscription plugins found.") + + return subscription_plugins diff --git a/backend/account_v2/templates/index.html b/backend/account_v2/templates/index.html new file mode 100644 index 000000000..ffa0b6085 --- /dev/null +++ b/backend/account_v2/templates/index.html @@ -0,0 +1,11 @@ + + + + + ZipstackID Django App Example + + +

Welcome Guest

+

Login

+ + diff --git a/backend/account_v2/templates/login.html b/backend/account_v2/templates/login.html new file mode 100644 index 000000000..4edf8e3f1 --- /dev/null +++ b/backend/account_v2/templates/login.html @@ -0,0 +1,134 @@ + + + + + + Login + + + +
+
+ +
+
+ {% load static %} +
+ My image +
+ + {% if error_message %} +

{{ error_message }}

+ {% endif %} + {% csrf_token %} + + + +

+ +
+ + + diff --git a/backend/account_v2/tests.py b/backend/account_v2/tests.py new file mode 100644 index 000000000..a39b155ac --- /dev/null +++ b/backend/account_v2/tests.py @@ -0,0 +1 @@ +# Create your tests here. diff --git a/backend/account_v2/urls.py b/backend/account_v2/urls.py new file mode 100644 index 000000000..767be8aba --- /dev/null +++ b/backend/account_v2/urls.py @@ -0,0 +1,22 @@ +from account_v2.views import ( + callback, + create_organization, + get_organizations, + get_session_data, + login, + logout, + set_organization, + signup, +) +from django.urls import path + +urlpatterns = [ + path("login", login, name="login"), + path("signup", signup, name="signup"), + path("logout", logout, name="logout"), + path("callback", callback, name="callback"), + path("session", get_session_data, name="session"), + path("organizations", get_organizations, name="get_organizations"), + path("organization//set", set_organization, name="set_organization"), + path("organization/create", create_organization, name="create_organization"), +] diff --git a/backend/account_v2/user.py b/backend/account_v2/user.py new file mode 100644 index 000000000..7967521e9 --- /dev/null +++ b/backend/account_v2/user.py @@ -0,0 +1,55 @@ +import logging +from typing import Any, Optional + +from account_v2.models import User +from django.db import IntegrityError + +Logger = logging.getLogger(__name__) + + +class UserService: + def __init__( + self, + ) -> None: + pass + + def create_user(self, email: str, user_id: str) -> User: + try: + user: User = User(email=email, user_id=user_id, username=email) + user.save() + except IntegrityError as error: + Logger.info(f"[Duplicate Id] Failed to create User Error: {error}") + raise error + return user + + def update_user(self, user: User, user_id: str) -> User: + user.user_id = user_id + user.save() + return user + + def get_user_by_email(self, email: str) -> Optional[User]: + try: + user: User = User.objects.get(email=email) + return user + except 
User.DoesNotExist: + return None + + def get_user_by_user_id(self, user_id: str) -> Any: + try: + return User.objects.get(user_id=user_id) + except User.DoesNotExist: + return None + + def get_user_by_id(self, id: str) -> Any: + """Retrieve a user by their ID, taking into account the schema context. + + Args: + id (str): The ID of the user. + + Returns: + Any: The user object if found, or None if not found. + """ + try: + return User.objects.get(id=id) + except User.DoesNotExist: + return None diff --git a/backend/account_v2/views.py b/backend/account_v2/views.py new file mode 100644 index 000000000..b73add92a --- /dev/null +++ b/backend/account_v2/views.py @@ -0,0 +1,169 @@ +import logging +from typing import Any + +from account_v2.authentication_controller import AuthenticationController +from account_v2.dto import ( + OrganizationSignupRequestBody, + OrganizationSignupResponse, + UserSessionInfo, +) +from account_v2.models import Organization +from account_v2.organization import OrganizationService +from account_v2.serializer import ( + OrganizationSignupResponseSerializer, + OrganizationSignupSerializer, + UserSessionResponseSerializer, +) +from rest_framework import status +from rest_framework.decorators import api_view +from rest_framework.request import Request +from rest_framework.response import Response +from utils.user_session import UserSessionUtils + +Logger = logging.getLogger(__name__) + + +@api_view(["POST"]) +def create_organization(request: Request) -> Response: + serializer = OrganizationSignupSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + try: + requestBody: OrganizationSignupRequestBody = makeSignupRequestParams(serializer) + + organization: Organization = OrganizationService.create_organization( + requestBody.name, + requestBody.display_name, + requestBody.organization_id, + ) + response = makeSignupResponse(organization) + return Response( + status=status.HTTP_201_CREATED, + data={"message": "success", "tenant": 
response}, + ) + except Exception as error: + Logger.error(error) + return Response( + status=status.HTTP_500_INTERNAL_SERVER_ERROR, data="Unknown Error" + ) + + +@api_view(["GET"]) +def callback(request: Request) -> Response: + auth_controller = AuthenticationController() + return auth_controller.authorization_callback(request) + + +@api_view(["GET", "POST"]) +def login(request: Request) -> Response: + auth_controller = AuthenticationController() + return auth_controller.user_login(request) + + +@api_view(["GET"]) +def signup(request: Request) -> Response: + auth_controller = AuthenticationController() + return auth_controller.user_signup(request) + + +@api_view(["GET"]) +def logout(request: Request) -> Response: + auth_controller = AuthenticationController() + return auth_controller.user_logout(request) + + +@api_view(["GET"]) +def get_organizations(request: Request) -> Response: + """get_organizations. + + Retrieve the list of organizations to which the user belongs. + Args: + request (HttpRequest): _description_ + + Returns: + Response: A list of organizations with associated information. + """ + auth_controller = AuthenticationController() + return auth_controller.user_organizations(request) + + +@api_view(["POST"]) +def set_organization(request: Request, id: str) -> Response: + """set_organization. + + Set the current organization to use. + Args: + request (HttpRequest): _description_ + id (String): organization Id + + Returns: + Response: Contains the User and Current organization details. + """ + + auth_controller = AuthenticationController() + return auth_controller.set_user_organization(request, id) + + +@api_view(["GET"]) +def get_session_data(request: Request) -> Response: + """get_session_data. + + Retrieve the current session data. + Args: + request (HttpRequest): _description_ + + Returns: + Response: Contains the User and Current organization details. 
+ """ + response = make_session_response(request) + + return Response( + status=status.HTTP_201_CREATED, + data=response, + ) + + +def make_session_response( + request: Request, +) -> Any: + """make_session_response. + + Make the current session data. + Args: + request (HttpRequest): _description_ + + Returns: + User and Current organization details. + """ + auth_controller = AuthenticationController() + return UserSessionResponseSerializer( + UserSessionInfo( + id=request.user.id, + user_id=request.user.user_id, + email=request.user.email, + user=auth_controller.get_user_info(request), + organization_id=UserSessionUtils.get_organization_id(request), + ) + ).data + + +def makeSignupRequestParams( + serializer: OrganizationSignupSerializer, +) -> OrganizationSignupRequestBody: + return OrganizationSignupRequestBody( + serializer.validated_data["name"], + serializer.validated_data["display_name"], + serializer.validated_data["organization_id"], + ) + + +def makeSignupResponse( + organization: Organization, +) -> Any: + return OrganizationSignupResponseSerializer( + OrganizationSignupResponse( + organization.name, + organization.display_name, + organization.organization_id, + organization.created_at, + ) + ).data diff --git a/backend/backend/public_urls.py b/backend/backend/public_urls.py index 27ece8786..017b79dce 100644 --- a/backend/backend/public_urls.py +++ b/backend/backend/public_urls.py @@ -47,10 +47,25 @@ try: - import pluggable_apps.platform_admin.urls # noqa: F401 + import pluggable_apps.platform_admin.urls # noqa # pylint: disable=unused-import urlpatterns += [ path(f"{path_prefix}/", include("pluggable_apps.platform_admin.urls")), ] except ImportError: pass + +try: + import pluggable_apps.public_shares.share_controller.urls # noqa # pylint: disable=unused-import + + share_path_prefix = settings.PUBLIC_PATH_PREFIX + + urlpatterns += [ + # Public Sharing + path( + f"{share_path_prefix}/", + include("pluggable_apps.public_shares.share_controller.urls"), + ), 
+ ] +except ImportError: + pass diff --git a/backend/backend/urls.py b/backend/backend/urls.py index 8b299ecc6..a86a6c7b5 100644 --- a/backend/backend/urls.py +++ b/backend/backend/urls.py @@ -84,7 +84,7 @@ # Subscription urls try: - import pluggable_apps.subscription.urls # noqa # pylint: disable=unused-import + import pluggable_apps.subscription.urls # noqa # pylint: disable=unused-import urlpatterns += [ path("", include("pluggable_apps.subscription.urls")), @@ -93,10 +93,32 @@ pass try: - import pluggable_apps.manual_review.urls # noqa: F401 + import pluggable_apps.manual_review.urls # noqa # pylint: disable=unused-import urlpatterns += [ path("manual_review/", include("pluggable_apps.manual_review.urls")), ] except ImportError: pass + +# Public share urls +try: + + import pluggable_apps.public_shares.share_manager.urls # noqa # pylint: disable=unused-import + + urlpatterns += [ + path("", include("pluggable_apps.public_shares.share_manager.urls")), + ] +except ImportError: + pass + +# Clone urls +try: + + import pluggable_apps.clone.urls # noqa # pylint: disable=unused-import + + urlpatterns += [ + path("", include("pluggable_apps.clone.urls")), + ] +except ImportError: + pass diff --git a/backend/connector_processor/connector_processor.py b/backend/connector_processor/connector_processor.py index dcababce3..b2346a9d4 100644 --- a/backend/connector_processor/connector_processor.py +++ b/backend/connector_processor/connector_processor.py @@ -12,8 +12,7 @@ InValidConnectorId, InValidConnectorMode, OAuthTimeOut, - TestConnectorException, - TestConnectorInputException, + TestConnectorInputError, ) from unstract.connectors.base import UnstractConnector @@ -100,15 +99,15 @@ def get_all_supported_connectors( return supported_connectors @staticmethod - def test_connectors(connector_id: str, cred_string: dict[str, Any]) -> bool: + def test_connectors(connector_id: str, credentials: dict[str, Any]) -> bool: logger.info(f"Testing connector: {connector_id}") connector: 
dict[str, Any] = fetch_connectors_by_key_value( ConnectorKeys.ID, connector_id )[0] if connector.get(ConnectorKeys.OAUTH): try: - oauth_key = cred_string.get(ConnectorAuthKey.OAUTH_KEY) - cred_string = ConnectorAuthHelper.get_oauth_creds_from_cache( + oauth_key = credentials.get(ConnectorAuthKey.OAUTH_KEY) + credentials = ConnectorAuthHelper.get_oauth_creds_from_cache( cache_key=oauth_key, delete_key=False ) except Exception as exc: @@ -120,17 +119,13 @@ def test_connectors(connector_id: str, cred_string: dict[str, Any]) -> bool: try: connector_impl = Connectorkit().get_connector_by_id( - connector_id, cred_string + connector_id, credentials ) test_result = connector_impl.test_credentials() logger.info(f"{connector_id} test result: {test_result}") return test_result except ConnectorError as e: - logger.error(f"Error while testing {connector_id}: {e}") - raise TestConnectorInputException(core_err=e) - except Exception as e: - logger.error(f"Error while testing {connector_id}: {e}") - raise TestConnectorException + raise TestConnectorInputError(core_err=e) def get_connector_data_with_key(connector_id: str, key_value: str) -> Any: """Generic Function to get connector data with provided key.""" diff --git a/backend/connector_processor/exceptions.py b/backend/connector_processor/exceptions.py index 1df8079d8..22a1e2966 100644 --- a/backend/connector_processor/exceptions.py +++ b/backend/connector_processor/exceptions.py @@ -31,7 +31,7 @@ class JSONParseException(APIException): class OAuthTimeOut(APIException): status_code = 408 - default_detail = "Timed Out. Please re authenticate." + default_detail = "Timed out. Please re-authenticate." class InternalServiceError(APIException): @@ -44,7 +44,7 @@ class TestConnectorException(APIException): default_detail = "Error while testing connector." 
-class TestConnectorInputException(UnstractBaseException): +class TestConnectorInputError(UnstractBaseException): def __init__(self, core_err: ConnectorError) -> None: super().__init__(detail=core_err.message, core_err=core_err) self.default_detail = core_err.message diff --git a/backend/connector_processor/views.py b/backend/connector_processor/views.py index 55367c40d..edca86ba1 100644 --- a/backend/connector_processor/views.py +++ b/backend/connector_processor/views.py @@ -67,10 +67,10 @@ def test(self, request: Request) -> Response: """Tests the connector against the credentials passed.""" serializer: TestConnectorSerializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) - connector_id = serializer.validated_data.get(ConnectorKeys.CONNECTOR_ID) + connector_id = serializer.validated_data.get(CIKey.CONNECTOR_ID) cred_string = serializer.validated_data.get(CIKey.CONNECTOR_METADATA) test_result = ConnectorProcessor.test_connectors( - connector_id=connector_id, cred_string=cred_string + connector_id=connector_id, credentials=cred_string ) return Response( {ConnectorKeys.IS_VALID: test_result}, diff --git a/backend/pdm.lock b/backend/pdm.lock index 230c6db84..4b896fa3e 100644 --- a/backend/pdm.lock +++ b/backend/pdm.lock @@ -542,13 +542,13 @@ files = [ [[package]] name = "cachetools" -version = "5.3.3" +version = "5.4.0" requires_python = ">=3.7" summary = "Extensible memoizing collections and decorators" groups = ["default", "dev"] files = [ - {file = "cachetools-5.3.3-py3-none-any.whl", hash = "sha256:0abad1021d3f8325b2fc1d2e9c8b9c9d57b04c3932657a72465447332c24d945"}, - {file = "cachetools-5.3.3.tar.gz", hash = "sha256:ba29e2dfa0b8b556606f097407ed1aa62080ee108ab0dc5ec9d6a723a007d105"}, + {file = "cachetools-5.4.0-py3-none-any.whl", hash = "sha256:3ae3b49a3d5e28a77a0be2b37dbcb89005058959cb2323858c2657c4a8cab474"}, + {file = "cachetools-5.4.0.tar.gz", hash = 
"sha256:b8adc2e7c07f105ced7bc56dbb6dfbe7c4a00acce20e2227b3f355be89bc6827"}, ] [[package]] @@ -1068,7 +1068,7 @@ files = [ [[package]] name = "dropboxdrivefs" -version = "1.3.1" +version = "1.4.1" requires_python = ">=3.5" summary = "Dropbox implementation for fsspec module" groups = ["default", "dev"] @@ -1078,7 +1078,7 @@ dependencies = [ "requests", ] files = [ - {file = "dropboxdrivefs-1.3.1.tar.gz", hash = "sha256:892ee9017c59648736d79c3989cadb9e129b469fcec0c68d12e42bd6826a962d"}, + {file = "dropboxdrivefs-1.4.1.tar.gz", hash = "sha256:6f3c6061d045813553ce91ed0e2b682f1d70bec74011943c92b3181faacefd34"}, ] [[package]] @@ -1098,14 +1098,14 @@ files = [ [[package]] name = "exceptiongroup" -version = "1.2.1" +version = "1.2.2" requires_python = ">=3.7" summary = "Backport of PEP 654 (exception groups)" groups = ["default", "dev", "test"] marker = "python_version < \"3.11\"" files = [ - {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"}, - {file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"}, + {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, + {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, ] [[package]] @@ -1299,7 +1299,7 @@ files = [ [[package]] name = "google-api-python-client" -version = "2.136.0" +version = "2.137.0" requires_python = ">=3.7" summary = "Google API Client Library for Python" groups = ["default", "dev"] @@ -1311,8 +1311,8 @@ dependencies = [ "uritemplate<5,>=3.0.1", ] files = [ - {file = "google-api-python-client-2.136.0.tar.gz", hash = "sha256:161c722c8864e7ed39393e2b7eea76ef4e1c933a6a59f9d7c70409b6635f225d"}, - {file = "google_api_python_client-2.136.0-py2.py3-none-any.whl", hash = 
"sha256:5a554c8b5edf0a609b905d89d7ced82e8f6ac31da1e4d8d5684ef63dbc0e49f5"}, + {file = "google_api_python_client-2.137.0-py2.py3-none-any.whl", hash = "sha256:a8b5c5724885e5be9f5368739aa0ccf416627da4ebd914b410a090c18f84d692"}, + {file = "google_api_python_client-2.137.0.tar.gz", hash = "sha256:e739cb74aac8258b1886cb853b0722d47c81fe07ad649d7f2206f06530513c04"}, ] [[package]] @@ -1364,7 +1364,7 @@ files = [ [[package]] name = "google-cloud-aiplatform" -version = "1.58.0" +version = "1.59.0" requires_python = ">=3.8" summary = "Vertex AI API client library" groups = ["default", "dev"] @@ -1376,14 +1376,14 @@ dependencies = [ "google-cloud-resource-manager<3.0.0dev,>=1.3.3", "google-cloud-storage<3.0.0dev,>=1.32.0", "packaging>=14.3", - "proto-plus<2.0.0dev,>=1.22.0", + "proto-plus<2.0.0dev,>=1.22.3", "protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.19.5", "pydantic<3", "shapely<3.0.0dev", ] files = [ - {file = "google-cloud-aiplatform-1.58.0.tar.gz", hash = "sha256:7a05aceac4a6c7eaa26e684e9f202b829cc7e57f82bffe7281684275a553fcad"}, - {file = "google_cloud_aiplatform-1.58.0-py2.py3-none-any.whl", hash = "sha256:21f1320860f4916183ec939fdf2ff3fc1d7fdde97fe5795974257ab21f9458ec"}, + {file = "google-cloud-aiplatform-1.59.0.tar.gz", hash = "sha256:2bebb59c0ba3e3b4b568305418ca1b021977988adbee8691a5bed09b037e7e63"}, + {file = "google_cloud_aiplatform-1.59.0-py2.py3-none-any.whl", hash = "sha256:549e6eb1844b0f853043309138ebe2db00de4bbd8197b3bde26804ac163ef52a"}, ] [[package]] @@ -1862,7 +1862,7 @@ files = [ [[package]] name = "huggingface-hub" -version = "0.23.4" +version = "0.23.5" requires_python = ">=3.8.0" summary = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" groups = ["default", "dev"] @@ -1876,8 +1876,8 @@ dependencies = [ "typing-extensions>=3.7.4.3", ] files = [ - {file = "huggingface_hub-0.23.4-py3-none-any.whl", hash = 
"sha256:3a0b957aa87150addf0cc7bd71b4d954b78e749850e1e7fb29ebbd2db64ca037"}, - {file = "huggingface_hub-0.23.4.tar.gz", hash = "sha256:35d99016433900e44ae7efe1c209164a5a81dbbcd53a52f99c281dcd7ce22431"}, + {file = "huggingface_hub-0.23.5-py3-none-any.whl", hash = "sha256:d7a7d337615e11a45cc14a0ce5a605db6b038dc24af42866f731684825226e90"}, + {file = "huggingface_hub-0.23.5.tar.gz", hash = "sha256:67a9caba79b71235be3752852ca27da86bd54311d2424ca8afdb8dda056edf98"}, ] [[package]] @@ -2017,21 +2017,6 @@ files = [ {file = "kombu-5.3.7.tar.gz", hash = "sha256:011c4cd9a355c14a1de8d35d257314a1d2456d52b7140388561acac3cf1a97bf"}, ] -[[package]] -name = "llama-cloud" -version = "0.0.6" -requires_python = "<4,>=3.8" -summary = "" -groups = ["default", "dev"] -dependencies = [ - "httpx>=0.20.0", - "pydantic>=1.10", -] -files = [ - {file = "llama_cloud-0.0.6-py3-none-any.whl", hash = "sha256:0f07c8a865be632b543dec2bcad350a68a61f13413a7421b4b03de32c36f0194"}, - {file = "llama_cloud-0.0.6.tar.gz", hash = "sha256:33b94cd119133dcb2899c9b69e8e1c36aec7bc7e80062c55c65f15618722e091"}, -] - [[package]] name = "llama-index" version = "0.10.38" @@ -2091,7 +2076,7 @@ files = [ [[package]] name = "llama-index-core" -version = "0.10.53.post1" +version = "0.10.55" requires_python = "<4.0,>=3.8.1" summary = "Interface between LLMs and your data" groups = ["default", "dev"] @@ -2104,7 +2089,6 @@ dependencies = [ "dirtyjson<2.0.0,>=1.0.8", "fsspec>=2023.5.0", "httpx", - "llama-cloud<0.0.7,>=0.0.6", "nest-asyncio<2.0.0,>=1.5.8", "networkx>=3.0", "nltk<4.0.0,>=3.8.1", @@ -2121,8 +2105,8 @@ dependencies = [ "wrapt", ] files = [ - {file = "llama_index_core-0.10.53.post1-py3-none-any.whl", hash = "sha256:565d0967dd8f05456c66f5aca6ee6ee3dbc5645b6a55c81957f776ff029d6a99"}, - {file = "llama_index_core-0.10.53.post1.tar.gz", hash = "sha256:6219a737b66c887b406814b0d9db6e24addd35f3136ffb6a879e54ac3f133406"}, + {file = "llama_index_core-0.10.55-py3-none-any.whl", hash = 
"sha256:e2f7dbc9c992d4487dabad6a7b0f40ed145cce0ab99e52cc78e9caf0cd4c1c08"}, + {file = "llama_index_core-0.10.55.tar.gz", hash = "sha256:b02d46595c17805221a8f404c04a97609d1ce22e5be24ad7b7c4ac30e5181561"}, ] [[package]] @@ -2413,7 +2397,7 @@ files = [ [[package]] name = "llama-index-readers-file" -version = "0.1.29" +version = "0.1.30" requires_python = "<4.0,>=3.8.1" summary = "llama-index readers file integration" groups = ["default", "dev"] @@ -2424,8 +2408,8 @@ dependencies = [ "striprtf<0.0.27,>=0.0.26", ] files = [ - {file = "llama_index_readers_file-0.1.29-py3-none-any.whl", hash = "sha256:b25f3dbf7bf3e0635290e499e808db5ba955eab67f205a3ff1cea6a4eb93556a"}, - {file = "llama_index_readers_file-0.1.29.tar.gz", hash = "sha256:f9f696e738383e7d14078e75958fba5a7030f7994a20586e3140e1ca41395a54"}, + {file = "llama_index_readers_file-0.1.30-py3-none-any.whl", hash = "sha256:d5f6cdd4685ee73103c68b9bc0dfb0d05439033133fc6bd45ef31ff41519e723"}, + {file = "llama_index_readers_file-0.1.30.tar.gz", hash = "sha256:32f40465f2a8a65fa5773e03c9f4dd55164be934ae67fad62113680436787d91"}, ] [[package]] @@ -3045,7 +3029,7 @@ files = [ [[package]] name = "portalocker" -version = "2.10.0" +version = "2.10.1" requires_python = ">=3.8" summary = "Wraps the portalocker recipe for easy usage" groups = ["default", "dev"] @@ -3053,8 +3037,8 @@ dependencies = [ "pywin32>=226; platform_system == \"Windows\"", ] files = [ - {file = "portalocker-2.10.0-py3-none-any.whl", hash = "sha256:48944147b2cd42520549bc1bb8fe44e220296e56f7c3d551bc6ecce69d9b0de1"}, - {file = "portalocker-2.10.0.tar.gz", hash = "sha256:49de8bc0a2f68ca98bf9e219c81a3e6b27097c7bf505a87c5a112ce1aaeb9b81"}, + {file = "portalocker-2.10.1-py3-none-any.whl", hash = "sha256:53a5984ebc86a025552264b459b46a2086e269b21823cb572f8f28ee759e45bf"}, + {file = "portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f"}, ] [[package]] @@ -3457,7 +3441,7 @@ files = [ [[package]] name = "pypdf" 
-version = "4.2.0" +version = "4.3.0" requires_python = ">=3.6" summary = "A pure-python PDF library capable of splitting, merging, cropping, and transforming PDF files" groups = ["default", "dev"] @@ -3465,8 +3449,8 @@ dependencies = [ "typing-extensions>=4.0; python_version < \"3.11\"", ] files = [ - {file = "pypdf-4.2.0-py3-none-any.whl", hash = "sha256:dc035581664e0ad717e3492acebc1a5fc23dba759e788e3d4a9fc9b1a32e72c1"}, - {file = "pypdf-4.2.0.tar.gz", hash = "sha256:fe63f3f7d1dcda1c9374421a94c1bba6c6f8c4a62173a59b64ffd52058f846b1"}, + {file = "pypdf-4.3.0-py3-none-any.whl", hash = "sha256:eeea4d019b57c099d02a0e1692eaaab23341ae3f255c1dafa3c8566b4636496d"}, + {file = "pypdf-4.3.0.tar.gz", hash = "sha256:0d7a4c67fd03782f5a09d3f48c11c7a31e0bb9af78861a25229bb49259ed0504"}, ] [[package]] @@ -4022,18 +4006,18 @@ files = [ [[package]] name = "setuptools" -version = "70.2.0" +version = "70.3.0" requires_python = ">=3.8" summary = "Easily download, build, install, upgrade, and uninstall Python packages" groups = ["default", "dev"] files = [ - {file = "setuptools-70.2.0-py3-none-any.whl", hash = "sha256:b8b8060bb426838fbe942479c90296ce976249451118ef566a5a0b7d8b78fb05"}, - {file = "setuptools-70.2.0.tar.gz", hash = "sha256:bd63e505105011b25c3c11f753f7e3b8465ea739efddaccef8f0efac2137bac1"}, + {file = "setuptools-70.3.0-py3-none-any.whl", hash = "sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc"}, + {file = "setuptools-70.3.0.tar.gz", hash = "sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5"}, ] [[package]] name = "shapely" -version = "2.0.4" +version = "2.0.5" requires_python = ">=3.7" summary = "Manipulation and analysis of geometric objects" groups = ["default", "dev"] @@ -4041,28 +4025,25 @@ dependencies = [ "numpy<3,>=1.14", ] files = [ - {file = "shapely-2.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:011b77153906030b795791f2fdfa2d68f1a8d7e40bce78b029782ade3afe4f2f"}, - {file = 
"shapely-2.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9831816a5d34d5170aa9ed32a64982c3d6f4332e7ecfe62dc97767e163cb0b17"}, - {file = "shapely-2.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5c4849916f71dc44e19ed370421518c0d86cf73b26e8656192fcfcda08218fbd"}, - {file = "shapely-2.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:841f93a0e31e4c64d62ea570d81c35de0f6cea224568b2430d832967536308e6"}, - {file = "shapely-2.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b4431f522b277c79c34b65da128029a9955e4481462cbf7ebec23aab61fc58"}, - {file = "shapely-2.0.4-cp310-cp310-win32.whl", hash = "sha256:92a41d936f7d6743f343be265ace93b7c57f5b231e21b9605716f5a47c2879e7"}, - {file = "shapely-2.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:30982f79f21bb0ff7d7d4a4e531e3fcaa39b778584c2ce81a147f95be1cd58c9"}, - {file = "shapely-2.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:de0205cb21ad5ddaef607cda9a3191eadd1e7a62a756ea3a356369675230ac35"}, - {file = "shapely-2.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7d56ce3e2a6a556b59a288771cf9d091470116867e578bebced8bfc4147fbfd7"}, - {file = "shapely-2.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:58b0ecc505bbe49a99551eea3f2e8a9b3b24b3edd2a4de1ac0dc17bc75c9ec07"}, - {file = "shapely-2.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:790a168a808bd00ee42786b8ba883307c0e3684ebb292e0e20009588c426da47"}, - {file = "shapely-2.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4310b5494271e18580d61022c0857eb85d30510d88606fa3b8314790df7f367d"}, - {file = "shapely-2.0.4-cp311-cp311-win32.whl", hash = "sha256:63f3a80daf4f867bd80f5c97fbe03314348ac1b3b70fb1c0ad255a69e3749879"}, - {file = "shapely-2.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:c52ed79f683f721b69a10fb9e3d940a468203f5054927215586c5d49a072de8d"}, - {file = "shapely-2.0.4-cp39-cp39-macosx_10_9_universal2.whl", 
hash = "sha256:3f9103abd1678cb1b5f7e8e1af565a652e036844166c91ec031eeb25c5ca8af0"}, - {file = "shapely-2.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:263bcf0c24d7a57c80991e64ab57cba7a3906e31d2e21b455f493d4aab534aaa"}, - {file = "shapely-2.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ddf4a9bfaac643e62702ed662afc36f6abed2a88a21270e891038f9a19bc08fc"}, - {file = "shapely-2.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:485246fcdb93336105c29a5cfbff8a226949db37b7473c89caa26c9bae52a242"}, - {file = "shapely-2.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8de4578e838a9409b5b134a18ee820730e507b2d21700c14b71a2b0757396acc"}, - {file = "shapely-2.0.4-cp39-cp39-win32.whl", hash = "sha256:9dab4c98acfb5fb85f5a20548b5c0abe9b163ad3525ee28822ffecb5c40e724c"}, - {file = "shapely-2.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:31c19a668b5a1eadab82ff070b5a260478ac6ddad3a5b62295095174a8d26398"}, - {file = "shapely-2.0.4.tar.gz", hash = "sha256:5dc736127fac70009b8d309a0eeb74f3e08979e530cf7017f2f507ef62e6cfb8"}, + {file = "shapely-2.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:89d34787c44f77a7d37d55ae821f3a784fa33592b9d217a45053a93ade899375"}, + {file = "shapely-2.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:798090b426142df2c5258779c1d8d5734ec6942f778dab6c6c30cfe7f3bf64ff"}, + {file = "shapely-2.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45211276900c4790d6bfc6105cbf1030742da67594ea4161a9ce6812a6721e68"}, + {file = "shapely-2.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e119444bc27ca33e786772b81760f2028d930ac55dafe9bc50ef538b794a8e1"}, + {file = "shapely-2.0.5-cp310-cp310-win32.whl", hash = "sha256:9a4492a2b2ccbeaebf181e7310d2dfff4fdd505aef59d6cb0f217607cb042fb3"}, + {file = "shapely-2.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:1e5cb5ee72f1bc7ace737c9ecd30dc174a5295fae412972d3879bac2e82c8fae"}, + {file = 
"shapely-2.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5bbfb048a74cf273db9091ff3155d373020852805a37dfc846ab71dde4be93ec"}, + {file = "shapely-2.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93be600cbe2fbaa86c8eb70656369f2f7104cd231f0d6585c7d0aa555d6878b8"}, + {file = "shapely-2.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8e71bb9a46814019f6644c4e2560a09d44b80100e46e371578f35eaaa9da1c"}, + {file = "shapely-2.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5251c28a29012e92de01d2e84f11637eb1d48184ee8f22e2df6c8c578d26760"}, + {file = "shapely-2.0.5-cp311-cp311-win32.whl", hash = "sha256:35110e80070d664781ec7955c7de557456b25727a0257b354830abb759bf8311"}, + {file = "shapely-2.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c6b78c0007a34ce7144f98b7418800e0a6a5d9a762f2244b00ea560525290c9"}, + {file = "shapely-2.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7545a39c55cad1562be302d74c74586f79e07b592df8ada56b79a209731c0219"}, + {file = "shapely-2.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4c83a36f12ec8dee2066946d98d4d841ab6512a6ed7eb742e026a64854019b5f"}, + {file = "shapely-2.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89e640c2cd37378480caf2eeda9a51be64201f01f786d127e78eaeff091ec897"}, + {file = "shapely-2.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06efe39beafde3a18a21dde169d32f315c57da962826a6d7d22630025200c5e6"}, + {file = "shapely-2.0.5-cp39-cp39-win32.whl", hash = "sha256:8203a8b2d44dcb366becbc8c3d553670320e4acf0616c39e218c9561dd738d92"}, + {file = "shapely-2.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:7fed9dbfbcfec2682d9a047b9699db8dcc890dfca857ecba872c42185fc9e64e"}, + {file = "shapely-2.0.5.tar.gz", hash = "sha256:bff2366bc786bfa6cb353d6b47d0443c570c32776612e527ee47b6df63fcfe32"}, ] [[package]] @@ -4322,13 +4303,13 @@ files = [ [[package]] name = "sqlparse" -version = "0.5.0" +version = 
"0.5.1" requires_python = ">=3.8" summary = "A non-validating SQL parser." groups = ["default"] files = [ - {file = "sqlparse-0.5.0-py3-none-any.whl", hash = "sha256:c204494cd97479d0e39f28c93d46c0b2d5959c7b9ab904762ea6c7af211c8663"}, - {file = "sqlparse-0.5.0.tar.gz", hash = "sha256:714d0a4932c059d16189f58ef5411ec2287a4360f17cdd0edd2d09d4c5087c93"}, + {file = "sqlparse-0.5.1-py3-none-any.whl", hash = "sha256:773dcbf9a5ab44a090f3441e2180efe2560220203dc2f8c0b0fa141e18b505e4"}, + {file = "sqlparse-0.5.1.tar.gz", hash = "sha256:bb6b4df465655ef332548e24f08e205afc81b9ab86cb1c45657a7ff173a3a00e"}, ] [[package]] @@ -4491,13 +4472,13 @@ files = [ [[package]] name = "tomlkit" -version = "0.12.5" -requires_python = ">=3.7" +version = "0.13.0" +requires_python = ">=3.8" summary = "Style preserving TOML library" groups = ["default", "dev"] files = [ - {file = "tomlkit-0.12.5-py3-none-any.whl", hash = "sha256:af914f5a9c59ed9d0762c7b64d3b5d5df007448eb9cd2edc8a46b1eafead172f"}, - {file = "tomlkit-0.12.5.tar.gz", hash = "sha256:eef34fba39834d4d6b73c9ba7f3e4d1c417a4e56f89a7e96e090dd0d24b8fb3c"}, + {file = "tomlkit-0.13.0-py3-none-any.whl", hash = "sha256:7075d3042d03b80f603482d69bf0c8f345c2b30e41699fd8883227f89972b264"}, + {file = "tomlkit-0.13.0.tar.gz", hash = "sha256:08ad192699734149f5b97b45f1f18dad7eb1b6d16bc72ad0c2335772650d7b72"}, ] [[package]] @@ -4698,7 +4679,7 @@ dependencies = [ "PyMySQL==1.1.0", "adlfs==2023.8.0", "boxfs==0.2.1", - "dropboxdrivefs==1.3.1", + "dropboxdrivefs==1.4.1", "gcsfs==2023.6.0", "google-auth==2.20.0", "google-cloud-bigquery==3.11.4", @@ -4829,13 +4810,13 @@ files = [ [[package]] name = "validators" -version = "0.31.0" +version = "0.33.0" requires_python = ">=3.8" summary = "Python Data Validation for Humans™" groups = ["default", "dev"] files = [ - {file = "validators-0.31.0-py3-none-any.whl", hash = "sha256:e15a600d81555a4cd409b17bf55946c5edec7748e776afc85ed0a19bdee54e56"}, - {file = "validators-0.31.0.tar.gz", hash = 
"sha256:de7574fc56a231c788162f3e7da15bc2053c5ff9e0281d9ff1afb3a7b69498df"}, + {file = "validators-0.33.0-py3-none-any.whl", hash = "sha256:134b586a98894f8139865953899fc2daeb3d0c35569552c5518f089ae43ed075"}, + {file = "validators-0.33.0.tar.gz", hash = "sha256:535867e9617f0100e676a1257ba1e206b9bfd847ddc171e4d44811f07ff0bfbf"}, ] [[package]] diff --git a/backend/platform_settings_v2/__init__.py b/backend/platform_settings_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/platform_settings_v2/admin.py b/backend/platform_settings_v2/admin.py new file mode 100644 index 000000000..846f6b406 --- /dev/null +++ b/backend/platform_settings_v2/admin.py @@ -0,0 +1 @@ +# Register your models here. diff --git a/backend/platform_settings_v2/apps.py b/backend/platform_settings_v2/apps.py new file mode 100644 index 000000000..7fd953c81 --- /dev/null +++ b/backend/platform_settings_v2/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class PlatformSettingsConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "platform_settings_v2" diff --git a/backend/platform_settings_v2/constants.py b/backend/platform_settings_v2/constants.py new file mode 100644 index 000000000..29ba63684 --- /dev/null +++ b/backend/platform_settings_v2/constants.py @@ -0,0 +1,14 @@ +class PlatformServiceConstants: + IS_ACTIVE = "is_active" + KEY = "key" + ORGANIZATION = "organization" + ID = "id" + ACTIVATE = "ACTIVATE" + DEACTIVATE = "DEACTIVATE" + ACTION = "action" + KEY_NAME = "key_name" + + +class ErrorMessage: + KEY_EXIST = "Key name already exists" + DUPLICATE_API = "It appears that a duplicate call may have been made." 
diff --git a/backend/platform_settings_v2/exceptions.py b/backend/platform_settings_v2/exceptions.py new file mode 100644 index 000000000..512c57d1b --- /dev/null +++ b/backend/platform_settings_v2/exceptions.py @@ -0,0 +1,49 @@ +from typing import Optional + +from rest_framework.exceptions import APIException + + +class InternalServiceError(APIException): + status_code = 500 + default_detail = "Internal error occurred while performing platform key operations." + + +class UserForbidden(APIException): + status_code = 403 + default_detail = ( + "User is forbidden from performing this action. Please contact admin." + ) + + +class KeyCountExceeded(APIException): + status_code = 403 + default_detail = ( + "Maximum key count is exceeded. Please delete one before generation." + ) + + +class FoundActiveKey(APIException): + status_code = 403 + default_detail = "Only one active key allowed at a time." + + +class ActiveKeyNotFound(APIException): + status_code = 404 + default_detail = "At least one active platform key should be available" + + +class InvalidRequest(APIException): + status_code = 401 + default_detail = "Invalid Request" + + +class DuplicateData(APIException): + status_code = 400 + default_detail = "Duplicate Data" + + def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): + if detail is not None: + self.detail = detail + if code is not None: + self.code = code + super().__init__(detail, code) diff --git a/backend/platform_settings_v2/migrations/__init__.py b/backend/platform_settings_v2/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/platform_settings_v2/models.py b/backend/platform_settings_v2/models.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/platform_settings_v2/platform_auth_helper.py b/backend/platform_settings_v2/platform_auth_helper.py new file mode 100644 index 000000000..144287e8e --- /dev/null +++ b/backend/platform_settings_v2/platform_auth_helper.py @@ -0,0 
+1,51 @@ +import logging + +from account_v2.authentication_controller import AuthenticationController +from account_v2.models import Organization, PlatformKey, User +from platform_settings_v2.exceptions import KeyCountExceeded, UserForbidden +from tenant_account_v2.models import OrganizationMember + +PLATFORM_KEY_COUNT = 2 + +logger = logging.getLogger(__name__) + + +class PlatformAuthHelper: + """Class to hold helper functions for Platform settings authentication.""" + + @staticmethod + def validate_user_role(user: User) -> None: + """This method validates if the logged in user has admin role for + performing appropriate actions. + + Args: + user (User): Logged in user from context + """ + auth_controller = AuthenticationController() + member: OrganizationMember = auth_controller.get_organization_members_by_user( + user=user + ) + if not auth_controller.is_admin_by_role(member.role): + logger.error("User is not having right access to perform this operation.") + raise UserForbidden() + else: + pass + + @staticmethod + def validate_token_count(organization: Organization) -> None: + """This method validates if the organization has reached the maximum + platform key count. + + Args: + organization (Organization): + Organization for which the key is being created. + """ + key_count = PlatformKey.objects.filter(organization=organization).count() + if key_count >= PLATFORM_KEY_COUNT: + logger.error( + f"Key count exceeded: {key_count}/{PLATFORM_KEY_COUNT} keys for " + f"organization ID {organization.id}." 
+ ) + raise KeyCountExceeded() + else: + pass diff --git a/backend/platform_settings_v2/platform_auth_service.py b/backend/platform_settings_v2/platform_auth_service.py new file mode 100644 index 000000000..e9099e04c --- /dev/null +++ b/backend/platform_settings_v2/platform_auth_service.py @@ -0,0 +1,242 @@ +import logging +import uuid +from typing import Any, Optional + +from account_v2.models import Organization, PlatformKey, User +from account_v2.organization import OrganizationService +from django.db import IntegrityError +from platform_settings_v2.exceptions import ( + ActiveKeyNotFound, + DuplicateData, + InternalServiceError, + InvalidRequest, +) +from tenant_account_v2.constants import ErrorMessage, PlatformServiceConstants +from utils.user_context import UserContext + +logger = logging.getLogger(__name__) + + +class PlatformAuthenticationService: + """Service class to hold Platform service authentication and validation. + + Supports generation, refresh, revoke and toggle of active keys. + """ + + @staticmethod + def generate_platform_key( + is_active: bool, + key_name: str, + user: User, + organization: Optional[Organization] = None, + ) -> dict[str, Any]: + """Method to support generation of new platform key. Throws error when + maximum count is exceeded. Forbids for user other than admin + permission. + + Args: + key_name (str): Value of the key + is_active (bool): By default the key is False + user (User): User object representing the user generating the key + organization (Optional[Organization], optional): + Org the key belongs to. Defaults to None. + + Returns: + dict[str, Any]: + A dictionary containing the generated platform key details, + including the id, key name, and key value. + Raises: + DuplicateData: If a platform key with the same key name + already exists for the organization. + InternalServiceError: If an internal error occurs while + generating the platform key. 
+ """ + organization: Organization = organization or UserContext.get_organization() + if not organization: + raise InternalServiceError("No valid organization provided") + try: + # TODO : Add encryption to Platform keys + # id is added here to avoid passing of keys in transactions. + platform_key: PlatformKey = PlatformKey( + id=str(uuid.uuid4()), + key=str(uuid.uuid4()), + is_active=is_active, + organization=organization, + key_name=key_name, + created_by=user, + modified_by=user, + ) + platform_key.save() + result: dict[str, Any] = {} + result[PlatformServiceConstants.ID] = platform_key.id + result[PlatformServiceConstants.KEY_NAME] = platform_key.key_name + result[PlatformServiceConstants.KEY] = platform_key.key + + logger.info(f"platform_key is generated for {organization.id}") + return result + except IntegrityError as error: + logger.error( + "Failed to generate platform key for " + f"organization {organization}, Integrity error: {error}" + ) + raise DuplicateData( + f"{ErrorMessage.KEY_EXIST}, \ + {ErrorMessage.DUPLICATE_API}" + ) + + @staticmethod + def delete_platform_key(id: str) -> None: + """Method to delete a platform key by id. + + Args: + id (str): platform key primary id + + Raises: + error: IntegrityError + """ + try: + platform_key: PlatformKey = PlatformKey.objects.get(pk=id) + platform_key.delete() + # TODO: Add organization details in logs in possible places once v2 enabled + logger.info(f"platform_key {id} is deleted for {platform_key.organization}") + except IntegrityError as error: + logger.error(f"Failed to delete platform key : {error}") + raise DuplicateData( + f"{ErrorMessage.KEY_EXIST}, \ + {ErrorMessage.DUPLICATE_API}" + ) + + @staticmethod + def refresh_platform_key(id: str, user: User) -> dict[str, Any]: + """Method to refresh a platform key. + + Args: + id (str): Unique id of the key to be refreshed + new_key (str): Value to be updated. 
+ + Raises: + error: IntegrityError + """ + try: + result: dict[str, Any] = {} + platform_key: PlatformKey = PlatformKey.objects.get(pk=id) + platform_key.key = str(uuid.uuid4()) + platform_key.modified_by = user + platform_key.save() + result[PlatformServiceConstants.ID] = platform_key.id + result[PlatformServiceConstants.KEY_NAME] = platform_key.key_name + result[PlatformServiceConstants.KEY] = platform_key.key + + logger.info(f"platform_key {id} is updated by user {user.id}") + return result + except IntegrityError as error: + logger.error( + f"Failed to refresh platform key {id} " + f"by user {user.id}, Integrity error: {error}" + ) + raise DuplicateData( + f"{ErrorMessage.KEY_EXIST}, \ + {ErrorMessage.DUPLICATE_API}" + ) + + @staticmethod + def toggle_platform_key_status( + platform_key: PlatformKey, action: str, user: User + ) -> None: + """Method to activate/deactivate a platform key. Only one active key is + allowed at a time. On change or setting, other keys are deactivated. + + Args: + platform_key (PlatformKey): The platform key to be toggled. + action (str): activate/deactivate + user (User): The user performing the action. + + Raises: + InvalidRequest: If no valid organization is found. + DuplicateData: If an IntegrityError occurs during the save operation. 
+ """ + try: + organization: Organization = UserContext.get_organization() + if not organization: + logger.error( + f"No valid organization provided to toggle status of platform key " + f"{platform_key.id} for user {user.id}" + ) + raise InvalidRequest("Invalid organization") + platform_key.modified_by = user + if action == PlatformServiceConstants.ACTIVATE: + # Deactivate all active keys for the organization + PlatformKey.objects.filter( + is_active=True, organization=organization + ).update(is_active=False, modified_by=user) + # Activate the chosen key + platform_key.is_active = True + elif action == PlatformServiceConstants.DEACTIVATE: + platform_key.is_active = False + else: + logger.error( + f"Invalid action: {action} for platform key {platform_key.id} " + f"by user {user.id}" + ) + raise InvalidRequest(f"Invalid action: {action}") + platform_key.save() + except IntegrityError as error: + logger.error( + f"IntegrityError - Failed to {action} platform key {platform_key.id}" + f": {error}" + ) + raise DuplicateData( + f"{ErrorMessage.KEY_EXIST}, {ErrorMessage.DUPLICATE_API}" + ) + + @staticmethod + def list_platform_key_ids() -> list[PlatformKey]: + """Method to fetch list of platform keys unique ids for internal usage. + + Returns: + Any: List of platform keys. + """ + organization_id = UserContext.get_organization_identifier() + organization: Organization = OrganizationService.get_organization_by_org_id( + org_id=organization_id + ) + organization_pk = organization.id + + platform_keys: list[PlatformKey] = PlatformKey.objects.filter( + organization=organization_pk + ) + return platform_keys + + @staticmethod + def fetch_platform_key_id() -> Any: + """Method to fetch list of platform keys unique ids for internal usage. + + Returns: + Any: List of platform keys. 
+ """ + platform_key: list[PlatformKey] = PlatformKey.objects.all() + return platform_key + + @staticmethod + def get_active_platform_key( + organization_id: Optional[str] = None, + ) -> PlatformKey: + """Method to fetch active key. + + Considering only one active key is allowed at a time + Returns: + Any: platformKey. + """ + try: + organization_id = ( + organization_id or UserContext.get_organization_identifier() + ) + organization: Organization = OrganizationService.get_organization_by_org_id( + org_id=organization_id + ) + platform_key: PlatformKey = PlatformKey.objects.get( + organization=organization, is_active=True + ) + return platform_key + except PlatformKey.DoesNotExist: + raise ActiveKeyNotFound() diff --git a/backend/platform_settings_v2/serializers.py b/backend/platform_settings_v2/serializers.py new file mode 100644 index 000000000..24bd93ec5 --- /dev/null +++ b/backend/platform_settings_v2/serializers.py @@ -0,0 +1,24 @@ +from account_v2.models import PlatformKey +from rest_framework import serializers + +from backend.serializers import AuditSerializer + + +class PlatformKeySerializer(AuditSerializer): + class Meta: + model = PlatformKey + fields = "__all__" + + +class PlatformKeyGenerateSerializer(serializers.Serializer): + # Adjust these fields based on your actual serializer + is_active = serializers.BooleanField() + + key_name = serializers.CharField() + + +class PlatformKeyIDSerializer(serializers.Serializer): + id = serializers.CharField() + key_name = serializers.CharField() + key = serializers.CharField() + is_active = serializers.BooleanField() diff --git a/backend/platform_settings_v2/tests.py b/backend/platform_settings_v2/tests.py new file mode 100644 index 000000000..a39b155ac --- /dev/null +++ b/backend/platform_settings_v2/tests.py @@ -0,0 +1 @@ +# Create your tests here. 
diff --git a/backend/platform_settings_v2/urls.py b/backend/platform_settings_v2/urls.py new file mode 100644 index 000000000..feb1f5cc1 --- /dev/null +++ b/backend/platform_settings_v2/urls.py @@ -0,0 +1,26 @@ +from django.urls import path +from rest_framework.urlpatterns import format_suffix_patterns + +from .views import PlatformKeyViewSet + +platform_key_list = PlatformKeyViewSet.as_view( + {"post": "create", "put": "refresh", "get": "list"} +) +platform_key_update = PlatformKeyViewSet.as_view( + {"put": "toggle_platform_key", "delete": "destroy"} +) + +urlpatterns = format_suffix_patterns( + [ + path( + "keys/", + platform_key_list, + name="generate_platform_key", + ), + path( + "keys//", + platform_key_update, + name="update_platform_key", + ), + ] +) diff --git a/backend/platform_settings_v2/views.py b/backend/platform_settings_v2/views.py new file mode 100644 index 000000000..137fd1712 --- /dev/null +++ b/backend/platform_settings_v2/views.py @@ -0,0 +1,122 @@ +# views.py + +import logging +from typing import Any + +from account_v2.models import Organization, PlatformKey +from platform_settings_v2.constants import PlatformServiceConstants +from platform_settings_v2.platform_auth_helper import PlatformAuthHelper +from platform_settings_v2.platform_auth_service import PlatformAuthenticationService +from platform_settings_v2.serializers import ( + PlatformKeyGenerateSerializer, + PlatformKeyIDSerializer, + PlatformKeySerializer, +) +from rest_framework import status, viewsets +from rest_framework.request import Request +from rest_framework.response import Response +from utils.user_context import UserContext + +logger = logging.getLogger(__name__) + + +class PlatformKeyViewSet(viewsets.ModelViewSet): + queryset = PlatformKey.objects.all() + serializer_class = PlatformKeySerializer + + def validate_user_role(func: Any) -> Any: + def wrapper( + self: Any, + request: Request, + *args: tuple[Any], + **kwargs: dict[str, Any], + ) -> Any: + 
PlatformAuthHelper.validate_user_role(request.user) + return func(self, request, *args, **kwargs) + + return wrapper + + @validate_user_role + def list( + self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> Response: + platform_key_ids = PlatformAuthenticationService.list_platform_key_ids() + serializer = PlatformKeyIDSerializer(platform_key_ids, many=True) + return Response( + status=status.HTTP_200_OK, + data=serializer.data, + ) + + @validate_user_role + def refresh( + self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> Response: + """API Endpoint for refreshing platform keys.""" + id = request.data.get(PlatformServiceConstants.ID) + if not id: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={ + "message": "validation error", + "errors": "Mandatory fields missing", + }, + ) + platform_key = PlatformAuthenticationService.refresh_platform_key( + id=id, user=request.user + ) + return Response( + status=status.HTTP_201_CREATED, + data=platform_key, + ) + + @validate_user_role + def destroy( + self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> Response: + instance = self.get_object() + instance.delete() + return Response( + status=status.HTTP_204_NO_CONTENT, + data={"message": "Platform key deleted successfully"}, + ) + + @validate_user_role + def toggle_platform_key( + self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] + ) -> Response: + instance = self.get_object() + action = request.data.get(PlatformServiceConstants.ACTION) + if not action: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={ + "message": "validation error", + "errors": "Mandatory fields missing", + }, + ) + PlatformAuthenticationService.toggle_platform_key_status( + platform_key=instance, action=action, user=request.user + ) + return Response( + status=status.HTTP_201_CREATED, + data={"message": "Platform key toggled successfully"}, + ) + + @validate_user_role + def create(self, 
request: Request) -> Response: + serializer = PlatformKeyGenerateSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + is_active = request.data.get(PlatformServiceConstants.IS_ACTIVE) + key_name = request.data.get(PlatformServiceConstants.KEY_NAME) + organization: Organization = UserContext.get_organization() + + PlatformAuthHelper.validate_token_count(organization=organization) + + platform_key = PlatformAuthenticationService.generate_platform_key( + is_active=is_active, key_name=key_name, user=request.user + ) + serialized_data = self.serializer_class(platform_key).data + return Response( + status=status.HTTP_201_CREATED, + data=serialized_data, + ) diff --git a/backend/prompt_studio/prompt_profile_manager/serializers.py b/backend/prompt_studio/prompt_profile_manager/serializers.py index fc83aaab4..4d4753561 100644 --- a/backend/prompt_studio/prompt_profile_manager/serializers.py +++ b/backend/prompt_studio/prompt_profile_manager/serializers.py @@ -2,7 +2,6 @@ from adapter_processor.adapter_processor import AdapterProcessor from prompt_studio.prompt_profile_manager.constants import ProfileManagerKeys -from prompt_studio.prompt_studio_core.exceptions import MaxProfilesReachedError from backend.serializers import AuditSerializer @@ -39,15 +38,3 @@ def to_representation(self, instance): # type: ignore AdapterProcessor.get_adapter_instance_by_id(x2text) ) return rep - - def validate(self, data): - prompt_studio_tool = data.get(ProfileManagerKeys.PROMPT_STUDIO_TOOL) - - profile_count = ProfileManager.objects.filter( - prompt_studio_tool=prompt_studio_tool - ).count() - - if profile_count >= ProfileManagerKeys.MAX_PROFILE_COUNT: - raise MaxProfilesReachedError() - - return data diff --git a/backend/prompt_studio/prompt_studio_core/constants.py b/backend/prompt_studio/prompt_studio_core/constants.py index 55d61e32e..213fc066f 100644 --- a/backend/prompt_studio/prompt_studio_core/constants.py +++ 
b/backend/prompt_studio/prompt_studio_core/constants.py @@ -88,6 +88,7 @@ class ToolStudioPromptKeys: PROFILE_MANAGER_ID = "profile_manager" CONTEXT = "context" METADATA = "metadata" + INCLUDE_METADATA = "include_metadata" class FileViewTypes: diff --git a/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py b/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py index e1e8e76b5..c5513b7e6 100644 --- a/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py +++ b/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py @@ -498,7 +498,11 @@ def _execute_prompts_in_single_pass( doc_path, tool_id, org_id, user_id, document_id, run_id ): prompts = PromptStudioHelper.fetch_prompt_from_tool(tool_id) - prompts = [prompt for prompt in prompts if prompt.prompt_type != TSPKeys.NOTES] + prompts = [ + prompt + for prompt in prompts + if prompt.prompt_type != TSPKeys.NOTES and prompt.active + ] if not prompts: logger.error(f"[{tool_id or 'NA'}] No prompts found for id: {id}") raise NoPromptsFound() @@ -959,8 +963,9 @@ def _fetch_single_pass_response( prompt_host=settings.PROMPT_HOST, prompt_port=settings.PROMPT_PORT, ) + include_metadata = {TSPKeys.INCLUDE_METADATA: True} - answer = responder.single_pass_extraction(payload) + answer = responder.single_pass_extraction(payload, include_metadata) # TODO: Make use of dataclasses if answer["status"] == "ERROR": error_message = answer.get("error", None) @@ -969,3 +974,11 @@ def _fetch_single_pass_response( ) output_response = json.loads(answer["structure_output"]) return output_response + + @staticmethod + def get_tool_from_tool_id(tool_id: str) -> Optional[CustomTool]: + try: + tool: CustomTool = CustomTool.objects.get(tool_id=tool_id) + return tool + except CustomTool.DoesNotExist: + return None diff --git a/backend/prompt_studio/prompt_studio_core/views.py b/backend/prompt_studio/prompt_studio_core/views.py index 8db0a3ef5..18915ac24 100644 --- 
a/backend/prompt_studio/prompt_studio_core/views.py +++ b/backend/prompt_studio/prompt_studio_core/views.py @@ -10,7 +10,10 @@ from file_management.file_management_helper import FileManagerHelper from permissions.permission import IsOwner, IsOwnerOrSharedUser from prompt_studio.processor_loader import ProcessorConfig, load_plugins -from prompt_studio.prompt_profile_manager.constants import ProfileManagerErrors +from prompt_studio.prompt_profile_manager.constants import ( + ProfileManagerErrors, + ProfileManagerKeys, +) from prompt_studio.prompt_profile_manager.models import ProfileManager from prompt_studio.prompt_profile_manager.serializers import ProfileManagerSerializer from prompt_studio.prompt_studio.constants import ToolStudioPromptErrors @@ -26,6 +29,7 @@ ) from prompt_studio.prompt_studio_core.exceptions import ( IndexingAPIError, + MaxProfilesReachedError, ToolDeleteError, ) from prompt_studio.prompt_studio_core.prompt_studio_helper import PromptStudioHelper @@ -345,6 +349,16 @@ def create_profile_manager(self, request: HttpRequest, pk: Any = None) -> Respon serializer = ProfileManagerSerializer(data=request.data, context=context) serializer.is_valid(raise_exception=True) + # Check for the maximum number of profiles constraint + prompt_studio_tool = serializer.validated_data[ + ProfileManagerKeys.PROMPT_STUDIO_TOOL + ] + profile_count = ProfileManager.objects.filter( + prompt_studio_tool=prompt_studio_tool + ).count() + + if profile_count >= ProfileManagerKeys.MAX_PROFILE_COUNT: + raise MaxProfilesReachedError() try: self.perform_create(serializer) except IntegrityError: diff --git a/backend/prompt_studio/prompt_studio_output_manager/constants.py b/backend/prompt_studio/prompt_studio_output_manager/constants.py index a9c046aae..1cee0b394 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/constants.py +++ b/backend/prompt_studio/prompt_studio_output_manager/constants.py @@ -5,3 +5,8 @@ class PromptStudioOutputManagerKeys: DOCUMENT_MANAGER = 
"document_manager" IS_SINGLE_PASS_EXTRACT = "is_single_pass_extract" NOTES = "NOTES" + + +class PromptOutputManagerErrorMessage: + TOOL_VALIDATION = "tool_id parameter is required" + TOOL_NOT_FOUND = "Tool not found" diff --git a/backend/prompt_studio/prompt_studio_output_manager/urls.py b/backend/prompt_studio/prompt_studio_output_manager/urls.py index 45af0ccee..61ec8540f 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/urls.py +++ b/backend/prompt_studio/prompt_studio_output_manager/urls.py @@ -4,9 +4,17 @@ from .views import PromptStudioOutputView prompt_doc_list = PromptStudioOutputView.as_view({"get": "list"}) +get_output_for_tool_default = PromptStudioOutputView.as_view( + {"get": "get_output_for_tool_default"} +) urlpatterns = format_suffix_patterns( [ path("prompt-output/", prompt_doc_list, name="prompt-doc-list"), + path( + "prompt-output/prompt-default-profile/", + get_output_for_tool_default, + name="prompt-default-profile-outputs", + ), ] ) diff --git a/backend/prompt_studio/prompt_studio_output_manager/views.py b/backend/prompt_studio/prompt_studio_output_manager/views.py index 3be3c5888..e1ef58296 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/views.py +++ b/backend/prompt_studio/prompt_studio_output_manager/views.py @@ -1,14 +1,20 @@ import logging -from typing import Optional +from typing import Any, Optional +from django.core.exceptions import ObjectDoesNotExist from django.db.models import QuerySet +from django.http import HttpRequest +from prompt_studio.prompt_studio.models import ToolStudioPrompt from prompt_studio.prompt_studio_output_manager.constants import ( + PromptOutputManagerErrorMessage, PromptStudioOutputManagerKeys, ) from prompt_studio.prompt_studio_output_manager.serializers import ( PromptStudioOutputSerializer, ) -from rest_framework import viewsets +from rest_framework import status, viewsets +from rest_framework.exceptions import APIException +from rest_framework.response import Response from 
rest_framework.versioning import URLPathVersioning from utils.common_utils import CommonUtils from utils.filtering import FilterHelper @@ -49,3 +55,51 @@ def get_queryset(self) -> Optional[QuerySet]: queryset = PromptStudioOutputManager.objects.filter(**filter_args) return queryset + + def get_output_for_tool_default(self, request: HttpRequest) -> Response: + # Get the tool_id from request parameters + # Get the tool_id from request parameters + tool_id = request.GET.get("tool_id") + document_manager_id = request.GET.get("document_manager") + tool_validation_message = PromptOutputManagerErrorMessage.TOOL_VALIDATION + tool_not_found = PromptOutputManagerErrorMessage.TOOL_NOT_FOUND + if not tool_id: + raise APIException(detail=tool_validation_message, code=400) + + try: + # Fetch ToolStudioPrompt records based on tool_id + tool_studio_prompts = ToolStudioPrompt.objects.filter(tool_id=tool_id) + except ObjectDoesNotExist: + raise APIException(detail=tool_not_found, code=400) + + # Initialize the result dictionary + result: dict[str, Any] = {} + + # Iterate over ToolStudioPrompt records + for tool_prompt in tool_studio_prompts: + prompt_id = str(tool_prompt.prompt_id) + profile_manager_id = str(tool_prompt.profile_manager.profile_id) + + # If profile_manager is not set, skip this record + if not profile_manager_id: + result[tool_prompt.prompt_key] = "" + continue + + try: + queryset = PromptStudioOutputManager.objects.filter( + prompt_id=prompt_id, + profile_manager=profile_manager_id, + is_single_pass_extract=False, + document_manager_id=document_manager_id, + ) + + if not queryset.exists(): + result[tool_prompt.prompt_key] = "" + continue + + for output in queryset: + result[tool_prompt.prompt_key] = output.output + except ObjectDoesNotExist: + result[tool_prompt.prompt_key] = "" + + return Response(result, status=status.HTTP_200_OK) diff --git a/backend/prompt_studio/prompt_studio_registry/constants.py b/backend/prompt_studio/prompt_studio_registry/constants.py 
index 6dbde1c09..7fea2c73c 100644 --- a/backend/prompt_studio/prompt_studio_registry/constants.py +++ b/backend/prompt_studio/prompt_studio_registry/constants.py @@ -89,9 +89,6 @@ class JsonSchemaKey: ENABLE_CHALLENGE = "enable_challenge" CHALLENGE_LLM = "challenge_llm" ENABLE_SINGLE_PASS_EXTRACTION = "enable_single_pass_extraction" - IMAGE_URL = "image_url" - IMAGE_NAME = "image_name" - IMAGE_TAG = "image_tag" SUMMARIZE_PROMPT = "summarize_prompt" SUMMARIZE_AS_SOURCE = "summarize_as_source" ENABLE_HIGHLIGHT = "enable_highlight" diff --git a/backend/prompt_studio/prompt_studio_registry/prompt_studio_registry_helper.py b/backend/prompt_studio/prompt_studio_registry/prompt_studio_registry_helper.py index 67f84b450..33c253205 100644 --- a/backend/prompt_studio/prompt_studio_registry/prompt_studio_registry_helper.py +++ b/backend/prompt_studio/prompt_studio_registry/prompt_studio_registry_helper.py @@ -119,25 +119,14 @@ def get_tool_by_prompt_registry_id( f"ID {prompt_registry_id}: {e} " ) return None - # The below properties are introduced after 0.20.0 - # So defaulting to 0.20.0 if the properties are not found - image_url = prompt_registry_tool.tool_metadata.get( - JsonSchemaKey.IMAGE_URL, "docker:unstract/tool-structure:0.0.20" - ) - image_name = prompt_registry_tool.tool_metadata.get( - JsonSchemaKey.IMAGE_NAME, "unstract/tool-structure" - ) - image_tag = prompt_registry_tool.tool_metadata.get( - JsonSchemaKey.IMAGE_TAG, "0.0.20" - ) return Tool( tool_uid=prompt_registry_tool.prompt_registry_id, properties=Properties.from_dict(prompt_registry_tool.tool_property), spec=Spec.from_dict(prompt_registry_tool.tool_spec), icon=prompt_registry_tool.icon, - image_url=image_url, - image_name=image_name, - image_tag=image_tag, + image_url=settings.STRUCTURE_TOOL_IMAGE_URL, + image_name=settings.STRUCTURE_TOOL_IMAGE_NAME, + image_tag=settings.STRUCTURE_TOOL_IMAGE_TAG, ) @staticmethod @@ -176,7 +165,6 @@ def update_or_create_psr_tool( obj, created = 
PromptStudioRegistry.objects.update_or_create( custom_tool=custom_tool, created_by=custom_tool.created_by, - modified_by=custom_tool.modified_by, defaults={ "name": custom_tool.tool_name, "tool_property": properties.to_dict(), @@ -190,7 +178,7 @@ def update_or_create_psr_tool( logger.info(f"PSR {obj.prompt_registry_id} was created") else: logger.info(f"PSR {obj.prompt_registry_id} was updated") - + obj.modified_by = custom_tool.modified_by obj.shared_to_org = shared_with_org if not shared_with_org: obj.shared_users.clear() @@ -242,9 +230,6 @@ def frame_export_json( export_metadata[JsonSchemaKey.DESCRIPTION] = tool.description export_metadata[JsonSchemaKey.AUTHOR] = tool.author export_metadata[JsonSchemaKey.TOOL_ID] = str(tool.tool_id) - export_metadata[JsonSchemaKey.IMAGE_URL] = settings.STRUCTURE_TOOL_IMAGE_URL - export_metadata[JsonSchemaKey.IMAGE_NAME] = settings.STRUCTURE_TOOL_IMAGE_NAME - export_metadata[JsonSchemaKey.IMAGE_TAG] = settings.STRUCTURE_TOOL_IMAGE_TAG default_llm_profile = ProfileManager.get_default_llm_profile(tool) challenge_llm_instance: Optional[AdapterInstance] = tool.challenge_llm @@ -283,6 +268,8 @@ def frame_export_json( tool_settings[JsonSchemaKey.ENABLE_HIGHLIGHT] = tool.enable_highlight for prompt in prompts: + if prompt.prompt_type == JsonSchemaKey.NOTES or not prompt.active: + continue if not prompt.prompt: invalidated_prompts.append(prompt.prompt_key) @@ -298,8 +285,6 @@ def frame_export_json( invalidated_outputs.append(prompt.prompt_key) continue - if prompt.prompt_type == JsonSchemaKey.NOTES: - continue if not prompt.profile_manager: prompt.profile_manager = default_llm_profile diff --git a/backend/sample.env b/backend/sample.env index 5ff45c8a8..2cca7fc19 100644 --- a/backend/sample.env +++ b/backend/sample.env @@ -88,7 +88,8 @@ PROMPT_PORT=3003 #Prompt Studio PROMPT_STUDIO_FILE_PATH=/app/prompt-studio-data -# Structure Tool +# Structure Tool Image (Runs prompt studio exported tools) +# 
https://hub.docker.com/r/unstract/tool-structure STRUCTURE_TOOL_IMAGE_URL="docker:unstract/tool-structure:0.0.30" STRUCTURE_TOOL_IMAGE_NAME="unstract/tool-structure" STRUCTURE_TOOL_IMAGE_TAG="0.0.30" diff --git a/backend/tool_instance/serializers.py b/backend/tool_instance/serializers.py index 4ce323b49..f36b66f87 100644 --- a/backend/tool_instance/serializers.py +++ b/backend/tool_instance/serializers.py @@ -68,6 +68,11 @@ def create(self, validated_data: dict[str, Any]) -> Any: raise ValidationError(f"Workflow with ID {workflow_id} does not exist.") validated_data[TIKey.WORKFLOW] = workflow + if workflow.workflow_tool.count() > 0: + raise ValidationError( + f"Workflow with ID {workflow_id} can't have more than one tool." + ) + tool_uid = validated_data.get(TIKey.TOOL_ID) if not tool_uid: raise ToolDoesNotExist() diff --git a/backend/workflow_manager/endpoint/base_connector.py b/backend/workflow_manager/endpoint/base_connector.py index e90ae03c2..6c35910e5 100644 --- a/backend/workflow_manager/endpoint/base_connector.py +++ b/backend/workflow_manager/endpoint/base_connector.py @@ -45,11 +45,26 @@ def get_fsspec( Raises: KeyError: If the connector_id is not found in the connectors dictionary. """ + return self.get_fs_connector( + settings=settings, connector_id=connector_id + ).get_fsspec_fs() + + def get_fs_connector( + self, settings: dict[str, Any], connector_id: str + ) -> UnstractFileSystem: + """Get an fs connector based specified connector settings. + + Parameters: + - settings (dict): Connector-specific settings. + - connector_id (str): Identifier for the desired connector. + + Returns: + UnstractFileSystem: An unstract fs connector instance. 
+ """ if connector_id not in connectors: - raise ValueError(f"Invalid connector_id: {connector_id}") + raise ValueError(f"Connector '{connector_id}' is not supported.") connector = connectors[connector_id][Common.METADATA][Common.CONNECTOR] - connector_class: UnstractFileSystem = connector(settings) - return connector_class.get_fsspec_fs() + return connector(settings) @classmethod def get_json_schema(cls, file_path: str) -> dict[str, Any]: diff --git a/backend/workflow_manager/endpoint/constants.py b/backend/workflow_manager/endpoint/constants.py index d9553245d..ca84b5d15 100644 --- a/backend/workflow_manager/endpoint/constants.py +++ b/backend/workflow_manager/endpoint/constants.py @@ -70,9 +70,9 @@ class FileType: class FilePattern: - PDF_DOCUMENTS = ["*.pdf"] + PDF_DOCUMENTS = ["*.pdf", "*.PDF"] TEXT_DOCUMENTS = ["*.txt"] - IMAGES = ["*.jpg", "*.jpeg", "*.png", "*.gif", "*.bmp"] + IMAGES = ["*.jpg", "*.jpeg", "*.png", "*.gif", "*.bmp", "*.tif", "*.tiff"] class SourceConstant: diff --git a/backend/workflow_manager/endpoint/destination.py b/backend/workflow_manager/endpoint/destination.py index 6b88dd86c..bdb37385d 100644 --- a/backend/workflow_manager/endpoint/destination.py +++ b/backend/workflow_manager/endpoint/destination.py @@ -175,32 +175,29 @@ def copy_output_to_output_directory(self) -> None: connector: ConnectorInstance = self.endpoint.connector_instance connector_settings: dict[str, Any] = connector.connector_metadata destination_configurations: dict[str, Any] = self.endpoint.configuration - root_path = str(connector_settings.get(DestinationKey.PATH)) - output_folder = str( + root_path = connector_settings.get(DestinationKey.PATH, "") + + output_directory = str( destination_configurations.get(DestinationKey.OUTPUT_FOLDER, "/") ) - overwrite = bool( - destination_configurations.get( - DestinationKey.OVERWRITE_OUTPUT_DOCUMENT, False - ) + destination_fs = self.get_fs_connector( + settings=connector_settings, connector_id=connector.connector_id ) - 
output_directory = os.path.join(root_path, output_folder) - + output_directory = destination_fs.get_connector_root_dir( + input_dir=output_directory, root_path=root_path + ) + logger.debug(f"destination output directory {output_directory}") destination_volume_path = os.path.join( self.execution_dir, ToolExecKey.OUTPUT_DIR ) - - connector_fs = self.get_fsspec( - settings=connector_settings, connector_id=connector.connector_id - ) - if not connector_fs.isdir(output_directory): - connector_fs.mkdir(output_directory) + destination_fs.create_dir_if_not_exists(input_dir=output_directory) + destination_fsspec = destination_fs.get_fsspec_fs() # Traverse local directory and create the same structure in the # output_directory for root, dirs, files in os.walk(destination_volume_path): for dir_name in dirs: - connector_fs.mkdir( + destination_fsspec.mkdir( os.path.join( output_directory, os.path.relpath(root, destination_volume_path), @@ -217,9 +214,7 @@ def copy_output_to_output_directory(self) -> None: ) normalized_path = os.path.normpath(destination_path) with open(source_path, "rb") as source_file: - connector_fs.write_bytes( - normalized_path, source_file.read(), overwrite=overwrite - ) + destination_fsspec.write_bytes(normalized_path, source_file.read()) def insert_into_db(self, file_history: Optional[FileHistory]) -> None: """Insert data into the database.""" diff --git a/backend/workflow_manager/endpoint/source.py b/backend/workflow_manager/endpoint/source.py index 4dd096044..33cb97484 100644 --- a/backend/workflow_manager/endpoint/source.py +++ b/backend/workflow_manager/endpoint/source.py @@ -4,7 +4,6 @@ import shutil from hashlib import md5, sha256 from io import BytesIO -from pathlib import Path from typing import Any, Optional import fsspec @@ -153,21 +152,27 @@ def list_files_from_file_connector(self) -> list[str]: ) ) root_dir_path = connector_settings.get(ConnectorKeys.PATH, "") + input_directory = str(source_configurations.get(SourceKey.ROOT_FOLDER, "")) - if 
root_dir_path: # user needs to manually type the optional file path - input_directory = str(Path(root_dir_path, input_directory.lstrip("/"))) - if not isinstance(required_patterns, list): - required_patterns = [required_patterns] - source_fs = self.get_fsspec( + source_fs = self.get_fs_connector( settings=connector_settings, connector_id=connector.connector_id ) + input_directory = source_fs.get_connector_root_dir( + input_dir=input_directory, root_path=root_dir_path + ) + logger.debug(f"source input directory {input_directory}") + if not isinstance(required_patterns, list): + required_patterns = [required_patterns] + + source_fs_fsspec = source_fs.get_fsspec_fs() + patterns = self.valid_file_patterns(required_patterns=required_patterns) - is_directory = source_fs.isdir(input_directory) + is_directory = source_fs_fsspec.isdir(input_directory) if not is_directory: raise InvalidInputDirectory() matched_files = self._get_matched_files( - source_fs, input_directory, patterns, recursive, limit + source_fs_fsspec, input_directory, patterns, recursive, limit ) self.publish_input_output_list_file_logs(input_directory, matched_files) return matched_files @@ -386,6 +391,17 @@ def handle_final_result( results.append({"file": file_name, "result": result}) def load_file(self, input_file_path: str) -> tuple[str, BytesIO]: + """Load file contnt and file name based on the file path. 
+ + Args: + input_file_path (str): source file + + Raises: + InvalidSource: _description_ + + Returns: + tuple[str, BytesIO]: file_name , file content + """ connector: ConnectorInstance = self.endpoint.connector_instance connector_settings: dict[str, Any] = connector.connector_metadata source_fs: fsspec.AbstractFileSystem = self.get_fsspec( @@ -395,7 +411,7 @@ def load_file(self, input_file_path: str) -> tuple[str, BytesIO]: file_content = remote_file.read() file_stream = BytesIO(file_content) - return remote_file.key, file_stream + return os.path.basename(input_file_path), file_stream @classmethod def add_input_file_to_api_storage( diff --git a/backend/workflow_manager/endpoint/static/dest/file.json b/backend/workflow_manager/endpoint/static/dest/file.json index aa3a1263e..95111993f 100644 --- a/backend/workflow_manager/endpoint/static/dest/file.json +++ b/backend/workflow_manager/endpoint/static/dest/file.json @@ -14,12 +14,6 @@ "minLength": 1, "maxLength": 100, "format": "file-path" - }, - "overwriteOutput": { - "type": "boolean", - "title": "Overwrite existing output", - "default": true, - "description": "Used to overwrite output document" } } } diff --git a/backend/workflow_manager/workflow/generator.py b/backend/workflow_manager/workflow/generator.py index 7c901b4bf..31aabafb6 100644 --- a/backend/workflow_manager/workflow/generator.py +++ b/backend/workflow_manager/workflow/generator.py @@ -17,6 +17,7 @@ logger = logging.getLogger(__name__) +# TODO: Can be removed as not getting used with UX chnages. 
class WorkflowGenerator: """Helps with generating a workflow using the LLM.""" diff --git a/backend/workflow_manager/workflow/views.py b/backend/workflow_manager/workflow/views.py index c02569355..0a3c451cf 100644 --- a/backend/workflow_manager/workflow/views.py +++ b/backend/workflow_manager/workflow/views.py @@ -4,6 +4,7 @@ from connector.connector_instance_helper import ConnectorInstanceHelper from django.conf import settings from django.db.models.query import QuerySet +from numpy import deprecate_with_doc from permissions.permission import IsOwner from pipeline.models import Pipeline from pipeline.pipeline_processor import PipelineProcessor @@ -78,6 +79,7 @@ def get_serializer_class(self) -> serializers.Serializer: else: return WorkflowSerializer + @deprecate_with_doc("Not using with the latest UX chnages") def _generate_workflow(self, workflow_id: str) -> WorkflowGenerator: registry_tools: list[Tool] = ToolProcessor.get_registry_tools() generator = WorkflowGenerator(workflow_id=workflow_id) @@ -86,18 +88,12 @@ def _generate_workflow(self, workflow_id: str) -> WorkflowGenerator: return generator def perform_update(self, serializer: WorkflowSerializer) -> Workflow: - """To edit a workflow. Regenerates the tool instances for a new prompt. + """To edit a workflow. Raises: WorkflowGenerationError """ kwargs = {} - if serializer.validated_data.get(WorkflowKey.PROMPT_TEXT): - workflow: Workflow = self.get_object() - generator = self._generate_workflow(workflow_id=workflow.id) - kwargs = { - WorkflowKey.LLM_RESPONSE: generator.llm_response, - WorkflowKey.WF_IS_ACTIVE: True, - } + try: workflow = serializer.save(**kwargs) return workflow diff --git a/docker/scripts/merge_env.py b/docker/scripts/merge_env.py index 3099da501..9a328b00c 100644 --- a/docker/scripts/merge_env.py +++ b/docker/scripts/merge_env.py @@ -54,8 +54,8 @@ def _merge_to_env_file(base_env_file_path: str, target_env: dict[str, str] = {}) target env. Args: - base_env_path (string): Base env file path. 
- target_env (dict, optional): Target env to use for merge. + base_env_file_path (string): Base env file path e.g. `sample.env` + target_env (dict, optional): Target env to use for merge e.g. `.env` Returns: string: File contents after merge. diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 018771100..2c7eec1d6 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -25,11 +25,13 @@ "cronstrue": "^2.48.0", "emoji-picker-react": "^4.8.0", "emoji-regex": "^10.3.0", + "file-saver": "^2.0.5", "framer-motion": "^11.2.10", "handlebars": "^4.7.8", "http-proxy-middleware": "^2.0.6", "js-cookie": "^3.0.5", "js-yaml": "^4.1.0", + "json-2-csv": "^5.5.4", "markdown-to-jsx": "^7.2.1", "moment": "^2.29.4", "moment-timezone": "^0.5.45", @@ -7520,6 +7522,14 @@ "resolved": "https://registry.npmjs.org/dedent/-/dedent-0.7.0.tgz", "integrity": "sha512-Q6fKUPqnAHAyhiUgFU7BUzLiv0kd8saH9al7tnu5Q/okj6dnupxyTgFIBjVzJATdfIAm9NAsvXNzjaKa+bxVyA==" }, + "node_modules/deeks": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/deeks/-/deeks-3.1.0.tgz", + "integrity": "sha512-e7oWH1LzIdv/prMQ7pmlDlaVoL64glqzvNgkgQNgyec9ORPHrT2jaOqMtRyqJuwWjtfb6v+2rk9pmaHj+F137A==", + "engines": { + "node": ">= 16" + } + }, "node_modules/deep-equal": { "version": "2.2.1", "resolved": "https://registry.npmjs.org/deep-equal/-/deep-equal-2.2.1.tgz", @@ -7748,6 +7758,14 @@ "node": ">=6" } }, + "node_modules/doc-path": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/doc-path/-/doc-path-4.1.1.tgz", + "integrity": "sha512-h1ErTglQAVv2gCnOpD3sFS6uolDbOKHDU1BZq+Kl3npPqroU3dYL42lUgMfd5UimlwtRgp7C9dLGwqQ5D2HYgQ==", + "engines": { + "node": ">=16" + } + }, "node_modules/doctrine": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/doctrine/-/doctrine-3.0.0.tgz", @@ -9045,6 +9063,11 @@ "webpack": "^4.0.0 || ^5.0.0" } }, + "node_modules/file-saver": { + "version": "2.0.5", + "resolved": 
"https://registry.npmjs.org/file-saver/-/file-saver-2.0.5.tgz", + "integrity": "sha512-P9bmyZ3h/PRG+Nzga+rbdI4OEpNDzAVyy74uVO9ATgzLK6VtAsYybF/+TOCvrc0MO793d6+42lLyZTw7/ArVzA==" + }, "node_modules/filelist": { "version": "1.0.4", "resolved": "https://registry.npmjs.org/filelist/-/filelist-1.0.4.tgz", @@ -12749,6 +12772,18 @@ "node": ">=4" } }, + "node_modules/json-2-csv": { + "version": "5.5.4", + "resolved": "https://registry.npmjs.org/json-2-csv/-/json-2-csv-5.5.4.tgz", + "integrity": "sha512-gB24IF5SvZn7QhEh6kp9QwFhRnI3FVEEXAGyq0xtPxqOQ4odYU3PU9pFKRoR1SGABxunQlBP6VFv0c8EnLbsLQ==", + "dependencies": { + "deeks": "3.1.0", + "doc-path": "4.1.1" + }, + "engines": { + "node": ">= 16" + } + }, "node_modules/json-parse-even-better-errors": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/json-parse-even-better-errors/-/json-parse-even-better-errors-2.3.1.tgz", @@ -25878,6 +25913,11 @@ "resolved": "https://registry.npmjs.org/dedent/-/dedent-0.7.0.tgz", "integrity": "sha512-Q6fKUPqnAHAyhiUgFU7BUzLiv0kd8saH9al7tnu5Q/okj6dnupxyTgFIBjVzJATdfIAm9NAsvXNzjaKa+bxVyA==" }, + "deeks": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/deeks/-/deeks-3.1.0.tgz", + "integrity": "sha512-e7oWH1LzIdv/prMQ7pmlDlaVoL64glqzvNgkgQNgyec9ORPHrT2jaOqMtRyqJuwWjtfb6v+2rk9pmaHj+F137A==" + }, "deep-equal": { "version": "2.2.1", "resolved": "https://registry.npmjs.org/deep-equal/-/deep-equal-2.2.1.tgz", @@ -26052,6 +26092,11 @@ "@leichtgewicht/ip-codec": "^2.0.1" } }, + "doc-path": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/doc-path/-/doc-path-4.1.1.tgz", + "integrity": "sha512-h1ErTglQAVv2gCnOpD3sFS6uolDbOKHDU1BZq+Kl3npPqroU3dYL42lUgMfd5UimlwtRgp7C9dLGwqQ5D2HYgQ==" + }, "doctrine": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/doctrine/-/doctrine-3.0.0.tgz", @@ -27018,6 +27063,11 @@ "schema-utils": "^3.0.0" } }, + "file-saver": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/file-saver/-/file-saver-2.0.5.tgz", + 
"integrity": "sha512-P9bmyZ3h/PRG+Nzga+rbdI4OEpNDzAVyy74uVO9ATgzLK6VtAsYybF/+TOCvrc0MO793d6+42lLyZTw7/ArVzA==" + }, "filelist": { "version": "1.0.4", "resolved": "https://registry.npmjs.org/filelist/-/filelist-1.0.4.tgz", @@ -29794,6 +29844,15 @@ "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-2.5.2.tgz", "integrity": "sha512-OYu7XEzjkCQ3C5Ps3QIZsQfNpqoJyZZA99wd9aWd05NCtC5pWOkShK2mkL6HXQR6/Cy2lbNdPlZBpuQHXE63gA==" }, + "json-2-csv": { + "version": "5.5.4", + "resolved": "https://registry.npmjs.org/json-2-csv/-/json-2-csv-5.5.4.tgz", + "integrity": "sha512-gB24IF5SvZn7QhEh6kp9QwFhRnI3FVEEXAGyq0xtPxqOQ4odYU3PU9pFKRoR1SGABxunQlBP6VFv0c8EnLbsLQ==", + "requires": { + "deeks": "3.1.0", + "doc-path": "4.1.1" + } + }, "json-parse-even-better-errors": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/json-parse-even-better-errors/-/json-parse-even-better-errors-2.3.1.tgz", diff --git a/frontend/package.json b/frontend/package.json index 674ef6689..9a9bd21ae 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -20,11 +20,13 @@ "cronstrue": "^2.48.0", "emoji-picker-react": "^4.8.0", "emoji-regex": "^10.3.0", + "file-saver": "^2.0.5", "framer-motion": "^11.2.10", "handlebars": "^4.7.8", "http-proxy-middleware": "^2.0.6", "js-cookie": "^3.0.5", "js-yaml": "^4.1.0", + "json-2-csv": "^5.5.4", "markdown-to-jsx": "^7.2.1", "moment": "^2.29.4", "moment-timezone": "^0.5.45", diff --git a/frontend/src/components/agency/actions/Actions.jsx b/frontend/src/components/agency/actions/Actions.jsx index faa973bf0..ba57e40fb 100644 --- a/frontend/src/components/agency/actions/Actions.jsx +++ b/frontend/src/components/agency/actions/Actions.jsx @@ -77,10 +77,10 @@ function Actions({ statusBarMsg, initializeWfComp, stepLoader }) { // Enable Deploy as ETL Pipeline only when // destination connection_type is DATABASE and Source & Destination are Configured setCanAddETLPipeline( - (destination?.connection_type === "DATABASE" || - destination.connection_type === 
"MANUALREVIEW") && - source?.connector_instance && - destination.connector_instance + source?.connector_instance && + ((destination?.connection_type === "DATABASE" && + destination.connector_instance) || + destination.connection_type === "MANUALREVIEW") ); }, [source, destination]); useEffect(() => { diff --git a/frontend/src/components/agency/configure-connector-modal/ConfigureConnectorModal.jsx b/frontend/src/components/agency/configure-connector-modal/ConfigureConnectorModal.jsx index e781869e4..00f925625 100644 --- a/frontend/src/components/agency/configure-connector-modal/ConfigureConnectorModal.jsx +++ b/frontend/src/components/agency/configure-connector-modal/ConfigureConnectorModal.jsx @@ -1,6 +1,6 @@ import { Col, Modal, Row, Tabs, Typography } from "antd"; import PropTypes from "prop-types"; -import { useState } from "react"; +import { useEffect, useState } from "react"; import { ListOfConnectors } from "../list-of-connectors/ListOfConnectors"; import "./ConfigureConnectorModal.css"; @@ -28,6 +28,13 @@ function ConfigureConnectorModal({ setSelectedItemName, }) { const [activeKey, setActiveKey] = useState("1"); + useEffect(() => { + if (connectorMetadata) { + setActiveKey("2"); // If connector is already configured + } else { + setActiveKey("1"); // default value + } + }, [open, connectorMetadata]); const { setPostHogCustomEvent, posthogConnectorEventText } = usePostHogEvents(); const tabItems = [ diff --git a/frontend/src/components/agency/configure-forms-layout/ConfigureFormsLayout.jsx b/frontend/src/components/agency/configure-forms-layout/ConfigureFormsLayout.jsx index 9f6a4ed9b..d423e39e1 100644 --- a/frontend/src/components/agency/configure-forms-layout/ConfigureFormsLayout.jsx +++ b/frontend/src/components/agency/configure-forms-layout/ConfigureFormsLayout.jsx @@ -45,6 +45,7 @@ function ConfigureFormsLayout({ connDetails={connDetails} connType={connType} selectedSourceName={selectedItemName} + formDataConfig={formDataConfig} /> )} diff --git 
a/frontend/src/components/agency/ds-settings-card/DsSettingsCard.jsx b/frontend/src/components/agency/ds-settings-card/DsSettingsCard.jsx index 9681a96c8..8d5b193e9 100644 --- a/frontend/src/components/agency/ds-settings-card/DsSettingsCard.jsx +++ b/frontend/src/components/agency/ds-settings-card/DsSettingsCard.jsx @@ -36,8 +36,6 @@ const tooltip = { const disabledIdsByType = { FILE_SYSTEM: [ "box|4d94d237-ce4b-45d8-8f34-ddeefc37c0bf", - "google_cloud_storage|109bbe7b-8861-45eb-8841-7244e833d97b", - "azure_cloud_storage|1476a54a-ed17-4a01-9f8f-cb7e4cf91c8a", "http|6fdea346-86e4-4383-9a21-132db7c9a576", ], }; @@ -89,21 +87,26 @@ function DsSettingsCard({ type, endpointDetails, message }) { input: , output: , }; + + const setUpdatedInputoptions = (inputOption) => { + setInputOptions((prevInputOptions) => { + // Check if inputOption already exists in prevInputOptions + if (prevInputOptions.some((opt) => opt.value === inputOption.value)) { + return prevInputOptions; // Return previous state unchanged + } else { + // Create a new array with the existing options and the new option + const updatedInputOptions = [...prevInputOptions, inputOption]; + return updatedInputOptions; + } + }); + }; + useEffect(() => { try { const inputOption = require("../../../plugins/dscard-input-options/DsSettingsCardInputOptions").inputOption; if (flags.manual_review && inputOption) { - setInputOptions((prevInputOptions) => { - // Check if inputOption already exists in prevInputOptions - if (prevInputOptions.some((opt) => opt.value === inputOption.value)) { - return prevInputOptions; // Return previous state unchanged - } else { - // Create a new array with the existing options and the new option - const updatedInputOptions = [...prevInputOptions, inputOption]; - return updatedInputOptions; - } - }); + setUpdatedInputoptions(inputOption); } } catch { // The component will remain null of it is not available @@ -114,9 +117,7 @@ function DsSettingsCard({ type, endpointDetails, message }) { 
const inputOption = require("../../../plugins/dscard-input-options/AppDeploymentCardInputOptions").appDeploymentInputOption; if (flags.app_deployment && inputOption) { - const updatedInputOptions = inputOptions; - updatedInputOptions.push(inputOption); - setInputOptions(updatedInputOptions); + setUpdatedInputoptions(inputOption); } } catch { // The component will remain null of it is not available diff --git a/frontend/src/components/agency/workflow-execution-layout/WorkflowExecutionMain.jsx b/frontend/src/components/agency/workflow-execution-layout/WorkflowExecutionMain.jsx index fb5acdaec..2d11f0eeb 100644 --- a/frontend/src/components/agency/workflow-execution-layout/WorkflowExecutionMain.jsx +++ b/frontend/src/components/agency/workflow-execution-layout/WorkflowExecutionMain.jsx @@ -1,7 +1,6 @@ import { Col, Row } from "antd"; import PropTypes from "prop-types"; -import { Prompt } from "../prompt/Prompt"; import { Steps } from "../steps/Steps"; import "./WorkflowExecutionMain.css"; import { InputOutput } from "../input-output/InputOutput"; @@ -19,9 +18,6 @@ function WorkflowExecutionMain({
-
- -
{ const data = res?.data || []; const prompts = details?.prompts; + if (activeKey === "0") { + const output = {}; + for (const key in data) { + if (Object.hasOwn(data, key)) { + output[key] = displayPromptResult(data[key], false); + } + } + setCombinedOutput(output); + return; + } const output = {}; prompts.forEach((item) => { if (item?.prompt_type === promptType.notes) { @@ -70,10 +93,7 @@ function CombinedOutput({ docId, setFilledFields }) { } output[item?.prompt_key] = ""; - let profileManager = selectedProfile || item?.profile_manager; - if (singlePassExtractMode) { - profileManager = defaultLlmProfile; - } + const profileManager = selectedProfile || item?.profile_manager; const outputDetails = data.find( (outputValue) => outputValue?.prompt_id === item?.prompt_id && @@ -119,6 +139,13 @@ function CombinedOutput({ docId, setFilledFields }) { let url; if (isSimplePromptStudio) { url = promptOutputApiSps(details?.tool_id, null, docId); + } else if (isPublicSource) { + url = publicOutputsDocApi( + id, + docId, + selectedProfile || defaultLlmProfile, + singlePassExtractMode + ); } else { url = `/api/v1/unstract/${ sessionDetails?.orgId @@ -127,6 +154,9 @@ function CombinedOutput({ docId, setFilledFields }) { }&document_manager=${docId}&is_single_pass_extract=${singlePassExtractMode}&profile_manager=${ selectedProfile || defaultLlmProfile }`; + if (activeKey === "0") { + url = `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/prompt-output/prompt-default-profile/?tool_id=${details?.tool_id}&document_manager=${docId}`; + } } const requestOptions = { method: "GET", @@ -135,7 +165,6 @@ function CombinedOutput({ docId, setFilledFields }) { "X-CSRFToken": sessionDetails?.csrfToken, }, }; - return axiosPrivate(requestOptions) .then((res) => res) .catch((err) => { @@ -144,14 +173,14 @@ function CombinedOutput({ docId, setFilledFields }) { }; const getAdapterInfo = () => { - axiosPrivate - .get( - `/api/v1/unstract/${sessionDetails?.orgId}/adapter/?adapter_type=LLM` - 
) - .then((res) => { - const adapterList = res?.data; - setAdapterData(getLLMModelNamesForProfiles(llmProfiles, adapterList)); - }); + let url = `/api/v1/unstract/${sessionDetails?.orgId}/adapter/?adapter_type=LLM`; + if (isPublicSource) { + url = publicAdapterApi(id, "LLM"); + } + axiosPrivate.get(url).then((res) => { + const adapterList = res?.data; + setAdapterData(getLLMModelNamesForProfiles(llmProfiles, adapterList)); + }); }; if (isOutputLoading) { @@ -178,6 +207,8 @@ function CombinedOutput({ docId, setFilledFields }) { selectedProfile={selectedProfile} llmProfiles={llmProfiles} activeKey={activeKey} + adapterData={adapterData} + isSinglePass={singlePassExtractMode} /> ); } diff --git a/frontend/src/components/custom-tools/combined-output/JsonView.jsx b/frontend/src/components/custom-tools/combined-output/JsonView.jsx index 61bfc9f46..9b48af632 100644 --- a/frontend/src/components/custom-tools/combined-output/JsonView.jsx +++ b/frontend/src/components/custom-tools/combined-output/JsonView.jsx @@ -12,6 +12,7 @@ function JsonView({ activeKey, selectedProfile, llmProfiles, + isSinglePass, }) { useEffect(() => { Prism.highlightAll(); @@ -21,7 +22,9 @@ function JsonView({
}> - Default} key={"0"}> + {!isSinglePass && ( + Default} key={"0"}> + )} {adapterData.map((adapter, index) => ( {adapter.llm_model}} @@ -54,6 +57,7 @@ JsonView.propTypes = { selectedProfile: PropTypes.string, llmProfiles: PropTypes.array, activeKey: PropTypes.string, + isSinglePass: PropTypes.bool, }; export { JsonView }; diff --git a/frontend/src/components/custom-tools/custom-synonyms/CustomSynonyms.jsx b/frontend/src/components/custom-tools/custom-synonyms/CustomSynonyms.jsx index 2a85e4cb1..e6422218b 100644 --- a/frontend/src/components/custom-tools/custom-synonyms/CustomSynonyms.jsx +++ b/frontend/src/components/custom-tools/custom-synonyms/CustomSynonyms.jsx @@ -37,7 +37,7 @@ function CustomSynonyms() { const [rows, setRows] = useState([]); const [isLoading, setIsLoading] = useState(false); const { sessionDetails } = useSessionStore(); - const { details, updateCustomTool } = useCustomToolStore(); + const { details, updateCustomTool, isPublicSource } = useCustomToolStore(); const { setAlertDetails } = useAlertStore(); const axiosPrivate = useAxiosPrivate(); const handleException = useExceptionHandler(); @@ -96,7 +96,7 @@ function CustomSynonyms() { handleConfirm={() => handleDelete(index)} content="The word, along with its corresponding synonyms, will be permanently deleted." 
> - @@ -213,6 +213,7 @@ function CustomSynonyms() { type="primary" icon={} onClick={handleAddRow} + disabled={isPublicSource} > Rows @@ -223,6 +224,7 @@ function CustomSynonyms() { type="primary" onClick={handleSave} loading={isLoading} + disabled={isPublicSource} > Save diff --git a/frontend/src/components/custom-tools/document-manager/DocumentManager.jsx b/frontend/src/components/custom-tools/document-manager/DocumentManager.jsx index 249a34edd..0b2a27fa3 100644 --- a/frontend/src/components/custom-tools/document-manager/DocumentManager.jsx +++ b/frontend/src/components/custom-tools/document-manager/DocumentManager.jsx @@ -21,6 +21,7 @@ import { ManageDocsModal } from "../manage-docs-modal/ManageDocsModal"; import { PdfViewer } from "../pdf-viewer/PdfViewer"; import { TextViewerPre } from "../text-viewer-pre/TextViewerPre"; import usePostHogEvents from "../../../hooks/usePostHogEvents"; +import { useParams } from "react-router-dom"; const items = [ { @@ -63,7 +64,13 @@ try { } catch { // The component will remain null of it is not available } - +let publicDocumentApi; +try { + publicDocumentApi = + require("../../../plugins/prompt-studio-public-share/helpers/PublicShareAPIs").publicDocumentApi; +} catch { + // The component will remain null of it is not available +} function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) { const [openManageDocsModal, setOpenManageDocsModal] = useState(false); const [page, setPage] = useState(1); @@ -85,10 +92,12 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) { indexDocs, isSinglePassExtractLoading, isSimplePromptStudio, + isPublicSource, } = useCustomToolStore(); const { sessionDetails } = useSessionStore(); const axiosPrivate = useAxiosPrivate(); const { setPostHogCustomEvent } = usePostHogEvents(); + const { id } = useParams(); useEffect(() => { if (isSimplePromptStudio) { @@ -186,11 +195,14 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange 
}) { }; const getDocuments = async (viewType) => { + let url = `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/file/${details?.tool_id}?document_id=${selectedDoc?.document_id}&view_type=${viewType}`; + if (isPublicSource) { + url = publicDocumentApi(id, selectedDoc?.document_id, viewType); + } const requestOptions = { + url, method: "GET", - url: `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/file/${details?.tool_id}?document_id=${selectedDoc?.document_id}&view_type=${viewType}`, }; - return axiosPrivate(requestOptions) .then((res) => res) .catch((err) => { diff --git a/frontend/src/components/custom-tools/editable-text/EditableText.jsx b/frontend/src/components/custom-tools/editable-text/EditableText.jsx index 2344429e5..37c2c81bd 100644 --- a/frontend/src/components/custom-tools/editable-text/EditableText.jsx +++ b/frontend/src/components/custom-tools/editable-text/EditableText.jsx @@ -25,6 +25,7 @@ function EditableText({ indexDocs, selectedDoc, isSinglePassExtractLoading, + isPublicSource, } = useCustomToolStore(); useEffect(() => { @@ -90,7 +91,9 @@ function EditableText({ onBlur={handleBlur} onClick={() => setIsEditing(true)} disabled={ - disableLlmOrDocChange.includes(promptId) || isSinglePassExtractLoading + disableLlmOrDocChange.includes(promptId) || + isSinglePassExtractLoading || + isPublicSource } /> ); @@ -114,7 +117,8 @@ function EditableText({ disabled={ disableLlmOrDocChange.includes(promptId) || indexDocs.includes(selectedDoc?.document_id) || - isSinglePassExtractLoading + isSinglePassExtractLoading || + isPublicSource } /> ); diff --git a/frontend/src/components/custom-tools/header-title/HeaderTitle.css b/frontend/src/components/custom-tools/header-title/HeaderTitle.css new file mode 100644 index 000000000..919146393 --- /dev/null +++ b/frontend/src/components/custom-tools/header-title/HeaderTitle.css @@ -0,0 +1,12 @@ +.custom-tools-name { + padding: 0px 8px; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; 
+ margin-left:auto; +} +.custom-tools-header { + display: flex; + justify-content: space-between; + margin-right: auto; +} diff --git a/frontend/src/components/custom-tools/header-title/HeaderTitle.jsx b/frontend/src/components/custom-tools/header-title/HeaderTitle.jsx new file mode 100644 index 000000000..d5db9cab6 --- /dev/null +++ b/frontend/src/components/custom-tools/header-title/HeaderTitle.jsx @@ -0,0 +1,34 @@ +import { ArrowLeftOutlined, EditOutlined } from "@ant-design/icons"; +import { Button, Typography } from "antd"; +import { useNavigate } from "react-router-dom"; + +import { useCustomToolStore } from "../../../store/custom-tool-store"; +import { useSessionStore } from "../../../store/session-store"; +import "./HeaderTitle.css"; + +function HeaderTitle() { + const navigate = useNavigate(); + const { details } = useCustomToolStore(); + const { sessionDetails } = useSessionStore(); + + return ( +
+
+ +
+
+ {details?.tool_name} + +
+
+ ); +} +export { HeaderTitle }; diff --git a/frontend/src/components/custom-tools/header/Header.jsx b/frontend/src/components/custom-tools/header/Header.jsx index 8cdf6bc56..da24362d3 100644 --- a/frontend/src/components/custom-tools/header/Header.jsx +++ b/frontend/src/components/custom-tools/header/Header.jsx @@ -1,14 +1,9 @@ -import { - ArrowLeftOutlined, - EditOutlined, - SettingOutlined, -} from "@ant-design/icons"; +import { SettingOutlined } from "@ant-design/icons"; import { Button, Tooltip, Typography } from "antd"; import PropTypes from "prop-types"; import { useState } from "react"; -import { useNavigate } from "react-router-dom"; -import "./Header.css"; +import { HeaderTitle } from "../header-title/HeaderTitle.jsx"; import { ExportToolIcon } from "../../../assets"; import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate"; import { useExceptionHandler } from "../../../hooks/useExceptionHandler"; @@ -18,21 +13,36 @@ import { useSessionStore } from "../../../store/session-store"; import { CustomButton } from "../../widgets/custom-button/CustomButton"; import { ExportTool } from "../export-tool/ExportTool"; import usePostHogEvents from "../../../hooks/usePostHogEvents"; +import "./Header.css"; let SinglePassToggleSwitch; +let CloneButton; +let PromptShareButton; try { SinglePassToggleSwitch = require("../../../plugins/single-pass-toggle-switch/SinglePassToggleSwitch").SinglePassToggleSwitch; } catch { // The variable will remain undefined if the component is not available. } -function Header({ setOpenSettings, handleUpdateTool }) { +try { + PromptShareButton = + require("../../../plugins/prompt-studio-public-share/public-share-btn/PromptShareButton.jsx").PromptShareButton; + CloneButton = + require("../../../plugins/prompt-studio-clone/clone-btn/CloneButton.jsx").CloneButton; +} catch { + // The variable will remain undefined if the component is not available. 
+} +function Header({ + setOpenSettings, + handleUpdateTool, + setOpenShareModal, + setOpenCloneModal, +}) { const [isExportLoading, setIsExportLoading] = useState(false); - const { details } = useCustomToolStore(); + const { details, isPublicSource } = useCustomToolStore(); const { sessionDetails } = useSessionStore(); const { setAlertDetails } = useAlertStore(); const axiosPrivate = useAxiosPrivate(); - const navigate = useNavigate(); const handleException = useExceptionHandler(); const [userList, setUserList] = useState([]); const [openExportToolModal, setOpenExportToolModal] = useState(false); @@ -138,23 +148,15 @@ function Header({ setOpenSettings, handleUpdateTool }) { return (
-
- -
-
- {details?.tool_name} -
-
- -
+ {isPublicSource ? ( +
+ + {details?.tool_name} + +
+ ) : ( + + )}
{SinglePassToggleSwitch && ( @@ -167,6 +169,10 @@ function Header({ setOpenSettings, handleUpdateTool }) { />
+ {CloneButton && } + {PromptShareButton && ( + + )}
@@ -174,6 +180,7 @@ function Header({ setOpenSettings, handleUpdateTool }) { type="primary" onClick={() => handleShare(true)} loading={isExportLoading} + disabled={isPublicSource} > @@ -195,6 +202,8 @@ function Header({ setOpenSettings, handleUpdateTool }) { Header.propTypes = { setOpenSettings: PropTypes.func.isRequired, handleUpdateTool: PropTypes.func.isRequired, + setOpenCloneModal: PropTypes.func.isRequired, + setOpenShareModal: PropTypes.func.isRequired, }; export { Header }; diff --git a/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx b/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx index e82457114..49ff09bbf 100644 --- a/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx +++ b/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx @@ -75,6 +75,7 @@ function ManageDocsModal({ rawIndexStatus, summarizeIndexStatus, isSinglePassExtractLoading, + isPublicSource, } = useCustomToolStore(); const { messages } = useSocketCustomToolStore(); const axiosPrivate = useAxiosPrivate(); @@ -137,10 +138,9 @@ function ManageDocsModal({ }, [defaultLlmProfile, details]); useEffect(() => { - if (!open) { + if (!open || isPublicSource) { return; } - handleGetIndexStatus(rawLlmProfile, indexTypes.raw); }, [indexDocs, rawLlmProfile, open]); @@ -384,7 +384,8 @@ function ManageDocsModal({ isSinglePassExtractLoading || indexDocs.includes(item?.document_id) || isUploading || - !defaultLlmProfile + !defaultLlmProfile || + isPublicSource } /> @@ -408,7 +409,8 @@ function ManageDocsModal({ disableLlmOrDocChange?.length > 0 || isSinglePassExtractLoading || indexDocs.includes(item?.document_id) || - isUploading + isUploading || + isPublicSource } > @@ -423,7 +425,8 @@ function ManageDocsModal({ disabled={ disableLlmOrDocChange?.length > 0 || isSinglePassExtractLoading || - indexDocs.includes(item?.document_id) + indexDocs.includes(item?.document_id) || + isPublicSource } /> ), diff --git 
a/frontend/src/components/custom-tools/manage-llm-profiles/ManageLlmProfiles.jsx b/frontend/src/components/custom-tools/manage-llm-profiles/ManageLlmProfiles.jsx index c49b40c80..31e9f5304 100644 --- a/frontend/src/components/custom-tools/manage-llm-profiles/ManageLlmProfiles.jsx +++ b/frontend/src/components/custom-tools/manage-llm-profiles/ManageLlmProfiles.jsx @@ -1,5 +1,5 @@ import { DeleteOutlined, EditOutlined } from "@ant-design/icons"; -import { Button, Radio, Table, Typography } from "antd"; +import { Button, Radio, Table, Tooltip, Typography } from "antd"; import { useEffect, useState } from "react"; import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate"; @@ -65,11 +65,18 @@ function ManageLlmProfiles() { const [editLlmProfileId, setEditLlmProfileId] = useState(null); const axiosPrivate = useAxiosPrivate(); const { sessionDetails } = useSessionStore(); - const { details, defaultLlmProfile, updateCustomTool, llmProfiles } = - useCustomToolStore(); + const { + details, + defaultLlmProfile, + updateCustomTool, + llmProfiles, + isPublicSource, + } = useCustomToolStore(); const { setAlertDetails } = useAlertStore(); const handleException = useExceptionHandler(); const { setPostHogCustomEvent } = usePostHogEvents(); + const MAX_PROFILE_COUNT = 4; + const isMaxProfile = llmProfiles.length >= MAX_PROFILE_COUNT; const handleDefaultLlm = (profileId) => { try { @@ -125,7 +132,11 @@ function ManageLlmProfiles() { handleConfirm={() => handleDelete(item?.profile_id)} content="The LLM profile will be permanently deleted." > - @@ -134,6 +145,7 @@ function ManageLlmProfiles() {
- - Add New LLM Profile - + + + Add New LLM Profile + +
); diff --git a/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.jsx b/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.jsx index 23de585e2..446a9ffb3 100644 --- a/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.jsx +++ b/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.jsx @@ -6,7 +6,7 @@ import { CloseCircleFilled, InfoCircleFilled, } from "@ant-design/icons"; -import { useNavigate } from "react-router-dom"; +import { useNavigate, useParams } from "react-router-dom"; import { useCustomToolStore } from "../../../store/custom-tool-store"; import { useSessionStore } from "../../../store/session-store"; @@ -24,25 +24,13 @@ import { useTokenUsageStore } from "../../../store/token-usage-store"; import TabPane from "antd/es/tabs/TabPane"; import { ProfileInfoBar } from "../profile-info-bar/ProfileInfoBar"; -const columns = [ - { - title: "Document", - dataIndex: "document", - key: "document", - }, - { - title: "Token Count", - dataIndex: "token_count", - key: "token_count", - width: 200, - }, - { - title: "Value", - dataIndex: "value", - key: "value", - width: 600, - }, -]; +let publicOutputsApi; +try { + publicOutputsApi = + require("../../../plugins/prompt-studio-public-share/helpers/PublicShareAPIs").publicOutputsApi; +} catch { + // The component will remain null of it is not available +} const outputStatus = { yet_to_process: "YET_TO_PROCESS", @@ -68,25 +56,27 @@ function OutputForDocModal({ details, listOfDocs, selectedDoc, - defaultLlmProfile, disableLlmOrDocChange, singlePassExtractMode, isSinglePassExtractLoading, + isPublicSource, llmProfiles, + defaultLlmProfile, } = useCustomToolStore(); + const [selectedProfile, setSelectedProfile] = useState(defaultLlmProfile); const { sessionDetails } = useSessionStore(); const axiosPrivate = useAxiosPrivate(); const navigate = useNavigate(); + const { id } = useParams(); const { setAlertDetails } = 
useAlertStore(); - const { handleException } = useExceptionHandler(); + const handleException = useExceptionHandler(); const { tokenUsage } = useTokenUsageStore(); - const [selectedProfile, setSelectedProfile] = useState(defaultLlmProfile); useEffect(() => { if (!open) { return; } - handleGetOutputForDocs(); + handleGetOutputForDocs(selectedProfile || profileManagerId); getAdapterInfo(); }, [open, singlePassExtractMode, isSinglePassExtractLoading]); @@ -130,34 +120,6 @@ function OutputForDocModal({ // If data is provided, use it; otherwise, create a copy of the previous state const updatedPromptOutput = data || [...prev]; - // Get the keys of docOutputs - const keys = Object.keys(docOutputs); - - keys.forEach((key) => { - // Find the index of the prompt output corresponding to the document manager key - const index = updatedPromptOutput.findIndex( - (promptOutput) => promptOutput?.document_manager === key - ); - - let promptOutputInstance = {}; - // If the prompt output for the current key doesn't exist, skip it - if (index > -1) { - promptOutputInstance = updatedPromptOutput[index]; - promptOutputInstance["output"] = docOutputs[key]?.output; - } - - // Update output and isLoading properties based on docOutputs - promptOutputInstance["document_manager"] = key; - promptOutputInstance["isLoading"] = docOutputs[key]?.isLoading || false; - - // Update the prompt output instance in the array - if (index > -1) { - updatedPromptOutput[index] = promptOutputInstance; - } else { - updatedPromptOutput.push(promptOutputInstance); - } - }); - return updatedPromptOutput; }); }; @@ -172,22 +134,21 @@ function OutputForDocModal({ }; const handleGetOutputForDocs = (profile = profileManagerId) => { - if (singlePassExtractMode) { - profile = defaultLlmProfile; - } - if (!profile) { setRows([]); return; } + let url = 
`/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/prompt-output/?tool_id=${details?.tool_id}&prompt_id=${promptId}&profile_manager=${profile}&is_single_pass_extract=${singlePassExtractMode}`; + if (isPublicSource) { + url = publicOutputsApi(id, promptId, profile, singlePassExtractMode); + } const requestOptions = { method: "GET", - url: `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/prompt-output/?tool_id=${details?.tool_id}&prompt_id=${promptId}&profile_manager=${profile}&is_single_pass_extract=${singlePassExtractMode}`, + url, headers: { "X-CSRFToken": sessionDetails?.csrfToken, }, }; - setIsLoading(true); axiosPrivate(requestOptions) .then((res) => { @@ -231,10 +192,14 @@ function OutputForDocModal({ const result = { key: item?.document_id, document: item?.document_name, - token_count: ( + token_count: !singlePassExtractMode && ( ), @@ -274,6 +239,26 @@ function OutputForDocModal({ } }; + const columns = [ + { + title: "Document", + dataIndex: "document", + key: "document", + }, + !singlePassExtractMode && { + title: "Token Count", + dataIndex: "token_count", + key: "token_count", + width: 200, + }, + { + title: "Value", + dataIndex: "value", + key: "value", + width: 600, + }, + ].filter(Boolean); + return ( ))} {" "} - +
- + Save diff --git a/frontend/src/components/custom-tools/prompt-card/Header.jsx b/frontend/src/components/custom-tools/prompt-card/Header.jsx index 954b5706d..41826f5df 100644 --- a/frontend/src/components/custom-tools/prompt-card/Header.jsx +++ b/frontend/src/components/custom-tools/prompt-card/Header.jsx @@ -41,6 +41,7 @@ function Header({ singlePassExtractMode, isSinglePassExtractLoading, indexDocs, + isPublicSource, } = useCustomToolStore(); const [isDisablePrompt, setIsDisablePrompt] = useState(promptDetails?.active); @@ -130,7 +131,8 @@ function Header({ disabled={ disableLlmOrDocChange.includes(promptDetails?.prompt_id) || isSinglePassExtractLoading || - indexDocs.includes(selectedDoc?.document_id) + indexDocs.includes(selectedDoc?.document_id) || + isPublicSource } > @@ -151,7 +153,8 @@ function Header({ updateStatus?.status === promptStudioUpdateStatus?.isUpdating) || disableLlmOrDocChange?.includes(promptDetails?.prompt_id) || - indexDocs?.includes(selectedDoc?.document_id) + indexDocs?.includes(selectedDoc?.document_id) || + isPublicSource } > @@ -168,7 +171,8 @@ function Header({ updateStatus?.status === promptStudioUpdateStatus?.isUpdating) || disableLlmOrDocChange?.includes(promptDetails?.prompt_id) || - indexDocs?.includes(selectedDoc?.document_id) + indexDocs?.includes(selectedDoc?.document_id) || + isPublicSource } > @@ -176,11 +180,13 @@ function Header({ )} - + + + handleDelete(promptDetails?.prompt_id)} @@ -194,7 +200,8 @@ function Header({ disabled={ disableLlmOrDocChange?.includes(promptDetails?.prompt_id) || isSinglePassExtractLoading || - indexDocs?.includes(selectedDoc?.document_id) + indexDocs?.includes(selectedDoc?.document_id) || + isPublicSource } > diff --git a/frontend/src/components/custom-tools/prompt-card/PromptCard.css b/frontend/src/components/custom-tools/prompt-card/PromptCard.css index d190564f2..6331b11dd 100644 --- a/frontend/src/components/custom-tools/prompt-card/PromptCard.css +++ 
b/frontend/src/components/custom-tools/prompt-card/PromptCard.css @@ -203,3 +203,8 @@ border-color: #00000026 !important; color: #000 !important; } + +.prompt-not-ran { + color: #575859; + font-size: 13px; +} diff --git a/frontend/src/components/custom-tools/prompt-card/PromptCard.jsx b/frontend/src/components/custom-tools/prompt-card/PromptCard.jsx index 6a67d9939..c945faec7 100644 --- a/frontend/src/components/custom-tools/prompt-card/PromptCard.jsx +++ b/frontend/src/components/custom-tools/prompt-card/PromptCard.jsx @@ -18,6 +18,7 @@ import useTokenUsage from "../../../hooks/useTokenUsage"; import { useTokenUsageStore } from "../../../store/token-usage-store"; import { PromptCardItems } from "./PromptCardItems"; import "./PromptCard.css"; +import { useParams } from "react-router-dom"; const EvalModal = null; const getEvalMetrics = (param1, param2) => { @@ -34,7 +35,13 @@ try { } catch { // The component will remain null of it is not available } - +let publicOutputsApi; +try { + publicOutputsApi = + require("../../../plugins/prompt-studio-public-share/helpers/PublicShareAPIs").publicOutputsApi; +} catch { + // The component will remain null of it is not available +} function PromptCard({ promptDetails, handleChange, @@ -54,7 +61,7 @@ function PromptCard({ const [isCoverageLoading, setIsCoverageLoading] = useState(false); const [openOutputForDoc, setOpenOutputForDoc] = useState(false); const [progressMsg, setProgressMsg] = useState({}); - const [docOutputs, setDocOutputs] = useState({}); + const [docOutputs, setDocOutputs] = useState([]); const [timers, setTimers] = useState({}); const { getDropdownItems, @@ -69,6 +76,7 @@ function PromptCard({ singlePassExtractMode, isSinglePassExtractLoading, isSimplePromptStudio, + isPublicSource, } = useCustomToolStore(); const { messages } = useSocketCustomToolStore(); const { sessionDetails } = useSessionStore(); @@ -78,6 +86,7 @@ function PromptCard({ const { setPostHogCustomEvent } = usePostHogEvents(); const { 
tokenUsage, setTokenUsage } = useTokenUsageStore(); const { getTokenUsage } = useTokenUsage(); + const { id } = useParams(); useEffect(() => { const outputTypeData = getDropdownItems("output_type") || {}; @@ -123,19 +132,13 @@ function PromptCard({ if (isSinglePassExtractLoading) { return; } - if (selectedLlmProfileId !== promptDetails?.profile_id) { - handleChange( - selectedLlmProfileId, - promptDetails?.prompt_id, - "profile_manager" - ); - } }, [ selectedLlmProfileId, selectedDoc, listOfDocs, singlePassExtractMode, isSinglePassExtractLoading, + defaultLlmProfile, ]); useEffect(() => { @@ -173,7 +176,7 @@ function PromptCard({ useEffect(() => { const isProfilePresent = llmProfiles?.some( - (profile) => profile?.profile_id === selectedLlmProfileId + (profile) => profile?.profile_id === defaultLlmProfile ); // If selectedLlmProfileId is not present, set it to null @@ -192,23 +195,35 @@ function PromptCard({ const handleSelectDefaultLLM = (llmProfileId) => { setSelectedLlmProfileId(llmProfileId); + handleChange(llmProfileId, promptDetails?.prompt_id, "profile_manager"); }; const handleTypeChange = (value) => { handleChange(value, promptDetails?.prompt_id, "enforce_type", true); }; - const handleDocOutputs = (docId, isLoading, output) => { + const handleDocOutputs = (docId, promptId, profileId, isLoading, output) => { if (isSimplePromptStudio) { return; } setDocOutputs((prev) => { - const updatedDocOutputs = { ...prev }; + const updatedDocOutputs = [...prev]; + const key = `${promptId}__${docId}__${profileId}`; // Update the entry for the provided docId with isLoading and output - updatedDocOutputs[docId] = { + const newData = { + key, isLoading, output, }; + const index = updatedDocOutputs.findIndex((item) => item.key === key); + + if (index !== -1) { + // Update the existing object + updatedDocOutputs[index] = newData; + } else { + // Append the new object + updatedDocOutputs.push(newData); + } return updatedDocOutputs; }); }; @@ -276,16 +291,9 @@ function 
PromptCard({ if (validateInputs(profileManagerId, selectedLlmProfiles, coverAllDoc)) { return; } - - handleIsRunLoading( - selectedDoc?.document_id, - profileManagerId || selectedLlmProfileId, - true - ); setIsCoverageLoading(true); setCoverage(0); setCoverageTotal(0); - setDocOutputs({}); resetInfoMsgs(); const docId = selectedDoc?.document_id; @@ -309,7 +317,6 @@ function PromptCard({ return; } - handleDocOutputs(docId, true, null); if (runAllLLM) { let selectedProfiles = llmProfiles; if (!coverAllDoc && selectedLlmProfiles?.length > 0) { @@ -318,6 +325,13 @@ function PromptCard({ ); } for (const profile of selectedProfiles) { + handleDocOutputs( + docId, + promptDetails?.prompt_id, + profile?.profile_id, + true, + null + ); setIsCoverageLoading(true); handleIsRunLoading(selectedDoc?.document_id, profile?.profile_id, true); @@ -328,22 +342,29 @@ function PromptCard({ if (value || value === 0) { setCoverage((prev) => prev + 1); } - handleDocOutputs(docId, false, value); + handleDocOutputs( + docId, + promptDetails?.prompt_id, + profile?.profile_id, + false, + value + ); handleGetOutput(profile?.profile_id); updateDocCoverage( - coverage, promptDetails?.prompt_id, profile?.profile_id, docId ); }) .catch((err) => { - handleIsRunLoading( - selectedDoc?.document_id, + handleIsRunLoading(docId, profile?.profile_id, false); + handleDocOutputs( + docId, + promptDetails?.prompt_id, profile?.profile_id, - false + false, + null ); - handleDocOutputs(docId, false, null); setAlertDetails( handleException(err, `Failed to generate output for ${docId}`) ); @@ -354,19 +375,32 @@ function PromptCard({ runCoverageForAllDoc(coverAllDoc, profile.profile_id); } } else { + handleIsRunLoading(selectedDoc?.document_id, profileManagerId, true); + handleDocOutputs( + docId, + promptDetails?.prompt_id, + profileManagerId, + true, + null + ); handleRunApiRequest(docId, profileManagerId) .then((res) => { const data = res?.data?.output; const value = data[promptDetails?.prompt_key]; if (value 
|| value === 0) { updateDocCoverage( - coverage, promptDetails?.prompt_id, profileManagerId, docId ); } - handleDocOutputs(docId, false, value); + handleDocOutputs( + docId, + promptDetails?.prompt_id, + profileManagerId, + false, + value + ); handleGetOutput(); setCoverageTotal(1); }) @@ -376,7 +410,13 @@ function PromptCard({ selectedLlmProfileId, false ); - handleDocOutputs(docId, false, null); + handleDocOutputs( + docId, + promptDetails?.prompt_id, + profileManagerId, + false, + null + ); setAlertDetails( handleException(err, `Failed to generate output for ${docId}`) ); @@ -430,23 +470,40 @@ function PromptCard({ } setIsCoverageLoading(true); - handleDocOutputs(docId, true, null); + handleDocOutputs( + docId, + promptDetails?.prompt_id, + profileManagerId, + true, + null + ); handleRunApiRequest(docId, profileManagerId) .then((res) => { const data = res?.data?.output; const outputValue = data[promptDetails?.prompt_key]; if (outputValue || outputValue === 0) { updateDocCoverage( - coverage, promptDetails?.prompt_id, profileManagerId, docId ); } - handleDocOutputs(docId, false, outputValue); + handleDocOutputs( + docId, + promptDetails?.prompt_id, + profileManagerId, + false, + outputValue + ); }) .catch((err) => { - handleDocOutputs(docId, false, null); + handleDocOutputs( + docId, + promptDetails?.prompt_id, + profileManagerId, + false, + null + ); setAlertDetails( handleException(err, `Failed to generate output for ${docId}`) ); @@ -462,23 +519,33 @@ function PromptCard({ }); }; - const updateDocCoverage = (coverage, promptId, profileManagerId, docId) => { - const key = `${promptId}_${profileManagerId}`; - const counts = { ...coverage }; - // If the key exists in the counts object, increment the count - if (counts[key]) { - if (!counts[key]?.docs_covered?.includes(docId)) { - counts[key]?.docs_covered?.push(docId); + const updateDocCoverage = (promptId, profileManagerId, docId) => { + setCoverage((prevCoverage) => { + const keySuffix = 
`${promptId}_${profileManagerId}`; + const key = singlePassExtractMode ? `singlepass_${keySuffix}` : keySuffix; + + // Create a shallow copy of the previous coverage state + const updatedCoverage = { ...prevCoverage }; + + // If the key exists in the updated coverage object, update the docs_covered array + if (updatedCoverage[key]) { + if (!updatedCoverage[key].docs_covered.includes(docId)) { + updatedCoverage[key].docs_covered = [ + ...updatedCoverage[key].docs_covered, + docId, + ]; + } + } else { + // Otherwise, add the key to the updated coverage object with the new entry + updatedCoverage[key] = { + prompt_id: promptId, + profile_manager: profileManagerId, + docs_covered: [docId], + }; } - } else { - // Otherwise, add the key to the counts object with an initial count of 1 - counts[key] = { - prompt_id: promptId, - profile_manager: profileManagerId, - docs_covered: [docId], - }; - } - setCoverage(counts); + + return updatedCoverage; + }); }; const handleRunApiRequest = async (docId, profileManagerId) => { @@ -648,6 +715,14 @@ function PromptCard({ } url = `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/prompt-output/?tool_id=${details?.tool_id}&prompt_id=${promptDetails?.prompt_id}&is_single_pass_extract=${singlePassExtractMode}`; } + if (isPublicSource) { + url = publicOutputsApi( + id, + promptDetails?.prompt_id, + selectedLlmProfileId, + singlePassExtractMode + ); + } if (isOutput) { url += `&document_manager=${selectedDoc?.document_id}`; } @@ -662,13 +737,12 @@ function PromptCard({ "X-CSRFToken": sessionDetails?.csrfToken, }, }; - return axiosPrivate(requestOptions) .then((res) => { const data = res?.data || []; if (singlePassExtractMode) { - const tokenUsageId = `single_pass__${selectedDoc?.document_id}`; + const tokenUsageId = `single_pass__${defaultLlmProfile}__${selectedDoc?.document_id}`; const usage = data?.find((item) => item?.run_id !== undefined); if (!tokenUsage[tokenUsageId] && usage) { @@ -693,7 +767,6 @@ function PromptCard({ const 
handleGetCoverageData = (data) => { data?.forEach((item) => { updateDocCoverage( - coverage, item?.prompt_id, item?.profile_manager, item?.document_manager @@ -761,7 +834,7 @@ function PromptCard({ setOpen={setOpenOutputForDoc} promptId={promptDetails?.prompt_id} promptKey={promptDetails?.prompt_key} - profileManagerId={promptDetails?.profile_manager} + profileManagerId={selectedLlmProfileId} docOutputs={docOutputs} /> diff --git a/frontend/src/components/custom-tools/prompt-card/PromptCardItems.jsx b/frontend/src/components/custom-tools/prompt-card/PromptCardItems.jsx index 4bc10fd7a..5fd5ff0ec 100644 --- a/frontend/src/components/custom-tools/prompt-card/PromptCardItems.jsx +++ b/frontend/src/components/custom-tools/prompt-card/PromptCardItems.jsx @@ -4,6 +4,7 @@ import { CheckCircleOutlined, DatabaseOutlined, ExclamationCircleFilled, + InfoCircleFilled, InfoCircleOutlined, PlayCircleFilled, PlayCircleOutlined, @@ -36,12 +37,8 @@ import { TokenUsage } from "../token-usage/TokenUsage"; import { useCustomToolStore } from "../../../store/custom-tool-store"; import { Header } from "./Header"; import CheckableTag from "antd/es/tag/CheckableTag"; -import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate"; -import { useSessionStore } from "../../../store/session-store"; import { motion, AnimatePresence } from "framer-motion"; import { OutputForIndex } from "./OutputForIndex"; -import { useExceptionHandler } from "../../../hooks/useExceptionHandler"; -import { useAlertStore } from "../../../store/alert-store"; import { useWindowDimensions } from "../../../hooks/useWindowDimensions"; const EvalBtn = null; @@ -80,22 +77,22 @@ function PromptCardItems({ isSinglePassExtractLoading, indexDocs, isSimplePromptStudio, + isPublicSource, + adapters, + defaultLlmProfile, } = useCustomToolStore(); const [isEditingPrompt, setIsEditingPrompt] = useState(false); const [isEditingTitle, setIsEditingTitle] = useState(false); const [expandCard, setExpandCard] = useState(true); const 
[llmProfileDetails, setLlmProfileDetails] = useState([]); const [openIndexProfile, setOpenIndexProfile] = useState(null); + const [coverageCount, setCoverageCount] = useState(0); const [enabledProfiles, setEnabledProfiles] = useState( llmProfiles.map((profile) => profile.profile_id) ); const [expandedProfiles, setExpandedProfiles] = useState([]); // New state for expanded profiles const [isIndexOpen, setIsIndexOpen] = useState(false); - const privateAxios = useAxiosPrivate(); - const { sessionDetails } = useSessionStore(); const { width: windowWidth } = useWindowDimensions(); - const handleException = useExceptionHandler(); - const { setAlertDetails } = useAlertStore(); const componentWidth = windowWidth * 0.4; const divRef = useRef(null); @@ -123,35 +120,26 @@ function PromptCardItems({ return result; }; - const getAdapterInfo = async () => { - privateAxios - .get(`/api/v1/unstract/${sessionDetails?.orgId}/adapter/`) - .then((res) => { - const adapterData = res?.data; - - // Update llmProfiles with additional fields - const updatedProfiles = llmProfiles?.map((profile) => { - return { ...getModelOrAdapterId(profile, adapterData), ...profile }; - }); - setLlmProfileDetails( - updatedProfiles - .map((profile) => ({ - ...profile, - isDefault: profile?.profile_id === selectedLlmProfileId, - isEnabled: enabledProfiles.includes(profile?.profile_id), - })) - .sort((a, b) => { - if (a?.isDefault) return -1; // Default profile comes first - if (b?.isDefault) return 1; - if (a?.isEnabled && !b?.isEnabled) return -1; // Enabled profiles come before disabled - if (!a?.isEnabled && b?.isEnabled) return 1; - return 0; - }) - ); - }) - .catch((err) => { - setAlertDetails(handleException(err)); - }); + const getAdapterInfo = async (adapterData) => { + // Update llmProfiles with additional fields + const updatedProfiles = llmProfiles?.map((profile) => { + return { ...getModelOrAdapterId(profile, adapterData), ...profile }; + }); + setLlmProfileDetails( + updatedProfiles + 
.map((profile) => ({ + ...profile, + isDefault: profile?.profile_id === selectedLlmProfileId, + isEnabled: enabledProfiles.includes(profile?.profile_id), + })) + .sort((a, b) => { + if (a?.isDefault) return -1; // Default profile comes first + if (b?.isDefault) return 1; + if (a?.isEnabled && !b?.isEnabled) return -1; // Enabled profiles come before disabled + if (!a?.isEnabled && b?.isEnabled) return 1; + return 0; + }) + ); }; const tooltipContent = (adapterConf) => ( @@ -235,6 +223,14 @@ function PromptCardItems({ } return <>; }; + const getCoverageData = () => { + const profileId = singlePassExtractMode + ? defaultLlmProfile + : selectedLlmProfileId; + const keySuffix = `${promptDetails?.prompt_id}_${profileId}`; + const key = singlePassExtractMode ? `singlepass_${keySuffix}` : keySuffix; + return coverage[key]?.docs_covered?.length || 0; + }; useEffect(() => { setExpandCard(true); @@ -244,10 +240,11 @@ function PromptCardItems({ if (singlePassExtractMode) { setExpandedProfiles([]); } - }, [singlePassExtractMode]); + setCoverageCount(getCoverageData()); + }, [singlePassExtractMode, coverage]); useEffect(() => { - getAdapterInfo(); + getAdapterInfo(adapters); }, [llmProfiles, selectedLlmProfileId, enabledProfiles]); return ( @@ -317,6 +314,7 @@ function PromptCardItems({ type="link" className="display-flex-align-center prompt-card-action-button" onClick={() => setOpenOutputForDoc(true)} + disabled={isPublicSource} > {isCoverageLoading ? 
( @@ -325,11 +323,8 @@ function PromptCardItems({ )} - Coverage:{" "} - {coverage[ - `${promptDetails?.prompt_id}_${selectedLlmProfileId}` - ]?.docs_covered?.length || 0}{" "} - of {listOfDocs?.length || 0} docs + Coverage: {coverageCount} of{" "} + {listOfDocs?.length || 0} docs @@ -347,7 +342,8 @@ function PromptCardItems({ promptDetails?.prompt_id ) || isSinglePassExtractLoading || - indexDocs.includes(selectedDoc?.document_id) + indexDocs.includes(selectedDoc?.document_id) || + isPublicSource } onChange={(value) => handleTypeChange(value)} /> @@ -428,6 +424,7 @@ function PromptCardItems({ onChange={(checked) => handleTagChange(checked, profileId) } + disabled={isPublicSource} className={isChecked ? "checked" : "unchecked"} > {isChecked ? ( @@ -472,6 +469,7 @@ function PromptCardItems({ onChange={() => handleSelectDefaultLLM(profileId) } + disabled={isPublicSource} > Default @@ -501,11 +499,24 @@ function PromptCardItems({ : "collapsed-output" } > - {displayPromptResult( - result.find( - (r) => r?.profileManager === profileId - )?.output, - true + {!result.find( + (r) => r?.profileManager === profileId + )?.output ? ( + + + + {" "} + Yet to run + + ) : ( + displayPromptResult( + result.find( + (r) => r?.profileManager === profileId + )?.output, + true + ) )}
@@ -520,7 +531,7 @@ function PromptCardItems({ disabled={ isRunLoading[ `${selectedDoc?.document_id}_${profileId}` - ] + ] || isPublicSource } > @@ -535,7 +546,7 @@ function PromptCardItems({ disabled={ isRunLoading[ `${selectedDoc?.document_id}_${profileId}` - ] + ] || isPublicSource } > diff --git a/frontend/src/components/custom-tools/tool-ide/ToolIde.css b/frontend/src/components/custom-tools/tool-ide/ToolIde.css index 41cb40103..e568d504b 100644 --- a/frontend/src/components/custom-tools/tool-ide/ToolIde.css +++ b/frontend/src/components/custom-tools/tool-ide/ToolIde.css @@ -1,103 +1,104 @@ /* Styles for ToolIde */ .tool-ide-layout { - background-color: var(--page-bg-2); - height: 100%; - display: flex; - flex-direction: column; + background-color: var(--page-bg-2); + height: 100%; + display: flex; + flex-direction: column; } .tool-ide-body { - flex: 1; - overflow-y: hidden; + flex: 1; + overflow-y: hidden; } .tool-ide-body-2 { - display: flex; - flex-direction: column; - height: 100%; + display: flex; + flex-direction: column; + height: 100%; } .tool-ide-main { - flex-grow: 1; - overflow-y: hidden; + flex-grow: 1; + overflow-y: hidden; } .tool-ide-col { - height: 100%; + height: 100%; } .tool-ide-main-row { - height: 100%; + height: 100%; } .tool-ide-footer { - padding: 0px 12px 12px 12px; + padding: 0px 12px 12px 12px; } .tool-ide-prompts { - padding: 12px 6px 1px 12px; - height: 100%; + padding: 12px 6px 1px 12px; + height: 100%; } .tool-ide-pdf { - padding: 12px 12px 1px 6px; - height: 100%; + padding: 12px 12px 1px 6px; + height: 100%; } -.tool-ide-prompts > div, .tool-ide-pdf > div { - background-color: var(--white); - height: 100%; +.tool-ide-prompts > div, +.tool-ide-pdf > div { + background-color: var(--white); + height: 100%; } .tool-ide-actions { - padding: 0px 12px; + padding: 0px 12px; } .tool-ide-logs { - height: 10vh; + height: 10vh; } .tool-ide-sider { - background-color: transparent !important; - height: 100%; + background-color: transparent 
!important; + height: 100%; } .tool-ide-sider-layout { - height: 100%; - background-color: transparent !important; + height: 100%; + background-color: transparent !important; } .tool-ide-sider-btn { - position: fixed; - transform: translate(-50%, 100%); - z-index: 1; - transition: "left 0.1s linear"; + position: fixed; + transform: translate(-50%, 100%); + z-index: 1; + transition: "left 0.1s linear"; } .tool-ide-collapse-panel { - background-color: var(--white); - border: none; - border-radius: 0px; + background-color: var(--white); + border: none; + border-radius: 0px; } /* Remove padding from modal content */ .custom-modal-wrapper .ant-modal-content { - padding: 0; - height: 90vh; - overflow-y: auto; + padding: 0; + height: 90vh; + overflow-y: auto; } .custom-modal-wrapper .tools-prompts-header-layout { - border-radius: 5px 5px 0px 0px; + border-radius: 5px 5px 0px 0px; } .custom-modal-gen-index .ant-modal { - top: 20px; - right: 20px; - position: absolute !important; + top: 20px; + right: 20px; + position: absolute !important; } .tool-ide-main-card .card-text { - font-size: 12px; + font-size: 12px; } diff --git a/frontend/src/components/custom-tools/tool-ide/ToolIde.jsx b/frontend/src/components/custom-tools/tool-ide/ToolIde.jsx index 7ade1571c..0dfaae6fd 100644 --- a/frontend/src/components/custom-tools/tool-ide/ToolIde.jsx +++ b/frontend/src/components/custom-tools/tool-ide/ToolIde.jsx @@ -1,6 +1,6 @@ import { FullscreenExitOutlined, FullscreenOutlined } from "@ant-design/icons"; import { Col, Collapse, Modal, Row } from "antd"; -import { useState } from "react"; +import { useState, useEffect } from "react"; import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate"; import { useExceptionHandler } from "../../../hooks/useExceptionHandler"; @@ -16,8 +16,11 @@ import { SettingsModal } from "../settings-modal/SettingsModal"; import { ToolsMain } from "../tools-main/ToolsMain"; import "./ToolIde.css"; import usePostHogEvents from 
"../../../hooks/usePostHogEvents.js"; - let OnboardMessagesModal; +let PromptShareModal; +let PromptShareLink; +let CloneTitle; +let HeaderPublic; let slides; try { OnboardMessagesModal = @@ -28,6 +31,18 @@ try { OnboardMessagesModal = null; slides = []; } +try { + PromptShareModal = + require("../../../plugins/prompt-studio-public-share/public-share-modal/PromptShareModal.jsx").PromptShareModal; + PromptShareLink = + require("../../../plugins/prompt-studio-public-share/public-link-modal/PromptShareLink.jsx").PromptShareLink; + CloneTitle = + require("../../../plugins/prompt-studio-clone/clone-title-modal/CloneTitle.jsx").CloneTitle; + HeaderPublic = + require("../../../plugins/prompt-studio-public-share/header-public/HeaderPublic.jsx").HeaderPublic; +} catch (err) { + // Do nothing if plugins are not loaded. +} function ToolIde() { const [showLogsModal, setShowLogsModal] = useState(false); @@ -42,6 +57,8 @@ function ToolIde() { indexDocs, pushIndexDoc, deleteIndexDoc, + shareId, + isPublicSource, } = useCustomToolStore(); const { sessionDetails } = useSessionStore(); const { promptOnboardingMessage } = sessionDetails; @@ -50,6 +67,10 @@ function ToolIde() { const handleException = useExceptionHandler(); const [loginModalOpen, setLoginModalOpen] = useState(true); const { setPostHogCustomEvent } = usePostHogEvents(); + const [openShareLink, setOpenShareLink] = useState(false); + const [openShareConfirmation, setOpenShareConfirmation] = useState(false); + const [openShareModal, setOpenShareModal] = useState(false); + const [openCloneModal, setOpenCloneModal] = useState(false); const openLogsModal = () => { setShowLogsModal(true); @@ -58,6 +79,17 @@ function ToolIde() { const closeLogsModal = () => { setShowLogsModal(false); }; + useEffect(() => { + if (openShareModal) { + if (shareId) { + setOpenShareConfirmation(false); + setOpenShareLink(true); + } else { + setOpenShareConfirmation(true); + setOpenShareLink(false); + } + } + }, [shareId, openShareModal]); const 
genExtra = () => ( + {isPublicSource && HeaderPublic && }
@@ -244,6 +279,26 @@ function ToolIde() { setOpen={setOpenSettings} handleUpdateTool={handleUpdateTool} /> + {PromptShareModal && ( + + )} + {PromptShareLink && ( + + )} + {CloneTitle && ( + + )} {!promptOnboardingMessage && OnboardMessagesModal && ( { const data = res?.data; updatedCusTool["llmProfiles"] = data; + if (shareManagerToolSource) { + const reqOpsShare = { + method: "GET", + url: shareManagerToolSource(id, sessionDetails?.orgId), + }; + return handleApiRequest(reqOpsShare); + } + }) + .then((res) => { + const data = res?.data; + updatedCusTool["shareId"] = data?.share_id; + const reqOpsLlmProfiles = { + method: "GET", + url: `/api/v1/unstract/${sessionDetails?.orgId}/adapter/`, + }; + + return handleApiRequest(reqOpsLlmProfiles); + }) + .then((res) => { + const data = res?.data; + updatedCusTool["adapters"] = data; }) .catch((err) => { setAlertDetails(handleException(err, "Failed to load the custom tool")); diff --git a/frontend/src/components/input-output/add-source/AddSource.jsx b/frontend/src/components/input-output/add-source/AddSource.jsx index 417936abb..eb56d76a2 100644 --- a/frontend/src/components/input-output/add-source/AddSource.jsx +++ b/frontend/src/components/input-output/add-source/AddSource.jsx @@ -21,6 +21,7 @@ function AddSource({ handleUpdate, connDetails, connType, + formDataConfig, }) { const [spec, setSpec] = useState({}); const [formData, setFormData] = useState({}); @@ -54,6 +55,7 @@ function AddSource({ axiosPrivate(requestOptions) .then((res) => { const data = res?.data; + setFormData(metadata || {}); setSpec(data?.json_schema || {}); if (data?.oauth) { setOAuthProvider(data?.python_social_auth_backend); @@ -110,6 +112,7 @@ function AddSource({ metadata={metadata} selectedSourceName={selectedSourceName} connType={connType} + formDataConfig={formDataConfig} /> ); } @@ -125,6 +128,7 @@ AddSource.propTypes = { handleUpdate: PropTypes.func, connDetails: PropTypes.object, connType: PropTypes.string, + formDataConfig: 
PropTypes.object, }; export { AddSource }; diff --git a/frontend/src/components/input-output/configure-ds/ConfigureDs.jsx b/frontend/src/components/input-output/configure-ds/ConfigureDs.jsx index 92f403962..75ecb2e11 100644 --- a/frontend/src/components/input-output/configure-ds/ConfigureDs.jsx +++ b/frontend/src/components/input-output/configure-ds/ConfigureDs.jsx @@ -31,6 +31,7 @@ function ConfigureDs({ metadata, selectedSourceName, connType, + formDataConfig, }) { const formRef = createRef(null); const axiosPrivate = useAxiosPrivate(); @@ -252,7 +253,11 @@ function ConfigureDs({ .then((res) => { const data = res?.data; if (sourceTypes.connectors.includes(type)) { - handleUpdate({ connector_instance: data?.id }); + handleUpdate( + { connector_instance: data?.id, configuration: formDataConfig }, + true + ); + setIsTcSuccessful(false); return; } if (data) { @@ -351,6 +356,7 @@ ConfigureDs.propTypes = { metadata: PropTypes.object, selectedSourceName: PropTypes.string.isRequired, connType: PropTypes.string, + formDataConfig: PropTypes.object, }; export { ConfigureDs }; diff --git a/frontend/src/hooks/useExceptionHandler.jsx b/frontend/src/hooks/useExceptionHandler.jsx index d1268823d..4f7835e54 100644 --- a/frontend/src/hooks/useExceptionHandler.jsx +++ b/frontend/src/hooks/useExceptionHandler.jsx @@ -26,6 +26,13 @@ const useExceptionHandler = () => { // Handle validation errors if (setBackendErrors) { setBackendErrors(err?.response?.data); + } else { + return { + title: title, + type: "error", + content: errors?.[0]?.detail ? 
errors[0].detail : errMessage, + duration: duration, + }; } break; case "subscription_error": diff --git a/frontend/src/routes/Router.jsx b/frontend/src/routes/Router.jsx index d564714c0..291c39ec0 100644 --- a/frontend/src/routes/Router.jsx +++ b/frontend/src/routes/Router.jsx @@ -37,6 +37,8 @@ let ChatAppPage; let ChatAppLayout; let ManualReviewPage; let ReviewLayout; +let PublicPromptStudioHelper; + try { TrialRoutes = require("../plugins/subscription/trial-page/TrialEndPage.jsx").TrialEndPage; @@ -84,7 +86,12 @@ try { } catch (err) { // Do nothing, Not-found Page will be triggered. } - +try { + PublicPromptStudioHelper = + require("../plugins/prompt-studio-public-share/helpers/PublicPromptStudioHelper.js").PublicPromptStudioHelper; +} catch (err) { + // Do nothing, Not-found Page will be triggered. +} function Router() { return ( @@ -111,6 +118,18 @@ function Router() { } /> )} + {PublicPromptStudioHelper && ( + } + > + } /> + } + /> + + )} {/* protected routes */} diff --git a/frontend/src/store/custom-tool-store.js b/frontend/src/store/custom-tool-store.js index f7b94e1c3..452989545 100644 --- a/frontend/src/store/custom-tool-store.js +++ b/frontend/src/store/custom-tool-store.js @@ -16,6 +16,9 @@ const defaultState = { singlePassExtractMode: false, isSinglePassExtractLoading: false, isSimplePromptStudio: false, + shareId: null, + isPublicSource: false, + adapters: [], }; const defaultPromptInstance = { diff --git a/prompt-service/pdm.lock b/prompt-service/pdm.lock index b6369da27..10ca43356 100644 --- a/prompt-service/pdm.lock +++ b/prompt-service/pdm.lock @@ -773,7 +773,7 @@ files = [ [[package]] name = "google-cloud-aiplatform" -version = "1.58.0" +version = "1.59.0" requires_python = ">=3.8" summary = "Vertex AI API client library" groups = ["default"] @@ -785,14 +785,14 @@ dependencies = [ "google-cloud-resource-manager<3.0.0dev,>=1.3.3", "google-cloud-storage<3.0.0dev,>=1.32.0", "packaging>=14.3", - "proto-plus<2.0.0dev,>=1.22.0", + 
"proto-plus<2.0.0dev,>=1.22.3", "protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.19.5", "pydantic<3", "shapely<3.0.0dev", ] files = [ - {file = "google-cloud-aiplatform-1.58.0.tar.gz", hash = "sha256:7a05aceac4a6c7eaa26e684e9f202b829cc7e57f82bffe7281684275a553fcad"}, - {file = "google_cloud_aiplatform-1.58.0-py2.py3-none-any.whl", hash = "sha256:21f1320860f4916183ec939fdf2ff3fc1d7fdde97fe5795974257ab21f9458ec"}, + {file = "google-cloud-aiplatform-1.59.0.tar.gz", hash = "sha256:2bebb59c0ba3e3b4b568305418ca1b021977988adbee8691a5bed09b037e7e63"}, + {file = "google_cloud_aiplatform-1.59.0-py2.py3-none-any.whl", hash = "sha256:549e6eb1844b0f853043309138ebe2db00de4bbd8197b3bde26804ac163ef52a"}, ] [[package]] @@ -1782,7 +1782,7 @@ files = [ [[package]] name = "llama-index-readers-file" -version = "0.1.29" +version = "0.1.30" requires_python = "<4.0,>=3.8.1" summary = "llama-index readers file integration" groups = ["default"] @@ -1793,8 +1793,8 @@ dependencies = [ "striprtf<0.0.27,>=0.0.26", ] files = [ - {file = "llama_index_readers_file-0.1.29-py3-none-any.whl", hash = "sha256:b25f3dbf7bf3e0635290e499e808db5ba955eab67f205a3ff1cea6a4eb93556a"}, - {file = "llama_index_readers_file-0.1.29.tar.gz", hash = "sha256:f9f696e738383e7d14078e75958fba5a7030f7994a20586e3140e1ca41395a54"}, + {file = "llama_index_readers_file-0.1.30-py3-none-any.whl", hash = "sha256:d5f6cdd4685ee73103c68b9bc0dfb0d05439033133fc6bd45ef31ff41519e723"}, + {file = "llama_index_readers_file-0.1.30.tar.gz", hash = "sha256:32f40465f2a8a65fa5773e03c9f4dd55164be934ae67fad62113680436787d91"}, ] [[package]] @@ -3095,13 +3095,13 @@ files = [ [[package]] name = "setuptools" -version = "70.2.0" +version = "70.3.0" requires_python = ">=3.8" summary = "Easily download, build, install, upgrade, and uninstall Python packages" groups = ["default"] files = [ - {file = "setuptools-70.2.0-py3-none-any.whl", hash = 
"sha256:b8b8060bb426838fbe942479c90296ce976249451118ef566a5a0b7d8b78fb05"}, - {file = "setuptools-70.2.0.tar.gz", hash = "sha256:bd63e505105011b25c3c11f753f7e3b8465ea739efddaccef8f0efac2137bac1"}, + {file = "setuptools-70.3.0-py3-none-any.whl", hash = "sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc"}, + {file = "setuptools-70.3.0.tar.gz", hash = "sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5"}, ] [[package]] @@ -3521,6 +3521,41 @@ files = [ {file = "ujson-5.10.0.tar.gz", hash = "sha256:b3cd8f3c5d8c7738257f1018880444f7b7d9b66232c64649f562d7ba86ad4bc1"}, ] +[[package]] +name = "unstract-adapters" +version = "0.21.0" +requires_python = "<3.12,>=3.9" +summary = "Unstract interface for LLMs, Embeddings and VectorDBs" +groups = ["default"] +dependencies = [ + "filetype~=1.2.0", + "httpx>=0.25.2", + "llama-index-embeddings-azure-openai==0.1.6", + "llama-index-embeddings-azure-openai==0.1.6", + "llama-index-embeddings-google==0.1.5", + "llama-index-embeddings-ollama==0.1.2", + "llama-index-llms-anthropic==0.1.11", + "llama-index-llms-anyscale==0.1.3", + "llama-index-llms-azure-openai==0.1.5", + "llama-index-llms-mistralai==0.1.10", + "llama-index-llms-ollama==0.1.3", + "llama-index-llms-palm==0.1.5", + "llama-index-llms-replicate==0.1.3", + "llama-index-llms-vertex==0.1.8", + "llama-index-vector-stores-milvus==0.1.18", + "llama-index-vector-stores-pinecone==0.1.4", + "llama-index-vector-stores-postgres==0.1.3", + "llama-index-vector-stores-qdrant==0.2.8", + "llama-index-vector-stores-weaviate==0.1.4", + "llama-index==0.10.38", + "llama-parse==0.4.1", + "singleton-decorator~=1.0.0", +] +files = [ + {file = "unstract_adapters-0.21.0-py3-none-any.whl", hash = "sha256:6c4f597602f55b80ba176a29a930755abd3494ff1c085f406536e7463902d655"}, + {file = "unstract_adapters-0.21.0.tar.gz", hash = "sha256:ebb8f80b7f26f04874bb5466fe0be8c3e1b51d44ed85c4e62289a7751c996af6"}, +] + [[package]] name = "unstract-core" version = "0.0.1" @@ 
-3557,35 +3592,14 @@ requires_python = "<3.11.1,>=3.9" summary = "A framework for writing Unstract Tools/Apps" groups = ["default"] dependencies = [ - "filetype~=1.2.0", - "httpx>=0.25.2", "jsonschema~=4.18.2", - "llama-index-embeddings-azure-openai==0.1.6", - "llama-index-embeddings-azure-openai==0.1.6", - "llama-index-embeddings-google==0.1.5", - "llama-index-embeddings-ollama==0.1.2", - "llama-index-llms-anthropic==0.1.11", - "llama-index-llms-anyscale==0.1.3", - "llama-index-llms-azure-openai==0.1.5", - "llama-index-llms-mistralai==0.1.10", - "llama-index-llms-ollama==0.1.3", - "llama-index-llms-palm==0.1.5", - "llama-index-llms-replicate==0.1.3", - "llama-index-llms-vertex==0.1.8", - "llama-index-vector-stores-milvus==0.1.18", - "llama-index-vector-stores-pinecone==0.1.4", - "llama-index-vector-stores-postgres==0.1.3", - "llama-index-vector-stores-qdrant==0.2.8", - "llama-index-vector-stores-weaviate==0.1.4", "llama-index==0.10.38", - "llama-index==0.10.38", - "llama-parse==0.4.1", "openai~=1.21.2", "python-dotenv==1.0.0", "python-magic~=0.4.27", - "singleton-decorator~=1.0.0", "tiktoken~=0.4.0", "transformers==4.37.0", + "unstract-adapters~=0.21.0", ] files = [ {file = "unstract_sdk-0.37.0-py3-none-any.whl", hash = "sha256:3bd83dfb7a760e73f35cc9ed4bdadf645bd77f03f77d25578e842156645e6f1d"}, @@ -3605,13 +3619,13 @@ files = [ [[package]] name = "validators" -version = "0.31.0" +version = "0.32.0" requires_python = ">=3.8" summary = "Python Data Validation for Humans™" groups = ["default"] files = [ - {file = "validators-0.31.0-py3-none-any.whl", hash = "sha256:e15a600d81555a4cd409b17bf55946c5edec7748e776afc85ed0a19bdee54e56"}, - {file = "validators-0.31.0.tar.gz", hash = "sha256:de7574fc56a231c788162f3e7da15bc2053c5ff9e0281d9ff1afb3a7b69498df"}, + {file = "validators-0.32.0-py3-none-any.whl", hash = "sha256:e9ce1703afb0adf7724b0f98e4081d9d10e88fa5d37254d21e41f27774c020cd"}, + {file = "validators-0.32.0.tar.gz", hash = 
"sha256:9ee6e6d7ac9292b9b755a3155d7c361d76bb2dce23def4f0627662da1e300676"}, ] [[package]] diff --git a/run-platform.sh b/run-platform.sh index a11261fe4..6483b7dd5 100755 --- a/run-platform.sh +++ b/run-platform.sh @@ -8,6 +8,7 @@ blue_text='\033[94m' green_text='\033[32m' red_text='\033[31m' default_text='\033[39m' +yellow_text='\033[33m' # set -x/xtrace uses PS4 for more info PS4="$blue_text""${0}:${LINENO}: ""$default_text" @@ -68,7 +69,7 @@ display_help() { echo -e " -e, --only-env Only do env files setup" echo -e " -p, --only-pull Only do docker images pull" echo -e " -b, --build-local Build docker images locally" - echo -e " -u, --upgrade Upgrade services" + echo -e " -u, --update Update services version" echo -e " -x, --trace Enables trace mode" echo -e " -V, --verbose Print verbose logs" echo -e " -v, --version Docker images version tag (default \"latest\")" @@ -92,8 +93,8 @@ parse_args() { -b | --build-local) opt_build_local=true ;; - -u | --upgrade) - opt_upgrade=true + -u | --update) + opt_update=true ;; -x | --trace) set -o xtrace # display every line before execution; enables PS4 @@ -125,21 +126,40 @@ parse_args() { debug "OPTION only_env: $opt_only_env" debug "OPTION only_pull: $opt_only_pull" debug "OPTION build_local: $opt_build_local" - debug "OPTION upgrade: $opt_upgrade" + debug "OPTION upgrade: $opt_update" debug "OPTION verbose: $opt_verbose" debug "OPTION version: $opt_version" } do_git_pull() { - if [ "$opt_upgrade" = false ]; then + if [ "$opt_update" = false ]; then return fi - echo -e "Performing git switch to ""$blue_text""main branch""$default_text". - git switch main + echo "Fetching release tags." + git fetch --quiet --tags - echo -e "Performing ""$blue_text""git pull""$default_text"" on main branch." 
- git pull + if [[ "$opt_version" == "latest" ]]; then + branch=`git describe --tags --abbrev=0` + elif [[ "$opt_version" == "main" ]]; then + branch="main" + opt_build_local=true + echo -e "Choosing ""$blue_text""local build""$default_text"" of Docker images from ""$blue_text""main""$default_text"" branch." + elif [ -z $(git tag -l "$opt_version") ]; then + echo -e "$red_text""Version not found.""$default_text" + if [[ ! $opt_version == v* ]]; then + echo -e "$red_text""Version must be provided with a 'v' prefix (e.g. v0.47.0).""$default_text" + fi + exit 1 + else + branch="$opt_version" + fi + + echo -e "Performing ""$blue_text""git checkout""$default_text"" to ""$blue_text""$branch""$default_text""." + git checkout --quiet $branch + + echo -e "Performing ""$blue_text""git pull""$default_text"" on ""$blue_text""$branch""$default_text""." + git pull --quiet $(git remote) $branch } setup_env() { @@ -179,7 +199,7 @@ setup_env() { fi fi echo -e "Created env for ""$blue_text""$service""$default_text" at ""$blue_text""$env_path""$default_text"." - elif [ "$opt_upgrade" = true ]; then + elif [ "$opt_update" = true ]; then python3 $script_dir/docker/scripts/merge_env.py $sample_env_path $env_path if [ $? -ne 0 ]; then exit 1 @@ -191,7 +211,7 @@ setup_env() { if [ ! -e "$script_dir/docker/essentials.env" ]; then cp "$script_dir/docker/sample.essentials.env" "$script_dir/docker/essentials.env" echo -e "Created env for ""$blue_text""essential services""$default_text"" at ""$blue_text""$script_dir/docker/essentials.env""$default_text""." - elif [ "$opt_upgrade" = true ]; then + elif [ "$opt_update" = true ]; then python3 $script_dir/docker/scripts/merge_env.py "$script_dir/docker/sample.essentials.env" "$script_dir/docker/essentials.env" if [ $? -ne 0 ]; then exit 1 @@ -201,9 +221,9 @@ setup_env() { # Not part of an upgrade. if [ ! 
-e "$script_dir/docker/proxy_overrides.yaml" ]; then - echo -e "NOTE: Proxy behaviour can be overridden via ""$blue_text""$script_dir/docker/proxy_overrides.yaml""$default_text""." + echo -e "NOTE: Reverse proxy config can be overridden via ""$blue_text""$script_dir/docker/proxy_overrides.yaml""$default_text""." else - echo -e "Found ""$blue_text""$script_dir/docker/proxy_overrides.yaml""$default_text"". Proxy behaviour will be overridden." + echo -e "Found ""$blue_text""$script_dir/docker/proxy_overrides.yaml""$default_text"". ""$yellow_text""Reverse proxy config will be overridden.""$default_text" fi if [ "$opt_only_env" = true ]; then @@ -220,7 +240,7 @@ build_services() { echo -e "$red_text""Failed to build docker images.""$default_text" exit 1 } - elif [ "$first_setup" = true ] || [ "$opt_upgrade" = true ]; then + elif [ "$first_setup" = true ] || [ "$opt_update" = true ]; then echo -e "$blue_text""Pulling""$default_text"" docker images tag ""$blue_text""$opt_version""$default_text""." # Try again on a slow network. VERSION=$opt_version $docker_compose_cmd -f $script_dir/docker/docker-compose.yaml pull || @@ -245,13 +265,20 @@ run_services() { echo -e "$blue_text""Starting docker containers in detached mode""$default_text" VERSION=$opt_version $docker_compose_cmd up -d - if [ "$opt_upgrade" = true ]; then + if [ "$opt_update" = true ]; then echo "" - echo -e "$green_text""Upgraded platform to $opt_version version.""$default_text" + if [[ "$opt_version" == "main" ]]; then + echo -e "$green_text""Updated platform to latest main (unstable).""$default_text" + else + echo -e "$green_text""Updated platform to $opt_version version.""$default_text" + fi fi echo -e "\nOnce the services are up, visit ""$blue_text""http://frontend.unstract.localhost""$default_text"" in your browser." 
- echo "See logs with:" + echo -e "\nSee logs with:" echo -e " ""$blue_text""$docker_compose_cmd -f docker/docker-compose.yaml logs -f""$default_text" + echo -e "Configure services by updating corresponding ""$yellow_text""/.env""$default_text"" files." + echo -e "Make sure to ""$yellow_text""restart""$default_text"" the services with:" + echo -e " ""$blue_text""$docker_compose_cmd -f docker/docker-compose.yaml up -d""$default_text" popd 1>/dev/null } @@ -264,7 +291,7 @@ check_dependencies opt_only_env=false opt_only_pull=false opt_build_local=false -opt_upgrade=false +opt_update=false opt_verbose=false opt_version="latest" diff --git a/unstract/connectors/pyproject.toml b/unstract/connectors/pyproject.toml index fc292d93d..f25839f3e 100644 --- a/unstract/connectors/pyproject.toml +++ b/unstract/connectors/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "s3fs[boto3]==2023.6.0", # For Minio "PyDrive2[fsspec]==1.15.4", # For GDrive "oauth2client==4.1.3", # For GDrive - "dropboxdrivefs==1.3.1", # For Dropbox + "dropboxdrivefs==1.4.1", # For Dropbox "boxfs==0.2.1", # For Box "gcsfs==2023.6.0", # For GoogleCloudStorage "adlfs==2023.8.0", # For AzureCloudStorage diff --git a/unstract/connectors/src/unstract/connectors/databases/unstract_db.py b/unstract/connectors/src/unstract/connectors/databases/unstract_db.py index d426c26e6..7ec82abc2 100644 --- a/unstract/connectors/src/unstract/connectors/databases/unstract_db.py +++ b/unstract/connectors/src/unstract/connectors/databases/unstract_db.py @@ -68,7 +68,7 @@ def test_credentials(self) -> bool: try: self.get_engine() except Exception as e: - raise ConnectorError(str(e)) from e + raise ConnectorError(f"Error while connecting to DB: {str(e)}") from e return True def execute(self, query: str) -> Any: diff --git a/unstract/connectors/src/unstract/connectors/filesystems/__init__.py b/unstract/connectors/src/unstract/connectors/filesystems/__init__.py index 4826bdb25..7635b7ca1 100644 --- 
a/unstract/connectors/src/unstract/connectors/filesystems/__init__.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/__init__.py @@ -1,5 +1,7 @@ from unstract.connectors import ConnectorDict # type: ignore from unstract.connectors.filesystems.register import register_connectors +from .local_storage.local_storage import * # noqa: F401, F403 + connectors: ConnectorDict = {} register_connectors(connectors) diff --git a/unstract/connectors/src/unstract/connectors/filesystems/azure_cloud_storage/azure_cloud_storage.py b/unstract/connectors/src/unstract/connectors/filesystems/azure_cloud_storage/azure_cloud_storage.py index d83426e2c..fe61fd6f2 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/azure_cloud_storage/azure_cloud_storage.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/azure_cloud_storage/azure_cloud_storage.py @@ -65,7 +65,11 @@ def get_fsspec_fs(self) -> AzureBlobFileSystem: def test_credentials(self) -> bool: """To test credentials for Azure Cloud Storage.""" try: - self.get_fsspec_fs().ls(f"{self.bucket}") + is_dir = bool(self.get_fsspec_fs().isdir(self.bucket)) + if not is_dir: + raise RuntimeError(f"'{self.bucket}' is not a valid bucket.") except Exception as e: - raise ConnectorError(str(e)) + raise ConnectorError( + f"Error from Azure Cloud Storage while testing connection: {str(e)}" + ) from e return True diff --git a/unstract/connectors/src/unstract/connectors/filesystems/box/box.py b/unstract/connectors/src/unstract/connectors/filesystems/box/box.py index e30cd264d..2989e7916 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/box/box.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/box/box.py @@ -24,7 +24,7 @@ def __init__(self, settings: dict[str, Any]): settings_dict = json.loads(settings["box_app_settings"]) if not isinstance(settings_dict, dict): raise ConnectorError( - "Box app settings is expected to be a valid JSON", + "Box app settings should be a valid 
JSON.", treat_as_user_message=True, ) except JSONDecodeError as e: @@ -112,8 +112,15 @@ def get_fsspec_fs(self) -> BoxFileSystem: def test_credentials(self) -> bool: """To test credentials for the Box connector.""" + is_dir = False try: - self.get_fsspec_fs().isdir("/") + is_dir = bool(self.get_fsspec_fs().isdir("/")) except Exception as e: - raise ConnectorError(str(e)) + raise ConnectorError( + f"Error from Box while testing connection: {str(e)}" + ) from e + if not is_dir: + raise ConnectorError( + "Unable to connect to Box, please check the connection settings." + ) return True diff --git a/unstract/connectors/src/unstract/connectors/filesystems/google_cloud_storage/google_cloud_storage.py b/unstract/connectors/src/unstract/connectors/filesystems/google_cloud_storage/google_cloud_storage.py index 915cb2b2c..b79bb8742 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/google_cloud_storage/google_cloud_storage.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/google_cloud_storage/google_cloud_storage.py @@ -64,7 +64,11 @@ def get_fsspec_fs(self) -> GCSFileSystem: def test_credentials(self) -> bool: """To test credentials for Google Cloud Storage.""" try: - is_dir = bool(self.get_fsspec_fs().isdir(f"{self.bucket}")) - return is_dir + is_dir = bool(self.get_fsspec_fs().isdir(self.bucket)) + if not is_dir: + raise RuntimeError(f"'{self.bucket}' is not a valid bucket.") except Exception as e: - raise ConnectorError(str(e)) + raise ConnectorError( + f"Error from Google Cloud Storage while testing connection: {str(e)}" + ) from e + return True diff --git a/unstract/connectors/src/unstract/connectors/filesystems/google_drive/google_drive.py b/unstract/connectors/src/unstract/connectors/filesystems/google_drive/google_drive.py index 1a2241b78..01b574273 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/google_drive/google_drive.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/google_drive/google_drive.py 
@@ -1,6 +1,7 @@ import json import logging import os +from pathlib import Path from typing import Any from oauth2client.client import OAuth2Credentials @@ -90,8 +91,25 @@ def get_fsspec_fs(self) -> GDriveFileSystem: def test_credentials(self) -> bool: """To test credentials for Google Drive.""" + is_dir = False try: - self.get_fsspec_fs().isdir("root") + is_dir = bool(self.get_fsspec_fs().isdir("root")) except Exception as e: - raise ConnectorError(str(e)) + raise ConnectorError( + f"Error from Google Drive while testing connection: {str(e)}" + ) from e + if not is_dir: + raise ConnectorError( + "Unable to connect to Google Drive, " + "please check the connection settings." + ) return True + + @staticmethod + def get_connector_root_dir(input_dir: str, **kwargs: Any) -> str: + """Get roor dir of gdrive.""" + root_path = kwargs.get("root_path") + if root_path is None: + raise ValueError("root_path is required to get root_dir for Google Drive") + input_dir = str(Path(root_path, input_dir.lstrip("/"))) + return f"{input_dir.strip('/')}/" diff --git a/unstract/connectors/src/unstract/connectors/filesystems/http/http.py b/unstract/connectors/src/unstract/connectors/filesystems/http/http.py index c12236e15..bf0a29dd6 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/http/http.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/http/http.py @@ -71,8 +71,16 @@ def get_fsspec_fs(self) -> HTTPFileSystem: def test_credentials(self) -> bool: """To test credentials for HTTP(S).""" + is_dir = False try: - self.get_fsspec_fs().isdir("/") + is_dir = bool(self.get_fsspec_fs().isdir("/")) except Exception as e: - raise ConnectorError(str(e)) + raise ConnectorError( + f"Error while connecting to HTTP server: {str(e)}" + ) from e + if not is_dir: + raise ConnectorError( + "Unable to connect to HTTP server, " + "please check the connection settings." 
+ ) return True diff --git a/unstract/connectors/src/unstract/connectors/filesystems/local_storage/local_storage.py b/unstract/connectors/src/unstract/connectors/filesystems/local_storage/local_storage.py index ec179e990..783803fe3 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/local_storage/local_storage.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/local_storage/local_storage.py @@ -4,6 +4,7 @@ from fsspec.implementations.local import LocalFileSystem +from unstract.connectors.exceptions import ConnectorError from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem logger = logging.getLogger(__name__) @@ -60,9 +61,16 @@ def get_fsspec_fs(self) -> Any: def test_credentials(self, *args, **kwargs) -> bool: # type:ignore """To test credentials for LocalStorage.""" + is_dir = False try: - self.get_fsspec_fs().isdir("/") + is_dir = bool(self.get_fsspec_fs().isdir("/")) except Exception as e: - logger.error(f"Test creds failed: {e}") - return False + raise ConnectorError( + f"Error while connecting to local storage: {str(e)}" + ) from e + if not is_dir: + raise ConnectorError( + "Unable to connect to local storage, " + "please check the connection settings." + ) return True diff --git a/unstract/connectors/src/unstract/connectors/filesystems/minio/exceptions.py b/unstract/connectors/src/unstract/connectors/filesystems/minio/exceptions.py new file mode 100644 index 000000000..2b9b6ea0a --- /dev/null +++ b/unstract/connectors/src/unstract/connectors/filesystems/minio/exceptions.py @@ -0,0 +1,26 @@ +from unstract.connectors.exceptions import ConnectorError + +S3FS_EXC_TO_UNSTRACT_EXC = { + "The AWS Access Key Id you provided does not exist in our records": ( + "Invalid Key (Access Key ID) provided, please provide a valid one." + ), + "The request signature we calculated does not match the signature you provided": ( + "Invalid Secret (Secret Access Key) provided, please provide a valid one." 
+ ), + "[Errno 22] S3 API Requests must be made to API port": ( # Minio only + "Request made to invalid port, please check the port of the endpoint URL." + ), +} + + +def handle_s3fs_exception(e: Exception) -> ConnectorError: + original_exc = str(e) + user_msg = "Error from S3 / MinIO while testing connection: " + exc_to_append = "" + for s3fs_exc, user_friendly_msg in S3FS_EXC_TO_UNSTRACT_EXC.items(): + if s3fs_exc in original_exc: + exc_to_append = user_friendly_msg + break + + user_msg += exc_to_append if exc_to_append else str(e) + return ConnectorError(message=user_msg) diff --git a/unstract/connectors/src/unstract/connectors/filesystems/minio/minio.py b/unstract/connectors/src/unstract/connectors/filesystems/minio/minio.py index f0cf798b1..b676cda42 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/minio/minio.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/minio/minio.py @@ -4,9 +4,10 @@ from s3fs.core import S3FileSystem -from unstract.connectors.exceptions import ConnectorError from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem +from .exceptions import handle_s3fs_exception + logger = logging.getLogger(__name__) @@ -17,7 +18,6 @@ def __init__(self, settings: dict[str, Any]): secret = settings["secret"] endpoint_url = settings["endpoint_url"] self.bucket = settings["bucket"] - self.path = settings["path"] client_kwargs = {} if "region_name" in settings and settings["region_name"] != "": client_kwargs = {"region_name": settings["region_name"]} @@ -38,11 +38,11 @@ def get_id() -> str: @staticmethod def get_name() -> str: - return "MinioFS/S3" + return "S3/Minio" @staticmethod def get_description() -> str: - return "All MinioFS compatible, including AWS S3" + return "Connect to AWS S3 and other compatible storage such as Minio." 
@staticmethod def get_icon() -> str: @@ -77,7 +77,9 @@ def get_fsspec_fs(self) -> S3FileSystem: def test_credentials(self) -> bool: """To test credentials for Minio.""" try: - self.get_fsspec_fs().isdir(f"{self.bucket}") + is_dir = bool(self.get_fsspec_fs().isdir(self.bucket)) + if not is_dir: + raise RuntimeError(f"'{self.bucket}' is not a valid bucket.") except Exception as e: - raise ConnectorError(str(e)) + raise handle_s3fs_exception(e) from e return True diff --git a/unstract/connectors/src/unstract/connectors/filesystems/minio/static/json_schema.json b/unstract/connectors/src/unstract/connectors/filesystems/minio/static/json_schema.json index b248cad7f..7f3d3997c 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/minio/static/json_schema.json +++ b/unstract/connectors/src/unstract/connectors/filesystems/minio/static/json_schema.json @@ -32,12 +32,6 @@ "title": "Bucket Name", "description": "Name of the bucket to be restricted to." }, - "path": { - "type": "string", - "title": "Path", - "default": "", - "description": "Path to restrict to. 
(example /path/to/restrict/to)" - }, "endpoint_url": { "type": "string", "title": "Endpoint URL", diff --git a/unstract/connectors/src/unstract/connectors/filesystems/unstract_file_system.py b/unstract/connectors/src/unstract/connectors/filesystems/unstract_file_system.py index 304488e0d..34057315a 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/unstract_file_system.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/unstract_file_system.py @@ -1,5 +1,6 @@ import logging from abc import ABC, abstractmethod +from typing import Any from fsspec import AbstractFileSystem @@ -69,3 +70,15 @@ def get_fsspec_fs(self) -> AbstractFileSystem: def test_credentials(self) -> bool: """Override to test credentials for a connector.""" pass + + @staticmethod + def get_connector_root_dir(input_dir: str, **kwargs: Any) -> str: + """Override to get root dir of a connector.""" + return f"{input_dir.strip('/')}/" + + def create_dir_if_not_exists(self, input_dir: str) -> None: + """Override to create dir of a connector if not exists.""" + fs_fsspec = self.get_fsspec_fs() + is_dir = fs_fsspec.isdir(input_dir) + if not is_dir: + fs_fsspec.mkdir(input_dir) diff --git a/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/exceptions.py b/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/exceptions.py index 1453a8f2c..bab073573 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/exceptions.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/exceptions.py @@ -7,16 +7,22 @@ def handle_dropbox_exception(e: DropboxException) -> ConnectorError: - user_msg = "" + user_msg = "Error from Dropbox while testing connection: " if isinstance(e, ExcAuthError): if isinstance(e.error, AuthError): if e.error.is_expired_access_token(): - user_msg = "Expired access token" + user_msg += ( + "Expired access token, please regenerate it " + "through the Dropbox console." 
+ )
 elif e.error.is_invalid_access_token():
- user_msg = "Invalid access token"
+ user_msg += (
+ "Invalid access token, please enter a valid token "
+ "from the Dropbox console."
+ )
 else:
- user_msg = e.error._tag
+ user_msg += e.error._tag
 elif isinstance(e, ApiError):
 if e.user_message_text is not None:
- user_msg = e.user_message_text
+ user_msg += e.user_message_text
 return ConnectorError(message=user_msg, treat_as_user_message=True)
diff --git a/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/zs_dropbox.py b/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/zs_dropbox.py
index 31b743cbb..e7d5a7d37 100644
--- a/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/zs_dropbox.py
+++ b/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/zs_dropbox.py
@@ -2,6 +2,7 @@
 import os
 from typing import Any

+from dropbox.exceptions import ApiError as DropBoxApiError
 from dropbox.exceptions import DropboxException
 from dropboxdrivefs import DropboxDriveFileSystem

@@ -68,9 +69,23 @@ def test_credentials(self) -> bool:
 # self.get_fsspec_fs().connect()
 self.get_fsspec_fs().ls("")
 except DropboxException as e:
- logger.error(f"Test creds failed: {e}")
- raise handle_dropbox_exception(e)
+ raise handle_dropbox_exception(e) from e
 except Exception as e:
- logger.error(f"Test creds failed: {e}")
- raise ConnectorError(str(e))
+ raise ConnectorError(f"Error while connecting to Dropbox: {str(e)}") from e
 return True
+
+ @staticmethod
+ def get_connector_root_dir(input_dir: str, **kwargs: Any) -> str:
+ """Get root dir of zs dropbox."""
+ return f"/{input_dir.strip('/')}"
+
+ def create_dir_if_not_exists(self, input_dir: str) -> None:
+ """Create root dir of zs dropbox if not exists."""
+ fs_fsspec = self.get_fsspec_fs()
+ try:
+ fs_fsspec.isdir(input_dir)
+ except (
+ DropBoxApiError
+ ) as e: # Dropbox returns this exception when directory is not present
+ logger.debug(f"Path not found in dropbox {e.error}")
+ 
fs_fsspec.mkdir(input_dir)