Skip to content

Commit

Permalink
Added support for single pass extraction (#104)
Browse files Browse the repository at this point in the history
* Added support for single-pass-extraction plugin

* Removed unwanted code

* Implemented single pass extraction in the FE

* Minor UI fixes

* Updated function name

* UI improvements in single pass extraction

* Fixed issue in which incorrect data is displayed

* Fixed eslint

---------

Co-authored-by: Neha <[email protected]>
Co-authored-by: Tahier Hussain <[email protected]>
  • Loading branch information
3 people authored Mar 15, 2024
1 parent 8d31eb7 commit 81e12b6
Show file tree
Hide file tree
Showing 13 changed files with 329 additions and 54 deletions.
2 changes: 2 additions & 0 deletions backend/prompt_studio/prompt_studio_core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class ToolStudioPromptKeys:
SUMMARIZE = "summarize"
SUMMARIZED_RESULT = "summarized_result"
DOCUMENT_ID = "document_id"
EXTRACT = "extract"
LLM_PROFILE_MANAGER = "llm_profile_manager"


class LogLevels:
Expand Down
7 changes: 6 additions & 1 deletion backend/prompt_studio/prompt_studio_core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class PromptNotValid(APIException):


class IndexingError(APIException):
status_code = 500
status_code = 400
default_detail = "Error while indexing file"


Expand Down Expand Up @@ -49,3 +49,8 @@ class OutputSaveError(APIException):
class ToolDeleteError(APIException):
status_code = 500
default_detail = "Failed to delete the error"


class NoPromptsFound(APIException):
    """Raised when a tool has no prompts configured to execute.

    Signals HTTP 404 to the caller (assumes DRF ``APIException``
    semantics — ``status_code``/``default_detail`` are returned to the
    client by the framework's exception handler).
    """

    status_code = 404
    default_detail = "No prompts available to process"
164 changes: 156 additions & 8 deletions backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
AnswerFetchError,
DefaultProfileError,
IndexingError,
NoPromptsFound,
PromptNotValid,
ToolNotValid,
)
Expand Down Expand Up @@ -171,18 +172,18 @@ def index_document(

@staticmethod
def prompt_responder(
id: str,
tool_id: str,
file_name: str,
org_id: str,
user_id: str,
document_id: str,
id: Optional[str] = None,
) -> Any:
"""Execute chain/single run of the prompts. Makes a call to prompt
service and returns the dict of response.
Args:
id (str): ID of the prompt
id (Optional[str]): ID of the prompt
tool_id (str): ID of tool created in prompt studio
file_name (str): Name of the file uploaded
org_id (str): Organization ID
Expand All @@ -195,16 +196,16 @@ def prompt_responder(
Returns:
Any: Dictionary containing the response from prompt-service
"""
logger.info(f"Invoking prompt responser for prompt {id}")
file_path = FileManagerHelper.handle_sub_directory_for_tenants(
org_id=org_id,
user_id=user_id,
tool_id=tool_id,
is_create=False,
)
file_path = str(Path(file_path) / file_name)
if id and tool_id:
logger.info("Executing in single prompt mode")
if id:
message: str = f"Executing single prompt {id} of tool {tool_id}"
logger.info(message)
try:
prompt_instance = PromptStudioHelper._fetch_prompt_from_id(id)
prompts: list[ToolStudioPrompt] = []
Expand All @@ -224,8 +225,7 @@ def prompt_responder(
stream_log.log(
stage=LogLevels.RUN,
level=LogLevels.INFO,
message=f"Executing single prompt \
{id} of tool {tool.tool_id}",
message=message,
),
)
if not prompt_instance:
Expand Down Expand Up @@ -271,6 +271,66 @@ def prompt_responder(
)
logger.info(f"Response fetched succesfully for prompt {id}")
return response
else:
message = f"Executing in single pass prompt mode for tool {tool_id}"
logger.info(message)
try:
prompts = PromptStudioHelper.fetch_prompt_from_tool(tool_id)

if not prompts:
raise NoPromptsFound()

tool = prompts[0].tool_id
response = PromptStudioHelper._fetch_single_pass_response(
file_path=file_path,
tool=tool,
prompts=prompts,
org_id=org_id,
document_id=document_id,
)

stream_log.publish(
tool_id,
stream_log.log(
stage=LogLevels.RUN,
level=LogLevels.INFO,
message="Executing single pass "
f"prompt for tool {tool_id}",
),
)
except NoPromptsFound:
logger.error("No prompts found for tool %s", tool_id)
raise
except Exception as e:
logger.error("Error while fetching prompt %s", e)
raise AnswerFetchError()
stream_log.publish(
tool_id,
stream_log.log(
stage=LogLevels.RUN,
level=LogLevels.INFO,
message=f"Prompt instances fetched for tool {tool_id}",
),
)
stream_log.publish(
tool_id,
stream_log.log(
stage=LogLevels.RUN,
level=LogLevels.INFO,
message=f"Invoking prompt service for tool {tool_id}",
),
)
stream_log.publish(
tool_id,
stream_log.log(
stage=LogLevels.RUN,
level=LogLevels.INFO,
message=f"Response fetched sucessfully \
from platform for tool {tool_id}",
),
)
logger.info("Response fetched succesfully for tool %s", tool_id)
return response

@staticmethod
def _fetch_response(
Expand Down Expand Up @@ -449,7 +509,8 @@ def dynamic_indexer(
extract_file_path = os.path.join(
directory, "extract", os.path.splitext(filename)[0] + ".txt"
)

else:
profile_manager.chunk_size = 0
try:
doc_id: str = tool_index.index_file(
tool_id=tool_id,
Expand All @@ -473,3 +534,90 @@ def dynamic_indexer(
return doc_id
except SdkError as e:
raise IndexingError(str(e))

@staticmethod
def _fetch_single_pass_response(
    tool: CustomTool,
    file_path: str,
    prompts: list[ToolStudioPrompt],
    org_id: str,
    document_id: str,
) -> Any:
    """Run every prompt of a tool through the prompt service in one pass.

    Indexes the document with the tool's default profile (chunking
    disabled so the full context is used), builds one payload containing
    all prompts plus the profile settings, and calls the prompt
    service's single-pass extraction endpoint.

    Args:
        tool (CustomTool): Tool whose prompts are executed.
        file_path (str): Path of the uploaded document.
        prompts (list[ToolStudioPrompt]): Prompts included in the payload.
        org_id (str): Organization ID.
        document_id (str): ID of the document being processed.

    Returns:
        Any: Parsed ``structure_output`` dict from the prompt service.

    Raises:
        DefaultProfileError: If no default profile is resolved.
        AnswerFetchError: If the prompt service reports an error.
    """
    tool_id: str = str(tool.tool_id)

    default_profile: ProfileManager = ProfileManager.objects.get(
        prompt_studio_tool=tool, is_default=True
    )
    # Guard BEFORE touching the profile (the original mutated
    # chunk_size first, making the check unreachable in the failure
    # case). NOTE(review): objects.get raises DoesNotExist rather than
    # returning None — this check is kept as a defensive guard.
    if not default_profile:
        raise DefaultProfileError()
    default_profile.chunk_size = 0  # To retrieve full context

    grammar: list[dict[str, Any]] = []
    # NOTE(review): "prompt_grammer" spelling comes from the model field.
    prompt_grammar = tool.prompt_grammer
    if prompt_grammar:
        grammar = [
            {TSPKeys.WORD: word, TSPKeys.SYNONYMS: synonyms}
            for word, synonyms in prompt_grammar.items()
        ]

    PromptStudioHelper.dynamic_indexer(
        profile_manager=default_profile,
        file_path=file_path,
        tool_id=tool_id,
        org_id=org_id,
        is_summary=tool.summarize_as_source,
        document_id=document_id,
    )

    # Profile settings forwarded to the prompt service; chunk_size is
    # the zero set above, i.e. chunking disabled.
    llm_profile_manager: dict[str, Any] = {
        TSPKeys.PREAMBLE: tool.preamble,
        TSPKeys.POSTAMBLE: tool.postamble,
        TSPKeys.GRAMMAR: grammar,
        TSPKeys.LLM: str(default_profile.llm.id),
        TSPKeys.X2TEXT_ADAPTER: str(default_profile.x2text.id),
        TSPKeys.VECTOR_DB: str(default_profile.vector_store.id),
        TSPKeys.EMBEDDING: str(default_profile.embedding_model.id),
        TSPKeys.CHUNK_SIZE: default_profile.chunk_size,
        TSPKeys.CHUNK_OVERLAP: default_profile.chunk_overlap,
    }

    outputs: list[dict[str, Any]] = [
        {
            TSPKeys.PROMPT: prompt.prompt,
            TSPKeys.ACTIVE: prompt.active,
            TSPKeys.TYPE: prompt.enforce_type,
            TSPKeys.NAME: prompt.prompt_key,
        }
        for prompt in prompts
    ]

    # When summarize-as-source is on, hash the summarized extract
    # instead of the original upload.
    if tool.summarize_as_source:
        path = Path(file_path)
        file_path = str(
            path.parent / TSPKeys.SUMMARIZE / (path.stem + ".txt")
        )
    file_hash = ToolUtils.get_hash_from_file(file_path=file_path)

    payload = {
        TSPKeys.LLM_PROFILE_MANAGER: llm_profile_manager,
        TSPKeys.OUTPUTS: outputs,
        TSPKeys.TOOL_ID: tool_id,
        TSPKeys.FILE_HASH: file_hash,
    }

    util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id)
    responder = PromptTool(
        tool=util,
        prompt_host=settings.PROMPT_HOST,
        prompt_port=settings.PROMPT_PORT,
    )

    answer = responder.single_pass_extraction(payload)
    # TODO: Make use of dataclasses
    if answer["status"] == "ERROR":
        raise AnswerFetchError()
    return json.loads(answer["structure_output"])
8 changes: 8 additions & 0 deletions backend/prompt_studio/prompt_studio_core/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
prompt_studio_adapter_choices = PromptStudioCoreView.as_view(
{"get": "get_adapter_choices"}
)
prompt_studio_single_pass_extraction = PromptStudioCoreView.as_view(
{"post": "single_pass_extraction"}
)

urlpatterns = format_suffix_patterns(
[
Expand Down Expand Up @@ -64,5 +67,10 @@
prompt_studio_adapter_choices,
name="prompt-studio-adapter-choices",
),
path(
"prompt-studio/single-pass-extraction",
prompt_studio_single_pass_extraction,
name="prompt-studio-single-pass-extraction",
),
]
)
32 changes: 32 additions & 0 deletions backend/prompt_studio/prompt_studio_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,35 @@ def fetch_response(self, request: HttpRequest) -> Response:
document_id=document_id,
)
return Response(response, status=status.HTTP_200_OK)

@action(detail=True, methods=["post"])
def single_pass_extraction(self, request: HttpRequest) -> Response:
    """API entry point to run all prompts of a tool in a single pass.

    Reads ``tool_id`` and ``document_id`` from the request body,
    resolves the document's file name via ``DocumentManager``, and
    delegates the combined extraction to
    ``PromptStudioHelper.prompt_responder``.

    Args:
        request (HttpRequest): POST body carrying ``tool_id`` and
            ``document_id``.

    Raises:
        FilenameMissingError: If the resolved document name is empty
            or the literal "undefined" sentinel.

    Returns:
        Response: HTTP 200 with the prompt-service response dict.
    """
    # TODO: Handle fetch_response and single_pass_
    # extraction using common function
    tool_id: str = request.data.get(ToolStudioPromptKeys.TOOL_ID)
    document_id: str = request.data.get(ToolStudioPromptKeys.DOCUMENT_ID)
    # NOTE(review): an unknown document_id makes objects.get raise
    # DoesNotExist, which is not handled here — confirm whether a 404
    # response is preferred over the default error.
    document: DocumentManager = DocumentManager.objects.get(pk=document_id)
    file_name: str = document.document_name

    if not file_name or file_name == ToolStudioPromptKeys.UNDEFINED:
        logger.error("Mandatory field file_name is missing")
        raise FilenameMissingError()
    response: dict[str, Any] = PromptStudioHelper.prompt_responder(
        tool_id=tool_id,
        file_name=file_name,
        org_id=request.org_id,
        user_id=request.user.user_id,
        document_id=document_id,
    )
    return Response(response, status=status.HTTP_200_OK)
Loading

0 comments on commit 81e12b6

Please sign in to comment.