From fbbd444992df9b045749abc6774e0088d80923cc Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 13 Dec 2024 11:15:04 -0800 Subject: [PATCH 01/11] Python: Add PR number to test coverage workflow (#9964) ### Motivation and Context The test coverage workflow needs the PR number to post comments to PRs. However, the `workflow_run` context doesn't contain the PR number thus we need to save it. ### Description Save the PR number in a file and upload it as artifact for the second workflow to use. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --- .github/workflows/python-test-coverage-report.yml | 9 +++++++++ .github/workflows/python-test-coverage.yml | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/.github/workflows/python-test-coverage-report.yml b/.github/workflows/python-test-coverage-report.yml index 7f0d323bb710..67c848609f6b 100644 --- a/.github/workflows/python-test-coverage-report.yml +++ b/.github/workflows/python-test-coverage-report.yml @@ -25,10 +25,19 @@ jobs: merge-multiple: true - name: Display structure of downloaded files run: ls + - name: Read and set PR number + # Need to read the PR number from the file saved in the previous workflow + # because the workflow_run event does not have access to the PR number + # The PR number is needed to post the comment on the PR + run: | + PR_NUMBER=$(cat pr_number) + echo "PR number: $PR_NUMBER" + echo "::set-env name=PR_NUMBER::$PR_NUMBER" - name: Pytest coverage comment id: coverageComment uses: MishaKav/pytest-coverage-comment@main with: + issue-number: ${{ env.PR_NUMBER }} 
pytest-coverage-path: python/python-coverage.txt title: "Python Test Coverage Report" badge-title: "Python Test Coverage" diff --git a/.github/workflows/python-test-coverage.yml b/.github/workflows/python-test-coverage.yml index 7ffc9925fb34..5d67b29b6b12 100644 --- a/.github/workflows/python-test-coverage.yml +++ b/.github/workflows/python-test-coverage.yml @@ -21,6 +21,11 @@ jobs: UV_PYTHON: "3.10" steps: - uses: actions/checkout@v4 + # Save the PR number to a file since the workflow_run event + # in the coverage report workflow does not have access to it + - name: Save PR number + run: | + echo ${{ github.event.number }} > ./pr_number - name: Set up uv uses: astral-sh/setup-uv@v4 with: @@ -37,6 +42,7 @@ jobs: path: | python/python-coverage.txt python/pytest.xml + python/pr_number overwrite: true retention-days: 1 if-no-files-found: error From 7531abfc2a197eeddf86478e257f158e2a6e379a Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 13 Dec 2024 14:22:58 -0800 Subject: [PATCH 02/11] Python: Deprecated retry_mechanism (#9965) ### Motivation and Context In Python, we have the `retry_mechanism` property in the kernel but it's never used. Address: https://github.com/microsoft/semantic-kernel/issues/9015 ### Description 1. Deprecate the retry_mechanism field. 2. Add a sample showing how to perform retries with the kernel using filter. 
### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --- .../concepts/filtering/retry_with_filters.py | 103 ++++++++++++++++++ python/semantic_kernel/kernel.py | 10 +- .../kernel_reliability_extension.py | 6 +- python/tests/unit/kernel/test_kernel.py | 5 +- 4 files changed, 111 insertions(+), 13 deletions(-) create mode 100644 python/samples/concepts/filtering/retry_with_filters.py diff --git a/python/samples/concepts/filtering/retry_with_filters.py b/python/samples/concepts/filtering/retry_with_filters.py new file mode 100644 index 000000000000..92131ad1d292 --- /dev/null +++ b/python/samples/concepts/filtering/retry_with_filters.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import logging +from collections.abc import Callable, Coroutine +from typing import Any + +from samples.concepts.setup.chat_completion_services import Services, get_chat_completion_service_and_request_settings +from semantic_kernel import Kernel +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior +from semantic_kernel.contents import ChatHistory +from semantic_kernel.filters import FunctionInvocationContext +from semantic_kernel.filters.filter_types import FilterTypes +from semantic_kernel.functions import kernel_function + +# This sample shows how to use a filter for retrying a function invocation. +# This sample requires the following components: +# - a ChatCompletionService: This component is responsible for generating responses to user messages. 
+# - a ChatHistory: This component is responsible for keeping track of the chat history. +# - a Kernel: This component is responsible for managing plugins and filters. +# - a mock plugin: This plugin contains a function that simulates a call to an external service. +# - a filter: This filter retries the function invocation if it fails. + +logger = logging.getLogger(__name__) + +# The maximum number of retries for the filter +MAX_RETRIES = 3 + + +class WeatherPlugin: + MAX_FAILURES = 2 + + def __init__(self): + self._invocation_count = 0 + + @kernel_function(name="GetWeather", description="Get the weather of the day at the current location.") + def get_wather(self) -> str: + """Get the weather of the day at the current location. + + Simulates a call to an external service to get the weather. + This function is designed to fail a certain number of times before succeeding. + """ + if self._invocation_count < self.MAX_FAILURES: + self._invocation_count += 1 + print(f"Number of attempts: {self._invocation_count}") + raise Exception("Failed to get the weather") + + return "Sunny" + + +async def retry_filter( + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Coroutine[Any, Any, None]], +) -> None: + """A filter that retries the function invocation if it fails. + + The filter uses a binary exponential backoff strategy to retry the function invocation. 
+ """ + for i in range(MAX_RETRIES): + try: + await next(context) + return + except Exception as e: + logger.warning(f"Failed to execute the function: {e}") + backoff = 2**i + logger.info(f"Sleeping for {backoff} seconds before retrying") + + +async def main() -> None: + kernel = Kernel() + # Register the plugin to the kernel + kernel.add_plugin(WeatherPlugin(), plugin_name="WeatherPlugin") + # Add the filter to the kernel as a function invocation filter + # A function invocation filter is called during when the kernel executes a function + kernel.add_filter(FilterTypes.FUNCTION_INVOCATION, retry_filter) + + chat_history = ChatHistory() + chat_history.add_user_message("What is the weather today?") + + chat_completion_service, request_settings = get_chat_completion_service_and_request_settings(Services.OPENAI) + # Need to set the function choice behavior to auto such that the + # service will automatically invoke the function in the response. + request_settings.function_choice_behavior = FunctionChoiceBehavior.Auto() + + response = await chat_completion_service.get_chat_message_content( + chat_history=chat_history, + settings=request_settings, + # Need to pass the kernel to the chat completion service so that it has access to the plugins and filters + kernel=kernel, + ) + + print(response) + + # Sample output: + # Number of attempts: 1 + # Failed to execute the function: Failed to get the weather + # Number of attempts: 2 + # Failed to execute the function: Failed to get the weather + # The weather today is Sunny + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/semantic_kernel/kernel.py b/python/semantic_kernel/kernel.py index 3fd5d33dcc1d..5ddef255f355 100644 --- a/python/semantic_kernel/kernel.py +++ b/python/semantic_kernel/kernel.py @@ -65,8 +65,6 @@ class Kernel(KernelFilterExtension, KernelFunctionExtension, KernelServicesExten plugins: A dict with the plugins registered with the Kernel, from KernelFunctionExtension. 
services: A dict with the services registered with the Kernel, from KernelServicesExtension. ai_service_selector: The AI service selector to be used by the kernel, from KernelServicesExtension. - retry_mechanism: The retry mechanism to be used by the kernel, from KernelReliabilityExtension. - """ def __init__( @@ -84,12 +82,8 @@ def __init__( plugins: The plugins to be used by the kernel, will be rewritten to a dict with plugin name as key services: The services to be used by the kernel, will be rewritten to a dict with service_id as key ai_service_selector: The AI service selector to be used by the kernel, - default is based on order of execution settings. - **kwargs: Additional fields to be passed to the Kernel model, - these are limited to retry_mechanism and function_invoking_handlers - and function_invoked_handlers, the best way to add function_invoking_handlers - and function_invoked_handlers is to use the add_function_invoking_handler - and add_function_invoked_handler methods. + default is based on order of execution settings. + **kwargs: Additional fields to be passed to the Kernel model, these are limited to filters. 
""" args = { "services": services, diff --git a/python/semantic_kernel/reliability/kernel_reliability_extension.py b/python/semantic_kernel/reliability/kernel_reliability_extension.py index 82a020cfdeff..9c89766c47db 100644 --- a/python/semantic_kernel/reliability/kernel_reliability_extension.py +++ b/python/semantic_kernel/reliability/kernel_reliability_extension.py @@ -4,6 +4,7 @@ from abc import ABC from pydantic import Field +from typing_extensions import deprecated from semantic_kernel.kernel_pydantic import KernelBaseModel from semantic_kernel.reliability.pass_through_without_retry import PassThroughWithoutRetry @@ -15,4 +16,7 @@ class KernelReliabilityExtension(KernelBaseModel, ABC): """Kernel reliability extension.""" - retry_mechanism: RetryMechanismBase = Field(default_factory=PassThroughWithoutRetry) + retry_mechanism: RetryMechanismBase = Field( + default_factory=PassThroughWithoutRetry, + deprecated=deprecated("retry_mechanism is deprecated; This property doesn't have any effect on the kernel."), + ) diff --git a/python/tests/unit/kernel/test_kernel.py b/python/tests/unit/kernel/test_kernel.py index 4180994792dd..38b1608d150f 100644 --- a/python/tests/unit/kernel/test_kernel.py +++ b/python/tests/unit/kernel/test_kernel.py @@ -18,10 +18,7 @@ from semantic_kernel.contents import ChatMessageContent from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.function_call_content import FunctionCallContent -from semantic_kernel.exceptions import ( - KernelFunctionAlreadyExistsError, - KernelServiceNotFoundError, -) +from semantic_kernel.exceptions import KernelFunctionAlreadyExistsError, KernelServiceNotFoundError from semantic_kernel.exceptions.content_exceptions import FunctionCallInvalidArgumentsException from semantic_kernel.exceptions.kernel_exceptions import ( KernelFunctionNotFoundError, From adb49f3cb9d607db84e7e7956dbda3d0c5988b33 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 13 Dec 2024 15:50:23 -0800 
Subject: [PATCH 03/11] Python: Fix set-env not allow in workflow (#9973) ### Motivation and Context `set-env` command is disabled in our workflow. ![image](https://github.com/user-attachments/assets/a6813291-e58f-4691-ae05-efd35c04ec15) ### Description Use `$GITHUB_ENV` instead. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: Co-authored-by: Evan Mattson <35585003+moonbox3@users.noreply.github.com> --- .github/workflows/python-test-coverage-report.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-test-coverage-report.yml b/.github/workflows/python-test-coverage-report.yml index 67c848609f6b..f63b81c18d77 100644 --- a/.github/workflows/python-test-coverage-report.yml +++ b/.github/workflows/python-test-coverage-report.yml @@ -32,7 +32,7 @@ jobs: run: | PR_NUMBER=$(cat pr_number) echo "PR number: $PR_NUMBER" - echo "::set-env name=PR_NUMBER::$PR_NUMBER" + echo "PR_NUMBER=$PR_NUMBER" >> $GITHUB_ENV - name: Pytest coverage comment id: coverageComment uses: MishaKav/pytest-coverage-comment@main From 12a4d4095798c4720447816f4e834bd038faee58 Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Mon, 16 Dec 2024 10:04:04 +0100 Subject: [PATCH 04/11] Python: add write token to report (#9980) ### Motivation and Context The coverage report write job did not have the right permissions yet. 
### Description ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --- .github/workflows/python-test-coverage-report.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/python-test-coverage-report.yml b/.github/workflows/python-test-coverage-report.yml index f63b81c18d77..01a2c7dc48e8 100644 --- a/.github/workflows/python-test-coverage-report.yml +++ b/.github/workflows/python-test-coverage-report.yml @@ -6,6 +6,10 @@ on: types: - completed +permissions: + contents: read + pull-requests: write + jobs: python-test-coverage-report: runs-on: ubuntu-latest @@ -37,6 +41,7 @@ jobs: id: coverageComment uses: MishaKav/pytest-coverage-comment@main with: + github-token: ${{ secrets.GH_ACTIONS_PR_WRITE }} issue-number: ${{ env.PR_NUMBER }} pytest-coverage-path: python/python-coverage.txt title: "Python Test Coverage Report" From c7a371e3861bda3812a11950435fc5aa75c5572d Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Mon, 16 Dec 2024 04:33:04 -0500 Subject: [PATCH 05/11] .Net: Add PostgresVectorStore Memory connector. (#9324) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds a PostgresVectorStore and related classes to Microsoft.SemanticKernel.Connectors.Postgres. 
### Motivation and Context As part of the move to having memory connectors implement the new Microsoft.Extensions.VectorData.IVectorStore architecture (see https://github.com/microsoft/semantic-kernel/blob/main/docs/decisions/0050-updated-vector-store-design.md), each memory connector needs to be updated with the new architecture. This PR tackles updating the existing Microsoft.SemanticKernel.Connectors.Postgres package to include this implementation. This will supersede the PostgresMemoryStore implementation. Some high level comments about design: - PostgresVectorStore and PostgresVectorStoreRecordCollection get injected with an IPostgresVectorStoreDbClient. This abstracts the database communication and allows for unit tests to mock database interactions. - The PostgresVectorStoreDbClient gets passed in a NpgsqlDataSource from the user, which is used to manage connections to the database. The responsibility of connection pool lifecycle management is on the user. - The IPostgresVectorStoreDbClient is designed to accept and produce the storage model, which in this case is a Dictionary . This is the intermediate type that is mapped to by the IVectorStoreRecordMapper. - The PostgresVectorStoreDbClient also takes an IPostgresVectorStoreCollectionSqlBuilder, which generates SQL command information for interacting with the database. This abstracts the SQL queries related to each task, and allows for future expansion. This is particularly targeted at creating an AzureDBForPostgre vector store that will enable alternate vector implementations like [DiskANN](https://techcommunity.microsoft.com/t5/azure-database-for-postgresql/introducing-diskann-vector-index-in-azure-database-for/ba-p/4261192), while leveraging the same database client as the Postgres connector. -  The integration tests for the vector store utilize Docker.Net to bring up a pgvector/pgvector docker container, which tests are run against. 
### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --------- Co-authored-by: Rob Emanuele Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> --- dotnet/SK-dotnet.sln | 9 + dotnet/samples/Concepts/Concepts.csproj | 3 + .../VectorStoreFixtures/VectorStoreInfra.cs | 45 ++ .../VectorStorePostgresContainerFixture.cs | 67 +++ ...rStore_VectorSearch_MultiStore_Postgres.cs | 85 +++ dotnet/samples/Concepts/README.md | 82 +++ .../Connectors.Memory.Postgres.csproj | 5 + .../IPostgresDbClient.cs | 2 +- ...PostgresVectorStoreCollectionSqlBuilder.cs | 136 +++++ .../IPostgresVectorStoreDbClient.cs | 132 ++++ ...tgresVectorStoreRecordCollectionFactory.cs | 24 + .../PostgresConstants.cs | 92 +++ .../PostgresDbClient.cs | 2 +- .../PostgresGenericDataModelMapper.cs | 104 ++++ .../PostgresServiceCollectionExtensions.cs | 172 ++++++ .../PostgresSqlCommandInfo.cs | 55 ++ .../PostgresVectorStore.cs | 75 +++ ...PostgresVectorStoreCollectionSqlBuilder.cs | 453 ++++++++++++++ .../PostgresVectorStoreDbClient.cs | 253 ++++++++ .../PostgresVectorStoreOptions.cs | 19 + .../PostgresVectorStoreRecordCollection.cs | 378 ++++++++++++ ...tgresVectorStoreRecordCollectionOptions.cs | 35 ++ .../PostgresVectorStoreRecordMapper.cs | 100 ++++ ...ostgresVectorStoreRecordPropertyMapping.cs | 269 +++++++++ .../PostgresVectorStoreUtils.cs | 59 ++ .../Connectors.Memory.Postgres/README.md | 75 +-- .../Connectors.Postgres.UnitTests.csproj | 32 + .../PostgresGenericDataModelMapperTests.cs | 190 ++++++ .../PostgresHotel.cs | 51 ++ 
...ostgresServiceCollectionExtensionsTests.cs | 70 +++ ...resVectorStoreCollectionSqlBuilderTests.cs | 422 +++++++++++++ ...ostgresVectorStoreRecordCollectionTests.cs | 207 +++++++ .../PostgresVectorStoreRecordMapperTests.cs | 213 +++++++ ...esVectorStoreRecordPropertyMappingTests.cs | 147 +++++ .../PostgresVectorStoreTests.cs | 143 +++++ .../Memory/Postgres/PostgresHotel.cs | 60 ++ .../Postgres/PostgresMemoryStoreTests.cs | 6 +- .../PostgresVectorStoreCollectionFixture.cs | 10 + .../Postgres/PostgresVectorStoreFixture.cs | 239 ++++++++ ...ostgresVectorStoreRecordCollectionTests.cs | 562 ++++++++++++++++++ .../Postgres/PostgresVectorStoreTests.cs | 28 + .../src/Linq/AsyncEnumerable.cs | 35 ++ 42 files changed, 5074 insertions(+), 72 deletions(-) create mode 100644 dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStorePostgresContainerFixture.cs create mode 100644 dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs create mode 
100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreCollectionFixture.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs create mode 100644 
dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 0844db359552..0a711f84f5f3 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -411,6 +411,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AotCompatibility", "samples EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SemanticKernel.AotTests", "src\SemanticKernel.AotTests\SemanticKernel.AotTests.csproj", "{39EAB599-742F-417D-AF80-95F90376BB18}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.Postgres.UnitTests", "src\Connectors\Connectors.Postgres.UnitTests\Connectors.Postgres.UnitTests.csproj", "{232E1153-6366-4175-A982-D66B30AAD610}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Process.Utilities.UnitTests", "src\Experimental\Process.Utilities.UnitTests\Process.Utilities.UnitTests.csproj", "{DAC54048-A39A-4739-8307-EA5A291F2EA0}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "GettingStartedWithVectorStores", "samples\GettingStartedWithVectorStores\GettingStartedWithVectorStores.csproj", "{8C3DE41C-E2C8-42B9-8638-574F8946EB0E}" @@ -1074,6 +1076,12 @@ Global {6F591D05-5F7F-4211-9042-42D8BCE60415}.Publish|Any CPU.Build.0 = Debug|Any CPU {6F591D05-5F7F-4211-9042-42D8BCE60415}.Release|Any CPU.ActiveCfg = Release|Any CPU {6F591D05-5F7F-4211-9042-42D8BCE60415}.Release|Any CPU.Build.0 = Release|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Debug|Any CPU.Build.0 = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Publish|Any CPU.Build.0 = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Release|Any CPU.ActiveCfg = Release|Any CPU + 
{232E1153-6366-4175-A982-D66B30AAD610}.Release|Any CPU.Build.0 = Release|Any CPU {E82B640C-1704-430D-8D71-FD8ED3695468}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {E82B640C-1704-430D-8D71-FD8ED3695468}.Debug|Any CPU.Build.0 = Debug|Any CPU {E82B640C-1704-430D-8D71-FD8ED3695468}.Publish|Any CPU.ActiveCfg = Debug|Any CPU @@ -1311,6 +1319,7 @@ Global {E82B640C-1704-430D-8D71-FD8ED3695468} = {5A7028A7-4DDF-4E4F-84A9-37CE8F8D7E89} {6ECFDF04-2237-4A85-B114-DAA34923E9E6} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {39EAB599-742F-417D-AF80-95F90376BB18} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0} + {232E1153-6366-4175-A982-D66B30AAD610} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {DAC54048-A39A-4739-8307-EA5A291F2EA0} = {0D8C6358-5DAA-4EA6-A924-C268A9A21BC9} {8C3DE41C-E2C8-42B9-8638-574F8946EB0E} = {FA3720F1-C99A-49B2-9577-A940257098BF} {DB58FDD0-308E-472F-BFF5-508BC64C727E} = {0D8C6358-5DAA-4EA6-A924-C268A9A21BC9} diff --git a/dotnet/samples/Concepts/Concepts.csproj b/dotnet/samples/Concepts/Concepts.csproj index d65aef92e0c3..746d5fbb73cf 100644 --- a/dotnet/samples/Concepts/Concepts.csproj +++ b/dotnet/samples/Concepts/Concepts.csproj @@ -102,6 +102,9 @@ + + Always + PreserveNewest diff --git a/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs index ea498f20c5ab..2681231c80d7 100644 --- a/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs +++ b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs @@ -10,6 +10,51 @@ namespace Memory.VectorStoreFixtures; /// internal static class VectorStoreInfra { + /// + /// Setup the postgres pgvector container by pulling the image and running it. + /// + /// The docker client to create the container with. + /// The id of the container. 
+ public static async Task SetupPostgresContainerAsync(DockerClient client) + { + await client.Images.CreateImageAsync( + new ImagesCreateParameters + { + FromImage = "pgvector/pgvector", + Tag = "pg16", + }, + null, + new Progress()); + + var container = await client.Containers.CreateContainerAsync(new CreateContainerParameters() + { + Image = "pgvector/pgvector:pg16", + HostConfig = new HostConfig() + { + PortBindings = new Dictionary> + { + {"5432", new List {new() {HostPort = "5432" } }}, + }, + PublishAllPorts = true + }, + ExposedPorts = new Dictionary + { + { "5432", default }, + }, + Env = new List + { + "POSTGRES_USER=postgres", + "POSTGRES_PASSWORD=example", + }, + }); + + await client.Containers.StartContainerAsync( + container.ID, + new ContainerStartParameters()); + + return container.ID; + } + /// /// Setup the qdrant container by pulling the image and running it. /// diff --git a/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStorePostgresContainerFixture.cs b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStorePostgresContainerFixture.cs new file mode 100644 index 000000000000..200c4e48f5ac --- /dev/null +++ b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStorePostgresContainerFixture.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet; +using Npgsql; + +namespace Memory.VectorStoreFixtures; + +/// +/// Fixture to use for creating a Postgres container before tests and delete it after tests. +/// +public class VectorStorePostgresContainerFixture : IAsyncLifetime +{ + private DockerClient? _dockerClient; + private string? _postgresContainerId; + + public async Task InitializeAsync() + { + } + + public async Task ManualInitializeAsync() + { + if (this._postgresContainerId == null) + { + // Connect to docker and start the docker container. 
+ using var dockerClientConfiguration = new DockerClientConfiguration(); + this._dockerClient = dockerClientConfiguration.CreateClient(); + this._postgresContainerId = await VectorStoreInfra.SetupPostgresContainerAsync(this._dockerClient); + + // Delay until the Postgres server is ready. + var connectionString = TestConfiguration.Postgres.ConnectionString; + var succeeded = false; + var attemptCount = 0; + while (!succeeded && attemptCount++ < 10) + { + try + { + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); + dataSourceBuilder.UseVector(); + using var dataSource = dataSourceBuilder.Build(); + NpgsqlConnection connection = await dataSource.OpenConnectionAsync().ConfigureAwait(false); + + await using (connection) + { + // Create extension vector if it doesn't exist + await using (NpgsqlCommand command = new("CREATE EXTENSION IF NOT EXISTS vector", connection)) + { + await command.ExecuteNonQueryAsync(); + } + } + } + catch (Exception) + { + await Task.Delay(1000); + } + } + } + } + + public async Task DisposeAsync() + { + if (this._dockerClient != null && this._postgresContainerId != null) + { + // Delete docker container. + await VectorStoreInfra.DeleteContainerAsync(this._dockerClient, this._postgresContainerId); + } + } +} diff --git a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs new file mode 100644 index 000000000000..e45c3390a2c0 --- /dev/null +++ b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using Azure.Identity; +using Memory.VectorStoreFixtures; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Npgsql; + +namespace Memory; + +/// +/// An example showing how to use common code, that can work with any vector database, with a Postgres database. +/// The common code is in the class. +/// The common code ingests data into the vector store and then searches over that data. +/// This example is part of a set of examples each showing a different vector database. +/// +/// For other databases, see the following classes: +/// +/// +/// +/// +/// To run this sample, you need a local instance of Docker running, since the associated fixture will try and start a Postgres container in the local docker instance. +/// +public class VectorStore_VectorSearch_MultiStore_Postgres(ITestOutputHelper output, VectorStorePostgresContainerFixture PostgresFixture) : BaseTest(output), IClassFixture +{ + [Fact] + public async Task ExampleWithDIAsync() + { + // Use the kernel for DI purposes. + var kernelBuilder = Kernel + .CreateBuilder(); + + // Register an embedding generation service with the DI container. + kernelBuilder.AddAzureOpenAITextEmbeddingGeneration( + deploymentName: TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, + endpoint: TestConfiguration.AzureOpenAIEmbeddings.Endpoint, + credential: new AzureCliCredential()); + + // Initialize the Postgres docker container via the fixtures and register the Postgres VectorStore. + await PostgresFixture.ManualInitializeAsync(); + kernelBuilder.Services.AddPostgresVectorStore(TestConfiguration.Postgres.ConnectionString); + + // Register the test output helper common processor with the DI container. + kernelBuilder.Services.AddSingleton(this.Output); + kernelBuilder.Services.AddTransient(); + + // Build the kernel. 
+ var kernel = kernelBuilder.Build(); + + // Build a common processor object using the DI container. + var processor = kernel.GetRequiredService(); + + // Run the process and pass a key generator function to it, to generate unique record keys. + // The key generator function is required, since different vector stores may require different key types. + // E.g. Postgres supports Guid and ulong keys, but others may support strings only. + await processor.IngestDataAndSearchAsync("skglossaryWithDI", () => Guid.NewGuid()); + } + + [Fact] + public async Task ExampleWithoutDIAsync() + { + // Create an embedding generation service. + var textEmbeddingGenerationService = new AzureOpenAITextEmbeddingGenerationService( + TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, + TestConfiguration.AzureOpenAIEmbeddings.Endpoint, + new AzureCliCredential()); + + // Initialize the Postgres docker container via the fixtures and construct the Postgres VectorStore. + await PostgresFixture.ManualInitializeAsync(); + var dataSourceBuilder = new NpgsqlDataSourceBuilder(TestConfiguration.Postgres.ConnectionString); + dataSourceBuilder.UseVector(); + await using var dataSource = dataSourceBuilder.Build(); + var vectorStore = new PostgresVectorStore(dataSource); + + // Create the common processor that works for any vector store. + var processor = new VectorStore_VectorSearch_MultiStore_Common(vectorStore, textEmbeddingGenerationService, this.Output); + + // Run the process and pass a key generator function to it, to generate unique record keys. + // The key generator function is required, since different vector stores may require different key types. + // E.g. Postgres supports Guid and ulong keys, but others may support strings only. 
+ await processor.IngestDataAndSearchAsync("skglossaryWithoutDI", () => Guid.NewGuid()); + } +} diff --git a/dotnet/samples/Concepts/README.md b/dotnet/samples/Concepts/README.md index 6b0f28b329ca..deb3a6a43a20 100644 --- a/dotnet/samples/Concepts/README.md +++ b/dotnet/samples/Concepts/README.md @@ -215,3 +215,85 @@ dotnet test -l "console;verbosity=detailed" --filter "FullyQualifiedName=ChatCom - [OpenAI_TextToImage](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/TextToImage/OpenAI_TextToImage.cs) - [OpenAI_TextToImageLegacy](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/TextToImage/OpenAI_TextToImageLegacy.cs) - [AzureOpenAI_TextToImage](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/TextToImage/AzureOpenAI_TextToImage.cs) + +## Configuration + +### Option 1: Use Secret Manager + +Concept samples will require secrets and credentials, to access OpenAI, Azure OpenAI, +Bing and other resources. + +We suggest using .NET [Secret Manager](https://learn.microsoft.com/en-us/aspnet/core/security/app-secrets) +to avoid the risk of leaking secrets into the repository, branches and pull requests. +You can also use environment variables if you prefer. + +To set your secrets with Secret Manager: + +``` +cd dotnet/src/samples/Concepts +dotnet user-secrets init +dotnet user-secrets set "OpenAI:ServiceId" "gpt-3.5-turbo-instruct" +dotnet user-secrets set "OpenAI:ModelId" "gpt-3.5-turbo-instruct" +dotnet user-secrets set "OpenAI:ChatModelId" "gpt-4" +dotnet user-secrets set "OpenAI:ApiKey" "..." +... +``` + +### Option 2: Use Configuration File +1. Create a `appsettings.Development.json` file next to the `Concepts.csproj` file. This file will be ignored by git, + the content will not end up in pull requests, so it's safe for personal settings. Keep the file safe. +2. Edit `appsettings.Development.json` and set the appropriate configuration for the samples you are running. 
+ +For example: + +```json +{ + "OpenAI": { + "ServiceId": "gpt-3.5-turbo-instruct", + "ModelId": "gpt-3.5-turbo-instruct", + "ChatModelId": "gpt-4", + "ApiKey": "sk-...." + }, + "AzureOpenAI": { + "ServiceId": "azure-gpt-35-turbo-instruct", + "DeploymentName": "gpt-35-turbo-instruct", + "ChatDeploymentName": "gpt-4", + "Endpoint": "https://contoso.openai.azure.com/", + "ApiKey": "...." + }, + // etc. +} +``` + +### Option 3: Use Environment Variables +You may also set the settings in your environment variables. The environment variables will override the settings in the `appsettings.Development.json` file. + +When setting environment variables, use a double underscore (i.e. "\_\_") to delineate between parent and child properties. For example: + +- bash: + + ```bash + export OpenAI__ApiKey="sk-...." + export AzureOpenAI__ApiKey="...." + export AzureOpenAI__DeploymentName="gpt-35-turbo-instruct" + export AzureOpenAI__ChatDeploymentName="gpt-4" + export AzureOpenAIEmbeddings__DeploymentName="azure-text-embedding-ada-002" + export AzureOpenAI__Endpoint="https://contoso.openai.azure.com/" + export HuggingFace__ApiKey="...." + export Bing__ApiKey="...." + export Postgres__ConnectionString="...." + ``` + +- PowerShell: + + ```ps + $env:OpenAI__ApiKey = "sk-...." + $env:AzureOpenAI__ApiKey = "...." + $env:AzureOpenAI__DeploymentName = "gpt-35-turbo-instruct" + $env:AzureOpenAI__ChatDeploymentName = "gpt-4" + $env:AzureOpenAIEmbeddings__DeploymentName = "azure-text-embedding-ada-002" + $env:AzureOpenAI__Endpoint = "https://contoso.openai.azure.com/" + $env:HuggingFace__ApiKey = "...." + $env:Bing__ApiKey = "...." + $env:Postgres__ConnectionString = "...." 
+ ``` diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj index a5ec850f1b6e..b1904c6cc1cd 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj @@ -29,4 +29,9 @@ + + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs index 70747990e2fd..2af6d4f5fb62 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs @@ -9,7 +9,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// -/// Interface for client managing postgres database operations. +/// Interface for client managing postgres database operations for . /// public interface IPostgresDbClient { diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs new file mode 100644 index 000000000000..d130d2f13b44 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Pgvector; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Interface for constructing SQL commands for Postgres vector store collections. +/// +internal interface IPostgresVectorStoreCollectionSqlBuilder +{ + /// + /// Builds a SQL command to check if a table exists in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The built SQL command. 
+ /// + /// The command must return a single row with a single column named "table_name" if the table exists. + /// + PostgresSqlCommandInfo BuildDoesTableExistCommand(string schema, string tableName); + + /// + /// Builds a SQL command to fetch all tables in the Postgres vector store. + /// + /// The schema of the tables. + PostgresSqlCommandInfo BuildGetTablesCommand(string schema); + + /// + /// Builds a SQL command to create a table in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The properties of the table. + /// Specifies whether to include IF NOT EXISTS in the command. + /// The built SQL command info. + PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, IReadOnlyList properties, bool ifNotExists = true); + + /// + /// Builds a SQL command to create a vector index in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The name of the vector column. + /// The kind of index to create. + /// The distance function to use for the index. + /// The built SQL command info. + PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction); + + /// + /// Builds a SQL command to drop a table in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The built SQL command info. + PostgresSqlCommandInfo BuildDropTableCommand(string schema, string tableName); + + /// + /// Builds a SQL command to upsert a record in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The key column of the table. + /// The row to upsert. + /// The built SQL command info. + PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName, string keyColumn, Dictionary row); + + /// + /// Builds a SQL command to upsert a batch of records in the Postgres vector store. 
+ /// + /// The schema of the table. + /// The name of the table. + /// The key column of the table. + /// The rows to upsert. + /// The built SQL command info. + PostgresSqlCommandInfo BuildUpsertBatchCommand(string schema, string tableName, string keyColumn, List> rows); + + /// + /// Builds a SQL command to get a record from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The properties of the table. + /// The key of the record to get. + /// Specifies whether to include vectors in the record. + /// The built SQL command info. + PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, IReadOnlyList properties, TKey key, bool includeVectors = false) where TKey : notnull; + + /// + /// Builds a SQL command to get a batch of records from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The properties of the table. + /// The keys of the records to get. + /// Specifies whether to include vectors in the records. + /// The built SQL command info. + PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, IReadOnlyList properties, List keys, bool includeVectors = false) where TKey : notnull; + + /// + /// Builds a SQL command to delete a record from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The key column of the table. + /// The key of the record to delete. + /// The built SQL command info. + PostgresSqlCommandInfo BuildDeleteCommand(string schema, string tableName, string keyColumn, TKey key); + + /// + /// Builds a SQL command to delete a batch of records from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The key column of the table. + /// The keys of the records to delete. + /// The built SQL command info. 
+ PostgresSqlCommandInfo BuildDeleteBatchCommand(string schema, string tableName, string keyColumn, List keys); + + /// + /// Builds a SQL command to get the nearest match from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The properties of the table. + /// The property which the vectors to compare are stored in. + /// The vector to match. + /// The filter conditions for the query. + /// The number of records to skip. + /// Specifies whether to include vectors in the result. + /// The maximum number of records to return. + /// The built SQL command info. + PostgresSqlCommandInfo BuildGetNearestMatchCommand(string schema, string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, VectorSearchFilter? filter, int? skip, bool includeVectors, int limit); +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs new file mode 100644 index 000000000000..59aa9829c568 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Npgsql; +using Pgvector; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Internal interface for client managing postgres database operations. +/// +internal interface IPostgresVectorStoreDbClient +{ + /// + /// The used to connect to the database. + /// + public NpgsqlDataSource DataSource { get; } + + /// + /// Check if a table exists. + /// + /// The name assigned to a table of entries. + /// The to monitor for cancellation requests. The default is . 
+ /// + Task DoesTableExistsAsync(string tableName, CancellationToken cancellationToken = default); + + /// + /// Get all tables. + /// + /// The to monitor for cancellation requests. The default is . + /// A group of tables. + IAsyncEnumerable GetTablesAsync(CancellationToken cancellationToken = default); + /// + /// Create a table. Also creates an index on vector columns if the table has vector properties defined. + /// + /// The name assigned to a table of entries. + /// The properties of the record definition that define the table. + /// Specifies whether to include IF NOT EXISTS in the command. + /// The to monitor for cancellation requests. The default is . + /// + Task CreateTableAsync(string tableName, IReadOnlyList properties, bool ifNotExists = true, CancellationToken cancellationToken = default); + + /// + /// Drop a table. + /// + /// The name assigned to a table of entries. + /// The to monitor for cancellation requests. The default is . + Task DeleteTableAsync(string tableName, CancellationToken cancellationToken = default); + + /// + /// Upsert entry into a table. + /// + /// The name assigned to a table of entries. + /// The row to upsert into the table. + /// The key column of the table. + /// The to monitor for cancellation requests. The default is . + /// + Task UpsertAsync(string tableName, Dictionary row, string keyColumn, CancellationToken cancellationToken = default); + + /// + /// Upsert multiple entries into a table. + /// + /// The name assigned to a table of entries. + /// The rows to upsert into the table. + /// The key column of the table. + /// The to monitor for cancellation requests. The default is . + /// + Task UpsertBatchAsync(string tableName, IEnumerable> rows, string keyColumn, CancellationToken cancellationToken = default); + + /// + /// Get a entry by its key. + /// + /// The name assigned to a table of entries. + /// The key of the entry to get. + /// The properties to include in the entry. 
+ /// If true, the vectors will be included in the entry. + /// The to monitor for cancellation requests. The default is . + /// The row if the key is found, otherwise null. + Task?> GetAsync(string tableName, TKey key, IReadOnlyList properties, bool includeVectors = false, CancellationToken cancellationToken = default) + where TKey : notnull; + + /// + /// Get multiple entries by their keys. + /// + /// The name assigned to a table of entries. + /// The keys of the entries to get. + /// The properties of the table. + /// If true, the vectors will be included in the entries. + /// The to monitor for cancellation requests. The default is . + /// The rows that match the given keys. + IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, IReadOnlyList properties, bool includeVectors = false, CancellationToken cancellationToken = default) + where TKey : notnull; + + /// + /// Delete a entry by its key. + /// + /// The name assigned to a table of entries. + /// The name of the key column. + /// The key of the entry to delete. + /// The to monitor for cancellation requests. The default is . + /// + Task DeleteAsync(string tableName, string keyColumn, TKey key, CancellationToken cancellationToken = default); + + /// + /// Delete multiple entries by their keys. + /// + /// The name assigned to a table of entries. + /// The name of the key column. + /// The keys of the entries to delete. + /// The to monitor for cancellation requests. The default is . + /// + Task DeleteBatchAsync(string tableName, string keyColumn, IEnumerable keys, CancellationToken cancellationToken = default); + + /// + /// Gets the nearest matches to the . + /// + /// The name assigned to a table of entries. + /// The properties to retrieve. + /// The property which the vectors to compare are stored in. + /// The to compare the table's vector with. + /// The maximum number of similarity results to return. + /// Optional conditions to filter the results. 
+ /// The number of entries to skip. + /// If true, the vectors will be returned in the entries. + /// The to monitor for cancellation requests. The default is . + /// An asynchronous stream of objects that the nearest matches to the . + IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync(string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, + VectorSearchFilter? filter = default, int? skip = default, bool includeVectors = false, CancellationToken cancellationToken = default); +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs new file mode 100644 index 000000000000..5bf0d9cad789 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using Npgsql; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Interface for constructing Postgres instances when using to retrieve these. +/// +public interface IPostgresVectorStoreRecordCollectionFactory +{ + /// + /// Constructs a new instance of the . + /// + /// The data type of the record key. + /// The data model to use for adding, updating and retrieving data from storage. + /// The Postgres data source. + /// The name of the collection to connect to. + /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. + /// The new instance of . + IVectorStoreRecordCollection CreateVectorStoreRecordCollection(NpgsqlDataSource dataSource, string name, VectorStoreRecordDefinition? 
vectorStoreRecordDefinition) + where TKey : notnull; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs new file mode 100644 index 000000000000..f8784890e83a --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal static class PostgresConstants +{ + /// The name of this database for telemetry purposes. + public const string DatabaseName = "Postgres"; + + /// A of types that a key on the provided model may have. + public static readonly HashSet SupportedKeyTypes = + [ + typeof(short), + typeof(int), + typeof(long), + typeof(string), + typeof(Guid), + ]; + + /// A of types that data properties on the provided model may have. + public static readonly HashSet SupportedDataTypes = + [ + typeof(bool), + typeof(bool?), + typeof(short), + typeof(short?), + typeof(int), + typeof(int?), + typeof(long), + typeof(long?), + typeof(float), + typeof(float?), + typeof(double), + typeof(double?), + typeof(decimal), + typeof(decimal?), + typeof(string), + typeof(DateTime), + typeof(DateTime?), + typeof(DateTimeOffset), + typeof(DateTimeOffset?), + typeof(Guid), + typeof(Guid?), + typeof(byte[]), + ]; + + /// A of types that enumerable data properties on the provided model may use as their element types. + public static readonly HashSet SupportedEnumerableDataElementTypes = + [ + typeof(bool), + typeof(short), + typeof(int), + typeof(long), + typeof(float), + typeof(double), + typeof(decimal), + typeof(string), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(Guid), + ]; + + /// A of types that vector properties on the provided model may have. 
+ public static readonly HashSet SupportedVectorTypes = + [ + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?) + ]; + + /// The default schema name. + public const string DefaultSchema = "public"; + + /// The name of the column that returns distance value in the database. + /// It is used in the similarity search query. Must not conflict with model property. + public const string DistanceColumnName = "sk_pg_distance"; + + /// The default index kind. + /// Defaults to "Flat", which means no indexing. + public const string DefaultIndexKind = IndexKind.Flat; + + /// The default distance function. + public const string DefaultDistanceFunction = DistanceFunction.CosineDistance; + + public static readonly Dictionary IndexMaxDimensions = new() + { + { IndexKind.Hnsw, 2000 }, + }; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs index 1dc1ffef3c1d..d927710d4fd9 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs @@ -13,7 +13,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// -/// An implementation of a client for Postgres. This class is used to managing postgres database operations. +/// An implementation of a client for Postgres. This class is used to managing postgres database operations for . 
/// [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")] public class PostgresDbClient : IPostgresDbClient diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs new file mode 100644 index 000000000000..efdec538c772 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal sealed class PostgresGenericDataModelMapper : IVectorStoreRecordMapper, Dictionary> + where TKey : notnull +{ + /// with helpers for reading vector store model properties and their attributes. + private readonly VectorStoreRecordPropertyReader _propertyReader; + + /// + /// Initializes a new instance of the class. + /// /// + /// with helpers for reading vector store model properties and their attributes. + public PostgresGenericDataModelMapper(VectorStoreRecordPropertyReader propertyReader) + { + Verify.NotNull(propertyReader); + + this._propertyReader = propertyReader; + + // Validate property types. 
+ this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, PostgresConstants.SupportedEnumerableDataElementTypes); + this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); + } + + public Dictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) + { + var properties = new Dictionary + { + // Add key property + { this._propertyReader.KeyPropertyStoragePropertyName, dataModel.Key } + }; + + // Add data properties + if (dataModel.Data is not null) + { + foreach (var property in this._propertyReader.DataProperties) + { + if (dataModel.Data.TryGetValue(property.DataModelPropertyName, out var dataValue)) + { + properties.Add(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), dataValue); + } + } + } + + // Add vector properties + if (dataModel.Vectors is not null) + { + foreach (var property in this._propertyReader.VectorProperties) + { + if (dataModel.Vectors.TryGetValue(property.DataModelPropertyName, out var vectorValue)) + { + var result = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vectorValue); + properties.Add(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), result); + } + } + } + + return properties; + } + + public VectorStoreGenericDataModel MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) + { + TKey key; + var dataProperties = new Dictionary(); + var vectorProperties = new Dictionary(); + + // Process key property. + if (storageModel.TryGetValue(this._propertyReader.KeyPropertyStoragePropertyName, out var keyObject) && keyObject is not null) + { + key = (TKey)keyObject; + } + else + { + throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); + } + + // Process data properties. 
+ foreach (var property in this._propertyReader.DataProperties) + { + if (storageModel.TryGetValue(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), out var dataValue)) + { + dataProperties.Add(property.DataModelPropertyName, dataValue); + } + } + + // Process vector properties + if (options.IncludeVectors) + { + foreach (var property in this._propertyReader.VectorProperties) + { + if (storageModel.TryGetValue(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), out var vectorValue)) + { + vectorProperties.Add(property.DataModelPropertyName, PostgresVectorStoreRecordPropertyMapping.MapVectorForDataModel(vectorValue)); + } + } + } + + return new VectorStoreGenericDataModel(key) { Data = dataProperties, Vectors = vectorProperties }; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs new file mode 100644 index 000000000000..983b8e7db443 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Npgsql; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods to register Postgres instances on an . +/// +public static class PostgresServiceCollectionExtensions +{ + /// + /// Register a Postgres with the specified service ID and where the NpgsqlDataSource is retrieved from the dependency injection container. + /// + /// The to register the on. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddPostgresVectorStore(this IServiceCollection services, PostgresVectorStoreOptions? 
options = default, string? serviceId = default) + { + // Since we are not constructing the data source, add the IVectorStore as transient, since we + // cannot make assumptions about how data source is being managed. + services.AddKeyedTransient( + serviceId, + (sp, obj) => + { + var dataSource = sp.GetRequiredService(); + var selectedOptions = options ?? sp.GetService(); + + return new PostgresVectorStore( + dataSource, + selectedOptions); + }); + + return services; + } + + /// + /// Register a Postgres with the specified service ID and where an NpgsqlDataSource is constructed using the provided parameters. + /// + /// The to register the on. + /// Postgres database connection string. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddPostgresVectorStore(this IServiceCollection services, string connectionString, PostgresVectorStoreOptions? options = default, string? serviceId = default) + { + string? npgsqlServiceId = serviceId == null ? default : $"{serviceId}_NpgsqlDataSource"; + // Register NpgsqlDataSource to ensure proper disposal. + services.AddKeyedSingleton( + npgsqlServiceId, + (sp, obj) => + { + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); + dataSourceBuilder.UseVector(); + return dataSourceBuilder.Build(); + }); + + services.AddKeyedSingleton( + serviceId, + (sp, obj) => + { + var dataSource = sp.GetRequiredKeyedService(npgsqlServiceId); + var selectedOptions = options ?? sp.GetService(); + + return new PostgresVectorStore( + dataSource, + selectedOptions); + }); + + return services; + } + + /// + /// Register a Postgres and with the specified service ID + /// and where the NpgsqlDataSource is retrieved from the dependency injection container. + /// + /// The type of the key. + /// The type of the record. + /// The to register the on. + /// The name of the collection. 
+ /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// Service collection. + public static IServiceCollection AddPostgresVectorStoreRecordCollection( + this IServiceCollection services, + string collectionName, + PostgresVectorStoreRecordCollectionOptions? options = default, + string? serviceId = default) + where TKey : notnull + { + services.AddKeyedTransient>( + serviceId, + (sp, obj) => + { + var dataSource = sp.GetRequiredService(); + var selectedOptions = options ?? sp.GetService>(); + + return (new PostgresVectorStoreRecordCollection(dataSource, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; + }); + + AddVectorizedSearch(services, serviceId); + + return services; + } + + /// + /// Register a Postgres and with the specified service ID + /// and where the NpgsqlDataSource is constructed using the provided parameters. + /// + /// The type of the key. + /// The type of the record. + /// The to register the on. + /// The name of the collection. + /// Postgres database connection string. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// Service collection. + public static IServiceCollection AddPostgresVectorStoreRecordCollection( + this IServiceCollection services, + string collectionName, + string connectionString, + PostgresVectorStoreRecordCollectionOptions? options = default, + string? serviceId = default) + where TKey : notnull + { + string? npgsqlServiceId = serviceId == null ? default : $"{serviceId}_NpgsqlDataSource"; + // Register NpgsqlDataSource to ensure proper disposal. 
+ services.AddKeyedSingleton( + npgsqlServiceId, + (sp, obj) => + { + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); + dataSourceBuilder.UseVector(); + return dataSourceBuilder.Build(); + }); + + services.AddKeyedSingleton>( + serviceId, + (sp, obj) => + { + var dataSource = sp.GetRequiredKeyedService(npgsqlServiceId); + + return (new PostgresVectorStoreRecordCollection(dataSource, collectionName, options) as IVectorStoreRecordCollection)!; + }); + + AddVectorizedSearch(services, serviceId); + + return services; + } + + /// + /// Also register the with the given as a . + /// + /// The type of the key. + /// The type of the data model that the collection should contain. + /// The service collection to register on. + /// The service id that the registrations should use. + private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) + where TKey : notnull + { + services.AddKeyedTransient>( + serviceId, + (sp, obj) => + { + return sp.GetRequiredKeyedService>(serviceId); + }); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs new file mode 100644 index 000000000000..fb520188b84b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using Npgsql; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Represents a SQL command for Postgres. +/// +internal class PostgresSqlCommandInfo +{ + /// + /// Gets or sets the SQL command text. + /// + public string CommandText { get; set; } + /// + /// Gets or sets the parameters for the SQL command. + /// + public List? Parameters { get; set; } = null; + + /// + /// Initializes a new instance of the class. + /// + /// The SQL command text. 
+ /// The parameters for the SQL command. + public PostgresSqlCommandInfo(string commandText, List? parameters = null) + { + this.CommandText = commandText; + this.Parameters = parameters; + } + + /// + /// Converts this instance to an . + /// + [SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "User input is passed using command parameters.")] + public NpgsqlCommand ToNpgsqlCommand(NpgsqlConnection connection, NpgsqlTransaction? transaction = null) + { + NpgsqlCommand cmd = connection.CreateCommand(); + if (transaction != null) + { + cmd.Transaction = transaction; + } + cmd.CommandText = this.CommandText; + if (this.Parameters != null) + { + foreach (var parameter in this.Parameters) + { + cmd.Parameters.Add(parameter); + } + } + return cmd; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs new file mode 100644 index 000000000000..99bbc8e320b5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading; +using Microsoft.Extensions.VectorData; +using Npgsql; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Represents a vector store implementation using PostgreSQL. +/// +public class PostgresVectorStore : IVectorStore +{ + private readonly IPostgresVectorStoreDbClient _postgresClient; + private readonly NpgsqlDataSource? _dataSource; + private readonly PostgresVectorStoreOptions _options; + + /// + /// Initializes a new instance of the class. + /// + /// Postgres data source. + /// Optional configuration options for this class + public PostgresVectorStore(NpgsqlDataSource dataSource, PostgresVectorStoreOptions? options = default) + { + this._dataSource = dataSource; + this._options = options ?? 
new PostgresVectorStoreOptions(); + this._postgresClient = new PostgresVectorStoreDbClient(this._dataSource, this._options.Schema); + } + + /// + /// Initializes a new instance of the class. + /// + /// An instance of . + /// Optional configuration options for this class + internal PostgresVectorStore(IPostgresVectorStoreDbClient postgresDbClient, PostgresVectorStoreOptions? options = default) + { + this._postgresClient = postgresDbClient; + this._options = options ?? new PostgresVectorStoreOptions(); + } + + /// + public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) + { + const string OperationName = "ListCollectionNames"; + return PostgresVectorStoreUtils.WrapAsyncEnumerableAsync( + this._postgresClient.GetTablesAsync(cancellationToken), + OperationName + ); + } + + /// + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + where TKey : notnull + { + if (!PostgresConstants.SupportedKeyTypes.Contains(typeof(TKey))) + { + throw new NotSupportedException($"Unsupported key type: {typeof(TKey)}"); + } + + if (this._options.VectorStoreCollectionFactory is not null) + { + return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection(this._postgresClient.DataSource, name, vectorStoreRecordDefinition); + } + + var recordCollection = new PostgresVectorStoreRecordCollection( + this._postgresClient, + name, + new PostgresVectorStoreRecordCollectionOptions() { Schema = this._options.Schema, VectorStoreRecordDefinition = vectorStoreRecordDefinition } + ); + + return recordCollection as IVectorStoreRecordCollection ?? 
throw new InvalidOperationException("Failed to cast record collection."); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs new file mode 100644 index 000000000000..d68412d31b7d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -0,0 +1,453 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.Extensions.VectorData; +using Npgsql; +using NpgsqlTypes; +using Pgvector; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Provides methods to build SQL commands for managing vector store collections in PostgreSQL. +/// +internal class PostgresVectorStoreCollectionSqlBuilder : IPostgresVectorStoreCollectionSqlBuilder +{ + /// + public PostgresSqlCommandInfo BuildDoesTableExistCommand(string schema, string tableName) + { + return new PostgresSqlCommandInfo( + commandText: @" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = $1 + AND table_type = 'BASE TABLE' + AND table_name = $2", + parameters: [ + new NpgsqlParameter() { Value = schema }, + new NpgsqlParameter() { Value = tableName } + ] + ); + } + + /// + public PostgresSqlCommandInfo BuildGetTablesCommand(string schema) + { + return new PostgresSqlCommandInfo( + commandText: @" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = $1 + AND table_type = 'BASE TABLE'", + parameters: [new NpgsqlParameter() { Value = schema }] + ); + } + + /// + public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, IReadOnlyList properties, bool ifNotExists = true) + { + if (string.IsNullOrWhiteSpace(tableName)) + { + throw new ArgumentException("Table name cannot be null or whitespace", nameof(tableName)); + } + + 
VectorStoreRecordKeyProperty? keyProperty = default; + List dataProperties = new(); + List vectorProperties = new(); + + foreach (var property in properties) + { + if (property is VectorStoreRecordKeyProperty keyProp) + { + if (keyProperty != null) + { + // Should be impossible, as property reader should have already validated that + // multiple key properties are not allowed. + throw new ArgumentException("Record definition cannot have more than one key property."); + } + keyProperty = keyProp; + } + else if (property is VectorStoreRecordDataProperty dataProp) + { + dataProperties.Add(dataProp); + } + else if (property is VectorStoreRecordVectorProperty vectorProp) + { + vectorProperties.Add(vectorProp); + } + else + { + throw new NotSupportedException($"Property type {property.GetType().Name} is not supported by this store."); + } + } + + if (keyProperty == null) + { + throw new ArgumentException("Record definition must have a key property."); + } + + var keyName = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; + + StringBuilder createTableCommand = new(); + createTableCommand.AppendLine($"CREATE TABLE {(ifNotExists ? "IF NOT EXISTS " : "")}{schema}.\"{tableName}\" ("); + + // Add the key column + var keyPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(keyProperty.PropertyType); + createTableCommand.AppendLine($" \"{keyName}\" {keyPgTypeInfo.PgType} {(keyPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); + + // Add the data columns + foreach (var dataProperty in dataProperties) + { + string columnName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; + var dataPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(dataProperty.PropertyType); + createTableCommand.AppendLine($" \"{columnName}\" {dataPgTypeInfo.PgType} {(dataPgTypeInfo.IsNullable ? 
"" : "NOT NULL")},"); + } + + // Add the vector columns + foreach (var vectorProperty in vectorProperties) + { + string columnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + var vectorPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPgVectorTypeName(vectorProperty); + createTableCommand.AppendLine($" \"{columnName}\" {vectorPgTypeInfo.PgType} {(vectorPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); + } + + createTableCommand.AppendLine($" PRIMARY KEY (\"{keyName}\")"); + + createTableCommand.AppendLine(");"); + + return new PostgresSqlCommandInfo(commandText: createTableCommand.ToString()); + } + + /// + public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction) + { + // Only support creating HNSW index creation through the connector. + var indexTypeName = indexKind switch + { + IndexKind.Hnsw => "hnsw", + _ => throw new NotSupportedException($"Index kind '{indexKind}' is not supported for table creation. If you need to create an index of this type, please do so manually. 
Only HNSW indexes are supported through the vector store.") + }; + + distanceFunction ??= PostgresConstants.DefaultDistanceFunction; // Default to Cosine distance + + var indexOps = distanceFunction switch + { + DistanceFunction.CosineDistance => "vector_cosine_ops", + DistanceFunction.CosineSimilarity => "vector_cosine_ops", + DistanceFunction.DotProductSimilarity => "vector_ip_ops", + DistanceFunction.EuclideanDistance => "vector_l2_ops", + DistanceFunction.ManhattanDistance => "vector_l1_ops", + _ => throw new NotSupportedException($"Distance function {distanceFunction} is not supported.") + }; + + var indexName = $"{tableName}_{vectorColumnName}_index"; + + return new PostgresSqlCommandInfo( + commandText: $@" + CREATE INDEX {indexName} ON {schema}.""{tableName}"" USING {indexTypeName} (""{vectorColumnName}"" {indexOps});" + ); + } + + /// + public PostgresSqlCommandInfo BuildDropTableCommand(string schema, string tableName) + { + return new PostgresSqlCommandInfo( + commandText: $@"DROP TABLE IF EXISTS {schema}.""{tableName}""" + ); + } + + /// + public PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName, string keyColumn, Dictionary row) + { + var columns = row.Keys.ToList(); + var columnNames = string.Join(", ", columns.Select(k => $"\"{k}\"")); + var valuesParams = string.Join(", ", columns.Select((k, i) => $"${i + 1}")); + var columnsWithIndex = columns.Select((k, i) => (col: k, idx: i)); + var updateColumnsWithParams = string.Join(", ", columnsWithIndex.Where(c => c.col != keyColumn).Select(c => $"\"{c.col}\"=${c.idx + 1}")); + var commandText = $@" + INSERT INTO {schema}.""{tableName}"" ({columnNames}) + VALUES({valuesParams}) + ON CONFLICT (""{keyColumn}"") + DO UPDATE SET {updateColumnsWithParams};"; + + return new PostgresSqlCommandInfo(commandText) + { + Parameters = columns.Select(c => + PostgresVectorStoreRecordPropertyMapping.GetNpgsqlParameter(row[c]) + ).ToList() + }; + } + + /// + public PostgresSqlCommandInfo 
BuildUpsertBatchCommand(string schema, string tableName, string keyColumn, List> rows) + { + if (rows == null || rows.Count == 0) + { + throw new ArgumentException("Rows cannot be null or empty", nameof(rows)); + } + + var firstRow = rows[0]; + var columns = firstRow.Keys.ToList(); + + // Generate column names and parameter placeholders + var columnNames = string.Join(", ", columns.Select(c => $"\"{c}\"")); + var valuePlaceholders = string.Join(", ", columns.Select((c, i) => $"${i + 1}")); + var valuesRows = string.Join(", ", + rows.Select((row, rowIndex) => + $"({string.Join(", ", + columns.Select((c, colIndex) => $"${rowIndex * columns.Count + colIndex + 1}"))})")); + + // Generate the update set clause + var updateSetClause = string.Join(", ", columns.Where(c => c != keyColumn).Select(c => $"\"{c}\" = EXCLUDED.\"{c}\"")); + + // Generate the SQL command + var commandText = $@" + INSERT INTO {schema}.""{tableName}"" ({columnNames}) + VALUES {valuesRows} + ON CONFLICT (""{keyColumn}"") + DO UPDATE SET {updateSetClause}; "; + + // Generate the parameters + var parameters = new List(); + for (int rowIndex = 0; rowIndex < rows.Count; rowIndex++) + { + var row = rows[rowIndex]; + foreach (var column in columns) + { + parameters.Add(new NpgsqlParameter() + { + Value = row[column] ?? DBNull.Value + }); + } + } + + return new PostgresSqlCommandInfo(commandText, parameters); + } + + /// + public PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, IReadOnlyList properties, TKey key, bool includeVectors = false) + where TKey : notnull + { + List queryColumns = new(); + string? keyColumn = null; + + foreach (var property in properties) + { + if (property is VectorStoreRecordKeyProperty keyProperty) + { + if (keyColumn != null) + { + throw new ArgumentException("Record definition cannot have more than one key property."); + } + keyColumn = keyProperty.StoragePropertyName ?? 
keyProperty.DataModelPropertyName; + queryColumns.Add($"\"{keyColumn}\""); + } + else if (property is VectorStoreRecordDataProperty dataProperty) + { + string columnName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; + queryColumns.Add($"\"{columnName}\""); + } + else if (property is VectorStoreRecordVectorProperty vectorProperty && includeVectors) + { + string columnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + queryColumns.Add($"\"{columnName}\""); + } + } + + Verify.NotNull(keyColumn, "Record definition must have a key property."); + + var queryColumnList = string.Join(", ", queryColumns); + + return new PostgresSqlCommandInfo( + commandText: $@" + SELECT {queryColumnList} + FROM {schema}.""{tableName}"" + WHERE ""{keyColumn}"" = ${1};", + parameters: [new NpgsqlParameter() { Value = key }] + ); + } + + /// + public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, IReadOnlyList properties, List keys, bool includeVectors = false) + where TKey : notnull + { + NpgsqlDbType? keyType = PostgresVectorStoreRecordPropertyMapping.GetNpgsqlDbType(typeof(TKey)) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); + + if (keys == null || keys.Count == 0) + { + throw new ArgumentException("Keys cannot be null or empty", nameof(keys)); + } + + var keyProperty = properties.OfType().FirstOrDefault() ?? throw new ArgumentException("Properties must contain a key property", nameof(properties)); + var keyColumn = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; + + // Generate the column names + var columns = properties + .Where(p => includeVectors || p is not VectorStoreRecordVectorProperty) + .Select(p => p.StoragePropertyName ?? 
p.DataModelPropertyName) + .ToList(); + + var columnNames = string.Join(", ", columns.Select(c => $"\"{c}\"")); + var keyParams = string.Join(", ", keys.Select((k, i) => $"${i + 1}")); + + // Generate the SQL command + var commandText = $@" + SELECT {columnNames} + FROM {schema}.""{tableName}"" + WHERE ""{keyColumn}"" = ANY($1);"; + + return new PostgresSqlCommandInfo(commandText) + { + Parameters = [new NpgsqlParameter() { Value = keys.ToArray(), NpgsqlDbType = NpgsqlDbType.Array | keyType.Value }] + }; + } + + /// + public PostgresSqlCommandInfo BuildDeleteCommand(string schema, string tableName, string keyColumn, TKey key) + { + return new PostgresSqlCommandInfo( + commandText: $@" + DELETE FROM {schema}.""{tableName}"" + WHERE ""{keyColumn}"" = ${1};", + parameters: [new NpgsqlParameter() { Value = key }] + ); + } + + /// + public PostgresSqlCommandInfo BuildDeleteBatchCommand(string schema, string tableName, string keyColumn, List keys) + { + NpgsqlDbType? keyType = PostgresVectorStoreRecordPropertyMapping.GetNpgsqlDbType(typeof(TKey)) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); + if (keys == null || keys.Count == 0) + { + throw new ArgumentException("Keys cannot be null or empty", nameof(keys)); + } + + for (int i = 0; i < keys.Count; i++) + { + if (keys[i] == null) + { + throw new ArgumentException("Keys cannot contain null values", nameof(keys)); + } + } + + var commandText = $@" + DELETE FROM {schema}.""{tableName}"" + WHERE ""{keyColumn}"" = ANY($1);"; + + return new PostgresSqlCommandInfo(commandText) + { + Parameters = [new NpgsqlParameter() { Value = keys, NpgsqlDbType = NpgsqlDbType.Array | keyType.Value }] + }; + } + + /// + public PostgresSqlCommandInfo BuildGetNearestMatchCommand( + string schema, string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, + VectorSearchFilter? filter, int? 
skip, bool includeVectors, int limit) + { + var columns = string.Join(" ,", + properties + .Select(property => property.StoragePropertyName ?? property.DataModelPropertyName) + .Select(column => $"\"{column}\"") + ); + + var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; + var distanceOp = distanceFunction switch + { + DistanceFunction.CosineDistance => "<=>", + DistanceFunction.CosineSimilarity => "<=>", + DistanceFunction.EuclideanDistance => "<->", + DistanceFunction.ManhattanDistance => "<+>", + DistanceFunction.DotProductSimilarity => "<#>", + null or "" => "<->", // Default to Euclidean distance + _ => throw new NotSupportedException($"Distance function {vectorProperty.DistanceFunction} is not supported.") + }; + + var vectorColumn = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + // Start where clause params at 2, vector takes param 1. + var where = GenerateWhereClause(schema, tableName, properties, filter, startParamIndex: 2); + + var commandText = $@" + SELECT {columns}, ""{vectorColumn}"" {distanceOp} $1 AS ""{PostgresConstants.DistanceColumnName}"" + FROM {schema}.""{tableName}"" {where.Clause} + ORDER BY {PostgresConstants.DistanceColumnName} + LIMIT {limit}"; + + if (skip.HasValue) { commandText += $" OFFSET {skip.Value}"; } + + // For cosine similarity, we need to take 1 - cosine distance. + // However, we can't use an expression in the ORDER BY clause or else the index won't be used. + // Instead we'll wrap the query in a subquery and modify the distance in the outer query. + if (vectorProperty.DistanceFunction == DistanceFunction.CosineSimilarity) + { + commandText = $@" + SELECT {columns}, 1 - ""{PostgresConstants.DistanceColumnName}"" AS ""{PostgresConstants.DistanceColumnName}"" + FROM ({commandText}) AS subquery"; + } + + // For inner product, we need to take -1 * inner product. 
+ // However, we can't use an expression in the ORDER BY clause or else the index won't be used. + // Instead we'll wrap the query in a subquery and modify the distance in the outer query. + if (vectorProperty.DistanceFunction == DistanceFunction.DotProductSimilarity) + { + commandText = $@" + SELECT {columns}, -1 * ""{PostgresConstants.DistanceColumnName}"" AS ""{PostgresConstants.DistanceColumnName}"" + FROM ({commandText}) AS subquery"; + } + + return new PostgresSqlCommandInfo(commandText) + { + Parameters = [new NpgsqlParameter() { Value = vectorValue }, .. where.Parameters.Select(p => new NpgsqlParameter() { Value = p })] + }; + } + + internal static (string Clause, List Parameters) GenerateWhereClause(string schema, string tableName, IReadOnlyList properties, VectorSearchFilter? filter, int startParamIndex) + { + if (filter == null) { return (string.Empty, new List()); } + + var whereClause = new StringBuilder("WHERE "); + var filterClauses = new List(); + var parameters = new List(); + + var paramIndex = startParamIndex; + + foreach (var filterClause in filter.FilterClauses) + { + if (filterClause is EqualToFilterClause equalTo) + { + var property = properties.FirstOrDefault(p => p.DataModelPropertyName == equalTo.FieldName); + if (property == null) { throw new ArgumentException($"Property {equalTo.FieldName} not found in record definition."); } + + var columnName = property.StoragePropertyName ?? 
property.DataModelPropertyName;
+                filterClauses.Add($"\"{columnName}\" = ${paramIndex}");
+                parameters.Add(equalTo.Value);
+                paramIndex++;
+            }
+            else if (filterClause is AnyTagEqualToFilterClause anyTagEqualTo)
+            {
+                var property = properties.FirstOrDefault(p => p.DataModelPropertyName == anyTagEqualTo.FieldName);
+                if (property == null) { throw new ArgumentException($"Property {anyTagEqualTo.FieldName} not found in record definition."); }
+
+                if (property.PropertyType != typeof(List))
+                {
+                    throw new ArgumentException($"Property {anyTagEqualTo.FieldName} must be of type List to use AnyTagEqualTo filter.");
+                }
+
+                var columnName = property.StoragePropertyName ?? property.DataModelPropertyName;
+                filterClauses.Add($"\"{columnName}\" @> ARRAY[${paramIndex}::TEXT]");
+                parameters.Add(anyTagEqualTo.Value);
+                paramIndex++;
+            }
+            else
+            {
+                throw new NotSupportedException($"Filter clause type {filterClause.GetType().Name} is not supported.");
+            }
+        }
+
+        whereClause.Append(string.Join(" AND ", filterClauses));
+        return (whereClause.ToString(), parameters);
+    }
+}
diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs
new file mode 100644
index 000000000000..5ef18cc88fdf
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs
@@ -0,0 +1,253 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Extensions.VectorData;
+using Npgsql;
+using Pgvector;
+
+namespace Microsoft.SemanticKernel.Connectors.Postgres;
+
+/// 
+/// An implementation of a client for Postgres. This class is used to manage Postgres database operations.
+/// 
+/// 
+/// Initializes a new instance of the  class.
+/// 
+/// Postgres data source.
+/// Schema of collection tables.
+[System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")] +internal class PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string schema = PostgresConstants.DefaultSchema) : IPostgresVectorStoreDbClient +{ + private readonly string _schema = schema; + + private IPostgresVectorStoreCollectionSqlBuilder _sqlBuilder = new PostgresVectorStoreCollectionSqlBuilder(); + + public NpgsqlDataSource DataSource { get; } = dataSource; + + /// + public async Task DoesTableExistsAsync(string tableName, CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildDoesTableExistCommand(this._schema, tableName); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + return dataReader.GetString(dataReader.GetOrdinal("table_name")) == tableName; + } + + return false; + } + } + + /// + public async IAsyncEnumerable GetTablesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildGetTablesCommand(this._schema); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return 
dataReader.GetString(dataReader.GetOrdinal("table_name")); + } + } + } + + /// + public async Task CreateTableAsync(string tableName, IReadOnlyList properties, bool ifNotExists = true, CancellationToken cancellationToken = default) + { + // Prepare the SQL commands. + var commandInfo = this._sqlBuilder.BuildCreateTableCommand(this._schema, tableName, properties, ifNotExists); + var createIndexCommands = + PostgresVectorStoreRecordPropertyMapping.GetVectorIndexInfo(properties) + .Select(index => + this._sqlBuilder.BuildCreateVectorIndexCommand(this._schema, tableName, index.column, index.kind, index.function) + ); + + // Execute the commands in a transaction. + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { +#if !NETSTANDARD2_0 + var transaction = await connection.BeginTransactionAsync(cancellationToken).ConfigureAwait(false); + await using (transaction) +#else + var transaction = connection.BeginTransaction(); + using (transaction) +#endif + { + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection, transaction); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + + foreach (var createIndexCommand in createIndexCommands) + { + using NpgsqlCommand indexCmd = createIndexCommand.ToNpgsqlCommand(connection, transaction); + await indexCmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + +#if !NETSTANDARD2_0 + await transaction.CommitAsync(cancellationToken).ConfigureAwait(false); +#else + transaction.Commit(); +#endif + } + } + } + + /// + public async Task DeleteTableAsync(string tableName, CancellationToken cancellationToken = default) + { + var commandInfo = this._sqlBuilder.BuildDropTableCommand(this._schema, tableName); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); + } + + /// + public async Task UpsertAsync(string tableName, Dictionary row, string keyColumn, 
CancellationToken cancellationToken = default) + { + var commandInfo = this._sqlBuilder.BuildUpsertCommand(this._schema, tableName, keyColumn, row); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); + } + + /// + public async Task UpsertBatchAsync(string tableName, IEnumerable> rows, string keyColumn, CancellationToken cancellationToken = default) + { + var commandInfo = this._sqlBuilder.BuildUpsertBatchCommand(this._schema, tableName, keyColumn, rows.ToList()); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); + } + + /// + public async Task?> GetAsync(string tableName, TKey key, IReadOnlyList properties, bool includeVectors = false, CancellationToken cancellationToken = default) where TKey : notnull + { + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildGetCommand(this._schema, tableName, properties, key, includeVectors); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + return this.GetRecord(dataReader, properties, includeVectors); + } + + return null; + } + } + + /// + public async IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, IReadOnlyList properties, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TKey : notnull + { + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildGetBatchCommand(this._schema, tableName, properties, keys.ToList(), includeVectors); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + 
using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return this.GetRecord(dataReader, properties, includeVectors); + } + } + } + + /// + public async Task DeleteAsync(string tableName, string keyColumn, TKey key, CancellationToken cancellationToken = default) + { + var commandInfo = this._sqlBuilder.BuildDeleteCommand(this._schema, tableName, keyColumn, key); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); + } + + /// + public async IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync( + string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, + VectorSearchFilter? filter = default, int? skip = default, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildGetNearestMatchCommand(this._schema, tableName, properties, vectorProperty, vectorValue, filter, skip, includeVectors, limit); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + var distance = dataReader.GetDouble(dataReader.GetOrdinal(PostgresConstants.DistanceColumnName)); + yield return (Row: this.GetRecord(dataReader, properties, includeVectors), Distance: distance); + } + } + } + + /// + public async Task DeleteBatchAsync(string tableName, string keyColumn, IEnumerable keys, CancellationToken cancellationToken = default) + { + var commandInfo = 
this._sqlBuilder.BuildDeleteBatchCommand(this._schema, tableName, keyColumn, keys.ToList()); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); + } + + #region internal =============================================================================== + + /// + /// Sets the SQL builder for the client. + /// + /// + /// + /// This method is used for other Semantic Kernel connectors that may need to override the default SQL + /// used by this client. + /// + internal void SetSqlBuilder(IPostgresVectorStoreCollectionSqlBuilder sqlBuilder) + { + this._sqlBuilder = sqlBuilder; + } + + #endregion + + #region private ================================================================================ + + private Dictionary GetRecord( + NpgsqlDataReader reader, + IEnumerable properties, + bool includeVectors = false + ) + { + var storageModel = new Dictionary(); + + foreach (var property in properties) + { + var isEmbedding = property is VectorStoreRecordVectorProperty; + var propertyName = property.StoragePropertyName ?? property.DataModelPropertyName; + var propertyType = property.PropertyType; + var propertyValue = !isEmbedding || includeVectors ? 
PostgresVectorStoreRecordPropertyMapping.GetPropertyValue(reader, propertyName, propertyType) : null; + + storageModel.Add(propertyName, propertyValue); + } + + return storageModel; + } + + private async Task ExecuteNonQueryAsync(PostgresSqlCommandInfo commandInfo, CancellationToken cancellationToken) + { + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs new file mode 100644 index 000000000000..013f1810e146 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Options when creating a . +/// +public sealed class PostgresVectorStoreOptions +{ + /// + /// Gets or sets the database schema. + /// + public string Schema { get; init; } = "public"; + + /// + /// An optional factory to use for constructing instances, if a custom record collection is required. + /// + public IPostgresVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs new file mode 100644 index 000000000000..95c8a4bcf282 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -0,0 +1,378 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Npgsql; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Represents a collection of vector store records in a Postgres database. +/// +/// The type of the key. +/// The type of the record. +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix +public sealed class PostgresVectorStoreRecordCollection : IVectorStoreRecordCollection +#pragma warning restore CA1711 // Identifiers should not have incorrect suffix + where TKey : notnull +{ + /// + public string CollectionName { get; } + + /// Postgres client that is used to interact with the database. + private readonly IPostgresVectorStoreDbClient _client; + + // Optional configuration options for this class. + private readonly PostgresVectorStoreRecordCollectionOptions _options; + + /// A helper to access property information for the current data model and record definition. + private readonly VectorStoreRecordPropertyReader _propertyReader; + + /// A mapper to use for converting between the data model and the Azure AI Search record. + private readonly IVectorStoreRecordMapper> _mapper; + + /// The default options for vector search. + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + + /// + /// Initializes a new instance of the class. + /// + /// The data source to use for connecting to the database. + /// The name of the collection. + /// Optional configuration options for this class. + public PostgresVectorStoreRecordCollection(NpgsqlDataSource dataSource, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default) + : this(new PostgresVectorStoreDbClient(dataSource), collectionName, options) + { + } + + /// + /// Initializes a new instance of the class. 
+ /// + /// The client to use for interacting with the database. + /// The name of the collection. + /// Optional configuration options for this class. + /// + /// This constructor is internal. It allows internal code to create an instance of this class with a custom client. + /// + internal PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default) + { + // Verify. + Verify.NotNull(client); + Verify.NotNullOrWhiteSpace(collectionName); + VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.DictionaryCustomMapper is not null, PostgresConstants.SupportedKeyTypes); + VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + + // Assign. + this._client = client; + this.CollectionName = collectionName; + this._options = options ?? new PostgresVectorStoreRecordCollectionOptions(); + this._propertyReader = new VectorStoreRecordPropertyReader( + typeof(TRecord), + this._options.VectorStoreRecordDefinition, + new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + }); + + // Validate property types. + this._propertyReader.VerifyKeyProperties(PostgresConstants.SupportedKeyTypes); + this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, PostgresConstants.SupportedEnumerableDataElementTypes); + this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); + + // Resolve mapper. + // First, if someone has provided a custom mapper, use that. + // If they didn't provide a custom mapper, and the record type is the generic data model, use the built in mapper for that. + // Otherwise, use our own default mapper implementation for all other data models. 
+ if (this._options.DictionaryCustomMapper is not null) + { + this._mapper = this._options.DictionaryCustomMapper; + } + else if (typeof(TRecord).IsGenericType && typeof(TRecord).GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>)) + { + this._mapper = (new PostgresGenericDataModelMapper(this._propertyReader) as IVectorStoreRecordMapper>)!; + } + else + { + this._mapper = new PostgresVectorStoreRecordMapper(this._propertyReader); + } + } + + /// + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + const string OperationName = "DoesTableExists"; + return this.RunOperationAsync(OperationName, () => + this._client.DoesTableExistsAsync(this.CollectionName, cancellationToken) + ); + } + + /// + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) + { + const string OperationName = "CreateCollection"; + return this.RunOperationAsync(OperationName, () => + this.InternalCreateCollectionAsync(false, cancellationToken) + ); + } + + /// + public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + { + const string OperationName = "CreateCollectionIfNotExists"; + return this.RunOperationAsync(OperationName, () => + this.InternalCreateCollectionAsync(true, cancellationToken) + ); + } + + /// + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + { + const string OperationName = "DeleteCollection"; + return this.RunOperationAsync(OperationName, () => + this._client.DeleteTableAsync(this.CollectionName, cancellationToken) + ); + } + + /// + public Task UpsertAsync(TRecord record, UpsertRecordOptions? 
options = null, CancellationToken cancellationToken = default) + { + const string OperationName = "Upsert"; + + var storageModel = VectorStoreErrorHandler.RunModelConversion( + PostgresConstants.DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromDataToStorageModel(record)); + + Verify.NotNull(storageModel); + + var keyObj = storageModel[this._propertyReader.KeyPropertyStoragePropertyName]; + Verify.NotNull(keyObj); + TKey key = (TKey)keyObj!; + + return this.RunOperationAsync(OperationName, async () => + { + await this._client.UpsertAsync(this.CollectionName, storageModel, this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken).ConfigureAwait(false); + return key; + } + ); + } + + /// + public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, UpsertRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + const string OperationName = "UpsertBatch"; + + var storageModels = records.Select(record => VectorStoreErrorHandler.RunModelConversion( + PostgresConstants.DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromDataToStorageModel(record))).ToList(); + + var keys = storageModels.Select(model => model[this._propertyReader.KeyPropertyStoragePropertyName]!).ToList(); + + await this.RunOperationAsync(OperationName, () => + this._client.UpsertBatchAsync(this.CollectionName, storageModels, this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken) + ).ConfigureAwait(false); + + foreach (var key in keys) { yield return (TKey)key!; } + } + + /// + public Task GetAsync(TKey key, GetRecordOptions? 
options = null, CancellationToken cancellationToken = default) + { + const string OperationName = "Get"; + + Verify.NotNull(key); + + bool includeVectors = options?.IncludeVectors is true; + + return this.RunOperationAsync(OperationName, async () => + { + var row = await this._client.GetAsync(this.CollectionName, key, this._propertyReader.RecordDefinition.Properties, includeVectors, cancellationToken).ConfigureAwait(false); + + if (row is null) { return default; } + return VectorStoreErrorHandler.RunModelConversion( + PostgresConstants.DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })); + }); + } + + /// + public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + const string OperationName = "GetBatch"; + + Verify.NotNull(keys); + + bool includeVectors = options?.IncludeVectors is true; + + return PostgresVectorStoreUtils.WrapAsyncEnumerableAsync( + this._client.GetBatchAsync(this.CollectionName, keys, this._propertyReader.RecordDefinition.Properties, includeVectors, cancellationToken) + .SelectAsync(row => + VectorStoreErrorHandler.RunModelConversion( + PostgresConstants.DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })), + cancellationToken + ), + OperationName, + this.CollectionName + ); + } + + /// + public Task DeleteAsync(TKey key, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + { + const string OperationName = "Delete"; + return this.RunOperationAsync(OperationName, () => + this._client.DeleteAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, key, cancellationToken) + ); + } + + /// + public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? 
options = null, CancellationToken cancellationToken = default) + { + const string OperationName = "DeleteBatch"; + return this.RunOperationAsync(OperationName, () => + this._client.DeleteBatchAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, keys, cancellationToken) + ); + } + + /// + public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + { + const string OperationName = "VectorizedSearch"; + + Verify.NotNull(vector); + + var vectorType = vector.GetType(); + + if (!PostgresConstants.SupportedVectorTypes.Contains(vectorType)) + { + throw new NotSupportedException( + $"The provided vector type {vectorType.FullName} is not supported by the SQLite connector. " + + $"Supported types are: {string.Join(", ", PostgresConstants.SupportedVectorTypes.Select(l => l.FullName))}"); + } + + var searchOptions = options ?? s_defaultVectorSearchOptions; + var vectorProperty = this.GetVectorPropertyForSearch(searchOptions.VectorPropertyName); + + if (vectorProperty is null) + { + throw new InvalidOperationException("The collection does not have any vector properties, so vector search is not possible."); + } + + var pgVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + Verify.NotNull(pgVector); + + // Simulating skip/offset logic locally, since OFFSET can work only with LIMIT in combination + // and LIMIT is not supported in vector search extension, instead of LIMIT - "k" parameter is used. 
+ var limit = searchOptions.Top + searchOptions.Skip; + + return this.RunOperationAsync(OperationName, () => + { + var results = this._client.GetNearestMatchesAsync( + this.CollectionName, + this._propertyReader.RecordDefinition.Properties, + vectorProperty, + pgVector, + searchOptions.Top, + searchOptions.Filter, + searchOptions.Skip, + searchOptions.IncludeVectors, + cancellationToken) + .SelectAsync(result => + { + var record = VectorStoreErrorHandler.RunModelConversion( + PostgresConstants.DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromStorageToDataModel( + result.Row, new StorageToDataModelMapperOptions() { IncludeVectors = searchOptions.IncludeVectors }) + ); + + return new VectorSearchResult(record, result.Distance); + }, cancellationToken); + + return Task.FromResult(new VectorSearchResults(results)); + }); + } + + private Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken = default) + { + return this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition.Properties, ifNotExists, cancellationToken); + } + + /// + /// Get vector property to use for a search by using the storage name for the field name from options + /// if available, and falling back to the first vector property in if not. + /// + /// The vector field name. + /// Thrown if the provided field name is not a valid field name. + private VectorStoreRecordVectorProperty? GetVectorPropertyForSearch(string? vectorFieldName) + { + // If vector property name is provided in options, try to find it in schema or throw an exception. + if (!string.IsNullOrWhiteSpace(vectorFieldName)) + { + // Check vector properties by data model property name. 
+ var vectorProperty = this._propertyReader.VectorProperties + .FirstOrDefault(l => l.DataModelPropertyName.Equals(vectorFieldName, StringComparison.Ordinal)); + + if (vectorProperty is not null) + { + return vectorProperty; + } + + throw new InvalidOperationException($"The {typeof(TRecord).FullName} type does not have a vector property named '{vectorFieldName}'."); + } + + // If vector property is not provided in options, return first vector property from schema. + return this._propertyReader.VectorProperty; + } + + private async Task RunOperationAsync(string operationName, Func operation) + { + try + { + await operation.Invoke().ConfigureAwait(false); + } + catch (Exception ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = PostgresConstants.DatabaseName, + CollectionName = this.CollectionName, + OperationName = operationName + }; + } + } + + private async Task RunOperationAsync(string operationName, Func> operation) + { + try + { + return await operation.Invoke().ConfigureAwait(false); + } + catch (Exception ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = PostgresConstants.DatabaseName, + CollectionName = this.CollectionName, + OperationName = operationName + }; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs new file mode 100644 index 000000000000..753713d21b3f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Options when creating a . 
+/// +public sealed class PostgresVectorStoreRecordCollectionOptions +{ + /// + /// Gets or sets the database schema. + /// + public string Schema { get; init; } = "public"; + + /// + /// Gets or sets an optional custom mapper to use when converting between the data model and the Postgres record. + /// + /// + /// If not set, the default mapper will be used. + /// + public IVectorStoreRecordMapper>? DictionaryCustomMapper { get; init; } = null; + + /// + /// Gets or sets an optional record definition that defines the schema of the record type. + /// + /// + /// If not provided, the schema will be inferred from the record model class using reflection. + /// In this case, the record model properties must be annotated with the appropriate attributes to indicate their usage. + /// See , and . + /// + public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs new file mode 100644 index 000000000000..e656678413cc --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// A mapper class that handles the conversion between data models and storage models for Postgres vector store. +/// +/// The type of the data model record. +internal sealed class PostgresVectorStoreRecordMapper : IVectorStoreRecordMapper> +{ + /// with helpers for reading vector store model properties and their attributes. + private readonly VectorStoreRecordPropertyReader _propertyReader; + + /// + /// Initializes a new instance of the class. + /// + /// A that defines the schema of the data in the database. 
+ public PostgresVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyReader) + { + Verify.NotNull(propertyReader); + + this._propertyReader = propertyReader; + + this._propertyReader.VerifyHasParameterlessConstructor(); + + // Validate property types. + this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, PostgresConstants.SupportedEnumerableDataElementTypes); + this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); + } + + public Dictionary MapFromDataToStorageModel(TRecord dataModel) + { + var properties = new Dictionary + { + // Add key property + { this._propertyReader.KeyPropertyStoragePropertyName, this._propertyReader.KeyPropertyInfo.GetValue(dataModel) } + }; + + // Add data properties + foreach (var property in this._propertyReader.DataPropertiesInfo) + { + properties.Add( + this._propertyReader.GetStoragePropertyName(property.Name), + property.GetValue(dataModel) + ); + } + + // Add vector properties + foreach (var property in this._propertyReader.VectorPropertiesInfo) + { + var propertyValue = property.GetValue(dataModel); + var result = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(propertyValue); + + properties.Add(this._propertyReader.GetStoragePropertyName(property.Name), result); + } + + return properties; + } + + public TRecord MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) + { + var record = (TRecord)this._propertyReader.ParameterLessConstructorInfo.Invoke(null); + + // Set key. + var keyPropertyValue = Convert.ChangeType( + storageModel[this._propertyReader.KeyPropertyStoragePropertyName], + this._propertyReader.KeyProperty.PropertyType); + + this._propertyReader.KeyPropertyInfo.SetValue(record, keyPropertyValue); + + // Process data properties. 
+ var dataPropertiesInfoWithValues = VectorStoreRecordMapping.BuildPropertiesInfoWithValues( + this._propertyReader.DataPropertiesInfo, + this._propertyReader.StoragePropertyNamesMap, + storageModel); + + VectorStoreRecordMapping.SetPropertiesOnRecord(record, dataPropertiesInfoWithValues); + + if (options.IncludeVectors) + { + // Process vector properties. + var vectorPropertiesInfoWithValues = VectorStoreRecordMapping.BuildPropertiesInfoWithValues( + this._propertyReader.VectorPropertiesInfo, + this._propertyReader.StoragePropertyNamesMap, + storageModel, + (object? vector, Type type) => + { + return PostgresVectorStoreRecordPropertyMapping.MapVectorForDataModel(vector); + }); + + VectorStoreRecordMapping.SetPropertiesOnRecord(record, vectorPropertiesInfoWithValues); + } + + return record; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs new file mode 100644 index 000000000000..0b36f2003bf5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -0,0 +1,269 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using Microsoft.Extensions.VectorData; +using Npgsql; +using NpgsqlTypes; +using Pgvector; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal static class PostgresVectorStoreRecordPropertyMapping +{ + internal static float[] GetOrCreateArray(ReadOnlyMemory memory) => + MemoryMarshal.TryGetArray(memory, out ArraySegment array) && + array.Count == array.Array!.Length ? + array.Array : + memory.ToArray(); + + public static Vector? 
MapVectorForStorageModel(TVector vector) + { + if (vector == null) + { + return null; + } + + if (vector is ReadOnlyMemory floatMemory) + { + var vecArray = MemoryMarshal.TryGetArray(floatMemory, out ArraySegment array) && + array.Count == array.Array!.Length ? + array.Array : + floatMemory.ToArray(); + return new Vector(vecArray); + } + + throw new NotSupportedException($"Mapping for type {typeof(TVector).FullName} to a vector is not supported."); + } + + public static ReadOnlyMemory? MapVectorForDataModel(object? vector) + { + var pgVector = vector is Vector pgv ? pgv : null; + if (pgVector == null) { return null; } + var vecArray = pgVector.ToArray(); + return vecArray != null && vecArray.Length != 0 ? (ReadOnlyMemory)vecArray : null; + } + + public static TPropertyType? GetPropertyValue(NpgsqlDataReader reader, string propertyName) + { + int propertyIndex = reader.GetOrdinal(propertyName); + + if (reader.IsDBNull(propertyIndex)) + { + return default; + } + + return reader.GetFieldValue(propertyIndex); + } + + public static object? GetPropertyValue(NpgsqlDataReader reader, string propertyName, Type propertyType) + { + int propertyIndex = reader.GetOrdinal(propertyName); + + if (reader.IsDBNull(propertyIndex)) + { + return null; + } + + // Check if the type implements IEnumerable + if (propertyType.IsGenericType && propertyType.GetInterfaces().Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>))) + { + var enumerable = (IEnumerable)reader.GetValue(propertyIndex); + return VectorStoreRecordMapping.CreateEnumerable(enumerable.Cast(), propertyType); + } + + return propertyType switch + { + Type t when t == typeof(bool) || t == typeof(bool?) => reader.GetBoolean(propertyIndex), + Type t when t == typeof(short) || t == typeof(short?) => reader.GetInt16(propertyIndex), + Type t when t == typeof(int) || t == typeof(int?) => reader.GetInt32(propertyIndex), + Type t when t == typeof(long) || t == typeof(long?) 
=> reader.GetInt64(propertyIndex), + Type t when t == typeof(float) || t == typeof(float?) => reader.GetFloat(propertyIndex), + Type t when t == typeof(double) || t == typeof(double?) => reader.GetDouble(propertyIndex), + Type t when t == typeof(decimal) || t == typeof(decimal?) => reader.GetDecimal(propertyIndex), + Type t when t == typeof(string) => reader.GetString(propertyIndex), + Type t when t == typeof(byte[]) => reader.GetFieldValue(propertyIndex), + Type t when t == typeof(DateTime) || t == typeof(DateTime?) => reader.GetDateTime(propertyIndex), + Type t when t == typeof(DateTimeOffset) || t == typeof(DateTimeOffset?) => reader.GetFieldValue(propertyIndex), + Type t when t == typeof(Guid) => reader.GetFieldValue(propertyIndex), + _ => reader.GetValue(propertyIndex) + }; + } + + public static NpgsqlDbType? GetNpgsqlDbType(Type propertyType) => + propertyType switch + { + Type t when t == typeof(bool) || t == typeof(bool?) => NpgsqlDbType.Boolean, + Type t when t == typeof(short) || t == typeof(short?) => NpgsqlDbType.Smallint, + Type t when t == typeof(int) || t == typeof(int?) => NpgsqlDbType.Integer, + Type t when t == typeof(long) || t == typeof(long?) => NpgsqlDbType.Bigint, + Type t when t == typeof(float) || t == typeof(float?) => NpgsqlDbType.Real, + Type t when t == typeof(double) || t == typeof(double?) => NpgsqlDbType.Double, + Type t when t == typeof(decimal) || t == typeof(decimal?) => NpgsqlDbType.Numeric, + Type t when t == typeof(string) => NpgsqlDbType.Text, + Type t when t == typeof(byte[]) => NpgsqlDbType.Bytea, + Type t when t == typeof(DateTime) || t == typeof(DateTime?) => NpgsqlDbType.Timestamp, + Type t when t == typeof(DateTimeOffset) || t == typeof(DateTimeOffset?) => NpgsqlDbType.TimestampTz, + Type t when t == typeof(Guid) => NpgsqlDbType.Uuid, + _ => null + }; + + /// + /// Maps a .NET type to a PostgreSQL type name. + /// + /// The .NET type. 
+ /// Tuple of the PostgreSQL type name and whether it can be NULL + public static (string PgType, bool IsNullable) GetPostgresTypeName(Type propertyType) + { + var (pgType, isNullable) = propertyType switch + { + Type t when t == typeof(bool) => ("BOOLEAN", false), + Type t when t == typeof(short) => ("SMALLINT", false), + Type t when t == typeof(int) => ("INTEGER", false), + Type t when t == typeof(long) => ("BIGINT", false), + Type t when t == typeof(float) => ("REAL", false), + Type t when t == typeof(double) => ("DOUBLE PRECISION", false), + Type t when t == typeof(decimal) => ("NUMERIC", false), + Type t when t == typeof(string) => ("TEXT", true), + Type t when t == typeof(byte[]) => ("BYTEA", true), + Type t when t == typeof(DateTime) => ("TIMESTAMP", false), + Type t when t == typeof(DateTimeOffset) => ("TIMESTAMPTZ", false), + Type t when t == typeof(Guid) => ("UUID", false), + _ => (null, false) + }; + + if (pgType != null) + { + return (pgType, isNullable); + } + + // Handle enumerables + if (VectorStoreRecordPropertyVerification.IsSupportedEnumerableType(propertyType)) + { + Type elementType = propertyType.GetGenericArguments()[0]; + var underlyingPgType = GetPostgresTypeName(elementType); + return (underlyingPgType.PgType + "[]", true); + } + + // Handle nullable types (e.g. Nullable) + if (Nullable.GetUnderlyingType(propertyType) != null) + { + Type underlyingType = Nullable.GetUnderlyingType(propertyType) ?? throw new ArgumentException("Nullable type must have an underlying type."); + var underlyingPgType = GetPostgresTypeName(underlyingType); + return (underlyingPgType.PgType, true); + } + + throw new NotSupportedException($"Type {propertyType.Name} is not supported by this store."); + } + + /// + /// Gets the PostgreSQL vector type name based on the dimensions of the vector property. + /// + /// The vector property. + /// The PostgreSQL vector type name. 
+ public static (string PgType, bool IsNullable) GetPgVectorTypeName(VectorStoreRecordVectorProperty vectorProperty) + { + if (vectorProperty.Dimensions <= 0) + { + throw new ArgumentException("Vector property must have a positive number of dimensions."); + } + + return ($"VECTOR({vectorProperty.Dimensions})", Nullable.GetUnderlyingType(vectorProperty.PropertyType) != null); + } + + public static NpgsqlParameter GetNpgsqlParameter(object? value) + { + if (value == null) + { + return new NpgsqlParameter() { Value = DBNull.Value }; + } + + // If it's an IEnumerable, use reflection to determine if it needs to be converted to a list + if (value is IEnumerable enumerable && !(value is string)) + { + Type propertyType = value.GetType(); + if (propertyType.IsGenericType && propertyType.GetGenericTypeDefinition() == typeof(List<>)) + { + // If it's already a List, return it directly + return new NpgsqlParameter() { Value = value }; + } + + return new NpgsqlParameter() { Value = ConvertToListIfNecessary(enumerable) }; + } + + // Return the value directly if it's not IEnumerable + return new NpgsqlParameter() { Value = value }; + } + + /// + /// Returns information about vector indexes to create, validating that the dimensions of the vector are supported. + /// + /// The properties of the vector store record. + /// A list of tuples containing the column name, index kind, and distance function for each vector property. + /// + /// The default index kind is "Flat", which prevents the creation of an index. + /// + public static List<(string column, string kind, string function)> GetVectorIndexInfo(IReadOnlyList properties) + { + var vectorIndexesToCreate = new List<(string column, string kind, string function)>(); + foreach (var property in properties) + { + if (property is VectorStoreRecordVectorProperty vectorProperty) + { + var vectorColumnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + var indexKind = vectorProperty.IndexKind ?? 
PostgresConstants.DefaultIndexKind; + var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; + + // Index kind of "Flat" to prevent the creation of an index. This is the default behavior. + // Otherwise, the index will be created with the specified index kind and distance function, if supported. + if (indexKind != IndexKind.Flat) + { + // Ensure the dimensionality of the vector is supported for indexing. + if (PostgresConstants.IndexMaxDimensions.TryGetValue(indexKind, out int maxDimensions) && vectorProperty.Dimensions > maxDimensions) + { + throw new NotSupportedException( + $"The provided vector property {vectorProperty.DataModelPropertyName} has {vectorProperty.Dimensions} dimensions, " + + $"which is not supported by the {indexKind} index. The maximum number of dimensions supported by the {indexKind} index " + + $"is {maxDimensions}. Please reduce the number of dimensions or use a different index." + ); + } + + vectorIndexesToCreate.Add((vectorColumnName, indexKind, distanceFunction)); + } + } + } + return vectorIndexesToCreate; + } + + // Helper method to convert an IEnumerable to a List if necessary + private static object ConvertToListIfNecessary(IEnumerable enumerable) + { + // Get an enumerator to manually iterate over the collection + var enumerator = enumerable.GetEnumerator(); + + // Check if the collection is empty by attempting to move to the first element + if (!enumerator.MoveNext()) + { + return enumerable; // Return the original enumerable if it's empty + } + + // Determine the type of the first element + var firstItem = enumerator.Current; + var itemType = firstItem?.GetType() ?? 
typeof(object); + + // Create a strongly-typed List based on the type of the first element + var typedList = Activator.CreateInstance(typeof(List<>).MakeGenericType(itemType)) as IList; + typedList!.Add(firstItem); // Add the first element to the typed list + + // Continue iterating through the rest of the enumerable and add items to the list + while (enumerator.MoveNext()) + { + typedList.Add(enumerator.Current); + } + + return typedList; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs new file mode 100644 index 000000000000..27fa7181bdc5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal static class PostgresVectorStoreUtils +{ + /// + /// Wraps an in an that will throw a + /// if an exception is thrown while iterating over the original enumerator. + /// + /// The type of the items in the async enumerable. + /// The async enumerable to wrap. + /// The name of the operation being performed. + /// The name of the collection being operated on. + /// An async enumerable that will throw a if an exception is thrown while iterating over the original enumerator. + public static async IAsyncEnumerable WrapAsyncEnumerableAsync(IAsyncEnumerable asyncEnumerable, string operationName, string? 
collectionName = null)
+ {
+ var enumerator = asyncEnumerable.ConfigureAwait(false).GetAsyncEnumerator();
+
+ var nextResult = await GetNextAsync(enumerator, operationName, collectionName).ConfigureAwait(false);
+ while (nextResult.more)
+ {
+ yield return nextResult.item;
+ nextResult = await GetNextAsync(enumerator, operationName, collectionName).ConfigureAwait(false);
+ }
+ }
+
+ /// 
+ /// Helper method to get the next item from the enumerator with a try catch around the move next call to convert
+ /// exceptions to .
+ /// 
+ /// The enumerator to get the next result from.
+ /// The name of the operation being performed.
+ /// The name of the collection being operated on.
+ /// A value indicating whether there are more results and the current item if true.
+ public static async Task<(T item, bool more)> GetNextAsync(ConfiguredCancelableAsyncEnumerable.Enumerator enumerator, string operationName, string? collectionName = null)
+ {
+ try
+ {
+ var more = await enumerator.MoveNextAsync();
+ return (enumerator.Current, more);
+ }
+ catch (Exception ex)
+ {
+ throw new VectorStoreOperationException("Call to vector store failed.", ex)
+ {
+ VectorStoreType = PostgresConstants.DatabaseName,
+ CollectionName = collectionName,
+ OperationName = operationName
+ };
+ }
+ }
+}
diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md
index 35c80a45087a..e9ed71109fbb 100644
--- a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md
+++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md
@@ -18,7 +18,7 @@ This extension is also available for **Azure Database for PostgreSQL - Flexible

 1. To install pgvector using Docker:

 ```bash
-docker run -d --name postgres-pgvector -p 5432:5432 -e POSTGRES_PASSWORD=mysecretpassword ankane/pgvector
+docker run -d --name postgres-pgvector -p 5432:5432 -e POSTGRES_PASSWORD=mysecretpassword pgvector/pgvector
```

