Skip to content

Commit

Permalink
.Net: Added vector search implementation for Azure CosmosDB for Mongo…
Browse files Browse the repository at this point in the history
…DB (#8887)

### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Related: #6522

- Implemented `VectorizedSearchAsync` method in Azure CosmosDB for
MongoDB connector.
- Added unit and integration tests.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
dmytrostruk authored Sep 18, 2024
1 parent f35d051 commit 8dc7dc3
Show file tree
Hide file tree
Showing 5 changed files with 414 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,132 @@ public async Task GetWithCustomMapperWorksCorrectlyAsync()
Assert.Equal("Name from mapper", result.HotelName);
}

[Theory]
[MemberData(nameof(VectorizedSearchVectorTypeData))]
public async Task VectorizedSearchThrowsExceptionWithInvalidVectorTypeAsync(object vector, bool exceptionExpected)
{
// Arrange
this.MockCollectionForSearch();

var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection<AzureCosmosDBMongoDBHotelModel>(
this._mockMongoDatabase.Object,
"collection");

// Act & Assert
if (exceptionExpected)
{
await Assert.ThrowsAsync<NotSupportedException>(async () => await sut.VectorizedSearchAsync(vector).ToListAsync());
}
else
{
var result = await sut.VectorizedSearchAsync(vector).FirstOrDefaultAsync();

Assert.NotNull(result);
}
}

[Theory]
[InlineData(null, "TestEmbedding1", 1, 1)]
[InlineData("", "TestEmbedding1", 2, 2)]
[InlineData("TestEmbedding1", "TestEmbedding1", 3, 3)]
[InlineData("TestEmbedding2", "test_embedding_2", 4, 4)]
public async Task VectorizedSearchUsesValidQueryAsync(
string? vectorPropertyName,
string expectedVectorPropertyName,
int actualLimit,
int expectedLimit)
{
// Arrange
var vector = new ReadOnlyMemory<float>([1f, 2f, 3f]);

var expectedSearch = new BsonDocument
{
{ "$search",
new BsonDocument
{
{ "cosmosSearch",
new BsonDocument
{
{ "vector", BsonArray.Create(vector.ToArray()) },
{ "path", expectedVectorPropertyName },
{ "k", expectedLimit },
}
},
{ "returnStoredSource", true }
}
}
};

var expectedProjection = new BsonDocument
{
{ "$project",
new BsonDocument
{
{ "similarityScore", new BsonDocument { { "$meta", "searchScore" } } },
{ "document", "$$ROOT" }
}
}
};

this.MockCollectionForSearch();

var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection<VectorSearchModel>(
this._mockMongoDatabase.Object,
"collection");

// Act
var result = await sut.VectorizedSearchAsync(vector, new()
{
VectorFieldName = vectorPropertyName,
Limit = actualLimit,
}).FirstOrDefaultAsync();

// Assert
Assert.NotNull(result);

this._mockMongoCollection.Verify(l => l.AggregateAsync(
It.Is<PipelineDefinition<BsonDocument, BsonDocument>>(pipeline =>
this.ComparePipeline(pipeline, expectedSearch, expectedProjection)),
It.IsAny<AggregateOptions>(),
It.IsAny<CancellationToken>()), Times.Once());
}

[Fact]
public async Task VectorizedSearchThrowsExceptionWithNonExistentVectorPropertyNameAsync()
{
// Arrange
this.MockCollectionForSearch();

var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection<AzureCosmosDBMongoDBHotelModel>(
this._mockMongoDatabase.Object,
"collection");

var options = new VectorSearchOptions { VectorFieldName = "non-existent-property" };

// Act & Assert
await Assert.ThrowsAsync<InvalidOperationException>(async () => await sut.VectorizedSearchAsync(new ReadOnlyMemory<float>([1f, 2f, 3f]), options).FirstOrDefaultAsync());
}

[Fact]
public async Task VectorizedSearchReturnsRecordWithScoreAsync()
{
// Arrange
this.MockCollectionForSearch();

var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection<AzureCosmosDBMongoDBHotelModel>(
this._mockMongoDatabase.Object,
"collection");

// Act
var result = await sut.VectorizedSearchAsync(new ReadOnlyMemory<float>([1f, 2f, 3f])).FirstOrDefaultAsync();

// Assert
Assert.NotNull(result);
Assert.Equal("key", result.Record.HotelId);
Assert.Equal("Test Name", result.Record.HotelName);
Assert.Equal(0.99f, result.Score);
}

public static TheoryData<List<string>, string, bool> CollectionExistsData => new()
{
{ ["collection-2"], "collection-2", true },
Expand All @@ -558,8 +684,54 @@ public async Task GetWithCustomMapperWorksCorrectlyAsync()
{ [], 1 }
};

