Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adds strict mode support #1

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
7348f8b
Replace stj-schema-mapper source code with M.E.AI schema generation
eiriktsarpalis Nov 25, 2024
39a79b9
feat: adds strict mode flag to function calling
baywet Dec 6, 2024
835dff6
feat: adds strict schema function behaviour and maps it to the metadata
baywet Dec 6, 2024
c85f93c
chore: adds unit test for additional properties false in strict mode
baywet Dec 6, 2024
ba4cf8b
chore: adds tests for tool call behaviour and strict mode
baywet Dec 6, 2024
6f5a10f
chore: adds unit test for new function choice behaviour options property
baywet Dec 6, 2024
dfe780a
chore: cleanup reference to default
baywet Dec 6, 2024
84ca142
fix: badly formatted doc comment
baywet Dec 6, 2024
498dcdf
chore: adds test for function metadata to OAI function strict more ma…
baywet Dec 6, 2024
3ee8998
chore: adds validation for strict property mapping on OpenAIFunction
baywet Dec 6, 2024
0769ca5
chore: migrates to foreach
baywet Dec 6, 2024
bcb9785
chore: adds unit test for required properties behaviour with strict mode
baywet Dec 6, 2024
9542514
chore: adds test for metadata copy constructor
baywet Dec 6, 2024
9355ec4
feat: adds strict parameter to OpenAPI based functions
baywet Dec 6, 2024
8792fa5
fix: pass strict when cloning function
baywet Dec 6, 2024
0901a10
smell: having to set strict in the function prompt
baywet Dec 6, 2024
208f1f4
fix: reverts additional strict property
baywet Dec 9, 2024
a2dd727
fix: tests after strict property removal
baywet Dec 9, 2024
7a8adb2
chore: code linting
baywet Dec 9, 2024
817c9f7
fix: makes schema less parameters optional in strict mode
baywet Dec 9, 2024
947dfb3
feat; sanitizes forbidden strict mode keywords
baywet Dec 9, 2024
667d0aa
fix: adds missing null type in strict mode
baywet Dec 9, 2024
0ae128c
docs: add links to null type behaviour
baywet Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ public static async IAsyncEnumerable<ChatMessageContent> GetMessagesAsync(Assist

FunctionCallsProcessor functionProcessor = new(logger);
// This matches current behavior. Will be configurable upon integrating with `FunctionChoice` (#6795/#5200)
FunctionChoiceBehaviorOptions functionOptions = new() { AllowConcurrentInvocation = true, AllowParallelCalls = true };
FunctionChoiceBehaviorOptions functionOptions = new() { AllowConcurrentInvocation = true, AllowParallelCalls = true, AllowStrictSchemaAdherence = true };

// Evaluate status and process steps and messages, as encountered.
HashSet<string> processedStepIds = [];
Expand Down Expand Up @@ -412,7 +412,7 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin

FunctionCallsProcessor functionProcessor = new(logger);
// This matches current behavior. Will be configurable upon integrating with `FunctionChoice` (#6795/#5200)
FunctionChoiceBehaviorOptions functionOptions = new() { AllowConcurrentInvocation = true, AllowParallelCalls = true };
FunctionChoiceBehaviorOptions functionOptions = new() { AllowConcurrentInvocation = true, AllowParallelCalls = true, AllowStrictSchemaAdherence = true };

IAsyncEnumerable<StreamingUpdate> asyncUpdates = client.CreateRunStreamingAsync(threadId, agent.Id, options, cancellationToken);
do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,50 @@ public void ItCanConvertToFunctionDefinitionWithNoPluginName()
Assert.Equal(sut.Description, result.FunctionDescription);
}

[Fact]
public void ItCanConvertToFunctionDefinitionWithNullParameters()
[InlineData(true)]
[InlineData(false)]
[Theory]
public void ItCanConvertToFunctionDefinitionWithNullParameters(bool strict)
{
// Arrange
// Arrange
OpenAIFunction sut = new("plugin", "function", "description", null, null);

// Act
var result = sut.ToFunctionDefinition();
var result = sut.ToFunctionDefinition(strict);

// Assert
Assert.Equal("{\"type\":\"object\",\"required\":[],\"properties\":{}}", result.FunctionParameters.ToString());
if (strict)
{
Assert.Equal("{\"type\":\"object\",\"required\":[],\"properties\":{},\"additionalProperties\":false}", result.FunctionParameters.ToString());
}
else
{
Assert.Equal("{\"type\":\"object\",\"required\":[],\"properties\":{}}", result.FunctionParameters.ToString());
}
}

