Skip to content

Commit

Permalink
Adding scorer explanation
Browse files Browse the repository at this point in the history
  • Loading branch information
metonymic-smokey committed Nov 20, 2023
1 parent 645d0e3 commit efaeacb
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 19 deletions.
69 changes: 56 additions & 13 deletions search/scorer/scorer_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package scorer

import (
"fmt"
"reflect"

"github.com/blevesearch/bleve/v2/search"
Expand All @@ -32,28 +33,25 @@ func init() {
}

type KNNQueryScorer struct {
queryVector []float32
queryField string
queryWeight float64
queryBoost float64
queryNorm float64
docTerm uint64
docTotal uint64
options search.SearcherOptions
includeScore bool
similarityMetric string
queryVector []float32
queryField string
queryWeight float64
queryBoost float64
queryNorm float64
options search.SearcherOptions
includeScore bool
similarityMetric string
queryWeightExplanation *search.Explanation
}

func NewKNNQueryScorer(queryVector []float32, queryField string, queryBoost float64,
docTerm uint64, docTotal uint64, options search.SearcherOptions,
options search.SearcherOptions,
similarityMetric string) *KNNQueryScorer {
return &KNNQueryScorer{
queryVector: queryVector,
queryField: queryField,
queryBoost: queryBoost,
queryWeight: 1.0,
docTerm: docTerm,
docTotal: docTotal,
options: options,
includeScore: options.Score != "none",
similarityMetric: similarityMetric,
Expand All @@ -73,9 +71,36 @@ func (sqs *KNNQueryScorer) Score(ctx *search.SearchContext,
score = 1.0 / score
}

if sqs.options.Explain {
childrenExplanations := make([]*search.Explanation, 1)
childrenExplanations[0] = &search.Explanation{
Value: score,
Message: fmt.Sprintf("vector(field(%s:%s) with similarity_metric(%s)=%f",
sqs.queryField, knnMatch.ID, sqs.similarityMetric, score),
}
scoreExplanation = &search.Explanation{
Value: score,
Message: fmt.Sprintf("fieldWeight(%s in doc %s), score of:",
sqs.queryField, knnMatch.ID),
Children: childrenExplanations,
}
}

// if the query weight isn't 1, multiply
if sqs.queryWeight != 1.0 {
score = score * sqs.queryWeight
if sqs.options.Explain {
childExplanations := make([]*search.Explanation, 2)
childExplanations[0] = sqs.queryWeightExplanation
childExplanations[1] = scoreExplanation
scoreExplanation = &search.Explanation{
Value: score,
// Product of score * weight
Message: fmt.Sprintf("weight(%s:%f^%f in %s), product of:",
sqs.queryField, sqs.queryVector, sqs.queryBoost, knnMatch.ID),
Children: childExplanations,
}
}
}

if sqs.includeScore {
Expand All @@ -100,4 +125,22 @@ func (sqs *KNNQueryScorer) SetQueryNorm(qnorm float64) {

// update the query weight
sqs.queryWeight = sqs.queryBoost * sqs.queryNorm

if sqs.options.Explain {
childrenExplanations := make([]*search.Explanation, 2)
childrenExplanations[0] = &search.Explanation{
Value: sqs.queryBoost,
Message: "boost",
}
childrenExplanations[1] = &search.Explanation{
Value: sqs.queryNorm,
Message: "queryNorm",
}
sqs.queryWeightExplanation = &search.Explanation{
Value: sqs.queryWeight,
Message: fmt.Sprintf("queryWeight(%s:%f^%f), product of:",
sqs.queryField, sqs.queryVector, sqs.queryBoost),
Children: childrenExplanations,
}
}
}
153 changes: 153 additions & 0 deletions search/scorer/scorer_knn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Copyright (c) 2023 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build vectors
// +build vectors

package scorer

import (
"reflect"
"testing"

"github.com/blevesearch/bleve/v2/search"
index "github.com/blevesearch/bleve_index_api"
)

// TestKNNScorerExplanation checks the explanation tree that KNNQueryScorer
// attaches to a DocumentMatch when SearcherOptions.Explain is set.
//
// Each table entry runs as a named subtest so a failure identifies the
// scenario that broke. Only the Expl field of the expected DocumentMatch is
// compared; the other fields of result are not asserted.
func TestKNNScorerExplanation(t *testing.T) {
	// Arbitrary vector dimensionality used for both the query and the hit.
	const dims = 64

	// vectorText is how the 64-dim query vector appears in the expected
	// explanation messages below (space separated, six decimals per
	// component). It is written out literally, rather than derived from
	// queryVector, so the expectation stays independent of the code under test.
	const vectorText = "[0.000000 1.000000 2.000000 3.000000 4.000000 5.000000 6.000000 7.000000 8.000000 9.000000 10.000000 11.000000 12.000000 13.000000 14.000000 15.000000 16.000000 17.000000 18.000000 19.000000 20.000000 21.000000 22.000000 23.000000 24.000000 25.000000 26.000000 27.000000 28.000000 29.000000 30.000000 31.000000 32.000000 33.000000 34.000000 35.000000 36.000000 37.000000 38.000000 39.000000 40.000000 41.000000 42.000000 43.000000 44.000000 45.000000 46.000000 47.000000 48.000000 49.000000 50.000000 51.000000 52.000000 53.000000 54.000000 55.000000 56.000000 57.000000 58.000000 59.000000 60.000000 61.000000 62.000000 63.000000]"

	// The query vector and the result vector hold the same values but are
	// kept as two distinct slices, as in the original fixture.
	queryVector := make([]float32, dims)
	resVector := make([]float32, dims)
	for i := 0; i < dims; i++ {
		queryVector[i] = float32(i)
		resVector[i] = float32(i)
	}

	tests := []struct {
		name      string
		termMatch *index.VectorDoc
		scorer    *KNNQueryScorer
		norm      float64
		result    *search.DocumentMatch
	}{
		{
			// Specifically testing EuclideanDistance since that involves
			// score inversion: the raw 0.5 is explained as 1/0.5.
			name: "euclidean_distance_inverts_score",
			termMatch: &index.VectorDoc{
				ID:     index.IndexInternalID("one"),
				Score:  0.5,
				Vector: resVector,
			},
			norm: 1.0,
			scorer: NewKNNQueryScorer(queryVector, "desc", 1.0,
				search.SearcherOptions{Explain: true}, index.EuclideanDistance),
			result: &search.DocumentMatch{
				IndexInternalID: index.IndexInternalID("one"),
				Score:           0.5,
				Expl: &search.Explanation{
					Value:   1 / 0.5,
					Message: "fieldWeight(desc in doc one), score of:",
					Children: []*search.Explanation{
						{
							Value:   1 / 0.5,
							Message: "vector(field(desc:one) with similarity_metric(l2_norm)=2.000000",
						},
					},
				},
			},
		},
		{
			// Boost 1.0 and norm 1.0 give a query weight of 1.0, so the
			// expected tree has no enclosing "weight(...)" node.
			name: "cosine_similarity_unit_query_weight",
			termMatch: &index.VectorDoc{
				ID:     index.IndexInternalID("one"),
				Score:  0.5,
				Vector: resVector,
			},
			norm: 1.0,
			scorer: NewKNNQueryScorer(queryVector, "desc", 1.0,
				search.SearcherOptions{Explain: true}, index.CosineSimilarity),
			result: &search.DocumentMatch{
				IndexInternalID: index.IndexInternalID("one"),
				Score:           0.5,
				Expl: &search.Explanation{
					Value:   0.5,
					Message: "fieldWeight(desc in doc one), score of:",
					Children: []*search.Explanation{
						{
							Value:   0.5,
							Message: "vector(field(desc:one) with similarity_metric(dot_product)=0.500000",
						},
					},
				},
			},
		},
		{
			// Norm 0.5 makes the query weight differ from 1.0, so the
			// field-weight explanation is wrapped in a "weight(...)" product
			// node together with the query-weight explanation.
			name: "cosine_similarity_with_query_norm",
			termMatch: &index.VectorDoc{
				ID:     index.IndexInternalID("one"),
				Score:  0.25,
				Vector: resVector,
			},
			norm: 0.5,
			scorer: NewKNNQueryScorer(queryVector, "desc", 1.0,
				search.SearcherOptions{Explain: true}, index.CosineSimilarity),
			result: &search.DocumentMatch{
				IndexInternalID: index.IndexInternalID("one"),
				Score:           0.25,
				Expl: &search.Explanation{
					Value:   0.125,
					Message: "weight(desc:" + vectorText + "^1.000000 in one), product of:",
					Children: []*search.Explanation{
						{
							Value:   0.5,
							Message: "queryWeight(desc:" + vectorText + "^1.000000), product of:",
							Children: []*search.Explanation{
								{
									Value:   1,
									Message: "boost",
								},
								{
									Value:   0.5,
									Message: "queryNorm",
								},
							},
						},
						{
							Value:   0.25,
							Message: "fieldWeight(desc in doc one), score of:",
							Children: []*search.Explanation{
								{
									Value:   0.25,
									Message: "vector(field(desc:one) with similarity_metric(dot_product)=0.250000",
								},
							},
						},
					},
				},
			},
		},
	}

	for _, test := range tests {
		// Subtests run synchronously (no t.Parallel), so capturing the loop
		// variable in the closure is safe on every Go version.
		t.Run(test.name, func(t *testing.T) {
			ctx := &search.SearchContext{
				DocumentMatchPool: search.NewDocumentMatchPool(1, 0),
			}
			// The query norm is applied before scoring, in the same order
			// as the original test.
			test.scorer.SetQueryNorm(test.norm)
			actual := test.scorer.Score(ctx, test.termMatch)
			actual.Complete(nil)

			if !reflect.DeepEqual(actual.Expl, test.result.Expl) {
				t.Errorf("expected %#v got %#v for %#v", test.result.Expl,
					actual.Expl, test.termMatch)
			}
		})
	}
}
7 changes: 1 addition & 6 deletions search/searcher/search_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,9 @@ func NewKNNSearcher(ctx context.Context, i index.IndexReader, m mapping.IndexMap
if err != nil {
return nil, err
}
count, err := i.DocCount()
if err != nil {
_ = vectorReader.Close()
return nil, err
}

knnScorer := scorer.NewKNNQueryScorer(vector, field, boost,
vectorReader.Count(), count, options, similarityMetric)
options, similarityMetric)
return &KNNSearcher{
indexReader: i,
vectorReader: vectorReader,
Expand Down

0 comments on commit efaeacb

Please sign in to comment.