From 660d12f94b99306db00a50c18b15fb5ccbd00514 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Wed, 18 Dec 2024 15:41:50 -0800 Subject: [PATCH] RAG: discard duplicate facts by default (#952) ## Motivation and Context (Why the change? What's the scenario?) When importing chat messages or running tests with small files or without IDs, the storage might fill up with duplicate chunks of text that affect the number of tokens used when generating an answer. ## High level description (Approach, Design) When using the ASK API, after fetching N records from storage, discard duplicates and use only unique chunks in the RAG prompt. The behavior is configurable via `SearchClientConfig.IncludeDuplicateFacts` and request context `custom_rag_include_duplicate_facts_bool` arg. --- examples/002-dotnet-Serverless/Program.cs | 4 ++ service/Abstractions/Constants.cs | 3 + service/Abstractions/Context/IContext.cs | 10 +++ .../Abstractions/Search/SearchClientConfig.cs | 15 +++++ service/Core/Search/SearchClient.cs | 63 ++++++++++++------- service/Core/Search/SearchClientResult.cs | 5 +- service/Service/appsettings.json | 2 + 7 files changed, 78 insertions(+), 24 deletions(-) diff --git a/examples/002-dotnet-Serverless/Program.cs b/examples/002-dotnet-Serverless/Program.cs index 5febbd487..77b77bc5a 100644 --- a/examples/002-dotnet-Serverless/Program.cs +++ b/examples/002-dotnet-Serverless/Program.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. 
using Microsoft.KernelMemory; +using Microsoft.KernelMemory.Diagnostics; using Microsoft.KernelMemory.Safety.AzureAIContentSafety; /* Use MemoryServerlessClient to run the default import pipeline @@ -21,6 +22,8 @@ public static class Program public static async Task Main() { + SensitiveDataLogger.Enabled = true; + var memoryConfiguration = new KernelMemoryConfig(); var searchClientConfig = new SearchClientConfig(); @@ -61,6 +64,7 @@ public static async Task Main() l.AddSimpleConsole(c => c.SingleLine = true); })) .AddSingleton(memoryConfiguration) + .WithSearchClientConfig(searchClientConfig) // .WithOpenAIDefaults(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) // Use OpenAI for text generation and embedding // .WithOpenAI(openAIConfig) // Use OpenAI for text generation and embedding // .WithLlamaTextGeneration(llamaConfig) // Generate answers and summaries using LLama diff --git a/service/Abstractions/Constants.cs b/service/Abstractions/Constants.cs index 408c1b8dc..40987040c 100644 --- a/service/Abstractions/Constants.cs +++ b/service/Abstractions/Constants.cs @@ -70,6 +70,9 @@ public static class Rag // Used to override how facts are injected into RAG prompt public const string FactTemplate = "custom_rag_fact_template_str"; + // Used to override if duplicate facts are included in RAG prompts + public const string IncludeDuplicateFacts = "custom_rag_include_duplicate_facts_bool"; + // Used to override the max tokens to generate when using the RAG prompt public const string MaxTokens = "custom_rag_max_tokens_int"; diff --git a/service/Abstractions/Context/IContext.cs b/service/Abstractions/Context/IContext.cs index 45ebc3ddb..a0e6ac560 100644 --- a/service/Abstractions/Context/IContext.cs +++ b/service/Abstractions/Context/IContext.cs @@ -110,6 +110,16 @@ public static string GetCustomRagFactTemplateOrDefault(this IContext? context, s return defaultValue; } + public static bool GetCustomRagIncludeDuplicateFactsOrDefault(this IContext? 
context, bool defaultValue) + { + if (context.TryGetArg<bool>(Constants.CustomContext.Rag.IncludeDuplicateFacts, out var customValue)) + { + return customValue; + } + + return defaultValue; + } + public static string GetCustomRagPromptOrDefault(this IContext? context, string defaultValue) { if (context.TryGetArg<string>(Constants.CustomContext.Rag.Prompt, out var customValue)) diff --git a/service/Abstractions/Search/SearchClientConfig.cs b/service/Abstractions/Search/SearchClientConfig.cs index a92498e9e..8c7b40e93 100644 --- a/service/Abstractions/Search/SearchClientConfig.cs +++ b/service/Abstractions/Search/SearchClientConfig.cs @@ -53,6 +53,21 @@ public class SearchClientConfig /// </summary> public string FactTemplate { get; set; } = "==== [File:{{$source}};Relevance:{{$relevance}}]:\n{{$content}}"; + /// <summary> + /// The memory DB might include duplicate chunks of text, e.g. when importing the same files + /// with different document IDs or chat messages (high probability), or when the same text + /// appears in different files (not very frequent, considering the partitioning process). + /// If two chunks are equal (not case-sensitive), regardless of tags and file names, it's usually + /// better to skip the duplication, including the chunk only once in the RAG prompt, reducing the + /// tokens used. The chunk will still be listed under sources. + /// You might want to set this to True if your prompt includes other chunk details, such as tags + /// and filenames, that could affect the LLM output. + /// Note: when the storage contains duplicate records, other relevant records will be left out, + /// possibly affecting RAG quality, because deduplication occurs after retrieving N records from storage, + /// leaving RAG with [N - count(duplicates)] records to work with. + /// </summary> + public bool IncludeDuplicateFacts { get; set; } = false; + /// <summary> /// Number between 0.0 and 2.0. It controls the randomness of the completion. /// The higher the temperature, the more random the completion. 
diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index f6b3ddae7..8216e8a83 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -201,6 +201,7 @@ public async IAsyncEnumerable AskStreamingAsync( string emptyAnswer = context.GetCustomEmptyAnswerTextOrDefault(this._config.EmptyAnswer); string answerPrompt = context.GetCustomRagPromptOrDefault(this._answerPrompt); int limit = context.GetCustomRagMaxMatchesCountOrDefault(this._config.MaxMatchesCount); + bool includeDuplicateFacts = context.GetCustomRagIncludeDuplicateFactsOrDefault(this._config.IncludeDuplicateFacts); var maxTokens = this._config.MaxAskPromptSize > 0 ? this._config.MaxAskPromptSize @@ -242,7 +243,7 @@ public async IAsyncEnumerable AskStreamingAsync( await foreach ((MemoryRecord memoryRecord, double recordRelevance) in matches.ConfigureAwait(false)) { result.State = SearchState.Continue; - result = this.ProcessMemoryRecord(result, index, memoryRecord, recordRelevance, factTemplate); + result = this.ProcessMemoryRecord(result, index, memoryRecord, recordRelevance, includeDuplicateFacts, factTemplate); if (result.State == SearchState.SkipRecord) { continue; } @@ -276,10 +277,11 @@ public async IAsyncEnumerable AskStreamingAsync( /// Memory record, e.g. text chunk + metadata /// Memory record relevance /// Memory index name + /// Whether to include or skip duplicate chunks of text /// How to render the record when preparing an LLM prompt /// Updated search result state private SearchClientResult ProcessMemoryRecord( - SearchClientResult result, string index, MemoryRecord record, double recordRelevance, string? factTemplate = null) + SearchClientResult result, string index, MemoryRecord record, double recordRelevance, bool includeDupes = true, string? 
factTemplate = null) { var partitionText = record.GetPartitionText(this._log).Trim(); if (string.IsNullOrEmpty(partitionText)) @@ -309,6 +311,11 @@ private SearchClientResult ProcessMemoryRecord( // Name of the file to show to the LLM, avoiding "content.url" string fileNameForLLM = (fileName == "content.url" ? fileDownloadUrl : fileName); + // Dupes management note: don't skip the record, only skip the chunk in the prompt + // so Citations also includes duplicates, which might have different tags + bool isDupe = !result.FactsUniqueness.Add($"{partitionText}"); + bool skipFactInPrompt = (isDupe && !includeDupes); + if (result.Mode == SearchMode.SearchMode) { // Relevance is `float.MinValue` when search uses only filters @@ -318,29 +325,39 @@ private SearchClientResult ProcessMemoryRecord( { result.FactsAvailableCount++; - string fact = PromptUtils.RenderFactTemplate( - template: factTemplate!, - factContent: partitionText, - source: fileNameForLLM, - relevance: recordRelevance.ToString("P1", CultureInfo.CurrentCulture), - recordId: record.Id, - tags: record.Tags, - metadata: record.Payload); - - // Use the partition/chunk only if there's room for it - int factSizeInTokens = this._textGenerator.CountTokens(fact); - if (factSizeInTokens >= result.TokensAvailable) + if (!skipFactInPrompt) { - // Stop after reaching the max number of tokens - return result.Stop(); + string fact = PromptUtils.RenderFactTemplate( + template: factTemplate!, + factContent: partitionText, + source: fileNameForLLM, + relevance: recordRelevance.ToString("P1", CultureInfo.CurrentCulture), + recordId: record.Id, + tags: record.Tags, + metadata: record.Payload); + + // Use the partition/chunk only if there's room for it + int factSizeInTokens = this._textGenerator.CountTokens(fact); + if (factSizeInTokens >= result.TokensAvailable) + { + // Stop after reaching the max number of tokens + return result.Stop(); + } + + result.Facts.Append(fact); + result.FactsUsedCount++; + result.TokensAvailable -= 
factSizeInTokens; + + // Relevance is cosine similarity when not using hybrid search + this._log.LogTrace("Adding content #{FactsUsedCount} with relevance {Relevance} (dupe: {IsDupe})", + result.FactsUsedCount, recordRelevance, isDupe); + } + else + { + // The counter must be increased to avoid long/infinite loops + // in case the storage contains several duplications + result.FactsUsedCount++; } - - result.Facts.Append(fact); - result.FactsUsedCount++; - result.TokensAvailable -= factSizeInTokens; - - // Relevance is cosine similarity when not using hybrid search - this._log.LogTrace("Adding content #{0} with relevance {1}", result.FactsUsedCount, recordRelevance); } var citation = result.Mode switch diff --git a/service/Core/Search/SearchClientResult.cs b/service/Core/Search/SearchClientResult.cs index 66ff58f15..a795d5cce 100644 --- a/service/Core/Search/SearchClientResult.cs +++ b/service/Core/Search/SearchClientResult.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; using System.Text; namespace Microsoft.KernelMemory.Search; @@ -37,8 +39,9 @@ internal class SearchClientResult public SearchResult SearchResult { get; private init; } = new(); public StringBuilder Facts { get; } = new(); public int FactsAvailableCount { get; set; } - public int FactsUsedCount { get; set; } + public int FactsUsedCount { get; set; } // Note: the number also includes duplicate chunks not used in the prompt public int TokensAvailable { get; set; } + public HashSet<string> FactsUniqueness { get; set; } = new(StringComparer.OrdinalIgnoreCase); /// <summary> /// Create new instance in Ask mode diff --git a/service/Service/appsettings.json b/service/Service/appsettings.json index f73fd4316..5519ff8d2 100644 --- a/service/Service/appsettings.json +++ b/service/Service/appsettings.json @@ -198,6 +198,8 @@ "AnswerTokens": 300, // Text to return when the LLM cannot produce an answer. 
"EmptyAnswer": "INFO NOT FOUND", + // Whether to include duplicate chunks in the RAG prompt. + "IncludeDuplicateFacts": false, // Number between 0 and 2 that controls the randomness of the completion. // The higher the temperature, the more random the completion. "Temperature": 0,