Skip to content

Commit

Permalink
Merge branch 'main' into users/markwallace/issue_8946_1
Browse files Browse the repository at this point in the history
  • Loading branch information
markwallace-microsoft authored Oct 11, 2024
2 parents 6dc2321 + 718d5cc commit fb4c9bb
Show file tree
Hide file tree
Showing 26 changed files with 507 additions and 179 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace Memory;
public class VectorStore_VectorSearch_MultiStore_AzureAISearch(ITestOutputHelper output) : BaseTest(output)
{
[Fact]
public async Task ExampleWitDIAsync()
public async Task ExampleWithDIAsync()
{
// Use the kernel for DI purposes.
var kernelBuilder = Kernel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace Memory;
public class VectorStore_VectorSearch_MultiStore_InMemory(ITestOutputHelper output) : BaseTest(output)
{
[Fact]
public async Task ExampleWitDIAsync()
public async Task ExampleWithDIAsync()
{
// Use the kernel for DI purposes.
var kernelBuilder = Kernel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace Memory;
public class VectorStore_VectorSearch_MultiStore_Qdrant(ITestOutputHelper output, VectorStoreQdrantContainerFixture qdrantFixture) : BaseTest(output), IClassFixture<VectorStoreQdrantContainerFixture>
{
[Fact]
public async Task ExampleWitDIAsync()
public async Task ExampleWithDIAsync()
{
// Use the kernel for DI purposes.
var kernelBuilder = Kernel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class VectorStore_VectorSearch_MultiStore_Redis(ITestOutputHelper output,
[Theory]
[InlineData(RedisStorageType.Json)]
[InlineData(RedisStorageType.HashSet)]
public async Task ExampleWitDIAsync(RedisStorageType redisStorageType)
public async Task ExampleWithDIAsync(RedisStorageType redisStorageType)
{
// Use the kernel for DI purposes.
var kernelBuilder = Kernel
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using Azure.Identity;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
using Microsoft.SemanticKernel.Connectors.InMemory;
Expand All @@ -25,7 +26,7 @@ public async Task ExampleAsync()
var textEmbeddingGenerationService = new AzureOpenAITextEmbeddingGenerationService(
TestConfiguration.AzureOpenAIEmbeddings.DeploymentName,
TestConfiguration.AzureOpenAIEmbeddings.Endpoint,
TestConfiguration.AzureOpenAIEmbeddings.ApiKey);
new AzureCliCredential());

// Construct an InMemory vector store.
var vectorStore = new InMemoryVectorStore();
Expand Down
4 changes: 4 additions & 0 deletions dotnet/src/Agents/Abstractions/AggregatorAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ public sealed class AggregatorAgent(Func<AgentChat> chatProvider) : Agent
public AggregatorMode Mode { get; init; } = AggregatorMode.Flat;

/// <inheritdoc/>
/// <remarks>
/// Different <see cref="AggregatorAgent"/> will never share the same channel.
/// </remarks>
protected internal override IEnumerable<string> GetChannelKeys()
{
yield return typeof(AggregatorChannel).FullName!;
yield return this.GetHashCode().ToString();
}

/// <inheritdoc/>
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Agents/Abstractions/KernelAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public abstract class KernelAgent : Agent
/// <summary>
/// A prompt-template based on the agent instructions.
/// </summary>
protected IPromptTemplate? Template { get; set; }
public IPromptTemplate? Template { get; protected set; }

/// <summary>
/// Format the system instructions for the agent.
Expand Down
136 changes: 90 additions & 46 deletions dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -247,14 +247,14 @@ public static async IAsyncEnumerable<ChatMessageContent> GetMessagesAsync(Assist
// Process code-interpreter content
if (toolCall.ToolKind == RunStepToolCallKind.CodeInterpreter)
{
content = GenerateCodeInterpreterContent(agent.GetName(), toolCall.CodeInterpreterInput);
content = GenerateCodeInterpreterContent(agent.GetName(), toolCall.CodeInterpreterInput, completedStep);
isVisible = true;
}
// Process function result content
else if (toolCall.ToolKind == RunStepToolCallKind.Function)
{
FunctionResultContent functionStep = functionSteps[toolCall.ToolCallId]; // Function step always captured on invocation
content = GenerateFunctionResultContent(agent.GetName(), [functionStep]);
content = GenerateFunctionResultContent(agent.GetName(), [functionStep], completedStep);
}

if (content is not null)
Expand Down Expand Up @@ -366,14 +366,14 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin

// Evaluate status and process steps and messages, as encountered.
HashSet<string> processedStepIds = [];
Dictionary<string, RunStep?> activeMessages = [];
Dictionary<string, FunctionResultContent[]> stepFunctionResults = [];
List<RunStep> stepsToProcess = [];
ThreadRun? run = null;
RunStep? currentStep = null;

IAsyncEnumerable<StreamingUpdate> asyncUpdates = client.CreateRunStreamingAsync(threadId, agent.Id, options, cancellationToken);
do
{
activeMessages.Clear();
stepsToProcess.Clear();

await foreach (StreamingUpdate update in asyncUpdates.ConfigureAwait(false))
{
Expand All @@ -397,15 +397,6 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
break;
}
}
else if (update is MessageStatusUpdate statusUpdate)
{
switch (statusUpdate.UpdateKind)
{
case StreamingUpdateReason.MessageCompleted:
activeMessages.Add(statusUpdate.Value.Id, currentStep);
break;
}
}
else if (update is RunStepDetailsUpdate detailsUpdate)
{
StreamingChatMessageContent? toolContent = GenerateStreamingCodeInterpreterContent(agent.GetName(), detailsUpdate);
Expand All @@ -418,11 +409,8 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
{
switch (stepUpdate.UpdateKind)
{
case StreamingUpdateReason.RunStepCreated:
currentStep = stepUpdate.Value;
break;
case StreamingUpdateReason.RunStepCompleted:
currentStep = null;
stepsToProcess.Add(stepUpdate.Value);
break;
default:
break;
Expand All @@ -445,6 +433,16 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
{
IReadOnlyList<RunStep> steps = await GetRunStepsAsync(client, run, cancellationToken).ConfigureAwait(false);

// Capture map between the tool call and its associated step
Dictionary<string, string> toolMap = [];
foreach (RunStep step in steps)
{
foreach (RunStepToolCall stepDetails in step.Details.ToolCalls)
{
toolMap[stepDetails.ToolCallId] = step.Id;
}
}

// Execute functions in parallel and post results at once.
FunctionCallContent[] functionCalls = steps.SelectMany(step => ParseFunctionStep(agent, step)).ToArray();
if (functionCalls.Length > 0)
Expand All @@ -462,27 +460,54 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
ToolOutput[] toolOutputs = GenerateToolOutputs(functionResults);
asyncUpdates = client.SubmitToolOutputsToRunStreamingAsync(run.ThreadId, run.Id, toolOutputs, cancellationToken);

messages?.Add(GenerateFunctionResultContent(agent.GetName(), functionResults));
foreach (RunStep step in steps)
{
stepFunctionResults.Add(step.Id, functionResults.Where(result => step.Id == toolMap[result.CallId!]).ToArray());
}
}
}

if (activeMessages.Count > 0)
if (stepsToProcess.Count > 0)
{
logger.LogOpenAIAssistantProcessingRunMessages(nameof(InvokeAsync), run!.Id, threadId);

foreach (string messageId in activeMessages.Keys)
foreach (RunStep step in stepsToProcess)
{
RunStep? step = activeMessages[messageId];
ThreadMessage? message = await RetrieveMessageAsync(client, threadId, messageId, agent.PollingOptions.MessageSynchronizationDelay, cancellationToken).ConfigureAwait(false);

if (message != null)
if (!string.IsNullOrEmpty(step.Details.CreatedMessageId))
{
ThreadMessage? message =
await RetrieveMessageAsync(
client,
threadId,
step.Details.CreatedMessageId,
agent.PollingOptions.MessageSynchronizationDelay,
cancellationToken).ConfigureAwait(false);

if (message != null)
{
ChatMessageContent content = GenerateMessageContent(agent.GetName(), message, step);
messages?.Add(content);
}
}
else
{
ChatMessageContent content = GenerateMessageContent(agent.GetName(), message, step);
messages?.Add(content);
foreach (RunStepToolCall toolCall in step.Details.ToolCalls)
{
switch (toolCall.ToolKind)
{
case RunStepToolCallKind.CodeInterpreter:
messages?.Add(GenerateCodeInterpreterContent(agent.GetName(), toolCall.CodeInterpreterInput, step));
break;
case RunStepToolCallKind.Function:
messages?.Add(GenerateFunctionResultContent(agent.GetName(), stepFunctionResults[step.Id], step));
stepFunctionResults.Remove(step.Id);
break;
}
}
}
}

logger.LogOpenAIAssistantProcessedRunMessages(nameof(InvokeAsync), activeMessages.Count, run!.Id, threadId);
logger.LogOpenAIAssistantProcessedRunMessages(nameof(InvokeAsync), stepsToProcess.Count, run!.Id, threadId);
}
}
while (run?.Status != RunStatus.Completed);
Expand All @@ -506,25 +531,27 @@ private static ChatMessageContent GenerateMessageContent(string? assistantName,
{
AuthorRole role = new(message.Role.ToString());

Dictionary<string, object?>? metaData =
completedStep != null ?
new Dictionary<string, object?>
{
{ nameof(completedStep.CreatedAt), completedStep.CreatedAt },
{ nameof(MessageContentUpdate.MessageId), message.Id },
{ nameof(RunStepDetailsUpdate.StepId), completedStep.Id },
{ nameof(completedStep.RunId), completedStep.RunId },
{ nameof(completedStep.ThreadId), completedStep.ThreadId },
{ nameof(completedStep.AssistantId), completedStep.AssistantId },
{ nameof(completedStep.Usage), completedStep.Usage },
} :
null;
Dictionary<string, object?>? metadata =
new()
{
{ nameof(ThreadMessage.CreatedAt), message.CreatedAt },
{ nameof(ThreadMessage.AssistantId), message.AssistantId },
{ nameof(ThreadMessage.ThreadId), message.ThreadId },
{ nameof(ThreadMessage.RunId), message.RunId },
{ nameof(MessageContentUpdate.MessageId), message.Id },
};

if (completedStep != null)
{
metadata[nameof(RunStepDetailsUpdate.StepId)] = completedStep.Id;
metadata[nameof(RunStep.Usage)] = completedStep.Usage;
}

ChatMessageContent content =
new(role, content: null)
{
AuthorName = assistantName,
Metadata = metaData,
Metadata = metadata,
};

foreach (MessageContent itemContent in message.Content)
Expand Down Expand Up @@ -655,8 +682,11 @@ private static StreamingAnnotationContent GenerateStreamingAnnotationContent(Tex
};
}

private static ChatMessageContent GenerateCodeInterpreterContent(string agentName, string pythonCode)
private static ChatMessageContent GenerateCodeInterpreterContent(string agentName, string pythonCode, RunStep completedStep)
{
Dictionary<string, object?> metadata = GenerateToolCallMetadata(completedStep);
metadata[OpenAIAssistantAgent.CodeInterpreterMetadataKey] = true;

return
new ChatMessageContent(
AuthorRole.Assistant,
Expand All @@ -665,7 +695,7 @@ private static ChatMessageContent GenerateCodeInterpreterContent(string agentNam
])
{
AuthorName = agentName,
Metadata = new Dictionary<string, object?> { { OpenAIAssistantAgent.CodeInterpreterMetadataKey, true } },
Metadata = metadata,
};
}

Expand Down Expand Up @@ -713,11 +743,12 @@ private static ChatMessageContent GenerateFunctionCallContent(string agentName,
return functionCallContent;
}

private static ChatMessageContent GenerateFunctionResultContent(string agentName, FunctionResultContent[] functionResults)
private static ChatMessageContent GenerateFunctionResultContent(string agentName, FunctionResultContent[] functionResults, RunStep completedStep)
{
ChatMessageContent functionResultContent = new(AuthorRole.Tool, content: null)
{
AuthorName = agentName
AuthorName = agentName,
Metadata = GenerateToolCallMetadata(completedStep),
};

foreach (FunctionResultContent functionResult in functionResults)
Expand All @@ -733,6 +764,19 @@ private static ChatMessageContent GenerateFunctionResultContent(string agentName
return functionResultContent;
}

private static Dictionary<string, object?> GenerateToolCallMetadata(RunStep completedStep)
{
return new()
{
{ nameof(RunStep.CreatedAt), completedStep.CreatedAt },
{ nameof(RunStep.AssistantId), completedStep.AssistantId },
{ nameof(RunStep.ThreadId), completedStep.ThreadId },
{ nameof(RunStep.RunId), completedStep.RunId },
{ nameof(RunStepDetailsUpdate.StepId), completedStep.Id },
{ nameof(RunStep.Usage), completedStep.Usage },
};
}

private static Task<FunctionResultContent>[] ExecuteFunctionSteps(OpenAIAssistantAgent agent, FunctionCallContent[] functionCalls, CancellationToken cancellationToken)
{
Task<FunctionResultContent>[] functionTasks = new Task<FunctionResultContent>[functionCalls.Length];
Expand Down
Loading

0 comments on commit fb4c9bb

Please sign in to comment.