Skip to content

Commit

Permalink
Added support to pass functions to index and llm complete functions (#79)
Browse files Browse the repository at this point in the history

* Added support to pass functions to index and llm complete functions

Signed-off-by: Deepak <[email protected]>

* Added docstring

Signed-off-by: Deepak <[email protected]>

---------

Signed-off-by: Deepak <[email protected]>
Signed-off-by: Deepak K <[email protected]>
Co-authored-by: Gayathri <[email protected]>
  • Loading branch information
Deepak-Kesavan and gaya3-zipstack authored Aug 9, 2024
1 parent dea9bc3 commit 28434ca
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 26 deletions.
36 changes: 13 additions & 23 deletions src/unstract/sdk/index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from typing import Any, Optional
import logging
from typing import Any, Callable, Optional

from llama_index.core import Document
from llama_index.core.node_parser import SimpleNodeParser
Expand All @@ -25,6 +26,8 @@
from unstract.sdk.vector_db import VectorDB
from unstract.sdk.x2txt import X2Text

logger = logging.getLogger(__name__)


class Constants:
TOP_K = 5
Expand Down Expand Up @@ -101,27 +104,6 @@ def query_index(
finally:
vector_db.close()

def _cleanup_text(self, full_text):
# Remove text which is not required
full_text_lines = full_text.split("\n")
new_context_lines = []
empty_line_count = 0
for line in full_text_lines:
if line.strip() == "":
empty_line_count += 1
else:
if empty_line_count >= 3:
empty_line_count = 3
for i in range(empty_line_count):
new_context_lines.append("")
empty_line_count = 0
new_context_lines.append(line.rstrip())
self.tool.stream_log(
f"Old context length: {len(full_text_lines)}, "
f"New context length: {len(new_context_lines)}"
)
return "\n".join(new_context_lines)

def index(
self,
tool_id: str,
Expand All @@ -136,6 +118,7 @@ def index(
output_file_path: Optional[str] = None,
enable_highlight: bool = False,
usage_kwargs: dict[Any, Any] = {},
process_text: Optional[Callable[[str], str]] = None,
) -> str:
"""Indexes an individual file using the passed arguments.
Expand Down Expand Up @@ -276,10 +259,17 @@ def index(
except AdapterError as e:
# Wrapping AdapterErrors with SdkError
raise IndexingError(str(e)) from e
if process_text:
try:
result = process_text(extracted_text)
if isinstance(result, str):
extracted_text = result
except Exception as e:
logger.error(f"Error occured inside function 'process_text': {e}")
full_text.append(
{
"section": "full",
"text_contents": self._cleanup_text(extracted_text),
"text_contents": extracted_text,
}
)

Expand Down
32 changes: 29 additions & 3 deletions src/unstract/sdk/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import re
from typing import Any, Optional
from typing import Any, Callable, Optional

from llama_index.core.base.llms.types import CompletionResponseGen
from llama_index.core.llms import LLM as LlamaIndexLLM
Expand Down Expand Up @@ -69,15 +69,41 @@ def _initialise(self):
def complete(
    self,
    prompt: str,
    retries: int = 3,
    process_text: Optional[Callable[..., Any]] = None,
    **kwargs: Any,
) -> Optional[dict[str, Any]]:
    """Generates a completion response for the given prompt.

    Args:
        prompt (str): The input text prompt for generating the completion.
        retries (int, optional): Number of retry attempts. Defaults to 3.
            NOTE(review): not referenced in this method's visible body —
            confirm it is consumed elsewhere (e.g. by a retry decorator).
        process_text (Optional[Callable[..., Any]], optional): A callable
            invoked as ``process_text(response, LLM.json_regex)`` that
            post-processes the completion and returns a dict of extra
            fields to merge into the result. A non-dict return value or
            an exception from the callable is ignored. Defaults to None.
        **kwargs (Any): Additional arguments passed to the completion
            function.

    Returns:
        Optional[dict[str, Any]]: A dictionary holding the completion
        under ``LLM.RESPONSE``, merged with any processed output.

    Raises:
        Any: Errors from the underlying LLM call are translated by
        ``parse_llm_err`` before being re-raised.
    """
    try:
        response: CompletionResponse = self._llm_instance.complete(prompt, **kwargs)
        process_text_output: dict[str, Any] = {}
        if process_text:
            try:
                process_text_output = process_text(response, LLM.json_regex)
                if not isinstance(process_text_output, dict):
                    # Callable broke the dict contract; fall back to no
                    # extra fields rather than corrupting the result.
                    process_text_output = {}
            except Exception as e:
                # A faulty user callback must not break the completion.
                logger.error(
                    "Error occurred inside function 'process_text': %s", e
                )
                process_text_output = {}
        # Trim the response text down to the first JSON object, if any.
        match = LLM.json_regex.search(response.text)
        if match:
            response.text = match.group(0)
        return {LLM.RESPONSE: response, **process_text_output}
    except Exception as e:
        raise parse_llm_err(e) from e

Expand Down

0 comments on commit 28434ca

Please sign in to comment.