public static TheoryData<object, bool> VectorizedSearchVectorTypeData => new()
{
{ new ReadOnlyMemory<float>([1f, 2f, 3f]), false },
{ new ReadOnlyMemory<double>([1f, 2f, 3f]), false },
{ new ReadOnlyMemory<float>?(new([1f, 2f, 3f])), false },
{ new ReadOnlyMemory<double>?(new([1f, 2f, 3f])), false },
{ new List<float>([1f, 2f, 3f]), true },
};

#region private

private bool ComparePipeline(
PipelineDefinition<BsonDocument, BsonDocument> actualPipeline,
BsonDocument expectedSearch,
BsonDocument expectedProjection)
{
var serializerRegistry = BsonSerializer.SerializerRegistry;
var documentSerializer = serializerRegistry.GetSerializer<BsonDocument>();

var documents = actualPipeline.Render(documentSerializer, serializerRegistry).Documents;

return
documents[0].ToJson() == expectedSearch.ToJson() &&
documents[1].ToJson() == expectedProjection.ToJson();
}

private void MockCollectionForSearch()
{
var document = new BsonDocument { ["_id"] = "key", ["HotelName"] = "Test Name" };
var searchResult = new BsonDocument { ["document"] = document, ["similarityScore"] = 0.99f };

var mockCursor = new Mock<IAsyncCursor<BsonDocument>>();
mockCursor
.Setup(l => l.MoveNextAsync(It.IsAny<CancellationToken>()))
.ReturnsAsync(true);

mockCursor
.Setup(l => l.Current)
.Returns([searchResult]);

this._mockMongoCollection
.Setup(l => l.AggregateAsync(
It.IsAny<PipelineDefinition<BsonDocument, BsonDocument>>(),
It.IsAny<AggregateOptions>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(mockCursor.Object);
}

private async Task TestUpsertWithModelAsync<TDataModel>(
TDataModel dataModel,
string expectedPropertyName,
Expand Down Expand Up @@ -645,6 +817,23 @@ private sealed class BsonVectorStoreWithNameTestModel
[VectorStoreRecordData(StoragePropertyName = "storage_hotel_name")]
public string? HotelName { get; set; }
}

private sealed class VectorSearchModel
{
[BsonId]
[VectorStoreRecordKey]
public string? Id { get; set; }

[VectorStoreRecordData]
public string? HotelName { get; set; }

[VectorStoreRecordVector(Dimensions: 4, IndexKind: IndexKind.IvfFlat, DistanceFunction: DistanceFunction.CosineDistance, StoragePropertyName = "test_embedding_1")]
public ReadOnlyMemory<float> TestEmbedding1 { get; set; }

[BsonElement("test_embedding_2")]
[VectorStoreRecordVector(Dimensions: 4, IndexKind: IndexKind.IvfFlat, DistanceFunction: DistanceFunction.CosineDistance)]
public ReadOnlyMemory<float> TestEmbedding2 { get; set; }
}
#pragma warning restore CA1812

#endregion
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.SemanticKernel.Data;
using MongoDB.Bson;

namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB;

/// <summary>
/// Contains mapping helpers to use when searching for documents using Azure CosmosDB MongoDB.
/// </summary>
internal sealed class AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping
{
/// <summary>Returns search part of the search query for <see cref="IndexKind.Hnsw"/> index kind.</summary>
public static BsonDocument GetSearchQueryForHnswIndex<TVector>(
TVector vector,
string vectorPropertyName,
int limit,
int efSearch)
{
return new BsonDocument
{
{ "$search",
new BsonDocument
{
{ "cosmosSearch",
new BsonDocument
{
{ "vector", BsonArray.Create(vector) },
{ "path", vectorPropertyName },
{ "k", limit },
{ "efSearch", efSearch }
}
}
}
}
};
}

/// <summary>Returns search part of the search query for <see cref="IndexKind.IvfFlat"/> index kind.</summary>
public static BsonDocument GetSearchQueryForIvfIndex<TVector>(
TVector vector,
string vectorPropertyName,
int limit)
{
return new BsonDocument
{
{ "$search",
new BsonDocument
{
{ "cosmosSearch",
new BsonDocument
{
{ "vector", BsonArray.Create(vector) },
{ "path", vectorPropertyName },
{ "k", limit },
}
},
{ "returnStoredSource", true }
}
}
};
}

/// <summary>Returns projection part of the search query to return similarity score together with document.</summary>
public static BsonDocument GetProjectionQuery(string scorePropertyName, string documentPropertyName)
{
return new BsonDocument
{
{ "$project",
new BsonDocument
{
{ scorePropertyName, new BsonDocument { { "$meta", "searchScore" } } },
{ documentPropertyName, "$$ROOT" }
}
}
};
}
}
Loading

0 comments on commit 8dc7dc3

Please sign in to comment.