Skip to content

Commit

Permalink
Special handling for Azure OpenAI
Browse files Browse the repository at this point in the history
  • Loading branch information
TaoChenOSU committed Sep 25, 2024
1 parent 3398b7a commit 3a5f8be
Showing 1 changed file with 51 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
# Copyright (c) Microsoft. All rights reserved.

import json
import logging
from collections.abc import Mapping
import sys
from collections.abc import AsyncGenerator, Mapping
from copy import deepcopy
from typing import Any, TypeVar
from uuid import uuid4

from openai import AsyncAzureOpenAI
if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
else:
from typing_extensions import override # pragma: no cover

from openai import AsyncAzureOpenAI, AsyncStream
from openai.lib.azure import AsyncAzureADTokenProvider
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
Expand All @@ -16,19 +23,24 @@
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.azure_chat_prompt_execution_settings import (
AzureChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import (
OpenAIChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.open_ai.services.azure_config_base import AzureOpenAIConfigBase
from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion_base import OpenAIChatCompletionBase
from semantic_kernel.connectors.ai.open_ai.services.open_ai_handler import OpenAIModelTypes
from semantic_kernel.connectors.ai.open_ai.services.open_ai_text_completion_base import OpenAITextCompletionBase
from semantic_kernel.connectors.ai.open_ai.settings.azure_open_ai_settings import AzureOpenAISettings
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.finish_reason import FinishReason
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceInvalidResponseError
from semantic_kernel.utils.telemetry.model_diagnostics.decorators import trace_streaming_chat_completion

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -126,6 +138,42 @@ def __init__(
client=async_client,
)

@override
@trace_streaming_chat_completion(OpenAIChatCompletionBase.MODEL_PROVIDER_NAME)
async def _inner_get_streaming_chat_message_contents(
    self,
    chat_history: "ChatHistory",
    settings: "PromptExecutionSettings",
) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
    """Stream chat completions from Azure OpenAI without `stream_options`.

    This overrides the base method because the latest Azure OpenAI API GA
    version doesn't support `stream_options` yet, and including the option
    can result in errors. This method will be called instead of the base method.

    TODO: Remove this method when `stream_options` is supported by the Azure OpenAI API.
    GitHub Issue: https://github.com/microsoft/semantic-kernel/issues/8996

    Args:
        chat_history: The chat history to send with the request.
        settings: The prompt execution settings; converted to
            `OpenAIChatPromptExecutionSettings` if necessary.

    Yields:
        Lists of streaming chat message contents, one element per choice
        in each received chunk.

    Raises:
        ServiceInvalidResponseError: If the service response is not an async stream.
    """
    if not isinstance(settings, OpenAIChatPromptExecutionSettings):
        settings = self.get_prompt_execution_settings_from_settings(settings)
    assert isinstance(settings, OpenAIChatPromptExecutionSettings)  # nosec

    settings.stream = True
    settings.messages = self._prepare_chat_history_for_request(chat_history)
    settings.ai_model_id = settings.ai_model_id or self.ai_model_id

    response = await self._send_request(request_settings=settings)
    if not isinstance(response, AsyncStream):
        raise ServiceInvalidResponseError("Expected an AsyncStream[ChatCompletionChunk] response.")
    async for chunk in response:
        # Narrow the chunk type BEFORE touching its attributes (the original
        # asserted only after accessing `chunk.choices`).
        assert isinstance(chunk, ChatCompletionChunk)  # nosec
        # Chunks may arrive without choices; they carry no message content to yield.
        if not chunk.choices:
            continue

        chunk_metadata = self._get_metadata_from_streaming_chat_response(chunk)
        yield [
            self._create_streaming_chat_message_content(chunk, choice, chunk_metadata)
            for choice in chunk.choices
        ]

@classmethod
def from_dict(cls, settings: dict[str, Any]) -> "AzureChatCompletion":
"""Initialize an Azure OpenAI service from a dictionary of settings.
Expand Down

0 comments on commit 3a5f8be

Please sign in to comment.