Python: enable ollama streaming tool calls #9890

Merged 9 commits on Dec 9, 2024
Changes from 8 commits
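For context, a minimal usage sketch of what this change enables: streaming a chat completion from the Ollama connector with automatic function calling turned on. This is a hedged sketch, not part of the PR; the model id, the plugin-free kernel, and the import paths for the Ollama connector are assumptions, while the service and settings classes are the ones touched in this diff.

```python
# Hedged sketch (not part of this PR): streaming tool calls through the Ollama connector.
import asyncio

from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.ollama import OllamaChatCompletion, OllamaChatPromptExecutionSettings
from semantic_kernel.contents.chat_history import ChatHistory


async def main() -> None:
    kernel = Kernel()  # in a real app, plugins with kernel functions would be registered here
    service = OllamaChatCompletion(ai_model_id="llama3.1")  # assumed local model with tool support

    settings = OllamaChatPromptExecutionSettings(
        function_choice_behavior=FunctionChoiceBehavior.Auto(),
    )
    history = ChatHistory()
    history.add_user_message("What is the weather in Seattle?")

    # Before this PR the streaming path raised ServiceInvalidExecutionSettingsError
    # when tools were configured; with ollama ~= 0.4 the streamed parts can carry tool calls.
    async for chunks in service.get_streaming_chat_message_contents(
        chat_history=history, settings=settings, kernel=kernel
    ):
        for chunk in chunks:
            print(chunk.content or "", end="")


asyncio.run(main())
```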
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -83,7 +83,7 @@ mistralai = [
"mistralai >= 1.2,< 2.0"
]
ollama = [
"ollama ~= 0.2"
"ollama ~= 0.4"
]
onnx = [
"onnxruntime-genai ~= 0.4; platform_system != 'Darwin'"
@@ -2,7 +2,7 @@

import json
import logging
from collections.abc import Callable
from collections.abc import Callable, Mapping

from azure.ai.inference.models import (
AssistantMessage,
@@ -98,7 +98,7 @@ def _format_assistant_message(message: ChatMessageContent) -> AssistantMessage:
function=FunctionCall(
name=item.name or "",
arguments=json.dumps(item.arguments)
if isinstance(item.arguments, dict)
if isinstance(item.arguments, Mapping)
else item.arguments or "",
),
)
@@ -2,7 +2,7 @@

import asyncio
import json
from collections.abc import Callable
from collections.abc import Callable, Mapping
from functools import partial
from typing import Any

@@ -108,7 +108,9 @@ def _format_assistant_message(message: ChatMessageContent) -> dict[str, Any]:
"toolUse": {
"toolUseId": item.id,
"name": item.custom_fully_qualified_name(BEDROCK_FUNCTION_NAME_SEPARATOR),
"input": item.arguments if isinstance(item.arguments, dict) else json.loads(item.arguments or "{}"),
"input": item.arguments
if isinstance(item.arguments, Mapping)
else json.loads(item.arguments or "{}"),
}
})
else:
@@ -99,6 +99,9 @@ async def get_chat_message_contents(
"""
# Create a copy of the settings to avoid modifying the original settings
settings = copy.deepcopy(settings)
# The tools (or equivalent settings) are used later on, so cast to the service-specific settings class here.
if not isinstance(settings, self.get_prompt_execution_settings_class()):
settings = self.get_prompt_execution_settings_from_settings(settings)

if not self.SUPPORTS_FUNCTION_CALLING:
return await self._inner_get_chat_message_contents(chat_history, settings)
@@ -211,6 +214,9 @@ async def get_streaming_chat_message_contents(
"""
# Create a copy of the settings to avoid modifying the original settings
settings = copy.deepcopy(settings)
# The tools (or equivalent settings) are used later on, so cast to the service-specific settings class here.
if not isinstance(settings, self.get_prompt_execution_settings_class()):
settings = self.get_prompt_execution_settings_from_settings(settings)

if not self.SUPPORTS_FUNCTION_CALLING:
async for streaming_chat_message_contents in self._inner_get_streaming_chat_message_contents(
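A hedged illustration of the cast added above: a generic PromptExecutionSettings instance is converted into the service-specific settings class, so that later code (such as the tool-call configuration) can set attributes like `tools` on it. The Ollama service, model id, and import paths below are assumptions for illustration; the two base-class methods are the ones used in this diff.

```python
# Illustrative only: what the isinstance check and cast above do for one concrete service.
from semantic_kernel.connectors.ai.ollama import OllamaChatCompletion, OllamaChatPromptExecutionSettings
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings

service = OllamaChatCompletion(ai_model_id="llama3.1")  # assumed model id
settings = PromptExecutionSettings(extension_data={"temperature": 0.2})

if not isinstance(settings, service.get_prompt_execution_settings_class()):
    settings = service.get_prompt_execution_settings_from_settings(settings)

# After the cast, the settings are an instance of the service-specific class.
assert isinstance(settings, OllamaChatPromptExecutionSettings)
```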
@@ -2,8 +2,8 @@

import logging
import sys
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping
from typing import TYPE_CHECKING, Any, ClassVar
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
@@ -12,7 +12,7 @@

import httpx
from ollama import AsyncClient
from ollama._types import Message
from ollama._types import ChatResponse, Message
from pydantic import ValidationError

from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
@@ -47,6 +47,8 @@
if TYPE_CHECKING:
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings

CMC_TYPE = TypeVar("CMC_TYPE", bound=ChatMessageContent)

logger: logging.Logger = logging.getLogger(__name__)


@@ -163,17 +165,14 @@ async def _inner_get_chat_message_contents(
**settings.prepare_settings_dict(),
)

if not isinstance(response_object, Mapping):
raise ServiceInvalidResponseError(
f"Invalid response type from Ollama chat completion. Expected Mapping but got {type(response_object)}."
)

return [
self._create_chat_message_content(
response_object,
self._get_metadata_from_response(response_object),
)
]
if isinstance(response_object, ChatResponse):
return [self._create_chat_message_content_from_chat_response(response_object)]
if isinstance(response_object, Mapping):
return [self._create_chat_message_content(response_object)]
raise ServiceInvalidResponseError(
"Invalid response type from Ollama chat completion. "
f"Expected Mapping or ChatResponse but got {type(response_object)}."
)

@override
@trace_streaming_chat_completion(OllamaBase.MODEL_PROVIDER_NAME)
@@ -186,11 +185,6 @@ async def _inner_get_streaming_chat_message_contents(
settings = self.get_prompt_execution_settings_from_settings(settings)
assert isinstance(settings, OllamaChatPromptExecutionSettings) # nosec

if settings.tools:
raise ServiceInvalidExecutionSettingsError(
"Ollama does not support tool calling in streaming chat completion."
)

prepared_chat_history = self._prepare_chat_history_for_request(chat_history)

response_object = await self.client.chat(
@@ -202,21 +196,79 @@

if not isinstance(response_object, AsyncIterator):
raise ServiceInvalidResponseError(
"Invalid response type from Ollama chat completion. "
"Invalid response type from Ollama streaming chat completion. "
f"Expected AsyncIterator but got {type(response_object)}."
)

async for part in response_object:
yield [
self._create_streaming_chat_message_content(
part,
self._get_metadata_from_response(part),
)
]
if isinstance(part, ChatResponse):
yield [self._create_streaming_chat_message_content_from_chat_response(part)]
continue
if isinstance(part, Mapping):
yield [self._create_streaming_chat_message_content(part)]
continue
raise ServiceInvalidResponseError(
"Invalid response type from Ollama streaming chat completion. "
f"Expected mapping or ChatResponse but got {type(part)}."
)

# endregion

def _create_chat_message_content(self, response: Mapping[str, Any], metadata: dict[str, Any]) -> ChatMessageContent:
def _create_streaming_chat_message_content_from_chat_response(
self, response: ChatResponse
) -> StreamingChatMessageContent:
"""Create a chat message content from the response."""
items: list[STREAMING_ITEM_TYPES] = []
if response.message.content:
items.append(
StreamingTextContent(
choice_index=0,
text=response.message.content,
inner_content=response.message,
)
)
self._parse_tool_calls(response.message.tool_calls, items)
return StreamingChatMessageContent(
choice_index=0,
role=AuthorRole.ASSISTANT,
items=items,
inner_content=response,
ai_model_id=self.ai_model_id,
metadata=self._get_metadata_from_chat_response(response),
)

def _parse_tool_calls(self, tool_calls: Sequence[Message.ToolCall] | None, items: list[Any]):
if tool_calls:
for tool_call in tool_calls:
items.append(
FunctionCallContent(
inner_content=tool_call,
ai_model_id=self.ai_model_id,
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
)

def _create_chat_message_content_from_chat_response(self, response: ChatResponse) -> ChatMessageContent:
"""Create a chat message content from the response."""
items: list[ITEM_TYPES] = []
if response.message.content:
items.append(
TextContent(
text=response.message.content,
inner_content=response.message,
)
)
self._parse_tool_calls(response.message.tool_calls, items)
return ChatMessageContent(
role=AuthorRole.ASSISTANT,
items=items,
inner_content=response,
ai_model_id=self.ai_model_id,
metadata=self._get_metadata_from_chat_response(response),
)

def _create_chat_message_content(self, response: Mapping[str, Any]) -> ChatMessageContent:
"""Create a chat message content from the response."""
items: list[ITEM_TYPES] = []
if not (message := response.get("message", None)):
@@ -244,12 +296,10 @@ def _create_chat_message_content(self, response: Mapping[str, Any], metadata: di
role=AuthorRole.ASSISTANT,
items=items,
inner_content=response,
metadata=metadata,
metadata=self._get_metadata_from_response(response),
)

def _create_streaming_chat_message_content(
self, part: Mapping[str, Any], metadata: dict[str, Any]
) -> StreamingChatMessageContent:
def _create_streaming_chat_message_content(self, part: Mapping[str, Any]) -> StreamingChatMessageContent:
"""Create a streaming chat message content from the response part."""
items: list[STREAMING_ITEM_TYPES] = []
if not (message := part.get("message", None)):
@@ -263,14 +313,24 @@ def _create_streaming_chat_message_content(
inner_content=message,
)
)
if tool_calls := message.get("tool_calls", None):
for tool_call in tool_calls:
items.append(
FunctionCallContent(
inner_content=tool_call,
ai_model_id=self.ai_model_id,
name=tool_call.get("function").get("name"),
arguments=tool_call.get("function").get("arguments"),
)
)

return StreamingChatMessageContent(
role=AuthorRole.ASSISTANT,
choice_index=0,
items=items,
inner_content=part,
ai_model_id=self.ai_model_id,
metadata=metadata,
metadata=self._get_metadata_from_response(part),
)

def _get_metadata_from_response(self, response: Mapping[str, Any]) -> dict[str, Any]:
Expand All @@ -286,3 +346,15 @@ def _get_metadata_from_response(self, response: Mapping[str, Any]) -> dict[str,
)

return metadata

def _get_metadata_from_chat_response(self, response: ChatResponse) -> dict[str, Any]:
"""Get metadata from the response."""
metadata: dict[str, Any] = {
"model": response.model,
}
if response.prompt_eval_count and response.eval_count:
metadata["usage"] = CompletionUsage(
prompt_tokens=response.prompt_eval_count,
completion_tokens=response.eval_count,
)
return metadata
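For reference, a hedged sketch of the ollama 0.4 message shape that the new `_parse_tool_calls` helper consumes. The tool name and arguments are made up; the attribute names (`tool_calls`, `function.name`, `function.arguments`) mirror the accesses in the code above, and the nested `Message.ToolCall.Function` constructor is assumed from the ollama 0.4 typed models.

```python
# Illustrative only: a typed ollama message carrying a streamed tool call.
from ollama._types import Message

message = Message(
    role="assistant",
    content="",
    tool_calls=[
        Message.ToolCall(
            function=Message.ToolCall.Function(
                name="get_weather",             # hypothetical tool name
                arguments={"city": "Seattle"},  # ollama 0.4 returns parsed arguments, not a JSON string
            )
        )
    ],
)

for tool_call in message.tool_calls or []:
    # The connector maps each of these onto a FunctionCallContent item.
    print(tool_call.function.name, tool_call.function.arguments)
```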
17 changes: 11 additions & 6 deletions python/semantic_kernel/connectors/ai/ollama/services/utils.py
@@ -1,14 +1,13 @@
# Copyright (c) Microsoft. All rights reserved.

import json
from collections.abc import Callable
from collections.abc import Callable, Mapping

from ollama._types import Message

from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration
from semantic_kernel.connectors.ai.function_calling_utils import kernel_function_metadata_to_function_call_format
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceType
from semantic_kernel.connectors.ai.ollama.ollama_prompt_execution_settings import OllamaChatPromptExecutionSettings
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
@@ -76,7 +75,7 @@ def _format_assistant_message(message: ChatMessageContent) -> Message:
"function": {
"name": tool_call.function_name,
"arguments": tool_call.arguments
if isinstance(tool_call.arguments, dict)
if isinstance(tool_call.arguments, Mapping)
else json.loads(tool_call.arguments or "{}"),
}
}
@@ -115,11 +114,17 @@ def update_settings_from_function_choice_configuration(
settings: PromptExecutionSettings,
type: FunctionChoiceType,
) -> None:
"""Update the settings from a FunctionChoiceConfiguration."""
assert isinstance(settings, OllamaChatPromptExecutionSettings) # nosec
"""Update the settings from a FunctionChoiceConfiguration.

Since this function might be called before the settings are cast to Ollama Settings
TaoChenOSU marked this conversation as resolved.
Show resolved Hide resolved
We need to try to use the tools attribute or fallback to the extension_data attribute.
"""
if function_choice_configuration.available_functions:
settings.tools = [
tools = [
kernel_function_metadata_to_function_call_format(f)
for f in function_choice_configuration.available_functions
]
try:
settings.tools = tools # type: ignore
except Exception:
settings.extension_data["tools"] = tools
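A hedged sketch of the two paths behind the try/except above: service-specific settings accept `tools` directly, while generic settings (which may not have been cast yet) fall back to `extension_data`. The class names and attributes come from this diff; the tool definition itself is a made-up example.

```python
# Illustrative only: where the tool definitions end up for each settings type.
from semantic_kernel.connectors.ai.ollama.ollama_prompt_execution_settings import (
    OllamaChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings

tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]  # hypothetical

ollama_settings = OllamaChatPromptExecutionSettings()
ollama_settings.tools = tools  # the service-specific class declares a `tools` field

generic_settings = PromptExecutionSettings()
generic_settings.extension_data["tools"] = tools  # generic settings only carry extension_data
```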