Skip to content

Commit

Permalink
.Net: Support polymorphic serialization of ChatMessageContent class a…
Browse files Browse the repository at this point in the history
…nd its derivatives (#8901)

### Motivation, Context and Description
Today, when serializing chat history, the `ChatMessageContent` type or
its derivatives, like `OpenAIChatMessageContent`, are serialized as the
`ChatMessageContent` type because the type is used as the chat history
element type and the JSON serializer uses this type's public contract
for serialization.

However, when attempting serialization of instances of either the
`ChatMessageContent` or `OpenAIChatMessageContent` type that are
declared as `KernelContent` or `object` type, the serialization fails
with the `System.NotSupportedException: Runtime type
'{OpenAI}ChatMessageContent' is not supported by polymorphic type
'Microsoft.SemanticKernel.KernelContent'` exception. The reason for this
exception is that neither of these types is registered for polymorphic
serialization.

This PR registers `ChatMessageContent` type for polymorphic
serialization to allow serialization of the type instances declared as
of `KernelContent` or `object` types:
```csharp

KernelContent content = new ChatMessageContent(...);

// Now it's possible to serialize the content variable of KernelContent type that holds reference to an instance of the ChatMessageContent type as ChatMessageContent type.
var json = JsonSerializer.Serialize(content); 
```
Additionally, it enables serialization of unknow in advance and external
derivatives of `ChatMessageContent` type like
`OpenAIChatMessageContent`. These types are serialized using contract of
nearest ancestor which is `ChatMessageContent` by default. To change
this behavior and register the unknown type for polymorphic
serialization use the contract model - [Configure polymorphism with the
contract
model](https://learn.microsoft.com/en-us/dotnet/standard/serialization/system-text-json/polymorphism?pivots=dotnet-8-0#configure-polymorphism-with-the-contract-model)
```csharp
KernelContent content = new UnknownChatMessageContent(...);

// The content variable will be serialized using the ChatMessageContent type contract.
var json = JsonSerializer.Serialize(content);

private class UnknownChatMessageContent : ChatMessageContent{}
```
Closes: #7478

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
SergeyMenshykh authored Sep 19, 2024
1 parent cc4a497 commit fbdd6bc
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ namespace Microsoft.SemanticKernel;
/// <summary>
/// Base class for all AI non-streaming results
/// </summary>
[JsonPolymorphic(TypeDiscriminatorPropertyName = "$type")]
[JsonPolymorphic(TypeDiscriminatorPropertyName = "$type", UnknownDerivedTypeHandling = JsonUnknownDerivedTypeHandling.FallBackToNearestAncestor)]
[JsonDerivedType(typeof(TextContent), typeDiscriminator: nameof(TextContent))]
[JsonDerivedType(typeof(ImageContent), typeDiscriminator: nameof(ImageContent))]
[JsonDerivedType(typeof(FunctionCallContent), typeDiscriminator: nameof(FunctionCallContent))]
[JsonDerivedType(typeof(FunctionResultContent), typeDiscriminator: nameof(FunctionResultContent))]
[JsonDerivedType(typeof(BinaryContent), typeDiscriminator: nameof(BinaryContent))]
[JsonDerivedType(typeof(AudioContent), typeDiscriminator: nameof(AudioContent))]
[JsonDerivedType(typeof(ChatMessageContent), typeDiscriminator: nameof(ChatMessageContent))]
#pragma warning disable SKEXP0110
[JsonDerivedType(typeof(AnnotationContent), typeDiscriminator: nameof(AnnotationContent))]
[JsonDerivedType(typeof(FileReferenceContent), typeDiscriminator: nameof(FileReferenceContent))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ public void ItCanBeSerializeAndDeserialized()
new FunctionCallContent("function-name", "plugin-name", "function-id", new KernelArguments { ["parameter"] = "argument" }),
new FunctionResultContent(new FunctionCallContent("function-name", "plugin-name", "function-id"), "function-result"),
new FileReferenceContent(fileId: "file-id-1") { ModelId = "model-7", Metadata = new Dictionary<string, object?>() { ["metadata-key-7"] = "metadata-value-7" } },
new AnnotationContent("quote-8") { ModelId = "model-8", FileId = "file-id-2", StartIndex = 2, EndIndex = 24, Metadata = new Dictionary<string, object?>() { ["metadata-key-8"] = "metadata-value-8" } }
new AnnotationContent("quote-8") { ModelId = "model-8", FileId = "file-id-2", StartIndex = 2, EndIndex = 24, Metadata = new Dictionary<string, object?>() { ["metadata-key-8"] = "metadata-value-8" } },
];

// Act
Expand Down Expand Up @@ -320,4 +320,92 @@ public void ItCanBeSerializeAndDeserialized()
Assert.Single(annotationContent.Metadata);
Assert.Equal("metadata-value-8", annotationContent.Metadata["metadata-key-8"]?.ToString());
}

[Fact]
public void ItCanBePolymorphicallySerializedAndDeserializedAsKernelContentType()
{
// Arrange
KernelContent sut = new ChatMessageContent(AuthorRole.User, "test-content", "test-model", metadata: new Dictionary<string, object?>()
{
["test-metadata-key"] = "test-metadata-value"
})
{
MimeType = "test-mime-type"
};

// Act
var json = JsonSerializer.Serialize(sut);

var deserialized = JsonSerializer.Deserialize<KernelContent>(json)!;

// Assert
Assert.IsType<ChatMessageContent>(deserialized);
Assert.Equal("test-content", ((ChatMessageContent)deserialized).Content);
Assert.Equal("test-model", deserialized.ModelId);
Assert.Equal("test-mime-type", deserialized.MimeType);
Assert.NotNull(deserialized.Metadata);
Assert.Single(deserialized.Metadata);
Assert.Equal("test-metadata-value", deserialized.Metadata["test-metadata-key"]?.ToString());
}

[Fact]
public void UnknownDerivativeCanBePolymorphicallySerializedAndDeserializedAsChatMessageContentType()
{
// Arrange
KernelContent sut = new UnknownExternalChatMessageContent(AuthorRole.User, "test-content")
{
MimeType = "test-mime-type",
};

// Act
var json = JsonSerializer.Serialize(sut);

var deserialized = JsonSerializer.Deserialize<KernelContent>(json)!;

// Assert
Assert.IsType<ChatMessageContent>(deserialized);
Assert.Equal("test-content", ((ChatMessageContent)deserialized).Content);
Assert.Equal("test-mime-type", deserialized.MimeType);
}

[Fact]
public void ItCanBeSerializeAndDeserializedWithFunctionResultOfChatMessageType()
{
// Arrange
ChatMessageContentItemCollection items = [
new FunctionResultContent(new FunctionCallContent("function-name-1", "plugin-name-1", "function-id-1"), new ChatMessageContent(AuthorRole.User, "test-content-1")),
new FunctionResultContent(new FunctionCallContent("function-name-2", "plugin-name-2", "function-id-2"), new UnknownExternalChatMessageContent(AuthorRole.Assistant, "test-content-2")),
];

// Act
var chatMessageJson = JsonSerializer.Serialize(new ChatMessageContent(AuthorRole.User, items: items, "message-model"));

var deserializedMessage = JsonSerializer.Deserialize<ChatMessageContent>(chatMessageJson)!;

// Assert
var functionResultContentWithResultOfChatMessageContentType = deserializedMessage.Items[0] as FunctionResultContent;
Assert.NotNull(functionResultContentWithResultOfChatMessageContentType);
Assert.Equal("function-name-1", functionResultContentWithResultOfChatMessageContentType.FunctionName);
Assert.Equal("function-id-1", functionResultContentWithResultOfChatMessageContentType.CallId);
Assert.Equal("plugin-name-1", functionResultContentWithResultOfChatMessageContentType.PluginName);
var chatMessageContent = Assert.IsType<JsonElement>(functionResultContentWithResultOfChatMessageContentType.Result);
Assert.Equal("user", chatMessageContent.GetProperty("Role").GetProperty("Label").GetString());
Assert.Equal("test-content-1", chatMessageContent.GetProperty("Items")[0].GetProperty("Text").GetString());

var functionResultContentWithResultOfUnknownChatMessageContentType = deserializedMessage.Items[1] as FunctionResultContent;
Assert.NotNull(functionResultContentWithResultOfUnknownChatMessageContentType);
Assert.Equal("function-name-2", functionResultContentWithResultOfUnknownChatMessageContentType.FunctionName);
Assert.Equal("function-id-2", functionResultContentWithResultOfUnknownChatMessageContentType.CallId);
Assert.Equal("plugin-name-2", functionResultContentWithResultOfUnknownChatMessageContentType.PluginName);
var unknownChatMessageContent = Assert.IsType<JsonElement>(functionResultContentWithResultOfUnknownChatMessageContentType.Result);
Assert.Equal("assistant", unknownChatMessageContent.GetProperty("Role").GetProperty("Label").GetString());
Assert.Equal("test-content-2", unknownChatMessageContent.GetProperty("Items")[0].GetProperty("Text").GetString());
}

private sealed class UnknownExternalChatMessageContent : ChatMessageContent
{
public UnknownExternalChatMessageContent(AuthorRole role, string? content) : base(role, content)
{
}
}
}

0 comments on commit fbdd6bc

Please sign in to comment.