Skip to content

Commit

Permalink
Update response_stream
Browse files Browse the repository at this point in the history
  • Loading branch information
ashpreetbedi committed Feb 7, 2025
1 parent 896ae72 commit acb3d2a
Show file tree
Hide file tree
Showing 18 changed files with 186 additions and 236 deletions.
5 changes: 2 additions & 3 deletions cookbook/models/groq/async_tool_use.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Please install dependencies using:
pip install openai duckduckgo-search newspaper4k lxml_html_clean agno
"""

import asyncio

from agno.agent import Agent
Expand All @@ -23,6 +24,4 @@
)

# -*- Print a response to the cli
asyncio.run(
agent.aprint_response("Simulation theory", stream=True)
)
asyncio.run(agent.aprint_response("Simulation theory", stream=True))
1 change: 1 addition & 0 deletions cookbook/models/ollama/async_basic_stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from typing import Iterator # noqa

from agno.agent import Agent, RunResponse # noqa
from agno.models.ollama import Ollama

Expand Down
1 change: 1 addition & 0 deletions cookbook/models/openai/async_tool_use.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Run `pip install duckduckgo-search` to install dependencies."""

import asyncio

from agno.agent import Agent
from agno.models.openai import OpenAIChat
from agno.tools.duckduckgo import DuckDuckGoTools
Expand Down
2 changes: 0 additions & 2 deletions libs/agno/agno/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,10 +1344,8 @@ def get_tools(self) -> Optional[List[Union[Toolkit, Callable, Function]]]:
return agent_tools

def add_tools_to_model(self, model: Model) -> None:

# Skip if functions_for_model is not None
if self._functions_for_model is None or self._tools_for_model is None:

# Get Agent tools
agent_tools = self.get_tools()
if agent_tools is not None:
Expand Down
4 changes: 2 additions & 2 deletions libs/agno/agno/document/reader/csv_reader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import csv
import io
import os
from time import sleep
from pathlib import Path
from time import sleep
from typing import IO, Any, List, Union
from urllib.parse import urlparse

Expand Down Expand Up @@ -73,7 +73,7 @@ def read(self, url: str) -> List[Document]:
if attempt == 2: # Last attempt
logger.error(f"Failed to fetch CSV after 3 attempts: {e}")
raise
wait_time = 2 ** attempt # Exponential backoff: 1, 2, 4 seconds
wait_time = 2**attempt # Exponential backoff: 1, 2, 4 seconds
logger.warning(f"Request failed, retrying in {wait_time} seconds...")
sleep(wait_time)

Expand Down
2 changes: 1 addition & 1 deletion libs/agno/agno/document/reader/pdf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def read(self, url: str) -> List[Document]:
if attempt == 2: # Last attempt
logger.error(f"Failed to fetch PDF after 3 attempts: {e}")
raise
wait_time = 2 ** attempt # Exponential backoff: 1, 2, 4 seconds
wait_time = 2**attempt # Exponential backoff: 1, 2, 4 seconds
logger.warning(f"Request failed, retrying in {wait_time} seconds...")
sleep(wait_time)

Expand Down
4 changes: 2 additions & 2 deletions libs/agno/agno/document/reader/url_reader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from time import sleep
from typing import List
from urllib.parse import urlparse
from time import sleep

from agno.document.base import Document
from agno.document.reader.base import Reader
Expand Down Expand Up @@ -29,7 +29,7 @@ def read(self, url: str) -> List[Document]:
if attempt == 2: # Last attempt
logger.error(f"Failed to fetch PDF after 3 attempts: {e}")
raise
wait_time = 2 ** attempt # Exponential backoff: 1, 2, 4 seconds
wait_time = 2**attempt # Exponential backoff: 1, 2, 4 seconds
logger.warning(f"Request failed, retrying in {wait_time} seconds...")
sleep(wait_time)

Expand Down
3 changes: 2 additions & 1 deletion libs/agno/agno/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ def __init__(
exc, user_message=user_message, agent_message=agent_message, messages=messages, stop_execution=True
)


class ModelProviderError(Exception):
"""Exception raised when a model provider returns an error."""

def __init__(self, exc, model_name: str, model_id: str):
super().__init__(exc)
self.model_name = model_name
self.model_id = model_id
self.model_id = model_id
35 changes: 21 additions & 14 deletions libs/agno/agno/models/anthropic/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,16 @@

