Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Fix: Enhance Function Argument Validation to Prevent Null Argument Exceptions in Tool Calls. #9273

Merged
merged 6 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1442,6 +1442,44 @@ public async Task ItTargetsApiVersionAsExpected(string? apiVersion, string? expe
Assert.Contains($"api-version={expectedVersion}", this._messageHandlerStub.RequestUris[0]!.ToString());
}

[Fact]
public async Task GetStreamingChatMessageContentsWithFunctionCallAndEmptyArgumentsDoNotThrowAsync()
{
// Arrange
int functionCallCount = 0;

var kernel = Kernel.CreateBuilder().Build();
var function = KernelFunctionFactory.CreateFromMethod((string addressCode) =>
{
functionCallCount++;
return "Some weather";
}, "GetWeather");

kernel.Plugins.Add(KernelPluginFactory.CreateFromFunctions("WeatherPlugin", [function]));
using var multiHttpClient = new HttpClient(this._messageHandlerStub, false);
var service = new OpenAIChatCompletionService("model-id", "api-key", httpClient: multiHttpClient, loggerFactory: this._mockLoggerFactory.Object);
var settings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };

this._messageHandlerStub.ResponsesToReturn.Add(
new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_single_function_call_empty_assistance_response.txt"))
});

this._messageHandlerStub.ResponsesToReturn.Add(
new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_test_response.txt"))
});

// Act & Assert
await foreach (var chunk in service.GetStreamingChatMessageContentsAsync([], settings, kernel))
{
}

Assert.Equal(1, functionCallCount);
}

public static TheoryData<string?, string?> Versions => new()
{
{ null, "2024-08-01-preview" },
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_id","type":"function","function":{"name":"WeatherPlugin-GetWeather","arguments":""}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\n"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" "}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" \""}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"address"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Code"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" \""}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"440"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"100"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"\n"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}

data: [DONE]
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@
<None Update="TestData\chat_completion_streaming_multiple_function_calls_test_response.txt">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
<None Update="TestData\chat_completion_streaming_single_function_call_empty_assistance_response.txt">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
<None Update="TestData\chat_completion_streaming_single_function_call_test_response.txt">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Text;
using System.Text.Json;
Expand Down Expand Up @@ -82,4 +83,50 @@ public void ConvertToolCallUpdatesWithNotEmptyIndexesReturnsNotEmptyToolCalls()
Assert.Equal("test-function", toolCall.FunctionName);
Assert.Equal("test-argument", toolCall.FunctionArguments.ToString());
}

[Fact]
public void TrackStreamingToolingUpdateWithNullUpdatesDoesNotThrowException()
RogerBarreto marked this conversation as resolved.
Show resolved Hide resolved
{
// Arrange
Dictionary<int, string>? toolCallIdsByIndex = null;
Dictionary<int, string>? functionNamesByIndex = null;
Dictionary<int, StringBuilder>? functionArgumentBuildersByIndex = null;
IReadOnlyList<StreamingChatToolCallUpdate>? updates = [];

StreamingChatToolCallUpdate update = ModelReaderWriter.Read<StreamingChatToolCallUpdate>(BinaryData.FromString("""{"index":0,"id":"call_id","type":"function","function":{"name":"WeatherPlugin-GetWeather","arguments":""}}"""))!;

// Act
var exception = Record.Exception(() =>
OpenAIFunctionToolCall.TrackStreamingToolingUpdate(
[
GetUpdateChunkFromString("""{"index":0,"id":"call_id","type":"function","function":{"name":"WeatherPlugin-GetWeather","arguments":""}}"""),
GetUpdateChunkFromString("""{"index":0,"function":{"arguments":"{\n"}}"""),
GetUpdateChunkFromString("""{"index":0,"function":{"arguments":" "}}"""),
GetUpdateChunkFromString("""{"index":0,"function":{"arguments":" \""}}"""),
GetUpdateChunkFromString("""{"index":0,"function":{"arguments":"address"}}"""),
GetUpdateChunkFromString("""{"index":0,"function":{"arguments":"Code"}}"""),
GetUpdateChunkFromString("""{"index":0,"function":{"arguments":"\":"}}"""),
GetUpdateChunkFromString("""{"index":0,"function":{"arguments":" \""}}"""),
GetUpdateChunkFromString("""{"index":0,"function":{"arguments":"440"}}"""),
GetUpdateChunkFromString("""{"index":0,"function":{"arguments":"100"}}"""),
GetUpdateChunkFromString("""{"index":0,"function":{"arguments":"\"\n"}}"""),
GetUpdateChunkFromString("""{"index":0,"function":{"arguments":"}"}}"""),
],
ref toolCallIdsByIndex,
ref functionNamesByIndex,
ref functionArgumentBuildersByIndex
));

// Assert
Assert.Equal(
"""
{
"addressCode": "440100"
}
""", functionArgumentBuildersByIndex![0].ToString());
Assert.Null(exception);
}

