Skip to content

Commit

Permalink
FEAT: Add Support for Public Indexing and Prompt Run Functionality (#74)
Browse files Browse the repository at this point in the history
* Implemented SPS support for index method

* Changes to support public calls

* Changes to support public calls

* Fixed sonar issues

* Code optimization

* Reverted index.py file to it's previous state

* Fixed pre-commit issues

* Code quality improvements and minor bug fixes

* Fixed pre-commit issues

* Fixed pre-commit issues

* Added log message and optimized code

---------

Co-authored-by: Gayathri <[email protected]>
  • Loading branch information
tahierhussain and gaya3-zipstack authored Jul 25, 2024
1 parent 0c624b3 commit 623807c
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 22 deletions.
16 changes: 16 additions & 0 deletions src/unstract/sdk/adapters.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
from typing import Any, Optional

import requests

from unstract.sdk.constants import AdapterKeys, LogLevel, ToolEnv
from unstract.sdk.helper import SdkHelper
from unstract.sdk.platform import PlatformBase
from unstract.sdk.tool.base import BaseTool

Expand Down Expand Up @@ -88,6 +90,11 @@ def get_adapter_config(
) -> Optional[dict[str, Any]]:
"""Get adapter spec by the help of unstract DB tool.
This method first checks if the adapter_instance_id matches
any of the public adapter keys. If it matches, the configuration
is fetched from environment variables. Otherwise, it connects to the
platform service to retrieve the configuration.
Args:
adapter_instance_id (str): ID of the adapter instance
tool (AbstractTool): Instance of AbstractTool
Expand All @@ -97,6 +104,15 @@ def get_adapter_config(
Returns:
Any: engine
"""
# Check if the adapter ID matches any public adapter keys
if SdkHelper.is_public_adapter(
adapter_id=adapter_instance_id
):
adapter_metadata_config = tool.get_env_or_die(
adapter_instance_id
)
adapter_metadata = json.loads(adapter_metadata_config)
return adapter_metadata
platform_host = tool.get_env_or_die(ToolEnv.PLATFORM_HOST)
platform_port = tool.get_env_or_die(ToolEnv.PLATFORM_PORT)

Expand Down
7 changes: 7 additions & 0 deletions src/unstract/sdk/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,10 @@ class ToolSettingsKey:
RUN_ID = "run_id"
WORKFLOW_ID = "workflow_id"
EXECUTION_ID = "execution_id"


class PublicAdapterKeys:
PUBLIC_LLM_CONFIG = "PUBLIC_LLM_CONFIG"
PUBLIC_EMBEDDING_CONFIG = "PUBLIC_EMBEDDING_CONFIG"
PUBLIC_VECTOR_DB_CONFIG = "PUBLIC_VECTOR_DB_CONFIG"
PUBLIC_X2TEXT_CONFIG = "PUBLIC_X2TEXT_CONFIG"
18 changes: 12 additions & 6 deletions src/unstract/sdk/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from unstract.sdk.adapters import ToolAdapter
from unstract.sdk.constants import LogLevel, ToolEnv
from unstract.sdk.exceptions import EmbeddingError, SdkError
from unstract.sdk.helper import SdkHelper
from unstract.sdk.tool.base import BaseTool
from unstract.sdk.utils.callback_manager import CallbackManager

Expand Down Expand Up @@ -36,12 +37,16 @@ def _initialise(self):
self._embedding_instance = self._get_embedding()
self._length: int = self._get_embedding_length()
self._usage_kwargs["adapter_instance_id"] = self._adapter_instance_id
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
CallbackManager.set_callback(
platform_api_key=platform_api_key,
model=self._embedding_instance,
kwargs=self._usage_kwargs,
)

if not SdkHelper.is_public_adapter(
adapter_id=self._adapter_instance_id
):
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
CallbackManager.set_callback(
platform_api_key=platform_api_key,
model=self._embedding_instance,
kwargs=self._usage_kwargs,
)

def _get_embedding(self) -> BaseEmbedding:
"""Gets an instance of LlamaIndex's embedding object.
Expand All @@ -57,6 +62,7 @@ def _get_embedding(self) -> BaseEmbedding:
raise EmbeddingError(
"Adapter instance ID not set. " "Initialisation failed"
)

embedding_config_data = ToolAdapter.get_adapter_config(
self._tool, self._adapter_instance_id
)
Expand Down
32 changes: 32 additions & 0 deletions src/unstract/sdk/helper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import logging

from unstract.sdk.constants import PublicAdapterKeys

logger = logging.getLogger(__name__)

class SdkHelper:
def __init__(self) -> None:
pass
Expand All @@ -16,3 +22,29 @@ def get_platform_base_url(platform_host: str, platform_port: str) -> str:
if platform_host[-1] == "/":
return f"{platform_host[:-1]}:{platform_port}"
return f"{platform_host}:{platform_port}"

@staticmethod
def is_public_adapter(adapter_id: str) -> bool:
"""Check if the given adapter_id is one of the public adapter keys.
This method iterates over the attributes of the PublicAdapterKeys class
and checks if the provided adapter_id matches any of the attribute values.
Args:
adapter_id (str): The ID of the adapter to check.
Returns:
bool: True if the adapter_id matches any public adapter key,
False otherwise.
"""
try:
for attr in dir(PublicAdapterKeys):
if getattr(PublicAdapterKeys, attr) == adapter_id:
return True
return False
except Exception as e:
logger.warning(
f"Unable to determine if adapter_id: {adapter_id}"
f"is public or not: {str(e)}"
)
return False
19 changes: 13 additions & 6 deletions src/unstract/sdk/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from unstract.sdk.adapters import ToolAdapter
from unstract.sdk.constants import LogLevel, ToolEnv
from unstract.sdk.exceptions import LLMError, RateLimitError, SdkError
from unstract.sdk.helper import SdkHelper
from unstract.sdk.tool.base import BaseTool
from unstract.sdk.utils.callback_manager import CallbackManager

Expand Down Expand Up @@ -54,12 +55,16 @@ def _initialise(self):
if self._adapter_instance_id:
self._llm_instance = self._get_llm(self._adapter_instance_id)
self._usage_kwargs["adapter_instance_id"] = self._adapter_instance_id
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
CallbackManager.set_callback(
platform_api_key=platform_api_key,
model=self._llm_instance,
kwargs=self._usage_kwargs,
)

if not SdkHelper.is_public_adapter(
adapter_id=self._adapter_instance_id
):
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
CallbackManager.set_callback(
platform_api_key=platform_api_key,
model=self._llm_instance,
kwargs=self._usage_kwargs,
)

def complete(
self,
Expand Down Expand Up @@ -94,9 +99,11 @@ def _get_llm(self, adapter_instance_id: str) -> LlamaIndexLLM:
try:
if not self._adapter_instance_id:
raise LLMError("Adapter instance ID not set. " "Initialisation failed")

llm_config_data = ToolAdapter.get_adapter_config(
self._tool, self._adapter_instance_id
)

llm_adapter_id = llm_config_data.get(Common.ADAPTER_ID)
if llm_adapter_id not in self.llm_adapters:
raise SdkError(f"LLM adapter not supported : " f"{llm_adapter_id}")
Expand Down
16 changes: 12 additions & 4 deletions src/unstract/sdk/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
tool: BaseTool,
prompt_host: str,
prompt_port: str,
is_public_call: bool = False,
) -> None:
"""
Args:
Expand All @@ -28,13 +29,18 @@ def __init__(
"""
self.tool = tool
self.base_url = SdkHelper.get_platform_base_url(prompt_host, prompt_port)
self.bearer_token = tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
self.is_public_call = is_public_call
if not is_public_call:
self.bearer_token = tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)

def answer_prompt(
self, payload: dict[str, Any], params: Optional[dict[str, str]] = None
) -> dict[str, Any]:
url_path = "answer-prompt"
if self.is_public_call:
url_path = "answer-prompt-public"
return self._post_call(
url_path="answer-prompt",
url_path=url_path,
payload=payload,
params=params,
)
Expand Down Expand Up @@ -85,14 +91,16 @@ def _post_call(
"structure_output": "",
}
url: str = f"{self.base_url}/{url_path}"
headers: dict[str, str] = {"Authorization": f"Bearer {self.bearer_token}"}
headers: dict[str, str] = {}
if not self.is_public_call:
headers = {"Authorization": f"Bearer {self.bearer_token}"}
response: Response = Response()
try:
response = requests.post(
url=url,
json=payload,
headers=headers,
params=params,
headers=headers
)
response.raise_for_status()
result["status"] = "OK"
Expand Down
10 changes: 8 additions & 2 deletions src/unstract/sdk/vector_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from unstract.sdk.constants import LogLevel, ToolEnv
from unstract.sdk.embedding import Embedding
from unstract.sdk.exceptions import SdkError, VectorDBError
from unstract.sdk.helper import SdkHelper
from unstract.sdk.platform import PlatformHelper
from unstract.sdk.tool.base import BaseTool

Expand Down Expand Up @@ -83,9 +84,11 @@ def _get_vector_db(self) -> Union[BasePydanticVectorStore, VectorStore]:
raise VectorDBError(
"Adapter instance ID not set. Initialisation failed"
)

vector_db_config = ToolAdapter.get_adapter_config(
self._tool, self._adapter_instance_id
)

vector_db_adapter_id = vector_db_config.get(Common.ADAPTER_ID)
if vector_db_adapter_id not in self.vector_db_adapters:
raise SdkError(
Expand All @@ -96,10 +99,13 @@ def _get_vector_db(self) -> Union[BasePydanticVectorStore, VectorStore]:
Common.METADATA
][Common.ADAPTER]
vector_db_metadata = vector_db_config.get(Common.ADAPTER_METADATA)
org = self._get_org_id()
# Adding the collection prefix and embedding type
# to the metadata
vector_db_metadata[VectorDbConstants.VECTOR_DB_NAME] = org

if not SdkHelper.is_public_adapter(adapter_id=self._adapter_instance_id):
org = self._get_org_id()
vector_db_metadata[VectorDbConstants.VECTOR_DB_NAME] = org

vector_db_metadata[
VectorDbConstants.EMBEDDING_DIMENSION
] = self._embedding_dimension
Expand Down
21 changes: 17 additions & 4 deletions src/unstract/sdk/x2txt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
from unstract.sdk.adapters import ToolAdapter
from unstract.sdk.constants import LogLevel
from unstract.sdk.exceptions import X2TextError
from unstract.sdk.helper import SdkHelper
from unstract.sdk.tool.base import BaseTool


class X2Text(metaclass=ABCMeta):
def __init__(self, tool: BaseTool, adapter_instance_id: Optional[str] = None):
def __init__(
self,
tool: BaseTool,
adapter_instance_id: Optional[str] = None
):
self._tool = tool
self._x2text_adapters = adapters
self._adapter_instance_id = adapter_instance_id
Expand All @@ -32,9 +37,11 @@ def _get_x2text(self) -> X2TextAdapter:
raise X2TextError(
"Adapter instance ID not set. " "Initialisation failed"
)

x2text_config = ToolAdapter.get_adapter_config(
self._tool, self._adapter_instance_id
)

x2text_adapter_id = x2text_config.get(Common.ADAPTER_ID)
if x2text_adapter_id in self._x2text_adapters:
x2text_adapter = self._x2text_adapters[x2text_adapter_id][
Expand All @@ -48,9 +55,15 @@ def _get_x2text(self) -> X2TextAdapter:
x2text_metadata[
X2TextConstants.X2TEXT_PORT
] = self._tool.get_env_or_die(X2TextConstants.X2TEXT_PORT)
x2text_metadata[
X2TextConstants.PLATFORM_SERVICE_API_KEY
] = self._tool.get_env_or_die(X2TextConstants.PLATFORM_SERVICE_API_KEY)

if not SdkHelper.is_public_adapter(
adapter_id=self._adapter_instance_id
):
x2text_metadata[
X2TextConstants.PLATFORM_SERVICE_API_KEY
] = self._tool.get_env_or_die(
X2TextConstants.PLATFORM_SERVICE_API_KEY
)

self._x2text_instance = x2text_adapter(x2text_metadata)

Expand Down

0 comments on commit 623807c

Please sign in to comment.