Skip to content

Commit

Permalink
[FIX] Adding data type support for variable replacement (#619)
Browse files Browse the repository at this point in the history
Adding data type support  for variable replacement
  • Loading branch information
harini-venkataraman authored Aug 23, 2024
1 parent beeba6a commit 9d430a7
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 6 deletions.
28 changes: 26 additions & 2 deletions prompt-service/src/unstract/prompt_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,39 @@ def prompt_processor() -> Any:
)
try:
variable_map = output[PSKeys.VARIABLE_MAP]
VariableExtractor.execute_variable_replacement(
promptx = VariableExtractor.execute_variable_replacement(
prompt=promptx, variable_map=variable_map
)
app.logger.info(f"[{tool_id}] Prompt after variable replacement: {promptx}")
_publish_log(
log_events_id,
{
"tool_id": tool_id,
"prompt_key": prompt_name,
"doc_name": doc_name,
},
LogLevel.DEBUG,
RunLevel.RUN,
f"Prompt after variable replacement:{promptx} ",
)
except KeyError:
# Executed incase of structured tool and
# APIs where we do not set the variable map
VariableExtractor.execute_variable_replacement(
promptx = VariableExtractor.execute_variable_replacement(
prompt=promptx, variable_map=structured_output
)
app.logger.info(f"[{tool_id}] Prompt after variable replacement: {promptx}")
_publish_log(
log_events_id,
{
"tool_id": tool_id,
"prompt_key": prompt_name,
"doc_name": doc_name,
},
LogLevel.DEBUG,
RunLevel.RUN,
f"Prompt after variable replacement:{promptx} ",
)
except APIError as api_error:
raise api_error

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import re
from typing import Any
Expand Down Expand Up @@ -44,10 +45,21 @@ def check_static_variable_run_status(
return output

@staticmethod
def replace_generic_string_value(prompt: str, variable: str, value: str) -> str:
replaced_prompt = prompt.replace(variable, value)
def replace_generic_string_value(prompt: str, variable: str, value: Any) -> str:
formatted_value: str = value
if not isinstance(value, str):
formatted_value = VariableService.handle_json_and_str_types(value)
replaced_prompt = prompt.replace(variable, formatted_value)
return replaced_prompt

@staticmethod
def handle_json_and_str_types(value: Any) -> str:
try:
formatted_value = json.dumps(value)
except ValueError:
formatted_value = str(value)
return formatted_value

@staticmethod
def identify_variable_type(variable: str) -> VariableType:
variable_type: VariableType
Expand All @@ -69,12 +81,17 @@ def replace_dynamic_variable(
)
if not output_value:
return prompt
api_response = VariableService.fetch_dynamic_variable_value(
api_response: Any = VariableService.fetch_dynamic_variable_value(
url=url, data=output_value
)
formatted_api_response: str = VariableService.handle_json_and_str_types(
api_response
)
static_variable_marker_string = "".join(["{{", variable, "}}"])
replaced_prompt: str = VariableService.replace_generic_string_value(
prompt=prompt, variable=static_variable_marker_string, value=api_response
prompt=prompt,
variable=static_variable_marker_string,
value=formatted_api_response,
)
return replaced_prompt

Expand Down

0 comments on commit 9d430a7

Please sign in to comment.