[FIX] Changes to display chunk data properly (#821)
* Converted context from string to array

Signed-off-by: Deepak <[email protected]>

* v2 changes

Signed-off-by: Deepak <[email protected]>

* Removed unwanted comment

* Update backend/prompt_studio/prompt_studio_output_manager_v2/serializers.py

Co-authored-by: Chandrasekharan M <[email protected]>
Signed-off-by: Deepak K <[email protected]>

* Minor fix

---------

Signed-off-by: Deepak <[email protected]>
Signed-off-by: Deepak K <[email protected]>
Co-authored-by: Chandrasekharan M <[email protected]>
Co-authored-by: Gayathri <[email protected]>
Co-authored-by: Hari John Kuriakose <[email protected]>
4 people authored Nov 15, 2024
1 parent cf9c7d6 commit 2003289
Showing 7 changed files with 34 additions and 45 deletions.
@@ -152,7 +152,7 @@ def update_or_create_prompt_output(
         output=output,
         eval_metrics=eval_metrics,
         tool=tool,
-        context=context,
+        context=json.dumps(context),
         challenge_data=challenge_data,
     )
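On the write path, the context collection is now JSON-encoded before it is persisted, so the column holds a parseable array instead of a delimiter-joined string. A minimal sketch of what gets stored, with made-up chunk values:

import json

# Illustrative chunks; the helper stores them as a JSON array string.
context = ["Revenue grew 12% in Q3.", "Operating costs held flat."]
stored = json.dumps(context)
print(stored)  # ["Revenue grew 12% in Q3.", "Operating costs held flat."]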
@@ -1,3 +1,4 @@
+import json
 import logging
 
 from usage.helper import UsageHelper
@@ -25,4 +26,10 @@ def to_representation(self, instance):
             )
             token_usage = {}
         data["token_usage"] = token_usage
+        # Convert string to list
+        try:
+            data["context"] = json.loads(data["context"])
+        except json.JSONDecodeError:
+            # Convert the old value of data["context"] to a list
+            data["context"] = [data["context"]]
         return data
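The serializer has to read both new rows (JSON-encoded arrays) and rows written before this change (plain strings). A hedged sketch of that fallback, with the helper name and sample values invented for illustration:

import json

def context_as_list(raw: str) -> list:
    # Mirrors the serializer's fallback: try to decode a JSON array,
    # and wrap legacy plain-string values in a single-element list.
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return [raw]

print(context_as_list('["chunk one", "chunk two"]'))  # ['chunk one', 'chunk two']
print(context_as_list("legacy plain-text context"))   # ['legacy plain-text context']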
@@ -154,7 +154,7 @@ def update_or_create_prompt_output(
         output=output,
         eval_metrics=eval_metrics,
         tool=tool,
-        context=context,
+        context=json.dumps(context),
         challenge_data=challenge_data,
     )
@@ -1,3 +1,4 @@
+import json
 import logging
 
 from usage_v2.helper import UsageHelper
@@ -25,4 +26,10 @@ def to_representation(self, instance):
             )
             token_usage = {}
         data["token_usage"] = token_usage
+        # Convert string to list
+        try:
+            data["context"] = json.loads(data["context"])
+        except json.JSONDecodeError:
+            # Convert the old value of data["context"] to a list
+            data["context"] = [data["context"]]
         return data
@@ -16,19 +16,7 @@ function OutputForIndex({ chunkData, setIsIndexOpen, isIndexOpen }) {
   const activeRef = useRef(null);
 
   useEffect(() => {
-    if (!chunkData) {
-      setChunks([]);
-    }
-    // Split chunkData into chunks using \f\n delimiter
-    const tempChunks = chunkData?.split("\f\n");
-    // To remove " at the end
-    if (tempChunks?.length > 0) {
-      const lastChunk = tempChunks[tempChunks?.length - 1].trim();
-      if (lastChunk === '\\n"' || lastChunk === "") {
-        tempChunks.pop();
-      }
-    }
-    setChunks(tempChunks);
+    setChunks(chunkData || []);
   }, [chunkData]);
 
   // Debounced search handler
4 changes: 2 additions & 2 deletions prompt-service/src/unstract/prompt_service/helper.py
@@ -83,11 +83,11 @@ def plugin_loader(app: Flask) -> None:
     initialize_plugin_endpoints(app=app)
 
 
-def get_cleaned_context(context: str) -> str:
+def get_cleaned_context(context: set[str]) -> list[str]:
     clean_context_plugin: dict[str, Any] = plugins.get(PSKeys.CLEAN_CONTEXT, {})
     if clean_context_plugin:
        return clean_context_plugin["entrypoint_cls"].run(context=context)
-    return context
+    return list(context)
 
 
 def initialize_plugin_endpoints(app: Flask) -> None:
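get_cleaned_context now accepts the deduplicated set built by the retrieval path and always hands back a list. A minimal sketch of the no-plugin branch, with sample chunk values made up for illustration:

# Simplified stand-in for helper.get_cleaned_context when no
# clean-context plugin is registered (plugin lookup elided).
def get_cleaned_context(context: set[str]) -> list[str]:
    return list(context)

chunks = {"chunk A", "chunk B"}
print(get_cleaned_context(chunks))  # e.g. ['chunk B', 'chunk A']; set order is arbitrary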
43 changes: 15 additions & 28 deletions prompt-service/src/unstract/prompt_service/main.py
@@ -251,10 +251,10 @@ def prompt_processor() -> Any:
         raise api_error
 
     try:
-        context = ""
+        context: set[str] = set()
         if output[PSKeys.CHUNK_SIZE] == 0:
             # We can do this only for chunkless indexes
-            context: Optional[str] = index.query_index(
+            retrieved_context: Optional[str] = index.query_index(
                 embedding_instance_id=output[PSKeys.EMBEDDING],
                 vector_db_instance_id=output[PSKeys.VECTOR_DB],
                 doc_id=doc_id,
@@ -270,13 +270,13 @@
                 # inconsistent, and not reproducible easily,
                 # this is just a safety net.
                 time.sleep(2)
-                context: Optional[str] = index.query_index(
+                retrieved_context: Optional[str] = index.query_index(
                     embedding_instance_id=output[PSKeys.EMBEDDING],
                     vector_db_instance_id=output[PSKeys.VECTOR_DB],
                     doc_id=doc_id,
                     usage_kwargs=usage_kwargs,
                 )
-            if context is None:
+            if retrieved_context is None:
                 # TODO: Obtain user set name for vector DB
                 msg = NO_CONTEXT_ERROR
                 app.logger.error(
@@ -294,6 +294,7 @@
                     msg,
                 )
                 raise APIError(message=msg)
+            context.add(retrieved_context)
             # TODO: Use vectorDB name when available
             publish_log(
                 log_events_id,
@@ -323,7 +324,7 @@ def prompt_processor() -> Any:
                 tool_settings=tool_settings,
                 output=output,
                 llm=llm,
-                context=context,
+                context="\n".join(context),
                 prompt="promptx",
                 metadata=metadata,
             )
@@ -537,7 +538,7 @@ def prompt_processor() -> Any:
                 llm=llm,
                 challenge_llm=challenge_llm,
                 run_id=run_id,
-                context=context,
+                context="\n".join(context),
                 tool_settings=tool_settings,
                 output=output,
                 structured_output=structured_output,
@@ -593,7 +594,7 @@ def prompt_processor() -> Any:
             try:
                 evaluator = eval_plugin["entrypoint_cls"](
                     "",
-                    context,
+                    "\n".join(context),
                     "",
                     "",
                     output,
@@ -680,7 +681,7 @@ def run_retrieval( # type:ignore
     retrieval_type: str,
     metadata: dict[str, Any],
 ) -> tuple[str, str]:
-    context: str = ""
+    context: set[str] = set()
     prompt = output[PSKeys.PROMPTX]
     if retrieval_type == PSKeys.SUBQUESTION:
         subq_prompt: str = (
@@ -713,19 +714,11 @@ def run_retrieval( # type:ignore
             prompt=subq_prompt,
         )
         subquestion_list = subquestions.split(",")
-        raw_retrieved_context = ""
         for each_subq in subquestion_list:
             retrieved_context = _retrieve_context(
                 output, doc_id, vector_index, each_subq
             )
-            # Not adding the potential for pinecode serverless
-            # inconsistency issue owing to risk of infinte loop
-            # and inablity to diffrentiate genuine cases of
-            # empty context.
-            raw_retrieved_context = "\f\n".join(
-                [raw_retrieved_context, retrieved_context]
-            )
-            context = _remove_duplicate_nodes(raw_retrieved_context)
+            context.update(retrieved_context)
 
     if retrieval_type == PSKeys.SIMPLE:

@@ -746,21 +739,15 @@ def run_retrieval( # type:ignore
         tool_settings=tool_settings,
         output=output,
         llm=llm,
-        context=context,
+        context="\n".join(context),
         prompt="promptx",
         metadata=metadata,
     )
 
     return (answer, context)
 
 
-def _remove_duplicate_nodes(retrieved_context: str) -> str:
-    context_set: set[str] = set(retrieved_context.split("\f\n"))
-    fomatted_context = "\f\n".join(context_set)
-    return fomatted_context
-
-
-def _retrieve_context(output, doc_id, vector_index, answer) -> str:
+def _retrieve_context(output, doc_id, vector_index, answer) -> set[str]:
     retriever = vector_index.as_retriever(
         similarity_top_k=output[PSKeys.SIMILARITY_TOP_K],
         filters=MetadataFilters(
@@ -773,18 +760,18 @@ def _retrieve_context(output, doc_id, vector_index, answer) -> str:
         ),
     )
     nodes = retriever.retrieve(answer)
-    text = ""
+    context: set[str] = set()
     for node in nodes:
         # ToDo: May have to fine-tune this value for node score or keep it
         # configurable at the adapter level
         if node.score > 0:
-            text += node.get_content() + "\f\n"
+            context.add(node.get_content())
         else:
             app.logger.info(
                 "Node score is less than 0. "
                 f"Ignored: {node.node_id} with score {node.score}"
             )
-    return text
+    return context
 
 
 def log_exceptions(e: HTTPException):
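The retrieval helpers now collect chunks in a set[str], so duplicates across sub-question retrievals collapse automatically; the old _remove_duplicate_nodes round trip through "\f\n"-joined strings is gone, and callers join the set with "\n" only at the LLM boundary. A small sketch of the flow, with invented chunk values standing in for retriever results:

# Hypothetical retriever results for two sub-questions; note the overlap.
hits_per_subquestion = [
    ["chunk A", "chunk B"],
    ["chunk B", "chunk C"],
]

context: set[str] = set()
for hits in hits_per_subquestion:
    context.update(hits)  # set.update() deduplicates as it inserts

print(sorted(context))               # ['chunk A', 'chunk B', 'chunk C']
prompt_context = "\n".join(context)  # joined only when handed to the LLM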
