Skip to content

Commit

Permalink
.Net - Expose Agent Thread Messages (#5486)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Retrieving a thread based on ID is supported, but there's not a good way
to retrieve messages.

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

Don't want to store messages in thread since this could impact memory
pressure but they _do_ need to be exposed when re-approaching an
existing thread. Turns out there was some dead-code from the early POC
to clean-up also.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
crickman authored Mar 14, 2024
1 parent b794c4a commit 74ff46f
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,32 @@ public async Task VerifyThreadLifecycleAsync()

await Assert.ThrowsAsync<HttpOperationException>(() => context.GetThreadModelAsync(thread.Id)).ConfigureAwait(true);
}

/// <summary>
/// Verify retrieval of thread messages
/// </summary>
[Fact(Skip = SkipReason)]
public async Task GetThreadAsync()
{
var threadId = "<your thread-id>";

var context = new OpenAIRestContext(AgentBuilder.OpenAIBaseUrl, TestConfig.OpenAIApiKey);
var thread = await ChatThread.GetAsync(context, threadId);

int index = 0;
string? messageId = null;
while (messageId != null || index == 0)
{
var messages = await thread.GetMessagesAsync(count: 100, lastMessageId: messageId).ConfigureAwait(true);
foreach (var message in messages)
{
++index;
this._output.WriteLine($"#{index:000} [{message.Id}] {message.Role} [{message.AgentId ?? "n/a"}]");

this._output.WriteLine(message.Content);
}

messageId = messages.Count > 0 ? messages[messages.Count - 1].Id : null;
}
}
}
3 changes: 1 addition & 2 deletions dotnet/src/Experimental/Agents.UnitTests/MockExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Net.Http;
using System.Threading;
using Moq;
Expand All @@ -15,7 +14,7 @@ public static void VerifyMock(this Mock<HttpMessageHandler> mockHandler, HttpMet
mockHandler.Protected().Verify(
"SendAsync",
Times.Exactly(times),
ItExpr.Is<HttpRequestMessage>(req => req.Method == method && (uri == null || req.RequestUri == new Uri(uri))),
ItExpr.Is<HttpRequestMessage>(req => req.Method == method && (uri == null || req.RequestUri!.AbsoluteUri.StartsWith(uri))),
ItExpr.IsAny<CancellationToken>());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,21 @@ public static Task<ThreadMessageModel> GetMessageAsync(
/// </summary>
/// <param name="context">A context for accessing OpenAI REST endpoint</param>
/// <param name="threadId">The thread identifier</param>
/// <param name="lastId">The identifier of the last message retrieved</param>
/// <param name="count">The maximum number of messages requested (up to 100 / default: 25)</param>
/// <param name="cancellationToken">A cancellation token</param>
/// <returns>A message list definition</returns>
public static Task<ThreadMessageListModel> GetMessagesAsync(
this OpenAIRestContext context,
string threadId,
string? lastId = null,
int? count = null,
CancellationToken cancellationToken = default)
{
return
context.ExecuteGetAsync<ThreadMessageListModel>(
context.GetMessagesUrl(threadId),
$"limit={count ?? 25}&after={lastId ?? string.Empty}",
cancellationToken);
}

Expand Down
9 changes: 9 additions & 0 deletions dotnet/src/Experimental/Agents/IAgentThread.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ public interface IAgentThread
/// <returns></returns>
Task<IChatMessage> AddUserMessageAsync(string message, IEnumerable<string>? fileIds = null, CancellationToken cancellationToken = default);

/// <summary>
/// Retrieve thread messages in descending order (most recent first).
/// </summary>
/// <param name="count">The maximum number of messages requested</param>
/// <param name="lastMessageId">The identifier of the last message retrieved</param>
/// <param name="cancellationToken">A cancellation token</param>
/// <returns>An list of <see cref="IChatMessage"/>.</returns>
Task<IReadOnlyList<IChatMessage>> GetMessagesAsync(int? count = null, string? lastMessageId = null, CancellationToken cancellationToken = default);

/// <summary>
/// Advance the thread with the specified agent.
/// </summary>
Expand Down
15 changes: 11 additions & 4 deletions dotnet/src/Experimental/Agents/Internal/ChatThread.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -31,7 +32,7 @@ public static async Task<IAgentThread> CreateAsync(OpenAIRestContext restContext
// Common case is for failure exception to be raised by REST invocation. Null result is a logical possibility, but unlikely edge case.
var threadModel = await restContext.CreateThreadModelAsync(cancellationToken).ConfigureAwait(false);

return new ChatThread(threadModel, messageListModel: null, restContext);
return new ChatThread(threadModel, restContext);
}

/// <summary>
Expand All @@ -44,9 +45,8 @@ public static async Task<IAgentThread> CreateAsync(OpenAIRestContext restContext
public static async Task<IAgentThread> GetAsync(OpenAIRestContext restContext, string threadId, CancellationToken cancellationToken = default)
{
var threadModel = await restContext.GetThreadModelAsync(threadId, cancellationToken).ConfigureAwait(false);
var messageListModel = await restContext.GetMessagesAsync(threadId, cancellationToken).ConfigureAwait(false);

return new ChatThread(threadModel, messageListModel, restContext);
return new ChatThread(threadModel, restContext);
}

/// <inheritdoc/>
Expand All @@ -59,6 +59,14 @@ public async Task<IChatMessage> AddUserMessageAsync(string message, IEnumerable<
return new ChatMessage(messageModel);
}

/// <inheritdoc/>
public async Task<IReadOnlyList<IChatMessage>> GetMessagesAsync(int? count = null, string? lastMessageId = null, CancellationToken cancellationToken = default)
{
var messageModel = await this._restContext.GetMessagesAsync(this.Id, lastMessageId, count, cancellationToken).ConfigureAwait(false);

return messageModel.Data.Select(m => new ChatMessage(m)).ToArray();
}

/// <inheritdoc/>
public IAsyncEnumerable<IChatMessage> InvokeAsync(IAgent agent, KernelArguments? arguments = null, CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -109,7 +117,6 @@ public async Task DeleteAsync(CancellationToken cancellationToken)
/// </summary>
private ChatThread(
ThreadModel threadModel,
ThreadMessageListModel? messageListModel,
OpenAIRestContext restContext)
{
this.Id = threadModel.Id;
Expand Down

0 comments on commit 74ff46f

Please sign in to comment.