private static StreamingChatToolCallUpdate GetUpdateChunkFromString(string jsonChunk)
=> ModelReaderWriter.Read<StreamingChatToolCallUpdate>(BinaryData.FromString(jsonChunk))!;
}
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,44 @@ public async Task GetStreamingChatMessageContentsWithFunctionCallMaximumAutoInvo
Assert.Equal(DefaultMaximumAutoInvokeAttempts, functionCallCount);
}

[Fact]
public async Task GetStreamingChatMessageContentsWithFunctionCallAndEmptyArgumentsDoNotThrowAsync()
{
// Arrange
int functionCallCount = 0;

var kernel = Kernel.CreateBuilder().Build();
var function = KernelFunctionFactory.CreateFromMethod((string addressCode) =>
{
functionCallCount++;
return "Some weather";
}, "GetWeather");

kernel.Plugins.Add(KernelPluginFactory.CreateFromFunctions("WeatherPlugin", [function]));
using var multiHttpClient = new HttpClient(this._multiMessageHandlerStub, false);
var service = new OpenAIChatCompletionService("model-id", "api-key", httpClient: multiHttpClient, loggerFactory: this._mockLoggerFactory.Object);
var settings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };

this._multiMessageHandlerStub.ResponsesToReturn.Add(
new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_single_function_call_empty_assistance_response.txt"))
});

this._multiMessageHandlerStub.ResponsesToReturn.Add(
new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_test_response.txt"))
});

// Act & Assert
await foreach (var chunk in service.GetStreamingChatMessageContentsAsync([], settings, kernel))
{
}

Assert.Equal(1, functionCallCount);
}

[Fact]
public async Task GetStreamingChatMessageContentsWithRequiredFunctionCallAsync()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_id","type":"function","function":{"name":"WeatherPlugin-GetWeather","arguments":""}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\n"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" "}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" \""}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"address"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Code"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" \""}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"440"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"100"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"\n"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-AH9wO192nxDoDKnTwpgdLCtAYLkjp","object":"chat.completion.chunk","created":1728653152,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_67802d9a6d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}

data: [DONE]
Original file line number Diff line number Diff line change
Expand Up @@ -308,22 +308,14 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
// If we're intending to invoke function calls, we need to consume that function call information.
if (toolCallingConfig.AutoInvoke)
{
try
foreach (var contentPart in chatCompletionUpdate.ContentUpdate)
{
foreach (var contentPart in chatCompletionUpdate.ContentUpdate)
if (contentPart.Kind == ChatMessageContentPartKind.Text)
{
if (contentPart.Kind == ChatMessageContentPartKind.Text)
{
(contentBuilder ??= new()).Append(contentPart.Text);
}
(contentBuilder ??= new()).Append(contentPart.Text);
}
OpenAIFunctionToolCall.TrackStreamingToolingUpdate(chatCompletionUpdate.ToolCallUpdates, ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex);
}
catch (NullReferenceException)
{
// Temporary workaround for OpenAI SDK Bug here: https://github.com/openai/openai-dotnet/issues/198
// TODO: Remove this try-catch block once the bug is fixed.
}
OpenAIFunctionToolCall.TrackStreamingToolingUpdate(chatCompletionUpdate.ToolCallUpdates, ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex);
}

var openAIStreamingChatMessageContent = new OpenAIStreamingChatMessageContent(chatCompletionUpdate, 0, targetModel, metadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ internal static void TrackStreamingToolingUpdate(
}

// Ensure we're tracking the function's arguments.
if (update.FunctionArgumentsUpdate is not null)
if (update.FunctionArgumentsUpdate is not null && !update.FunctionArgumentsUpdate.ToMemory().IsEmpty)
Copy link
Member

@stephentoub stephentoub Oct 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my edification, why is this additional clause required? Is it trying to be an optimization, or is it fixing an issue? Based on a cursory look at the code, for the condition that clause is guarding, I'd have expected arguments.Append(update.FunctionArgumentsUpdate.ToString()) to simply be the equivalent of arguments.Append(""), since BinaryData.ToString() just does Encoding.UTF8.GetString on the supplied bytes, and passing in an empty set of bytes will lead that to return an empty string.

{
if (!(functionArgumentBuildersByIndex ??= []).TryGetValue(update.Index, out StringBuilder? arguments))
{
Expand Down
Loading