2. 
Create a database and enable pgvector extension on this database @@ -33,8 +33,13 @@ sk_demo=# CREATE EXTENSION vector; > Note, "Azure Cosmos DB for PostgreSQL" uses `SELECT CREATE_EXTENSION('vector');` to enable the extension. -3. To use Postgres as a semantic memory store: - > See [Example 14](../../../samples/Concepts/Memory/SemanticTextMemory_Building.cs) and [Example 15](../../../samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) for more memory usage examples with the kernel. +### Using PostgresVectorStore + +See [this sample](../../../samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs) for an example of using the vector store. + +### Using PostgresMemoryStore + +> See [Example 14](../../../samples/Concepts/Memory/SemanticTextMemory_Building.cs) and [Example 15](../../../samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) for more memory usage examples with the kernel. ```csharp NpgsqlDataSourceBuilder dataSourceBuilder = new NpgsqlDataSourceBuilder("Host=localhost;Port=5432;Database=sk_demo;User Id=postgres;Password=mysecretpassword"); @@ -87,67 +92,3 @@ BEGIN END IF; END $$; ``` - -## Migration from older versions - -Since Postgres Memory connector has been re-implemented, the new implementation uses a separate table to store each Collection. - -We provide the following migration script to help you migrate to the new structure. However, please note that due to the use of collections as table names, you need to make sure that all Collections conform to the [Postgres naming convention](https://www.postgresql.org/docs/15/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS) before migrating. - -- Table names may only consist of ASCII letters, digits, and underscores. -- Table names must start with a letter or an underscore. -- Table names may not exceed 63 characters in length. -- Table names are case-insensitive, but it is recommended to use lowercase letters. 
- -```sql --- Create new tables, each with the name of the collection field value -DO $$ -DECLARE - r record; - c_count integer; -BEGIN - FOR r IN SELECT DISTINCT collection FROM sk_memory_table LOOP - - -- Drop Table (This will delete the table that already exists. Please consider carefully if you think you need to cancel this comment!) - -- EXECUTE format('DROP TABLE IF EXISTS %I;', r.collection); - - -- Create Table (Modify vector size on demand) - EXECUTE format('CREATE TABLE public.%I ( - key TEXT NOT NULL, - metadata JSONB, - embedding vector(1536), - timestamp TIMESTAMP WITH TIME ZONE, - PRIMARY KEY (key) - );', r.collection); - - -- Get count of records in collection - SELECT count(*) INTO c_count FROM sk_memory_table WHERE collection = r.collection AND key <> ''; - - -- Create Index (https://github.com/pgvector/pgvector#indexing) - IF c_count > 10000000 THEN - EXECUTE format('CREATE INDEX %I - ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s);', - r.collection || '_ix', r.collection, ROUND(sqrt(c_count))); - ELSIF c_count > 10000 THEN - EXECUTE format('CREATE INDEX %I - ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s);', - r.collection || '_ix', r.collection, c_count / 1000); - END IF; - END LOOP; -END $$; - --- Copy data from the old table to the new table -DO $$ -DECLARE - r record; -BEGIN - FOR r IN SELECT DISTINCT collection FROM sk_memory_table LOOP - EXECUTE format('INSERT INTO public.%I (key, metadata, embedding, timestamp) - SELECT key, metadata::JSONB, embedding, to_timestamp(timestamp / 1000.0) AT TIME ZONE ''UTC'' - FROM sk_memory_table WHERE collection = %L AND key <> '''';', r.collection, r.collection); - END LOOP; -END $$; - --- Drop old table (After ensuring successful execution, you can remove the following comments to remove sk_memory_table.) 
--- DROP TABLE IF EXISTS sk_memory_table; -``` diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj new file mode 100644 index 000000000000..5698a909022e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj @@ -0,0 +1,32 @@ + + + + SemanticKernel.Connectors.Postgres.UnitTests + SemanticKernel.Connectors.Postgres.UnitTests + net8.0 + true + enable + disable + false + $(NoWarn);SKEXP0001,SKEXP0020,VSTHRD111,CA2007,CS1591 + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs new file mode 100644 index 000000000000..d9e97fc6b855 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Pgvector; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Unit tests for class. 
+/// +public sealed class PostgresGenericDataModelMapperTests +{ + [Fact] + public void MapFromDataToStorageModelWithStringKeyReturnsValidStorageModel() + { + // Arrange + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + var dataModel = GetGenericDataModel("key"); + + var mapper = new PostgresGenericDataModelMapper(propertyReader); + + // Act + var result = mapper.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal("key", result["Key"]); + Assert.Equal("Value1", result["StringProperty"]); + Assert.Equal(5, result["IntProperty"]); + + var vector = result["FloatVector"] as Vector; + + Assert.NotNull(vector); + Assert.True(vector.ToArray().Length > 0); + } + + [Fact] + public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() + { + // Arrange + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + var dataModel = GetGenericDataModel(1); + + var mapper = new PostgresGenericDataModelMapper(propertyReader); + + // Act + var result = mapper.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal(1, result["Key"]); + Assert.Equal("Value1", result["StringProperty"]); + Assert.Equal(5, result["IntProperty"]); + + var vector = result["FloatVector"] as Vector; + + Assert.NotNull(vector); + Assert.True(vector.ToArray().Length > 0); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool includeVectors) + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + var storageModel = new Dictionary + { + ["Key"] = "key", + ["StringProperty"] = "Value1", + ["IntProperty"] = 5, + ["FloatVector"] = storageVector + }; + + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + + var mapper = new 
PostgresGenericDataModelMapper(propertyReader); + + // Act + var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); + + // Assert + Assert.Equal("key", result.Key); + Assert.Equal("Value1", result.Data["StringProperty"]); + Assert.Equal(5, result.Data["IntProperty"]); + + if (includeVectors) + { + Assert.NotNull(result.Vectors["FloatVector"]); + Assert.Equal(vector.ToArray(), ((ReadOnlyMemory)result.Vectors["FloatVector"]!).ToArray()); + } + else + { + Assert.False(result.Vectors.ContainsKey("FloatVector")); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool includeVectors) + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + var storageModel = new Dictionary + { + ["Key"] = 1, + ["StringProperty"] = "Value1", + ["IntProperty"] = 5, + ["FloatVector"] = storageVector + }; + + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + + var mapper = new PostgresGenericDataModelMapper(propertyReader); + + // Act + var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); + + // Assert + Assert.Equal(1, result.Key); + Assert.Equal("Value1", result.Data["StringProperty"]); + Assert.Equal(5, result.Data["IntProperty"]); + + if (includeVectors) + { + Assert.NotNull(result.Vectors["FloatVector"]); + Assert.Equal(vector.ToArray(), ((ReadOnlyMemory)result.Vectors["FloatVector"]!).ToArray()); + } + else + { + Assert.False(result.Vectors.ContainsKey("FloatVector")); + } + } + + #region private + + private static VectorStoreRecordDefinition GetRecordDefinition() + { + return new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(TKey)), + new 
VectorStoreRecordDataProperty("StringProperty", typeof(string)), + new VectorStoreRecordDataProperty("IntProperty", typeof(int)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), + } + }; + } + + private static VectorStoreGenericDataModel GetGenericDataModel(TKey key) + { + return new VectorStoreGenericDataModel(key) + { + Data = new() + { + ["StringProperty"] = "Value1", + ["IntProperty"] = 5 + }, + Vectors = new() + { + ["FloatVector"] = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]) + } + }; + } + + private static VectorStoreRecordPropertyReader GetPropertyReader(VectorStoreRecordDefinition definition) + { + return new VectorStoreRecordPropertyReader(typeof(TRecord), definition, new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true + }); + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs new file mode 100644 index 000000000000..e8e84badf292 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + +/// +/// A test model for the postgres vector store. +/// +public record PostgresHotel() +{ + /// The key of the record. + [VectorStoreRecordKey] + public T HotelId { get; init; } + + /// A string metadata field. + [VectorStoreRecordData()] + public string? HotelName { get; set; } + + /// An int metadata field. + [VectorStoreRecordData()] + public int HotelCode { get; set; } + + /// A float metadata field. + [VectorStoreRecordData()] + public float? 
HotelRating { get; set; } + + /// A bool metadata field. + [VectorStoreRecordData(StoragePropertyName = "parking_is_included")] + public bool ParkingIncluded { get; set; } + + [VectorStoreRecordData] + public List Tags { get; set; } = []; + + /// A data field. + [VectorStoreRecordData] + public string Description { get; set; } + + public DateTime CreatedAt { get; set; } = DateTime.UtcNow; + + public DateTimeOffset UpdatedAt { get; set; } = DateTimeOffset.UtcNow; + + /// A vector field. + [VectorStoreRecordVector(4, IndexKind.Hnsw, DistanceFunction.ManhattanDistance)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } +} +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs new file mode 100644 index 000000000000..f667d86eee30 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Npgsql; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Unit tests for class. 
+/// +public sealed class PostgresServiceCollectionExtensionsTests +{ + private readonly IServiceCollection _serviceCollection = new ServiceCollection(); + + [Fact] + public void AddVectorStoreRegistersClass() + { + // Arrange + using var dataSource = NpgsqlDataSource.Create("Host=fake;"); + this._serviceCollection.AddSingleton(dataSource); + + // Act + this._serviceCollection.AddPostgresVectorStore(); + + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + var vectorStore = serviceProvider.GetRequiredService(); + + // Assert + Assert.NotNull(vectorStore); + Assert.IsType(vectorStore); + } + + [Fact] + public void AddVectorStoreRecordCollectionRegistersClass() + { + // Arrange + using var dataSource = NpgsqlDataSource.Create("Host=fake;"); + this._serviceCollection.AddSingleton(dataSource); + + // Act + this._serviceCollection.AddPostgresVectorStoreRecordCollection("testcollection"); + + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + + // Assert + var collection = serviceProvider.GetRequiredService>(); + Assert.NotNull(collection); + Assert.IsType>(collection); + + var vectorizedSearch = serviceProvider.GetRequiredService>(); + Assert.NotNull(vectorizedSearch); + Assert.IsType>(vectorizedSearch); + } + + #region private + +#pragma warning disable CA1812 // Avoid uninstantiated internal classes + private sealed class TestRecord +#pragma warning restore CA1812 // Avoid uninstantiated internal classes + { + [VectorStoreRecordKey] + public string Id { get; set; } = string.Empty; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs new file mode 100644 index 000000000000..675843a78c18 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -0,0 +1,422 @@ +// Copyright (c) 
Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Pgvector; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +public class PostgresVectorStoreCollectionSqlBuilderTests +{ + private readonly ITestOutputHelper _output; + private static readonly float[] s_vector = new float[] { 1.0f, 2.0f, 3.0f }; + + public PostgresVectorStoreCollectionSqlBuilderTests(ITestOutputHelper output) + { + this._output = output; + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void TestBuildCreateTableCommand(bool ifNotExists) + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)) { StoragePropertyName = "free_parking" }, + new VectorStoreRecordDataProperty("tags", typeof(List)), + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = "hnsw", + }, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var cmdInfo = builder.BuildCreateTableCommand("public", "testcollection", recordDefinition.Properties, ifNotExists: ifNotExists); + + // Check for expected properties; integration tests will validate the actual SQL. 
+ Assert.Contains("public.\"testcollection\" (", cmdInfo.CommandText); + Assert.Contains("\"name\" TEXT", cmdInfo.CommandText); + Assert.Contains("\"code\" INTEGER NOT NULL", cmdInfo.CommandText); + Assert.Contains("\"rating\" REAL", cmdInfo.CommandText); + Assert.Contains("\"description\" TEXT", cmdInfo.CommandText); + Assert.Contains("\"free_parking\" BOOLEAN NOT NULL", cmdInfo.CommandText); + Assert.Contains("\"tags\" TEXT[]", cmdInfo.CommandText); + Assert.Contains("\"description\" TEXT", cmdInfo.CommandText); + Assert.Contains("\"embedding1\" VECTOR(10) NOT NULL", cmdInfo.CommandText); + Assert.Contains("\"embedding2\" VECTOR(10)", cmdInfo.CommandText); + Assert.Contains("PRIMARY KEY (\"id\")", cmdInfo.CommandText); + + if (ifNotExists) + { + Assert.Contains("IF NOT EXISTS", cmdInfo.CommandText); + } + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Theory] + [InlineData(IndexKind.Hnsw, DistanceFunction.EuclideanDistance)] + [InlineData(IndexKind.IvfFlat, DistanceFunction.DotProductSimilarity)] + [InlineData(IndexKind.Hnsw, DistanceFunction.CosineDistance)] + public void TestBuildCreateIndexCommand(string indexKind, string distanceFunction) + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var vectorColumn = "embedding1"; + + if (indexKind != IndexKind.Hnsw) + { + Assert.Throws(() => builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction)); + return; + } + + var cmdInfo = builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction); + + // Check for expected properties; integration tests will validate the actual SQL. 
+ Assert.Contains("CREATE INDEX ", cmdInfo.CommandText); + Assert.Contains("ON public.\"testcollection\" USING hnsw (\"embedding1\" ", cmdInfo.CommandText); + if (distanceFunction == null) + { + // Check for distance function defaults to cosine distance + Assert.Contains("vector_cosine_ops)", cmdInfo.CommandText); + } + else if (distanceFunction == DistanceFunction.CosineDistance) + { + Assert.Contains("vector_cosine_ops)", cmdInfo.CommandText); + } + else if (distanceFunction == DistanceFunction.EuclideanDistance) + { + Assert.Contains("vector_l2_ops)", cmdInfo.CommandText); + } + else + { + throw new NotImplementedException($"Test case for Distance function {distanceFunction} is not implemented."); + } + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildDropTableCommand() + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var cmdInfo = builder.BuildDropTableCommand("public", "testcollection"); + + // Check for expected properties; integration tests will validate the actual SQL. + Assert.Contains("DROP TABLE IF EXISTS public.\"testcollection\"", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildUpsertCommand() + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var row = new Dictionary() + { + ["id"] = 123, + ["name"] = "Hotel", + ["code"] = 456, + ["rating"] = 4.5f, + ["description"] = "Hotel description", + ["parking_is_included"] = true, + ["tags"] = new List { "tag1", "tag2" }, + ["embedding1"] = new Vector(s_vector), + }; + + var keyColumn = "id"; + + var cmdInfo = builder.BuildUpsertCommand("public", "testcollection", keyColumn, row); + + // Check for expected properties; integration tests will validate the actual SQL. 
+ Assert.Contains("INSERT INTO public.\"testcollection\" (", cmdInfo.CommandText); + Assert.Contains("ON CONFLICT (\"id\")", cmdInfo.CommandText); + Assert.Contains("DO UPDATE SET", cmdInfo.CommandText); + Assert.NotNull(cmdInfo.Parameters); + + foreach (var (key, index) in row.Keys.Select((key, index) => (key, index))) + { + Assert.Equal(row[key], cmdInfo.Parameters[index].Value); + // If the key is not the key column, it should be included in the update clause. + if (key != keyColumn) + { + Assert.Contains($"\"{key}\"=${index + 1}", cmdInfo.CommandText); + } + } + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildUpsertBatchCommand() + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var rows = new List>() + { + new() + { + ["id"] = 123, + ["name"] = "Hotel", + ["code"] = 456, + ["rating"] = 4.5f, + ["description"] = "Hotel description", + ["parking_is_included"] = true, + ["tags"] = new List { "tag1", "tag2" }, + ["embedding1"] = new Vector(s_vector), + }, + new() + { + ["id"] = 124, + ["name"] = "Motel", + ["code"] = 457, + ["rating"] = 4.6f, + ["description"] = "Motel description", + ["parking_is_included"] = false, + ["tags"] = new List { "tag3", "tag4" }, + ["embedding1"] = new Vector(s_vector), + }, + }; + + var keyColumn = "id"; + var columnCount = rows.First().Count; + + var cmdInfo = builder.BuildUpsertBatchCommand("public", "testcollection", keyColumn, rows); + + // Check for expected properties; integration tests will validate the actual SQL. 
+ Assert.Contains("INSERT INTO public.\"testcollection\" (", cmdInfo.CommandText); + Assert.Contains("ON CONFLICT (\"id\")", cmdInfo.CommandText); + Assert.Contains("DO UPDATE SET", cmdInfo.CommandText); + Assert.NotNull(cmdInfo.Parameters); + + foreach (var (row, rowIndex) in rows.Select((row, rowIndex) => (row, rowIndex))) + { + foreach (var (column, columnIndex) in row.Keys.Select((key, index) => (key, index))) + { + Assert.Equal(row[column], cmdInfo.Parameters[columnIndex + (rowIndex * columnCount)].Value); + // If the key is not the key column, it should be included in the update clause. + if (column != keyColumn) + { + Assert.Contains($"\"{column}\" = EXCLUDED.\"{column}\"", cmdInfo.CommandText); + } + } + } + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildGetCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)) { StoragePropertyName = "free_parking" }, + new VectorStoreRecordDataProperty("tags", typeof(List)), + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = "hnsw", + }, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var key = 123; + + // Act + var cmdInfo = builder.BuildGetCommand("public", "testcollection", recordDefinition.Properties, key, includeVectors: true); + + // Assert + Assert.Contains("SELECT", cmdInfo.CommandText); + Assert.Contains("\"free_parking\"", 
cmdInfo.CommandText); + Assert.Contains("\"embedding1\"", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("WHERE \"id\" = $1", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildGetBatchCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)) { StoragePropertyName = "free_parking" }, + new VectorStoreRecordDataProperty("tags", typeof(List)), + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = "hnsw", + }, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var keys = new List { 123, 124 }; + + // Act + var cmdInfo = builder.BuildGetBatchCommand("public", "testcollection", recordDefinition.Properties, keys, includeVectors: true); + + // Assert + Assert.Contains("SELECT", cmdInfo.CommandText); + Assert.Contains("\"code\"", cmdInfo.CommandText); + Assert.Contains("\"free_parking\"", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("WHERE \"id\" = ANY($1)", cmdInfo.CommandText); + Assert.NotNull(cmdInfo.Parameters); + Assert.Single(cmdInfo.Parameters); + Assert.Equal(keys, cmdInfo.Parameters[0].Value); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildDeleteCommand() + { + // Arrange + var 
builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var key = 123; + + // Act + var cmdInfo = builder.BuildDeleteCommand("public", "testcollection", "id", key); + + // Assert + Assert.Contains("DELETE", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("WHERE \"id\" = $1", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildDeleteBatchCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var keys = new List { 123, 124 }; + + // Act + var cmdInfo = builder.BuildDeleteBatchCommand("public", "testcollection", "id", keys); + + // Assert + Assert.Contains("DELETE", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("WHERE \"id\" = ANY($1)", cmdInfo.CommandText); + Assert.NotNull(cmdInfo.Parameters); + Assert.Single(cmdInfo.Parameters); + Assert.Equal(keys, cmdInfo.Parameters[0].Value); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildGetNearestMatchCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var vectorProperty = new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = "hnsw", + }; + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), + new VectorStoreRecordDataProperty("tags", typeof(List)), + vectorProperty, + new VectorStoreRecordVectorProperty("embedding2", 
typeof(ReadOnlyMemory?)) + { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var vector = new Vector(s_vector); + + // Act + var cmdInfo = builder.BuildGetNearestMatchCommand("public", "testcollection", + properties: recordDefinition.Properties, + vectorProperty: vectorProperty, + vectorValue: vector, + filter: null, + skip: null, + includeVectors: true, + limit: 10); + + // Assert + Assert.Contains("SELECT", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("ORDER BY", cmdInfo.CommandText); + Assert.Contains("LIMIT 10", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..0533ab28c3f3 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Moq; +using Pgvector; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +public class PostgresVectorStoreRecordCollectionTests +{ + private const string TestCollectionName = "testcollection"; + + private readonly Mock _postgresClientMock; + private readonly CancellationToken _testCancellationToken = new(false); + + public PostgresVectorStoreRecordCollectionTests() + { + this._postgresClientMock = new Mock(MockBehavior.Strict); + } + + [Fact] + public async Task CreatesCollectionForGenericModelAsync() + { + // Arrange + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = [ + new VectorStoreRecordKeyProperty("HotelId", typeof(int)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { IsFilterable = true, StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty("HotelRating", typeof(float)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("Tags", typeof(List)), + new VectorStoreRecordDataProperty("Description", typeof(string)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 100, DistanceFunction = DistanceFunction.ManhattanDistance } + ] + }; + var options = new PostgresVectorStoreRecordCollectionOptions>() + { + VectorStoreRecordDefinition = recordDefinition + }; + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName, options); + this._postgresClientMock.Setup(x => x.DoesTableExistsAsync(TestCollectionName, 
this._testCancellationToken)).ReturnsAsync(false); + + // Act + var exists = await sut.CollectionExistsAsync(); + + // Assert. + Assert.False(exists); + } + + [Fact] + public void ThrowsForUnsupportedType() + { + // Arrange + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = [ + new VectorStoreRecordKeyProperty("HotelId", typeof(ulong)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + ] + }; + var options = new PostgresVectorStoreRecordCollectionOptions>() + { + VectorStoreRecordDefinition = recordDefinition + }; + + // Act & Assert + Assert.Throws(() => new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName, options)); + } + + [Fact] + public async Task UpsertRecordAsyncProducesExpectedClientCallAsync() + { + // Arrange + Dictionary? capturedArguments = null; + + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName); + var record = new PostgresHotel + { + HotelId = 1, + HotelName = "Hotel 1", + HotelCode = 1, + HotelRating = 4.5f, + ParkingIncluded = true, + Tags = ["tag1", "tag2"], + Description = "A hotel", + DescriptionEmbedding = new ReadOnlyMemory([1.0f, 2.0f, 3.0f, 4.0f]) + }; + + this._postgresClientMock.Setup(x => x.UpsertAsync( + TestCollectionName, + It.IsAny>(), + "HotelId", + this._testCancellationToken)) + .Callback, string, CancellationToken>((collectionName, args, key, ct) => capturedArguments = args) + .Returns(Task.CompletedTask); + + // Act + await sut.UpsertAsync(record, cancellationToken: this._testCancellationToken); + + // Assert + Assert.NotNull(capturedArguments); + Assert.Equal(1, (int)(capturedArguments["HotelId"] ?? 0)); + Assert.Equal("Hotel 1", (string)(capturedArguments["HotelName"] ?? "")); + Assert.Equal(1, (int)(capturedArguments["HotelCode"] ?? 0)); + Assert.Equal(4.5f, (float)(capturedArguments["HotelRating"] ?? 
0.0f)); + Assert.True((bool)(capturedArguments["parking_is_included"] ?? false)); + Assert.True(capturedArguments["Tags"] is List); + var tags = capturedArguments["Tags"] as List; + Assert.Equal(2, tags!.Count); + Assert.Equal("tag1", tags[0]); + Assert.Equal("tag2", tags[1]); + Assert.Equal("A hotel", (string)(capturedArguments["Description"] ?? "")); + Assert.NotNull(capturedArguments["DescriptionEmbedding"]); + Assert.IsType(capturedArguments["DescriptionEmbedding"]); + var embedding = ((Vector)capturedArguments["DescriptionEmbedding"]!).ToArray(); + Assert.Equal(1.0f, embedding[0]); + Assert.Equal(2.0f, embedding[1]); + Assert.Equal(3.0f, embedding[2]); + Assert.Equal(4.0f, embedding[3]); + } + + [Fact] + public async Task CollectionExistsReturnsValidResultAsync() + { + // Arrange + const string TableName = "CollectionExists"; + + this._postgresClientMock.Setup(x => x.DoesTableExistsAsync(TableName, this._testCancellationToken)).ReturnsAsync(true); + + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TableName); + + // Act + var result = await sut.CollectionExistsAsync(); + + Assert.True(result); + } + + [Fact] + public async Task DeleteCollectionCallsClientDeleteAsync() + { + // Arrange + const string TableName = "DeleteCollection"; + + this._postgresClientMock.Setup(x => x.DeleteTableAsync(TableName, this._testCancellationToken)).Returns(Task.CompletedTask); + + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TableName); + + // Act + await sut.DeleteCollectionAsync(); + + // Assert + this._postgresClientMock.Verify(x => x.DeleteTableAsync(TableName, this._testCancellationToken), Times.Once); + } + + #region private + + private static void AssertRecord(TestRecord expectedRecord, TestRecord? 
actualRecord, bool includeVectors) + { + Assert.NotNull(actualRecord); + + Assert.Equal(expectedRecord.Key, actualRecord.Key); + Assert.Equal(expectedRecord.Data, actualRecord.Data); + + if (includeVectors) + { + Assert.NotNull(actualRecord.Vector); + Assert.Equal(expectedRecord.Vector!.Value.ToArray(), actualRecord.Vector.Value.Span.ToArray()); + } + else + { + Assert.Null(actualRecord.Vector); + } + } + +#pragma warning disable CA1812 + private sealed class TestRecord + { + [VectorStoreRecordKey] + public TKey? Key { get; set; } + + [VectorStoreRecordData] + public string? Data { get; set; } + + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + public ReadOnlyMemory? Vector { get; set; } + } + + private sealed class TestRecordWithoutVectorProperty + { + [VectorStoreRecordKey] + public TKey? Key { get; set; } + + [VectorStoreRecordData] + public string? Data { get; set; } + } +#pragma warning restore CA1812 + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs new file mode 100644 index 000000000000..11dfd2ecd564 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Pgvector; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Unit tests for class. 
+/// +public sealed class PostgresVectorStoreRecordMapperTests +{ + [Fact] + public void MapFromDataToStorageModelWithStringKeyReturnsValidStorageModel() + { + // Arrange + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + var dataModel = GetDataModel("key"); + + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + + // Act + var result = mapper.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal("key", result["Key"]); + Assert.Equal("Value1", result["StringProperty"]); + Assert.Equal(5, result["IntProperty"]); + Assert.Equal(new List { "Value2", "Value3" }, result["StringArray"]); + + Vector? vector = result["FloatVector"] as Vector; + + Assert.NotNull(vector); + Assert.True(vector.ToArray().Length > 0); + } + + [Fact] + public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() + { + // Arrange + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + var dataModel = GetDataModel(1); + + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + + // Act + var result = mapper.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal((ulong)1, result["Key"]); + Assert.Equal("Value1", result["StringProperty"]); + Assert.Equal(5, result["IntProperty"]); + Assert.Equal(new List { "Value2", "Value3" }, result["StringArray"]); + + var vector = result["FloatVector"] as Vector; + + Assert.NotNull(vector); + Assert.True(vector.ToArray().Length > 0); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool includeVectors) + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + var storageModel = new Dictionary + { + ["Key"] = "key", + ["StringProperty"] = "Value1", + ["IntProperty"] = 5, + ["StringArray"] = new 
List { "Value2", "Value3" }, + ["FloatVector"] = storageVector, + }; + + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + + // Act + var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); + + // Assert + Assert.Equal("key", result.Key); + Assert.Equal("Value1", result.StringProperty); + Assert.Equal(5, result.IntProperty); + Assert.Equal(new List { "Value2", "Value3" }, result.StringArray); + + if (includeVectors) + { + Assert.NotNull(result.FloatVector); + Assert.Equal(vector.Span.ToArray(), result.FloatVector.Value.Span.ToArray()); + } + else + { + Assert.Null(result.FloatVector); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool includeVectors) + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + var storageModel = new Dictionary + { + ["Key"] = (ulong)1, + ["StringProperty"] = "Value1", + ["IntProperty"] = 5, + ["StringArray"] = new List { "Value2", "Value3" }, + ["FloatVector"] = storageVector + }; + + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + + // Act + var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); + + // Assert + Assert.Equal((ulong)1, result.Key); + Assert.Equal("Value1", result.StringProperty); + Assert.Equal(5, result.IntProperty); + Assert.Equal(new List { "Value2", "Value3" }, result.StringArray); + + if (includeVectors) + { + Assert.NotNull(result.FloatVector); + Assert.Equal(vector.Span.ToArray(), result.FloatVector.Value.Span.ToArray()); + } + else + { + Assert.Null(result.FloatVector); 
+ } + } + + #region private + + private static VectorStoreRecordDefinition GetRecordDefinition() + { + return new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(TKey)), + new VectorStoreRecordDataProperty("StringProperty", typeof(string)), + new VectorStoreRecordDataProperty("IntProperty", typeof(int)), + new VectorStoreRecordDataProperty("StringArray", typeof(IEnumerable)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), + } + }; + } + + private static TestRecord GetDataModel(TKey key) + { + return new TestRecord + { + Key = key, + StringProperty = "Value1", + IntProperty = 5, + StringArray = new List { "Value2", "Value3" }, + FloatVector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]) + }; + } + + private static VectorStoreRecordPropertyReader GetPropertyReader(VectorStoreRecordDefinition definition) + { + return new VectorStoreRecordPropertyReader(typeof(TRecord), definition, new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true + }); + } + +#pragma warning disable CA1812 + private sealed class TestRecord + { + [VectorStoreRecordKey] + public TKey? Key { get; set; } + + [VectorStoreRecordData] + public string? StringProperty { get; set; } + + [VectorStoreRecordData] + public int? IntProperty { get; set; } + + [VectorStoreRecordData] + public IEnumerable? StringArray { get; set; } + + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + public ReadOnlyMemory? 
FloatVector { get; set; } + } +#pragma warning restore CA1812 + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs new file mode 100644 index 000000000000..0631cc2c0df4 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Pgvector; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class PostgresVectorStoreRecordPropertyMappingTests +{ + [Fact] + public void MapVectorForStorageModelWithInvalidVectorTypeThrowsException() + { + // Arrange + var vector = new float[] { 1f, 2f, 3f }; + + // Act & Assert + Assert.Throws(() => PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector)); + } + + [Fact] + public void MapVectorForStorageModelReturnsVector() + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + + // Act + var storageModelVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + // Assert + Assert.IsType(storageModelVector); + Assert.True(storageModelVector.ToArray().Length > 0); + } + + [Fact] + public void MapVectorForDataModelReturnsReadOnlyMemory() + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var pgVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + // Act + var dataModelVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForDataModel(pgVector); + + // Assert + Assert.NotNull(dataModelVector); + Assert.Equal(vector.ToArray(), dataModelVector!.Value.ToArray()); + } + + [Fact] 
+ public void GetPropertyValueReturnsCorrectValuesForLists() + { + // Arrange + var typesAndExpectedValues = new List<(Type, object)> + { + (typeof(List), "INTEGER[]"), + (typeof(List), "REAL[]"), + (typeof(List), "DOUBLE PRECISION[]"), + (typeof(List), "TEXT[]"), + (typeof(List), "BOOLEAN[]"), + (typeof(List), "TIMESTAMP[]"), + (typeof(List), "UUID[]"), + }; + + // Act & Assert + foreach (var (type, expectedValue) in typesAndExpectedValues) + { + var (pgType, _) = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(type); + Assert.Equal(expectedValue, pgType); + } + } + + [Fact] + public void GetPropertyValueReturnsCorrectNullableValue() + { + // Arrange + var typesAndExpectedValues = new List<(Type, object)> + { + (typeof(short), false), + (typeof(short?), true), + (typeof(int?), true), + (typeof(long), false), + (typeof(string), true), + (typeof(bool?), true), + (typeof(DateTime?), true), + (typeof(Guid), false), + }; + + // Act & Assert + foreach (var (type, expectedValue) in typesAndExpectedValues) + { + var (_, isNullable) = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(type); + Assert.Equal(expectedValue, isNullable); + } + } + + [Fact] + public void GetVectorIndexInfoReturnsCorrectValues() + { + // Arrange + List vectorProperties = [ + new VectorStoreRecordVectorProperty("vector1", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Hnsw, Dimensions = 1000 }, + new VectorStoreRecordVectorProperty("vector2", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Flat, Dimensions = 3000 }, + new VectorStoreRecordVectorProperty("vector3", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Hnsw, Dimensions = 900, DistanceFunction = DistanceFunction.ManhattanDistance }, + ]; + + // Act + var indexInfo = PostgresVectorStoreRecordPropertyMapping.GetVectorIndexInfo(vectorProperties); + + // Assert + Assert.Equal(2, indexInfo.Count); + foreach (var (columnName, indexKind, distanceFunction) in indexInfo) + { + if (columnName == "vector1") + { + 
Assert.Equal(IndexKind.Hnsw, indexKind); + Assert.Equal(DistanceFunction.CosineDistance, distanceFunction); + } + else if (columnName == "vector3") + { + Assert.Equal(IndexKind.Hnsw, indexKind); + Assert.Equal(DistanceFunction.ManhattanDistance, distanceFunction); + } + else + { + Assert.Fail("Unexpected column name"); + } + } + } + + [Theory] + [InlineData(IndexKind.Hnsw, 3000)] + public void GetVectorIndexInfoReturnsThrowsForInvalidDimensions(string indexKind, int dimensions) + { + // Arrange + var vectorProperty = new VectorStoreRecordVectorProperty("vector", typeof(ReadOnlyMemory?)) { IndexKind = indexKind, Dimensions = dimensions }; + + // Act & Assert + Assert.Throws(() => PostgresVectorStoreRecordPropertyMapping.GetVectorIndexInfo([vectorProperty])); + } +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs new file mode 100644 index 000000000000..b11d6a81963f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Moq; +using Npgsql; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Contains tests for the class. +/// +public class PostgresVectorStoreTests +{ + private const string TestCollectionName = "testcollection"; + + private readonly Mock _postgresClientMock; + private readonly CancellationToken _testCancellationToken = new(false); + + public PostgresVectorStoreTests() + { + this._postgresClientMock = new Mock(MockBehavior.Strict); + } + + [Fact] + public void GetCollectionReturnsCollection() + { + // Arrange. 
+ var sut = new PostgresVectorStore(this._postgresClientMock.Object); + + // Act. + var actual = sut.GetCollection>(TestCollectionName); + + // Assert. + Assert.NotNull(actual); + Assert.IsType>>(actual); + } + + [Fact] + public void GetCollectionThrowsForInvalidKeyType() + { + // Arrange. + var sut = new PostgresVectorStore(this._postgresClientMock.Object); + + // Act and Assert. + Assert.Throws(() => sut.GetCollection>(TestCollectionName)); + } + + [Fact] + public void GetCollectionCallsFactoryIfProvided() + { + // Arrange. + var factoryMock = new Mock(MockBehavior.Strict); + var collectionMock = new Mock>>(MockBehavior.Strict); + var clientMock = new Mock(MockBehavior.Strict); + clientMock.Setup(x => x.DataSource).Returns(null); + factoryMock + .Setup(x => x.CreateVectorStoreRecordCollection>(It.IsAny(), TestCollectionName, null)) + .Returns(collectionMock.Object); + var sut = new PostgresVectorStore(clientMock.Object, new() { VectorStoreCollectionFactory = factoryMock.Object }); + + // Act. + var actual = sut.GetCollection>(TestCollectionName); + + // Assert. + Assert.Equal(collectionMock.Object, actual); + } + + [Fact] + public async Task ListCollectionNamesCallsSDKAsync() + { + // Arrange + var expectedCollections = new List { "fake-collection-1", "fake-collection-2", "fake-collection-3" }; + + this._postgresClientMock + .Setup(client => client.GetTablesAsync(CancellationToken.None)) + .Returns(expectedCollections.ToAsyncEnumerable()); + + var sut = new PostgresVectorStore(this._postgresClientMock.Object); + + // Act. 
+ var actual = sut.ListCollectionNamesAsync(this._testCancellationToken); + + // Assert + Assert.NotNull(actual); + var actualList = await actual.ToListAsync(); + Assert.Equal(expectedCollections, actualList); + } + + [Fact] + public async Task ListCollectionNamesThrowsCorrectExceptionAsync() + { + // Arrange + var expectedCollections = new List { "fake-collection-1", "fake-collection-2", "fake-collection-3" }; + + this._postgresClientMock + .Setup(client => client.GetTablesAsync(CancellationToken.None)) + .Returns(this.ThrowingAsyncEnumerableAsync); + + var sut = new PostgresVectorStore(this._postgresClientMock.Object); + + // Act. + var actual = sut.ListCollectionNamesAsync(this._testCancellationToken); + + // Assert + Assert.NotNull(actual); + await Assert.ThrowsAsync(async () => await actual.ToListAsync()); + } + + private async IAsyncEnumerable ThrowingAsyncEnumerableAsync() + { + int itemIndex = 0; + await foreach (var item in new List { "item1", "item2", "item3" }.ToAsyncEnumerable()) + { + if (itemIndex == 1) + { + throw new InvalidOperationException("Test exception"); + } + yield return item; + itemIndex++; + } + } + + public sealed class SinglePropsModel + { + [VectorStoreRecordKey] + public required TKey Key { get; set; } + + [VectorStoreRecordData] + public string Data { get; set; } = string.Empty; + + [VectorStoreRecordVector(4)] + public ReadOnlyMemory? Vector { get; set; } + + public string? NotAnnotated { get; set; } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs new file mode 100644 index 000000000000..48a8f5f36a41 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + +/// +/// A test model for the postgres vector store. +/// +public record PostgresHotel() +{ + /// The key of the record. + [VectorStoreRecordKey] + public T HotelId { get; init; } + + /// A string metadata field. + [VectorStoreRecordData()] + public string? HotelName { get; set; } + + /// An int metadata field. + [VectorStoreRecordData()] + public int HotelCode { get; set; } + + /// A float metadata field. + [VectorStoreRecordData()] + public float? HotelRating { get; set; } + + /// A bool metadata field. + [VectorStoreRecordData(StoragePropertyName = "parking_is_included")] + public bool ParkingIncluded { get; set; } + + [VectorStoreRecordData] + public List Tags { get; set; } = []; + + [VectorStoreRecordData] + public List? ListInts { get; set; } = null; + + /// A data field. + [VectorStoreRecordData] + public string Description { get; set; } + + /// A vector field. + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.EuclideanDistance, IndexKind: IndexKind.Hnsw)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } + + public DateTime CreatedAt { get; set; } = DateTime.UtcNow; + + public DateTimeOffset UpdatedAt { get; set; } = DateTimeOffset.UtcNow; + + public PostgresHotel(T key) : this() + { + this.HotelId = key; + } +} + +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. 
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs index 19126a090874..71474ff0ebc6 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs @@ -11,7 +11,7 @@ using Npgsql; using Xunit; -namespace SemanticKernel.IntegrationTests.Connectors.Postgres; +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; /// /// Integration tests of . @@ -41,6 +41,8 @@ public async Task InitializeAsync() this._connectionString = connectionString; this._databaseName = $"sk_it_{Guid.NewGuid():N}"; + await this.CreateDatabaseAsync(); + NpgsqlConnectionStringBuilder connectionStringBuilder = new(this._connectionString) { Database = this._databaseName @@ -50,8 +52,6 @@ public async Task InitializeAsync() dataSourceBuilder.UseVector(); this._dataSource = dataSourceBuilder.Build(); - - await this.CreateDatabaseAsync(); } public async Task DisposeAsync() diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreCollectionFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreCollectionFixture.cs new file mode 100644 index 000000000000..5d202af5b9f5 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreCollectionFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +[CollectionDefinition("PostgresVectorStoreCollection")] +public class PostgresVectorStoreCollectionFixture : ICollectionFixture +{ +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs new file mode 100644 index 000000000000..5888a513ace0 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs @@ -0,0 +1,239 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Docker.DotNet; +using Docker.DotNet.Models; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Npgsql; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +public class PostgresVectorStoreFixture : IAsyncLifetime +{ + /// The docker client we are using to create a postgres container with. + private readonly DockerClient _client; + + /// The id of the postgres container that we are testing with. + private string? _containerId = null; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + + /// + /// Initializes a new instance of the class. + /// + public PostgresVectorStoreFixture() + { + using var dockerClientConfiguration = new DockerClientConfiguration(); + this._client = dockerClientConfiguration.CreateClient(); + } + + /// + /// Holds the Npgsql data source to use for tests. + /// + private NpgsqlDataSource? _dataSource; + + private string _connectionString = null!; + private string _databaseName = null!; + + /// + /// Gets a vector store to use for tests. 
+ /// + public IVectorStore VectorStore => new PostgresVectorStore(this._dataSource!); + + /// + /// Get a database connection + /// + public NpgsqlConnection GetConnection() + { + return this._dataSource!.OpenConnection(); + } + + public IVectorStoreRecordCollection GetCollection( + string collectionName, + VectorStoreRecordDefinition? recordDefinition = default) + where TKey : notnull + where TRecord : class + { + var vectorStore = this.VectorStore; + return vectorStore.GetCollection(collectionName, recordDefinition); + } + + /// + /// Create / Recreate postgres docker container and run it. + /// + /// An async task. + public async Task InitializeAsync() + { + this._containerId = await SetupPostgresContainerAsync(this._client); + this._connectionString = "Host=localhost;Port=5432;Username=postgres;Password=example;Database=postgres;"; + this._databaseName = $"sk_it_{Guid.NewGuid():N}"; + + // Connect to postgres. + NpgsqlConnectionStringBuilder connectionStringBuilder = new(this._connectionString) + { + Database = this._databaseName + }; + + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionStringBuilder.ToString()); + dataSourceBuilder.UseVector(); + + this._dataSource = dataSourceBuilder.Build(); + + // Wait for the postgres container to be ready and create the test database using the initial data source. 
+ var initialDataSource = NpgsqlDataSource.Create(this._connectionString); + using (initialDataSource) + { + var retryCount = 0; + var exceptionCount = 0; + while (retryCount++ < 5) + { + try + { + NpgsqlConnection connection = await initialDataSource.OpenConnectionAsync().ConfigureAwait(false); + + await using (connection) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = "SELECT count(*) FROM information_schema.tables WHERE table_schema = 'public';"; + await cmd.ExecuteScalarAsync().ConfigureAwait(false); + } + } + catch (NpgsqlException) + { + exceptionCount++; + await Task.Delay(1000); + } + } + + if (exceptionCount >= 5) + { + // Throw an exception for test setup + throw new InvalidOperationException("Postgres container did not start in time."); + } + + await this.CreateDatabaseAsync(initialDataSource); + } + + // Create the table. + await this.CreateTableAsync(); + } + + private async Task CreateTableAsync() + { + NpgsqlConnection connection = await this._dataSource!.OpenConnectionAsync().ConfigureAwait(false); + + await using (connection) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = @" + CREATE TABLE hotel_info ( + HotelId INTEGER NOT NULL, + HotelName TEXT, + HotelCode INTEGER NOT NULL, + HotelRating REAL, + parking_is_included BOOLEAN, + Tags TEXT[] NOT NULL, + Description TEXT NOT NULL, + DescriptionEmbedding VECTOR(4) NOT NULL, + PRIMARY KEY (HotelId));"; + await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); + } + } + + /// + /// Delete the docker container after the test run. + /// + /// An async task. 
+ public async Task DisposeAsync() + { + if (this._dataSource != null) + { + this._dataSource.Dispose(); + } + + await this.DropDatabaseAsync(); + + if (this._containerId != null) + { + await this._client.Containers.StopContainerAsync(this._containerId, new ContainerStopParameters()); + await this._client.Containers.RemoveContainerAsync(this._containerId, new ContainerRemoveParameters()); + } + } + + /// + /// Setup the postgres container by pulling the image and running it. + /// + /// The docker client to create the container with. + /// The id of the container. + private static async Task SetupPostgresContainerAsync(DockerClient client) + { + await client.Images.CreateImageAsync( + new ImagesCreateParameters + { + FromImage = "pgvector/pgvector", + Tag = "pg16", + }, + null, + new Progress()); + + var container = await client.Containers.CreateContainerAsync(new CreateContainerParameters() + { + Image = "pgvector/pgvector:pg16", + HostConfig = new HostConfig() + { + PortBindings = new Dictionary> + { + {"5432", new List {new() {HostPort = "5432" } }}, + }, + PublishAllPorts = true + }, + ExposedPorts = new Dictionary + { + { "5432", default }, + }, + Env = new List + { + "POSTGRES_USER=postgres", + "POSTGRES_PASSWORD=example", + }, + }); + + await client.Containers.StartContainerAsync( + container.ID, + new ContainerStartParameters()); + + return container.ID; + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "The database name is generated randomly, it does not support parameterized passing.")] + private async Task CreateDatabaseAsync(NpgsqlDataSource initialDataSource) + { + await using (NpgsqlConnection conn = await initialDataSource.OpenConnectionAsync()) + { + await using NpgsqlCommand command = new($"CREATE DATABASE \"{this._databaseName}\"", conn); + await command.ExecuteNonQueryAsync(); + } + + await using (NpgsqlConnection conn = await 
this._dataSource!.OpenConnectionAsync()) + { + await using (NpgsqlCommand command = new("CREATE EXTENSION vector", conn)) + { + await command.ExecuteNonQueryAsync(); + } + await conn.ReloadTypesAsync(); + } + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "The database name is generated randomly, it does not support parameterized passing.")] + private async Task DropDatabaseAsync() + { + using NpgsqlDataSource dataSource = NpgsqlDataSource.Create(this._connectionString); + await using NpgsqlConnection conn = await dataSource.OpenConnectionAsync(); + await using NpgsqlCommand command = new($"DROP DATABASE IF EXISTS \"{this._databaseName}\"", conn); + await command.ExecuteNonQueryAsync(); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..7e3ae3ad9392 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -0,0 +1,562 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Npgsql; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +[Collection("PostgresVectorStoreCollection")] +public sealed class PostgresVectorStoreRecordCollectionTests(PostgresVectorStoreFixture fixture) +{ + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CollectionExistsReturnsCollectionStateAsync(bool createCollection) + { + // Arrange + var sut = fixture.GetCollection>("CollectionExists"); + + if (createCollection) + { + await sut.CreateCollectionAsync(); + } + + try + { + // Act + var collectionExists = await sut.CollectionExistsAsync(); + + // Assert + Assert.Equal(createCollection, collectionExists); + } + finally + { + // Cleanup + if (createCollection) + { + await sut.DeleteCollectionAsync(); + } + } + } + + [Fact] + public async Task CollectionCanUpsertAndGetAsync() + { + // Arrange + var sut = fixture.GetCollection>("CollectionCanUpsertAndGet"); + if (await sut.CollectionExistsAsync()) + { + await sut.DeleteCollectionAsync(); + } + + await sut.CreateCollectionAsync(); + + var writtenHotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + var writtenHotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, ListInts = [1, 2] }; + + try + { + // Act + + await sut.UpsertAsync(writtenHotel1); + + await sut.UpsertAsync(writtenHotel2); + + var fetchedHotel1 = await sut.GetAsync(1); + var fetchedHotel2 = await sut.GetAsync(2); + + // Assert + Assert.NotNull(fetchedHotel1); + Assert.Equal(1, fetchedHotel1!.HotelId); + Assert.Equal("Hotel 1", fetchedHotel1!.HotelName); + Assert.Equal(1, fetchedHotel1!.HotelCode); + Assert.True(fetchedHotel1!.ParkingIncluded); + Assert.Equal(4.5f, 
fetchedHotel1!.HotelRating); + Assert.NotNull(fetchedHotel1!.Tags); + Assert.Equal(2, fetchedHotel1!.Tags!.Count); + Assert.Equal("tag1", fetchedHotel1!.Tags![0]); + Assert.Equal("tag2", fetchedHotel1!.Tags![1]); + Assert.Null(fetchedHotel1!.ListInts); + Assert.Equal(TruncateMilliseconds(fetchedHotel1.CreatedAt), TruncateMilliseconds(writtenHotel1.CreatedAt)); + Assert.Equal(TruncateMilliseconds(fetchedHotel1.UpdatedAt), TruncateMilliseconds(writtenHotel1.UpdatedAt)); + + Assert.NotNull(fetchedHotel2); + Assert.Equal(2, fetchedHotel2!.HotelId); + Assert.Equal("Hotel 2", fetchedHotel2!.HotelName); + Assert.Equal(2, fetchedHotel2!.HotelCode); + Assert.False(fetchedHotel2!.ParkingIncluded); + Assert.Equal(2.5f, fetchedHotel2!.HotelRating); + Assert.NotNull(fetchedHotel2!.Tags); + Assert.Empty(fetchedHotel2!.Tags); + Assert.NotNull(fetchedHotel2!.ListInts); + Assert.Equal(2, fetchedHotel2!.ListInts!.Count); + Assert.Equal(1, fetchedHotel2!.ListInts![0]); + Assert.Equal(2, fetchedHotel2!.ListInts![1]); + Assert.Equal(TruncateMilliseconds(fetchedHotel2.CreatedAt), TruncateMilliseconds(writtenHotel2.CreatedAt)); + Assert.Equal(TruncateMilliseconds(fetchedHotel2.UpdatedAt), TruncateMilliseconds(writtenHotel2.UpdatedAt)); + } + finally + { + // Cleanup + await sut.DeleteCollectionAsync(); + } + } + + public static IEnumerable ItCanGetAndDeleteRecordParameters => + new List + { + new object[] { typeof(short), (short)3 }, + new object[] { typeof(int), 5 }, + new object[] { typeof(long), 7L }, + new object[] { typeof(string), "key1" }, + new object[] { typeof(Guid), Guid.NewGuid() } + }; + + [Theory] + [MemberData(nameof(ItCanGetAndDeleteRecordParameters))] + public async Task ItCanGetAndDeleteRecordAsync(Type idType, TKey? 
key) + { + // Arrange + var collectionName = "DeleteRecord"; + var sut = this.GetCollection(idType, collectionName); + + await sut.CreateCollectionAsync(); + + try + { + var record = this.CreateRecord(idType, key!); + var recordKey = record.HotelId; + var upsertResult = await sut.UpsertAsync(record); + var getResult = await sut.GetAsync(recordKey); + + Assert.Equal(key, upsertResult); + Assert.NotNull(getResult); + + // Act + await sut.DeleteAsync(recordKey); + + getResult = await sut.GetAsync(recordKey); + + // Assert + Assert.Null(getResult); + } + finally + { + // Cleanup + await sut.DeleteCollectionAsync(); + } + } + + [Fact] + public async Task ItCanGetUpsertDeleteBatchAsync() + { + // Arrange + const int HotelId1 = 1; + const int HotelId2 = 2; + const int HotelId3 = 3; + + var sut = fixture.GetCollection>("GetUpsertDeleteBatch"); + + await sut.CreateCollectionAsync(); + + var record1 = new PostgresHotel { HotelId = HotelId1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + var record2 = new PostgresHotel { HotelId = HotelId2, HotelName = "Hotel 2", HotelCode = 1, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag3"] }; + var record3 = new PostgresHotel { HotelId = HotelId3, HotelName = "Hotel 3", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag4"] }; + + var upsertResults = await sut.UpsertBatchAsync([record1, record2, record3]).ToListAsync(); + var getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + + Assert.Equal([HotelId1, HotelId2, HotelId3], upsertResults); + + Assert.NotNull(getResults.First(l => l.HotelId == HotelId1)); + Assert.NotNull(getResults.First(l => l.HotelId == HotelId2)); + Assert.NotNull(getResults.First(l => l.HotelId == HotelId3)); + + // Act + await sut.DeleteBatchAsync([HotelId1, HotelId2, HotelId3]); + + getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + + // 
Assert + Assert.Empty(getResults); + } + + [Fact] + public async Task ItCanUpsertExistingRecordAsync() + { + // Arrange + const int HotelId = 5; + var sut = fixture.GetCollection>("UpsertRecord"); + + await sut.CreateCollectionAsync(); + + var record = new PostgresHotel { HotelId = HotelId, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + + var upsertResult = await sut.UpsertAsync(record); + var getResult = await sut.GetAsync(HotelId, new() { IncludeVectors = true }); + + Assert.Equal(HotelId, upsertResult); + Assert.NotNull(getResult); + Assert.Null(getResult!.DescriptionEmbedding); + + // Act + record.HotelName = "Updated name"; + record.HotelRating = 10; + record.DescriptionEmbedding = new[] { 1f, 2f, 3f, 4f }; + + upsertResult = await sut.UpsertAsync(record); + getResult = await sut.GetAsync(HotelId, new() { IncludeVectors = true }); + + // Assert + Assert.NotNull(getResult); + Assert.Equal("Updated name", getResult.HotelName); + Assert.Equal(10, getResult.HotelRating); + + Assert.NotNull(getResult.DescriptionEmbedding); + Assert.Equal(record.DescriptionEmbedding!.Value.ToArray(), getResult.DescriptionEmbedding.Value.ToArray()); + } + + [Fact] + public async Task ItCanReadManuallyInsertedRecordAsync() + { + const string CollectionName = "ItCanReadManuallyInsertedRecordAsync"; + // Arrange + var sut = fixture.GetCollection>(CollectionName); + await sut.CreateCollectionAsync().ConfigureAwait(true); + Assert.True(await sut.CollectionExistsAsync().ConfigureAwait(true)); + await using (var connection = fixture.GetConnection()) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = @$" + INSERT INTO public.""{CollectionName}"" ( + ""HotelId"", ""HotelName"", ""HotelCode"", ""HotelRating"", ""parking_is_included"", ""Tags"", ""Description"", ""DescriptionEmbedding"" + ) VALUES ( + 215, 'Divine Lorraine', 215, 5, false, ARRAY['historic', 'philly'], 'An iconic building on broad 
street', '[10,20,30,40]' + );"; + await cmd.ExecuteNonQueryAsync().ConfigureAwait(true); + } + + // Act + var getResult = await sut.GetAsync(215, new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.NotNull(getResult); + Assert.Equal(215, getResult!.HotelId); + Assert.Equal("Divine Lorraine", getResult.HotelName); + Assert.Equal(215, getResult.HotelCode); + Assert.Equal(5, getResult.HotelRating); + Assert.False(getResult.ParkingIncluded); + Assert.Equal(new List { "historic", "philly" }, getResult.Tags); + Assert.Equal("An iconic building on broad street", getResult.Description); + Assert.Equal([10f, 20f, 30f, 40f], getResult.DescriptionEmbedding!.Value.ToArray()); + } + + [Fact] + public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() + { + const int HotelId = 5; + + var sut = fixture.GetCollection>("GenericMapperWithNumericKey", GetVectorStoreRecordDefinition()); + + await sut.CreateCollectionAsync(); + + var record = new PostgresHotel { HotelId = (int)HotelId, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + + // Act + var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel(HotelId) + { + Data = + { + { "HotelName", "Generic Mapper Hotel" }, + { "Description", "This is a generic mapper hotel" }, + { "HotelCode", 1 }, + { "ParkingIncluded", true }, + { "HotelRating", 3.6f } + }, + Vectors = + { + { "DescriptionEmbedding", new ReadOnlyMemory([30f, 31f, 32f, 33f]) } + } + }); + + var localGetResult = await sut.GetAsync(HotelId, new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.Equal(HotelId, upsertResult); + + Assert.NotNull(localGetResult); + Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); + Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); + Assert.True((bool?)localGetResult.Data["ParkingIncluded"]); + Assert.Equal(3.6f, localGetResult.Data["HotelRating"]); + 
Assert.Equal([30f, 31f, 32f, 33f], ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + + // Act - update with null embeddings + // Act + var upsertResult2 = await sut.UpsertAsync(new VectorStoreGenericDataModel(HotelId) + { + Data = + { + { "HotelName", "Generic Mapper Hotel" }, + { "Description", "This is a generic mapper hotel" }, + { "HotelCode", 1 }, + { "ParkingIncluded", true }, + { "HotelRating", 3.6f } + }, + Vectors = + { + { "DescriptionEmbedding", null } + } + }); + + var localGetResult2 = await sut.GetAsync(HotelId, new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.NotNull(localGetResult2); + Assert.Null(localGetResult2.Vectors["DescriptionEmbedding"]); + } + + [Theory] + [InlineData(true, DistanceFunction.CosineDistance)] + [InlineData(false, DistanceFunction.CosineDistance)] + [InlineData(false, DistanceFunction.CosineSimilarity)] + [InlineData(false, DistanceFunction.EuclideanDistance)] + [InlineData(false, DistanceFunction.ManhattanDistance)] + [InlineData(false, DistanceFunction.DotProductSimilarity)] + public async Task VectorizedSearchReturnsValidResultsByDefaultAsync(bool includeVectors, string distanceFunction) + { + // Arrange + var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 1f, 0f, 0f, 0f } }; + var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 0f, 1f, 0f, 0f } }; + var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 3.5f, Tags = ["tag1", "tag4"], DescriptionEmbedding = new[] { 0f, 0f, 1f, 0f } }; + var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 1.5f, Tags = ["tag1", "tag5"], 
DescriptionEmbedding = new[] { 0f, 0f, 0f, 1f } }; + + var sut = fixture.GetCollection>($"VectorizedSearch_{includeVectors}_{distanceFunction}", GetVectorStoreRecordDefinition(distanceFunction)); + + await sut.CreateCollectionAsync(); + + await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + + // Act + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([0.9f, 0.1f, 0.5f, 0.8f]), new() + { + IncludeVectors = includeVectors + }); + + var results = await searchResults.Results.ToListAsync(); + + // Assert + var ids = results.Select(l => l.Record.HotelId).ToList(); + + Assert.Equal(1, ids[0]); + Assert.Equal(4, ids[1]); + Assert.Equal(3, ids[2]); + + // Default limit is 3 + Assert.DoesNotContain(2, ids); + + Assert.True(0 < results.First(l => l.Record.HotelId == 1).Score); + + Assert.Equal(includeVectors, results.All(result => result.Record.DescriptionEmbedding is not null)); + } + + [Fact] + public async Task VectorizedSearchWithEqualToFilterReturnsValidResultsAsync() + { + // Arrange + var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } }; + var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 10f, 10f, 10f, 10f } }; + var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag4"], DescriptionEmbedding = new[] { 20f, 20f, 20f, 20f } }; + var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 40f, 40f, 40f, 40f } }; + + var sut = fixture.GetCollection>("VectorizedSearchWithEqualToFilter"); + + await sut.CreateCollectionAsync(); + + 
await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + + // Act + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 29f, 28f, 27f]), new() + { + IncludeVectors = false, + Top = 5, + Filter = new([ + new EqualToFilterClause("HotelRating", 2.5f) + ]) + }); + + var results = await searchResults.Results.ToListAsync(); + + // Assert + var ids = results.Select(l => l.Record.HotelId).ToList(); + + Assert.Equal([1, 3, 2], ids); + } + + [Fact] + public async Task VectorizedSearchWithAnyTagFilterReturnsValidResultsAsync() + { + // Arrange + var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } }; + var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 10f, 10f, 10f, 10f } }; + var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag2", "tag4"], DescriptionEmbedding = new[] { 20f, 20f, 20f, 20f } }; + var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 40f, 40f, 40f, 40f } }; + + var sut = fixture.GetCollection>("VectorizedSearchWithAnyTagEqualToFilter"); + + await sut.CreateCollectionAsync(); + + await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + + // Act + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 29f, 28f, 27f]), new() + { + IncludeVectors = false, + Top = 5, + Filter = new([ + new AnyTagEqualToFilterClause("Tags", "tag2") + ]) + }); + + var results = await searchResults.Results.ToListAsync(); + + // Assert + var ids = results.Select(l => l.Record.HotelId).ToList(); + + 
Assert.Equal([1, 3], ids); + } + + [Fact] + public async Task ItCanUpsertAndGetEnumerableTypesAsync() + { + // Arrange + var sut = fixture.GetCollection("UpsertAndGetEnumerableTypes"); + + await sut.CreateCollectionAsync(); + + var record = new RecordWithEnumerables + { + Id = 1, + ListInts = new() { 1, 2, 3 }, + CollectionInts = new HashSet() { 4, 5, 6 }, + EnumerableInts = [7, 8, 9], + ReadOnlyCollectionInts = new List { 10, 11, 12 }, + ReadOnlyListInts = new List { 13, 14, 15 } + }; + + // Act + await sut.UpsertAsync(record); + + var getResult = await sut.GetAsync(1); + + // Assert + Assert.NotNull(getResult); + Assert.Equal(1, getResult!.Id); + Assert.NotNull(getResult.ListInts); + Assert.Equal(3, getResult.ListInts!.Count); + Assert.Equal(1, getResult.ListInts![0]); + Assert.Equal(2, getResult.ListInts![1]); + Assert.Equal(3, getResult.ListInts![2]); + Assert.NotNull(getResult.CollectionInts); + Assert.Equal(3, getResult.CollectionInts!.Count); + Assert.Contains(4, getResult.CollectionInts); + Assert.Contains(5, getResult.CollectionInts); + Assert.Contains(6, getResult.CollectionInts); + Assert.NotNull(getResult.EnumerableInts); + Assert.Equal(3, getResult.EnumerableInts!.Count()); + Assert.Equal(7, getResult.EnumerableInts.ElementAt(0)); + Assert.Equal(8, getResult.EnumerableInts.ElementAt(1)); + Assert.Equal(9, getResult.EnumerableInts.ElementAt(2)); + Assert.NotNull(getResult.ReadOnlyCollectionInts); + Assert.Equal(3, getResult.ReadOnlyCollectionInts!.Count); + var readOnlyCollectionIntsList = getResult.ReadOnlyCollectionInts.ToList(); + Assert.Equal(10, readOnlyCollectionIntsList[0]); + Assert.Equal(11, readOnlyCollectionIntsList[1]); + Assert.Equal(12, readOnlyCollectionIntsList[2]); + Assert.NotNull(getResult.ReadOnlyListInts); + Assert.Equal(3, getResult.ReadOnlyListInts!.Count); + Assert.Equal(13, getResult.ReadOnlyListInts[0]); + Assert.Equal(14, getResult.ReadOnlyListInts[1]); + Assert.Equal(15, getResult.ReadOnlyListInts[2]); + } + + #region private 
================================================================================== + + private static VectorStoreRecordDefinition GetVectorStoreRecordDefinition(string distanceFunction = DistanceFunction.CosineDistance) => new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("HotelId", typeof(TKey)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)), + new VectorStoreRecordDataProperty("HotelCode", typeof(int)), + new VectorStoreRecordDataProperty("HotelRating", typeof(float?)), + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty("Tags", typeof(List)), + new VectorStoreRecordDataProperty("ListInts", typeof(List)), + new VectorStoreRecordDataProperty("Description", typeof(string)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, IndexKind = IndexKind.Hnsw, DistanceFunction = distanceFunction } + ] + }; + + private dynamic GetCollection(Type idType, string collectionName) + { + var method = typeof(PostgresVectorStoreFixture).GetMethod("GetCollection"); + var genericMethod = method!.MakeGenericMethod(idType, typeof(PostgresHotel<>).MakeGenericType(idType)); + return genericMethod.Invoke(fixture, [collectionName, null])!; + } + + private PostgresHotel CreateRecord(Type idType, TKey key) + { + var recordType = typeof(PostgresHotel<>).MakeGenericType(idType); + var record = (PostgresHotel)Activator.CreateInstance(recordType, key)!; + record.HotelName = "Hotel 1"; + record.HotelCode = 1; + record.ParkingIncluded = true; + record.HotelRating = 4.5f; + record.Tags = new List { "tag1", "tag2" }; + return record; + } + private static DateTime TruncateMilliseconds(DateTime dateTime) + { + return new DateTime(dateTime.Ticks - (dateTime.Ticks % TimeSpan.TicksPerSecond), dateTime.Kind); + } + + private static DateTimeOffset TruncateMilliseconds(DateTimeOffset dateTimeOffset) + { + return new 
DateTimeOffset(dateTimeOffset.Ticks - (dateTimeOffset.Ticks % TimeSpan.TicksPerSecond), dateTimeOffset.Offset); + } + +#pragma warning disable CA1812, CA1859 + private sealed class RecordWithEnumerables + { + [VectorStoreRecordKey] + public int Id { get; set; } + + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + public ReadOnlyMemory? Embedding { get; set; } + + [VectorStoreRecordData] + public List? ListInts { get; set; } + + [VectorStoreRecordData] + public ICollection? CollectionInts { get; set; } + + [VectorStoreRecordData] + public IEnumerable? EnumerableInts { get; set; } + + [VectorStoreRecordData] + public IReadOnlyCollection? ReadOnlyCollectionInts { get; set; } + + [VectorStoreRecordData] + public IReadOnlyList? ReadOnlyListInts { get; set; } + } +#pragma warning restore CA1812, CA1859 + + #endregion + +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs new file mode 100644 index 000000000000..3eb2c02d54c6 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Linq; +using System.Threading.Tasks; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +[Collection("PostgresVectorStoreCollection")] +public class PostgresVectorStoreTests(PostgresVectorStoreFixture fixture) +{ + [Fact] + public async Task ItCanGetAListOfExistingCollectionNamesAsync() + { + // Arrange + var sut = fixture.VectorStore; + + // Setup + var collection = sut.GetCollection>("VS_TEST_HOTELS"); + await collection.CreateCollectionIfNotExistsAsync(); + + // Act + var collectionNames = await sut.ListCollectionNamesAsync().ToListAsync(); + + // Assert + Assert.Contains("VS_TEST_HOTELS", collectionNames); + } +} diff --git a/dotnet/src/InternalUtilities/src/Linq/AsyncEnumerable.cs b/dotnet/src/InternalUtilities/src/Linq/AsyncEnumerable.cs index 844ae7e2f573..a85a509d1980 100644 --- a/dotnet/src/InternalUtilities/src/Linq/AsyncEnumerable.cs +++ b/dotnet/src/InternalUtilities/src/Linq/AsyncEnumerable.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.SemanticKernel; @@ -135,6 +136,40 @@ static async ValueTask Core(IAsyncEnumerable source, Func + /// Projects each element of an into a new form by incorporating + /// an asynchronous transformation function. + /// + /// The type of the elements of the source sequence. + /// The type of the elements of the resulting sequence. + /// An to invoke a transform function on. + /// + /// A transform function to apply to each element. This function takes an element of + /// type TSource and returns an element of type TResult. + /// + /// + /// A CancellationToken to observe while iterating through the sequence. + /// + /// + /// An whose elements are the result of invoking the transform + /// function on each element of the original sequence. + /// + /// Thrown when the source or selector is null. 
+ public static async IAsyncEnumerable SelectAsync( + this IAsyncEnumerable source, + Func selector, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return selector(item); + } + } + +#pragma warning restore IDE1006 // Naming rule violation: Missing suffix: 'Async' + private sealed class EmptyAsyncEnumerable : IAsyncEnumerable, IAsyncEnumerator { public static readonly EmptyAsyncEnumerable Instance = new(); From 7c25ac4c3be7e19950916a41f918dbadf66a819a Mon Sep 17 00:00:00 2001 From: blurred83 Date: Mon, 16 Dec 2024 03:53:21 -0600 Subject: [PATCH 06/11] .Net: Fix typo in GettingStarted.Step3_Yaml_Prompt - CreatPrompt -> CreatePrompt (#9823) ### Motivation and Context Fixing a typo in a unit test name (CreatPromptFromYamlAsync) that ReSharper noticed. ### Description Changed CreatPromptFromYamlAsync to CreatePromptFromYamlAsync (and rebuilt/ran the test just to be sure). Co-authored-by: Max Szczurek Co-authored-by: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com> --- dotnet/samples/GettingStarted/Step3_Yaml_Prompt.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/samples/GettingStarted/Step3_Yaml_Prompt.cs b/dotnet/samples/GettingStarted/Step3_Yaml_Prompt.cs index 29d50f7b6da7..a848779d4e96 100644 --- a/dotnet/samples/GettingStarted/Step3_Yaml_Prompt.cs +++ b/dotnet/samples/GettingStarted/Step3_Yaml_Prompt.cs @@ -15,7 +15,7 @@ public sealed class Step3_Yaml_Prompt(ITestOutputHelper output) : BaseTest(outpu /// Show how to create a prompt from a YAML resource. 
/// [Fact] - public async Task CreatPromptFromYamlAsync() + public async Task CreatePromptFromYamlAsync() { // Create a kernel with OpenAI chat completion Kernel kernel = Kernel.CreateBuilder() From 6d02eeff815915f12cb830180243532743c2a211 Mon Sep 17 00:00:00 2001 From: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:00:06 +0000 Subject: [PATCH 07/11] .Net: Allow customization of building REST API operation URL, payload, and headers (#9985) ### Motivation and Context CopilotAgentPlugin functionality may need more control over the way url, headers and payload are created. ### Description This PR adds internal factories for creating URLs, headers, and payloads. The factories are kept internal because the necessity of having them and their structure may change in the future. --- .../Functions.OpenApi/HttpContentFactory.cs | 2 +- .../Model/RestApiOperationHeadersFactory.cs | 14 +++ .../Model/RestApiOperationPayloadFactory.cs | 23 ++++ .../Model/RestApiOperationUrlFactory.cs | 15 +++ .../RestApiOperationRunner.cs | 38 ++++-- .../OpenApi/RestApiOperationRunnerTests.cs | 115 ++++++++++++++++++ 6 files changed, 199 insertions(+), 8 deletions(-) create mode 100644 dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationHeadersFactory.cs create mode 100644 dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationPayloadFactory.cs create mode 100644 dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationUrlFactory.cs diff --git a/dotnet/src/Functions/Functions.OpenApi/HttpContentFactory.cs b/dotnet/src/Functions/Functions.OpenApi/HttpContentFactory.cs index c3ebf9251e0a..45cea8a3ec3a 100644 --- a/dotnet/src/Functions/Functions.OpenApi/HttpContentFactory.cs +++ b/dotnet/src/Functions/Functions.OpenApi/HttpContentFactory.cs @@ -11,4 +11,4 @@ namespace Microsoft.SemanticKernel.Plugins.OpenApi; /// The operation payload metadata. /// The operation arguments. 
/// The object and HttpContent representing the operation payload. -internal delegate (object? Payload, HttpContent Content) HttpContentFactory(RestApiPayload? payload, IDictionary arguments); +internal delegate (object Payload, HttpContent Content) HttpContentFactory(RestApiPayload? payload, IDictionary arguments); diff --git a/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationHeadersFactory.cs b/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationHeadersFactory.cs new file mode 100644 index 000000000000..738a47a670f8 --- /dev/null +++ b/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationHeadersFactory.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; + +namespace Microsoft.SemanticKernel.Plugins.OpenApi; + +/// +/// Represents a delegate for creating headers for a REST API operation. +/// +/// The REST API operation. +/// The arguments for the operation. +/// The operation run options. +/// The operation headers. +internal delegate IDictionary? RestApiOperationHeadersFactory(RestApiOperation operation, IDictionary arguments, RestApiOperationRunOptions? options); diff --git a/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationPayloadFactory.cs b/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationPayloadFactory.cs new file mode 100644 index 000000000000..1000a616fe73 --- /dev/null +++ b/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationPayloadFactory.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Net.Http; + +namespace Microsoft.SemanticKernel.Plugins.OpenApi; + +/// +/// Represents a delegate for creating a payload for a REST API operation. +/// +/// The REST API operation. +/// The arguments for the operation. +/// +/// Determines whether the operation payload is constructed dynamically based on operation payload metadata. 
+/// If false, the operation payload must be provided via the 'payload' property. +/// +/// +/// Determines whether payload parameters are resolved from the arguments by +/// full name (parameter name prefixed with the parent property name). +/// +/// The operation run options. +/// The operation payload. +internal delegate (object Payload, HttpContent Content)? RestApiOperationPayloadFactory(RestApiOperation operation, IDictionary arguments, bool enableDynamicPayload, bool enablePayloadNamespacing, RestApiOperationRunOptions? options); diff --git a/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationUrlFactory.cs b/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationUrlFactory.cs new file mode 100644 index 000000000000..64736c6decbe --- /dev/null +++ b/dotnet/src/Functions/Functions.OpenApi/Model/RestApiOperationUrlFactory.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; + +namespace Microsoft.SemanticKernel.Plugins.OpenApi; + +/// +/// Represents a delegate for creating a URL for a REST API operation. +/// +/// The REST API operation. +/// The arguments for the operation. +/// The operation run options. +/// The operation URL. +internal delegate Uri? RestApiOperationUrlFactory(RestApiOperation operation, IDictionary arguments, RestApiOperationRunOptions? options); diff --git a/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs b/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs index 29b58fa6b480..9c1c2bcb1177 100644 --- a/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs +++ b/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs @@ -88,6 +88,21 @@ internal sealed class RestApiOperationRunner /// private readonly HttpResponseContentReader? _httpResponseContentReader; + /// + /// The external URL factory to use if provided, instead of the default one. + /// + private readonly RestApiOperationUrlFactory? 
_urlFactory; + + /// + /// The external header factory to use if provided, instead of the default one. + /// + private readonly RestApiOperationHeadersFactory? _headersFactory; + + /// + /// The external payload factory to use if provided, instead of the default one. + /// + private readonly RestApiOperationPayloadFactory? _payloadFactory; + /// /// Creates an instance of the class. /// @@ -100,19 +115,28 @@ internal sealed class RestApiOperationRunner /// Determines whether payload parameters are resolved from the arguments by /// full name (parameter name prefixed with the parent property name). /// Custom HTTP response content reader. + /// The external URL factory to use if provided if provided instead of the default one. + /// The external headers factory to use if provided instead of the default one. + /// The external payload factory to use if provided instead of the default one. public RestApiOperationRunner( HttpClient httpClient, AuthenticateRequestAsyncCallback? authCallback = null, string? userAgent = null, bool enableDynamicPayload = false, bool enablePayloadNamespacing = false, - HttpResponseContentReader? httpResponseContentReader = null) + HttpResponseContentReader? httpResponseContentReader = null, + RestApiOperationUrlFactory? urlFactory = null, + RestApiOperationHeadersFactory? headersFactory = null, + RestApiOperationPayloadFactory? payloadFactory = null) { this._httpClient = httpClient; this._userAgent = userAgent ?? HttpHeaderConstant.Values.UserAgent; this._enableDynamicPayload = enableDynamicPayload; this._enablePayloadNamespacing = enablePayloadNamespacing; this._httpResponseContentReader = httpResponseContentReader; + this._urlFactory = urlFactory; + this._headersFactory = headersFactory; + this._payloadFactory = payloadFactory; // If no auth callback provided, use empty function if (authCallback is null) @@ -145,13 +169,13 @@ public Task RunAsync( RestApiOperationRunOptions? 
options = null, CancellationToken cancellationToken = default) { - var url = this.BuildsOperationUrl(operation, arguments, options?.ServerUrlOverride, options?.ApiHostUrl); + var url = this._urlFactory?.Invoke(operation, arguments, options) ?? this.BuildsOperationUrl(operation, arguments, options?.ServerUrlOverride, options?.ApiHostUrl); - var headers = operation.BuildHeaders(arguments); + var headers = this._headersFactory?.Invoke(operation, arguments, options) ?? operation.BuildHeaders(arguments); - var operationPayload = this.BuildOperationPayload(operation, arguments); + var (Payload, Content) = this._payloadFactory?.Invoke(operation, arguments, this._enableDynamicPayload, this._enablePayloadNamespacing, options) ?? this.BuildOperationPayload(operation, arguments); - return this.SendAsync(url, operation.Method, headers, operationPayload.Payload, operationPayload.Content, operation.Responses.ToDictionary(item => item.Key, item => item.Value.Schema), options, cancellationToken); + return this.SendAsync(url, operation.Method, headers, Payload, Content, operation.Responses.ToDictionary(item => item.Key, item => item.Value.Schema), options, cancellationToken); } #region private @@ -340,7 +364,7 @@ private async Task ReadContentAndCreateOperationRespon /// The payload meta-data. /// The payload arguments. /// The JSON payload the corresponding HttpContent. - private (object? Payload, HttpContent Content) BuildJsonPayload(RestApiPayload? payloadMetadata, IDictionary arguments) + private (object Payload, HttpContent Content) BuildJsonPayload(RestApiPayload? payloadMetadata, IDictionary arguments) { // Build operation payload dynamically if (this._enableDynamicPayload) @@ -440,7 +464,7 @@ private JsonObject BuildJsonObject(IList properties, IDi /// The payload meta-data. /// The payload arguments. /// The text payload and corresponding HttpContent. - private (object? Payload, HttpContent Content) BuildPlainTextPayload(RestApiPayload? 
payloadMetadata, IDictionary arguments) + private (object Payload, HttpContent Content) BuildPlainTextPayload(RestApiPayload? payloadMetadata, IDictionary arguments) { if (!arguments.TryGetValue(RestApiOperation.PayloadArgumentName, out object? argument) || argument is not string payload) { diff --git a/dotnet/src/Functions/Functions.UnitTests/OpenApi/RestApiOperationRunnerTests.cs b/dotnet/src/Functions/Functions.UnitTests/OpenApi/RestApiOperationRunnerTests.cs index e30d115aaece..089644ad7848 100644 --- a/dotnet/src/Functions/Functions.UnitTests/OpenApi/RestApiOperationRunnerTests.cs +++ b/dotnet/src/Functions/Functions.UnitTests/OpenApi/RestApiOperationRunnerTests.cs @@ -1517,6 +1517,121 @@ public async Task ItShouldUseRestApiOperationPayloadPropertyNameToLookupArgument Assert.Equal("true", enabledProperty.ToString()); } + [Fact] + public async Task ItShouldUseUrlHeaderAndPayloadFactoriesIfProvidedAsync() + { + // Arrange + this._httpMessageHandlerStub.ResponseToReturn.Content = new StringContent("fake-content", Encoding.UTF8, MediaTypeNames.Application.Json); + + List payloadProperties = + [ + new("name", "string", true, []) + ]; + + var payload = new RestApiPayload(MediaTypeNames.Application.Json, payloadProperties); + + var expectedOperation = new RestApiOperation( + id: "fake-id", + servers: [new RestApiServer("https://fake-random-test-host")], + path: "fake-path", + method: HttpMethod.Post, + description: "fake-description", + parameters: [], + responses: new Dictionary(), + securityRequirements: [], + payload: payload + ); + + var expectedArguments = new KernelArguments(); + + var expectedOptions = new RestApiOperationRunOptions() + { + Kernel = new(), + KernelFunction = KernelFunctionFactory.CreateFromMethod(() => false), + KernelArguments = expectedArguments, + }; + + bool createUrlFactoryCalled = false; + bool createHeadersFactoryCalled = false; + bool createPayloadFactoryCalled = false; + + Uri CreateUrl(RestApiOperation operation, IDictionary 
arguments, RestApiOperationRunOptions? options) + { + createUrlFactoryCalled = true; + Assert.Same(expectedOperation, operation); + Assert.Same(expectedArguments, arguments); + Assert.Same(expectedOptions, options); + + return new Uri("https://fake-random-test-host-from-factory/"); + } + + IDictionary? CreateHeaders(RestApiOperation operation, IDictionary arguments, RestApiOperationRunOptions? options) + { + createHeadersFactoryCalled = true; + Assert.Same(expectedOperation, operation); + Assert.Same(expectedArguments, arguments); + Assert.Same(expectedOptions, options); + + return new Dictionary() { ["header-from-factory"] = "value-of-header-from-factory" }; + } + + (object Payload, HttpContent Content)? CreatePayload(RestApiOperation operation, IDictionary arguments, bool enableDynamicPayload, bool enablePayloadNamespacing, RestApiOperationRunOptions? options) + { + createPayloadFactoryCalled = true; + Assert.Same(expectedOperation, operation); + Assert.Same(expectedArguments, arguments); + Assert.True(enableDynamicPayload); + Assert.True(enablePayloadNamespacing); + Assert.Same(expectedOptions, options); + + var json = """{"name":"fake-name-value"}"""; + + return ((JsonObject)JsonObject.Parse(json)!, new StringContent(json, Encoding.UTF8, MediaTypeNames.Application.Json)); + } + + var sut = new RestApiOperationRunner( + this._httpClient, + enableDynamicPayload: true, + enablePayloadNamespacing: true, + urlFactory: CreateUrl, + headersFactory: CreateHeaders, + payloadFactory: CreatePayload); + + // Act + var result = await sut.RunAsync(expectedOperation, expectedArguments, expectedOptions); + + // Assert + Assert.True(createUrlFactoryCalled); + Assert.True(createHeadersFactoryCalled); + Assert.True(createPayloadFactoryCalled); + + // Assert url factory + Assert.NotNull(this._httpMessageHandlerStub.RequestUri); + Assert.Equal("https://fake-random-test-host-from-factory/", this._httpMessageHandlerStub.RequestUri.AbsoluteUri); + + // Assert headers factory + 
Assert.NotNull(this._httpMessageHandlerStub.RequestHeaders); + Assert.Equal(3, this._httpMessageHandlerStub.RequestHeaders.Count()); + + Assert.Contains(this._httpMessageHandlerStub.RequestHeaders, h => h.Key == "header-from-factory" && h.Value.Contains("value-of-header-from-factory")); + Assert.Contains(this._httpMessageHandlerStub.RequestHeaders, h => h.Key == "User-Agent" && h.Value.Contains("Semantic-Kernel")); + Assert.Contains(this._httpMessageHandlerStub.RequestHeaders, h => h.Key == "Semantic-Kernel-Version"); + + // Assert payload factory + var messageContent = this._httpMessageHandlerStub.RequestContent; + Assert.NotNull(messageContent); + + var deserializedPayload = await JsonNode.ParseAsync(new MemoryStream(messageContent)); + Assert.NotNull(deserializedPayload); + + var nameProperty = deserializedPayload["name"]?.ToString(); + Assert.Equal("fake-name-value", nameProperty); + + Assert.NotNull(result.RequestPayload); + Assert.IsType(result.RequestPayload); + Assert.Equal("""{"name":"fake-name-value"}""", ((JsonObject)result.RequestPayload).ToJsonString()); + } + public class SchemaTestData : IEnumerable { public IEnumerator GetEnumerator() From 5874188b2b967c72a4a309c28aa594b506ab32a2 Mon Sep 17 00:00:00 2001 From: Vincent Biret Date: Mon, 16 Dec 2024 11:12:08 -0500 Subject: [PATCH 08/11] .Net: fix: includes path item path parameters to OpenAPI document parsing (#9969) fixes #9962 --------- Signed-off-by: Vincent Biret Co-authored-by: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> --- .../OpenApi/OpenApiDocumentParser.cs | 25 ++- .../OpenApi/OpenApiDocumentParserV20Tests.cs | 160 ++++++++++++++++ .../OpenApi/OpenApiDocumentParserV30Tests.cs | 171 ++++++++++++++++++ .../OpenApi/OpenApiDocumentParserV31Tests.cs | 171 ++++++++++++++++++ 4 files changed, 525 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Functions/Functions.OpenApi/OpenApi/OpenApiDocumentParser.cs 
b/dotnet/src/Functions/Functions.OpenApi/OpenApi/OpenApiDocumentParser.cs index 67ba2d34e79a..4803d28e1e1b 100644 --- a/dotnet/src/Functions/Functions.OpenApi/OpenApi/OpenApiDocumentParser.cs +++ b/dotnet/src/Functions/Functions.OpenApi/OpenApi/OpenApiDocumentParser.cs @@ -211,7 +211,7 @@ internal static List CreateRestApiOperations(OpenApiDocument d path: path, method: new HttpMethod(method), description: string.IsNullOrEmpty(operationItem.Description) ? operationItem.Summary : operationItem.Description, - parameters: CreateRestApiOperationParameters(operationItem.OperationId, operationItem.Parameters), + parameters: CreateRestApiOperationParameters(operationItem.OperationId, operationItem.Parameters.Union(pathItem.Parameters, s_parameterNameAndLocationComparer)), payload: CreateRestApiOperationPayload(operationItem.OperationId, operationItem.RequestBody), responses: CreateRestApiOperationExpectedResponses(operationItem.Responses).ToDictionary(static item => item.Item1, static item => item.Item2), securityRequirements: CreateRestApiOperationSecurityRequirements(operationItem.Security) @@ -237,6 +237,27 @@ internal static List CreateRestApiOperations(OpenApiDocument d } } + private static readonly ParameterNameAndLocationComparer s_parameterNameAndLocationComparer = new(); + + /// + /// Compares two objects by their name and location. + /// + private sealed class ParameterNameAndLocationComparer : IEqualityComparer + { + public bool Equals(OpenApiParameter? x, OpenApiParameter? y) + { + if (x is null || y is null) + { + return x == y; + } + return this.GetHashCode(x) == this.GetHashCode(y); + } + public int GetHashCode([DisallowNull] OpenApiParameter obj) + { + return HashCode.Combine(obj.Name, obj.In); + } + } + /// /// Build a list of objects from the given list of objects. /// @@ -381,7 +402,7 @@ internal static List CreateRestApiOperationSecurityR /// The operation id. /// The OpenAPI parameters. /// The parameters. 
- private static List CreateRestApiOperationParameters(string operationId, IList parameters) + private static List CreateRestApiOperationParameters(string operationId, IEnumerable parameters) { var result = new List(); diff --git a/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV20Tests.cs b/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV20Tests.cs index 625420e2f956..9313297ace66 100644 --- a/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV20Tests.cs +++ b/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV20Tests.cs @@ -5,6 +5,7 @@ using System.IO; using System.Linq; using System.Net.Http; +using System.Text; using System.Text.Json; using System.Threading.Tasks; using Microsoft.SemanticKernel; @@ -434,6 +435,165 @@ public async Task ItCanFilterOutSpecifiedOperationsAsync() Assert.Contains(restApiSpec.Operations, o => o.Id == "SetSecret"); Assert.Contains(restApiSpec.Operations, o => o.Id == "GetSecret"); } + [Fact] + public async Task ItCanParsePathItemPathParametersAsync() + { + var document = + """ + { + "swagger": "2.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "paths": { + "/items/{itemId}/{format}": { + "parameters": [ + { + "name": "itemId", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "parameters": [ + { + "name": "format", + "in": "path", + "required": true, + "type": "string" + } + ], + "summary": "Get an item by ID", + "responses": { + "200": { + "description": "Successful response" + } + } + } + } + } + } + """; + + await using var steam = new MemoryStream(Encoding.UTF8.GetBytes(document)); + var restApi = await this._sut.ParseAsync(steam); + + Assert.NotNull(restApi); + Assert.NotNull(restApi.Operations); + Assert.NotEmpty(restApi.Operations); + + var firstOperation = restApi.Operations[0]; + + Assert.NotNull(firstOperation); + Assert.Equal("Get an item by ID", firstOperation.Description); + 
Assert.Equal("/items/{itemId}/{format}", firstOperation.Path); + + var parameters = firstOperation.GetParameters(); + Assert.NotNull(parameters); + Assert.Equal(2, parameters.Count); + + var pathParameter = parameters.Single(static p => "itemId".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(pathParameter); + Assert.True(pathParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, pathParameter.Location); + Assert.Null(pathParameter.DefaultValue); + Assert.NotNull(pathParameter.Schema); + Assert.Equal("string", pathParameter.Schema.RootElement.GetProperty("type").GetString()); + + var formatParameter = parameters.Single(static p => "format".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(formatParameter); + Assert.True(formatParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, formatParameter.Location); + Assert.Null(formatParameter.DefaultValue); + Assert.NotNull(formatParameter.Schema); + Assert.Equal("string", formatParameter.Schema.RootElement.GetProperty("type").GetString()); + } + + [Fact] + public async Task ItCanParsePathItemPathParametersAndOverridesAsync() + { + var document = + """ + { + "swagger": "2.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "paths": { + "/items/{itemId}/{format}": { + "parameters": [ + { + "name": "itemId", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "parameters": [ + { + "name": "format", + "in": "path", + "required": true, + "type": "string" + }, + { + "name": "itemId", + "in": "path", + "description": "item ID override", + "required": true, + "type": "string" + } + ], + "summary": "Get an item by ID", + "responses": { + "200": { + "description": "Successful response" + } + } + } + } + } + } + """; + + await using var steam = new MemoryStream(Encoding.UTF8.GetBytes(document)); + var restApi = await this._sut.ParseAsync(steam); + + Assert.NotNull(restApi); + Assert.NotNull(restApi.Operations); + 
Assert.NotEmpty(restApi.Operations); + + var firstOperation = restApi.Operations[0]; + + Assert.NotNull(firstOperation); + Assert.Equal("Get an item by ID", firstOperation.Description); + Assert.Equal("/items/{itemId}/{format}", firstOperation.Path); + + var parameters = firstOperation.GetParameters(); + Assert.NotNull(parameters); + Assert.Equal(2, parameters.Count); + + var pathParameter = parameters.Single(static p => "itemId".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(pathParameter); + Assert.True(pathParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, pathParameter.Location); + Assert.Null(pathParameter.DefaultValue); + Assert.NotNull(pathParameter.Schema); + Assert.Equal("string", pathParameter.Schema.RootElement.GetProperty("type").GetString()); + Assert.Equal("item ID override", pathParameter.Description); + + var formatParameter = parameters.Single(static p => "format".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(formatParameter); + Assert.True(formatParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, formatParameter.Location); + Assert.Null(formatParameter.DefaultValue); + Assert.NotNull(formatParameter.Schema); + Assert.Equal("string", formatParameter.Schema.RootElement.GetProperty("type").GetString()); + } private static RestApiParameter GetParameterMetadata(IList operations, string operationId, RestApiParameterLocation location, string name) diff --git a/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV30Tests.cs b/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV30Tests.cs index 8728771ac54a..02b3d363ebfb 100644 --- a/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV30Tests.cs +++ b/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV30Tests.cs @@ -5,6 +5,7 @@ using System.IO; using System.Linq; using System.Net.Http; +using System.Text; using System.Text.Json; using 
System.Text.Json.Nodes; using System.Threading.Tasks; @@ -500,6 +501,176 @@ public async Task ItCanParseDocumentWithMultipleServersAsync() Assert.Equal("https://ppe.my-key-vault.vault.azure.net", restApi.Operations[0].Servers[1].Url); } + [Fact] + public async Task ItCanParsePathItemPathParametersAsync() + { + var document = + """ + { + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "paths": { + "/items/{itemId}/{format}": { + "parameters": [ + { + "name": "itemId", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "get": { + "parameters": [ + { + "name": "format", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "summary": "Get an item by ID", + "responses": { + "200": { + "description": "Successful response" + } + } + } + } + } + } + """; + + await using var steam = new MemoryStream(Encoding.UTF8.GetBytes(document)); + var restApi = await this._sut.ParseAsync(steam); + + Assert.NotNull(restApi); + Assert.NotNull(restApi.Operations); + Assert.NotEmpty(restApi.Operations); + + var firstOperation = restApi.Operations[0]; + + Assert.NotNull(firstOperation); + Assert.Equal("Get an item by ID", firstOperation.Description); + Assert.Equal("/items/{itemId}/{format}", firstOperation.Path); + + var parameters = firstOperation.GetParameters(); + Assert.NotNull(parameters); + Assert.Equal(2, parameters.Count); + + var pathParameter = parameters.Single(static p => "itemId".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(pathParameter); + Assert.True(pathParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, pathParameter.Location); + Assert.Null(pathParameter.DefaultValue); + Assert.NotNull(pathParameter.Schema); + Assert.Equal("string", pathParameter.Schema.RootElement.GetProperty("type").GetString()); + + var formatParameter = parameters.Single(static p => "format".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + 
Assert.NotNull(formatParameter); + Assert.True(formatParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, formatParameter.Location); + Assert.Null(formatParameter.DefaultValue); + Assert.NotNull(formatParameter.Schema); + Assert.Equal("string", formatParameter.Schema.RootElement.GetProperty("type").GetString()); + } + + [Fact] + public async Task ItCanParsePathItemPathParametersAndOverridesAsync() + { + var document = + """ + { + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "paths": { + "/items/{itemId}/{format}": { + "parameters": [ + { + "name": "itemId", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "get": { + "parameters": [ + { + "name": "format", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "itemId", + "in": "path", + "description": "item ID override", + "required": true, + "schema": { + "type": "string" + } + } + ], + "summary": "Get an item by ID", + "responses": { + "200": { + "description": "Successful response" + } + } + } + } + } + } + """; + + await using var steam = new MemoryStream(Encoding.UTF8.GetBytes(document)); + var restApi = await this._sut.ParseAsync(steam); + + Assert.NotNull(restApi); + Assert.NotNull(restApi.Operations); + Assert.NotEmpty(restApi.Operations); + + var firstOperation = restApi.Operations[0]; + + Assert.NotNull(firstOperation); + Assert.Equal("Get an item by ID", firstOperation.Description); + Assert.Equal("/items/{itemId}/{format}", firstOperation.Path); + + var parameters = firstOperation.GetParameters(); + Assert.NotNull(parameters); + Assert.Equal(2, parameters.Count); + + var pathParameter = parameters.Single(static p => "itemId".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(pathParameter); + Assert.True(pathParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, pathParameter.Location); + Assert.Null(pathParameter.DefaultValue); + 
Assert.NotNull(pathParameter.Schema); + Assert.Equal("string", pathParameter.Schema.RootElement.GetProperty("type").GetString()); + Assert.Equal("item ID override", pathParameter.Description); + + var formatParameter = parameters.Single(static p => "format".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(formatParameter); + Assert.True(formatParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, formatParameter.Location); + Assert.Null(formatParameter.DefaultValue); + Assert.NotNull(formatParameter.Schema); + Assert.Equal("string", formatParameter.Schema.RootElement.GetProperty("type").GetString()); + } + private static MemoryStream ModifyOpenApiDocument(Stream openApiDocument, Action transformer) { var json = JsonSerializer.Deserialize(openApiDocument); diff --git a/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV31Tests.cs b/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV31Tests.cs index 6455b95dd34b..5fc59c70a8f9 100644 --- a/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV31Tests.cs +++ b/dotnet/src/Functions/Functions.UnitTests/OpenApi/OpenApiDocumentParserV31Tests.cs @@ -6,6 +6,7 @@ using System.IO; using System.Linq; using System.Net.Http; +using System.Text; using System.Text.Json; using System.Threading.Tasks; using Microsoft.SemanticKernel; @@ -477,6 +478,176 @@ public async Task ItCanParseDocumentWithMultipleServersAsync() Assert.Equal("https://ppe.my-key-vault.vault.azure.net", restApi.Operations[0].Servers[1].Url); } + [Fact] + public async Task ItCanParsePathItemPathParametersAsync() + {//TODO update the document version when upgrading Microsoft.OpenAPI to v2 + var document = + """ + { + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "paths": { + "/items/{itemId}/{format}": { + "parameters": [ + { + "name": "itemId", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "get": { + 
"parameters": [ + { + "name": "format", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "summary": "Get an item by ID", + "responses": { + "200": { + "description": "Successful response" + } + } + } + } + } + } + """; + + await using var steam = new MemoryStream(Encoding.UTF8.GetBytes(document)); + var restApi = await this._sut.ParseAsync(steam); + + Assert.NotNull(restApi); + Assert.NotNull(restApi.Operations); + Assert.NotEmpty(restApi.Operations); + + var firstOperation = restApi.Operations[0]; + + Assert.NotNull(firstOperation); + Assert.Equal("Get an item by ID", firstOperation.Description); + Assert.Equal("/items/{itemId}/{format}", firstOperation.Path); + + var parameters = firstOperation.GetParameters(); + Assert.NotNull(parameters); + Assert.Equal(2, parameters.Count); + + var pathParameter = parameters.Single(static p => "itemId".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(pathParameter); + Assert.True(pathParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, pathParameter.Location); + Assert.Null(pathParameter.DefaultValue); + Assert.NotNull(pathParameter.Schema); + Assert.Equal("string", pathParameter.Schema.RootElement.GetProperty("type").GetString()); + + var formatParameter = parameters.Single(static p => "format".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(formatParameter); + Assert.True(formatParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, formatParameter.Location); + Assert.Null(formatParameter.DefaultValue); + Assert.NotNull(formatParameter.Schema); + Assert.Equal("string", formatParameter.Schema.RootElement.GetProperty("type").GetString()); + } + + [Fact] + public async Task ItCanParsePathItemPathParametersAndOverridesAsync() + {//TODO update the document version when upgrading Microsoft.OpenAPI to v2 + var document = + """ + { + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "paths": { + 
"/items/{itemId}/{format}": { + "parameters": [ + { + "name": "itemId", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "get": { + "parameters": [ + { + "name": "format", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "itemId", + "in": "path", + "description": "item ID override", + "required": true, + "schema": { + "type": "string" + } + } + ], + "summary": "Get an item by ID", + "responses": { + "200": { + "description": "Successful response" + } + } + } + } + } + } + """; + + await using var steam = new MemoryStream(Encoding.UTF8.GetBytes(document)); + var restApi = await this._sut.ParseAsync(steam); + + Assert.NotNull(restApi); + Assert.NotNull(restApi.Operations); + Assert.NotEmpty(restApi.Operations); + + var firstOperation = restApi.Operations[0]; + + Assert.NotNull(firstOperation); + Assert.Equal("Get an item by ID", firstOperation.Description); + Assert.Equal("/items/{itemId}/{format}", firstOperation.Path); + + var parameters = firstOperation.GetParameters(); + Assert.NotNull(parameters); + Assert.Equal(2, parameters.Count); + + var pathParameter = parameters.Single(static p => "itemId".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(pathParameter); + Assert.True(pathParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, pathParameter.Location); + Assert.Null(pathParameter.DefaultValue); + Assert.NotNull(pathParameter.Schema); + Assert.Equal("string", pathParameter.Schema.RootElement.GetProperty("type").GetString()); + Assert.Equal("item ID override", pathParameter.Description); + + var formatParameter = parameters.Single(static p => "format".Equals(p.Name, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(formatParameter); + Assert.True(formatParameter.IsRequired); + Assert.Equal(RestApiParameterLocation.Path, formatParameter.Location); + Assert.Null(formatParameter.DefaultValue); + Assert.NotNull(formatParameter.Schema); + 
Assert.Equal("string", formatParameter.Schema.RootElement.GetProperty("type").GetString()); + } + private static MemoryStream ModifyOpenApiDocument(Stream openApiDocument, Action> transformer) { var serializer = new SharpYaml.Serialization.Serializer(); From 62a50f32cf140b24876517219726a28465ef640e Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Mon, 16 Dec 2024 20:14:36 +0100 Subject: [PATCH 09/11] Python: Qdrant - fix in filter and 100% test coverage (#9982) ### Motivation and Context There was a small error in the filter creation logic, and improved test coverage for Qdrant. ### Description ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --- .../memory/qdrant/qdrant_collection.py | 4 +- .../connectors/memory/qdrant/test_qdrant.py | 84 ++++++++++++++++--- 2 files changed, 73 insertions(+), 15 deletions(-) diff --git a/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py b/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py index 5fb8c177be89..cb30fa0cdc76 100644 --- a/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py +++ b/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py @@ -188,7 +188,7 @@ async def _inner_search( else: query_vector = vector if query_vector is None: - raise VectorSearchExecutionException("Search requires either a vector.") + raise VectorSearchExecutionException("Search requires a vector.") results = await self.qdrant_client.search( collection_name=self.collection_name, query_vector=query_vector, @@ -214,7 +214,7 @@ def _get_score_from_result(self, 
result: ScoredPoint) -> float: def _create_filter(self, options: VectorSearchOptions) -> Filter: return Filter( must=[ - FieldCondition(key=filter.field_name, match=MatchAny(any=filter.value)) + FieldCondition(key=filter.field_name, match=MatchAny(any=[filter.value])) for filter in options.filter.filters ] ) diff --git a/python/tests/unit/connectors/memory/qdrant/test_qdrant.py b/python/tests/unit/connectors/memory/qdrant/test_qdrant.py index ce00e7d88c95..c92571daf238 100644 --- a/python/tests/unit/connectors/memory/qdrant/test_qdrant.py +++ b/python/tests/unit/connectors/memory/qdrant/test_qdrant.py @@ -4,17 +4,19 @@ from pytest import fixture, mark, raises from qdrant_client.async_qdrant_client import AsyncQdrantClient -from qdrant_client.models import Datatype, Distance, VectorParams +from qdrant_client.models import Datatype, Distance, FieldCondition, Filter, MatchAny, VectorParams from semantic_kernel.connectors.memory.qdrant.qdrant_collection import QdrantCollection from semantic_kernel.connectors.memory.qdrant.qdrant_store import QdrantStore from semantic_kernel.data.record_definition.vector_store_record_fields import VectorStoreRecordVectorField +from semantic_kernel.data.vector_search.vector_search_filter import VectorSearchFilter from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions from semantic_kernel.exceptions.memory_connector_exceptions import ( MemoryConnectorException, MemoryConnectorInitializationError, VectorStoreModelValidationError, ) +from semantic_kernel.exceptions.search_exceptions import VectorSearchExecutionException BASE_PATH = "qdrant_client.async_qdrant_client.AsyncQdrantClient" @@ -119,9 +121,10 @@ def mock_search(): yield mock_search -def test_vector_store_defaults(vector_store): - assert vector_store.qdrant_client is not None - assert vector_store.qdrant_client._client.rest_uri == "http://localhost:6333" +async def test_vector_store_defaults(vector_store): + async with vector_store: + assert 
vector_store.qdrant_client is not None + assert vector_store.qdrant_client._client.rest_uri == "http://localhost:6333" def test_vector_store_with_client(): @@ -162,18 +165,18 @@ def test_get_collection(vector_store, data_model_definition, qdrant_unit_test_en assert vector_store.vector_record_collections["test"] == collection -def test_collection_init(data_model_definition, qdrant_unit_test_env): - collection = QdrantCollection( +async def test_collection_init(data_model_definition, qdrant_unit_test_env): + async with QdrantCollection( data_model_type=dict, collection_name="test", data_model_definition=data_model_definition, env_file_path="test.env", - ) - assert collection.collection_name == "test" - assert collection.qdrant_client is not None - assert collection.data_model_type is dict - assert collection.data_model_definition == data_model_definition - assert collection.named_vectors + ) as collection: + assert collection.collection_name == "test" + assert collection.qdrant_client is not None + assert collection.data_model_type is dict + assert collection.data_model_definition == data_model_definition + assert collection.named_vectors def test_collection_init_fail(data_model_definition): @@ -275,8 +278,63 @@ async def test_create_index_fail(collection_to_use, request): await collection.create_collection() -async def test_search(collection): +async def test_search(collection, mock_search): results = await collection._inner_search(vector=[1.0, 2.0, 3.0], options=VectorSearchOptions(include_vectors=False)) async for result in results.results: assert result.record["id"] == "id1" break + + assert mock_search.call_count == 1 + mock_search.assert_called_with( + collection_name="test", + query_vector=[1.0, 2.0, 3.0], + query_filter=Filter(must=[]), + with_vectors=False, + limit=3, + offset=0, + ) + + +async def test_search_named_vectors(collection, mock_search): + collection.named_vectors = True + results = await collection._inner_search( + vector=[1.0, 2.0, 3.0], 
options=VectorSearchOptions(vector_field_name="vector", include_vectors=False) + ) + async for result in results.results: + assert result.record["id"] == "id1" + break + + assert mock_search.call_count == 1 + mock_search.assert_called_with( + collection_name="test", + query_vector=("vector", [1.0, 2.0, 3.0]), + query_filter=Filter(must=[]), + with_vectors=False, + limit=3, + offset=0, + ) + + +async def test_search_filter(collection, mock_search): + results = await collection._inner_search( + vector=[1.0, 2.0, 3.0], + options=VectorSearchOptions(include_vectors=False, filter=VectorSearchFilter.equal_to("id", "id1")), + ) + async for result in results.results: + assert result.record["id"] == "id1" + break + + assert mock_search.call_count == 1 + mock_search.assert_called_with( + collection_name="test", + query_vector=[1.0, 2.0, 3.0], + query_filter=Filter(must=[FieldCondition(key="id", match=MatchAny(any=["id1"]))]), + with_vectors=False, + limit=3, + offset=0, + ) + + +async def test_search_fail(collection): + with raises(VectorSearchExecutionException, match="Search requires a vector."): + await collection._inner_search(options=VectorSearchOptions(include_vectors=False)) From b42720890d2de562d9bb4a7c4767e90508a54d6e Mon Sep 17 00:00:00 2001 From: Evan Mattson <35585003+moonbox3@users.noreply.github.com> Date: Tue, 17 Dec 2024 07:14:22 +0900 Subject: [PATCH 10/11] Python: FunctionResultContent hash fix to handle lists/sets (#9978) ### Motivation and Context Currently, the hash function in `FunctionResultContent` does not properly handle types of list or set. Adding the proper handling to turn these into a tuple, and then the hash function works properly. 
### Description This PR: - Makes sure the FunctionResultContent `__hash__` method can properly handle mutable types like lists or sets - Adds tests to ensure the new behavior works as expected - Ensures the create message content method can handle lists or sets when generating the content and the model can only accept a string. - Fixes #9977 ### Contribution Checklist - [X] The code builds clean without any errors or warnings - [X] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [X] All unit tests pass, and I have added new tests where possible - [X] I didn't break anyone :smile: --------- Co-authored-by: Eduard van Valkenburg --- .../open_ai/assistant_content_generation.py | 23 +++++++++++-- .../contents/function_result_content.py | 17 +++++++++- .../functions/kernel_function_from_prompt.py | 2 +- .../agents/test_open_ai_assistant_base.py | 2 ++ .../contents/test_function_result_content.py | 33 +++++++++++++++++-- 5 files changed, 71 insertions(+), 6 deletions(-) diff --git a/python/semantic_kernel/agents/open_ai/assistant_content_generation.py b/python/semantic_kernel/agents/open_ai/assistant_content_generation.py index 872978adbdd4..1c4a79f3e5eb 100644 --- a/python/semantic_kernel/agents/open_ai/assistant_content_generation.py +++ b/python/semantic_kernel/agents/open_ai/assistant_content_generation.py @@ -87,17 +87,36 @@ def get_message_contents(message: "ChatMessageContent") -> list[dict[str, Any]]: for content in message.items: match content: case TextContent(): - contents.append({"type": "text", "text": content.text}) + # Make sure text is a string + final_text = content.text + if not isinstance(final_text, str): + if isinstance(final_text, (list, tuple)): + final_text = " ".join(map(str, final_text)) + else: + final_text = 
str(final_text) + + contents.append({"type": "text", "text": final_text}) + case ImageContent(): if content.uri: contents.append(content.to_dict()) + case FileReferenceContent(): contents.append({ "type": "image_file", "image_file": {"file_id": content.file_id}, }) + case FunctionResultContent(): - contents.append({"type": "text", "text": content.result}) + final_result = content.result + match final_result: + case str(): + contents.append({"type": "text", "text": final_result}) + case list() | tuple(): + contents.append({"type": "text", "text": " ".join(map(str, final_result))}) + case _: + contents.append({"type": "text", "text": str(final_result)}) + return contents diff --git a/python/semantic_kernel/contents/function_result_content.py b/python/semantic_kernel/contents/function_result_content.py index 821cc46615d1..536fb4ff19ce 100644 --- a/python/semantic_kernel/contents/function_result_content.py +++ b/python/semantic_kernel/contents/function_result_content.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.function_call_content import FunctionCallContent + from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent from semantic_kernel.functions.function_result import FunctionResult TAG_CONTENT_MAP = { @@ -157,6 +158,12 @@ def to_chat_message_content(self) -> "ChatMessageContent": return ChatMessageContent(role=AuthorRole.TOOL, items=[self]) + def to_streaming_chat_message_content(self) -> "StreamingChatMessageContent": + """Convert the instance to a StreamingChatMessageContent.""" + from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent + + return StreamingChatMessageContent(role=AuthorRole.TOOL, choice_index=0, items=[self]) + def to_dict(self) -> dict[str, str]: """Convert the instance to a dictionary.""" return { @@ -187,4 +194,12 @@ def serialize_result(self, value: Any) -> str: def 
__hash__(self) -> int: """Return the hash of the function result content.""" - return hash((self.tag, self.id, self.result, self.name, self.function_name, self.plugin_name, self.encoding)) + return hash(( + self.tag, + self.id, + tuple(self.result) if isinstance(self.result, list) else self.result, + self.name, + self.function_name, + self.plugin_name, + self.encoding, + )) diff --git a/python/semantic_kernel/functions/kernel_function_from_prompt.py b/python/semantic_kernel/functions/kernel_function_from_prompt.py index ecba3e9aa96c..1e301da4fa17 100644 --- a/python/semantic_kernel/functions/kernel_function_from_prompt.py +++ b/python/semantic_kernel/functions/kernel_function_from_prompt.py @@ -6,7 +6,7 @@ from html import unescape from typing import TYPE_CHECKING, Any -import yaml +import yaml # type: ignore from pydantic import Field, ValidationError, model_validator from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase diff --git a/python/tests/unit/agents/test_open_ai_assistant_base.py b/python/tests/unit/agents/test_open_ai_assistant_base.py index 7e0658b252d6..411102270286 100644 --- a/python/tests/unit/agents/test_open_ai_assistant_base.py +++ b/python/tests/unit/agents/test_open_ai_assistant_base.py @@ -1487,6 +1487,8 @@ def test_get_message_contents(azure_openai_assistant_agent: AzureAssistantAgent, ImageContent(role=AuthorRole.ASSISTANT, content="test message", uri="http://image.url"), TextContent(role=AuthorRole.ASSISTANT, text="test message"), FileReferenceContent(role=AuthorRole.ASSISTANT, file_id="test_file_id"), + TextContent(role=AuthorRole.USER, text="test message"), + FunctionResultContent(role=AuthorRole.ASSISTANT, result=["test result"], id="test_id"), ] result = get_message_contents(message) diff --git a/python/tests/unit/contents/test_function_result_content.py b/python/tests/unit/contents/test_function_result_content.py index 4b013d8a83dd..5bb549924d81 100644 --- 
a/python/tests/unit/contents/test_function_result_content.py +++ b/python/tests/unit/contents/test_function_result_content.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import Any from unittest.mock import Mock import pytest @@ -15,6 +14,26 @@ from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata +class CustomResultClass: + """Custom class for testing.""" + + def __init__(self, result): + self.result = result + + def __str__(self) -> str: + return self.result + + +class CustomObjectWithList: + """Custom class for testing.""" + + def __init__(self, items): + self.items = items + + def __str__(self): + return f"CustomObjectWithList({self.items})" + + def test_init(): frc = FunctionResultContent(id="test", name="test-function", result="test-result", metadata={"test": "test"}) assert frc.name == "test-function" @@ -50,6 +69,11 @@ def test_init_from_names(): ChatMessageContent(role="user", content="Hello world!"), ChatMessageContent(role="user", items=[ImageContent(uri="https://example.com")]), ChatMessageContent(role="user", items=[FunctionResultContent(id="test", name="test", result="Hello world!")]), + [1, 2, 3], + [{"key": "value"}, {"another": "item"}], + {"a", "b"}, + CustomResultClass("test"), + CustomObjectWithList(["one", "two", "three"]), ], ids=[ "str", @@ -60,9 +84,14 @@ def test_init_from_names(): "ChatMessageContent", "ChatMessageContent-ImageContent", "ChatMessageContent-FunctionResultContent", + "list", + "list_of_dicts", + "set", + "CustomResultClass", + "CustomObjectWithList", ], ) -def test_from_fcc_and_result(result: Any): +def test_from_fcc_and_result(result: any): fcc = FunctionCallContent( id="test", name="test-function", arguments='{"input": "world"}', metadata={"test": "test"} ) From 4a21254fe8420a36b7ed4bf09f4cae3bb4c0f84d Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 16 Dec 2024 14:51:22 -0800 Subject: [PATCH 11/11] Python: Anthropic function calling fixes (#9938) ### 
Motivation and Context The current implementation of the Anthropic connector relies on the `inner_content`s in chat messages to prepare the chat history for the Anthropic client. This will only work when the chat history is created by the Anthropic connector. This won't work if the chat history has been processed by other connectors, or if it is hardcoded as in testing. ### Description 1. Prepare the chat history for the Anthropic client by parsing the actual Semantic Kernel item types. 2. Fix tests for the Anthropic connector. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --- .../services/anthropic_chat_completion.py | 72 +++--------- .../connectors/ai/anthropic/services/utils.py | 110 ++++++++++++++++++ ...t_chat_completion_with_function_calling.py | 12 +- .../completions/test_chat_completions.py | 20 ++++ 4 files changed, 154 insertions(+), 60 deletions(-) create mode 100644 python/semantic_kernel/connectors/ai/anthropic/services/utils.py diff --git a/python/semantic_kernel/connectors/ai/anthropic/services/anthropic_chat_completion.py b/python/semantic_kernel/connectors/ai/anthropic/services/anthropic_chat_completion.py index 53a52bd13ff9..ed2616ba71aa 100644 --- a/python/semantic_kernel/connectors/ai/anthropic/services/anthropic_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/anthropic/services/anthropic_chat_completion.py @@ -26,6 +26,7 @@ from semantic_kernel.connectors.ai.anthropic.prompt_execution_settings.anthropic_prompt_execution_settings import ( AnthropicChatPromptExecutionSettings, ) +from 
semantic_kernel.connectors.ai.anthropic.services.utils import MESSAGE_CONVERTERS from semantic_kernel.connectors.ai.anthropic.settings.anthropic_settings import AnthropicSettings from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration @@ -34,7 +35,6 @@ from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.chat_message_content import ITEM_TYPES, ChatMessageContent from semantic_kernel.contents.function_call_content import FunctionCallContent -from semantic_kernel.contents.function_result_content import FunctionResultContent from semantic_kernel.contents.streaming_chat_message_content import ITEM_TYPES as STREAMING_ITEM_TYPES from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent from semantic_kernel.contents.streaming_text_content import StreamingTextContent @@ -192,69 +192,25 @@ def _prepare_chat_history_for_request( A tuple containing the prepared chat history and the first SYSTEM message content. 
""" system_message_content = None - remaining_messages: list[dict[str, Any]] = [] - system_message_found = False + system_message_count = 0 + formatted_messages: list[dict[str, Any]] = [] for message in chat_history.messages: # Skip system messages after the first one is found if message.role == AuthorRole.SYSTEM: - if not system_message_found: + if system_message_count == 0: system_message_content = message.content - system_message_found = True - elif message.role == AuthorRole.TOOL: - # if tool result message isn't the most recent message, add it to the remaining messages - if not remaining_messages or remaining_messages[-1][role_key] != AuthorRole.USER: - remaining_messages.append({ - role_key: AuthorRole.USER, - content_key: [], - }) - - # add the tool result to the most recent message - tool_results_message = remaining_messages[-1] - for item in message.items: - if isinstance(item, FunctionResultContent): - tool_results_message["content"].append({ - "type": "tool_result", - "tool_use_id": item.id, - content_key: str(item.result), - }) - elif message.finish_reason == SemanticKernelFinishReason.TOOL_CALLS: - if not stream: - if not message.inner_content: - raise ServiceInvalidResponseError( - "Expected a message with an Anthropic Message as inner content." 
- ) - - remaining_messages.append({ - role_key: AuthorRole.ASSISTANT, - content_key: [content_block.to_dict() for content_block in message.inner_content.content], - }) - else: - content: list[TextBlock | ToolUseBlock] = [] - # for remaining items, add them to the content - for item in message.items: - if isinstance(item, TextContent): - content.append(TextBlock(text=item.text, type="text")) - elif isinstance(item, FunctionCallContent): - item_arguments = ( - item.arguments if not isinstance(item.arguments, str) else json.loads(item.arguments) - ) - - content.append( - ToolUseBlock(id=item.id, input=item_arguments, name=item.name, type="tool_use") - ) - - remaining_messages.append({ - role_key: AuthorRole.ASSISTANT, - content_key: content, - }) + system_message_count += 1 else: - # The API requires only role and content keys for the remaining messages - remaining_messages.append({ - role_key: getattr(message, role_key), - content_key: getattr(message, content_key), - }) + formatted_messages.append(MESSAGE_CONVERTERS[message.role](message)) - return remaining_messages, system_message_content + if system_message_count > 1: + logger.warning( + "Anthropic service only supports one system message, but %s system messages were found." + " Only the first system message will be included in the request.", + system_message_count, + ) + + return formatted_messages, system_message_content # endregion diff --git a/python/semantic_kernel/connectors/ai/anthropic/services/utils.py b/python/semantic_kernel/connectors/ai/anthropic/services/utils.py new file mode 100644 index 000000000000..774d93615927 --- /dev/null +++ b/python/semantic_kernel/connectors/ai/anthropic/services/utils.py @@ -0,0 +1,110 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import json +import logging +from collections.abc import Callable, Mapping +from typing import Any + +from semantic_kernel.contents.chat_message_content import ChatMessageContent +from semantic_kernel.contents.function_call_content import FunctionCallContent +from semantic_kernel.contents.function_result_content import FunctionResultContent +from semantic_kernel.contents.text_content import TextContent +from semantic_kernel.contents.utils.author_role import AuthorRole + +logger: logging.Logger = logging.getLogger(__name__) + + +def _format_user_message(message: ChatMessageContent) -> dict[str, Any]: + """Format a user message to the expected object for the Anthropic client. + + Args: + message: The user message. + + Returns: + The formatted user message. + """ + return { + "role": "user", + "content": message.content, + } + + +def _format_assistant_message(message: ChatMessageContent) -> dict[str, Any]: + """Format an assistant message to the expected object for the Anthropic client. + + Args: + message: The assistant message. + + Returns: + The formatted assistant message. + """ + tool_calls: list[dict[str, Any]] = [] + + for item in message.items: + if isinstance(item, TextContent): + # Assuming the assistant message will have only one text content item + # and we assign the content directly to the message content, which is a string. 
+ continue + if isinstance(item, FunctionCallContent): + tool_calls.append({ + "type": "tool_use", + "id": item.id or "", + "name": item.name or "", + "input": item.arguments if isinstance(item.arguments, Mapping) else json.loads(item.arguments or ""), + }) + else: + logger.warning( + f"Unsupported item type in Assistant message while formatting chat history for Anthropic: {type(item)}" + ) + + if tool_calls: + return { + "role": "assistant", + "content": [ + { + "type": "text", + "text": message.content, + }, + *tool_calls, + ], + } + + return { + "role": "assistant", + "content": message.content, + } + + +def _format_tool_message(message: ChatMessageContent) -> dict[str, Any]: + """Format a tool message to the expected object for the Anthropic client. + + Args: + message: The tool message. + + Returns: + The formatted tool message. + """ + function_result_contents: list[dict[str, Any]] = [] + for item in message.items: + if not isinstance(item, FunctionResultContent): + logger.warning( + f"Unsupported item type in Tool message while formatting chat history for Anthropic: {type(item)}" + ) + continue + function_result_contents.append({ + "type": "tool_result", + "tool_use_id": item.id, + "content": str(item.result), + }) + + return { + "role": "user", + "content": function_result_contents, + } + + +MESSAGE_CONVERTERS: dict[AuthorRole, Callable[[ChatMessageContent], dict[str, Any]]] = { + AuthorRole.USER: _format_user_message, + AuthorRole.ASSISTANT: _format_assistant_message, + AuthorRole.TOOL: _format_tool_message, +} diff --git a/python/tests/integration/completions/test_chat_completion_with_function_calling.py b/python/tests/integration/completions/test_chat_completion_with_function_calling.py index 3542a6a39459..b6d83c6d0735 100644 --- a/python/tests/integration/completions/test_chat_completion_with_function_calling.py +++ b/python/tests/integration/completions/test_chat_completion_with_function_calling.py @@ -450,7 +450,12 @@ class 
FunctionChoiceTestTypes(str, Enum): ), pytest.param( "anthropic", - {}, + { + # Anthropic expects tools in the request when it sees tool use in the chat history. + "function_choice_behavior": FunctionChoiceBehavior.Auto( + auto_invoke=True, filters={"excluded_plugins": ["task_plugin"]} + ), + }, [ [ ChatMessageContent( @@ -460,9 +465,12 @@ class FunctionChoiceTestTypes(str, Enum): ChatMessageContent( role=AuthorRole.ASSISTANT, items=[ + # Anthropic will often include a chain of thought in the tool call by default. + # If this is not in the message, it will complain about the missing chain of thought. + TextContent(text="I will find the revenue for you."), FunctionCallContent( id="123456789", name="finance-search", arguments='{"company": "contoso", "year": 2024}' - ) + ), ], ), ChatMessageContent( diff --git a/python/tests/integration/completions/test_chat_completions.py b/python/tests/integration/completions/test_chat_completions.py index 71195148aa9b..810be08fd5e2 100644 --- a/python/tests/integration/completions/test_chat_completions.py +++ b/python/tests/integration/completions/test_chat_completions.py @@ -43,6 +43,7 @@ class Reasoning(KernelBaseModel): pytestmark = pytest.mark.parametrize( "service_id, execution_settings_kwargs, inputs, kwargs", [ + # region OpenAI pytest.param( "openai", {}, @@ -63,6 +64,8 @@ class Reasoning(KernelBaseModel): {}, id="openai_json_schema_response_format", ), + # endregion + # region Azure pytest.param( "azure", {}, @@ -83,6 +86,8 @@ class Reasoning(KernelBaseModel): {}, id="azure_custom_client", ), + # endregion + # region Azure AI Inference pytest.param( "azure_ai_inference", {}, @@ -93,6 +98,8 @@ class Reasoning(KernelBaseModel): {}, id="azure_ai_inference_text_input", ), + # endregion + # region Anthropic pytest.param( "anthropic", {}, @@ -104,6 +111,8 @@ class Reasoning(KernelBaseModel): marks=pytest.mark.skipif(not anthropic_setup, reason="Anthropic Environment Variables not set"), id="anthropic_text_input", ), + # 
endregion + # region Mistral AI pytest.param( "mistral_ai", {}, @@ -115,6 +124,8 @@ class Reasoning(KernelBaseModel): marks=pytest.mark.skipif(not mistral_ai_setup, reason="Mistral AI Environment Variables not set"), id="mistral_ai_text_input", ), + # endregion + # region Ollama pytest.param( "ollama", {}, @@ -129,6 +140,8 @@ class Reasoning(KernelBaseModel): ), id="ollama_text_input", ), + # endregion + # region Onnx Gen AI pytest.param( "onnx_gen_ai", {}, @@ -140,6 +153,8 @@ class Reasoning(KernelBaseModel): marks=pytest.mark.skipif(not onnx_setup, reason="Need a Onnx Model setup"), id="onnx_gen_ai", ), + # endregion + # region Google AI pytest.param( "google_ai", {}, @@ -151,6 +166,8 @@ class Reasoning(KernelBaseModel): marks=pytest.mark.skip(reason="Skipping due to 429s from Google AI."), id="google_ai_text_input", ), + # endregion + # region Vertex AI pytest.param( "vertex_ai", {}, @@ -162,6 +179,8 @@ class Reasoning(KernelBaseModel): marks=pytest.mark.skipif(not vertex_ai_setup, reason="Vertex AI Environment Variables not set"), id="vertex_ai_text_input", ), + # endregion + # region Bedrock pytest.param( "bedrock_amazon_titan", {}, @@ -228,6 +247,7 @@ class Reasoning(KernelBaseModel): marks=pytest.mark.skip(reason="Skipping due to occasional throttling from Bedrock."), id="bedrock_mistralai_text_input", ), + # endregion ], )