diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 712a133c3de5..b69a75f9a175 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -31,6 +31,8 @@ + + diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 3a0ec36d6ff9..39ff0cf90a81 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -103,6 +103,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "InternalUtilities", "Intern EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenApiSkillsExample", "..\samples\dotnet\openapi-skills\OpenApiSkillsExample.csproj", "{4D91A3E0-C404-495B-AD4A-411C4E83CF54}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Memory.DuckDB", "src\Connectors\Connectors.Memory.DuckDB\Connectors.Memory.DuckDB.csproj", "{50FAE231-6F24-4779-9D02-12ABBC9A49E2}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -294,6 +296,12 @@ Global {4D91A3E0-C404-495B-AD4A-411C4E83CF54}.Publish|Any CPU.Build.0 = Release|Any CPU {4D91A3E0-C404-495B-AD4A-411C4E83CF54}.Release|Any CPU.ActiveCfg = Release|Any CPU {4D91A3E0-C404-495B-AD4A-411C4E83CF54}.Release|Any CPU.Build.0 = Release|Any CPU + {50FAE231-6F24-4779-9D02-12ABBC9A49E2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {50FAE231-6F24-4779-9D02-12ABBC9A49E2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {50FAE231-6F24-4779-9D02-12ABBC9A49E2}.Publish|Any CPU.ActiveCfg = Publish|Any CPU + {50FAE231-6F24-4779-9D02-12ABBC9A49E2}.Publish|Any CPU.Build.0 = Publish|Any CPU + {50FAE231-6F24-4779-9D02-12ABBC9A49E2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {50FAE231-6F24-4779-9D02-12ABBC9A49E2}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -336,6 +344,7 @@ Global {136823BE-8665-4D57-87E0-EF41535539E2} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {4D3DAE63-41C6-4E1C-A35A-E77BDFC40675} = 
/// <summary>
/// One row of the memory table, without the collection column.
/// </summary>
internal struct DatabaseEntry
{
    /// <summary>Value of the <c>key</c> column.</summary>
    public string Key { get; set; }

    /// <summary>Value of the <c>metadata</c> column.</summary>
    public string MetadataString { get; set; }

    /// <summary>Value of the <c>embedding</c> column.</summary>
    public string EmbeddingString { get; set; }

    /// <summary>Value of the <c>timestamp</c> column.</summary>
    public string? Timestamp { get; set; }
}

/// <summary>
/// SQL helpers for the single table that holds every collection's records.
/// All methods run on the connection supplied by the caller.
/// </summary>
internal sealed class Database
{
    private const string TableName = "SKMemoryTable";

