diff --git a/search/scorer/scorer_knn.go b/search/scorer/scorer_knn.go index 70724fa65..5b36d1e52 100644 --- a/search/scorer/scorer_knn.go +++ b/search/scorer/scorer_knn.go @@ -18,6 +18,7 @@ package scorer import ( + "fmt" "reflect" "github.com/blevesearch/bleve/v2/search" @@ -32,28 +33,25 @@ func init() { } type KNNQueryScorer struct { - queryVector []float32 - queryField string - queryWeight float64 - queryBoost float64 - queryNorm float64 - docTerm uint64 - docTotal uint64 - options search.SearcherOptions - includeScore bool - similarityMetric string + queryVector []float32 + queryField string + queryWeight float64 + queryBoost float64 + queryNorm float64 + options search.SearcherOptions + includeScore bool + similarityMetric string + queryWeightExplanation *search.Explanation } func NewKNNQueryScorer(queryVector []float32, queryField string, queryBoost float64, - docTerm uint64, docTotal uint64, options search.SearcherOptions, + options search.SearcherOptions, similarityMetric string) *KNNQueryScorer { return &KNNQueryScorer{ queryVector: queryVector, queryField: queryField, queryBoost: queryBoost, queryWeight: 1.0, - docTerm: docTerm, - docTotal: docTotal, options: options, includeScore: options.Score != "none", similarityMetric: similarityMetric, @@ -73,9 +71,36 @@ func (sqs *KNNQueryScorer) Score(ctx *search.SearchContext, score = 1.0 / score } + if sqs.options.Explain { + childrenExplanations := make([]*search.Explanation, 1) + childrenExplanations[0] = &search.Explanation{ + Value: score, + Message: fmt.Sprintf("vector(field(%s:%s) with similarity_metric(%s)=%f", + sqs.queryField, knnMatch.ID, sqs.similarityMetric, score), + } + scoreExplanation = &search.Explanation{ + Value: score, + Message: fmt.Sprintf("fieldWeight(%s in doc %s), score of:", + sqs.queryField, knnMatch.ID), + Children: childrenExplanations, + } + } + // if the query weight isn't 1, multiply if sqs.queryWeight != 1.0 { score = score * sqs.queryWeight + if sqs.options.Explain { + childExplanations := make([]*search.Explanation, 2) + childExplanations[0] = sqs.queryWeightExplanation + childExplanations[1] = scoreExplanation + scoreExplanation = &search.Explanation{ + Value: score, + // Product of score * weight + Message: fmt.Sprintf("weight(%s:%f^%f in %s), product of:", + sqs.queryField, sqs.queryVector, sqs.queryBoost, knnMatch.ID), + Children: childExplanations, + } + } } if sqs.includeScore { @@ -100,4 +125,22 @@ func (sqs *KNNQueryScorer) SetQueryNorm(qnorm float64) { // update the query weight sqs.queryWeight = sqs.queryBoost * sqs.queryNorm + + if sqs.options.Explain { + childrenExplanations := make([]*search.Explanation, 2) + childrenExplanations[0] = &search.Explanation{ + Value: sqs.queryBoost, + Message: "boost", + } + childrenExplanations[1] = &search.Explanation{ + Value: sqs.queryNorm, + Message: "queryNorm", + } + sqs.queryWeightExplanation = &search.Explanation{ + Value: sqs.queryWeight, + Message: fmt.Sprintf("queryWeight(%s:%f^%f), product of:", + sqs.queryField, sqs.queryVector, sqs.queryBoost), + Children: childrenExplanations, + } + } } diff --git a/search/scorer/scorer_knn_test.go b/search/scorer/scorer_knn_test.go new file mode 100644 index 000000000..57900e28b --- /dev/null +++ b/search/scorer/scorer_knn_test.go @@ -0,0 +1,153 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vectors +// +build vectors + +package scorer + +import ( + "reflect" + "testing" + + "github.com/blevesearch/bleve/v2/search" + index "github.com/blevesearch/bleve_index_api" +) + +func TestKNNScorerExplanation(t *testing.T) { + var queryVector []float32 + // arbitrary vector of dims: 64 + for i := 0; i < 64; i++ { + queryVector = append(queryVector, float32(i)) + } + + var resVector []float32 + // arbitrary res vector. + for i := 0; i < 64; i++ { + resVector = append(resVector, float32(i)) + } + + tests := []struct { + termMatch *index.VectorDoc + scorer *KNNQueryScorer + norm float64 + result *search.DocumentMatch + }{ + { + termMatch: &index.VectorDoc{ + ID: index.IndexInternalID("one"), + Score: 0.5, + Vector: resVector, + }, + norm: 1.0, + scorer: NewKNNQueryScorer(queryVector, "desc", 1.0, + search.SearcherOptions{Explain: true}, index.EuclideanDistance), + // Specifically testing EuclideanDistance since that involves score inversion. + result: &search.DocumentMatch{ + IndexInternalID: index.IndexInternalID("one"), + Score: 0.5, + Expl: &search.Explanation{ + Value: 1 / 0.5, + Message: "fieldWeight(desc in doc one), score of:", + Children: []*search.Explanation{ + {Value: 1 / 0.5, + Message: "vector(field(desc:one) with similarity_metric(l2_norm)=2.000000", + }, + }, + }, + }, + }, + { + termMatch: &index.VectorDoc{ + ID: index.IndexInternalID("one"), + Score: 0.5, + Vector: resVector, + }, + norm: 1.0, + scorer: NewKNNQueryScorer(queryVector, "desc", 1.0, + search.SearcherOptions{Explain: true}, index.CosineSimilarity), + result: &search.DocumentMatch{ + IndexInternalID: index.IndexInternalID("one"), + Score: 0.5, + Expl: &search.Explanation{ + Value: 0.5, + Message: "fieldWeight(desc in doc one), score of:", + Children: []*search.Explanation{ + {Value: 0.5, + Message: "vector(field(desc:one) with similarity_metric(dot_product)=0.500000", + }, + }, + }, + }, + }, + { + termMatch: &index.VectorDoc{ + ID: index.IndexInternalID("one"), + Score: 0.25, + Vector: resVector, + }, + norm: 0.5, + scorer: NewKNNQueryScorer(queryVector, "desc", 1.0, + search.SearcherOptions{Explain: true}, index.CosineSimilarity), + result: &search.DocumentMatch{ + IndexInternalID: index.IndexInternalID("one"), + Score: 0.25, + Expl: &search.Explanation{ + Value: 0.125, + Message: "weight(desc:[0.000000 1.000000 2.000000 3.000000 4.000000 5.000000 6.000000 7.000000 8.000000 9.000000 10.000000 11.000000 12.000000 13.000000 14.000000 15.000000 16.000000 17.000000 18.000000 19.000000 20.000000 21.000000 22.000000 23.000000 24.000000 25.000000 26.000000 27.000000 28.000000 29.000000 30.000000 31.000000 32.000000 33.000000 34.000000 35.000000 36.000000 37.000000 38.000000 39.000000 40.000000 41.000000 42.000000 43.000000 44.000000 45.000000 46.000000 47.000000 48.000000 49.000000 50.000000 51.000000 52.000000 53.000000 54.000000 55.000000 56.000000 57.000000 58.000000 59.000000 60.000000 61.000000 62.000000 63.000000]^1.000000 in one), product of:", + Children: []*search.Explanation{ + { + Value: 0.5, + Message: "queryWeight(desc:[0.000000 1.000000 2.000000 3.000000 4.000000 5.000000 6.000000 7.000000 8.000000 9.000000 10.000000 11.000000 12.000000 13.000000 14.000000 15.000000 16.000000 17.000000 18.000000 19.000000 20.000000 21.000000 22.000000 23.000000 24.000000 25.000000 26.000000 27.000000 28.000000 29.000000 30.000000 31.000000 32.000000 33.000000 34.000000 35.000000 36.000000 37.000000 38.000000 39.000000 40.000000 41.000000 42.000000 43.000000 44.000000 45.000000 46.000000 47.000000 48.000000 49.000000 50.000000 51.000000 52.000000 53.000000 54.000000 55.000000 56.000000 57.000000 58.000000 59.000000 60.000000 61.000000 62.000000 63.000000]^1.000000), product of:", + Children: []*search.Explanation{ + { + Value: 1, + Message: "boost", + }, + { + Value: 0.5, + Message: "queryNorm", + }, + }, + }, + { + Value: 0.25, + Message: "fieldWeight(desc in doc one), score of:", + Children: []*search.Explanation{ + { + Value: 0.25, + Message: "vector(field(desc:one) with similarity_metric(dot_product)=0.250000", + }, + }, + }, + }, + }, + }, + }, + } + + for _, test := range tests { + ctx := &search.SearchContext{ + DocumentMatchPool: search.NewDocumentMatchPool(1, 0), + } + test.scorer.SetQueryNorm(test.norm) + actual := test.scorer.Score(ctx, test.termMatch) + actual.Complete(nil) + + if !reflect.DeepEqual(actual.Expl, test.result.Expl) { + t.Errorf("expected %#v got %#v for %#v", test.result.Expl, + actual.Expl, test.termMatch) + } + } +} diff --git a/search/searcher/search_knn.go b/search/searcher/search_knn.go index ed1c078aa..a88dcf849 100644 --- a/search/searcher/search_knn.go +++ b/search/searcher/search_knn.go @@ -45,14 +45,9 @@ func NewKNNSearcher(ctx context.Context, i index.IndexReader, m mapping.IndexMap if err != nil { return nil, err } - count, err := i.DocCount() - if err != nil { - _ = vectorReader.Close() - return nil, err - } knnScorer := scorer.NewKNNQueryScorer(vector, field, boost, - vectorReader.Count(), count, options, similarityMetric) + options, similarityMetric) return &KNNSearcher{ indexReader: i, vectorReader: vectorReader,