From c7a371e3861bda3812a11950435fc5aa75c5572d Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Mon, 16 Dec 2024 04:33:04 -0500 Subject: [PATCH 1/5] .Net: Add PostgresVectorStore Memory connector. (#9324) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds a PostgresVectorStore and related classes to Microsoft.SemanticKernel.Connectors.Postgres. ### Motivation and Context As part of the move to having memory connectors implement the new Microsoft.Extensions.VectorData.IVectorStore architecture (see https://github.com/microsoft/semantic-kernel/blob/main/docs/decisions/0050-updated-vector-store-design.md), each memory connector needs to be updated with the new architecture. This PR tackles updating the existing Microsoft.SemanticKernel.Connectors.Postgres package to include this implementation. This will supercede the PostgresMemoryStore implementation. Some high level comments about design: - PostgresVectorStore and PostgresVectorStoreRecordCollection get injected with an IPostgresVectorStoreDbClient. This abstracts the database communication and allows for unit tests to mock database interactions. - The PostgresVectorStoreDbClient gets passed in a NpgsqlDataSource from the user, which is used to manage connections to the database. The responsibility of connection pool lifecycle management is on the user. - The IPostgresVectorStoreDbClient is designed to accept and produce the storage model, which in this case is a Dictionary . This is the intermediate type that is mapped to by the IVectorStoreRecordMapper. - The PostgresVectorStoreDbClient also takes a IPostgresVectorStoreCollectionSqlBuilder, which generates SQL command information for interacting with the database. This abstracts the SQL queries related to each task, and allows for future expansion. This is particularly targeted at creating a AzureDBForPostgre vector store that will enable alternate vector implementations like [DiskANN](https://techcommunity.microsoft.com/t5/azure-database-for-postgresql/introducing-diskann-vector-index-in-azure-database-for/ba-p/4261192), while leveraging the same database client as the Postgres connector. -  The integration tests for the vector store utilize Docker.Net to bring up a pgvector/pgvector docker container, which test are run against. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --------- Co-authored-by: Rob Emanuele Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> --- dotnet/SK-dotnet.sln | 9 + dotnet/samples/Concepts/Concepts.csproj | 3 + .../VectorStoreFixtures/VectorStoreInfra.cs | 45 ++ .../VectorStorePostgresContainerFixture.cs | 67 +++ ...rStore_VectorSearch_MultiStore_Postgres.cs | 85 +++ dotnet/samples/Concepts/README.md | 82 +++ .../Connectors.Memory.Postgres.csproj | 5 + .../IPostgresDbClient.cs | 2 +- ...PostgresVectorStoreCollectionSqlBuilder.cs | 136 +++++ .../IPostgresVectorStoreDbClient.cs | 132 ++++ ...tgresVectorStoreRecordCollectionFactory.cs | 24 + .../PostgresConstants.cs | 92 +++ .../PostgresDbClient.cs | 2 +- .../PostgresGenericDataModelMapper.cs | 104 ++++ .../PostgresServiceCollectionExtensions.cs | 172 ++++++ .../PostgresSqlCommandInfo.cs | 55 ++ .../PostgresVectorStore.cs | 75 +++ ...PostgresVectorStoreCollectionSqlBuilder.cs | 453 ++++++++++++++ .../PostgresVectorStoreDbClient.cs | 253 ++++++++ .../PostgresVectorStoreOptions.cs | 19 + .../PostgresVectorStoreRecordCollection.cs | 378 ++++++++++++ ...tgresVectorStoreRecordCollectionOptions.cs | 35 ++ .../PostgresVectorStoreRecordMapper.cs | 100 ++++ ...ostgresVectorStoreRecordPropertyMapping.cs | 269 +++++++++ .../PostgresVectorStoreUtils.cs | 59 ++ .../Connectors.Memory.Postgres/README.md | 75 +-- .../Connectors.Postgres.UnitTests.csproj | 32 + .../PostgresGenericDataModelMapperTests.cs | 190 ++++++ .../PostgresHotel.cs | 51 ++ ...ostgresServiceCollectionExtensionsTests.cs | 70 +++ ...resVectorStoreCollectionSqlBuilderTests.cs | 422 +++++++++++++ ...ostgresVectorStoreRecordCollectionTests.cs | 207 +++++++ .../PostgresVectorStoreRecordMapperTests.cs | 213 +++++++ ...esVectorStoreRecordPropertyMappingTests.cs | 147 +++++ .../PostgresVectorStoreTests.cs | 143 +++++ .../Memory/Postgres/PostgresHotel.cs | 60 ++ .../Postgres/PostgresMemoryStoreTests.cs | 6 +- .../PostgresVectorStoreCollectionFixture.cs | 10 + .../Postgres/PostgresVectorStoreFixture.cs | 239 ++++++++ ...ostgresVectorStoreRecordCollectionTests.cs | 562 ++++++++++++++++++ .../Postgres/PostgresVectorStoreTests.cs | 28 + .../src/Linq/AsyncEnumerable.cs | 35 ++ 42 files changed, 5074 insertions(+), 72 deletions(-) create mode 100644 dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStorePostgresContainerFixture.cs create mode 100644 dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreCollectionFixture.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 0844db359552..0a711f84f5f3 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -411,6 +411,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AotCompatibility", "samples EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SemanticKernel.AotTests", "src\SemanticKernel.AotTests\SemanticKernel.AotTests.csproj", "{39EAB599-742F-417D-AF80-95F90376BB18}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.Postgres.UnitTests", "src\Connectors\Connectors.Postgres.UnitTests\Connectors.Postgres.UnitTests.csproj", "{232E1153-6366-4175-A982-D66B30AAD610}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Process.Utilities.UnitTests", "src\Experimental\Process.Utilities.UnitTests\Process.Utilities.UnitTests.csproj", "{DAC54048-A39A-4739-8307-EA5A291F2EA0}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "GettingStartedWithVectorStores", "samples\GettingStartedWithVectorStores\GettingStartedWithVectorStores.csproj", "{8C3DE41C-E2C8-42B9-8638-574F8946EB0E}" @@ -1074,6 +1076,12 @@ Global {6F591D05-5F7F-4211-9042-42D8BCE60415}.Publish|Any CPU.Build.0 = Debug|Any CPU {6F591D05-5F7F-4211-9042-42D8BCE60415}.Release|Any CPU.ActiveCfg = Release|Any CPU {6F591D05-5F7F-4211-9042-42D8BCE60415}.Release|Any CPU.Build.0 = Release|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Debug|Any CPU.Build.0 = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Publish|Any CPU.Build.0 = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Release|Any CPU.ActiveCfg = Release|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Release|Any CPU.Build.0 = Release|Any CPU {E82B640C-1704-430D-8D71-FD8ED3695468}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {E82B640C-1704-430D-8D71-FD8ED3695468}.Debug|Any CPU.Build.0 = Debug|Any CPU {E82B640C-1704-430D-8D71-FD8ED3695468}.Publish|Any CPU.ActiveCfg = Debug|Any CPU @@ -1311,6 +1319,7 @@ Global {E82B640C-1704-430D-8D71-FD8ED3695468} = {5A7028A7-4DDF-4E4F-84A9-37CE8F8D7E89} {6ECFDF04-2237-4A85-B114-DAA34923E9E6} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {39EAB599-742F-417D-AF80-95F90376BB18} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0} + {232E1153-6366-4175-A982-D66B30AAD610} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {DAC54048-A39A-4739-8307-EA5A291F2EA0} = {0D8C6358-5DAA-4EA6-A924-C268A9A21BC9} {8C3DE41C-E2C8-42B9-8638-574F8946EB0E} = {FA3720F1-C99A-49B2-9577-A940257098BF} {DB58FDD0-308E-472F-BFF5-508BC64C727E} = {0D8C6358-5DAA-4EA6-A924-C268A9A21BC9} diff --git a/dotnet/samples/Concepts/Concepts.csproj b/dotnet/samples/Concepts/Concepts.csproj index d65aef92e0c3..746d5fbb73cf 100644 --- a/dotnet/samples/Concepts/Concepts.csproj +++ b/dotnet/samples/Concepts/Concepts.csproj @@ -102,6 +102,9 @@ + + Always + PreserveNewest diff --git a/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs index ea498f20c5ab..2681231c80d7 100644 --- a/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs +++ b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs @@ -10,6 +10,51 @@ namespace Memory.VectorStoreFixtures; /// internal static class VectorStoreInfra { + /// + /// Setup the postgres pgvector container by pulling the image and running it. + /// + /// The docker client to create the container with. + /// The id of the container. + public static async Task SetupPostgresContainerAsync(DockerClient client) + { + await client.Images.CreateImageAsync( + new ImagesCreateParameters + { + FromImage = "pgvector/pgvector", + Tag = "pg16", + }, + null, + new Progress()); + + var container = await client.Containers.CreateContainerAsync(new CreateContainerParameters() + { + Image = "pgvector/pgvector:pg16", + HostConfig = new HostConfig() + { + PortBindings = new Dictionary> + { + {"5432", new List {new() {HostPort = "5432" } }}, + }, + PublishAllPorts = true + }, + ExposedPorts = new Dictionary + { + { "5432", default }, + }, + Env = new List + { + "POSTGRES_USER=postgres", + "POSTGRES_PASSWORD=example", + }, + }); + + await client.Containers.StartContainerAsync( + container.ID, + new ContainerStartParameters()); + + return container.ID; + } + /// /// Setup the qdrant container by pulling the image and running it. /// diff --git a/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStorePostgresContainerFixture.cs b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStorePostgresContainerFixture.cs new file mode 100644 index 000000000000..200c4e48f5ac --- /dev/null +++ b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStorePostgresContainerFixture.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet; +using Npgsql; + +namespace Memory.VectorStoreFixtures; + +/// +/// Fixture to use for creating a Postgres container before tests and delete it after tests. +/// +public class VectorStorePostgresContainerFixture : IAsyncLifetime +{ + private DockerClient? _dockerClient; + private string? _postgresContainerId; + + public async Task InitializeAsync() + { + } + + public async Task ManualInitializeAsync() + { + if (this._postgresContainerId == null) + { + // Connect to docker and start the docker container. + using var dockerClientConfiguration = new DockerClientConfiguration(); + this._dockerClient = dockerClientConfiguration.CreateClient(); + this._postgresContainerId = await VectorStoreInfra.SetupPostgresContainerAsync(this._dockerClient); + + // Delay until the Postgres server is ready. + var connectionString = TestConfiguration.Postgres.ConnectionString; + var succeeded = false; + var attemptCount = 0; + while (!succeeded && attemptCount++ < 10) + { + try + { + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); + dataSourceBuilder.UseVector(); + using var dataSource = dataSourceBuilder.Build(); + NpgsqlConnection connection = await dataSource.OpenConnectionAsync().ConfigureAwait(false); + + await using (connection) + { + // Create extension vector if it doesn't exist + await using (NpgsqlCommand command = new("CREATE EXTENSION IF NOT EXISTS vector", connection)) + { + await command.ExecuteNonQueryAsync(); + } + } + } + catch (Exception) + { + await Task.Delay(1000); + } + } + } + } + + public async Task DisposeAsync() + { + if (this._dockerClient != null && this._postgresContainerId != null) + { + // Delete docker container. + await VectorStoreInfra.DeleteContainerAsync(this._dockerClient, this._postgresContainerId); + } + } +} diff --git a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs new file mode 100644 index 000000000000..e45c3390a2c0 --- /dev/null +++ b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Azure.Identity; +using Memory.VectorStoreFixtures; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Npgsql; + +namespace Memory; + +/// +/// An example showing how to use common code, that can work with any vector database, with a Postgres database. +/// The common code is in the class. +/// The common code ingests data into the vector store and then searches over that data. +/// This example is part of a set of examples each showing a different vector database. +/// +/// For other databases, see the following classes: +/// +/// +/// +/// +/// To run this sample, you need a local instance of Docker running, since the associated fixture will try and start a Postgres container in the local docker instance. +/// +public class VectorStore_VectorSearch_MultiStore_Postgres(ITestOutputHelper output, VectorStorePostgresContainerFixture PostgresFixture) : BaseTest(output), IClassFixture +{ + [Fact] + public async Task ExampleWithDIAsync() + { + // Use the kernel for DI purposes. + var kernelBuilder = Kernel + .CreateBuilder(); + + // Register an embedding generation service with the DI container. + kernelBuilder.AddAzureOpenAITextEmbeddingGeneration( + deploymentName: TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, + endpoint: TestConfiguration.AzureOpenAIEmbeddings.Endpoint, + credential: new AzureCliCredential()); + + // Initialize the Postgres docker container via the fixtures and register the Postgres VectorStore. + await PostgresFixture.ManualInitializeAsync(); + kernelBuilder.Services.AddPostgresVectorStore(TestConfiguration.Postgres.ConnectionString); + + // Register the test output helper common processor with the DI container. + kernelBuilder.Services.AddSingleton(this.Output); + kernelBuilder.Services.AddTransient(); + + // Build the kernel. + var kernel = kernelBuilder.Build(); + + // Build a common processor object using the DI container. + var processor = kernel.GetRequiredService(); + + // Run the process and pass a key generator function to it, to generate unique record keys. + // The key generator function is required, since different vector stores may require different key types. + // E.g. Postgres supports Guid and ulong keys, but others may support strings only. + await processor.IngestDataAndSearchAsync("skglossaryWithDI", () => Guid.NewGuid()); + } + + [Fact] + public async Task ExampleWithoutDIAsync() + { + // Create an embedding generation service. + var textEmbeddingGenerationService = new AzureOpenAITextEmbeddingGenerationService( + TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, + TestConfiguration.AzureOpenAIEmbeddings.Endpoint, + new AzureCliCredential()); + + // Initialize the Postgres docker container via the fixtures and construct the Postgres VectorStore. + await PostgresFixture.ManualInitializeAsync(); + var dataSourceBuilder = new NpgsqlDataSourceBuilder(TestConfiguration.Postgres.ConnectionString); + dataSourceBuilder.UseVector(); + await using var dataSource = dataSourceBuilder.Build(); + var vectorStore = new PostgresVectorStore(dataSource); + + // Create the common processor that works for any vector store. + var processor = new VectorStore_VectorSearch_MultiStore_Common(vectorStore, textEmbeddingGenerationService, this.Output); + + // Run the process and pass a key generator function to it, to generate unique record keys. + // The key generator function is required, since different vector stores may require different key types. + // E.g. Postgres supports Guid and ulong keys, but others may support strings only. + await processor.IngestDataAndSearchAsync("skglossaryWithoutDI", () => Guid.NewGuid()); + } +} diff --git a/dotnet/samples/Concepts/README.md b/dotnet/samples/Concepts/README.md index 6b0f28b329ca..deb3a6a43a20 100644 --- a/dotnet/samples/Concepts/README.md +++ b/dotnet/samples/Concepts/README.md @@ -215,3 +215,85 @@ dotnet test -l "console;verbosity=detailed" --filter "FullyQualifiedName=ChatCom - [OpenAI_TextToImage](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/TextToImage/OpenAI_TextToImage.cs) - [OpenAI_TextToImageLegacy](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/TextToImage/OpenAI_TextToImageLegacy.cs) - [AzureOpenAI_TextToImage](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/TextToImage/AzureOpenAI_TextToImage.cs) + +## Configuration + +### Option 1: Use Secret Manager + +Concept samples will require secrets and credentials, to access OpenAI, Azure OpenAI, +Bing and other resources. + +We suggest using .NET [Secret Manager](https://learn.microsoft.com/en-us/aspnet/core/security/app-secrets) +to avoid the risk of leaking secrets into the repository, branches and pull requests. +You can also use environment variables if you prefer. + +To set your secrets with Secret Manager: + +``` +cd dotnet/src/samples/Concepts +dotnet user-secrets init +dotnet user-secrets set "OpenAI:ServiceId" "gpt-3.5-turbo-instruct" +dotnet user-secrets set "OpenAI:ModelId" "gpt-3.5-turbo-instruct" +dotnet user-secrets set "OpenAI:ChatModelId" "gpt-4" +dotnet user-secrets set "OpenAI:ApiKey" "..." +... +``` + +### Option 2: Use Configuration File +1. Create a `appsettings.Development.json` file next to the `Concepts.csproj` file. This file will be ignored by git, + the content will not end up in pull requests, so it's safe for personal settings. Keep the file safe. +2. Edit `appsettings.Development.json` and set the appropriate configuration for the samples you are running. + +For example: + +```json +{ + "OpenAI": { + "ServiceId": "gpt-3.5-turbo-instruct", + "ModelId": "gpt-3.5-turbo-instruct", + "ChatModelId": "gpt-4", + "ApiKey": "sk-...." + }, + "AzureOpenAI": { + "ServiceId": "azure-gpt-35-turbo-instruct", + "DeploymentName": "gpt-35-turbo-instruct", + "ChatDeploymentName": "gpt-4", + "Endpoint": "https://contoso.openai.azure.com/", + "ApiKey": "...." + }, + // etc. +} +``` + +### Option 3: Use Environment Variables +You may also set the settings in your environment variables. The environment variables will override the settings in the `appsettings.Development.json` file. + +When setting environment variables, use a double underscore (i.e. "\_\_") to delineate between parent and child properties. For example: + +- bash: + + ```bash + export OpenAI__ApiKey="sk-...." + export AzureOpenAI__ApiKey="...." + export AzureOpenAI__DeploymentName="gpt-35-turbo-instruct" + export AzureOpenAI__ChatDeploymentName="gpt-4" + export AzureOpenAIEmbeddings__DeploymentName="azure-text-embedding-ada-002" + export AzureOpenAI__Endpoint="https://contoso.openai.azure.com/" + export HuggingFace__ApiKey="...." + export Bing__ApiKey="...." + export Postgres__ConnectionString="...." + ``` + +- PowerShell: + + ```ps + $env:OpenAI__ApiKey = "sk-...." + $env:AzureOpenAI__ApiKey = "...." + $env:AzureOpenAI__DeploymentName = "gpt-35-turbo-instruct" + $env:AzureOpenAI__ChatDeploymentName = "gpt-4" + $env:AzureOpenAIEmbeddings__DeploymentName = "azure-text-embedding-ada-002" + $env:AzureOpenAI__Endpoint = "https://contoso.openai.azure.com/" + $env:HuggingFace__ApiKey = "...." + $env:Bing__ApiKey = "...." + $env:Postgres__ConnectionString = "...." + ``` diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj index a5ec850f1b6e..b1904c6cc1cd 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj @@ -29,4 +29,9 @@ + + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs index 70747990e2fd..2af6d4f5fb62 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs @@ -9,7 +9,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// -/// Interface for client managing postgres database operations. +/// Interface for client managing postgres database operations for . /// public interface IPostgresDbClient { diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs new file mode 100644 index 000000000000..d130d2f13b44 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Pgvector; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Interface for constructing SQL commands for Postgres vector store collections. +/// +internal interface IPostgresVectorStoreCollectionSqlBuilder +{ + /// + /// Builds a SQL command to check if a table exists in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The built SQL command. + /// + /// The command must return a single row with a single column named "table_name" if the table exists. + /// + PostgresSqlCommandInfo BuildDoesTableExistCommand(string schema, string tableName); + + /// + /// Builds a SQL command to fetch all tables in the Postgres vector store. + /// + /// The schema of the tables. + PostgresSqlCommandInfo BuildGetTablesCommand(string schema); + + /// + /// Builds a SQL command to create a table in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The properties of the table. + /// Specifies whether to include IF NOT EXISTS in the command. + /// The built SQL command info. + PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, IReadOnlyList properties, bool ifNotExists = true); + + /// + /// Builds a SQL command to create a vector index in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The name of the vector column. + /// The kind of index to create. + /// The distance function to use for the index. + /// The built SQL command info. + PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction); + + /// + /// Builds a SQL command to drop a table in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The built SQL command info. + PostgresSqlCommandInfo BuildDropTableCommand(string schema, string tableName); + + /// + /// Builds a SQL command to upsert a record in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The key column of the table. + /// The row to upsert. + /// The built SQL command info. + PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName, string keyColumn, Dictionary row); + + /// + /// Builds a SQL command to upsert a batch of records in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The key column of the table. + /// The rows to upsert. + /// The built SQL command info. + PostgresSqlCommandInfo BuildUpsertBatchCommand(string schema, string tableName, string keyColumn, List> rows); + + /// + /// Builds a SQL command to get a record from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The properties of the table. + /// The key of the record to get. + /// Specifies whether to include vectors in the record. + /// The built SQL command info. + PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, IReadOnlyList properties, TKey key, bool includeVectors = false) where TKey : notnull; + + /// + /// Builds a SQL command to get a batch of records from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The properties of the table. + /// The keys of the records to get. + /// Specifies whether to include vectors in the records. + /// The built SQL command info. + PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, IReadOnlyList properties, List keys, bool includeVectors = false) where TKey : notnull; + + /// + /// Builds a SQL command to delete a record from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The key column of the table. + /// The key of the record to delete. + /// The built SQL command info. + PostgresSqlCommandInfo BuildDeleteCommand(string schema, string tableName, string keyColumn, TKey key); + + /// + /// Builds a SQL command to delete a batch of records from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The key column of the table. + /// The keys of the records to delete. + /// The built SQL command info. + PostgresSqlCommandInfo BuildDeleteBatchCommand(string schema, string tableName, string keyColumn, List keys); + + /// + /// Builds a SQL command to get the nearest match from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The properties of the table. + /// The property which the vectors to compare are stored in. + /// The vector to match. + /// The filter conditions for the query. + /// The number of records to skip. + /// Specifies whether to include vectors in the result. + /// The maximum number of records to return. + /// The built SQL command info. + PostgresSqlCommandInfo BuildGetNearestMatchCommand(string schema, string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, VectorSearchFilter? filter, int? skip, bool includeVectors, int limit); +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs new file mode 100644 index 000000000000..59aa9829c568 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Npgsql; +using Pgvector; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Internal interface for client managing postgres database operations. +/// +internal interface IPostgresVectorStoreDbClient +{ + /// + /// The used to connect to the database. + /// + public NpgsqlDataSource DataSource { get; } + + /// + /// Check if a table exists. + /// + /// The name assigned to a table of entries. + /// The to monitor for cancellation requests. The default is . + /// + Task DoesTableExistsAsync(string tableName, CancellationToken cancellationToken = default); + + /// + /// Get all tables. + /// + /// The to monitor for cancellation requests. The default is . + /// A group of tables. + IAsyncEnumerable GetTablesAsync(CancellationToken cancellationToken = default); + /// + /// Create a table. Also creates an index on vector columns if the table has vector properties defined. + /// + /// The name assigned to a table of entries. + /// The properties of the record definition that define the table. + /// Specifies whether to include IF NOT EXISTS in the command. + /// The to monitor for cancellation requests. The default is . + /// + Task CreateTableAsync(string tableName, IReadOnlyList properties, bool ifNotExists = true, CancellationToken cancellationToken = default); + + /// + /// Drop a table. + /// + /// The name assigned to a table of entries. + /// The to monitor for cancellation requests. The default is . + Task DeleteTableAsync(string tableName, CancellationToken cancellationToken = default); + + /// + /// Upsert entry into a table. + /// + /// The name assigned to a table of entries. + /// The row to upsert into the table. + /// The key column of the table. + /// The to monitor for cancellation requests. The default is . + /// + Task UpsertAsync(string tableName, Dictionary row, string keyColumn, CancellationToken cancellationToken = default); + + /// + /// Upsert multiple entries into a table. + /// + /// The name assigned to a table of entries. + /// The rows to upsert into the table. + /// The key column of the table. + /// The to monitor for cancellation requests. The default is . + /// + Task UpsertBatchAsync(string tableName, IEnumerable> rows, string keyColumn, CancellationToken cancellationToken = default); + + /// + /// Get a entry by its key. + /// + /// The name assigned to a table of entries. + /// The key of the entry to get. + /// The properties to include in the entry. + /// If true, the vectors will be included in the entry. + /// The to monitor for cancellation requests. The default is . + /// The row if the key is found, otherwise null. + Task?> GetAsync(string tableName, TKey key, IReadOnlyList properties, bool includeVectors = false, CancellationToken cancellationToken = default) + where TKey : notnull; + + /// + /// Get multiple entries by their keys. + /// + /// The name assigned to a table of entries. + /// The keys of the entries to get. + /// The properties of the table. + /// If true, the vectors will be included in the entries. + /// The to monitor for cancellation requests. The default is . + /// The rows that match the given keys. + IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, IReadOnlyList properties, bool includeVectors = false, CancellationToken cancellationToken = default) + where TKey : notnull; + + /// + /// Delete a entry by its key. + /// + /// The name assigned to a table of entries. + /// The name of the key column. + /// The key of the entry to delete. + /// The to monitor for cancellation requests. The default is . + /// + Task DeleteAsync(string tableName, string keyColumn, TKey key, CancellationToken cancellationToken = default); + + /// + /// Delete multiple entries by their keys. + /// + /// The name assigned to a table of entries. + /// The name of the key column. + /// The keys of the entries to delete. + /// The to monitor for cancellation requests. The default is . + /// + Task DeleteBatchAsync(string tableName, string keyColumn, IEnumerable keys, CancellationToken cancellationToken = default); + + /// + /// Gets the nearest matches to the . + /// + /// The name assigned to a table of entries. + /// The properties to retrieve. + /// The property which the vectors to compare are stored in. + /// The to compare the table's vector with. + /// The maximum number of similarity results to return. + /// Optional conditions to filter the results. + /// The number of entries to skip. + /// If true, the vectors will be returned in the entries. + /// The to monitor for cancellation requests. The default is . + /// An asynchronous stream of objects that the nearest matches to the . + IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync(string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, + VectorSearchFilter? filter = default, int? skip = default, bool includeVectors = false, CancellationToken cancellationToken = default); +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs new file mode 100644 index 000000000000..5bf0d9cad789 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using Npgsql; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Interface for constructing Postgres instances when using to retrieve these. +/// +public interface IPostgresVectorStoreRecordCollectionFactory +{ + /// + /// Constructs a new instance of the . + /// + /// The data type of the record key. + /// The data model to use for adding, updating and retrieving data from storage. + /// The Postgres data source. + /// The name of the collection to connect to. + /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. + /// The new instance of . + IVectorStoreRecordCollection CreateVectorStoreRecordCollection(NpgsqlDataSource dataSource, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) + where TKey : notnull; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs new file mode 100644 index 000000000000..f8784890e83a --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal static class PostgresConstants +{ + /// The name of this database for telemetry purposes. + public const string DatabaseName = "Postgres"; + + /// A of types that a key on the provided model may have. + public static readonly HashSet SupportedKeyTypes = + [ + typeof(short), + typeof(int), + typeof(long), + typeof(string), + typeof(Guid), + ]; + + /// A of types that data properties on the provided model may have. + public static readonly HashSet SupportedDataTypes = + [ + typeof(bool), + typeof(bool?), + typeof(short), + typeof(short?), + typeof(int), + typeof(int?), + typeof(long), + typeof(long?), + typeof(float), + typeof(float?), + typeof(double), + typeof(double?), + typeof(decimal), + typeof(decimal?), + typeof(string), + typeof(DateTime), + typeof(DateTime?), + typeof(DateTimeOffset), + typeof(DateTimeOffset?), + typeof(Guid), + typeof(Guid?), + typeof(byte[]), + ]; + + /// A of types that enumerable data properties on the provided model may use as their element types. + public static readonly HashSet SupportedEnumerableDataElementTypes = + [ + typeof(bool), + typeof(short), + typeof(int), + typeof(long), + typeof(float), + typeof(double), + typeof(decimal), + typeof(string), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(Guid), + ]; + + /// A of types that vector properties on the provided model may have. + public static readonly HashSet SupportedVectorTypes = + [ + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?) + ]; + + /// The default schema name. + public const string DefaultSchema = "public"; + + /// The name of the column that returns distance value in the database. + /// It is used in the similarity search query. Must not conflict with model property. + public const string DistanceColumnName = "sk_pg_distance"; + + /// The default index kind. + /// Defaults to "Flat", which means no indexing. + public const string DefaultIndexKind = IndexKind.Flat; + + /// The default distance function. + public const string DefaultDistanceFunction = DistanceFunction.CosineDistance; + + public static readonly Dictionary IndexMaxDimensions = new() + { + { IndexKind.Hnsw, 2000 }, + }; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs index 1dc1ffef3c1d..d927710d4fd9 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs @@ -13,7 +13,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// -/// An implementation of a client for Postgres. This class is used to managing postgres database operations. +/// An implementation of a client for Postgres. This class is used to managing postgres database operations for . /// [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")] public class PostgresDbClient : IPostgresDbClient diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs new file mode 100644 index 000000000000..efdec538c772 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal sealed class PostgresGenericDataModelMapper : IVectorStoreRecordMapper, Dictionary> + where TKey : notnull +{ + /// with helpers for reading vector store model properties and their attributes. + private readonly VectorStoreRecordPropertyReader _propertyReader; + + /// + /// Initializes a new instance of the class. + /// /// + /// with helpers for reading vector store model properties and their attributes. + public PostgresGenericDataModelMapper(VectorStoreRecordPropertyReader propertyReader) + { + Verify.NotNull(propertyReader); + + this._propertyReader = propertyReader; + + // Validate property types. + this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, PostgresConstants.SupportedEnumerableDataElementTypes); + this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); + } + + public Dictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) + { + var properties = new Dictionary + { + // Add key property + { this._propertyReader.KeyPropertyStoragePropertyName, dataModel.Key } + }; + + // Add data properties + if (dataModel.Data is not null) + { + foreach (var property in this._propertyReader.DataProperties) + { + if (dataModel.Data.TryGetValue(property.DataModelPropertyName, out var dataValue)) + { + properties.Add(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), dataValue); + } + } + } + + // Add vector properties + if (dataModel.Vectors is not null) + { + foreach (var property in this._propertyReader.VectorProperties) + { + if (dataModel.Vectors.TryGetValue(property.DataModelPropertyName, out var vectorValue)) + { + var result = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vectorValue); + properties.Add(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), result); + } + } + } + + return properties; + } + + public VectorStoreGenericDataModel MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) + { + TKey key; + var dataProperties = new Dictionary(); + var vectorProperties = new Dictionary(); + + // Process key property. + if (storageModel.TryGetValue(this._propertyReader.KeyPropertyStoragePropertyName, out var keyObject) && keyObject is not null) + { + key = (TKey)keyObject; + } + else + { + throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); + } + + // Process data properties. + foreach (var property in this._propertyReader.DataProperties) + { + if (storageModel.TryGetValue(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), out var dataValue)) + { + dataProperties.Add(property.DataModelPropertyName, dataValue); + } + } + + // Process vector properties + if (options.IncludeVectors) + { + foreach (var property in this._propertyReader.VectorProperties) + { + if (storageModel.TryGetValue(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), out var vectorValue)) + { + vectorProperties.Add(property.DataModelPropertyName, PostgresVectorStoreRecordPropertyMapping.MapVectorForDataModel(vectorValue)); + } + } + } + + return new VectorStoreGenericDataModel(key) { Data = dataProperties, Vectors = vectorProperties }; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs new file mode 100644 index 000000000000..983b8e7db443 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Npgsql; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods to register Postgres instances on an . +/// +public static class PostgresServiceCollectionExtensions +{ + /// + /// Register a Postgres with the specified service ID and where the NpgsqlDataSource is retrieved from the dependency injection container. + /// + /// The to register the on. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddPostgresVectorStore(this IServiceCollection services, PostgresVectorStoreOptions? options = default, string? serviceId = default) + { + // Since we are not constructing the data source, add the IVectorStore as transient, since we + // cannot make assumptions about how data source is being managed. + services.AddKeyedTransient( + serviceId, + (sp, obj) => + { + var dataSource = sp.GetRequiredService(); + var selectedOptions = options ?? sp.GetService(); + + return new PostgresVectorStore( + dataSource, + selectedOptions); + }); + + return services; + } + + /// + /// Register a Postgres with the specified service ID and where an NpgsqlDataSource is constructed using the provided parameters. + /// + /// The to register the on. + /// Postgres database connection string. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddPostgresVectorStore(this IServiceCollection services, string connectionString, PostgresVectorStoreOptions? options = default, string? serviceId = default) + { + string? npgsqlServiceId = serviceId == null ? default : $"{serviceId}_NpgsqlDataSource"; + // Register NpgsqlDataSource to ensure proper disposal. + services.AddKeyedSingleton( + npgsqlServiceId, + (sp, obj) => + { + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); + dataSourceBuilder.UseVector(); + return dataSourceBuilder.Build(); + }); + + services.AddKeyedSingleton( + serviceId, + (sp, obj) => + { + var dataSource = sp.GetRequiredKeyedService(npgsqlServiceId); + var selectedOptions = options ?? sp.GetService(); + + return new PostgresVectorStore( + dataSource, + selectedOptions); + }); + + return services; + } + + /// + /// Register a Postgres and with the specified service ID + /// and where the NpgsqlDataSource is retrieved from the dependency injection container. + /// + /// The type of the key. + /// The type of the record. + /// The to register the on. + /// The name of the collection. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// Service collection. + public static IServiceCollection AddPostgresVectorStoreRecordCollection( + this IServiceCollection services, + string collectionName, + PostgresVectorStoreRecordCollectionOptions? options = default, + string? serviceId = default) + where TKey : notnull + { + services.AddKeyedTransient>( + serviceId, + (sp, obj) => + { + var dataSource = sp.GetRequiredService(); + var selectedOptions = options ?? sp.GetService>(); + + return (new PostgresVectorStoreRecordCollection(dataSource, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; + }); + + AddVectorizedSearch(services, serviceId); + + return services; + } + + /// + /// Register a Postgres and with the specified service ID + /// and where the NpgsqlDataSource is constructed using the provided parameters. + /// + /// The type of the key. + /// The type of the record. + /// The to register the on. + /// The name of the collection. + /// Postgres database connection string. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// Service collection. + public static IServiceCollection AddPostgresVectorStoreRecordCollection( + this IServiceCollection services, + string collectionName, + string connectionString, + PostgresVectorStoreRecordCollectionOptions? options = default, + string? serviceId = default) + where TKey : notnull + { + string? npgsqlServiceId = serviceId == null ? default : $"{serviceId}_NpgsqlDataSource"; + // Register NpgsqlDataSource to ensure proper disposal. + services.AddKeyedSingleton( + npgsqlServiceId, + (sp, obj) => + { + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); + dataSourceBuilder.UseVector(); + return dataSourceBuilder.Build(); + }); + + services.AddKeyedSingleton>( + serviceId, + (sp, obj) => + { + var dataSource = sp.GetRequiredKeyedService(npgsqlServiceId); + + return (new PostgresVectorStoreRecordCollection(dataSource, collectionName, options) as IVectorStoreRecordCollection)!; + }); + + AddVectorizedSearch(services, serviceId); + + return services; + } + + /// + /// Also register the with the given as a . + /// + /// The type of the key. + /// The type of the data model that the collection should contain. + /// The service collection to register on. + /// The service id that the registrations should use. + private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) + where TKey : notnull + { + services.AddKeyedTransient>( + serviceId, + (sp, obj) => + { + return sp.GetRequiredKeyedService>(serviceId); + }); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs new file mode 100644 index 000000000000..fb520188b84b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using Npgsql; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Represents a SQL command for Postgres. +/// +internal class PostgresSqlCommandInfo +{ + /// + /// Gets or sets the SQL command text. + /// + public string CommandText { get; set; } + /// + /// Gets or sets the parameters for the SQL command. + /// + public List? Parameters { get; set; } = null; + + /// + /// Initializes a new instance of the class. + /// + /// The SQL command text. + /// The parameters for the SQL command. + public PostgresSqlCommandInfo(string commandText, List? parameters = null) + { + this.CommandText = commandText; + this.Parameters = parameters; + } + + /// + /// Converts this instance to an . + /// + [SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "User input is passed using command parameters.")] + public NpgsqlCommand ToNpgsqlCommand(NpgsqlConnection connection, NpgsqlTransaction? transaction = null) + { + NpgsqlCommand cmd = connection.CreateCommand(); + if (transaction != null) + { + cmd.Transaction = transaction; + } + cmd.CommandText = this.CommandText; + if (this.Parameters != null) + { + foreach (var parameter in this.Parameters) + { + cmd.Parameters.Add(parameter); + } + } + return cmd; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs new file mode 100644 index 000000000000..99bbc8e320b5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading; +using Microsoft.Extensions.VectorData; +using Npgsql; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Represents a vector store implementation using PostgreSQL. +/// +public class PostgresVectorStore : IVectorStore +{ + private readonly IPostgresVectorStoreDbClient _postgresClient; + private readonly NpgsqlDataSource? _dataSource; + private readonly PostgresVectorStoreOptions _options; + + /// + /// Initializes a new instance of the class. + /// + /// Postgres data source. + /// Optional configuration options for this class + public PostgresVectorStore(NpgsqlDataSource dataSource, PostgresVectorStoreOptions? options = default) + { + this._dataSource = dataSource; + this._options = options ?? new PostgresVectorStoreOptions(); + this._postgresClient = new PostgresVectorStoreDbClient(this._dataSource, this._options.Schema); + } + + /// + /// Initializes a new instance of the class. + /// + /// An instance of . + /// Optional configuration options for this class + internal PostgresVectorStore(IPostgresVectorStoreDbClient postgresDbClient, PostgresVectorStoreOptions? options = default) + { + this._postgresClient = postgresDbClient; + this._options = options ?? new PostgresVectorStoreOptions(); + } + + /// + public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) + { + const string OperationName = "ListCollectionNames"; + return PostgresVectorStoreUtils.WrapAsyncEnumerableAsync( + this._postgresClient.GetTablesAsync(cancellationToken), + OperationName + ); + } + + /// + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + where TKey : notnull + { + if (!PostgresConstants.SupportedKeyTypes.Contains(typeof(TKey))) + { + throw new NotSupportedException($"Unsupported key type: {typeof(TKey)}"); + } + + if (this._options.VectorStoreCollectionFactory is not null) + { + return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection(this._postgresClient.DataSource, name, vectorStoreRecordDefinition); + } + + var recordCollection = new PostgresVectorStoreRecordCollection( + this._postgresClient, + name, + new PostgresVectorStoreRecordCollectionOptions() { Schema = this._options.Schema, VectorStoreRecordDefinition = vectorStoreRecordDefinition } + ); + + return recordCollection as IVectorStoreRecordCollection ?? throw new InvalidOperationException("Failed to cast record collection."); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs new file mode 100644 index 000000000000..d68412d31b7d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -0,0 +1,453 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.Extensions.VectorData; +using Npgsql; +using NpgsqlTypes; +using Pgvector; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Provides methods to build SQL commands for managing vector store collections in PostgreSQL. +/// +internal class PostgresVectorStoreCollectionSqlBuilder : IPostgresVectorStoreCollectionSqlBuilder +{ + /// + public PostgresSqlCommandInfo BuildDoesTableExistCommand(string schema, string tableName) + { + return new PostgresSqlCommandInfo( + commandText: @" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = $1 + AND table_type = 'BASE TABLE' + AND table_name = $2", + parameters: [ + new NpgsqlParameter() { Value = schema }, + new NpgsqlParameter() { Value = tableName } + ] + ); + } + + /// + public PostgresSqlCommandInfo BuildGetTablesCommand(string schema) + { + return new PostgresSqlCommandInfo( + commandText: @" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = $1 + AND table_type = 'BASE TABLE'", + parameters: [new NpgsqlParameter() { Value = schema }] + ); + } + + /// + public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, IReadOnlyList properties, bool ifNotExists = true) + { + if (string.IsNullOrWhiteSpace(tableName)) + { + throw new ArgumentException("Table name cannot be null or whitespace", nameof(tableName)); + } + + VectorStoreRecordKeyProperty? keyProperty = default; + List dataProperties = new(); + List vectorProperties = new(); + + foreach (var property in properties) + { + if (property is VectorStoreRecordKeyProperty keyProp) + { + if (keyProperty != null) + { + // Should be impossible, as property reader should have already validated that + // multiple key properties are not allowed. + throw new ArgumentException("Record definition cannot have more than one key property."); + } + keyProperty = keyProp; + } + else if (property is VectorStoreRecordDataProperty dataProp) + { + dataProperties.Add(dataProp); + } + else if (property is VectorStoreRecordVectorProperty vectorProp) + { + vectorProperties.Add(vectorProp); + } + else + { + throw new NotSupportedException($"Property type {property.GetType().Name} is not supported by this store."); + } + } + + if (keyProperty == null) + { + throw new ArgumentException("Record definition must have a key property."); + } + + var keyName = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; + + StringBuilder createTableCommand = new(); + createTableCommand.AppendLine($"CREATE TABLE {(ifNotExists ? "IF NOT EXISTS " : "")}{schema}.\"{tableName}\" ("); + + // Add the key column + var keyPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(keyProperty.PropertyType); + createTableCommand.AppendLine($" \"{keyName}\" {keyPgTypeInfo.PgType} {(keyPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); + + // Add the data columns + foreach (var dataProperty in dataProperties) + { + string columnName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; + var dataPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(dataProperty.PropertyType); + createTableCommand.AppendLine($" \"{columnName}\" {dataPgTypeInfo.PgType} {(dataPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); + } + + // Add the vector columns + foreach (var vectorProperty in vectorProperties) + { + string columnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + var vectorPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPgVectorTypeName(vectorProperty); + createTableCommand.AppendLine($" \"{columnName}\" {vectorPgTypeInfo.PgType} {(vectorPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); + } + + createTableCommand.AppendLine($" PRIMARY KEY (\"{keyName}\")"); + + createTableCommand.AppendLine(");"); + + return new PostgresSqlCommandInfo(commandText: createTableCommand.ToString()); + } + + /// + public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction) + { + // Only support creating HNSW index creation through the connector. + var indexTypeName = indexKind switch + { + IndexKind.Hnsw => "hnsw", + _ => throw new NotSupportedException($"Index kind '{indexKind}' is not supported for table creation. If you need to create an index of this type, please do so manually. Only HNSW indexes are supported through the vector store.") + }; + + distanceFunction ??= PostgresConstants.DefaultDistanceFunction; // Default to Cosine distance + + var indexOps = distanceFunction switch + { + DistanceFunction.CosineDistance => "vector_cosine_ops", + DistanceFunction.CosineSimilarity => "vector_cosine_ops", + DistanceFunction.DotProductSimilarity => "vector_ip_ops", + DistanceFunction.EuclideanDistance => "vector_l2_ops", + DistanceFunction.ManhattanDistance => "vector_l1_ops", + _ => throw new NotSupportedException($"Distance function {distanceFunction} is not supported.") + }; + + var indexName = $"{tableName}_{vectorColumnName}_index"; + + return new PostgresSqlCommandInfo( + commandText: $@" + CREATE INDEX {indexName} ON {schema}.""{tableName}"" USING {indexTypeName} (""{vectorColumnName}"" {indexOps});" + ); + } + + /// + public PostgresSqlCommandInfo BuildDropTableCommand(string schema, string tableName) + { + return new PostgresSqlCommandInfo( + commandText: $@"DROP TABLE IF EXISTS {schema}.""{tableName}""" + ); + } + + /// + public PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName, string keyColumn, Dictionary row) + { + var columns = row.Keys.ToList(); + var columnNames = string.Join(", ", columns.Select(k => $"\"{k}\"")); + var valuesParams = string.Join(", ", columns.Select((k, i) => $"${i + 1}")); + var columnsWithIndex = columns.Select((k, i) => (col: k, idx: i)); + var updateColumnsWithParams = string.Join(", ", columnsWithIndex.Where(c => c.col != keyColumn).Select(c => $"\"{c.col}\"=${c.idx + 1}")); + var commandText = $@" + INSERT INTO {schema}.""{tableName}"" ({columnNames}) + VALUES({valuesParams}) + ON CONFLICT (""{keyColumn}"") + DO UPDATE SET {updateColumnsWithParams};"; + + return new PostgresSqlCommandInfo(commandText) + { + Parameters = columns.Select(c => + PostgresVectorStoreRecordPropertyMapping.GetNpgsqlParameter(row[c]) + ).ToList() + }; + } + + /// + public PostgresSqlCommandInfo BuildUpsertBatchCommand(string schema, string tableName, string keyColumn, List> rows) + { + if (rows == null || rows.Count == 0) + { + throw new ArgumentException("Rows cannot be null or empty", nameof(rows)); + } + + var firstRow = rows[0]; + var columns = firstRow.Keys.ToList(); + + // Generate column names and parameter placeholders + var columnNames = string.Join(", ", columns.Select(c => $"\"{c}\"")); + var valuePlaceholders = string.Join(", ", columns.Select((c, i) => $"${i + 1}")); + var valuesRows = string.Join(", ", + rows.Select((row, rowIndex) => + $"({string.Join(", ", + columns.Select((c, colIndex) => $"${rowIndex * columns.Count + colIndex + 1}"))})")); + + // Generate the update set clause + var updateSetClause = string.Join(", ", columns.Where(c => c != keyColumn).Select(c => $"\"{c}\" = EXCLUDED.\"{c}\"")); + + // Generate the SQL command + var commandText = $@" + INSERT INTO {schema}.""{tableName}"" ({columnNames}) + VALUES {valuesRows} + ON CONFLICT (""{keyColumn}"") + DO UPDATE SET {updateSetClause}; "; + + // Generate the parameters + var parameters = new List(); + for (int rowIndex = 0; rowIndex < rows.Count; rowIndex++) + { + var row = rows[rowIndex]; + foreach (var column in columns) + { + parameters.Add(new NpgsqlParameter() + { + Value = row[column] ?? DBNull.Value + }); + } + } + + return new PostgresSqlCommandInfo(commandText, parameters); + } + + /// + public PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, IReadOnlyList properties, TKey key, bool includeVectors = false) + where TKey : notnull + { + List queryColumns = new(); + string? keyColumn = null; + + foreach (var property in properties) + { + if (property is VectorStoreRecordKeyProperty keyProperty) + { + if (keyColumn != null) + { + throw new ArgumentException("Record definition cannot have more than one key property."); + } + keyColumn = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; + queryColumns.Add($"\"{keyColumn}\""); + } + else if (property is VectorStoreRecordDataProperty dataProperty) + { + string columnName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; + queryColumns.Add($"\"{columnName}\""); + } + else if (property is VectorStoreRecordVectorProperty vectorProperty && includeVectors) + { + string columnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + queryColumns.Add($"\"{columnName}\""); + } + } + + Verify.NotNull(keyColumn, "Record definition must have a key property."); + + var queryColumnList = string.Join(", ", queryColumns); + + return new PostgresSqlCommandInfo( + commandText: $@" + SELECT {queryColumnList} + FROM {schema}.""{tableName}"" + WHERE ""{keyColumn}"" = ${1};", + parameters: [new NpgsqlParameter() { Value = key }] + ); + } + + /// + public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, IReadOnlyList properties, List keys, bool includeVectors = false) + where TKey : notnull + { + NpgsqlDbType? keyType = PostgresVectorStoreRecordPropertyMapping.GetNpgsqlDbType(typeof(TKey)) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); + + if (keys == null || keys.Count == 0) + { + throw new ArgumentException("Keys cannot be null or empty", nameof(keys)); + } + + var keyProperty = properties.OfType().FirstOrDefault() ?? throw new ArgumentException("Properties must contain a key property", nameof(properties)); + var keyColumn = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; + + // Generate the column names + var columns = properties + .Where(p => includeVectors || p is not VectorStoreRecordVectorProperty) + .Select(p => p.StoragePropertyName ?? p.DataModelPropertyName) + .ToList(); + + var columnNames = string.Join(", ", columns.Select(c => $"\"{c}\"")); + var keyParams = string.Join(", ", keys.Select((k, i) => $"${i + 1}")); + + // Generate the SQL command + var commandText = $@" + SELECT {columnNames} + FROM {schema}.""{tableName}"" + WHERE ""{keyColumn}"" = ANY($1);"; + + return new PostgresSqlCommandInfo(commandText) + { + Parameters = [new NpgsqlParameter() { Value = keys.ToArray(), NpgsqlDbType = NpgsqlDbType.Array | keyType.Value }] + }; + } + + /// + public PostgresSqlCommandInfo BuildDeleteCommand(string schema, string tableName, string keyColumn, TKey key) + { + return new PostgresSqlCommandInfo( + commandText: $@" + DELETE FROM {schema}.""{tableName}"" + WHERE ""{keyColumn}"" = ${1};", + parameters: [new NpgsqlParameter() { Value = key }] + ); + } + + /// + public PostgresSqlCommandInfo BuildDeleteBatchCommand(string schema, string tableName, string keyColumn, List keys) + { + NpgsqlDbType? keyType = PostgresVectorStoreRecordPropertyMapping.GetNpgsqlDbType(typeof(TKey)) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); + if (keys == null || keys.Count == 0) + { + throw new ArgumentException("Keys cannot be null or empty", nameof(keys)); + } + + for (int i = 0; i < keys.Count; i++) + { + if (keys[i] == null) + { + throw new ArgumentException("Keys cannot contain null values", nameof(keys)); + } + } + + var commandText = $@" + DELETE FROM {schema}.""{tableName}"" + WHERE ""{keyColumn}"" = ANY($1);"; + + return new PostgresSqlCommandInfo(commandText) + { + Parameters = [new NpgsqlParameter() { Value = keys, NpgsqlDbType = NpgsqlDbType.Array | keyType.Value }] + }; + } + + /// + public PostgresSqlCommandInfo BuildGetNearestMatchCommand( + string schema, string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, + VectorSearchFilter? filter, int? skip, bool includeVectors, int limit) + { + var columns = string.Join(" ,", + properties + .Select(property => property.StoragePropertyName ?? property.DataModelPropertyName) + .Select(column => $"\"{column}\"") + ); + + var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; + var distanceOp = distanceFunction switch + { + DistanceFunction.CosineDistance => "<=>", + DistanceFunction.CosineSimilarity => "<=>", + DistanceFunction.EuclideanDistance => "<->", + DistanceFunction.ManhattanDistance => "<+>", + DistanceFunction.DotProductSimilarity => "<#>", + null or "" => "<->", // Default to Euclidean distance + _ => throw new NotSupportedException($"Distance function {vectorProperty.DistanceFunction} is not supported.") + }; + + var vectorColumn = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + // Start where clause params at 2, vector takes param 1. + var where = GenerateWhereClause(schema, tableName, properties, filter, startParamIndex: 2); + + var commandText = $@" + SELECT {columns}, ""{vectorColumn}"" {distanceOp} $1 AS ""{PostgresConstants.DistanceColumnName}"" + FROM {schema}.""{tableName}"" {where.Clause} + ORDER BY {PostgresConstants.DistanceColumnName} + LIMIT {limit}"; + + if (skip.HasValue) { commandText += $" OFFSET {skip.Value}"; } + + // For cosine similarity, we need to take 1 - cosine distance. + // However, we can't use an expression in the ORDER BY clause or else the index won't be used. + // Instead we'll wrap the query in a subquery and modify the distance in the outer query. + if (vectorProperty.DistanceFunction == DistanceFunction.CosineSimilarity) + { + commandText = $@" + SELECT {columns}, 1 - ""{PostgresConstants.DistanceColumnName}"" AS ""{PostgresConstants.DistanceColumnName}"" + FROM ({commandText}) AS subquery"; + } + + // For inner product, we need to take -1 * inner product. + // However, we can't use an expression in the ORDER BY clause or else the index won't be used. + // Instead we'll wrap the query in a subquery and modify the distance in the outer query. + if (vectorProperty.DistanceFunction == DistanceFunction.DotProductSimilarity) + { + commandText = $@" + SELECT {columns}, -1 * ""{PostgresConstants.DistanceColumnName}"" AS ""{PostgresConstants.DistanceColumnName}"" + FROM ({commandText}) AS subquery"; + } + + return new PostgresSqlCommandInfo(commandText) + { + Parameters = [new NpgsqlParameter() { Value = vectorValue }, .. where.Parameters.Select(p => new NpgsqlParameter() { Value = p })] + }; + } + + internal static (string Clause, List Parameters) GenerateWhereClause(string schema, string tableName, IReadOnlyList properties, VectorSearchFilter? filter, int startParamIndex) + { + if (filter == null) { return (string.Empty, new List()); } + + var whereClause = new StringBuilder("WHERE "); + var filterClauses = new List(); + var parameters = new List(); + + var paramIndex = startParamIndex; + + foreach (var filterClause in filter.FilterClauses) + { + if (filterClause is EqualToFilterClause equalTo) + { + var property = properties.FirstOrDefault(p => p.DataModelPropertyName == equalTo.FieldName); + if (property == null) { throw new ArgumentException($"Property {equalTo.FieldName} not found in record definition."); } + + var columnName = property.StoragePropertyName ?? property.DataModelPropertyName; + filterClauses.Add($"\"{columnName}\" = ${paramIndex}"); + parameters.Add(equalTo.Value); + paramIndex++; + } + else if (filterClause is AnyTagEqualToFilterClause anyTagEqualTo) + { + var property = properties.FirstOrDefault(p => p.DataModelPropertyName == anyTagEqualTo.FieldName); + if (property == null) { throw new ArgumentException($"Property {anyTagEqualTo.FieldName} not found in record definition."); } + + if (property.PropertyType != typeof(List)) + { + throw new ArgumentException($"Property {anyTagEqualTo.FieldName} must be of type List to use AnyTagEqualTo filter."); + } + + var columnName = property.StoragePropertyName ?? property.DataModelPropertyName; + filterClauses.Add($"\"{columnName}\" @> ARRAY[${paramIndex}::TEXT]"); + parameters.Add(anyTagEqualTo.Value); + paramIndex++; + } + else + { + throw new NotSupportedException($"Filter clause type {filterClause.GetType().Name} is not supported."); + } + } + + whereClause.Append(string.Join(" AND ", filterClauses)); + return (whereClause.ToString(), parameters); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs new file mode 100644 index 000000000000..5ef18cc88fdf --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -0,0 +1,253 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Npgsql; +using Pgvector; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// An implementation of a client for Postgres. This class is used to managing postgres database operations. +/// +/// +/// Initializes a new instance of the class. +/// +/// Postgres data source. +/// Schema of collection tables. +[System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")] +internal class PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string schema = PostgresConstants.DefaultSchema) : IPostgresVectorStoreDbClient +{ + private readonly string _schema = schema; + + private IPostgresVectorStoreCollectionSqlBuilder _sqlBuilder = new PostgresVectorStoreCollectionSqlBuilder(); + + public NpgsqlDataSource DataSource { get; } = dataSource; + + /// + public async Task DoesTableExistsAsync(string tableName, CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildDoesTableExistCommand(this._schema, tableName); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + return dataReader.GetString(dataReader.GetOrdinal("table_name")) == tableName; + } + + return false; + } + } + + /// + public async IAsyncEnumerable GetTablesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildGetTablesCommand(this._schema); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return dataReader.GetString(dataReader.GetOrdinal("table_name")); + } + } + } + + /// + public async Task CreateTableAsync(string tableName, IReadOnlyList properties, bool ifNotExists = true, CancellationToken cancellationToken = default) + { + // Prepare the SQL commands. + var commandInfo = this._sqlBuilder.BuildCreateTableCommand(this._schema, tableName, properties, ifNotExists); + var createIndexCommands = + PostgresVectorStoreRecordPropertyMapping.GetVectorIndexInfo(properties) + .Select(index => + this._sqlBuilder.BuildCreateVectorIndexCommand(this._schema, tableName, index.column, index.kind, index.function) + ); + + // Execute the commands in a transaction. + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { +#if !NETSTANDARD2_0 + var transaction = await connection.BeginTransactionAsync(cancellationToken).ConfigureAwait(false); + await using (transaction) +#else + var transaction = connection.BeginTransaction(); + using (transaction) +#endif + { + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection, transaction); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + + foreach (var createIndexCommand in createIndexCommands) + { + using NpgsqlCommand indexCmd = createIndexCommand.ToNpgsqlCommand(connection, transaction); + await indexCmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + +#if !NETSTANDARD2_0 + await transaction.CommitAsync(cancellationToken).ConfigureAwait(false); +#else + transaction.Commit(); +#endif + } + } + } + + /// + public async Task DeleteTableAsync(string tableName, CancellationToken cancellationToken = default) + { + var commandInfo = this._sqlBuilder.BuildDropTableCommand(this._schema, tableName); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); + } + + /// + public async Task UpsertAsync(string tableName, Dictionary row, string keyColumn, CancellationToken cancellationToken = default) + { + var commandInfo = this._sqlBuilder.BuildUpsertCommand(this._schema, tableName, keyColumn, row); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); + } + + /// + public async Task UpsertBatchAsync(string tableName, IEnumerable> rows, string keyColumn, CancellationToken cancellationToken = default) + { + var commandInfo = this._sqlBuilder.BuildUpsertBatchCommand(this._schema, tableName, keyColumn, rows.ToList()); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); + } + + /// + public async Task?> GetAsync(string tableName, TKey key, IReadOnlyList properties, bool includeVectors = false, CancellationToken cancellationToken = default) where TKey : notnull + { + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildGetCommand(this._schema, tableName, properties, key, includeVectors); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + return this.GetRecord(dataReader, properties, includeVectors); + } + + return null; + } + } + + /// + public async IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, IReadOnlyList properties, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TKey : notnull + { + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildGetBatchCommand(this._schema, tableName, properties, keys.ToList(), includeVectors); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return this.GetRecord(dataReader, properties, includeVectors); + } + } + } + + /// + public async Task DeleteAsync(string tableName, string keyColumn, TKey key, CancellationToken cancellationToken = default) + { + var commandInfo = this._sqlBuilder.BuildDeleteCommand(this._schema, tableName, keyColumn, key); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); + } + + /// + public async IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync( + string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, + VectorSearchFilter? filter = default, int? skip = default, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildGetNearestMatchCommand(this._schema, tableName, properties, vectorProperty, vectorValue, filter, skip, includeVectors, limit); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + var distance = dataReader.GetDouble(dataReader.GetOrdinal(PostgresConstants.DistanceColumnName)); + yield return (Row: this.GetRecord(dataReader, properties, includeVectors), Distance: distance); + } + } + } + + /// + public async Task DeleteBatchAsync(string tableName, string keyColumn, IEnumerable keys, CancellationToken cancellationToken = default) + { + var commandInfo = this._sqlBuilder.BuildDeleteBatchCommand(this._schema, tableName, keyColumn, keys.ToList()); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); + } + + #region internal =============================================================================== + + /// + /// Sets the SQL builder for the client. + /// + /// + /// + /// This method is used for other Semnatic Kernel connectors that may need to override the default SQL + /// used by this client. + /// + internal void SetSqlBuilder(IPostgresVectorStoreCollectionSqlBuilder sqlBuilder) + { + this._sqlBuilder = sqlBuilder; + } + + #endregion + + #region private ================================================================================ + + private Dictionary GetRecord( + NpgsqlDataReader reader, + IEnumerable properties, + bool includeVectors = false + ) + { + var storageModel = new Dictionary(); + + foreach (var property in properties) + { + var isEmbedding = property is VectorStoreRecordVectorProperty; + var propertyName = property.StoragePropertyName ?? property.DataModelPropertyName; + var propertyType = property.PropertyType; + var propertyValue = !isEmbedding || includeVectors ? PostgresVectorStoreRecordPropertyMapping.GetPropertyValue(reader, propertyName, propertyType) : null; + + storageModel.Add(propertyName, propertyValue); + } + + return storageModel; + } + + private async Task ExecuteNonQueryAsync(PostgresSqlCommandInfo commandInfo, CancellationToken cancellationToken) + { + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs new file mode 100644 index 000000000000..013f1810e146 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Options when creating a . +/// +public sealed class PostgresVectorStoreOptions +{ + /// + /// Gets or sets the database schema. + /// + public string Schema { get; init; } = "public"; + + /// + /// An optional factory to use for constructing instances, if a custom record collection is required. + /// + public IPostgresVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs new file mode 100644 index 000000000000..95c8a4bcf282 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -0,0 +1,378 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Npgsql; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Represents a collection of vector store records in a Postgres database. +/// +/// The type of the key. +/// The type of the record. +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix +public sealed class PostgresVectorStoreRecordCollection : IVectorStoreRecordCollection +#pragma warning restore CA1711 // Identifiers should not have incorrect suffix + where TKey : notnull +{ + /// + public string CollectionName { get; } + + /// Postgres client that is used to interact with the database. + private readonly IPostgresVectorStoreDbClient _client; + + // Optional configuration options for this class. + private readonly PostgresVectorStoreRecordCollectionOptions _options; + + /// A helper to access property information for the current data model and record definition. + private readonly VectorStoreRecordPropertyReader _propertyReader; + + /// A mapper to use for converting between the data model and the Azure AI Search record. + private readonly IVectorStoreRecordMapper> _mapper; + + /// The default options for vector search. + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + + /// + /// Initializes a new instance of the class. + /// + /// The data source to use for connecting to the database. + /// The name of the collection. + /// Optional configuration options for this class. + public PostgresVectorStoreRecordCollection(NpgsqlDataSource dataSource, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default) + : this(new PostgresVectorStoreDbClient(dataSource), collectionName, options) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The client to use for interacting with the database. + /// The name of the collection. + /// Optional configuration options for this class. + /// + /// This constructor is internal. It allows internal code to create an instance of this class with a custom client. + /// + internal PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default) + { + // Verify. + Verify.NotNull(client); + Verify.NotNullOrWhiteSpace(collectionName); + VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.DictionaryCustomMapper is not null, PostgresConstants.SupportedKeyTypes); + VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + + // Assign. + this._client = client; + this.CollectionName = collectionName; + this._options = options ?? new PostgresVectorStoreRecordCollectionOptions(); + this._propertyReader = new VectorStoreRecordPropertyReader( + typeof(TRecord), + this._options.VectorStoreRecordDefinition, + new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + }); + + // Validate property types. + this._propertyReader.VerifyKeyProperties(PostgresConstants.SupportedKeyTypes); + this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, PostgresConstants.SupportedEnumerableDataElementTypes); + this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); + + // Resolve mapper. + // First, if someone has provided a custom mapper, use that. + // If they didn't provide a custom mapper, and the record type is the generic data model, use the built in mapper for that. + // Otherwise, use our own default mapper implementation for all other data models. + if (this._options.DictionaryCustomMapper is not null) + { + this._mapper = this._options.DictionaryCustomMapper; + } + else if (typeof(TRecord).IsGenericType && typeof(TRecord).GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>)) + { + this._mapper = (new PostgresGenericDataModelMapper(this._propertyReader) as IVectorStoreRecordMapper>)!; + } + else + { + this._mapper = new PostgresVectorStoreRecordMapper(this._propertyReader); + } + } + + /// + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + const string OperationName = "DoesTableExists"; + return this.RunOperationAsync(OperationName, () => + this._client.DoesTableExistsAsync(this.CollectionName, cancellationToken) + ); + } + + /// + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) + { + const string OperationName = "CreateCollection"; + return this.RunOperationAsync(OperationName, () => + this.InternalCreateCollectionAsync(false, cancellationToken) + ); + } + + /// + public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + { + const string OperationName = "CreateCollectionIfNotExists"; + return this.RunOperationAsync(OperationName, () => + this.InternalCreateCollectionAsync(true, cancellationToken) + ); + } + + /// + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + { + const string OperationName = "DeleteCollection"; + return this.RunOperationAsync(OperationName, () => + this._client.DeleteTableAsync(this.CollectionName, cancellationToken) + ); + } + + /// + public Task UpsertAsync(TRecord record, UpsertRecordOptions? options = null, CancellationToken cancellationToken = default) + { + const string OperationName = "Upsert"; + + var storageModel = VectorStoreErrorHandler.RunModelConversion( + PostgresConstants.DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromDataToStorageModel(record)); + + Verify.NotNull(storageModel); + + var keyObj = storageModel[this._propertyReader.KeyPropertyStoragePropertyName]; + Verify.NotNull(keyObj); + TKey key = (TKey)keyObj!; + + return this.RunOperationAsync(OperationName, async () => + { + await this._client.UpsertAsync(this.CollectionName, storageModel, this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken).ConfigureAwait(false); + return key; + } + ); + } + + /// + public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, UpsertRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + const string OperationName = "UpsertBatch"; + + var storageModels = records.Select(record => VectorStoreErrorHandler.RunModelConversion( + PostgresConstants.DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromDataToStorageModel(record))).ToList(); + + var keys = storageModels.Select(model => model[this._propertyReader.KeyPropertyStoragePropertyName]!).ToList(); + + await this.RunOperationAsync(OperationName, () => + this._client.UpsertBatchAsync(this.CollectionName, storageModels, this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken) + ).ConfigureAwait(false); + + foreach (var key in keys) { yield return (TKey)key!; } + } + + /// + public Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + const string OperationName = "Get"; + + Verify.NotNull(key); + + bool includeVectors = options?.IncludeVectors is true; + + return this.RunOperationAsync(OperationName, async () => + { + var row = await this._client.GetAsync(this.CollectionName, key, this._propertyReader.RecordDefinition.Properties, includeVectors, cancellationToken).ConfigureAwait(false); + + if (row is null) { return default; } + return VectorStoreErrorHandler.RunModelConversion( + PostgresConstants.DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })); + }); + } + + /// + public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + const string OperationName = "GetBatch"; + + Verify.NotNull(keys); + + bool includeVectors = options?.IncludeVectors is true; + + return PostgresVectorStoreUtils.WrapAsyncEnumerableAsync( + this._client.GetBatchAsync(this.CollectionName, keys, this._propertyReader.RecordDefinition.Properties, includeVectors, cancellationToken) + .SelectAsync(row => + VectorStoreErrorHandler.RunModelConversion( + PostgresConstants.DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })), + cancellationToken + ), + OperationName, + this.CollectionName + ); + } + + /// + public Task DeleteAsync(TKey key, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + { + const string OperationName = "Delete"; + return this.RunOperationAsync(OperationName, () => + this._client.DeleteAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, key, cancellationToken) + ); + } + + /// + public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + { + const string OperationName = "DeleteBatch"; + return this.RunOperationAsync(OperationName, () => + this._client.DeleteBatchAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, keys, cancellationToken) + ); + } + + /// + public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + { + const string OperationName = "VectorizedSearch"; + + Verify.NotNull(vector); + + var vectorType = vector.GetType(); + + if (!PostgresConstants.SupportedVectorTypes.Contains(vectorType)) + { + throw new NotSupportedException( + $"The provided vector type {vectorType.FullName} is not supported by the SQLite connector. " + + $"Supported types are: {string.Join(", ", PostgresConstants.SupportedVectorTypes.Select(l => l.FullName))}"); + } + + var searchOptions = options ?? s_defaultVectorSearchOptions; + var vectorProperty = this.GetVectorPropertyForSearch(searchOptions.VectorPropertyName); + + if (vectorProperty is null) + { + throw new InvalidOperationException("The collection does not have any vector properties, so vector search is not possible."); + } + + var pgVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + Verify.NotNull(pgVector); + + // Simulating skip/offset logic locally, since OFFSET can work only with LIMIT in combination + // and LIMIT is not supported in vector search extension, instead of LIMIT - "k" parameter is used. + var limit = searchOptions.Top + searchOptions.Skip; + + return this.RunOperationAsync(OperationName, () => + { + var results = this._client.GetNearestMatchesAsync( + this.CollectionName, + this._propertyReader.RecordDefinition.Properties, + vectorProperty, + pgVector, + searchOptions.Top, + searchOptions.Filter, + searchOptions.Skip, + searchOptions.IncludeVectors, + cancellationToken) + .SelectAsync(result => + { + var record = VectorStoreErrorHandler.RunModelConversion( + PostgresConstants.DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromStorageToDataModel( + result.Row, new StorageToDataModelMapperOptions() { IncludeVectors = searchOptions.IncludeVectors }) + ); + + return new VectorSearchResult(record, result.Distance); + }, cancellationToken); + + return Task.FromResult(new VectorSearchResults(results)); + }); + } + + private Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken = default) + { + return this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition.Properties, ifNotExists, cancellationToken); + } + + /// + /// Get vector property to use for a search by using the storage name for the field name from options + /// if available, and falling back to the first vector property in if not. + /// + /// The vector field name. + /// Thrown if the provided field name is not a valid field name. + private VectorStoreRecordVectorProperty? GetVectorPropertyForSearch(string? vectorFieldName) + { + // If vector property name is provided in options, try to find it in schema or throw an exception. + if (!string.IsNullOrWhiteSpace(vectorFieldName)) + { + // Check vector properties by data model property name. + var vectorProperty = this._propertyReader.VectorProperties + .FirstOrDefault(l => l.DataModelPropertyName.Equals(vectorFieldName, StringComparison.Ordinal)); + + if (vectorProperty is not null) + { + return vectorProperty; + } + + throw new InvalidOperationException($"The {typeof(TRecord).FullName} type does not have a vector property named '{vectorFieldName}'."); + } + + // If vector property is not provided in options, return first vector property from schema. + return this._propertyReader.VectorProperty; + } + + private async Task RunOperationAsync(string operationName, Func operation) + { + try + { + await operation.Invoke().ConfigureAwait(false); + } + catch (Exception ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = PostgresConstants.DatabaseName, + CollectionName = this.CollectionName, + OperationName = operationName + }; + } + } + + private async Task RunOperationAsync(string operationName, Func> operation) + { + try + { + return await operation.Invoke().ConfigureAwait(false); + } + catch (Exception ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = PostgresConstants.DatabaseName, + CollectionName = this.CollectionName, + OperationName = operationName + }; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs new file mode 100644 index 000000000000..753713d21b3f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Options when creating a . +/// +public sealed class PostgresVectorStoreRecordCollectionOptions +{ + /// + /// Gets or sets the database schema. + /// + public string Schema { get; init; } = "public"; + + /// + /// Gets or sets an optional custom mapper to use when converting between the data model and the Postgres record. + /// + /// + /// If not set, the default mapper will be used. + /// + public IVectorStoreRecordMapper>? DictionaryCustomMapper { get; init; } = null; + + /// + /// Gets or sets an optional record definition that defines the schema of the record type. + /// + /// + /// If not provided, the schema will be inferred from the record model class using reflection. + /// In this case, the record model properties must be annotated with the appropriate attributes to indicate their usage. + /// See , and . + /// + public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs new file mode 100644 index 000000000000..e656678413cc --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// A mapper class that handles the conversion between data models and storage models for Postgres vector store. +/// +/// The type of the data model record. +internal sealed class PostgresVectorStoreRecordMapper : IVectorStoreRecordMapper> +{ + /// with helpers for reading vector store model properties and their attributes. + private readonly VectorStoreRecordPropertyReader _propertyReader; + + /// + /// Initializes a new instance of the class. + /// + /// A that defines the schema of the data in the database. + public PostgresVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyReader) + { + Verify.NotNull(propertyReader); + + this._propertyReader = propertyReader; + + this._propertyReader.VerifyHasParameterlessConstructor(); + + // Validate property types. + this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, PostgresConstants.SupportedEnumerableDataElementTypes); + this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); + } + + public Dictionary MapFromDataToStorageModel(TRecord dataModel) + { + var properties = new Dictionary + { + // Add key property + { this._propertyReader.KeyPropertyStoragePropertyName, this._propertyReader.KeyPropertyInfo.GetValue(dataModel) } + }; + + // Add data properties + foreach (var property in this._propertyReader.DataPropertiesInfo) + { + properties.Add( + this._propertyReader.GetStoragePropertyName(property.Name), + property.GetValue(dataModel) + ); + } + + // Add vector properties + foreach (var property in this._propertyReader.VectorPropertiesInfo) + { + var propertyValue = property.GetValue(dataModel); + var result = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(propertyValue); + + properties.Add(this._propertyReader.GetStoragePropertyName(property.Name), result); + } + + return properties; + } + + public TRecord MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) + { + var record = (TRecord)this._propertyReader.ParameterLessConstructorInfo.Invoke(null); + + // Set key. + var keyPropertyValue = Convert.ChangeType( + storageModel[this._propertyReader.KeyPropertyStoragePropertyName], + this._propertyReader.KeyProperty.PropertyType); + + this._propertyReader.KeyPropertyInfo.SetValue(record, keyPropertyValue); + + // Process data properties. + var dataPropertiesInfoWithValues = VectorStoreRecordMapping.BuildPropertiesInfoWithValues( + this._propertyReader.DataPropertiesInfo, + this._propertyReader.StoragePropertyNamesMap, + storageModel); + + VectorStoreRecordMapping.SetPropertiesOnRecord(record, dataPropertiesInfoWithValues); + + if (options.IncludeVectors) + { + // Process vector properties. + var vectorPropertiesInfoWithValues = VectorStoreRecordMapping.BuildPropertiesInfoWithValues( + this._propertyReader.VectorPropertiesInfo, + this._propertyReader.StoragePropertyNamesMap, + storageModel, + (object? vector, Type type) => + { + return PostgresVectorStoreRecordPropertyMapping.MapVectorForDataModel(vector); + }); + + VectorStoreRecordMapping.SetPropertiesOnRecord(record, vectorPropertiesInfoWithValues); + } + + return record; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs new file mode 100644 index 000000000000..0b36f2003bf5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -0,0 +1,269 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using Microsoft.Extensions.VectorData; +using Npgsql; +using NpgsqlTypes; +using Pgvector; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal static class PostgresVectorStoreRecordPropertyMapping +{ + internal static float[] GetOrCreateArray(ReadOnlyMemory memory) => + MemoryMarshal.TryGetArray(memory, out ArraySegment array) && + array.Count == array.Array!.Length ? + array.Array : + memory.ToArray(); + + public static Vector? MapVectorForStorageModel(TVector vector) + { + if (vector == null) + { + return null; + } + + if (vector is ReadOnlyMemory floatMemory) + { + var vecArray = MemoryMarshal.TryGetArray(floatMemory, out ArraySegment array) && + array.Count == array.Array!.Length ? + array.Array : + floatMemory.ToArray(); + return new Vector(vecArray); + } + + throw new NotSupportedException($"Mapping for type {typeof(TVector).FullName} to a vector is not supported."); + } + + public static ReadOnlyMemory? MapVectorForDataModel(object? vector) + { + var pgVector = vector is Vector pgv ? pgv : null; + if (pgVector == null) { return null; } + var vecArray = pgVector.ToArray(); + return vecArray != null && vecArray.Length != 0 ? (ReadOnlyMemory)vecArray : null; + } + + public static TPropertyType? GetPropertyValue(NpgsqlDataReader reader, string propertyName) + { + int propertyIndex = reader.GetOrdinal(propertyName); + + if (reader.IsDBNull(propertyIndex)) + { + return default; + } + + return reader.GetFieldValue(propertyIndex); + } + + public static object? GetPropertyValue(NpgsqlDataReader reader, string propertyName, Type propertyType) + { + int propertyIndex = reader.GetOrdinal(propertyName); + + if (reader.IsDBNull(propertyIndex)) + { + return null; + } + + // Check if the type implements IEnumerable + if (propertyType.IsGenericType && propertyType.GetInterfaces().Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>))) + { + var enumerable = (IEnumerable)reader.GetValue(propertyIndex); + return VectorStoreRecordMapping.CreateEnumerable(enumerable.Cast(), propertyType); + } + + return propertyType switch + { + Type t when t == typeof(bool) || t == typeof(bool?) => reader.GetBoolean(propertyIndex), + Type t when t == typeof(short) || t == typeof(short?) => reader.GetInt16(propertyIndex), + Type t when t == typeof(int) || t == typeof(int?) => reader.GetInt32(propertyIndex), + Type t when t == typeof(long) || t == typeof(long?) => reader.GetInt64(propertyIndex), + Type t when t == typeof(float) || t == typeof(float?) => reader.GetFloat(propertyIndex), + Type t when t == typeof(double) || t == typeof(double?) => reader.GetDouble(propertyIndex), + Type t when t == typeof(decimal) || t == typeof(decimal?) => reader.GetDecimal(propertyIndex), + Type t when t == typeof(string) => reader.GetString(propertyIndex), + Type t when t == typeof(byte[]) => reader.GetFieldValue(propertyIndex), + Type t when t == typeof(DateTime) || t == typeof(DateTime?) => reader.GetDateTime(propertyIndex), + Type t when t == typeof(DateTimeOffset) || t == typeof(DateTimeOffset?) => reader.GetFieldValue(propertyIndex), + Type t when t == typeof(Guid) => reader.GetFieldValue(propertyIndex), + _ => reader.GetValue(propertyIndex) + }; + } + + public static NpgsqlDbType? GetNpgsqlDbType(Type propertyType) => + propertyType switch + { + Type t when t == typeof(bool) || t == typeof(bool?) => NpgsqlDbType.Boolean, + Type t when t == typeof(short) || t == typeof(short?) => NpgsqlDbType.Smallint, + Type t when t == typeof(int) || t == typeof(int?) => NpgsqlDbType.Integer, + Type t when t == typeof(long) || t == typeof(long?) => NpgsqlDbType.Bigint, + Type t when t == typeof(float) || t == typeof(float?) => NpgsqlDbType.Real, + Type t when t == typeof(double) || t == typeof(double?) => NpgsqlDbType.Double, + Type t when t == typeof(decimal) || t == typeof(decimal?) => NpgsqlDbType.Numeric, + Type t when t == typeof(string) => NpgsqlDbType.Text, + Type t when t == typeof(byte[]) => NpgsqlDbType.Bytea, + Type t when t == typeof(DateTime) || t == typeof(DateTime?) => NpgsqlDbType.Timestamp, + Type t when t == typeof(DateTimeOffset) || t == typeof(DateTimeOffset?) => NpgsqlDbType.TimestampTz, + Type t when t == typeof(Guid) => NpgsqlDbType.Uuid, + _ => null + }; + + /// + /// Maps a .NET type to a PostgreSQL type name. + /// + /// The .NET type. + /// Tuple of the the PostgreSQL type name and whether it can be NULL + public static (string PgType, bool IsNullable) GetPostgresTypeName(Type propertyType) + { + var (pgType, isNullable) = propertyType switch + { + Type t when t == typeof(bool) => ("BOOLEAN", false), + Type t when t == typeof(short) => ("SMALLINT", false), + Type t when t == typeof(int) => ("INTEGER", false), + Type t when t == typeof(long) => ("BIGINT", false), + Type t when t == typeof(float) => ("REAL", false), + Type t when t == typeof(double) => ("DOUBLE PRECISION", false), + Type t when t == typeof(decimal) => ("NUMERIC", false), + Type t when t == typeof(string) => ("TEXT", true), + Type t when t == typeof(byte[]) => ("BYTEA", true), + Type t when t == typeof(DateTime) => ("TIMESTAMP", false), + Type t when t == typeof(DateTimeOffset) => ("TIMESTAMPTZ", false), + Type t when t == typeof(Guid) => ("UUID", false), + _ => (null, false) + }; + + if (pgType != null) + { + return (pgType, isNullable); + } + + // Handle enumerables + if (VectorStoreRecordPropertyVerification.IsSupportedEnumerableType(propertyType)) + { + Type elementType = propertyType.GetGenericArguments()[0]; + var underlyingPgType = GetPostgresTypeName(elementType); + return (underlyingPgType.PgType + "[]", true); + } + + // Handle nullable types (e.g. Nullable) + if (Nullable.GetUnderlyingType(propertyType) != null) + { + Type underlyingType = Nullable.GetUnderlyingType(propertyType) ?? throw new ArgumentException("Nullable type must have an underlying type."); + var underlyingPgType = GetPostgresTypeName(underlyingType); + return (underlyingPgType.PgType, true); + } + + throw new NotSupportedException($"Type {propertyType.Name} is not supported by this store."); + } + + /// + /// Gets the PostgreSQL vector type name based on the dimensions of the vector property. + /// + /// The vector property. + /// The PostgreSQL vector type name. + public static (string PgType, bool IsNullable) GetPgVectorTypeName(VectorStoreRecordVectorProperty vectorProperty) + { + if (vectorProperty.Dimensions <= 0) + { + throw new ArgumentException("Vector property must have a positive number of dimensions."); + } + + return ($"VECTOR({vectorProperty.Dimensions})", Nullable.GetUnderlyingType(vectorProperty.PropertyType) != null); + } + + public static NpgsqlParameter GetNpgsqlParameter(object? value) + { + if (value == null) + { + return new NpgsqlParameter() { Value = DBNull.Value }; + } + + // If it's an IEnumerable, use reflection to determine if it needs to be converted to a list + if (value is IEnumerable enumerable && !(value is string)) + { + Type propertyType = value.GetType(); + if (propertyType.IsGenericType && propertyType.GetGenericTypeDefinition() == typeof(List<>)) + { + // If it's already a List, return it directly + return new NpgsqlParameter() { Value = value }; + } + + return new NpgsqlParameter() { Value = ConvertToListIfNecessary(enumerable) }; + } + + // Return the value directly if it's not IEnumerable + return new NpgsqlParameter() { Value = value }; + } + + /// + /// Returns information about vector indexes to create, validating that the dimensions of the vector are supported. + /// + /// The properties of the vector store record. + /// A list of tuples containing the column name, index kind, and distance function for each vector property. + /// + /// The default index kind is "Flat", which prevents the creation of an index. + /// + public static List<(string column, string kind, string function)> GetVectorIndexInfo(IReadOnlyList properties) + { + var vectorIndexesToCreate = new List<(string column, string kind, string function)>(); + foreach (var property in properties) + { + if (property is VectorStoreRecordVectorProperty vectorProperty) + { + var vectorColumnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + var indexKind = vectorProperty.IndexKind ?? PostgresConstants.DefaultIndexKind; + var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; + + // Index kind of "Flat" to prevent the creation of an index. This is the default behavior. + // Otherwise, the index will be created with the specified index kind and distance function, if supported. + if (indexKind != IndexKind.Flat) + { + // Ensure the dimensionality of the vector is supported for indexing. + if (PostgresConstants.IndexMaxDimensions.TryGetValue(indexKind, out int maxDimensions) && vectorProperty.Dimensions > maxDimensions) + { + throw new NotSupportedException( + $"The provided vector property {vectorProperty.DataModelPropertyName} has {vectorProperty.Dimensions} dimensions, " + + $"which is not supported by the {indexKind} index. The maximum number of dimensions supported by the {indexKind} index " + + $"is {maxDimensions}. Please reduce the number of dimensions or use a different index." + ); + } + + vectorIndexesToCreate.Add((vectorColumnName, indexKind, distanceFunction)); + } + } + } + return vectorIndexesToCreate; + } + + // Helper method to convert an IEnumerable to a List if necessary + private static object ConvertToListIfNecessary(IEnumerable enumerable) + { + // Get an enumerator to manually iterate over the collection + var enumerator = enumerable.GetEnumerator(); + + // Check if the collection is empty by attempting to move to the first element + if (!enumerator.MoveNext()) + { + return enumerable; // Return the original enumerable if it's empty + } + + // Determine the type of the first element + var firstItem = enumerator.Current; + var itemType = firstItem?.GetType() ?? typeof(object); + + // Create a strongly-typed List based on the type of the first element + var typedList = Activator.CreateInstance(typeof(List<>).MakeGenericType(itemType)) as IList; + typedList!.Add(firstItem); // Add the first element to the typed list + + // Continue iterating through the rest of the enumerable and add items to the list + while (enumerator.MoveNext()) + { + typedList.Add(enumerator.Current); + } + + return typedList; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs new file mode 100644 index 000000000000..27fa7181bdc5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal static class PostgresVectorStoreUtils +{ + /// + /// Wraps an in an that will throw a + /// if an exception is thrown while iterating over the original enumerator. + /// + /// The type of the items in the async enumerable. + /// The async enumerable to wrap. + /// The name of the operation being performed. + /// The name of the collection being operated on. + /// An async enumerable that will throw a if an exception is thrown while iterating over the original enumerator. + public static async IAsyncEnumerable WrapAsyncEnumerableAsync(IAsyncEnumerable asyncEnumerable, string operationName, string? collectionName = null) + { + var enumerator = asyncEnumerable.ConfigureAwait(false).GetAsyncEnumerator(); + + var nextResult = await GetNextAsync(enumerator, operationName, collectionName).ConfigureAwait(false); + while (nextResult.more) + { + yield return nextResult.item; + nextResult = await GetNextAsync(enumerator, operationName, collectionName).ConfigureAwait(false); + } + } + + /// + /// Helper method to get the next index name from the enumerator with a try catch around the move next call to convert + /// exceptions to . + /// + /// The enumerator to get the next result from. + /// The name of the operation being performed. + /// The name of the collection being operated on. + /// A value indicating whether there are more results and the current string if true. + public static async Task<(T item, bool more)> GetNextAsync(ConfiguredCancelableAsyncEnumerable.Enumerator enumerator, string operationName, string? collectionName = null) + { + try + { + var more = await enumerator.MoveNextAsync(); + return (enumerator.Current, more); + } + catch (Exception ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = PostgresConstants.DatabaseName, + CollectionName = collectionName, + OperationName = operationName + }; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md index 35c80a45087a..e9ed71109fbb 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md @@ -18,7 +18,7 @@ This extension is also available for **Azure Database for PostgreSQL - Flexible 1. To install pgvector using Docker: ```bash -docker run -d --name postgres-pgvector -p 5432:5432 -e POSTGRES_PASSWORD=mysecretpassword ankane/pgvector +docker run -d --name postgres-pgvector -p 5432:5432 -e POSTGRES_PASSWORD=mysecretpassword pgvector/pgvector ``` 2. Create a database and enable pgvector extension on this database @@ -33,8 +33,13 @@ sk_demo=# CREATE EXTENSION vector; > Note, "Azure Cosmos DB for PostgreSQL" uses `SELECT CREATE_EXTENSION('vector');` to enable the extension. -3. To use Postgres as a semantic memory store: - > See [Example 14](../../../samples/Concepts/Memory/SemanticTextMemory_Building.cs) and [Example 15](../../../samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) for more memory usage examples with the kernel. +### Using PostgresVectorStore + +See [this sample](../../../samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs) for an example of using the vector store. + +### Using PostgresMemoryStore + +> See [Example 14](../../../samples/Concepts/Memory/SemanticTextMemory_Building.cs) and [Example 15](../../../samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) for more memory usage examples with the kernel. ```csharp NpgsqlDataSourceBuilder dataSourceBuilder = new NpgsqlDataSourceBuilder("Host=localhost;Port=5432;Database=sk_demo;User Id=postgres;Password=mysecretpassword"); @@ -87,67 +92,3 @@ BEGIN END IF; END $$; ``` - -## Migration from older versions - -Since Postgres Memory connector has been re-implemented, the new implementation uses a separate table to store each Collection. - -We provide the following migration script to help you migrate to the new structure. However, please note that due to the use of collections as table names, you need to make sure that all Collections conform to the [Postgres naming convention](https://www.postgresql.org/docs/15/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS) before migrating. - -- Table names may only consist of ASCII letters, digits, and underscores. -- Table names must start with a letter or an underscore. -- Table names may not exceed 63 characters in length. -- Table names are case-insensitive, but it is recommended to use lowercase letters. - -```sql --- Create new tables, each with the name of the collection field value -DO $$ -DECLARE - r record; - c_count integer; -BEGIN - FOR r IN SELECT DISTINCT collection FROM sk_memory_table LOOP - - -- Drop Table (This will delete the table that already exists. Please consider carefully if you think you need to cancel this comment!) - -- EXECUTE format('DROP TABLE IF EXISTS %I;', r.collection); - - -- Create Table (Modify vector size on demand) - EXECUTE format('CREATE TABLE public.%I ( - key TEXT NOT NULL, - metadata JSONB, - embedding vector(1536), - timestamp TIMESTAMP WITH TIME ZONE, - PRIMARY KEY (key) - );', r.collection); - - -- Get count of records in collection - SELECT count(*) INTO c_count FROM sk_memory_table WHERE collection = r.collection AND key <> ''; - - -- Create Index (https://github.com/pgvector/pgvector#indexing) - IF c_count > 10000000 THEN - EXECUTE format('CREATE INDEX %I - ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s);', - r.collection || '_ix', r.collection, ROUND(sqrt(c_count))); - ELSIF c_count > 10000 THEN - EXECUTE format('CREATE INDEX %I - ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s);', - r.collection || '_ix', r.collection, c_count / 1000); - END IF; - END LOOP; -END $$; - --- Copy data from the old table to the new table -DO $$ -DECLARE - r record; -BEGIN - FOR r IN SELECT DISTINCT collection FROM sk_memory_table LOOP - EXECUTE format('INSERT INTO public.%I (key, metadata, embedding, timestamp) - SELECT key, metadata::JSONB, embedding, to_timestamp(timestamp / 1000.0) AT TIME ZONE ''UTC'' - FROM sk_memory_table WHERE collection = %L AND key <> '''';', r.collection, r.collection); - END LOOP; -END $$; - --- Drop old table (After ensuring successful execution, you can remove the following comments to remove sk_memory_table.) --- DROP TABLE IF EXISTS sk_memory_table; -``` diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj new file mode 100644 index 000000000000..5698a909022e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj @@ -0,0 +1,32 @@ + + + + SemanticKernel.Connectors.Postgres.UnitTests + SemanticKernel.Connectors.Postgres.UnitTests + net8.0 + true + enable + disable + false + $(NoWarn);SKEXP0001,SKEXP0020,VSTHRD111,CA2007,CS1591 + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs new file mode 100644 index 000000000000..d9e97fc6b855 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Pgvector; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class PostgresGenericDataModelMapperTests +{ + [Fact] + public void MapFromDataToStorageModelWithStringKeyReturnsValidStorageModel() + { + // Arrange + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + var dataModel = GetGenericDataModel("key"); + + var mapper = new PostgresGenericDataModelMapper(propertyReader); + + // Act + var result = mapper.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal("key", result["Key"]); + Assert.Equal("Value1", result["StringProperty"]); + Assert.Equal(5, result["IntProperty"]); + + var vector = result["FloatVector"] as Vector; + + Assert.NotNull(vector); + Assert.True(vector.ToArray().Length > 0); + } + + [Fact] + public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() + { + // Arrange + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + var dataModel = GetGenericDataModel(1); + + var mapper = new PostgresGenericDataModelMapper(propertyReader); + + // Act + var result = mapper.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal(1, result["Key"]); + Assert.Equal("Value1", result["StringProperty"]); + Assert.Equal(5, result["IntProperty"]); + + var vector = result["FloatVector"] as Vector; + + Assert.NotNull(vector); + Assert.True(vector.ToArray().Length > 0); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool includeVectors) + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + var storageModel = new Dictionary + { + ["Key"] = "key", + ["StringProperty"] = "Value1", + ["IntProperty"] = 5, + ["FloatVector"] = storageVector + }; + + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + + var mapper = new PostgresGenericDataModelMapper(propertyReader); + + // Act + var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); + + // Assert + Assert.Equal("key", result.Key); + Assert.Equal("Value1", result.Data["StringProperty"]); + Assert.Equal(5, result.Data["IntProperty"]); + + if (includeVectors) + { + Assert.NotNull(result.Vectors["FloatVector"]); + Assert.Equal(vector.ToArray(), ((ReadOnlyMemory)result.Vectors["FloatVector"]!).ToArray()); + } + else + { + Assert.False(result.Vectors.ContainsKey("FloatVector")); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool includeVectors) + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + var storageModel = new Dictionary + { + ["Key"] = 1, + ["StringProperty"] = "Value1", + ["IntProperty"] = 5, + ["FloatVector"] = storageVector + }; + + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + + var mapper = new PostgresGenericDataModelMapper(propertyReader); + + // Act + var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); + + // Assert + Assert.Equal(1, result.Key); + Assert.Equal("Value1", result.Data["StringProperty"]); + Assert.Equal(5, result.Data["IntProperty"]); + + if (includeVectors) + { + Assert.NotNull(result.Vectors["FloatVector"]); + Assert.Equal(vector.ToArray(), ((ReadOnlyMemory)result.Vectors["FloatVector"]!).ToArray()); + } + else + { + Assert.False(result.Vectors.ContainsKey("FloatVector")); + } + } + + #region private + + private static VectorStoreRecordDefinition GetRecordDefinition() + { + return new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(TKey)), + new VectorStoreRecordDataProperty("StringProperty", typeof(string)), + new VectorStoreRecordDataProperty("IntProperty", typeof(int)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), + } + }; + } + + private static VectorStoreGenericDataModel GetGenericDataModel(TKey key) + { + return new VectorStoreGenericDataModel(key) + { + Data = new() + { + ["StringProperty"] = "Value1", + ["IntProperty"] = 5 + }, + Vectors = new() + { + ["FloatVector"] = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]) + } + }; + } + + private static VectorStoreRecordPropertyReader GetPropertyReader(VectorStoreRecordDefinition definition) + { + return new VectorStoreRecordPropertyReader(typeof(TRecord), definition, new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true + }); + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs new file mode 100644 index 000000000000..e8e84badf292 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + +/// +/// A test model for the postgres vector store. +/// +public record PostgresHotel() +{ + /// The key of the record. + [VectorStoreRecordKey] + public T HotelId { get; init; } + + /// A string metadata field. + [VectorStoreRecordData()] + public string? HotelName { get; set; } + + /// An int metadata field. + [VectorStoreRecordData()] + public int HotelCode { get; set; } + + /// A float metadata field. + [VectorStoreRecordData()] + public float? HotelRating { get; set; } + + /// A bool metadata field. + [VectorStoreRecordData(StoragePropertyName = "parking_is_included")] + public bool ParkingIncluded { get; set; } + + [VectorStoreRecordData] + public List Tags { get; set; } = []; + + /// A data field. + [VectorStoreRecordData] + public string Description { get; set; } + + public DateTime CreatedAt { get; set; } = DateTime.UtcNow; + + public DateTimeOffset UpdatedAt { get; set; } = DateTimeOffset.UtcNow; + + /// A vector field. + [VectorStoreRecordVector(4, IndexKind.Hnsw, DistanceFunction.ManhattanDistance)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } +} +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs new file mode 100644 index 000000000000..f667d86eee30 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Npgsql; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class PostgresServiceCollectionExtensionsTests +{ + private readonly IServiceCollection _serviceCollection = new ServiceCollection(); + + [Fact] + public void AddVectorStoreRegistersClass() + { + // Arrange + using var dataSource = NpgsqlDataSource.Create("Host=fake;"); + this._serviceCollection.AddSingleton(dataSource); + + // Act + this._serviceCollection.AddPostgresVectorStore(); + + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + var vectorStore = serviceProvider.GetRequiredService(); + + // Assert + Assert.NotNull(vectorStore); + Assert.IsType(vectorStore); + } + + [Fact] + public void AddVectorStoreRecordCollectionRegistersClass() + { + // Arrange + using var dataSource = NpgsqlDataSource.Create("Host=fake;"); + this._serviceCollection.AddSingleton(dataSource); + + // Act + this._serviceCollection.AddPostgresVectorStoreRecordCollection("testcollection"); + + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + + // Assert + var collection = serviceProvider.GetRequiredService>(); + Assert.NotNull(collection); + Assert.IsType>(collection); + + var vectorizedSearch = serviceProvider.GetRequiredService>(); + Assert.NotNull(vectorizedSearch); + Assert.IsType>(vectorizedSearch); + } + + #region private + +#pragma warning disable CA1812 // Avoid uninstantiated internal classes + private sealed class TestRecord +#pragma warning restore CA1812 // Avoid uninstantiated internal classes + { + [VectorStoreRecordKey] + public string Id { get; set; } = string.Empty; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs new file mode 100644 index 000000000000..675843a78c18 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -0,0 +1,422 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Pgvector; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +public class PostgresVectorStoreCollectionSqlBuilderTests +{ + private readonly ITestOutputHelper _output; + private static readonly float[] s_vector = new float[] { 1.0f, 2.0f, 3.0f }; + + public PostgresVectorStoreCollectionSqlBuilderTests(ITestOutputHelper output) + { + this._output = output; + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void TestBuildCreateTableCommand(bool ifNotExists) + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)) { StoragePropertyName = "free_parking" }, + new VectorStoreRecordDataProperty("tags", typeof(List)), + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = "hnsw", + }, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var cmdInfo = builder.BuildCreateTableCommand("public", "testcollection", recordDefinition.Properties, ifNotExists: ifNotExists); + + // Check for expected properties; integration tests will validate the actual SQL. + Assert.Contains("public.\"testcollection\" (", cmdInfo.CommandText); + Assert.Contains("\"name\" TEXT", cmdInfo.CommandText); + Assert.Contains("\"code\" INTEGER NOT NULL", cmdInfo.CommandText); + Assert.Contains("\"rating\" REAL", cmdInfo.CommandText); + Assert.Contains("\"description\" TEXT", cmdInfo.CommandText); + Assert.Contains("\"free_parking\" BOOLEAN NOT NULL", cmdInfo.CommandText); + Assert.Contains("\"tags\" TEXT[]", cmdInfo.CommandText); + Assert.Contains("\"description\" TEXT", cmdInfo.CommandText); + Assert.Contains("\"embedding1\" VECTOR(10) NOT NULL", cmdInfo.CommandText); + Assert.Contains("\"embedding2\" VECTOR(10)", cmdInfo.CommandText); + Assert.Contains("PRIMARY KEY (\"id\")", cmdInfo.CommandText); + + if (ifNotExists) + { + Assert.Contains("IF NOT EXISTS", cmdInfo.CommandText); + } + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Theory] + [InlineData(IndexKind.Hnsw, DistanceFunction.EuclideanDistance)] + [InlineData(IndexKind.IvfFlat, DistanceFunction.DotProductSimilarity)] + [InlineData(IndexKind.Hnsw, DistanceFunction.CosineDistance)] + public void TestBuildCreateIndexCommand(string indexKind, string distanceFunction) + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var vectorColumn = "embedding1"; + + if (indexKind != IndexKind.Hnsw) + { + Assert.Throws(() => builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction)); + return; + } + + var cmdInfo = builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction); + + // Check for expected properties; integration tests will validate the actual SQL. + Assert.Contains("CREATE INDEX ", cmdInfo.CommandText); + Assert.Contains("ON public.\"testcollection\" USING hnsw (\"embedding1\" ", cmdInfo.CommandText); + if (distanceFunction == null) + { + // Check for distance function defaults to cosine distance + Assert.Contains("vector_cosine_ops)", cmdInfo.CommandText); + } + else if (distanceFunction == DistanceFunction.CosineDistance) + { + Assert.Contains("vector_cosine_ops)", cmdInfo.CommandText); + } + else if (distanceFunction == DistanceFunction.EuclideanDistance) + { + Assert.Contains("vector_l2_ops)", cmdInfo.CommandText); + } + else + { + throw new NotImplementedException($"Test case for Distance function {distanceFunction} is not implemented."); + } + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildDropTableCommand() + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var cmdInfo = builder.BuildDropTableCommand("public", "testcollection"); + + // Check for expected properties; integration tests will validate the actual SQL. + Assert.Contains("DROP TABLE IF EXISTS public.\"testcollection\"", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildUpsertCommand() + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var row = new Dictionary() + { + ["id"] = 123, + ["name"] = "Hotel", + ["code"] = 456, + ["rating"] = 4.5f, + ["description"] = "Hotel description", + ["parking_is_included"] = true, + ["tags"] = new List { "tag1", "tag2" }, + ["embedding1"] = new Vector(s_vector), + }; + + var keyColumn = "id"; + + var cmdInfo = builder.BuildUpsertCommand("public", "testcollection", keyColumn, row); + + // Check for expected properties; integration tests will validate the actual SQL. + Assert.Contains("INSERT INTO public.\"testcollection\" (", cmdInfo.CommandText); + Assert.Contains("ON CONFLICT (\"id\")", cmdInfo.CommandText); + Assert.Contains("DO UPDATE SET", cmdInfo.CommandText); + Assert.NotNull(cmdInfo.Parameters); + + foreach (var (key, index) in row.Keys.Select((key, index) => (key, index))) + { + Assert.Equal(row[key], cmdInfo.Parameters[index].Value); + // If the key is not the key column, it should be included in the update clause. + if (key != keyColumn) + { + Assert.Contains($"\"{key}\"=${index + 1}", cmdInfo.CommandText); + } + } + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildUpsertBatchCommand() + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var rows = new List>() + { + new() + { + ["id"] = 123, + ["name"] = "Hotel", + ["code"] = 456, + ["rating"] = 4.5f, + ["description"] = "Hotel description", + ["parking_is_included"] = true, + ["tags"] = new List { "tag1", "tag2" }, + ["embedding1"] = new Vector(s_vector), + }, + new() + { + ["id"] = 124, + ["name"] = "Motel", + ["code"] = 457, + ["rating"] = 4.6f, + ["description"] = "Motel description", + ["parking_is_included"] = false, + ["tags"] = new List { "tag3", "tag4" }, + ["embedding1"] = new Vector(s_vector), + }, + }; + + var keyColumn = "id"; + var columnCount = rows.First().Count; + + var cmdInfo = builder.BuildUpsertBatchCommand("public", "testcollection", keyColumn, rows); + + // Check for expected properties; integration tests will validate the actual SQL. + Assert.Contains("INSERT INTO public.\"testcollection\" (", cmdInfo.CommandText); + Assert.Contains("ON CONFLICT (\"id\")", cmdInfo.CommandText); + Assert.Contains("DO UPDATE SET", cmdInfo.CommandText); + Assert.NotNull(cmdInfo.Parameters); + + foreach (var (row, rowIndex) in rows.Select((row, rowIndex) => (row, rowIndex))) + { + foreach (var (column, columnIndex) in row.Keys.Select((key, index) => (key, index))) + { + Assert.Equal(row[column], cmdInfo.Parameters[columnIndex + (rowIndex * columnCount)].Value); + // If the key is not the key column, it should be included in the update clause. + if (column != keyColumn) + { + Assert.Contains($"\"{column}\" = EXCLUDED.\"{column}\"", cmdInfo.CommandText); + } + } + } + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildGetCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)) { StoragePropertyName = "free_parking" }, + new VectorStoreRecordDataProperty("tags", typeof(List)), + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = "hnsw", + }, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var key = 123; + + // Act + var cmdInfo = builder.BuildGetCommand("public", "testcollection", recordDefinition.Properties, key, includeVectors: true); + + // Assert + Assert.Contains("SELECT", cmdInfo.CommandText); + Assert.Contains("\"free_parking\"", cmdInfo.CommandText); + Assert.Contains("\"embedding1\"", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("WHERE \"id\" = $1", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildGetBatchCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)) { StoragePropertyName = "free_parking" }, + new VectorStoreRecordDataProperty("tags", typeof(List)), + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = "hnsw", + }, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var keys = new List { 123, 124 }; + + // Act + var cmdInfo = builder.BuildGetBatchCommand("public", "testcollection", recordDefinition.Properties, keys, includeVectors: true); + + // Assert + Assert.Contains("SELECT", cmdInfo.CommandText); + Assert.Contains("\"code\"", cmdInfo.CommandText); + Assert.Contains("\"free_parking\"", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("WHERE \"id\" = ANY($1)", cmdInfo.CommandText); + Assert.NotNull(cmdInfo.Parameters); + Assert.Single(cmdInfo.Parameters); + Assert.Equal(keys, cmdInfo.Parameters[0].Value); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildDeleteCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var key = 123; + + // Act + var cmdInfo = builder.BuildDeleteCommand("public", "testcollection", "id", key); + + // Assert + Assert.Contains("DELETE", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("WHERE \"id\" = $1", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildDeleteBatchCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var keys = new List { 123, 124 }; + + // Act + var cmdInfo = builder.BuildDeleteBatchCommand("public", "testcollection", "id", keys); + + // Assert + Assert.Contains("DELETE", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("WHERE \"id\" = ANY($1)", cmdInfo.CommandText); + Assert.NotNull(cmdInfo.Parameters); + Assert.Single(cmdInfo.Parameters); + Assert.Equal(keys, cmdInfo.Parameters[0].Value); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildGetNearestMatchCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var vectorProperty = new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = "hnsw", + }; + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), + new VectorStoreRecordDataProperty("tags", typeof(List)), + vectorProperty, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var vector = new Vector(s_vector); + + // Act + var cmdInfo = builder.BuildGetNearestMatchCommand("public", "testcollection", + properties: recordDefinition.Properties, + vectorProperty: vectorProperty, + vectorValue: vector, + filter: null, + skip: null, + includeVectors: true, + limit: 10); + + // Assert + Assert.Contains("SELECT", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("ORDER BY", cmdInfo.CommandText); + Assert.Contains("LIMIT 10", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..0533ab28c3f3 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Moq; +using Pgvector; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +public class PostgresVectorStoreRecordCollectionTests +{ + private const string TestCollectionName = "testcollection"; + + private readonly Mock _postgresClientMock; + private readonly CancellationToken _testCancellationToken = new(false); + + public PostgresVectorStoreRecordCollectionTests() + { + this._postgresClientMock = new Mock(MockBehavior.Strict); + } + + [Fact] + public async Task CreatesCollectionForGenericModelAsync() + { + // Arrange + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = [ + new VectorStoreRecordKeyProperty("HotelId", typeof(int)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { IsFilterable = true, StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty("HotelRating", typeof(float)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("Tags", typeof(List)), + new VectorStoreRecordDataProperty("Description", typeof(string)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 100, DistanceFunction = DistanceFunction.ManhattanDistance } + ] + }; + var options = new PostgresVectorStoreRecordCollectionOptions>() + { + VectorStoreRecordDefinition = recordDefinition + }; + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName, options); + this._postgresClientMock.Setup(x => x.DoesTableExistsAsync(TestCollectionName, this._testCancellationToken)).ReturnsAsync(false); + + // Act + var exists = await sut.CollectionExistsAsync(); + + // Assert. + Assert.False(exists); + } + + [Fact] + public void ThrowsForUnsupportedType() + { + // Arrange + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = [ + new VectorStoreRecordKeyProperty("HotelId", typeof(ulong)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + ] + }; + var options = new PostgresVectorStoreRecordCollectionOptions>() + { + VectorStoreRecordDefinition = recordDefinition + }; + + // Act & Assert + Assert.Throws(() => new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName, options)); + } + + [Fact] + public async Task UpsertRecordAsyncProducesExpectedClientCallAsync() + { + // Arrange + Dictionary? capturedArguments = null; + + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName); + var record = new PostgresHotel + { + HotelId = 1, + HotelName = "Hotel 1", + HotelCode = 1, + HotelRating = 4.5f, + ParkingIncluded = true, + Tags = ["tag1", "tag2"], + Description = "A hotel", + DescriptionEmbedding = new ReadOnlyMemory([1.0f, 2.0f, 3.0f, 4.0f]) + }; + + this._postgresClientMock.Setup(x => x.UpsertAsync( + TestCollectionName, + It.IsAny>(), + "HotelId", + this._testCancellationToken)) + .Callback, string, CancellationToken>((collectionName, args, key, ct) => capturedArguments = args) + .Returns(Task.CompletedTask); + + // Act + await sut.UpsertAsync(record, cancellationToken: this._testCancellationToken); + + // Assert + Assert.NotNull(capturedArguments); + Assert.Equal(1, (int)(capturedArguments["HotelId"] ?? 0)); + Assert.Equal("Hotel 1", (string)(capturedArguments["HotelName"] ?? "")); + Assert.Equal(1, (int)(capturedArguments["HotelCode"] ?? 0)); + Assert.Equal(4.5f, (float)(capturedArguments["HotelRating"] ?? 0.0f)); + Assert.True((bool)(capturedArguments["parking_is_included"] ?? false)); + Assert.True(capturedArguments["Tags"] is List); + var tags = capturedArguments["Tags"] as List; + Assert.Equal(2, tags!.Count); + Assert.Equal("tag1", tags[0]); + Assert.Equal("tag2", tags[1]); + Assert.Equal("A hotel", (string)(capturedArguments["Description"] ?? "")); + Assert.NotNull(capturedArguments["DescriptionEmbedding"]); + Assert.IsType(capturedArguments["DescriptionEmbedding"]); + var embedding = ((Vector)capturedArguments["DescriptionEmbedding"]!).ToArray(); + Assert.Equal(1.0f, embedding[0]); + Assert.Equal(2.0f, embedding[1]); + Assert.Equal(3.0f, embedding[2]); + Assert.Equal(4.0f, embedding[3]); + } + + [Fact] + public async Task CollectionExistsReturnsValidResultAsync() + { + // Arrange + const string TableName = "CollectionExists"; + + this._postgresClientMock.Setup(x => x.DoesTableExistsAsync(TableName, this._testCancellationToken)).ReturnsAsync(true); + + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TableName); + + // Act + var result = await sut.CollectionExistsAsync(); + + Assert.True(result); + } + + [Fact] + public async Task DeleteCollectionCallsClientDeleteAsync() + { + // Arrange + const string TableName = "DeleteCollection"; + + this._postgresClientMock.Setup(x => x.DeleteTableAsync(TableName, this._testCancellationToken)).Returns(Task.CompletedTask); + + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TableName); + + // Act + await sut.DeleteCollectionAsync(); + + // Assert + this._postgresClientMock.Verify(x => x.DeleteTableAsync(TableName, this._testCancellationToken), Times.Once); + } + + #region private + + private static void AssertRecord(TestRecord expectedRecord, TestRecord? actualRecord, bool includeVectors) + { + Assert.NotNull(actualRecord); + + Assert.Equal(expectedRecord.Key, actualRecord.Key); + Assert.Equal(expectedRecord.Data, actualRecord.Data); + + if (includeVectors) + { + Assert.NotNull(actualRecord.Vector); + Assert.Equal(expectedRecord.Vector!.Value.ToArray(), actualRecord.Vector.Value.Span.ToArray()); + } + else + { + Assert.Null(actualRecord.Vector); + } + } + +#pragma warning disable CA1812 + private sealed class TestRecord + { + [VectorStoreRecordKey] + public TKey? Key { get; set; } + + [VectorStoreRecordData] + public string? Data { get; set; } + + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + public ReadOnlyMemory? Vector { get; set; } + } + + private sealed class TestRecordWithoutVectorProperty + { + [VectorStoreRecordKey] + public TKey? Key { get; set; } + + [VectorStoreRecordData] + public string? Data { get; set; } + } +#pragma warning restore CA1812 + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs new file mode 100644 index 000000000000..11dfd2ecd564 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Pgvector; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class PostgresVectorStoreRecordMapperTests +{ + [Fact] + public void MapFromDataToStorageModelWithStringKeyReturnsValidStorageModel() + { + // Arrange + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + var dataModel = GetDataModel("key"); + + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + + // Act + var result = mapper.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal("key", result["Key"]); + Assert.Equal("Value1", result["StringProperty"]); + Assert.Equal(5, result["IntProperty"]); + Assert.Equal(new List { "Value2", "Value3" }, result["StringArray"]); + + Vector? vector = result["FloatVector"] as Vector; + + Assert.NotNull(vector); + Assert.True(vector.ToArray().Length > 0); + } + + [Fact] + public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() + { + // Arrange + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + var dataModel = GetDataModel(1); + + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + + // Act + var result = mapper.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal((ulong)1, result["Key"]); + Assert.Equal("Value1", result["StringProperty"]); + Assert.Equal(5, result["IntProperty"]); + Assert.Equal(new List { "Value2", "Value3" }, result["StringArray"]); + + var vector = result["FloatVector"] as Vector; + + Assert.NotNull(vector); + Assert.True(vector.ToArray().Length > 0); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool includeVectors) + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + var storageModel = new Dictionary + { + ["Key"] = "key", + ["StringProperty"] = "Value1", + ["IntProperty"] = 5, + ["StringArray"] = new List { "Value2", "Value3" }, + ["FloatVector"] = storageVector, + }; + + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + + // Act + var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); + + // Assert + Assert.Equal("key", result.Key); + Assert.Equal("Value1", result.StringProperty); + Assert.Equal(5, result.IntProperty); + Assert.Equal(new List { "Value2", "Value3" }, result.StringArray); + + if (includeVectors) + { + Assert.NotNull(result.FloatVector); + Assert.Equal(vector.Span.ToArray(), result.FloatVector.Value.Span.ToArray()); + } + else + { + Assert.Null(result.FloatVector); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool includeVectors) + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + var storageModel = new Dictionary + { + ["Key"] = (ulong)1, + ["StringProperty"] = "Value1", + ["IntProperty"] = 5, + ["StringArray"] = new List { "Value2", "Value3" }, + ["FloatVector"] = storageVector + }; + + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + + // Act + var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); + + // Assert + Assert.Equal((ulong)1, result.Key); + Assert.Equal("Value1", result.StringProperty); + Assert.Equal(5, result.IntProperty); + Assert.Equal(new List { "Value2", "Value3" }, result.StringArray); + + if (includeVectors) + { + Assert.NotNull(result.FloatVector); + Assert.Equal(vector.Span.ToArray(), result.FloatVector.Value.Span.ToArray()); + } + else + { + Assert.Null(result.FloatVector); + } + } + + #region private + + private static VectorStoreRecordDefinition GetRecordDefinition() + { + return new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(TKey)), + new VectorStoreRecordDataProperty("StringProperty", typeof(string)), + new VectorStoreRecordDataProperty("IntProperty", typeof(int)), + new VectorStoreRecordDataProperty("StringArray", typeof(IEnumerable)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), + } + }; + } + + private static TestRecord GetDataModel(TKey key) + { + return new TestRecord + { + Key = key, + StringProperty = "Value1", + IntProperty = 5, + StringArray = new List { "Value2", "Value3" }, + FloatVector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]) + }; + } + + private static VectorStoreRecordPropertyReader GetPropertyReader(VectorStoreRecordDefinition definition) + { + return new VectorStoreRecordPropertyReader(typeof(TRecord), definition, new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true + }); + } + +#pragma warning disable CA1812 + private sealed class TestRecord + { + [VectorStoreRecordKey] + public TKey? Key { get; set; } + + [VectorStoreRecordData] + public string? StringProperty { get; set; } + + [VectorStoreRecordData] + public int? IntProperty { get; set; } + + [VectorStoreRecordData] + public IEnumerable? StringArray { get; set; } + + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + public ReadOnlyMemory? FloatVector { get; set; } + } +#pragma warning restore CA1812 + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs new file mode 100644 index 000000000000..0631cc2c0df4 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Pgvector; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class PostgresVectorStoreRecordPropertyMappingTests +{ + [Fact] + public void MapVectorForStorageModelWithInvalidVectorTypeThrowsException() + { + // Arrange + var vector = new float[] { 1f, 2f, 3f }; + + // Act & Assert + Assert.Throws(() => PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector)); + } + + [Fact] + public void MapVectorForStorageModelReturnsVector() + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + + // Act + var storageModelVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + // Assert + Assert.IsType(storageModelVector); + Assert.True(storageModelVector.ToArray().Length > 0); + } + + [Fact] + public void MapVectorForDataModelReturnsReadOnlyMemory() + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var pgVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + // Act + var dataModelVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForDataModel(pgVector); + + // Assert + Assert.NotNull(dataModelVector); + Assert.Equal(vector.ToArray(), dataModelVector!.Value.ToArray()); + } + + [Fact] + public void GetPropertyValueReturnsCorrectValuesForLists() + { + // Arrange + var typesAndExpectedValues = new List<(Type, object)> + { + (typeof(List), "INTEGER[]"), + (typeof(List), "REAL[]"), + (typeof(List), "DOUBLE PRECISION[]"), + (typeof(List), "TEXT[]"), + (typeof(List), "BOOLEAN[]"), + (typeof(List), "TIMESTAMP[]"), + (typeof(List), "UUID[]"), + }; + + // Act & Assert + foreach (var (type, expectedValue) in typesAndExpectedValues) + { + var (pgType, _) = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(type); + Assert.Equal(expectedValue, pgType); + } + } + + [Fact] + public void GetPropertyValueReturnsCorrectNullableValue() + { + // Arrange + var typesAndExpectedValues = new List<(Type, object)> + { + (typeof(short), false), + (typeof(short?), true), + (typeof(int?), true), + (typeof(long), false), + (typeof(string), true), + (typeof(bool?), true), + (typeof(DateTime?), true), + (typeof(Guid), false), + }; + + // Act & Assert + foreach (var (type, expectedValue) in typesAndExpectedValues) + { + var (_, isNullable) = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(type); + Assert.Equal(expectedValue, isNullable); + } + } + + [Fact] + public void GetVectorIndexInfoReturnsCorrectValues() + { + // Arrange + List vectorProperties = [ + new VectorStoreRecordVectorProperty("vector1", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Hnsw, Dimensions = 1000 }, + new VectorStoreRecordVectorProperty("vector2", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Flat, Dimensions = 3000 }, + new VectorStoreRecordVectorProperty("vector3", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Hnsw, Dimensions = 900, DistanceFunction = DistanceFunction.ManhattanDistance }, + ]; + + // Act + var indexInfo = PostgresVectorStoreRecordPropertyMapping.GetVectorIndexInfo(vectorProperties); + + // Assert + Assert.Equal(2, indexInfo.Count); + foreach (var (columnName, indexKind, distanceFunction) in indexInfo) + { + if (columnName == "vector1") + { + Assert.Equal(IndexKind.Hnsw, indexKind); + Assert.Equal(DistanceFunction.CosineDistance, distanceFunction); + } + else if (columnName == "vector3") + { + Assert.Equal(IndexKind.Hnsw, indexKind); + Assert.Equal(DistanceFunction.ManhattanDistance, distanceFunction); + } + else + { + Assert.Fail("Unexpected column name"); + } + } + } + + [Theory] + [InlineData(IndexKind.Hnsw, 3000)] + public void GetVectorIndexInfoReturnsThrowsForInvalidDimensions(string indexKind, int dimensions) + { + // Arrange + var vectorProperty = new VectorStoreRecordVectorProperty("vector", typeof(ReadOnlyMemory?)) { IndexKind = indexKind, Dimensions = dimensions }; + + // Act & Assert + Assert.Throws(() => PostgresVectorStoreRecordPropertyMapping.GetVectorIndexInfo([vectorProperty])); + } +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs new file mode 100644 index 000000000000..b11d6a81963f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Moq; +using Npgsql; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Contains tests for the class. +/// +public class PostgresVectorStoreTests +{ + private const string TestCollectionName = "testcollection"; + + private readonly Mock _postgresClientMock; + private readonly CancellationToken _testCancellationToken = new(false); + + public PostgresVectorStoreTests() + { + this._postgresClientMock = new Mock(MockBehavior.Strict); + } + + [Fact] + public void GetCollectionReturnsCollection() + { + // Arrange. + var sut = new PostgresVectorStore(this._postgresClientMock.Object); + + // Act. + var actual = sut.GetCollection>(TestCollectionName); + + // Assert. + Assert.NotNull(actual); + Assert.IsType>>(actual); + } + + [Fact] + public void GetCollectionThrowsForInvalidKeyType() + { + // Arrange. + var sut = new PostgresVectorStore(this._postgresClientMock.Object); + + // Act and Assert. + Assert.Throws(() => sut.GetCollection>(TestCollectionName)); + } + + [Fact] + public void GetCollectionCallsFactoryIfProvided() + { + // Arrange. + var factoryMock = new Mock(MockBehavior.Strict); + var collectionMock = new Mock>>(MockBehavior.Strict); + var clientMock = new Mock(MockBehavior.Strict); + clientMock.Setup(x => x.DataSource).Returns(null); + factoryMock + .Setup(x => x.CreateVectorStoreRecordCollection>(It.IsAny(), TestCollectionName, null)) + .Returns(collectionMock.Object); + var sut = new PostgresVectorStore(clientMock.Object, new() { VectorStoreCollectionFactory = factoryMock.Object }); + + // Act. + var actual = sut.GetCollection>(TestCollectionName); + + // Assert. + Assert.Equal(collectionMock.Object, actual); + } + + [Fact] + public async Task ListCollectionNamesCallsSDKAsync() + { + // Arrange + var expectedCollections = new List { "fake-collection-1", "fake-collection-2", "fake-collection-3" }; + + this._postgresClientMock + .Setup(client => client.GetTablesAsync(CancellationToken.None)) + .Returns(expectedCollections.ToAsyncEnumerable()); + + var sut = new PostgresVectorStore(this._postgresClientMock.Object); + + // Act. + var actual = sut.ListCollectionNamesAsync(this._testCancellationToken); + + // Assert + Assert.NotNull(actual); + var actualList = await actual.ToListAsync(); + Assert.Equal(expectedCollections, actualList); + } + + [Fact] + public async Task ListCollectionNamesThrowsCorrectExceptionAsync() + { + // Arrange + var expectedCollections = new List { "fake-collection-1", "fake-collection-2", "fake-collection-3" }; + + this._postgresClientMock + .Setup(client => client.GetTablesAsync(CancellationToken.None)) + .Returns(this.ThrowingAsyncEnumerableAsync); + + var sut = new PostgresVectorStore(this._postgresClientMock.Object); + + // Act. + var actual = sut.ListCollectionNamesAsync(this._testCancellationToken); + + // Assert + Assert.NotNull(actual); + await Assert.ThrowsAsync(async () => await actual.ToListAsync()); + } + + private async IAsyncEnumerable ThrowingAsyncEnumerableAsync() + { + int itemIndex = 0; + await foreach (var item in new List { "item1", "item2", "item3" }.ToAsyncEnumerable()) + { + if (itemIndex == 1) + { + throw new InvalidOperationException("Test exception"); + } + yield return item; + itemIndex++; + } + } + + public sealed class SinglePropsModel + { + [VectorStoreRecordKey] + public required TKey Key { get; set; } + + [VectorStoreRecordData] + public string Data { get; set; } = string.Empty; + + [VectorStoreRecordVector(4)] + public ReadOnlyMemory? Vector { get; set; } + + public string? NotAnnotated { get; set; } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs new file mode 100644 index 000000000000..48a8f5f36a41 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + +/// +/// A test model for the postgres vector store. +/// +public record PostgresHotel() +{ + /// The key of the record. + [VectorStoreRecordKey] + public T HotelId { get; init; } + + /// A string metadata field. + [VectorStoreRecordData()] + public string? HotelName { get; set; } + + /// An int metadata field. + [VectorStoreRecordData()] + public int HotelCode { get; set; } + + /// A float metadata field. + [VectorStoreRecordData()] + public float? HotelRating { get; set; } + + /// A bool metadata field. + [VectorStoreRecordData(StoragePropertyName = "parking_is_included")] + public bool ParkingIncluded { get; set; } + + [VectorStoreRecordData] + public List Tags { get; set; } = []; + + [VectorStoreRecordData] + public List? ListInts { get; set; } = null; + + /// A data field. + [VectorStoreRecordData] + public string Description { get; set; } + + /// A vector field. + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.EuclideanDistance, IndexKind: IndexKind.Hnsw)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } + + public DateTime CreatedAt { get; set; } = DateTime.UtcNow; + + public DateTimeOffset UpdatedAt { get; set; } = DateTimeOffset.UtcNow; + + public PostgresHotel(T key) : this() + { + this.HotelId = key; + } +} + +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs index 19126a090874..71474ff0ebc6 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs @@ -11,7 +11,7 @@ using Npgsql; using Xunit; -namespace SemanticKernel.IntegrationTests.Connectors.Postgres; +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; /// /// Integration tests of . @@ -41,6 +41,8 @@ public async Task InitializeAsync() this._connectionString = connectionString; this._databaseName = $"sk_it_{Guid.NewGuid():N}"; + await this.CreateDatabaseAsync(); + NpgsqlConnectionStringBuilder connectionStringBuilder = new(this._connectionString) { Database = this._databaseName @@ -50,8 +52,6 @@ public async Task InitializeAsync() dataSourceBuilder.UseVector(); this._dataSource = dataSourceBuilder.Build(); - - await this.CreateDatabaseAsync(); } public async Task DisposeAsync() diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreCollectionFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreCollectionFixture.cs new file mode 100644 index 000000000000..5d202af5b9f5 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreCollectionFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +[CollectionDefinition("PostgresVectorStoreCollection")] +public class PostgresVectorStoreCollectionFixture : ICollectionFixture +{ +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs new file mode 100644 index 000000000000..5888a513ace0 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs @@ -0,0 +1,239 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Docker.DotNet; +using Docker.DotNet.Models; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Npgsql; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +public class PostgresVectorStoreFixture : IAsyncLifetime +{ + /// The docker client we are using to create a postgres container with. + private readonly DockerClient _client; + + /// The id of the postgres container that we are testing with. + private string? _containerId = null; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + + /// + /// Initializes a new instance of the class. + /// + public PostgresVectorStoreFixture() + { + using var dockerClientConfiguration = new DockerClientConfiguration(); + this._client = dockerClientConfiguration.CreateClient(); + } + + /// + /// Holds the Npgsql data source to use for tests. + /// + private NpgsqlDataSource? _dataSource; + + private string _connectionString = null!; + private string _databaseName = null!; + + /// + /// Gets a vector store to use for tests. + /// + public IVectorStore VectorStore => new PostgresVectorStore(this._dataSource!); + + /// + /// Get a database connection + /// + public NpgsqlConnection GetConnection() + { + return this._dataSource!.OpenConnection(); + } + + public IVectorStoreRecordCollection GetCollection( + string collectionName, + VectorStoreRecordDefinition? recordDefinition = default) + where TKey : notnull + where TRecord : class + { + var vectorStore = this.VectorStore; + return vectorStore.GetCollection(collectionName, recordDefinition); + } + + /// + /// Create / Recreate postgres docker container and run it. + /// + /// An async task. + public async Task InitializeAsync() + { + this._containerId = await SetupPostgresContainerAsync(this._client); + this._connectionString = "Host=localhost;Port=5432;Username=postgres;Password=example;Database=postgres;"; + this._databaseName = $"sk_it_{Guid.NewGuid():N}"; + + // Connect to postgres. + NpgsqlConnectionStringBuilder connectionStringBuilder = new(this._connectionString) + { + Database = this._databaseName + }; + + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionStringBuilder.ToString()); + dataSourceBuilder.UseVector(); + + this._dataSource = dataSourceBuilder.Build(); + + // Wait for the postgres container to be ready and create the test database using the initial data source. + var initialDataSource = NpgsqlDataSource.Create(this._connectionString); + using (initialDataSource) + { + var retryCount = 0; + var exceptionCount = 0; + while (retryCount++ < 5) + { + try + { + NpgsqlConnection connection = await initialDataSource.OpenConnectionAsync().ConfigureAwait(false); + + await using (connection) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = "SELECT count(*) FROM information_schema.tables WHERE table_schema = 'public';"; + await cmd.ExecuteScalarAsync().ConfigureAwait(false); + } + } + catch (NpgsqlException) + { + exceptionCount++; + await Task.Delay(1000); + } + } + + if (exceptionCount >= 5) + { + // Throw an exception for test setup + throw new InvalidOperationException("Postgres container did not start in time."); + } + + await this.CreateDatabaseAsync(initialDataSource); + } + + // Create the table. + await this.CreateTableAsync(); + } + + private async Task CreateTableAsync() + { + NpgsqlConnection connection = await this._dataSource!.OpenConnectionAsync().ConfigureAwait(false); + + await using (connection) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = @" + CREATE TABLE hotel_info ( + HotelId INTEGER NOT NULL, + HotelName TEXT, + HotelCode INTEGER NOT NULL, + HotelRating REAL, + parking_is_included BOOLEAN, + Tags TEXT[] NOT NULL, + Description TEXT NOT NULL, + DescriptionEmbedding VECTOR(4) NOT NULL, + PRIMARY KEY (HotelId));"; + await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); + } + } + + /// + /// Delete the docker container after the test run. + /// + /// An async task. + public async Task DisposeAsync() + { + if (this._dataSource != null) + { + this._dataSource.Dispose(); + } + + await this.DropDatabaseAsync(); + + if (this._containerId != null) + { + await this._client.Containers.StopContainerAsync(this._containerId, new ContainerStopParameters()); + await this._client.Containers.RemoveContainerAsync(this._containerId, new ContainerRemoveParameters()); + } + } + + /// + /// Setup the postgres container by pulling the image and running it. + /// + /// The docker client to create the container with. + /// The id of the container. + private static async Task SetupPostgresContainerAsync(DockerClient client) + { + await client.Images.CreateImageAsync( + new ImagesCreateParameters + { + FromImage = "pgvector/pgvector", + Tag = "pg16", + }, + null, + new Progress()); + + var container = await client.Containers.CreateContainerAsync(new CreateContainerParameters() + { + Image = "pgvector/pgvector:pg16", + HostConfig = new HostConfig() + { + PortBindings = new Dictionary> + { + {"5432", new List {new() {HostPort = "5432" } }}, + }, + PublishAllPorts = true + }, + ExposedPorts = new Dictionary + { + { "5432", default }, + }, + Env = new List + { + "POSTGRES_USER=postgres", + "POSTGRES_PASSWORD=example", + }, + }); + + await client.Containers.StartContainerAsync( + container.ID, + new ContainerStartParameters()); + + return container.ID; + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "The database name is generated randomly, it does not support parameterized passing.")] + private async Task CreateDatabaseAsync(NpgsqlDataSource initialDataSource) + { + await using (NpgsqlConnection conn = await initialDataSource.OpenConnectionAsync()) + { + await using NpgsqlCommand command = new($"CREATE DATABASE \"{this._databaseName}\"", conn); + await command.ExecuteNonQueryAsync(); + } + + await using (NpgsqlConnection conn = await this._dataSource!.OpenConnectionAsync()) + { + await using (NpgsqlCommand command = new("CREATE EXTENSION vector", conn)) + { + await command.ExecuteNonQueryAsync(); + } + await conn.ReloadTypesAsync(); + } + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "The database name is generated randomly, it does not support parameterized passing.")] + private async Task DropDatabaseAsync() + { + using NpgsqlDataSource dataSource = NpgsqlDataSource.Create(this._connectionString); + await using NpgsqlConnection conn = await dataSource.OpenConnectionAsync(); + await using NpgsqlCommand command = new($"DROP DATABASE IF EXISTS \"{this._databaseName}\"", conn); + await command.ExecuteNonQueryAsync(); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..7e3ae3ad9392 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -0,0 +1,562 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Npgsql; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +[Collection("PostgresVectorStoreCollection")] +public sealed class PostgresVectorStoreRecordCollectionTests(PostgresVectorStoreFixture fixture) +{ + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CollectionExistsReturnsCollectionStateAsync(bool createCollection) + { + // Arrange + var sut = fixture.GetCollection>("CollectionExists"); + + if (createCollection) + { + await sut.CreateCollectionAsync(); + } + + try + { + // Act + var collectionExists = await sut.CollectionExistsAsync(); + + // Assert + Assert.Equal(createCollection, collectionExists); + } + finally + { + // Cleanup + if (createCollection) + { + await sut.DeleteCollectionAsync(); + } + } + } + + [Fact] + public async Task CollectionCanUpsertAndGetAsync() + { + // Arrange + var sut = fixture.GetCollection>("CollectionCanUpsertAndGet"); + if (await sut.CollectionExistsAsync()) + { + await sut.DeleteCollectionAsync(); + } + + await sut.CreateCollectionAsync(); + + var writtenHotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + var writtenHotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, ListInts = [1, 2] }; + + try + { + // Act + + await sut.UpsertAsync(writtenHotel1); + + await sut.UpsertAsync(writtenHotel2); + + var fetchedHotel1 = await sut.GetAsync(1); + var fetchedHotel2 = await sut.GetAsync(2); + + // Assert + Assert.NotNull(fetchedHotel1); + Assert.Equal(1, fetchedHotel1!.HotelId); + Assert.Equal("Hotel 1", fetchedHotel1!.HotelName); + Assert.Equal(1, fetchedHotel1!.HotelCode); + Assert.True(fetchedHotel1!.ParkingIncluded); + Assert.Equal(4.5f, fetchedHotel1!.HotelRating); + Assert.NotNull(fetchedHotel1!.Tags); + Assert.Equal(2, fetchedHotel1!.Tags!.Count); + Assert.Equal("tag1", fetchedHotel1!.Tags![0]); + Assert.Equal("tag2", fetchedHotel1!.Tags![1]); + Assert.Null(fetchedHotel1!.ListInts); + Assert.Equal(TruncateMilliseconds(fetchedHotel1.CreatedAt), TruncateMilliseconds(writtenHotel1.CreatedAt)); + Assert.Equal(TruncateMilliseconds(fetchedHotel1.UpdatedAt), TruncateMilliseconds(writtenHotel1.UpdatedAt)); + + Assert.NotNull(fetchedHotel2); + Assert.Equal(2, fetchedHotel2!.HotelId); + Assert.Equal("Hotel 2", fetchedHotel2!.HotelName); + Assert.Equal(2, fetchedHotel2!.HotelCode); + Assert.False(fetchedHotel2!.ParkingIncluded); + Assert.Equal(2.5f, fetchedHotel2!.HotelRating); + Assert.NotNull(fetchedHotel2!.Tags); + Assert.Empty(fetchedHotel2!.Tags); + Assert.NotNull(fetchedHotel2!.ListInts); + Assert.Equal(2, fetchedHotel2!.ListInts!.Count); + Assert.Equal(1, fetchedHotel2!.ListInts![0]); + Assert.Equal(2, fetchedHotel2!.ListInts![1]); + Assert.Equal(TruncateMilliseconds(fetchedHotel2.CreatedAt), TruncateMilliseconds(writtenHotel2.CreatedAt)); + Assert.Equal(TruncateMilliseconds(fetchedHotel2.UpdatedAt), TruncateMilliseconds(writtenHotel2.UpdatedAt)); + } + finally + { + // Cleanup + await sut.DeleteCollectionAsync(); + } + } + + public static IEnumerable ItCanGetAndDeleteRecordParameters => + new List + { + new object[] { typeof(short), (short)3 }, + new object[] { typeof(int), 5 }, + new object[] { typeof(long), 7L }, + new object[] { typeof(string), "key1" }, + new object[] { typeof(Guid), Guid.NewGuid() } + }; + + [Theory] + [MemberData(nameof(ItCanGetAndDeleteRecordParameters))] + public async Task ItCanGetAndDeleteRecordAsync(Type idType, TKey? key) + { + // Arrange + var collectionName = "DeleteRecord"; + var sut = this.GetCollection(idType, collectionName); + + await sut.CreateCollectionAsync(); + + try + { + var record = this.CreateRecord(idType, key!); + var recordKey = record.HotelId; + var upsertResult = await sut.UpsertAsync(record); + var getResult = await sut.GetAsync(recordKey); + + Assert.Equal(key, upsertResult); + Assert.NotNull(getResult); + + // Act + await sut.DeleteAsync(recordKey); + + getResult = await sut.GetAsync(recordKey); + + // Assert + Assert.Null(getResult); + } + finally + { + // Cleanup + await sut.DeleteCollectionAsync(); + } + } + + [Fact] + public async Task ItCanGetUpsertDeleteBatchAsync() + { + // Arrange + const int HotelId1 = 1; + const int HotelId2 = 2; + const int HotelId3 = 3; + + var sut = fixture.GetCollection>("GetUpsertDeleteBatch"); + + await sut.CreateCollectionAsync(); + + var record1 = new PostgresHotel { HotelId = HotelId1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + var record2 = new PostgresHotel { HotelId = HotelId2, HotelName = "Hotel 2", HotelCode = 1, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag3"] }; + var record3 = new PostgresHotel { HotelId = HotelId3, HotelName = "Hotel 3", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag4"] }; + + var upsertResults = await sut.UpsertBatchAsync([record1, record2, record3]).ToListAsync(); + var getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + + Assert.Equal([HotelId1, HotelId2, HotelId3], upsertResults); + + Assert.NotNull(getResults.First(l => l.HotelId == HotelId1)); + Assert.NotNull(getResults.First(l => l.HotelId == HotelId2)); + Assert.NotNull(getResults.First(l => l.HotelId == HotelId3)); + + // Act + await sut.DeleteBatchAsync([HotelId1, HotelId2, HotelId3]); + + getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + + // Assert + Assert.Empty(getResults); + } + + [Fact] + public async Task ItCanUpsertExistingRecordAsync() + { + // Arrange + const int HotelId = 5; + var sut = fixture.GetCollection>("UpsertRecord"); + + await sut.CreateCollectionAsync(); + + var record = new PostgresHotel { HotelId = HotelId, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + + var upsertResult = await sut.UpsertAsync(record); + var getResult = await sut.GetAsync(HotelId, new() { IncludeVectors = true }); + + Assert.Equal(HotelId, upsertResult); + Assert.NotNull(getResult); + Assert.Null(getResult!.DescriptionEmbedding); + + // Act + record.HotelName = "Updated name"; + record.HotelRating = 10; + record.DescriptionEmbedding = new[] { 1f, 2f, 3f, 4f }; + + upsertResult = await sut.UpsertAsync(record); + getResult = await sut.GetAsync(HotelId, new() { IncludeVectors = true }); + + // Assert + Assert.NotNull(getResult); + Assert.Equal("Updated name", getResult.HotelName); + Assert.Equal(10, getResult.HotelRating); + + Assert.NotNull(getResult.DescriptionEmbedding); + Assert.Equal(record.DescriptionEmbedding!.Value.ToArray(), getResult.DescriptionEmbedding.Value.ToArray()); + } + + [Fact] + public async Task ItCanReadManuallyInsertedRecordAsync() + { + const string CollectionName = "ItCanReadManuallyInsertedRecordAsync"; + // Arrange + var sut = fixture.GetCollection>(CollectionName); + await sut.CreateCollectionAsync().ConfigureAwait(true); + Assert.True(await sut.CollectionExistsAsync().ConfigureAwait(true)); + await using (var connection = fixture.GetConnection()) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = @$" + INSERT INTO public.""{CollectionName}"" ( + ""HotelId"", ""HotelName"", ""HotelCode"", ""HotelRating"", ""parking_is_included"", ""Tags"", ""Description"", ""DescriptionEmbedding"" + ) VALUES ( + 215, 'Divine Lorraine', 215, 5, false, ARRAY['historic', 'philly'], 'An iconic building on broad street', '[10,20,30,40]' + );"; + await cmd.ExecuteNonQueryAsync().ConfigureAwait(true); + } + + // Act + var getResult = await sut.GetAsync(215, new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.NotNull(getResult); + Assert.Equal(215, getResult!.HotelId); + Assert.Equal("Divine Lorraine", getResult.HotelName); + Assert.Equal(215, getResult.HotelCode); + Assert.Equal(5, getResult.HotelRating); + Assert.False(getResult.ParkingIncluded); + Assert.Equal(new List { "historic", "philly" }, getResult.Tags); + Assert.Equal("An iconic building on broad street", getResult.Description); + Assert.Equal([10f, 20f, 30f, 40f], getResult.DescriptionEmbedding!.Value.ToArray()); + } + + [Fact] + public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() + { + const int HotelId = 5; + + var sut = fixture.GetCollection>("GenericMapperWithNumericKey", GetVectorStoreRecordDefinition()); + + await sut.CreateCollectionAsync(); + + var record = new PostgresHotel { HotelId = (int)HotelId, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + + // Act + var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel(HotelId) + { + Data = + { + { "HotelName", "Generic Mapper Hotel" }, + { "Description", "This is a generic mapper hotel" }, + { "HotelCode", 1 }, + { "ParkingIncluded", true }, + { "HotelRating", 3.6f } + }, + Vectors = + { + { "DescriptionEmbedding", new ReadOnlyMemory([30f, 31f, 32f, 33f]) } + } + }); + + var localGetResult = await sut.GetAsync(HotelId, new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.Equal(HotelId, upsertResult); + + Assert.NotNull(localGetResult); + Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); + Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); + Assert.True((bool?)localGetResult.Data["ParkingIncluded"]); + Assert.Equal(3.6f, localGetResult.Data["HotelRating"]); + Assert.Equal([30f, 31f, 32f, 33f], ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + + // Act - update with null embeddings + // Act + var upsertResult2 = await sut.UpsertAsync(new VectorStoreGenericDataModel(HotelId) + { + Data = + { + { "HotelName", "Generic Mapper Hotel" }, + { "Description", "This is a generic mapper hotel" }, + { "HotelCode", 1 }, + { "ParkingIncluded", true }, + { "HotelRating", 3.6f } + }, + Vectors = + { + { "DescriptionEmbedding", null } + } + }); + + var localGetResult2 = await sut.GetAsync(HotelId, new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.NotNull(localGetResult2); + Assert.Null(localGetResult2.Vectors["DescriptionEmbedding"]); + } + + [Theory] + [InlineData(true, DistanceFunction.CosineDistance)] + [InlineData(false, DistanceFunction.CosineDistance)] + [InlineData(false, DistanceFunction.CosineSimilarity)] + [InlineData(false, DistanceFunction.EuclideanDistance)] + [InlineData(false, DistanceFunction.ManhattanDistance)] + [InlineData(false, DistanceFunction.DotProductSimilarity)] + public async Task VectorizedSearchReturnsValidResultsByDefaultAsync(bool includeVectors, string distanceFunction) + { + // Arrange + var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 1f, 0f, 0f, 0f } }; + var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 0f, 1f, 0f, 0f } }; + var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 3.5f, Tags = ["tag1", "tag4"], DescriptionEmbedding = new[] { 0f, 0f, 1f, 0f } }; + var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 1.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 0f, 0f, 0f, 1f } }; + + var sut = fixture.GetCollection>($"VectorizedSearch_{includeVectors}_{distanceFunction}", GetVectorStoreRecordDefinition(distanceFunction)); + + await sut.CreateCollectionAsync(); + + await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + + // Act + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([0.9f, 0.1f, 0.5f, 0.8f]), new() + { + IncludeVectors = includeVectors + }); + + var results = await searchResults.Results.ToListAsync(); + + // Assert + var ids = results.Select(l => l.Record.HotelId).ToList(); + + Assert.Equal(1, ids[0]); + Assert.Equal(4, ids[1]); + Assert.Equal(3, ids[2]); + + // Default limit is 3 + Assert.DoesNotContain(2, ids); + + Assert.True(0 < results.First(l => l.Record.HotelId == 1).Score); + + Assert.Equal(includeVectors, results.All(result => result.Record.DescriptionEmbedding is not null)); + } + + [Fact] + public async Task VectorizedSearchWithEqualToFilterReturnsValidResultsAsync() + { + // Arrange + var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } }; + var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 10f, 10f, 10f, 10f } }; + var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag4"], DescriptionEmbedding = new[] { 20f, 20f, 20f, 20f } }; + var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 40f, 40f, 40f, 40f } }; + + var sut = fixture.GetCollection>("VectorizedSearchWithEqualToFilter"); + + await sut.CreateCollectionAsync(); + + await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + + // Act + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 29f, 28f, 27f]), new() + { + IncludeVectors = false, + Top = 5, + Filter = new([ + new EqualToFilterClause("HotelRating", 2.5f) + ]) + }); + + var results = await searchResults.Results.ToListAsync(); + + // Assert + var ids = results.Select(l => l.Record.HotelId).ToList(); + + Assert.Equal([1, 3, 2], ids); + } + + [Fact] + public async Task VectorizedSearchWithAnyTagFilterReturnsValidResultsAsync() + { + // Arrange + var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } }; + var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 10f, 10f, 10f, 10f } }; + var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag2", "tag4"], DescriptionEmbedding = new[] { 20f, 20f, 20f, 20f } }; + var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 40f, 40f, 40f, 40f } }; + + var sut = fixture.GetCollection>("VectorizedSearchWithAnyTagEqualToFilter"); + + await sut.CreateCollectionAsync(); + + await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + + // Act + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 29f, 28f, 27f]), new() + { + IncludeVectors = false, + Top = 5, + Filter = new([ + new AnyTagEqualToFilterClause("Tags", "tag2") + ]) + }); + + var results = await searchResults.Results.ToListAsync(); + + // Assert + var ids = results.Select(l => l.Record.HotelId).ToList(); + + Assert.Equal([1, 3], ids); + } + + [Fact] + public async Task ItCanUpsertAndGetEnumerableTypesAsync() + { + // Arrange + var sut = fixture.GetCollection("UpsertAndGetEnumerableTypes"); + + await sut.CreateCollectionAsync(); + + var record = new RecordWithEnumerables + { + Id = 1, + ListInts = new() { 1, 2, 3 }, + CollectionInts = new HashSet() { 4, 5, 6 }, + EnumerableInts = [7, 8, 9], + ReadOnlyCollectionInts = new List { 10, 11, 12 }, + ReadOnlyListInts = new List { 13, 14, 15 } + }; + + // Act + await sut.UpsertAsync(record); + + var getResult = await sut.GetAsync(1); + + // Assert + Assert.NotNull(getResult); + Assert.Equal(1, getResult!.Id); + Assert.NotNull(getResult.ListInts); + Assert.Equal(3, getResult.ListInts!.Count); + Assert.Equal(1, getResult.ListInts![0]); + Assert.Equal(2, getResult.ListInts![1]); + Assert.Equal(3, getResult.ListInts![2]); + Assert.NotNull(getResult.CollectionInts); + Assert.Equal(3, getResult.CollectionInts!.Count); + Assert.Contains(4, getResult.CollectionInts); + Assert.Contains(5, getResult.CollectionInts); + Assert.Contains(6, getResult.CollectionInts); + Assert.NotNull(getResult.EnumerableInts); + Assert.Equal(3, getResult.EnumerableInts!.Count()); + Assert.Equal(7, getResult.EnumerableInts.ElementAt(0)); + Assert.Equal(8, getResult.EnumerableInts.ElementAt(1)); + Assert.Equal(9, getResult.EnumerableInts.ElementAt(2)); + Assert.NotNull(getResult.ReadOnlyCollectionInts); + Assert.Equal(3, getResult.ReadOnlyCollectionInts!.Count); + var readOnlyCollectionIntsList = getResult.ReadOnlyCollectionInts.ToList(); + Assert.Equal(10, readOnlyCollectionIntsList[0]); + Assert.Equal(11, readOnlyCollectionIntsList[1]); + Assert.Equal(12, readOnlyCollectionIntsList[2]); + Assert.NotNull(getResult.ReadOnlyListInts); + Assert.Equal(3, getResult.ReadOnlyListInts!.Count); + Assert.Equal(13, getResult.ReadOnlyListInts[0]); + Assert.Equal(14, getResult.ReadOnlyListInts[1]); + Assert.Equal(15, getResult.ReadOnlyListInts[2]); + } + + #region private ================================================================================== + + private static VectorStoreRecordDefinition GetVectorStoreRecordDefinition(string distanceFunction = DistanceFunction.CosineDistance) => new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("HotelId", typeof(TKey)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)), + new VectorStoreRecordDataProperty("HotelCode", typeof(int)), + new VectorStoreRecordDataProperty("HotelRating", typeof(float?)), + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty("Tags", typeof(List)), + new VectorStoreRecordDataProperty("ListInts", typeof(List)), + new VectorStoreRecordDataProperty("Description", typeof(string)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, IndexKind = IndexKind.Hnsw, DistanceFunction = distanceFunction } + ] + }; + + private dynamic GetCollection(Type idType, string collectionName) + { + var method = typeof(PostgresVectorStoreFixture).GetMethod("GetCollection"); + var genericMethod = method!.MakeGenericMethod(idType, typeof(PostgresHotel<>).MakeGenericType(idType)); + return genericMethod.Invoke(fixture, [collectionName, null])!; + } + + private PostgresHotel CreateRecord(Type idType, TKey key) + { + var recordType = typeof(PostgresHotel<>).MakeGenericType(idType); + var record = (PostgresHotel)Activator.CreateInstance(recordType, key)!; + record.HotelName = "Hotel 1"; + record.HotelCode = 1; + record.ParkingIncluded = true; + record.HotelRating = 4.5f; + record.Tags = new List { "tag1", "tag2" }; + return record; + } + private static DateTime TruncateMilliseconds(DateTime dateTime) + { + return new DateTime(dateTime.Ticks - (dateTime.Ticks % TimeSpan.TicksPerSecond), dateTime.Kind); + } + + private static DateTimeOffset TruncateMilliseconds(DateTimeOffset dateTimeOffset) + { + return new DateTimeOffset(dateTimeOffset.Ticks - (dateTimeOffset.Ticks % TimeSpan.TicksPerSecond), dateTimeOffset.Offset); + } + +#pragma warning disable CA1812, CA1859 + private sealed class RecordWithEnumerables + { + [VectorStoreRecordKey] + public int Id { get; set; } + + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + public ReadOnlyMemory? Embedding { get; set; } + + [VectorStoreRecordData] + public List? ListInts { get; set; } + + [VectorStoreRecordData] + public ICollection? CollectionInts { get; set; } + + [VectorStoreRecordData] + public IEnumerable? EnumerableInts { get; set; } + + [VectorStoreRecordData] + public IReadOnlyCollection? ReadOnlyCollectionInts { get; set; } + + [VectorStoreRecordData] + public IReadOnlyList? ReadOnlyListInts { get; set; } + } +#pragma warning restore CA1812, CA1859 + + #endregion + +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs new file mode 100644 index 000000000000..3eb2c02d54c6 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq; +using System.Threading.Tasks; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +[Collection("PostgresVectorStoreCollection")] +public class PostgresVectorStoreTests(PostgresVectorStoreFixture fixture) +{ + [Fact] + public async Task ItCanGetAListOfExistingCollectionNamesAsync() + { + // Arrange + var sut = fixture.VectorStore; + + // Setup + var collection = sut.GetCollection>("VS_TEST_HOTELS"); + await collection.CreateCollectionIfNotExistsAsync(); + + // Act + var collectionNames = await sut.ListCollectionNamesAsync().ToListAsync(); + + // Assert + Assert.Contains("VS_TEST_HOTELS", collectionNames); + } +} diff --git a/dotnet/src/InternalUtilities/src/Linq/AsyncEnumerable.cs b/dotnet/src/InternalUtilities/src/Linq/AsyncEnumerable.cs index 844ae7e2f573..a85a509d1980 100644 --- a/dotnet/src/InternalUtilities/src/Linq/AsyncEnumerable.cs +++ b/dotnet/src/InternalUtilities/src/Linq/AsyncEnumerable.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.SemanticKernel; @@ -135,6 +136,40 @@ static async ValueTask Core(IAsyncEnumerable source, Func + /// Projects each element of an into a new form by incorporating + /// an asynchronous transformation function. + /// + /// The type of the elements of the source sequence. + /// The type of the elements of the resulting sequence. + /// An to invoke a transform function on. + /// + /// A transform function to apply to each element. This function takes an element of + /// type TSource and returns an element of type TResult. + /// + /// + /// A CancellationToken to observe while iterating through the sequence. + /// + /// + /// An whose elements are the result of invoking the transform + /// function on each element of the original sequence. + /// + /// Thrown when the source or selector is null. + public static async IAsyncEnumerable SelectAsync( + this IAsyncEnumerable source, + Func selector, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return selector(item); + } + } + +#pragma warning restore IDE1006 // Naming rule violation: Missing suffix: 'Async' + private sealed class EmptyAsyncEnumerable : IAsyncEnumerable, IAsyncEnumerator { public static readonly EmptyAsyncEnumerable Instance = new(); From 7c25ac4c3be7e19950916a41f918dbadf66a819a Mon Sep 17 00:00:00 2001 From: blurred83 Date: Mon, 16 Dec 2024 03:53:21 -0600 Subject: [PATCH 2/5] .Net: Fix typo in GettingStarted.Step3_Yaml_Prompt - CreatPrompt -> CreatePrompt (#9823) ### Motivation and Context Fixing a typo in a unit test name (CreatPromptFromYamlAsync) that ReSharper noticed. ### Description Changed CreatPromptFromYamlAsync to CreatePromptFromYamlAsync (and rebuilt/ran the test just to be sure). Co-authored-by: Max Szczurek Co-authored-by: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com> --- dotnet/samples/GettingStarted/Step3_Yaml_Prompt.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/samples/GettingStarted/Step3_Yaml_Prompt.cs b/dotnet/samples/GettingStarted/Step3_Yaml_Prompt.cs index 29d50f7b6da7..a848779d4e96 100644 --- a/dotnet/samples/GettingStarted/Step3_Yaml_Prompt.cs +++ b/dotnet/samples/GettingStarted/Step3_Yaml_Prompt.cs @@ -15,7 +15,7 @@ public sealed class Step3_Yaml_Prompt(ITestOutputHelper output) : BaseTest(outpu /// Show how to create a prompt from a YAML resource. /// [Fact] - public async Task CreatPromptFromYamlAsync() + public async Task CreatePromptFromYamlAsync() { // Create a kernel with OpenAI chat completion Kernel kernel = Kernel.CreateBuilder() From 6d02eeff815915f12cb830180243532743c2a211 Mon Sep 17 00:00:00 2001 From: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:00:06 +0000 Subject: [PATCH 3/5] .Net: Allow customization of building REST API operation URL, payload, and headers (#9985) ### Motivation and Context CopilotAgentPlugin functionality may need more control over the way url, headers and payload are created. ### Description This PR adds internal factories for creating URLs, headers, and payloads. The factories are kept internal because the necessity of having them and their structure may change in the future. --- .../Functions.OpenApi/HttpContentFactory.cs | 2 +- .../Model/RestApiOperationHeadersFactory.cs | 14 +++ .../Model/RestApiOperationPayloadFactory.cs | 23 ++++ .../Model/RestApiOperationUrlFactory.cs | 15 +++ .../RestApiOperationRunner.cs | 38 ++++-- .../OpenApi/RestApiOperationRunnerTests.cs | 115 ++++++++++++++++++ 6 files changed, 199 insertions(+), 8 deletions(-) create mode 100644 dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationHeadersFactory.cs create mode 100644 dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationPayloadFactory.cs create mode 100644 dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationUrlFactory.cs diff --git a/dotnet/src/Functions/Functions.OpenApi/HttpContentFactory.cs b/dotnet/src/Functions/Functions.OpenApi/HttpContentFactory.cs index c3ebf9251e0a..45cea8a3ec3a 100644 --- a/dotnet/src/Functions/Functions.OpenApi/HttpContentFactory.cs +++ b/dotnet/src/Functions/Functions.OpenApi/HttpContentFactory.cs @@ -11,4 +11,4 @@ namespace Microsoft.SemanticKernel.Plugins.OpenApi; /// The operation payload metadata. /// The operation arguments. /// The object and HttpContent representing the operation payload. -internal delegate (object? Payload, HttpContent Content) HttpContentFactory(RestApiPayload? payload, IDictionary arguments); +internal delegate (object Payload, HttpContent Content) HttpContentFactory(RestApiPayload? payload, IDictionary arguments); diff --git a/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationHeadersFactory.cs b/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationHeadersFactory.cs new file mode 100644 index 000000000000..738a47a670f8 --- /dev/null +++ b/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationHeadersFactory.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; + +namespace Microsoft.SemanticKernel.Plugins.OpenApi; + +/// +/// Represents a delegate for creating headers for a REST API operation. +/// +/// The REST API operation. +/// The arguments for the operation. +/// The operation run options. +/// The operation headers. +internal delegate IDictionary? RestApiOperationHeadersFactory(RestApiOperation operation, IDictionary arguments, RestApiOperationRunOptions? options); diff --git a/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationPayloadFactory.cs b/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationPayloadFactory.cs new file mode 100644 index 000000000000..1000a616fe73 --- /dev/null +++ b/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationPayloadFactory.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Net.Http; + +namespace Microsoft.SemanticKernel.Plugins.OpenApi; + +/// +/// Represents a delegate for creating a payload for a REST API operation. +/// +/// The REST API operation. +/// The arguments for the operation. +/// +/// Determines whether the operation payload is constructed dynamically based on operation payload metadata. +/// If false, the operation payload must be provided via the 'payload' property. +/// +/// +/// Determines whether payload parameters are resolved from the arguments by +/// full name (parameter name prefixed with the parent property name). +/// +/// The operation run options. +/// The operation payload. +internal delegate (object Payload, HttpContent Content)? RestApiOperationPayloadFactory(RestApiOperation operation, IDictionary arguments, bool enableDynamicPayload, bool enablePayloadNamespacing, RestApiOperationRunOptions? options); diff --git a/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationUrlFactory.cs b/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationUrlFactory.cs new file mode 100644 index 000000000000..64736c6decbe --- /dev/null +++ b/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationUrlFactory.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; + +namespace Microsoft.SemanticKernel.Plugins.OpenApi; + +/// +/// Represents a delegate for creating a URL for a REST API operation. +/// +/// The REST API operation. +/// The arguments for the operation. +/// The operation run options. +/// The operation URL. +internal delegate Uri? RestApiOperationUrlFactory(RestApiOperation operation, IDictionary arguments, RestApiOperationRunOptions? options); diff --git a/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs b/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs index 29b58fa6b480..9c1c2bcb1177 100644 --- a/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs +++ b/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs @@ -88,6 +88,21 @@ internal sealed class RestApiOperationRunner /// private readonly HttpResponseContentReader? _httpResponseContentReader; + /// + /// The external URL factory to use if provided, instead of the default one. + /// + private readonly RestApiOperationUrlFactory? _urlFactory; + + /// + /// The external header factory to use if provided, instead of the default one. + /// + private readonly RestApiOperationHeadersFactory? _headersFactory; + + /// + /// The external payload factory to use if provided, instead of the default one. + /// + private readonly RestApiOperationPayloadFactory? _payloadFactory; + /// /// Creates an instance of the class. /// @@ -100,19 +115,28 @@ internal sealed class RestApiOperationRunner /// Determines whether payload parameters are resolved from the arguments by /// full name (parameter name prefixed with the parent property name). /// Custom HTTP response content reader. + /// The external URL factory to use if provided if provided instead of the default one. + /// The external headers factory to use if provided instead of the default one. + /// The external payload factory to use if provided instead of the default one. public RestApiOperationRunner( HttpClient httpClient, AuthenticateRequestAsyncCallback? authCallback = null, string? userAgent = null, bool enableDynamicPayload = false, bool enablePayloadNamespacing = false, - HttpResponseContentReader? httpResponseContentReader = null) + HttpResponseContentReader? httpResponseContentReader = null, + RestApiOperationUrlFactory? urlFactory = null, + RestApiOperationHeadersFactory? headersFactory = null, + RestApiOperationPayloadFactory? payloadFactory = null) { this._httpClient = httpClient; this._userAgent = userAgent ?? HttpHeaderConstant.Values.UserAgent; this._enableDynamicPayload = enableDynamicPayload; this._enablePayloadNamespacing = enablePayloadNamespacing; this._httpResponseContentReader = httpResponseContentReader; + this._urlFactory = urlFactory; + this._headersFactory = headersFactory; + this._payloadFactory = payloadFactory; // If no auth callback provided, use empty function if (authCallback is null) @@ -145,13 +169,13 @@ public Task RunAsync( RestApiOperationRunOptions? options = null, CancellationToken cancellationToken = default) { - var url = this.BuildsOperationUrl(operation, arguments, options?.ServerUrlOverride, options?.ApiHostUrl); + var url = this._urlFactory?.Invoke(operation, arguments, options) ?? this.BuildsOperationUrl(operation, arguments, options?.ServerUrlOverride, options?.ApiHostUrl); - var headers = operation.BuildHeaders(arguments); + var headers = this._headersFactory?.Invoke(operation, arguments, options) ?? operation.BuildHeaders(arguments); - var operationPayload = this.BuildOperationPayload(operation, arguments); + var (Payload, Content) = this._payloadFactory?.Invoke(operation, arguments, this._enableDynamicPayload, this._enablePayloadNamespacing, options) ?? this.BuildOperationPayload(operation, arguments); - return this.SendAsync(url, operation.Method, headers, operationPayload.Payload, operationPayload.Content, operation.Responses.ToDictionary(item => item.Key, item => item.Value.Schema), options, cancellationToken); + return this.SendAsync(url, operation.Method, headers, Payload, Content, operation.Responses.ToDictionary(item => item.Key, item => item.Value.Schema), options, cancellationToken); } #region private @@ -340,7 +364,7 @@ private async Task ReadContentAndCreateOperationRespon /// The payload meta-data. /// The payload arguments. /// The JSON payload the corresponding HttpContent. - private (object? Payload, HttpContent Content) BuildJsonPayload(RestApiPayload? payloadMetadata, IDictionary arguments) + private (object Payload, HttpContent Content) BuildJsonPayload(RestApiPayload? payloadMetadata, IDictionary arguments) { // Build operation payload dynamically if (this._enableDynamicPayload) @@ -440,7 +464,7 @@ private JsonObject BuildJsonObject(IList properties, IDi /// The payload meta-data. /// The payload arguments. /// The text payload and corresponding HttpContent. - private (object? Payload, HttpContent Content) BuildPlainTextPayload(RestApiPayload? payloadMetadata, IDictionary arguments) + private (object Payload, HttpContent Content) BuildPlainTextPayload(RestApiPayload? payloadMetadata, IDictionary arguments) { if (!arguments.TryGetValue(RestApiOperation.PayloadArgumentName, out object? argument) || argument is not string payload) { diff --git a/dotnet/src/Functions/Functions.UnitTests/OpenApi/RestApiOperationRunnerTests.cs b/dotnet/src/Functions/Functions.UnitTests/OpenApi/RestApiOperationRunnerTests.cs index e30d115aaece..089644ad7848 100644 --- a/dotnet/src/Functions/Functions.UnitTests/OpenApi/RestApiOperationRunnerTests.cs +++ b/dotnet/src/Functions/Functions.UnitTests/OpenApi/RestApiOperationRunnerTests.cs @@ -1517,6 +1517,121 @@ public async Task ItShouldUseRestApiOperationPayloadPropertyNameToLookupArgument Assert.Equal("true", enabledProperty.ToString()); } + [Fact] + public async Task ItShouldUseUrlHeaderAndPayloadFactoriesIfProvidedAsync() + { + // Arrange + this._httpMessageHandlerStub.ResponseToReturn.Content = new StringContent("fake-content", Encoding.UTF8, MediaTypeNames.Application.Json); + + List payloadProperties = + [ + new("name", "string", true, []) + ]; + + var payload = new RestApiPayload(MediaTypeNames.Application.Json, payloadProperties); + + var expectedOperation = new RestApiOperation( + id: "fake-id", + servers: [new RestApiServer("https://fake-random-test-host")], + path: "fake-path", + method: HttpMethod.Post, + description: "fake-description", + parameters: [], + responses: new Dictionary(), + securityRequirements: [], + payload: payload + ); + + var expectedArguments = new KernelArguments(); + + var expectedOptions = new RestApiOperationRunOptions() + { + Kernel = new(), + KernelFunction = KernelFunctionFactory.CreateFromMethod(() => false), + KernelArguments = expectedArguments, + }; + + bool createUrlFactoryCalled = false; + bool createHeadersFactoryCalled = false; + bool createPayloadFactoryCalled = false; + + Uri CreateUrl(RestApiOperation operation, IDictionary arguments, RestApiOperationRunOptions? options) + { + createUrlFactoryCalled = true; + Assert.Same(expectedOperation, operation); + Assert.Same(expectedArguments, arguments); + Assert.Same(expectedOptions, options); + + return new Uri("https://fake-random-test-host-from-factory/"); + } + + IDictionary? CreateHeaders(RestApiOperation operation, IDictionary arguments, RestApiOperationRunOptions? options) + { + createHeadersFactoryCalled = true; + Assert.Same(expectedOperation, operation); + Assert.Same(expectedArguments, arguments); + Assert.Same(expectedOptions, options); + + return new Dictionary() { ["header-from-factory"] = "value-of-header-from-factory" }; + } + + (object Payload, HttpContent Content)? CreatePayload(RestApiOperation operation, IDictionary arguments, bool enableDynamicPayload, bool enablePayloadNamespacing, RestApiOperationRunOptions? options) + { + createPayloadFactoryCalled = true; + Assert.Same(expectedOperation, operation); + Assert.Same(expectedArguments, arguments); + Assert.True(enableDynamicPayload); + Assert.True(enablePayloadNamespacing); + Assert.Same(expectedOptions, options); + + var json = """{"name":"fake-name-value"}"""; + + return ((JsonObject)JsonObject.Parse(json)!, new StringContent(json, Encoding.UTF8, MediaTypeNames.Application.Json)); + } + + var sut = new RestApiOperationRunner( + this._httpClient, + enableDynamicPayload: true, + enablePayloadNamespacing: true, + urlFactory: CreateUrl, + headersFactory: CreateHeaders, + payloadFactory: CreatePayload); + + // Act + var result = await sut.RunAsync(expectedOperation, expectedArguments, expectedOptions); + + // Assert + Assert.True(createUrlFactoryCalled); + Assert.True(createHeadersFactoryCalled); + Assert.True(createPayloadFactoryCalled); + + // Assert url factory + Assert.NotNull(this._httpMessageHandlerStub.RequestUri); + Assert.Equal("https://fake-random-test-host-from-factory/", this._httpMessageHandlerStub.RequestUri.AbsoluteUri); + + // Assert headers factory + Assert.NotNull(this._httpMessageHandlerStub.RequestHeaders); + Assert.Equal(3, this._httpMessageHandlerStub.RequestHeaders.Count()); + + Assert.Contains(this._httpMessageHandlerStub.RequestHeaders, h => h.Key == "header-from-factory" && h.Value.Contains("value-of-header-from-factory")); + Assert.Contains(this._httpMessageHandlerStub.RequestHeaders, h => h.Key == "User-Agent" && h.Value.Contains("Semantic-Kernel")); + Assert.Contains(this._httpMessageHandlerStub.RequestHeaders, h => h.Key == "Semantic-Kernel-Version"); + + // Assert payload factory + var messageContent = this._httpMessageHandlerStub.RequestContent; + Assert.NotNull(messageContent); + + var deserializedPayload = await JsonNode.ParseAsync(new MemoryStream(messageContent)); + Assert.NotNull(deserializedPayload); + + var nameProperty = deserializedPayload["name"]?.ToString(); + Assert.Equal("fake-name-value", nameProperty); + + Assert.NotNull(result.RequestPayload); + Assert.IsType(result.RequestPayload); + Assert.Equal("""{"name":"fake-name-value"}""", ((JsonObject)result.RequestPayload).ToJsonString()); + } + public class SchemaTestData : IEnumerable { public IEnumerator GetEnumerator() From 5874188b2b967c72a4a309c28aa594b506ab32a2 Mon Sep 17 00:00:00 2001 From: Vincent Biret Date: Mon, 16 Dec 2024 11:12:08 -0500 Subject: [PATCH 4/5] .Net: fix: includes path item path parameters to OpenAPI document parsing (#9969) fixes #9962 --------- Signed-off-by: Vincent Biret Co-authored-by: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> --- .../OpenApi/OpenApiDocumentParser.cs | 25 ++- .../OpenApi/OpenApiDocumentParserV20Tests.cs | 160 ++++++++++++++++ .../OpenApi/OpenApiDocumentParserV30Tests.cs | 171 ++++++++++++++++++ .../OpenApi/OpenApiDocumentParserV31Tests.cs | 171 ++++++++++++++++++ 4 files changed, 525 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Functions/Functions.OpenApi/OpenApi/OpenApiDocumentParser.cs b/dotnet/src/Functions/Functions.OpenApi/OpenApi/OpenApiDocumentParser.cs index 67ba2d34e79a..4803d28e1e1b 100644 --- a/dotnet/src/Functions/Functions.OpenApi/OpenApi/OpenApiDocumentParser.cs +++ b/dotnet/src/Functions/Functions.OpenApi/OpenApi/OpenApiDocumentParser.cs @@ -211,7 +211,7 @@ internal static List CreateRestApiOperations(OpenApiDocument d path: path, method: new HttpMethod(method), description: string.IsNullOrEmpty(operationItem.Description) ? operationItem.Summary : operationItem.Description, - parameters: CreateRestApiOperationParameters(operationItem.OperationId, operationItem.Parameters), + parameters: CreateRestApiOperationParameters(operationItem.OperationId, operationItem.Parameters.Union(pathItem.Parameters, s_parameterNameAndLocationComparer)), payload: CreateRestApiOperationPayload(operationItem.OperationId, operationItem.RequestBody), responses: CreateRestApiOperationExpectedResponses(operationItem.Responses).ToDictionary(static item => item.Item1, static item => item.Item2), securityRequirements: CreateRestApiOperationSecurityRequirements(operationItem.Security) @@ -237,6 +237,27 @@ internal static List CreateRestApiOperations(OpenApiDocument d } } + private static readonly ParameterNameAndLocationComparer s_parameterNameAndLocationComparer = new(); + + /// + /// Compares two objects by their name and location. + /// + private sealed class ParameterNameAndLocationComparer : IEqualityComparer + { + public bool Equals(OpenApiParameter? x, OpenApiParameter? y) + { + if (x is null || y is null) + { + return x == y; + } + return this.GetHashCode(x) == this.GetHashCode(y); + } + public int GetHashCode([DisallowNull] OpenApiParameter obj) + { + return HashCode.Combine(obj.Name, obj.In); + } + } + /// /// Build a list of objects from the given list of objects. /// @@ -381,7 +402,7 @@ internal static List CreateRestApiOperationSecurityR /// The operation id. /// The OpenAPI parameters. /// The parameters. - private static List CreateRestApiOperationParameters(string operationId, IList parameters) + private static List CreateRestApiOperationParameters(string operationId, IEnumerable parameters) { var result = new List(); diff --git a/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV20Tests.cs b/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV20Tests.cs index 625420e2f956..9313297ace66 100644 --- a/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV20Tests.cs +++ b/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV20Tests.cs @@ -5,6 +5,7 @@ using System.IO; using System.Linq; using System.Net.Http; +using System.Text; using System.Text.Json; using System.Threading.Tasks; using Microsoft.SemanticKernel; @@ -434,6 +435,165 @@ public async Task ItCanFilterOutSpecifiedOperationsAsync() Assert.Contains(restApiSpec.Operations, o => o.Id == "SetSecret"); Assert.Contains(restApiSpec.Operations, o => o.Id == "GetSecret"); } + [Fact] + public async Task ItCanParsePathItemPathParametersAsync() + { + var document = + """ + { + "swagger": "2.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "paths": { + "/items/{itemId}/{format}": { + "parameters": [ + { + "name": "itemId", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "parameters": [ + { + "name": "format", + "in": "path", + "required": true, + "type": "string" + } + ], + "summary": "Get an item by ID", + "responses": { + "200": { + "description": "Successful response" + } + } + } + } + } + } + """; + + await using var steam = new MemoryStream(Encoding.UTF8.GetBytes(document)); + var restApi = await this._sut.ParseAsync(steam); + + Assert.NotNull(restApi); + Assert.NotNull(restApi.Operations); + Assert.NotEmpty(restApi.Operations); + + var firstOperation = restApi.Operations[0]; + + Assert.NotNull(firstOperation); + Assert.Equal("Get an item by ID", firstOperation.Description); + Assert.Equal("/items/{itemId}/{format}", firstOperation.Path); + + var parameters = firstOperation.GetParameters(); + Assert.NotNull(parameters); + Assert.Equal(2, parameters.Count); + + var pathParameter = parameters.Single(static p => "itemId".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(pathParameter); + Assert.True(pathParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, pathParameter.Location); + Assert.Null(pathParameter.DefaultValue); + Assert.NotNull(pathParameter.Schema); + Assert.Equal("string", pathParameter.Schema.RootElement.GetProperty("type").GetString()); + + var formatParameter = parameters.Single(static p => "format".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(formatParameter); + Assert.True(formatParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, formatParameter.Location); + Assert.Null(formatParameter.DefaultValue); + Assert.NotNull(formatParameter.Schema); + Assert.Equal("string", formatParameter.Schema.RootElement.GetProperty("type").GetString()); + } + + [Fact] + public async Task ItCanParsePathItemPathParametersAndOverridesAsync() + { + var document = + """ + { + "swagger": "2.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "paths": { + "/items/{itemId}/{format}": { + "parameters": [ + { + "name": "itemId", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "parameters": [ + { + "name": "format", + "in": "path", + "required": true, + "type": "string" + }, + { + "name": "itemId", + "in": "path", + "description": "item ID override", + "required": true, + "type": "string" + } + ], + "summary": "Get an item by ID", + "responses": { + "200": { + "description": "Successful response" + } + } + } + } + } + } + """; + + await using var steam = new MemoryStream(Encoding.UTF8.GetBytes(document)); + var restApi = await this._sut.ParseAsync(steam); + + Assert.NotNull(restApi); + Assert.NotNull(restApi.Operations); + Assert.NotEmpty(restApi.Operations); + + var firstOperation = restApi.Operations[0]; + + Assert.NotNull(firstOperation); + Assert.Equal("Get an item by ID", firstOperation.Description); + Assert.Equal("/items/{itemId}/{format}", firstOperation.Path); + + var parameters = firstOperation.GetParameters(); + Assert.NotNull(parameters); + Assert.Equal(2, parameters.Count); + + var pathParameter = parameters.Single(static p => "itemId".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(pathParameter); + Assert.True(pathParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, pathParameter.Location); + Assert.Null(pathParameter.DefaultValue); + Assert.NotNull(pathParameter.Schema); + Assert.Equal("string", pathParameter.Schema.RootElement.GetProperty("type").GetString()); + Assert.Equal("item ID override", pathParameter.Description); + + var formatParameter = parameters.Single(static p => "format".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(formatParameter); + Assert.True(formatParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, formatParameter.Location); + Assert.Null(formatParameter.DefaultValue); + Assert.NotNull(formatParameter.Schema); + Assert.Equal("string", formatParameter.Schema.RootElement.GetProperty("type").GetString()); + } private static RestApiParameter GetParameterMetadata(IList operations, string operationId, RestApiParameterLocation location, string name) diff --git a/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV30Tests.cs b/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV30Tests.cs index 8728771ac54a..02b3d363ebfb 100644 --- a/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV30Tests.cs +++ b/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV30Tests.cs @@ -5,6 +5,7 @@ using System.IO; using System.Linq; using System.Net.Http; +using System.Text; using System.Text.Json; using System.Text.Json.Nodes; using System.Threading.Tasks; @@ -500,6 +501,176 @@ public async Task ItCanParseDocumentWithMultipleServersAsync() Assert.Equal("https://ppe.my-key-vault.vault.azure.net", restApi.Operations[0].Servers[1].Url); } + [Fact] + public async Task ItCanParsePathItemPathParametersAsync() + { + var document = + """ + { + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "paths": { + "/items/{itemId}/{format}": { + "parameters": [ + { + "name": "itemId", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "get": { + "parameters": [ + { + "name": "format", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "summary": "Get an item by ID", + "responses": { + "200": { + "description": "Successful response" + } + } + } + } + } + } + """; + + await using var steam = new MemoryStream(Encoding.UTF8.GetBytes(document)); + var restApi = await this._sut.ParseAsync(steam); + + Assert.NotNull(restApi); + Assert.NotNull(restApi.Operations); + Assert.NotEmpty(restApi.Operations); + + var firstOperation = restApi.Operations[0]; + + Assert.NotNull(firstOperation); + Assert.Equal("Get an item by ID", firstOperation.Description); + Assert.Equal("/items/{itemId}/{format}", firstOperation.Path); + + var parameters = firstOperation.GetParameters(); + Assert.NotNull(parameters); + Assert.Equal(2, parameters.Count); + + var pathParameter = parameters.Single(static p => "itemId".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(pathParameter); + Assert.True(pathParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, pathParameter.Location); + Assert.Null(pathParameter.DefaultValue); + Assert.NotNull(pathParameter.Schema); + Assert.Equal("string", pathParameter.Schema.RootElement.GetProperty("type").GetString()); + + var formatParameter = parameters.Single(static p => "format".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(formatParameter); + Assert.True(formatParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, formatParameter.Location); + Assert.Null(formatParameter.DefaultValue); + Assert.NotNull(formatParameter.Schema); + Assert.Equal("string", formatParameter.Schema.RootElement.GetProperty("type").GetString()); + } + + [Fact] + public async Task ItCanParsePathItemPathParametersAndOverridesAsync() + { + var document = + """ + { + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "paths": { + "/items/{itemId}/{format}": { + "parameters": [ + { + "name": "itemId", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "get": { + "parameters": [ + { + "name": "format", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "itemId", + "in": "path", + "description": "item ID override", + "required": true, + "schema": { + "type": "string" + } + } + ], + "summary": "Get an item by ID", + "responses": { + "200": { + "description": "Successful response" + } + } + } + } + } + } + """; + + await using var steam = new MemoryStream(Encoding.UTF8.GetBytes(document)); + var restApi = await this._sut.ParseAsync(steam); + + Assert.NotNull(restApi); + Assert.NotNull(restApi.Operations); + Assert.NotEmpty(restApi.Operations); + + var firstOperation = restApi.Operations[0]; + + Assert.NotNull(firstOperation); + Assert.Equal("Get an item by ID", firstOperation.Description); + Assert.Equal("/items/{itemId}/{format}", firstOperation.Path); + + var parameters = firstOperation.GetParameters(); + Assert.NotNull(parameters); + Assert.Equal(2, parameters.Count); + + var pathParameter = parameters.Single(static p => "itemId".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(pathParameter); + Assert.True(pathParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, pathParameter.Location); + Assert.Null(pathParameter.DefaultValue); + Assert.NotNull(pathParameter.Schema); + Assert.Equal("string", pathParameter.Schema.RootElement.GetProperty("type").GetString()); + Assert.Equal("item ID override", pathParameter.Description); + + var formatParameter = parameters.Single(static p => "format".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(formatParameter); + Assert.True(formatParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, formatParameter.Location); + Assert.Null(formatParameter.DefaultValue); + Assert.NotNull(formatParameter.Schema); + Assert.Equal("string", formatParameter.Schema.RootElement.GetProperty("type").GetString()); + } + private static MemoryStream ModifyOpenApiDocument(Stream openApiDocument, Action transformer) { var json = JsonSerializer.Deserialize(openApiDocument); diff --git a/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV31Tests.cs b/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV31Tests.cs index 6455b95dd34b..5fc59c70a8f9 100644 --- a/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV31Tests.cs +++ b/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV31Tests.cs @@ -6,6 +6,7 @@ using System.IO; using System.Linq; using System.Net.Http; +using System.Text; using System.Text.Json; using System.Threading.Tasks; using Microsoft.SemanticKernel; @@ -477,6 +478,176 @@ public async Task ItCanParseDocumentWithMultipleServersAsync() Assert.Equal("https://ppe.my-key-vault.vault.azure.net", restApi.Operations[0].Servers[1].Url); } + [Fact] + public async Task ItCanParsePathItemPathParametersAsync() + {//TODO update the document version when upgrading Microsoft.OpenAPI to v2 + var document = + """ + { + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "paths": { + "/items/{itemId}/{format}": { + "parameters": [ + { + "name": "itemId", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "get": { + "parameters": [ + { + "name": "format", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "summary": "Get an item by ID", + "responses": { + "200": { + "description": "Successful response" + } + } + } + } + } + } + """; + + await using var steam = new MemoryStream(Encoding.UTF8.GetBytes(document)); + var restApi = await this._sut.ParseAsync(steam); + + Assert.NotNull(restApi); + Assert.NotNull(restApi.Operations); + Assert.NotEmpty(restApi.Operations); + + var firstOperation = restApi.Operations[0]; + + Assert.NotNull(firstOperation); + Assert.Equal("Get an item by ID", firstOperation.Description); + Assert.Equal("/items/{itemId}/{format}", firstOperation.Path); + + var parameters = firstOperation.GetParameters(); + Assert.NotNull(parameters); + Assert.Equal(2, parameters.Count); + + var pathParameter = parameters.Single(static p => "itemId".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(pathParameter); + Assert.True(pathParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, pathParameter.Location); + Assert.Null(pathParameter.DefaultValue); + Assert.NotNull(pathParameter.Schema); + Assert.Equal("string", pathParameter.Schema.RootElement.GetProperty("type").GetString()); + + var formatParameter = parameters.Single(static p => "format".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(formatParameter); + Assert.True(formatParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, formatParameter.Location); + Assert.Null(formatParameter.DefaultValue); + Assert.NotNull(formatParameter.Schema); + Assert.Equal("string", formatParameter.Schema.RootElement.GetProperty("type").GetString()); + } + + [Fact] + public async Task ItCanParsePathItemPathParametersAndOverridesAsync() + {//TODO update the document version when upgrading Microsoft.OpenAPI to v2 + var document = + """ + { + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "paths": { + "/items/{itemId}/{format}": { + "parameters": [ + { + "name": "itemId", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "get": { + "parameters": [ + { + "name": "format", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "itemId", + "in": "path", + "description": "item ID override", + "required": true, + "schema": { + "type": "string" + } + } + ], + "summary": "Get an item by ID", + "responses": { + "200": { + "description": "Successful response" + } + } + } + } + } + } + """; + + await using var steam = new MemoryStream(Encoding.UTF8.GetBytes(document)); + var restApi = await this._sut.ParseAsync(steam); + + Assert.NotNull(restApi); + Assert.NotNull(restApi.Operations); + Assert.NotEmpty(restApi.Operations); + + var firstOperation = restApi.Operations[0]; + + Assert.NotNull(firstOperation); + Assert.Equal("Get an item by ID", firstOperation.Description); + Assert.Equal("/items/{itemId}/{format}", firstOperation.Path); + + var parameters = firstOperation.GetParameters(); + Assert.NotNull(parameters); + Assert.Equal(2, parameters.Count); + + var pathParameter = parameters.Single(static p => "itemId".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(pathParameter); + Assert.True(pathParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, pathParameter.Location); + Assert.Null(pathParameter.DefaultValue); + Assert.NotNull(pathParameter.Schema); + Assert.Equal("string", pathParameter.Schema.RootElement.GetProperty("type").GetString()); + Assert.Equal("item ID override", pathParameter.Description); + + var formatParameter = parameters.Single(static p => "format".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(formatParameter); + Assert.True(formatParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, formatParameter.Location); + Assert.Null(formatParameter.DefaultValue); + Assert.NotNull(formatParameter.Schema); + Assert.Equal("string", formatParameter.Schema.RootElement.GetProperty("type").GetString()); + } + private static MemoryStream ModifyOpenApiDocument(Stream openApiDocument, Action> transformer) { var serializer = new SharpYaml.Serialization.Serializer(); From 62a50f32cf140b24876517219726a28465ef640e Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Mon, 16 Dec 2024 20:14:36 +0100 Subject: [PATCH 5/5] Python: Qdrant - fix in filter and 100% test coverage (#9982) ### Motivation and Context There was a small error in the filter creation logic, and improved test coverage for Qdrant. ### Description ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --- .../memory/qdrant/qdrant_collection.py | 4 +- .../connectors/memory/qdrant/test_qdrant.py | 84 ++++++++++++++++--- 2 files changed, 73 insertions(+), 15 deletions(-) diff --git a/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py b/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py index 5fb8c177be89..cb30fa0cdc76 100644 --- a/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py +++ b/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py @@ -188,7 +188,7 @@ async def _inner_search( else: query_vector = vector if query_vector is None: - raise VectorSearchExecutionException("Search requires either a vector.") + raise VectorSearchExecutionException("Search requires a vector.") results = await self.qdrant_client.search( collection_name=self.collection_name, query_vector=query_vector, @@ -214,7 +214,7 @@ def _get_score_from_result(self, result: ScoredPoint) -> float: def _create_filter(self, options: VectorSearchOptions) -> Filter: return Filter( must=[ - FieldCondition(key=filter.field_name, match=MatchAny(any=filter.value)) + FieldCondition(key=filter.field_name, match=MatchAny(any=[filter.value])) for filter in options.filter.filters ] ) diff --git a/python/tests/unit/connectors/memory/qdrant/test_qdrant.py b/python/tests/unit/connectors/memory/qdrant/test_qdrant.py index ce00e7d88c95..c92571daf238 100644 --- a/python/tests/unit/connectors/memory/qdrant/test_qdrant.py +++ b/python/tests/unit/connectors/memory/qdrant/test_qdrant.py @@ -4,17 +4,19 @@ from pytest import fixture, mark, raises from qdrant_client.async_qdrant_client import AsyncQdrantClient -from qdrant_client.models import Datatype, Distance, VectorParams +from qdrant_client.models import Datatype, Distance, FieldCondition, Filter, MatchAny, VectorParams from semantic_kernel.connectors.memory.qdrant.qdrant_collection import QdrantCollection from semantic_kernel.connectors.memory.qdrant.qdrant_store import QdrantStore from semantic_kernel.data.record_definition.vector_store_record_fields import VectorStoreRecordVectorField +from semantic_kernel.data.vector_search.vector_search_filter import VectorSearchFilter from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions from semantic_kernel.exceptions.memory_connector_exceptions import ( MemoryConnectorException, MemoryConnectorInitializationError, VectorStoreModelValidationError, ) +from semantic_kernel.exceptions.search_exceptions import VectorSearchExecutionException BASE_PATH = "qdrant_client.async_qdrant_client.AsyncQdrantClient" @@ -119,9 +121,10 @@ def mock_search(): yield mock_search -def test_vector_store_defaults(vector_store): - assert vector_store.qdrant_client is not None - assert vector_store.qdrant_client._client.rest_uri == "http://localhost:6333" +async def test_vector_store_defaults(vector_store): + async with vector_store: + assert vector_store.qdrant_client is not None + assert vector_store.qdrant_client._client.rest_uri == "http://localhost:6333" def test_vector_store_with_client(): @@ -162,18 +165,18 @@ def test_get_collection(vector_store, data_model_definition, qdrant_unit_test_en assert vector_store.vector_record_collections["test"] == collection -def test_collection_init(data_model_definition, qdrant_unit_test_env): - collection = QdrantCollection( +async def test_collection_init(data_model_definition, qdrant_unit_test_env): + async with QdrantCollection( data_model_type=dict, collection_name="test", data_model_definition=data_model_definition, env_file_path="test.env", - ) - assert collection.collection_name == "test" - assert collection.qdrant_client is not None - assert collection.data_model_type is dict - assert collection.data_model_definition == data_model_definition - assert collection.named_vectors + ) as collection: + assert collection.collection_name == "test" + assert collection.qdrant_client is not None + assert collection.data_model_type is dict + assert collection.data_model_definition == data_model_definition + assert collection.named_vectors def test_collection_init_fail(data_model_definition): @@ -275,8 +278,63 @@ async def test_create_index_fail(collection_to_use, request): await collection.create_collection() -async def test_search(collection): +async def test_search(collection, mock_search): results = await collection._inner_search(vector=[1.0, 2.0, 3.0], options=VectorSearchOptions(include_vectors=False)) async for result in results.results: assert result.record["id"] == "id1" break + + assert mock_search.call_count == 1 + mock_search.assert_called_with( + collection_name="test", + query_vector=[1.0, 2.0, 3.0], + query_filter=Filter(must=[]), + with_vectors=False, + limit=3, + offset=0, + ) + + +async def test_search_named_vectors(collection, mock_search): + collection.named_vectors = True + results = await collection._inner_search( + vector=[1.0, 2.0, 3.0], options=VectorSearchOptions(vector_field_name="vector", include_vectors=False) + ) + async for result in results.results: + assert result.record["id"] == "id1" + break + + assert mock_search.call_count == 1 + mock_search.assert_called_with( + collection_name="test", + query_vector=("vector", [1.0, 2.0, 3.0]), + query_filter=Filter(must=[]), + with_vectors=False, + limit=3, + offset=0, + ) + + +async def test_search_filter(collection, mock_search): + results = await collection._inner_search( + vector=[1.0, 2.0, 3.0], + options=VectorSearchOptions(include_vectors=False, filter=VectorSearchFilter.equal_to("id", "id1")), + ) + async for result in results.results: + assert result.record["id"] == "id1" + break + + assert mock_search.call_count == 1 + mock_search.assert_called_with( + collection_name="test", + query_vector=[1.0, 2.0, 3.0], + query_filter=Filter(must=[FieldCondition(key="id", match=MatchAny(any=["id1"]))]), + with_vectors=False, + limit=3, + offset=0, + ) + + +async def test_search_fail(collection): + with raises(VectorSearchExecutionException, match="Search requires a vector."): + await collection._inner_search(options=VectorSearchOptions(include_vectors=False))