Skip to content

Commit

Permalink
.Net Agents - Fix Function Call Handling for Streaming (#9652)
Browse files Browse the repository at this point in the history
### Motivation and Context
<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Fixes: #9638

`System.ArgumentException: An item with the same key has already been
added. ` - Duplicate key added when processing function result.

### Description
<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

The processing loop for assistant streaming is selecting completed steps
which is resulting in over-processing that violates the state-tracking.
This was due to leveraging the existing utiility method
`GetRunStepsAsync`. I removed this method in favor of inline invocation
since the processing for Streaming and Non-Streaming have distinct
considerations. Also, as the SDK has evolved (paging removed), the
utility method isn't adding much value.

> Note: Was able to reproduce reported issue by setting
`ParallelToolCallsEnabled = false` on `OpenAIAssistant_Streaming` demo
and verify fix. Existing approach was able to handle parallel function
calls and function calls on different steps adequetly.

### Contribution Checklist
<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
  • Loading branch information
crickman authored Nov 11, 2024
1 parent 4283cf2 commit 83a59d4
Showing 1 changed file with 9 additions and 18 deletions.
27 changes: 9 additions & 18 deletions dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public static async IAsyncEnumerable<ChatMessageContent> GetMessagesAsync(Assist
throw new KernelException($"Agent Failure - Run terminated: {run.Status} [{run.Id}]: {run.LastError?.Message ?? "Unknown"}");
}

IReadOnlyList<RunStep> steps = await GetRunStepsAsync(client, run, cancellationToken).ConfigureAwait(false);
RunStep[] steps = await client.GetRunStepsAsync(run.ThreadId, run.Id, cancellationToken: cancellationToken).ToArrayAsync(cancellationToken).ConfigureAwait(false);

// Is tool action required?
if (run.Status == RunStatus.RequiresAction)
Expand Down Expand Up @@ -475,11 +475,14 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin

if (run.Status == RunStatus.RequiresAction)
{
IReadOnlyList<RunStep> steps = await GetRunStepsAsync(client, run, cancellationToken).ConfigureAwait(false);
RunStep[] activeSteps =
await client.GetRunStepsAsync(run.ThreadId, run.Id, cancellationToken: cancellationToken)
.Where(step => step.Status == RunStepStatus.InProgress)
.ToArrayAsync(cancellationToken).ConfigureAwait(false);

// Capture map between the tool call and its associated step
Dictionary<string, string> toolMap = [];
foreach (RunStep step in steps)
foreach (RunStep step in activeSteps)
{
foreach (RunStepToolCall stepDetails in step.Details.ToolCalls)
{
Expand All @@ -488,7 +491,7 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
}

// Execute functions in parallel and post results at once.
FunctionCallContent[] functionCalls = steps.SelectMany(step => ParseFunctionStep(agent, step)).ToArray();
FunctionCallContent[] functionCalls = activeSteps.SelectMany(step => ParseFunctionStep(agent, step)).ToArray();
if (functionCalls.Length > 0)
{
// Emit function-call content
Expand All @@ -504,7 +507,7 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
ToolOutput[] toolOutputs = GenerateToolOutputs(functionResults);
asyncUpdates = client.SubmitToolOutputsToRunStreamingAsync(run.ThreadId, run.Id, toolOutputs, cancellationToken);

foreach (RunStep step in steps)
foreach (RunStep step in activeSteps)
{
stepFunctionResults.Add(step.Id, functionResults.Where(result => step.Id == toolMap[result.CallId!]).ToArray());
}
Expand Down Expand Up @@ -560,18 +563,6 @@ await RetrieveMessageAsync(
logger.LogOpenAIAssistantCompletedRun(nameof(InvokeAsync), run?.Id ?? "Failed", threadId);
}

private static async Task<IReadOnlyList<RunStep>> GetRunStepsAsync(AssistantClient client, ThreadRun run, CancellationToken cancellationToken)
{
List<RunStep> steps = [];

await foreach (RunStep step in client.GetRunStepsAsync(run.ThreadId, run.Id, cancellationToken: cancellationToken).ConfigureAwait(false))
{
steps.Add(step);
}

return steps;
}

private static ChatMessageContent GenerateMessageContent(string? assistantName, ThreadMessage message, RunStep? completedStep = null)
{
AuthorRole role = new(message.Role.ToString());
Expand Down Expand Up @@ -788,7 +779,7 @@ private static ChatMessageContent GenerateFunctionCallContent(string agentName,
return functionCallContent;
}

private static ChatMessageContent GenerateFunctionResultContent(string agentName, FunctionResultContent[] functionResults, RunStep completedStep)
private static ChatMessageContent GenerateFunctionResultContent(string agentName, IEnumerable<FunctionResultContent> functionResults, RunStep completedStep)
{
ChatMessageContent functionResultContent = new(AuthorRole.Tool, content: null)
{
Expand Down

0 comments on commit 83a59d4

Please sign in to comment.