Skip to content

Commit

Permalink
SDK backward compatibility changes (#55)
Browse files Browse the repository at this point in the history
* Backward comaptibility

* Make SDK backward compatible

* Fix for run_id population

* Address review comments

* roll sdk version

* Make param optiona;

* Update src/unstract/sdk/vector_db.py

Co-authored-by: Chandrasekharan M <[email protected]>
Signed-off-by: Gayathri <[email protected]>

---------

Signed-off-by: Gayathri <[email protected]>
Co-authored-by: Chandrasekharan M <[email protected]>
  • Loading branch information
1 parent 0a69375 commit e8979f5
Show file tree
Hide file tree
Showing 10 changed files with 427 additions and 190 deletions.
2 changes: 1 addition & 1 deletion src/unstract/sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.30.0"
__version__ = "0.31.0"


def get_sdk_version():
Expand Down
57 changes: 42 additions & 15 deletions src/unstract/sdk/embedding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional

from llama_index.core.base.embeddings.base import Embedding
from llama_index.core.embeddings import BaseEmbedding
Expand All @@ -21,22 +21,27 @@ class Embedding:
def __init__(
self,
tool: BaseTool,
adapter_instance_id: str,
adapter_instance_id: Optional[str] = None,
usage_kwargs: dict[Any, Any] = {},
):
self._tool = tool
self._adapter_instance_id = adapter_instance_id
self._embedding_instance: BaseEmbedding = self._get_embedding()
self._length: int = self._get_embedding_length()

self._usage_kwargs = usage_kwargs.copy()
self._usage_kwargs["adapter_instance_id"] = adapter_instance_id
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
CallbackManager.set_callback_manager(
platform_api_key=platform_api_key,
model=self._embedding_instance,
kwargs=self._usage_kwargs,
)
self._embedding_instance: BaseEmbedding = None
self._length: int = None
self._usage_kwargs = usage_kwargs
self._initialise()

def _initialise(self):
if self._adapter_instance_id:
self._embedding_instance = self._get_embedding()
self._length: int = self._get_embedding_length()
self._usage_kwargs["adapter_instance_id"] = self._adapter_instance_id
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
CallbackManager.set_callback(
platform_api_key=platform_api_key,
model=self._embedding_instance,
kwargs=self._usage_kwargs,
)

def _get_embedding(self) -> BaseEmbedding:
"""Gets an instance of LlamaIndex's embedding object.
Expand All @@ -48,6 +53,10 @@ def _get_embedding(self) -> BaseEmbedding:
BaseEmbedding: Embedding instance
"""
try:
if not self._adapter_instance_id:
raise EmbeddingError(
"Adapter instance ID not set. " "Initialisation failed"
)
embedding_config_data = ToolAdapter.get_adapter_config(
self._tool, self._adapter_instance_id
)
Expand Down Expand Up @@ -79,9 +88,27 @@ def _get_embedding_length(self) -> int:
embedding_dimension = len(embedding_list)
return embedding_dimension

@deprecated("Use the new class Embedding")
def get_class_name(self) -> str:
"""Gets the class name of the Llama Index Embedding.
Args:
NA
Returns:
Class name
"""
return self._embedding_instance.class_name()

@deprecated("Use Embedding instead of ToolEmbedding")
def get_embedding_length(self, embedding: BaseEmbedding) -> int:
return self._get_embedding_length(embedding)
return self._get_embedding_length()

@deprecated("Use Embedding instead of ToolEmbedding")
def get_embedding(self, adapter_instance_id: str) -> BaseEmbedding:
if not self._embedding_instance:
self._adapter_instance_id = adapter_instance_id
self._initialise()
return self._embedding_instance


# Legacy
Expand Down
4 changes: 4 additions & 0 deletions src/unstract/sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,9 @@ class X2TextError(SdkError):
DEFAULT_MESSAGE = "Error ocurred related to text extractor"


class OCRError(SdkError):
DEFAULT_MESSAGE = "Error ocurred related to OCR"


class RateLimitError(SdkError):
DEFAULT_MESSAGE = "Running into rate limit errors, please try again later"
Loading

0 comments on commit e8979f5

Please sign in to comment.