Skip to content

Commit

Permalink
Adding KNN scorer explanation (#1899)
Browse files Browse the repository at this point in the history
This PR adds a score Explanation to the KNN scorer, along with a unit
test.
It also contains a minor refactor of scoring-related code.

---------

Co-authored-by: Abhinav Dangeti <[email protected]>
  • Loading branch information
metonymic-smokey and abhinavdangeti authored Nov 22, 2023
1 parent 6291df2 commit 835f042
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 27 deletions.
89 changes: 68 additions & 21 deletions search/scorer/scorer_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package scorer

import (
"fmt"
"math"
"reflect"

Expand All @@ -33,28 +34,25 @@ func init() {
}

type KNNQueryScorer struct {
queryVector []float32
queryField string
queryWeight float64
queryBoost float64
queryNorm float64
docTerm uint64
docTotal uint64
options search.SearcherOptions
includeScore bool
similarityMetric string
queryVector []float32
queryField string
queryWeight float64
queryBoost float64
queryNorm float64
options search.SearcherOptions
includeScore bool
similarityMetric string
queryWeightExplanation *search.Explanation
}

func NewKNNQueryScorer(queryVector []float32, queryField string, queryBoost float64,
docTerm uint64, docTotal uint64, options search.SearcherOptions,
options search.SearcherOptions,
similarityMetric string) *KNNQueryScorer {
return &KNNQueryScorer{
queryVector: queryVector,
queryField: queryField,
queryBoost: queryBoost,
queryWeight: 1.0,
docTerm: docTerm,
docTotal: docTotal,
options: options,
includeScore: options.Score != "none",
similarityMetric: similarityMetric,
Expand All @@ -72,19 +70,50 @@ func (sqs *KNNQueryScorer) Score(ctx *search.SearchContext,
if sqs.includeScore || sqs.options.Explain {
var scoreExplanation *search.Explanation
score := knnMatch.Score
// in case of euclidean distance being the distance metric,
// an exact vector (perfect match), would return distance = 0
if score == 0 {
score = maxKNNScore
} else {
// euclidean distances need to be inverted to work with
// tf-idf scoring
score = 1.0 / score
if sqs.similarityMetric == index.EuclideanDistance {
// in case of euclidean distance being the distance metric,
// an exact vector (perfect match), would return distance = 0
if score == 0 {
score = maxKNNScore
} else {
// euclidean distances need to be inverted to work with
// tf-idf scoring
score = 1.0 / score
}
}

if sqs.options.Explain {
childExplanations := make([]*search.Explanation, 1)
childExplanations[0] = &search.Explanation{
Value: score,
Message: fmt.Sprintf("vector(field(%s:%s) with similarity_metric(%s)=%e",
sqs.queryField, knnMatch.ID, sqs.similarityMetric, score),
}
scoreExplanation = &search.Explanation{
Value: score,
Message: fmt.Sprintf("fieldWeight(%s in doc %s), score of:",
sqs.queryField, knnMatch.ID),
Children: childExplanations,
}
}

// if the query weight isn't 1, multiply
if sqs.queryWeight != 1.0 && score != maxKNNScore {
score = score * sqs.queryWeight
if sqs.options.Explain {
childExplanations := make([]*search.Explanation, 2)
childExplanations[0] = sqs.queryWeightExplanation
childExplanations[1] = scoreExplanation
scoreExplanation = &search.Explanation{
Value: score,
// Product of score * weight
// Avoid adding the query vector to the explanation since vectors
// can get quite large.
Message: fmt.Sprintf("weight(%s:query Vector^%f in %s), product of:",
sqs.queryField, sqs.queryBoost, knnMatch.ID),
Children: childExplanations,
}
}
}

if sqs.includeScore {
Expand All @@ -109,4 +138,22 @@ func (sqs *KNNQueryScorer) SetQueryNorm(qnorm float64) {

// update the query weight
sqs.queryWeight = sqs.queryBoost * sqs.queryNorm

if sqs.options.Explain {
childrenExplanations := make([]*search.Explanation, 2)
childrenExplanations[0] = &search.Explanation{
Value: sqs.queryBoost,
Message: "boost",
}
childrenExplanations[1] = &search.Explanation{
Value: sqs.queryNorm,
Message: "queryNorm",
}
sqs.queryWeightExplanation = &search.Explanation{
Value: sqs.queryWeight,
Message: fmt.Sprintf("queryWeight(%s:query Vector^%f), product of:",
sqs.queryField, sqs.queryBoost),
Children: childrenExplanations,
}
}
}
178 changes: 178 additions & 0 deletions search/scorer/scorer_knn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// Copyright (c) 2023 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build vectors
// +build vectors

package scorer

import (
"reflect"
"testing"

"github.com/blevesearch/bleve/v2/search"
index "github.com/blevesearch/bleve_index_api"
)

// TestKNNScorerExplanation verifies the explanation tree that
// KNNQueryScorer.Score attaches to a document match when
// SearcherOptions.Explain is enabled.
//
// Every case builds a scorer, applies the case's query norm through
// SetQueryNorm, scores a single vector match and compares the produced
// explanation with the expected one. Only the Expl field of the expected
// DocumentMatch is asserted; its Score field is not compared.
//
// Cases run as named subtests so that a failure identifies the scenario
// that broke.
func TestKNNScorerExplanation(t *testing.T) {
	// arbitrary vector of dims: 64
	queryVector := make([]float32, 0, 64)
	for i := 0; i < 64; i++ {
		queryVector = append(queryVector, float32(i))
	}

	// arbitrary res vector.
	resVector := make([]float32, 0, 64)
	for i := 0; i < 64; i++ {
		resVector = append(resVector, float32(i))
	}

	tests := []struct {
		name        string
		vectorMatch *index.VectorDoc
		scorer      *KNNQueryScorer
		norm        float64
		result      *search.DocumentMatch
	}{
		{
			// Specifically testing EuclideanDistance since that involves
			// score inversion.
			name: "euclidean distance is inverted",
			vectorMatch: &index.VectorDoc{
				ID:     index.IndexInternalID("one"),
				Score:  0.5,
				Vector: resVector,
			},
			norm: 1.0,
			scorer: NewKNNQueryScorer(queryVector, "desc", 1.0,
				search.SearcherOptions{Explain: true}, index.EuclideanDistance),
			result: &search.DocumentMatch{
				IndexInternalID: index.IndexInternalID("one"),
				Score:           0.5,
				Expl: &search.Explanation{
					Value:   1 / 0.5,
					Message: "fieldWeight(desc in doc one), score of:",
					Children: []*search.Explanation{
						{
							Value:   1 / 0.5,
							Message: "vector(field(desc:one) with similarity_metric(l2_norm)=2.000000e+00",
						},
					},
				},
			},
		},
		{
			// Specifically testing EuclideanDistance with 0 score: the
			// result vector is an exact match of the query vector.
			name: "euclidean distance of zero maps to maxKNNScore",
			vectorMatch: &index.VectorDoc{
				ID:     index.IndexInternalID("one"),
				Score:  0.0,
				Vector: queryVector,
			},
			norm: 1.0,
			scorer: NewKNNQueryScorer(queryVector, "desc", 1.0,
				search.SearcherOptions{Explain: true}, index.EuclideanDistance),
			result: &search.DocumentMatch{
				IndexInternalID: index.IndexInternalID("one"),
				Score:           0.0,
				Expl: &search.Explanation{
					Value:   maxKNNScore,
					Message: "fieldWeight(desc in doc one), score of:",
					Children: []*search.Explanation{
						{
							Value:   maxKNNScore,
							Message: "vector(field(desc:one) with similarity_metric(l2_norm)=1.797693e+308",
						},
					},
				},
			},
		},
		{
			// Boost and norm are both 1, so no query weight level is
			// expected in the explanation.
			name: "cosine similarity with query norm 1",
			vectorMatch: &index.VectorDoc{
				ID:     index.IndexInternalID("one"),
				Score:  0.5,
				Vector: resVector,
			},
			norm: 1.0,
			scorer: NewKNNQueryScorer(queryVector, "desc", 1.0,
				search.SearcherOptions{Explain: true}, index.CosineSimilarity),
			result: &search.DocumentMatch{
				IndexInternalID: index.IndexInternalID("one"),
				Score:           0.5,
				Expl: &search.Explanation{
					Value:   0.5,
					Message: "fieldWeight(desc in doc one), score of:",
					Children: []*search.Explanation{
						{
							Value:   0.5,
							Message: "vector(field(desc:one) with similarity_metric(dot_product)=5.000000e-01",
						},
					},
				},
			},
		},
		{
			// A norm of 0.5 makes the query weight differ from 1, so the
			// explanation gains a top-level "product of" node.
			name: "cosine similarity with query norm 0.5",
			vectorMatch: &index.VectorDoc{
				ID:     index.IndexInternalID("one"),
				Score:  0.25,
				Vector: resVector,
			},
			norm: 0.5,
			scorer: NewKNNQueryScorer(queryVector, "desc", 1.0,
				search.SearcherOptions{Explain: true}, index.CosineSimilarity),
			result: &search.DocumentMatch{
				IndexInternalID: index.IndexInternalID("one"),
				Score:           0.25,
				Expl: &search.Explanation{
					Value:   0.125,
					Message: "weight(desc:query Vector^1.000000 in one), product of:",
					Children: []*search.Explanation{
						{
							Value:   0.5,
							Message: "queryWeight(desc:query Vector^1.000000), product of:",
							Children: []*search.Explanation{
								{
									Value:   1,
									Message: "boost",
								},
								{
									Value:   0.5,
									Message: "queryNorm",
								},
							},
						},
						{
							Value:   0.25,
							Message: "fieldWeight(desc in doc one), score of:",
							Children: []*search.Explanation{
								{
									Value:   0.25,
									Message: "vector(field(desc:one) with similarity_metric(dot_product)=2.500000e-01",
								},
							},
						},
					},
				},
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			ctx := &search.SearchContext{
				DocumentMatchPool: search.NewDocumentMatchPool(1, 0),
			}
			test.scorer.SetQueryNorm(test.norm)
			actual := test.scorer.Score(ctx, test.vectorMatch)
			actual.Complete(nil)

			if !reflect.DeepEqual(actual.Expl, test.result.Expl) {
				t.Errorf("expected %#v got %#v for %#v", test.result.Expl,
					actual.Expl, test.vectorMatch)
			}
		})
	}
}
7 changes: 1 addition & 6 deletions search/searcher/search_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,9 @@ func NewKNNSearcher(ctx context.Context, i index.IndexReader, m mapping.IndexMap
if err != nil {
return nil, err
}
count, err := i.DocCount()
if err != nil {
_ = vectorReader.Close()
return nil, err
}

knnScorer := scorer.NewKNNQueryScorer(vector, field, boost,
vectorReader.Count(), count, options, similarityMetric)
options, similarityMetric)
return &KNNSearcher{
indexReader: i,
vectorReader: vectorReader,
Expand Down

0 comments on commit 835f042

Please sign in to comment.