diff --git a/backend/adapter_processor/serializers.py b/backend/adapter_processor/serializers.py index 80637206e..a3f281dcb 100644 --- a/backend/adapter_processor/serializers.py +++ b/backend/adapter_processor/serializers.py @@ -124,6 +124,10 @@ def to_representation(self, instance: AdapterInstance) -> dict[str, str]: rep[common.ICON] = AdapterProcessor.get_adapter_data_with_key( instance.adapter_id, common.ICON ) + adapter_metadata = instance.get_adapter_meta_data() + model = adapter_metadata.get("model") + if model: + rep["model"] = model if instance.is_friction_less: rep["created_by_email"] = "Unstract" diff --git a/backend/backend/settings/base.py b/backend/backend/settings/base.py index ae25fe6b2..902a7d012 100644 --- a/backend/backend/settings/base.py +++ b/backend/backend/settings/base.py @@ -167,6 +167,7 @@ def get_required_setting( "CELERY_BROKER_URL", f"redis://{REDIS_HOST}:{REDIS_PORT}" ) +INDEXING_FLAG_TTL = int(get_required_setting("INDEXING_FLAG_TTL")) # Flag to Enable django admin ADMIN_ENABLED = False diff --git a/backend/prompt_studio/prompt_profile_manager/constants.py b/backend/prompt_studio/prompt_profile_manager/constants.py index 70cf34019..6540b58ee 100644 --- a/backend/prompt_studio/prompt_profile_manager/constants.py +++ b/backend/prompt_studio/prompt_profile_manager/constants.py @@ -7,6 +7,8 @@ class ProfileManagerKeys: VECTOR_STORE = "vector_store" EMBEDDING_MODEL = "embedding_model" X2TEXT = "x2text" + PROMPT_STUDIO_TOOL = "prompt_studio_tool" + MAX_PROFILE_COUNT = 4 class ProfileManagerErrors: diff --git a/backend/prompt_studio/prompt_profile_manager/profile_manager_helper.py b/backend/prompt_studio/prompt_profile_manager/profile_manager_helper.py new file mode 100644 index 000000000..68783b551 --- /dev/null +++ b/backend/prompt_studio/prompt_profile_manager/profile_manager_helper.py @@ -0,0 +1,11 @@ +from prompt_studio.prompt_profile_manager.models import ProfileManager + + +class ProfileManagerHelper: + + @classmethod + def get_profile_manager(cls, profile_manager_id: str) -> ProfileManager: + try: + return ProfileManager.objects.get(profile_id=profile_manager_id) + except ProfileManager.DoesNotExist: + raise ValueError("ProfileManager does not exist.") diff --git a/backend/prompt_studio/prompt_profile_manager/serializers.py b/backend/prompt_studio/prompt_profile_manager/serializers.py index 4d4753561..fc83aaab4 100644 --- a/backend/prompt_studio/prompt_profile_manager/serializers.py +++ b/backend/prompt_studio/prompt_profile_manager/serializers.py @@ -2,6 +2,7 @@ from adapter_processor.adapter_processor import AdapterProcessor from prompt_studio.prompt_profile_manager.constants import ProfileManagerKeys +from prompt_studio.prompt_studio_core.exceptions import MaxProfilesReachedError from backend.serializers import AuditSerializer @@ -38,3 +39,15 @@ def to_representation(self, instance): # type: ignore AdapterProcessor.get_adapter_instance_by_id(x2text) ) return rep + + def validate(self, data): + prompt_studio_tool = data.get(ProfileManagerKeys.PROMPT_STUDIO_TOOL) + + profile_count = ProfileManager.objects.filter( + prompt_studio_tool=prompt_studio_tool + ).count() + + if profile_count >= ProfileManagerKeys.MAX_PROFILE_COUNT: + raise MaxProfilesReachedError() + + return data diff --git a/backend/prompt_studio/prompt_studio_core/constants.py b/backend/prompt_studio/prompt_studio_core/constants.py index 934d9b530..55d61e32e 100644 --- a/backend/prompt_studio/prompt_studio_core/constants.py +++ b/backend/prompt_studio/prompt_studio_core/constants.py @@ -85,6 
+85,9 @@ class ToolStudioPromptKeys: NOTES = "NOTES" OUTPUT = "output" SEQUENCE_NUMBER = "sequence_number" + PROFILE_MANAGER_ID = "profile_manager" + CONTEXT = "context" + METADATA = "metadata" class FileViewTypes: @@ -108,6 +111,13 @@ class LogLevel(Enum): FATAL = "FATAL" +class IndexingStatus(Enum): + PENDING_STATUS = "pending" + COMPLETED_STATUS = "completed" + STARTED_STATUS = "started" + DOCUMENT_BEING_INDEXED = "Document is being indexed" + + class DefaultPrompts: PREAMBLE = ( "Your ability to extract and summarize this context accurately " diff --git a/backend/prompt_studio/prompt_studio_core/document_indexing_service.py b/backend/prompt_studio/prompt_studio_core/document_indexing_service.py new file mode 100644 index 000000000..539c5a2dc --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core/document_indexing_service.py @@ -0,0 +1,53 @@ +from typing import Optional + +from django.conf import settings +from prompt_studio.prompt_studio_core.constants import IndexingStatus +from utils.cache_service import CacheService + + +class DocumentIndexingService: + CACHE_PREFIX = "document_indexing:" + + @classmethod + def set_document_indexing(cls, org_id: str, user_id: str, doc_id_key: str) -> None: + CacheService.set_key( + cls._cache_key(org_id, user_id, doc_id_key), + IndexingStatus.STARTED_STATUS.value, + expire=settings.INDEXING_FLAG_TTL, + ) + + @classmethod + def is_document_indexing(cls, org_id: str, user_id: str, doc_id_key: str) -> bool: + return ( + CacheService.get_key(cls._cache_key(org_id, user_id, doc_id_key)) + == IndexingStatus.STARTED_STATUS.value + ) + + @classmethod + def mark_document_indexed( + cls, org_id: str, user_id: str, doc_id_key: str, doc_id: str + ) -> None: + CacheService.set_key( + cls._cache_key(org_id, user_id, doc_id_key), + doc_id, + expire=settings.INDEXING_FLAG_TTL, + ) + + @classmethod + def get_indexed_document_id( + cls, org_id: str, user_id: str, doc_id_key: str + ) -> Optional[str]: + result = CacheService.get_key(cls._cache_key(org_id, user_id, doc_id_key)) + if result and result != IndexingStatus.STARTED_STATUS.value: + return result + return None + + @classmethod + def remove_document_indexing( + cls, org_id: str, user_id: str, doc_id_key: str + ) -> None: + CacheService.delete_a_key(cls._cache_key(org_id, user_id, doc_id_key)) + + @classmethod + def _cache_key(cls, org_id: str, user_id: str, doc_id_key: str) -> str: + return f"{cls.CACHE_PREFIX}{org_id}:{user_id}:{doc_id_key}" diff --git a/backend/prompt_studio/prompt_studio_core/exceptions.py b/backend/prompt_studio/prompt_studio_core/exceptions.py index 666d41241..241418060 100644 --- a/backend/prompt_studio/prompt_studio_core/exceptions.py +++ b/backend/prompt_studio/prompt_studio_core/exceptions.py @@ -1,3 +1,4 @@ +from prompt_studio.prompt_profile_manager.constants import ProfileManagerKeys from prompt_studio.prompt_studio_core.constants import ToolStudioErrors from rest_framework.exceptions import APIException @@ -58,3 +59,11 @@ class PermissionError(APIException): class EmptyPromptError(APIException): status_code = 422 default_detail = "Prompt(s) cannot be empty" + + +class MaxProfilesReachedError(APIException): + status_code = 403 + default_detail = ( + f"Maximum number of profiles (max {ProfileManagerKeys.MAX_PROFILE_COUNT})" + " per prompt studio project has been reached." 
+ ) diff --git a/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py b/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py index caa4e377c..c0f7b73d7 100644 --- a/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py +++ b/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py @@ -13,9 +13,15 @@ from django.db.models.manager import BaseManager from file_management.file_management_helper import FileManagerHelper from prompt_studio.prompt_profile_manager.models import ProfileManager +from prompt_studio.prompt_profile_manager.profile_manager_helper import ( + ProfileManagerHelper, +) from prompt_studio.prompt_studio.models import ToolStudioPrompt -from prompt_studio.prompt_studio_core.constants import LogLevels +from prompt_studio.prompt_studio_core.constants import IndexingStatus, LogLevels from prompt_studio.prompt_studio_core.constants import ToolStudioPromptKeys as TSPKeys +from prompt_studio.prompt_studio_core.document_indexing_service import ( + DocumentIndexingService, +) from prompt_studio.prompt_studio_core.exceptions import ( AnswerFetchError, DefaultProfileError, @@ -344,6 +350,7 @@ def index_document( is_summary=is_summary, reindex=True, run_id=run_id, + user_id=user_id, ) logger.info(f"[{tool_id}] Indexing successful for doc: {file_name}") @@ -354,7 +361,7 @@ def index_document( "Indexing successful", ) - return doc_id + return doc_id.get("output") @staticmethod def prompt_responder( @@ -364,6 +371,7 @@ def prompt_responder( document_id: str, id: Optional[str] = None, run_id: str = None, + profile_manager_id: Optional[str] = None, ) -> Any: """Execute chain/single run of the prompts. Makes a call to prompt service and returns the dict of response. @@ -374,6 +382,7 @@ def prompt_responder( user_id (str): User's ID document_id (str): UUID of the document uploaded id (Optional[str]): ID of the prompt + profile_manager_id (Optional[str]): UUID of the profile manager Raises: AnswerFetchError: Error from prompt-service @@ -383,44 +392,94 @@ def prompt_responder( """ document: DocumentManager = DocumentManager.objects.get(pk=document_id) doc_name: str = document.document_name - - doc_path = FileManagerHelper.handle_sub_directory_for_tenants( - org_id=org_id, - user_id=user_id, - tool_id=tool_id, - is_create=False, + doc_path = PromptStudioHelper._get_document_path( + org_id, user_id, tool_id, doc_name ) - doc_path = str(Path(doc_path) / doc_name) if id: - prompt_instance = PromptStudioHelper._fetch_prompt_from_id(id) - prompt_name = prompt_instance.prompt_key - logger.info(f"[{tool_id}] Executing single prompt {id}") - PromptStudioHelper._publish_log( - { - "tool_id": tool_id, - "run_id": run_id, - "prompt_key": prompt_name, - "doc_name": doc_name, - }, - LogLevels.INFO, - LogLevels.RUN, - "Executing single prompt", + return PromptStudioHelper._execute_single_prompt( + id, + doc_path, + doc_name, + tool_id, + org_id, + user_id, + document_id, + run_id, + profile_manager_id, + ) + else: + return PromptStudioHelper._execute_prompts_in_single_pass( + doc_path, tool_id, org_id, user_id, document_id, run_id ) - prompts: list[ToolStudioPrompt] = [] - prompts.append(prompt_instance) - tool: CustomTool = prompt_instance.tool_id + @staticmethod + def _execute_single_prompt( + id, + doc_path, + doc_name, + tool_id, + org_id, + user_id, + document_id, + run_id, + profile_manager_id, + ): + prompt_instance = PromptStudioHelper._fetch_prompt_from_id(id) + prompt_name = prompt_instance.prompt_key + PromptStudioHelper._publish_log( + { + "tool_id": 
tool_id, + "run_id": run_id, + "prompt_key": prompt_name, + "doc_name": doc_name, + }, + LogLevels.INFO, + LogLevels.RUN, + "Executing single prompt", + ) + prompts = [prompt_instance] + tool = prompt_instance.tool_id - if tool.summarize_as_source: - directory, filename = os.path.split(doc_path) - doc_path = os.path.join( - directory, - TSPKeys.SUMMARIZE, - os.path.splitext(filename)[0] + ".txt", - ) + if tool.summarize_as_source: + directory, filename = os.path.split(doc_path) + doc_path = os.path.join( + directory, TSPKeys.SUMMARIZE, os.path.splitext(filename)[0] + ".txt" + ) - logger.info(f"[{tool.tool_id}] Invoking prompt service for prompt {id}") + PromptStudioHelper._publish_log( + { + "tool_id": tool_id, + "run_id": run_id, + "prompt_key": prompt_name, + "doc_name": doc_name, + }, + LogLevels.DEBUG, + LogLevels.RUN, + "Invoking prompt service", + ) + + try: + response = PromptStudioHelper._fetch_response( + doc_path=doc_path, + doc_name=doc_name, + tool=tool, + prompt=prompt_instance, + org_id=org_id, + document_id=document_id, + run_id=run_id, + profile_manager_id=profile_manager_id, + user_id=user_id, + ) + return PromptStudioHelper._handle_response( + response, run_id, prompts, document_id, False, profile_manager_id + ) + except Exception as e: + logger.error( + f"[{tool.tool_id}] Error while fetching response for " + f"prompt {id} and doc {document_id}: {e}" + ) + msg = str(e) PromptStudioHelper._publish_log( { "tool_id": tool_id, @@ -428,130 +487,89 @@ def prompt_responder( "prompt_key": prompt_name, "doc_name": doc_name, }, - LogLevels.DEBUG, + LogLevels.ERROR, LogLevels.RUN, - "Invoking prompt service", + msg, ) + raise e - try: - response = PromptStudioHelper._fetch_response( - doc_path=doc_path, - doc_name=doc_name, - tool=tool, - prompt=prompt_instance, - org_id=org_id, - document_id=document_id, - run_id=run_id, - ) + @staticmethod + def _execute_prompts_in_single_pass( + doc_path, tool_id, org_id, user_id, document_id, run_id + ): + prompts = PromptStudioHelper.fetch_prompt_from_tool(tool_id) + prompts = [prompt for prompt in prompts if prompt.prompt_type != TSPKeys.NOTES] + if not prompts: + logger.error(f"[{tool_id or 'NA'}] No prompts found for id: {id}") + raise NoPromptsFound() - OutputManagerHelper.handle_prompt_output_update( - run_id=run_id, - prompts=prompts, - outputs=response["output"], - document_id=document_id, - is_single_pass_extract=False, - ) - # TODO: Review if this catch-all is required - except Exception as e: - logger.error( - f"[{tool.tool_id}] Error while fetching response for " - f"prompt {id} and doc {document_id}: {e}" - ) - msg: str = ( - f"Error while fetching response for " - f"'{prompt_name}' with '{doc_name}'. 
{e}" - ) - if isinstance(e, AnswerFetchError): - msg = str(e) - PromptStudioHelper._publish_log( - { - "tool_id": tool_id, - "run_id": run_id, - "prompt_key": prompt_name, - "doc_name": doc_name, - }, - LogLevels.ERROR, - LogLevels.RUN, - msg, - ) - raise e + PromptStudioHelper._publish_log( + {"tool_id": tool_id, "run_id": run_id, "prompt_id": str(id)}, + LogLevels.INFO, + LogLevels.RUN, + "Executing prompts in single pass", + ) - logger.info( - f"[{tool.tool_id}] Response fetched successfully for prompt {id}" + try: + tool = prompts[0].tool_id + response = PromptStudioHelper._fetch_single_pass_response( + file_path=doc_path, + tool=tool, + prompts=prompts, + org_id=org_id, + document_id=document_id, + run_id=run_id, + user_id=user_id, + ) + return PromptStudioHelper._handle_response( + response, run_id, prompts, document_id, True + ) + except Exception as e: + logger.error( + f"[{tool.tool_id}] Error while fetching single pass response: {e}" ) PromptStudioHelper._publish_log( { "tool_id": tool_id, "run_id": run_id, - "prompt_key": prompt_name, - "doc_name": doc_name, + "prompt_id": str(id), }, - LogLevels.INFO, - LogLevels.RUN, - "Single prompt execution completed", - ) - - return response - else: - prompts = PromptStudioHelper.fetch_prompt_from_tool(tool_id) - prompts = [ - prompt for prompt in prompts if prompt.prompt_type != TSPKeys.NOTES - ] - if not prompts: - logger.error(f"[{tool_id or 'NA'}] No prompts found for id: {id}") - raise NoPromptsFound() - - logger.info(f"[{tool_id}] Executing prompts in single pass") - PromptStudioHelper._publish_log( - {"tool_id": tool_id, "run_id": run_id, "prompt_id": str(id)}, - LogLevels.INFO, + LogLevels.ERROR, LogLevels.RUN, - "Executing prompts in single pass", + f"Failed to fetch single pass response. {e}", ) + raise e - try: - tool = prompts[0].tool_id - response = PromptStudioHelper._fetch_single_pass_response( - file_path=doc_path, - tool=tool, - prompts=prompts, - org_id=org_id, - document_id=document_id, - run_id=run_id, - ) - - OutputManagerHelper.handle_prompt_output_update( - run_id=run_id, - prompts=prompts, - outputs=response[TSPKeys.OUTPUT], - document_id=document_id, - is_single_pass_extract=True, - ) - except Exception as e: - logger.error( - f"[{tool.tool_id}] Error while fetching single pass response: {e}" # noqa: E501 - ) - PromptStudioHelper._publish_log( - { - "tool_id": tool_id, - "run_id": run_id, - "prompt_id": str(id), - }, - LogLevels.ERROR, - LogLevels.RUN, - f"Failed to fetch single pass response. 
{e}", - ) - raise e - - logger.info(f"[{tool.tool_id}] Single pass response fetched successfully") - PromptStudioHelper._publish_log( - {"tool_id": tool_id, "run_id": run_id, "prompt_id": str(id)}, - LogLevels.INFO, - LogLevels.RUN, - "Single pass execution completed", - ) + @staticmethod + def _get_document_path(org_id, user_id, tool_id, doc_name): + doc_path = FileManagerHelper.handle_sub_directory_for_tenants( + org_id=org_id, + user_id=user_id, + tool_id=tool_id, + is_create=False, + ) + return str(Path(doc_path) / doc_name) - return response + @staticmethod + def _handle_response( + response, run_id, prompts, document_id, is_single_pass, profile_manager_id=None + ): + if response.get("status") == IndexingStatus.PENDING_STATUS.value: + return { + "status": IndexingStatus.PENDING_STATUS.value, + "message": IndexingStatus.DOCUMENT_BEING_INDEXED.value, + } + + OutputManagerHelper.handle_prompt_output_update( + run_id=run_id, + prompts=prompts, + outputs=response["output"], + document_id=document_id, + is_single_pass_extract=is_single_pass, + profile_manager_id=profile_manager_id, + context=response["metadata"].get("context"), + ) + return response @staticmethod def _fetch_response( @@ -562,6 +580,8 @@ def _fetch_response( org_id: str, document_id: str, run_id: str, + user_id: str, + profile_manager_id: Optional[str] = None, ) -> Any: """Utility function to invoke prompt service. Used internally. @@ -572,6 +592,9 @@ def _fetch_response( prompt (ToolStudioPrompt): ToolStudioPrompt instance to fetch response org_id (str): UUID of the organization document_id (str): UUID of the document + profile_manager_id (Optional[str]): UUID of the profile manager + user_id (str): The ID of the user who uploaded the document + Raises: DefaultProfileError: If no default profile is selected @@ -580,6 +603,14 @@ def _fetch_response( Returns: Any: Output from LLM """ + + # Fetch the ProfileManager instance using the profile_manager_id if provided + profile_manager = prompt.profile_manager + if profile_manager_id: + profile_manager = ProfileManagerHelper.get_profile_manager( + profile_manager_id=profile_manager_id + ) + monitor_llm_instance: Optional[AdapterInstance] = tool.monitor_llm monitor_llm: Optional[str] = None challenge_llm_instance: Optional[AdapterInstance] = tool.challenge_llm @@ -600,28 +631,33 @@ def _fetch_response( challenge_llm = str(default_profile.llm.id) # Need to check the user who created profile manager - PromptStudioHelper.validate_adapter_status(prompt.profile_manager) + PromptStudioHelper.validate_adapter_status(profile_manager) # Need to check the user who created profile manager # has access to adapters - PromptStudioHelper.validate_profile_manager_owner_access(prompt.profile_manager) + PromptStudioHelper.validate_profile_manager_owner_access(profile_manager) # Not checking reindex here as there might be # change in Profile Manager - vector_db = str(prompt.profile_manager.vector_store.id) - embedding_model = str(prompt.profile_manager.embedding_model.id) - llm = str(prompt.profile_manager.llm.id) - x2text = str(prompt.profile_manager.x2text.id) - prompt_profile_manager: ProfileManager = prompt.profile_manager - if not prompt_profile_manager: + vector_db = str(profile_manager.vector_store.id) + embedding_model = str(profile_manager.embedding_model.id) + llm = str(profile_manager.llm.id) + x2text = str(profile_manager.x2text.id) + if not profile_manager: raise DefaultProfileError() - PromptStudioHelper.dynamic_indexer( - profile_manager=prompt_profile_manager, + index_result = 
PromptStudioHelper.dynamic_indexer( + profile_manager=profile_manager, file_path=doc_path, tool_id=str(tool.tool_id), org_id=org_id, document_id=document_id, is_summary=tool.summarize_as_source, run_id=run_id, + user_id=user_id, ) + if index_result.get("status") == IndexingStatus.PENDING_STATUS.value: + return { + "status": IndexingStatus.PENDING_STATUS.value, + "message": IndexingStatus.DOCUMENT_BEING_INDEXED.value, + } output: dict[str, Any] = {} outputs: list[dict[str, Any]] = [] @@ -639,16 +675,16 @@ def _fetch_response( output[TSPKeys.PROMPT] = prompt.prompt output[TSPKeys.ACTIVE] = prompt.active - output[TSPKeys.CHUNK_SIZE] = prompt.profile_manager.chunk_size + output[TSPKeys.CHUNK_SIZE] = profile_manager.chunk_size output[TSPKeys.VECTOR_DB] = vector_db output[TSPKeys.EMBEDDING] = embedding_model - output[TSPKeys.CHUNK_OVERLAP] = prompt.profile_manager.chunk_overlap + output[TSPKeys.CHUNK_OVERLAP] = profile_manager.chunk_overlap output[TSPKeys.LLM] = llm output[TSPKeys.TYPE] = prompt.enforce_type output[TSPKeys.NAME] = prompt.prompt_key - output[TSPKeys.RETRIEVAL_STRATEGY] = prompt.profile_manager.retrieval_strategy - output[TSPKeys.SIMILARITY_TOP_K] = prompt.profile_manager.similarity_top_k - output[TSPKeys.SECTION] = prompt.profile_manager.section + output[TSPKeys.RETRIEVAL_STRATEGY] = profile_manager.retrieval_strategy + output[TSPKeys.SIMILARITY_TOP_K] = profile_manager.similarity_top_k + output[TSPKeys.SECTION] = profile_manager.section output[TSPKeys.X2TEXT_ADAPTER] = x2text # Eval settings for the prompt output[TSPKeys.EVAL_SETTINGS] = {} @@ -715,10 +751,11 @@ def dynamic_indexer( file_path: str, org_id: str, document_id: str, + user_id: str, is_summary: bool = False, reindex: bool = False, run_id: str = None, - ) -> str: + ) -> Any: """Used to index a file based on the passed arguments. This is useful when a file needs to be indexed dynamically as the @@ -732,6 +769,7 @@ def dynamic_indexer( org_id (str): ID of the organization is_summary (bool, optional): Flag to ensure if extracted contents need to be persisted. Defaults to False. 
+ user_id (str): The ID of the user who uploaded the document Returns: str: Index key for the combination of arguments @@ -750,9 +788,42 @@ def dynamic_indexer( profile_manager.chunk_size = 0 try: + usage_kwargs = {"run_id": run_id} util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) tool_index = Index(tool=util) + doc_id_key = tool_index.generate_file_id( + tool_id=tool_id, + vector_db=vector_db, + embedding=embedding_model, + x2text=x2text_adapter, + chunk_size=str(profile_manager.chunk_size), + chunk_overlap=str(profile_manager.chunk_overlap), + file_path=file_path, + file_hash=None, + ) + if not reindex: + indexed_doc_id = DocumentIndexingService.get_indexed_document_id( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ) + if indexed_doc_id: + return { + "status": IndexingStatus.COMPLETED_STATUS.value, + "output": indexed_doc_id, + } + # Polling if document is already being indexed + if DocumentIndexingService.is_document_indexing( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ): + return { + "status": IndexingStatus.PENDING_STATUS.value, + "output": IndexingStatus.DOCUMENT_BEING_INDEXED.value, + } + + # Set the document as being indexed + DocumentIndexingService.set_document_indexing( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ) doc_id: str = tool_index.index( tool_id=tool_id, embedding_instance_id=embedding_model, @@ -772,7 +843,10 @@ def dynamic_indexer( profile_manager=profile_manager, doc_id=doc_id, ) - return doc_id + DocumentIndexingService.mark_document_indexed( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key, doc_id=doc_id + ) + return {"status": IndexingStatus.COMPLETED_STATUS.value, "output": doc_id} except (IndexingError, IndexingAPIError, SdkError) as e: doc_name = os.path.split(file_path)[1] PromptStudioHelper._publish_log( @@ -791,6 +865,7 @@ def _fetch_single_pass_response( file_path: str, prompts: list[ToolStudioPrompt], org_id: str, + user_id: str, document_id: str, run_id: str = None, ) -> Any: @@ -819,7 +894,7 @@ def _fetch_single_pass_response( if not default_profile: raise DefaultProfileError() - PromptStudioHelper.dynamic_indexer( + index_result = PromptStudioHelper.dynamic_indexer( profile_manager=default_profile, file_path=file_path, tool_id=tool_id, @@ -827,7 +902,13 @@ def _fetch_single_pass_response( is_summary=tool.summarize_as_source, document_id=document_id, run_id=run_id, + user_id=user_id, ) + if index_result.get("status") == IndexingStatus.PENDING_STATUS.value: + return { + "status": IndexingStatus.PENDING_STATUS.value, + "message": IndexingStatus.DOCUMENT_BEING_INDEXED.value, + } vector_db = str(default_profile.vector_store.id) embedding_model = str(default_profile.embedding_model.id) diff --git a/backend/prompt_studio/prompt_studio_core/views.py b/backend/prompt_studio/prompt_studio_core/views.py index 9093efd42..8db0a3ef5 100644 --- a/backend/prompt_studio/prompt_studio_core/views.py +++ b/backend/prompt_studio/prompt_studio_core/views.py @@ -21,6 +21,9 @@ ToolStudioKeys, ToolStudioPromptKeys, ) +from prompt_studio.prompt_studio_core.document_indexing_service import ( + DocumentIndexingService, +) from prompt_studio.prompt_studio_core.exceptions import ( IndexingAPIError, ToolDeleteError, @@ -30,6 +33,7 @@ from prompt_studio.prompt_studio_document_manager.prompt_studio_document_helper import ( # noqa: E501 PromptStudioDocumentHelper, ) +from prompt_studio.prompt_studio_index_manager.models import IndexManager from prompt_studio.prompt_studio_registry.prompt_studio_registry_helper import ( 
PromptStudioRegistryHelper, ) @@ -264,6 +268,7 @@ def fetch_response(self, request: HttpRequest, pk: Any = None) -> Response: document_id: str = request.data.get(ToolStudioPromptKeys.DOCUMENT_ID) id: str = request.data.get(ToolStudioPromptKeys.ID) run_id: str = request.data.get(ToolStudioPromptKeys.RUN_ID) + profile_manager: str = request.data.get(ToolStudioPromptKeys.PROFILE_MANAGER_ID) if not run_id: # Generate a run_id run_id = CommonUtils.generate_uuid() @@ -275,6 +280,7 @@ def fetch_response(self, request: HttpRequest, pk: Any = None) -> Response: user_id=custom_tool.created_by.user_id, document_id=document_id, run_id=run_id, + profile_manager_id=profile_manager, ) return Response(response, status=status.HTTP_200_OK) @@ -446,17 +452,26 @@ def delete_for_ide(self, request: HttpRequest, pk: uuid) -> Response: document_id: str = serializer.validated_data.get( ToolStudioPromptKeys.DOCUMENT_ID ) + org_id = UserSessionUtils.get_organization_id(request) + user_id = custom_tool.created_by.user_id document: DocumentManager = DocumentManager.objects.get(pk=document_id) file_name: str = document.document_name file_path = FileManagerHelper.handle_sub_directory_for_tenants( - UserSessionUtils.get_organization_id(request), + org_id=org_id, is_create=False, - user_id=custom_tool.created_by.user_id, + user_id=user_id, tool_id=str(custom_tool.tool_id), ) path = file_path file_system = LocalStorageFS(settings={"path": path}) try: + # Delete indexed flags in redis + index_managers = IndexManager.objects.filter(document_manager=document_id) + for index_manager in index_managers: + raw_index_id = index_manager.raw_index_id + DocumentIndexingService.remove_document_indexing( + org_id=org_id, user_id=user_id, doc_id_key=raw_index_id + ) # Delete the document record document.delete() # Delete the files diff --git a/backend/prompt_studio/prompt_studio_output_manager/migrations/0013_promptstudiooutputmanager_context.py b/backend/prompt_studio/prompt_studio_output_manager/migrations/0013_promptstudiooutputmanager_context.py new file mode 100644 index 000000000..9d72dbd4d --- /dev/null +++ b/backend/prompt_studio/prompt_studio_output_manager/migrations/0013_promptstudiooutputmanager_context.py @@ -0,0 +1,20 @@ +# Generated by Django 4.2.1 on 2024-06-27 18:27 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("prompt_studio_output_manager", "0012_promptstudiooutputmanager_run_id"), + ] + + operations = [ + migrations.AddField( + model_name="promptstudiooutputmanager", + name="context", + field=models.CharField( + blank=True, db_comment="Field to store chucks used", null=True + ), + ), + ] diff --git a/backend/prompt_studio/prompt_studio_output_manager/migrations/0014_alter_promptstudiooutputmanager_context.py b/backend/prompt_studio/prompt_studio_output_manager/migrations/0014_alter_promptstudiooutputmanager_context.py new file mode 100644 index 000000000..9d7844eaa --- /dev/null +++ b/backend/prompt_studio/prompt_studio_output_manager/migrations/0014_alter_promptstudiooutputmanager_context.py @@ -0,0 +1,20 @@ +# Generated by Django 4.2.1 on 2024-06-30 17:17 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("prompt_studio_output_manager", "0013_promptstudiooutputmanager_context"), + ] + + operations = [ + migrations.AlterField( + model_name="promptstudiooutputmanager", + name="context", + field=models.TextField( + blank=True, db_comment="Field to store chunks used", null=True + ), + ), + ] 
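The `DocumentIndexingService` introduced earlier in this diff keeps a short-lived, per-org/per-user flag in Redis so concurrent prompt runs do not re-index the same document. Below is a minimal, self-contained sketch of that flag lifecycle; the plain dict stands in for the Redis-backed `CacheService`, and TTL expiry (`INDEXING_FLAG_TTL`) is omitted for brevity.

```python
STARTED = "started"  # IndexingStatus.STARTED_STATUS.value
_cache: dict[str, str] = {}  # stand-in for the Redis-backed CacheService


def _key(org_id: str, user_id: str, doc_id_key: str) -> str:
    return f"document_indexing:{org_id}:{user_id}:{doc_id_key}"


def try_start_indexing(org_id: str, user_id: str, doc_id_key: str) -> str:
    """Mirrors the checks dynamic_indexer performs before indexing."""
    value = _cache.get(_key(org_id, user_id, doc_id_key))
    if value is not None and value != STARTED:
        return "completed"  # cached value is the already-indexed doc_id
    if value == STARTED:
        return "pending"  # another request is indexing this document
    _cache[_key(org_id, user_id, doc_id_key)] = STARTED  # claim the flag
    return "started"


def mark_indexed(org_id: str, user_id: str, doc_id_key: str, doc_id: str) -> None:
    _cache[_key(org_id, user_id, doc_id_key)] = doc_id


# First caller claims the flag, a concurrent caller sees "pending",
# and callers after completion short-circuit with the cached doc_id.
assert try_start_indexing("org1", "u1", "k1") == "started"
assert try_start_indexing("org1", "u1", "k1") == "pending"
mark_indexed("org1", "u1", "k1", "doc-123")
assert try_start_indexing("org1", "u1", "k1") == "completed"
```

Note that the check-then-set claim (here and in `dynamic_indexer`) is not atomic; under Redis, a `SET NX` with a TTL would close that small race window.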
diff --git a/backend/prompt_studio/prompt_studio_output_manager/models.py b/backend/prompt_studio/prompt_studio_output_manager/models.py index 14febf634..e1f7f5b86 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/models.py +++ b/backend/prompt_studio/prompt_studio_output_manager/models.py @@ -21,6 +21,9 @@ class PromptStudioOutputManager(BaseModel): output = models.CharField( db_comment="Field to store output", editable=True, null=True, blank=True ) + context = models.TextField( + db_comment="Field to store chunks used", editable=True, null=True, blank=True + ) eval_metrics = models.JSONField( db_column="eval_metrics", null=False, diff --git a/backend/prompt_studio/prompt_studio_output_manager/output_manager_helper.py b/backend/prompt_studio/prompt_studio_output_manager/output_manager_helper.py index 6f942e3b7..b88a25602 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/output_manager_helper.py +++ b/backend/prompt_studio/prompt_studio_output_manager/output_manager_helper.py @@ -1,10 +1,14 @@ import json import logging -from typing import Any +from typing import Any, Optional from prompt_studio.prompt_profile_manager.models import ProfileManager -from prompt_studio.prompt_studio.exceptions import AnswerFetchError from prompt_studio.prompt_studio.models import ToolStudioPrompt +from prompt_studio.prompt_studio_core.exceptions import ( + AnswerFetchError, + DefaultProfileError, +) +from prompt_studio.prompt_studio_core.models import CustomTool from prompt_studio.prompt_studio_document_manager.models import DocumentManager from prompt_studio.prompt_studio_output_manager.constants import ( PromptStudioOutputManagerKeys as PSOMKeys, @@ -20,42 +24,32 @@ def handle_prompt_output_update( run_id: str, prompts: list[ToolStudioPrompt], outputs: Any, + context: Any, document_id: str, is_single_pass_extract: bool, + profile_manager_id: Optional[str] = None, ) -> None: """Handles updating prompt outputs in the database. Args: + run_id (str): ID of the run. prompts (list[ToolStudioPrompt]): List of prompts to update. outputs (Any): Outputs corresponding to the prompts. document_id (str): ID of the document. + profile_manager_id (Optional[str]): UUID of the profile manager. is_single_pass_extract (bool): Flag indicating if single pass extract is active. 
""" - # Check if prompts list is empty - if not prompts: - return # Return early if prompts list is empty - tool = prompts[0].tool_id - document_manager = DocumentManager.objects.get(pk=document_id) - default_profile = ProfileManager.get_default_llm_profile(tool=tool) - # Iterate through each prompt in the list - for prompt in prompts: - if prompt.prompt_type == PSOMKeys.NOTES: - continue - if is_single_pass_extract: - profile_manager = default_profile - else: - profile_manager = prompt.profile_manager - output = json.dumps(outputs.get(prompt.prompt_key)) - eval_metrics = outputs.get(f"{prompt.prompt_key}__evaluation", []) - - # Attempt to update an existing output manager, - # for the given criteria, - # or create a new one if it doesn't exist + def update_or_create_prompt_output( + prompt: ToolStudioPrompt, + profile_manager: ProfileManager, + output: str, + eval_metrics: list[Any], + tool: CustomTool, + context: str, + ): try: - # Create or get the existing record for this document, prompt and - # profile combo _, success = PromptStudioOutputManager.objects.get_or_create( document_manager=document_manager, tool_id=tool, @@ -65,6 +59,7 @@ def handle_prompt_output_update( defaults={ "output": output, "eval_metrics": eval_metrics, + "context": context, }, ) @@ -79,11 +74,12 @@ def handle_prompt_output_update( f"profile {profile_manager.profile_id}" ) - args: dict[str, str] = dict() - args["run_id"] = run_id - args["output"] = output - args["eval_metrics"] = eval_metrics - # Update the record with the run id and other params + args: dict[str, str] = { + "run_id": run_id, + "output": output, + "eval_metrics": eval_metrics, + "context": context, + } PromptStudioOutputManager.objects.filter( document_manager=document_manager, tool_id=tool, @@ -94,3 +90,57 @@ def handle_prompt_output_update( except Exception as e: raise AnswerFetchError(f"Error updating prompt output {e}") from e + + if not prompts: + return # Return early if prompts list is empty + + tool = prompts[0].tool_id + default_profile = OutputManagerHelper.get_default_profile( + profile_manager_id, tool + ) + document_manager = DocumentManager.objects.get(pk=document_id) + + for prompt in prompts: + if prompt.prompt_type == PSOMKeys.NOTES or not prompt.active: + continue + + if not is_single_pass_extract: + context = json.dumps(context.get(prompt.prompt_key)) + + output = json.dumps(outputs.get(prompt.prompt_key)) + profile_manager = default_profile + eval_metrics = outputs.get(f"{prompt.prompt_key}__evaluation", []) + + update_or_create_prompt_output( + prompt=prompt, + profile_manager=profile_manager, + output=output, + eval_metrics=eval_metrics, + tool=tool, + context=context, + ) + + @staticmethod + def get_default_profile( + profile_manager_id: Optional[str], tool: CustomTool + ) -> ProfileManager: + if profile_manager_id: + return OutputManagerHelper.fetch_profile_manager(profile_manager_id) + else: + return OutputManagerHelper.fetch_default_llm_profile(tool) + + @staticmethod + def fetch_profile_manager(profile_manager_id: str) -> ProfileManager: + try: + return ProfileManager.objects.get(profile_id=profile_manager_id) + except ProfileManager.DoesNotExist: + raise DefaultProfileError( + f"ProfileManager with ID {profile_manager_id} does not exist." 
+ ) + + @staticmethod + def fetch_default_llm_profile(tool: CustomTool) -> ProfileManager: + try: + return ProfileManager.get_default_llm_profile(tool=tool) + except DefaultProfileError: + raise DefaultProfileError("Default ProfileManager does not exist.") diff --git a/backend/sample.env b/backend/sample.env index abbc1757c..f42a32068 100644 --- a/backend/sample.env +++ b/backend/sample.env @@ -139,3 +139,6 @@ LOGS_BATCH_LIMIT=30 # Celery Configuration CELERY_BROKER_URL = "redis://unstract-redis:6379" + +# Indexing flag to prevent re-index +INDEXING_FLAG_TTL=1800 diff --git a/backend/usage/constants.py b/backend/usage/constants.py index d28074e1e..8da54da05 100644 --- a/backend/usage/constants.py +++ b/backend/usage/constants.py @@ -4,3 +4,4 @@ class UsageKeys: PROMPT_TOKENS = "prompt_tokens" COMPLETION_TOKENS = "completion_tokens" TOTAL_TOKENS = "total_tokens" + COST_IN_DOLLARS = "cost_in_dollars" diff --git a/backend/usage/helper.py b/backend/usage/helper.py index b91fae556..0bfab7556 100644 --- a/backend/usage/helper.py +++ b/backend/usage/helper.py @@ -36,6 +36,7 @@ def get_aggregated_token_count(run_id: str) -> dict: prompt_tokens=Sum(UsageKeys.PROMPT_TOKENS), completion_tokens=Sum(UsageKeys.COMPLETION_TOKENS), total_tokens=Sum(UsageKeys.TOTAL_TOKENS), + cost_in_dollars=Sum(UsageKeys.COST_IN_DOLLARS), ) logger.info(f"Token counts aggregated successfully for run_id: {run_id}") @@ -50,6 +51,7 @@ def get_aggregated_token_count(run_id: str) -> dict: UsageKeys.COMPLETION_TOKENS ), UsageKeys.TOTAL_TOKENS: usage_summary.get(UsageKeys.TOTAL_TOKENS), + UsageKeys.COST_IN_DOLLARS: usage_summary.get(UsageKeys.COST_IN_DOLLARS), } return result except Usage.DoesNotExist: diff --git a/frontend/package-lock.json b/frontend/package-lock.json index aee9ec808..018771100 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -25,6 +25,7 @@ "cronstrue": "^2.48.0", "emoji-picker-react": "^4.8.0", "emoji-regex": "^10.3.0", + "framer-motion": "^11.2.10", "handlebars": "^4.7.8", "http-proxy-middleware": "^2.0.6", "js-cookie": "^3.0.5", @@ -9383,6 +9384,30 @@ "url": "https://www.patreon.com/infusion" } }, + "node_modules/framer-motion": { + "version": "11.2.10", + "resolved": "https://registry.npmjs.org/framer-motion/-/framer-motion-11.2.10.tgz", + "integrity": "sha512-/gr3PLZUVFCc86a9MqCUboVrALscrdluzTb3yew+2/qKBU8CX6nzs918/SRBRCqaPbx0TZP10CB6yFgK2C5cYQ==", + "dependencies": { + "tslib": "^2.4.0" + }, + "peerDependencies": { + "@emotion/is-prop-valid": "*", + "react": "^18.0.0", + "react-dom": "^18.0.0" + }, + "peerDependenciesMeta": { + "@emotion/is-prop-valid": { + "optional": true + }, + "react": { + "optional": true + }, + "react-dom": { + "optional": true + } + } + }, "node_modules/fresh": { "version": "0.5.2", "resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz", @@ -27235,6 +27260,14 @@ "resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-4.2.0.tgz", "integrity": "sha512-MhLuK+2gUcnZe8ZHlaaINnQLl0xRIGRfcGk2yl8xoQAfHrSsL3rYu6FCmBdkdbhc9EPlwyGHewaRsvwRMJtAlA==" }, + "framer-motion": { + "version": "11.2.10", + "resolved": "https://registry.npmjs.org/framer-motion/-/framer-motion-11.2.10.tgz", + "integrity": "sha512-/gr3PLZUVFCc86a9MqCUboVrALscrdluzTb3yew+2/qKBU8CX6nzs918/SRBRCqaPbx0TZP10CB6yFgK2C5cYQ==", + "requires": { + "tslib": "^2.4.0" + } + }, "fresh": { "version": "0.5.2", "resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz", diff --git a/frontend/package.json b/frontend/package.json index a01f53738..674ef6689 100644 --- 
a/frontend/package.json +++ b/frontend/package.json @@ -20,6 +20,7 @@ "cronstrue": "^2.48.0", "emoji-picker-react": "^4.8.0", "emoji-regex": "^10.3.0", + "framer-motion": "^11.2.10", "handlebars": "^4.7.8", "http-proxy-middleware": "^2.0.6", "js-cookie": "^3.0.5", diff --git a/frontend/src/components/custom-tools/combined-output/CombinedOutput.jsx b/frontend/src/components/custom-tools/combined-output/CombinedOutput.jsx index 4b0cb5281..980a80916 100644 --- a/frontend/src/components/custom-tools/combined-output/CombinedOutput.jsx +++ b/frontend/src/components/custom-tools/combined-output/CombinedOutput.jsx @@ -7,6 +7,7 @@ import PropTypes from "prop-types"; import { displayPromptResult, + getLLMModelNamesForProfiles, promptType, } from "../../../helpers/GetStaticData"; import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate"; @@ -31,18 +32,25 @@ try { function CombinedOutput({ docId, setFilledFields }) { const [combinedOutput, setCombinedOutput] = useState({}); const [isOutputLoading, setIsOutputLoading] = useState(false); + const [adapterData, setAdapterData] = useState([]); + const [activeKey, setActiveKey] = useState("0"); const { details, defaultLlmProfile, singlePassExtractMode, isSinglePassExtractLoading, + llmProfiles, isSimplePromptStudio, } = useCustomToolStore(); const { sessionDetails } = useSessionStore(); const { setAlertDetails } = useAlertStore(); const axiosPrivate = useAxiosPrivate(); const handleException = useExceptionHandler(); + const [selectedProfile, setSelectedProfile] = useState(defaultLlmProfile); + useEffect(() => { + getAdapterInfo(); + }, []); useEffect(() => { if (!docId || isSinglePassExtractLoading) { return; @@ -62,7 +70,7 @@ function CombinedOutput({ docId, setFilledFields }) { } output[item?.prompt_key] = ""; - let profileManager = item?.profile_manager; + let profileManager = selectedProfile || item?.profile_manager; if (singlePassExtractMode) { profileManager = defaultLlmProfile; } @@ -100,12 +108,25 @@ function CombinedOutput({ docId, setFilledFields }) { .finally(() => { setIsOutputLoading(false); }); - }, [docId, singlePassExtractMode, isSinglePassExtractLoading]); + }, [ + docId, + singlePassExtractMode, + isSinglePassExtractLoading, + selectedProfile, + ]); const handleOutputApiRequest = async () => { - let url = `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/prompt-output/?tool_id=${details?.tool_id}&document_manager=${docId}&is_single_pass_extract=${singlePassExtractMode}`; + let url; if (isSimplePromptStudio) { url = promptOutputApiSps(details?.tool_id, null, docId); + } else { + url = `/api/v1/unstract/${ + sessionDetails?.orgId + }/prompt-studio/prompt-output/?tool_id=${ + details?.tool_id + }&document_manager=${docId}&is_single_pass_extract=${singlePassExtractMode}&profile_manager=${ + selectedProfile || defaultLlmProfile + }`; } const requestOptions = { method: "GET", @@ -122,15 +143,43 @@ function CombinedOutput({ docId, setFilledFields }) { }); }; + const getAdapterInfo = () => { + axiosPrivate + .get( + `/api/v1/unstract/${sessionDetails?.orgId}/adapter/?adapter_type=LLM` + ) + .then((res) => { + const adapterList = res?.data; + setAdapterData(getLLMModelNamesForProfiles(llmProfiles, adapterList)); + }); + }; + if (isOutputLoading) { return ; } + const handleTabChange = (key) => { + if (key === "0") { + setSelectedProfile(defaultLlmProfile); + } else { + setSelectedProfile(adapterData[key - 1]?.profile_id); + } + setActiveKey(key); + }; + if (isSimplePromptStudio && TableView) { return ; } - return ; + return ( + + ); } 
CombinedOutput.propTypes = { diff --git a/frontend/src/components/custom-tools/combined-output/JsonView.jsx b/frontend/src/components/custom-tools/combined-output/JsonView.jsx index 4ad2fa7a1..61bfc9f46 100644 --- a/frontend/src/components/custom-tools/combined-output/JsonView.jsx +++ b/frontend/src/components/custom-tools/combined-output/JsonView.jsx @@ -1,8 +1,18 @@ import PropTypes from "prop-types"; import Prism from "prismjs"; import { useEffect } from "react"; +import { ProfileInfoBar } from "../profile-info-bar/ProfileInfoBar"; +import TabPane from "antd/es/tabs/TabPane"; +import { Tabs } from "antd"; -function JsonView({ combinedOutput }) { +function JsonView({ + combinedOutput, + handleTabChange, + adapterData, + activeKey, + selectedProfile, + llmProfiles, +}) { useEffect(() => { Prism.highlightAll(); }, [combinedOutput]); @@ -10,9 +20,19 @@ function JsonView({ combinedOutput }) { return (
+      <Tabs activeKey={activeKey} onChange={handleTabChange}>
+        <TabPane tab={<span>Default</span>} key={"0"}></TabPane>
+        {adapterData.map((adapter, index) => (
+          <TabPane
+            tab={<span>{adapter.llm_model}</span>}
+            key={(index + 1)?.toString()}
+          />
+        ))}
+      </Tabs>
+      <ProfileInfoBar profileId={selectedProfile} profiles={llmProfiles} />
@@ -29,6 +49,11 @@ function JsonView({ combinedOutput }) {
 
 JsonView.propTypes = {
   combinedOutput: PropTypes.object.isRequired,
+  handleTabChange: PropTypes.func,
+  adapterData: PropTypes.array,
+  selectedProfile: PropTypes.string,
+  llmProfiles: PropTypes.array,
+  activeKey: PropTypes.string,
 };
 
 export { JsonView };
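The tabs added to `JsonView` label each extra tab with `adapter.llm_model`. The `getLLMModelNamesForProfiles` helper is referenced but not shown in this diff; given that the adapter serializer change at the top of this PR now exposes `model` on each adapter, the join it performs plausibly looks like the following Python sketch (field names are assumptions inferred from usage):

```python
def llm_model_names_for_profiles(
    profiles: list[dict], adapters: list[dict]
) -> list[dict]:
    """Pair each LLM profile with the model name of its LLM adapter."""
    model_by_adapter = {a["id"]: a.get("model") for a in adapters}
    return [
        {"profile_id": p["profile_id"], "llm_model": model_by_adapter.get(p["llm"])}
        for p in profiles
    ]


profiles = [{"profile_id": "p1", "llm": "a1"}]
adapters = [{"id": "a1", "model": "gpt-4o"}]
assert llm_model_names_for_profiles(profiles, adapters) == [
    {"profile_id": "p1", "llm_model": "gpt-4o"}
]
```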
diff --git a/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.jsx b/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.jsx
index cc4d6d6f8..23de585e2 100644
--- a/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.jsx
+++ b/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.jsx
@@ -1,4 +1,4 @@
-import { Button, Modal, Table, Typography } from "antd";
+import { Button, Modal, Table, Tabs, Typography } from "antd";
 import PropTypes from "prop-types";
 import { useEffect, useState } from "react";
 import {
@@ -12,12 +12,17 @@ import { useCustomToolStore } from "../../../store/custom-tool-store";
 import { useSessionStore } from "../../../store/session-store";
 import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate";
 import "./OutputForDocModal.css";
-import { displayPromptResult } from "../../../helpers/GetStaticData";
+import {
+  displayPromptResult,
+  getLLMModelNamesForProfiles,
+} from "../../../helpers/GetStaticData";
 import { SpinnerLoader } from "../../widgets/spinner-loader/SpinnerLoader";
 import { useAlertStore } from "../../../store/alert-store";
 import { useExceptionHandler } from "../../../hooks/useExceptionHandler";
 import { TokenUsage } from "../token-usage/TokenUsage";
 import { useTokenUsageStore } from "../../../store/token-usage-store";
+import TabPane from "antd/es/tabs/TabPane";
+import { ProfileInfoBar } from "../profile-info-bar/ProfileInfoBar";
 
 const columns = [
   {
@@ -57,6 +62,7 @@ function OutputForDocModal({
 }) {
   const [promptOutputs, setPromptOutputs] = useState([]);
   const [rows, setRows] = useState([]);
+  const [adapterData, setAdapterData] = useState([]);
   const [isLoading, setIsLoading] = useState(false);
   const {
     details,
@@ -66,6 +72,7 @@ function OutputForDocModal({
     disableLlmOrDocChange,
     singlePassExtractMode,
     isSinglePassExtractLoading,
+    llmProfiles,
   } = useCustomToolStore();
   const { sessionDetails } = useSessionStore();
   const axiosPrivate = useAxiosPrivate();
@@ -73,12 +80,14 @@ function OutputForDocModal({
   const { setAlertDetails } = useAlertStore();
   const { handleException } = useExceptionHandler();
   const { tokenUsage } = useTokenUsageStore();
+  const [selectedProfile, setSelectedProfile] = useState(defaultLlmProfile);
 
   useEffect(() => {
     if (!open) {
       return;
     }
     handleGetOutputForDocs();
+    getAdapterInfo();
   }, [open, singlePassExtractMode, isSinglePassExtractLoading]);
 
   useEffect(() => {
@@ -89,6 +98,12 @@ function OutputForDocModal({
     handleRowsGeneration(promptOutputs);
   }, [promptOutputs, tokenUsage]);
 
+  useEffect(() => {
+    if (selectedProfile) {
+      handleGetOutputForDocs(selectedProfile);
+    }
+  }, [selectedProfile]);
+
   const moveSelectedDocToTop = () => {
     // Create a copy of the list of documents
     const docs = [...listOfDocs];
@@ -147,8 +162,16 @@ function OutputForDocModal({
     });
   };
 
-  const handleGetOutputForDocs = () => {
-    let profile = profileManagerId;
+  const getAdapterInfo = () => {
+    axiosPrivate
+      .get(`/api/v1/unstract/${sessionDetails.orgId}/adapter/?adapter_type=LLM`)
+      .then((res) => {
+        const adapterList = res.data;
+        setAdapterData(getLLMModelNamesForProfiles(llmProfiles, adapterList));
+      });
+  };
+
+  const handleGetOutputForDocs = (profile = profileManagerId) => {
     if (singlePassExtractMode) {
       profile = defaultLlmProfile;
     }
@@ -206,10 +229,14 @@ function OutputForDocModal({
       }
 
       const result = {
-        key: item,
+        key: item?.document_id,
         document: item?.document_name,
         token_count: (
-          
+          
         ),
         value: (
           <>
@@ -239,6 +266,14 @@ function OutputForDocModal({
     setRows(rowsData);
   };
 
+  const handleTabChange = (key) => {
+    if (key === "0") {
+      setSelectedProfile(profileManagerId);
+    } else {
+      setSelectedProfile(adapterData[key - 1]?.profile_id);
+    }
+  };
+
   return (
     
         
+
+ + Default} key={"0"}> + {adapterData?.map((adapter, index) => ( + {adapter?.llm_model}} + key={(index + 1)?.toString()} + > + ))} + {" "} + +
+ ), + spinning: isLoading, + }} />
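With indexing now asynchronous, `fetch_response` can return `{"status": "pending", "message": "Document is being indexed"}` instead of a final output, and the frontend's `pollForCompletion` helper (imported in `PromptCard.jsx` further below) is expected to retry until that status clears. A Python sketch of this client-side contract; the `fetch_response` callable and the timing values are illustrative assumptions:

```python
import time
from typing import Any, Callable


def poll_for_completion(
    fetch_response: Callable[[], dict[str, Any]],
    interval_s: float = 5.0,
    max_wait_s: float = 300.0,
) -> dict[str, Any]:
    """Retry fetch_response until the backend stops reporting 'pending'."""
    deadline = time.monotonic() + max_wait_s
    while True:
        response = fetch_response()
        if response.get("status") != "pending":
            return response  # final prompt output (or an error from the call)
        if time.monotonic() >= deadline:
            raise TimeoutError("document is still being indexed")
        time.sleep(interval_s)
```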
diff --git a/frontend/src/components/custom-tools/profile-info-bar/ProfileInfoBar.css b/frontend/src/components/custom-tools/profile-info-bar/ProfileInfoBar.css new file mode 100644 index 000000000..8346c26d7 --- /dev/null +++ b/frontend/src/components/custom-tools/profile-info-bar/ProfileInfoBar.css @@ -0,0 +1,3 @@ +.profile-info-bar { + margin-bottom: 10px; +} diff --git a/frontend/src/components/custom-tools/profile-info-bar/ProfileInfoBar.jsx b/frontend/src/components/custom-tools/profile-info-bar/ProfileInfoBar.jsx new file mode 100644 index 000000000..8f52a68a4 --- /dev/null +++ b/frontend/src/components/custom-tools/profile-info-bar/ProfileInfoBar.jsx @@ -0,0 +1,58 @@ +import { Col, Row, Tag } from "antd"; +import PropTypes from "prop-types"; +import "./ProfileInfoBar.css"; + +const ProfileInfoBar = ({ profiles, profileId }) => { + const profile = profiles?.find((p) => p?.profile_id === profileId); + + if (!profile) { + return

Profile not found

; + } + + return ( + + + + Profile Name: {profile?.profile_name} + + + + + Chunk Size: {profile?.chunk_size} + + + + + Vector Store: {profile?.vector_store} + + + + + Embedding Model: {profile?.embedding_model} + + + + + LLM: {profile?.llm} + + + + + X2Text: {profile?.x2text} + + + + + Reindex: {profile?.reindex ? "Yes" : "No"} + + + + ); +}; + +ProfileInfoBar.propTypes = { + profiles: PropTypes.array, + profileId: PropTypes.string, +}; + +export { ProfileInfoBar }; diff --git a/frontend/src/components/custom-tools/prompt-card/Header.jsx b/frontend/src/components/custom-tools/prompt-card/Header.jsx index 35ac9d5aa..954b5706d 100644 --- a/frontend/src/components/custom-tools/prompt-card/Header.jsx +++ b/frontend/src/components/custom-tools/prompt-card/Header.jsx @@ -3,10 +3,12 @@ import { DeleteOutlined, EditOutlined, LoadingOutlined, + PlayCircleFilled, PlayCircleOutlined, SyncOutlined, } from "@ant-design/icons"; -import { Button, Col, Row, Tag, Tooltip } from "antd"; +import { useState } from "react"; +import { Button, Checkbox, Col, Divider, Row, Tag, Tooltip } from "antd"; import PropTypes from "prop-types"; import { promptStudioUpdateStatus } from "../../../helpers/GetStaticData"; @@ -31,6 +33,7 @@ function Header({ enableEdit, expandCard, setExpandCard, + enabledProfiles, }) { const { selectedDoc, @@ -40,9 +43,21 @@ function Header({ indexDocs, } = useCustomToolStore(); - const handleRunBtnClick = () => { + const [isDisablePrompt, setIsDisablePrompt] = useState(promptDetails?.active); + + const handleRunBtnClick = (profileManager = null, coverAllDoc = true) => { setExpandCard(true); - handleRun(); + handleRun(profileManager, coverAllDoc, enabledProfiles, true); + }; + + const handleDisablePrompt = (event) => { + const check = event?.target?.checked; + setIsDisablePrompt(check); + handleChange(check, promptDetails?.prompt_id, "active", true, true).catch( + () => { + setIsDisablePrompt(!check); + } + ); }; return ( @@ -122,24 +137,51 @@ function Header({ {!singlePassExtractMode && ( - - - + <> + + + + + + + )} + + handleDelete(promptDetails?.prompt_id)} content="The prompt will be permanently deleted." @@ -150,9 +192,9 @@ function Header({ type="text" className="prompt-card-action-button" disabled={ - disableLlmOrDocChange.includes(promptDetails?.prompt_id) || + disableLlmOrDocChange?.includes(promptDetails?.prompt_id) || isSinglePassExtractLoading || - indexDocs.includes(selectedDoc?.document_id) + indexDocs?.includes(selectedDoc?.document_id) } > @@ -180,6 +222,7 @@ Header.propTypes = { enableEdit: PropTypes.func.isRequired, expandCard: PropTypes.bool.isRequired, setExpandCard: PropTypes.func.isRequired, + enabledProfiles: PropTypes.array.isRequired, }; export { Header }; diff --git a/frontend/src/components/custom-tools/prompt-card/OutputForIndex.jsx b/frontend/src/components/custom-tools/prompt-card/OutputForIndex.jsx new file mode 100644 index 000000000..1f29cc8a1 --- /dev/null +++ b/frontend/src/components/custom-tools/prompt-card/OutputForIndex.jsx @@ -0,0 +1,50 @@ +import PropTypes from "prop-types"; +import { Modal } from "antd"; +import "./PromptCard.css"; +import { uniqueId } from "lodash"; + +function OutputForIndex({ chunkData, setIsIndexOpen, isIndexOpen }) { + const handleClose = () => { + setIsIndexOpen(false); + }; + + const lines = chunkData?.split("\\n"); // Split text into lines and remove any empty lines + + const renderContent = (chunk) => { + if (!chunk) { + return

No chunks found

; + } + return ( + <> + {chunk?.map((line) => ( +
+ {line} +
+
+ ))} + + ); + }; + + return ( + +
{renderContent(lines)}
+
+ ); +} + +OutputForIndex.propTypes = { + chunkData: PropTypes.string, + isIndexOpen: PropTypes.bool.isRequired, + setIsIndexOpen: PropTypes.func.isRequired, +}; + +export { OutputForIndex }; diff --git a/frontend/src/components/custom-tools/prompt-card/PromptCard.css b/frontend/src/components/custom-tools/prompt-card/PromptCard.css index a57d12fd8..d190564f2 100644 --- a/frontend/src/components/custom-tools/prompt-card/PromptCard.css +++ b/frontend/src/components/custom-tools/prompt-card/PromptCard.css @@ -2,6 +2,7 @@ .prompt-card { border: 1px solid #d9d9d9; + border-radius: 0; } .prompt-card .ant-card-body { @@ -20,10 +21,6 @@ background-color: #eceff3; } -.prompt-card-rad { - border-radius: 8px 8px 0px 0px; -} - .prompt-card-head-info-icon { color: #575859; } @@ -61,8 +58,11 @@ background-color: #f5f7f9; } -.prompt-card-comp-layout-border { - border-radius: 0px 0px 10px 10px; +.prompt-card-llm-layout { + width: 100%; + padding: 8px 12px; + background-color: #f5f7f9; + row-gap: 2; } .prompt-card-actions-dropdowns { @@ -76,7 +76,10 @@ .prompt-card-result { padding-top: 12px; background-color: #fff8e6; - border-radius: 0px 0px 8px 8px; + display: flex; + justify-content: space-between; + align-items: center; + width: -webkit-fill-available; } .prompt-card-result .ant-typography { @@ -84,7 +87,8 @@ } .prompt-card-res { - white-space: pre-wrap; + min-width: 0; + flex-basis: 60%; } .prompt-card-select-type { @@ -116,3 +120,86 @@ .prompt-card-collapse .ant-collapse-content-box { padding: 0px !important; } + +.llm-info { + display: flex; + align-items: center; +} + +.prompt-card-llm-title { + margin: 0 0 0 10px !important; +} + +.prompt-card-llm-icon { + display: flex; + justify-content: center; +} + +.prompt-cost-item { + font-size: 12px; + margin-right: 10px; +} + +.prompt-info { + display: flex; + justify-content: space-between; +} + +.llm-info-container > * { + margin-left: 10px; +} + +.prompt-card-llm-container { + border-right: 1px solid #0000000f; + width: fill-available !important; +} + +.prompt-profile-run-expanded { + flex-direction: column; + align-items: flex-start; +} + +.prompt-card-llm { + flex: 1; + min-width: 250px; +} + +.collapsed-output { + max-height: 20px; /* Adjust height as necessary */ + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.expanded-output { + max-height: 250px; /* Adjust height as necessary */ + overflow-y: auto; +} + +.prompt-profile-run { + align-self: flex-end; + margin-left: 10px; +} + +.index-output-tab { + overflow: scroll; + height: 60vh; +} + +.header-delete-divider { + margin: auto 2px auto 10px; + border: 1px solid rgba(5, 5, 5, 0.1); + height: 20px; +} + +.ant-tag-checkable.checked { + background-color: #f6ffed !important; + border-color: #b7eb8f !important; + color: #52c41a !important; +} + +.ant-tag-checkable.unchecked { + background-color: #00000005 !important; + border-color: #00000026 !important; + color: #000 !important; +} diff --git a/frontend/src/components/custom-tools/prompt-card/PromptCard.jsx b/frontend/src/components/custom-tools/prompt-card/PromptCard.jsx index a6efe92b1..6a67d9939 100644 --- a/frontend/src/components/custom-tools/prompt-card/PromptCard.jsx +++ b/frontend/src/components/custom-tools/prompt-card/PromptCard.jsx @@ -4,6 +4,7 @@ import { useEffect, useState } from "react"; import { defaultTokenUsage, generateUUID, + pollForCompletion, } from "../../../helpers/GetStaticData"; import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate"; import { useExceptionHandler } from 
"../../../hooks/useExceptionHandler"; @@ -42,22 +43,19 @@ function PromptCard({ updatePlaceHolder, }) { const [enforceTypeList, setEnforceTypeList] = useState([]); - const [page, setPage] = useState(0); - const [isRunLoading, setIsRunLoading] = useState(false); + const [isRunLoading, setIsRunLoading] = useState({}); const [promptKey, setPromptKey] = useState(""); const [promptText, setPromptText] = useState(""); const [selectedLlmProfileId, setSelectedLlmProfileId] = useState(null); const [openEval, setOpenEval] = useState(false); - const [result, setResult] = useState({ - promptOutputId: null, - output: "", - }); - const [coverage, setCoverage] = useState(0); + const [result, setResult] = useState([]); + const [coverage, setCoverage] = useState({}); const [coverageTotal, setCoverageTotal] = useState(0); const [isCoverageLoading, setIsCoverageLoading] = useState(false); const [openOutputForDoc, setOpenOutputForDoc] = useState(false); const [progressMsg, setProgressMsg] = useState({}); const [docOutputs, setDocOutputs] = useState({}); + const [timers, setTimers] = useState({}); const { getDropdownItems, llmProfiles, @@ -113,17 +111,25 @@ function PromptCard({ }, [messages]); useEffect(() => { - setSelectedLlmProfileId(promptDetails?.profile_manager || null); + setSelectedLlmProfileId( + promptDetails?.profile_manager || llmProfiles[0]?.profile_id + ); }, [promptDetails]); useEffect(() => { resetInfoMsgs(); + handleGetOutput(); + handleGetCoverage(); if (isSinglePassExtractLoading) { return; } - - handleGetOutput(); - handleGetCoverage(); + if (selectedLlmProfileId !== promptDetails?.profile_id) { + handleChange( + selectedLlmProfileId, + promptDetails?.prompt_id, + "profile_manager" + ); + } }, [ selectedLlmProfileId, selectedDoc, @@ -154,20 +160,6 @@ function PromptCard({ updateCustomTool({ disableLlmOrDocChange: listOfIds }); }, [isCoverageLoading]); - useEffect(() => { - if (page < 1) { - return; - } - const llmProfile = llmProfiles[page - 1]; - if (llmProfile?.profile_id !== promptDetails?.profile_id) { - handleChange( - llmProfile?.profile_id, - promptDetails?.prompt_id, - "profile_manager" - ); - } - }, [page]); - useEffect(() => { if (isCoverageLoading && coverageTotal === listOfDocs?.length) { setIsCoverageLoading(false); @@ -180,42 +172,26 @@ function PromptCard({ }; useEffect(() => { - const isProfilePresent = llmProfiles.some( - (profile) => profile.profile_id === selectedLlmProfileId + const isProfilePresent = llmProfiles?.some( + (profile) => profile?.profile_id === selectedLlmProfileId ); // If selectedLlmProfileId is not present, set it to null if (!isProfilePresent) { setSelectedLlmProfileId(null); } - - const llmProfileId = promptDetails?.profile_manager; - if (!llmProfileId) { - setPage(0); - return; - } - const index = llmProfiles.findIndex( - (item) => item?.profile_id === llmProfileId - ); - setPage(index + 1); }, [llmProfiles]); - const handlePageLeft = () => { - if (page <= 1) { - return; - } - - const newPage = page - 1; - setPage(newPage); + // Function to update loading state for a specific document and profile + const handleIsRunLoading = (docId, profileId, isLoading) => { + setIsRunLoading((prevLoadingProfiles) => ({ + ...prevLoadingProfiles, + [`${docId}_${profileId}`]: isLoading, + })); }; - const handlePageRight = () => { - if (page >= llmProfiles?.length) { - return; - } - - const newPage = page + 1; - setPage(newPage); + const handleSelectDefaultLLM = (llmProfileId) => { + setSelectedLlmProfileId(llmProfileId); }; const handleTypeChange = (value) => { @@ 
-238,7 +214,12 @@ function PromptCard({ }; // Generate the result for the currently selected document - const handleRun = () => { + const handleRun = ( + profileManagerId, + coverAllDoc = true, + selectedLlmProfiles = [], + runAllLLM = false + ) => { try { setPostHogCustomEvent("ps_prompt_run", { info: "Click on 'Run Prompt' button (Multi Pass)", @@ -247,39 +228,60 @@ function PromptCard({ // If an error occurs while setting custom posthog event, ignore it and continue } - if (!promptDetails?.profile_manager?.length && !isSimplePromptStudio) { - setAlertDetails({ - type: "error", - content: "LLM Profile is not selected", - }); - return; - } + const validateInputs = ( + profileManagerId, + selectedLlmProfiles, + coverAllDoc + ) => { + if ( + !profileManagerId && + !promptDetails?.profile_manager?.length && + !(!coverAllDoc && selectedLlmProfiles?.length > 0) && + !isSimplePromptStudio + ) { + setAlertDetails({ + type: "error", + content: "LLM Profile is not selected", + }); + return true; + } - if (!selectedDoc) { - setAlertDetails({ - type: "error", - content: "Document not selected", - }); - return; - } + if (!selectedDoc) { + setAlertDetails({ + type: "error", + content: "Document not selected", + }); + return true; + } - if (!promptKey) { - setAlertDetails({ - type: "error", - content: "Prompt key cannot be empty", - }); - return; - } + if (!promptKey) { + setAlertDetails({ + type: "error", + content: "Prompt key cannot be empty", + }); + return true; + } - if (!promptText) { - setAlertDetails({ - type: "error", - content: "Prompt cannot be empty", - }); + if (!promptText) { + setAlertDetails({ + type: "error", + content: "Prompt cannot be empty", + }); + return true; + } + + return false; + }; + + if (validateInputs(profileManagerId, selectedLlmProfiles, coverAllDoc)) { return; } - setIsRunLoading(true); + handleIsRunLoading( + selectedDoc?.document_id, + profileManagerId || selectedLlmProfileId, + true + ); setIsCoverageLoading(true); setCoverage(0); setCoverageTotal(0); @@ -297,8 +299,9 @@ function PromptCard({ details?.summarize_llm_profile ) { // Summary needs to be indexed before running the prompt - setIsRunLoading(false); - handleStepsAfterRunCompletion(); + handleIsRunLoading(selectedDoc?.document_id, selectedLlmProfileId, false); + setCoverageTotal(1); + handleCoverage(selectedLlmProfileId); setAlertDetails({ type: "error", content: `Summary needs to be indexed before running the prompt - ${selectedDoc?.document_name}.`, @@ -307,39 +310,93 @@ function PromptCard({ } handleDocOutputs(docId, true, null); - handleRunApiRequest(docId) - .then((res) => { - const data = res?.data?.output; - const value = data[promptDetails?.prompt_key]; - if (value || value === 0) { - setCoverage((prev) => prev + 1); - } - handleDocOutputs(docId, false, value); - handleGetOutput(); - }) - .catch((err) => { - setIsRunLoading(false); - handleDocOutputs(docId, false, null); - setAlertDetails( - handleException(err, `Failed to generate output for ${docId}`) + if (runAllLLM) { + let selectedProfiles = llmProfiles; + if (!coverAllDoc && selectedLlmProfiles?.length > 0) { + selectedProfiles = llmProfiles.filter((profile) => + selectedLlmProfiles.includes(profile?.profile_id) ); - }) - .finally(() => { - if (isSimplePromptStudio) { + } + for (const profile of selectedProfiles) { + setIsCoverageLoading(true); + + handleIsRunLoading(selectedDoc?.document_id, profile?.profile_id, true); + handleRunApiRequest(docId, profile?.profile_id) + .then((res) => { + const data = res?.data?.output; + const value = 
data[promptDetails?.prompt_key]; + if (value || value === 0) { + setCoverage((prev) => prev + 1); + } + handleDocOutputs(docId, false, value); + handleGetOutput(profile?.profile_id); + updateDocCoverage( + coverage, + promptDetails?.prompt_id, + profile?.profile_id, + docId + ); + }) + .catch((err) => { + handleIsRunLoading( + selectedDoc?.document_id, + profile?.profile_id, + false + ); + handleDocOutputs(docId, false, null); + setAlertDetails( + handleException(err, `Failed to generate output for ${docId}`) + ); + }) + .finally(() => { + setIsCoverageLoading(false); + }); + runCoverageForAllDoc(coverAllDoc, profile.profile_id); + } + } else { + handleRunApiRequest(docId, profileManagerId) + .then((res) => { + const data = res?.data?.output; + const value = data[promptDetails?.prompt_key]; + if (value || value === 0) { + updateDocCoverage( + coverage, + promptDetails?.prompt_id, + profileManagerId, + docId + ); + } + handleDocOutputs(docId, false, value); + handleGetOutput(); + setCoverageTotal(1); + }) + .catch((err) => { + handleIsRunLoading( + selectedDoc?.document_id, + selectedLlmProfileId, + false + ); + handleDocOutputs(docId, false, null); + setAlertDetails( + handleException(err, `Failed to generate output for ${docId}`) + ); + }) + .finally(() => { setIsCoverageLoading(false); - } else { - handleStepsAfterRunCompletion(); - } - }); + handleIsRunLoading(selectedDoc?.document_id, profileManagerId, false); + }); + runCoverageForAllDoc(coverAllDoc, profileManagerId); + } }; - const handleStepsAfterRunCompletion = () => { - setCoverageTotal(1); - handleCoverage(); + const runCoverageForAllDoc = (coverAllDoc, profileManagerId) => { + if (coverAllDoc) { + handleCoverage(profileManagerId); + } }; // Get the coverage for all the documents except the one that's currently selected - const handleCoverage = () => { + const handleCoverage = (profileManagerId) => { const listOfDocsToProcess = [...listOfDocs].filter( (item) => item?.document_id !== selectedDoc?.document_id ); @@ -372,13 +429,19 @@ function PromptCard({ return; } + setIsCoverageLoading(true); handleDocOutputs(docId, true, null); - handleRunApiRequest(docId) + handleRunApiRequest(docId, profileManagerId) .then((res) => { const data = res?.data?.output; const outputValue = data[promptDetails?.prompt_key]; if (outputValue || outputValue === 0) { - setCoverage((prev) => prev + 1); + updateDocCoverage( + coverage, + promptDetails?.prompt_id, + profileManagerId, + docId + ); } handleDocOutputs(docId, false, outputValue); }) @@ -390,102 +453,164 @@ function PromptCard({ }) .finally(() => { totalCoverageValue++; + if (listOfDocsToProcess?.length >= totalCoverageValue) { + setIsCoverageLoading(false); + return; + } setCoverageTotal(totalCoverageValue); }); }); }; - const handleRunApiRequest = async (docId) => { + const updateDocCoverage = (coverage, promptId, profileManagerId, docId) => { + const key = `${promptId}_${profileManagerId}`; + const counts = { ...coverage }; + // If the key exists in the counts object, increment the count + if (counts[key]) { + if (!counts[key]?.docs_covered?.includes(docId)) { + counts[key]?.docs_covered?.push(docId); + } + } else { + // Otherwise, add the key to the counts object with an initial count of 1 + counts[key] = { + prompt_id: promptId, + profile_manager: profileManagerId, + docs_covered: [docId], + }; + } + setCoverage(counts); + }; + + const handleRunApiRequest = async (docId, profileManagerId) => { const promptId = promptDetails?.prompt_id; const runId = generateUUID(); + const maxWaitTime = 30 
* 1000; // 30 seconds + const pollingInterval = 5000; // 5 seconds + const tokenUsagepollingInterval = 5000; const body = { document_id: docId, id: promptId, }; - let intervalId; - let tokenUsageId; - let url = `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/fetch_response/${details?.tool_id}`; - if (!isSimplePromptStudio) { - body["run_id"] = runId; - // Update the token usage state with default token usage for a specific document ID - tokenUsageId = promptId + "__" + docId; - setTokenUsage(tokenUsageId, defaultTokenUsage); - - // Set up an interval to fetch token usage data at regular intervals - intervalId = setInterval( - () => getTokenUsage(runId, tokenUsageId), - 5000 // Fetch token usage data every 5000 milliseconds (5 seconds) - ); - } else { - body["sps_id"] = details?.tool_id; - url = promptRunApiSps; - } - - const requestOptions = { - method: "POST", - url, - headers: { - "X-CSRFToken": sessionDetails?.csrfToken, - "Content-Type": "application/json", - }, - data: body, - }; - - return axiosPrivate(requestOptions) - .then((res) => res) - .catch((err) => { - throw err; - }) - .finally(() => { - if (!isSimplePromptStudio) { - clearInterval(intervalId); - getTokenUsage(runId, tokenUsageId); + if (profileManagerId) { + body.profile_manager = profileManagerId; + let intervalId; + let tokenUsageId; + let url = `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/fetch_response/${details?.tool_id}`; + if (!isSimplePromptStudio) { + body["run_id"] = runId; + // Update the token usage state with default token usage for a specific document ID + tokenUsageId = promptId + "__" + docId + "__" + profileManagerId; + setTokenUsage(tokenUsageId, defaultTokenUsage); + + // Set up an interval to fetch token usage data at regular intervals + if ( + profileManagerId === selectedLlmProfileId && + docId === selectedDoc?.document_id + ) { + intervalId = setInterval( + () => getTokenUsage(runId, tokenUsageId), + tokenUsagepollingInterval // Fetch token usage data every 5000 milliseconds (5 seconds) + ); } - }); + setTimers((prev) => ({ + ...prev, + [tokenUsageId]: 0, + })); + } else { + body["sps_id"] = details?.tool_id; + url = promptRunApiSps; + } + const timerIntervalId = startTimer(tokenUsageId); + + const requestOptions = { + method: "POST", + url, + headers: { + "X-CSRFToken": sessionDetails?.csrfToken, + "Content-Type": "application/json", + }, + data: body, + }; + + const makeApiRequest = (requestOptions) => { + return axiosPrivate(requestOptions); + }; + const startTime = Date.now(); + return pollForCompletion( + startTime, + requestOptions, + maxWaitTime, + pollingInterval, + makeApiRequest + ) + .then((response) => { + return response; + }) + .catch((err) => { + throw err; + }) + .finally(() => { + if (!isSimplePromptStudio) { + clearInterval(intervalId); + getTokenUsage(runId, tokenUsageId); + stopTimer(tokenUsageId, timerIntervalId); + } + }); + } }; - const handleGetOutput = () => { - setIsRunLoading(true); - if ( - !selectedDoc || - (!singlePassExtractMode && !selectedLlmProfileId && !isSimplePromptStudio) - ) { - setResult({ - promptOutputId: null, - output: "", - }); - setIsRunLoading(false); + const handleGetOutput = (profileManager = undefined) => { + if (!selectedDoc) { + setResult([]); + return; + } + + if (!singlePassExtractMode && !selectedLlmProfileId) { + setResult([]); return; } + handleIsRunLoading( + selectedDoc?.document_id, + profileManager || selectedLlmProfileId, + true + ); + handleOutputApiRequest(true) .then((res) => { const data = res?.data; if (!data || 
data?.length === 0) { - setResult({ - promptOutputId: null, - output: "", - }); + setResult([]); return; } - const outputResult = data[0]; - setResult({ - promptOutputId: outputResult?.prompt_output_id, - output: outputResult?.output, - evalMetrics: getEvalMetrics( - promptDetails?.evaluate, - outputResult?.eval_metrics || [] - ), + const outputResults = data.map((outputResult) => { + return { + runId: outputResult?.run_id, + promptOutputId: outputResult?.prompt_output_id, + profileManager: outputResult?.profile_manager, + context: outputResult?.context, + output: outputResult?.output, + totalCost: outputResult?.token_usage?.cost_in_dollars, + evalMetrics: getEvalMetrics( + promptDetails?.evaluate, + outputResult?.eval_metrics || [] + ), + }; }); + setResult(outputResults); }) .catch((err) => { setAlertDetails(handleException(err, "Failed to generate the result")); }) .finally(() => { - setIsRunLoading(false); + handleIsRunLoading( + selectedDoc?.document_id, + profileManager || selectedLlmProfileId, + false + ); }); }; @@ -494,7 +619,7 @@ function PromptCard({ (singlePassExtractMode && !defaultLlmProfile) || (!singlePassExtractMode && !selectedLlmProfileId) ) { - setCoverage(0); + setCoverage({}); return; } @@ -510,6 +635,7 @@ function PromptCard({ const handleOutputApiRequest = async (isOutput) => { let url; + let profileManager = selectedLlmProfileId; if (isSimplePromptStudio) { url = promptOutputApiSps( details?.tool_id, @@ -517,16 +643,17 @@ function PromptCard({ null ); } else { - let profileManager = selectedLlmProfileId; if (singlePassExtractMode) { profileManager = defaultLlmProfile; } - url = `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/prompt-output/?tool_id=${details?.tool_id}&prompt_id=${promptDetails?.prompt_id}&profile_manager=${profileManager}&is_single_pass_extract=${singlePassExtractMode}`; + url = `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/prompt-output/?tool_id=${details?.tool_id}&prompt_id=${promptDetails?.prompt_id}&is_single_pass_extract=${singlePassExtractMode}`; } - if (isOutput) { url += `&document_manager=${selectedDoc?.document_id}`; } + if (singlePassExtractMode) { + url += `&profile_manager=${profileManager}`; + } const requestOptions = { method: "GET", @@ -542,14 +669,14 @@ function PromptCard({ if (singlePassExtractMode) { const tokenUsageId = `single_pass__${selectedDoc?.document_id}`; - const usage = data.find((item) => item?.run_id !== undefined); + const usage = data?.find((item) => item?.run_id !== undefined); if (!tokenUsage[tokenUsageId] && usage) { setTokenUsage(tokenUsageId, usage?.token_usage); } } else { - data.forEach((item) => { - const tokenUsageId = `${item?.prompt_id}__${item?.document_manager}`; + data?.forEach((item) => { + const tokenUsageId = `${item?.prompt_id}__${item?.document_manager}__${item?.profile_manager}`; if (tokenUsage[tokenUsageId] === undefined) { setTokenUsage(tokenUsageId, item?.token_usage); @@ -564,14 +691,35 @@ function PromptCard({ }; const handleGetCoverageData = (data) => { - const coverageValue = data.reduce((acc, item) => { - if (item?.output || item?.output === 0) { - return acc + 1; - } else { - return acc; - } - }, 0); - setCoverage(coverageValue); + data?.forEach((item) => { + updateDocCoverage( + coverage, + item?.prompt_id, + item?.profile_manager, + item?.document_manager + ); + }); + }; + + const startTimer = (profileId) => { + setTimers((prev) => ({ + ...prev, + [profileId]: (prev[profileId] || 0) + 1, + })); + return setInterval(() => { + setTimers((prev) => ({ + ...prev, + 
[profileId]: (prev[profileId] || 0) + 1, + })); + }, 1000); + }; + + const stopTimer = (profileId, intervalId) => { + clearInterval(intervalId); + setTimers((prev) => ({ + ...prev, + [profileId]: prev[profileId] || 0, + })); }; return ( @@ -589,8 +737,6 @@ function PromptCard({ progressMsg={progressMsg} handleRun={handleRun} handleChange={handleChange} - handlePageLeft={handlePageLeft} - handlePageRight={handlePageRight} handleTypeChange={handleTypeChange} handleDelete={handleDelete} updateStatus={updateStatus} @@ -599,7 +745,8 @@ function PromptCard({ setOpenEval={setOpenEval} setOpenOutputForDoc={setOpenOutputForDoc} selectedLlmProfileId={selectedLlmProfileId} - page={page} + handleSelectDefaultLLM={handleSelectDefaultLLM} + timers={timers} /> {EvalModal && !singlePassExtractMode && ( profile.profile_id) + ); + const [expandedProfiles, setExpandedProfiles] = useState([]); // New state for expanded profiles + const [isIndexOpen, setIsIndexOpen] = useState(false); + const privateAxios = useAxiosPrivate(); + const { sessionDetails } = useSessionStore(); + const { width: windowWidth } = useWindowDimensions(); + const handleException = useExceptionHandler(); + const { setAlertDetails } = useAlertStore(); + const componentWidth = windowWidth * 0.4; - useEffect(() => { - setExpandCard(true); - }, [isSinglePassExtractLoading]); + const divRef = useRef(null); const enableEdit = (event) => { event.stopPropagation(); @@ -73,7 +106,149 @@ function PromptCardItems({ setIsEditingTitle(true); setIsEditingPrompt(true); }; + const getModelOrAdapterId = (profile, adapters) => { + const result = { conf: {} }; + const keys = ["vector_store", "embedding_model", "llm", "x2text"]; + + keys.forEach((key) => { + const adapterName = profile[key]; + const adapter = adapters?.find( + (adapter) => adapter?.adapter_name === adapterName + ); + if (adapter) { + result.conf[key] = adapter?.model || adapter?.adapter_id?.split("|")[0]; + if (adapter?.adapter_type === "LLM") result.icon = adapter?.icon; + } + }); + return result; + }; + const getAdapterInfo = async () => { + privateAxios + .get(`/api/v1/unstract/${sessionDetails?.orgId}/adapter/`) + .then((res) => { + const adapterData = res?.data; + + // Update llmProfiles with additional fields + const updatedProfiles = llmProfiles?.map((profile) => { + return { ...getModelOrAdapterId(profile, adapterData), ...profile }; + }); + setLlmProfileDetails( + updatedProfiles + .map((profile) => ({ + ...profile, + isDefault: profile?.profile_id === selectedLlmProfileId, + isEnabled: enabledProfiles.includes(profile?.profile_id), + })) + .sort((a, b) => { + if (a?.isDefault) return -1; // Default profile comes first + if (b?.isDefault) return 1; + if (a?.isEnabled && !b?.isEnabled) return -1; // Enabled profiles come before disabled + if (!a?.isEnabled && b?.isEnabled) return 1; + return 0; + }) + ); + }) + .catch((err) => { + setAlertDetails(handleException(err)); + }); + }; + + const tooltipContent = (adapterConf) => ( +
+      <div>
+        {Object.entries(adapterConf)?.map(([key, value]) => (
+          <div key={key}>
+            {key}: {value}
+          </div>
+        ))}
+      </div>
+  );
+
+  const handleExpandClick = (profile) => {
+    const profileId = profile?.profile_id;
+    setExpandedProfiles((prevState) =>
+      prevState.includes(profileId)
+        ? prevState.filter((id) => id !== profileId)
+        : [...prevState, profileId]
+    );
+  };
+
+  const handleTagChange = (checked, profileId) => {
+    setEnabledProfiles((prevState) =>
+      checked
+        ? [...prevState, profileId]
+        : prevState.filter((id) => id !== profileId)
+    );
+  };
+
+  const getColSpan = () => (componentWidth < 1200 ? 24 : 6);
+
+  const renderSinglePassResult = () => {
+    const [firstResult] = result || [];
+    if (
+      promptDetails?.active &&
+      (firstResult?.output || firstResult?.output === 0)
+    ) {
+      return (
+        <>
+          {isSinglePassExtractLoading ? (
+            <Spin indicator={<SpinnerLoader size="small" />} />
+          ) : (
+            <div>
+              {displayPromptResult(firstResult.output, true)}
+            </div>
+          )}
+        </>
+      );
+    }
+    return <></>;
+  };
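The `handleTagChange` handler above is what the new `.ant-tag-checkable.checked` / `.ant-tag-checkable.unchecked` rules in PromptCard.css style. A minimal sketch of how the two could be wired together with antd's `Tag.CheckableTag` — the `ProfileToggle` wrapper and its label are illustrative, not part of this diff:

```jsx
import { Tag } from "antd";

const { CheckableTag } = Tag;

// Hypothetical usage: one checkable tag per LLM profile. `checked` drives
// both the antd state and the checked/unchecked CSS classes added in this PR.
function ProfileToggle({ profile, enabledProfiles, handleTagChange }) {
  const checked = enabledProfiles?.includes(profile?.profile_id);
  return (
    <CheckableTag
      checked={checked}
      className={checked ? "checked" : "unchecked"}
      onChange={(value) => handleTagChange(value, profile?.profile_id)}
    >
      {profile?.profile_name}
    </CheckableTag>
  );
}

export { ProfileToggle };
```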
+ + ); + } + return <>; + }; + + useEffect(() => { + setExpandCard(true); + }, [isSinglePassExtractLoading]); + + useEffect(() => { + if (singlePassExtractMode) { + setExpandedProfiles([]); + } + }, [singlePassExtractMode]); + + useEffect(() => { + getAdapterInfo(); + }, [llmProfiles, selectedLlmProfileId, enabledProfiles]); return (
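For context, `getAdapterInfo` above decorates each LLM profile with adapter details from `getModelOrAdapterId`, then sorts the default and enabled profiles to the top. One entry of `llmProfileDetails` might look like this — the IDs and model names are made up for illustration:

```js
// Illustrative only: shape of one enriched entry in llmProfileDetails.
const exampleProfile = {
  profile_id: "1f2e3d4c", // made-up ProfileManager id
  profile_name: "default-profile",
  conf: {
    // from getModelOrAdapterId: adapter.model, else the adapter_id prefix
    llm: "gpt-4",
    embedding_model: "text-embedding-3-small",
    vector_store: "qdrant",
    x2text: "llmwhisperer",
  },
  icon: "/icons/openai.png", // icon of the matched LLM adapter
  isDefault: true, // profile_id === selectedLlmProfileId
  isEnabled: true, // profile_id is present in enabledProfiles
};
```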
@@ -94,6 +269,7 @@ function PromptCardItems({
           enableEdit={enableEdit}
           expandCard={expandCard}
           setExpandCard={setExpandCard}
+          enabledProfiles={enabledProfiles}
         />
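`handleRunApiRequest` in PromptCard.jsx now defers to the `pollForCompletion` helper imported from `helpers/GetStaticData`, whose body is not part of this diff. A rough sketch of the retry loop it implies, assuming the backend reports a `"pending"` status while the document is still being indexed:

```js
// Sketch only; the real helper lives in helpers/GetStaticData and may differ.
const pollForCompletion = async (
  startTime,
  requestOptions,
  maxWaitTime,
  pollingInterval,
  makeApiRequest
) => {
  const response = await makeApiRequest(requestOptions);
  // Assumed contract: "pending" means indexing is still in progress.
  if (response?.data?.status !== "pending") {
    return response;
  }
  if (Date.now() - startTime >= maxWaitTime) {
    throw new Error("Timed out waiting for the document to finish indexing");
  }
  await new Promise((resolve) => setTimeout(resolve, pollingInterval));
  return pollForCompletion(
    startTime,
    requestOptions,
    maxWaitTime,
    pollingInterval,
    makeApiRequest
  );
};
```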
@@ -110,7 +286,7 @@ function PromptCardItems({
             text={promptText}
             setText={setPromptText}
             promptId={promptDetails?.prompt_id}
-            defaultText={promptDetails.prompt}
+            defaultText={promptDetails?.prompt}
             handleChange={handleChange}
             isTextarea={true}
             placeHolder={updatePlaceHolder}
@@ -120,7 +296,6 @@ function PromptCardItems({
       {!isSimplePromptStudio && (
         <>
-
       )}
@@ -150,22 +325,16 @@ function PromptCardItems({
       )}
-            Coverage: {coverage} of {listOfDocs?.length || 0}{" "}
-            docs
+            Coverage:{" "}
+            {coverage[
+              `${promptDetails?.prompt_id}_${selectedLlmProfileId}`
+            ]?.docs_covered?.length || 0}{" "}
+            of {listOfDocs?.length || 0} docs
-      {!singlePassExtractMode && (
-
-      )}
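The coverage readout above indexes into the reshaped `coverage` state, which this PR turns from a plain counter into a map keyed by prompt and profile. An illustrative shape, with made-up IDs, matching what `updateDocCoverage` builds:

```js
// Illustrative only: coverage after running one prompt against two profiles.
const coverage = {
  "prompt-101_profile-a": {
    prompt_id: "prompt-101",
    profile_manager: "profile-a",
    docs_covered: ["doc-1", "doc-2"], // documents that produced an output
  },
  "prompt-101_profile-b": {
    prompt_id: "prompt-101",
    profile_manager: "profile-b",
    docs_covered: ["doc-1"],
  },
};

// The card header count then comes from the same key format:
const promptDetails = { prompt_id: "prompt-101" }; // placeholder values
const selectedLlmProfileId = "profile-a";
const count =
  coverage[`${promptDetails?.prompt_id}_${selectedLlmProfileId}`]
    ?.docs_covered?.length || 0; // => 2
```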