From 4721625e6af328ebb29cff12fa45dcdbe453b0a7 Mon Sep 17 00:00:00 2001 From: "Daniel Dror (Dubovski)" Date: Thu, 10 Aug 2023 13:53:29 -0700 Subject: [PATCH] .Net: Adding Kusto as an external memory (#2257) ### Motivation and Context ### Description Adding Kusto as an external memory. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --------- Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Co-authored-by: Shawn Callegari <36091529+shawncal@users.noreply.github.com> --- dotnet/Directory.Packages.props | 1 + dotnet/SK-dotnet.sln | 9 + .../KernelSyntaxExamples/Example53_Kusto.cs | 80 ++++ .../KernelSyntaxExamples.csproj | 1 + .../samples/KernelSyntaxExamples/Program.cs | 1 + dotnet/samples/KernelSyntaxExamples/README.md | 1 + .../KernelSyntaxExamples/TestConfiguration.cs | 6 + .../Connectors.Memory.Kusto.csproj | 31 ++ .../KustoMemoryRecord.cs | 95 ++++ .../KustoMemoryStore.cs | 413 ++++++++++++++++++ .../KustoSerializer.cs | 107 +++++ .../Connectors.Memory.Kusto/README.md | 42 ++ .../Connectors.UnitTests.csproj | 1 + .../Memory/Kusto/KustoMemoryStoreTests.cs | 406 +++++++++++++++++ 14 files changed, 1194 insertions(+) create mode 100644 dotnet/samples/KernelSyntaxExamples/Example53_Kusto.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Kusto/Connectors.Memory.Kusto.csproj create mode 100644 dotnet/src/Connectors/Connectors.Memory.Kusto/KustoMemoryRecord.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Kusto/KustoMemoryStore.cs create mode 100644 
dotnet/src/Connectors/Connectors.Memory.Kusto/KustoSerializer.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Kusto/README.md create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/Memory/Kusto/KustoMemoryStoreTests.cs diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 5ffc494033a7..bc3ecadf1947 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -9,6 +9,7 @@ + diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index bde158c3c1ac..d668e5b2d465 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -150,6 +150,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Planning.StepwisePlanner", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ApplicationInsightsExample", "samples\ApplicationInsightsExample\ApplicationInsightsExample.csproj", "{C754950A-E16C-4F96-9CC7-9328E361B5AF}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Memory.Kusto", "src\Connectors\Connectors.Memory.Kusto\Connectors.Memory.Kusto.csproj", "{E07608CC-D710-4655-BB9E-D22CF3CDD193}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -361,6 +363,12 @@ Global {C754950A-E16C-4F96-9CC7-9328E361B5AF}.Publish|Any CPU.ActiveCfg = Release|Any CPU {C754950A-E16C-4F96-9CC7-9328E361B5AF}.Release|Any CPU.ActiveCfg = Release|Any CPU {C754950A-E16C-4F96-9CC7-9328E361B5AF}.Release|Any CPU.Build.0 = Release|Any CPU + {E07608CC-D710-4655-BB9E-D22CF3CDD193}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E07608CC-D710-4655-BB9E-D22CF3CDD193}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E07608CC-D710-4655-BB9E-D22CF3CDD193}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {E07608CC-D710-4655-BB9E-D22CF3CDD193}.Publish|Any CPU.Build.0 = Debug|Any CPU + {E07608CC-D710-4655-BB9E-D22CF3CDD193}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E07608CC-D710-4655-BB9E-D22CF3CDD193}.Release|Any CPU.Build.0 = Release|Any CPU 
EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -413,6 +421,7 @@ Global {677F1381-7830-4115-9C1A-58B282629DC6} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {4762BCAF-E1C5-4714-B88D-E50FA333C50E} = {078F96B4-09E1-4E0E-B214-F71A4F4BF633} {C754950A-E16C-4F96-9CC7-9328E361B5AF} = {FA3720F1-C99A-49B2-9577-A940257098BF} + {E07608CC-D710-4655-BB9E-D22CF3CDD193} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83} diff --git a/dotnet/samples/KernelSyntaxExamples/Example53_Kusto.cs b/dotnet/samples/KernelSyntaxExamples/Example53_Kusto.cs new file mode 100644 index 000000000000..c133c69b2e73 --- /dev/null +++ b/dotnet/samples/KernelSyntaxExamples/Example53_Kusto.cs @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Memory.Kusto; +using Microsoft.SemanticKernel.Memory; +using RepoUtils; + +// ReSharper disable once InconsistentNaming +public static class Example53_Kusto +{ + private const string MemoryCollectionName = "kusto_test"; + + public static async Task RunAsync() + { + var connectionString = new Kusto.Data.KustoConnectionStringBuilder(TestConfiguration.Kusto.ConnectionString).WithAadUserPromptAuthentication(); + using KustoMemoryStore memoryStore = new(connectionString, "MyDatabase"); + + IKernel kernel = Kernel.Builder + .WithLogger(ConsoleLogger.Logger) + .WithOpenAITextCompletionService( + modelId: TestConfiguration.OpenAI.ModelId, + apiKey: TestConfiguration.OpenAI.ApiKey) + .WithOpenAITextEmbeddingGenerationService( + modelId: TestConfiguration.OpenAI.EmbeddingModelId, + apiKey: TestConfiguration.OpenAI.ApiKey) + .WithMemoryStorage(memoryStore) + .Build(); + + Console.WriteLine("== Printing Collections in DB =="); + var collections = memoryStore.GetCollectionsAsync(); + await 
foreach (var collection in collections) + { + Console.WriteLine(collection); + } + + Console.WriteLine("== Adding Memories =="); + + var key1 = await kernel.Memory.SaveInformationAsync(MemoryCollectionName, id: "cat1", text: "british short hair"); + var key2 = await kernel.Memory.SaveInformationAsync(MemoryCollectionName, id: "cat2", text: "orange tabby"); + var key3 = await kernel.Memory.SaveInformationAsync(MemoryCollectionName, id: "cat3", text: "norwegian forest cat"); + + Console.WriteLine("== Printing Collections in DB =="); + collections = memoryStore.GetCollectionsAsync(); + await foreach (var collection in collections) + { + Console.WriteLine(collection); + } + + Console.WriteLine("== Retrieving Memories Through the Kernel =="); + MemoryQueryResult? lookup = await kernel.Memory.GetAsync(MemoryCollectionName, "cat1"); + Console.WriteLine(lookup != null ? lookup.Metadata.Text : "ERROR: memory not found"); + + Console.WriteLine("== Retrieving Memories Directly From the Store =="); + var memory1 = await memoryStore.GetAsync(MemoryCollectionName, key1); + var memory2 = await memoryStore.GetAsync(MemoryCollectionName, key2); + var memory3 = await memoryStore.GetAsync(MemoryCollectionName, key3); + Console.WriteLine(memory1 != null ? memory1.Metadata.Text : "ERROR: memory not found"); + Console.WriteLine(memory2 != null ? memory2.Metadata.Text : "ERROR: memory not found"); + Console.WriteLine(memory3 != null ? 
memory3.Metadata.Text : "ERROR: memory not found"); + + Console.WriteLine("== Similarity Searching Memories: My favorite color is orange =="); + var searchResults = kernel.Memory.SearchAsync(MemoryCollectionName, "My favorite color is orange", limit: 3, minRelevanceScore: 0.8); + + await foreach (var item in searchResults) + { + Console.WriteLine(item.Metadata.Text + " : " + item.Relevance); + } + + Console.WriteLine("== Removing Collection {0} ==", MemoryCollectionName); + await memoryStore.DeleteCollectionAsync(MemoryCollectionName); + + Console.WriteLine("== Printing Collections in DB =="); + await foreach (var collection in collections) + { + Console.WriteLine(collection); + } + } +} diff --git a/dotnet/samples/KernelSyntaxExamples/KernelSyntaxExamples.csproj b/dotnet/samples/KernelSyntaxExamples/KernelSyntaxExamples.csproj index 244cadfa2ea4..b440c11e94f5 100644 --- a/dotnet/samples/KernelSyntaxExamples/KernelSyntaxExamples.csproj +++ b/dotnet/samples/KernelSyntaxExamples/KernelSyntaxExamples.csproj @@ -33,6 +33,7 @@ + diff --git a/dotnet/samples/KernelSyntaxExamples/Program.cs b/dotnet/samples/KernelSyntaxExamples/Program.cs index 4e2ea73ef816..b0db3feae29f 100644 --- a/dotnet/samples/KernelSyntaxExamples/Program.cs +++ b/dotnet/samples/KernelSyntaxExamples/Program.cs @@ -72,6 +72,7 @@ public static async Task Main() await Example50_Chroma.RunAsync().SafeWaitAsync(cancelToken); await Example51_StepwisePlanner.RunAsync().SafeWaitAsync(cancelToken); await Example52_ApimAuth.RunAsync().SafeWaitAsync(cancelToken); + await Example53_Kusto.RunAsync().SafeWaitAsync(cancelToken); } private static void LoadUserSecrets() diff --git a/dotnet/samples/KernelSyntaxExamples/README.md b/dotnet/samples/KernelSyntaxExamples/README.md index 81562e86e582..d96fefbb6262 100644 --- a/dotnet/samples/KernelSyntaxExamples/README.md +++ b/dotnet/samples/KernelSyntaxExamples/README.md @@ -69,6 +69,7 @@ dotnet user-secrets set "Apim:SubscriptionKey" "..." 
dotnet user-secrets set "Postgres:ConnectionString" "..." dotnet user-secrets set "Redis:Configuration" "..." +dotnet user-secrets set "Kusto:ConnectionString" "..." ``` To set your secrets with environment variables, use these names: diff --git a/dotnet/samples/KernelSyntaxExamples/TestConfiguration.cs b/dotnet/samples/KernelSyntaxExamples/TestConfiguration.cs index 1c2ff6f60078..7b41017cafec 100644 --- a/dotnet/samples/KernelSyntaxExamples/TestConfiguration.cs +++ b/dotnet/samples/KernelSyntaxExamples/TestConfiguration.cs @@ -36,6 +36,7 @@ public static void Initialize(IConfigurationRoot configRoot) public static RedisConfig Redis => LoadSection(); public static JiraConfig Jira => LoadSection(); public static ChromaConfig Chroma => LoadSection(); + public static KustoConfig Kusto => LoadSection(); private static T LoadSection([CallerMemberName] string? caller = null) { @@ -154,5 +155,10 @@ public class ChromaConfig { public string Endpoint { get; set; } } + + public class KustoConfig + { + public string ConnectionString { get; set; } + } #pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. 
} diff --git a/dotnet/src/Connectors/Connectors.Memory.Kusto/Connectors.Memory.Kusto.csproj b/dotnet/src/Connectors/Connectors.Memory.Kusto/Connectors.Memory.Kusto.csproj new file mode 100644 index 000000000000..d0b94bd567ae --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Kusto/Connectors.Memory.Kusto.csproj @@ -0,0 +1,31 @@ + + + + Microsoft.SemanticKernel.Connectors.Memory.Kusto + Microsoft.SemanticKernel.Connectors.Memory.Kusto + netstandard2.0 + + + NU5104 + + + + + + + + + Microsoft.SemanticKernel.Connectors.Memory.Kusto + Semantic Kernel - Azure Data Explorer (Kusto) Semantic Memory + Azure Data Explorer (Kusto) Semantic Memory connector for Semantic Kernel + + + + + + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Kusto/KustoMemoryRecord.cs b/dotnet/src/Connectors/Connectors.Memory.Kusto/KustoMemoryRecord.cs new file mode 100644 index 000000000000..74fb8e9897b2 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Kusto/KustoMemoryRecord.cs @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Kusto.Cloud.Platform.Utils; +using Microsoft.SemanticKernel.AI.Embeddings; +using Microsoft.SemanticKernel.Memory; + +namespace Microsoft.SemanticKernel.Connectors.Memory.Kusto; + +/// +/// Kusto memory record entity. +/// +public sealed class KustoMemoryRecord +{ + /// + /// Entity key. + /// + public string Key { get; set; } + + /// + /// Metadata associated with memory entity. + /// + public MemoryRecordMetadata Metadata { get; set; } + + /// + /// Source content embedding. + /// + public Embedding Embedding { get; set; } + + /// + /// Optional timestamp. + /// + public DateTimeOffset? Timestamp { get; set; } + + /// + /// Initializes a new instance of the class. + /// + /// Instance of . + public KustoMemoryRecord(MemoryRecord record) : this(record.Key, record.Metadata, record.Embedding, record.Timestamp) { } + + /// + /// Initializes a new instance of the class. + /// + /// Entity key. 
+ /// Metadata associated with memory entity. + /// Source content embedding. + /// Optional timestamp. + public KustoMemoryRecord(string key, MemoryRecordMetadata metadata, Embedding embedding, DateTimeOffset? timestamp = null) + { + this.Key = key; + this.Metadata = metadata; + this.Embedding = embedding; + this.Timestamp = timestamp; + } + + /// + /// Initializes a new instance of the class. + /// + /// Entity key. + /// Serialized metadata associated with memory entity. + /// Source content embedding. + /// Optional timestamp. + public KustoMemoryRecord(string key, string metadata, string? embedding, string? timestamp = null) + { + this.Key = key; + this.Metadata = KustoSerializer.DeserializeMetadata(metadata); + this.Embedding = KustoSerializer.DeserializeEmbedding(embedding); + this.Timestamp = KustoSerializer.DeserializeDateTimeOffset(timestamp); + } + + /// + /// Returns instance of mapped . + /// + public MemoryRecord ToMemoryRecord() + { + return new MemoryRecord(this.Metadata, this.Embedding, this.Key, this.Timestamp); + } + + /// + /// Writes properties of instance to stream using . + /// + /// Instance of to write properties to stream. 
+ public void WriteToCsvStream(CsvWriter streamWriter) + { + var jsonifiedMetadata = KustoSerializer.SerializeMetadata(this.Metadata); + var jsonifiedEmbedding = KustoSerializer.SerializeEmbedding(this.Embedding); + var isoFormattedDate = KustoSerializer.SerializeDateTimeOffset(this.Timestamp); + + streamWriter.WriteField(this.Key); + streamWriter.WriteField(jsonifiedMetadata); + streamWriter.WriteField(jsonifiedEmbedding); + streamWriter.WriteField(isoFormattedDate); + streamWriter.CompleteRecord(); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Kusto/KustoMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.Kusto/KustoMemoryStore.cs new file mode 100644 index 000000000000..b5d9e7a95a8f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Kusto/KustoMemoryStore.cs @@ -0,0 +1,413 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Kusto.Cloud.Platform.Utils; +using Kusto.Data; +using Kusto.Data.Common; +using Kusto.Data.Net.Client; +using Microsoft.SemanticKernel.AI.Embeddings; +using Microsoft.SemanticKernel.Diagnostics; +using Microsoft.SemanticKernel.Memory; + +namespace Microsoft.SemanticKernel.Connectors.Memory.Kusto; + +/// +/// An implementation of backed by a Kusto database. +/// +/// +/// The embedded data is saved to the Kusto database specified in the constructor. +/// Similarity search capability is provided through a cosine similarity function (added on first search operation). Use Kusto's "Table" to implement "Collection". +/// +public class KustoMemoryStore : IMemoryStore, IDisposable +{ + /// + /// Initializes a new instance of the class. + /// + /// Kusto Admin Client. + /// Kusto Query Client. + /// The database used for the tables. 
+ public KustoMemoryStore(ICslAdminProvider cslAdminProvider, ICslQueryProvider cslQueryProvider, string database) + { + this._database = database; + this._queryClient = cslQueryProvider; + this._adminClient = cslAdminProvider; + + this._searchInitialized = false; + this._disposer = new Disposer(nameof(KustoMemoryStore), nameof(KustoMemoryStore)); + } + + /// + /// Initializes a new instance of the class. + /// + /// Kusto Connection String Builder. + /// The database used for the tables. + public KustoMemoryStore(KustoConnectionStringBuilder builder, string database) + : this(KustoClientFactory.CreateCslAdminProvider(builder), KustoClientFactory.CreateCslQueryProvider(builder), database) + { + // Dispose resources provided by this class + this._disposer.Add(this._queryClient); + this._disposer.Add(this._adminClient); + } + + /// + public async Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default) + { + using var resp = await this._adminClient + .ExecuteControlCommandAsync( + this._database, + CslCommandGenerator.GenerateTableCreateCommand(new TableSchema(GetTableName(collectionName, normalized: false), s_collectionColumns)), + GetClientRequestProperties() + ).ConfigureAwait(false); + } + + /// + public async Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default) + { + using var resp = await this._adminClient + .ExecuteControlCommandAsync( + this._database, + CslCommandGenerator.GenerateTableDropCommand(GetTableName(collectionName, normalized: false)), + GetClientRequestProperties() + ).ConfigureAwait(false); + } + + /// + public async Task DoesCollectionExistAsync(string collectionName, CancellationToken cancellationToken = default) + { + var command = CslCommandGenerator.GenerateTablesShowCommand() + $" | where TableName == '{GetTableName(collectionName, normalized: false)}' | project TableName"; + var result = await this._adminClient + .ExecuteControlCommandAsync( + 
this._database, + command, + GetClientRequestProperties() + ).ConfigureAwait(false); + + return result.Count() == 1; + } + + /// + public async Task GetAsync(string collectionName, string key, bool withEmbedding = false, CancellationToken cancellationToken = default) + { + var result = this.GetBatchAsync(collectionName, new[] { key }, withEmbedding, cancellationToken); + return await result.FirstOrDefaultAsync(cancellationToken).ConfigureAwait(false); + } + + /// + public async IAsyncEnumerable GetBatchAsync( + string collectionName, + IEnumerable keys, + bool withEmbeddings = false, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var inClauseValue = string.Join(",", keys.Select(k => $"'{k}'")); + var query = $"{this.GetBaseQuery(collectionName)} " + + $"| where Key in ({inClauseValue}) " + + "| project " + + $"{KeyColumn.Name}, " + + $"{MetadataColumn.Name}=tostring({MetadataColumn.Name}), " + + $"{TimestampColumn.Name}, " + + $"{EmbeddingColumn.Name}=tostring({EmbeddingColumn.Name})"; + + if (!withEmbeddings) + { + // easiest way to ignore embeddings + query += " | extend Embedding = ''"; + } + + using var reader = await this._queryClient + .ExecuteQueryAsync( + this._database, + query, + GetClientRequestProperties(), + cancellationToken + ).ConfigureAwait(false); + + while (reader.Read()) + { + var key = reader.GetString(0); + var metadata = reader.GetString(1); + var timestamp = !reader.IsDBNull(2) ? reader.GetString(2) : null; + var embedding = withEmbeddings ? 
reader.GetString(3) : default; + + var kustoRecord = new KustoMemoryRecord(key, metadata, embedding, timestamp); + + yield return kustoRecord.ToMemoryRecord(); + } + } + + /// + public async IAsyncEnumerable GetCollectionsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var result = await this._adminClient + .ExecuteControlCommandAsync( + this._database, + CslCommandGenerator.GenerateTablesShowCommand(), + GetClientRequestProperties() + ).ConfigureAwait(false); + + foreach (var item in result) + { + yield return GetCollectionName(item.TableName); + } + } + + /// + public async Task<(MemoryRecord, double)?> GetNearestMatchAsync( + string collectionName, + Embedding embedding, + double minRelevanceScore = 0, + bool withEmbedding = false, + CancellationToken cancellationToken = default) + { + var result = this.GetNearestMatchesAsync(collectionName, embedding, 1, minRelevanceScore, withEmbedding, cancellationToken); + return await result.FirstOrDefaultAsync(cancellationToken).ConfigureAwait(false); + } + + /// + public async IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync( + string collectionName, + Embedding embedding, + int limit, + double minRelevanceScore = 0, + bool withEmbeddings = false, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + this.InitializeVectorFunctions(); + + var similarityQuery = $"{this.GetBaseQuery(collectionName)} | extend similarity=series_cosine_similarity_fl('{KustoSerializer.SerializeEmbedding(embedding)}', {EmbeddingColumn.Name}, 1, 1)"; + + if (minRelevanceScore != 0) + { + similarityQuery += $" | where similarity > {minRelevanceScore}"; + } + + similarityQuery += $" | top {limit} by similarity desc"; + + // reorder to make it easier to ignore the embedding (key, metadata, timestamp, similarity, embedding) + // Using tostring to make it easier to parse the result. There are probably better ways we should explore. 
+ similarityQuery += "| project " + + $"{KeyColumn.Name}, " + + $"{MetadataColumn.Name}=tostring({MetadataColumn.Name}), " + + $"{TimestampColumn.Name}, " + + "similarity, " + + $"{EmbeddingColumn.Name}=tostring({EmbeddingColumn.Name})"; + + if (!withEmbeddings) + { + similarityQuery += $" | project-away {EmbeddingColumn.Name} "; + } + + using var reader = await this._queryClient + .ExecuteQueryAsync( + this._database, + similarityQuery, + GetClientRequestProperties(), + cancellationToken + ).ConfigureAwait(false); + + while (reader.Read()) + { + var key = reader.GetString(0); + var metadata = reader.GetString(1); + var timestamp = !reader.IsDBNull(2) ? reader.GetString(2) : null; + var similarity = reader.GetDouble(3); + var recordEmbedding = withEmbeddings ? reader.GetString(4) : default; + + var kustoRecord = new KustoMemoryRecord(key, metadata, recordEmbedding, timestamp); + + yield return (kustoRecord.ToMemoryRecord(), similarity); + } + } + + /// + public Task RemoveAsync(string collectionName, string key, CancellationToken cancellationToken = default) + => this.RemoveBatchAsync(collectionName, new[] { key }, cancellationToken); + + /// + public async Task RemoveBatchAsync(string collectionName, IEnumerable keys, CancellationToken cancellationToken = default) + { + if (keys != null) + { + var keysString = string.Join(",", keys.Select(k => $"'{k}'")); + using var resp = await this._adminClient + .ExecuteControlCommandAsync( + this._database, + CslCommandGenerator.GenerateDeleteTableRecordsCommand(GetTableName(collectionName), $"{GetTableName(collectionName)} | where Key in ({keysString})"), + GetClientRequestProperties() + ).ConfigureAwait(false); + } + } + + /// + public async Task UpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancellationToken = default) + { + var result = this.UpsertBatchAsync(collectionName, new[] { record }, cancellationToken); + return await result.FirstOrDefaultAsync(cancellationToken).ConfigureAwait(false) 
?? string.Empty; + } + + /// + public async IAsyncEnumerable UpsertBatchAsync( + string collectionName, + IEnumerable records, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + // In Kusto, upserts don't exist because it operates as an append-only store. + // Nevertheless, given that we have a straightforward primary key (PK), we can simply insert a new record. + // Our query always selects the latest row of that PK. + // An interesting scenario arises when performing deletion after many "upserts". + // This could turn out to be a heavy operation since, in theory, we might need to remove many outdated versions. + // Additionally, deleting these records equates to a "soft delete" operation. + // For simplicity, and under the assumption that upserts are relatively rare in most systems, + // we will overlook the potential accumulation of "garbage" records. + // Kusto is generally efficient with handling large volumes of data. + using var stream = new MemoryStream(); + using var streamWriter = new StreamWriter(stream); + var csvWriter = new FastCsvWriter(streamWriter); + + var keys = new List(); + var recordsAsList = records.ToList(); + + for (var i = 0; i < recordsAsList.Count; i++) + { + var record = recordsAsList[i]; + record.Key = record.Metadata.Id; + keys.Add(record.Key); + new KustoMemoryRecord(record).WriteToCsvStream(csvWriter); + } + + csvWriter.Flush(); + stream.Seek(0, SeekOrigin.Begin); + + var command = CslCommandGenerator.GenerateTableIngestPushCommand(GetTableName(collectionName), false, stream); + await this._adminClient + .ExecuteControlCommandAsync( + this._database, + command, + GetClientRequestProperties() + ).ConfigureAwait(false); + + foreach (var key in keys) + { + yield return key; + } + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + this._disposer.Dispose(); + } + } + + #region private 
================================================================================ + + private Disposer _disposer; + private object _lock = new(); + + private string _database; + + private static ClientRequestProperties GetClientRequestProperties() => new() + { + Application = Telemetry.HttpUserAgent, + }; + + private bool _searchInitialized; + + private readonly ICslQueryProvider _queryClient; + private readonly ICslAdminProvider _adminClient; + + private static ColumnSchema KeyColumn = new("Key", typeof(string).FullName); + private static ColumnSchema MetadataColumn = new("Metadata", typeof(object).FullName); + private static ColumnSchema EmbeddingColumn = new("Embedding", typeof(object).FullName); + private static ColumnSchema TimestampColumn = new("Timestamp", typeof(DateTime).FullName); + + private static readonly ColumnSchema[] s_collectionColumns = new ColumnSchema[] + { + KeyColumn, + MetadataColumn, + EmbeddingColumn, + TimestampColumn + }; + + /// + /// Converts collection name to Kusto table name. + /// + /// + /// Kusto escaping rules for table names: https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/schema-entities/entity-names#identifier-quoting + /// + /// Kusto table name. + /// Boolean flag that indicates if table name normalization is needed. + private static string GetTableName(string collectionName, bool normalized = true) + => normalized ? CslSyntaxGenerator.NormalizeTableName(collectionName) : collectionName; + + /// + /// Converts Kusto table name to collection name. + /// + /// + /// Kusto escaping rules for table names: https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/schema-entities/entity-names#identifier-quoting + /// + /// Kusto table name. + private static string GetCollectionName(string tableName) + => tableName.Replace("['", "").Replace("']", ""); + + /// + /// Returns base Kusto query. + /// + /// + /// Kusto is an append-only store. 
Although deletions are possible, they are highly discouraged, + /// and should only be used in rare cases (see: https://learn.microsoft.com/en-us/azure/data-explorer/kusto/concepts/data-soft-delete#use-cases). + /// As such, the recommended approach for dealing with row updates is versioning. + /// An easy way to achieve this is by using the ingestion time of the record (insertion time). + /// + /// Collection name. + private string GetBaseQuery(string collection) + => $"{GetTableName(collection)} | summarize arg_max(ingestion_time(), *) by {KeyColumn.Name} "; + + /// + /// Initializes vector cosine similarity function for given database. + /// + /// + /// Cosine similarity function is created only once for better performance. + /// It's possible to run function creation multiple times, since .create-or-alter command is idempotent. + /// + private void InitializeVectorFunctions() + { + if (!this._searchInitialized) + { + lock (this._lock) + { + if (!this._searchInitialized) + { + var resp = this._adminClient + .ExecuteControlCommand( + this._database, + ".create-or-alter function with (docstring = 'Calculate the Cosine similarity of 2 numerical arrays',folder = 'Vector') series_cosine_similarity_fl(vec1:dynamic,vec2:dynamic,vec1_size:real=real(null),vec2_size:real=real(null)) {" + + " let dp = series_dot_product(vec1, vec2);" + + " let v1l = iff(isnull(vec1_size), sqrt(series_dot_product(vec1, vec1)), vec1_size);" + + " let v2l = iff(isnull(vec2_size), sqrt(series_dot_product(vec2, vec2)), vec2_size);" + + " dp/(v1l*v2l)" + + "}", + GetClientRequestProperties() + ); + + this._searchInitialized = true; + } + } + } + } + + #endregion private ================================================================================ +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Kusto/KustoSerializer.cs b/dotnet/src/Connectors/Connectors.Memory.Kusto/KustoSerializer.cs new file mode 100644 index 000000000000..d4d5e3d73883 --- /dev/null +++ 
b/dotnet/src/Connectors/Connectors.Memory.Kusto/KustoSerializer.cs @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Globalization; +using System.Text.Json; +using Microsoft.SemanticKernel.AI.Embeddings; +using Microsoft.SemanticKernel.Memory; + +namespace Microsoft.SemanticKernel.Connectors.Memory.Kusto; + +/// +/// Contains serialization/deserialization logic for memory record properties in Kusto. +/// +public static class KustoSerializer +{ + /// + /// Returns serialized string from instance. + /// + /// Instance of for serialization. + public static string SerializeEmbedding(Embedding embedding) + { + return JsonSerializer.Serialize(embedding.Vector); + } + + /// + /// Returns deserialized instance of from serialized embedding. + /// + /// Serialized embedding. + public static Embedding DeserializeEmbedding(string? embedding) + { + if (string.IsNullOrEmpty(embedding)) + { + return default; + } + + float[]? floatArray = JsonSerializer.Deserialize(embedding!); + + if (floatArray == null) + { + return default; + } + + return new Embedding(floatArray); + } + + /// + /// Returns serialized string from instance. + /// + /// Instance of for serialization. + public static string SerializeMetadata(MemoryRecordMetadata metadata) + { + if (metadata == null) + { + return string.Empty; + } + + return JsonSerializer.Serialize(metadata); + } + + /// + /// Returns deserialized instance of from serialized metadata. + /// + /// Serialized metadata. + public static MemoryRecordMetadata DeserializeMetadata(string metadata) + { + return JsonSerializer.Deserialize(metadata)!; + } + + /// + /// Returns serialized string from the DateTimeOffset instance, converted to UTC because the timestamp format ends with a literal 'Z'. + /// + /// Instance of DateTimeOffset for serialization; null yields an empty string. + public static string SerializeDateTimeOffset(DateTimeOffset? 
dateTimeOffset) + { + if (dateTimeOffset == null) + { + return string.Empty; + } + + return dateTimeOffset.Value.UtcDateTime.ToString(TimestampFormat, CultureInfo.InvariantCulture); + } + + /// + /// Returns deserialized instance of from serialized timestamp. + /// + /// Serialized timestamp. + public static DateTimeOffset? DeserializeDateTimeOffset(string? dateTimeOffset) + { + if (string.IsNullOrWhiteSpace(dateTimeOffset)) + { + return null; + } + + if (DateTimeOffset.TryParseExact(dateTimeOffset, TimestampFormat, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset result)) + { + return result; + } + + throw new InvalidCastException("Timestamp format cannot be parsed"); + } + + #region private ================================================================================ + + private const string TimestampFormat = "yyyy-MM-ddTHH:mm:ssZ"; + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Kusto/README.md b/dotnet/src/Connectors/Connectors.Memory.Kusto/README.md new file mode 100644 index 000000000000..6b33ab53ec81 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Kusto/README.md @@ -0,0 +1,42 @@ +# Microsoft.SemanticKernel.Connectors.Memory.Kusto + +This connector uses [Azure Data Explorer (Kusto)](https://learn.microsoft.com/en-us/azure/data-explorer/) to implement Semantic Memory. + +## Quick Start + +1. Create a cluster and database in Azure Data Explorer (Kusto) - see https://learn.microsoft.com/en-us/azure/data-explorer/create-cluster-and-database?tabs=free + +2. 
To use Kusto as a semantic memory store, use the following code:
+
+```csharp
+using Kusto.Data;
+
+var connectionString = new KustoConnectionStringBuilder("https://kvc123.eastus.kusto.windows.net").WithAadUserPromptAuthentication();
+KustoMemoryStore memoryStore = new(connectionString, "MyDatabase");
+
+IKernel kernel = Kernel.Builder
+    .WithLogger(ConsoleLogger.Log)
+    .WithOpenAITextCompletionService(modelId: TestConfiguration.OpenAI.ModelId, apiKey: TestConfiguration.OpenAI.ApiKey)
+    .WithOpenAITextEmbeddingGenerationService(modelId: TestConfiguration.OpenAI.EmbeddingModelId, apiKey: TestConfiguration.OpenAI.ApiKey)
+    .WithMemoryStorage(memoryStore)
+    .Build();
+```
+
+## Important Notes
+
+### Cosine Similarity
+As of now, cosine similarity is not built into Kusto.
+A function to calculate cosine similarity is automatically added to the Kusto database during the first search operation.
+This function (`series_cosine_similarity_fl`) is not removed automatically.
+You might want to delete it manually if you stop using the Kusto database as a semantic memory store.
+If you want to delete the function, you can do it manually using the Kusto explorer.
+The function is called `series_cosine_similarity_fl` and is located in the `Functions` folder of the database.
+
+### Append-Only Store
+Kusto is an append-only store. This means that when a fact is updated, the old fact is not deleted.
+This isn't a problem for the semantic memory connector, as it always utilizes the most recent fact.
+This is made possible by using the [arg_max](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/arg-max-aggfunction) aggregation function in conjunction with the [ingestion_time](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ingestiontimefunction) function.
+However, users manually querying the underlying table should be aware of this behavior.
+
+### Authentication
+Please note that the authentication used in the example above is not recommended for production use. You can find more details here: https://learn.microsoft.com/en-us/azure/data-explorer/kusto/api/connection-strings/kusto
diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj b/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj
index 62baaacb23ed..0c4ae44ca110 100644
--- a/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj
+++ b/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj
@@ -35,6 +35,7 @@
+
diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Kusto/KustoMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Kusto/KustoMemoryStoreTests.cs
new file mode 100644
index 000000000000..2a3d6867eba1
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Kusto/KustoMemoryStoreTests.cs
@@ -0,0 +1,406 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Collections.Generic;
+using System.Data;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Kusto.Cloud.Platform.Utils;
+using Kusto.Data.Common;
+using Microsoft.SemanticKernel.AI.Embeddings;
+using Microsoft.SemanticKernel.Connectors.Memory.Kusto;
+using Microsoft.SemanticKernel.Diagnostics;
+using Microsoft.SemanticKernel.Memory;
+using Moq;
+using Xunit;
+
+namespace SemanticKernel.Connectors.UnitTests.Memory.Kusto;
+
+/// <summary>
+/// Unit tests for <see cref="KustoMemoryStore"/> class.
+/// </summary>
+public class KustoMemoryStoreTests
+{
+    private const string CollectionName = "fake_collection";
+    private const string DatabaseName = "FakeDb";
+    private readonly Mock<ICslQueryProvider> _cslQueryProviderMock;
+    private readonly Mock<ICslAdminProvider> _cslAdminProviderMock;
+
+    /// <summary>
+    /// Creates provider mocks that answer every command and query against the fake database with an empty result.
+    /// </summary>
+    public KustoMemoryStoreTests()
+    {
+        this._cslQueryProviderMock = new Mock<ICslQueryProvider>();
+        this._cslAdminProviderMock = new Mock<ICslAdminProvider>();
+
+        this._cslAdminProviderMock
+            .Setup(client => client.ExecuteControlCommandAsync(
+                DatabaseName,
+                It.IsAny<string>(),
+                It.IsAny<ClientRequestProperties>()))
+            .ReturnsAsync(FakeEmptyResult());
+
+        this._cslAdminProviderMock
+            .Setup(client => client.ExecuteControlCommand(
+                DatabaseName,
+                It.IsAny<string>(),
+                It.IsAny<ClientRequestProperties>()))
+            .Returns(FakeEmptyResult());
+
+        this._cslQueryProviderMock
+            .Setup(client => client.ExecuteQueryAsync(
+                DatabaseName,
+                It.IsAny<string>(),
+                It.IsAny<ClientRequestProperties>(),
+                It.IsAny<CancellationToken>()))
+            .ReturnsAsync(FakeEmptyResult());
+    }
+
+    [Fact]
+    public async Task ItCanCreateCollectionAsync()
+    {
+        // Arrange
+        using var store = new KustoMemoryStore(this._cslAdminProviderMock.Object, this._cslQueryProviderMock.Object, DatabaseName);
+
+        // Act
+        await store.CreateCollectionAsync(CollectionName);
+
+        // Assert
+        this._cslAdminProviderMock
+            .Verify(client => client.ExecuteControlCommandAsync(
+                DatabaseName,
+                It.Is<string>(s => s.StartsWith($".create table {CollectionName}", StringComparison.Ordinal)),
+                It.Is<ClientRequestProperties>(crp => string.Equals(crp.Application, Telemetry.HttpUserAgent, StringComparison.Ordinal))
+            ), Times.Once());
+    }
+
+    [Fact]
+    public async Task ItCanDeleteCollectionAsync()
+    {
+        // Arrange
+        using var store = new KustoMemoryStore(this._cslAdminProviderMock.Object, this._cslQueryProviderMock.Object, DatabaseName);
+
+        // Act
+        await store.DeleteCollectionAsync(CollectionName);
+
+        // Assert
+        this._cslAdminProviderMock
+            .Verify(client => client.ExecuteControlCommandAsync(
+                DatabaseName,
+                It.Is<string>(s => s.StartsWith($".drop table {CollectionName}", StringComparison.Ordinal)),
+                It.Is<ClientRequestProperties>(crp => string.Equals(crp.Application, Telemetry.HttpUserAgent, StringComparison.Ordinal))
+            ), Times.Once());
+    }
+
+    [Fact]
+    public async Task ItReturnsTrueWhenCollectionExistsAsync()
+    {
+        // Arrange
+        using var store = new KustoMemoryStore(this._cslAdminProviderMock.Object, this._cslQueryProviderMock.Object, DatabaseName);
+
+        this._cslAdminProviderMock
+            .Setup(client => client.ExecuteControlCommandAsync(
+                DatabaseName,
+                It.Is<string>(s => s.StartsWith(CslCommandGenerator.GenerateTablesShowCommand(), StringComparison.Ordinal)),
+                It.IsAny<ClientRequestProperties>()))
+            .ReturnsAsync(CollectionToSingleColumnDataReader(new[] { CollectionName }));
+
+        // Act
+        var doesCollectionExist = await store.DoesCollectionExistAsync(CollectionName);
+
+        // Assert
+        Assert.True(doesCollectionExist);
+    }
+
+    [Fact]
+    public async Task ItReturnsFalseWhenCollectionDoesNotExistAsync()
+    {
+        // Arrange
+        using var store = new KustoMemoryStore(this._cslAdminProviderMock.Object, this._cslQueryProviderMock.Object, DatabaseName);
+
+        this._cslAdminProviderMock
+            .Setup(client => client.ExecuteControlCommandAsync(
+                DatabaseName,
+                It.Is<string>(s => s.StartsWith(CslCommandGenerator.GenerateTablesShowCommand(), StringComparison.Ordinal)),
+                It.IsAny<ClientRequestProperties>()))
+            .ReturnsAsync(FakeEmptyResult());
+
+        // Act
+        var doesCollectionExist = await store.DoesCollectionExistAsync(CollectionName);
+
+        // Assert
+        Assert.False(doesCollectionExist);
+    }
+
+    [Fact]
+    public async Task ItCanUpsertAsync()
+    {
+        // Arrange
+        var expectedMemoryRecord = this.GetRandomMemoryRecord();
+
+        using var store = new KustoMemoryStore(this._cslAdminProviderMock.Object, this._cslQueryProviderMock.Object, DatabaseName);
+
+        // Act
+        var actualMemoryRecordKey = await store.UpsertAsync(CollectionName, expectedMemoryRecord);
+
+        // Assert
+        this._cslAdminProviderMock.Verify(client => client.ExecuteControlCommandAsync(
+            DatabaseName,
+            It.Is<string>(s => s.StartsWith($".ingest inline into table {CollectionName}", StringComparison.Ordinal) && s.Contains(actualMemoryRecordKey, StringComparison.Ordinal)),
+            It.IsAny<ClientRequestProperties>()), Times.Once());
+        Assert.Equal(expectedMemoryRecord.Key, actualMemoryRecordKey);
+    }
+
+    [Fact]
+    public async Task ItCanUpsertBatchAsyncAsync()
+    {
+        // Arrange
+        var memoryRecord1 = this.GetRandomMemoryRecord();
+        var memoryRecord2 = this.GetRandomMemoryRecord();
+        var memoryRecord3 = this.GetRandomMemoryRecord();
+
+        var batchUpsertMemoryRecords = new[] { memoryRecord1, memoryRecord2, memoryRecord3 };
+        var expectedMemoryRecordKeys = batchUpsertMemoryRecords.Select(l => l.Key).ToList();
+
+        using var store = new KustoMemoryStore(this._cslAdminProviderMock.Object, this._cslQueryProviderMock.Object, DatabaseName);
+
+        // Act
+        var actualMemoryRecordKeys = await store.UpsertBatchAsync(CollectionName, batchUpsertMemoryRecords).ToListAsync();
+
+        // Assert
+        this._cslAdminProviderMock
+            .Verify(client => client.ExecuteControlCommandAsync(
+                DatabaseName,
+                It.Is<string>(s =>
+                    s.StartsWith($".ingest inline into table {CollectionName}", StringComparison.Ordinal) &&
+                    batchUpsertMemoryRecords.All(r => s.Contains(r.Key, StringComparison.Ordinal))),
+                It.IsAny<ClientRequestProperties>()
+            ), Times.Once());
+
+        Assert.Equal(expectedMemoryRecordKeys, actualMemoryRecordKeys);
+    }
+
+    [Fact]
+    public async Task ItCanGetMemoryRecordFromCollectionAsync()
+    {
+        // Arrange
+        var expectedMemoryRecord = this.GetRandomMemoryRecord();
+
+        this._cslQueryProviderMock
+            .Setup(client => client.ExecuteQueryAsync(
+                DatabaseName,
+                It.Is<string>(s => s.Contains(CollectionName, StringComparison.Ordinal) && s.Contains(expectedMemoryRecord.Key, StringComparison.Ordinal)),
+                It.IsAny<ClientRequestProperties>(),
+                CancellationToken.None))
+            .ReturnsAsync(CollectionToDataReader(new string[][] {
+                new string[] {
+                    expectedMemoryRecord.Key,
+                    KustoSerializer.SerializeMetadata(expectedMemoryRecord.Metadata),
+                    KustoSerializer.SerializeDateTimeOffset(expectedMemoryRecord.Timestamp),
+                    KustoSerializer.SerializeEmbedding(expectedMemoryRecord.Embedding),
+                }}));
+
+        using var store = new KustoMemoryStore(this._cslAdminProviderMock.Object, this._cslQueryProviderMock.Object, DatabaseName);
+
+        // Act
+        var actualMemoryRecord = await store.GetAsync(CollectionName, expectedMemoryRecord.Key, withEmbedding: true);
+
+        // Assert
+        Assert.NotNull(actualMemoryRecord);
+        this.AssertMemoryRecordEqual(expectedMemoryRecord, actualMemoryRecord);
+    }
+
+    [Fact]
+    public async Task ItReturnsNullWhenMemoryRecordDoesNotExistAsync()
+    {
+        // Arrange
+        const string memoryRecordKey = "fake-record-key";
+
+        using var store = new KustoMemoryStore(this._cslAdminProviderMock.Object, this._cslQueryProviderMock.Object, DatabaseName);
+
+        // Act
+        var actualMemoryRecord = await store.GetAsync(CollectionName, memoryRecordKey, withEmbedding: true);
+
+        // Assert
+        Assert.Null(actualMemoryRecord);
+    }
+
+    [Fact]
+    public async Task ItCanGetMemoryRecordBatchFromCollectionAsync()
+    {
+        // Arrange
+        var memoryRecord1 = this.GetRandomMemoryRecord();
+        var memoryRecord2 = this.GetRandomMemoryRecord();
+        var memoryRecord3 = this.GetRandomMemoryRecord();
+
+        var batchUpsertMemoryRecords = new[] { memoryRecord1, memoryRecord2, memoryRecord3 };
+        var expectedMemoryRecordKeys = batchUpsertMemoryRecords.Select(l => l.Key).ToList();
+
+        using var store = new KustoMemoryStore(this._cslAdminProviderMock.Object, this._cslQueryProviderMock.Object, DatabaseName);
+
+        this._cslQueryProviderMock
+            .Setup(client => client.ExecuteQueryAsync(
+                DatabaseName,
+                It.Is<string>(s =>
+                    s.Contains(CollectionName, StringComparison.Ordinal) &&
+                    batchUpsertMemoryRecords.All(r => s.Contains(r.Key, StringComparison.Ordinal))),
+                It.IsAny<ClientRequestProperties>(),
+                CancellationToken.None))
+            .ReturnsAsync(CollectionToDataReader(batchUpsertMemoryRecords.Select(r => new string[] {
+                r.Key,
+                KustoSerializer.SerializeMetadata(r.Metadata),
+                KustoSerializer.SerializeDateTimeOffset(r.Timestamp),
+                KustoSerializer.SerializeEmbedding(r.Embedding),
+            }).ToArray()));
+
+        // Act
+        var actualMemoryRecords = await store.GetBatchAsync(CollectionName, expectedMemoryRecordKeys, withEmbeddings: true).ToListAsync();
+
+        // Assert
+        // Compare counts first so the per-record loop cannot pass vacuously on an empty result.
+        Assert.Equal(batchUpsertMemoryRecords.Length, actualMemoryRecords.Count);
+        for (var i = 0; i < actualMemoryRecords.Count; i++)
+        {
+            this.AssertMemoryRecordEqual(batchUpsertMemoryRecords[i], actualMemoryRecords[i]);
+        }
+    }
+
+    [Fact]
+    public async Task ItCanReturnCollectionsAsync()
+    {
+        // Arrange
+        var expectedCollections = new List<string> { "fake-collection-1", "fake-collection-2", "fake-collection-3" };
+
+        this._cslAdminProviderMock
+            .Setup(client => client.ExecuteControlCommandAsync(
+                DatabaseName,
+                It.Is<string>(s => s.StartsWith(CslCommandGenerator.GenerateTablesShowCommand(), StringComparison.Ordinal)),
+                It.IsAny<ClientRequestProperties>())
+            ).ReturnsAsync(CollectionToSingleColumnDataReader(expectedCollections));
+
+        using var store = new KustoMemoryStore(this._cslAdminProviderMock.Object, this._cslQueryProviderMock.Object, DatabaseName);
+
+        // Act
+        var actualCollections = await store.GetCollectionsAsync().ToListAsync();
+
+        // Assert
+        Assert.Equal(expectedCollections, actualCollections);
+    }
+
+    [Fact]
+    public async Task ItCanRemoveAsync()
+    {
+        // Arrange
+        const string memoryRecordKey = "fake-record-key";
+        using var store = new KustoMemoryStore(this._cslAdminProviderMock.Object, this._cslQueryProviderMock.Object, DatabaseName);
+
+        // Act
+        await store.RemoveAsync(CollectionName, memoryRecordKey);
+
+        // Assert
+        // Double spaces are collapsed to single spaces because the query is formatted with double spaces; this also keeps the check future proof.
+        this._cslAdminProviderMock
+            .Verify(client => client.ExecuteControlCommandAsync(
+                DatabaseName,
+                It.Is<string>(s => s.Replace("  ", " ").StartsWith($".delete table {CollectionName}", StringComparison.Ordinal) && s.Contains(memoryRecordKey, StringComparison.Ordinal)),
+                It.IsAny<ClientRequestProperties>()
+            ), Times.Once());
+    }
+
+    [Fact]
+    public async Task ItCanRemoveBatchAsync()
+    {
+        // Arrange
+        string[] memoryRecordKeys = new string[] { "fake-record-key1", "fake-record-key2", "fake-record-key3" };
+        using var store = new KustoMemoryStore(this._cslAdminProviderMock.Object, this._cslQueryProviderMock.Object, DatabaseName);
+
+        // Act
+        await store.RemoveBatchAsync(CollectionName, memoryRecordKeys);
+
+        // Assert
+        this._cslAdminProviderMock
+            .Verify(client => client.ExecuteControlCommandAsync(
+                DatabaseName,
+                It.Is<string>(s => s.Replace("  ", " ").StartsWith($".delete table {CollectionName}", StringComparison.Ordinal) && memoryRecordKeys.All(r => s.Contains(r, StringComparison.OrdinalIgnoreCase))),
+                It.IsAny<ClientRequestProperties>()
+            ), Times.Once());
+    }
+
+    #region private ================================================================================
+
+    // Asserts that key, timestamp, embedding vector and every metadata field of the two records match.
+    private void AssertMemoryRecordEqual(MemoryRecord expectedRecord, MemoryRecord actualRecord)
+    {
+        Assert.Equal(expectedRecord.Key, actualRecord.Key);
+        Assert.Equal(expectedRecord.Timestamp, actualRecord.Timestamp);
+        Assert.Equal(expectedRecord.Embedding.Vector, actualRecord.Embedding.Vector);
+        Assert.Equal(expectedRecord.Metadata.Id, actualRecord.Metadata.Id);
+        Assert.Equal(expectedRecord.Metadata.Text, actualRecord.Metadata.Text);
+        Assert.Equal(expectedRecord.Metadata.Description, actualRecord.Metadata.Description);
+        Assert.Equal(expectedRecord.Metadata.AdditionalMetadata, actualRecord.Metadata.AdditionalMetadata);
+        Assert.Equal(expectedRecord.Metadata.IsReference, actualRecord.Metadata.IsReference);
+        Assert.Equal(expectedRecord.Metadata.ExternalSourceName, actualRecord.Metadata.ExternalSourceName);
+    }
+
+    // Builds a local record with a fresh GUID id/key, GUID-suffixed texts and a fixed UTC timestamp.
+    private MemoryRecord GetRandomMemoryRecord(Embedding<float>? embedding = null)
+    {
+        var id = Guid.NewGuid().ToString();
+        var memoryEmbedding = embedding ?? new Embedding<float>(new[] { 1f, 3f, 5f });
+
+        return MemoryRecord.LocalRecord(
+            id: id,
+            text: "text-" + Guid.NewGuid().ToString(),
+            description: "description-" + Guid.NewGuid().ToString(),
+            embedding: memoryEmbedding,
+            additionalMetadata: "metadata-" + Guid.NewGuid().ToString(),
+            key: id,
+            timestamp: new DateTimeOffset(2023, 8, 4, 23, 59, 59, TimeSpan.Zero));
+    }
+
+    // Returns a reader over a table that has no rows.
+    private static DataTableReader FakeEmptyResult() => CollectionToSingleColumnDataReader(Array.Empty<string>());
+
+    // Returns a reader over a single string column holding one row per item.
+    private static DataTableReader CollectionToSingleColumnDataReader(IEnumerable<string> collection)
+    {
+        using var table = new DataTable();
+        table.Columns.Add("Column1", typeof(string));
+
+        foreach (var item in collection)
+        {
+            table.Rows.Add(item);
+        }
+
+        return table.CreateDataReader();
+    }
+
+    // Returns a reader with one row per inner array; columns are derived from the first row, if there is one.
+    private static DataTableReader CollectionToDataReader(string[][] data)
+    {
+        using var table = new DataTable();
+
+        if (data != null)
+        {
+            data = data.ToArrayIfNotAlready();
+            if (data.Length > 0 && data[0] != null)
+            {
+                for (int i = 0; i < data[0].Length; i++)
+                {
+                    table.Columns.Add($"Column{i + 1}", typeof(string));
+                }
+            }
+
+            for (int i = 0; i < data.Length; i++)
+            {
+                table.Rows.Add(data[i]);
+            }
+        }
+
+        return table.CreateDataReader();
+    }
+
+    #endregion
+}