From 2fdd74e18d765260da663273708bd38134450ff1 Mon Sep 17 00:00:00 2001 From: Evan Mattson <35585003+moonbox3@users.noreply.github.com> Date: Thu, 31 Oct 2024 10:11:48 -0400 Subject: [PATCH 1/8] Python: Bump Python version to 1.13.0 for a release. (#9480) ### Motivation and Context Bump Python version to 1.13.0 for a release. ### Description Bump Python version to 1.13.0 for a release. ### Contribution Checklist - [X] The code builds clean without any errors or warnings - [X] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [X] All unit tests pass, and I have added new tests where possible - [X] I didn't break anyone :smile: --- python/semantic_kernel/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/semantic_kernel/__init__.py b/python/semantic_kernel/__init__.py index 545ce3e87884..2910ffe9ce06 100644 --- a/python/semantic_kernel/__init__.py +++ b/python/semantic_kernel/__init__.py @@ -2,5 +2,5 @@ from semantic_kernel.kernel import Kernel -__version__ = "1.12.1" +__version__ = "1.13.0" __all__ = ["Kernel", "__version__"] From 44b6762f4524707c0f14be73eb83d1b7c127189e Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Thu, 31 Oct 2024 08:02:24 -0700 Subject: [PATCH 2/8] .Net: Added a streaming flag to filter context models (#9482) ### Motivation and Context Resolves: https://github.com/microsoft/semantic-kernel/issues/7336 This PR adds a boolean flag to identify whether a filter is invoked within streaming or non-streaming mode. This provides an ability to use the same filter for both scenarios and access filter context data in easier way based on used mode. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --- .../Filtering/FunctionInvocationFiltering.cs | 90 ++++++++++++++++--- .../Filtering/TelemetryWithFilters.cs | 76 +++++++++++++--- .../Client/MistralClientTests.cs | 43 +++++++++ .../Client/MistralClient.cs | 6 +- .../Core/AutoFunctionInvocationFilterTests.cs | 57 +++++++++++- .../Core/ClientCore.ChatCompletion.cs | 4 + .../FunctionCalling/FunctionCallsProcessor.cs | 6 +- .../AutoFunctionInvocationContext.cs | 5 ++ .../Function/FunctionInvocationContext.cs | 7 ++ .../Filters/Prompt/PromptRenderContext.cs | 7 ++ .../Functions/KernelFunction.cs | 4 +- .../src/SemanticKernel.Abstractions/Kernel.cs | 8 +- .../Functions/KernelFunctionFromPrompt.cs | 20 ++++- .../Filters/FunctionInvocationFilterTests.cs | 31 +++++++ .../Filters/PromptRenderFilterTests.cs | 32 +++++++ .../FunctionCallsProcessorTests.cs | 16 ++++ 16 files changed, 375 insertions(+), 37 deletions(-) diff --git a/dotnet/samples/Concepts/Filtering/FunctionInvocationFiltering.cs b/dotnet/samples/Concepts/Filtering/FunctionInvocationFiltering.cs index e1bbd1561463..48b9763da081 100644 --- a/dotnet/samples/Concepts/Filtering/FunctionInvocationFiltering.cs +++ b/dotnet/samples/Concepts/Filtering/FunctionInvocationFiltering.cs @@ -63,26 +63,63 @@ public async Task FunctionFilterResultOverrideOnStreamingAsync() { var builder = Kernel.CreateBuilder(); - // This filter overrides streaming results with "item * 2" logic. + // This filter overrides streaming results with new ending in each chunk. builder.Services.AddSingleton(); var kernel = builder.Build(); - static async IAsyncEnumerable GetData() + static async IAsyncEnumerable GetData() { - yield return 1; - yield return 2; - yield return 3; + yield return "chunk1"; + yield return "chunk2"; + yield return "chunk3"; } var function = KernelFunctionFactory.CreateFromMethod(GetData); - await foreach (var item in kernel.InvokeStreamingAsync(function)) + await foreach (var item in kernel.InvokeStreamingAsync(function)) { Console.WriteLine(item); } - // Output: 2, 4, 6. + // Output: + // chunk1 - updated from filter + // chunk2 - updated from filter + // chunk3 - updated from filter + } + + [Fact] + public async Task FunctionFilterResultOverrideForBothStreamingAndNonStreamingAsync() + { + var builder = Kernel.CreateBuilder(); + + // This filter overrides result for both streaming and non-streaming invocation modes. + builder.Services.AddSingleton(); + + var kernel = builder.Build(); + + static async IAsyncEnumerable GetData() + { + yield return "chunk1"; + yield return "chunk2"; + yield return "chunk3"; + } + + var nonStreamingFunction = KernelFunctionFactory.CreateFromMethod(() => "Result"); + var streamingFunction = KernelFunctionFactory.CreateFromMethod(GetData); + + var nonStreamingResult = await kernel.InvokeAsync(nonStreamingFunction); + var streamingResult = await kernel.InvokeStreamingAsync(streamingFunction).ToListAsync(); + + Console.WriteLine($"Non-streaming result: {nonStreamingResult}"); + Console.WriteLine($"Streaming result \n: {string.Join("\n", streamingResult)}"); + + // Output: + // Non-streaming result: Result - updated from filter + // Streaming result: + // chunk1 - updated from filter + // chunk2 - updated from filter + // chunk3 - updated from filter } [Fact] @@ -172,16 +209,16 @@ public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, F // In streaming scenario, async enumerable is available in context result object. // To override data: get async enumerable from function result, override data and set new async enumerable in context result: - var enumerable = context.Result.GetValue>(); + var enumerable = context.Result.GetValue>(); context.Result = new FunctionResult(context.Result, OverrideStreamingDataAsync(enumerable!)); } - private async IAsyncEnumerable OverrideStreamingDataAsync(IAsyncEnumerable data) + private async IAsyncEnumerable OverrideStreamingDataAsync(IAsyncEnumerable data) { await foreach (var item in data) { // Example: override streaming data - yield return item * 2; + yield return $"{item} - updated from filter"; } } } @@ -255,6 +292,39 @@ private async IAsyncEnumerable StreamingWithExceptionHandlingAsync(IAsyn } } + /// Filter that can be used for both streaming and non-streaming invocation modes at the same time. + private sealed class DualModeFilter : IFunctionInvocationFilter + { + public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func next) + { + await next(context); + + if (context.IsStreaming) + { + var enumerable = context.Result.GetValue>(); + context.Result = new FunctionResult(context.Result, OverrideStreamingDataAsync(enumerable!)); + } + else + { + var data = context.Result.GetValue(); + context.Result = new FunctionResult(context.Result, OverrideNonStreamingData(data!)); + } + } + + private async IAsyncEnumerable OverrideStreamingDataAsync(IAsyncEnumerable data) + { + await foreach (var item in data) + { + yield return $"{item} - updated from filter"; + } + } + + private string OverrideNonStreamingData(string data) + { + return $"{data} - updated from filter"; + } + } + #endregion #region Filters diff --git a/dotnet/samples/Concepts/Filtering/TelemetryWithFilters.cs b/dotnet/samples/Concepts/Filtering/TelemetryWithFilters.cs index 6823f6c14820..0b5938d1761c 100644 --- a/dotnet/samples/Concepts/Filtering/TelemetryWithFilters.cs +++ b/dotnet/samples/Concepts/Filtering/TelemetryWithFilters.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Diagnostics; +using System.Text; using System.Text.Json; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; @@ -17,8 +18,10 @@ namespace Filtering; /// public class TelemetryWithFilters(ITestOutputHelper output) : BaseTest(output) { - [Fact] - public async Task LoggingAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task LoggingAsync(bool isStreaming) { // Initialize kernel with chat completion service. var builder = Kernel @@ -69,9 +72,25 @@ public async Task LoggingAsync() { // Invoke prompt with arguments. const string Prompt = "Given the current time of day and weather, what is the likely color of the sky in {{$city}}?"; - var result = await kernel.InvokePromptAsync(Prompt, new(executionSettings) { ["city"] = "Boston" }); - Console.WriteLine(result); + var arguments = new KernelArguments(executionSettings) { ["city"] = "Boston" }; + + if (isStreaming) + { + await foreach (var item in kernel.InvokePromptStreamingAsync(Prompt, arguments)) + { + if (item.Content is not null) + { + Console.Write(item.Content); + } + } + } + else + { + var result = await kernel.InvokePromptAsync(Prompt, arguments); + + Console.WriteLine(result); + } } // Output: @@ -127,17 +146,8 @@ public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, F await next(context); logger.LogInformation("Function {FunctionName} succeeded.", context.Function.Name); - logger.LogTrace("Function result: {Result}", context.Result.ToString()); - if (logger.IsEnabled(LogLevel.Information)) - { - var usage = context.Result.Metadata?["Usage"]; - - if (usage is not null) - { - logger.LogInformation("Usage: {Usage}", JsonSerializer.Serialize(usage)); - } - } + await this.LogFunctionResultAsync(context); } catch (Exception exception) { @@ -156,6 +166,44 @@ public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, F } } } + + private async Task LogFunctionResultAsync(FunctionInvocationContext context) + { + string? result = null; + object? usage = null; + + if (context.IsStreaming) + { + var stringBuilder = new StringBuilder(); + + await foreach (var item in context.Result.GetValue>()!) + { + if (item.Content is not null) + { + stringBuilder.Append(item.Content); + } + + usage = item.Metadata?["Usage"]; + } + + result = stringBuilder.ToString(); + } + else + { + result = context.Result.GetValue(); + usage = context.Result.Metadata?["Usage"]; + } + + if (result is not null) + { + logger.LogTrace("Function result: {Result}", result); + } + + if (logger.IsEnabled(LogLevel.Information) && usage is not null) + { + logger.LogInformation("Usage: {Usage}", JsonSerializer.Serialize(usage)); + } + } } /// diff --git a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs index fbd082eb077c..37e00ec56154 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs @@ -357,6 +357,49 @@ public async Task ValidateGetChatMessageContentsWithFunctionInvocationFilterAsyn Assert.Contains("GetWeather", invokedFunctions); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming) + { + // Arrange + bool? actualStreamingFlag = null; + + var client = isStreaming ? + this.CreateMistralClientStreaming("mistral-tiny", "https://api.mistral.ai/v1/chat/completions", "chat_completions_streaming_function_call_response.txt") : + this.CreateMistralClient("mistral-large-latest", "https://api.mistral.ai/v1/chat/completions", "chat_completions_function_call_response.json", "chat_completions_function_called_response.json"); + + var kernel = new Kernel(); + kernel.Plugins.AddFromType(); + + var filter = new FakeAutoFunctionFilter(async (context, next) => + { + actualStreamingFlag = context.IsStreaming; + await next(context); + }); + + kernel.AutoFunctionInvocationFilters.Add(filter); + + // Act + var executionSettings = new MistralAIPromptExecutionSettings { ToolCallBehavior = MistralAIToolCallBehavior.AutoInvokeKernelFunctions }; + var chatHistory = new ChatHistory + { + new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?") + }; + + if (isStreaming) + { + await client.GetStreamingChatMessageContentsAsync(chatHistory, default, executionSettings, kernel).ToListAsync(); + } + else + { + await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel); + } + + // Assert + Assert.Equal(isStreaming, actualStreamingFlag); + } + [Fact] public async Task ValidateGetChatMessageContentsWithAutoFunctionInvocationFilterTerminateAsync() { diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs index 532bc94e6150..9157073b244c 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs @@ -178,7 +178,8 @@ internal async Task> GetChatMessageContentsAsy RequestSequenceIndex = requestIndex - 1, FunctionSequenceIndex = toolCallIndex, FunctionCount = chatChoice.ToolCalls.Count, - CancellationToken = cancellationToken + CancellationToken = cancellationToken, + IsStreaming = false }; s_inflightAutoInvokes.Value++; try @@ -408,7 +409,8 @@ internal async IAsyncEnumerable GetStreamingChatMes RequestSequenceIndex = requestIndex - 1, FunctionSequenceIndex = toolCallIndex, FunctionCount = toolCalls.Count, - CancellationToken = cancellationToken + CancellationToken = cancellationToken, + IsStreaming = true }; s_inflightAutoInvokes.Value++; try diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs index 5df2fb54cdb5..e15b1f74b042 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs @@ -126,7 +126,6 @@ public async Task FiltersAreExecutedCorrectlyOnStreamingAsync() public async Task DifferentWaysOfAddingFiltersWorkCorrectlyAsync() { // Arrange - var function = KernelFunctionFactory.CreateFromMethod(() => "Result"); var executionOrder = new List(); var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1"); @@ -183,7 +182,6 @@ public async Task DifferentWaysOfAddingFiltersWorkCorrectlyAsync() public async Task MultipleFiltersAreExecutedInOrderAsync(bool isStreaming) { // Arrange - var function = KernelFunctionFactory.CreateFromMethod(() => "Result"); var executionOrder = new List(); var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1"); @@ -573,6 +571,61 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync() Assert.Equal(AuthorRole.Tool, lastMessageContent.Role); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming) + { + // Arrange + bool? actualStreamingFlag = null; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var filter = new AutoFunctionInvocationFilter(async (context, next) => + { + actualStreamingFlag = context.IsStreaming; + await next(context); + }); + + var builder = Kernel.CreateBuilder(); + + builder.Plugins.Add(plugin); + + builder.Services.AddSingleton((serviceProvider) => + { + return new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient); + }); + + builder.Services.AddSingleton(filter); + + var kernel = builder.Build(); + + var arguments = new KernelArguments(new OpenAIPromptExecutionSettings + { + ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + }); + + // Act + if (isStreaming) + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + await kernel.InvokePromptStreamingAsync("Test prompt", arguments).ToListAsync(); + } + else + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + await kernel.InvokePromptAsync("Test prompt", arguments); + } + + // Assert + Assert.Equal(isStreaming, actualStreamingFlag); + } + public void Dispose() { this._httpClient.Dispose(); diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs index ac52d8361307..ff160b0dfcaf 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs @@ -209,7 +209,9 @@ internal async Task> GetChatMessageContentsAsy (FunctionCallContent content) => IsRequestableTool(chatOptions.Tools, content), functionCallingConfig.Options ?? new FunctionChoiceBehaviorOptions(), kernel, + isStreaming: false, cancellationToken).ConfigureAwait(false); + if (lastMessage != null) { return [lastMessage]; @@ -388,7 +390,9 @@ internal async IAsyncEnumerable GetStreamingC (FunctionCallContent content) => IsRequestableTool(chatOptions.Tools, content), functionCallingConfig.Options ?? new FunctionChoiceBehaviorOptions(), kernel, + isStreaming: true, cancellationToken).ConfigureAwait(false); + if (lastMessage != null) { yield return new OpenAIStreamingChatMessageContent(lastMessage.Role, lastMessage.Content); diff --git a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs index ffce85072e6b..a1c92b842669 100644 --- a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs +++ b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs @@ -131,6 +131,7 @@ public FunctionCallsProcessor(ILogger? logger = null) /// Callback to check if a function was advertised to AI model or not. /// Function choice behavior options. /// The . + /// Boolean flag which indicates whether an operation is invoked within streaming or non-streaming mode. /// The to monitor for cancellation requests. /// Last chat history message if function invocation filter requested processing termination, otherwise null. public async Task ProcessFunctionCallsAsync( @@ -140,6 +141,7 @@ public FunctionCallsProcessor(ILogger? logger = null) Func checkIfFunctionAdvertised, FunctionChoiceBehaviorOptions options, Kernel? kernel, + bool isStreaming, CancellationToken cancellationToken) { var functionCalls = FunctionCallContent.GetFunctionCalls(chatMessageContent).ToList(); @@ -201,7 +203,9 @@ public FunctionCallsProcessor(ILogger? logger = null) Arguments = functionCall.Arguments, RequestSequenceIndex = requestIndex, FunctionSequenceIndex = functionCallIndex, - FunctionCount = functionCalls.Count + FunctionCount = functionCalls.Count, + CancellationToken = cancellationToken, + IsStreaming = isStreaming }; var functionTask = Task.Run<(string? Result, string? ErrorMessage, FunctionCallContent FunctionCall, bool Terminate)>(async () => diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index 6710d18070a8..68be900e1389 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -46,6 +46,11 @@ public AutoFunctionInvocationContext( /// public CancellationToken CancellationToken { get; init; } + /// + /// Boolean flag which indicates whether a filter is invoked within streaming or non-streaming mode. + /// + public bool IsStreaming { get; init; } + /// /// Gets the arguments associated with the operation. /// diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/Function/FunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/Function/FunctionInvocationContext.cs index 2c7e92166ed0..a358c1a3d22f 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/Function/FunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/Function/FunctionInvocationContext.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Diagnostics.CodeAnalysis; using System.Threading; namespace Microsoft.SemanticKernel; @@ -34,6 +35,12 @@ internal FunctionInvocationContext(Kernel kernel, KernelFunction function, Kerne /// public CancellationToken CancellationToken { get; init; } + /// + /// Boolean flag which indicates whether a filter is invoked within streaming or non-streaming mode. + /// + [Experimental("SKEXP0001")] + public bool IsStreaming { get; init; } + /// /// Gets the containing services, plugins, and other state for use throughout the operation. /// diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs index ee64d0a01f09..2b04e9afc540 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Diagnostics.CodeAnalysis; using System.Threading; namespace Microsoft.SemanticKernel; @@ -33,6 +34,12 @@ internal PromptRenderContext(Kernel kernel, KernelFunction function, KernelArgum /// public CancellationToken CancellationToken { get; init; } + /// + /// Boolean flag which indicates whether a filter is invoked within streaming or non-streaming mode. + /// + [Experimental("SKEXP0001")] + public bool IsStreaming { get; init; } + /// /// Gets the containing services, plugins, and other state for use throughout the operation. /// diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs index 759e83235699..9c851bd2dfa0 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs @@ -250,7 +250,7 @@ public async Task InvokeAsync( throw new OperationCanceledException($"A {nameof(Kernel)}.{nameof(Kernel.FunctionInvoking)} event handler requested cancellation before function invocation."); } - var invocationContext = await kernel.OnFunctionInvocationAsync(this, arguments, functionResult, async (context) => + var invocationContext = await kernel.OnFunctionInvocationAsync(this, arguments, functionResult, isStreaming: false, async (context) => { // Invoking the function and updating context with result. context.Result = functionResult = await this.InvokeCoreAsync(kernel, context.Arguments, cancellationToken).ConfigureAwait(false); @@ -382,7 +382,7 @@ public async IAsyncEnumerable InvokeStreamingAsync( FunctionResult functionResult = new(this, culture: kernel.Culture); - var invocationContext = await kernel.OnFunctionInvocationAsync(this, arguments, functionResult, (context) => + var invocationContext = await kernel.OnFunctionInvocationAsync(this, arguments, functionResult, isStreaming: true, (context) => { // Invoke the function and get its streaming enumerable. var enumerable = this.InvokeStreamingCoreAsync(kernel, context.Arguments, cancellationToken); diff --git a/dotnet/src/SemanticKernel.Abstractions/Kernel.cs b/dotnet/src/SemanticKernel.Abstractions/Kernel.cs index 987766feda4f..5a44a4dffd6a 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Kernel.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Kernel.cs @@ -311,12 +311,14 @@ internal async Task OnFunctionInvocationAsync( KernelFunction function, KernelArguments arguments, FunctionResult functionResult, + bool isStreaming, Func functionCallback, CancellationToken cancellationToken) { FunctionInvocationContext context = new(this, function, arguments, functionResult) { - CancellationToken = cancellationToken + CancellationToken = cancellationToken, + IsStreaming = isStreaming }; await InvokeFilterOrFunctionAsync(this._functionInvocationFilters, functionCallback, context).ConfigureAwait(false); @@ -351,12 +353,14 @@ await functionFilters[index].OnFunctionInvocationAsync(context, internal async Task OnPromptRenderAsync( KernelFunction function, KernelArguments arguments, + bool isStreaming, Func renderCallback, CancellationToken cancellationToken) { PromptRenderContext context = new(this, function, arguments) { - CancellationToken = cancellationToken + CancellationToken = cancellationToken, + IsStreaming = isStreaming }; await InvokeFilterOrPromptRenderAsync(this._promptRenderFilters, renderCallback, context).ConfigureAwait(false); diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs index 2f17bb8fcadc..8652cfa1cbfe 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs @@ -235,7 +235,11 @@ protected override async ValueTask InvokeCoreAsync( { this.AddDefaultValues(arguments); - var promptRenderingResult = await this.RenderPromptAsync(kernel, arguments, cancellationToken).ConfigureAwait(false); + var promptRenderingResult = await this.RenderPromptAsync( + kernel, + arguments, + isStreaming: false, + cancellationToken).ConfigureAwait(false); #pragma warning disable CS0612 // Events are deprecated if (promptRenderingResult.RenderedEventArgs?.Cancel is true) @@ -268,7 +272,11 @@ protected override async IAsyncEnumerable InvokeStreamingCoreAsync RenderPromptAsync(Kernel kernel, KernelArguments arguments, CancellationToken cancellationToken) + private async Task RenderPromptAsync( + Kernel kernel, + KernelArguments arguments, + bool isStreaming, + CancellationToken cancellationToken) { var serviceSelector = kernel.ServiceSelector; @@ -506,7 +518,7 @@ private async Task RenderPromptAsync(Kernel kernel, Kerne kernel.OnPromptRendering(this, arguments); #pragma warning restore CS0618 // Events are deprecated - var renderingContext = await kernel.OnPromptRenderAsync(this, arguments, async (context) => + var renderingContext = await kernel.OnPromptRenderAsync(this, arguments, isStreaming, async (context) => { renderedPrompt = await this._promptTemplate.RenderAsync(kernel, context.Arguments, cancellationToken).ConfigureAwait(false); diff --git a/dotnet/src/SemanticKernel.UnitTests/Filters/FunctionInvocationFilterTests.cs b/dotnet/src/SemanticKernel.UnitTests/Filters/FunctionInvocationFilterTests.cs index 99d81af29c4e..6ad0d67aa04f 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Filters/FunctionInvocationFilterTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Filters/FunctionInvocationFilterTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Globalization; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; @@ -1052,4 +1053,34 @@ public async Task FilterContextHasCancellationTokenAsync() Assert.NotNull(exception.FunctionResult); Assert.Equal("Result", exception.FunctionResult.ToString()); } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming) + { + // Arrange + bool? actualStreamingFlag = null; + + var function = KernelFunctionFactory.CreateFromMethod(() => "Result"); + + var kernel = this.GetKernelWithFilters(onFunctionInvocation: async (context, next) => + { + actualStreamingFlag = context.IsStreaming; + await next(context); + }); + + // Act + if (isStreaming) + { + await kernel.InvokeStreamingAsync(function).ToListAsync(); + } + else + { + await kernel.InvokeAsync(function); + } + + // Assert + Assert.Equal(isStreaming, actualStreamingFlag); + } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs b/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs index 4cb0c46082b7..3a0f1e627bd6 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; @@ -289,4 +290,35 @@ public async Task FilterContextHasCancellationTokenAsync() await Assert.ThrowsAsync(() => kernel.InvokeAsync(function, cancellationToken: cancellationTokenSource.Token)); } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming) + { + // Arrange + bool? actualStreamingFlag = null; + + var mockTextGeneration = this.GetMockTextGeneration(); + + var kernel = this.GetKernelWithFilters(textGenerationService: mockTextGeneration.Object, + onPromptRender: async (context, next) => + { + actualStreamingFlag = context.IsStreaming; + await next(context); + }); + + // Act + if (isStreaming) + { + await kernel.InvokePromptStreamingAsync("Prompt").ToListAsync(); + } + else + { + await kernel.InvokePromptAsync("Prompt"); + } + + // Assert + Assert.Equal(isStreaming, actualStreamingFlag); + } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs index 41f92d08071c..a4111bc9b5c0 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs @@ -99,6 +99,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: kernel, + isStreaming: false, cancellationToken: CancellationToken.None); } @@ -127,6 +128,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: CreateKernel(), + isStreaming: false, cancellationToken: CancellationToken.None); // Assert @@ -154,6 +156,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: CreateKernel(), + isStreaming: false, cancellationToken: CancellationToken.None); // Assert @@ -186,6 +189,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: kernel, + isStreaming: false, cancellationToken: CancellationToken.None); // Assert @@ -213,6 +217,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => false, // Return false to simulate that the function is not advertised options: this._functionChoiceBehaviorOptions, kernel: CreateKernel(), + isStreaming: false, cancellationToken: CancellationToken.None); // Assert @@ -240,6 +245,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: CreateKernel(), + isStreaming: false, cancellationToken: CancellationToken.None); // Assert @@ -280,6 +286,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: kernel, + isStreaming: false, cancellationToken: CancellationToken.None); // Assert @@ -345,6 +352,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: kernel!, + isStreaming: false, cancellationToken: CancellationToken.None); // Assert @@ -433,6 +441,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: kernel!, + isStreaming: false, cancellationToken: CancellationToken.None); // Assert @@ -480,6 +489,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: kernel!, + isStreaming: false, cancellationToken: CancellationToken.None); // Assert @@ -531,6 +541,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: kernel!, + isStreaming: false, cancellationToken: CancellationToken.None); var firstFunctionResult = chatHistory[^2].Content; @@ -582,6 +593,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: kernel!, + isStreaming: false, cancellationToken: CancellationToken.None); // Assert @@ -627,6 +639,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: kernel!, + isStreaming: false, cancellationToken: CancellationToken.None); // Assert @@ -670,6 +683,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: kernel!, + isStreaming: false, cancellationToken: CancellationToken.None); // Assert @@ -723,6 +737,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: kernel, + isStreaming: false, cancellationToken: CancellationToken.None); // Assert @@ -757,6 +772,7 @@ await this._sut.ProcessFunctionCallsAsync( checkIfFunctionAdvertised: (_) => true, options: this._functionChoiceBehaviorOptions, kernel: kernel, + isStreaming: false, cancellationToken: CancellationToken.None); // Assert From d0145347948a0f33eab1c50ea0aeb1dcf3bff89a Mon Sep 17 00:00:00 2001 From: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> Date: Thu, 31 Oct 2024 19:01:06 +0000 Subject: [PATCH 3/8] .Net: Parallel function calls option (#9487) ### Motivation and Context This PR adds the `FunctionChoiceBehaviorOptions.AllowParallelCalls` option and updates {Azure} OpenAI AI connectors to support it. This option instructs the AI model to generate multiple function calls in a single response when set to true. _"This is especially useful if executing the given functions takes a long time. For example, the model may call functions to get the weather in three different locations at the same time, which will result in a message with three function calls in the tool_calls array."_ **Source** - [Configuring parallel function calling](https://platform.openai.com/docs/guides/function-calling/configuring-parallel-function-calling) Closes: https://github.com/microsoft/semantic-kernel/issues/6636 --- .../FunctionCalling/FunctionCalling.cs | 67 +++++++++++++++++++ .../AzureOpenAIChatCompletionServiceTests.cs | 55 +++++++++++++++ .../Core/AzureClientCore.ChatCompletion.cs | 5 ++ .../OpenAIChatCompletionServiceTests.cs | 52 ++++++++++++++ .../Core/ClientCore.ChatCompletion.cs | 7 +- ...omptExecutionSettingsTypeConverterTests.cs | 44 ++++++++++++ ...pletion_AutoFunctionChoiceBehaviorTests.cs | 47 +++++++++++++ ...pletion_AutoFunctionChoiceBehaviorTests.cs | 47 +++++++++++++ .../FunctionChoiceBehaviorOptions.cs | 12 +++- ...ctionChoiceBehaviorDeserializationTests.cs | 48 +++++++++++++ 10 files changed, 382 insertions(+), 2 deletions(-) diff --git a/dotnet/samples/Concepts/FunctionCalling/FunctionCalling.cs b/dotnet/samples/Concepts/FunctionCalling/FunctionCalling.cs index 70dbe2bdd0ef..9ce10a4ae5ea 100644 --- a/dotnet/samples/Concepts/FunctionCalling/FunctionCalling.cs +++ b/dotnet/samples/Concepts/FunctionCalling/FunctionCalling.cs @@ -46,7 +46,19 @@ namespace FunctionCalling; /// * The option enables concurrent invocation of functions by SK. /// By default, this option is set to false, meaning that functions are invoked sequentially. Concurrent invocation is only possible if the AI model can /// call or select multiple functions for invocation in a single request; otherwise, there is no distinction between sequential and concurrent invocation. +/// * The option instructs the AI model to call multiple functions in one request if the model supports parallel function calls. +/// By default, this option is set to null, meaning that the AI model default value will be used. /// +/// The following table summarizes the effects of different combinations of these options: +/// +/// | AllowParallelCalls | AllowConcurrentInvocation | AI function call requests | Concurrent Invocation | +/// |---------------------|---------------------------|--------------------------------|-----------------------| +/// | false | false | one request per call | false | +/// | false | true | one request per call | false* | +/// | true | false | one request per multiple calls | false | +/// | true | true | one request per multiple calls | true | +/// +/// `*` There's only one function to call /// public class FunctionCalling(ITestOutputHelper output) : BaseTest(output) { @@ -458,6 +470,61 @@ public async Task RunNonStreamingChatCompletionApiWithConcurrentFunctionInvocati // Expected output: Good morning! The current UTC time is 07:47 on October 22, 2024. Here are the latest news headlines: 1. Squirrel Steals Show - Discover the unexpected star of a recent event. 2. Dog Wins Lottery - Unbelievably, a lucky canine has hit the jackpot. } + [Fact] + /// + /// This example demonstrates usage of the non-streaming chat completion API with that + /// advertises all kernel functions to the AI model and instructs the model to call multiple functions in parallel. + /// + public async Task RunNonStreamingChatCompletionApiWithParallelFunctionCallOptionAsync() + { + Kernel kernel = CreateKernel(); + + // The `AllowParallelCalls` option instructs the AI model to call multiple functions in parallel if the model supports parallel function calls. + FunctionChoiceBehaviorOptions options = new() { AllowParallelCalls = true }; + + OpenAIPromptExecutionSettings settings = new() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(options: options) }; + + IChatCompletionService chatCompletionService = kernel.GetRequiredService(); + + ChatMessageContent result = await chatCompletionService.GetChatMessageContentAsync( + "Good morning! What’s the current time and latest news headlines?", + settings, + kernel); + + // Assert + Console.WriteLine(result); + + // Expected output: Good morning! The current UTC time is 07:47 on October 22, 2024. Here are the latest news headlines: 1. Squirrel Steals Show - Discover the unexpected star of a recent event. 2. Dog Wins Lottery - Unbelievably, a lucky canine has hit the jackpot. + } + + [Fact] + /// + /// This example demonstrates usage of the non-streaming chat completion API with that + /// advertises all kernel functions to the AI model, instructs the model to call multiple functions in parallel, and invokes them concurrently. + /// + public async Task RunNonStreamingChatCompletionApiWithParallelFunctionCallAndConcurrentFunctionInvocationOptionsAsync() + { + Kernel kernel = CreateKernel(); + + // The `AllowParallelCalls` option instructs the AI model to call multiple functions in parallel if the model supports parallel function calls. + // The `AllowConcurrentInvocation` option enables concurrent invocation of the functions. + FunctionChoiceBehaviorOptions options = new() { AllowParallelCalls = true, AllowConcurrentInvocation = true }; + + OpenAIPromptExecutionSettings settings = new() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(options: options) }; + + IChatCompletionService chatCompletionService = kernel.GetRequiredService(); + + ChatMessageContent result = await chatCompletionService.GetChatMessageContentAsync( + "Good morning! What’s the current time and latest news headlines?", + settings, + kernel); + + // Assert + Console.WriteLine(result); + + // Expected output: Good morning! The current UTC time is 07:47 on October 22, 2024. Here are the latest news headlines: 1. Squirrel Steals Show - Discover the unexpected star of a recent event. 2. Dog Wins Lottery - Unbelievably, a lucky canine has hit the jackpot. + } + private static Kernel CreateKernel() { // Create kernel diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs index 995d8c7e4913..074018f14fe6 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs @@ -1342,6 +1342,61 @@ public async Task ItCreatesCorrectFunctionToolCallsWhenUsingRequiredFunctionChoi Assert.Equal("required", optionsJson.GetProperty("tool_choice").ToString()); } + [Theory] + [InlineData("auto", true)] + [InlineData("auto", false)] + [InlineData("auto", null)] + [InlineData("required", true)] + [InlineData("required", false)] + [InlineData("required", null)] + public async Task ItPassesAllowParallelCallsOptionToLLMAsync(string choice, bool? optionValue) + { + // Arrange + var kernel = new Kernel(); + kernel.Plugins.AddFromFunctions("TimePlugin", [ + KernelFunctionFactory.CreateFromMethod(() => { }, "Date"), + KernelFunctionFactory.CreateFromMethod(() => { }, "Now") + ]); + + var sut = new AzureOpenAIChatCompletionService("deployment", "https://endpoint", "api-key", "model-id", this._httpClient); + + using var responseMessage = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(AzureOpenAITestHelper.GetTestResponse("chat_completion_test_response.json")) + }; + this._messageHandlerStub.ResponsesToReturn.Add(responseMessage); + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Fake prompt"); + + var functionChoiceBehaviorOptions = new FunctionChoiceBehaviorOptions() { AllowParallelCalls = optionValue }; + + var executionSettings = new OpenAIPromptExecutionSettings() + { + FunctionChoiceBehavior = choice switch + { + "auto" => FunctionChoiceBehavior.Auto(options: functionChoiceBehaviorOptions), + "required" => FunctionChoiceBehavior.Required(options: functionChoiceBehaviorOptions), + _ => throw new ArgumentException("Invalid choice", nameof(choice)) + } + }; + + // Act + await sut.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel); + + // Assert + var optionsJson = JsonSerializer.Deserialize(Encoding.UTF8.GetString(this._messageHandlerStub.RequestContents[0]!)); + + if (optionValue is null) + { + Assert.False(optionsJson.TryGetProperty("parallel_tool_calls", out _)); + } + else + { + Assert.Equal(optionValue, optionsJson.GetProperty("parallel_tool_calls").GetBoolean()); + } + } + [Fact] public async Task ItDoesNotChangeDefaultsForToolsAndChoiceIfNeitherOfFunctionCallingConfigurationsSetAsync() { diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI/Core/AzureClientCore.ChatCompletion.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI/Core/AzureClientCore.ChatCompletion.cs index 6627b7482fae..63d46c7c77e2 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI/Core/AzureClientCore.ChatCompletion.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI/Core/AzureClientCore.ChatCompletion.cs @@ -90,6 +90,11 @@ protected override ChatCompletionOptions CreateChatCompletionOptions( } } + if (toolCallingConfig.Options?.AllowParallelCalls is not null) + { + options.AllowParallelToolCalls = toolCallingConfig.Options.AllowParallelCalls; + } + return options; } } diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs index 943e8e577b7d..80b2ad0331c2 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs @@ -1371,6 +1371,58 @@ public async Task ItCreatesCorrectFunctionToolCallsWhenUsingRequiredFunctionChoi Assert.Equal("required", optionsJson.GetProperty("tool_choice").ToString()); } + [Theory] + [InlineData("auto", true)] + [InlineData("auto", false)] + [InlineData("auto", null)] + [InlineData("required", true)] + [InlineData("required", false)] + [InlineData("required", null)] + public async Task ItPassesAllowParallelCallsOptionToLLMAsync(string choice, bool? optionValue) + { + // Arrange + var kernel = new Kernel(); + kernel.Plugins.AddFromFunctions("TimePlugin", [ + KernelFunctionFactory.CreateFromMethod(() => { }, "Date"), + KernelFunctionFactory.CreateFromMethod(() => { }, "Now") + ]); + + var chatCompletion = new OpenAIChatCompletionService(modelId: "gpt-3.5-turbo", apiKey: "NOKEY", httpClient: this._httpClient); + + using var response = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/chat_completion_test_response.json")) }; + this._messageHandlerStub.ResponseQueue.Enqueue(response); + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Fake prompt"); + + var functionChoiceBehaviorOptions = new FunctionChoiceBehaviorOptions() { AllowParallelCalls = optionValue }; + + var executionSettings = new OpenAIPromptExecutionSettings() + { + FunctionChoiceBehavior = choice switch + { + "auto" => FunctionChoiceBehavior.Auto(options: functionChoiceBehaviorOptions), + "required" => FunctionChoiceBehavior.Required(options: functionChoiceBehaviorOptions), + _ => throw new ArgumentException("Invalid choice", nameof(choice)) + } + }; + + // Act + await chatCompletion.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel); + + // Assert + var optionsJson = JsonSerializer.Deserialize(Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!)); + + if (optionValue is null) + { + Assert.False(optionsJson.TryGetProperty("parallel_tool_calls", out _)); + } + else + { + Assert.Equal(optionValue, optionsJson.GetProperty("parallel_tool_calls").GetBoolean()); + } + } + [Fact] public async Task ItDoesNotChangeDefaultsForToolsAndChoiceIfNeitherOfFunctionCallingConfigurationsSetAsync() { diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs index ff160b0dfcaf..7017ca1eb929 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs @@ -468,7 +468,7 @@ protected virtual ChatCompletionOptions CreateChatCompletionOptions( #pragma warning restore OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. EndUserId = executionSettings.User, TopLogProbabilityCount = executionSettings.TopLogprobs, - IncludeLogProbabilities = executionSettings.Logprobs, + IncludeLogProbabilities = executionSettings.Logprobs }; var responseFormat = GetResponseFormat(executionSettings); @@ -503,6 +503,11 @@ protected virtual ChatCompletionOptions CreateChatCompletionOptions( } } + if (toolCallingConfig.Options?.AllowParallelCalls is not null) + { + options.AllowParallelToolCalls = toolCallingConfig.Options.AllowParallelCalls; + } + return options; } diff --git a/dotnet/src/Functions/Functions.UnitTests/Yaml/PromptExecutionSettingsTypeConverterTests.cs b/dotnet/src/Functions/Functions.UnitTests/Yaml/PromptExecutionSettingsTypeConverterTests.cs index d8c927393ca4..45334b1f39f4 100644 --- a/dotnet/src/Functions/Functions.UnitTests/Yaml/PromptExecutionSettingsTypeConverterTests.cs +++ b/dotnet/src/Functions/Functions.UnitTests/Yaml/PromptExecutionSettingsTypeConverterTests.cs @@ -294,6 +294,50 @@ public void ItShouldDeserializedNoneFunctionChoiceBehaviorFromYamlWithSpecifiedF Assert.Contains(config.Functions, f => f.PluginName == "MyPlugin" && f.Name == "Function3"); } + [Fact] + public void ItShouldDeserializeAutoFunctionChoiceBehaviorFromJsonWithOptions() + { + // Arrange + var yaml = """ + function_choice_behavior: + type: auto + options: + allow_parallel_calls: true + allow_concurrent_invocation: true + """; + + var executionSettings = this._deserializer.Deserialize(yaml); + + // Act + var config = executionSettings!.FunctionChoiceBehavior!.GetConfiguration(new(chatHistory: []) { Kernel = this._kernel }); + + // Assert + Assert.True(config.Options.AllowParallelCalls); + Assert.True(config.Options.AllowConcurrentInvocation); + } + + [Fact] + public void ItShouldDeserializeRequiredFunctionChoiceBehaviorFromJsonWithOptions() + { + // Arrange + var yaml = """ + function_choice_behavior: + type: required + options: + allow_parallel_calls: true + allow_concurrent_invocation: true + """; + + var executionSettings = this._deserializer.Deserialize(yaml); + + // Act + var config = executionSettings!.FunctionChoiceBehavior!.GetConfiguration(new(chatHistory: []) { Kernel = this._kernel }); + + // Assert + Assert.True(config.Options.AllowParallelCalls); + Assert.True(config.Options.AllowConcurrentInvocation); + } + private readonly string _yaml = """ template_format: semantic-kernel template: Say hello world to {{$name}} in {{$language}} diff --git a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_AutoFunctionChoiceBehaviorTests.cs b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_AutoFunctionChoiceBehaviorTests.cs index 32321fb81da9..e3ecebadf687 100644 --- a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_AutoFunctionChoiceBehaviorTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_AutoFunctionChoiceBehaviorTests.cs @@ -350,6 +350,53 @@ public async Task SpecifiedInCodeInstructsConnectorToInvokeKernelFunctionsAutoma Assert.True(requestIndexLog.All((item) => item == 0)); // Assert that all functions called by the AI model were executed within the same initial request. } + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task SpecifiedInCodeInstructsAIModelToCallFunctionInParallelOrSequentiallyAsync(bool callInParallel) + { + // Arrange + var requestIndexLog = new ConcurrentBag(); + + this._kernel.ImportPluginFromType(); + this._kernel.ImportPluginFromFunctions("WeatherUtils", [KernelFunctionFactory.CreateFromMethod(() => "Rainy day magic!", "GetCurrentWeather")]); + + var invokedFunctions = new ConcurrentBag(); + + this._autoFunctionInvocationFilter.RegisterFunctionInvocationHandler(async (context, next) => + { + requestIndexLog.Add(context.RequestSequenceIndex); + invokedFunctions.Add(context.Function.Name); + + await next(context); + }); + + var settings = new AzureOpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(options: new() { AllowParallelCalls = callInParallel }) }; + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Give me today's date and weather."); + + // Act + var result = await this._chatCompletionService.GetChatMessageContentAsync(chatHistory, settings, this._kernel); + + // Assert + Assert.NotNull(result); + + Assert.Contains("GetCurrentDate", invokedFunctions); + Assert.Contains("GetCurrentWeather", invokedFunctions); + + if (callInParallel) + { + // Assert that all functions are called within the same initial request. + Assert.True(requestIndexLog.All((item) => item == 0)); + } + else + { + // Assert that all functions are called in separate requests. + Assert.Equal([0, 1], requestIndexLog); + } + } + private Kernel InitializeKernel() { var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get(); diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletion_AutoFunctionChoiceBehaviorTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletion_AutoFunctionChoiceBehaviorTests.cs index ab030369ab42..f98918d08eaf 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletion_AutoFunctionChoiceBehaviorTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletion_AutoFunctionChoiceBehaviorTests.cs @@ -347,6 +347,53 @@ public async Task SpecifiedInCodeInstructsConnectorToInvokeKernelFunctionsAutoma Assert.True(requestIndexLog.All((item) => item == 0)); // Assert that all functions called by the AI model were executed within the same initial request. } + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task SpecifiedInCodeInstructsAIModelToCallFunctionInParallelOrSequentiallyAsync(bool callInParallel) + { + // Arrange + var requestIndexLog = new ConcurrentBag(); + + this._kernel.ImportPluginFromType(); + this._kernel.ImportPluginFromFunctions("WeatherUtils", [KernelFunctionFactory.CreateFromMethod(() => "Rainy day magic!", "GetCurrentWeather")]); + + var invokedFunctions = new ConcurrentBag(); + + this._autoFunctionInvocationFilter.RegisterFunctionInvocationHandler(async (context, next) => + { + requestIndexLog.Add(context.RequestSequenceIndex); + invokedFunctions.Add(context.Function.Name); + + await next(context); + }); + + var settings = new OpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(options: new() { AllowParallelCalls = callInParallel }) }; + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Give me today's date and weather."); + + // Act + var result = await this._chatCompletionService.GetChatMessageContentAsync(chatHistory, settings, this._kernel); + + // Assert + Assert.NotNull(result); + + Assert.Contains("GetCurrentDate", invokedFunctions); + Assert.Contains("GetCurrentWeather", invokedFunctions); + + if (callInParallel) + { + // Assert that all functions are called within the same initial request. + Assert.True(requestIndexLog.All((item) => item == 0)); + } + else + { + // Assert that all functions are called in separate requests. + Assert.Equal([0, 1], requestIndexLog); + } + } + private Kernel InitializeKernel() { var openAIConfiguration = this._configuration.GetSection("OpenAI").Get(); diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/FunctionChoiceBehaviors/FunctionChoiceBehaviorOptions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/FunctionChoiceBehaviors/FunctionChoiceBehaviorOptions.cs index 870cc75616ec..ecb3988b9611 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/FunctionChoiceBehaviors/FunctionChoiceBehaviorOptions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/FunctionChoiceBehaviors/FunctionChoiceBehaviorOptions.cs @@ -11,7 +11,17 @@ namespace Microsoft.SemanticKernel; [Experimental("SKEXP0001")] public sealed class FunctionChoiceBehaviorOptions { - /// Gets or sets whether multiple function invocations requested in parallel by the service may be invoked to run concurrently. + /// + /// Gets or sets whether AI model should prefer parallel function calls over sequential ones. + /// If set to true, instructs the model to call multiple functions in one request if the model supports parallel function calls. + /// Otherwise, it will send a request for each function call. If set to null, the AI model default value will be used. + /// + [JsonPropertyName("allow_parallel_calls")] + public bool? AllowParallelCalls { get; set; } = null; + + /// + /// Gets or sets whether multiple function invocations requested in parallel by the service may be invoked to run concurrently. + /// /// /// The default value is set to false. However, if the function invocations are safe to execute concurrently, /// such as when the function does not modify shared state, this setting can be set to true. diff --git a/dotnet/src/SemanticKernel.UnitTests/AI/FunctionChoiceBehaviors/FunctionChoiceBehaviorDeserializationTests.cs b/dotnet/src/SemanticKernel.UnitTests/AI/FunctionChoiceBehaviors/FunctionChoiceBehaviorDeserializationTests.cs index 197640eca0f0..1d8b239f3ee2 100644 --- a/dotnet/src/SemanticKernel.UnitTests/AI/FunctionChoiceBehaviors/FunctionChoiceBehaviorDeserializationTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/AI/FunctionChoiceBehaviors/FunctionChoiceBehaviorDeserializationTests.cs @@ -270,6 +270,54 @@ public void ItShouldDeserializedNoneFunctionChoiceBehaviorFromJsonWithNotEmptyFu Assert.Contains(config.Functions, f => f.PluginName == "MyPlugin" && f.Name == "Function3"); } + [Fact] + public void ItShouldDeserializeAutoFunctionChoiceBehaviorFromJsonWithOptions() + { + // Arrange + var json = """ + { + "type": "auto", + "options": { + "allow_parallel_calls": true, + "allow_concurrent_invocation": true + } + } + """; + + var sut = JsonSerializer.Deserialize(json); + + // Act + var config = sut!.GetConfiguration(new(chatHistory: []) { Kernel = this._kernel }); + + // Assert + Assert.True(config.Options.AllowParallelCalls); + Assert.True(config.Options.AllowConcurrentInvocation); + } + + [Fact] + public void ItShouldDeserializeRequiredFunctionChoiceBehaviorFromJsonWithOptions() + { + // Arrange + var json = """ + { + "type": "required", + "options": { + "allow_parallel_calls": true, + "allow_concurrent_invocation": true + } + } + """; + + var sut = JsonSerializer.Deserialize(json); + + // Act + var config = sut!.GetConfiguration(new(chatHistory: []) { Kernel = this._kernel }); + + // Assert + Assert.True(config.Options.AllowParallelCalls); + Assert.True(config.Options.AllowConcurrentInvocation); + } + private static KernelPlugin GetTestPlugin() { var function1 = KernelFunctionFactory.CreateFromMethod(() => { }, "Function1"); From f0de0b614607db84b162435a82c141f5b710769c Mon Sep 17 00:00:00 2001 From: Ben Thomas Date: Thu, 31 Oct 2024 12:27:58 -0700 Subject: [PATCH 4/8] .Net Processes: Fixing an issue with nested processes in Dapr runtime. (#9491) ### Description This PR addresses an issue when running a nested process in the Dapr runtime. After this changes, the Dapr runtime fully supports nested processes with the ability to pass events in both directions. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --------- Co-authored-by: Chris <66376200+crickman@users.noreply.github.com> --- .../Process.Core/ProcessStepBuilder.cs | 3 +- .../Process.LocalRuntime/LocalProcess.cs | 1 - .../Actors/ProcessActor.cs | 58 ++++++++++++++++--- .../Process.Runtime.Dapr/Actors/StepActor.cs | 6 +- .../Process.Runtime.Dapr/DaprStepInfo.cs | 1 + 5 files changed, 56 insertions(+), 13 deletions(-) diff --git a/dotnet/src/Experimental/Process.Core/ProcessStepBuilder.cs b/dotnet/src/Experimental/Process.Core/ProcessStepBuilder.cs index 27db41d73c0b..5b917fd3fa8d 100644 --- a/dotnet/src/Experimental/Process.Core/ProcessStepBuilder.cs +++ b/dotnet/src/Experimental/Process.Core/ProcessStepBuilder.cs @@ -263,8 +263,9 @@ internal override KernelProcessStepInfo BuildStep(KernelProcessStepStateMetadata throw new KernelException($"The initial state provided for step {this.Name} is not of the correct type. The expected type is {userStateType.Name}."); } + var initialState = this._initialState ?? Activator.CreateInstance(userStateType); stateObject = (KernelProcessStepState?)Activator.CreateInstance(stateType, this.Name, this.Id); - stateType.GetProperty(nameof(KernelProcessStepState.State))?.SetValue(stateObject, this._initialState); + stateType.GetProperty(nameof(KernelProcessStepState.State))?.SetValue(stateObject, initialState); } else { diff --git a/dotnet/src/Experimental/Process.LocalRuntime/LocalProcess.cs b/dotnet/src/Experimental/Process.LocalRuntime/LocalProcess.cs index 3b9edf23651a..1b4ad7c1de07 100644 --- a/dotnet/src/Experimental/Process.LocalRuntime/LocalProcess.cs +++ b/dotnet/src/Experimental/Process.LocalRuntime/LocalProcess.cs @@ -188,7 +188,6 @@ private ValueTask InitializeProcessAsync() kernel: this._kernel, parentProcessId: this.Id); - //await process.StartAsync(kernel: this._kernel, keepAlive: true).ConfigureAwait(false); localStep = process; } else diff --git a/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/ProcessActor.cs b/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/ProcessActor.cs index e13a33997d4a..51f9098d7b99 100644 --- a/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/ProcessActor.cs +++ b/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/ProcessActor.cs @@ -271,9 +271,6 @@ private async Task InitializeProcessActorAsync(DaprProcessInfo processInfo, stri private async Task Internal_ExecuteAsync(Kernel? kernel = null, int maxSupersteps = 100, bool keepAlive = true, CancellationToken cancellationToken = default) { - Kernel localKernel = kernel ?? this._kernel; - Queue messageChannel = new(); - try { // Run the Pregel algorithm until there are no more messages being sent. @@ -308,8 +305,7 @@ private async Task Internal_ExecuteAsync(Kernel? kernel = null, int maxSuperstep await Task.WhenAll(stepProcessingTasks).ConfigureAwait(false); // Handle public events that need to be bubbled out of the process. - var eventQueue = this.ProxyFactory.CreateActorProxy(new ActorId(this.Id.GetId()), nameof(EventBufferActor)); - var allEvents = await eventQueue.DequeueAllAsync().ConfigureAwait(false); + await this.SendOutgoingPublicEventsAsync().ConfigureAwait(false); } } catch (Exception ex) @@ -354,6 +350,36 @@ private async Task EnqueueExternalMessagesAsync() } } + /// + /// Public events that are produced inside of this process need to be sent to the parent process. This method reads + /// all of the public events from the event buffer and sends them to the targeted step in the parent process. + /// + private async Task SendOutgoingPublicEventsAsync() + { + // Loop through all steps that are processes and call a function requesting their outgoing events, then queue them up. + if (!string.IsNullOrWhiteSpace(this.ParentProcessId)) + { + // Handle public events that need to be bubbled out of the process. + var eventQueue = this.ProxyFactory.CreateActorProxy(new ActorId(this.Id.GetId()), nameof(EventBufferActor)); + var allEvents = await eventQueue.DequeueAllAsync().ConfigureAwait(false); + + foreach (var e in allEvents) + { + var scopedEvent = this.ScopedEvent(e); + if (this._outputEdges!.TryGetValue(scopedEvent.Id, out List? edges) && edges is not null) + { + foreach (var edge in edges) + { + ProcessMessage message = ProcessMessageFactory.CreateFromEdge(edge, e.Data); + var scopedMessageBufferId = this.ScopedActorId(new ActorId(edge.OutputTarget.StepId), scopeToParent: true); + var messageQueue = this.ProxyFactory.CreateActorProxy(scopedMessageBufferId, nameof(MessageBufferActor)); + await messageQueue.EnqueueAsync(message).ConfigureAwait(false); + } + } + } + } + } + /// /// Determines is the end message has been sent to the process. /// @@ -383,10 +409,28 @@ private async Task ToDaprProcessInfoAsync() /// Scopes the Id of a step within the process to the process. /// /// The actor Id to scope. + /// Indicates if the Id should be scoped to the parent process. /// A new which is scoped to the process. - private ActorId ScopedActorId(ActorId actorId) + private ActorId ScopedActorId(ActorId actorId, bool scopeToParent = false) + { + if (scopeToParent && string.IsNullOrWhiteSpace(this.ParentProcessId)) + { + throw new InvalidOperationException("The parent process Id must be set before scoping to the parent process."); + } + + string id = scopeToParent ? this.ParentProcessId! : this.Id.GetId(); + return new ActorId($"{id}.{actorId.GetId()}"); + } + + /// + /// Generates a scoped event for the step. + /// + /// The event. + /// A with the correctly scoped namespace. + private ProcessEvent ScopedEvent(ProcessEvent daprEvent) { - return new ActorId($"{this.Id}.{actorId.GetId()}"); + Verify.NotNull(daprEvent); + return daprEvent with { Namespace = $"{this.Name}_{this._process!.State.Id}" }; } #endregion diff --git a/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/StepActor.cs b/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/StepActor.cs index efe8d5007612..9b627ad4d43f 100644 --- a/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/StepActor.cs +++ b/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/StepActor.cs @@ -378,8 +378,6 @@ private Task InvokeFunction(KernelFunction function, Kernel kern /// The event to emit. internal async ValueTask EmitEventAsync(ProcessEvent daprEvent) { - var scopedEvent = this.ScopedEvent(daprEvent); - // Emit the event out of the process (this one) if it's visibility is public. if (daprEvent.Visibility == KernelProcessEventVisibility.Public) { @@ -387,7 +385,7 @@ internal async ValueTask EmitEventAsync(ProcessEvent daprEvent) { // Emit the event to the parent process var parentProcess = this.ProxyFactory.CreateActorProxy(new ActorId(this.ParentProcessId), nameof(EventBufferActor)); - await parentProcess.EnqueueAsync(scopedEvent).ConfigureAwait(false); + await parentProcess.EnqueueAsync(daprEvent).ConfigureAwait(false); } } @@ -406,7 +404,7 @@ internal async ValueTask EmitEventAsync(ProcessEvent daprEvent) /// /// The event. /// A with the correctly scoped namespace. - internal ProcessEvent ScopedEvent(ProcessEvent daprEvent) + private ProcessEvent ScopedEvent(ProcessEvent daprEvent) { Verify.NotNull(daprEvent, nameof(daprEvent)); return daprEvent with { Namespace = $"{this.Name}_{this.Id}" }; diff --git a/dotnet/src/Experimental/Process.Runtime.Dapr/DaprStepInfo.cs b/dotnet/src/Experimental/Process.Runtime.Dapr/DaprStepInfo.cs index 05d53042705d..a5d63077a08b 100644 --- a/dotnet/src/Experimental/Process.Runtime.Dapr/DaprStepInfo.cs +++ b/dotnet/src/Experimental/Process.Runtime.Dapr/DaprStepInfo.cs @@ -12,6 +12,7 @@ namespace Microsoft.SemanticKernel; /// [KnownType(typeof(KernelProcessEdge))] [KnownType(typeof(KernelProcessStepState))] +[KnownType(typeof(DaprProcessInfo))] public record DaprStepInfo { /// From 303c2022277e923d8db24b1a56858ace27a5c69a Mon Sep 17 00:00:00 2001 From: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com> Date: Thu, 31 Oct 2024 20:01:01 +0000 Subject: [PATCH 5/8] .Net: Version 1.26.0 (#9492) ### Motivation and Context Version bump for 1.26.0 release ### Description ### Contribution Checklist - [ ] The code builds clean without any errors or warnings - [ ] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [ ] All unit tests pass, and I have added new tests where possible - [ ] I didn't break anyone :smile: --- dotnet/nuget/nuget-package.props | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/nuget/nuget-package.props b/dotnet/nuget/nuget-package.props index 98065d22b848..1b389acf97dc 100644 --- a/dotnet/nuget/nuget-package.props +++ b/dotnet/nuget/nuget-package.props @@ -1,7 +1,7 @@ - 1.25.0 + 1.26.0 $(VersionPrefix)-$(VersionSuffix) $(VersionPrefix) From 5ac8460a4250ce3bb4b6df8c0d7d24ad92beb6a6 Mon Sep 17 00:00:00 2001 From: Chris <66376200+crickman@users.noreply.github.com> Date: Thu, 31 Oct 2024 14:31:12 -0700 Subject: [PATCH 6/8] .Net Processes - Add Process-Level Error Handler (#9477) ### Motivation and Context Enabled support for function-specific error-handler step in this PR: https://github.com/microsoft/semantic-kernel/pull/9187 Fixes: https://github.com/microsoft/semantic-kernel/issues/9291 This change provides the ability to define a _process scoped_ error handler (as opposed to function specific). When a function-scoped error-handler is defined, it will take precedence. ### Description ```c# ProcessBuilder process = new(nameof(ProcessFunctionErrorHandledAsync)); ProcessStepBuilder errorStep = process.AddStepFromType(); process.OnError().SendEventTo(new ProcessFunctionTargetBuilder(errorStep)); class ErrorStep : KernelProcessStep { [KernelFunction] public void GlobalErrorHandler(Exception exception) { } } ``` **Notes:** - Switch error handler from passing `Exception` object to a `KernelProcessError` to satisfy serialization expectations - Normalized namespaces for `Internal` shared code - Introduced shared `ProcessConstants` file - Opportunistically converted some `List` creation to `Array` - Opportunistically included parameter name in some `Verify` assertions. - Opportunistically removed a extraneous _not-null_ directives (`!`) - Verified DAPR error handling in demo app (`True` means the expected error handler was invoked): image ### Contribution Checklist - [X] The code builds clean without any errors or warnings - [X] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [X] All unit tests pass, and I have added new tests where possible - [X] I didn't break anyone :smile: --- .../Controllers/ProcessController.cs | 2 +- .../Step04/Steps/RenderMessageStep.cs | 8 +- .../KernelProcessError.cs | 39 ++++++ .../KernelProcessStepInfo.cs | 2 +- .../Process.Core/Internal/EndStep.cs | 17 +-- .../Process.Core/ProcessBuilder.cs | 14 ++ .../Process.Core/ProcessStepBuilder.cs | 2 +- .../Process.Core/ProcessStepEdgeBuilder.cs | 3 +- .../Process.LocalRuntime/LocalProcess.cs | 27 +++- .../Process.LocalRuntime/LocalStep.cs | 44 ++++--- .../Actors/EventBufferActor.cs | 5 +- .../Actors/MessageBufferActor.cs | 5 +- .../Actors/ProcessActor.cs | 59 +++++++-- .../Process.Runtime.Dapr/Actors/StepActor.cs | 66 ++++++---- .../DaprKernelProcessFactory.cs | 4 +- .../Process.Runtime.Dapr/DaprStepInfo.cs | 3 +- .../Interfaces/IEventBuffer.cs | 2 +- .../Interfaces/IMessageBuffer.cs | 2 +- .../Core/ProcessBuilderTests.cs | 35 +++++ .../Core/ProcessStepBuilderTests.cs | 1 - .../Runtime.Local/LocalProcessTests.cs | 120 ++++++++++++++++++ .../Process.Utilities.UnitTests/CloneTests.cs | 2 +- .../ProcessTypeExtensionsTests.cs | 1 + .../Abstractions/ExceptionExtensions.cs | 2 +- .../KernelProcessStepExtension.cs | 2 +- .../process/Abstractions/ProcessConstants.cs | 15 +++ .../process/Abstractions/ProcessExtensions.cs | 2 +- .../process/Abstractions/StepExtensions.cs | 2 +- .../process/Runtime/ProcessEvent.cs | 13 +- .../process/Runtime/ProcessMessage.cs | 1 + 30 files changed, 395 insertions(+), 105 deletions(-) create mode 100644 dotnet/src/Experimental/Process.Abstractions/KernelProcessError.cs create mode 100644 dotnet/src/InternalUtilities/process/Abstractions/ProcessConstants.cs diff --git a/dotnet/samples/Demos/ProcessWithDapr/Controllers/ProcessController.cs b/dotnet/samples/Demos/ProcessWithDapr/Controllers/ProcessController.cs index b50cd8fba34a..21bf44ae717c 100644 --- a/dotnet/samples/Demos/ProcessWithDapr/Controllers/ProcessController.cs +++ b/dotnet/samples/Demos/ProcessWithDapr/Controllers/ProcessController.cs @@ -32,7 +32,7 @@ public ProcessController(Kernel kernel) public async Task PostAsync(string processId) { var process = this.GetProcess(); - var processContext = await process.StartAsync(this._kernel, new KernelProcessEvent() { Id = CommonEvents.StartProcess }, processId: processId); + var processContext = await process.StartAsync(new KernelProcessEvent() { Id = CommonEvents.StartProcess }, processId: processId); var finalState = await processContext.GetStateAsync(); return this.Ok(processId); diff --git a/dotnet/samples/GettingStartedWithProcesses/Step04/Steps/RenderMessageStep.cs b/dotnet/samples/GettingStartedWithProcesses/Step04/Steps/RenderMessageStep.cs index 938a2c4a6ea8..684bdc29bda9 100644 --- a/dotnet/samples/GettingStartedWithProcesses/Step04/Steps/RenderMessageStep.cs +++ b/dotnet/samples/GettingStartedWithProcesses/Step04/Steps/RenderMessageStep.cs @@ -41,11 +41,11 @@ public void RenderDone() /// Render exception /// [KernelFunction] - public void RenderError(Exception exception, ILogger logger) + public void RenderError(KernelProcessError error, ILogger logger) { - string message = string.IsNullOrWhiteSpace(exception.Message) ? "Unexpected failure" : exception.Message; - Render($"ERROR: {message} [{exception.GetType().Name}]{Environment.NewLine}{exception.StackTrace}"); - logger.LogError(exception, "Unexpected failure."); + string message = string.IsNullOrWhiteSpace(error.Message) ? "Unexpected failure" : error.Message; + Render($"ERROR: {message} [{error.GetType().Name}]{Environment.NewLine}{error.StackTrace}"); + logger.LogError("Unexpected failure: {ErrorMessage} [{ErrorType}]", error.Message, error.Type); } /// diff --git a/dotnet/src/Experimental/Process.Abstractions/KernelProcessError.cs b/dotnet/src/Experimental/Process.Abstractions/KernelProcessError.cs new file mode 100644 index 000000000000..3af07e70d384 --- /dev/null +++ b/dotnet/src/Experimental/Process.Abstractions/KernelProcessError.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. +using System; +using System.Runtime.Serialization; + +namespace Microsoft.SemanticKernel; + +/// +/// Represents an failure that occurred during the execution of a process. +/// +/// +/// Initializes a new instance of the class. +/// +/// The exception type name +/// The exception message ( +/// The exception stack-trace ( +[DataContract] +public sealed record KernelProcessError( + [property:DataMember] + string Type, + [property:DataMember] + string Message, + [property:DataMember] + string? StackTrace) +{ + /// + /// The inner failure, when exists, as . + /// + [DataMember] + public KernelProcessError? InnerError { get; init; } + + /// + /// Factory method to create a from a source object. + /// + public static KernelProcessError FromException(Exception ex) => + new(ex.GetType().Name, ex.Message, ex.StackTrace) + { + InnerError = ex.InnerException is not null ? FromException(ex.InnerException) : null + }; +} diff --git a/dotnet/src/Experimental/Process.Abstractions/KernelProcessStepInfo.cs b/dotnet/src/Experimental/Process.Abstractions/KernelProcessStepInfo.cs index 88e1d4cfdd3c..26b76de4604c 100644 --- a/dotnet/src/Experimental/Process.Abstractions/KernelProcessStepInfo.cs +++ b/dotnet/src/Experimental/Process.Abstractions/KernelProcessStepInfo.cs @@ -3,7 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; -using Microsoft.SemanticKernel.Process; +using Microsoft.SemanticKernel.Process.Internal; using Microsoft.SemanticKernel.Process.Models; namespace Microsoft.SemanticKernel; diff --git a/dotnet/src/Experimental/Process.Core/Internal/EndStep.cs b/dotnet/src/Experimental/Process.Core/Internal/EndStep.cs index 432aecf33128..cf1b4b770c75 100644 --- a/dotnet/src/Experimental/Process.Core/Internal/EndStep.cs +++ b/dotnet/src/Experimental/Process.Core/Internal/EndStep.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using Microsoft.SemanticKernel.Process.Internal; using Microsoft.SemanticKernel.Process.Models; namespace Microsoft.SemanticKernel; @@ -10,18 +11,6 @@ namespace Microsoft.SemanticKernel; /// internal sealed class EndStep : ProcessStepBuilder { - private const string EndStepValue = "Microsoft.SemanticKernel.Process.EndStep"; - - /// - /// The name of the end step. - /// - public const string EndStepName = EndStepValue; - - /// - /// The event ID for stopping a process. - /// - public const string EndStepId = EndStepValue; - /// /// The static instance of the class. /// @@ -31,7 +20,7 @@ internal sealed class EndStep : ProcessStepBuilder /// Represents the end of a process. /// internal EndStep() - : base(EndStepName) + : base(ProcessConstants.EndStepName) { } @@ -49,6 +38,6 @@ internal override KernelProcessStepInfo BuildStep() internal override KernelProcessStepInfo BuildStep(KernelProcessStepStateMetadata? stateMetadata) { // The end step has no state. - return new KernelProcessStepInfo(typeof(KernelProcessStepState), new KernelProcessStepState(EndStepName), []); + return new KernelProcessStepInfo(typeof(KernelProcessStepState), new KernelProcessStepState(ProcessConstants.EndStepName), []); } } diff --git a/dotnet/src/Experimental/Process.Core/ProcessBuilder.cs b/dotnet/src/Experimental/Process.Core/ProcessBuilder.cs index ff5130f47db4..e905764b8096 100644 --- a/dotnet/src/Experimental/Process.Core/ProcessBuilder.cs +++ b/dotnet/src/Experimental/Process.Core/ProcessBuilder.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Microsoft.SemanticKernel.Process.Internal; using Microsoft.SemanticKernel.Process.Models; namespace Microsoft.SemanticKernel; @@ -177,6 +178,19 @@ public ProcessEdgeBuilder OnInputEvent(string eventId) return new ProcessEdgeBuilder(this, eventId); } + /// + /// Provides an instance of for defining an edge to a + /// step that responds to an unhandled process error. + /// + /// An instance of + /// + /// To target a specific error source, use the on the step. + /// + public ProcessEdgeBuilder OnError() + { + return new ProcessEdgeBuilder(this, ProcessConstants.GlobalErrorEventId); + } + /// /// Retrieves the target for a given external event. The step associated with the target is the process itself (this). /// diff --git a/dotnet/src/Experimental/Process.Core/ProcessStepBuilder.cs b/dotnet/src/Experimental/Process.Core/ProcessStepBuilder.cs index 5b917fd3fa8d..04d27023ce4f 100644 --- a/dotnet/src/Experimental/Process.Core/ProcessStepBuilder.cs +++ b/dotnet/src/Experimental/Process.Core/ProcessStepBuilder.cs @@ -4,7 +4,7 @@ using System.Collections.Generic; using System.Linq; using System.Text.Json; -using Microsoft.SemanticKernel.Process; +using Microsoft.SemanticKernel.Process.Internal; using Microsoft.SemanticKernel.Process.Models; namespace Microsoft.SemanticKernel; diff --git a/dotnet/src/Experimental/Process.Core/ProcessStepEdgeBuilder.cs b/dotnet/src/Experimental/Process.Core/ProcessStepEdgeBuilder.cs index fdcbe5d402ca..2e4afbfa51e9 100644 --- a/dotnet/src/Experimental/Process.Core/ProcessStepEdgeBuilder.cs +++ b/dotnet/src/Experimental/Process.Core/ProcessStepEdgeBuilder.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Microsoft.SemanticKernel.Process.Internal; namespace Microsoft.SemanticKernel; @@ -76,6 +77,6 @@ public void StopProcess() var outputTarget = new ProcessFunctionTargetBuilder(EndStep.Instance); this.Target = outputTarget; - this.Source.LinkTo(EndStep.EndStepName, this); + this.Source.LinkTo(ProcessConstants.EndStepName, this); } } diff --git a/dotnet/src/Experimental/Process.LocalRuntime/LocalProcess.cs b/dotnet/src/Experimental/Process.LocalRuntime/LocalProcess.cs index 1b4ad7c1de07..4286b482579e 100644 --- a/dotnet/src/Experimental/Process.LocalRuntime/LocalProcess.cs +++ b/dotnet/src/Experimental/Process.LocalRuntime/LocalProcess.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.SemanticKernel.Process.Internal; using Microsoft.SemanticKernel.Process.Runtime; using Microsoft.VisualStudio.Threading; @@ -15,7 +16,6 @@ namespace Microsoft.SemanticKernel; internal sealed class LocalProcess : LocalStep, IDisposable { - private const string EndProcessId = "Microsoft.SemanticKernel.Process.EndStep"; private readonly JoinableTaskFactory _joinableTaskFactory; private readonly JoinableTaskContext _joinableTaskContext; private readonly Channel _externalEventChannel; @@ -240,11 +240,11 @@ private async Task Internal_ExecuteAsync(Kernel? kernel = null, int maxSuperstep } // Complete the writing side, indicating no more messages in this superstep. - var messagesToProcess = messageChannel.ToList(); + var messagesToProcess = messageChannel.ToArray(); messageChannel.Clear(); // If there are no messages to process, wait for an external event. - if (messagesToProcess.Count == 0) + if (messagesToProcess.Length == 0) { if (!keepAlive || !await this._externalEventChannel.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) { @@ -257,7 +257,7 @@ private async Task Internal_ExecuteAsync(Kernel? kernel = null, int maxSuperstep foreach (var message in messagesToProcess) { // Check for end condition - if (message.DestinationId.Equals(EndProcessId, StringComparison.OrdinalIgnoreCase)) + if (message.DestinationId.Equals(ProcessConstants.EndStepName, StringComparison.OrdinalIgnoreCase)) { this._processCancelSource?.Cancel(); break; @@ -320,7 +320,7 @@ private void EnqueueExternalMessages(Queue messageChannel) private void EnqueueStepMessages(LocalStep step, Queue messageChannel) { var allStepEvents = step.GetAllEvents(); - foreach (var stepEvent in allStepEvents) + foreach (ProcessEvent stepEvent in allStepEvents) { // Emit the event out of the process (this one) if it's visibility is public. if (stepEvent.Visibility == KernelProcessEventVisibility.Public) @@ -329,10 +329,25 @@ private void EnqueueStepMessages(LocalStep step, Queue messageCh } // Get the edges for the event and queue up the messages to be sent to the next steps. - foreach (var edge in step.GetEdgeForEvent(stepEvent.Id!)) + bool foundEdge = false; + foreach (KernelProcessEdge edge in step.GetEdgeForEvent(stepEvent.Id)) { ProcessMessage message = ProcessMessageFactory.CreateFromEdge(edge, stepEvent.Data); messageChannel.Enqueue(message); + foundEdge = true; + } + + // Error event was raised with no edge to handle it, send it to an edge defined as the global error target. + if (!foundEdge && stepEvent.IsError) + { + if (this._outputEdges.TryGetValue(ProcessConstants.GlobalErrorEventId, out List? edges)) + { + foreach (KernelProcessEdge edge in edges) + { + ProcessMessage message = ProcessMessageFactory.CreateFromEdge(edge, stepEvent.Data); + messageChannel.Enqueue(message); + } + } } } } diff --git a/dotnet/src/Experimental/Process.LocalRuntime/LocalStep.cs b/dotnet/src/Experimental/Process.LocalRuntime/LocalStep.cs index 0e0f06668095..41f6ba552822 100644 --- a/dotnet/src/Experimental/Process.LocalRuntime/LocalStep.cs +++ b/dotnet/src/Experimental/Process.LocalRuntime/LocalStep.cs @@ -8,6 +8,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.SemanticKernel.Process.Internal; using Microsoft.SemanticKernel.Process.Runtime; namespace Microsoft.SemanticKernel; @@ -111,9 +112,17 @@ internal IEnumerable GetEdgeForEvent(string eventId) /// /// The event to emit. /// A - public ValueTask EmitEventAsync(KernelProcessEvent processEvent) + public ValueTask EmitEventAsync(KernelProcessEvent processEvent) => this.EmitEventAsync(processEvent, isError: false); + + /// + /// Emits an event from the step. + /// + /// The event to emit. + /// Flag indicating if the event being emitted is in response to a step failure + /// A + internal ValueTask EmitEventAsync(KernelProcessEvent processEvent, bool isError) { - this.EmitEvent(ProcessEvent.FromKernelProcessEvent(processEvent, this._eventNamespace)); + this.EmitEvent(new ProcessEvent(this._eventNamespace, processEvent, isError)); return default; } @@ -148,7 +157,7 @@ internal virtual async Task HandleMessageAsync(ProcessMessage message) if (!this._inputs.TryGetValue(message.FunctionName, out Dictionary? functionParameters)) { - this._inputs[message.FunctionName] = new(); + this._inputs[message.FunctionName] = []; functionParameters = this._inputs[message.FunctionName]; } @@ -179,28 +188,31 @@ internal virtual async Task HandleMessageAsync(ProcessMessage message) throw new ArgumentException($"Function {targetFunction} not found in plugin {this.Name}"); } - FunctionResult? invokeResult = null; - string? eventName = null; - object? eventValue = null; - // Invoke the function, catching all exceptions that it may throw, and then post the appropriate event. #pragma warning disable CA1031 // Do not catch general exception types try { - invokeResult = await this.InvokeFunction(function, this._kernel, arguments).ConfigureAwait(false); - eventName = $"{targetFunction}.OnResult"; - eventValue = invokeResult?.GetValue(); + FunctionResult invokeResult = await this.InvokeFunction(function, this._kernel, arguments).ConfigureAwait(false); + await this.EmitEventAsync( + new KernelProcessEvent + { + Id = $"{targetFunction}.OnResult", + Data = invokeResult.GetValue(), + }).ConfigureAwait(false); } catch (Exception ex) { this._logger.LogError("Error in Step {StepName}: {ErrorMessage}", this.Name, ex.Message); - eventName = $"{targetFunction}.OnError"; - eventValue = ex; + await this.EmitEventAsync( + new KernelProcessEvent + { + Id = $"{targetFunction}.OnError", + Data = KernelProcessError.FromException(ex), + }, + isError: true).ConfigureAwait(false); } finally { - await this.EmitEventAsync(new KernelProcessEvent { Id = eventName, Data = eventValue }).ConfigureAwait(false); - // Reset the inputs for the function that was just executed this._inputs[targetFunction] = new(this._initialInputs[targetFunction] ?? []); } @@ -216,7 +228,7 @@ protected virtual async ValueTask InitializeStepAsync() { // Instantiate an instance of the inner step object KernelProcessStep stepInstance = (KernelProcessStep)ActivatorUtilities.CreateInstance(this._kernel.Services, this._stepInfo.InnerStepType); - var kernelPlugin = KernelPluginFactory.CreateFromObject(stepInstance, pluginName: this._stepInfo.State.Name!); + var kernelPlugin = KernelPluginFactory.CreateFromObject(stepInstance, pluginName: this._stepInfo.State.Name); // Load the kernel functions foreach (KernelFunction f in kernelPlugin) @@ -312,6 +324,6 @@ protected ProcessEvent ScopedEvent(ProcessEvent localEvent) protected ProcessEvent ScopedEvent(KernelProcessEvent processEvent) { Verify.NotNull(processEvent, nameof(processEvent)); - return ProcessEvent.FromKernelProcessEvent(processEvent, $"{this.Name}_{this.Id}"); + return new ProcessEvent($"{this.Name}_{this.Id}", processEvent); } } diff --git a/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/EventBufferActor.cs b/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/EventBufferActor.cs index 1d61a497bd35..f9c44aee6488 100644 --- a/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/EventBufferActor.cs +++ b/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/EventBufferActor.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; -using System.Linq; using System.Threading.Tasks; using Dapr.Actors.Runtime; using Microsoft.SemanticKernel.Process.Runtime; @@ -27,10 +26,10 @@ public EventBufferActor(ActorHost host) : base(host) /// Dequeues an event. /// /// A where T is - public async Task> DequeueAllAsync() + public async Task> DequeueAllAsync() { // Dequeue and clear the queue. - var items = this._queue!.ToList(); + var items = this._queue!.ToArray(); this._queue!.Clear(); // Save the state. diff --git a/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/MessageBufferActor.cs b/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/MessageBufferActor.cs index 65acb3099441..0d3a9e9931ce 100644 --- a/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/MessageBufferActor.cs +++ b/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/MessageBufferActor.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; -using System.Linq; using System.Threading.Tasks; using Dapr.Actors.Runtime; using Microsoft.SemanticKernel.Process.Runtime; @@ -27,10 +26,10 @@ public MessageBufferActor(ActorHost host) : base(host) /// Dequeues an event. /// /// A where T is - public async Task> DequeueAllAsync() + public async Task> DequeueAllAsync() { // Dequeue and clear the queue. - var items = this._queue!.ToList(); + var items = this._queue!.ToArray(); this._queue!.Clear(); // Save the state. diff --git a/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/ProcessActor.cs b/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/ProcessActor.cs index 51f9098d7b99..f6fe3f2b63ff 100644 --- a/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/ProcessActor.cs +++ b/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/ProcessActor.cs @@ -10,6 +10,7 @@ using Dapr.Actors.Runtime; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.SemanticKernel.Process.Internal; using Microsoft.SemanticKernel.Process.Runtime; using Microsoft.VisualStudio.Threading; @@ -17,14 +18,13 @@ namespace Microsoft.SemanticKernel; internal sealed class ProcessActor : StepActor, IProcess, IDisposable { - private const string EndStepId = "Microsoft.SemanticKernel.Process.EndStep"; private readonly JoinableTaskFactory _joinableTaskFactory; private readonly JoinableTaskContext _joinableTaskContext; private readonly Channel _externalEventChannel; internal readonly List _steps = []; - internal List? _stepsInfos; + internal IList? _stepsInfos; internal DaprProcessInfo? _process; private JoinableTask? _processTask; private CancellationTokenSource? _processCancelSource; @@ -82,7 +82,7 @@ public Task StartAsync(bool keepAlive) this._processCancelSource = new CancellationTokenSource(); this._processTask = this._joinableTaskFactory.RunAsync(() - => this.Internal_ExecuteAsync(this._kernel, keepAlive: keepAlive, cancellationToken: this._processCancelSource.Token)); + => this.Internal_ExecuteAsync(keepAlive: keepAlive, cancellationToken: this._processCancelSource.Token)); return Task.CompletedTask; } @@ -206,6 +206,8 @@ internal override async Task HandleMessageAsync(ProcessMessage message) } } + internal static ActorId GetScopedGlobalErrorEventBufferId(string processId) => new($"{ProcessConstants.GlobalErrorEventId}_{processId}"); + #region Private Methods /// @@ -225,7 +227,7 @@ private async Task InitializeProcessActorAsync(DaprProcessInfo processInfo, stri this.ParentProcessId = parentProcessId; this._process = processInfo; - this._stepsInfos = new List(this._process.Steps); + this._stepsInfos = [.. this._process.Steps]; this._logger = this._kernel.LoggerFactory?.CreateLogger(this._process.State.Name) ?? new NullLogger(); // Initialize the input and output edges for the process @@ -269,7 +271,7 @@ private async Task InitializeProcessActorAsync(DaprProcessInfo processInfo, stri this._isInitialized = true; } - private async Task Internal_ExecuteAsync(Kernel? kernel = null, int maxSupersteps = 100, bool keepAlive = true, CancellationToken cancellationToken = default) + private async Task Internal_ExecuteAsync(int maxSupersteps = 100, bool keepAlive = true, CancellationToken cancellationToken = default) { try { @@ -283,11 +285,14 @@ private async Task Internal_ExecuteAsync(Kernel? kernel = null, int maxSuperstep break; } + // Translate any global error events into an message that targets the appropriate step, when one exists. + await this.HandleGlobalErrorMessageAsync().ConfigureAwait(false); + // Check for external events await this.EnqueueExternalMessagesAsync().ConfigureAwait(false); // Reach out to all of the steps in the process and instruct them to retrieve their pending messages from their associated queues. - var stepPreparationTasks = this._steps.Select(step => step.PrepareIncomingMessagesAsync()).ToList(); + var stepPreparationTasks = this._steps.Select(step => step.PrepareIncomingMessagesAsync()).ToArray(); var messageCounts = await Task.WhenAll(stepPreparationTasks).ConfigureAwait(false); // If there are no messages to process, wait for an external event or finish. @@ -301,7 +306,7 @@ private async Task Internal_ExecuteAsync(Kernel? kernel = null, int maxSuperstep } // Process the incoming messages for each step. - var stepProcessingTasks = this._steps.Select(step => step.ProcessIncomingMessagesAsync()).ToList(); + var stepProcessingTasks = this._steps.Select(step => step.ProcessIncomingMessagesAsync()).ToArray(); await Task.WhenAll(stepProcessingTasks).ConfigureAwait(false); // Handle public events that need to be bubbled out of the process. @@ -310,7 +315,7 @@ private async Task Internal_ExecuteAsync(Kernel? kernel = null, int maxSuperstep } catch (Exception ex) { - this._logger?.LogError("An error occurred while running the process: {ErrorMessage}.", ex.Message); + this._logger?.LogError(ex, "An error occurred while running the process: {ErrorMessage}.", ex.Message); throw; } finally @@ -350,6 +355,40 @@ private async Task EnqueueExternalMessagesAsync() } } + /// + /// Check for the presence of an global-error event and any edges defined for processing it. + /// When both exist, the error event is processed and sent to the appropriate targets. + /// + private async Task HandleGlobalErrorMessageAsync() + { + var errorEventQueue = this.ProxyFactory.CreateActorProxy(ProcessActor.GetScopedGlobalErrorEventBufferId(this.Id.GetId()), nameof(EventBufferActor)); + + var errorEvents = await errorEventQueue.DequeueAllAsync().ConfigureAwait(false); + if (errorEvents.Count == 0) + { + // No error events in queue. + return; + } + + var errorEdges = this.GetEdgeForEvent(ProcessConstants.GlobalErrorEventId).ToArray(); + if (errorEdges.Length == 0) + { + // No further action is required when there are no targetes defined for processing the error. + return; + } + + foreach (var errorEdge in errorEdges) + { + foreach (ProcessEvent errorEvent in errorEvents) + { + var errorMessage = ProcessMessageFactory.CreateFromEdge(errorEdge, errorEvent.Data); + var scopedErrorMessageBufferId = this.ScopedActorId(new ActorId(errorEdge.OutputTarget.StepId)); + var errorStepQueue = this.ProxyFactory.CreateActorProxy(scopedErrorMessageBufferId, nameof(MessageBufferActor)); + await errorStepQueue.EnqueueAsync(errorMessage).ConfigureAwait(false); + } + } + } + /// /// Public events that are produced inside of this process need to be sent to the parent process. This method reads /// all of the public events from the event buffer and sends them to the targeted step in the parent process. @@ -386,7 +425,7 @@ private async Task SendOutgoingPublicEventsAsync() /// True if the end message has been sent, otherwise false. private async Task IsEndMessageSentAsync() { - var scopedMessageBufferId = this.ScopedActorId(new ActorId(EndStepId)); + var scopedMessageBufferId = this.ScopedActorId(new ActorId(ProcessConstants.EndStepName)); var endMessageQueue = this.ProxyFactory.CreateActorProxy(scopedMessageBufferId, nameof(MessageBufferActor)); var messages = await endMessageQueue.DequeueAllAsync().ConfigureAwait(false); return messages.Count > 0; @@ -402,7 +441,7 @@ private async Task ToDaprProcessInfoAsync() var processState = new KernelProcessState(this.Name, this.Id.GetId()); var stepTasks = this._steps.Select(step => step.ToDaprStepInfoAsync()).ToList(); var steps = await Task.WhenAll(stepTasks).ConfigureAwait(false); - return new DaprProcessInfo { InnerStepDotnetType = this._process!.InnerStepDotnetType, Edges = this._process!.Edges, State = processState, Steps = steps.ToList() }; + return new DaprProcessInfo { InnerStepDotnetType = this._process!.InnerStepDotnetType, Edges = this._process!.Edges, State = processState, Steps = [.. steps] }; } /// diff --git a/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/StepActor.cs b/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/StepActor.cs index 9b627ad4d43f..e6c04bf00674 100644 --- a/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/StepActor.cs +++ b/dotnet/src/Experimental/Process.Runtime.Dapr/Actors/StepActor.cs @@ -11,6 +11,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.SemanticKernel.Process.Internal; using Microsoft.SemanticKernel.Process.Runtime; namespace Microsoft.SemanticKernel; @@ -112,7 +113,8 @@ public async Task PrepareIncomingMessagesAsync() { var messageQueue = this.ProxyFactory.CreateActorProxy(new ActorId(this.Id.GetId()), nameof(MessageBufferActor)); var incoming = await messageQueue.DequeueAllAsync().ConfigureAwait(false); - foreach (var message in incoming) + + foreach (ProcessMessage message in incoming) { this._incomingMessages.Enqueue(message); } @@ -190,10 +192,16 @@ protected override async Task OnActivateAsync() /// /// The event to emit. /// A - public ValueTask EmitEventAsync(KernelProcessEvent processEvent) - { - return this.EmitEventAsync(ProcessEvent.FromKernelProcessEvent(processEvent, this._eventNamespace!)); - } + public ValueTask EmitEventAsync(KernelProcessEvent processEvent) => this.EmitEventAsync(processEvent, isError: false); + + /// + /// Emits an event from the step. + /// + /// The event to emit. + /// Flag indicating if the event being emitted is in response to a step failure + /// A + internal ValueTask EmitEventAsync(KernelProcessEvent processEvent, bool isError) => + this.EmitEventAsync(new ProcessEvent(this._eventNamespace, processEvent, isError)); /// /// Handles a that has been sent to the step. @@ -226,7 +234,7 @@ internal virtual async Task HandleMessageAsync(ProcessMessage message) if (!this._inputs.TryGetValue(message.FunctionName, out Dictionary? functionParameters)) { - this._inputs[message.FunctionName] = new(); + this._inputs[message.FunctionName] = []; functionParameters = this._inputs[message.FunctionName]; } @@ -257,36 +265,40 @@ internal virtual async Task HandleMessageAsync(ProcessMessage message) throw new InvalidOperationException($"Function {targetFunction} not found in plugin {this.Name}").Log(this._logger); } - FunctionResult? invokeResult = null; - string? eventName = null; - object? eventValue = null; - // Invoke the function, catching all exceptions that it may throw, and then post the appropriate event. #pragma warning disable CA1031 // Do not catch general exception types try { this?._logger?.LogInformation("Invoking function {FunctionName} with arguments {Arguments}", targetFunction, arguments); - invokeResult = await this.InvokeFunction(function, this._kernel, arguments).ConfigureAwait(false); + FunctionResult invokeResult = await this.InvokeFunction(function, this._kernel, arguments).ConfigureAwait(false); this?.Logger?.LogInformation("Function {FunctionName} returned {Result}", targetFunction, invokeResult); - eventName = $"{targetFunction}.OnResult"; - eventValue = invokeResult?.GetValue(); // Persist the state after the function has been executed var stateJson = JsonSerializer.Serialize(this._stepState, this._stepStateType!); await this.StateManager.SetStateAsync(ActorStateKeys.StepStateJson, stateJson).ConfigureAwait(false); await this.StateManager.SaveStateAsync().ConfigureAwait(false); + + await this.EmitEventAsync( + new KernelProcessEvent + { + Id = $"{targetFunction}.OnResult", + Data = invokeResult.GetValue(), + }).ConfigureAwait(false); } catch (Exception ex) { this._logger?.LogInformation("Error in Step {StepName}: {ErrorMessage}", this.Name, ex.Message); - eventName = $"{targetFunction}.OnError"; - eventValue = ex.Message; + await this.EmitEventAsync( + new KernelProcessEvent + { + Id = $"{targetFunction}.OnError", + Data = KernelProcessError.FromException(ex), + }, + isError: true).ConfigureAwait(false); } finally { - await this.EmitEventAsync(new KernelProcessEvent { Id = eventName, Data = eventValue }).ConfigureAwait(false); - // Reset the inputs for the function that was just executed this._inputs[targetFunction] = new(this._initialInputs[targetFunction] ?? []); } @@ -307,7 +319,7 @@ protected virtual async ValueTask ActivateStepAsync() // Instantiate an instance of the inner step object KernelProcessStep stepInstance = (KernelProcessStep)ActivatorUtilities.CreateInstance(this._kernel.Services, this._innerStepType!); - var kernelPlugin = KernelPluginFactory.CreateFromObject(stepInstance, pluginName: this._stepInfo.State.Name!); + var kernelPlugin = KernelPluginFactory.CreateFromObject(stepInstance, pluginName: this._stepInfo.State.Name); // Load the kernel functions foreach (KernelFunction f in kernelPlugin) @@ -347,12 +359,9 @@ protected virtual async ValueTask ActivateStepAsync() throw new KernelException("The state object for the KernelProcessStep could not be created.").Log(this._logger); } - MethodInfo? methodInfo = this._innerStepType!.GetMethod(nameof(KernelProcessStep.ActivateAsync), [stateType]); - - if (methodInfo is null) - { + MethodInfo? methodInfo = + this._innerStepType!.GetMethod(nameof(KernelProcessStep.ActivateAsync), [stateType]) ?? throw new KernelException("The ActivateAsync method for the KernelProcessStep could not be found.").Log(this._logger); - } this._stepState = stateObject; this._stepStateType = stateType; @@ -390,12 +399,21 @@ internal async ValueTask EmitEventAsync(ProcessEvent daprEvent) } // Get the edges for the event and queue up the messages to be sent to the next steps. - foreach (var edge in this.GetEdgeForEvent(daprEvent.Id!)) + bool foundEdge = false; + foreach (var edge in this.GetEdgeForEvent(daprEvent.Id)) { ProcessMessage message = ProcessMessageFactory.CreateFromEdge(edge, daprEvent.Data); var scopedStepId = this.ScopedActorId(new ActorId(edge.OutputTarget.StepId)); var targetStep = this.ProxyFactory.CreateActorProxy(scopedStepId, nameof(MessageBufferActor)); await targetStep.EnqueueAsync(message).ConfigureAwait(false); + foundEdge = true; + } + + // Error event was raised with no edge to handle it, send it to the global error event buffer. + if (!foundEdge && daprEvent.IsError && this.ParentProcessId != null) + { + var parentProcess1 = this.ProxyFactory.CreateActorProxy(ProcessActor.GetScopedGlobalErrorEventBufferId(this.ParentProcessId), nameof(EventBufferActor)); + await parentProcess1.EnqueueAsync(daprEvent).ConfigureAwait(false); } } diff --git a/dotnet/src/Experimental/Process.Runtime.Dapr/DaprKernelProcessFactory.cs b/dotnet/src/Experimental/Process.Runtime.Dapr/DaprKernelProcessFactory.cs index 0d8f74dfe661..8e84d878d034 100644 --- a/dotnet/src/Experimental/Process.Runtime.Dapr/DaprKernelProcessFactory.cs +++ b/dotnet/src/Experimental/Process.Runtime.Dapr/DaprKernelProcessFactory.cs @@ -12,15 +12,13 @@ public static class DaprKernelProcessFactory /// Starts the specified process. /// /// Required: The to start running. - /// Required: An instance of /// Required: The initial event to start the process. /// Optional: Used to specify the unique Id of the process. If the process already has an Id, it will not be overwritten and this parameter has no effect. /// An instance of that can be used to interrogate or stop the running process. - public static async Task StartAsync(this KernelProcess process, Kernel kernel, KernelProcessEvent initialEvent, string? processId = null) + public static async Task StartAsync(this KernelProcess process, KernelProcessEvent initialEvent, string? processId = null) { Verify.NotNull(process); Verify.NotNullOrWhiteSpace(process.State?.Name); - Verify.NotNull(kernel); Verify.NotNull(initialEvent); // Assign the process Id if one is provided and the processes does not already have an Id. diff --git a/dotnet/src/Experimental/Process.Runtime.Dapr/DaprStepInfo.cs b/dotnet/src/Experimental/Process.Runtime.Dapr/DaprStepInfo.cs index a5d63077a08b..252a705ec12d 100644 --- a/dotnet/src/Experimental/Process.Runtime.Dapr/DaprStepInfo.cs +++ b/dotnet/src/Experimental/Process.Runtime.Dapr/DaprStepInfo.cs @@ -52,7 +52,8 @@ public KernelProcessStepInfo ToKernelProcessStepInfo() /// An instance of public static DaprStepInfo FromKernelStepInfo(KernelProcessStepInfo kernelStepInfo) { - Verify.NotNull(kernelStepInfo); + Verify.NotNull(kernelStepInfo, nameof(kernelStepInfo)); + return new DaprStepInfo { InnerStepDotnetType = kernelStepInfo.InnerStepType.AssemblyQualifiedName!, diff --git a/dotnet/src/Experimental/Process.Runtime.Dapr/Interfaces/IEventBuffer.cs b/dotnet/src/Experimental/Process.Runtime.Dapr/Interfaces/IEventBuffer.cs index b7f726e5c3bb..c2e354610c4d 100644 --- a/dotnet/src/Experimental/Process.Runtime.Dapr/Interfaces/IEventBuffer.cs +++ b/dotnet/src/Experimental/Process.Runtime.Dapr/Interfaces/IEventBuffer.cs @@ -23,5 +23,5 @@ public interface IEventBuffer : IActor /// Dequeues all external events. /// /// A where T is - Task> DequeueAllAsync(); + Task> DequeueAllAsync(); } diff --git a/dotnet/src/Experimental/Process.Runtime.Dapr/Interfaces/IMessageBuffer.cs b/dotnet/src/Experimental/Process.Runtime.Dapr/Interfaces/IMessageBuffer.cs index b92ccc0c8ee7..eac72cf7492d 100644 --- a/dotnet/src/Experimental/Process.Runtime.Dapr/Interfaces/IMessageBuffer.cs +++ b/dotnet/src/Experimental/Process.Runtime.Dapr/Interfaces/IMessageBuffer.cs @@ -23,5 +23,5 @@ public interface IMessageBuffer : IActor /// Dequeues all external events. /// /// A where T is - Task> DequeueAllAsync(); + Task> DequeueAllAsync(); } diff --git a/dotnet/src/Experimental/Process.UnitTests/Core/ProcessBuilderTests.cs b/dotnet/src/Experimental/Process.UnitTests/Core/ProcessBuilderTests.cs index 60b232e04a32..cee388496f92 100644 --- a/dotnet/src/Experimental/Process.UnitTests/Core/ProcessBuilderTests.cs +++ b/dotnet/src/Experimental/Process.UnitTests/Core/ProcessBuilderTests.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using Xunit; namespace Microsoft.SemanticKernel.Process.Core.UnitTests; @@ -99,6 +100,26 @@ public void BuildCreatesKernelProcess() Assert.Single(kernelProcess.Steps); } + /// + /// Verify that the method returns a . + /// + [Fact] + public void OnFunctionErrorCreatesEdgeBuilder() + { + // Arrange + var processBuilder = new ProcessBuilder(ProcessName); + var errorStep = processBuilder.AddStepFromType(); + var edgeBuilder = processBuilder.OnError().SendEventTo(new ProcessFunctionTargetBuilder(errorStep)); + processBuilder.AddStepFromType(); + + // Act + var kernelProcess = processBuilder.Build(); + + // Assert + Assert.NotNull(edgeBuilder); + Assert.EndsWith("Global.OnError", edgeBuilder.EventId); + } + /// /// A class that represents a step for testing. /// @@ -118,6 +139,20 @@ public void TestFunction() } } + /// + /// A class that represents a step for testing. + /// + private sealed class ErrorStep : KernelProcessStep + { + /// + /// A method for unhandling failures at the process level. + /// + [KernelFunction] + public void GlobalErrorHandler(Exception exception) + { + } + } + /// /// A class that represents a state for testing. /// diff --git a/dotnet/src/Experimental/Process.UnitTests/Core/ProcessStepBuilderTests.cs b/dotnet/src/Experimental/Process.UnitTests/Core/ProcessStepBuilderTests.cs index ba50da10c6e8..db1ef24a7b31 100644 --- a/dotnet/src/Experimental/Process.UnitTests/Core/ProcessStepBuilderTests.cs +++ b/dotnet/src/Experimental/Process.UnitTests/Core/ProcessStepBuilderTests.cs @@ -80,7 +80,6 @@ public void OnFunctionErrorShouldReturnProcessStepEdgeBuilder() // Assert Assert.NotNull(edgeBuilder); - Assert.IsType(edgeBuilder); Assert.EndsWith("TestFunction.OnError", edgeBuilder.EventId); } diff --git a/dotnet/src/Experimental/Process.UnitTests/Runtime.Local/LocalProcessTests.cs b/dotnet/src/Experimental/Process.UnitTests/Runtime.Local/LocalProcessTests.cs index 8c7a6d015728..0a2599c2956b 100644 --- a/dotnet/src/Experimental/Process.UnitTests/Runtime.Local/LocalProcessTests.cs +++ b/dotnet/src/Experimental/Process.UnitTests/Runtime.Local/LocalProcessTests.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Threading.Tasks; using Xunit; @@ -81,6 +82,86 @@ public void ProcessWithAssignedIdIsNotOverwrittenId() Assert.Equal("AlreadySet", localProcess.Id); } + /// + /// Verify that the function level error handler is called when a function fails. + /// + [Fact] + public async Task ProcessFunctionErrorHandledAsync() + { + // Arrange + ProcessBuilder process = new(nameof(ProcessFunctionErrorHandledAsync)); + + ProcessStepBuilder testStep = process.AddStepFromType(); + process.OnInputEvent("Start").SendEventTo(new ProcessFunctionTargetBuilder(testStep)); + + ProcessStepBuilder errorStep = process.AddStepFromType(); + testStep.OnFunctionError(nameof(FailedStep.TestFailure)).SendEventTo(new ProcessFunctionTargetBuilder(errorStep, nameof(ErrorStep.FunctionErrorHandler))); + + KernelProcess processInstance = process.Build(); + Kernel kernel = new(); + + // Act + using LocalKernelProcessContext runningProcess = await processInstance.StartAsync(kernel, new KernelProcessEvent() { Id = "Start" }); + + // Assert + Assert.True(kernel.Data.ContainsKey("error-function")); + Assert.IsType(kernel.Data["error-function"]); + } + + /// + /// Verify that the process level error handler is called when a function fails. + /// + [Fact] + public async Task ProcessGlobalErrorHandledAsync() + { + // Arrange + ProcessBuilder process = new(nameof(ProcessFunctionErrorHandledAsync)); + + ProcessStepBuilder testStep = process.AddStepFromType(); + process.OnInputEvent("Start").SendEventTo(new ProcessFunctionTargetBuilder(testStep)); + + ProcessStepBuilder errorStep = process.AddStepFromType(); + process.OnError().SendEventTo(new ProcessFunctionTargetBuilder(errorStep, nameof(ErrorStep.GlobalErrorHandler))); + + KernelProcess processInstance = process.Build(); + Kernel kernel = new(); + + // Act + using LocalKernelProcessContext runningProcess = await processInstance.StartAsync(kernel, new KernelProcessEvent() { Id = "Start" }); + + // Assert + Assert.True(kernel.Data.ContainsKey("error-global")); + Assert.IsType(kernel.Data["error-global"]); + } + + /// + /// Verify that the function level error handler has precedence over the process level error handler. + /// + [Fact] + public async Task FunctionErrorHandlerTakesPrecedenceAsync() + { + // Arrange + ProcessBuilder process = new(nameof(ProcessFunctionErrorHandledAsync)); + + ProcessStepBuilder testStep = process.AddStepFromType(); + process.OnInputEvent("Start").SendEventTo(new ProcessFunctionTargetBuilder(testStep)); + + ProcessStepBuilder errorStep = process.AddStepFromType(); + testStep.OnFunctionError(nameof(FailedStep.TestFailure)).SendEventTo(new ProcessFunctionTargetBuilder(errorStep, nameof(ErrorStep.FunctionErrorHandler))); + process.OnError().SendEventTo(new ProcessFunctionTargetBuilder(errorStep, nameof(ErrorStep.GlobalErrorHandler))); + + KernelProcess processInstance = process.Build(); + Kernel kernel = new(); + + // Act + using LocalKernelProcessContext runningProcess = await processInstance.StartAsync(kernel, new KernelProcessEvent() { Id = "Start" }); + + // Assert + Assert.False(kernel.Data.ContainsKey("error-global")); + Assert.True(kernel.Data.ContainsKey("error-function")); + Assert.IsType(kernel.Data["error-function"]); + } + /// /// A class that represents a step for testing. /// @@ -100,6 +181,45 @@ public void TestFunction() } } + /// + /// A class that represents a step for testing. + /// + private sealed class FailedStep : KernelProcessStep + { + /// + /// A method that represents a function for testing. + /// + [KernelFunction] + public void TestFailure() + { + throw new InvalidOperationException("I failed!"); + } + } + + /// + /// A class that represents a step for testing. + /// + private sealed class ErrorStep : KernelProcessStep + { + /// + /// A method for unhandling failures at the process level. + /// + [KernelFunction] + public void GlobalErrorHandler(KernelProcessError exception, Kernel kernel) + { + kernel.Data.Add("error-global", exception); + } + + /// + /// A method for unhandling failures at the function level. + /// + [KernelFunction] + public void FunctionErrorHandler(KernelProcessError exception, Kernel kernel) + { + kernel.Data.Add("error-function", exception); + } + } + /// /// A class that represents a state for testing. /// diff --git a/dotnet/src/Experimental/Process.Utilities.UnitTests/CloneTests.cs b/dotnet/src/Experimental/Process.Utilities.UnitTests/CloneTests.cs index c54a5c7c1dfb..48043fe541c6 100644 --- a/dotnet/src/Experimental/Process.Utilities.UnitTests/CloneTests.cs +++ b/dotnet/src/Experimental/Process.Utilities.UnitTests/CloneTests.cs @@ -4,7 +4,7 @@ using System.Linq; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.Process.Runtime; +using Microsoft.SemanticKernel.Process.Internal; using Xunit; namespace SemanticKernel.Process.Utilities.UnitTests; diff --git a/dotnet/src/Experimental/Process.Utilities.UnitTests/ProcessTypeExtensionsTests.cs b/dotnet/src/Experimental/Process.Utilities.UnitTests/ProcessTypeExtensionsTests.cs index 23e27b6d121b..96e22e308a55 100644 --- a/dotnet/src/Experimental/Process.Utilities.UnitTests/ProcessTypeExtensionsTests.cs +++ b/dotnet/src/Experimental/Process.Utilities.UnitTests/ProcessTypeExtensionsTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Microsoft.SemanticKernel.Process.Internal; using Xunit; namespace Microsoft.SemanticKernel.Process.Core.UnitTests; diff --git a/dotnet/src/InternalUtilities/process/Abstractions/ExceptionExtensions.cs b/dotnet/src/InternalUtilities/process/Abstractions/ExceptionExtensions.cs index 26abfe74934a..95a9ae784aae 100644 --- a/dotnet/src/InternalUtilities/process/Abstractions/ExceptionExtensions.cs +++ b/dotnet/src/InternalUtilities/process/Abstractions/ExceptionExtensions.cs @@ -2,7 +2,7 @@ using System; using Microsoft.Extensions.Logging; -namespace Microsoft.SemanticKernel.Process.Runtime; +namespace Microsoft.SemanticKernel.Process.Internal; internal static class ExceptionExtensions { diff --git a/dotnet/src/InternalUtilities/process/Abstractions/KernelProcessStepExtension.cs b/dotnet/src/InternalUtilities/process/Abstractions/KernelProcessStepExtension.cs index 63cf4003127a..ef37c0429f08 100644 --- a/dotnet/src/InternalUtilities/process/Abstractions/KernelProcessStepExtension.cs +++ b/dotnet/src/InternalUtilities/process/Abstractions/KernelProcessStepExtension.cs @@ -2,7 +2,7 @@ using System; -namespace Microsoft.SemanticKernel.Process; +namespace Microsoft.SemanticKernel.Process.Internal; internal static class KernelProcessStepExtensions { diff --git a/dotnet/src/InternalUtilities/process/Abstractions/ProcessConstants.cs b/dotnet/src/InternalUtilities/process/Abstractions/ProcessConstants.cs new file mode 100644 index 000000000000..92e005479675 --- /dev/null +++ b/dotnet/src/InternalUtilities/process/Abstractions/ProcessConstants.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. +namespace Microsoft.SemanticKernel.Process.Internal; + +internal static class ProcessConstants +{ + /// + /// Event raised internally for errors not handled at the step level. + /// + public const string GlobalErrorEventId = "Microsoft.SemanticKernel.Process.Global.OnError"; + + /// + /// Qualified name of the end step. + /// + public const string EndStepName = "Microsoft.SemanticKernel.Process.EndStep"; +} diff --git a/dotnet/src/InternalUtilities/process/Abstractions/ProcessExtensions.cs b/dotnet/src/InternalUtilities/process/Abstractions/ProcessExtensions.cs index 85faf0813946..bd5f08cf519c 100644 --- a/dotnet/src/InternalUtilities/process/Abstractions/ProcessExtensions.cs +++ b/dotnet/src/InternalUtilities/process/Abstractions/ProcessExtensions.cs @@ -3,7 +3,7 @@ using System.Linq; using Microsoft.Extensions.Logging; -namespace Microsoft.SemanticKernel.Process.Runtime; +namespace Microsoft.SemanticKernel.Process.Internal; internal static class ProcessExtensions { diff --git a/dotnet/src/InternalUtilities/process/Abstractions/StepExtensions.cs b/dotnet/src/InternalUtilities/process/Abstractions/StepExtensions.cs index b4112a0f0541..1344ea5c4979 100644 --- a/dotnet/src/InternalUtilities/process/Abstractions/StepExtensions.cs +++ b/dotnet/src/InternalUtilities/process/Abstractions/StepExtensions.cs @@ -5,7 +5,7 @@ using System.Linq; using Microsoft.Extensions.Logging; -namespace Microsoft.SemanticKernel.Process.Runtime; +namespace Microsoft.SemanticKernel.Process.Internal; internal static class StepExtensions { diff --git a/dotnet/src/InternalUtilities/process/Runtime/ProcessEvent.cs b/dotnet/src/InternalUtilities/process/Runtime/ProcessEvent.cs index 6b6babf0e86d..da270c773911 100644 --- a/dotnet/src/InternalUtilities/process/Runtime/ProcessEvent.cs +++ b/dotnet/src/InternalUtilities/process/Runtime/ProcessEvent.cs @@ -8,10 +8,13 @@ namespace Microsoft.SemanticKernel.Process.Runtime; /// /// The namespace of the event. /// The instance of that this came from. +/// This event represents a runtime error / exception raised internally by the framework. [DataContract] +[KnownType(typeof(KernelProcessError))] public record ProcessEvent( [property: DataMember] string? Namespace, - [property: DataMember] KernelProcessEvent InnerEvent) + [property: DataMember] KernelProcessEvent InnerEvent, + [property: DataMember] bool IsError = false) { /// /// The Id of the event. @@ -27,12 +30,4 @@ public record ProcessEvent( /// The visibility of the event. /// internal KernelProcessEventVisibility Visibility => this.InnerEvent.Visibility; - - /// - /// Creates a new from a . - /// - /// The - /// The namespace of the event. - /// An instance of - internal static ProcessEvent FromKernelProcessEvent(KernelProcessEvent kernelProcessEvent, string Namespace) => new(Namespace, kernelProcessEvent); } diff --git a/dotnet/src/InternalUtilities/process/Runtime/ProcessMessage.cs b/dotnet/src/InternalUtilities/process/Runtime/ProcessMessage.cs index a1931fad363e..c63d89f10262 100644 --- a/dotnet/src/InternalUtilities/process/Runtime/ProcessMessage.cs +++ b/dotnet/src/InternalUtilities/process/Runtime/ProcessMessage.cs @@ -15,6 +15,7 @@ namespace Microsoft.SemanticKernel.Process.Runtime; /// The name of the function associated with the message. /// The dictionary of values associated with the message. [DataContract] +[KnownType(typeof(KernelProcessError))] public record ProcessMessage( [property:DataMember] string SourceId, From 6f223c27e7debb070806731d2f96bd0599dcd40a Mon Sep 17 00:00:00 2001 From: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> Date: Fri, 1 Nov 2024 16:25:27 +0000 Subject: [PATCH 7/8] .Net: Improve logging for function calls processor and kernel function (#9495) ### Motivation, Context and Description This PR improves the existing logging functionality in both `FunctionCallsProcessor` and `KernelFunction` by logging additional contextual information, such as function details, function choice behavior configuration, and function call details, to simplify troubleshooting. --- dotnet/SK-dotnet.sln | 1 + .../Connectors.OpenAI/Core/ClientCore.cs | 4 +- .../FunctionCalling/FunctionCallsProcessor.cs | 72 ++++----- .../FunctionCallsProcessorLoggerExtensions.cs | 147 ++++++++++++++++++ .../Functions/KernelFunction.cs | 30 ++-- .../Functions/KernelFunctionLogMessages.cs | 64 ++++---- .../KernelFunctionLogMessagesTests.cs | 12 +- 7 files changed, 234 insertions(+), 96 deletions(-) create mode 100644 dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessorLoggerExtensions.cs diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 374f877460f6..c2062d3ed7fd 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -349,6 +349,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "FunctionCalling", "Function ProjectSection(SolutionItems) = preProject src\InternalUtilities\connectors\AI\FunctionCalling\FunctionCallingUtilities.props = src\InternalUtilities\connectors\AI\FunctionCalling\FunctionCallingUtilities.props src\InternalUtilities\connectors\AI\FunctionCalling\FunctionCallsProcessor.cs = src\InternalUtilities\connectors\AI\FunctionCalling\FunctionCallsProcessor.cs + src\InternalUtilities\connectors\AI\FunctionCalling\FunctionCallsProcessorLoggerExtensions.cs = src\InternalUtilities\connectors\AI\FunctionCalling\FunctionCallsProcessorLoggerExtensions.cs EndProjectSection EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Weaviate.UnitTests", "src\Connectors\Connectors.Weaviate.UnitTests\Connectors.Weaviate.UnitTests.csproj", "{E8FC97B0-B417-4A90-993C-B8AA9223B058}" diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.cs b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.cs index a295c7876e69..146701cdfaf0 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.cs @@ -88,6 +88,8 @@ internal ClientCore( HttpClient? httpClient = null, ILogger? logger = null) { + this.Logger = logger ?? NullLogger.Instance; + this.FunctionCallsProcessor = new FunctionCallsProcessor(this.Logger); // Empty constructor will be used when inherited by a specialized Client. @@ -107,8 +109,6 @@ internal ClientCore( this.AddAttribute(AIServiceExtensions.ModelIdKey, modelId); } - this.Logger = logger ?? NullLogger.Instance; - // Accepts the endpoint if provided, otherwise uses the default OpenAI endpoint. this.Endpoint = endpoint ?? httpClient?.BaseAddress; if (this.Endpoint is null) diff --git a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs index a1c92b842669..1b591b78db77 100644 --- a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs +++ b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs @@ -91,6 +91,8 @@ public FunctionCallsProcessor(ILogger? logger = null) var configuration = behavior.GetConfiguration(new(chatHistory) { Kernel = kernel, RequestSequenceIndex = requestIndex }); + this._logger.LogFunctionChoiceBehaviorConfiguration(configuration); + // Disable auto invocation if no kernel is provided. configuration.AutoInvoke = kernel is not null && configuration.AutoInvoke; @@ -99,24 +101,13 @@ public FunctionCallsProcessor(ILogger? logger = null) if (requestIndex >= maximumAutoInvokeAttempts) { configuration.AutoInvoke = false; - if (this._logger!.IsEnabled(LogLevel.Debug)) - { - this._logger.LogDebug("Maximum auto-invoke ({MaximumAutoInvoke}) reached.", maximumAutoInvokeAttempts); - } + this._logger.LogMaximumNumberOfAutoInvocationsPerUserRequestReached(maximumAutoInvokeAttempts); } // Disable auto invocation if we've exceeded the allowed limit of in-flight auto-invokes. See XML comment for the "MaxInflightAutoInvokes" const for more details. else if (s_inflightAutoInvokes.Value >= MaxInflightAutoInvokes) { configuration.AutoInvoke = false; - if (this._logger!.IsEnabled(LogLevel.Debug)) - { - this._logger.LogDebug("Maximum auto-invoke ({MaxInflightAutoInvoke}) reached.", MaxInflightAutoInvokes); - } - } - - if (configuration.Functions?.Count == 0) - { - this._logger.LogDebug("No functions provided to AI model. Function calling is disabled."); + this._logger.LogMaximumNumberOfInFlightAutoInvocationsReached(MaxInflightAutoInvokes); } return configuration; @@ -146,28 +137,14 @@ public FunctionCallsProcessor(ILogger? logger = null) { var functionCalls = FunctionCallContent.GetFunctionCalls(chatMessageContent).ToList(); - if (this._logger.IsEnabled(LogLevel.Debug)) - { - this._logger.LogDebug("Function calls: {Calls}", functionCalls.Count); - } - if (this._logger.IsEnabled(LogLevel.Trace)) - { - var messages = new List(functionCalls.Count); - foreach (var call in functionCalls) - { - var argumentsString = call.Arguments is not null ? $"({string.Join(",", call.Arguments.Select(a => $"{a.Key}={a.Value}"))})" : "()"; - var pluginName = string.IsNullOrEmpty(call.PluginName) ? string.Empty : $"{call.PluginName}-"; - messages.Add($"{pluginName}{call.FunctionName}{argumentsString}"); - } - this._logger.LogTrace("Function calls: {Calls}", string.Join(", ", messages)); - } + this._logger.LogFunctionCalls(functionCalls); // Add the result message to the caller's chat history; // this is required for AI model to understand the function results. chatHistory.Add(chatMessageContent); var functionTasks = options.AllowConcurrentInvocation && functionCalls.Count > 1 ? - new List>(functionCalls.Count) : + new List>(functionCalls.Count) : null; // We must send back a result for every function call, regardless of whether we successfully executed it or not. @@ -205,15 +182,16 @@ public FunctionCallsProcessor(ILogger? logger = null) FunctionSequenceIndex = functionCallIndex, FunctionCount = functionCalls.Count, CancellationToken = cancellationToken, - IsStreaming = isStreaming + IsStreaming = isStreaming, + ToolCallId = functionCall.Id }; - var functionTask = Task.Run<(string? Result, string? ErrorMessage, FunctionCallContent FunctionCall, bool Terminate)>(async () => + var functionTask = Task.Run<(string? Result, string? ErrorMessage, FunctionCallContent FunctionCall, AutoFunctionInvocationContext Context)>(async () => { s_inflightAutoInvokes.Value++; try { - invocationContext = await OnAutoFunctionInvocationAsync(kernel, invocationContext, async (context) => + invocationContext = await this.OnAutoFunctionInvocationAsync(kernel, invocationContext, async (context) => { // Check if filter requested termination. if (context.Terminate) @@ -231,12 +209,12 @@ public FunctionCallsProcessor(ILogger? logger = null) catch (Exception e) #pragma warning restore CA1031 // Do not catch general exception types { - return (null, $"Error: Exception while invoking function. {e.Message}", functionCall, false); + return (null, $"Error: Exception while invoking function. {e.Message}", functionCall, invocationContext); } // Apply any changes from the auto function invocation filters context to final result. var stringResult = ProcessFunctionResult(invocationContext.Result.GetValue() ?? string.Empty); - return (stringResult, null, functionCall, invocationContext.Terminate); + return (stringResult, null, functionCall, invocationContext); }, cancellationToken); // If concurrent invocation is enabled, add the task to the list for later waiting. Otherwise, join with it now. @@ -250,9 +228,9 @@ public FunctionCallsProcessor(ILogger? logger = null) this.AddFunctionCallResultToChatHistory(chatHistory, functionResult.FunctionCall, functionResult.Result, functionResult.ErrorMessage); // If filter requested termination, return last chat history message. - if (functionResult.Terminate) + if (functionResult.Context.Terminate) { - this._logger.LogDebug("Filter requested termination of automatic function invocation."); + this._logger.LogAutoFunctionInvocationProcessTermination(functionResult.Context); return chatHistory.Last(); } } @@ -270,8 +248,9 @@ public FunctionCallsProcessor(ILogger? logger = null) { this.AddFunctionCallResultToChatHistory(chatHistory, functionTask.Result.FunctionCall, functionTask.Result.Result, functionTask.Result.ErrorMessage); - if (functionTask.Result.Terminate) + if (functionTask.Result.Context.Terminate) { + this._logger.LogAutoFunctionInvocationProcessTermination(functionTask.Result.Context); terminationRequested = true; } } @@ -279,7 +258,6 @@ public FunctionCallsProcessor(ILogger? logger = null) // If filter requested termination, return last chat history message. if (terminationRequested) { - this._logger.LogDebug("Filter requested termination of automatic function invocation."); return chatHistory.Last(); } } @@ -297,9 +275,9 @@ public FunctionCallsProcessor(ILogger? logger = null) private void AddFunctionCallResultToChatHistory(ChatHistory chatHistory, FunctionCallContent functionCall, string? result, string? errorMessage = null) { // Log any error - if (errorMessage is not null && this._logger.IsEnabled(LogLevel.Debug)) + if (errorMessage is not null) { - this._logger.LogDebug("Failed to handle function request ({Id}). {Error}", functionCall.Id, errorMessage); + this._logger.LogFunctionCallRequestFailure(functionCall, errorMessage); } result ??= errorMessage ?? string.Empty; @@ -317,12 +295,12 @@ private void AddFunctionCallResultToChatHistory(ChatHistory chatHistory, Functio /// The auto function invocation context. /// The function to call after the filters. /// The auto function invocation context. - private static async Task OnAutoFunctionInvocationAsync( + private async Task OnAutoFunctionInvocationAsync( Kernel kernel, AutoFunctionInvocationContext context, Func functionCallCallback) { - await InvokeFilterOrFunctionAsync(kernel.AutoFunctionInvocationFilters, functionCallCallback, context).ConfigureAwait(false); + await this.InvokeFilterOrFunctionAsync(kernel.AutoFunctionInvocationFilters, functionCallCallback, context).ConfigureAwait(false); return context; } @@ -334,7 +312,7 @@ private static async Task OnAutoFunctionInvocatio /// Second parameter of filter is callback. It can be either filter on + 1 position or function if there are no remaining filters to execute. /// Function will be always executed as last step after all filters. /// - private static async Task InvokeFilterOrFunctionAsync( + private async Task InvokeFilterOrFunctionAsync( IList? autoFunctionInvocationFilters, Func functionCallCallback, AutoFunctionInvocationContext context, @@ -342,8 +320,12 @@ private static async Task InvokeFilterOrFunctionAsync( { if (autoFunctionInvocationFilters is { Count: > 0 } && index < autoFunctionInvocationFilters.Count) { - await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(context, - (context) => InvokeFilterOrFunctionAsync(autoFunctionInvocationFilters, functionCallCallback, context, index + 1)).ConfigureAwait(false); + this._logger.LogAutoFunctionInvocationFilterContext(context); + + await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( + context, + (context) => this.InvokeFilterOrFunctionAsync(autoFunctionInvocationFilters, functionCallCallback, context, index + 1) + ).ConfigureAwait(false); } else { diff --git a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessorLoggerExtensions.cs b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessorLoggerExtensions.cs new file mode 100644 index 000000000000..ad6c2e033af0 --- /dev/null +++ b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessorLoggerExtensions.cs @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.Extensions.Logging; + +namespace Microsoft.SemanticKernel.Connectors.FunctionCalling; + +[ExcludeFromCodeCoverage] +internal static partial class FunctionCallsProcessorLoggingExtensions +{ + /// + /// Action to log the . + /// + private static readonly Action s_logFunctionChoiceBehaviorConfiguration = + LoggerMessage.Define( + logLevel: LogLevel.Debug, + eventId: 0, + "Function choice behavior configuration: Choice:{Choice}, AutoInvoke:{AutoInvoke}, AllowConcurrentInvocation:{AllowConcurrentInvocation}, AllowParallelCalls:{AllowParallelCalls} Functions:{Functions}"); + + /// + /// Action to log function calls. + /// + private static readonly Action s_logFunctionCalls = + LoggerMessage.Define( + logLevel: LogLevel.Debug, + eventId: 0, + "Function calls: {Calls}"); + + /// + /// Action to log auto function invocation filter context. + /// + private static readonly Action s_logAutoFunctionInvocationFilterContext = + LoggerMessage.Define( + logLevel: LogLevel.Debug, + eventId: 0, + "Auto function invocation filter context: Name:{Name}, Id:{Id}, IsStreaming:{IsStreaming} FunctionSequenceIndex:{FunctionSequenceIndex}, RequestSequenceIndex:{RequestSequenceIndex}, FunctionCount:{FunctionCount}"); + + /// + /// Action to log auto function invocation filter termination. + /// + private static readonly Action s_logAutoFunctionInvocationFilterTermination = + LoggerMessage.Define( + logLevel: LogLevel.Debug, + eventId: 0, + "Auto function invocation filter requested termination: Name:{Name}, Id:{Id}"); + + /// + /// Logs . + /// + public static void LogFunctionChoiceBehaviorConfiguration(this ILogger logger, FunctionChoiceBehaviorConfiguration configuration) + { + if (logger.IsEnabled(LogLevel.Debug)) + { + var functionsLog = (configuration.Functions != null && configuration.Functions.Any()) + ? string.Join(", ", configuration.Functions.Select(f => FunctionName.ToFullyQualifiedName(f.Name, f.PluginName))) + : "None (Function calling is disabled)"; + + s_logFunctionChoiceBehaviorConfiguration( + logger, + configuration.Choice.Label, + configuration.AutoInvoke, + configuration.Options.AllowConcurrentInvocation, + configuration.Options.AllowParallelCalls, + functionsLog, + null); + } + } + + /// + /// Logs function calls. + /// + public static void LogFunctionCalls(this ILogger logger, List functionCalls) + { + if (logger.IsEnabled(LogLevel.Debug)) + { + s_logFunctionCalls( + logger, + string.Join(", ", functionCalls.Select(call => $"{FunctionName.ToFullyQualifiedName(call.FunctionName, call.PluginName)} [Id: {call.Id}]")), + null + ); + } + } + + /// + /// Logs the . + /// + public static void LogAutoFunctionInvocationFilterContext(this ILogger logger, AutoFunctionInvocationContext context) + { + if (logger.IsEnabled(LogLevel.Debug)) + { + var fqn = FunctionName.ToFullyQualifiedName(context.Function.Name, context.Function.PluginName); + + s_logAutoFunctionInvocationFilterContext( + logger, + fqn, + context.ToolCallId, + context.IsStreaming, + context.FunctionSequenceIndex, + context.RequestSequenceIndex, + context.FunctionCount, + null); + } + } + + /// + /// Logs the auto function invocation process termination. + /// + public static void LogAutoFunctionInvocationProcessTermination(this ILogger logger, AutoFunctionInvocationContext context) + { + if (logger.IsEnabled(LogLevel.Debug)) + { + var fqn = FunctionName.ToFullyQualifiedName(context.Function.Name, context.Function.PluginName); + + s_logAutoFunctionInvocationFilterTermination(logger, fqn, context.ToolCallId, null); + } + } + + /// + /// Logs function call request failure. + /// + public static void LogFunctionCallRequestFailure(this ILogger logger, FunctionCallContent functionCall, string error) + { + if (logger.IsEnabled(LogLevel.Debug)) + { + var fqn = FunctionName.ToFullyQualifiedName(functionCall.FunctionName, functionCall.PluginName); + + logger.LogDebug("Function call request failed: Name:{Name}, Id:{Id}", fqn, functionCall.Id); + } + + // Log error at trace level only because it may contain sensitive information. + if (logger.IsEnabled(LogLevel.Trace)) + { + var fqn = FunctionName.ToFullyQualifiedName(functionCall.FunctionName, functionCall.PluginName); + + logger.LogTrace("Function call request failed: Name:{Name}, Id:{Id}, Error:{Error}", fqn, functionCall.Id, error); + } + } + + [LoggerMessage(EventId = 0, Level = LogLevel.Debug, Message = "The maximum limit of {MaxNumberOfAutoInvocations} auto invocations per user request has been reached. Auto invocation is now disabled.")] + public static partial void LogMaximumNumberOfAutoInvocationsPerUserRequestReached(this ILogger logger, int maxNumberOfAutoInvocations); + + [LoggerMessage(EventId = 0, Level = LogLevel.Debug, Message = "The maximum limit of {MaxNumberOfInflightAutoInvocations} in-flight auto invocations has been reached. Auto invocation is now disabled.")] + public static partial void LogMaximumNumberOfInFlightAutoInvocationsReached(this ILogger logger, int maxNumberOfInflightAutoInvocations); +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs index 9c851bd2dfa0..9885ff22ba9d 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs @@ -228,9 +228,9 @@ public async Task InvokeAsync( // Ensure arguments are initialized. arguments ??= []; - logger.LogFunctionInvoking(this.Name); + logger.LogFunctionInvoking(this.PluginName, this.Name); - this.LogFunctionArguments(logger, arguments); + this.LogFunctionArguments(logger, this.PluginName, this.Name, arguments); TagList tags = new() { { MeasurementFunctionTagName, this.Name } }; long startingTimestamp = Stopwatch.GetTimestamp(); @@ -275,9 +275,9 @@ public async Task InvokeAsync( throw new OperationCanceledException($"A {nameof(Kernel)}.{nameof(Kernel.FunctionInvoked)} event handler requested cancellation after function invocation."); } - logger.LogFunctionInvokedSuccess(this.Name); + logger.LogFunctionInvokedSuccess(this.PluginName, this.Name); - this.LogFunctionResult(logger, functionResult); + this.LogFunctionResult(logger, this.PluginName, this.Name, functionResult); return functionResult; } @@ -291,7 +291,7 @@ public async Task InvokeAsync( // Record the invocation duration metric and log the completion. TimeSpan duration = new((long)((Stopwatch.GetTimestamp() - startingTimestamp) * (10_000_000.0 / Stopwatch.Frequency))); s_invocationDuration.Record(duration.TotalSeconds, in tags); - logger.LogFunctionComplete(duration.TotalSeconds); + logger.LogFunctionComplete(this.PluginName, this.Name, duration.TotalSeconds); } } @@ -355,9 +355,9 @@ public async IAsyncEnumerable InvokeStreamingAsync( ILogger logger = kernel.LoggerFactory.CreateLogger(this.Name) ?? NullLogger.Instance; arguments ??= []; - logger.LogFunctionStreamingInvoking(this.Name); + logger.LogFunctionStreamingInvoking(this.PluginName, this.Name); - this.LogFunctionArguments(logger, arguments); + this.LogFunctionArguments(logger, this.PluginName, this.Name, arguments); TagList tags = new() { { MeasurementFunctionTagName, this.Name } }; long startingTimestamp = Stopwatch.GetTimestamp(); @@ -436,7 +436,7 @@ public async IAsyncEnumerable InvokeStreamingAsync( // Record the streaming duration metric and log the completion. TimeSpan duration = new((long)((Stopwatch.GetTimestamp() - startingTimestamp) * (10_000_000.0 / Stopwatch.Frequency))); s_streamingDuration.Record(duration.TotalSeconds, in tags); - logger.LogFunctionStreamingComplete(duration.TotalSeconds); + logger.LogFunctionStreamingComplete(this.PluginName, this.Name, duration.TotalSeconds); } } @@ -493,7 +493,7 @@ private static void HandleException( // Log the exception and add its type to the tags that'll be included with recording the invocation duration. tags.Add(MeasurementErrorTagName, ex.GetType().FullName); activity?.SetError(ex); - logger.LogFunctionError(ex, ex.Message); + logger.LogFunctionError(kernelFunction.PluginName, kernelFunction.Name, ex, ex.Message); // If the exception is an OperationCanceledException, wrap it in a KernelFunctionCanceledException // in order to convey additional details about what function was canceled. This is particularly @@ -513,29 +513,29 @@ private static void HandleException( [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "The warning is shown and should be addressed at the function creation site; there is no need to show it again at the function invocation sites.")] [UnconditionalSuppressMessage("AOT", "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling.", Justification = "The warning is shown and should be addressed at the function creation site; there is no need to show it again at the function invocation sites.")] - private void LogFunctionArguments(ILogger logger, KernelArguments arguments) + private void LogFunctionArguments(ILogger logger, string? pluginName, string functionName, KernelArguments arguments) { if (this.JsonSerializerOptions is not null) { - logger.LogFunctionArguments(arguments, this.JsonSerializerOptions); + logger.LogFunctionArguments(pluginName, functionName, arguments, this.JsonSerializerOptions); } else { - logger.LogFunctionArguments(arguments); + logger.LogFunctionArguments(pluginName, functionName, arguments); } } [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "The warning is shown and should be addressed at the function creation site; there is no need to show it again at the function invocation sites.")] [UnconditionalSuppressMessage("AOT", "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling.", Justification = "The warning is shown and should be addressed at the function creation site; there is no need to show it again at the function invocation sites.")] - private void LogFunctionResult(ILogger logger, FunctionResult functionResult) + private void LogFunctionResult(ILogger logger, string? pluginName, string functionName, FunctionResult functionResult) { if (this.JsonSerializerOptions is not null) { - logger.LogFunctionResultValue(functionResult, this.JsonSerializerOptions); + logger.LogFunctionResultValue(pluginName, functionName, functionResult, this.JsonSerializerOptions); } else { - logger.LogFunctionResultValue(functionResult); + logger.LogFunctionResultValue(pluginName, functionName, functionResult); } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionLogMessages.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionLogMessages.cs index 594294d8aacc..42c4b7f6e6a9 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionLogMessages.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionLogMessages.cs @@ -23,9 +23,10 @@ internal static partial class KernelFunctionLogMessages [LoggerMessage( EventId = 0, Level = LogLevel.Information, - Message = "Function {FunctionName} invoking.")] + Message = "Function {PluginName}-{FunctionName} invoking.")] public static partial void LogFunctionInvoking( this ILogger logger, + string? pluginName, string functionName); /// @@ -33,17 +34,17 @@ public static partial void LogFunctionInvoking( /// The action provides the benefit of caching the template parsing result for better performance. /// And the public method is a helper to serialize the arguments. /// - private static readonly Action s_logFunctionArguments = - LoggerMessage.Define( + private static readonly Action s_logFunctionArguments = + LoggerMessage.Define( logLevel: LogLevel.Trace, // Sensitive data, logging as trace, disabled by default eventId: 0, - "Function arguments: {Arguments}"); + "Function {PluginName}-{FunctionName} arguments: {Arguments}"); [RequiresUnreferencedCode("Uses reflection to serialize function arguments, making it incompatible with AOT scenarios.")] [RequiresDynamicCode("Uses reflection to serialize the function arguments, making it incompatible with AOT scenarios.")] - public static void LogFunctionArguments(this ILogger logger, KernelArguments arguments) + public static void LogFunctionArguments(this ILogger logger, string? pluginName, string functionName, KernelArguments arguments) { - LogFunctionArgumentsInternal(logger, arguments); + LogFunctionArgumentsInternal(logger, pluginName, functionName, arguments); } /// @@ -51,9 +52,9 @@ public static void LogFunctionArguments(this ILogger logger, KernelArguments arg /// [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "This method is AOT safe.")] [UnconditionalSuppressMessage("AOT", "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling.", Justification = "This method is AOT safe.")] - public static void LogFunctionArguments(this ILogger logger, KernelArguments arguments, JsonSerializerOptions jsonSerializerOptions) + public static void LogFunctionArguments(this ILogger logger, string? pluginName, string functionName, KernelArguments arguments, JsonSerializerOptions jsonSerializerOptions) { - LogFunctionArgumentsInternal(logger, arguments, jsonSerializerOptions); + LogFunctionArgumentsInternal(logger, pluginName, functionName, arguments, jsonSerializerOptions); } /// @@ -61,7 +62,7 @@ public static void LogFunctionArguments(this ILogger logger, KernelArguments arg /// [RequiresUnreferencedCode("Uses reflection, if no JOSs are supplied, to serialize function arguments, making it incompatible with AOT scenarios.")] [RequiresDynamicCode("Uses reflection, if no JOSs are supplied, to serialize function arguments, making it incompatible with AOT scenarios.")] - private static void LogFunctionArgumentsInternal(this ILogger logger, KernelArguments arguments, JsonSerializerOptions? jsonSerializerOptions = null) + private static void LogFunctionArgumentsInternal(this ILogger logger, string? pluginName, string functionName, KernelArguments arguments, JsonSerializerOptions? jsonSerializerOptions = null) { if (logger.IsEnabled(LogLevel.Trace)) { @@ -79,11 +80,11 @@ private static void LogFunctionArgumentsInternal(this ILogger logger, KernelArgu jsonString = JsonSerializer.Serialize(arguments); } - s_logFunctionArguments(logger, jsonString, null); + s_logFunctionArguments(logger, pluginName, functionName, jsonString, null); } catch (NotSupportedException ex) { - s_logFunctionArguments(logger, "Failed to serialize arguments to Json", ex); + s_logFunctionArguments(logger, pluginName, functionName, "Failed to serialize arguments to Json", ex); } } } @@ -94,24 +95,24 @@ private static void LogFunctionArgumentsInternal(this ILogger logger, KernelArgu [LoggerMessage( EventId = 0, Level = LogLevel.Information, - Message = "Function {FunctionName} succeeded.")] - public static partial void LogFunctionInvokedSuccess(this ILogger logger, string functionName); + Message = "Function {PluginName}-{FunctionName} succeeded.")] + public static partial void LogFunctionInvokedSuccess(this ILogger logger, string? pluginName, string functionName); /// /// Logs result of a . /// The action provides the benefit of caching the template parsing result for better performance. /// And the public method is a helper to serialize the result. /// - private static readonly Action s_logFunctionResultValue = - LoggerMessage.Define( + private static readonly Action s_logFunctionResultValue = + LoggerMessage.Define( logLevel: LogLevel.Trace, // Sensitive data, logging as trace, disabled by default eventId: 0, - "Function result: {ResultValue}"); + "Function {PluginName}-{FunctionName} result: {ResultValue}"); [RequiresUnreferencedCode("Uses reflection to serialize function result, making it incompatible with AOT scenarios.")] [RequiresDynamicCode("Uses reflection to serialize the function result, making it incompatible with AOT scenarios.")] - public static void LogFunctionResultValue(this ILogger logger, FunctionResult? resultValue) + public static void LogFunctionResultValue(this ILogger logger, string? pluginName, string functionName, FunctionResult? resultValue) { - LogFunctionResultValueInternal(logger, resultValue); + LogFunctionResultValueInternal(logger, pluginName, functionName, resultValue); } /// @@ -121,22 +122,22 @@ public static void LogFunctionResultValue(this ILogger logger, FunctionResult? r /// [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "This method is AOT safe.")] [UnconditionalSuppressMessage("AOT", "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling.", Justification = "This method is AOT safe.")] - public static void LogFunctionResultValue(this ILogger logger, FunctionResult? resultValue, JsonSerializerOptions jsonSerializerOptions) + public static void LogFunctionResultValue(this ILogger logger, string? pluginName, string functionName, FunctionResult? resultValue, JsonSerializerOptions jsonSerializerOptions) { - LogFunctionResultValueInternal(logger, resultValue, jsonSerializerOptions); + LogFunctionResultValueInternal(logger, pluginName, functionName, resultValue, jsonSerializerOptions); } [SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "By design. See comment below.")] [RequiresUnreferencedCode("Uses reflection, if no JOSs are supplied, to serialize function arguments, making it incompatible with AOT scenarios.")] [RequiresDynamicCode("Uses reflection, if no JOSs are supplied, to serialize function arguments, making it incompatible with AOT scenarios.")] - private static void LogFunctionResultValueInternal(this ILogger logger, FunctionResult? resultValue, JsonSerializerOptions? jsonSerializerOptions = null) + private static void LogFunctionResultValueInternal(this ILogger logger, string? pluginName, string functionName, FunctionResult? resultValue, JsonSerializerOptions? jsonSerializerOptions = null) { if (logger.IsEnabled(LogLevel.Trace)) { // Attempt to convert the result value to string using the GetValue heuristic try { - s_logFunctionResultValue(logger, resultValue?.GetValue() ?? string.Empty, null); + s_logFunctionResultValue(logger, pluginName, functionName, resultValue?.GetValue() ?? string.Empty, null); return; } catch { } @@ -156,11 +157,11 @@ private static void LogFunctionResultValueInternal(this ILogger logger, Function jsonString = JsonSerializer.Serialize(resultValue?.Value); } - s_logFunctionResultValue(logger, jsonString, null); + s_logFunctionResultValue(logger, pluginName, functionName, jsonString, null); } catch (NotSupportedException ex) { - s_logFunctionResultValue(logger, "Failed to log function result value", ex); + s_logFunctionResultValue(logger, pluginName, functionName, "Failed to log function result value", ex); } } } @@ -171,9 +172,11 @@ private static void LogFunctionResultValueInternal(this ILogger logger, Function [LoggerMessage( EventId = 0, Level = LogLevel.Error, - Message = "Function failed. Error: {Message}")] + Message = "Function {PluginName}-{FunctionName} failed. Error: {Message}")] public static partial void LogFunctionError( this ILogger logger, + string? pluginName, + string functionName, Exception exception, string message); @@ -183,9 +186,11 @@ public static partial void LogFunctionError( [LoggerMessage( EventId = 0, Level = LogLevel.Information, - Message = "Function completed. Duration: {Duration}s")] + Message = "Function {PluginName}-{FunctionName} completed. Duration: {Duration}s")] public static partial void LogFunctionComplete( this ILogger logger, + string? pluginName, + string functionName, double duration); /// @@ -194,9 +199,10 @@ public static partial void LogFunctionComplete( [LoggerMessage( EventId = 0, Level = LogLevel.Information, - Message = "Function {FunctionName} streaming.")] + Message = "Function {PluginName}-{FunctionName} streaming.")] public static partial void LogFunctionStreamingInvoking( this ILogger logger, + string? pluginName, string functionName); /// @@ -205,8 +211,10 @@ public static partial void LogFunctionStreamingInvoking( [LoggerMessage( EventId = 0, Level = LogLevel.Information, - Message = "Function streaming completed. Duration: {Duration}s.")] + Message = "Function {PluginName}-{FunctionName} streaming completed. Duration: {Duration}s.")] public static partial void LogFunctionStreamingComplete( this ILogger logger, + string? pluginName, + string functionName, double duration); } diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionLogMessagesTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionLogMessagesTests.cs index ab00eb27b9be..1ca1d1b124f1 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionLogMessagesTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionLogMessagesTests.cs @@ -22,11 +22,11 @@ public void ItShouldLogFunctionResultOfAnyType(Type resultType) // Arrange (object FunctionResult, string LogMessage) testData = resultType switch { - Type t when t == typeof(string) => ("test-string", "Function result: test-string"), - Type t when t == typeof(int) => (6, "Function result: 6"), - Type t when t == typeof(bool) => (true, "Function result: true"), - Type t when t == typeof(ChatMessageContent) => (new ChatMessageContent(AuthorRole.Assistant, "test-content"), "Function result: test-content"), - Type t when t == typeof(User) => (new User { Name = "test-user-name" }, "Function result: {\"name\":\"test-user-name\"}"), + Type t when t == typeof(string) => ("test-string", "Function p1-f1 result: test-string"), + Type t when t == typeof(int) => (6, "Function p1-f1 result: 6"), + Type t when t == typeof(bool) => (true, "Function p1-f1 result: true"), + Type t when t == typeof(ChatMessageContent) => (new ChatMessageContent(AuthorRole.Assistant, "test-content"), "Function p1-f1 result: test-content"), + Type t when t == typeof(User) => (new User { Name = "test-user-name" }, "Function p1-f1 result: {\"name\":\"test-user-name\"}"), _ => throw new ArgumentException("Invalid type") }; @@ -36,7 +36,7 @@ public void ItShouldLogFunctionResultOfAnyType(Type resultType) var functionResult = new FunctionResult(KernelFunctionFactory.CreateFromMethod(() => { }), testData.FunctionResult); // Act - logger.Object.LogFunctionResultValue(functionResult); + logger.Object.LogFunctionResultValue("p1", "f1", functionResult); // Assert logger.Verify(l => l.Log( From 936366ee0fcd74a14877d3bb3dc5ad42a1280ce1 Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Fri, 1 Nov 2024 11:10:47 -0700 Subject: [PATCH 8/8] .Net: Fixed typos (#9503) ### Motivation and Context Fixed typos that are blocking other PRs from merge due to Spell Checker. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --- .github/_typos.toml | 1 + .../UnitTests/Core/History/MockHistoryGenerator.cs | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/_typos.toml b/.github/_typos.toml index a926a1403856..c2394c3dd9e1 100644 --- a/.github/_typos.toml +++ b/.github/_typos.toml @@ -23,6 +23,7 @@ extend-exclude = [ "PopulationByCountry.csv", "PopulationByAdmin1.csv", "WomensSuffrage.txt", + "SK-dotnet.sln.DotSettings" ] [default.extend-words] diff --git a/dotnet/src/Agents/UnitTests/Core/History/MockHistoryGenerator.cs b/dotnet/src/Agents/UnitTests/Core/History/MockHistoryGenerator.cs index 375b6fc9aa40..3475776a1935 100644 --- a/dotnet/src/Agents/UnitTests/Core/History/MockHistoryGenerator.cs +++ b/dotnet/src/Agents/UnitTests/Core/History/MockHistoryGenerator.cs @@ -30,7 +30,7 @@ public static IEnumerable CreateHistoryWithUserInput(int mes { yield return index % 2 == 1 ? - new ChatMessageContent(AuthorRole.Assistant, $"asistant response: {index}") : + new ChatMessageContent(AuthorRole.Assistant, $"assistant response: {index}") : new ChatMessageContent(AuthorRole.User, $"user input: {index}"); } } @@ -49,18 +49,18 @@ public static IEnumerable CreateHistoryWithUserInput(int mes public static IEnumerable CreateHistoryWithFunctionContent() { yield return new ChatMessageContent(AuthorRole.User, "user input: 0"); - yield return new ChatMessageContent(AuthorRole.Assistant, "asistant response: 1"); + yield return new ChatMessageContent(AuthorRole.Assistant, "assistant response: 1"); yield return new ChatMessageContent(AuthorRole.User, "user input: 2"); - yield return new ChatMessageContent(AuthorRole.Assistant, "asistant response: 3"); + yield return new ChatMessageContent(AuthorRole.Assistant, "assistant response: 3"); yield return new ChatMessageContent(AuthorRole.User, "user input: 4"); yield return new ChatMessageContent(AuthorRole.Assistant, [new FunctionCallContent("function call: 5")]); yield return new ChatMessageContent(AuthorRole.Tool, [new FunctionResultContent("function result: 6")]); - yield return new ChatMessageContent(AuthorRole.Assistant, "asistant response: 7"); + yield return new ChatMessageContent(AuthorRole.Assistant, "assistant response: 7"); yield return new ChatMessageContent(AuthorRole.User, "user input: 8"); yield return new ChatMessageContent(AuthorRole.Assistant, [new FunctionCallContent("function call: 9")]); yield return new ChatMessageContent(AuthorRole.Tool, [new FunctionResultContent("function result: 10")]); - yield return new ChatMessageContent(AuthorRole.Assistant, "asistant response: 11"); + yield return new ChatMessageContent(AuthorRole.Assistant, "assistant response: 11"); yield return new ChatMessageContent(AuthorRole.User, "user input: 12"); - yield return new ChatMessageContent(AuthorRole.Assistant, "asistant response: 13"); + yield return new ChatMessageContent(AuthorRole.Assistant, "assistant response: 13"); } }