Skip to content

Commit

Permalink
[Feature] SDK llama index 0.10.x migration (#19)
Browse files Browse the repository at this point in the history
* Llama index migration changes

* Service context deprecation

* Remove service_context file

* Fix for token usage

* Add comments

* Change signature

* Remove Optional/None type

* Avoid nesting

* Update llama-index version

---------

Signed-off-by: Gayathri <[email protected]>
  • Loading branch information
gaya3-zipstack authored Apr 12, 2024
1 parent 4fa116b commit f2c9bf1
Show file tree
Hide file tree
Showing 13 changed files with 1,435 additions and 1,369 deletions.
2,380 changes: 1,206 additions & 1,174 deletions pdm.lock

Large diffs are not rendered by default.

18 changes: 16 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ dependencies = [
"python-magic~=0.4.27",
"python-dotenv==1.0.0",
# LLM Triad
"unstract-adapters~=0.8.0",
"llama-index==0.9.28",
"unstract-adapters~=0.9.0",
"llama-index==0.10.28",
"tiktoken~=0.4.0",
"transformers==4.37.0",
]
Expand Down Expand Up @@ -52,10 +52,24 @@ lint = [
"yamllint>=1.35.1",
]

[tool.isort]
line_length = 80
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
profile = "black"

[tool.pdm.build]
includes = ["src"]
package-dir = "src"

[tool.pdm.version]
source = "file"
path = "src/unstract/sdk/__init__.py"

# Adding the following override to resolve dependency version
# for environs. Otherwise, it stays stuck while resolving pins
[tool.pdm.resolution.overrides]
grpcio = ">=1.62.1"
2 changes: 1 addition & 1 deletion src/unstract/sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.18.2"
__version__ = "0.19.0"


def get_sdk_version():
Expand Down
3 changes: 1 addition & 2 deletions src/unstract/sdk/audit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import requests
from llama_index.callbacks import TokenCountingHandler
from llama_index.callbacks.schema import CBEventType
from llama_index.core.callbacks import CBEventType, TokenCountingHandler

from unstract.sdk.constants import LogLevel, ToolEnv
from unstract.sdk.helper import SdkHelper
Expand Down
60 changes: 31 additions & 29 deletions src/unstract/sdk/embedding.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional

from llama_index.embeddings.base import BaseEmbedding
from llama_index.core.embeddings import BaseEmbedding
from unstract.adapters.constants import Common
from unstract.adapters.embedding import adapters

from unstract.sdk.adapters import ToolAdapter
from unstract.sdk.constants import LogLevel, ToolSettingsKey
from unstract.sdk.exceptions import SdkError
from unstract.sdk.tool.base import BaseTool


Expand All @@ -23,41 +24,42 @@ def __init__(self, tool: BaseTool, tool_settings: dict[str, str] = {}):

def get_embedding(
self, adapter_instance_id: Optional[str] = None
) -> Optional[BaseEmbedding]:
) -> BaseEmbedding:
adapter_instance_id = (
adapter_instance_id
if adapter_instance_id
else self.embedding_adapter_instance_id
)
if adapter_instance_id is not None:
try:
embedding_config_data = ToolAdapter.get_adapter_config(
self.tool, adapter_instance_id
if not adapter_instance_id:
raise SdkError(
f"Adapter_instance_id does not have "
f"a valid value: {adapter_instance_id}"
)
try:
embedding_config_data = ToolAdapter.get_adapter_config(
self.tool, adapter_instance_id
)
embedding_adapter_id = embedding_config_data.get(Common.ADAPTER_ID)
self.embedding_adapter_id = embedding_adapter_id
if embedding_adapter_id in self.embedding_adapters:
embedding_adapter = self.embedding_adapters[
embedding_adapter_id
][Common.METADATA][Common.ADAPTER]
embedding_metadata = embedding_config_data.get(
Common.ADAPTER_METADATA
)
embedding_adapter_id = embedding_config_data.get(
Common.ADAPTER_ID
embedding_adapter_class = embedding_adapter(embedding_metadata)
return embedding_adapter_class.get_embedding_instance()
else:
raise SdkError(
f"Embedding adapter not supported : "
f"{embedding_adapter_id}"
)
self.embedding_adapter_id = embedding_adapter_id
if embedding_adapter_id in self.embedding_adapters:
embedding_adapter = self.embedding_adapters[
embedding_adapter_id
][Common.METADATA][Common.ADAPTER]
embedding_metadata = embedding_config_data.get(
Common.ADAPTER_METADATA
)
embedding_adapter_class = embedding_adapter(
embedding_metadata
)
return embedding_adapter_class.get_embedding_instance()
else:
return None
except Exception as e:
self.tool.stream_log(
log=f"Error getting embedding: {e}", level=LogLevel.ERROR
)
return None
else:
return None
except Exception as e:
self.tool.stream_log(
log=f"Error getting embedding: {e}", level=LogLevel.ERROR
)
raise SdkError(f"Error getting embedding instance: {e}")

def get_embedding_length(self, embedding: BaseEmbedding) -> int:
embedding_list = embedding._get_text_embedding(self.__TEST_SNIPPET)
Expand Down
22 changes: 14 additions & 8 deletions src/unstract/sdk/index.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Optional

from llama_index import Document, StorageContext, VectorStoreIndex
from llama_index.node_parser import SimpleNodeParser
from llama_index.vector_stores import (
from llama_index.core import Document
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.storage import StorageContext
from llama_index.core.vector_stores import (
FilterOperator,
MetadataFilter,
MetadataFilters,
Expand All @@ -17,7 +19,9 @@
from unstract.sdk.exceptions import IndexingError, SdkError
from unstract.sdk.tool.base import BaseTool
from unstract.sdk.utils import ToolUtils
from unstract.sdk.utils.service_context import ServiceContext
from unstract.sdk.utils.callback_manager import (
CallbackManager as UNCallbackManager,
)
from unstract.sdk.vector_db import ToolVectorDB
from unstract.sdk.x2txt import X2Text

Expand Down Expand Up @@ -281,12 +285,12 @@ def index_file(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)

service_context = ServiceContext.get_service_context(
# Set callback_manager to collect Usage stats
callback_manager = UNCallbackManager.set_callback_manager(
platform_api_key=self.tool.get_env_or_die(
ToolEnv.PLATFORM_API_KEY
),
embed_model=embedding_li,
node_parser=parser,
embedding=embedding_li,
)

self.tool.stream_log("Adding nodes to vector db...")
Expand All @@ -295,7 +299,9 @@ def index_file(
documents,
storage_context=storage_context,
show_progress=True,
service_context=service_context,
embed_model=embedding_li,
node_parser=parser,
callback_manager=callback_manager,
)
except Exception as e:
self.tool.stream_log(
Expand Down
33 changes: 18 additions & 15 deletions src/unstract/sdk/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@
import time
from typing import Any, Optional

from llama_index.llms import LLM
from llama_index.llms.base import CompletionResponse
from llama_index.core.llms import LLM, CompletionResponse
from unstract.adapters.constants import Common
from unstract.adapters.llm import adapters
from unstract.adapters.llm.llm_adapter import LLMAdapter

from unstract.sdk.adapters import ToolAdapter
from unstract.sdk.constants import LogLevel, ToolSettingsKey
from unstract.sdk.exceptions import SdkError
from unstract.sdk.tool.base import BaseTool
from unstract.sdk.utils.service_context import ServiceContext
from unstract.sdk.utils.callback_manager import (
CallbackManager as UNCallbackManager,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,7 +55,8 @@ def run_completion(
retries: int = 3,
**kwargs: Any,
) -> Optional[dict[str, Any]]:
ServiceContext.get_service_context(
# Setup callback manager to collect Usage stats
UNCallbackManager.set_callback_manager(
platform_api_key=platform_api_key, llm=llm
)
for i in range(retries):
Expand Down Expand Up @@ -87,13 +90,11 @@ def run_completion(
time.sleep(5)
return None

def get_llm(
self, adapter_instance_id: Optional[str] = None
) -> Optional[LLM]:
def get_llm(self, adapter_instance_id: Optional[str] = None) -> LLM:
"""Returns the LLM object for the tool.
Returns:
Optional[LLM]: The LLM object for the tool.
LLM: The LLM object for the tool.
(llama_index.llms.base.LLM)
"""
adapter_instance_id = (
Expand All @@ -114,20 +115,22 @@ def get_llm(
][Common.ADAPTER]
llm_metadata = llm_config_data.get(Common.ADAPTER_METADATA)
llm_adapter_class: LLMAdapter = llm_adapter(llm_metadata)
llm_instance: Optional[
LLM
] = llm_adapter_class.get_llm_instance()
llm_instance: LLM = llm_adapter_class.get_llm_instance()
return llm_instance
else:
return None
raise SdkError(
f"LLM adapter not supported : " f"{llm_adapter_id}"
)
except Exception as e:
self.tool.stream_log(
log=f"Unable to get llm instance: {e}", level=LogLevel.ERROR
)
return None
raise SdkError(f"Error getting llm instance: {e}")
else:
logger.error("The adapter_instance_id parameter is None")
return None
raise SdkError(
f"Adapter_instance_id does not have "
f"a valid value: {adapter_instance_id}"
)

def get_max_tokens(self, reserved_for_output: int = 0) -> int:
"""Returns the maximum number of tokens that can be used for the LLM.
Expand Down
123 changes: 123 additions & 0 deletions src/unstract/sdk/utils/callback_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import logging
from typing import Callable, Optional, Union

import tiktoken
from llama_index.core.callbacks import (
CallbackManager as LlamaIndexCallbackManager,
)
from llama_index.core.callbacks import TokenCountingHandler
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.llms import LLM
from transformers import AutoTokenizer

from unstract.sdk.utils.usage_handler import UsageHandler

logger = logging.getLogger(__name__)


class CallbackManager:
    """Manages llama_index callbacks for token counting and usage reporting.

    Replacement for llama_index's deprecated ServiceContext: builds a
    ``CallbackManager`` wired with a ``TokenCountingHandler`` and an
    ``UsageHandler`` so token usage can be collected and pushed to the
    platform.

    Methods:
        set_callback_manager: Build a callback manager and attach it to
            the given models.
        get_tokenizer: Resolve a tokenizer callable for an LLM or
            embedding model.

    Example:
        callback_manager = CallbackManager.set_callback_manager(
            platform_api_key="abc",
            llm=llm,
            embedding=embedding,
        )
    """

    @staticmethod
    def set_callback_manager(
        platform_api_key: str,
        llm: Optional[LLM] = None,
        embedding: Optional[BaseEmbedding] = None,
        workflow_id: str = "",
        execution_id: str = "",
    ) -> LlamaIndexCallbackManager:
        """Builds the standard callback manager used to collect usage stats.

        To be called explicitly whenever the callback handling defined here
        needs to be active. The returned manager is also attached to the
        supplied ``llm``/``embedding`` instances so their calls emit events.

        Args:
            platform_api_key: API key the UsageHandler uses to report usage.
            llm: Optional LLM whose tokenizer/callbacks to configure.
            embedding: Optional embedding model; consulted when ``llm``
                is not given.
            workflow_id: Optional workflow identifier for usage attribution.
            execution_id: Optional execution identifier for usage attribution.

        Returns:
            LlamaIndexCallbackManager: The llama_index callback manager with
            token-counting and usage handlers registered.

        Example:
            CallbackManager.set_callback_manager(
                platform_api_key="abc",
                llm=llm,
                embedding=embedding,
            )
        """
        # Prefer the LLM's tokenizer, then the embedding's. Fall back to
        # the default tokenizer when neither model is supplied — previously
        # `tokenizer` was left unbound here and raised UnboundLocalError.
        if llm:
            tokenizer = CallbackManager.get_tokenizer(llm)
        elif embedding:
            tokenizer = CallbackManager.get_tokenizer(embedding)
        else:
            tokenizer = CallbackManager.get_tokenizer(None)

        token_counter = TokenCountingHandler(tokenizer=tokenizer, verbose=True)
        usage_handler = UsageHandler(
            token_counter=token_counter,
            platform_api_key=platform_api_key,
            llm_model=llm,
            embed_model=embedding,
            workflow_id=workflow_id,
            execution_id=execution_id,
        )

        callback_manager: LlamaIndexCallbackManager = LlamaIndexCallbackManager(
            handlers=[token_counter, usage_handler]
        )

        # Attach so subsequent calls through these models fire the handlers.
        if llm is not None:
            llm.callback_manager = callback_manager
        if embedding is not None:
            embedding.callback_manager = callback_manager

        return callback_manager

    @staticmethod
    def get_tokenizer(
        model: Optional[Union[LLM, BaseEmbedding]],
        fallback_tokenizer: Callable[[str], list] = tiktoken.encoding_for_model(
            "gpt-3.5-turbo"
        ).encode,
    ) -> Callable[[str], list]:
        """Returns a tokenizer function based on the provided model.

        Args:
            model: The LLM or embedding model to derive a tokenizer from;
                ``None`` selects the fallback tokenizer.
            fallback_tokenizer: Tokenizer used when ``model`` is ``None``
                or its tokenizer cannot be loaded (defaults to tiktoken's
                gpt-3.5-turbo encoder, resolved once at import time).

        Returns:
            Callable[[str], list]: The tokenizer function.
        """
        try:
            if isinstance(model, LLM):
                model_name: str = model.metadata.model_name
            elif isinstance(model, BaseEmbedding):
                model_name = model.model_name
            else:
                # No usable model. Previously this fell through with
                # `model_name` unbound and raised a NameError that the
                # OSError handler below did not catch.
                return fallback_tokenizer

            tokenizer: Callable[[str], list] = AutoTokenizer.from_pretrained(
                model_name
            ).encode
            return tokenizer
        except OSError as e:
            # Tokenizer files unavailable (locally or from the hub);
            # degrade gracefully to the fallback tokenizer.
            logger.warning(str(e))
            return fallback_tokenizer
Loading

0 comments on commit f2c9bf1

Please sign in to comment.