From 72cdd4af2d126ef243ea2057feac78bb230617c6 Mon Sep 17 00:00:00 2001 From: Rahul Johny <116638720+johnyrahul@users.noreply.github.com> Date: Tue, 9 Apr 2024 10:48:50 +0530 Subject: [PATCH] Authorization/Permissions Prompt studio (#112) * intial commit for permission in prompt studio * Create prompts to be part of prompt studio core * added share to prompt project * Prettier fix * File upload changes * Profile manger creation from prompt studio * Removing relation from custom tool on user deletion * Refractored the code * Refractored thr code to avoid exposing unused endpoints * added permission around prompt * Support for sharing the exported tool * Support for sharing the exported tool and access validation * Conflict resolution models * Conflict resolution jsx files * Check prompt registry is associated with custom tool * implemented UI for tool sharing * code clean up * Updated migrations * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Removed commented code * Fixed code climate issues * Fix code duplication * Removed redudant checks * Update backend/prompt_studio/permission.py Signed-off-by: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com> --------- Signed-off-by: Rahul Johny <116638720+johnyrahul@users.noreply.github.com> Signed-off-by: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com> Co-authored-by: jagadeeswaran-zipstack Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: siddhiq Co-authored-by: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com> Co-authored-by: Neha <115609453+nehabagdia@users.noreply.github.com> --- backend/account/authentication_controller.py | 4 +- backend/account/serializer.py | 2 +- backend/adapter_processor/serializers.py | 10 +- backend/file_management/urls.py | 15 - backend/file_management/views.py | 111 ------- backend/prompt_studio/permission.py | 21 ++ .../prompt_profile_manager/urls.py | 6 - .../prompt_profile_manager/views.py | 8 +- backend/prompt_studio/prompt_studio/urls.py | 8 - backend/prompt_studio/prompt_studio/views.py | 38 +-- .../prompt_studio_core/constants.py | 6 + .../0012_customtool_shared_users.py | 24 ++ .../prompt_studio_core/models.py | 23 ++ .../prompt_studio_helper.py | 5 +- .../prompt_studio_core/serializers.py | 45 ++- .../prompt_studio/prompt_studio_core/urls.py | 54 +++- .../prompt_studio/prompt_studio_core/views.py | 275 ++++++++++++++++-- .../prompt_studio_document_manager/urls.py | 8 - .../prompt_studio_document_manager/views.py | 4 +- .../prompt_studio_index_manager/urls.py | 8 - .../prompt_studio_index_manager/views.py | 3 +- ...ptstudioregistry_shared_to_org_and_more.py | 30 ++ .../prompt_studio_registry/models.py | 25 ++ .../prompt_studio_registry_helper.py | 9 +- .../prompt_studio_registry/serializers.py | 14 +- .../prompt_studio_registry/urls.py | 14 +- .../prompt_studio_registry/views.py | 40 +-- backend/tool_instance/tool_instance_helper.py | 21 ++ .../add-llm-profile/AddLlmProfile.jsx | 5 +- .../document-manager/DocumentManager.jsx | 14 +- .../components/custom-tools/header/Header.jsx | 101 ++++++- .../list-of-tools/ListOfTools.jsx | 85 +++++- .../manage-docs-modal/ManageDocsModal.jsx | 2 +- .../manage-llm-profiles/ManageLlmProfiles.jsx | 8 +- .../OutputAnalyzerCard.jsx | 10 +- .../custom-tools/prompt-card/PromptCard.jsx | 8 +- .../custom-tools/tool-ide/ToolIde.jsx | 11 +- .../custom-tools/tools-main/ToolsMain.jsx | 9 +- .../helpers/custom-tools/CustomToolsHelper.js | 4 +- .../share-permission/SharePermission.css | 4 + .../share-permission/SharePermission.jsx | 196 +++++++------ 41 files changed, 852 insertions(+), 436 deletions(-) create mode 100644 backend/prompt_studio/permission.py create mode 100644 backend/prompt_studio/prompt_studio_core/migrations/0012_customtool_shared_users.py create mode 100644 backend/prompt_studio/prompt_studio_registry/migrations/0006_promptstudioregistry_shared_to_org_and_more.py diff --git a/backend/account/authentication_controller.py b/backend/account/authentication_controller.py index 275c5b838..72c1c2de2 100644 --- a/backend/account/authentication_controller.py +++ b/backend/account/authentication_controller.py @@ -402,8 +402,10 @@ def remove_users_from_organization( is_removed = False if is_removed: OrganizationMember.objects.filter(user__in=ids_list).delete() - # removing adapter relations on user removal + # removing user m2m relations , while removing user for user_id in ids_list: + User.objects.get(pk=user_id).shared_exported_tools.clear() + User.objects.get(pk=user_id).shared_custom_tool.clear() User.objects.get(pk=user_id).shared_adapters.clear() return is_removed diff --git a/backend/account/serializer.py b/backend/account/serializer.py index 77cb89b2a..15d165d3a 100644 --- a/backend/account/serializer.py +++ b/backend/account/serializer.py @@ -78,7 +78,7 @@ class Meta: class UserSerializer(serializers.ModelSerializer): class Meta: model = User - fields = ("id", "email") + fields = ("id", "username") class OrganizationSignupResponseSerializer(serializers.Serializer): diff --git a/backend/adapter_processor/serializers.py b/backend/adapter_processor/serializers.py index 162119815..7d1f75022 100644 --- a/backend/adapter_processor/serializers.py +++ b/backend/adapter_processor/serializers.py @@ -1,7 +1,7 @@ import json from typing import Any -from account.models import User +from account.serializer import UserSerializer from adapter_processor.adapter_processor import AdapterProcessor from adapter_processor.constants import AdapterKeys from cryptography.fernet import Fernet @@ -120,16 +120,10 @@ def to_representation(self, instance: AdapterInstance) -> dict[str, str]: return rep -class UserSerializer(serializers.ModelSerializer): - class Meta: - model = User - fields = ("id", "username") - - class SharedUserListSerializer(BaseAdapterSerializer): """Inherits BaseAdapterSerializer. - Used for listing adapters + Used for listing adapter users """ shared_users = UserSerializer(many=True) diff --git a/backend/file_management/urls.py b/backend/file_management/urls.py index 709b3795a..8b0ae2dcf 100644 --- a/backend/file_management/urls.py +++ b/backend/file_management/urls.py @@ -45,20 +45,5 @@ path("file/download", file_downlaod, name="download"), path("file/upload", file_upload, name="upload"), path("file/delete", file_delete, name="delete"), - path( - "prompt-studio/file/upload", - prompt_studio_file_upload, - name="prompt_studio_upload", - ), - path( - "prompt-studio/file/fetch_contents", - prompt_studio_fetch_content, - name="tool_studio_fetch", - ), - path( - "prompt-studio/file", - prompt_studio_file_list, - name="prompt_studio_list", - ), ] ) diff --git a/backend/file_management/views.py b/backend/file_management/views.py index db2a5429d..879ead194 100644 --- a/backend/file_management/views.py +++ b/backend/file_management/views.py @@ -1,10 +1,8 @@ import logging -import os from typing import Any from connector.models import ConnectorInstance from django.http import HttpRequest -from file_management.constants import FileViewTypes from file_management.exceptions import ( ConnectorInstanceNotFound, ConnectorOAuthError, @@ -15,16 +13,11 @@ from file_management.serializer import ( FileInfoIdeSerializer, FileInfoSerializer, - FileListRequestIdeSerializer, FileListRequestSerializer, - FileUploadIdeSerializer, FileUploadSerializer, ) from oauth2client.client import HttpAccessTokenRefreshError from prompt_studio.prompt_studio_document_manager.models import DocumentManager -from prompt_studio.prompt_studio_document_manager.prompt_studio_document_helper import ( - PromptStudioDocumentHelper, -) from rest_framework import serializers, status, viewsets from rest_framework.decorators import action from rest_framework.response import Response @@ -125,110 +118,6 @@ def upload(self, request: HttpRequest) -> Response: ) return Response({"message": "Files are uploaded successfully!"}) - @action(detail=True, methods=["post"]) - def upload_for_ide(self, request: HttpRequest) -> Response: - serializer = FileUploadIdeSerializer(data=request.data) - serializer.is_valid(raise_exception=True) - uploaded_files: Any = serializer.validated_data.get("file") - tool_id: str = request.query_params.get("tool_id") - file_path = FileManagerHelper.handle_sub_directory_for_tenants( - request.org_id, - is_create=True, - user_id=request.user.user_id, - tool_id=tool_id, - ) - file_system = LocalStorageFS(settings={"path": file_path}) - - documents = [] - for uploaded_file in uploaded_files: - file_name = uploaded_file.name - - # Create a record in the db for the file - document = PromptStudioDocumentHelper.create( - tool_id=tool_id, document_name=file_name - ) - # Create a dictionary to store document data - doc = { - "document_id": document.document_id, - "document_name": document.document_name, - "tool": document.tool.tool_id, - } - # Store file - logger.info( - f"Uploading file: {file_name}" - if file_name - else "Uploading file" - ) - FileManagerHelper.upload_file( - file_system, - file_path, - uploaded_file, - file_name, - ) - documents.append(doc) - return Response({"data": documents}) - - @action(detail=True, methods=["get"]) - def fetch_contents_ide(self, request: HttpRequest) -> Response: - serializer = FileInfoIdeSerializer(data=request.GET) - serializer.is_valid(raise_exception=True) - document_id: str = serializer.validated_data.get("document_id") - document: DocumentManager = DocumentManager.objects.get(pk=document_id) - file_name: str = document.document_name - tool_id: str = serializer.validated_data.get("tool_id") - view_type: str = serializer.validated_data.get("view_type") - - filename_without_extension = file_name.rsplit(".", 1)[0] - if view_type == FileViewTypes.EXTRACT: - file_name = ( - f"{FileViewTypes.EXTRACT.lower()}/" - f"{filename_without_extension}.txt" - ) - if view_type == FileViewTypes.SUMMARIZE: - file_name = ( - f"{FileViewTypes.SUMMARIZE.lower()}/" - f"{filename_without_extension}.txt" - ) - - file_path = file_path = ( - FileManagerHelper.handle_sub_directory_for_tenants( - request.org_id, - is_create=True, - user_id=request.user.user_id, - tool_id=tool_id, - ) - ) - file_system = LocalStorageFS(settings={"path": file_path}) - if not file_path.endswith("/"): - file_path += "/" - file_path += file_name - contents = FileManagerHelper.fetch_file_contents(file_system, file_path) - return Response({"data": contents}, status=status.HTTP_200_OK) - - @action(detail=True, methods=["get"]) - def list_ide(self, request: HttpRequest) -> Response: - serializer = FileListRequestIdeSerializer(data=request.GET) - serializer.is_valid(raise_exception=True) - tool_id: str = serializer.validated_data.get("tool_id") - file_path = FileManagerHelper.handle_sub_directory_for_tenants( - request.org_id, - is_create=True, - user_id=request.user.user_id, - tool_id=tool_id, - ) - file_system = LocalStorageFS(settings={"path": file_path}) - try: - files = FileManagerHelper.list_files(file_system, file_path) - serializer = FileInfoSerializer(files, many=True) - # fetching only the name from path - for file in serializer.data: - file_name = os.path.basename(file.get("name")) - file["name"] = file_name - return Response(serializer.data) - except Exception as error: - logger.error(f"Exception thrown from file list, error {error}") - raise InternalServerError() - @action(detail=True, methods=["get"]) def delete(self, request: HttpRequest) -> Response: serializer = FileInfoIdeSerializer(data=request.GET) diff --git a/backend/prompt_studio/permission.py b/backend/prompt_studio/permission.py new file mode 100644 index 000000000..ba036172a --- /dev/null +++ b/backend/prompt_studio/permission.py @@ -0,0 +1,21 @@ +from typing import Any + +from rest_framework import permissions +from rest_framework.request import Request +from rest_framework.views import APIView + + +class PromptAcesssToUser(permissions.BasePermission): + """Is the crud to Prompt/Notes allowed to user.""" + + def has_object_permission( + self, request: Request, view: APIView, obj: Any + ) -> bool: + return ( + True + if ( + obj.tool_id.created_by == request.user + or obj.tool_id.shared_users.filter(pk=request.user.pk).exists() + ) + else False + ) diff --git a/backend/prompt_studio/prompt_profile_manager/urls.py b/backend/prompt_studio/prompt_profile_manager/urls.py index c846b259d..ae95f1fb9 100644 --- a/backend/prompt_studio/prompt_profile_manager/urls.py +++ b/backend/prompt_studio/prompt_profile_manager/urls.py @@ -3,7 +3,6 @@ from .views import ProfileManagerView -profile_manager_list = ProfileManagerView.as_view({"post": "create"}) profile_manager_detail = ProfileManagerView.as_view( { "get": "retrieve", @@ -16,11 +15,6 @@ urlpatterns = format_suffix_patterns( [ - path( - "profile-manager/", - profile_manager_list, - name="profile-manager-list", - ), path( "profile-manager//", profile_manager_detail, diff --git a/backend/prompt_studio/prompt_profile_manager/views.py b/backend/prompt_studio/prompt_profile_manager/views.py index 3ffbec51d..16d66ef1b 100644 --- a/backend/prompt_studio/prompt_profile_manager/views.py +++ b/backend/prompt_studio/prompt_profile_manager/views.py @@ -33,13 +33,9 @@ def get_queryset(self) -> Optional[QuerySet]: ProfileManagerKeys.CREATED_BY, ) if filter_args: - queryset = ProfileManager.objects.filter( - created_by=self.request.user, **filter_args - ) + queryset = ProfileManager.objects.filter(**filter_args) else: - queryset = ProfileManager.objects.filter( - created_by=self.request.user - ) + queryset = ProfileManager.objects.all() return queryset def create( diff --git a/backend/prompt_studio/prompt_studio/urls.py b/backend/prompt_studio/prompt_studio/urls.py index a0cb10ce4..e2f06d824 100644 --- a/backend/prompt_studio/prompt_studio/urls.py +++ b/backend/prompt_studio/prompt_studio/urls.py @@ -3,9 +3,6 @@ from .views import ToolStudioPromptView -prompt_studio_prompt_list = ToolStudioPromptView.as_view( - {"get": "list", "post": "create"} -) prompt_studio_prompt_detail = ToolStudioPromptView.as_view( { "get": "retrieve", @@ -17,11 +14,6 @@ urlpatterns = format_suffix_patterns( [ - path( - "prompt/", - prompt_studio_prompt_list, - name="prompt-studio-prompt-list", - ), path( "prompt//", prompt_studio_prompt_detail, diff --git a/backend/prompt_studio/prompt_studio/views.py b/backend/prompt_studio/prompt_studio/views.py index c226d3e36..aa1604bc7 100644 --- a/backend/prompt_studio/prompt_studio/views.py +++ b/backend/prompt_studio/prompt_studio/views.py @@ -1,16 +1,10 @@ import logging -from typing import Any, Optional +from typing import Optional -from account.custom_exceptions import DuplicateData -from django.db import IntegrityError from django.db.models import QuerySet -from django.http import HttpRequest -from prompt_studio.prompt_studio.constants import ( - ToolStudioPromptErrors, - ToolStudioPromptKeys, -) -from rest_framework import status, viewsets -from rest_framework.response import Response +from prompt_studio.permission import PromptAcesssToUser +from prompt_studio.prompt_studio.constants import ToolStudioPromptKeys +from rest_framework import viewsets from rest_framework.versioning import URLPathVersioning from utils.filtering import FilterHelper @@ -35,6 +29,7 @@ class ToolStudioPromptView(viewsets.ModelViewSet): versioning_class = URLPathVersioning serializer_class = ToolStudioPromptSerializer + permission_classes: list[type[PromptAcesssToUser]] = [PromptAcesssToUser] def get_queryset(self) -> Optional[QuerySet]: filter_args = FilterHelper.build_filter_args( @@ -42,26 +37,7 @@ def get_queryset(self) -> Optional[QuerySet]: ToolStudioPromptKeys.TOOL_ID, ) if filter_args: - queryset = ToolStudioPrompt.objects.filter( - created_by=self.request.user, **filter_args - ) + queryset = ToolStudioPrompt.objects.filter(**filter_args) else: - queryset = ToolStudioPrompt.objects.filter( - created_by=self.request.user, - ) + queryset = ToolStudioPrompt.objects.all() return queryset - - def create( - self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any] - ) -> Response: - serializer = self.get_serializer(data=request.data) - # TODO : Handle model related exceptions. - serializer.is_valid(raise_exception=True) - try: - self.perform_create(serializer) - except IntegrityError: - raise DuplicateData( - f"{ToolStudioPromptErrors.PROMPT_NAME_EXISTS}, \ - {ToolStudioPromptErrors.DUPLICATE_API}" - ) - return Response(serializer.data, status=status.HTTP_201_CREATED) diff --git a/backend/prompt_studio/prompt_studio_core/constants.py b/backend/prompt_studio/prompt_studio_core/constants.py index ac998aa92..806afd3e7 100644 --- a/backend/prompt_studio/prompt_studio_core/constants.py +++ b/backend/prompt_studio/prompt_studio_core/constants.py @@ -85,6 +85,12 @@ class ToolStudioPromptKeys: NOTES = "NOTES" +class FileViewTypes: + ORIGINAL = "ORIGINAL" + EXTRACT = "EXTRACT" + SUMMARIZE = "SUMMARIZE" + + class LogLevels: INFO = "INFO" ERROR = "ERROR" diff --git a/backend/prompt_studio/prompt_studio_core/migrations/0012_customtool_shared_users.py b/backend/prompt_studio/prompt_studio_core/migrations/0012_customtool_shared_users.py new file mode 100644 index 000000000..fa7a16bbe --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core/migrations/0012_customtool_shared_users.py @@ -0,0 +1,24 @@ +# Generated by Django 4.2.1 on 2024-03-26 04:26 + +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ( + "prompt_studio_core", + "0011_alter_customtool_postamble_alter_customtool_preamble", + ), + ] + + operations = [ + migrations.AddField( + model_name="customtool", + name="shared_users", + field=models.ManyToManyField( + related_name="shared_custom_tool", to=settings.AUTH_USER_MODEL + ), + ), + ] diff --git a/backend/prompt_studio/prompt_studio_core/models.py b/backend/prompt_studio/prompt_studio_core/models.py index 4878f0d9d..cac0a0607 100644 --- a/backend/prompt_studio/prompt_studio_core/models.py +++ b/backend/prompt_studio/prompt_studio_core/models.py @@ -1,12 +1,26 @@ import uuid +from typing import Any from account.models import User from adapter_processor.models import AdapterInstance from django.db import models +from django.db.models import QuerySet from prompt_studio.prompt_studio_core.constants import DefaultPrompts from utils.models.base_model import BaseModel +class CustomToolModelManager(models.Manager): + def get_queryset(self) -> QuerySet[Any]: + return super().get_queryset() + + def for_user(self, user: User) -> QuerySet[Any]: + return ( + self.get_queryset() + .filter(models.Q(created_by=user) | models.Q(shared_users=user)) + .distinct("tool_id") + ) + + class CustomTool(BaseModel): """Model to store the custom tools designed in the tool studio.""" @@ -82,6 +96,7 @@ class CustomTool(BaseModel): blank=True, editable=False, ) + exclude_failed = models.BooleanField( db_comment="Flag to make the answer null if it is incorrect", default=True, @@ -101,3 +116,11 @@ class CustomTool(BaseModel): enable_challenge = models.BooleanField( db_comment="Flag to enable or disable challenge", default=False ) + + # Introduced field to establish M2M relation between users and custom_tool. + # This will introduce intermediary table which relates both the models. + shared_users = models.ManyToManyField( + User, related_name="shared_custom_tool" + ) + + objects = CustomToolModelManager() diff --git a/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py b/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py index 148017e6d..1043d5f23 100644 --- a/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py +++ b/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py @@ -223,7 +223,10 @@ def index_document( else: default_profile = ProfileManager.get_default_llm_profile(tool) file_path = FileManagerHelper.handle_sub_directory_for_tenants( - org_id, is_create=False, user_id=user_id, tool_id=tool_id + org_id, + is_create=False, + user_id=user_id, + tool_id=tool_id, ) file_path = str(Path(file_path) / file_name) diff --git a/backend/prompt_studio/prompt_studio_core/serializers.py b/backend/prompt_studio/prompt_studio_core/serializers.py index 088db4b17..595a45f2a 100644 --- a/backend/prompt_studio/prompt_studio_core/serializers.py +++ b/backend/prompt_studio/prompt_studio_core/serializers.py @@ -1,13 +1,17 @@ import logging from typing import Any +from account.models import User +from account.serializer import UserSerializer from django.core.exceptions import ObjectDoesNotExist +from file_management.constants import FileInformationKey from prompt_studio.prompt_profile_manager.models import ProfileManager from prompt_studio.prompt_studio.models import ToolStudioPrompt from prompt_studio.prompt_studio.serializers import ToolStudioPromptSerializer from prompt_studio.prompt_studio_core.constants import ToolStudioKeys as TSKeys from prompt_studio.prompt_studio_core.exceptions import DefaultProfileError from rest_framework import serializers +from utils.FileValidator import FileValidator from backend.serializers import AuditSerializer @@ -17,6 +21,10 @@ class CustomToolSerializer(AuditSerializer): + shared_users = serializers.PrimaryKeyRelatedField( + queryset=User.objects.all(), required=False, allow_null=True, many=True + ) + class Meta: model = CustomTool fields = "__all__" @@ -64,10 +72,45 @@ def to_representation(self, instance): # type: ignore class PromptStudioIndexSerializer(serializers.Serializer): document_id = serializers.CharField() - tool_id = serializers.CharField() class PromptStudioResponseSerializer(serializers.Serializer): file_name = serializers.CharField() tool_id = serializers.CharField() id = serializers.CharField() + + +class SharedUserListSerializer(serializers.ModelSerializer): + """Used for listing users of Custom tool.""" + + created_by = UserSerializer() + shared_users = UserSerializer(many=True) + + class Meta: + model = CustomTool + fields = ( + "tool_id", + "tool_name", + "created_by", + "shared_users", + ) + + +class FileInfoIdeSerializer(serializers.Serializer): + document_id = serializers.CharField() + view_type = serializers.CharField(required=False) + + +class FileUploadIdeSerializer(serializers.Serializer): + file = serializers.ListField( + child=serializers.FileField(), + required=True, + validators=[ + FileValidator( + allowed_extensions=FileInformationKey.FILE_UPLOAD_ALLOWED_EXT, + allowed_mimetypes=FileInformationKey.FILE_UPLOAD_ALLOWED_MIME, + min_size=0, + max_size=FileInformationKey.FILE_UPLOAD_MAX_SIZE, + ) + ], + ) diff --git a/backend/prompt_studio/prompt_studio_core/urls.py b/backend/prompt_studio/prompt_studio_core/urls.py index 089ac4ace..4d00fab5d 100644 --- a/backend/prompt_studio/prompt_studio_core/urls.py +++ b/backend/prompt_studio/prompt_studio_core/urls.py @@ -21,6 +21,12 @@ {"get": "list_profiles", "patch": "make_profile_default"} ) +prompt_studio_prompts = PromptStudioCoreView.as_view({"post": "create_prompt"}) + +prompt_studio_profilemanager = PromptStudioCoreView.as_view( + {"post": "create_profile_manager"} +) + prompt_studio_prompt_index = PromptStudioCoreView.as_view( {"post": "index_document"} ) @@ -33,6 +39,23 @@ prompt_studio_single_pass_extraction = PromptStudioCoreView.as_view( {"post": "single_pass_extraction"} ) +prompt_studio_users = PromptStudioCoreView.as_view( + {"get": "list_of_shared_users"} +) + + +prompt_studio_file = PromptStudioCoreView.as_view( + { + "post": "upload_for_ide", + "get": "fetch_contents_ide", + "delete": "delete_for_ide", + } +) + +prompt_studio_export = PromptStudioCoreView.as_view( + {"post": "export_tool", "get": "export_tool_info"} +) + urlpatterns = format_suffix_patterns( [ @@ -48,17 +71,27 @@ name="prompt-studio-choices", ), path( - "prompt-studio/profiles//", + "prompt-studio/prompt-studio-profile//", prompt_studio_profiles, name="prompt-studio-profiles", ), path( - "prompt-studio/index-document/", + "prompt-studio/prompt-studio-prompt//", + prompt_studio_prompts, + name="prompt-studio-prompts", + ), + path( + "prompt-studio/profilemanager/", + prompt_studio_profilemanager, + name="prompt-studio-profilemanager", + ), + path( + "prompt-studio/index-document/", prompt_studio_prompt_index, name="prompt-studio-prompt-index", ), path( - "prompt-studio/fetch_response/", + "prompt-studio/fetch_response/", prompt_studio_prompt_response, name="prompt-studio-prompt-response", ), @@ -72,5 +105,20 @@ prompt_studio_single_pass_extraction, name="prompt-studio-single-pass-extraction", ), + path( + "prompt-studio/users/", + prompt_studio_users, + name="prompt-studio-users", + ), + path( + "prompt-studio/file/", + prompt_studio_file, + name="prompt_studio_file", + ), + path( + "prompt-studio/export/", + prompt_studio_export, + name="prompt_studio_export", + ), ] ) diff --git a/backend/prompt_studio/prompt_studio_core/views.py b/backend/prompt_studio/prompt_studio_core/views.py index b5b575bf8..015bcbd23 100644 --- a/backend/prompt_studio/prompt_studio_core/views.py +++ b/backend/prompt_studio/prompt_studio_core/views.py @@ -5,14 +5,19 @@ from django.db import IntegrityError from django.db.models import QuerySet from django.http import HttpRequest -from permissions.permission import IsOwner +from file_management.file_management_helper import FileManagerHelper +from permissions.permission import IsOwner, IsOwnerOrSharedUser from prompt_studio.processor_loader import ProcessorConfig, load_plugins +from prompt_studio.prompt_profile_manager.constants import ProfileManagerErrors from prompt_studio.prompt_profile_manager.models import ProfileManager from prompt_studio.prompt_profile_manager.serializers import ( ProfileManagerSerializer, ) +from prompt_studio.prompt_studio.constants import ToolStudioPromptErrors from prompt_studio.prompt_studio.exceptions import FilenameMissingError +from prompt_studio.prompt_studio.serializers import ToolStudioPromptSerializer from prompt_studio.prompt_studio_core.constants import ( + FileViewTypes, ToolStudioErrors, ToolStudioKeys, ToolStudioPromptKeys, @@ -25,16 +30,35 @@ PromptStudioHelper, ) from prompt_studio.prompt_studio_document_manager.models import DocumentManager +from prompt_studio.prompt_studio_document_manager.prompt_studio_document_helper import ( # noqa: E501 + PromptStudioDocumentHelper, +) +from prompt_studio.prompt_studio_registry.prompt_studio_registry_helper import ( + PromptStudioRegistryHelper, +) +from prompt_studio.prompt_studio_registry.serializers import ( + ExportToolRequestSerializer, + PromptStudioRegistryInfoSerializer, +) from rest_framework import status, viewsets from rest_framework.decorators import action from rest_framework.request import Request from rest_framework.response import Response from rest_framework.versioning import URLPathVersioning from tool_instance.models import ToolInstance -from utils.filtering import FilterHelper + +from unstract.connectors.filesystems.local_storage.local_storage import ( + LocalStorageFS, +) from .models import CustomTool -from .serializers import CustomToolSerializer, PromptStudioIndexSerializer +from .serializers import ( + CustomToolSerializer, + FileInfoIdeSerializer, + FileUploadIdeSerializer, + PromptStudioIndexSerializer, + SharedUserListSerializer, +) logger = logging.getLogger(__name__) @@ -44,29 +68,20 @@ class PromptStudioCoreView(viewsets.ModelViewSet): versioning_class = URLPathVersioning - permission_classes = [IsOwner] serializer_class = CustomToolSerializer processor_plugins = load_plugins() + def get_permissions(self) -> list[Any]: + if self.action == "destroy": + return [IsOwner()] + + return [IsOwnerOrSharedUser()] + def get_queryset(self) -> Optional[QuerySet]: - filter_args = FilterHelper.build_filter_args( - self.request, - ToolStudioKeys.CREATED_BY, - ) - if filter_args: - queryset = CustomTool.objects.filter( - created_by=self.request.user, **filter_args - ) - else: - queryset = CustomTool.objects.filter( - created_by=self.request.user, - ) - return queryset + return CustomTool.objects.for_user(self.request.user) - def create( - self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any] - ) -> Response: + def create(self, request: HttpRequest) -> Response: serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) try: @@ -179,7 +194,7 @@ def make_profile_default( ) @action(detail=True, methods=["get"]) - def index_document(self, request: HttpRequest) -> Response: + def index_document(self, request: HttpRequest, pk: Any = None) -> Response: """API Entry point method to index input file. Args: @@ -192,21 +207,19 @@ def index_document(self, request: HttpRequest) -> Response: Returns: Response """ + tool = self.get_object() serializer = PromptStudioIndexSerializer(data=request.data) serializer.is_valid(raise_exception=True) - tool_id: str = serializer.validated_data.get( - ToolStudioPromptKeys.TOOL_ID - ) document_id: str = serializer.validated_data.get( ToolStudioPromptKeys.DOCUMENT_ID ) document: DocumentManager = DocumentManager.objects.get(pk=document_id) file_name: str = document.document_name unique_id = PromptStudioHelper.index_document( - tool_id=tool_id, + tool_id=str(tool.tool_id), file_name=file_name, org_id=request.org_id, - user_id=request.user.user_id, + user_id=tool.created_by.user_id, document_id=document_id, ) @@ -215,10 +228,10 @@ def index_document(self, request: HttpRequest) -> Response: ProcessorConfig.METADATA_SERVICE_CLASS ] cls.process( - tool_id=tool_id, + tool_id=str(tool.tool_id), file_name=file_name, org_id=request.org_id, - user_id=request.user.user_id, + user_id=tool.created_by.user_id, document_id=document_id, ) @@ -234,7 +247,7 @@ def index_document(self, request: HttpRequest) -> Response: raise IndexingAPIError() @action(detail=True, methods=["post"]) - def fetch_response(self, request: HttpRequest) -> Response: + def fetch_response(self, request: HttpRequest, pk: Any = None) -> Response: """API Entry point method to fetch response to prompt. Args: @@ -246,7 +259,8 @@ def fetch_response(self, request: HttpRequest) -> Response: Returns: Response """ - tool_id: str = request.data.get(ToolStudioPromptKeys.TOOL_ID) + custom_tool = self.get_object() + tool_id: str = str(custom_tool.tool_id) document_id: str = request.data.get(ToolStudioPromptKeys.DOCUMENT_ID) document: DocumentManager = DocumentManager.objects.get(pk=document_id) file_name: str = document.document_name @@ -260,7 +274,7 @@ def fetch_response(self, request: HttpRequest) -> Response: tool_id=tool_id, file_name=file_name, org_id=request.org_id, - user_id=request.user.user_id, + user_id=custom_tool.created_by.user_id, document_id=document_id, ) return Response(response, status=status.HTTP_200_OK) @@ -296,3 +310,202 @@ def single_pass_extraction(self, request: HttpRequest) -> Response: document_id=document_id, ) return Response(response, status=status.HTTP_200_OK) + + @action(detail=True, methods=["get"]) + def list_of_shared_users( + self, request: HttpRequest, pk: Any = None + ) -> Response: + + custom_tool = ( + self.get_object() + ) # Assuming you have a get_object method in your viewset + + serialized_instances = SharedUserListSerializer(custom_tool).data + + return Response(serialized_instances) + + @action(detail=True, methods=["post"]) + def create_prompt(self, request: HttpRequest, pk: Any = None) -> Response: + context = super().get_serializer_context() + serializer = ToolStudioPromptSerializer( + data=request.data, context=context + ) + serializer.is_valid(raise_exception=True) + try: + # serializer.save() + self.perform_create(serializer) + except IntegrityError: + raise DuplicateData( + f"{ToolStudioPromptErrors.PROMPT_NAME_EXISTS}, \ + {ToolStudioPromptErrors.DUPLICATE_API}" + ) + return Response(serializer.data, status=status.HTTP_201_CREATED) + + @action(detail=True, methods=["post"]) + def create_profile_manager( + self, request: HttpRequest, pk: Any = None + ) -> Response: + context = super().get_serializer_context() + serializer = ProfileManagerSerializer( + data=request.data, context=context + ) + + serializer.is_valid(raise_exception=True) + try: + self.perform_create(serializer) + except IntegrityError: + raise DuplicateData( + f"{ProfileManagerErrors.PROFILE_NAME_EXISTS}, \ + {ProfileManagerErrors.DUPLICATE_API}" + ) + return Response(serializer.data, status=status.HTTP_201_CREATED) + + @action(detail=True, methods=["get"]) + def fetch_contents_ide( + self, request: HttpRequest, pk: Any = None + ) -> Response: + custom_tool = self.get_object() + serializer = FileInfoIdeSerializer(data=request.GET) + serializer.is_valid(raise_exception=True) + document_id: str = serializer.validated_data.get("document_id") + document: DocumentManager = DocumentManager.objects.get(pk=document_id) + file_name: str = document.document_name + view_type: str = serializer.validated_data.get("view_type") + + filename_without_extension = file_name.rsplit(".", 1)[0] + if view_type == FileViewTypes.EXTRACT: + file_name = ( + f"{FileViewTypes.EXTRACT.lower()}/" + f"{filename_without_extension}.txt" + ) + if view_type == FileViewTypes.SUMMARIZE: + file_name = ( + f"{FileViewTypes.SUMMARIZE.lower()}/" + f"{filename_without_extension}.txt" + ) + + file_path = file_path = ( + FileManagerHelper.handle_sub_directory_for_tenants( + request.org_id, + is_create=True, + user_id=custom_tool.created_by.user_id, + tool_id=str(custom_tool.tool_id), + ) + ) + file_system = LocalStorageFS(settings={"path": file_path}) + if not file_path.endswith("/"): + file_path += "/" + file_path += file_name + contents = FileManagerHelper.fetch_file_contents(file_system, file_path) + return Response({"data": contents}, status=status.HTTP_200_OK) + + @action(detail=True, methods=["post"]) + def upload_for_ide(self, request: HttpRequest, pk: Any = None) -> Response: + custom_tool = self.get_object() + serializer = FileUploadIdeSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + uploaded_files: Any = serializer.validated_data.get("file") + + file_path = FileManagerHelper.handle_sub_directory_for_tenants( + request.org_id, + is_create=True, + user_id=custom_tool.created_by.user_id, + tool_id=str(custom_tool.tool_id), + ) + file_system = LocalStorageFS(settings={"path": file_path}) + + documents = [] + for uploaded_file in uploaded_files: + file_name = uploaded_file.name + + # Create a record in the db for the file + document = PromptStudioDocumentHelper.create( + tool_id=str(custom_tool.tool_id), document_name=file_name + ) + # Create a dictionary to store document data + doc = { + "document_id": document.document_id, + "document_name": document.document_name, + "tool": document.tool.tool_id, + } + # Store file + logger.info( + f"Uploading file: {file_name}" + if file_name + else "Uploading file" + ) + FileManagerHelper.upload_file( + file_system, + file_path, + uploaded_file, + file_name, + ) + documents.append(doc) + return Response({"data": documents}) + + @action(detail=True, methods=["delete"]) + def delete_for_ide(self, request: HttpRequest, pk: Any = None) -> Response: + custom_tool = self.get_object() + serializer = FileInfoIdeSerializer(data=request.GET) + serializer.is_valid(raise_exception=True) + document_id: str = serializer.validated_data.get("document_id") + document: DocumentManager = DocumentManager.objects.get(pk=document_id) + file_name: str = document.document_name + file_path = FileManagerHelper.handle_sub_directory_for_tenants( + request.org_id, + is_create=False, + user_id=custom_tool.created_by.user_id, + tool_id=str(custom_tool.tool_id), + ) + path = file_path + file_system = LocalStorageFS(settings={"path": path}) + try: + # Delete the document record + document.delete() + + # Delete the file + FileManagerHelper.delete_file(file_system, path, file_name) + return Response( + {"data": "File deleted succesfully."}, + status=status.HTTP_200_OK, + ) + except Exception as exc: + logger.error(f"Exception thrown from file deletion, error {exc}") + return Response( + {"data": "File deletion failed."}, + status=status.HTTP_400_BAD_REQUEST, + ) + + @action(detail=True, methods=["post"]) + def export_tool(self, request: Request, pk: Any = None) -> Response: + """API Endpoint for exporting required jsons for the custom tool.""" + custom_tool = self.get_object() + serializer = ExportToolRequestSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + is_shared_with_org: bool = serializer.validated_data.get( + "is_shared_with_org" + ) + user_ids = set(serializer.validated_data.get("user_id")) + + PromptStudioRegistryHelper.update_or_create_psr_tool( + custom_tool=custom_tool, + shared_with_org=is_shared_with_org, + user_ids=user_ids, + ) + return Response( + {"message": "Custom tool exported sucessfully."}, + status=status.HTTP_200_OK, + ) + + @action(detail=True, methods=["get"]) + def export_tool_info(self, request: Request, pk: Any = None) -> Response: + custom_tool = self.get_object() + serialized_instances = None + if hasattr(custom_tool, "prompt_studio_registry"): + serialized_instances = PromptStudioRegistryInfoSerializer( + custom_tool.prompt_studio_registry + ).data + + return Response(serialized_instances) + else: + return Response(status=status.HTTP_404_NOT_FOUND) diff --git a/backend/prompt_studio/prompt_studio_document_manager/urls.py b/backend/prompt_studio/prompt_studio_document_manager/urls.py index 208070544..f9fb9bcd3 100644 --- a/backend/prompt_studio/prompt_studio_document_manager/urls.py +++ b/backend/prompt_studio/prompt_studio_document_manager/urls.py @@ -10,9 +10,6 @@ prompt_studio_documents_detail = PromptStudioDocumentManagerView.as_view( { "get": "retrieve", - "put": "update", - "patch": "partial_update", - "delete": "destroy", } ) @@ -23,10 +20,5 @@ prompt_studio_documents_list, name="prompt-studio-documents-list", ), - path( - "prompt-document//", - prompt_studio_documents_detail, - name="tool-studio-documents-detail", - ), ] ) diff --git a/backend/prompt_studio/prompt_studio_document_manager/views.py b/backend/prompt_studio/prompt_studio_document_manager/views.py index 71596532a..49b2f0227 100644 --- a/backend/prompt_studio/prompt_studio_document_manager/views.py +++ b/backend/prompt_studio/prompt_studio_document_manager/views.py @@ -24,8 +24,8 @@ def get_queryset(self) -> Optional[QuerySet]: self.request, PromptStudioOutputManagerKeys.TOOL_ID, ) + queryset = None if filter_args: queryset = DocumentManager.objects.filter(**filter_args) - else: - queryset = DocumentManager.objects.all() + return queryset diff --git a/backend/prompt_studio/prompt_studio_index_manager/urls.py b/backend/prompt_studio/prompt_studio_index_manager/urls.py index 63858612b..cc7d45448 100644 --- a/backend/prompt_studio/prompt_studio_index_manager/urls.py +++ b/backend/prompt_studio/prompt_studio_index_manager/urls.py @@ -10,9 +10,6 @@ prompt_studio_index_detail = IndexManagerView.as_view( { "get": "retrieve", - "put": "update", - "patch": "partial_update", - "delete": "destroy", } ) @@ -23,10 +20,5 @@ prompt_studio_index_list, name="prompt-studio-documents-list", ), - path( - "document-index//", - prompt_studio_index_detail, - name="tool-studio-documents-detail", - ), ] ) diff --git a/backend/prompt_studio/prompt_studio_index_manager/views.py b/backend/prompt_studio/prompt_studio_index_manager/views.py index 9a870f029..a07e6f6c8 100644 --- a/backend/prompt_studio/prompt_studio_index_manager/views.py +++ b/backend/prompt_studio/prompt_studio_index_manager/views.py @@ -23,8 +23,7 @@ def get_queryset(self) -> Optional[QuerySet]: IndexManagerKeys.PROFILE_MANAGER, IndexManagerKeys.DOCUMENT_MANAGER, ) + queryset = None if filter_args: queryset = IndexManager.objects.filter(**filter_args) - else: - queryset = IndexManager.objects.all() return queryset diff --git a/backend/prompt_studio/prompt_studio_registry/migrations/0006_promptstudioregistry_shared_to_org_and_more.py b/backend/prompt_studio/prompt_studio_registry/migrations/0006_promptstudioregistry_shared_to_org_and_more.py new file mode 100644 index 000000000..df8b7e6be --- /dev/null +++ b/backend/prompt_studio/prompt_studio_registry/migrations/0006_promptstudioregistry_shared_to_org_and_more.py @@ -0,0 +1,30 @@ +# Generated by Django 4.2.1 on 2024-03-21 11:36 + +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ("prompt_studio_registry", "0005_delete_corrupt_tool_instance"), + ] + + operations = [ + migrations.AddField( + model_name="promptstudioregistry", + name="shared_to_org", + field=models.BooleanField( + db_comment="Is the exported tool shared with entire org", + default=False, + ), + ), + migrations.AddField( + model_name="promptstudioregistry", + name="shared_users", + field=models.ManyToManyField( + related_name="shared_exported_tools", + to=settings.AUTH_USER_MODEL, + ), + ), + ] diff --git a/backend/prompt_studio/prompt_studio_registry/models.py b/backend/prompt_studio/prompt_studio_registry/models.py index df52c6856..bfff30d36 100644 --- a/backend/prompt_studio/prompt_studio_registry/models.py +++ b/backend/prompt_studio/prompt_studio_registry/models.py @@ -1,7 +1,9 @@ import uuid +from typing import Any from account.models import User from django.db import models +from django.db.models import QuerySet from prompt_studio.prompt_studio.models import CustomTool from utils.models.base_model import BaseModel @@ -12,6 +14,18 @@ ) +class PromptStudioRegistryModelManager(models.Manager): + def get_queryset(self) -> QuerySet[Any]: + return super().get_queryset() + + def list_tools(self, user: User) -> QuerySet[Any]: + return ( + self.get_queryset() + .filter(models.Q(shared_users=user) | models.Q(shared_to_org=True)) + .distinct("prompt_registry_id") + ) + + class PromptStudioRegistry(BaseModel): """Data model to export JSON fields needed for registering the Custom tool to the tool registry. @@ -72,3 +86,14 @@ class PromptStudioRegistry(BaseModel): blank=True, editable=False, ) + shared_to_org = models.BooleanField( + default=False, + db_comment="Is the exported tool shared with entire org", + ) + # Introduced field to establish M2M relation between users and tools. + # This will introduce intermediary table which relates both the models. + shared_users = models.ManyToManyField( + User, related_name="shared_exported_tools" + ) + + objects = PromptStudioRegistryModelManager() diff --git a/backend/prompt_studio/prompt_studio_registry/prompt_studio_registry_helper.py b/backend/prompt_studio/prompt_studio_registry/prompt_studio_registry_helper.py index 991bf0304..ee0b3b028 100644 --- a/backend/prompt_studio/prompt_studio_registry/prompt_studio_registry_helper.py +++ b/backend/prompt_studio/prompt_studio_registry/prompt_studio_registry_helper.py @@ -93,7 +93,7 @@ def get_tool_by_prompt_registry_id( @staticmethod def update_or_create_psr_tool( - custom_tool: CustomTool, + custom_tool: CustomTool, shared_with_org: bool, user_ids: set[int] ) -> PromptStudioRegistry: """Updates or creates the PromptStudioRegistry record. @@ -143,6 +143,10 @@ def update_or_create_psr_tool( logger.info(f"PSR {obj.prompt_registry_id} was created") else: logger.info(f"PSR {obj.prompt_registry_id} was updated") + obj.shared_to_org = shared_with_org + obj.shared_users.clear() + obj.shared_users.add(*user_ids) + obj.save() return obj except IntegrityError as error: logger.error( @@ -238,7 +242,8 @@ def frame_export_json( @staticmethod def fetch_json_for_registry(user: User) -> list[dict[str, Any]]: try: - prompt_studio_tools = PromptStudioRegistry.objects.all() + # filter the Prompt studio registry based on the users and org flag + prompt_studio_tools = PromptStudioRegistry.objects.list_tools(user) pi_serializer = PromptStudioRegistrySerializer( instance=prompt_studio_tools, many=True ) diff --git a/backend/prompt_studio/prompt_studio_registry/serializers.py b/backend/prompt_studio/prompt_studio_registry/serializers.py index fe74848f7..dd1d6ea20 100644 --- a/backend/prompt_studio/prompt_studio_registry/serializers.py +++ b/backend/prompt_studio/prompt_studio_registry/serializers.py @@ -1,3 +1,4 @@ +from account.serializer import UserSerializer from rest_framework import serializers from backend.serializers import AuditSerializer @@ -11,5 +12,16 @@ class Meta: fields = "__all__" +class PromptStudioRegistryInfoSerializer(AuditSerializer): + shared_users = UserSerializer(many=True) + + class Meta: + model = PromptStudioRegistry + fields = ("name", "shared_users", "shared_to_org") + + class ExportToolRequestSerializer(serializers.Serializer): - prompt_registry_id = serializers.UUIDField(required=True) + is_shared_with_org = serializers.BooleanField(default=False) + user_id = serializers.ListField( + child=serializers.IntegerField(), required=False + ) diff --git a/backend/prompt_studio/prompt_studio_registry/urls.py b/backend/prompt_studio/prompt_studio_registry/urls.py index 68aed0856..9cc93e05e 100644 --- a/backend/prompt_studio/prompt_studio_registry/urls.py +++ b/backend/prompt_studio/prompt_studio_registry/urls.py @@ -1,15 +1,3 @@ -from django.urls import path from rest_framework.urlpatterns import format_suffix_patterns -from .views import PromptStudioRegistryView - -tool_studio_export = PromptStudioRegistryView.as_view({"get": "export_tool"}) -urlpatterns = format_suffix_patterns( - [ - path( - "export/", - tool_studio_export, - name="prompt_studio_export", - ), - ] -) +urlpatterns = format_suffix_patterns([]) diff --git a/backend/prompt_studio/prompt_studio_registry/views.py b/backend/prompt_studio/prompt_studio_registry/views.py index 330721b81..5b0bd74cc 100644 --- a/backend/prompt_studio/prompt_studio_registry/views.py +++ b/backend/prompt_studio/prompt_studio_registry/views.py @@ -2,22 +2,13 @@ from typing import Optional from django.db.models import QuerySet -from prompt_studio.prompt_studio_core.models import CustomTool from prompt_studio.prompt_studio_registry.constants import ( PromptStudioRegistryKeys, ) -from prompt_studio.prompt_studio_registry.exceptions import ToolDoesNotExist -from prompt_studio.prompt_studio_registry.prompt_studio_registry_helper import ( - PromptStudioRegistryHelper, -) from prompt_studio.prompt_studio_registry.serializers import ( - ExportToolRequestSerializer, PromptStudioRegistrySerializer, ) -from rest_framework import status, viewsets -from rest_framework.decorators import action -from rest_framework.request import Request -from rest_framework.response import Response +from rest_framework import viewsets from rest_framework.versioning import URLPathVersioning from utils.filtering import FilterHelper @@ -39,33 +30,8 @@ def get_queryset(self) -> Optional[QuerySet]: self.request, PromptStudioRegistryKeys.PROMPT_REGISTRY_ID, ) + queryset = None if filterArgs: queryset = PromptStudioRegistry.objects.filter(**filterArgs) - else: - queryset = PromptStudioRegistry.objects.all() - return queryset - @action(detail=True, methods=["get"]) - def export_tool(self, request: Request) -> Response: - """API Endpoint for exporting required jsons for the custom tool.""" - serializer = ExportToolRequestSerializer(data=request.query_params) - serializer.is_valid(raise_exception=True) - - custom_tool_id = serializer.validated_data.get( - PromptStudioRegistryKeys.PROMPT_REGISTRY_ID - ) - try: - custom_tool = CustomTool.objects.get(tool_id=custom_tool_id) - except CustomTool.DoesNotExist as error: - logger.error( - f"Error occured while fetching tool \ - for tool_id:{custom_tool_id} {error}" - ) - raise ToolDoesNotExist from error - PromptStudioRegistryHelper.update_or_create_psr_tool( - custom_tool=custom_tool - ) - return Response( - {"message": "Custom tool exported sucessfully."}, - status=status.HTTP_200_OK, - ) + return queryset diff --git a/backend/tool_instance/tool_instance_helper.py b/backend/tool_instance/tool_instance_helper.py index bfb804d19..1e0d9c5e2 100644 --- a/backend/tool_instance/tool_instance_helper.py +++ b/backend/tool_instance/tool_instance_helper.py @@ -10,6 +10,7 @@ from connector.connector_instance_helper import ConnectorInstanceHelper from django.core.exceptions import PermissionDenied from jsonschema.exceptions import UnknownType, ValidationError +from prompt_studio.prompt_studio_registry.models import PromptStudioRegistry from tool_instance.constants import JsonSchemaKey from tool_instance.models import ToolInstance from tool_instance.tool_processor import ToolProcessor @@ -338,6 +339,9 @@ def validate_tool_settings( user: User, tool_uid: str, tool_meta: dict[str, Any] ) -> tuple[bool, str]: """Function to validate Tools settings.""" + + # check if exported tool is valid for the user who created workflow + ToolInstanceHelper.validate_tool_access(user=user, tool_uid=tool_uid) ToolInstanceHelper.validate_adapter_permissions( user=user, tool_uid=tool_uid, tool_meta=tool_meta ) @@ -432,3 +436,20 @@ def validate_adapter_access( raise PermissionDenied( "You don't have permission to perform this action." ) + + @staticmethod + def validate_tool_access( + user: User, + tool_uid: str, + ) -> None: + prompt_regitry_tool = PromptStudioRegistry.objects.get(pk=tool_uid) + + if ( + prompt_regitry_tool.shared_to_org + or prompt_regitry_tool.shared_users.filter(pk=user.pk).exists() + ): + return + else: + raise PermissionDenied( + "You don't have permission to perform this action." + ) diff --git a/frontend/src/components/custom-tools/add-llm-profile/AddLlmProfile.jsx b/frontend/src/components/custom-tools/add-llm-profile/AddLlmProfile.jsx index 50f31f20b..7b0c28fdb 100644 --- a/frontend/src/components/custom-tools/add-llm-profile/AddLlmProfile.jsx +++ b/frontend/src/components/custom-tools/add-llm-profile/AddLlmProfile.jsx @@ -16,13 +16,13 @@ import { useEffect, useState } from "react"; import { getBackendErrorDetail } from "../../../helpers/GetStaticData"; import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate"; +import { useExceptionHandler } from "../../../hooks/useExceptionHandler"; import { useAlertStore } from "../../../store/alert-store"; import { useCustomToolStore } from "../../../store/custom-tool-store"; import { useSessionStore } from "../../../store/session-store"; import { CustomButton } from "../../widgets/custom-button/CustomButton"; import SpaceWrapper from "../../widgets/space-wrapper/SpaceWrapper"; import "./AddLlmProfile.css"; -import { useExceptionHandler } from "../../../hooks/useExceptionHandler"; function AddLlmProfile({ editLlmProfileId, @@ -290,10 +290,11 @@ function AddLlmProfile({ const handleSubmit = () => { setLoading(true); let method = "POST"; - let url = `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/profile-manager/`; + let url = `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/profilemanager/${details?.tool_id}`; if (editLlmProfileId?.length) { method = "PUT"; + url = `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/profile-manager/`; url += `${editLlmProfileId}/`; } diff --git a/frontend/src/components/custom-tools/document-manager/DocumentManager.jsx b/frontend/src/components/custom-tools/document-manager/DocumentManager.jsx index f039b9cf0..26da57682 100644 --- a/frontend/src/components/custom-tools/document-manager/DocumentManager.jsx +++ b/frontend/src/components/custom-tools/document-manager/DocumentManager.jsx @@ -1,19 +1,19 @@ -import { Button, Space, Tabs, Tooltip, Typography } from "antd"; +import { LeftOutlined, RightOutlined } from "@ant-design/icons"; import "@react-pdf-viewer/core/lib/styles/index.css"; import "@react-pdf-viewer/default-layout/lib/styles/index.css"; import "@react-pdf-viewer/page-navigation/lib/styles/index.css"; +import { Button, Space, Tabs, Tooltip, Typography } from "antd"; import PropTypes from "prop-types"; -import { LeftOutlined, RightOutlined } from "@ant-design/icons"; import { useEffect, useState } from "react"; import "./DocumentManager.css"; -import { useCustomToolStore } from "../../../store/custom-tool-store"; -import { PdfViewer } from "../pdf-viewer/PdfViewer"; -import { ManageDocsModal } from "../manage-docs-modal/ManageDocsModal"; +import { base64toBlob, docIndexStatus } from "../../../helpers/GetStaticData"; import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate"; +import { useCustomToolStore } from "../../../store/custom-tool-store"; import { useSessionStore } from "../../../store/session-store"; -import { base64toBlob, docIndexStatus } from "../../../helpers/GetStaticData"; import { DocumentViewer } from "../document-viewer/DocumentViewer"; +import { ManageDocsModal } from "../manage-docs-modal/ManageDocsModal"; +import { PdfViewer } from "../pdf-viewer/PdfViewer"; import { TextViewerPre } from "../text-viewer-pre/TextViewerPre"; const items = [ @@ -127,7 +127,7 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) { const requestOptions = { method: "GET", - url: `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/file/fetch_contents?document_id=${selectedDoc?.document_id}&view_type=${viewType}&tool_id=${details?.tool_id}`, + url: `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/file/${details?.tool_id}?document_id=${selectedDoc?.document_id}&view_type=${viewType}`, }; handleLoadingStateUpdate(viewType, true); diff --git a/frontend/src/components/custom-tools/header/Header.jsx b/frontend/src/components/custom-tools/header/Header.jsx index 4380cc47c..78372d333 100644 --- a/frontend/src/components/custom-tools/header/Header.jsx +++ b/frontend/src/components/custom-tools/header/Header.jsx @@ -9,13 +9,14 @@ import { useState } from "react"; import { useNavigate } from "react-router-dom"; import "./Header.css"; +import { ExportToolIcon } from "../../../assets"; import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate"; +import { useExceptionHandler } from "../../../hooks/useExceptionHandler"; import { useAlertStore } from "../../../store/alert-store"; import { useCustomToolStore } from "../../../store/custom-tool-store"; import { useSessionStore } from "../../../store/session-store"; import { CustomButton } from "../../widgets/custom-button/CustomButton"; -import { useExceptionHandler } from "../../../hooks/useExceptionHandler"; -import { ExportToolIcon } from "../../../assets"; +import { SharePermission } from "../../widgets/share-permission/SharePermission"; let SinglePassToggleSwitch; try { @@ -32,13 +33,26 @@ function Header({ setOpenSettings, handleUpdateTool }) { const axiosPrivate = useAxiosPrivate(); const navigate = useNavigate(); const handleException = useExceptionHandler(); + const [userList, setUserList] = useState([]); + const [openSharePermissionModal, setOpenSharePermissionModal] = + useState(false); - const handleExport = () => { + const [toolDetails, setToolDetails] = useState(null); + + const handleExport = (selectedUsers, toolDetail, isSharedWithEveryone) => { + const body = { + is_shared_with_org: isSharedWithEveryone, + user_id: isSharedWithEveryone ? [] : selectedUsers, + }; const requestOptions = { - method: "GET", - url: `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/export/?prompt_registry_id=${details?.tool_id}`, + method: "POST", + url: `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/export/${details?.tool_id}`, + headers: { + "X-CSRFToken": sessionDetails?.csrfToken, + "Content-Type": "application/json", + }, + data: body, }; - setIsExportLoading(true); axiosPrivate(requestOptions) .then(() => { @@ -50,9 +64,72 @@ function Header({ setOpenSettings, handleUpdateTool }) { .catch((err) => { setAlertDetails(handleException(err, "Failed to export")); }) + .finally(() => { + setIsExportLoading(false); + setOpenSharePermissionModal(false); + }); + }; + + const handleShare = (isEdit) => { + const requestOptions = { + method: "GET", + url: `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/export/${details?.tool_id}`, + headers: { + "X-CSRFToken": sessionDetails?.csrfToken, + }, + }; + setIsExportLoading(true); + getAllUsers().then((users) => { + if (users.length < 2) { + handleExport([details?.created_by], details, false); + } else { + axiosPrivate(requestOptions) + .then((res) => { + setOpenSharePermissionModal(true); + setToolDetails({ ...res?.data, created_by: details?.created_by }); + }) + .catch((err) => { + if (err?.response?.status === 404) { + setToolDetails(details); + setOpenSharePermissionModal(true); + setAlertDetails(handleException(err, "Tool not exported yet")); + } else { + setAlertDetails(handleException(err)); + } + }) + .finally(() => { + setIsExportLoading(false); + }); + } + }); + }; + + const getAllUsers = async () => { + setIsExportLoading(true); + const requestOptions = { + method: "GET", + url: `/api/v1/unstract/${sessionDetails?.orgId}/users/`, + }; + + const userList = axiosPrivate(requestOptions) + .then((response) => { + const users = response?.data?.members || []; + setUserList( + users.map((user) => ({ + id: user?.id, + email: user?.email, + })) + ); + return users; + }) + .catch((err) => { + setAlertDetails(handleException(err, "Failed to load")); + }) .finally(() => { setIsExportLoading(false); }); + + return userList; }; return ( @@ -91,13 +168,23 @@ function Header({ setOpenSettings, handleUpdateTool }) { handleShare(true)} loading={isExportLoading} > + ); diff --git a/frontend/src/components/custom-tools/list-of-tools/ListOfTools.jsx b/frontend/src/components/custom-tools/list-of-tools/ListOfTools.jsx index 5293bdc49..703f6ebb6 100644 --- a/frontend/src/components/custom-tools/list-of-tools/ListOfTools.jsx +++ b/frontend/src/components/custom-tools/list-of-tools/ListOfTools.jsx @@ -10,6 +10,7 @@ import { ViewTools } from "../view-tools/ViewTools"; import "./ListOfTools.css"; import { useExceptionHandler } from "../../../hooks/useExceptionHandler"; import { ToolNavBar } from "../../navigations/tool-nav-bar/ToolNavBar"; +import { SharePermission } from "../../widgets/share-permission/SharePermission"; function ListOfTools() { const [isListLoading, setIsListLoading] = useState(false); @@ -23,7 +24,12 @@ function ListOfTools() { const [filteredListOfTools, setFilteredListOfTools] = useState([]); const handleException = useExceptionHandler(); const [isEdit, setIsEdit] = useState(false); - + const [promptDetails, setPromptDetails] = useState(null); + const [openSharePermissionModal, setOpenSharePermissionModal] = + useState(false); + const [isPermissionEdit, setIsPermissionEdit] = useState(false); + const [isShareLoading, setIsShareLoading] = useState(false); + const [allUserList, setAllUserList] = useState([]); useEffect(() => { getListOfTools(); }, []); @@ -171,6 +177,73 @@ function ListOfTools() { ); }; + const handleShare = (_event, promptProject, isEdit) => { + const requestOptions = { + method: "GET", + url: `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/users/${promptProject?.tool_id}`, + headers: { + "X-CSRFToken": sessionDetails?.csrfToken, + }, + }; + setIsShareLoading(true); + getAllUsers(); + axiosPrivate(requestOptions) + .then((res) => { + setOpenSharePermissionModal(true); + setPromptDetails(res?.data); + setIsPermissionEdit(isEdit); + }) + .catch((err) => { + setAlertDetails(handleException(err)); + }) + .finally(() => { + setIsShareLoading(false); + }); + }; + + const getAllUsers = () => { + setIsShareLoading(true); + const requestOptions = { + method: "GET", + url: `/api/v1/unstract/${sessionDetails?.orgId}/users/`, + }; + + axiosPrivate(requestOptions) + .then((response) => { + const users = response?.data?.members || []; + setAllUserList( + users.map((user) => ({ + id: user?.id, + email: user?.email, + })) + ); + }) + .catch((err) => { + setAlertDetails(handleException(err, "Failed to load")); + }) + .finally(() => { + setIsShareLoading(false); + }); + }; + + const onShare = (userIds, adapter) => { + const requestOptions = { + method: "PATCH", + url: `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/${adapter?.tool_id}`, + headers: { + "X-CSRFToken": sessionDetails?.csrfToken, + }, + data: { shared_users: userIds }, + }; + axiosPrivate(requestOptions) + .then((response) => { + setOpenSharePermissionModal(false); + }) + .catch((err) => { + setAlertDetails(handleException(err, "Failed to load")); + }); + }; + return ( <> @@ -209,6 +283,15 @@ function ListOfTools() { handleAddNewTool={handleAddNewTool} /> )} + ); } diff --git a/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx b/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx index a2929d9a6..89783c7e0 100644 --- a/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx +++ b/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx @@ -520,7 +520,7 @@ function ManageDocsModal({
{ if (permissionEdit && adapter && adapter?.shared_users) { @@ -27,29 +38,30 @@ function SharePermission({ // set the selectedUsers to the IDs of shared users const users = allUsers.filter((user) => { if (adapter?.created_by?.id !== undefined) { - return ( - user?.id !== adapter?.created_by?.id?.toString() && - !selectedUsers.includes(user.id.toString()) - ); + return isSharableToOrg + ? !selectedUsers.includes(user?.id?.toString()) + : user?.id !== adapter?.created_by?.id?.toString() && + !selectedUsers.includes(user?.id?.toString()); } else { - return ( - user?.id !== adapter?.created_by?.toString() && - !selectedUsers.includes(user.id.toString()) - ); + return isSharableToOrg + ? !selectedUsers.includes(user?.id?.toString()) + : user?.id !== adapter?.created_by?.toString() && + !selectedUsers.includes(user?.id?.toString()); } }); setFilteredUsers(users); + setShareWithEveryone(adapter?.shared_to_org || false); } }, [permissionEdit, adapter, allUsers, selectedUsers]); useEffect(() => { - if (adapter && adapter.shared_users) { + if (adapter?.shared_users) { setSelectedUsers( adapter.shared_users.map((user) => { if (user?.id !== undefined) { return user.id.toString(); } else { - return user.toString(); + return user?.toString(); } }) ); @@ -64,6 +76,68 @@ function SharePermission({ const filterOption = (input, option) => (option?.label ?? "").toLowerCase().includes(input.toLowerCase()); + const handleShareWithEveryone = (checked) => { + setShareWithEveryone(checked); + }; + + let sharedWithContent; + if (shareWithEveryone) { + sharedWithContent = Shared with everyone; + } else if (selectedUsers.length > 0) { + sharedWithContent = ( + { + const user = allUsers.find( + (u) => u?.id.toString() === userId.toString() + ); + return { + id: user?.id, + email: user?.email, + }; + })} + renderItem={(item) => ( + event.stopPropagation()} role="none"> + } + onConfirm={(event) => handleDeleteUser(item?.id)} + > + + + + +
+ ) + } + > + + } + /> + + {item.email} + + + } + /> + + )} + /> + ); + } else { + sharedWithContent = Not shared with anyone yet; + } + return ( adapter && ( onApply(selectedUsers, adapter)} + onOk={() => onApply(selectedUsers, adapter, shareWithEveryone)} cancelButtonProps={!permissionEdit && { style: { display: "none" } }} okButtonProps={!permissionEdit && { style: { display: "none" } }} className="share-permission-modal" @@ -83,7 +157,16 @@ function SharePermission({ ) : ( <> - {permissionEdit ? ( + {isSharableToOrg && allUsers.length > 1 && ( + handleShareWithEveryone(e.target.checked)} + className="share-per-checkbox" + > + Share with everyone + + )} + {permissionEdit && !shareWithEveryone && (