Skip to content

Commit

Permalink
.Net: Make ONNX connector Native-AOT compatible. (#9192)
Browse files Browse the repository at this point in the history
### Motivation and Context
This PR makes the ONNX connector Native-AOT friendly, allowing it to be
used in applications targeting Native-AOT.

### Description
- Adds a new parameter of `JsonSerializerOptions` type to the
constructor of the `OnnxRuntimeGenAIChatCompletionService` class and a
few extension methods.
- Adds a new
`OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings` overload
that accepts `JsonSerializerOptions` for serializing prompt execution
settings whose type is not known in advance.
- Makes function choice behavior classes public, allowing the source
generator to access them and not fail with the error:
"_dotnet\src\Connectors\Connectors.Onnx\Text\OnnxRuntimeGenAIPromptExecutionSettingsJsonSerializerContext
- Copy.cs(10,31,10,91): warning SYSLIB1222: The constructor on type
'Microsoft.SemanticKernel.AutoFunctionChoiceBehavior' has been annotated
with JsonConstructorAttribute but is not accessible by the source
generator.
(https://learn.microsoft.com/dotnet/fundamentals/syslib-diagnostics/syslib1222)_"

CC: @eiriktsarpalis
  • Loading branch information
SergeyMenshykh authored Oct 11, 2024
1 parent c029612 commit 47a2ad3
Show file tree
Hide file tree
Showing 15 changed files with 201 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
<ItemGroup>
<ProjectReference Include="..\..\..\src\SemanticKernel.Abstractions\SemanticKernel.Abstractions.csproj" />
<ProjectReference Include="..\..\..\src\SemanticKernel.Core\SemanticKernel.Core.csproj" />
<ProjectReference Include="..\..\..\src\Connectors\Connectors.Onnx\Connectors.Onnx.csproj" />
<TrimmerRootAssembly Include="Microsoft.SemanticKernel.Abstractions" />
<TrimmerRootAssembly Include="Microsoft.SemanticKernel.Core" />
<TrimmerRootAssembly Include="Microsoft.SemanticKernel.Connectors.Onnx" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json.Serialization;
using Microsoft.SemanticKernel;

namespace SemanticKernel.Connectors.Onnx.UnitTests;

/// <summary>
/// Test double: a custom <see cref="PromptExecutionSettings"/> type unknown to the ONNX
/// connector, used to exercise conversion to <see cref="Microsoft.SemanticKernel.Connectors.Onnx.OnnxRuntimeGenAIPromptExecutionSettings"/>.
/// </summary>
internal sealed class CustomPromptExecutionSettings : PromptExecutionSettings
{
/// <summary>
/// Temperature to sample with.
/// </summary>
[JsonPropertyName("temperature")]
public float? Temperature { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json.Serialization;

namespace SemanticKernel.Connectors.Onnx.UnitTests;

/// <summary>
/// Source-generated JSON context for <see cref="CustomPromptExecutionSettings"/>, enabling
/// reflection-free (Native-AOT compatible) serialization of the custom settings type in tests.
/// </summary>
[JsonSerializable(typeof(CustomPromptExecutionSettings))]
internal sealed partial class CustomPromptExecutionSettingsJsonSerializerContext : JsonSerializerContext
{
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,37 @@ public void FromExecutionSettingsWhenSerializedHasPropertiesShouldPopulateSpecia
Assert.False(onnxExecutionSettings.EarlyStopping);
Assert.True(onnxExecutionSettings.DoSample);
}

[Fact]
public void ItShouldCreateOnnxPromptExecutionSettingsFromCustomPromptExecutionSettings()
{
    // Arrange: settings of a type the ONNX connector does not know about.
    var source = new CustomPromptExecutionSettings
    {
        ServiceId = "service-id",
        Temperature = 36.6f,
    };

    // Act: convert via the overload that does not take JsonSerializerOptions.
    var converted = OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(source);

    // Assert: both the base property and the matching setting survive the conversion.
    Assert.Equal("service-id", converted.ServiceId);
    Assert.Equal(36.6f, converted.Temperature);
}

[Fact]
public void ItShouldCreateOnnxPromptExecutionSettingsFromCustomPromptExecutionSettingsUsingJSOs()
{
    // Arrange: options backed by the source-generated context so no reflection is needed.
    var options = new JsonSerializerOptions
    {
        TypeInfoResolver = CustomPromptExecutionSettingsJsonSerializerContext.Default,
    };
    var source = new CustomPromptExecutionSettings
    {
        ServiceId = "service-id",
        Temperature = 36.6f,
    };

    // Act: convert via the AOT-friendly overload that accepts JsonSerializerOptions.
    var converted = OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(source, options);

    // Assert: both the base property and the matching setting survive the conversion.
    Assert.Equal("service-id", converted.ServiceId);
    Assert.Equal(36.6f, converted.Temperature);
}
}
16 changes: 14 additions & 2 deletions dotnet/src/Connectors/Connectors.Onnx/Connectors.Onnx.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,24 @@
<RootNamespace>$(AssemblyName)</RootNamespace>
<TargetFrameworks>net8.0;netstandard2.0</TargetFrameworks>
<VersionSuffix>alpha</VersionSuffix>
<IsAotCompatible Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net7.0'))">true</IsAotCompatible>
</PropertyGroup>

<!-- IMPORT NUGET PACKAGE SHARED PROPERTIES -->
<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
<Import Project="$(RepoRoot)/dotnet/src/InternalUtilities/src/InternalUtilities.props" />

<ItemGroup>
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/Diagnostics/Verify.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/Diagnostics/NullableAttributes.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/Diagnostics/ExperimentalAttribute.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/Diagnostics/CompilerServicesAttributes.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/Diagnostics/IsExternalInit.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/Diagnostics/RequiresUnreferencedCodeAttribute.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/Diagnostics/RequiresDynamicCodeAttribute.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/Diagnostics/UnconditionalSuppressMessageAttribute.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/Text/JsonOptionsCache.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/System/AppContextSwitchHelper.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
</ItemGroup>

<PropertyGroup>
<Title>Semantic Kernel - ONNX Connectors</Title>
Expand All @@ -19,7 +32,6 @@

<ItemGroup>
<ProjectReference Include="..\..\SemanticKernel.Core\SemanticKernel.Core.csproj" />

<PackageReference Include="FastBertTokenizer" />
<PackageReference Include="Microsoft.ML.OnnxRuntime" />
<PackageReference Include="System.Numerics.Tensors" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System.IO;
using System.Text.Json;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;
Expand All @@ -23,18 +24,21 @@ public static class OnnxKernelBuilderExtensions
/// <param name="modelId">Model Id.</param>
/// <param name="modelPath">The generative AI ONNX model path.</param>
/// <param name="serviceId">The optional service ID.</param>
/// <param name="jsonSerializerOptions">The <see cref="JsonSerializerOptions"/> to use for various aspects of serialization, such as function argument deserialization, function result serialization, logging, etc., of the service.</param>
/// <returns>The updated kernel builder.</returns>
public static IKernelBuilder AddOnnxRuntimeGenAIChatCompletion(
    this IKernelBuilder builder,
    string modelId,
    string modelPath,
    string? serviceId = null,
    JsonSerializerOptions? jsonSerializerOptions = null)
{
    // Register the ONNX chat completion service as a keyed singleton,
    // keyed by the optional service ID. The service is constructed lazily,
    // resolving the logger factory from the container at that point.
    builder.Services.AddKeyedSingleton<IChatCompletionService>(
        serviceId,
        (provider, _) => new OnnxRuntimeGenAIChatCompletionService(
            modelId,
            modelPath: modelPath,
            loggerFactory: provider.GetService<ILoggerFactory>(),
            jsonSerializerOptions: jsonSerializerOptions));

    return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
Expand All @@ -20,6 +22,7 @@ public sealed class OnnxRuntimeGenAIChatCompletionService : IChatCompletionServi
{
private readonly string _modelId;
private readonly string _modelPath;
private readonly JsonSerializerOptions? _jsonSerializerOptions;
private Model? _model;
private Tokenizer? _tokenizer;

Expand All @@ -31,17 +34,19 @@ public sealed class OnnxRuntimeGenAIChatCompletionService : IChatCompletionServi
/// <param name="modelId">The name of the model.</param>
/// <param name="modelPath">The generative AI ONNX model path for the chat completion service.</param>
/// <param name="loggerFactory">Optional logger factory to be used for logging.</param>
/// <param name="jsonSerializerOptions">The <see cref="JsonSerializerOptions"/> to use for various aspects of serialization and deserialization required by the service.</param>
public OnnxRuntimeGenAIChatCompletionService(
    string modelId,
    string modelPath,
    ILoggerFactory? loggerFactory = null,
    JsonSerializerOptions? jsonSerializerOptions = null)
{
    // Both the model identifier and the on-disk model path are mandatory.
    Verify.NotNullOrWhiteSpace(modelId);
    Verify.NotNullOrWhiteSpace(modelPath);

    this._modelId = modelId;
    this._modelPath = modelPath;
    this._jsonSerializerOptions = jsonSerializerOptions;

    // NOTE(review): loggerFactory is accepted but not stored in this constructor — TODO confirm intended.
    // Surface the model id through the service attributes.
    this.AttributesInternal.Add(AIServiceExtensions.ModelIdKey, modelId);
}

Expand Down Expand Up @@ -82,7 +87,7 @@ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessa

private async IAsyncEnumerable<string> RunInferenceAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings, [EnumeratorCancellation] CancellationToken cancellationToken)
{
OnnxRuntimeGenAIPromptExecutionSettings onnxPromptExecutionSettings = OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);
OnnxRuntimeGenAIPromptExecutionSettings onnxPromptExecutionSettings = this.GetOnnxPromptExecutionSettingsSettings(executionSettings);

var prompt = this.GetPrompt(chatHistory, onnxPromptExecutionSettings);
var tokens = this.GetTokenizer().Encode(prompt);
Expand Down Expand Up @@ -190,6 +195,18 @@ private void UpdateGeneratorParamsFromPromptExecutionSettings(GeneratorParams ge
}
}

[UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "JSOs are required only in cases where the supplied settings are not Onnx-specific. For these cases, JSOs can be provided via the class constructor.")]
[UnconditionalSuppressMessage("AOT", "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling.", Justification = "JSOs are required only in cases where the supplied settings are not Onnx-specific. For these cases, JSOs can be provided via class constructor.")]
private OnnxRuntimeGenAIPromptExecutionSettings GetOnnxPromptExecutionSettingsSettings(PromptExecutionSettings? executionSettings) =>
    // Prefer the AOT-safe overload when the caller supplied serializer options via the
    // constructor; otherwise fall back to the reflection-based overload (suppressed above).
    this._jsonSerializerOptions is null
        ? OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings)
        : OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings, this._jsonSerializerOptions);

