Skip to content

Commit

Permalink
fix scoring
Browse files Browse the repository at this point in the history
  • Loading branch information
CascadingRadium committed Nov 17, 2023
1 parent 147574e commit 8f0d8c1
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 156 deletions.
185 changes: 105 additions & 80 deletions knn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,11 @@ import (
"github.com/blevesearch/bleve/v2/mapping"
)

const testDatasetFileName = "test/tests/knn/dataset-30-docs.json"
const testQueryFileName = "test/tests/knn/small-query.json"
const testDatasetFileName = "test/knn/knn_dataset.json"
const testQueryFileName = "test/knn/knn_queries.json"

const testDatasetDims = 384

const randomizeDocuments = false

type testDocument struct {
ID string `json:"id"`
Content string `json:"content"`
Expand Down Expand Up @@ -93,7 +91,17 @@ func createPartitionedIndex(documents []map[string]interface{}, index *indexAlia
mapping mapping.IndexMapping, t *testing.T) []string {

partitionSize := len(documents) / numPartitions
extraDocs := len(documents) % numPartitions
docsPerPartition := make([]int, numPartitions)
for i := 0; i < numPartitions; i++ {
docsPerPartition[i] = partitionSize
if extraDocs > 0 {
docsPerPartition[i]++
extraDocs--
}
}
var rv []string
prevCutoff := 0
for i := 0; i < numPartitions; i++ {
tmpIndexPath := createTmpIndexPath(t)
rv = append(rv, tmpIndexPath)
Expand All @@ -103,14 +111,15 @@ func createPartitionedIndex(documents []map[string]interface{}, index *indexAlia
t.Fatal(err)
}
batch := childIndex.NewBatch()
for j := i * partitionSize; (j < (i+1)*partitionSize) && j < len(documents); j++ {
for j := prevCutoff; j < prevCutoff+docsPerPartition[i]; j++ {
doc := documents[j]
err := batch.Index(doc["id"].(string), doc)
if err != nil {
cleanUp(t, rv)
t.Fatal(err)
}
}
prevCutoff += docsPerPartition[i]
err = childIndex.Batch(batch)
if err != nil {
cleanUp(t, rv)
Expand All @@ -134,8 +143,15 @@ func truncateScore(score float64) float64 {
return float64(int(score*1e6)) / 1e6
}

func TestSimilaritySearchQuery(t *testing.T) {
// TestSimilaritySearchRandomized runs the shared KNN test body (runKNNTest)
// with its randomizeDocuments argument set to true.
func TestSimilaritySearchRandomized(t *testing.T) {
	const randomize = true
	runKNNTest(t, randomize)
}

// TestSimilaritySearchNotRandomized runs the shared KNN test body (runKNNTest)
// with its randomizeDocuments argument set to false.
func TestSimilaritySearchNotRandomized(t *testing.T) {
	const randomize = false
	runKNNTest(t, randomize)
}

func runKNNTest(t *testing.T, randomizeDocuments bool) {
dataset, err := createVectorDataset(testDatasetFileName)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -197,6 +213,10 @@ func TestSimilaritySearchQuery(t *testing.T) {
queryIndex: 0,
numIndexPartitions: 4,
expectedResults: map[string]testResult{
"doc29": {
score: 0.5547758085810349,
scoreBreakdown: []float64{0, 1.1095516171620698},
},
"doc23": {
score: 0.3817633037007331,
scoreBreakdown: []float64{0, 0.7635266074014662},
Expand All @@ -205,10 +225,6 @@ func TestSimilaritySearchQuery(t *testing.T) {
score: 0.33983667469689355,
scoreBreakdown: []float64{0, 0.6796733493937871},
},
"doc13": {
score: 0.3206958457835452,
scoreBreakdown: []float64{0, 0.6413916915670904},
},
},
},
{
Expand All @@ -233,59 +249,59 @@ func TestSimilaritySearchQuery(t *testing.T) {
queryIndex: 1,
numIndexPartitions: 1,
expectedResults: map[string]testResult{
"doc23": {
score: 1.304929,
scoreBreakdown: []float64{1.016928, 0.288001},
},
"doc29": {
score: 1.137598,
scoreBreakdown: []float64{0.719076, 0.418521},
score: 1.8859816084399936,
scoreBreakdown: []float64{0.7764299912779237, 1.1095516171620698},
},
"doc23": {
score: 1.8615644255330264,
scoreBreakdown: []float64{1.0980378181315602, 0.7635266074014662},
},
"doc27": {
score: 0.429730,
scoreBreakdown: []float64{0.859461, 0},
score: 0.4640056648691007,
scoreBreakdown: []float64{0.9280113297382014, 0},
},
"doc28": {
score: 0.401976,
scoreBreakdown: []float64{0.803952, 0},
score: 0.434037555556026,
scoreBreakdown: []float64{0.868075111112052, 0},
},
"doc30": {
score: 0.359538,
scoreBreakdown: []float64{0.719076, 0},
score: 0.38821499563896184,
scoreBreakdown: []float64{0.7764299912779237, 0},
},
"doc24": {
score: 0.359538,
scoreBreakdown: []float64{0.719076, 0},
score: 0.38821499563896184,
scoreBreakdown: []float64{0.7764299912779237, 0},
},
},
},
{
queryIndex: 1,
numIndexPartitions: 5,
expectedResults: map[string]testResult{
"doc29": {
score: 1.0019961733597083,
scoreBreakdown: []float64{0.28546778431249004, 0.7165283890472183},
},
"doc28": {
score: 0.758083452201489,
scoreBreakdown: []float64{0.3191626822375969, 0.4389207699638921},
},
"doc23": {
score: 0.32598793043274804,
scoreBreakdown: []float64{0.6519758608654961, 0},
score: 1.5207250366637521,
scoreBreakdown: []float64{0.7571984292622859, 0.7635266074014662},
},
"doc29": {
score: 1.4834345192674083,
scoreBreakdown: []float64{0.3738829021053385, 1.1095516171620698},
},
"doc24": {
score: 0.2305082774045959,
scoreBreakdown: []float64{0.4610165548091918, 0},
score: 0.2677100734235977,
scoreBreakdown: []float64{0.5354201468471954, 0},
},
"doc27": {
score: 0.17059962977552257,
scoreBreakdown: []float64{0.34119925955104513, 0},
score: 0.22343776840593196,
scoreBreakdown: []float64{0.4468755368118639, 0},
},
"doc28": {
score: 0.20900689401100958,
scoreBreakdown: []float64{0.41801378802201916, 0},
},
"doc30": {
score: 0.14273389215624502,
scoreBreakdown: []float64{0.28546778431249004, 0},
score: 0.18694145105266924,
scoreBreakdown: []float64{0.3738829021053385, 0},
},
},
},
Expand All @@ -297,41 +313,41 @@ func TestSimilaritySearchQuery(t *testing.T) {
score: 3333.333333333333,
scoreBreakdown: []float64{0, 0, 10000},
},
"doc23": {
score: 0.3234943403006508,
scoreBreakdown: []float64{0.30977843187823456, 0.1754630785727416, 0},
},
"doc29": {
score: 0.31601878039486375,
scoreBreakdown: []float64{0.21904643099686824, 0.2549817395954274, 0},
score: 0.6774608026082964,
scoreBreakdown: []float64{0.23161973134064517, 0.7845714725717996, 0},
},
"doc23": {
score: 0.5783030702431613,
scoreBreakdown: []float64{0.32755976365480655, 0.5398948417099355, 0},
},
"doc3": {
score: 0.24118912169729392,
scoreBreakdown: []float64{0.7235673650918818, 0, 0},
score: 0.2550334160459894,
scoreBreakdown: []float64{0.7651002481379682, 0, 0},
},
"doc13": {
score: 0.20887591025526625,
scoreBreakdown: []float64{0.6266277307657988, 0, 0},
score: 0.2208654210738964,
scoreBreakdown: []float64{0.6625962632216892, 0, 0},
},
"doc5": {
score: 0.21180931116413285,
scoreBreakdown: []float64{0, 0, 0.6354279334923986},
},
"doc27": {
score: 0.08727018618864227,
scoreBreakdown: []float64{0.26181055856592683, 0, 0},
score: 0.09227950890170131,
scoreBreakdown: []float64{0.27683852670510395, 0, 0},
},
"doc28": {
score: 0.0816337841412418,
scoreBreakdown: []float64{0.2449013524237254, 0, 0},
},
"doc24": {
score: 0.07301547699895608,
scoreBreakdown: []float64{0.21904643099686824, 0, 0},
score: 0.0863195764709126,
scoreBreakdown: []float64{0.2589587294127378, 0, 0},
},
"doc30": {
score: 0.07301547699895608,
scoreBreakdown: []float64{0.21904643099686824, 0, 0},
score: 0.07720657711354839,
scoreBreakdown: []float64{0.23161973134064517, 0, 0},
},
"doc5": {
score: 0.06883694922797147,
scoreBreakdown: []float64{0, 0, 0.20651084768391442},
"doc24": {
score: 0.07720657711354839,
scoreBreakdown: []float64{0.23161973134064517, 0, 0},
},
},
},
Expand All @@ -343,33 +359,41 @@ func TestSimilaritySearchQuery(t *testing.T) {
score: 3333.333333333333,
scoreBreakdown: []float64{0, 0, 10000},
},
"doc29": {
score: 0.567426591648309,
scoreBreakdown: []float64{0.06656841490066398, 0.7845714725717996, 0},
},
"doc23": {
score: 0.2195946458144798,
scoreBreakdown: []float64{0.11312710434186327, 0.21626486437985648, 0},
score: 0.5639255136185979,
scoreBreakdown: []float64{0.3059934287179615, 0.5398948417099355, 0},
},
"doc28": {
score: 0.18796580117926376,
scoreBreakdown: []float64{0.08943482824521586, 0.1925138735236798, 0},
"doc5": {
score: 0.21180931116413285,
scoreBreakdown: []float64{0, 0, 0.6354279334923986},
},
"doc3": {
score: 0.12303621410516037,
scoreBreakdown: []float64{0.36910864231548113, 0, 0},
score: 0.14064944169372873,
scoreBreakdown: []float64{0.42194832508118624, 0, 0},
},
"doc13": {
score: 0.10655248891295889,
scoreBreakdown: []float64{0.3196574667388767, 0, 0},
},
"doc5": {
score: 0.07546992065969621,
scoreBreakdown: []float64{0, 0, 0.22640976197908863},
score: 0.12180599172106943,
scoreBreakdown: []float64{0.3654179751632083, 0, 0},
},
"doc27": {
score: 0.03186995104545246,
scoreBreakdown: []float64{0.09560985313635739, 0, 0},
score: 0.026521491065731144,
scoreBreakdown: []float64{0.07956447319719344, 0, 0},
},
"doc28": {
score: 0.024808583220893122,
scoreBreakdown: []float64{0.07442574966267937, 0, 0},
},
"doc30": {
score: 0.02218947163355466,
scoreBreakdown: []float64{0.06656841490066398, 0, 0},
},
"doc24": {
score: 0.02666431434541773,
scoreBreakdown: []float64{0.07999294303625319, 0, 0},
score: 0.02218947163355466,
scoreBreakdown: []float64{0.06656841490066398, 0, 0},
},
},
},
Expand Down Expand Up @@ -419,4 +443,5 @@ func TestSimilaritySearchQuery(t *testing.T) {
}
cleanUp(t, indexPaths, index.indexes...)
}

}
34 changes: 10 additions & 24 deletions search/searcher/search_conjunction.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package searcher

import (
"context"
"math"
"reflect"
"sort"

Expand All @@ -34,15 +33,16 @@ func init() {
}

type ConjunctionSearcher struct {
indexReader index.IndexReader
searchers []search.Searcher
queryNorm float64
currs []*search.DocumentMatch
maxIDIdx int
scorer *scorer.ConjunctionQueryScorer
initialized bool
options search.SearcherOptions
bytesRead uint64
indexReader index.IndexReader
searchers []search.Searcher
queryNorm float64
queryNormForKNN float64
currs []*search.DocumentMatch
maxIDIdx int
scorer *scorer.ConjunctionQueryScorer
initialized bool
options search.SearcherOptions
bytesRead uint64
}

func NewConjunctionSearcher(ctx context.Context, indexReader index.IndexReader,
Expand Down Expand Up @@ -110,20 +110,6 @@ func (s *ConjunctionSearcher) Size() int {
return sizeInBytes
}

// computeQueryNorm derives the query norm for this conjunction and hands it
// to every child searcher.
//
// The norm is 1/sqrt(total), where total is the sum of Weight() over all
// child searchers. The result is stored in s.queryNorm and then passed to
// each child via SetQueryNorm.
//
// NOTE(review): Weight() is presumably already a squared weight, so that
// total is a sum of squares — confirm against the searcher implementations.
func (s *ConjunctionSearcher) computeQueryNorm() {
	children := s.searchers

	// Step 1: accumulate the weight reported by each child searcher.
	total := 0.0
	for i := range children {
		total += children[i].Weight()
	}

	// Step 2: turn the accumulated weight into the normalization factor.
	norm := 1.0 / math.Sqrt(total)
	s.queryNorm = norm

	// Step 3: propagate the factor downstream so every child scores with it.
	for i := range children {
		children[i].SetQueryNorm(s.queryNorm)
	}
}

func (s *ConjunctionSearcher) initSearchers(ctx *search.SearchContext) error {
var err error
// get all searchers pointing at their first match
Expand Down
Loading

0 comments on commit 8f0d8c1

Please sign in to comment.