    /// <summary>
    /// Creates the memory table if it does not exist yet.
    /// </summary>
    /// <param name="conn">Connection to run the statement on.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
    public async Task CreateTableAsync(DuckDBConnection conn, CancellationToken cancellationToken = default)
    {
        using var cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            CREATE TABLE IF NOT EXISTS {TableName}(
                collection TEXT,
                key TEXT,
                metadata TEXT,
                embedding TEXT,
                timestamp TEXT,
                PRIMARY KEY(collection, key))";

        // Awaited here (rather than returning the task) so the command is not
        // disposed by the using declaration before execution has completed.
        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Registers a collection by inserting a placeholder row whose key, metadata,
    /// embedding and timestamp are all empty strings. Does nothing if the collection already exists.
    /// </summary>
    /// <param name="conn">Connection to run the statement on.</param>
    /// <param name="collectionName">Name of the collection to create.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
    public async Task CreateCollectionAsync(DuckDBConnection conn, string collectionName, CancellationToken cancellationToken = default)
    {
        if (await this.DoesCollectionExistsAsync(conn, collectionName, cancellationToken).ConfigureAwait(false))
        {
            // Collection already exists
            return;
        }

        using var cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            INSERT INTO {TableName} VALUES (?1,?2,?3,?4,?5 ); ";
        cmd.Parameters.Add(new DuckDBParameter(collectionName));
        cmd.Parameters.Add(new DuckDBParameter(string.Empty));
        cmd.Parameters.Add(new DuckDBParameter(string.Empty));
        cmd.Parameters.Add(new DuckDBParameter(string.Empty));
        cmd.Parameters.Add(new DuckDBParameter(string.Empty));

        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Inserts a record, or overwrites metadata, embedding and timestamp when a row
    /// with the same (collection, key) already exists. Null values are stored as empty strings.
    /// </summary>
    /// <param name="conn">Connection to run the statement on.</param>
    /// <param name="collection">Collection the record belongs to.</param>
    /// <param name="key">Record key.</param>
    /// <param name="metadata">Serialized metadata, or null.</param>
    /// <param name="embedding">Serialized embedding, or null.</param>
    /// <param name="timestamp">Timestamp text, or null.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
    public async Task UpdateOrInsertAsync(DuckDBConnection conn,
        string collection, string key, string? metadata, string? embedding, string? timestamp, CancellationToken cancellationToken = default)
    {
        using var cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            INSERT INTO {TableName} VALUES(?1, ?2, ?3, ?4, ?5)
            ON CONFLICT (collection, key) DO UPDATE SET metadata=?3, embedding=?4, timestamp=?5; ";
        cmd.Parameters.Add(new DuckDBParameter(collection));
        cmd.Parameters.Add(new DuckDBParameter(key));
        cmd.Parameters.Add(new DuckDBParameter(metadata ?? string.Empty));
        cmd.Parameters.Add(new DuckDBParameter(embedding ?? string.Empty));
        cmd.Parameters.Add(new DuckDBParameter(timestamp ?? string.Empty));
        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Checks whether any row exists for the given collection name.
    /// </summary>
    /// <param name="conn">Connection to run the query on.</param>
    /// <param name="collectionName">Name of the collection to look for.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
    /// <returns>True if the collection name is present in the table.</returns>
    public async Task<bool> DoesCollectionExistsAsync(DuckDBConnection conn,
        string collectionName,
        CancellationToken cancellationToken = default)
    {
        // Stream the names and stop at the first match instead of buffering them all in a list.
        return await this.GetCollectionsAsync(conn, cancellationToken)
            .ContainsAsync(collectionName, cancellationToken)
            .ConfigureAwait(false);
    }

    /// <summary>
    /// Streams the distinct collection names found in the table.
    /// </summary>
    /// <param name="conn">Connection to run the query on.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
    public async IAsyncEnumerable<string> GetCollectionsAsync(DuckDBConnection conn,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        using var cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            SELECT DISTINCT collection
            FROM {TableName};";

        using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
        while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
        {
            yield return dataReader.GetString("collection");
        }
    }

    /// <summary>
    /// Streams every record of a collection. Rows with a blank key
    /// (such as the placeholder written by <see cref="CreateCollectionAsync"/>) are skipped.
    /// </summary>
    /// <param name="conn">Connection to run the query on.</param>
    /// <param name="collectionName">Collection to read.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
    public async IAsyncEnumerable<DatabaseEntry> ReadAllAsync(DuckDBConnection conn,
        string collectionName,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        using var cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            SELECT * FROM {TableName}
            WHERE collection=?1;";
        cmd.Parameters.Add(new DuckDBParameter(collectionName));

        using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
        while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
        {
            string key = dataReader.GetString("key");
            if (string.IsNullOrWhiteSpace(key))
            {
                continue;
            }

            yield return new DatabaseEntry
            {
                Key = key,
                MetadataString = dataReader.GetString("metadata"),
                EmbeddingString = dataReader.GetString("embedding"),
                Timestamp = dataReader.GetString("timestamp")
            };
        }
    }

    /// <summary>
    /// Reads a single record by collection and key.
    /// </summary>
    /// <param name="conn">Connection to run the query on.</param>
    /// <param name="collectionName">Collection the record belongs to.</param>
    /// <param name="key">Record key.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
    /// <returns>The matching entry, or null when no row matches.</returns>
    public async Task<DatabaseEntry?> ReadAsync(DuckDBConnection conn,
        string collectionName,
        string key,
        CancellationToken cancellationToken = default)
    {
        using var cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            SELECT * FROM {TableName}
            WHERE collection=?1
            AND key=?2; ";
        cmd.Parameters.Add(new DuckDBParameter(collectionName));
        cmd.Parameters.Add(new DuckDBParameter(key));

        using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
        if (!await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
        {
            return null;
        }

        // Same by-name column lookup as ReadAllAsync.
        return new DatabaseEntry
        {
            Key = key,
            MetadataString = dataReader.GetString("metadata"),
            EmbeddingString = dataReader.GetString("embedding"),
            Timestamp = dataReader.GetString("timestamp")
        };
    }

    /// <summary>
    /// Deletes every row of a collection, including its placeholder row.
    /// </summary>
    /// <param name="conn">Connection to run the statement on.</param>
    /// <param name="collectionName">Collection to delete.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
    public async Task DeleteCollectionAsync(DuckDBConnection conn, string collectionName, CancellationToken cancellationToken = default)
    {
        using var cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            DELETE FROM {TableName}
            WHERE collection=?;";
        cmd.Parameters.Add(new DuckDBParameter(collectionName));

        // Awaited so the command outlives the execution (see CreateTableAsync).
        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Deletes a single record by collection and key. Deleting a missing record is not an error.
    /// </summary>
    /// <param name="conn">Connection to run the statement on.</param>
    /// <param name="collectionName">Collection the record belongs to.</param>
    /// <param name="key">Record key.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
    public async Task DeleteAsync(DuckDBConnection conn, string collectionName, string key, CancellationToken cancellationToken = default)
    {
        using var cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            DELETE FROM {TableName}
            WHERE collection=?1
            AND key=?2; ";
        cmd.Parameters.Add(new DuckDBParameter(collectionName));
        cmd.Parameters.Add(new DuckDBParameter(key));

        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Deletes rows of the collection whose key is NULL.
    /// </summary>
    /// <remarks>
    /// NOTE(review): <see cref="CreateCollectionAsync"/> writes its placeholder with an empty-string key,
    /// and key is part of the primary key, so this filter presumably never matches that row.
    /// Confirm whether <c>key = ''</c> was intended before changing it: removing the placeholder
    /// would make an otherwise empty collection disappear from <see cref="GetCollectionsAsync"/>.
    /// </remarks>
    /// <param name="conn">Connection to run the statement on.</param>
    /// <param name="collectionName">Collection to clean up.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
    public async Task DeleteEmptyAsync(DuckDBConnection conn, string collectionName, CancellationToken cancellationToken = default)
    {
        using var cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            DELETE FROM {TableName}
            WHERE collection=?1
            AND key IS NULL";
        cmd.Parameters.Add(new DuckDBParameter(collectionName));

        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
    }
}

/// <summary>
/// Convenience extensions for reading query results.
/// </summary>
internal static class DuckDBExtensions
{
    /// <summary>
    /// Reads a string column by name instead of by ordinal.
    /// </summary>
    /// <param name="reader">Reader positioned on a row.</param>
    /// <param name="fieldName">Name of the column to read.</param>
    /// <returns>The column value as a string.</returns>
    public static string GetString(this DbDataReader reader, string fieldName)
    {
        int ordinal = reader.GetOrdinal(fieldName);
        return reader.GetString(ordinal);
    }
}
/// <summary>
/// An implementation of <see cref="IMemoryStore"/> backed by a DuckDB database.
/// </summary>
/// <remarks>
/// Instances are created through the <c>ConnectAsync</c> factory methods, which open the connection
/// and create the backing table when needed. Disposing the store closes and disposes its connection,
/// including a connection that was supplied by the caller.
/// </remarks>
public class DuckDBMemoryStore : IMemoryStore, IDisposable
{
    /// <summary>
    /// Connect to a DuckDB database file.
    /// </summary>
    /// <param name="filename">Path to the database file, used as the connection's data source.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>An initialised memory store.</returns>
    public static async Task<DuckDBMemoryStore> ConnectAsync(string filename,
        CancellationToken cancellationToken = default)
    {
        // The constructor builds the "Data Source=...;" connection string itself,
        // so pass the bare file name; prefixing it here would double the prefix.
        var memoryStore = new DuckDBMemoryStore(filename);
        return await InitialiseMemoryStoreAsync(memoryStore, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Connect to an in-memory DuckDB database.
    /// </summary>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>An initialised memory store.</returns>
    public static async Task<DuckDBMemoryStore> ConnectAsync(
        CancellationToken cancellationToken = default)
    {
        var memoryStore = new DuckDBMemoryStore();
        return await InitialiseMemoryStoreAsync(memoryStore, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Create a memory store on top of an existing DuckDB connection.
    /// </summary>
    /// <param name="connection">The connection to use. It is opened if it is not open already.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>An initialised memory store.</returns>
    public static async Task<DuckDBMemoryStore> ConnectAsync(DuckDBConnection connection,
        CancellationToken cancellationToken = default)
    {
        var memoryStore = new DuckDBMemoryStore(connection);
        return await InitialiseMemoryStoreAsync(memoryStore, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
    {
        await this._dbConnector.CreateCollectionAsync(this._dbConnection, collectionName, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async Task<bool> DoesCollectionExistAsync(string collectionName, CancellationToken cancellationToken = default)
    {
        return await this._dbConnector.DoesCollectionExistsAsync(this._dbConnection, collectionName, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<string> GetCollectionsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        await foreach (var collection in this._dbConnector.GetCollectionsAsync(this._dbConnection, cancellationToken))
        {
            yield return collection;
        }
    }

    /// <inheritdoc/>
    public async Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
    {
        await this._dbConnector.DeleteCollectionAsync(this._dbConnection, collectionName, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async Task<string> UpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancellationToken = default)
    {
        return await this.InternalUpsertAsync(this._dbConnection, collectionName, record, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<string> UpsertBatchAsync(string collectionName, IEnumerable<MemoryRecord> records,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        foreach (var record in records)
        {
            yield return await this.InternalUpsertAsync(this._dbConnection, collectionName, record, cancellationToken).ConfigureAwait(false);
        }
    }

    /// <inheritdoc/>
    public async Task<MemoryRecord?> GetAsync(string collectionName, string key, bool withEmbedding = false, CancellationToken cancellationToken = default)
    {
        return await this.InternalGetAsync(this._dbConnection, collectionName, key, withEmbedding, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<MemoryRecord> GetBatchAsync(string collectionName, IEnumerable<string> keys, bool withEmbeddings = false,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        foreach (var key in keys)
        {
            var result = await this.InternalGetAsync(this._dbConnection, collectionName, key, withEmbeddings, cancellationToken).ConfigureAwait(false);

            // A missing key is skipped; it must not end the batch, otherwise every
            // record requested after the first miss would be silently dropped.
            if (result != null)
            {
                yield return result;
            }
        }
    }

    /// <inheritdoc/>
    public async Task RemoveAsync(string collectionName, string key, CancellationToken cancellationToken = default)
    {
        await this._dbConnector.DeleteAsync(this._dbConnection, collectionName, key, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async Task RemoveBatchAsync(string collectionName, IEnumerable<string> keys, CancellationToken cancellationToken = default)
    {
        // All deletes share one connection, so run them one after another
        // instead of starting them all at once.
        foreach (var key in keys)
        {
            await this._dbConnector.DeleteAsync(this._dbConnection, collectionName, key, cancellationToken).ConfigureAwait(false);
        }
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync(
        string collectionName,
        Embedding<float> embedding,
        int limit,
        double minRelevanceScore = 0,
        bool withEmbeddings = false,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (limit <= 0)
        {
            yield break;
        }

        TopNCollection<MemoryRecord> embeddings = new(limit);

        // Similarity is computed in process: every record of the collection is read and scored.
        await foreach (var record in this.GetAllAsync(collectionName, cancellationToken))
        {
            if (record != null)
            {
                double similarity = embedding
                    .AsReadOnlySpan()
                    .CosineSimilarity(record.Embedding.AsReadOnlySpan());
                if (similarity >= minRelevanceScore)
                {
                    // Strip the embedding from the result unless the caller asked for it.
                    var entry = withEmbeddings ? record : MemoryRecord.FromMetadata(record.Metadata, Embedding<float>.Empty, record.Key, record.Timestamp);
                    embeddings.Add(new(entry, similarity));
                }
            }
        }

        embeddings.SortByScore();

        foreach (var item in embeddings)
        {
            yield return (item.Value, item.Score.Value);
        }
    }

    /// <inheritdoc/>
    public async Task<(MemoryRecord, double)?> GetNearestMatchAsync(string collectionName, Embedding<float> embedding, double minRelevanceScore = 0, bool withEmbedding = false,
        CancellationToken cancellationToken = default)
    {
        return await this.GetNearestMatchesAsync(
            collectionName: collectionName,
            embedding: embedding,
            limit: 1,
            minRelevanceScore: minRelevanceScore,
            withEmbeddings: withEmbedding,
            cancellationToken: cancellationToken).FirstOrDefaultAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public void Dispose()
    {
        this.Dispose(true);
        GC.SuppressFinalize(this);
    }

    #region protected ================================================================================

    /// <summary>
    /// Closes and disposes the underlying connection the first time it is called with
    /// <paramref name="disposing"/> set to true.
    /// </summary>
    /// <param name="disposing">True when called from <see cref="Dispose()"/>.</param>
    protected virtual void Dispose(bool disposing)
    {
        if (!this._disposedValue)
        {
            if (disposing)
            {
                this._dbConnection.Close();
                this._dbConnection.Dispose();
            }

            this._disposedValue = true;
        }
    }

    #endregion

    #region private ================================================================================

    private readonly Database _dbConnector;
    private readonly DuckDBConnection _dbConnection;
    private bool _disposedValue;

    /// <summary>
    /// Opens the store's connection when necessary and makes sure the memory table exists.
    /// </summary>
    private static async Task<DuckDBMemoryStore> InitialiseMemoryStoreAsync(DuckDBMemoryStore memoryStore, CancellationToken cancellationToken = default)
    {
        // A caller-supplied connection may already be open; only open it when it is not.
        if (memoryStore._dbConnection.State != System.Data.ConnectionState.Open)
        {
            await memoryStore._dbConnection.OpenAsync(cancellationToken).ConfigureAwait(false);
        }

        await memoryStore._dbConnector.CreateTableAsync(memoryStore._dbConnection, cancellationToken).ConfigureAwait(false);
        return memoryStore;
    }

    /// <summary>
    /// Creates a store over the given database file.
    /// </summary>
    /// <param name="filename">DuckDB database file name, without any connection string prefix.</param>
    private DuckDBMemoryStore(string filename)
    {
        this._dbConnector = new Database();
        this._dbConnection = new DuckDBConnection($"Data Source={filename};");
        this._disposedValue = false;
    }

    /// <summary>
    /// Creates a store over an in-memory database.
    /// </summary>
    private DuckDBMemoryStore()
    {
        this._dbConnector = new Database();
        this._dbConnection = new DuckDBConnection("Data Source=:memory:;");
        this._disposedValue = false;
    }

    /// <summary>
    /// Creates a store over a connection supplied by the caller.
    /// </summary>
    /// <param name="connection">The connection to use.</param>
    private DuckDBMemoryStore(DuckDBConnection connection)
    {
        this._dbConnector = new Database();
        this._dbConnection = connection;
        this._disposedValue = false;
    }

    /// <summary>
    /// Formats a timestamp for storage using the invariant "u" (universal sortable) format.
    /// That format has one-second resolution, so fractions of a second are not preserved.
    /// </summary>
    private static string? ToTimestampString(DateTimeOffset? timestamp)
    {
        return timestamp?.ToString("u", CultureInfo.InvariantCulture);
    }

    /// <summary>
    /// Parses a stored timestamp; returns null for null, empty or unparsable text.
    /// </summary>
    private static DateTimeOffset? ParseTimestamp(string? str)
    {
        if (!string.IsNullOrEmpty(str)
            && DateTimeOffset.TryParse(str, CultureInfo.InvariantCulture, DateTimeStyles.AssumeUniversal, out DateTimeOffset timestamp))
        {
            return timestamp;
        }

        return null;
    }

    /// <summary>
    /// Streams every record of a collection, embeddings included.
    /// </summary>
    private async IAsyncEnumerable<MemoryRecord> GetAllAsync(string collectionName, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // delete empty entry in the database if it exists (see CreateCollection)
        await this._dbConnector.DeleteEmptyAsync(this._dbConnection, collectionName, cancellationToken).ConfigureAwait(false);

        await foreach (DatabaseEntry dbEntry in this._dbConnector.ReadAllAsync(this._dbConnection, collectionName, cancellationToken))
        {
            Embedding<float>? vector = JsonSerializer.Deserialize<Embedding<float>>(dbEntry.EmbeddingString);

            yield return MemoryRecord.FromJsonMetadata(dbEntry.MetadataString, vector, dbEntry.Key, ParseTimestamp(dbEntry.Timestamp));
        }
    }

    /// <summary>
    /// Stores a record under its metadata id and returns that key.
    /// </summary>
    private async Task<string> InternalUpsertAsync(DuckDBConnection connection, string collectionName, MemoryRecord record, CancellationToken cancellationToken)
    {
        // The metadata id is the storage key.
        record.Key = record.Metadata.Id;

        await this._dbConnector.UpdateOrInsertAsync(conn: connection,
            collection: collectionName,
            key: record.Key,
            metadata: record.GetSerializedMetadata(),
            embedding: JsonSerializer.Serialize(record.Embedding),
            timestamp: ToTimestampString(record.Timestamp),
            cancellationToken: cancellationToken).ConfigureAwait(false);

        return record.Key;
    }

    /// <summary>
    /// Reads one record; the embedding is deserialized only when <paramref name="withEmbedding"/> is true.
    /// </summary>
    /// <returns>The record, or null when the key is not present in the collection.</returns>
    private async Task<MemoryRecord?> InternalGetAsync(
        DuckDBConnection connection,
        string collectionName,
        string key, bool withEmbedding,
        CancellationToken cancellationToken)
    {
        DatabaseEntry? entry = await this._dbConnector.ReadAsync(connection, collectionName, key, cancellationToken).ConfigureAwait(false);

        if (!entry.HasValue) { return null; }

        if (withEmbedding)
        {
            return MemoryRecord.FromJsonMetadata(
                json: entry.Value.MetadataString,
                JsonSerializer.Deserialize<Embedding<float>>(entry.Value.EmbeddingString),
                entry.Value.Key,
                ParseTimestamp(entry.Value.Timestamp));
        }

        return MemoryRecord.FromJsonMetadata(
            json: entry.Value.MetadataString,
            Embedding<float>.Empty,
            entry.Value.Key,
            ParseTimestamp(entry.Value.Timestamp));
    }

    #endregion
}
/// <summary>
/// Builds <paramref name="numRecords"/> test records: the first half are local records with
/// embedding (1, 1, 1), the second half are reference records with embedding (1, 2, 3).
/// </summary>
/// <param name="numRecords">Number of records to create; must be even and greater than zero.</param>
private IEnumerable<MemoryRecord> CreateBatchRecords(int numRecords)
{
    Assert.True(numRecords % 2 == 0, "Number of records must be even");
    Assert.True(numRecords > 0, "Number of records must be greater than 0");

    // Fill the preallocated list directly; chaining Enumerable.Append would leave the
    // list empty and re-run the whole deferred chain on every enumeration.
    var records = new List<MemoryRecord>(numRecords);
    for (int i = 0; i < numRecords / 2; i++)
    {
        records.Add(MemoryRecord.LocalRecord(
            id: "test" + i,
            text: "text" + i,
            description: "description" + i,
            embedding: new Embedding<float>(new float[] { 1, 1, 1 })));
    }

    for (int i = numRecords / 2; i < numRecords; i++)
    {
        records.Add(MemoryRecord.ReferenceRecord(
            externalId: "test" + i,
            sourceName: "sourceName" + i,
            description: "description" + i,
            embedding: new Embedding<float>(new float[] { 1, 2, 3 })));
    }

    return records;
}
/// <summary>
/// Creating a collection that already exists must not change the number of collections.
/// </summary>
[Fact]
public async Task CreatingDuplicateCollectionDoesNothingAsync()
{
    // Arrange
    using var db = await DuckDBMemoryStore.ConnectAsync();
    string collection = "test_collection" + this._collectionNum;
    this._collectionNum++;

    // Act
    await db.CreateCollectionAsync(collection);

    // Count now: GetCollectionsAsync is lazy, so the "before" value has to be
    // materialized before the second create call runs.
    int countBefore = await db.GetCollectionsAsync().CountAsync();
    await db.CreateCollectionAsync(collection);
    int countAfter = await db.GetCollectionsAsync().CountAsync();

    // Assert
    Assert.Equal(countBefore, countAfter);
}
/// <summary>
/// A record stored without a timestamp round-trips and is read back without a timestamp.
/// </summary>
[Fact]
public async Task ItCanUpsertAndRetrieveARecordWithNoTimestampAsync()
{
    // Arrange
    using var db = await DuckDBMemoryStore.ConnectAsync();
    MemoryRecord testRecord = MemoryRecord.LocalRecord(
        id: "test",
        text: "text",
        description: "description",
        embedding: new Embedding<float>(new float[] { 1, 2, 3 }),
        key: null,
        timestamp: null);
    string collection = "test_collection" + this._collectionNum;
    this._collectionNum++;

    // Act
    await db.CreateCollectionAsync(collection);
    var key = await db.UpsertAsync(collection, testRecord);
    var actual = await db.GetAsync(collection, key, true);

    // Assert
    Assert.NotNull(actual);
    Assert.Equal(testRecord.Metadata.Id, key);
    Assert.Equal(testRecord.Metadata.Id, actual.Key);
    Assert.Equal(testRecord.Embedding.Vector, actual.Embedding.Vector);
    Assert.Equal(testRecord.Metadata.Text, actual.Metadata.Text);
    Assert.Equal(testRecord.Metadata.Description, actual.Metadata.Description);
    Assert.Equal(testRecord.Metadata.ExternalSourceName, actual.Metadata.ExternalSourceName);
    Assert.Equal(testRecord.Metadata.Id, actual.Metadata.Id);

    // The behavior this test is named after: no timestamp in, no timestamp out.
    Assert.Null(actual.Timestamp);
}
/// <summary>
/// Upserting a second record with the same id overwrites the first one.
/// </summary>
[Fact]
public async Task UpsertReplacesExistingRecordWithSameIdAsync()
{
    // Arrange
    using var db = await DuckDBMemoryStore.ConnectAsync();
    string commonId = "test";
    MemoryRecord testRecord = MemoryRecord.LocalRecord(
        id: commonId,
        text: "text",
        description: "description",
        embedding: new Embedding<float>(new float[] { 1, 2, 3 }));
    MemoryRecord testRecord2 = MemoryRecord.LocalRecord(
        id: commonId,
        text: "text2",
        description: "description2",
        embedding: new Embedding<float>(new float[] { 1, 2, 4 }));
    string collection = "test_collection" + this._collectionNum;
    this._collectionNum++;

    // Act
    await db.CreateCollectionAsync(collection);
    var key = await db.UpsertAsync(collection, testRecord);
    var key2 = await db.UpsertAsync(collection, testRecord2);
    var actual = await db.GetAsync(collection, key, true);

    // Assert
    Assert.NotNull(actual);
    Assert.Equal(testRecord.Metadata.Id, key);

    // Both upserts target the same key, otherwise nothing was replaced.
    Assert.Equal(key, key2);
    Assert.Equal(testRecord2.Metadata.Id, actual.Key);
    Assert.NotEqual(testRecord.Embedding.Vector, actual.Embedding.Vector);
    Assert.Equal(testRecord2.Embedding.Vector, actual.Embedding.Vector);
    Assert.NotEqual(testRecord.Metadata.Text, actual.Metadata.Text);
    Assert.Equal(testRecord2.Metadata.Text, actual.Metadata.Text);
    Assert.Equal(testRecord2.Metadata.Description, actual.Metadata.Description);
}
async Task RemovingNonExistingRecordDoesNothingAsync() + { + // Arrange + using var db = await DuckDBMemoryStore.ConnectAsync(); + string collection = "test_collection" + this._collectionNum; + this._collectionNum++; + + // Act + await db.CreateCollectionAsync(collection); + await db.RemoveAsync(collection, "key"); + var actual = await db.GetAsync(collection, "key"); + + // Assert + Assert.Null(actual); + } + + [Fact] + public async Task ItCanListAllDatabaseCollectionsAsync() + { + // Arrange + using var db = await DuckDBMemoryStore.ConnectAsync(); + string[] testCollections = { "random_collection1", "random_collection2", "random_collection3" }; + this._collectionNum += 3; + await db.CreateCollectionAsync(testCollections[0]); + await db.CreateCollectionAsync(testCollections[1]); + await db.CreateCollectionAsync(testCollections[2]); + + // Act + var collections = await db.GetCollectionsAsync().ToListAsync(); + + // Assert + foreach (var collection in testCollections) + { + Assert.True(await db.DoesCollectionExistAsync(collection)); + } + + Assert.NotNull(collections); + Assert.NotEmpty(collections); + Assert.Equal(testCollections.Length, collections.Count); + Assert.True(collections.Contains(testCollections[0]), + $"Collections does not contain the newly-created collection {testCollections[0]}"); + Assert.True(collections.Contains(testCollections[1]), + $"Collections does not contain the newly-created collection {testCollections[1]}"); + Assert.True(collections.Contains(testCollections[2]), + $"Collections does not contain the newly-created collection {testCollections[2]}"); + } + + [Fact] + public async Task GetNearestMatchesReturnsAllResultsWithNoMinScoreAsync() + { + // Arrange + using var db = await DuckDBMemoryStore.ConnectAsync(); + var compareEmbedding = new Embedding(new float[] { 1, 1, 1 }); + int topN = 4; + string collection = "test_collection" + this._collectionNum; + this._collectionNum++; + await db.CreateCollectionAsync(collection); + int i = 0; + 
MemoryRecord testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { 1, 1, 1 })); + _ = await db.UpsertAsync(collection, testRecord); + + i++; + testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { -1, -1, -1 })); + _ = await db.UpsertAsync(collection, testRecord); + + i++; + testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { 1, 2, 3 })); + _ = await db.UpsertAsync(collection, testRecord); + + i++; + testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { -1, -2, -3 })); + _ = await db.UpsertAsync(collection, testRecord); + + i++; + testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { 1, -1, -2 })); + _ = await db.UpsertAsync(collection, testRecord); + + // Act + double threshold = -1; + var topNResults = db.GetNearestMatchesAsync(collection, compareEmbedding, limit: topN, minRelevanceScore: threshold).ToEnumerable().ToArray(); + + // Assert + Assert.Equal(topN, topNResults.Length); + for (int j = 0; j < topN - 1; j++) + { + int compare = topNResults[j].Item2.CompareTo(topNResults[j + 1].Item2); + Assert.True(compare >= 0); + } + } + + [Fact] + public async Task GetNearestMatchAsyncReturnsEmptyEmbeddingUnlessSpecifiedAsync() + { + // Arrange + using var db = await DuckDBMemoryStore.ConnectAsync(); + var compareEmbedding = new Embedding(new float[] { 1, 1, 1 }); + string collection = "test_collection" + this._collectionNum; + this._collectionNum++; + await db.CreateCollectionAsync(collection); + int i = 0; + MemoryRecord testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: 
"text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { 1, 1, 1 })); + _ = await db.UpsertAsync(collection, testRecord); + + i++; + testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { -1, -1, -1 })); + _ = await db.UpsertAsync(collection, testRecord); + + i++; + testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { 1, 2, 3 })); + _ = await db.UpsertAsync(collection, testRecord); + + i++; + testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { -1, -2, -3 })); + _ = await db.UpsertAsync(collection, testRecord); + + i++; + testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { 1, -1, -2 })); + _ = await db.UpsertAsync(collection, testRecord); + + // Act + double threshold = 0.75; + var topNResultDefault = await db.GetNearestMatchAsync(collection, compareEmbedding, minRelevanceScore: threshold); + var topNResultWithEmbedding = await db.GetNearestMatchAsync(collection, compareEmbedding, minRelevanceScore: threshold, withEmbedding: true); + + // Assert + Assert.NotNull(topNResultDefault); + Assert.NotNull(topNResultWithEmbedding); + Assert.Empty(topNResultDefault.Value.Item1.Embedding.Vector); + Assert.NotEmpty(topNResultWithEmbedding.Value.Item1.Embedding.Vector); + } + + [Fact] + public async Task GetNearestMatchAsyncReturnsExpectedAsync() + { + // Arrange + using var db = await DuckDBMemoryStore.ConnectAsync(); + var compareEmbedding = new Embedding(new float[] { 1, 1, 1 }); + string collection = "test_collection" + this._collectionNum; + this._collectionNum++; + await db.CreateCollectionAsync(collection); + int i = 0; + MemoryRecord testRecord = 
MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { 1, 1, 1 })); + _ = await db.UpsertAsync(collection, testRecord); + + i++; + testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { -1, -1, -1 })); + _ = await db.UpsertAsync(collection, testRecord); + + i++; + testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { 1, 2, 3 })); + _ = await db.UpsertAsync(collection, testRecord); + + i++; + testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { -1, -2, -3 })); + _ = await db.UpsertAsync(collection, testRecord); + + i++; + testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { 1, -1, -2 })); + _ = await db.UpsertAsync(collection, testRecord); + + // Act + double threshold = 0.75; + var topNResult = await db.GetNearestMatchAsync(collection, compareEmbedding, minRelevanceScore: threshold); + + // Assert + Assert.NotNull(topNResult); + Assert.Equal("test0", topNResult.Value.Item1.Metadata.Id); + Assert.True(topNResult.Value.Item2 >= threshold); + } + + [Fact] + public async Task GetNearestMatchesDifferentiatesIdenticalVectorsByKeyAsync() + { + // Arrange + using var db = await DuckDBMemoryStore.ConnectAsync(); + var compareEmbedding = new Embedding(new float[] { 1, 1, 1 }); + int topN = 4; + string collection = "test_collection" + this._collectionNum; + this._collectionNum++; + await db.CreateCollectionAsync(collection); + + for (int i = 0; i < 10; i++) + { + MemoryRecord testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new 
Embedding(new float[] { 1, 1, 1 })); + _ = await db.UpsertAsync(collection, testRecord); + } + + // Act + var topNResults = db.GetNearestMatchesAsync(collection, compareEmbedding, limit: topN, minRelevanceScore: 0.75).ToEnumerable().ToArray(); + IEnumerable topNKeys = topNResults.Select(x => x.Item1.Key).ToImmutableSortedSet(); + + // Assert + Assert.Equal(topN, topNResults.Length); + Assert.Equal(topN, topNKeys.Count()); + + for (int i = 0; i < topNResults.Length; i++) + { + int compare = topNResults[i].Item2.CompareTo(0.75); + Assert.True(compare >= 0); + } + } + + [Fact] + public async Task ItCanBatchUpsertRecordsAsync() + { + // Arrange + using var db = await DuckDBMemoryStore.ConnectAsync(); + int numRecords = 10; + string collection = "test_collection" + this._collectionNum; + this._collectionNum++; + IEnumerable records = this.CreateBatchRecords(numRecords); + + // Act + await db.CreateCollectionAsync(collection); + var keys = db.UpsertBatchAsync(collection, records); + var resultRecords = db.GetBatchAsync(collection, keys.ToEnumerable()); + + // Assert + Assert.NotNull(keys); + Assert.Equal(numRecords, keys.ToEnumerable().Count()); + Assert.Equal(numRecords, resultRecords.ToEnumerable().Count()); + } + + [Fact] + public async Task ItCanBatchGetRecordsAsync() + { + // Arrange + using var db = await DuckDBMemoryStore.ConnectAsync(); + int numRecords = 10; + string collection = "test_collection" + this._collectionNum; + this._collectionNum++; + IEnumerable records = this.CreateBatchRecords(numRecords); + var keys = db.UpsertBatchAsync(collection, records); + + // Act + await db.CreateCollectionAsync(collection); + var results = db.GetBatchAsync(collection, keys.ToEnumerable()); + + // Assert + Assert.NotNull(keys); + Assert.NotNull(results); + Assert.Equal(numRecords, results.ToEnumerable().Count()); + } + + [Fact] + public async Task ItCanBatchRemoveRecordsAsync() + { + // Arrange + using var db = await DuckDBMemoryStore.ConnectAsync(); + int numRecords = 10; 
+ string collection = "test_collection" + this._collectionNum; + this._collectionNum++; + IEnumerable records = this.CreateBatchRecords(numRecords); + await db.CreateCollectionAsync(collection); + + List keys = new(); + + // Act + await foreach (var key in db.UpsertBatchAsync(collection, records)) + { + keys.Add(key); + } + + await db.RemoveBatchAsync(collection, keys); + + // Assert + await foreach (var result in db.GetBatchAsync(collection, keys)) + { + Assert.Null(result); + } + } + + [Fact] + public async Task DeletingNonExistentCollectionDoesNothingAsync() + { + // Arrange + using var db = await DuckDBMemoryStore.ConnectAsync(); + string collection = "test_collection" + this._collectionNum; + this._collectionNum++; + + // Act + await db.DeleteCollectionAsync(collection); + } +}