/// <inheritdoc/>
public void Dispose()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.SemanticKernel.Text;

namespace Microsoft.SemanticKernel.Connectors.Onnx;
Expand All @@ -14,8 +16,10 @@ public sealed class OnnxRuntimeGenAIPromptExecutionSettings : PromptExecutionSet
/// <summary>
/// Convert PromptExecutionSettings to OnnxRuntimeGenAIPromptExecutionSettings
/// </summary>
/// <param name="executionSettings"></param>
/// <returns></returns>
/// <param name="executionSettings">The <see cref="PromptExecutionSettings"/> to convert to <see cref="OnnxRuntimeGenAIPromptExecutionSettings"/>.</param>
/// <returns>Returns the <see cref="OnnxRuntimeGenAIPromptExecutionSettings"/> object.</returns>
[RequiresUnreferencedCode("This method uses reflection to serialize and deserialize the execution settings, making it incompatible with AOT scenarios.")]
[RequiresDynamicCode("This method uses reflection to serialize and deserialize the execution settings, making it incompatible with AOT scenarios.")]
public static OnnxRuntimeGenAIPromptExecutionSettings FromExecutionSettings(PromptExecutionSettings? executionSettings)
{
if (executionSettings is null)
Expand All @@ -28,10 +32,34 @@ public static OnnxRuntimeGenAIPromptExecutionSettings FromExecutionSettings(Prom
return settings;
}

var json = JsonSerializer.Serialize(executionSettings);
var json = JsonSerializer.Serialize<object>(executionSettings);

var onnxRuntimeGenAIPromptExecutionSettings = JsonSerializer.Deserialize<OnnxRuntimeGenAIPromptExecutionSettings>(json, JsonOptionsCache.ReadPermissive);
return onnxRuntimeGenAIPromptExecutionSettings!;
return JsonSerializer.Deserialize<OnnxRuntimeGenAIPromptExecutionSettings>(json, JsonOptionsCache.ReadPermissive)!;
}

