Add Mistral
dirkbrnd committed Feb 7, 2025
1 parent acb3d2a commit e6af03d
Showing 4 changed files with 120 additions and 507 deletions.
libs/agno/agno/models/aws/bedrock.py (6 changes: 3 additions & 3 deletions)
@@ -6,7 +6,7 @@
 from agno.exceptions import ModelProviderError
 from agno.models.base import Model
 from agno.models.message import Message
-from agno.models.response import ProviderResponse
+from agno.models.response import ModelResponse
 from agno.utils.log import logger

 try:
@@ -121,7 +121,7 @@ def format_messages(self, messages: List[Message]) -> Dict[str, Any]:
         raise NotImplementedError("Please use a subclass of AwsBedrock")

     @abstractmethod
-    def parse_provider_response(self, response: Dict[str, Any]) -> ProviderResponse:
+    def parse_provider_response(self, response: Dict[str, Any]) -> ModelResponse:
         raise NotImplementedError("Please use a subclass of AwsBedrock")

     async def ainvoke(self, *args, **kwargs) -> Any:
@@ -130,5 +130,5 @@ async def ainvoke(self, *args, **kwargs) -> Any:
     async def ainvoke_stream(self, *args, **kwargs) -> Any:
         raise NotImplementedError(f"Async not supported on {self.name}.")

-    def parse_provider_response_delta(self, response: Any) -> Iterator[ProviderResponse]:
+    def parse_provider_response_delta(self, response: Any) -> ModelResponse:
         pass
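
The bedrock.py hunks above swap the abstract parsing hooks from the removed ProviderResponse type to ModelResponse, so an AwsBedrock subclass now returns the same unified response object used throughout the model layer. Below is a minimal sketch of what a concrete parser might look like under the new signature, using stand-in classes rather than agno's real ones; the Bedrock payload shape and every field except content and response_usage are assumptions.

from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ModelResponse:  # stand-in for agno.models.response.ModelResponse
    content: Optional[str] = None
    response_usage: Optional[Dict[str, Any]] = None


class MyBedrockModel:  # stand-in for an AwsBedrock subclass
    def parse_provider_response(self, response: Dict[str, Any]) -> ModelResponse:
        # Map a raw Bedrock-style payload onto the unified response type.
        message = response.get("output", {}).get("message", {})
        parts = message.get("content") or [{}]
        return ModelResponse(content=parts[0].get("text"), response_usage=response.get("usage"))


raw = {"output": {"message": {"content": [{"text": "hello"}]}}, "usage": {"inputTokens": 3}}
print(MyBedrockModel().parse_provider_response(raw))
# ModelResponse(content='hello', response_usage={'inputTokens': 3})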
libs/agno/agno/models/base.py (18 changes: 12 additions & 6 deletions)
@@ -851,26 +851,30 @@ async def aresponse(self, messages: List[Message]) -> ModelResponse:

     def update_stream_data_and_assistant_message(
         self, stream_data: MessageData, assistant_message: Message, model_response: ModelResponse
-    ) -> ModelResponse:
+    ) -> Iterator[ModelResponse]:
         """Update the stream data and assistant message with the model response."""

         # Update metrics
         assistant_message.metrics.completion_tokens += 1
         if not assistant_message.metrics.time_to_first_token:
             assistant_message.metrics.set_time_to_first_token()

+        should_yield = False
         # Update stream_data content
         if model_response.content is not None:
             stream_data.response_content += model_response.content
+            should_yield = True

         # Update stream_data tool calls
         if model_response.tool_calls is not None:
             if stream_data.response_tool_calls is None:
                 stream_data.response_tool_calls = []
             stream_data.response_tool_calls.extend(model_response.tool_calls)
+            should_yield = True

         if model_response.audio is not None:
             stream_data.response_audio = model_response.audio
+            should_yield = True

         if model_response.extra is not None:
             stream_data.extra.update(model_response.extra)
@@ -880,6 +884,9 @@ def update_stream_data_and_assistant_message(
                 assistant_message=assistant_message, response_usage=model_response.response_usage
             )

+        if should_yield:
+            yield model_response

     def process_response_stream(
         self, messages: List[Message], assistant_message: Message, stream_data: MessageData
     ) -> Iterator[ModelResponse]:
@@ -889,8 +896,7 @@ def process_response_stream(
         for response_delta in self.invoke_stream(messages=messages):
             model_response_delta = self.parse_provider_response_delta(response_delta)
             if model_response_delta:
-                yield model_response_delta
-                self.update_stream_data_and_assistant_message(
+                yield from self.update_stream_data_and_assistant_message(
                     stream_data=stream_data, assistant_message=assistant_message, model_response=model_response_delta
                 )
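
The hunks above turn update_stream_data_and_assistant_message into a generator: it accumulates each delta into the stream state as before, but yields the delta back only when it carried new content, tool calls, or audio (tracked via should_yield), which lets process_response_stream simply yield from the helper instead of yielding every delta unconditionally before updating state. A self-contained sketch of that pattern follows, with simplified stand-ins for agno's MessageData and ModelResponse; all names and fields here are illustrative.

from dataclasses import dataclass, field
from typing import Iterator, List, Optional


@dataclass
class Delta:  # stand-in for ModelResponse
    content: Optional[str] = None
    tool_calls: Optional[List[dict]] = None


@dataclass
class StreamState:  # stand-in for MessageData
    response_content: str = ""
    response_tool_calls: List[dict] = field(default_factory=list)


def update_state(state: StreamState, delta: Delta) -> Iterator[Delta]:
    # Accumulate the delta into the running state; yield it back only if it carried anything.
    should_yield = False
    if delta.content is not None:
        state.response_content += delta.content
        should_yield = True
    if delta.tool_calls is not None:
        state.response_tool_calls.extend(delta.tool_calls)
        should_yield = True
    if should_yield:
        yield delta


def stream(deltas: List[Delta], state: StreamState) -> Iterator[Delta]:
    for delta in deltas:
        # Mirrors the new process_response_stream: empty deltas are filtered out here.
        yield from update_state(state, delta)


state = StreamState()
chunks = [Delta(content="Hel"), Delta(), Delta(content="lo")]
print([d.content for d in stream(chunks, state)], state.response_content)
# ['Hel', 'lo'] Hello

The practical effect is that deltas carrying only usage metrics or other metadata still update the assistant message but are no longer surfaced to the caller.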

@@ -975,10 +981,10 @@ async def aprocess_response_stream(
         async for response_delta in await self.ainvoke_stream(messages=messages):
             model_response_delta = self.parse_provider_response_delta(response_delta)
             if model_response_delta:
-                yield model_response_delta
-                self.update_stream_data_and_assistant_message(
+                for model_response in self.update_stream_data_and_assistant_message(
                     stream_data=stream_data, assistant_message=assistant_message, model_response=model_response_delta
-                )
+                ):
+                    yield model_response

     async def aresponse_stream(self, messages: List[Message]) -> AsyncIterator[ModelResponse]:
         """
(Diffs for the remaining 2 changed files are not rendered in this view.)

0 comments on commit e6af03d
