diff --git a/knn_test.go b/knn_test.go
index 8efd6144c..b19b7005d 100644
--- a/knn_test.go
+++ b/knn_test.go
@@ -26,13 +26,11 @@ import (
 	"github.com/blevesearch/bleve/v2/mapping"
 )
 
-const testDatasetFileName = "test/tests/knn/dataset-30-docs.json"
-const testQueryFileName = "test/tests/knn/small-query.json"
+const testDatasetFileName = "test/knn/knn_dataset.json"
+const testQueryFileName = "test/knn/knn_queries.json"
 
 const testDatasetDims = 384
 
-const randomizeDocuments = false
-
 type testDocument struct {
 	ID      string `json:"id"`
 	Content string `json:"content"`
@@ -93,7 +91,17 @@ func createPartitionedIndex(documents []map[string]interface{}, index *indexAlia
 	mapping mapping.IndexMapping, t *testing.T) []string {
 
 	partitionSize := len(documents) / numPartitions
+	extraDocs := len(documents) % numPartitions
+	docsPerPartition := make([]int, numPartitions)
+	for i := 0; i < numPartitions; i++ {
+		docsPerPartition[i] = partitionSize
+		if extraDocs > 0 {
+			docsPerPartition[i]++
+			extraDocs--
+		}
+	}
 	var rv []string
+	prevCutoff := 0
 	for i := 0; i < numPartitions; i++ {
 		tmpIndexPath := createTmpIndexPath(t)
 		rv = append(rv, tmpIndexPath)
@@ -103,7 +111,7 @@ func createPartitionedIndex(documents []map[string]interface{}, index *indexAlia
 			t.Fatal(err)
 		}
 		batch := childIndex.NewBatch()
-		for j := i * partitionSize; (j < (i+1)*partitionSize) && j < len(documents); j++ {
+		for j := prevCutoff; j < prevCutoff+docsPerPartition[i]; j++ {
 			doc := documents[j]
 			err := batch.Index(doc["id"].(string), doc)
 			if err != nil {
@@ -111,6 +119,7 @@ func createPartitionedIndex(documents []map[string]interface{}, index *indexAlia
 				t.Fatal(err)
 			}
 		}
+		prevCutoff += docsPerPartition[i]
 		err = childIndex.Batch(batch)
 		if err != nil {
 			cleanUp(t, rv)
@@ -134,8 +143,15 @@ func truncateScore(score float64) float64 {
 	return float64(int(score*1e6)) / 1e6
 }
 
-func TestSimilaritySearchQuery(t *testing.T) {
+func TestSimilaritySearchRandomized(t *testing.T) {
+	runKNNTest(t, true)
+}
+
+func TestSimilaritySearchNotRandomized(t *testing.T) {
+	runKNNTest(t, false)
+}
+func runKNNTest(t *testing.T, randomizeDocuments bool) {
 	dataset, err := createVectorDataset(testDatasetFileName)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -197,6 +213,10 @@ func TestSimilaritySearchQuery(t *testing.T) {
 			queryIndex:         0,
 			numIndexPartitions: 4,
 			expectedResults: map[string]testResult{
+				"doc29": {
+					score:          0.5547758085810349,
+					scoreBreakdown: []float64{0, 1.1095516171620698},
+				},
 				"doc23": {
 					score:          0.3817633037007331,
 					scoreBreakdown: []float64{0, 0.7635266074014662},
@@ -205,10 +225,6 @@ func TestSimilaritySearchQuery(t *testing.T) {
 					score:          0.33983667469689355,
 					scoreBreakdown: []float64{0, 0.6796733493937871},
 				},
-				"doc13": {
-					score:          0.3206958457835452,
-					scoreBreakdown: []float64{0, 0.6413916915670904},
-				},
 			},
 		},
 		{
@@ -233,29 +249,29 @@ func TestSimilaritySearchQuery(t *testing.T) {
 			queryIndex:         1,
 			numIndexPartitions: 1,
 			expectedResults: map[string]testResult{
-				"doc23": {
-					score:          1.304929,
-					scoreBreakdown: []float64{1.016928, 0.288001},
-				},
 				"doc29": {
-					score:          1.137598,
-					scoreBreakdown: []float64{0.719076, 0.418521},
+					score:          1.8859816084399936,
+					scoreBreakdown: []float64{0.7764299912779237, 1.1095516171620698},
+				},
+				"doc23": {
+					score:          1.8615644255330264,
+					scoreBreakdown: []float64{1.0980378181315602, 0.7635266074014662},
 				},
 				"doc27": {
-					score:          0.429730,
-					scoreBreakdown: []float64{0.859461, 0},
+					score:          0.4640056648691007,
+					scoreBreakdown: []float64{0.9280113297382014, 0},
 				},
 				"doc28": {
-					score:          0.401976,
-					scoreBreakdown: []float64{0.803952, 0},
+					score:          0.434037555556026,
+					scoreBreakdown: []float64{0.868075111112052, 0},
 				},
 				"doc30": {
-					score:          0.359538,
-					scoreBreakdown: []float64{0.719076, 0},
+					score:          0.38821499563896184,
+					scoreBreakdown: []float64{0.7764299912779237, 0},
 				},
 				"doc24": {
-					score:          0.359538,
-					scoreBreakdown: []float64{0.719076, 0},
+					score:          0.38821499563896184,
+					scoreBreakdown: []float64{0.7764299912779237, 0},
 				},
 			},
 		},
@@ -263,29 +279,29 @@ func TestSimilaritySearchQuery(t *testing.T) {
 			queryIndex:         1,
 			numIndexPartitions: 5,
 			expectedResults: map[string]testResult{
-				"doc29": {
-					score:          1.0019961733597083,
-					scoreBreakdown: []float64{0.28546778431249004, 0.7165283890472183},
-				},
-				"doc28": {
-					score:          0.758083452201489,
-					scoreBreakdown: []float64{0.3191626822375969, 0.4389207699638921},
-				},
 				"doc23": {
-					score:          0.32598793043274804,
-					scoreBreakdown: []float64{0.6519758608654961, 0},
+					score:          1.5207250366637521,
+					scoreBreakdown: []float64{0.7571984292622859, 0.7635266074014662},
+				},
+				"doc29": {
+					score:          1.4834345192674083,
+					scoreBreakdown: []float64{0.3738829021053385, 1.1095516171620698},
 				},
 				"doc24": {
-					score:          0.2305082774045959,
-					scoreBreakdown: []float64{0.4610165548091918, 0},
+					score:          0.2677100734235977,
+					scoreBreakdown: []float64{0.5354201468471954, 0},
 				},
 				"doc27": {
-					score:          0.17059962977552257,
-					scoreBreakdown: []float64{0.34119925955104513, 0},
+					score:          0.22343776840593196,
+					scoreBreakdown: []float64{0.4468755368118639, 0},
+				},
+				"doc28": {
+					score:          0.20900689401100958,
+					scoreBreakdown: []float64{0.41801378802201916, 0},
 				},
 				"doc30": {
-					score:          0.14273389215624502,
-					scoreBreakdown: []float64{0.28546778431249004, 0},
+					score:          0.18694145105266924,
+					scoreBreakdown: []float64{0.3738829021053385, 0},
 				},
 			},
 		},
@@ -297,41 +313,41 @@ func TestSimilaritySearchQuery(t *testing.T) {
 					score:          3333.333333333333,
 					scoreBreakdown: []float64{0, 0, 10000},
 				},
-				"doc23": {
-					score:          0.3234943403006508,
-					scoreBreakdown: []float64{0.30977843187823456, 0.1754630785727416, 0},
-				},
 				"doc29": {
-					score:          0.31601878039486375,
-					scoreBreakdown: []float64{0.21904643099686824, 0.2549817395954274, 0},
+					score:          0.6774608026082964,
+					scoreBreakdown: []float64{0.23161973134064517, 0.7845714725717996, 0},
+				},
+				"doc23": {
+					score:          0.5783030702431613,
+					scoreBreakdown: []float64{0.32755976365480655, 0.5398948417099355, 0},
 				},
 				"doc3": {
-					score:          0.24118912169729392,
-					scoreBreakdown: []float64{0.7235673650918818, 0, 0},
+					score:          0.2550334160459894,
+					scoreBreakdown: []float64{0.7651002481379682, 0, 0},
 				},
 				"doc13": {
-					score:          0.20887591025526625,
-					scoreBreakdown: []float64{0.6266277307657988, 0, 0},
+					score:          0.2208654210738964,
+					scoreBreakdown: []float64{0.6625962632216892, 0, 0},
+				},
+				"doc5": {
+					score:          0.21180931116413285,
+					scoreBreakdown: []float64{0, 0, 0.6354279334923986},
 				},
 				"doc27": {
-					score:          0.08727018618864227,
-					scoreBreakdown: []float64{0.26181055856592683, 0, 0},
+					score:          0.09227950890170131,
+					scoreBreakdown: []float64{0.27683852670510395, 0, 0},
 				},
 				"doc28": {
-					score:          0.0816337841412418,
-					scoreBreakdown: []float64{0.2449013524237254, 0, 0},
-				},
-				"doc24": {
-					score:          0.07301547699895608,
-					scoreBreakdown: []float64{0.21904643099686824, 0, 0},
+					score:          0.0863195764709126,
+					scoreBreakdown: []float64{0.2589587294127378, 0, 0},
 				},
 				"doc30": {
-					score:          0.07301547699895608,
-					scoreBreakdown: []float64{0.21904643099686824, 0, 0},
+					score:          0.07720657711354839,
+					scoreBreakdown: []float64{0.23161973134064517, 0, 0},
 				},
-				"doc5": {
-					score:          0.06883694922797147,
-					scoreBreakdown: []float64{0, 0, 0.20651084768391442},
+				"doc24": {
+					score:          0.07720657711354839,
+					scoreBreakdown: []float64{0.23161973134064517, 0, 0},
 				},
 			},
 		},
@@ -343,33 +359,41 @@ func TestSimilaritySearchQuery(t *testing.T) {
 					score:          3333.333333333333,
 					scoreBreakdown: []float64{0, 0, 10000},
 				},
+				"doc29": {
+					score:          0.567426591648309,
+					scoreBreakdown: []float64{0.06656841490066398, 0.7845714725717996, 0},
+				},
 				"doc23": {
-					score:          0.2195946458144798,
-					scoreBreakdown: []float64{0.11312710434186327, 0.21626486437985648, 0},
+					score:          0.5639255136185979,
+					scoreBreakdown: []float64{0.3059934287179615, 0.5398948417099355, 0},
 				},
-				"doc28": {
-					score:          0.18796580117926376,
-					scoreBreakdown: []float64{0.08943482824521586, 0.1925138735236798, 0},
+				"doc5": {
+					score:          0.21180931116413285,
+					scoreBreakdown: []float64{0, 0, 0.6354279334923986},
 				},
 				"doc3": {
-					score:          0.12303621410516037,
-					scoreBreakdown: []float64{0.36910864231548113, 0, 0},
+					score:          0.14064944169372873,
+					scoreBreakdown: []float64{0.42194832508118624, 0, 0},
 				},
 				"doc13": {
-					score:          0.10655248891295889,
-					scoreBreakdown: []float64{0.3196574667388767, 0, 0},
-				},
-				"doc5": {
-					score:          0.07546992065969621,
-					scoreBreakdown: []float64{0, 0, 0.22640976197908863},
+					score:          0.12180599172106943,
+					scoreBreakdown: []float64{0.3654179751632083, 0, 0},
 				},
 				"doc27": {
-					score:          0.03186995104545246,
-					scoreBreakdown: []float64{0.09560985313635739, 0, 0},
+					score:          0.026521491065731144,
+					scoreBreakdown: []float64{0.07956447319719344, 0, 0},
+				},
+				"doc28": {
+					score:          0.024808583220893122,
+					scoreBreakdown: []float64{0.07442574966267937, 0, 0},
+				},
+				"doc30": {
+					score:          0.02218947163355466,
+					scoreBreakdown: []float64{0.06656841490066398, 0, 0},
 				},
 				"doc24": {
-					score:          0.02666431434541773,
-					scoreBreakdown: []float64{0.07999294303625319, 0, 0},
+					score:          0.02218947163355466,
+					scoreBreakdown: []float64{0.06656841490066398, 0, 0},
 				},
 			},
 		},
@@ -419,4 +443,5 @@ func TestSimilaritySearchQuery(t *testing.T) {
 		}
 		cleanUp(t, indexPaths, index.indexes...)
 	}
+}
 
diff --git a/search/searcher/search_conjunction.go b/search/searcher/search_conjunction.go
index e0ae2c349..21fcf696a 100644
--- a/search/searcher/search_conjunction.go
+++ b/search/searcher/search_conjunction.go
@@ -16,7 +16,6 @@ package searcher
 
 import (
 	"context"
-	"math"
 	"reflect"
 	"sort"
 
@@ -34,15 +33,16 @@ func init() {
 }
 
 type ConjunctionSearcher struct {
-	indexReader index.IndexReader
-	searchers   []search.Searcher
-	queryNorm   float64
-	currs       []*search.DocumentMatch
-	maxIDIdx    int
-	scorer      *scorer.ConjunctionQueryScorer
-	initialized bool
-	options     search.SearcherOptions
-	bytesRead   uint64
+	indexReader     index.IndexReader
+	searchers       []search.Searcher
+	queryNorm       float64
+	queryNormForKNN float64
+	currs           []*search.DocumentMatch
+	maxIDIdx        int
+	scorer          *scorer.ConjunctionQueryScorer
+	initialized     bool
+	options         search.SearcherOptions
+	bytesRead       uint64
 }
 
 func NewConjunctionSearcher(ctx context.Context, indexReader index.IndexReader,
@@ -110,20 +110,6 @@ func (s *ConjunctionSearcher) Size() int {
 	return sizeInBytes
 }
 
-func (s *ConjunctionSearcher) computeQueryNorm() {
-	// first calculate sum of squared weights
-	sumOfSquaredWeights := 0.0
-	for _, searcher := range s.searchers {
-		sumOfSquaredWeights += searcher.Weight()
-	}
-	// now compute query norm from this
-	s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights)
-	// finally tell all the downstream searchers the norm
-	for _, searcher := range s.searchers {
-		searcher.SetQueryNorm(s.queryNorm)
-	}
-}
-
 func (s *ConjunctionSearcher) initSearchers(ctx *search.SearchContext) error {
 	var err error
 	// get all searchers pointing at their first match
diff --git a/search/searcher/search_disjunction_heap.go b/search/searcher/search_disjunction_heap.go
index af5915934..51e210804 100644
--- a/search/searcher/search_disjunction_heap.go
+++ b/search/searcher/search_disjunction_heap.go
@@ -18,7 +18,6 @@ import (
 	"bytes"
 	"container/heap"
 	"context"
-	"math"
 	"reflect"
 
 	"github.com/blevesearch/bleve/v2/search"
@@ -47,13 +46,14 @@ type SearcherCurr struct {
 type DisjunctionHeapSearcher struct {
 	indexReader index.IndexReader
 
-	numSearchers int
-	scorer       *scorer.DisjunctionQueryScorer
-	min          int
-	queryNorm    float64
-	initialized  bool
-	searchers    []search.Searcher
-	heap         []*SearcherCurr
+	numSearchers    int
+	scorer          *scorer.DisjunctionQueryScorer
+	min             int
+	queryNorm       float64
+	queryNormForKNN float64
+	initialized     bool
+	searchers       []search.Searcher
+	heap            []*SearcherCurr
 
 	matching     []*search.DocumentMatch
 	matchingIdxs []int
@@ -109,20 +109,6 @@ func (s *DisjunctionHeapSearcher) Size() int {
 	return sizeInBytes
 }
 
-func (s *DisjunctionHeapSearcher) computeQueryNorm() {
-	// first calculate sum of squared weights
-	sumOfSquaredWeights := 0.0
-	for _, searcher := range s.searchers {
-		sumOfSquaredWeights += searcher.Weight()
-	}
-	// now compute query norm from this
-	s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights)
-	// finally tell all the downstream searchers the norm
-	for _, searcher := range s.searchers {
-		searcher.SetQueryNorm(s.queryNorm)
-	}
-}
-
 func (s *DisjunctionHeapSearcher) initSearchers(ctx *search.SearchContext) error {
 	// alloc a single block of SearcherCurrs
 	block := make([]SearcherCurr, len(s.searchers))
diff --git a/search/searcher/search_disjunction_slice.go b/search/searcher/search_disjunction_slice.go
index fb9f7bab9..251463e6a 100644
--- a/search/searcher/search_disjunction_slice.go
+++ b/search/searcher/search_disjunction_slice.go
@@ -16,7 +16,6 @@ package searcher
 
 import (
 	"context"
-	"math"
 	"reflect"
 	"sort"
 
@@ -34,18 +33,19 @@ func init() {
 }
 
 type DisjunctionSliceSearcher struct {
-	indexReader  index.IndexReader
-	searchers    []search.Searcher
-	originalPos  []int
-	numSearchers int
-	queryNorm    float64
-	currs        []*search.DocumentMatch
-	scorer       *scorer.DisjunctionQueryScorer
-	min          int
-	matching     []*search.DocumentMatch
-	matchingIdxs []int
-	initialized  bool
-	bytesRead    uint64
+	indexReader     index.IndexReader
+	searchers       []search.Searcher
+	originalPos     []int
+	numSearchers    int
+	queryNorm       float64
+	queryNormForKNN float64
+	currs           []*search.DocumentMatch
+	scorer          *scorer.DisjunctionQueryScorer
+	min             int
+	matching        []*search.DocumentMatch
+	matchingIdxs    []int
+	initialized     bool
+	bytesRead       uint64
 }
 
 func newDisjunctionSliceSearcher(ctx context.Context, indexReader index.IndexReader,
@@ -110,20 +110,6 @@ func (s *DisjunctionSliceSearcher) Size() int {
 	return sizeInBytes
 }
 
-func (s *DisjunctionSliceSearcher) computeQueryNorm() {
-	// first calculate sum of squared weights
-	sumOfSquaredWeights := 0.0
-	for _, searcher := range s.searchers {
-		sumOfSquaredWeights += searcher.Weight()
-	}
-	// now compute query norm from this
-	s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights)
-	// finally tell all the downstream searchers the norm
-	for _, searcher := range s.searchers {
-		searcher.SetQueryNorm(s.queryNorm)
-	}
-}
-
 func (s *DisjunctionSliceSearcher) initSearchers(ctx *search.SearchContext) error {
 	var err error
 	// get all searchers pointing at their first match
diff --git a/search/searcher/search_knn_util.go b/search/searcher/search_knn_util.go
new file mode 100644
index 000000000..dabc61c54
--- /dev/null
+++ b/search/searcher/search_knn_util.go
@@ -0,0 +1,72 @@
+// Copyright (c) 2023 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build vectors
+// +build vectors
+
+package searcher
+
+import (
+	"math"
+
+	"github.com/blevesearch/bleve/v2/search"
+)
+
+// util func used by both disjunction and conjunction searchers
+// to compute the query norm.
+// This follows a separate code path from the non-knn version
+// because we need to separate out the weights from the KNN searchers
+// and the rest of the searchers to make the knn
+// score completely independent of tf-idf.
+// the sumOfSquaredWeights depends on the tf-idf weights
+// and using the same value for knn searchers will make the
+// knn score dependent on tf-idf.
+func computeQueryNorm(searchers []search.Searcher) (float64, float64) {
+	// first calculate sum of squared weights
+	sumOfSquaredWeights := 0.0
+
+	sumOfSquaredWeightsForKNN := 0.0
+
+	for _, searcher := range searchers {
+		if knnSearcher, ok := searcher.(*KNNSearcher); ok {
+			sumOfSquaredWeightsForKNN += knnSearcher.Weight()
+		} else {
+			sumOfSquaredWeights += searcher.Weight()
+		}
+	}
+	// now compute query norm from this
+	queryNorm := 1.0 / math.Sqrt(sumOfSquaredWeights)
+	queryNormForKNN := 1.0 / math.Sqrt(sumOfSquaredWeightsForKNN)
+	// finally tell all the downstream searchers the norm
+	for _, searcher := range searchers {
+		if knnSearcher, ok := searcher.(*KNNSearcher); ok {
+			knnSearcher.SetQueryNorm(queryNormForKNN)
+		} else {
+			searcher.SetQueryNorm(queryNorm)
+		}
+	}
+	return queryNorm, queryNormForKNN
+}
+
+func (s *DisjunctionSliceSearcher) computeQueryNorm() {
+	s.queryNorm, s.queryNormForKNN = computeQueryNorm(s.searchers)
+}
+
+func (s *DisjunctionHeapSearcher) computeQueryNorm() {
+	s.queryNorm, s.queryNormForKNN = computeQueryNorm(s.searchers)
+}
+
+func (s *ConjunctionSearcher) computeQueryNorm() {
+	s.queryNorm, s.queryNormForKNN = computeQueryNorm(s.searchers)
+}
diff --git a/search/searcher/search_no_knn_util.go b/search/searcher/search_no_knn_util.go
new file mode 100644
index 000000000..abf7a32d8
--- /dev/null
+++ b/search/searcher/search_no_knn_util.go
@@ -0,0 +1,62 @@
+// Copyright (c) 2023 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build !vectors
+// +build !vectors
+
+package searcher
+
+import "math"
+
+func (s *DisjunctionSliceSearcher) computeQueryNorm() {
+	// first calculate sum of squared weights
+	sumOfSquaredWeights := 0.0
+	for _, searcher := range s.searchers {
+		sumOfSquaredWeights += searcher.Weight()
+	}
+	// now compute query norm from this
+	s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights)
+	// finally tell all the downstream searchers the norm
+	for _, searcher := range s.searchers {
+		searcher.SetQueryNorm(s.queryNorm)
+	}
+}
+
+func (s *DisjunctionHeapSearcher) computeQueryNorm() {
+	// first calculate sum of squared weights
+	sumOfSquaredWeights := 0.0
+	for _, searcher := range s.searchers {
+		sumOfSquaredWeights += searcher.Weight()
+	}
+	// now compute query norm from this
+	s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights)
+	// finally tell all the downstream searchers the norm
+	for _, searcher := range s.searchers {
+		searcher.SetQueryNorm(s.queryNorm)
+	}
+}
+
+func (s *ConjunctionSearcher) computeQueryNorm() {
+	// first calculate sum of squared weights
+	sumOfSquaredWeights := 0.0
+	for _, searcher := range s.searchers {
+		sumOfSquaredWeights += searcher.Weight()
+	}
+	// now compute query norm from this
+	s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights)
+	// finally tell all the downstream searchers the norm
+	for _, searcher := range s.searchers {
+		searcher.SetQueryNorm(s.queryNorm)
+	}
+}
diff --git a/test/tests/knn/dataset-30-docs.json b/test/knn/knn_dataset.json
similarity index 100%
rename from test/tests/knn/dataset-30-docs.json
rename to test/knn/knn_dataset.json
diff --git a/test/tests/knn/small-query.json b/test/knn/knn_queries.json
similarity index 99%
rename from test/tests/knn/small-query.json
rename to test/knn/knn_queries.json
index 8bafde6e6..2304a96c1 100644
--- a/test/tests/knn/small-query.json
+++ b/test/knn/knn_queries.json
@@ -8,7 +8,7 @@
         "content"
     ],
     "from":0,
-    "size":10,
+    "size":30,
     "knn":[{
         "field":"vector",
         "vector":[
@@ -410,7 +410,7 @@
         "content"
     ],
     "from":0,
-    "size":10,
+    "size":30,
     "knn":[{
         "field":"vector",
         "vector":[
@@ -820,7 +820,7 @@
         "content"
    ],
     "from":0,
-    "size":20,
+    "size":30,
     "knn":[{
         "field":"vector",
         "vector":[