Skip to content

Commit

Permalink
.Net Agents - Support AdditionalMessages for OpenAIAssistantAgent (#…
Browse files Browse the repository at this point in the history
…9737)

### Motivation and Context
<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Add support for `AdditionalMessages` option when invoking a run.

Fixes: #9685

### Description
<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

Allow the addition of multiple messages to a thread when invoking a
`OpenAIAssistantAgent`.

### Contribution Checklist
<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
  • Loading branch information
crickman authored Nov 18, 2024
1 parent fa24473 commit cde12d3
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public static IEnumerable<MessageContent> GetMessageContents(ChatMessageContent
{
yield return MessageContent.FromImageUri(imageContent.Uri);
}
else if (string.IsNullOrWhiteSpace(imageContent.DataUri))
else if (!string.IsNullOrWhiteSpace(imageContent.DataUri))
{
yield return MessageContent.FromImageUri(new(imageContent.DataUri!));
}
Expand Down
13 changes: 13 additions & 0 deletions dotnet/src/Agents/OpenAI/Internal/AssistantRunOptionsFactory.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using Microsoft.SemanticKernel.ChatCompletion;
using OpenAI.Assistants;

namespace Microsoft.SemanticKernel.Agents.OpenAI.Internal;
Expand Down Expand Up @@ -45,6 +46,18 @@ public static RunCreationOptions GenerateOptions(OpenAIAssistantDefinition defin
}
}

if (invocationOptions?.AdditionalMessages != null)
{
foreach (ChatMessageContent message in invocationOptions.AdditionalMessages)
{
ThreadInitializationMessage threadMessage = new(
role: message.Role == AuthorRole.User ? MessageRole.User : MessageRole.Assistant,
content: AssistantMessageFactory.GetMessageContents(message));

options.AdditionalMessages.Add(threadMessage);
}
}

return options;
}

