Skip to content

Commit

Permalink
Feature/prompt studio document manager (#78)
Browse files Browse the repository at this point in the history
* Implemented re-index

* Fixed data persistence issue in the summarize modal

* Reverted to the default background color of the PDF viewer

* UI Improvements in Output Analyzer

* Fixed issue with spaces in the prompt/notes card

* Added loader to the submit button

* Revert "FIX: Prompt Studio Bug Fixed (#65)"

This reverts commit 35429b1.

* Revert "Revert "FIX: Prompt Studio Bug Fixed (#65)""

This reverts commit 1e2921d.

* Revert "FIX: Prompt Studio Bug Fixed (#65)"

This reverts commit 35429b1.

* Backend changes related to document manager

* UI changes related to document manager

* Index Manager and Document Manager changes

* Code efficiency improvement in the document manager BE

* UI changes to support document manager changes

* implemented new design

* Added API support for index manager

* FE changes for showing the indexing status in the Manage Documents table

* Fixed the prompt list not updating after a new prompt is added

* UI bug fixes and improvements

* UI bug fixes

* Optimized migrations

* Modification in migrations

* Code quality improvement

---------

Co-authored-by: Neha <[email protected]>
Co-authored-by: Jaseem Jas <[email protected]>
Co-authored-by: jagadeeswaran-zipstack <[email protected]>
  • Loading branch information
4 people authored Mar 11, 2024
1 parent 80d4870 commit af6139f
Show file tree
Hide file tree
Showing 65 changed files with 1,642 additions and 720 deletions.
2 changes: 2 additions & 0 deletions backend/backend/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ def get_required_setting(
"prompt_studio.prompt_studio_core",
"prompt_studio.prompt_studio_registry",
"prompt_studio.prompt_studio_output_manager",
"prompt_studio.prompt_studio_document_manager",
"prompt_studio.prompt_studio_index_manager",
)

INSTALLED_APPS = list(SHARED_APPS) + [
Expand Down
8 changes: 8 additions & 0 deletions backend/backend/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,12 @@
UrlPathConstants.PROMPT_STUDIO,
include("prompt_studio.prompt_studio_output_manager.urls"),
),
path(
UrlPathConstants.PROMPT_STUDIO,
include("prompt_studio.prompt_studio_document_manager.urls"),
),
path(
UrlPathConstants.PROMPT_STUDIO,
include("prompt_studio.prompt_studio_index_manager.urls"),
),
]
6 changes: 6 additions & 0 deletions backend/file_management/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,9 @@ class FileInformationKey:
FILE_UPLOAD_MAX_SIZE = 100 * 1024 * 1024
FILE_UPLOAD_ALLOWED_EXT = ["pdf"]
FILE_UPLOAD_ALLOWED_MIME = ["application/pdf"]

class FileViewTypes:
    """Values accepted for the ``view_type`` parameter when fetching a file.

    The constants are plain strings so they can be compared directly with
    the request value and lower-cased to build sub-directory names.
    """

    # The document under its own stored name (no path rewriting).
    ORIGINAL = "ORIGINAL"
    # Text variant looked up as ``extract/<name>.txt``.
    EXTRACT = "EXTRACT"
    # Text variant looked up as ``summarize/<name>.txt``.
    SUMMARIZE = "SUMMARIZE"

3 changes: 2 additions & 1 deletion backend/file_management/file_management_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ def fetch_file_contents(

elif file_content_type == "text/plain":
with fs.open(file_path, "r") as file:
FileManagerHelper.logger.info(f"Reading text file: {file_path}")
FileManagerHelper.logger.info(
f"Reading text file: {file_path}")
text_content = file.read()
return text_content
else:
Expand Down
3 changes: 2 additions & 1 deletion backend/file_management/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ class FileUploadIdeSerializer(serializers.Serializer):


class FileInfoIdeSerializer(serializers.Serializer):
file_name = serializers.CharField()
document_id = serializers.CharField()
tool_id = serializers.CharField()
view_type = serializers.CharField(required=False)


class FileListRequestIdeSerializer(serializers.Serializer):
Expand Down
56 changes: 47 additions & 9 deletions backend/file_management/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from connector.models import ConnectorInstance
from django.http import HttpRequest
from file_management.constants import FileViewTypes
from file_management.exceptions import (
ConnectorInstanceNotFound,
ConnectorOAuthError,
Expand All @@ -20,10 +21,15 @@
FileUploadSerializer,
)
from oauth2client.client import HttpAccessTokenRefreshError
from prompt_studio.prompt_studio_document_manager.models import DocumentManager
from prompt_studio.prompt_studio_document_manager.prompt_studio_document_helper import (
PromptStudioDocumentHelper,
)
from rest_framework import serializers, status, viewsets
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.versioning import URLPathVersioning

from unstract.connectors.exceptions import ConnectorError
from unstract.connectors.filesystems.local_storage.local_storage import (
LocalStorageFS,
Expand Down Expand Up @@ -132,8 +138,21 @@ def upload_for_ide(self, request: HttpRequest) -> Response:
tool_id=tool_id,
)
file_system = LocalStorageFS(settings={"path": file_path})

documents = []
for uploaded_file in uploaded_files:
file_name = uploaded_file.name

# Create a record in the db for the file
document = PromptStudioDocumentHelper.create(
tool_id=tool_id, document_name=file_name)
# Create a dictionary to store document data
doc = {
"document_id": document.document_id,
"document_name": document.document_name,
"tool": document.tool.tool_id
}
# Store file
logger.info(
f"Uploading file: {file_name}"
if file_name
Expand All @@ -145,14 +164,31 @@ def upload_for_ide(self, request: HttpRequest) -> Response:
uploaded_file,
file_name,
)
return Response({"message": "Files are uploaded successfully!"})
documents.append(doc)
return Response({"data": documents})

@action(detail=True, methods=["get"])
def fetch_contents_ide(self, request: HttpRequest) -> Response:
serializer = FileInfoIdeSerializer(data=request.GET)
serializer.is_valid(raise_exception=True)
file_name: str = serializer.validated_data.get("file_name")
document_id: str = serializer.validated_data.get("document_id")
document: DocumentManager = DocumentManager.objects.get(pk=document_id)
file_name: str = document.document_name
tool_id: str = serializer.validated_data.get("tool_id")
view_type: str = serializer.validated_data.get("view_type")

filename_without_extension = file_name.rsplit('.', 1)[0]
if view_type == FileViewTypes.EXTRACT:
file_name = (
f"{FileViewTypes.EXTRACT.lower()}/"
f"{filename_without_extension}.txt"
)
if view_type == FileViewTypes.SUMMARIZE:
file_name = (
f"{FileViewTypes.SUMMARIZE.lower()}/"
f"{filename_without_extension}.txt"
)

file_path = (
file_path
) = FileManagerHelper.handle_sub_directory_for_tenants(
Expand All @@ -165,7 +201,8 @@ def fetch_contents_ide(self, request: HttpRequest) -> Response:
if not file_path.endswith("/"):
file_path += "/"
file_path += file_name
contents = FileManagerHelper.fetch_file_contents(file_system, file_path)
contents = FileManagerHelper.fetch_file_contents(
file_system, file_path)
return Response({"data": contents}, status=status.HTTP_200_OK)

@action(detail=True, methods=["get"])
Expand Down Expand Up @@ -196,7 +233,9 @@ def list_ide(self, request: HttpRequest) -> Response:
def delete(self, request: HttpRequest) -> Response:
serializer = FileInfoIdeSerializer(data=request.GET)
serializer.is_valid(raise_exception=True)
file_name: str = serializer.validated_data.get("file_name")
document_id: str = serializer.validated_data.get("document_id")
document: DocumentManager = DocumentManager.objects.get(pk=document_id)
file_name: str = document.document_name
tool_id: str = serializer.validated_data.get("tool_id")
file_path = FileManagerHelper.handle_sub_directory_for_tenants(
request.org_id,
Expand All @@ -205,13 +244,12 @@ def delete(self, request: HttpRequest) -> Response:
tool_id=tool_id,
)
path = file_path
if not file_name:
return Response(
{"data": "File deletion failed. File name is mandatory"},
status=status.HTTP_400_BAD_REQUEST,
)
file_system = LocalStorageFS(settings={"path": path})
try:
# Delete the document record
document.delete()

# Delete the file
FileManagerHelper.delete_file(file_system, path, file_name)
return Response(
{"data": "File deleted succesfully."},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ class Migration(migrations.Migration):
"prompt_profile_manager",
"0007_profilemanager_is_default_and_more",
),
(
"prompt_studio",
"0006_alter_toolstudioprompt_prompt_key_and_more",
),
(
"prompt_studio_core",
"0007_remove_customtool_default_profile_and_more",
)
]

def MigrateProfileManager(apps: Any, schema_editor: Any) -> None:
Expand Down
2 changes: 1 addition & 1 deletion backend/prompt_studio/prompt_profile_manager/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_queryset(self) -> Optional[QuerySet]:
def create(
self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any]
) -> Response:
serializer = self.get_serializer(data=request.data)
serializer: ProfileManagerSerializer = self.get_serializer(data=request.data)
# Overriding default exception behaviour
# TO DO : Handle model related exceptions.
serializer.is_valid(raise_exception=True)
Expand Down
1 change: 1 addition & 0 deletions backend/prompt_studio/prompt_studio_core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class ToolStudioPromptKeys:
EVAL_SETTINGS_EXCLUDE_FAILED = "exclude_failed"
SUMMARIZE = "summarize"
SUMMARIZED_RESULT = "summarized_result"
DOCUMENT_ID = "document_id"


class LogLevels:
Expand Down
35 changes: 31 additions & 4 deletions backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from prompt_studio.prompt_studio_core.prompt_ide_base_tool import (
PromptIdeBaseTool,
)
from prompt_studio.prompt_studio_index_manager.prompt_studio_index_helper import (
PromptStudioIndexHelper,
)
from unstract.sdk.constants import LogLevel
from unstract.sdk.index import ToolIndex
from unstract.sdk.prompt import PromptTool
Expand Down Expand Up @@ -85,6 +88,7 @@ def index_document(
file_name: str,
org_id: str,
user_id: str,
document_id: str,
is_summary: bool = False,
) -> Any:
"""Method to index a document.
Expand Down Expand Up @@ -149,6 +153,7 @@ def index_document(
tool_id=tool_id,
file_name=file_path,
org_id=org_id,
document_id=document_id,
is_summary=is_summary,
)
logger.info(f"Indexing done sucessfully for {file_name}")
Expand All @@ -164,7 +169,12 @@ def index_document(

@staticmethod
def prompt_responder(
id: str, tool_id: str, file_name: str, org_id: str, user_id: str
id: str,
tool_id: str,
file_name: str,
org_id: str,
user_id: str,
document_id: str
) -> Any:
"""Execute chain/single run of the prompts. Makes a call to prompt
service and returns the dict of response.
Expand Down Expand Up @@ -217,7 +227,8 @@ def prompt_responder(
),
)
if not prompt_instance:
logger.error(f"Prompt id {id} does not have any data in db")
logger.error(
f"Prompt id {id} does not have any data in db")
raise PromptNotValid()
except Exception as exc:
logger.error(f"Error while fetching prompt {exc}")
Expand All @@ -242,7 +253,11 @@ def prompt_responder(
)
logger.info(f"Invoking prompt service for prompt id {id}")
response = PromptStudioHelper._fetch_response(
path=file_path, tool=tool, prompts=prompts, org_id=org_id
path=file_path,
tool=tool,
prompts=prompts,
org_id=org_id,
document_id=document_id
)
stream_log.publish(
tool.tool_id,
Expand All @@ -262,6 +277,7 @@ def _fetch_response(
path: str,
prompts: list[ToolStudioPrompt],
org_id: str,
document_id: str
) -> Any:
"""Utility function to invoke prompt service. Used internally.
Expand Down Expand Up @@ -302,6 +318,7 @@ def _fetch_response(
file_name=path,
tool_id=str(tool.tool_id),
org_id=org_id,
document_id=document_id,
is_summary=tool.summarize_as_source,
)

Expand Down Expand Up @@ -382,6 +399,7 @@ def dynamic_indexer(
tool_id: str,
file_name: str,
org_id: str,
document_id: str,
is_summary: bool = False,
) -> str:
try:
Expand All @@ -400,7 +418,7 @@ def dynamic_indexer(
extract_file_path = os.path.join(
directory, "extract", os.path.splitext(filename)[0] + ".txt"
)
return str(
doc_id = str(
tool_index.index_file(
tool_id=tool_id,
embedding_type=embedding_model,
Expand All @@ -414,3 +432,12 @@ def dynamic_indexer(
output_file_path=extract_file_path,
)
)

PromptStudioIndexHelper.handle_index_manager(
document_id=document_id,
is_summary=is_summary,
profile_manager=profile_manager,
doc_id=doc_id,
)

return doc_id
2 changes: 1 addition & 1 deletion backend/prompt_studio/prompt_studio_core/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def to_representation(self, instance): # type: ignore


class PromptStudioIndexSerializer(serializers.Serializer):
file_name = serializers.CharField()
document_id = serializers.CharField()
tool_id = serializers.CharField()


Expand Down
15 changes: 11 additions & 4 deletions backend/prompt_studio/prompt_studio_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from prompt_studio.prompt_studio_core.prompt_studio_helper import (
PromptStudioHelper,
)
from prompt_studio.prompt_studio_document_manager.models import DocumentManager
from rest_framework import status, viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
Expand Down Expand Up @@ -196,15 +197,17 @@ def index_document(self, request: HttpRequest) -> Response:
tool_id: str = serializer.validated_data.get(
ToolStudioPromptKeys.TOOL_ID
)
file_name: str = serializer.validated_data.get(
ToolStudioPromptKeys.FILE_NAME
)
document_id: str = serializer.validated_data.get(
ToolStudioPromptKeys.DOCUMENT_ID)
document: DocumentManager = DocumentManager.objects.get(pk=document_id)
file_name: str = document.document_name
try:
unique_id = PromptStudioHelper.index_document(
tool_id=tool_id,
file_name=file_name,
org_id=request.org_id,
user_id=request.user.user_id,
document_id=document_id,
)

for processor_plugin in self.processor_plugins:
Expand All @@ -216,6 +219,7 @@ def index_document(self, request: HttpRequest) -> Response:
file_name=file_name,
org_id=request.org_id,
user_id=request.user.user_id,
document_id=document_id,
)

if unique_id:
Expand Down Expand Up @@ -246,7 +250,9 @@ def fetch_response(self, request: HttpRequest) -> Response:
Response
"""
tool_id: str = request.data.get(ToolStudioPromptKeys.TOOL_ID)
file_name: str = request.data.get(ToolStudioPromptKeys.FILE_NAME)
document_id: str = request.data.get(ToolStudioPromptKeys.DOCUMENT_ID)
document: DocumentManager = DocumentManager.objects.get(pk=document_id)
file_name: str = document.document_name
id: str = request.data.get(ToolStudioPromptKeys.ID)

if not file_name or file_name == ToolStudioPromptKeys.UNDEFINED:
Expand All @@ -258,5 +264,6 @@ def fetch_response(self, request: HttpRequest) -> Response:
file_name=file_name,
org_id=request.org_id,
user_id=request.user.user_id,
document_id=document_id,
)
return Response(response, status=status.HTTP_200_OK)
Empty file.
5 changes: 5 additions & 0 deletions backend/prompt_studio/prompt_studio_document_manager/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.contrib import admin

from .models import DocumentManager

admin.site.register(DocumentManager)
5 changes: 5 additions & 0 deletions backend/prompt_studio/prompt_studio_document_manager/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.apps import AppConfig


class PromptStudioDocumentManagerConfig(AppConfig):
    """Django app configuration for the Prompt Studio document manager."""

    # Dotted path of the application package; this is the same path that is
    # listed in the project's installed apps.
    name = "prompt_studio.prompt_studio_document_manager"
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class PSDMKeys:
    """String keys used by the Prompt Studio document manager.

    The values match the keys of the per-document dictionary returned after
    an upload (``document_id``, ``document_name``, ``tool``).
    """

    # Name of the uploaded document.
    DOCUMENT_NAME = "document_name"
    # Reference to the tool the document belongs to.
    TOOL = "tool"
    # Identifier of the document record.
    DOCUMENT_ID = "document_id"
Loading

0 comments on commit af6139f

Please sign in to comment.