From 5e19613dbda0f6811bf1533901f2af406a41cf53 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 6 Dec 2024 18:44:15 +0000 Subject: [PATCH] Fix bug where redis score was mapped from wrong score field. --- ...RedisHashSetVectorStoreRecordCollection.cs | 20 ++++--- .../RedisJsonVectorStoreRecordCollection.cs | 6 +- ...RedisVectorStoreCollectionCreateMapping.cs | 1 + ...RedisVectorStoreCollectionSearchMapping.cs | 47 ++++++++++++++++ ...HashSetVectorStoreRecordCollectionTests.cs | 15 +++-- ...disJsonVectorStoreRecordCollectionTests.cs | 16 ++++-- ...VectorStoreCollectionSearchMappingTests.cs | 56 +++++++++++++++++++ ...HashSetVectorStoreRecordCollectionTests.cs | 2 + ...disJsonVectorStoreRecordCollectionTests.cs | 4 +- 9 files changed, 146 insertions(+), 21 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs index 25236402cbdf..4767967a4340 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs @@ -78,8 +78,8 @@ public sealed class RedisHashSetVectorStoreRecordCollection : IVectorSt /// An array of the names of all the data properties that are part of the Redis payload as RedisValue objects, i.e. all properties except the key and vector properties. private readonly RedisValue[] _dataStoragePropertyNameRedisValues; - /// An array of the names of all the data properties that are part of the Redis payload, i.e. all properties except the key and vector properties. - private readonly string[] _dataStoragePropertyNames; + /// An array of the names of all the data properties that are part of the Redis payload, i.e. all properties except the key and vector properties, plus the generated score property. + private readonly string[] _dataStoragePropertyNamesWithScore; /// The mapper to use when mapping between the consumer data model and the Redis record. private readonly IVectorStoreRecordMapper _mapper; @@ -119,14 +119,14 @@ public RedisHashSetVectorStoreRecordCollection(IDatabase database, string collec this._propertyReader.VerifyVectorProperties(s_supportedVectorTypes); // Lookup storage property names. - this._dataStoragePropertyNames = this._propertyReader - .DataPropertyStoragePropertyNames - .ToArray(); + var dataStoragePropertyNames = this._propertyReader.DataPropertyStoragePropertyNames; - this._dataStoragePropertyNameRedisValues = this._dataStoragePropertyNames + this._dataStoragePropertyNameRedisValues = dataStoragePropertyNames .Select(RedisValue.Unbox) .ToArray(); + this._dataStoragePropertyNamesWithScore = [.. dataStoragePropertyNames, "vector_score"]; + // Assign Mapper. if (this._options.HashEntriesCustomMapper is not null) { @@ -342,7 +342,7 @@ public async Task> VectorizedSearchAsync(T var internalOptions = options ?? s_defaultVectorSearchOptions; // Build query & search. - var selectFields = internalOptions.IncludeVectors ? null : this._dataStoragePropertyNames; + var selectFields = internalOptions.IncludeVectors ? null : this._dataStoragePropertyNamesWithScore; byte[] vectorBytes = RedisVectorStoreCollectionSearchMapping.ValidateVectorAndConvertToBytes(vector, "HashSet"); var query = RedisVectorStoreCollectionSearchMapping.BuildQuery(vectorBytes, internalOptions, this._propertyReader.StoragePropertyNamesMap, this._propertyReader.FirstVectorPropertyStoragePropertyName!, selectFields); var results = await this.RunOperationAsync( @@ -369,7 +369,11 @@ public async Task> VectorizedSearchAsync(T return this._mapper.MapFromStorageToDataModel((this.RemoveKeyPrefixIfNeeded(result.Id), retrievedHashEntries), new() { IncludeVectors = internalOptions.IncludeVectors }); }); - return new VectorSearchResult(dataModel, result.Score); + // Process the score of the result item. + var distanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(internalOptions, this._propertyReader.VectorProperties, this._propertyReader.VectorProperty!); + var score = RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(result["vector_score"].HasValue ? (float)result["vector_score"] : null, distanceFunction); + + return new VectorSearchResult(dataModel, score); }); return new VectorSearchResults(mappedResults.ToAsyncEnumerable()); diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs index b3467b12abb6..08fb1155ee60 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs @@ -410,7 +410,11 @@ public async Task> VectorizedSearchAsync(T new() { IncludeVectors = internalOptions.IncludeVectors }); }); - return new VectorSearchResult(mappedRecord, result.Score); + // Process the score of the result item. + var distanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(internalOptions, this._propertyReader.VectorProperties, this._propertyReader.VectorProperty!); + var score = RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(result["vector_score"].HasValue ? (float)result["vector_score"] : null, distanceFunction); + + return new VectorSearchResult(mappedRecord, score); }); return new VectorSearchResults(mappedResults.ToAsyncEnumerable()); diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionCreateMapping.cs index 86da3a800f6f..cec458e99c2d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionCreateMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionCreateMapping.cs @@ -177,6 +177,7 @@ public static string GetSDKDistanceAlgorithm(VectorStoreRecordVectorProperty vec return vectorProperty.DistanceFunction switch { DistanceFunction.CosineSimilarity => "COSINE", + DistanceFunction.CosineDistance => "COSINE", DistanceFunction.DotProductSimilarity => "IP", DistanceFunction.EuclideanDistance => "L2", _ => throw new InvalidOperationException($"Distance function '{vectorProperty.DistanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Redis VectorStore.") diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs index bd2d504ba5d0..9171cf4e389d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs @@ -117,6 +117,53 @@ public static string BuildFilter(VectorSearchFilter? basicVectorSearchFilter, IR return $"({string.Join(" ", filterClauses)})"; } + /// + /// Resolve the distance function to use for a search by checking the distance function of the vector property specified in options + /// or by falling back to the distance function of the first vector property, or by falling back to the default distance function. + /// + /// The search options potentially containing a vector field to search. + /// The list of all vector properties. + /// The first vector property in the record. + /// The distance function for the vector we want to search. + /// Thrown when a user asked for a vector property that doesn't exist on the record. + public static string ResolveDistanceFunction(VectorSearchOptions options, IReadOnlyList vectorProperties, VectorStoreRecordVectorProperty firstVectorProperty) + { + if (options.VectorPropertyName == null || vectorProperties.Count == 1) + { + return firstVectorProperty.DistanceFunction ?? DistanceFunction.CosineSimilarity; + } + + var vectorProperty = vectorProperties.FirstOrDefault(p => p.DataModelPropertyName == options.VectorPropertyName) + ?? throw new InvalidOperationException($"The collection does not have a vector field named '{options.VectorPropertyName}'."); + + return vectorProperty.DistanceFunction ?? DistanceFunction.CosineSimilarity; + } + + /// + /// Convert the score from redis into the appropriate output score based on the distance function. + /// Redis doesn't support Cosine Similarity, so we need to convert from distance to similarity if it was chosen. + /// + /// The redis score to convert. + /// The distance function used in the search. + /// The converted score. + /// Thrown if the provided distance function is not supported by redis. + public static float? GetOutputScoreFromRedisScore(float? redisScore, string distanceFunction) + { + if (redisScore is null) + { + return null; + } + + return distanceFunction switch + { + DistanceFunction.CosineSimilarity => 1 - redisScore, + DistanceFunction.CosineDistance => redisScore, + DistanceFunction.DotProductSimilarity => redisScore, + DistanceFunction.EuclideanDistance => redisScore, + _ => throw new InvalidOperationException($"The distance function '{distanceFunction}' is not supported."), + }; + } + /// /// Resolve the vector field name to use for a search by using the storage name for the field name from options /// if available, and falling back to the first vector field name if not. diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs index a34eae25bab3..5457582661ee 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs @@ -427,7 +427,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool inc { RedisResult.Create(new RedisValue("1")), RedisResult.Create(new RedisValue(TestRecordKey1)), - RedisResult.Create(new RedisValue("0.5")), + RedisResult.Create(new RedisValue("0.8")), RedisResult.Create( [ new RedisValue("OriginalNameData"), @@ -436,6 +436,8 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool inc new RedisValue("data 1"), new RedisValue("vector_storage_name"), RedisValue.Unbox(MemoryMarshal.AsBytes(new ReadOnlySpan(new float[] { 1, 2, 3, 4 })).ToArray()), + new RedisValue("vector_score"), + new RedisValue("0.25"), ]), }); var sut = this.CreateRecordCollection(useDefinition); @@ -468,9 +470,10 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool inc var returnArgs = includeVectors ? Array.Empty() : new object[] { "RETURN", - 2, + 3, "OriginalNameData", - "data_storage_name" + "data_storage_name", + "vector_score" }; var expectedArgsPart2 = new object[] { @@ -493,7 +496,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool inc var results = await actual.Results.ToListAsync(); Assert.Single(results); Assert.Equal(TestRecordKey1, results.First().Record.Key); - Assert.Equal(0.5d, results.First().Score); + Assert.Equal(0.25d, results.First().Score); Assert.Equal("original data 1", results.First().Record.OriginalNameData); Assert.Equal("data 1", results.First().Record.Data); if (includeVectors) @@ -613,7 +616,7 @@ private static SinglePropsModel CreateModel(string key, bool withVectors) new VectorStoreRecordKeyProperty("Key", typeof(string)), new VectorStoreRecordDataProperty("OriginalNameData", typeof(string)), new VectorStoreRecordDataProperty("Data", typeof(string)) { StoragePropertyName = "data_storage_name" }, - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) { StoragePropertyName = "vector_storage_name" } + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) { StoragePropertyName = "vector_storage_name", DistanceFunction = DistanceFunction.CosineDistance } ] }; @@ -630,7 +633,7 @@ public sealed class SinglePropsModel public string Data { get; set; } = string.Empty; [JsonPropertyName("ignored_vector_json_name")] - [VectorStoreRecordVector(4, StoragePropertyName = "vector_storage_name")] + [VectorStoreRecordVector(4, DistanceFunction.CosineDistance, StoragePropertyName = "vector_storage_name")] public ReadOnlyMemory? Vector { get; set; } public string? NotAnnotated { get; set; } diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs index 477e08bfca73..20d1b0da5831 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs @@ -450,8 +450,14 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition) { RedisResult.Create(new RedisValue("1")), RedisResult.Create(new RedisValue(TestRecordKey1)), - RedisResult.Create(new RedisValue("0.5")), - RedisResult.Create([new RedisValue("$"), new RedisValue(jsonResult)]), + RedisResult.Create(new RedisValue("0.8")), + RedisResult.Create( + [ + new RedisValue("$"), + new RedisValue(jsonResult), + new RedisValue("vector_score"), + new RedisValue("0.25") + ]), }); var sut = this.CreateRecordCollection(useDefinition); @@ -496,7 +502,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition) var results = await actual.Results.ToListAsync(); Assert.Single(results); Assert.Equal(TestRecordKey1, results.First().Record.Key); - Assert.Equal(0.5d, results.First().Score); + Assert.Equal(0.25d, results.First().Score); Assert.Equal("data 1", results.First().Record.Data1); Assert.Equal("data 2", results.First().Record.Data2); Assert.Equal(new float[] { 1, 2, 3, 4 }, results.First().Record.Vector1!.Value.ToArray()); @@ -617,7 +623,7 @@ private static MultiPropsModel CreateModel(string key, bool withVectors) new VectorStoreRecordKeyProperty("Key", typeof(string)), new VectorStoreRecordDataProperty("Data1", typeof(string)) { IsFilterable = true, StoragePropertyName = "ignored_data1_storage_name" }, new VectorStoreRecordDataProperty("Data2", typeof(string)) { IsFilterable = true }, - new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory)) { Dimensions = 4, StoragePropertyName = "ignored_vector1_storage_name" }, + new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory)) { Dimensions = 4, DistanceFunction = DistanceFunction.CosineDistance, StoragePropertyName = "ignored_vector1_storage_name" }, new VectorStoreRecordVectorProperty("Vector2", typeof(ReadOnlyMemory)) { Dimensions = 4 } ] }; @@ -635,7 +641,7 @@ public sealed class MultiPropsModel public string Data2 { get; set; } = string.Empty; [JsonPropertyName("vector1_json_name")] - [VectorStoreRecordVector(4, StoragePropertyName = "ignored_vector1_storage_name")] + [VectorStoreRecordVector(4, DistanceFunction.CosineDistance, StoragePropertyName = "ignored_vector1_storage_name")] public ReadOnlyMemory? Vector1 { get; set; } [VectorStoreRecordVector(4)] diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs index f6f11fdd73bc..5fb91154caf1 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs @@ -204,4 +204,60 @@ public void BuildFilterThrowsForUnknownFieldName() var filter = RedisVectorStoreCollectionSearchMapping.BuildFilter(basicVectorSearchFilter, storagePropertyNames); }); } + + [Fact] + public void ResolveDistanceFunctionReturnsCosineSimilarityIfNoDistanceFunctionSpecified() + { + var property = new VectorStoreRecordVectorProperty("Prop", typeof(ReadOnlyMemory)); + + // Act. + var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions(), [property], property); + + // Assert. + Assert.Equal(DistanceFunction.CosineSimilarity, resolvedDistanceFunction); + } + + [Fact] + public void ResolveDistanceFunctionReturnsDistanceFunctionFromFirstPropertyIfNoFieldChosen() + { + var property = new VectorStoreRecordVectorProperty("Prop", typeof(ReadOnlyMemory)) { DistanceFunction = DistanceFunction.DotProductSimilarity }; + + // Act. + var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions(), [property], property); + + // Assert. + Assert.Equal(DistanceFunction.DotProductSimilarity, resolvedDistanceFunction); + } + + [Fact] + public void ResolveDistanceFunctionReturnsDistanceFunctionFromChosenPropertyIfFieldChosen() + { + var property1 = new VectorStoreRecordVectorProperty("Prop1", typeof(ReadOnlyMemory)) { DistanceFunction = DistanceFunction.CosineDistance }; + var property2 = new VectorStoreRecordVectorProperty("Prop2", typeof(ReadOnlyMemory)) { DistanceFunction = DistanceFunction.DotProductSimilarity }; + + // Act. + var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions() { VectorPropertyName = "Prop2" }, [property1, property2], property1); + + // Assert. + Assert.Equal(DistanceFunction.DotProductSimilarity, resolvedDistanceFunction); + } + + [Fact] + public void GetOutputScoreFromRedisScoreConvertsCosineDistanceToSimilarity() + { + // Act & Assert. + Assert.Equal(-1, RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(2, DistanceFunction.CosineSimilarity)); + Assert.Equal(0, RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(1, DistanceFunction.CosineSimilarity)); + Assert.Equal(1, RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(0, DistanceFunction.CosineSimilarity)); + } + + [Theory] + [InlineData(DistanceFunction.CosineDistance, 2)] + [InlineData(DistanceFunction.DotProductSimilarity, 2)] + [InlineData(DistanceFunction.EuclideanDistance, 2)] + public void GetOutputScoreFromRedisScoreLeavesNonConsineSimiliartyUntouched(string distanceFunction, float score) + { + // Act & Assert. + Assert.Equal(score, RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(score, distanceFunction)); + } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs index d5d807781807..4fff25413c5c 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs @@ -84,6 +84,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); + Assert.Equal(1, searchResults.First().Score); var searchResultRecord = searchResults.First().Record; Assert.Equal(record.HotelId, searchResultRecord?.HotelId); Assert.Equal(record.HotelName, searchResultRecord?.HotelName); @@ -325,6 +326,7 @@ public async Task ItCanSearchWithFloat32VectorAndFilterAsync(string filterType, // Assert var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); + Assert.Equal(1, searchResults.First().Score); var searchResult = searchResults.First().Record; Assert.Equal("HBaseSet-1", searchResult?.HotelId); Assert.Equal("My Hotel 1", searchResult?.HotelName); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs index 2ed69bc63055..0667f8328983 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs @@ -88,6 +88,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); + Assert.Equal(1, searchResults.First().Score); var searchResultRecord = searchResults.First().Record; Assert.Equal(record.HotelId, searchResultRecord?.HotelId); Assert.Equal(record.HotelName, searchResultRecord?.HotelName); @@ -101,7 +102,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe Assert.Equal(record.Address.City, searchResultRecord?.Address.City); Assert.Equal(record.Description, searchResultRecord?.Description); Assert.Equal(record.DescriptionEmbedding?.ToArray(), searchResultRecord?.DescriptionEmbedding?.ToArray()); - + // Output output.WriteLine(collectionExistResult.ToString()); output.WriteLine(upsertResult); @@ -351,6 +352,7 @@ public async Task ItCanSearchWithFloat32VectorAndFilterAsync(string filterType) // Assert var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); + Assert.Equal(1, searchResults.First().Score); var searchResult = searchResults.First().Record; Assert.Equal("My Hotel 1", searchResults.First().Record.HotelName); Assert.Equal("BaseSet-1", searchResult?.HotelId);