diff --git a/src/unstract/sdk/index.py b/src/unstract/sdk/index.py
index 005ddaae..c7f6e03c 100644
--- a/src/unstract/sdk/index.py
+++ b/src/unstract/sdk/index.py
@@ -1,5 +1,6 @@
 import json
-from typing import Any, Optional
+import logging
+from typing import Any, Callable, Optional
 
 from llama_index.core import Document
 from llama_index.core.node_parser import SimpleNodeParser
@@ -25,6 +26,8 @@ from unstract.sdk.vector_db import VectorDB
 from unstract.sdk.x2txt import X2Text
 
+logger = logging.getLogger(__name__)
+
 
 class Constants:
     TOP_K = 5
@@ -101,27 +104,6 @@ def query_index(
         finally:
             vector_db.close()
 
-    def _cleanup_text(self, full_text):
-        # Remove text which is not required
-        full_text_lines = full_text.split("\n")
-        new_context_lines = []
-        empty_line_count = 0
-        for line in full_text_lines:
-            if line.strip() == "":
-                empty_line_count += 1
-            else:
-                if empty_line_count >= 3:
-                    empty_line_count = 3
-                for i in range(empty_line_count):
-                    new_context_lines.append("")
-                empty_line_count = 0
-                new_context_lines.append(line.rstrip())
-        self.tool.stream_log(
-            f"Old context length: {len(full_text_lines)}, "
-            f"New context length: {len(new_context_lines)}"
-        )
-        return "\n".join(new_context_lines)
-
     def index(
         self,
         tool_id: str,
@@ -136,6 +118,7 @@
         output_file_path: Optional[str] = None,
         enable_highlight: bool = False,
         usage_kwargs: dict[Any, Any] = {},
+        process_text: Optional[Callable[[str], str]] = None,
     ) -> str:
         """Indexes an individual file using the passed arguments.
 
@@ -276,10 +259,17 @@
         except AdapterError as e:
             # Wrapping AdapterErrors with SdkError
             raise IndexingError(str(e)) from e
+        if process_text:
+            try:
+                result = process_text(extracted_text)
+                if isinstance(result, str):
+                    extracted_text = result
+            except Exception as e:
+                logger.error(f"Error occurred inside function 'process_text': {e}")
         full_text.append(
             {
                 "section": "full",
-                "text_contents": self._cleanup_text(extracted_text),
+                "text_contents": extracted_text,
             }
        )
 
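Since `_cleanup_text` is no longer applied automatically, a caller that still wants blank-line collapsing can supply it through the new `process_text` hook. A minimal sketch reproducing the removed behaviour follows; the function name `collapse_blank_lines` is hypothetical, and only the `process_text` parameter comes from the diff:

```python
def collapse_blank_lines(extracted_text: str) -> str:
    """Reproduce the removed _cleanup_text: cap runs of blank lines
    at three and strip trailing whitespace from every line."""
    cleaned: list[str] = []
    blanks = 0
    for line in extracted_text.split("\n"):
        if line.strip() == "":
            blanks += 1
        else:
            cleaned.extend([""] * min(blanks, 3))  # keep at most 3 blank lines
            blanks = 0
            cleaned.append(line.rstrip())
    return "\n".join(cleaned)  # trailing blank lines are dropped

# Usage (remaining index() arguments elided):
# index.index(..., process_text=collapse_blank_lines)
# Per the hunk above, a raised exception or a non-str return value is
# logged and the unprocessed extracted text is indexed instead.
```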
+ """ try: response: CompletionResponse = self._llm_instance.complete(prompt, **kwargs) + process_text_output = {} + if process_text: + try: + process_text_output = process_text(response, LLM.json_regex) + if not isinstance(process_text_output, dict): + process_text_output = {} + except Exception as e: + logger.error(f"Error occured inside function 'process_text': {e}") + process_text_output = {} match = LLM.json_regex.search(response.text) if match: response.text = match.group(0) - return {LLM.RESPONSE: response} + return {LLM.RESPONSE: response, **process_text_output} except Exception as e: raise parse_llm_err(e) from e