Skip to content

Commit

Permalink
RAG: discard duplicate facts by default (#952)
Browse files Browse the repository at this point in the history
## Motivation and Context (Why the change? What's the scenario?)

When importing chat messages or running tests with small files or
without IDs, the storage might fill up with duplicate chunks of text
that affect the number of tokens used when generating an answer.

## High level description (Approach, Design)

When using the ASK API, after fetching N records from storage, duplicates
are discarded and only unique chunks are used in the RAG prompt.
The behavior is configurable via
`SearchClientConfig.IncludeDuplicateFacts` and the
`custom_rag_include_duplicate_facts_bool` request-context argument.
  • Loading branch information
dluc authored Dec 18, 2024
1 parent 0b006f8 commit 660d12f
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 24 deletions.
4 changes: 4 additions & 0 deletions examples/002-dotnet-Serverless/Program.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.KernelMemory;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.Safety.AzureAIContentSafety;

/* Use MemoryServerlessClient to run the default import pipeline
Expand All @@ -21,6 +22,8 @@ public static class Program

public static async Task Main()
{
SensitiveDataLogger.Enabled = true;

var memoryConfiguration = new KernelMemoryConfig();
var searchClientConfig = new SearchClientConfig();

Expand Down Expand Up @@ -61,6 +64,7 @@ public static async Task Main()
l.AddSimpleConsole(c => c.SingleLine = true);
}))
.AddSingleton(memoryConfiguration)
.WithSearchClientConfig(searchClientConfig)
// .WithOpenAIDefaults(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) // Use OpenAI for text generation and embedding
// .WithOpenAI(openAIConfig) // Use OpenAI for text generation and embedding
// .WithLlamaTextGeneration(llamaConfig) // Generate answers and summaries using LLama
Expand Down
3 changes: 3 additions & 0 deletions service/Abstractions/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ public static class Rag
// Used to override how facts are injected into RAG prompt
public const string FactTemplate = "custom_rag_fact_template_str";

// Used to override if duplicate facts are included in RAG prompts
public const string IncludeDuplicateFacts = "custom_rag_include_duplicate_facts_bool";

// Used to override the max tokens to generate when using the RAG prompt
public const string MaxTokens = "custom_rag_max_tokens_int";

Expand Down
10 changes: 10 additions & 0 deletions service/Abstractions/Context/IContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@ public static string GetCustomRagFactTemplateOrDefault(this IContext? context, s
return defaultValue;
}

/// <summary>
/// Read the optional "include duplicate facts" override from the request context,
/// falling back to the provided default when the arg is not set.
/// </summary>
public static bool GetCustomRagIncludeDuplicateFactsOrDefault(this IContext? context, bool defaultValue)
{
    return context.TryGetArg<bool>(Constants.CustomContext.Rag.IncludeDuplicateFacts, out var customValue)
        ? customValue
        : defaultValue;
}

public static string GetCustomRagPromptOrDefault(this IContext? context, string defaultValue)
{
if (context.TryGetArg<string>(Constants.CustomContext.Rag.Prompt, out var customValue))
Expand Down
15 changes: 15 additions & 0 deletions service/Abstractions/Search/SearchClientConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ public class SearchClientConfig
/// </summary>
public string FactTemplate { get; set; } = "==== [File:{{$source}};Relevance:{{$relevance}}]:\n{{$content}}";

/// <summary>
/// The memory DB might include duplicate chunks of text, e.g. when importing the same files
/// with different document IDs or chat messages (high probability), or when the same text
/// appears in different files (not very frequent, considering the partitioning process).
/// If two chunks are equal (comparison is not case-sensitive), regardless of tags and file names,
/// it's usually better to skip the duplicate, including the chunk only once in the RAG prompt and
/// reducing the tokens used. The chunk will still be listed under sources.
/// You might want to set this to True if your prompt includes other chunk details, such as tags
/// and filenames, that could affect the LLM output.
/// Note: when the storage contains duplicate records, other relevant records will be left out,
/// possibly affecting RAG quality, because deduplication occurs after retrieving N records from storage,
/// leaving RAG with [N - count(duplicates)] records to work with.
/// </summary>
public bool IncludeDuplicateFacts { get; set; } = false;

/// <summary>
/// Number between 0.0 and 2.0. It controls the randomness of the completion.
/// The higher the temperature, the more random the completion.
Expand Down
63 changes: 40 additions & 23 deletions service/Core/Search/SearchClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ public async IAsyncEnumerable<MemoryAnswer> AskStreamingAsync(
string emptyAnswer = context.GetCustomEmptyAnswerTextOrDefault(this._config.EmptyAnswer);
string answerPrompt = context.GetCustomRagPromptOrDefault(this._answerPrompt);
int limit = context.GetCustomRagMaxMatchesCountOrDefault(this._config.MaxMatchesCount);
bool includeDuplicateFacts = context.GetCustomRagIncludeDuplicateFactsOrDefault(this._config.IncludeDuplicateFacts);

var maxTokens = this._config.MaxAskPromptSize > 0
? this._config.MaxAskPromptSize
Expand Down Expand Up @@ -242,7 +243,7 @@ public async IAsyncEnumerable<MemoryAnswer> AskStreamingAsync(
await foreach ((MemoryRecord memoryRecord, double recordRelevance) in matches.ConfigureAwait(false))
{
result.State = SearchState.Continue;
result = this.ProcessMemoryRecord(result, index, memoryRecord, recordRelevance, factTemplate);
result = this.ProcessMemoryRecord(result, index, memoryRecord, recordRelevance, includeDuplicateFacts, factTemplate);

if (result.State == SearchState.SkipRecord) { continue; }

Expand Down Expand Up @@ -276,10 +277,11 @@ public async IAsyncEnumerable<MemoryAnswer> AskStreamingAsync(
/// <param name="record">Memory record, e.g. text chunk + metadata</param>
/// <param name="recordRelevance">Memory record relevance</param>
/// <param name="index">Memory index name</param>
/// <param name="includeDupes">Whether to include or skip duplicate chunks of text</param>
/// <param name="factTemplate">How to render the record when preparing an LLM prompt</param>
/// <returns>Updated search result state</returns>
private SearchClientResult ProcessMemoryRecord(
SearchClientResult result, string index, MemoryRecord record, double recordRelevance, string? factTemplate = null)
SearchClientResult result, string index, MemoryRecord record, double recordRelevance, bool includeDupes = true, string? factTemplate = null)
{
var partitionText = record.GetPartitionText(this._log).Trim();
if (string.IsNullOrEmpty(partitionText))
Expand Down Expand Up @@ -309,6 +311,11 @@ private SearchClientResult ProcessMemoryRecord(
// Name of the file to show to the LLM, avoiding "content.url"
string fileNameForLLM = (fileName == "content.url" ? fileDownloadUrl : fileName);

// Dupes management note: don't skip the record, only skip the chunk in the prompt
// so Citations includes also duplicates, which might have different tags
bool isDupe = !result.FactsUniqueness.Add($"{partitionText}");
bool skipFactInPrompt = (isDupe && !includeDupes);

if (result.Mode == SearchMode.SearchMode)
{
// Relevance is `float.MinValue` when search uses only filters
Expand All @@ -318,29 +325,39 @@ private SearchClientResult ProcessMemoryRecord(
{
result.FactsAvailableCount++;

string fact = PromptUtils.RenderFactTemplate(
template: factTemplate!,
factContent: partitionText,
source: fileNameForLLM,
relevance: recordRelevance.ToString("P1", CultureInfo.CurrentCulture),
recordId: record.Id,
tags: record.Tags,
metadata: record.Payload);

// Use the partition/chunk only if there's room for it
int factSizeInTokens = this._textGenerator.CountTokens(fact);
if (factSizeInTokens >= result.TokensAvailable)
if (!skipFactInPrompt)
{
// Stop after reaching the max number of tokens
return result.Stop();
string fact = PromptUtils.RenderFactTemplate(
template: factTemplate!,
factContent: partitionText,
source: fileNameForLLM,
relevance: recordRelevance.ToString("P1", CultureInfo.CurrentCulture),
recordId: record.Id,
tags: record.Tags,
metadata: record.Payload);

// Use the partition/chunk only if there's room for it
int factSizeInTokens = this._textGenerator.CountTokens(fact);
if (factSizeInTokens >= result.TokensAvailable)
{
// Stop after reaching the max number of tokens
return result.Stop();
}

result.Facts.Append(fact);
result.FactsUsedCount++;
result.TokensAvailable -= factSizeInTokens;

// Relevance is cosine similarity when not using hybrid search
this._log.LogTrace("Adding content #{FactsUsedCount} with relevance {Relevance} (dupe: {IsDupe})",
result.FactsUsedCount, recordRelevance, isDupe);
}
else
{
// The counter must be increased to avoid long/infinite loops
// in case the storage contains several duplications
result.FactsUsedCount++;
}

result.Facts.Append(fact);
result.FactsUsedCount++;
result.TokensAvailable -= factSizeInTokens;

// Relevance is cosine similarity when not using hybrid search
this._log.LogTrace("Adding content #{0} with relevance {1}", result.FactsUsedCount, recordRelevance);
}

var citation = result.Mode switch
Expand Down
5 changes: 4 additions & 1 deletion service/Core/Search/SearchClientResult.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Text;

namespace Microsoft.KernelMemory.Search;
Expand Down Expand Up @@ -37,8 +39,9 @@ internal class SearchClientResult
public SearchResult SearchResult { get; private init; } = new();
public StringBuilder Facts { get; } = new();
public int FactsAvailableCount { get; set; }
public int FactsUsedCount { get; set; }
public int FactsUsedCount { get; set; } // Note: the number includes also duplicate chunks not used in the prompt
public int TokensAvailable { get; set; }
public HashSet<string> FactsUniqueness { get; set; } = new(StringComparer.OrdinalIgnoreCase);

/// <summary>
/// Create new instance in Ask mode
Expand Down
2 changes: 2 additions & 0 deletions service/Service/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@
"AnswerTokens": 300,
// Text to return when the LLM cannot produce an answer.
"EmptyAnswer": "INFO NOT FOUND",
// Whether to include duplicate chunks in the RAG prompt.
"IncludeDuplicateFacts": false,
// Number between 0 and 2 that controls the randomness of the completion.
// The higher the temperature, the more random the completion.
"Temperature": 0,
Expand Down

0 comments on commit 660d12f

Please sign in to comment.