/// <summary>
/// Converts <see cref="PromptExecutionSettings"/> to <see cref="OnnxRuntimeGenAIPromptExecutionSettings"/>
/// without using reflection, making this overload safe for trimming and Native-AOT.
/// </summary>
/// <param name="executionSettings">The <see cref="PromptExecutionSettings"/> to convert to <see cref="OnnxRuntimeGenAIPromptExecutionSettings"/>. May be null, in which case default settings are returned.</param>
/// <param name="jsonSerializerOptions">The <see cref="JsonSerializerOptions"/> to use for serialization of <see cref="PromptExecutionSettings"/> and deserialize them to <see cref="OnnxRuntimeGenAIPromptExecutionSettings"/>. Must supply contract metadata (e.g. a source-generated <see cref="JsonSerializerContext"/>) for the concrete settings type.</param>
/// <returns>Returns the <see cref="OnnxRuntimeGenAIPromptExecutionSettings"/> object.</returns>
public static OnnxRuntimeGenAIPromptExecutionSettings FromExecutionSettings(PromptExecutionSettings? executionSettings, JsonSerializerOptions jsonSerializerOptions)
{
    // Fail fast on a null options argument instead of surfacing an NRE from GetTypeInfo below.
    Verify.NotNull(jsonSerializerOptions);

    if (executionSettings is null)
    {
        return new OnnxRuntimeGenAIPromptExecutionSettings();
    }

    if (executionSettings is OnnxRuntimeGenAIPromptExecutionSettings settings)
    {
        return settings;
    }

    // The settings type is not known at compile time, so resolve its contract metadata from
    // the caller-supplied options; this throws if the options do not cover the type.
    // (The null-forgiving '!' the original used on executionSettings is unnecessary here:
    // flow analysis already proved it non-null after the early return above.)
    JsonTypeInfo typeInfo = jsonSerializerOptions.GetTypeInfo(executionSettings.GetType());

    var json = JsonSerializer.Serialize(executionSettings, typeInfo);

    // Deserialize with the connector's own source-generated context — no reflection involved.
    return JsonSerializer.Deserialize<OnnxRuntimeGenAIPromptExecutionSettings>(json, OnnxRuntimeGenAIPromptExecutionSettingsJsonSerializerContext.ReadPermissive.OnnxRuntimeGenAIPromptExecutionSettings)!;
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System.IO;
using System.Text.Json;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;
Expand All @@ -23,18 +24,21 @@ public static class OnnxServiceCollectionExtensions
/// <param name="modelId">The name of the model.</param>
/// <param name="modelPath">The generative AI ONNX model path.</param>
/// <param name="serviceId">Optional service ID.</param>
/// <param name="jsonSerializerOptions">The <see cref="JsonSerializerOptions"/> to use for various aspects of serialization and deserialization required by the service.</param>
/// <returns>The updated service collection.</returns>
public static IServiceCollection AddOnnxRuntimeGenAIChatCompletion(
    this IServiceCollection services,
    string modelId,
    string modelPath,
    string? serviceId = null,
    JsonSerializerOptions? jsonSerializerOptions = null)
{
    // Register the ONNX chat completion service as a keyed singleton, keyed by
    // the optional service ID; the logger factory is resolved lazily from the
    // container when the service is first constructed.
    services.AddKeyedSingleton<IChatCompletionService>(
        serviceId,
        (provider, _) => new OnnxRuntimeGenAIChatCompletionService(
            modelId,
            modelPath,
            loggerFactory: provider.GetService<ILoggerFactory>(),
            jsonSerializerOptions: jsonSerializerOptions));

    return services;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.SemanticKernel.Connectors.Onnx;

namespace Microsoft.SemanticKernel.Text;

/// <summary>
/// Source-generated JSON context for <see cref="OnnxRuntimeGenAIPromptExecutionSettings"/>,
/// allowing the connector to deserialize settings without reflection (trimming/Native-AOT safe).
/// </summary>
[JsonSerializable(typeof(OnnxRuntimeGenAIPromptExecutionSettings))]
internal sealed partial class OnnxRuntimeGenAIPromptExecutionSettingsJsonSerializerContext : JsonSerializerContext
{
// Context instance configured for permissive reading: tolerates trailing commas,
// case-insensitive property names, and comments in the JSON payload.
public static readonly OnnxRuntimeGenAIPromptExecutionSettingsJsonSerializerContext ReadPermissive = new(new()
{
AllowTrailingCommas = true,
PropertyNameCaseInsensitive = true,
ReadCommentHandling = JsonCommentHandling.Skip,
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using YamlDotNet.Serialization;
using YamlDotNet.Serialization.BufferedDeserialization;
using YamlDotNet.Serialization.NamingConventions;
using YamlDotNet.Serialization.ObjectFactories;

namespace Microsoft.SemanticKernel;

Expand All @@ -29,6 +30,7 @@ public bool Accepts(Type type)
s_deserializer ??= new DeserializerBuilder()
.WithNamingConvention(UnderscoredNamingConvention.Instance)
.IgnoreUnmatchedProperties() // Required to ignore the 'type' property used as type discrimination. Otherwise, the "Property 'type' not found on type '{type.FullName}'" exception is thrown.
.WithObjectFactory(new FunctionChoiceBehaviorsObjectFactory())
.WithTypeDiscriminatingNodeDeserializer(CreateAndRegisterTypeDiscriminatingNodeDeserializer)
.Build();

Expand Down Expand Up @@ -98,4 +100,22 @@ private static void CreateAndRegisterTypeDiscriminatingNodeDeserializer(ITypeDis
/// The YamlDotNet deserializer instance.
/// </summary>
private static IDeserializer? s_deserializer;

/// <summary>
/// YamlDotNet object factory that can instantiate the function choice behavior types, whose
/// constructors are not public, and delegates every other type to the default factory.
/// </summary>
private sealed class FunctionChoiceBehaviorsObjectFactory : ObjectFactoryBase
{
    // Eagerly created and readonly: construction is trivial, and this avoids the
    // non-thread-safe lazy '??=' initialization of a mutable nullable static.
    private static readonly DefaultObjectFactory s_defaultFactory = new DefaultObjectFactory();

    public override object Create(Type type)
    {
        // The function choice behavior types have non-public constructors, so they must
        // be created with Activator.CreateInstance(type, nonPublic: true).
        if (type == typeof(AutoFunctionChoiceBehavior) ||
            type == typeof(NoneFunctionChoiceBehavior) ||
            type == typeof(RequiredFunctionChoiceBehavior))
        {
            return Activator.CreateInstance(type, nonPublic: true)!;
        }

        // Use the default object factory for other types
        return s_defaultFactory.Create(type);
    }
}
}
6 changes: 3 additions & 3 deletions dotnet/src/IntegrationTests/Processes/ProcessCycleTests.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.SemanticKernel;
using System.Threading.Tasks;
using System;
using System.Threading.Tasks;
using Microsoft.Extensions.Configuration;
using Microsoft.SemanticKernel;
using SemanticKernel.IntegrationTests.Agents;
using SemanticKernel.IntegrationTests.TestSettings;
using Xunit;
using SemanticKernel.IntegrationTests.Agents;

namespace SemanticKernel.IntegrationTests.Processes;

Expand Down
Loading

0 comments on commit 47a2ad3

Please sign in to comment.