[InlineData(false)]
[InlineData(true)]
[Theory]
public void SetsParametersToRequiredWhenStrict(bool strict)
{
var parameters = new List<OpenAIFunctionParameter>
{
new ("foo", "bar", false, typeof(string), null),
};
OpenAIFunction sut = new("plugin", "function", "description", parameters, null);

var result = sut.ToFunctionDefinition(strict);

Assert.Equal(strict, result.FunctionSchemaIsStrict);
if (strict)
{
Assert.Equal("""{"type":"object","required":["foo"],"properties":{"foo":{"description":"bar","type":["string","null"]}},"additionalProperties":false}""", result.FunctionParameters.ToString());
}
else
{
Assert.Equal("""{"type":"object","required":[],"properties":{"foo":{"description":"bar","type":"string"}}}""", result.FunctionParameters.ToString());
}
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC

// Process function calls by invoking the functions and adding the results to the chat history.
// Each function call will trigger auto-function-invocation filters, which can terminate the process.
// In such cases, we'll return the last message in the chat history.
// In such cases, we'll return the last message in the chat history.
var lastMessage = await this.FunctionCallsProcessor.ProcessFunctionCallsAsync(
chatMessageContent,
chatHistory,
Expand Down Expand Up @@ -679,8 +679,8 @@ private static List<ChatMessage> CreateRequestMessages(ChatMessageContent messag
{
var toolCalls = new List<ChatToolCall>();

// Handling function calls supplied via either:
// ChatCompletionsToolCall.ToolCalls collection items or
// Handling function calls supplied via either:
// ChatCompletionsToolCall.ToolCalls collection items or
// ChatMessageContent.Metadata collection item with 'ChatResponseMessage.FunctionToolCalls' key.
IEnumerable<ChatToolCall>? tools = (message as OpenAIChatMessageContent)?.ToolCalls;
if (tools is null && message.Metadata?.TryGetValue(OpenAIChatMessageContent.FunctionToolCallsProperty, out object? toolCallsObject) is true)
Expand Down Expand Up @@ -734,7 +734,7 @@ private static List<ChatMessage> CreateRequestMessages(ChatMessageContent messag
}

// This check is necessary to prevent an exception that will be thrown if the toolCalls collection is empty.
// HTTP 400 (invalid_request_error:) [] should be non-empty - 'messages.3.tool_calls'
// HTTP 400 (invalid_request_error:) [] should be non-empty - 'messages.3.tool_calls'
if (toolCalls.Count == 0)
{
return [new AssistantChatMessage(message.Content) { ParticipantName = message.AuthorName }];
Expand Down Expand Up @@ -1011,7 +1011,7 @@ private ToolCallingConfig GetFunctionCallingConfiguration(Kernel? kernel, OpenAI

foreach (var function in functions)
{
tools.Add(function.Metadata.ToOpenAIFunction().ToFunctionDefinition());
tools.Add(function.Metadata.ToOpenAIFunction().ToFunctionDefinition(config?.Options?.AllowStrictSchemaAdherence ?? false));
}
}

Expand Down
135 changes: 119 additions & 16 deletions dotnet/src/Connectors/Connectors.OpenAI/Core/OpenAIFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using Microsoft.Extensions.AI;
using OpenAI.Chat;

namespace Microsoft.SemanticKernel.Connectors.OpenAI;
Expand Down Expand Up @@ -72,9 +76,17 @@ public sealed class OpenAIFunction
/// </remarks>
private static readonly BinaryData s_zeroFunctionParametersSchema = new("""{"type":"object","required":[],"properties":{}}""");
/// <summary>
/// Same as above, but with additionalProperties: false for strict mode.
/// </summary>
private static readonly BinaryData s_zeroFunctionParametersSchema_strict = new("""{"type":"object","required":[],"properties":{},"additionalProperties":false}""");
/// <summary>
/// Cached schema for a descriptionless string.
/// </summary>
private static readonly KernelJsonSchema s_stringNoDescriptionSchema = KernelJsonSchema.Parse("""{"type":"string"}""");
/// <summary>
/// Cached schema for a descriptionless string that's nullable.
/// </summary>
private static readonly KernelJsonSchema s_stringNoDescriptionSchemaAndNull = KernelJsonSchema.Parse("""{"type":["string","null"]}""");

/// <summary>Initializes the OpenAIFunction.</summary>
internal OpenAIFunction(
Expand Down Expand Up @@ -127,52 +139,143 @@ internal OpenAIFunction(
/// <see cref="ChatTool"/> representation.
/// </summary>
/// <returns>A <see cref="ChatTool"/> containing all the function information.</returns>
public ChatTool ToFunctionDefinition()
public ChatTool ToFunctionDefinition(bool allowStrictSchemaAdherence = false)
{
BinaryData resultParameters = s_zeroFunctionParametersSchema;
BinaryData resultParameters = allowStrictSchemaAdherence ? s_zeroFunctionParametersSchema_strict : s_zeroFunctionParametersSchema;

IReadOnlyList<OpenAIFunctionParameter>? parameters = this.Parameters;
if (parameters is { Count: > 0 })
{
var properties = new Dictionary<string, KernelJsonSchema>();
var required = new List<string>();

for (int i = 0; i < parameters.Count; i++)
foreach (var parameter in parameters)
{
var parameter = parameters[i];
properties.Add(parameter.Name, parameter.Schema ?? GetDefaultSchemaForTypelessParameter(parameter.Description));
if (parameter.IsRequired)
properties.Add(parameter.Name, GetSanitizedSchemaForStrictMode(parameter.Schema, !parameter.IsRequired && allowStrictSchemaAdherence) ?? GetDefaultSchemaForTypelessParameter(parameter.Description, allowStrictSchemaAdherence));
if (parameter.IsRequired || allowStrictSchemaAdherence)
{
required.Add(parameter.Name);
}
}

resultParameters = BinaryData.FromObjectAsJson(new
{
type = "object",
required,
properties,
});
resultParameters = allowStrictSchemaAdherence
? BinaryData.FromObjectAsJson(new
{
type = "object",
required,
properties,
additionalProperties = false
})
: BinaryData.FromObjectAsJson(new
{
type = "object",
required,
properties,
});
}

return ChatTool.CreateFunctionTool
(
functionName: this.FullyQualifiedName,
functionDescription: this.Description,
functionParameters: resultParameters
functionParameters: resultParameters,
functionSchemaIsStrict: allowStrictSchemaAdherence
);
}

/// <summary>Gets a <see cref="KernelJsonSchema"/> for a typeless parameter with the specified description, defaulting to typeof(string)</summary>
private static KernelJsonSchema GetDefaultSchemaForTypelessParameter(string? description)
private static KernelJsonSchema GetDefaultSchemaForTypelessParameter(string? description, bool allowStrictSchemaAdherence)
{
// If there's a description, incorporate it.
if (!string.IsNullOrWhiteSpace(description))
{
return KernelJsonSchemaBuilder.Build(typeof(string), description);
return allowStrictSchemaAdherence ?
GetOptionalStringSchemaWithDescription(description!) :
KernelJsonSchemaBuilder.Build(typeof(string), description, AIJsonSchemaCreateOptions.Default);
}

// Otherwise, we can use a cached schema for a string with no description.
return s_stringNoDescriptionSchema;
return allowStrictSchemaAdherence ? s_stringNoDescriptionSchemaAndNull : s_stringNoDescriptionSchema;
}

private static KernelJsonSchema GetOptionalStringSchemaWithDescription(string description)
{
var jObject = new JsonObject
{
{ "description", description },
{ "type", new JsonArray { "string", "null" } },
};
return KernelJsonSchema.Parse(jObject.ToString());
}
private static KernelJsonSchema? GetSanitizedSchemaForStrictMode(KernelJsonSchema? schema, bool insertNullType)
{
if (schema is null)
{
return null;
}
var forbiddenPropertyNames = s_forbiddenKeywords.Where(k => schema.RootElement.TryGetProperty(k, out _)).ToArray();

if (forbiddenPropertyNames.Length > 0 || insertNullType && schema.RootElement.TryGetProperty(TypeKey, out var typeElement) &&
(typeElement.ValueKind == JsonValueKind.Array && !typeElement.EnumerateArray().Any(static t => NullType.Equals(t.GetString(), StringComparison.OrdinalIgnoreCase)) ||
typeElement.ValueKind == JsonValueKind.String && typeElement.GetString() != NullType))
{
var originalSchema = JsonSerializer.Serialize(schema.RootElement);
var parsedJson = JsonNode.Parse(originalSchema);

if (parsedJson is null)
{
return schema;
}

var jsonObject = parsedJson.AsObject();
foreach (var forbiddenPropertyName in forbiddenPropertyNames)
{
jsonObject.Remove(forbiddenPropertyName);
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is where I'm doing the forbidden keywords cleanup, might want to be recursive in case we have those on the request body.
more context microsoft#9807 (comment)

}

InsertNullTypeIfRequired(insertNullType, jsonObject);

return KernelJsonSchema.Parse(jsonObject.ToString());
}
return schema;
}
// https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
private static void InsertNullTypeIfRequired(bool insertNullType, JsonObject jsonObject)
{
if (insertNullType && jsonObject.TryGetPropertyValue(TypeKey, out var typeValue))
baywet marked this conversation as resolved.
Show resolved Hide resolved
{
if (typeValue is JsonArray jsonArray && !jsonArray.Contains(NullType))
{
jsonArray.Add(NullType);
}
else if (typeValue is JsonValue jsonValue && jsonValue.GetValueKind() == JsonValueKind.String)
{
jsonObject[TypeKey] = new JsonArray { typeValue.GetValue<string>(), NullType };
}
}
}
private const string NullType = "null";
private const string TypeKey = "type";
// https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
private static readonly string[] s_forbiddenKeywords = [
"contains",
"format",
"maxContains",
"maximum",
"maxItems",
"maxLength",
"maxProperties",
"minContains",
"minimum",
"minItems",
"minLength",
"minProperties",
"multipleOf",
"pattern",
"patternProperties",
"propertyNames",
"unevaluatedItems",
"unevaluatedProperties",
"uniqueItems",
];
}
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ public void ItShouldDeserializeAutoFunctionChoiceBehaviorFromJsonWithOptions()
options:
allow_parallel_calls: true
allow_concurrent_invocation: true
allow_strict_schema_adherence: true
""";

var executionSettings = this._deserializer.Deserialize<PromptExecutionSettings>(yaml);
Expand All @@ -314,6 +315,7 @@ public void ItShouldDeserializeAutoFunctionChoiceBehaviorFromJsonWithOptions()
// Assert
Assert.True(config.Options.AllowParallelCalls);
Assert.True(config.Options.AllowConcurrentInvocation);
Assert.True(config.Options.AllowStrictSchemaAdherence);
}

[Fact]
Expand All @@ -326,6 +328,7 @@ public void ItShouldDeserializeRequiredFunctionChoiceBehaviorFromJsonWithOptions
options:
allow_parallel_calls: true
allow_concurrent_invocation: true
allow_strict_schema_adherence: true
""";

var executionSettings = this._deserializer.Deserialize<PromptExecutionSettings>(yaml);
Expand All @@ -336,6 +339,7 @@ public void ItShouldDeserializeRequiredFunctionChoiceBehaviorFromJsonWithOptions
// Assert
Assert.True(config.Options.AllowParallelCalls);
Assert.True(config.Options.AllowConcurrentInvocation);
Assert.True(config.Options.AllowStrictSchemaAdherence);
}

private readonly string _yaml = """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace Microsoft.SemanticKernel;
internal static class KernelJsonSchemaBuilder
{
private static JsonSerializerOptions? s_options;
private static readonly AIJsonSchemaCreateOptions s_schemaOptions = new()
internal static readonly AIJsonSchemaCreateOptions s_schemaOptions = new()
{
IncludeSchemaKeyword = false,
IncludeTypeInEnumSchemas = true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,13 @@ public sealed class FunctionChoiceBehaviorOptions
/// </remarks>
[JsonPropertyName("allow_concurrent_invocation")]
public bool AllowConcurrentInvocation { get; set; } = false;

/// <summary>
/// Gets or sets whether the AI model should strictly adhere to the function schema.
/// </summary>
/// <remarks>
/// The default value is set to false. If set to true, the AI model will strictly adhere to the function schema.
/// </remarks>
[JsonPropertyName("allow_strict_schema_adherence")]
public bool AllowStrictSchemaAdherence { get; set; } = false;
}
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,24 @@ public void ItShouldPropagateAllowConcurrentInvocationOptionToConfiguration()
Assert.True(configuration.Options.AllowConcurrentInvocation);
}

[Fact]
public void ItShouldPropagateAllowStrictSchemaAdherenceOptionToConfiguration()
{
// Arrange
var options = new FunctionChoiceBehaviorOptions
{
AllowStrictSchemaAdherence = true
};

// Act
var choiceBehavior = new AutoFunctionChoiceBehavior(autoInvoke: false, options: options);

// Assert
var configuration = choiceBehavior.GetConfiguration(new FunctionChoiceBehaviorConfigurationContext(chatHistory: []));

Assert.True(configuration.Options.AllowStrictSchemaAdherence);
}

private static KernelPlugin GetTestPlugin()
{
var function1 = KernelFunctionFactory.CreateFromMethod(() => { }, "Function1");
Expand Down
Loading
Loading