
Commit

updated parsing
eavanvalkenburg committed Dec 9, 2024
1 parent ec5bb77 commit d86e000
Showing 2 changed files with 94 additions and 66 deletions.
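The change teaches the Ollama connector to accept both the typed `ChatResponse` objects that recent `ollama` client releases return and the plain `Mapping` responses it handled before, raising `ServiceInvalidResponseError` for anything else. Below is a minimal, self-contained sketch of that dispatch pattern; the stub classes stand in for `ollama.ChatResponse` and the semantic_kernel error type and are illustrative only, not the real APIs.

```python
# Minimal sketch of the dispatch introduced by this commit. The stubs below
# stand in for ollama.ChatResponse and semantic_kernel's error type; they are
# assumptions for illustration, not the libraries' actual classes.
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any


class ServiceInvalidResponseError(Exception):
    """Stand-in for semantic_kernel's ServiceInvalidResponseError."""


@dataclass
class ChatResponse:
    """Stand-in for the typed response object returned by newer ollama clients."""
    model: str
    content: str


def parse_response(response_object: Any) -> str:
    # Typed response object: read attributes directly.
    if isinstance(response_object, ChatResponse):
        return response_object.content
    # Legacy dict-shaped response: read keys defensively.
    if isinstance(response_object, Mapping):
        return response_object.get("message", {}).get("content", "")
    # Anything else is an invalid response.
    raise ServiceInvalidResponseError(
        f"Expected Mapping or ChatResponse but got {type(response_object)}."
    )


print(parse_response(ChatResponse(model="llama3", content="hi")))   # -> hi
print(parse_response({"message": {"content": "hello"}}))            # -> hello
```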
@@ -3,7 +3,7 @@
import logging
import sys
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
@@ -47,6 +47,8 @@
if TYPE_CHECKING:
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings

CMC_TYPE = TypeVar("CMC_TYPE", bound=ChatMessageContent)

logger: logging.Logger = logging.getLogger(__name__)


@@ -163,17 +165,14 @@ async def _inner_get_chat_message_contents(
**settings.prepare_settings_dict(),
)

if not isinstance(response_object, Mapping):
raise ServiceInvalidResponseError(
f"Invalid response type from Ollama chat completion. Expected Mapping but got {type(response_object)}."
)

return [
self._create_chat_message_content(
response_object,
self._get_metadata_from_response(response_object),
)
]
if isinstance(response_object, ChatResponse):
return [self._create_chat_message_content_by_type(response_object, ChatMessageContent)]
if isinstance(response_object, Mapping):
return [self._create_chat_message_content(response_object)]
raise ServiceInvalidResponseError(
"Invalid response type from Ollama chat completion. "
f"Expected Mapping or ChatResponse but got {type(response_object)}."
)

@override
@trace_streaming_chat_completion(OllamaBase.MODEL_PROVIDER_NAME)
@@ -197,21 +196,66 @@ async def _inner_get_streaming_chat_message_contents(

if not isinstance(response_object, AsyncIterator):
raise ServiceInvalidResponseError(
"Invalid response type from Ollama chat completion. "
"Invalid response type from Ollama streaming chat completion. "
f"Expected AsyncIterator but got {type(response_object)}."
)

async for part in response_object:
yield [
self._create_streaming_chat_message_content(
part,
self._get_metadata_from_response(part),
)
]
if isinstance(part, ChatResponse):
yield [self._create_chat_message_content_by_type(part, StreamingChatMessageContent)]
continue
if isinstance(part, Mapping):
yield [self._create_streaming_chat_message_content(part)]
continue
raise ServiceInvalidResponseError(
"Invalid response type from Ollama streaming chat completion. "
f"Expected mapping or ChatResponse but got {type(part)}."
)

# endregion

def _create_chat_message_content(self, response: Mapping[str, Any], metadata: dict[str, Any]) -> ChatMessageContent:
def _create_chat_message_content_by_type(self, response: ChatResponse, type: type[CMC_TYPE]) -> CMC_TYPE:
"""Create a chat message content from the response."""
items: list[ITEM_TYPES] = []
if response.message is None:
raise ServiceInvalidResponseError("No message content found in response.")
if response.message.content:
items.append(
TextContent(
text=response.message.content,
inner_content=response.message,
)
)
if response.message.tool_calls:
for tool_call in response.message.tool_calls:
items.append(
FunctionCallContent(
inner_content=tool_call,
ai_model_id=self.ai_model_id,
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
)
metadata = self._get_metadata_from_chat_response(response)
if type is StreamingChatMessageContent:
return type(
choice_index=0,
role=AuthorRole.ASSISTANT,
items=items,
inner_content=response,
ai_model_id=self.ai_model_id,
metadata=metadata,
)

return type(
role=AuthorRole.ASSISTANT,
items=items,
inner_content=response,
ai_model_id=self.ai_model_id,
metadata=metadata,
)

def _create_chat_message_content(self, response: Mapping[str, Any]) -> ChatMessageContent:
"""Create a chat message content from the response."""
items: list[ITEM_TYPES] = []
if not (message := response.get("message", None)):
@@ -239,34 +283,12 @@ def _create_chat_message_content(self, response: Mapping[str, Any], metadata: di
role=AuthorRole.ASSISTANT,
items=items,
inner_content=response,
metadata=metadata,
metadata=self._get_metadata_from_response(response),
)

def _create_streaming_chat_message_content(
self, part: Mapping[str, Any] | ChatResponse, metadata: dict[str, Any]
) -> StreamingChatMessageContent:
def _create_streaming_chat_message_content(self, part: Mapping[str, Any]) -> StreamingChatMessageContent:
"""Create a streaming chat message content from the response part."""
items: list[STREAMING_ITEM_TYPES] = []
if isinstance(part, ChatResponse):
if part.message is None:
raise ServiceInvalidResponseError("No message content found in response part.")
if part.message.content:
items.append(
StreamingTextContent(
choice_index=0,
text=part.message.content,
inner_content=part.message,
)
)
return StreamingChatMessageContent(
role=AuthorRole.ASSISTANT,
choice_index=0,
items=items,
inner_content=part,
ai_model_id=self.ai_model_id,
metadata=metadata,
)

if not (message := part.get("message", None)):
raise ServiceInvalidResponseError("No message content found in response part.")

@@ -278,28 +300,28 @@ def _create_streaming_chat_message_content(
inner_content=message,
)
)
if tool_calls := message.get("tool_calls", None):
for tool_call in tool_calls:
items.append(
FunctionCallContent(
inner_content=tool_call,
ai_model_id=self.ai_model_id,
name=tool_call.get("function").get("name"),
arguments=tool_call.get("function").get("arguments"),
)
)

return StreamingChatMessageContent(
role=AuthorRole.ASSISTANT,
choice_index=0,
items=items,
inner_content=part,
ai_model_id=self.ai_model_id,
metadata=metadata,
metadata=self._get_metadata_from_response(part),
)

def _get_metadata_from_response(self, response: Mapping[str, Any] | ChatResponse) -> dict[str, Any]:
def _get_metadata_from_response(self, response: Mapping[str, Any]) -> dict[str, Any]:
"""Get metadata from the response."""
if isinstance(response, ChatResponse):
metadata: dict[str, Any] = {
"model": response.model,
}
if response.prompt_eval_count and response.eval_count:
metadata["usage"] = CompletionUsage(
prompt_tokens=response.prompt_eval_count,
completion_tokens=response.eval_count,
)
return metadata
metadata = {
"model": response.get("model"),
}
@@ -311,3 +333,15 @@ def _get_metadata_from_response(self, response: Mapping[str, Any] | ChatResponse
)

return metadata

def _get_metadata_from_chat_response(self, response: ChatResponse) -> dict[str, Any]:
"""Get metadata from the response."""
metadata: dict[str, Any] = {
"model": response.model,
}
if response.prompt_eval_count and response.eval_count:
metadata["usage"] = CompletionUsage(
prompt_tokens=response.prompt_eval_count,
completion_tokens=response.eval_count,
)
return metadata
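The new `_get_metadata_from_chat_response` helper mirrors the existing dict-based metadata extraction for typed responses: it records the model name and, only when both token counters are present, a usage entry built from `prompt_eval_count` and `eval_count`. A rough stand-alone sketch of that mapping follows; the stub dataclasses are assumptions standing in for `ollama.ChatResponse` and semantic_kernel's `CompletionUsage`.

```python
# Rough sketch of the metadata mapping; StubResponse/StubUsage are assumptions
# standing in for ollama.ChatResponse and semantic_kernel's CompletionUsage.
from dataclasses import dataclass
from typing import Any


@dataclass
class StubUsage:
    prompt_tokens: int
    completion_tokens: int


@dataclass
class StubResponse:
    model: str
    prompt_eval_count: int | None = None
    eval_count: int | None = None


def metadata_from_chat_response(response: StubResponse) -> dict[str, Any]:
    metadata: dict[str, Any] = {"model": response.model}
    # Usage is only reported when both counters are present (matching the diff,
    # which checks `prompt_eval_count and eval_count` before building usage).
    if response.prompt_eval_count and response.eval_count:
        metadata["usage"] = StubUsage(
            prompt_tokens=response.prompt_eval_count,
            completion_tokens=response.eval_count,
        )
    return metadata


print(metadata_from_chat_response(StubResponse(model="llama3", prompt_eval_count=12, eval_count=34)))
```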
14 changes: 4 additions & 10 deletions python/tests/integration/completions/chat_completion_test_base.py
@@ -38,7 +38,7 @@
from semantic_kernel.kernel_pydantic import KernelBaseModel
from semantic_kernel.utils.authentication.entra_id_authentication import get_entra_auth_token
from tests.integration.completions.completion_test_base import CompletionTestBase, ServiceType
from tests.utils import is_service_setup_for_testing, is_test_running_on_supported_platforms
from tests.utils import is_service_setup_for_testing

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
@@ -56,15 +56,9 @@
# There is no single model in Ollama that supports both image and tool call in chat completion
# We are splitting the Ollama test into three services: chat, image, and tool call. The chat model
# can be any model that supports chat completion. Also, Ollama is only available on Linux runners in our pipeline.
ollama_setup: bool = is_service_setup_for_testing(["OLLAMA_CHAT_MODEL_ID"]) and is_test_running_on_supported_platforms([
"Linux"
])
ollama_image_setup: bool = is_service_setup_for_testing([
"OLLAMA_CHAT_MODEL_ID_IMAGE"
]) and is_test_running_on_supported_platforms(["Linux"])
ollama_tool_call_setup: bool = is_service_setup_for_testing([
"OLLAMA_CHAT_MODEL_ID_TOOL_CALL"
]) and is_test_running_on_supported_platforms(["Linux"])
ollama_setup: bool = is_service_setup_for_testing(["OLLAMA_CHAT_MODEL_ID"])
ollama_image_setup: bool = is_service_setup_for_testing(["OLLAMA_CHAT_MODEL_ID_IMAGE"])
ollama_tool_call_setup: bool = is_service_setup_for_testing(["OLLAMA_CHAT_MODEL_ID_TOOL_CALL"])
google_ai_setup: bool = is_service_setup_for_testing(["GOOGLE_AI_API_KEY", "GOOGLE_AI_GEMINI_MODEL_ID"])
vertex_ai_setup: bool = is_service_setup_for_testing(["VERTEX_AI_PROJECT_ID", "VERTEX_AI_GEMINI_MODEL_ID"])
onnx_setup: bool = is_service_setup_for_testing(
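The test base now gates the Ollama suites purely on environment variables, dropping the Linux-only platform check. The implementation of `is_service_setup_for_testing` is not shown in this diff; the sketch below assumes the simplest plausible behaviour, namely that every listed variable must be set and non-empty.

```python
import os


def is_service_setup_for_testing(env_var_names: list[str]) -> bool:
    # Assumed behaviour (not taken from the diff): the service counts as set up
    # only when every listed environment variable is present and non-empty.
    return all(os.environ.get(name) for name in env_var_names)


ollama_setup: bool = is_service_setup_for_testing(["OLLAMA_CHAT_MODEL_ID"])
ollama_image_setup: bool = is_service_setup_for_testing(["OLLAMA_CHAT_MODEL_ID_IMAGE"])
ollama_tool_call_setup: bool = is_service_setup_for_testing(["OLLAMA_CHAT_MODEL_ID_TOOL_CALL"])
```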