try:
from anthropic import Anthropic as AnthropicClient
from anthropic import APIConnectionError, APIStatusError, RateLimitError
from anthropic import AsyncAnthropic as AsyncAnthropicClient
from anthropic import APIConnectionError, RateLimitError, APIStatusError

from anthropic.types import (
ContentBlockDeltaEvent,
MessageStopEvent,
ContentBlockStopEvent,
MessageDeltaEvent,
MessageStopEvent,
TextBlock,
TextDelta,
ToolUseBlock,
ContentBlockStopEvent,
)
from anthropic.types import Message as AnthropicMessage
except (ModuleNotFoundError, ImportError):
Expand Down Expand Up @@ -335,11 +334,15 @@ def invoke_stream(self, messages: List[Message]) -> Any:
request_kwargs = self._prepare_request_kwargs(system_message)

try:
return self.get_client().messages.stream(
model=self.id,
messages=chat_messages, # type: ignore
**request_kwargs,
).__enter__()
return (
self.get_client()
.messages.stream(
model=self.id,
messages=chat_messages, # type: ignore
**request_kwargs,
)
.__enter__()
)
except APIConnectionError as e:
logger.error(f"Connection error while calling Claude API: {str(e)}")
raise ModelProviderError(e, self.name, self.id) from e
Expand Down Expand Up @@ -404,11 +407,15 @@ async def ainvoke_stream(self, messages: List[Message]) -> Any:
chat_messages, system_message = _format_messages(messages)
request_kwargs = self._prepare_request_kwargs(system_message)

return await self.get_async_client().messages.stream(
model=self.id,
messages=chat_messages, # type: ignore
**request_kwargs,
).__aenter__()
return (
await self.get_async_client()
.messages.stream(
model=self.id,
messages=chat_messages, # type: ignore
**request_kwargs,
)
.__aenter__()
)
except APIConnectionError as e:
logger.error(f"Connection error while calling Claude API: {str(e)}")
raise ModelProviderError(e, self.name, self.id) from e
Expand Down
4 changes: 1 addition & 3 deletions libs/agno/agno/models/aws/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,5 @@ async def ainvoke(self, *args, **kwargs) -> Any:
async def ainvoke_stream(self, *args, **kwargs) -> Any:
raise NotImplementedError(f"Async not supported on {self.name}.")

def parse_provider_response_delta(
self, response: Any
) -> Iterator[ProviderResponse]:
def parse_provider_response_delta(self, response: Any) -> Iterator[ProviderResponse]:
pass
18 changes: 11 additions & 7 deletions libs/agno/agno/models/aws/claude.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
from dataclasses import dataclass
import json
from typing import Any, Dict, List, Optional, Iterator
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional

from agno.models.aws.bedrock import AwsBedrock
from agno.models.base import MessageData
from agno.models.message import Message
from agno.models.response import ModelResponse
from agno.utils.log import logger


@dataclass
class BedrockResponseUsage:
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0


@dataclass
class Claude(AwsBedrock):
"""
Expand Down Expand Up @@ -230,7 +232,9 @@ def parse_provider_response(self, response: Dict[str, Any]) -> ModelResponse:
return model_response

# Override the base class method
def format_function_call_results(self, messages: List[Message], function_call_results: List[Message], tool_ids: List[str]) -> None:
def format_function_call_results(
self, messages: List[Message], function_call_results: List[Message], tool_ids: List[str]
) -> None:
"""
Format function call results.
"""
Expand All @@ -248,9 +252,10 @@ def format_function_call_results(self, messages: List[Message], function_call_re
logger.debug(f"Tool call responses: {fc_responses}")
messages.append(Message(role="user", content=json.dumps(fc_responses)))


# Override the base class method
def process_response_stream(self, messages: List[Message], assistant_message: Message, stream_data: MessageData) -> Iterator[ModelResponse]:
def process_response_stream(
self, messages: List[Message], assistant_message: Message, stream_data: MessageData
) -> Iterator[ModelResponse]:
"""
Process the streaming response from the Bedrock API.
"""
Expand Down Expand Up @@ -321,8 +326,7 @@ def process_response_stream(self, messages: List[Message], assistant_message: Me

# Update metrics
self.add_usage_metrics_to_assistant_message(
assistant_message=assistant_message,
response_usage=response_usage
assistant_message=assistant_message, response_usage=response_usage
)

if tool_ids:
Expand Down
Loading

0 comments on commit acb3d2a

Please sign in to comment.