# Python: Multiple results per prompt (incl. streaming) (#1316)
### Motivation and Context
To take advantage of APIs offered by most LLMs and to stay in sync with the
.NET SK, this PR introduces the ability to generate multiple text
completions or chat completions from a single prompt.


![MultiChatCompletionStreamMulti](https://github.com/microsoft/semantic-kernel/assets/54643756/7bec03ec-0be2-40b0-b938-6ff71beac209)


### Description
- The return type hint for `complete_async` and `complete_chat_async`
changed from `str` to `Union[str, List[str]]`. `Union` is the proper way to
indicate multiple possible return types prior to Python 3.10; 3.10 adds the
`|` syntax, but since the Python SK supports 3.8 and 3.9, the newer syntax
was not adopted (see the sketch after this list).
- `complete_async`, `complete_stream_async`, `complete_chat_async`, and
`complete_chat_stream_async` now support a `number_of_responses` settings
value greater than 1. Previously only a value of 1 was supported.
- **Note: `hf_text_completion` does not support streaming multiple
responses due to a limitation of `TextIteratorStreamer`. This feature
would require the ability to parse distinct responses out of the
`TextIteratorStreamer` output.**
- Fixed a bug where `complete_async` was streaming single responses as
1D arrays; the content is now simply a string.
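
For reference, a minimal sketch of the old vs. new hint (the parameter list is abbreviated here; the full signatures are in the diff below):

```
from typing import List, Union

# Before: only a single completion could be returned
# async def complete_async(self, prompt, request_settings) -> str: ...

# After: a single string, or a list of strings when number_of_responses > 1
# (Python 3.8/3.9 compatible; the 3.10+ spelling would be `str | list[str]`)
async def complete_async(self, prompt, request_settings) -> Union[str, List[str]]: ...
```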

### Example Usage
#### Setup
```
    kernel = sk.Kernel()

    # Configure OpenAI service
    api_key, org_id = sk.openai_settings_from_dot_env()
    oai_text_service = OpenAITextCompletion("text-davinci-003", api_key, org_id)
    oai_chat_service = OpenAIChatCompletion("gpt-3.5-turbo", api_key, org_id)

    # Configure Hugging Face service
    hf_text_service = HuggingFaceTextCompletion("gpt2", task="text-generation")
  
    # Configure Prompt
    prompt = "what is the purpose of a rubber duck?"

    # Configure Request Settings
    text_request_settings_multi = CompleteRequestSettings(
        max_tokens=100,
        temperature=0.7,
        top_p=1,
        frequency_penalty=0.5,
        presence_penalty=0.8,
        number_of_responses=4
    )

    chat_request_settings_multi = ChatRequestSettings(
        max_tokens=100,
        temperature=0.7,
        top_p=1,
        frequency_penalty=0.5,
        presence_penalty=0.8,
        number_of_responses=4
    )
```
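
The snippets below use `await` at the top level, so they assume an async entry point. A minimal runner is sketched here (the `main` wrapper and imports are illustrative, not part of the PR):

```
import asyncio

async def main():
    # paste the setup and example snippets here
    # (they are indented to fit inside this coroutine)
    ...

asyncio.run(main())
```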

#### Text Completion (Standard)
```
    texts = await oai_text_service.complete_async(prompt, text_request_settings_multi)
    for i, text in enumerate(texts):
        print(f"Option {i}: {text}")
```

#### Streaming Text Completion
```
    multi_stream = oai_text_service.complete_stream_async(prompt, text_request_settings_multi)
    texts = [''] * text_request_settings_multi.number_of_responses
    async for partial_results in multi_stream:
        os.system('cls' if os.name == 'nt' else 'clear')  # clear the screen for a better experience
        print("PROMPT: " + prompt)
        for i, option in enumerate(partial_results):
            texts[i] += option
            print(f"{i}: {texts[i]}")
```

#### Chat Completion (Standard)
```
    texts = await oai_chat_service.complete_chat_async([("user", prompt)], chat_request_settings_multi)
    for i, text in enumerate(texts):
        print(f"Option {i}: {text}")
```

#### Streaming Chat Completion
```
    multi_stream = oai_chat_service.complete_chat_stream_async([("user", prompt)], chat_request_settings_multi)
    texts = [''] * chat_request_settings_multi.number_of_responses
    async for partial_results in multi_stream:
        os.system('cls' if os.name == 'nt' else 'clear')  # clear the screen for a better experience
        print("PROMPT: " + prompt)
        for i, option in enumerate(partial_results):
            texts[i] += option
            print(f"{i}: {texts[i]}")
```

#### HuggingFace Standard Completion
```
    texts = await hf_text_service.complete_async(prompt, text_request_settings_multi)
    for i, text in enumerate(texts):
        print("-----------------------------------")
        print(f"Option {i}: {text}")
```
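
#### HuggingFace Streaming (limitation)
The Hugging Face connector raises an `AIException` when more than one streamed response is requested (see the diff below). A hedged sketch of how a caller might guard for this; the fallback logic is illustrative and not part of the PR:
```
    from semantic_kernel.connectors.ai.ai_exception import AIException

    try:
        async for chunk in hf_text_service.complete_stream_async(prompt, text_request_settings_multi):
            print(chunk, end="")
    except AIException:
        # streaming multiple responses is unsupported; fall back to the non-streaming API
        texts = await hf_text_service.complete_async(prompt, text_request_settings_multi)
```
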
awharrison-28 authored Jun 9, 2023
1 parent dae1c16 commit b2e1548
Showing 11 changed files with 170 additions and 47 deletions.
@@ -2,7 +2,7 @@

from abc import ABC, abstractmethod
from logging import Logger
from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, List, Tuple, Union

if TYPE_CHECKING:
from semantic_kernel.connectors.ai.chat_request_settings import ChatRequestSettings
@@ -15,7 +15,19 @@ async def complete_chat_async(
messages: List[Tuple[str, str]],
settings: "ChatRequestSettings",
logger: Logger,
) -> str:
) -> Union[str, List[str]]:
"""
This is the method that is called from the kernel to get a response from a chat-optimized LLM.
Arguments:
messages {List[Tuple[str, str]]} -- A list of tuples, where each tuple is
comprised of a speaker ID and a message.
settings {ChatRequestSettings} -- Settings for the request.
logger {Logger} -- A logger to use for logging.
Returns:
Union[str, List[str]] -- A string or list of strings representing the response(s) from the LLM.
"""
pass

@abstractmethod
@@ -25,4 +37,16 @@ async def complete_chat_stream_async(
settings: "ChatRequestSettings",
logger: Logger,
):
"""
This is the method that is called from the kernel to get a stream response from a chat-optimized LLM.
Arguments:
messages {List[Tuple[str, str]]} -- A list of tuples, where each tuple is
comprised of a speaker ID and a message.
settings {ChatRequestSettings} -- Settings for the request.
logger {Logger} -- A logger to use for logging.
Yields:
A stream representing the response(s) from the LLM.
"""
pass
2 changes: 2 additions & 0 deletions python/semantic_kernel/connectors/ai/chat_request_settings.py
@@ -15,6 +15,7 @@ class ChatRequestSettings:
top_p: float = 1.0
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
number_of_responses: int = 1
max_tokens: int = 256

def update_from_completion_config(
Expand All @@ -24,6 +25,7 @@ def update_from_completion_config(
self.top_p = completion_config.top_p
self.presence_penalty = completion_config.presence_penalty
self.frequency_penalty = completion_config.frequency_penalty
self.number_of_responses = completion_config.number_of_responses
self.max_tokens = completion_config.max_tokens

@staticmethod
@@ -29,6 +29,7 @@ def update_from_completion_config(
self.frequency_penalty = completion_config.frequency_penalty
self.max_tokens = completion_config.max_tokens
self.stop_sequences = completion_config.stop_sequences
self.number_of_responses = completion_config.number_of_responses

@staticmethod
def from_completion_config(
@@ -2,7 +2,7 @@

from logging import Logger
from threading import Thread
from typing import Optional
from typing import List, Optional, Union

from semantic_kernel.connectors.ai.ai_exception import AIException
from semantic_kernel.connectors.ai.complete_request_settings import (
@@ -64,17 +64,7 @@ def __init__(

async def complete_async(
self, prompt: str, request_settings: CompleteRequestSettings
) -> str:
"""
Completes a prompt using the Hugging Face model.
Arguments:
prompt {str} -- Prompt to complete.
request_settings {CompleteRequestSettings} -- Request settings.
Returns:
str -- Completion result.
"""
) -> Union[str, List[str]]:
try:
import transformers

@@ -84,15 +74,30 @@ async def complete_async(
max_new_tokens=request_settings.max_tokens,
pad_token_id=50256, # EOS token
)
result = self.generator(
prompt, num_return_sequences=1, generation_config=generation_config

results = self.generator(
prompt,
do_sample=True,
num_return_sequences=request_settings.number_of_responses,
generation_config=generation_config,
)

completions = list()
if self._task == "text-generation" or self._task == "text2text-generation":
return result[0]["generated_text"]
for response in results:
completions.append(response["generated_text"])
if len(completions) == 1:
return completions[0]
else:
return completions

elif self._task == "summarization":
return result[0]["summary_text"]
for response in results:
completions.append(response["summary_text"])
if len(completions) == 1:
return completions[0]
else:
return completions

else:
raise AIException(
@@ -107,6 +112,23 @@ async def complete_async(
async def complete_stream_async(
self, prompt: str, request_settings: CompleteRequestSettings
):
"""
Streams a text completion using a Hugging Face model.
Note that this method does not support multiple responses.
Arguments:
prompt {str} -- Prompt to complete.
request_settings {CompleteRequestSettings} -- Request settings.
Yields:
str -- Completion result.
"""
if request_settings.number_of_responses > 1:
raise AIException(
AIException.ErrorCodes.InvalidConfiguration,
"HuggingFace TextIteratorStreamer does not stream multiple responses in a parseable format. \
If you need multiple responses, please use the complete_async method.",
)
try:
import transformers

@@ -116,15 +138,18 @@ async def complete_stream_async(
max_new_tokens=request_settings.max_tokens,
pad_token_id=50256, # EOS token
)

tokenizer = transformers.AutoTokenizer.from_pretrained(self._model_id)
streamer = transformers.TextIteratorStreamer(tokenizer)
args = {"prompt": prompt}
args = {prompt}
kwargs = {
"num_return_sequences": 1,
"num_return_sequences": request_settings.number_of_responses,
"generation_config": generation_config,
"streamer": streamer,
"do_sample": True,
}

# See https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py#L159
thread = Thread(target=self.generator, args=args, kwargs=kwargs)
thread.start()

@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

from logging import Logger
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, Union

import openai

@@ -61,28 +61,37 @@ def __init__(

async def complete_chat_async(
self, messages: List[Tuple[str, str]], request_settings: ChatRequestSettings
) -> str:
) -> Union[str, List[str]]:
# TODO: tracking on token counts/etc.
response = await self._send_chat_request(messages, request_settings, False)

return response.choices[0].message.content
if len(response.choices) == 1:
return response.choices[0].message.content
else:
return [choice.message.content for choice in response.choices]

async def complete_chat_stream_async(
self, messages: List[Tuple[str, str]], request_settings: ChatRequestSettings
):
response = await self._send_chat_request(messages, request_settings, True)

# parse the completion text(s) and yield them
async for chunk in response:
if "role" in chunk.choices[0].delta:
yield chunk.choices[0].delta.role + ": "
if "content" in chunk.choices[0].delta:
yield chunk.choices[0].delta.content
text, index = _parse_choices(chunk)
# if multiple responses are requested, keep track of them
if request_settings.number_of_responses > 1:
completions = [""] * request_settings.number_of_responses
completions[index] = text
yield completions
# if only one response is requested, yield it
else:
yield text

async def complete_async(
self, prompt: str, request_settings: CompleteRequestSettings
) -> str:
) -> Union[str, List[str]]:
"""
Completes the given prompt. Returns a single string completion.
Cannot return multiple completions. Cannot return logprobs.
Completes the given prompt.
Arguments:
prompt {str} -- The prompt to complete.
@@ -98,12 +107,16 @@ async def complete_async(
presence_penalty=request_settings.presence_penalty,
frequency_penalty=request_settings.frequency_penalty,
max_tokens=request_settings.max_tokens,
number_of_responses=request_settings.number_of_responses,
)
response = await self._send_chat_request(
prompt_to_message, chat_settings, False
)

return response.choices[0].message.content
if len(response.choices) == 1:
return response.choices[0].message.content
else:
return [choice.message.content for choice in response.choices]

async def complete_stream_async(
self, prompt: str, request_settings: CompleteRequestSettings
@@ -115,12 +128,21 @@ async def complete_stream_async(
presence_penalty=request_settings.presence_penalty,
frequency_penalty=request_settings.frequency_penalty,
max_tokens=request_settings.max_tokens,
number_of_responses=request_settings.number_of_responses,
)
response = await self._send_chat_request(prompt_to_message, chat_settings, True)

# parse the completion text(s) and yield them
async for chunk in response:
if "content" in chunk.choices[0].delta:
yield chunk.choices[0].delta.content
text, index = _parse_choices(chunk)
# if multiple responses are requested, keep track of them
if request_settings.number_of_responses > 1:
completions = [""] * request_settings.number_of_responses
completions[index] = text
yield completions
# if only one response is requested, yield it
else:
yield text

async def _send_chat_request(
self,
@@ -129,7 +151,7 @@ async def _send_chat_request(
stream: bool,
):
"""
Completes the given user message. Returns a single string completion.
Completes the given user message with an asynchronous stream.
Arguments:
user_message {str} -- The message (from a user) to respond to.
@@ -184,6 +206,7 @@ async def _send_chat_request(
presence_penalty=request_settings.presence_penalty,
frequency_penalty=request_settings.frequency_penalty,
max_tokens=request_settings.max_tokens,
n=request_settings.number_of_responses,
stream=stream,
)
except Exception as ex:
@@ -196,3 +219,14 @@ async def _send_chat_request(
# TODO: tracking on token counts/etc.

return response


def _parse_choices(chunk):
message = ""
if "role" in chunk.choices[0].delta:
message += chunk.choices[0].delta.role + ": "
if "content" in chunk.choices[0].delta:
message += chunk.choices[0].delta.content

index = chunk.choices[0].index
return message, index
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

from logging import Logger
from typing import Any, Optional
from typing import Any, List, Optional, Union

import openai

@@ -56,19 +56,30 @@ def __init__(

async def complete_async(
self, prompt: str, request_settings: CompleteRequestSettings
) -> str:
) -> Union[str, List[str]]:
# TODO: tracking on token counts/etc.
response = await self._send_completion_request(prompt, request_settings, False)
return response.choices[0].text

if len(response.choices) == 1:
return response.choices[0].text
else:
return [choice.text for choice in response.choices]

# TODO: complete w/ multiple...

async def complete_stream_async(
self, prompt: str, request_settings: CompleteRequestSettings
):
response = await self._send_completion_request(prompt, request_settings, True)

async for chunk in response:
yield chunk.choices[0].text
if request_settings.number_of_responses > 1:
for choice in chunk.choices:
completions = [""] * request_settings.number_of_responses
completions[choice.index] = choice.text
yield completions
else:
yield chunk.choices[0].text

async def _send_completion_request(
self, prompt: str, request_settings: CompleteRequestSettings, stream: bool
@@ -96,13 +107,6 @@ async def _send_completion_request(
f"but was {request_settings.max_tokens}",
)

if request_settings.number_of_responses != 1:
raise AIException(
AIException.ErrorCodes.InvalidRequest,
"complete_async only supports a single completion, "
f"but {request_settings.number_of_responses} were requested",
)

if request_settings.logprobs != 0:
raise AIException(
AIException.ErrorCodes.InvalidRequest,
@@ -131,6 +135,7 @@ async def _send_completion_request(
frequency_penalty=request_settings.frequency_penalty,
max_tokens=request_settings.max_tokens,
stream=stream,
n=request_settings.number_of_responses,
stop=(
request_settings.stop_sequences
if request_settings.stop_sequences is not None