Merge branch 'main' into add-filtering-samples
SergeyMenshykh authored Nov 27, 2024
2 parents 0bdca1b + e780d7b commit 700c9c2
Showing 4 changed files with 188 additions and 92 deletions.
149 changes: 101 additions & 48 deletions dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs
@@ -1,11 +1,11 @@
 // Copyright (c) Microsoft. All rights reserved.

 using System.Diagnostics;
+using Azure.Identity;
 using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.VectorData;
 using Microsoft.SemanticKernel;
-using Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB;
-using Microsoft.SemanticKernel.Connectors.Redis;
-using Microsoft.SemanticKernel.Memory;
+using Microsoft.SemanticKernel.Embeddings;

 namespace Caching;

@@ -18,20 +18,17 @@ namespace Caching;
 /// </summary>
 public class SemanticCachingWithFilters(ITestOutputHelper output) : BaseTest(output)
 {
-    /// <summary>
-    /// Similarity/relevance score, from 0 to 1, where 1 means an exact match.
-    /// It's possible to change this value during testing to see how the caching logic behaves.
-    /// </summary>
-    private const double SimilarityScore = 0.9;
-
     /// <summary>
     /// Executes similar requests twice using the in-memory caching store to compare execution time and results.
     /// The second execution is faster because the result is returned from the cache.
     /// </summary>
     [Fact]
     public async Task InMemoryCacheAsync()
     {
-        var kernel = GetKernelWithCache(_ => new VolatileMemoryStore());
+        var kernel = GetKernelWithCache(services =>
+        {
+            services.AddInMemoryVectorStore();
+        });

         var result1 = await ExecuteAsync(kernel, "First run", "What's the tallest building in New York?");
         var result2 = await ExecuteAsync(kernel, "Second run", "What is the highest building in New York City?");
@@ -53,12 +50,15 @@ public async Task InMemoryCacheAsync()
     /// <summary>
     /// Executes similar requests twice using the Redis caching store to compare execution time and results.
     /// The second execution is faster because the result is returned from the cache.
-    /// How to run Redis on Docker locally: https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/docker/
+    /// How to run Redis on Docker locally: https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/docker/.
     /// </summary>
     [Fact]
     public async Task RedisCacheAsync()
     {
-        var kernel = GetKernelWithCache(_ => new RedisMemoryStore("localhost:6379", vectorSize: 1536));
+        var kernel = GetKernelWithCache(services =>
+        {
+            services.AddRedisVectorStore("localhost:6379");
+        });

         var result1 = await ExecuteAsync(kernel, "First run", "What's the tallest building in New York?");
         var result2 = await ExecuteAsync(kernel, "Second run", "What is the highest building in New York City?");
@@ -84,10 +84,12 @@ public async Task RedisCacheAsync()
     [Fact]
     public async Task AzureCosmosDBMongoDBCacheAsync()
     {
-        var kernel = GetKernelWithCache(_ => new AzureCosmosDBMongoDBMemoryStore(
-            TestConfiguration.AzureCosmosDbMongoDb.ConnectionString,
-            TestConfiguration.AzureCosmosDbMongoDb.DatabaseName,
-            new(dimensions: 1536)));
+        var kernel = GetKernelWithCache(services =>
+        {
+            services.AddAzureCosmosDBMongoDBVectorStore(
+                TestConfiguration.AzureCosmosDbMongoDb.ConnectionString,
+                TestConfiguration.AzureCosmosDbMongoDb.DatabaseName);
+        });

         var result1 = await ExecuteAsync(kernel, "First run", "What's the tallest building in New York?");
         var result2 = await ExecuteAsync(kernel, "Second run", "What is the highest building in New York City?");
@@ -110,27 +112,41 @@ public async Task AzureCosmosDBMongoDBCacheAsync()
     /// <summary>
     /// Returns <see cref="Kernel"/> instance with required registered services.
     /// </summary>
-    private Kernel GetKernelWithCache(Func<IServiceProvider, IMemoryStore> cacheFactory)
+    private Kernel GetKernelWithCache(Action<IServiceCollection> configureVectorStore)
     {
         var builder = Kernel.CreateBuilder();

-        // Add Azure OpenAI chat completion service
-        builder.AddAzureOpenAIChatCompletion(
-            TestConfiguration.AzureOpenAI.ChatDeploymentName,
-            TestConfiguration.AzureOpenAI.Endpoint,
-            TestConfiguration.AzureOpenAI.ApiKey);
-
-        // Add Azure OpenAI text embedding generation service
-        builder.AddAzureOpenAITextEmbeddingGeneration(
-            TestConfiguration.AzureOpenAIEmbeddings.DeploymentName,
-            TestConfiguration.AzureOpenAIEmbeddings.Endpoint,
-            TestConfiguration.AzureOpenAIEmbeddings.ApiKey);
-
-        // Add memory store for caching purposes (e.g. in-memory, Redis, Azure Cosmos DB)
-        builder.Services.AddSingleton<IMemoryStore>(cacheFactory);
+        if (!string.IsNullOrWhiteSpace(TestConfiguration.AzureOpenAI.ApiKey))
+        {
+            // Add Azure OpenAI chat completion service
+            builder.AddAzureOpenAIChatCompletion(
+                TestConfiguration.AzureOpenAI.ChatDeploymentName,
+                TestConfiguration.AzureOpenAI.Endpoint,
+                TestConfiguration.AzureOpenAI.ApiKey);
+
+            // Add Azure OpenAI text embedding generation service
+            builder.AddAzureOpenAITextEmbeddingGeneration(
+                TestConfiguration.AzureOpenAIEmbeddings.DeploymentName,
+                TestConfiguration.AzureOpenAIEmbeddings.Endpoint,
+                TestConfiguration.AzureOpenAI.ApiKey);
+        }
+        else
+        {
+            // Add Azure OpenAI chat completion service
+            builder.AddAzureOpenAIChatCompletion(
+                TestConfiguration.AzureOpenAI.ChatDeploymentName,
+                TestConfiguration.AzureOpenAI.Endpoint,
+                new AzureCliCredential());
+
+            // Add Azure OpenAI text embedding generation service
+            builder.AddAzureOpenAITextEmbeddingGeneration(
+                TestConfiguration.AzureOpenAIEmbeddings.DeploymentName,
+                TestConfiguration.AzureOpenAIEmbeddings.Endpoint,
+                new AzureCliCredential());
+        }

-        // Add text memory service that will be used to generate embeddings and query/store data.
-        builder.Services.AddSingleton<ISemanticTextMemory, SemanticTextMemory>();
+        // Add vector store for caching purposes (e.g. in-memory, Redis, Azure Cosmos DB)
+        configureVectorStore(builder.Services);

         // Add prompt render filter to query cache and check if rendered prompt was already answered.
         builder.Services.AddSingleton<IPromptRenderFilter, PromptCacheFilter>();
@@ -164,7 +180,10 @@ public class CacheBaseFilter
 /// <summary>
 /// Filter which is executed during prompt rendering operation.
 /// </summary>
-public sealed class PromptCacheFilter(ISemanticTextMemory semanticTextMemory) : CacheBaseFilter, IPromptRenderFilter
+public sealed class PromptCacheFilter(
+    ITextEmbeddingGenerationService textEmbeddingGenerationService,
+    IVectorStore vectorStore)
+    : CacheBaseFilter, IPromptRenderFilter
 {
     public async Task OnPromptRenderAsync(PromptRenderContext context, Func<PromptRenderContext, Task> next)
     {
@@ -174,20 +193,22 @@ public async Task OnPromptRenderAsync(PromptRenderContext context, Func<PromptRenderContext, Task> next)
         // Get rendered prompt
         var prompt = context.RenderedPrompt!;

-        // Search for similar prompts in cache with provided similarity/relevance score
-        var searchResult = await semanticTextMemory.SearchAsync(
-            CollectionName,
-            prompt,
-            limit: 1,
-            minRelevanceScore: SimilarityScore).FirstOrDefaultAsync();
+        var promptEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(prompt);
+
+        var collection = vectorStore.GetCollection<string, CacheRecord>(CollectionName);
+        await collection.CreateCollectionIfNotExistsAsync();
+
+        // Search for similar prompts in cache.
+        var searchResults = await collection.VectorizedSearchAsync(promptEmbedding, new() { Top = 1 }, context.CancellationToken);
+        var searchResult = (await searchResults.Results.FirstOrDefaultAsync())?.Record;

         // If result exists, return it.
         if (searchResult is not null)
         {
             // Override function result. This will prevent calling LLM and will return result immediately.
-            context.Result = new FunctionResult(context.Function, searchResult.Metadata.AdditionalMetadata)
+            context.Result = new FunctionResult(context.Function, searchResult.Result)
             {
-                Metadata = new Dictionary<string, object?> { [RecordIdKey] = searchResult.Metadata.Id }
+                Metadata = new Dictionary<string, object?> { [RecordIdKey] = searchResult.Id }
             };
         }
     }
@@ -196,7 +217,10 @@ public async Task OnPromptRenderAsync(PromptRenderContext context, Func<PromptRenderContext, Task> next)
 /// <summary>
 /// Filter which is executed during function invocation.
 /// </summary>
-public sealed class FunctionCacheFilter(ISemanticTextMemory semanticTextMemory) : CacheBaseFilter, IFunctionInvocationFilter
+public sealed class FunctionCacheFilter(
+    ITextEmbeddingGenerationService textEmbeddingGenerationService,
+    IVectorStore vectorStore)
+    : CacheBaseFilter, IFunctionInvocationFilter
 {
     public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
     {
@@ -212,12 +236,22 @@ public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
         // Get cache record id if result was cached previously or generate new id.
         var recordId = context.Result.Metadata?.GetValueOrDefault(RecordIdKey, Guid.NewGuid().ToString()) as string;

+        // Generate prompt embedding.
+        var promptEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(context.Result.RenderedPrompt);
+
         // Cache rendered prompt and LLM result.
-        await semanticTextMemory.SaveInformationAsync(
-            CollectionName,
-            context.Result.RenderedPrompt,
-            recordId!,
-            additionalMetadata: result.ToString());
+        var collection = vectorStore.GetCollection<string, CacheRecord>(CollectionName);
+        await collection.CreateCollectionIfNotExistsAsync();
+
+        var cacheRecord = new CacheRecord
+        {
+            Id = recordId!,
+            Prompt = context.Result.RenderedPrompt,
+            Result = result.ToString(),
+            PromptEmbedding = promptEmbedding
+        };
+
+        await collection.UpsertAsync(cacheRecord, cancellationToken: context.CancellationToken);
         }
     }
 }
@@ -245,4 +279,23 @@ private async Task<FunctionResult> ExecuteAsync(Kernel kernel, string title, string prompt)
     }

     #endregion
+
+    #region Vector Store Record
+
+    private sealed class CacheRecord
+    {
+        [VectorStoreRecordKey]
+        public string Id { get; set; }
+
+        [VectorStoreRecordData]
+        public string Prompt { get; set; }
+
+        [VectorStoreRecordData]
+        public string Result { get; set; }
+
+        [VectorStoreRecordVector(Dimensions: 1536)]
+        public ReadOnlyMemory<float> PromptEmbedding { get; set; }
+    }
+
+    #endregion
 }
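Note: the diff above replaces the legacy IMemoryStore/ISemanticTextMemory caching plumbing with the Microsoft.Extensions.VectorData abstractions. Below is a minimal, self-contained sketch of the resulting cache flow, assuming the CacheRecord type defined above; it uses only calls that appear in the diff (GetCollection, CreateCollectionIfNotExistsAsync, UpsertAsync, VectorizedSearchAsync). The collection name, prompt text, and placeholder embedding are illustrative, not from the commit.

using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel.Connectors.InMemory;

// Placeholder embedding; the sample generates real vectors with ITextEmbeddingGenerationService.
ReadOnlyMemory<float> promptEmbedding = new float[1536];

var vectorStore = new InMemoryVectorStore();
var collection = vectorStore.GetCollection<string, CacheRecord>("llm_responses"); // illustrative name
await collection.CreateCollectionIfNotExistsAsync();

// FunctionCacheFilter path: persist the rendered prompt, the LLM result, and the embedding.
await collection.UpsertAsync(new CacheRecord
{
    Id = Guid.NewGuid().ToString(),
    Prompt = "What's the tallest building in New York?",
    Result = "One World Trade Center.",
    PromptEmbedding = promptEmbedding
});

// PromptCacheFilter path: nearest-neighbor lookup; a hit short-circuits the LLM call.
var searchResults = await collection.VectorizedSearchAsync(promptEmbedding, new() { Top = 1 });
var cached = (await searchResults.Results.FirstOrDefaultAsync())?.Record;
Console.WriteLine(cached?.Result ?? "cache miss");

One behavioral difference worth noting: the old code filtered matches with minRelevanceScore (SimilarityScore = 0.9), while the migrated lookup takes the single nearest record with Top = 1 and no score threshold, so any nearest neighbor counts as a cache hit.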
49 changes: 35 additions & 14 deletions dotnet/samples/Concepts/Optimization/FrugalGPTWithFilters.cs
@@ -2,10 +2,11 @@

 using System.Runtime.CompilerServices;
 using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.VectorData;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Connectors.InMemory;
 using Microsoft.SemanticKernel.Embeddings;
-using Microsoft.SemanticKernel.Memory;
 using Microsoft.SemanticKernel.PromptTemplates.Handlebars;
 using Microsoft.SemanticKernel.Services;
@@ -97,11 +98,11 @@ public async Task ReducePromptSizeAsync()

         // Add few-shot prompt optimization filter.
         // The filter uses an in-memory store for vector similarity search and a text embedding generation service to generate embeddings.
-        var memoryStore = new VolatileMemoryStore();
+        var vectorStore = new InMemoryVectorStore();
         var textEmbeddingGenerationService = kernel.GetRequiredService<ITextEmbeddingGenerationService>();

         // Register optimization filter.
-        kernel.PromptRenderFilters.Add(new FewShotPromptOptimizationFilter(memoryStore, textEmbeddingGenerationService));
+        kernel.PromptRenderFilters.Add(new FewShotPromptOptimizationFilter(vectorStore, textEmbeddingGenerationService));

         // Get result again and compare the usage.
         result = await kernel.InvokeAsync(function, arguments);
@@ -167,7 +168,7 @@ public async Task LLMCascadeAsync()
     /// which are similar to original request.
     /// </summary>
     private sealed class FewShotPromptOptimizationFilter(
-        IMemoryStore memoryStore,
+        IVectorStore vectorStore,
         ITextEmbeddingGenerationService textEmbeddingGenerationService) : IPromptRenderFilter
     {
         /// <summary>
@@ -176,7 +177,7 @@
         private const int TopN = 5;

         /// <summary>
-        /// Collection name to use in memory store.
+        /// Collection name to use in vector store.
         /// </summary>
         private const string CollectionName = "examples";

@@ -188,30 +189,38 @@ public async Task OnPromptRenderAsync(PromptRenderContext context, Func<PromptRenderContext, Task> next)

             if (examples is { Count: > 0 } && !string.IsNullOrEmpty(request))
             {
-                var memoryRecords = new List<MemoryRecord>();
+                var exampleRecords = new List<ExampleRecord>();

                 // Generate embedding for each example.
                 var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(examples);

-                // Create memory record instances with example text and embedding.
+                // Create vector store record instances with example text and embedding.
                 for (var i = 0; i < examples.Count; i++)
                 {
-                    memoryRecords.Add(MemoryRecord.LocalRecord(Guid.NewGuid().ToString(), examples[i], "description", embeddings[i]));
+                    exampleRecords.Add(new ExampleRecord
+                    {
+                        Id = Guid.NewGuid().ToString(),
+                        Example = examples[i],
+                        ExampleEmbedding = embeddings[i]
+                    });
                 }

-                // Create collection and upsert all memory records for search.
+                // Create collection and upsert all vector store records for search.
                 // It's possible to do it only once and re-use the same examples for future requests.
-                await memoryStore.CreateCollectionAsync(CollectionName);
-                await memoryStore.UpsertBatchAsync(CollectionName, memoryRecords).ToListAsync();
+                var collection = vectorStore.GetCollection<string, ExampleRecord>(CollectionName);
+                await collection.CreateCollectionIfNotExistsAsync(context.CancellationToken);
+
+                await collection.UpsertBatchAsync(exampleRecords, cancellationToken: context.CancellationToken).ToListAsync(context.CancellationToken);

                 // Generate embedding for original request.
-                var requestEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(request);
+                var requestEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(request, cancellationToken: context.CancellationToken);

                 // Find top N examples which are similar to original request.
-                var topNExamples = await memoryStore.GetNearestMatchesAsync(CollectionName, requestEmbedding, TopN).ToListAsync();
+                var searchResults = await collection.VectorizedSearchAsync(requestEmbedding, new() { Top = TopN }, cancellationToken: context.CancellationToken);
+                var topNExamples = (await searchResults.Results.ToListAsync(context.CancellationToken)).Select(l => l.Record).ToList();

                 // Override arguments to use only top N examples, which will be sent to LLM.
-                context.Arguments["Examples"] = topNExamples.Select(l => l.Item1.Metadata.Text);
+                context.Arguments["Examples"] = topNExamples.Select(l => l.Example);
             }

             // Continue prompt rendering operation.
@@ -305,4 +314,16 @@ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(
             yield return new StreamingChatMessageContent(AuthorRole.Assistant, mockResult);
         }
     }
+
+    private sealed class ExampleRecord
+    {
+        [VectorStoreRecordKey]
+        public string Id { get; set; }
+
+        [VectorStoreRecordData]
+        public string Example { get; set; }
+
+        [VectorStoreRecordVector]
+        public ReadOnlyMemory<float> ExampleEmbedding { get; set; }
+    }
 }
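The few-shot filter follows the same migration pattern: embed each candidate example, upsert the examples as ExampleRecord entries, then keep only the TopN nearest to the incoming request. Below is a minimal sketch of that selection step, assuming the ExampleRecord type defined above; the example strings, three-dimensional placeholder embeddings, and Top = 2 are illustrative (the sample uses real embeddings and TopN = 5).

using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel.Connectors.InMemory;

var vectorStore = new InMemoryVectorStore();
var collection = vectorStore.GetCollection<string, ExampleRecord>("examples");
await collection.CreateCollectionIfNotExistsAsync();

// Hypothetical few-shot examples with placeholder embeddings; the sample produces
// real vectors via ITextEmbeddingGenerationService.GenerateEmbeddingsAsync(examples).
string[] examples = ["2 + 2 = 4", "Paris is the capital of France.", "Water boils at 100 degrees Celsius."];
float[][] vectors = [[1f, 0f, 0f], [0f, 1f, 0f], [0f, 0f, 1f]];

for (var i = 0; i < examples.Length; i++)
{
    await collection.UpsertAsync(new ExampleRecord
    {
        Id = Guid.NewGuid().ToString(),
        Example = examples[i],
        ExampleEmbedding = vectors[i]
    });
}

// Keep only the examples closest to the request embedding; only these reach the prompt.
ReadOnlyMemory<float> requestEmbedding = new float[] { 0.9f, 0.1f, 0f };
var search = await collection.VectorizedSearchAsync(requestEmbedding, new() { Top = 2 });
var fewShot = (await search.Results.ToListAsync()).Select(r => r.Record.Example).ToList();
fewShot.ForEach(Console.WriteLine); // with the default distance metric, "2 + 2 = 4" should rank first

This mirrors the intent of the filter: the prompt only ever carries the handful of examples most similar to the user request, which keeps token usage roughly constant as the example pool grows.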
