Skip to content

Commit

Permalink
Added adapter_instance_id and run_id (#40)
Browse files Browse the repository at this point in the history
* Added adapter_instance_id and run_id

* Updated field model_type -> model_name

* Review comment fixes

* Version bump

---------

Co-authored-by: Rahul Johny <[email protected]>
  • Loading branch information
Deepak-Kesavan and johnyrahul authored Apr 26, 2024
1 parent 6aa69f6 commit b78f41b
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 57 deletions.
2 changes: 1 addition & 1 deletion src/unstract/sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.22.1"
__version__ = "0.23.0"


def get_sdk_version():
Expand Down
48 changes: 27 additions & 21 deletions src/unstract/sdk/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,6 @@ class Audit(StreamMixin):
Attributes:
None
Example usage:
audit = Audit()
audit.push_usage_data(
token_counter,
workflow_id,
execution_id,
external_service,
event_type)
"""

def __init__(self, log_level: LogLevel = LogLevel.INFO) -> None:
Expand All @@ -33,23 +24,28 @@ def push_usage_data(
self,
platform_api_key: str,
token_counter: TokenCountingHandler = None,
workflow_id: str = "",
execution_id: str = "",
external_service: str = "",
model_name: str = "",
event_type: CBEventType = None,
**kwargs,
) -> None:
"""Pushes the usage data to the platform service.
Args:
platform_api_key (str): The platform API key.
token_counter (TokenCountingHandler, optional): The token counter
object. Defaults to None.
workflow_id (str, optional): The ID of the workflow. Defaults to "".
execution_id (str, optional): The ID of the execution. Defaults
to "".
external_service (str, optional): The name of the external service.
Defaults to "".
object. Defaults to None.
model_name (str, optional): The name of the model.
Defaults to "".
event_type (CBEventType, optional): The type of the event. Defaults
to None.
to None.
**kwargs: Optional keyword arguments.
workflow_id (str, optional): The ID of the workflow.
Defaults to "".
execution_id (str, optional): The ID of the execution. Defaults
to "".
adapter_instance_id (str, optional): The adapter instance ID.
Defaults to "".
run_id (str, optional): The run ID. Defaults to "".
Returns:
None
Expand All @@ -66,11 +62,18 @@ def push_usage_data(
)
bearer_token = platform_api_key

workflow_id = kwargs.get("workflow_id", "")
execution_id = kwargs.get("execution_id", "")
adapter_instance_id = kwargs.get("adapter_instance_id", "")
run_id = kwargs.get("run_id", "")

data = {
"usage_type": event_type,
"external_service": external_service,
"workflow_id": workflow_id,
"execution_id": execution_id,
"adapter_instance_id": adapter_instance_id,
"run_id": run_id,
"usage_type": event_type,
"model_name": model_name,
"embedding_tokens": token_counter.total_embedding_token_count,
"prompt_tokens": token_counter.prompt_llm_token_count,
"completion_tokens": token_counter.completion_llm_token_count,
Expand Down Expand Up @@ -100,3 +103,6 @@ def push_usage_data(
log=f"Error while pushing usage details: {e}",
level=LogLevel.ERROR,
)

finally:
token_counter.reset_counts()
34 changes: 14 additions & 20 deletions src/unstract/sdk/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,32 +57,26 @@ def run_completion(
) -> Optional[dict[str, Any]]:
# Setup callback manager to collect Usage stats
UNCallbackManager.set_callback_manager(
platform_api_key=platform_api_key, llm=llm
platform_api_key=platform_api_key, llm=llm, **kwargs
)
# Removing specific keys from kwargs
new_kwargs = kwargs.copy()
for key in [
"workflow_id",
"execution_id",
"adapter_instance_id",
"run_id",
]:
new_kwargs.pop(key, None)
for i in range(retries):
try:
response: CompletionResponse = llm.complete(prompt, **kwargs)
response: CompletionResponse = llm.complete(
prompt, **new_kwargs
)
match = cls.json_regex.search(response.text)
if match:
response.text = match.group(0)

usage = {}
llm_token_counts = llm.callback_manager.handlers[
0
].llm_token_counts
if llm_token_counts:
llm_token_count = llm_token_counts[0]
usage[
"prompt_token_count"
] = llm_token_count.prompt_token_count
usage[
"completion_token_count"
] = llm_token_count.completion_token_count
usage[
"total_token_count"
] = llm_token_count.total_token_count

return {"response": response, "usage": usage}
return {"response": response}

except Exception as e:
if i == retries - 1:
Expand Down
8 changes: 3 additions & 5 deletions src/unstract/sdk/utils/callback_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def set_callback_manager(
platform_api_key: str,
llm: Optional[LLM] = None,
embedding: Optional[BaseEmbedding] = None,
workflow_id: str = "",
execution_id: str = "",
**kwargs,
) -> LlamaIndexCallbackManager:
"""Sets the standard callback manager for the llm. This is to be called
explicitly whenever there is a need for the callback handling defined
Expand All @@ -52,7 +51,7 @@ def set_callback_manager(
llm (LLM): The LLM type
Returns:
CallbackManager tyoe of llama index
CallbackManager type of llama index
Example:
UNCallbackManager.set_callback_manager(
Expand All @@ -73,8 +72,7 @@ def set_callback_manager(
platform_api_key=platform_api_key,
llm_model=llm,
embed_model=embedding,
workflow_id=workflow_id,
execution_id=execution_id,
**kwargs,
)

callback_manager: LlamaIndexCallbackManager = LlamaIndexCallbackManager(
Expand Down
16 changes: 6 additions & 10 deletions src/unstract/sdk/utils/usage_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,14 @@ def __init__(
platform_api_key: str,
llm_model: LLM = None,
embed_model: BaseEmbedding = None,
workflow_id: str = "",
execution_id: str = "",
event_starts_to_ignore: Optional[list[CBEventType]] = None,
event_ends_to_ignore: Optional[list[CBEventType]] = None,
verbose: bool = False,
log_level: LogLevel = LogLevel.INFO,
**kwargs,
) -> None:
self.kwargs = kwargs
self._verbose = verbose
self.workflow_id = workflow_id
self.execution_id = execution_id
self.token_counter = token_counter
self.llm_model = llm_model
self.embed_model = embed_model
Expand Down Expand Up @@ -96,9 +94,8 @@ def on_event_end(
platform_api_key=self.platform_api_key,
token_counter=self.token_counter,
event_type=event_type,
external_service=self.llm_model.metadata.model_name,
workflow_id=self.workflow_id,
execution_id=self.execution_id,
model_name=self.llm_model.metadata.model_name,
**self.kwargs,
)

elif (
Expand All @@ -113,7 +110,6 @@ def on_event_end(
platform_api_key=self.platform_api_key,
token_counter=self.token_counter,
event_type=event_type,
external_service=self.embed_model.model_name,
workflow_id=self.workflow_id,
execution_id=self.execution_id,
model_name=self.embed_model.model_name,
**self.kwargs,
)

0 comments on commit b78f41b

Please sign in to comment.