diff --git a/index_alias_impl.go b/index_alias_impl.go index ccb52f244..4942eb1ab 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -165,7 +165,6 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest if len(i.indexes) == 1 { return i.indexes[0].SearchInContext(ctx, req) } - return MultiSearch(ctx, req, i.indexes...) } @@ -453,6 +452,10 @@ func MultiSearch(ctx context.Context, req *SearchRequest, indexes ...Index) (*Se req.SearchAfter = req.SearchBefore req.SearchBefore = nil } + originalSize := req.Size + if len(indexes) > 1 { + req.Size = adjustRequestSizeForKNN(req, len(indexes)) + } // run search on each index in separate go routine var waitGroup sync.WaitGroup @@ -491,6 +494,7 @@ func MultiSearch(ctx context.Context, req *SearchRequest, indexes ...Index) (*Se indexErrors[asr.Name] = asr.Err } } + req.Size = originalSize // merge just concatenated all the hits // now lets clean it up @@ -504,6 +508,10 @@ func MultiSearch(ctx context.Context, req *SearchRequest, indexes ...Index) (*Se } } + if len(indexes) > 1 { + mergeKNNResults(req, sr) + } + sortFunc := req.SortFunc() // sort all hits with the requested order if len(req.Sort) > 0 { diff --git a/index_impl.go b/index_impl.go index 5c9538822..bda58b7c4 100644 --- a/index_impl.go +++ b/index_impl.go @@ -642,7 +642,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr req.SearchAfter = nil } - return &SearchResult{ + rv := &SearchResult{ Status: &SearchStatus{ Total: 1, Successful: 1, @@ -653,7 +653,9 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr MaxScore: coll.MaxScore(), Took: searchDuration, Facets: coll.FacetResults(), - }, nil + } + mergeKNNResults(req, rv) + return rv, nil } func LoadAndHighlightFields(hit *search.DocumentMatch, req *SearchRequest, diff --git a/knn_test.go b/knn_test.go new file mode 100644 index 000000000..32e2ff62d --- /dev/null +++ b/knn_test.go @@ -0,0 +1,846 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build vectors +// +build vectors + +package bleve + +import ( + "archive/zip" + "encoding/json" + "math" + "math/rand" + "testing" + + "github.com/blevesearch/bleve/v2/mapping" +) + +const testInputCompressedFile = "test/knn/knn_dataset_queries.zip" +const testDatasetFileName = "knn_dataset.json" +const testQueryFileName = "knn_queries.json" + +const testDatasetDims = 384 + +type testDocument struct { + ID string `json:"id"` + Content string `json:"content"` + Vector []float64 `json:"vector"` +} + +func readDatasetAndQueries(fileName string) ([]testDocument, []*SearchRequest, error) { + // Open the zip archive for reading + r, err := zip.OpenReader(fileName) + if err != nil { + return nil, nil, err + } + var dataset []testDocument + var queries []*SearchRequest + + defer r.Close() + for _, f := range r.File { + jsonFile, err := f.Open() + if err != nil { + return nil, nil, err + } + defer jsonFile.Close() + if f.Name == testDatasetFileName { + err = json.NewDecoder(jsonFile).Decode(&dataset) + if err != nil { + return nil, nil, err + } + } else if f.Name == testQueryFileName { + err = json.NewDecoder(jsonFile).Decode(&queries) + if err != nil { + return nil, nil, err + } + } + } + return dataset, queries, nil +} + +func makeDatasetIntoDocuments(dataset []testDocument) []map[string]interface{} { + documents := make([]map[string]interface{}, len(dataset)) + for i := 0; i < len(dataset); i++ { + document := make(map[string]interface{}) + document["id"] = dataset[i].ID + document["content"] = dataset[i].Content + document["vector"] = dataset[i].Vector + documents[i] = document + } + return documents +} + +func cleanUp(t *testing.T, indexPaths []string, indexes ...Index) { + for _, childIndex := range indexes { + err := childIndex.Close() + if err != nil { + t.Fatal(err) + } + } + for _, indexPath := range indexPaths { + cleanupTmpIndexPath(t, indexPath) + } +} + +func createPartitionedIndex(documents []map[string]interface{}, index *indexAliasImpl, numPartitions int, + mapping mapping.IndexMapping, t *testing.T) []string { + + partitionSize := len(documents) / numPartitions + extraDocs := len(documents) % numPartitions + docsPerPartition := make([]int, numPartitions) + for i := 0; i < numPartitions; i++ { + docsPerPartition[i] = partitionSize + if extraDocs > 0 { + docsPerPartition[i]++ + extraDocs-- + } + } + var rv []string + prevCutoff := 0 + for i := 0; i < numPartitions; i++ { + tmpIndexPath := createTmpIndexPath(t) + rv = append(rv, tmpIndexPath) + childIndex, err := New(tmpIndexPath, mapping) + if err != nil { + cleanUp(t, rv) + t.Fatal(err) + } + batch := childIndex.NewBatch() + for j := prevCutoff; j < prevCutoff+docsPerPartition[i]; j++ { + doc := documents[j] + err := batch.Index(doc["id"].(string), doc) + if err != nil { + cleanUp(t, rv) + t.Fatal(err) + } + } + prevCutoff += docsPerPartition[i] + err = childIndex.Batch(batch) + if err != nil { + cleanUp(t, rv) + t.Fatal(err) + } + index.Add(childIndex) + } + return rv +} + +func createMultipleSegmentsIndex(documents []map[string]interface{}, index Index, numSegments int) error { + // create multiple batches to simulate more than one segment + numBatches := numSegments + + batches := make([]*Batch, numBatches) + numDocsPerBatch := len(documents) / numBatches + extraDocs := len(documents) % numBatches + + docsPerBatch := make([]int, numBatches) + for i := 0; i < numBatches; i++ { + docsPerBatch[i] = numDocsPerBatch + if extraDocs > 0 { + docsPerBatch[i]++ + extraDocs-- + } + } + prevCutoff := 0 + for i := 0; i < numBatches; 
i++ { + batches[i] = index.NewBatch() + for j := prevCutoff; j < prevCutoff+docsPerBatch[i]; j++ { + doc := documents[j] + err := batches[i].Index(doc["id"].(string), doc) + if err != nil { + return err + } + } + prevCutoff += docsPerBatch[i] + } + for _, batch := range batches { + err := index.Batch(batch) + if err != nil { + return err + } + } + return nil +} + +// Fisher-Yates shuffle +func shuffleDocuments(documents []map[string]interface{}) []map[string]interface{} { + for i := range documents { + j := i + rand.Intn(len(documents)-i) + documents[i], documents[j] = documents[j], documents[i] + } + return documents +} + +func truncateScore(score float64) float64 { + return float64(int(score*1e6)) / 1e6 +} + +func TestSimilaritySearchRandomized(t *testing.T) { + runKNNTest(t, true) +} + +func TestSimilaritySearchNotRandomized(t *testing.T) { + runKNNTest(t, false) +} + +func runKNNTest(t *testing.T, randomizeDocuments bool) { + dataset, searchRequests, err := readDatasetAndQueries(testInputCompressedFile) + if err != nil { + t.Fatal(err) + } + documents := makeDatasetIntoDocuments(dataset) + if randomizeDocuments { + documents = shuffleDocuments(documents) + } + + indexMapping := NewIndexMapping() + contentFieldMapping := NewTextFieldMapping() + contentFieldMapping.Analyzer = "en" + + vecFieldMapping := mapping.NewVectorFieldMapping() + vecFieldMapping.Dims = testDatasetDims + vecFieldMapping.Similarity = "l2_norm" + + indexMapping.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping) + indexMapping.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMapping) + + index := NewIndexAlias() + + type testResult struct { + score float64 + scoreBreakdown []float64 + } + + type testCase struct { + testType string + queryIndex int + numIndexPartitions int + expectedResults map[string]testResult + } + + testCases := []testCase{ + { + testType: "single_partition:match_none:oneKNNreq:k=3", + queryIndex: 0, + numIndexPartitions: 1, + expectedResults: map[string]testResult{ + "doc29": { + score: 0.5547758085810349, + scoreBreakdown: []float64{0, 1.1095516171620698}, + }, + "doc23": { + score: 0.3817633037007331, + scoreBreakdown: []float64{0, 0.7635266074014662}, + }, + "doc28": { + score: 0.33983667469689355, + scoreBreakdown: []float64{0, 0.6796733493937871}, + }, + }, + }, + { + testType: "multi_partition:match_none:oneKNNreq:k=3", + queryIndex: 0, + numIndexPartitions: 4, + expectedResults: map[string]testResult{ + "doc29": { + score: 0.5547758085810349, + scoreBreakdown: []float64{0, 1.1095516171620698}, + }, + "doc23": { + score: 0.3817633037007331, + scoreBreakdown: []float64{0, 0.7635266074014662}, + }, + "doc28": { + score: 0.33983667469689355, + scoreBreakdown: []float64{0, 0.6796733493937871}, + }, + }, + }, + { + testType: "multi_partition:match_none:oneKNNreq:k=2", + queryIndex: 0, + numIndexPartitions: 10, + expectedResults: map[string]testResult{ + "doc29": { + score: 0.554775, + scoreBreakdown: []float64{0, 1.109551}, + }, + "doc23": { + score: 0.381763, + scoreBreakdown: []float64{0, 0.763526}, + }, + "doc28": { + score: 0.339836, + scoreBreakdown: []float64{0, 0.679673}, + }, + }, + }, + { + testType: "single_partition:match:oneKNNreq:k=2", + queryIndex: 1, + numIndexPartitions: 1, + expectedResults: map[string]testResult{ + "doc29": { + score: 1.8859816084399936, + scoreBreakdown: []float64{0.7764299912779237, 1.1095516171620698}, + }, + "doc23": { + score: 1.8615644255330264, + scoreBreakdown: []float64{1.0980378181315602, 0.7635266074014662}, + }, + "doc27": { + score:
0.4640056648691007, + scoreBreakdown: []float64{0.9280113297382014, 0}, + }, + "doc28": { + score: 0.434037555556026, + scoreBreakdown: []float64{0.868075111112052, 0}, + }, + "doc30": { + score: 0.38821499563896184, + scoreBreakdown: []float64{0.7764299912779237, 0}, + }, + "doc24": { + score: 0.38821499563896184, + scoreBreakdown: []float64{0.7764299912779237, 0}, + }, + }, + }, + { + testType: "multi_partition:match:oneKNNreq:k=2", + queryIndex: 1, + numIndexPartitions: 5, + expectedResults: map[string]testResult{ + "doc23": { + score: 1.5207250366637521, + scoreBreakdown: []float64{0.7571984292622859, 0.7635266074014662}, + }, + "doc29": { + score: 1.4834345192674083, + scoreBreakdown: []float64{0.3738829021053385, 1.1095516171620698}, + }, + "doc24": { + score: 0.2677100734235977, + scoreBreakdown: []float64{0.5354201468471954, 0}, + }, + "doc27": { + score: 0.22343776840593196, + scoreBreakdown: []float64{0.4468755368118639, 0}, + }, + "doc28": { + score: 0.20900689401100958, + scoreBreakdown: []float64{0.41801378802201916, 0}, + }, + "doc30": { + score: 0.18694145105266924, + scoreBreakdown: []float64{0.3738829021053385, 0}, + }, + }, + }, + { + testType: "single_partition:disjunction:twoKNNreq:k=2,2", + queryIndex: 2, + numIndexPartitions: 1, + expectedResults: map[string]testResult{ + "doc7": { + score: math.MaxFloat64, + scoreBreakdown: []float64{0, 0, math.MaxFloat64 / 3.0}, + }, + "doc29": { + score: 0.6774608026082964, + scoreBreakdown: []float64{0.23161973134064517, 0.7845714725717996, 0}, + }, + "doc23": { + score: 0.5783030702431613, + scoreBreakdown: []float64{0.32755976365480655, 0.5398948417099355, 0}, + }, + "doc3": { + score: 0.2550334160459894, + scoreBreakdown: []float64{0.7651002481379682, 0, 0}, + }, + "doc13": { + score: 0.2208654210738964, + scoreBreakdown: []float64{0.6625962632216892, 0, 0}, + }, + "doc5": { + score: 0.21180931116413285, + scoreBreakdown: []float64{0, 0, 0.6354279334923986}, + }, + "doc27": { + score: 0.09227950890170131, + scoreBreakdown: []float64{0.27683852670510395, 0, 0}, + }, + "doc28": { + score: 0.0863195764709126, + scoreBreakdown: []float64{0.2589587294127378, 0, 0}, + }, + "doc30": { + score: 0.07720657711354839, + scoreBreakdown: []float64{0.23161973134064517, 0, 0}, + }, + "doc24": { + score: 0.07720657711354839, + scoreBreakdown: []float64{0.23161973134064517, 0, 0}, + }, + }, + }, + { + testType: "multi_partition:disjunction:twoKNNreq:k=2,2", + queryIndex: 2, + numIndexPartitions: 4, + expectedResults: map[string]testResult{ + "doc7": { + score: math.MaxFloat64, + scoreBreakdown: []float64{0, 0, math.MaxFloat64 / 3.0}, + }, + "doc29": { + score: 0.567426591648309, + scoreBreakdown: []float64{0.06656841490066398, 0.7845714725717996, 0}, + }, + "doc23": { + score: 0.5639255136185979, + scoreBreakdown: []float64{0.3059934287179615, 0.5398948417099355, 0}, + }, + "doc5": { + score: 0.21180931116413285, + scoreBreakdown: []float64{0, 0, 0.6354279334923986}, + }, + "doc3": { + score: 0.14064944169372873, + scoreBreakdown: []float64{0.42194832508118624, 0, 0}, + }, + "doc13": { + score: 0.12180599172106943, + scoreBreakdown: []float64{0.3654179751632083, 0, 0}, + }, + "doc27": { + score: 0.026521491065731144, + scoreBreakdown: []float64{0.07956447319719344, 0, 0}, + }, + "doc28": { + score: 0.024808583220893122, + scoreBreakdown: []float64{0.07442574966267937, 0, 0}, + }, + "doc30": { + score: 0.02218947163355466, + scoreBreakdown: []float64{0.06656841490066398, 0, 0}, + }, + "doc24": { + score: 0.02218947163355466, + scoreBreakdown: 
[]float64{0.06656841490066398, 0, 0}, + }, + }, + }, + { + // control: + // from = 0 + // size = 8 + testType: "pagination", + queryIndex: 3, + numIndexPartitions: 4, + expectedResults: map[string]testResult{ + "doc24": { + score: 1.22027994094805, + scoreBreakdown: []float64{0.027736154383370196, 0.3471022633855392, 0.5085619451465123, 0.33687957803262836}, + }, + "doc17": { + score: 0.7851856993753307, + scoreBreakdown: []float64{0.3367753689069724, 0, 0.3892791754255179, 0.320859721501284}, + }, + "doc21": { + score: 0.5927148028393034, + scoreBreakdown: []float64{0.06974846263723515, 0, 0.3914133076090359, 0.3291246335394669}, + }, + "doc14": { + score: 0.45680756875853035, + scoreBreakdown: []float64{0.5968461853543279, 0, 0, 0.31676895216273276}, + }, + "doc25": { + score: 0.292014972318407, + scoreBreakdown: []float64{0.17861510907524708, 0, 0.405414835561567, 0}, + }, + "doc23": { + score: 0.24706850662359503, + scoreBreakdown: []float64{0.09761951136424651, 0, 0.39651750188294355, 0}, + }, + "doc15": { + score: 0.24489276164017085, + scoreBreakdown: []float64{0.17216818679645968, 0, 0, 0.317617336483882}, + }, + "doc5": { + score: 0.10331722282971788, + scoreBreakdown: []float64{0, 0.4132688913188715, 0, 0}, + }, + }, + }, + { + // experimental: + // from = 0 + // size = 3 + testType: "pagination", + queryIndex: 4, + numIndexPartitions: 4, + expectedResults: map[string]testResult{ + "doc24": { + score: 1.22027994094805, + scoreBreakdown: []float64{0.027736154383370196, 0.3471022633855392, 0.5085619451465123, 0.33687957803262836}, + }, + "doc17": { + score: 0.7851856993753307, + scoreBreakdown: []float64{0.3367753689069724, 0, 0.3892791754255179, 0.320859721501284}, + }, + "doc21": { + score: 0.5927148028393034, + scoreBreakdown: []float64{0.06974846263723515, 0, 0.3914133076090359, 0.3291246335394669}, + }, + }, + }, + { + // from = 3 + // size = 3 + testType: "pagination", + queryIndex: 5, + numIndexPartitions: 4, + expectedResults: map[string]testResult{ + "doc14": { + score: 0.45680756875853035, + scoreBreakdown: []float64{0.5968461853543279, 0, 0, 0.31676895216273276}, + }, + "doc25": { + score: 0.292014972318407, + scoreBreakdown: []float64{0.17861510907524708, 0, 0.405414835561567, 0}, + }, + "doc23": { + score: 0.24706850662359503, + scoreBreakdown: []float64{0.09761951136424651, 0, 0.39651750188294355, 0}, + }, + }, + }, + } + + for testCaseNum, testCase := range testCases { + index.indexes = make([]Index, 0) + indexPaths := createPartitionedIndex(documents, index, testCase.numIndexPartitions, indexMapping, t) + query := searchRequests[testCase.queryIndex] + res, err := index.Search(query) + if err != nil { + t.Fatal(err) + } + if len(res.Hits) != len(testCase.expectedResults) { + t.Fatalf("testcase %d failed: expected %d results, got %d", testCaseNum, len(testCase.expectedResults), len(res.Hits)) + } + if randomizeDocuments && testCase.testType == "pagination" { + // pagination is not deterministic when documents are randomized + continue + } + for i, hit := range res.Hits { + var expectedHit testResult + var ok bool + if expectedHit, ok = testCase.expectedResults[hit.ID]; !ok { + t.Fatalf("testcase %d failed: unexpected result %s", testCaseNum, hit.ID) + } + // Truncate to 6 decimal places + actualScore := truncateScore(hit.Score) + expectScore := truncateScore(expectedHit.score) + if !randomizeDocuments && expectScore != actualScore { + t.Fatalf("testcase %d failed: expected hit %d to have score %f, got %f", testCaseNum, i, expectedHit.score, hit.Score) + } + if 
len(hit.ScoreBreakdown) != len(expectedHit.scoreBreakdown) { + t.Fatalf("testcase %d failed: expected hit %d to have %d score breakdowns, got %d", testCaseNum, i, len(expectedHit.scoreBreakdown), len(hit.ScoreBreakdown)) + } + if !randomizeDocuments { + actualScore := truncateScore(hit.ScoreBreakdown[0]) + expectScore := truncateScore(expectedHit.scoreBreakdown[0]) + if expectScore != actualScore { + t.Fatalf("testcase %d failed: expected hit %d to have score breakdown %f, got %f", testCaseNum, i, expectedHit.scoreBreakdown[0], hit.ScoreBreakdown[0]) + } + } + for j := 1; j < len(hit.ScoreBreakdown); j++ { + // Truncate to 6 decimal places + actualScore := truncateScore(hit.ScoreBreakdown[j]) + expectScore := truncateScore(expectedHit.scoreBreakdown[j]) + if expectScore != actualScore { + t.Fatalf("testcase %d failed: expected hit %d to have score breakdown %f, got %f", testCaseNum, i, expectedHit.scoreBreakdown[j], hit.ScoreBreakdown[j]) + } + } + } + cleanUp(t, indexPaths, index.indexes...) + } +} + +func TestSimilaritySearchMultipleSegments(t *testing.T) { + dataset, searchRequests, err := readDatasetAndQueries(testInputCompressedFile) + if err != nil { + t.Fatal(err) + } + documents := makeDatasetIntoDocuments(dataset) + + indexMapping := NewIndexMapping() + contentFieldMapping := NewTextFieldMapping() + contentFieldMapping.Analyzer = "en" + + vecFieldMapping := mapping.NewVectorFieldMapping() + vecFieldMapping.Dims = testDatasetDims + vecFieldMapping.Similarity = "l2_norm" + + indexMapping.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping) + indexMapping.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMapping) + + type testResult struct { + score float64 + scoreBreakdown []float64 + } + + testCases := []struct { + numSegments int + queryIndex int + expectedResults map[string]testResult + }{ + { + numSegments: 1, + queryIndex: 0, + expectedResults: map[string]testResult{ + "doc29": { + score: 0.5547758085810349, + scoreBreakdown: []float64{0, 1.1095516171620698}, + }, + "doc23": { + score: 0.3817633037007331, + scoreBreakdown: []float64{0, 0.7635266074014662}, + }, + "doc28": { + score: 0.33983667469689355, + scoreBreakdown: []float64{0, 0.6796733493937871}, + }, + }, + }, + { + numSegments: 6, + queryIndex: 0, + expectedResults: map[string]testResult{ + "doc29": { + score: 0.5547758085810349, + scoreBreakdown: []float64{0, 1.1095516171620698}, + }, + "doc23": { + score: 0.3817633037007331, + scoreBreakdown: []float64{0, 0.7635266074014662}, + }, + "doc28": { + score: 0.33983667469689355, + scoreBreakdown: []float64{0, 0.6796733493937871}, + }, + }, + }, + { + numSegments: 1, + queryIndex: 1, + expectedResults: map[string]testResult{ + "doc29": { + score: 1.8859816084399936, + scoreBreakdown: []float64{0.7764299912779237, 1.1095516171620698}, + }, + "doc23": { + score: 1.8615644255330264, + scoreBreakdown: []float64{1.0980378181315602, 0.7635266074014662}, + }, + "doc27": { + score: 0.4640056648691007, + scoreBreakdown: []float64{0.9280113297382014, 0}, + }, + "doc28": { + score: 0.434037555556026, + scoreBreakdown: []float64{0.868075111112052, 0}, + }, + "doc30": { + score: 0.38821499563896184, + scoreBreakdown: []float64{0.7764299912779237, 0}, + }, + "doc24": { + score: 0.38821499563896184, + scoreBreakdown: []float64{0.7764299912779237, 0}, + }, + }, + }, + { + numSegments: 7, + queryIndex: 1, + expectedResults: map[string]testResult{ + "doc29": { + score: 1.8859816084399936, + scoreBreakdown: []float64{0.7764299912779237, 1.1095516171620698}, + }, + "doc23": { + 
score: 1.8615644255330264, + scoreBreakdown: []float64{1.0980378181315602, 0.7635266074014662}, + }, + "doc27": { + score: 0.4640056648691007, + scoreBreakdown: []float64{0.9280113297382014, 0}, + }, + "doc28": { + score: 0.434037555556026, + scoreBreakdown: []float64{0.868075111112052, 0}, + }, + "doc30": { + score: 0.38821499563896184, + scoreBreakdown: []float64{0.7764299912779237, 0}, + }, + "doc24": { + score: 0.38821499563896184, + scoreBreakdown: []float64{0.7764299912779237, 0}, + }, + }, + }, + { + numSegments: 1, + queryIndex: 2, + expectedResults: map[string]testResult{ + "doc7": { + score: 2357.022603955158, + scoreBreakdown: []float64{0, 0, 7071.067811865475}, + }, + "doc29": { + score: 0.6774608026082964, + scoreBreakdown: []float64{0.23161973134064517, 0.7845714725717996, 0}, + }, + "doc23": { + score: 0.5783030702431613, + scoreBreakdown: []float64{0.32755976365480655, 0.5398948417099355, 0}, + }, + "doc3": { + score: 0.2550334160459894, + scoreBreakdown: []float64{0.7651002481379682, 0, 0}, + }, + "doc13": { + score: 0.2208654210738964, + scoreBreakdown: []float64{0.6625962632216892, 0, 0}, + }, + "doc5": { + score: 0.21180931116413285, + scoreBreakdown: []float64{0, 0, 0.6354279334923986}, + }, + "doc27": { + score: 0.09227950890170131, + scoreBreakdown: []float64{0.27683852670510395, 0, 0}, + }, + "doc28": { + score: 0.0863195764709126, + scoreBreakdown: []float64{0.2589587294127378, 0, 0}, + }, + "doc30": { + score: 0.07720657711354839, + scoreBreakdown: []float64{0.23161973134064517, 0, 0}, + }, + "doc24": { + score: 0.07720657711354839, + scoreBreakdown: []float64{0.23161973134064517, 0, 0}, + }, + }, + }, + { + numSegments: 6, + queryIndex: 2, + expectedResults: map[string]testResult{ + "doc7": { + score: 2357.022603955158, + scoreBreakdown: []float64{0, 0, 7071.067811865475}, + }, + "doc29": { + score: 0.6774608026082964, + scoreBreakdown: []float64{0.23161973134064517, 0.7845714725717996, 0}, + }, + "doc23": { + score: 0.5783030702431613, + scoreBreakdown: []float64{0.32755976365480655, 0.5398948417099355, 0}, + }, + "doc3": { + score: 0.2550334160459894, + scoreBreakdown: []float64{0.7651002481379682, 0, 0}, + }, + "doc13": { + score: 0.2208654210738964, + scoreBreakdown: []float64{0.6625962632216892, 0, 0}, + }, + "doc5": { + score: 0.21180931116413285, + scoreBreakdown: []float64{0, 0, 0.6354279334923986}, + }, + "doc27": { + score: 0.09227950890170131, + scoreBreakdown: []float64{0.27683852670510395, 0, 0}, + }, + "doc28": { + score: 0.0863195764709126, + scoreBreakdown: []float64{0.2589587294127378, 0, 0}, + }, + "doc30": { + score: 0.07720657711354839, + scoreBreakdown: []float64{0.23161973134064517, 0, 0}, + }, + "doc24": { + score: 0.07720657711354839, + scoreBreakdown: []float64{0.23161973134064517, 0, 0}, + }, + }, + }, + } + for testCaseNum, testCase := range testCases { + tmpIndexPath := createTmpIndexPath(t) + index, err := New(tmpIndexPath, indexMapping) + if err != nil { + t.Fatal(err) + } + query := searchRequests[testCase.queryIndex] + err = createMultipleSegmentsIndex(documents, index, testCase.numSegments) + if err != nil { + t.Fatal(err) + } + res, err := index.Search(query) + if err != nil { + t.Fatal(err) + } + for i, hit := range res.Hits { + var expectedHit testResult + var ok bool + if expectedHit, ok = testCase.expectedResults[hit.ID]; !ok { + t.Fatalf("testcase %d failed: unexpected result %s", testCaseNum, hit.ID) + } + // Truncate to 6 decimal places + actualScore := truncateScore(hit.Score) + expectScore := 
truncateScore(expectedHit.score) + if expectScore != actualScore { + t.Fatalf("testcase %d failed: expected hit %d to have score %f, got %f", testCaseNum, i, expectedHit.score, hit.Score) + } + if len(hit.ScoreBreakdown) != len(expectedHit.scoreBreakdown) { + t.Fatalf("testcase %d failed: expected hit %d to have %d score breakdowns, got %d", testCaseNum, i, len(expectedHit.scoreBreakdown), len(hit.ScoreBreakdown)) + } + for j := 0; j < len(hit.ScoreBreakdown); j++ { + // Truncate to 6 decimal places + actualScore := truncateScore(hit.ScoreBreakdown[j]) + expectScore := truncateScore(expectedHit.scoreBreakdown[j]) + if expectScore != actualScore { + t.Fatalf("testcase %d failed: expected hit %d to have score breakdown %f, got %f", testCaseNum, i, expectedHit.scoreBreakdown[j], hit.ScoreBreakdown[j]) + } + } + } + err = index.Close() + if err != nil { + t.Fatal(err) + } + cleanupTmpIndexPath(t, tmpIndexPath) + } +} diff --git a/search/scorer/scorer_conjunction.go b/search/scorer/scorer_conjunction.go index f3c81a78c..c94aeebb8 100644 --- a/search/scorer/scorer_conjunction.go +++ b/search/scorer/scorer_conjunction.go @@ -42,15 +42,28 @@ func NewConjunctionQueryScorer(options search.SearcherOptions) *ConjunctionQuery } } -func (s *ConjunctionQueryScorer) Score(ctx *search.SearchContext, constituents []*search.DocumentMatch) *search.DocumentMatch { +func (s *ConjunctionQueryScorer) Score(ctx *search.SearchContext, constituents []*search.DocumentMatch, originalPositions []int) *search.DocumentMatch { var sum float64 var childrenExplanations []*search.Explanation if s.options.Explain { childrenExplanations = make([]*search.Explanation, len(constituents)) } - + scoreBreakdown := make([]float64, len(constituents)) for i, docMatch := range constituents { sum += docMatch.Score + if originalPositions != nil { + // used by the conjunction searcher: originalPositions holds each + // searcher's position prior to sorting, since the conjunction + // searcher sorts its searchers in order of their Count().
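+ // indexing by the original position keeps the score breakdown + // aligned with the order of the sub-queries in the query.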
+ scoreBreakdown[originalPositions[i]] = docMatch.Score + } else { + // the indexes of the searchers are their original positions; + // e.g. the boolean searcher also uses the conjunction scorer, + // with index 0 being the must (conjunction) searcher + // and index 1 being the should (disjunction) searcher + scoreBreakdown[i] = docMatch.Score + } if s.options.Explain { childrenExplanations[i] = docMatch.Expl } @@ -65,6 +78,7 @@ func (s *ConjunctionQueryScorer) Score(ctx *search.SearchContext, constituents [ rv := constituents[0] rv.Score = newScore rv.Expl = newExpl + rv.ScoreBreakdown = scoreBreakdown rv.FieldTermLocations = search.MergeFieldTermLocations( rv.FieldTermLocations, constituents[1:]) diff --git a/search/scorer/scorer_disjunction.go b/search/scorer/scorer_disjunction.go index 054e76fd4..db79025ff 100644 --- a/search/scorer/scorer_disjunction.go +++ b/search/scorer/scorer_disjunction.go @@ -43,15 +43,22 @@ func NewDisjunctionQueryScorer(options search.SearcherOptions) *DisjunctionQuery } } -func (s *DisjunctionQueryScorer) Score(ctx *search.SearchContext, constituents []*search.DocumentMatch, countMatch, countTotal int) *search.DocumentMatch { +func (s *DisjunctionQueryScorer) Score(ctx *search.SearchContext, constituents []*search.DocumentMatch, countMatch, countTotal int, + matchingIdxs []int, originalPositions []int) *search.DocumentMatch { + var sum float64 var childrenExplanations []*search.Explanation if s.options.Explain { childrenExplanations = make([]*search.Explanation, len(constituents)) } - + scoreBreakdown := make([]float64, countTotal) for i, docMatch := range constituents { sum += docMatch.Score + if originalPositions != nil { + scoreBreakdown[originalPositions[matchingIdxs[i]]] = docMatch.Score + } else { + scoreBreakdown[matchingIdxs[i]] = docMatch.Score + } if s.options.Explain { childrenExplanations[i] = docMatch.Expl } @@ -75,6 +82,7 @@ func (s *DisjunctionQueryScorer) Score(ctx *search.SearchContext, constituents [ // reuse constituents[0] as the return value rv := constituents[0] rv.Score = newScore + rv.ScoreBreakdown = scoreBreakdown rv.Expl = newExpl rv.FieldTermLocations = search.MergeFieldTermLocations( rv.FieldTermLocations, constituents[1:]) diff --git a/search/scorer/scorer_knn.go b/search/scorer/scorer_knn.go index 70724fa65..7623fec46 100644 --- a/search/scorer/scorer_knn.go +++ b/search/scorer/scorer_knn.go @@ -18,6 +18,7 @@ package scorer import ( + "math" "reflect" "github.com/blevesearch/bleve/v2/search" @@ -60,6 +61,10 @@ func NewKNNQueryScorer(queryVector []float32, queryField string, queryBoost floa } } +// Score assigned when knnMatch.Score = 0, which means the query +// vector and the indexed vector are exactly the same.
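+// A zero distance cannot be inverted into a similarity score +// (see the 1.0 / score computation below), so such hits are +// awarded the maximum representable float64 score instead.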
+const maxKNNScore = math.MaxFloat64 + func (sqs *KNNQueryScorer) Score(ctx *search.SearchContext, knnMatch *index.VectorDoc) *search.DocumentMatch { rv := ctx.DocumentMatchPool.Get() @@ -67,14 +72,18 @@ func (sqs *KNNQueryScorer) Score(ctx *search.SearchContext, if sqs.includeScore || sqs.options.Explain { var scoreExplanation *search.Explanation score := knnMatch.Score - if sqs.similarityMetric == index.EuclideanDistance { - // eucliden distances need to be inverted to work + // in case of euclidean distance being the distance metric, + // an exact vector (perfect match), would return distance = 0 + if score == 0 { + score = maxKNNScore + } else { + // euclidean distances need to be inverted to work with // tf-idf scoring score = 1.0 / score } // if the query weight isn't 1, multiply - if sqs.queryWeight != 1.0 { + if sqs.queryWeight != 1.0 && score != maxKNNScore { score = score * sqs.queryWeight } diff --git a/search/search.go b/search/search.go index b7a3c42ae..9b77251cf 100644 --- a/search/search.go +++ b/search/search.go @@ -173,6 +173,15 @@ type DocumentMatch struct { // not all sub-queries matched // if false, all the sub-queries matched PartialMatch bool `json:"partial_match,omitempty"` + + // used to indicate the sub-scores that combined to form the + // final score for this document match. This is only populated + // when the search request's query is a DisjunctionQuery + // or a ConjunctionQuery. The length of this slice will be + // the same as the number of sub-queries in the query. + // the order of the scores will match the order of the sub-queries + // in the query. + ScoreBreakdown []float64 `json:"score_breakdown,omitempty"` } func (dm *DocumentMatch) AddFieldValue(name string, value interface{}) { diff --git a/search/searcher/ordered_searchers_list.go b/search/searcher/ordered_searchers_list.go index f3e646e9d..4e9409224 100644 --- a/search/searcher/ordered_searchers_list.go +++ b/search/searcher/ordered_searchers_list.go @@ -18,18 +18,22 @@ import ( "github.com/blevesearch/bleve/v2/search" ) -type OrderedSearcherList []search.Searcher +type OrderedSearcherList struct { + searchers []search.Searcher + index []int +} // sort.Interface func (otrl OrderedSearcherList) Len() int { - return len(otrl) + return len(otrl.searchers) } func (otrl OrderedSearcherList) Less(i, j int) bool { - return otrl[i].Count() < otrl[j].Count() + return otrl.searchers[i].Count() < otrl.searchers[j].Count() } func (otrl OrderedSearcherList) Swap(i, j int) { - otrl[i], otrl[j] = otrl[j], otrl[i] + otrl.searchers[i], otrl.searchers[j] = otrl.searchers[j], otrl.searchers[i] + otrl.index[i], otrl.index[j] = otrl.index[j], otrl.index[i] } diff --git a/search/searcher/search_boolean.go b/search/searcher/search_boolean.go index bf207f810..5f2c7893e 100644 --- a/search/searcher/search_boolean.go +++ b/search/searcher/search_boolean.go @@ -274,7 +274,7 @@ func (s *BooleanSearcher) Next(ctx *search.SearchContext) (*search.DocumentMatch cons = s.matches[0:1] cons[0] = s.currShould } - rv = s.scorer.Score(ctx, cons) + rv = s.scorer.Score(ctx, cons, nil) err = s.advanceNextMust(ctx, rv) if err != nil { return nil, err @@ -284,7 +284,7 @@ func (s *BooleanSearcher) Next(ctx *search.SearchContext) (*search.DocumentMatch // match is OK anyway cons := s.matches[0:1] cons[0] = s.currMust - rv = s.scorer.Score(ctx, cons) + rv = s.scorer.Score(ctx, cons, nil) err = s.advanceNextMust(ctx, rv) if err != nil { return nil, err @@ -302,7 +302,7 @@ func (s *BooleanSearcher) Next(ctx *search.SearchContext) 
(*search.DocumentMatch cons = s.matches[0:1] cons[0] = s.currShould } - rv = s.scorer.Score(ctx, cons) + rv = s.scorer.Score(ctx, cons, nil) err = s.advanceNextMust(ctx, rv) if err != nil { return nil, err @@ -312,7 +312,7 @@ func (s *BooleanSearcher) Next(ctx *search.SearchContext) (*search.DocumentMatch // match is OK anyway cons := s.matches[0:1] cons[0] = s.currMust - rv = s.scorer.Score(ctx, cons) + rv = s.scorer.Score(ctx, cons, nil) err = s.advanceNextMust(ctx, rv) if err != nil { return nil, err diff --git a/search/searcher/search_conjunction.go b/search/searcher/search_conjunction.go index 19ef199ac..89a441d1e 100644 --- a/search/searcher/search_conjunction.go +++ b/search/searcher/search_conjunction.go @@ -16,7 +16,6 @@ package searcher import ( "context" - "math" "reflect" "sort" @@ -34,26 +33,34 @@ func init() { } type ConjunctionSearcher struct { - indexReader index.IndexReader - searchers OrderedSearcherList - queryNorm float64 - currs []*search.DocumentMatch - maxIDIdx int - scorer *scorer.ConjunctionQueryScorer - initialized bool - options search.SearcherOptions - bytesRead uint64 + indexReader index.IndexReader + searchers []search.Searcher + originalPos []int + queryNorm float64 + queryNormForKNN float64 + currs []*search.DocumentMatch + maxIDIdx int + scorer *scorer.ConjunctionQueryScorer + initialized bool + options search.SearcherOptions + bytesRead uint64 } func NewConjunctionSearcher(ctx context.Context, indexReader index.IndexReader, qsearchers []search.Searcher, options search.SearcherOptions) ( search.Searcher, error) { // build the sorted downstream searchers - searchers := make(OrderedSearcherList, len(qsearchers)) + sortedSearchers := &OrderedSearcherList{ + searchers: make([]search.Searcher, len(qsearchers)), + index: make([]int, len(qsearchers)), + } for i, searcher := range qsearchers { - searchers[i] = searcher + sortedSearchers.searchers[i] = searcher + sortedSearchers.index[i] = i } - sort.Sort(searchers) + sort.Sort(sortedSearchers) + searchers := sortedSearchers.searchers + originalPos := sortedSearchers.index // attempt the "unadorned" conjunction optimization only when we // do not need extra information like freq-norm's or term vectors @@ -70,6 +77,7 @@ func NewConjunctionSearcher(ctx context.Context, indexReader index.IndexReader, rv := ConjunctionSearcher{ indexReader: indexReader, options: options, + originalPos: originalPos, searchers: searchers, currs: make([]*search.DocumentMatch, len(searchers)), scorer: scorer.NewConjunctionQueryScorer(options), @@ -102,21 +110,9 @@ func (s *ConjunctionSearcher) Size() int { } } - return sizeInBytes -} + sizeInBytes += len(s.originalPos) * size.SizeOfInt -func (s *ConjunctionSearcher) computeQueryNorm() { - // first calculate sum of squared weights - sumOfSquaredWeights := 0.0 - for _, searcher := range s.searchers { - sumOfSquaredWeights += searcher.Weight() - } - // now compute query norm from this - s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) - // finally tell all the downstream searchers the norm - for _, searcher := range s.searchers { - searcher.SetQueryNorm(s.queryNorm) - } + return sizeInBytes } func (s *ConjunctionSearcher) initSearchers(ctx *search.SearchContext) error { @@ -207,7 +203,7 @@ OUTER: } // if we get here, a doc matched all readers, so score and add it - rv = s.scorer.Score(ctx, s.currs) + rv = s.scorer.Score(ctx, s.currs, s.originalPos) // we know all the searchers are pointing at the same thing // so they all need to be bumped diff --git 
a/search/searcher/search_disjunction_heap.go b/search/searcher/search_disjunction_heap.go index d36e30131..51e210804 100644 --- a/search/searcher/search_disjunction_heap.go +++ b/search/searcher/search_disjunction_heap.go @@ -18,7 +18,6 @@ import ( "bytes" "container/heap" "context" - "math" "reflect" "github.com/blevesearch/bleve/v2/search" @@ -39,22 +38,25 @@ func init() { } type SearcherCurr struct { - searcher search.Searcher - curr *search.DocumentMatch + searcher search.Searcher + curr *search.DocumentMatch + matchingIdx int } type DisjunctionHeapSearcher struct { indexReader index.IndexReader - numSearchers int - scorer *scorer.DisjunctionQueryScorer - min int - queryNorm float64 - initialized bool - searchers []search.Searcher - heap []*SearcherCurr + numSearchers int + scorer *scorer.DisjunctionQueryScorer + min int + queryNorm float64 + queryNormForKNN float64 + initialized bool + searchers []search.Searcher + heap []*SearcherCurr matching []*search.DocumentMatch + matchingIdxs []int matchingCurrs []*SearcherCurr bytesRead uint64 @@ -77,6 +79,7 @@ func newDisjunctionHeapSearcher(ctx context.Context, indexReader index.IndexRead min: int(min), matching: make([]*search.DocumentMatch, len(searchers)), matchingCurrs: make([]*SearcherCurr, len(searchers)), + matchingIdxs: make([]int, len(searchers)), heap: make([]*SearcherCurr, 0, len(searchers)), } rv.computeQueryNorm() @@ -101,24 +104,11 @@ func (s *DisjunctionHeapSearcher) Size() int { // since searchers and document matches already counted above sizeInBytes += len(s.matchingCurrs) * reflectStaticSizeSearcherCurr sizeInBytes += len(s.heap) * reflectStaticSizeSearcherCurr + sizeInBytes += len(s.matchingIdxs) * size.SizeOfInt return sizeInBytes } -func (s *DisjunctionHeapSearcher) computeQueryNorm() { - // first calculate sum of squared weights - sumOfSquaredWeights := 0.0 - for _, searcher := range s.searchers { - sumOfSquaredWeights += searcher.Weight() - } - // now compute query norm from this - s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) - // finally tell all the downstream searchers the norm - for _, searcher := range s.searchers { - searcher.SetQueryNorm(s.queryNorm) - } -} - func (s *DisjunctionHeapSearcher) initSearchers(ctx *search.SearchContext) error { // alloc a single block of SearcherCurrs block := make([]SearcherCurr, len(s.searchers)) @@ -132,6 +122,7 @@ func (s *DisjunctionHeapSearcher) initSearchers(ctx *search.SearchContext) error if curr != nil { block[i].searcher = searcher block[i].curr = curr + block[i].matchingIdx = i heap.Push(s, &block[i]) } } @@ -147,6 +138,7 @@ func (s *DisjunctionHeapSearcher) initSearchers(ctx *search.SearchContext) error func (s *DisjunctionHeapSearcher) updateMatches() error { matching := s.matching[:0] matchingCurrs := s.matchingCurrs[:0] + matchingIdxs := s.matchingIdxs[:0] if len(s.heap) > 0 { @@ -154,17 +146,20 @@ func (s *DisjunctionHeapSearcher) updateMatches() error { next := heap.Pop(s).(*SearcherCurr) matching = append(matching, next.curr) matchingCurrs = append(matchingCurrs, next) + matchingIdxs = append(matchingIdxs, next.matchingIdx) // now as long as top of heap matches, keep popping for len(s.heap) > 0 && bytes.Compare(next.curr.IndexInternalID, s.heap[0].curr.IndexInternalID) == 0 { next = heap.Pop(s).(*SearcherCurr) matching = append(matching, next.curr) matchingCurrs = append(matchingCurrs, next) + matchingIdxs = append(matchingIdxs, next.matchingIdx) } } s.matching = matching s.matchingCurrs = matchingCurrs + s.matchingIdxs = matchingIdxs return nil } @@ -199,7 
+194,7 @@ func (s *DisjunctionHeapSearcher) Next(ctx *search.SearchContext) ( found = true partialMatch := len(s.matching) != len(s.searchers) // score this match - rv = s.scorer.Score(ctx, s.matching, len(s.matching), s.numSearchers) + rv = s.scorer.Score(ctx, s.matching, len(s.matching), s.numSearchers, s.matchingIdxs, nil) rv.PartialMatch = partialMatch } diff --git a/search/searcher/search_disjunction_slice.go b/search/searcher/search_disjunction_slice.go index 0969c8cf3..e14598409 100644 --- a/search/searcher/search_disjunction_slice.go +++ b/search/searcher/search_disjunction_slice.go @@ -16,7 +16,6 @@ package searcher import ( "context" - "math" "reflect" "sort" @@ -34,17 +33,19 @@ func init() { } type DisjunctionSliceSearcher struct { - indexReader index.IndexReader - searchers OrderedSearcherList - numSearchers int - queryNorm float64 - currs []*search.DocumentMatch - scorer *scorer.DisjunctionQueryScorer - min int - matching []*search.DocumentMatch - matchingIdxs []int - initialized bool - bytesRead uint64 + indexReader index.IndexReader + searchers []search.Searcher + originalPos []int + numSearchers int + queryNorm float64 + queryNormForKNN float64 + currs []*search.DocumentMatch + scorer *scorer.DisjunctionQueryScorer + min int + matching []*search.DocumentMatch + matchingIdxs []int + initialized bool + bytesRead uint64 } func newDisjunctionSliceSearcher(ctx context.Context, indexReader index.IndexReader, @@ -55,16 +56,24 @@ func newDisjunctionSliceSearcher(ctx context.Context, indexReader index.IndexRea return nil, tooManyClausesErr("", len(qsearchers)) } // build the downstream searchers - searchers := make(OrderedSearcherList, len(qsearchers)) + sortedSearchers := &OrderedSearcherList{ + searchers: make([]search.Searcher, len(qsearchers)), + index: make([]int, len(qsearchers)), + } for i, searcher := range qsearchers { - searchers[i] = searcher + sortedSearchers.searchers[i] = searcher + sortedSearchers.index[i] = i } // sort the searchers - sort.Sort(sort.Reverse(searchers)) + sort.Sort(sort.Reverse(sortedSearchers)) // build our searcher + searchers := sortedSearchers.searchers + originalPos := sortedSearchers.index + rv := DisjunctionSliceSearcher{ indexReader: indexReader, searchers: searchers, + originalPos: originalPos, numSearchers: len(searchers), currs: make([]*search.DocumentMatch, len(searchers)), scorer: scorer.NewDisjunctionQueryScorer(options), @@ -97,24 +106,11 @@ func (s *DisjunctionSliceSearcher) Size() int { } sizeInBytes += len(s.matchingIdxs) * size.SizeOfInt + sizeInBytes += len(s.originalPos) * size.SizeOfInt return sizeInBytes } -func (s *DisjunctionSliceSearcher) computeQueryNorm() { - // first calculate sum of squared weights - sumOfSquaredWeights := 0.0 - for _, searcher := range s.searchers { - sumOfSquaredWeights += searcher.Weight() - } - // now compute query norm from this - s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) - // finally tell all the downstream searchers the norm - for _, searcher := range s.searchers { - searcher.SetQueryNorm(s.queryNorm) - } -} - func (s *DisjunctionSliceSearcher) initSearchers(ctx *search.SearchContext) error { var err error // get all searchers pointing at their first match @@ -199,7 +195,7 @@ func (s *DisjunctionSliceSearcher) Next(ctx *search.SearchContext) ( found = true partialMatch := len(s.matching) != len(s.searchers) // score this match - rv = s.scorer.Score(ctx, s.matching, len(s.matching), s.numSearchers) + rv = s.scorer.Score(ctx, s.matching, len(s.matching), s.numSearchers, s.matchingIdxs, 
s.originalPos) rv.PartialMatch = partialMatch } diff --git a/search/searcher/util_knn.go b/search/searcher/util_knn.go new file mode 100644 index 000000000..62f7f9b22 --- /dev/null +++ b/search/searcher/util_knn.go @@ -0,0 +1,78 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vectors +// +build vectors + +package searcher + +import ( + "math" + + "github.com/blevesearch/bleve/v2/search" +) + +// Utility function used by both the disjunction and conjunction +// searchers to compute the query norm. +// It follows a separate code path from the non-knn version because +// the weights of the KNN searchers must be kept separate from the +// weights of the remaining searchers, to make the knn score +// completely independent of tf-idf. +// The sumOfSquaredWeights depends on the tf-idf weights, so reusing +// the same value for the knn searchers would make the knn score +// dependent on tf-idf. +func computeQueryNorm(searchers []search.Searcher) (float64, float64) { + var queryNorm float64 + var queryNormForKNN float64 + // first calculate sum of squared weights + sumOfSquaredWeights := 0.0 + + sumOfSquaredWeightsForKNN := 0.0 + + for _, searcher := range searchers { + if knnSearcher, ok := searcher.(*KNNSearcher); ok { + sumOfSquaredWeightsForKNN += knnSearcher.Weight() + } else { + sumOfSquaredWeights += searcher.Weight() + } + } + // now compute query norm from this + if sumOfSquaredWeights != 0.0 { + queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) + } + if sumOfSquaredWeightsForKNN != 0.0 { + queryNormForKNN = 1.0 / math.Sqrt(sumOfSquaredWeightsForKNN) + } + // finally tell all the downstream searchers the norm + for _, searcher := range searchers { + if knnSearcher, ok := searcher.(*KNNSearcher); ok { + knnSearcher.SetQueryNorm(queryNormForKNN) + } else { + searcher.SetQueryNorm(queryNorm) + } + } + return queryNorm, queryNormForKNN +} + +func (s *DisjunctionSliceSearcher) computeQueryNorm() { + s.queryNorm, s.queryNormForKNN = computeQueryNorm(s.searchers) +} + +func (s *DisjunctionHeapSearcher) computeQueryNorm() { + s.queryNorm, s.queryNormForKNN = computeQueryNorm(s.searchers) +} + +func (s *ConjunctionSearcher) computeQueryNorm() { + s.queryNorm, s.queryNormForKNN = computeQueryNorm(s.searchers) +} diff --git a/search/searcher/util_no_knn.go b/search/searcher/util_no_knn.go new file mode 100644 index 000000000..abf7a32d8 --- /dev/null +++ b/search/searcher/util_no_knn.go @@ -0,0 +1,62 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !vectors +// +build !vectors + +package searcher + +import "math" + +func (s *DisjunctionSliceSearcher) computeQueryNorm() { + // first calculate sum of squared weights + sumOfSquaredWeights := 0.0 + for _, searcher := range s.searchers { + sumOfSquaredWeights += searcher.Weight() + } + // now compute query norm from this + s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) + // finally tell all the downstream searchers the norm + for _, searcher := range s.searchers { + searcher.SetQueryNorm(s.queryNorm) + } +} + +func (s *DisjunctionHeapSearcher) computeQueryNorm() { + // first calculate sum of squared weights + sumOfSquaredWeights := 0.0 + for _, searcher := range s.searchers { + sumOfSquaredWeights += searcher.Weight() + } + // now compute query norm from this + s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) + // finally tell all the downstream searchers the norm + for _, searcher := range s.searchers { + searcher.SetQueryNorm(s.queryNorm) + } +} + +func (s *ConjunctionSearcher) computeQueryNorm() { + // first calculate sum of squared weights + sumOfSquaredWeights := 0.0 + for _, searcher := range s.searchers { + sumOfSquaredWeights += searcher.Weight() + } + // now compute query norm from this + s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) + // finally tell all the downstream searchers the norm + for _, searcher := range s.searchers { + searcher.SetQueryNorm(s.queryNorm) + } +} diff --git a/search_knn.go b/search_knn.go index 0e20d8d99..53e8829de 100644 --- a/search_knn.go +++ b/search_knn.go @@ -18,6 +18,7 @@ package bleve import ( + "container/heap" "encoding/json" "fmt" "sort" @@ -173,3 +174,113 @@ func validateKNN(req *SearchRequest) error { } return nil } + +func mergeKNNResults(req *SearchRequest, sr *SearchResult) { + if len(req.KNN) > 0 { + mergeKNN(req, sr) + } +} + +func adjustRequestSizeForKNN(req *SearchRequest, numIndexPartitions int) int { + var adjustedSize int + if req != nil { + adjustedSize = req.Size + if len(req.KNN) > 0 { + var minSizeReq int64 + for _, knn := range req.KNN { + minSizeReq += knn.K + } + minSizeReq *= int64(numIndexPartitions) + if int64(adjustedSize) < minSizeReq { + adjustedSize = int(minSizeReq) + } + } + } + return adjustedSize +} + +// heap impl +type scoreHeap struct { + scoreBreakdown []*[]float64 + sortIndex int +} + +func (s *scoreHeap) Len() int { return len(s.scoreBreakdown) } + +func (s *scoreHeap) Less(i, j int) bool { + return (*s.scoreBreakdown[i])[s.sortIndex] > (*s.scoreBreakdown[j])[s.sortIndex] +} + +func (s *scoreHeap) Swap(i, j int) { + s.scoreBreakdown[i], s.scoreBreakdown[j] = s.scoreBreakdown[j], s.scoreBreakdown[i] +} + +func (s *scoreHeap) Push(x interface{}) { + s.scoreBreakdown = append(s.scoreBreakdown, x.(*[]float64)) +} + +func (s *scoreHeap) Pop() interface{} { + old := s.scoreBreakdown + n := len(old) + x := old[n-1] + s.scoreBreakdown = old[0 : n-1] + return x +} + +func mergeKNN(req *SearchRequest, sr *SearchResult) { + // index 0 of score breakdown is always tf-idf score + numKnnQuery := len(req.KNN) + maxHeap := &scoreHeap{ + scoreBreakdown: make([]*[]float64, 0), + } + for i := 0; i < numKnnQuery; i++ { + kVal := req.KNN[i].K + maxHeap.sortIndex = i + 1 + for _, hit := range sr.Hits { + heap.Push(maxHeap, &hit.ScoreBreakdown) + } + for maxHeap.Len() > 0 { + arr := heap.Pop(maxHeap).(*[]float64) + if kVal > 0 { + kVal-- + } else { + (*arr)[maxHeap.sortIndex] = 0 + } + } + } + 
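// with only the top-K hits per KNN sub-query retaining their scores, + // recompute each hit's total score: a top-level conjunction query + // sums the breakdown as is, while any other query is treated as a + // disjunction and has a coord factor applied. +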
operator := 0 + if _, ok := req.Query.(*query.ConjunctionQuery); ok { + operator = 1 + } + nonZeroScoreHits := make([]*search.DocumentMatch, 0, len(sr.Hits)) + maxScore := 0.0 + for _, hit := range sr.Hits { + newScore := recomputeTotalScore(operator, hit) + if newScore > 0 { + hit.Score = newScore + if newScore > maxScore { + maxScore = newScore + } + nonZeroScoreHits = append(nonZeroScoreHits, hit) + } + } + sr.Hits = nonZeroScoreHits + sr.MaxScore = maxScore + sr.Total = uint64(len(nonZeroScoreHits)) +} + +func recomputeTotalScore(operator int, hit *search.DocumentMatch) float64 { + totalScore := 0.0 + numNonZero := 0 + for _, score := range hit.ScoreBreakdown { + if score != 0 { + numNonZero += 1 + } + totalScore += score + } + if operator == 0 { + coord := float64(numNonZero) / float64(len(hit.ScoreBreakdown)) + totalScore = totalScore * coord + } + return totalScore +} diff --git a/search_no_knn.go b/search_no_knn.go index 4600f8748..d4751279a 100644 --- a/search_no_knn.go +++ b/search_no_knn.go @@ -151,3 +151,14 @@ func disjunctQueryWithKNN(req *SearchRequest) query.Query { func validateKNN(req *SearchRequest) error { return nil } + +func mergeKNNResults(req *SearchRequest, sr *SearchResult) { + // no-op +} + +func adjustRequestSizeForKNN(req *SearchRequest, numIndexPartitions int) int { + if req != nil { + return req.Size + } + return 0 +} diff --git a/search_test.go b/search_test.go index 37da8da0a..fd56b0959 100644 --- a/search_test.go +++ b/search_test.go @@ -17,6 +17,7 @@ package bleve import ( "encoding/json" "fmt" + "math" "reflect" "strconv" "strings" @@ -26,6 +27,7 @@ import ( "github.com/blevesearch/bleve/v2/analysis" "github.com/blevesearch/bleve/v2/analysis/analyzer/custom" "github.com/blevesearch/bleve/v2/analysis/analyzer/keyword" + "github.com/blevesearch/bleve/v2/analysis/analyzer/simple" "github.com/blevesearch/bleve/v2/analysis/analyzer/standard" html_char_filter "github.com/blevesearch/bleve/v2/analysis/char/html" regexp_char_filter "github.com/blevesearch/bleve/v2/analysis/char/regexp" @@ -3376,3 +3378,175 @@ func TestPercentAndIsoStyleDates(t *testing.T) { } } } + +func roundToDecimalPlace(num float64, decimalPlaces int) float64 { + precision := math.Pow(10, float64(decimalPlaces)) + return math.Round(num*precision) / precision +} + +func TestScoreBreakdown(t *testing.T) { + tmpIndexPath := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath) + + imap := mapping.NewIndexMapping() + textField := mapping.NewTextFieldMapping() + textField.Analyzer = simple.Name + imap.DefaultMapping.AddFieldMappingsAt("text", textField) + + documents := map[string]map[string]interface{}{ + "doc1": { + "text": "lorem ipsum dolor sit amet consectetur adipiscing elit do eiusmod tempor", + }, + "doc2": { + "text": "lorem dolor amet adipiscing sed eiusmod", + }, + "doc3": { + "text": "ipsum sit consectetur elit do tempor", + }, + "doc4": { + "text": "lorem ipsum sit amet adipiscing elit do eiusmod", + }, + } + + idx, err := New(tmpIndexPath, imap) + if err != nil { + t.Fatal(err) + } + defer func() { + err = idx.Close() + if err != nil { + t.Fatal(err) + } + }() + + batch := idx.NewBatch() + for docID, doc := range documents { + err := batch.Index(docID, doc) + if err != nil { + t.Fatal(err) + } + } + err = idx.Batch(batch) + if err != nil { + t.Fatal(err) + } + + type testResult struct { + docID string // doc ID of the hit + score float64 + scoreBreakdown []float64 + } + type testStruct struct { + query string + operator int + expectHits []testResult + } + testQueries := 
[]testStruct{ + { + // trigger conjunction searcher from match query with operator 1 + // expect dolor and tempor to have higher term score - since present in lesser docs and having same term freq + query: "lorem dolor amet adipiscing do tempor", + operator: 1, + expectHits: []testResult{ + { + docID: "doc1", + score: 0.815545, + scoreBreakdown: []float64{0.11147035536863306, 0.18483179634014485, 0.11147035536863306, 0.11147035536863306, 0.11147035536863306, 0.18483179634014485}, + }, + }, + }, + { + // trigger disjunction heap searcher from match query with operator 0 (>10 searchers) + // expect score breakdown to have a 0 at BLANK + query: "lorem BLANK ipsum BLANK BLANK dolor sit amet consectetur BLANK adipiscing BLANK elit sed do eiusmod tempor BLANK BLANK", + operator: 0, + expectHits: []testResult{ + { + docID: "doc1", + score: 0.3034548543819603, + scoreBreakdown: []float64{0.040398807605268316, 0, 0.040398807605268316, 0, 0, 0.0669862776967768, 0.040398807605268316, 0.040398807605268316, 0.0669862776967768, 0, 0.040398807605268316, 0, 0.040398807605268316, 0, 0.040398807605268316, 0.040398807605268316, 0.0669862776967768, 0, 0}, + }, + { + docID: "doc4", + score: 0.15956816751152955, + scoreBreakdown: []float64{0.04737179972998534, 0, 0.04737179972998534, 0, 0, 0, 0.04737179972998534, 0.04737179972998534, 0, 0, 0.04737179972998534, 0, 0.04737179972998534, 0, 0.04737179972998534, 0.04737179972998534, 0, 0, 0}, + }, + { + docID: "doc2", + score: 0.14725661652397853, + scoreBreakdown: []float64{0.05470024557900147, 0, 0, 0, 0, 0.09069985124905133, 0, 0.05470024557900147, 0, 0, 0.05470024557900147, 0, 0, 0.15681178542754148, 0, 0.05470024557900147, 0, 0, 0}, + }, + { + docID: "doc3", + score: 0.12637916362550797, + scoreBreakdown: []float64{0, 0, 0.05470024557900147, 0, 0, 0, 0.05470024557900147, 0, 0.09069985124905133, 0, 0, 0, 0.05470024557900147, 0, 0.05470024557900147, 0, 0.09069985124905133, 0, 0}, + }, + }, + }, + { + // trigger disjunction slice searcher from match query with operator 0 (< 10 searchers) + // expect BLANK to give a 0 in score breakdown + query: "BLANK lorem ipsum BLANK BLANK dolor sit BLANK", + operator: 0, + expectHits: []testResult{ + { + docID: "doc1", + score: 0.1340684440934241, + scoreBreakdown: []float64{0, 0.05756326446708409, 0.05756326446708409, 0, 0, 0.09544709478559595, 0.05756326446708409, 0}, + }, + { + docID: "doc4", + score: 0.07593627256602972, + scoreBreakdown: []float64{0, 0.06749890894758198, 0.06749890894758198, 0, 0, 0, 0.06749890894758198, 0}, + }, + { + docID: "doc2", + score: 0.05179425287147191, + scoreBreakdown: []float64{0, 0.0779410306721006, 0, 0, 0, 0.129235980813787, 0, 0}, + }, + { + docID: "doc3", + score: 0.0389705153360503, + scoreBreakdown: []float64{0, 0, 0.0779410306721006, 0, 0, 0, 0.0779410306721006, 0}, + }, + }, + }, + } + + for _, dtq := range testQueries { + + mq := NewMatchQuery(dtq.query) + mq.SetField("text") + mq.SetOperator(query.MatchQueryOperator(dtq.operator)) + sr := NewSearchRequest(mq) + sr.Explain = true + res, err := idx.Search(sr) + if err != nil { + t.Fatal(err) + } + if len(res.Hits) != len(dtq.expectHits) { + t.Fatalf("expected %d hits, got %d", len(dtq.expectHits), len(res.Hits)) + } + for i, hit := range res.Hits { + if hit.ID != dtq.expectHits[i].docID { + t.Fatalf("expected docID %s, got %s", dtq.expectHits[i].docID, hit.ID) + } + hit.Score = roundToDecimalPlace(hit.Score, 3) + expectScore := roundToDecimalPlace(dtq.expectHits[i].score, 3) + if hit.Score != expectScore { + t.Fatalf("expected score 
%f, got %f", dtq.expectHits[i].score, hit.Score) + } + if len(hit.ScoreBreakdown) != len(dtq.expectHits[i].scoreBreakdown) { + t.Fatalf("expected %d score breakdown, got %d", len(dtq.expectHits[i].scoreBreakdown), len(hit.ScoreBreakdown)) + } + for j, score := range hit.ScoreBreakdown { + actualScore := roundToDecimalPlace(score, 3) + expectScore := roundToDecimalPlace(dtq.expectHits[i].scoreBreakdown[j], 3) + if actualScore != expectScore { + t.Fatalf("expected score breakdown %f, got %f", dtq.expectHits[i].scoreBreakdown[j], score) + } + } + } + } + +} diff --git a/test/knn/knn_dataset_queries.zip b/test/knn/knn_dataset_queries.zip new file mode 100644 index 000000000..d840ded2f Binary files /dev/null and b/test/knn/knn_dataset_queries.zip differ