Skip to content

Commit

Permalink
.Net - Fix Open AI Agent Run State Processing (#5488)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Eric pointed out that not all run states are handled correctly:

#5449

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

Fully support the run states:
https://platform.openai.com/docs/api-reference/runs/object#runs/object-status

Ran complex tool calling scenarios repeatedly to verify.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
crickman authored Mar 15, 2024
1 parent bf64bdd commit 88bdb11
Showing 1 changed file with 33 additions and 38 deletions.
71 changes: 33 additions & 38 deletions dotnet/src/Experimental/Agents/Internal/ChatRun.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ internal sealed class ChatRun
public string ThreadId => this._model.ThreadId;

private const string ActionState = "requires_action";
private const string FailedState = "failed";
private const string CompletedState = "completed";
private static readonly TimeSpan s_pollingInterval = TimeSpan.FromMilliseconds(500);
private static readonly TimeSpan s_pollingBackoff = TimeSpan.FromSeconds(1);
Expand All @@ -38,6 +37,15 @@ internal sealed class ChatRun
{
"queued",
"in_progress",
"cancelling",
};

private static readonly HashSet<string> s_terminalStates =
new(StringComparer.OrdinalIgnoreCase)
{
"expired",
"failed",
"cancelled",
};

private readonly OpenAIRestContext _restContext;
Expand All @@ -48,38 +56,32 @@ internal sealed class ChatRun
/// <inheritdoc/>
public async IAsyncEnumerable<string> GetResultAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// Poll until actionable
await PollRunStatus().ConfigureAwait(false);

// Retrieve steps
var processedMessageIds = new HashSet<string>();
var steps = await this._restContext.GetRunStepsAsync(this.ThreadId, this.Id, cancellationToken).ConfigureAwait(false);

do
{
// Poll run and steps until actionable
var steps = await PollRunStatusAsync().ConfigureAwait(false);

// Is in terminal state?
if (s_terminalStates.Contains(this._model.Status))
{
throw new AgentException($"Run terminated - {this._model.Status} [{this.Id}]: {this._model.LastError?.Message ?? "Unknown"}");
}

// Is tool action required?
if (ActionState.Equals(this._model.Status, StringComparison.OrdinalIgnoreCase))
{
// Execute functions in parallel and post results at once.
var tasks = steps.Data.SelectMany(step => this.ExecuteStep(step, cancellationToken)).ToArray();
await Task.WhenAll(tasks).ConfigureAwait(false);

var results = tasks.Select(t => t.Result).ToArray();
await this._restContext.AddToolOutputsAsync(this.ThreadId, this.Id, results, cancellationToken).ConfigureAwait(false);

// Refresh run as it goes back into pending state after posting function results.
await PollRunStatus(force: true).ConfigureAwait(false);

// Refresh steps to retrieve additional messages.
steps = await this._restContext.GetRunStepsAsync(this.ThreadId, this.Id, cancellationToken).ConfigureAwait(false);
}

// Did fail?
if (FailedState.Equals(this._model.Status, StringComparison.OrdinalIgnoreCase))
{
throw new AgentException($"Unexpected failure processing run: {this.Id}: {this._model.LastError?.Message ?? "Unknown"}");
if (tasks.Length > 0)
{
var results = await Task.WhenAll(tasks).ConfigureAwait(false);
await this._restContext.AddToolOutputsAsync(this.ThreadId, this.Id, results, cancellationToken).ConfigureAwait(false);
}
}

// Enumerate completed messages
var newMessageIds =
steps.Data
.Where(s => s.StepDetails.MessageCreation != null)
Expand All @@ -96,21 +98,15 @@ public async IAsyncEnumerable<string> GetResultAsync([EnumeratorCancellation] Ca
}
while (!CompletedState.Equals(this._model.Status, StringComparison.OrdinalIgnoreCase));

async Task PollRunStatus(bool force = false)
async Task<ThreadRunStepListModel> PollRunStatusAsync()
{
int count = 0;

// Ignore model status when forced.
while (force || s_pollingStates.Contains(this._model.Status))
do
{
if (!force)
{
// Reduce polling frequency after a couple attempts
await Task.Delay(count >= 2 ? s_pollingInterval : s_pollingBackoff, cancellationToken).ConfigureAwait(false);
++count;
}

force = false;
// Reduce polling frequency after a couple attempts
await Task.Delay(count >= 2 ? s_pollingInterval : s_pollingBackoff, cancellationToken).ConfigureAwait(false);
++count;

try
{
Expand All @@ -121,6 +117,9 @@ async Task PollRunStatus(bool force = false)
// Retry anyway..
}
}
while (s_pollingStates.Contains(this._model.Status));

return await this._restContext.GetRunStepsAsync(this.ThreadId, this.Id, cancellationToken).ConfigureAwait(false);
}
}

Expand Down Expand Up @@ -153,11 +152,7 @@ private IEnumerable<Task<ToolResultModel>> ExecuteStep(ThreadRunStepModel step,
private async Task<ToolResultModel> ProcessFunctionStepAsync(string callId, ThreadRunStepModel.FunctionDetailsModel functionDetails, CancellationToken cancellationToken)
{
var result = await InvokeFunctionCallAsync().ConfigureAwait(false);
var toolResult = result as string;
if (toolResult == null)
{
toolResult = JsonSerializer.Serialize(result);
}
var toolResult = result as string ?? JsonSerializer.Serialize(result);

return
new ToolResultModel
Expand Down

0 comments on commit 88bdb11

Please sign in to comment.