Expand Down
4 changes: 4 additions & 0 deletions dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,10 @@ public async Task<string> UploadFileAsync(Stream stream, string name, Cancellati
/// <param name="threadId">The thread identifier</param>
/// <param name="message">A non-system message with which to append to the conversation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <remarks>
/// Only supports messages with role = User or Assistant:
/// https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-additional_messages
/// </remarks>
public Task AddChatMessageAsync(string threadId, ChatMessageContent message, CancellationToken cancellationToken = default)
{
this.ThrowIfDeleted();
Expand Down
10 changes: 10 additions & 0 deletions dotnet/src/Agents/OpenAI/OpenAIAssistantInvocationOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ public sealed class OpenAIAssistantInvocationOptions
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? AdditionalInstructions { get; init; }

/// <summary>
/// Additional messages to add to the thread.
/// </summary>
/// <remarks>
/// Only supports messages with role = User or Assistant:
/// https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-additional_messages
/// </remarks>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public IReadOnlyList<ChatMessageContent>? AdditionalMessages { get; init; }

/// <summary>
/// Set if code_interpreter tool is enabled.
/// </summary>
Expand Down
4 changes: 4 additions & 0 deletions dotnet/src/Agents/OpenAI/OpenAIThreadCreationOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ public sealed class OpenAIThreadCreationOptions
/// <summary>
/// Optional messages to initialize thread with..
/// </summary>
/// <remarks>
/// Only supports messages with role = User or Assistant:
/// https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-additional_messages
/// </remarks>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public IReadOnlyList<ChatMessageContent>? Messages { get; init; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ public void VerifyAssistantMessageAdapterGetMessageWithImageUrl()
/// <summary>
/// Verify options creation.
/// </summary>
[Fact(Skip = "API bug with data Uri construction")]
[Fact]
public void VerifyAssistantMessageAdapterGetMessageWithImageData()
{
// Arrange
ChatMessageContent message = new(AuthorRole.User, items: [new ImageContent(new byte[] { 1, 2, 3 }, "image/png")]);
ChatMessageContent message = new(AuthorRole.User, items: [new ImageContent(new byte[] { 1, 2, 3 }, "image/png") { DataUri = "data:image/png;base64,MTIz" }]);

// Act
MessageContent[] contents = AssistantMessageFactory.GetMessageContents(message).ToArray();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents.OpenAI;
using Microsoft.SemanticKernel.Agents.OpenAI.Internal;
using Microsoft.SemanticKernel.ChatCompletion;
using OpenAI.Assistants;
using Xunit;

Expand Down Expand Up @@ -35,6 +37,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsNullTest()

// Assert
Assert.NotNull(options);
Assert.Empty(options.AdditionalMessages);
Assert.Null(options.InstructionsOverride);
Assert.Null(options.Temperature);
Assert.Null(options.NucleusSamplingFactor);
Expand Down Expand Up @@ -147,4 +150,28 @@ public void AssistantRunOptionsFactoryExecutionOptionsMetadataTest()
Assert.Equal("value", options.Metadata["key1"]);
Assert.Equal(string.Empty, options.Metadata["key2"]);
}

/// <summary>
/// Verify run options generation with <see cref="OpenAIAssistantInvocationOptions"/> metadata.
/// </summary>
[Fact]
public void AssistantRunOptionsFactoryExecutionOptionsMessagesTest()
{
// Arrange
OpenAIAssistantDefinition definition = new("gpt-anything");

OpenAIAssistantInvocationOptions invocationOptions =
new()
{
AdditionalMessages = [
new ChatMessageContent(AuthorRole.User, "test message")
]
};

// Act
RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(definition, null, invocationOptions);

// Assert
Assert.Single(options.AdditionalMessages);
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Text.Json;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents.OpenAI;
using Microsoft.SemanticKernel.ChatCompletion;
using Xunit;

namespace SemanticKernel.Agents.UnitTests.OpenAI;
Expand All @@ -23,6 +25,7 @@ public void OpenAIAssistantInvocationOptionsInitialState()
// Assert
Assert.Null(options.ModelName);
Assert.Null(options.AdditionalInstructions);
Assert.Null(options.AdditionalMessages);
Assert.Null(options.Metadata);
Assert.Null(options.Temperature);
Assert.Null(options.TopP);
Expand Down Expand Up @@ -50,6 +53,9 @@ public void OpenAIAssistantInvocationOptionsAssignment()
{
ModelName = "testmodel",
AdditionalInstructions = "test instructions",
AdditionalMessages = [
new ChatMessageContent(AuthorRole.User, "test message")
],
Metadata = new Dictionary<string, string>() { { "a", "1" } },
MaxCompletionTokens = 1000,
MaxPromptTokens = 1000,
Expand All @@ -65,6 +71,7 @@ public void OpenAIAssistantInvocationOptionsAssignment()
// Assert
Assert.Equal("testmodel", options.ModelName);
Assert.Equal("test instructions", options.AdditionalInstructions);
Assert.Single(options.AdditionalMessages);
Assert.Equal(2, options.Temperature);
Assert.Equal(0, options.TopP);
Assert.Equal(1000, options.MaxCompletionTokens);
Expand All @@ -89,6 +96,8 @@ private static void ValidateSerialization(OpenAIAssistantInvocationOptions sourc

// Assert
Assert.NotNull(target);
Assert.Equal(source.AdditionalInstructions, target.AdditionalInstructions);
Assert.Equivalent(source.AdditionalMessages, target.AdditionalMessages);
Assert.Equal(source.ModelName, target.ModelName);
Assert.Equal(source.Temperature, target.Temperature);
Assert.Equal(source.TopP, target.TopP);
Expand Down
55 changes: 53 additions & 2 deletions dotnet/src/IntegrationTests/Agents/OpenAIAssistantAgentTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ await this.ExecuteStreamingAgentAsync(
}

/// <summary>
/// Integration test for <see cref="OpenAIAssistantAgent"/> using function calling
/// and targeting Azure OpenAI services.
/// Integration test for <see cref="OpenAIAssistantAgent"/> adding a message with
/// function result contents.
/// </summary>
[RetryFact(typeof(HttpOperationException))]
public async Task AzureOpenAIAssistantAgentFunctionCallResultAsync()
Expand Down Expand Up @@ -130,6 +130,57 @@ await OpenAIAssistantAgent.CreateAsync(
}
}

/// <summary>
/// Integration test for <see cref="OpenAIAssistantAgent"/> adding additional message to a thread.
/// function result contents.
/// </summary>
[RetryFact(typeof(HttpOperationException))]
public async Task AzureOpenAIAssistantAgentAdditionalMessagesAsync()
{
var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get<AzureOpenAIConfiguration>();
Assert.NotNull(azureOpenAIConfiguration);

OpenAIAssistantAgent agent =
await OpenAIAssistantAgent.CreateAsync(
OpenAIClientProvider.ForAzureOpenAI(new AzureCliCredential(), new Uri(azureOpenAIConfiguration.Endpoint)),
new(azureOpenAIConfiguration.ChatDeploymentName!),
new Kernel());

OpenAIThreadCreationOptions threadOptions = new()
{
Messages = [
new ChatMessageContent(AuthorRole.User, "Hello"),
new ChatMessageContent(AuthorRole.Assistant, "How may I help you?"),
]
};
string threadId = await agent.CreateThreadAsync(threadOptions);
try
{
var messages = await agent.GetThreadMessagesAsync(threadId).ToArrayAsync();
Assert.Equal(2, messages.Length);

OpenAIAssistantInvocationOptions invocationOptions = new()
{
AdditionalMessages = [
new ChatMessageContent(AuthorRole.User, "This is my real question...in three parts:"),
new ChatMessageContent(AuthorRole.User, "Part 1"),
new ChatMessageContent(AuthorRole.User, "Part 2"),
new ChatMessageContent(AuthorRole.User, "Part 3"),
]
};

messages = await agent.InvokeAsync(threadId, invocationOptions).ToArrayAsync();
Assert.Single(messages);

messages = await agent.GetThreadMessagesAsync(threadId).ToArrayAsync();
Assert.Equal(7, messages.Length);
}
finally
{
await agent.DeleteThreadAsync(threadId);
}
}

private async Task ExecuteAgentAsync(
OpenAIClientProvider config,
string modelName,
Expand Down

0 comments on commit cde12d3

Please sign in to comment.