diff --git a/backend/adapter_processor/serializers.py b/backend/adapter_processor/serializers.py
index 80637206e..a3f281dcb 100644
--- a/backend/adapter_processor/serializers.py
+++ b/backend/adapter_processor/serializers.py
@@ -124,6 +124,10 @@ def to_representation(self, instance: AdapterInstance) -> dict[str, str]:
rep[common.ICON] = AdapterProcessor.get_adapter_data_with_key(
instance.adapter_id, common.ICON
)
+ adapter_metadata = instance.get_adapter_meta_data()
+ model = adapter_metadata.get("model")
+ if model:
+ rep["model"] = model
if instance.is_friction_less:
rep["created_by_email"] = "Unstract"
diff --git a/backend/backend/settings/base.py b/backend/backend/settings/base.py
index ae25fe6b2..902a7d012 100644
--- a/backend/backend/settings/base.py
+++ b/backend/backend/settings/base.py
@@ -167,6 +167,7 @@ def get_required_setting(
"CELERY_BROKER_URL", f"redis://{REDIS_HOST}:{REDIS_PORT}"
)
+INDEXING_FLAG_TTL = int(get_required_setting("INDEXING_FLAG_TTL"))
# Flag to Enable django admin
ADMIN_ENABLED = False
diff --git a/backend/prompt_studio/prompt_profile_manager/constants.py b/backend/prompt_studio/prompt_profile_manager/constants.py
index 70cf34019..6540b58ee 100644
--- a/backend/prompt_studio/prompt_profile_manager/constants.py
+++ b/backend/prompt_studio/prompt_profile_manager/constants.py
@@ -7,6 +7,8 @@ class ProfileManagerKeys:
VECTOR_STORE = "vector_store"
EMBEDDING_MODEL = "embedding_model"
X2TEXT = "x2text"
+ PROMPT_STUDIO_TOOL = "prompt_studio_tool"
+ MAX_PROFILE_COUNT = 4
class ProfileManagerErrors:
diff --git a/backend/prompt_studio/prompt_profile_manager/profile_manager_helper.py b/backend/prompt_studio/prompt_profile_manager/profile_manager_helper.py
new file mode 100644
index 000000000..68783b551
--- /dev/null
+++ b/backend/prompt_studio/prompt_profile_manager/profile_manager_helper.py
@@ -0,0 +1,11 @@
+from prompt_studio.prompt_profile_manager.models import ProfileManager
+
+
+class ProfileManagerHelper:
+
+ @classmethod
+ def get_profile_manager(cls, profile_manager_id: str) -> ProfileManager:
+ try:
+ return ProfileManager.objects.get(profile_id=profile_manager_id)
+ except ProfileManager.DoesNotExist:
+ raise ValueError("ProfileManager does not exist.")
diff --git a/backend/prompt_studio/prompt_profile_manager/serializers.py b/backend/prompt_studio/prompt_profile_manager/serializers.py
index 4d4753561..fc83aaab4 100644
--- a/backend/prompt_studio/prompt_profile_manager/serializers.py
+++ b/backend/prompt_studio/prompt_profile_manager/serializers.py
@@ -2,6 +2,7 @@
from adapter_processor.adapter_processor import AdapterProcessor
from prompt_studio.prompt_profile_manager.constants import ProfileManagerKeys
+from prompt_studio.prompt_studio_core.exceptions import MaxProfilesReachedError
from backend.serializers import AuditSerializer
@@ -38,3 +39,15 @@ def to_representation(self, instance): # type: ignore
AdapterProcessor.get_adapter_instance_by_id(x2text)
)
return rep
+
+ def validate(self, data):
+ prompt_studio_tool = data.get(ProfileManagerKeys.PROMPT_STUDIO_TOOL)
+
+ profile_count = ProfileManager.objects.filter(
+ prompt_studio_tool=prompt_studio_tool
+ ).count()
+
+ if profile_count >= ProfileManagerKeys.MAX_PROFILE_COUNT:
+ raise MaxProfilesReachedError()
+
+ return data
diff --git a/backend/prompt_studio/prompt_studio_core/constants.py b/backend/prompt_studio/prompt_studio_core/constants.py
index 934d9b530..55d61e32e 100644
--- a/backend/prompt_studio/prompt_studio_core/constants.py
+++ b/backend/prompt_studio/prompt_studio_core/constants.py
@@ -85,6 +85,9 @@ class ToolStudioPromptKeys:
NOTES = "NOTES"
OUTPUT = "output"
SEQUENCE_NUMBER = "sequence_number"
+ PROFILE_MANAGER_ID = "profile_manager"
+ CONTEXT = "context"
+ METADATA = "metadata"
class FileViewTypes:
@@ -108,6 +111,13 @@ class LogLevel(Enum):
FATAL = "FATAL"
+class IndexingStatus(Enum):
+ PENDING_STATUS = "pending"
+ COMPLETED_STATUS = "completed"
+ STARTED_STATUS = "started"
+ DOCUMENT_BEING_INDEXED = "Document is being indexed"
+
+
class DefaultPrompts:
PREAMBLE = (
"Your ability to extract and summarize this context accurately "
diff --git a/backend/prompt_studio/prompt_studio_core/document_indexing_service.py b/backend/prompt_studio/prompt_studio_core/document_indexing_service.py
new file mode 100644
index 000000000..539c5a2dc
--- /dev/null
+++ b/backend/prompt_studio/prompt_studio_core/document_indexing_service.py
@@ -0,0 +1,53 @@
+from typing import Optional
+
+from django.conf import settings
+from prompt_studio.prompt_studio_core.constants import IndexingStatus
+from utils.cache_service import CacheService
+
+
+class DocumentIndexingService:
+ CACHE_PREFIX = "document_indexing:"
+
+ @classmethod
+ def set_document_indexing(cls, org_id: str, user_id: str, doc_id_key: str) -> None:
+ CacheService.set_key(
+ cls._cache_key(org_id, user_id, doc_id_key),
+ IndexingStatus.STARTED_STATUS.value,
+ expire=settings.INDEXING_FLAG_TTL,
+ )
+
+ @classmethod
+ def is_document_indexing(cls, org_id: str, user_id: str, doc_id_key: str) -> bool:
+ return (
+ CacheService.get_key(cls._cache_key(org_id, user_id, doc_id_key))
+ == IndexingStatus.STARTED_STATUS.value
+ )
+
+ @classmethod
+ def mark_document_indexed(
+ cls, org_id: str, user_id: str, doc_id_key: str, doc_id: str
+ ) -> None:
+ CacheService.set_key(
+ cls._cache_key(org_id, user_id, doc_id_key),
+ doc_id,
+ expire=settings.INDEXING_FLAG_TTL,
+ )
+
+ @classmethod
+ def get_indexed_document_id(
+ cls, org_id: str, user_id: str, doc_id_key: str
+ ) -> Optional[str]:
+ result = CacheService.get_key(cls._cache_key(org_id, user_id, doc_id_key))
+ if result and result != IndexingStatus.STARTED_STATUS.value:
+ return result
+ return None
+
+ @classmethod
+ def remove_document_indexing(
+ cls, org_id: str, user_id: str, doc_id_key: str
+ ) -> None:
+ CacheService.delete_a_key(cls._cache_key(org_id, user_id, doc_id_key))
+
+ @classmethod
+ def _cache_key(cls, org_id: str, user_id: str, doc_id_key: str) -> str:
+ return f"{cls.CACHE_PREFIX}{org_id}:{user_id}:{doc_id_key}"
diff --git a/backend/prompt_studio/prompt_studio_core/exceptions.py b/backend/prompt_studio/prompt_studio_core/exceptions.py
index 666d41241..241418060 100644
--- a/backend/prompt_studio/prompt_studio_core/exceptions.py
+++ b/backend/prompt_studio/prompt_studio_core/exceptions.py
@@ -1,3 +1,4 @@
+from prompt_studio.prompt_profile_manager.constants import ProfileManagerKeys
from prompt_studio.prompt_studio_core.constants import ToolStudioErrors
from rest_framework.exceptions import APIException
@@ -58,3 +59,11 @@ class PermissionError(APIException):
class EmptyPromptError(APIException):
status_code = 422
default_detail = "Prompt(s) cannot be empty"
+
+
+class MaxProfilesReachedError(APIException):
+ status_code = 403
+ default_detail = (
+ f"Maximum number of profiles (max {ProfileManagerKeys.MAX_PROFILE_COUNT})"
+ " per prompt studio project has been reached."
+ )
diff --git a/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py b/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py
index caa4e377c..c0f7b73d7 100644
--- a/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py
+++ b/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py
@@ -13,9 +13,15 @@
from django.db.models.manager import BaseManager
from file_management.file_management_helper import FileManagerHelper
from prompt_studio.prompt_profile_manager.models import ProfileManager
+from prompt_studio.prompt_profile_manager.profile_manager_helper import (
+ ProfileManagerHelper,
+)
from prompt_studio.prompt_studio.models import ToolStudioPrompt
-from prompt_studio.prompt_studio_core.constants import LogLevels
+from prompt_studio.prompt_studio_core.constants import IndexingStatus, LogLevels
from prompt_studio.prompt_studio_core.constants import ToolStudioPromptKeys as TSPKeys
+from prompt_studio.prompt_studio_core.document_indexing_service import (
+ DocumentIndexingService,
+)
from prompt_studio.prompt_studio_core.exceptions import (
AnswerFetchError,
DefaultProfileError,
@@ -344,6 +350,7 @@ def index_document(
is_summary=is_summary,
reindex=True,
run_id=run_id,
+ user_id=user_id,
)
logger.info(f"[{tool_id}] Indexing successful for doc: {file_name}")
@@ -354,7 +361,7 @@ def index_document(
"Indexing successful",
)
- return doc_id
+ return doc_id.get("output")
@staticmethod
def prompt_responder(
@@ -364,6 +371,7 @@ def prompt_responder(
document_id: str,
id: Optional[str] = None,
run_id: str = None,
+ profile_manager_id: Optional[str] = None,
) -> Any:
"""Execute chain/single run of the prompts. Makes a call to prompt
service and returns the dict of response.
@@ -374,6 +382,7 @@ def prompt_responder(
user_id (str): User's ID
document_id (str): UUID of the document uploaded
id (Optional[str]): ID of the prompt
+ profile_manager_id (Optional[str]): UUID of the profile manager
Raises:
AnswerFetchError: Error from prompt-service
@@ -383,44 +392,94 @@ def prompt_responder(
"""
document: DocumentManager = DocumentManager.objects.get(pk=document_id)
doc_name: str = document.document_name
-
- doc_path = FileManagerHelper.handle_sub_directory_for_tenants(
- org_id=org_id,
- user_id=user_id,
- tool_id=tool_id,
- is_create=False,
+ doc_path = PromptStudioHelper._get_document_path(
+ org_id, user_id, tool_id, doc_name
)
- doc_path = str(Path(doc_path) / doc_name)
if id:
- prompt_instance = PromptStudioHelper._fetch_prompt_from_id(id)
- prompt_name = prompt_instance.prompt_key
- logger.info(f"[{tool_id}] Executing single prompt {id}")
- PromptStudioHelper._publish_log(
- {
- "tool_id": tool_id,
- "run_id": run_id,
- "prompt_key": prompt_name,
- "doc_name": doc_name,
- },
- LogLevels.INFO,
- LogLevels.RUN,
- "Executing single prompt",
+ return PromptStudioHelper._execute_single_prompt(
+ id,
+ doc_path,
+ doc_name,
+ tool_id,
+ org_id,
+ user_id,
+ document_id,
+ run_id,
+ profile_manager_id,
+ )
+ else:
+ return PromptStudioHelper._execute_prompts_in_single_pass(
+ doc_path, tool_id, org_id, user_id, document_id, run_id
)
- prompts: list[ToolStudioPrompt] = []
- prompts.append(prompt_instance)
- tool: CustomTool = prompt_instance.tool_id
+ @staticmethod
+ def _execute_single_prompt(
+ id,
+ doc_path,
+ doc_name,
+ tool_id,
+ org_id,
+ user_id,
+ document_id,
+ run_id,
+ profile_manager_id,
+ ):
+ prompt_instance = PromptStudioHelper._fetch_prompt_from_id(id)
+ prompt_name = prompt_instance.prompt_key
+ PromptStudioHelper._publish_log(
+ {
+ "tool_id": tool_id,
+ "run_id": run_id,
+ "prompt_key": prompt_name,
+ "doc_name": doc_name,
+ },
+ LogLevels.INFO,
+ LogLevels.RUN,
+ "Executing single prompt",
+ )
+ prompts = [prompt_instance]
+ tool = prompt_instance.tool_id
- if tool.summarize_as_source:
- directory, filename = os.path.split(doc_path)
- doc_path = os.path.join(
- directory,
- TSPKeys.SUMMARIZE,
- os.path.splitext(filename)[0] + ".txt",
- )
+ if tool.summarize_as_source:
+ directory, filename = os.path.split(doc_path)
+ doc_path = os.path.join(
+ directory, TSPKeys.SUMMARIZE, os.path.splitext(filename)[0] + ".txt"
+ )
- logger.info(f"[{tool.tool_id}] Invoking prompt service for prompt {id}")
+ PromptStudioHelper._publish_log(
+ {
+ "tool_id": tool_id,
+ "run_id": run_id,
+ "prompt_key": prompt_name,
+ "doc_name": doc_name,
+ },
+ LogLevels.DEBUG,
+ LogLevels.RUN,
+ "Invoking prompt service",
+ )
+
+ try:
+ response = PromptStudioHelper._fetch_response(
+ doc_path=doc_path,
+ doc_name=doc_name,
+ tool=tool,
+ prompt=prompt_instance,
+ org_id=org_id,
+ document_id=document_id,
+ run_id=run_id,
+ profile_manager_id=profile_manager_id,
+ user_id=user_id,
+ )
+ return PromptStudioHelper._handle_response(
+ response, run_id, prompts, document_id, False, profile_manager_id
+ )
+ except Exception as e:
+ logger.error(
+ f"[{tool.tool_id}] Error while fetching response for "
+ f"prompt {id} and doc {document_id}: {e}"
+ )
+ msg = str(e)
PromptStudioHelper._publish_log(
{
"tool_id": tool_id,
@@ -428,130 +487,89 @@ def prompt_responder(
"prompt_key": prompt_name,
"doc_name": doc_name,
},
- LogLevels.DEBUG,
+ LogLevels.ERROR,
LogLevels.RUN,
- "Invoking prompt service",
+ msg,
)
+ raise e
- try:
- response = PromptStudioHelper._fetch_response(
- doc_path=doc_path,
- doc_name=doc_name,
- tool=tool,
- prompt=prompt_instance,
- org_id=org_id,
- document_id=document_id,
- run_id=run_id,
- )
+ @staticmethod
+ def _execute_prompts_in_single_pass(
+ doc_path, tool_id, org_id, user_id, document_id, run_id
+ ):
+ prompts = PromptStudioHelper.fetch_prompt_from_tool(tool_id)
+ prompts = [prompt for prompt in prompts if prompt.prompt_type != TSPKeys.NOTES]
+ if not prompts:
+ logger.error(f"[{tool_id or 'NA'}] No prompts found for id: {id}")
+ raise NoPromptsFound()
- OutputManagerHelper.handle_prompt_output_update(
- run_id=run_id,
- prompts=prompts,
- outputs=response["output"],
- document_id=document_id,
- is_single_pass_extract=False,
- )
- # TODO: Review if this catch-all is required
- except Exception as e:
- logger.error(
- f"[{tool.tool_id}] Error while fetching response for "
- f"prompt {id} and doc {document_id}: {e}"
- )
- msg: str = (
- f"Error while fetching response for "
- f"'{prompt_name}' with '{doc_name}'. {e}"
- )
- if isinstance(e, AnswerFetchError):
- msg = str(e)
- PromptStudioHelper._publish_log(
- {
- "tool_id": tool_id,
- "run_id": run_id,
- "prompt_key": prompt_name,
- "doc_name": doc_name,
- },
- LogLevels.ERROR,
- LogLevels.RUN,
- msg,
- )
- raise e
+ PromptStudioHelper._publish_log(
+ {"tool_id": tool_id, "run_id": run_id, "prompt_id": str(id)},
+ LogLevels.INFO,
+ LogLevels.RUN,
+ "Executing prompts in single pass",
+ )
- logger.info(
- f"[{tool.tool_id}] Response fetched successfully for prompt {id}"
+ try:
+ tool = prompts[0].tool_id
+ response = PromptStudioHelper._fetch_single_pass_response(
+ file_path=doc_path,
+ tool=tool,
+ prompts=prompts,
+ org_id=org_id,
+ document_id=document_id,
+ run_id=run_id,
+ user_id=user_id,
+ )
+ return PromptStudioHelper._handle_response(
+ response, run_id, prompts, document_id, True
+ )
+ except Exception as e:
+ logger.error(
+ f"[{tool.tool_id}] Error while fetching single pass response: {e}"
)
PromptStudioHelper._publish_log(
{
"tool_id": tool_id,
"run_id": run_id,
- "prompt_key": prompt_name,
- "doc_name": doc_name,
+ "prompt_id": str(id),
},
- LogLevels.INFO,
- LogLevels.RUN,
- "Single prompt execution completed",
- )
-
- return response
- else:
- prompts = PromptStudioHelper.fetch_prompt_from_tool(tool_id)
- prompts = [
- prompt for prompt in prompts if prompt.prompt_type != TSPKeys.NOTES
- ]
- if not prompts:
- logger.error(f"[{tool_id or 'NA'}] No prompts found for id: {id}")
- raise NoPromptsFound()
-
- logger.info(f"[{tool_id}] Executing prompts in single pass")
- PromptStudioHelper._publish_log(
- {"tool_id": tool_id, "run_id": run_id, "prompt_id": str(id)},
- LogLevels.INFO,
+ LogLevels.ERROR,
LogLevels.RUN,
- "Executing prompts in single pass",
+ f"Failed to fetch single pass response. {e}",
)
+ raise e
- try:
- tool = prompts[0].tool_id
- response = PromptStudioHelper._fetch_single_pass_response(
- file_path=doc_path,
- tool=tool,
- prompts=prompts,
- org_id=org_id,
- document_id=document_id,
- run_id=run_id,
- )
-
- OutputManagerHelper.handle_prompt_output_update(
- run_id=run_id,
- prompts=prompts,
- outputs=response[TSPKeys.OUTPUT],
- document_id=document_id,
- is_single_pass_extract=True,
- )
- except Exception as e:
- logger.error(
- f"[{tool.tool_id}] Error while fetching single pass response: {e}" # noqa: E501
- )
- PromptStudioHelper._publish_log(
- {
- "tool_id": tool_id,
- "run_id": run_id,
- "prompt_id": str(id),
- },
- LogLevels.ERROR,
- LogLevels.RUN,
- f"Failed to fetch single pass response. {e}",
- )
- raise e
-
- logger.info(f"[{tool.tool_id}] Single pass response fetched successfully")
- PromptStudioHelper._publish_log(
- {"tool_id": tool_id, "run_id": run_id, "prompt_id": str(id)},
- LogLevels.INFO,
- LogLevels.RUN,
- "Single pass execution completed",
- )
+ @staticmethod
+ def _get_document_path(org_id, user_id, tool_id, doc_name):
+ doc_path = FileManagerHelper.handle_sub_directory_for_tenants(
+ org_id=org_id,
+ user_id=user_id,
+ tool_id=tool_id,
+ is_create=False,
+ )
+ return str(Path(doc_path) / doc_name)
- return response
+ @staticmethod
+ def _handle_response(
+ response, run_id, prompts, document_id, is_single_pass, profile_manager_id=None
+ ):
+ if response.get("status") == IndexingStatus.PENDING_STATUS.value:
+ return {
+ "status": IndexingStatus.PENDING_STATUS.value,
+ "message": IndexingStatus.DOCUMENT_BEING_INDEXED.value,
+ }
+
+ OutputManagerHelper.handle_prompt_output_update(
+ run_id=run_id,
+ prompts=prompts,
+ outputs=response["output"],
+ document_id=document_id,
+ is_single_pass_extract=is_single_pass,
+ profile_manager_id=profile_manager_id,
+ context=response["metadata"].get("context"),
+ )
+ return response
@staticmethod
def _fetch_response(
@@ -562,6 +580,8 @@ def _fetch_response(
org_id: str,
document_id: str,
run_id: str,
+ user_id: str,
+ profile_manager_id: Optional[str] = None,
) -> Any:
"""Utility function to invoke prompt service. Used internally.
@@ -572,6 +592,9 @@ def _fetch_response(
prompt (ToolStudioPrompt): ToolStudioPrompt instance to fetch response
org_id (str): UUID of the organization
document_id (str): UUID of the document
+ profile_manager_id (Optional[str]): UUID of the profile manager
+ user_id (str): The ID of the user who uploaded the document
+
Raises:
DefaultProfileError: If no default profile is selected
@@ -580,6 +603,14 @@ def _fetch_response(
Returns:
Any: Output from LLM
"""
+
+ # Fetch the ProfileManager instance using the profile_manager_id if provided
+ profile_manager = prompt.profile_manager
+ if profile_manager_id:
+ profile_manager = ProfileManagerHelper.get_profile_manager(
+ profile_manager_id=profile_manager_id
+ )
+
monitor_llm_instance: Optional[AdapterInstance] = tool.monitor_llm
monitor_llm: Optional[str] = None
challenge_llm_instance: Optional[AdapterInstance] = tool.challenge_llm
@@ -600,28 +631,33 @@ def _fetch_response(
challenge_llm = str(default_profile.llm.id)
# Need to check the user who created profile manager
- PromptStudioHelper.validate_adapter_status(prompt.profile_manager)
+ PromptStudioHelper.validate_adapter_status(profile_manager)
# Need to check the user who created profile manager
# has access to adapters
- PromptStudioHelper.validate_profile_manager_owner_access(prompt.profile_manager)
+ PromptStudioHelper.validate_profile_manager_owner_access(profile_manager)
# Not checking reindex here as there might be
# change in Profile Manager
- vector_db = str(prompt.profile_manager.vector_store.id)
- embedding_model = str(prompt.profile_manager.embedding_model.id)
- llm = str(prompt.profile_manager.llm.id)
- x2text = str(prompt.profile_manager.x2text.id)
- prompt_profile_manager: ProfileManager = prompt.profile_manager
- if not prompt_profile_manager:
+ vector_db = str(profile_manager.vector_store.id)
+ embedding_model = str(profile_manager.embedding_model.id)
+ llm = str(profile_manager.llm.id)
+ x2text = str(profile_manager.x2text.id)
+ if not profile_manager:
raise DefaultProfileError()
- PromptStudioHelper.dynamic_indexer(
- profile_manager=prompt_profile_manager,
+ index_result = PromptStudioHelper.dynamic_indexer(
+ profile_manager=profile_manager,
file_path=doc_path,
tool_id=str(tool.tool_id),
org_id=org_id,
document_id=document_id,
is_summary=tool.summarize_as_source,
run_id=run_id,
+ user_id=user_id,
)
+ if index_result.get("status") == IndexingStatus.PENDING_STATUS.value:
+ return {
+ "status": IndexingStatus.PENDING_STATUS.value,
+ "message": IndexingStatus.DOCUMENT_BEING_INDEXED.value,
+ }
output: dict[str, Any] = {}
outputs: list[dict[str, Any]] = []
@@ -639,16 +675,16 @@ def _fetch_response(
output[TSPKeys.PROMPT] = prompt.prompt
output[TSPKeys.ACTIVE] = prompt.active
- output[TSPKeys.CHUNK_SIZE] = prompt.profile_manager.chunk_size
+ output[TSPKeys.CHUNK_SIZE] = profile_manager.chunk_size
output[TSPKeys.VECTOR_DB] = vector_db
output[TSPKeys.EMBEDDING] = embedding_model
- output[TSPKeys.CHUNK_OVERLAP] = prompt.profile_manager.chunk_overlap
+ output[TSPKeys.CHUNK_OVERLAP] = profile_manager.chunk_overlap
output[TSPKeys.LLM] = llm
output[TSPKeys.TYPE] = prompt.enforce_type
output[TSPKeys.NAME] = prompt.prompt_key
- output[TSPKeys.RETRIEVAL_STRATEGY] = prompt.profile_manager.retrieval_strategy
- output[TSPKeys.SIMILARITY_TOP_K] = prompt.profile_manager.similarity_top_k
- output[TSPKeys.SECTION] = prompt.profile_manager.section
+ output[TSPKeys.RETRIEVAL_STRATEGY] = profile_manager.retrieval_strategy
+ output[TSPKeys.SIMILARITY_TOP_K] = profile_manager.similarity_top_k
+ output[TSPKeys.SECTION] = profile_manager.section
output[TSPKeys.X2TEXT_ADAPTER] = x2text
# Eval settings for the prompt
output[TSPKeys.EVAL_SETTINGS] = {}
@@ -715,10 +751,11 @@ def dynamic_indexer(
file_path: str,
org_id: str,
document_id: str,
+ user_id: str,
is_summary: bool = False,
reindex: bool = False,
run_id: str = None,
- ) -> str:
+ ) -> Any:
"""Used to index a file based on the passed arguments.
This is useful when a file needs to be indexed dynamically as the
@@ -732,6 +769,7 @@ def dynamic_indexer(
org_id (str): ID of the organization
is_summary (bool, optional): Flag to ensure if extracted contents
need to be persisted. Defaults to False.
+ user_id (str): The ID of the user who uploaded the document
Returns:
str: Index key for the combination of arguments
@@ -750,9 +788,42 @@ def dynamic_indexer(
profile_manager.chunk_size = 0
try:
+
usage_kwargs = {"run_id": run_id}
util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id)
tool_index = Index(tool=util)
+ doc_id_key = tool_index.generate_file_id(
+ tool_id=tool_id,
+ vector_db=vector_db,
+ embedding=embedding_model,
+ x2text=x2text_adapter,
+ chunk_size=str(profile_manager.chunk_size),
+ chunk_overlap=str(profile_manager.chunk_overlap),
+ file_path=file_path,
+ file_hash=None,
+ )
+ if not reindex:
+ indexed_doc_id = DocumentIndexingService.get_indexed_document_id(
+ org_id=org_id, user_id=user_id, doc_id_key=doc_id_key
+ )
+ if indexed_doc_id:
+ return {
+ "status": IndexingStatus.COMPLETED_STATUS.value,
+ "output": indexed_doc_id,
+ }
+ # Polling if document is already being indexed
+ if DocumentIndexingService.is_document_indexing(
+ org_id=org_id, user_id=user_id, doc_id_key=doc_id_key
+ ):
+ return {
+ "status": IndexingStatus.PENDING_STATUS.value,
+ "output": IndexingStatus.DOCUMENT_BEING_INDEXED.value,
+ }
+
+ # Set the document as being indexed
+ DocumentIndexingService.set_document_indexing(
+ org_id=org_id, user_id=user_id, doc_id_key=doc_id_key
+ )
doc_id: str = tool_index.index(
tool_id=tool_id,
embedding_instance_id=embedding_model,
@@ -772,7 +843,10 @@ def dynamic_indexer(
profile_manager=profile_manager,
doc_id=doc_id,
)
- return doc_id
+ DocumentIndexingService.mark_document_indexed(
+ org_id=org_id, user_id=user_id, doc_id_key=doc_id_key, doc_id=doc_id
+ )
+ return {"status": IndexingStatus.COMPLETED_STATUS.value, "output": doc_id}
except (IndexingError, IndexingAPIError, SdkError) as e:
doc_name = os.path.split(file_path)[1]
PromptStudioHelper._publish_log(
@@ -791,6 +865,7 @@ def _fetch_single_pass_response(
file_path: str,
prompts: list[ToolStudioPrompt],
org_id: str,
+ user_id: str,
document_id: str,
run_id: str = None,
) -> Any:
@@ -819,7 +894,7 @@ def _fetch_single_pass_response(
if not default_profile:
raise DefaultProfileError()
- PromptStudioHelper.dynamic_indexer(
+ index_result = PromptStudioHelper.dynamic_indexer(
profile_manager=default_profile,
file_path=file_path,
tool_id=tool_id,
@@ -827,7 +902,13 @@ def _fetch_single_pass_response(
is_summary=tool.summarize_as_source,
document_id=document_id,
run_id=run_id,
+ user_id=user_id,
)
+ if index_result.get("status") == IndexingStatus.PENDING_STATUS.value:
+ return {
+ "status": IndexingStatus.PENDING_STATUS.value,
+ "message": IndexingStatus.DOCUMENT_BEING_INDEXED.value,
+ }
vector_db = str(default_profile.vector_store.id)
embedding_model = str(default_profile.embedding_model.id)
diff --git a/backend/prompt_studio/prompt_studio_core/views.py b/backend/prompt_studio/prompt_studio_core/views.py
index 9093efd42..8db0a3ef5 100644
--- a/backend/prompt_studio/prompt_studio_core/views.py
+++ b/backend/prompt_studio/prompt_studio_core/views.py
@@ -21,6 +21,9 @@
ToolStudioKeys,
ToolStudioPromptKeys,
)
+from prompt_studio.prompt_studio_core.document_indexing_service import (
+ DocumentIndexingService,
+)
from prompt_studio.prompt_studio_core.exceptions import (
IndexingAPIError,
ToolDeleteError,
@@ -30,6 +33,7 @@
from prompt_studio.prompt_studio_document_manager.prompt_studio_document_helper import ( # noqa: E501
PromptStudioDocumentHelper,
)
+from prompt_studio.prompt_studio_index_manager.models import IndexManager
from prompt_studio.prompt_studio_registry.prompt_studio_registry_helper import (
PromptStudioRegistryHelper,
)
@@ -264,6 +268,7 @@ def fetch_response(self, request: HttpRequest, pk: Any = None) -> Response:
document_id: str = request.data.get(ToolStudioPromptKeys.DOCUMENT_ID)
id: str = request.data.get(ToolStudioPromptKeys.ID)
run_id: str = request.data.get(ToolStudioPromptKeys.RUN_ID)
+ profile_manager: str = request.data.get(ToolStudioPromptKeys.PROFILE_MANAGER_ID)
if not run_id:
# Generate a run_id
run_id = CommonUtils.generate_uuid()
@@ -275,6 +280,7 @@ def fetch_response(self, request: HttpRequest, pk: Any = None) -> Response:
user_id=custom_tool.created_by.user_id,
document_id=document_id,
run_id=run_id,
+ profile_manager_id=profile_manager,
)
return Response(response, status=status.HTTP_200_OK)
@@ -446,17 +452,26 @@ def delete_for_ide(self, request: HttpRequest, pk: uuid) -> Response:
document_id: str = serializer.validated_data.get(
ToolStudioPromptKeys.DOCUMENT_ID
)
+ org_id = UserSessionUtils.get_organization_id(request)
+ user_id = custom_tool.created_by.user_id
document: DocumentManager = DocumentManager.objects.get(pk=document_id)
file_name: str = document.document_name
file_path = FileManagerHelper.handle_sub_directory_for_tenants(
- UserSessionUtils.get_organization_id(request),
+ org_id=org_id,
is_create=False,
- user_id=custom_tool.created_by.user_id,
+ user_id=user_id,
tool_id=str(custom_tool.tool_id),
)
path = file_path
file_system = LocalStorageFS(settings={"path": path})
try:
+ # Delete indexed flags in redis
+ index_managers = IndexManager.objects.filter(document_manager=document_id)
+ for index_manager in index_managers:
+ raw_index_id = index_manager.raw_index_id
+ DocumentIndexingService.remove_document_indexing(
+ org_id=org_id, user_id=user_id, doc_id_key=raw_index_id
+ )
# Delete the document record
document.delete()
# Delete the files
diff --git a/backend/prompt_studio/prompt_studio_output_manager/migrations/0013_promptstudiooutputmanager_context.py b/backend/prompt_studio/prompt_studio_output_manager/migrations/0013_promptstudiooutputmanager_context.py
new file mode 100644
index 000000000..9d72dbd4d
--- /dev/null
+++ b/backend/prompt_studio/prompt_studio_output_manager/migrations/0013_promptstudiooutputmanager_context.py
@@ -0,0 +1,20 @@
+# Generated by Django 4.2.1 on 2024-06-27 18:27
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("prompt_studio_output_manager", "0012_promptstudiooutputmanager_run_id"),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="promptstudiooutputmanager",
+ name="context",
+ field=models.CharField(
+ blank=True, db_comment="Field to store chunks used", null=True
+ ),
+ ),
+ ]
diff --git a/backend/prompt_studio/prompt_studio_output_manager/migrations/0014_alter_promptstudiooutputmanager_context.py b/backend/prompt_studio/prompt_studio_output_manager/migrations/0014_alter_promptstudiooutputmanager_context.py
new file mode 100644
index 000000000..9d7844eaa
--- /dev/null
+++ b/backend/prompt_studio/prompt_studio_output_manager/migrations/0014_alter_promptstudiooutputmanager_context.py
@@ -0,0 +1,20 @@
+# Generated by Django 4.2.1 on 2024-06-30 17:17
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("prompt_studio_output_manager", "0013_promptstudiooutputmanager_context"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="promptstudiooutputmanager",
+ name="context",
+ field=models.TextField(
+ blank=True, db_comment="Field to store chunks used", null=True
+ ),
+ ),
+ ]
diff --git a/backend/prompt_studio/prompt_studio_output_manager/models.py b/backend/prompt_studio/prompt_studio_output_manager/models.py
index 14febf634..e1f7f5b86 100644
--- a/backend/prompt_studio/prompt_studio_output_manager/models.py
+++ b/backend/prompt_studio/prompt_studio_output_manager/models.py
@@ -21,6 +21,9 @@ class PromptStudioOutputManager(BaseModel):
output = models.CharField(
db_comment="Field to store output", editable=True, null=True, blank=True
)
+ context = models.TextField(
+ db_comment="Field to store chunks used", editable=True, null=True, blank=True
+ )
eval_metrics = models.JSONField(
db_column="eval_metrics",
null=False,
diff --git a/backend/prompt_studio/prompt_studio_output_manager/output_manager_helper.py b/backend/prompt_studio/prompt_studio_output_manager/output_manager_helper.py
index 6f942e3b7..b88a25602 100644
--- a/backend/prompt_studio/prompt_studio_output_manager/output_manager_helper.py
+++ b/backend/prompt_studio/prompt_studio_output_manager/output_manager_helper.py
@@ -1,10 +1,14 @@
import json
import logging
-from typing import Any
+from typing import Any, Optional
from prompt_studio.prompt_profile_manager.models import ProfileManager
-from prompt_studio.prompt_studio.exceptions import AnswerFetchError
from prompt_studio.prompt_studio.models import ToolStudioPrompt
+from prompt_studio.prompt_studio_core.exceptions import (
+ AnswerFetchError,
+ DefaultProfileError,
+)
+from prompt_studio.prompt_studio_core.models import CustomTool
from prompt_studio.prompt_studio_document_manager.models import DocumentManager
from prompt_studio.prompt_studio_output_manager.constants import (
PromptStudioOutputManagerKeys as PSOMKeys,
@@ -20,42 +24,32 @@ def handle_prompt_output_update(
run_id: str,
prompts: list[ToolStudioPrompt],
outputs: Any,
+ context: Any,
document_id: str,
is_single_pass_extract: bool,
+ profile_manager_id: Optional[str] = None,
) -> None:
"""Handles updating prompt outputs in the database.
Args:
+ run_id (str): ID of the run.
prompts (list[ToolStudioPrompt]): List of prompts to update.
outputs (Any): Outputs corresponding to the prompts.
document_id (str): ID of the document.
+ profile_manager_id (Optional[str]): UUID of the profile manager.
is_single_pass_extract (bool):
Flag indicating if single pass extract is active.
"""
- # Check if prompts list is empty
- if not prompts:
- return # Return early if prompts list is empty
- tool = prompts[0].tool_id
- document_manager = DocumentManager.objects.get(pk=document_id)
- default_profile = ProfileManager.get_default_llm_profile(tool=tool)
- # Iterate through each prompt in the list
- for prompt in prompts:
- if prompt.prompt_type == PSOMKeys.NOTES:
- continue
- if is_single_pass_extract:
- profile_manager = default_profile
- else:
- profile_manager = prompt.profile_manager
- output = json.dumps(outputs.get(prompt.prompt_key))
- eval_metrics = outputs.get(f"{prompt.prompt_key}__evaluation", [])
-
- # Attempt to update an existing output manager,
- # for the given criteria,
- # or create a new one if it doesn't exist
+ def update_or_create_prompt_output(
+ prompt: ToolStudioPrompt,
+ profile_manager: ProfileManager,
+ output: str,
+ eval_metrics: list[Any],
+ tool: CustomTool,
+ context: str,
+ ):
try:
- # Create or get the existing record for this document, prompt and
- # profile combo
_, success = PromptStudioOutputManager.objects.get_or_create(
document_manager=document_manager,
tool_id=tool,
@@ -65,6 +59,7 @@ def handle_prompt_output_update(
defaults={
"output": output,
"eval_metrics": eval_metrics,
+ "context": context,
},
)
@@ -79,11 +74,12 @@ def handle_prompt_output_update(
f"profile {profile_manager.profile_id}"
)
- args: dict[str, str] = dict()
- args["run_id"] = run_id
- args["output"] = output
- args["eval_metrics"] = eval_metrics
- # Update the record with the run id and other params
+ args: dict[str, str] = {
+ "run_id": run_id,
+ "output": output,
+ "eval_metrics": eval_metrics,
+ "context": context,
+ }
PromptStudioOutputManager.objects.filter(
document_manager=document_manager,
tool_id=tool,
@@ -94,3 +90,57 @@ def handle_prompt_output_update(
except Exception as e:
raise AnswerFetchError(f"Error updating prompt output {e}") from e
+
+ if not prompts:
+ return # Return early if prompts list is empty
+
+ tool = prompts[0].tool_id
+ default_profile = OutputManagerHelper.get_default_profile(
+ profile_manager_id, tool
+ )
+ document_manager = DocumentManager.objects.get(pk=document_id)
+
+ for prompt in prompts:
+ if prompt.prompt_type == PSOMKeys.NOTES or not prompt.active:
+ continue
+
+ if not is_single_pass_extract:
+ context = json.dumps(context.get(prompt.prompt_key))
+
+ output = json.dumps(outputs.get(prompt.prompt_key))
+ profile_manager = default_profile
+ eval_metrics = outputs.get(f"{prompt.prompt_key}__evaluation", [])
+
+ update_or_create_prompt_output(
+ prompt=prompt,
+ profile_manager=profile_manager,
+ output=output,
+ eval_metrics=eval_metrics,
+ tool=tool,
+ context=context,
+ )
+
+ @staticmethod
+ def get_default_profile(
+ profile_manager_id: Optional[str], tool: CustomTool
+ ) -> ProfileManager:
+ if profile_manager_id:
+ return OutputManagerHelper.fetch_profile_manager(profile_manager_id)
+ else:
+ return OutputManagerHelper.fetch_default_llm_profile(tool)
+
+ @staticmethod
+ def fetch_profile_manager(profile_manager_id: str) -> ProfileManager:
+ try:
+ return ProfileManager.objects.get(profile_id=profile_manager_id)
+ except ProfileManager.DoesNotExist:
+ raise DefaultProfileError(
+ f"ProfileManager with ID {profile_manager_id} does not exist."
+ )
+
+ @staticmethod
+ def fetch_default_llm_profile(tool: CustomTool) -> ProfileManager:
+ try:
+ return ProfileManager.get_default_llm_profile(tool=tool)
+ except DefaultProfileError:
+ raise DefaultProfileError("Default ProfileManager does not exist.")
diff --git a/backend/sample.env b/backend/sample.env
index abbc1757c..f42a32068 100644
--- a/backend/sample.env
+++ b/backend/sample.env
@@ -139,3 +139,6 @@ LOGS_BATCH_LIMIT=30
# Celery Configuration
CELERY_BROKER_URL = "redis://unstract-redis:6379"
+
+# TTL (seconds) of the Redis flag that marks a document as indexed, preventing re-indexing
+INDEXING_FLAG_TTL=1800
diff --git a/backend/usage/constants.py b/backend/usage/constants.py
index d28074e1e..8da54da05 100644
--- a/backend/usage/constants.py
+++ b/backend/usage/constants.py
@@ -4,3 +4,4 @@ class UsageKeys:
PROMPT_TOKENS = "prompt_tokens"
COMPLETION_TOKENS = "completion_tokens"
TOTAL_TOKENS = "total_tokens"
+ COST_IN_DOLLARS = "cost_in_dollars"
diff --git a/backend/usage/helper.py b/backend/usage/helper.py
index b91fae556..0bfab7556 100644
--- a/backend/usage/helper.py
+++ b/backend/usage/helper.py
@@ -36,6 +36,7 @@ def get_aggregated_token_count(run_id: str) -> dict:
prompt_tokens=Sum(UsageKeys.PROMPT_TOKENS),
completion_tokens=Sum(UsageKeys.COMPLETION_TOKENS),
total_tokens=Sum(UsageKeys.TOTAL_TOKENS),
+ cost_in_dollars=Sum(UsageKeys.COST_IN_DOLLARS),
)
logger.info(f"Token counts aggregated successfully for run_id: {run_id}")
@@ -50,6 +51,7 @@ def get_aggregated_token_count(run_id: str) -> dict:
UsageKeys.COMPLETION_TOKENS
),
UsageKeys.TOTAL_TOKENS: usage_summary.get(UsageKeys.TOTAL_TOKENS),
+ UsageKeys.COST_IN_DOLLARS: usage_summary.get(UsageKeys.COST_IN_DOLLARS),
}
return result
except Usage.DoesNotExist:
diff --git a/frontend/package-lock.json b/frontend/package-lock.json
index aee9ec808..018771100 100644
--- a/frontend/package-lock.json
+++ b/frontend/package-lock.json
@@ -25,6 +25,7 @@
"cronstrue": "^2.48.0",
"emoji-picker-react": "^4.8.0",
"emoji-regex": "^10.3.0",
+ "framer-motion": "^11.2.10",
"handlebars": "^4.7.8",
"http-proxy-middleware": "^2.0.6",
"js-cookie": "^3.0.5",
@@ -9383,6 +9384,30 @@
"url": "https://www.patreon.com/infusion"
}
},
+ "node_modules/framer-motion": {
+ "version": "11.2.10",
+ "resolved": "https://registry.npmjs.org/framer-motion/-/framer-motion-11.2.10.tgz",
+ "integrity": "sha512-/gr3PLZUVFCc86a9MqCUboVrALscrdluzTb3yew+2/qKBU8CX6nzs918/SRBRCqaPbx0TZP10CB6yFgK2C5cYQ==",
+ "dependencies": {
+ "tslib": "^2.4.0"
+ },
+ "peerDependencies": {
+ "@emotion/is-prop-valid": "*",
+ "react": "^18.0.0",
+ "react-dom": "^18.0.0"
+ },
+ "peerDependenciesMeta": {
+ "@emotion/is-prop-valid": {
+ "optional": true
+ },
+ "react": {
+ "optional": true
+ },
+ "react-dom": {
+ "optional": true
+ }
+ }
+ },
"node_modules/fresh": {
"version": "0.5.2",
"resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz",
@@ -27235,6 +27260,14 @@
"resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-4.2.0.tgz",
"integrity": "sha512-MhLuK+2gUcnZe8ZHlaaINnQLl0xRIGRfcGk2yl8xoQAfHrSsL3rYu6FCmBdkdbhc9EPlwyGHewaRsvwRMJtAlA=="
},
+ "framer-motion": {
+ "version": "11.2.10",
+ "resolved": "https://registry.npmjs.org/framer-motion/-/framer-motion-11.2.10.tgz",
+ "integrity": "sha512-/gr3PLZUVFCc86a9MqCUboVrALscrdluzTb3yew+2/qKBU8CX6nzs918/SRBRCqaPbx0TZP10CB6yFgK2C5cYQ==",
+ "requires": {
+ "tslib": "^2.4.0"
+ }
+ },
"fresh": {
"version": "0.5.2",
"resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz",
diff --git a/frontend/package.json b/frontend/package.json
index a01f53738..674ef6689 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -20,6 +20,7 @@
"cronstrue": "^2.48.0",
"emoji-picker-react": "^4.8.0",
"emoji-regex": "^10.3.0",
+ "framer-motion": "^11.2.10",
"handlebars": "^4.7.8",
"http-proxy-middleware": "^2.0.6",
"js-cookie": "^3.0.5",
diff --git a/frontend/src/components/custom-tools/combined-output/CombinedOutput.jsx b/frontend/src/components/custom-tools/combined-output/CombinedOutput.jsx
index 4b0cb5281..980a80916 100644
--- a/frontend/src/components/custom-tools/combined-output/CombinedOutput.jsx
+++ b/frontend/src/components/custom-tools/combined-output/CombinedOutput.jsx
@@ -7,6 +7,7 @@ import PropTypes from "prop-types";
import {
displayPromptResult,
+ getLLMModelNamesForProfiles,
promptType,
} from "../../../helpers/GetStaticData";
import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate";
@@ -31,18 +32,25 @@ try {
function CombinedOutput({ docId, setFilledFields }) {
const [combinedOutput, setCombinedOutput] = useState({});
const [isOutputLoading, setIsOutputLoading] = useState(false);
+ const [adapterData, setAdapterData] = useState([]);
+ const [activeKey, setActiveKey] = useState("0");
const {
details,
defaultLlmProfile,
singlePassExtractMode,
isSinglePassExtractLoading,
+ llmProfiles,
isSimplePromptStudio,
} = useCustomToolStore();
const { sessionDetails } = useSessionStore();
const { setAlertDetails } = useAlertStore();
const axiosPrivate = useAxiosPrivate();
const handleException = useExceptionHandler();
+ const [selectedProfile, setSelectedProfile] = useState(defaultLlmProfile);
+ useEffect(() => {
+ getAdapterInfo();
+ }, []);
useEffect(() => {
if (!docId || isSinglePassExtractLoading) {
return;
@@ -62,7 +70,7 @@ function CombinedOutput({ docId, setFilledFields }) {
}
output[item?.prompt_key] = "";
- let profileManager = item?.profile_manager;
+ let profileManager = selectedProfile || item?.profile_manager;
if (singlePassExtractMode) {
profileManager = defaultLlmProfile;
}
@@ -100,12 +108,25 @@ function CombinedOutput({ docId, setFilledFields }) {
.finally(() => {
setIsOutputLoading(false);
});
- }, [docId, singlePassExtractMode, isSinglePassExtractLoading]);
+ }, [
+ docId,
+ singlePassExtractMode,
+ isSinglePassExtractLoading,
+ selectedProfile,
+ ]);
const handleOutputApiRequest = async () => {
- let url = `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/prompt-output/?tool_id=${details?.tool_id}&document_manager=${docId}&is_single_pass_extract=${singlePassExtractMode}`;
+ let url;
if (isSimplePromptStudio) {
url = promptOutputApiSps(details?.tool_id, null, docId);
+ } else {
+ url = `/api/v1/unstract/${
+ sessionDetails?.orgId
+ }/prompt-studio/prompt-output/?tool_id=${
+ details?.tool_id
+ }&document_manager=${docId}&is_single_pass_extract=${singlePassExtractMode}&profile_manager=${
+ selectedProfile || defaultLlmProfile
+ }`;
}
const requestOptions = {
method: "GET",
@@ -122,15 +143,43 @@ function CombinedOutput({ docId, setFilledFields }) {
});
};
+ const getAdapterInfo = () => {
+ axiosPrivate
+ .get(
+ `/api/v1/unstract/${sessionDetails?.orgId}/adapter/?adapter_type=LLM`
+ )
+ .then((res) => {
+ const adapterList = res?.data;
+ setAdapterData(getLLMModelNamesForProfiles(llmProfiles, adapterList));
+ });
+ };
+
if (isOutputLoading) {
return
@@ -29,6 +49,11 @@ function JsonView({ combinedOutput }) { JsonView.propTypes = { combinedOutput: PropTypes.object.isRequired, + handleTabChange: PropTypes.func, + adapterData: PropTypes.array, + selectedProfile: PropTypes.string, + llmProfiles: PropTypes.array, + activeKey: PropTypes.string, }; export { JsonView }; diff --git a/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.jsx b/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.jsx index cc4d6d6f8..23de585e2 100644 --- a/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.jsx +++ b/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.jsx @@ -1,4 +1,4 @@ -import { Button, Modal, Table, Typography } from "antd"; +import { Button, Modal, Table, Tabs, Typography } from "antd"; import PropTypes from "prop-types"; import { useEffect, useState } from "react"; import { @@ -12,12 +12,17 @@ import { useCustomToolStore } from "../../../store/custom-tool-store"; import { useSessionStore } from "../../../store/session-store"; import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate"; import "./OutputForDocModal.css"; -import { displayPromptResult } from "../../../helpers/GetStaticData"; +import { + displayPromptResult, + getLLMModelNamesForProfiles, +} from "../../../helpers/GetStaticData"; import { SpinnerLoader } from "../../widgets/spinner-loader/SpinnerLoader"; import { useAlertStore } from "../../../store/alert-store"; import { useExceptionHandler } from "../../../hooks/useExceptionHandler"; import { TokenUsage } from "../token-usage/TokenUsage"; import { useTokenUsageStore } from "../../../store/token-usage-store"; +import TabPane from "antd/es/tabs/TabPane"; +import { ProfileInfoBar } from "../profile-info-bar/ProfileInfoBar"; const columns = [ { @@ -57,6 +62,7 @@ function OutputForDocModal({ }) { const [promptOutputs, setPromptOutputs] = useState([]); const [rows, setRows] = useState([]); + const [adapterData, 
setAdapterData] = useState([]); const [isLoading, setIsLoading] = useState(false); const { details, @@ -66,6 +72,7 @@ function OutputForDocModal({ disableLlmOrDocChange, singlePassExtractMode, isSinglePassExtractLoading, + llmProfiles, } = useCustomToolStore(); const { sessionDetails } = useSessionStore(); const axiosPrivate = useAxiosPrivate(); @@ -73,12 +80,14 @@ function OutputForDocModal({ const { setAlertDetails } = useAlertStore(); const { handleException } = useExceptionHandler(); const { tokenUsage } = useTokenUsageStore(); + const [selectedProfile, setSelectedProfile] = useState(defaultLlmProfile); useEffect(() => { if (!open) { return; } handleGetOutputForDocs(); + getAdapterInfo(); }, [open, singlePassExtractMode, isSinglePassExtractLoading]); useEffect(() => { @@ -89,6 +98,12 @@ function OutputForDocModal({ handleRowsGeneration(promptOutputs); }, [promptOutputs, tokenUsage]); + useEffect(() => { + if (selectedProfile) { + handleGetOutputForDocs(selectedProfile); + } + }, [selectedProfile]); + const moveSelectedDocToTop = () => { // Create a copy of the list of documents const docs = [...listOfDocs]; @@ -147,8 +162,16 @@ function OutputForDocModal({ }); }; - const handleGetOutputForDocs = () => { - let profile = profileManagerId; + const getAdapterInfo = () => { + axiosPrivate + .get(`/api/v1/unstract/${sessionDetails.orgId}/adapter/?adapter_type=LLM`) + .then((res) => { + const adapterList = res.data; + setAdapterData(getLLMModelNamesForProfiles(llmProfiles, adapterList)); + }); + }; + + const handleGetOutputForDocs = (profile = profileManagerId) => { if (singlePassExtractMode) { profile = defaultLlmProfile; } @@ -206,10 +229,14 @@ function OutputForDocModal({ } const result = { - key: item, + key: item?.document_id, document: item?.document_name, token_count: ( -+ ), value: ( <> @@ -239,6 +266,14 @@ function OutputForDocModal({ setRows(rowsData); }; + const handleTabChange = (key) => { + if (key === "0") { + setSelectedProfile(profileManagerId); + } 
else { + setSelectedProfile(adapterData[key - 1]?.profile_id); + } + }; + return (
Profile not found
; + } + + return ( +No chunks founds
; + } + return ( + <> + {chunk?.map((line) => ( +