Skip to content

Commit

Permalink
making bm25 presearch (i.e. global scoring) optional
Browse files Browse the repository at this point in the history
  • Loading branch information
Thejas-bhat committed Dec 10, 2024
1 parent 6ca0e3e commit 3a3acfc
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 43 deletions.
18 changes: 12 additions & 6 deletions index_alias_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package bleve

import (
"context"
"fmt"
"sync"
"time"

Expand Down Expand Up @@ -215,15 +214,14 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest
// - the request requires preSearch
var preSearchDuration time.Duration
var sr *SearchResult
flags := preSearchRequired(req, i.mapping)
flags := preSearchRequired(ctx, req, i.mapping)
if req.PreSearchData == nil && flags != nil {
searchStart := time.Now()
preSearchResult, err := preSearch(ctx, req, flags, i.indexes...)
if err != nil {
return nil, err
}

fmt.Println("presearch result", preSearchResult.docCount)
// check if the preSearch result has any errors and if so
// return the search result as is without executing the query
// so that the errors are not lost
Expand Down Expand Up @@ -557,12 +555,20 @@ type preSearchFlags struct {

// preSearchRequired checks if preSearch is required and returns the presearch flags struct
// indicating which preSearch is required
func preSearchRequired(req *SearchRequest, m mapping.IndexMapping) *preSearchFlags {
func preSearchRequired(ctx context.Context, req *SearchRequest, m mapping.IndexMapping) *preSearchFlags {
// Check for KNN query
knn := requestHasKNN(req)
var bm25 bool
if _, ok := m.(mapping.BM25Mapping); ok {
bm25 = true

if ctx != nil {
if searchType := ctx.Value(search.SearchTypeKey); searchType != nil {
if searchType.(string) == search.FetchStatsAndSearch {
// todo: check mapping to see if bm25 is needed
if _, ok := m.(mapping.BM25Mapping); ok {
bm25 = true
}
}
}
}

if knn || bm25 {
Expand Down
4 changes: 3 additions & 1 deletion index_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
if v != nil {
bm25Data, ok = v.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("bm25 preSearchData must be of type uint64")
return nil, fmt.Errorf("bm25 preSearchData must be of type map[string]interface{}")
}
}
}
Expand All @@ -563,6 +563,8 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr

setKnnHitsInCollector(knnHits, req, coll)

// set the bm25 presearch data (stats important for consistent scoring) in
// the context object
if bm25Data != nil {
ctx = context.WithValue(ctx, search.BM25PreSearchDataKey, bm25Data)
}
Expand Down
15 changes: 5 additions & 10 deletions index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,20 +464,15 @@ func TestBM25(t *testing.T) {
t.Fatal(err)
}

res, err = multiPartIndex.Search(searchRequest)
ctx := context.Background()
// if this key is not set in the context, no bm25 presearch (global stats gathering) is performed
ctx = context.WithValue(ctx, search.SearchTypeKey, search.FetchStatsAndSearch)

res, err = multiPartIndex.SearchInContext(ctx, searchRequest)
if err != nil {
t.Error(err)
}

// ctx := context.Background()
// ctx = context.WithValue(ctx, search.PreSearchKey,
// search.SearcherStartCallbackFn(bleveCtxSearcherStartCallback))

// res, err = multiPartIndex.SearchInContext(ctx, searchRequest)
// if err != nil {
// t.Error(err)
// }

fmt.Println("length of hits alias search", res.Hits[0].Score)

}
Expand Down
51 changes: 32 additions & 19 deletions search/scorer/scorer_term.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,30 +62,36 @@ func (s *TermQueryScorer) Size() int {
return sizeInBytes
}

func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal,
docTerm uint64, avgDocLength float64, options search.SearcherOptions) *TermQueryScorer {

var idfVal float64
// computeIDF derives the inverse document frequency component for the
// scorer's term. A positive avgDocLength signals that BM25 scoring is in
// effect, selecting the probabilistic BM25 IDF formula; otherwise the
// classic tf-idf variant is used.
func (s *TermQueryScorer) computeIDF(avgDocLength float64, docTotal, docTerm uint64) float64 {
	total := float64(docTotal)
	term := float64(docTerm)
	if avgDocLength > 0 {
		// avgDocLength is only set when bm25 scoring is active
		return math.Log(1 + (total-term+0.5)/(term+0.5))
	}
	return 1.0 + math.Log(total/float64(docTerm+1))
}

func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal,
docTerm uint64, avgDocLength float64, options search.SearcherOptions) *TermQueryScorer {

rv := TermQueryScorer{
queryTerm: string(queryTerm),
queryField: queryField,
queryBoost: queryBoost,
docTerm: docTerm,
docTotal: docTotal,
avgDocLength: avgDocLength,
idf: idfVal,
options: options,
queryWeight: 1.0,
includeScore: options.Score != "none",
}

rv.idf = rv.computeIDF(avgDocLength, docTotal, docTerm)
if options.Explain {
rv.idfExplanation = &search.Explanation{
Value: rv.idf,
Expand Down Expand Up @@ -126,6 +132,24 @@ func (s *TermQueryScorer) SetQueryNorm(qnorm float64) {
}
}

// docScore computes a single document's score from its term frequency and
// the posting's stored norm. It defaults to tf-idf scoring and switches to
// BM25 when the scorer carries an average document length (avgDocLength > 0).
func (s *TermQueryScorer) docScore(tf, norm float64) float64 {
	if s.avgDocLength <= 0 {
		// classic tf-idf scoring
		return tf * norm * s.idf
	}

	// bm25 scoring:
	// the posting's norm value is used to recover the field length for this doc
	fieldLength := 1 / (norm * norm)

	// multipliers deciding how much doc length (b) and term frequency (k1)
	// influence the score
	var k1 float64 = 1
	var b float64 = 1
	return s.idf * (tf * k1) /
		(tf + k1*(1-b+(b*fieldLength/s.avgDocLength)))
}

func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.TermFieldDoc) *search.DocumentMatch {
rv := ctx.DocumentMatchPool.Get()
// perform any score computations only when needed
Expand All @@ -138,18 +162,7 @@ func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.Term
tf = math.Sqrt(float64(termMatch.Freq))
}

// tf-idf scoring by default
score := tf * termMatch.Norm * s.idf
if s.avgDocLength > 0 {
// using the posting's norm value to recompute the field length for the doc num
fieldLength := 1 / (termMatch.Norm * termMatch.Norm)

// multipliers. todo: these are something to be set in the scorer by parent layer
var k float64 = 1
var b float64 = 1
score = s.idf * (tf * k) / (tf + k*(1-b+(b*fieldLength/s.avgDocLength)))
}

score := s.docScore(tf, termMatch.Norm)
// todo: explain stuff properly
if s.options.Explain {
childrenExplanations := make([]*search.Explanation, 3)
Expand Down
10 changes: 5 additions & 5 deletions search/searcher/search_term.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,16 @@ type TermSearcher struct {
tfd index.TermFieldDoc
}

func NewTermSearcher(ctx context.Context, indexReader index.IndexReader, term string, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) {
func NewTermSearcher(ctx context.Context, indexReader index.IndexReader,
term string, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) {
if isTermQuery(ctx) {
ctx = context.WithValue(ctx, search.QueryTypeKey, search.Term)
}
return NewTermSearcherBytes(ctx, indexReader, []byte(term), field, boost, options)
}

func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) {
func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader,
term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) {
needFreqNorm := options.Score != "none"
reader, err := indexReader.TermFieldReader(ctx, term, field, needFreqNorm, needFreqNorm, options.IncludeTermVectors)
if err != nil {
Expand Down Expand Up @@ -89,9 +91,7 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade
fmt.Println("average doc length for", field, "is", fieldCardinality/int(count))
}

// in case of bm25 need to fetch the multipliers as well (maybe something set in index mapping?)
// fieldMapping := m.FieldMappingForPath(q.VectorField)
// but tbd how to pass on the field mapping here, can we pass it (the multipliers) in the context?
// in case of bm25 need to fetch the multipliers as well (perhaps via context's presearch data)
}
scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), float64(fieldCardinality/int(count)), options)
return &TermSearcher{
Expand Down
5 changes: 3 additions & 2 deletions search/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,13 @@ const BM25PreSearchDataKey = "_bm25_pre_search_data_key"

const PreSearchKey = "_presearch_key"

const SearchTypeKey = "_search_type_key"
const FetchStatsAndSearch = "fetch_stats_and_search"

type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation)

type SearcherStartCallbackFn func(size uint64) error
type SearcherEndCallbackFn func(size uint64) error

const BM25MapKey = "_bm25_map_key"

const SearcherStartCallbackKey = "_searcher_start_callback_key"
const SearcherEndCallbackKey = "_searcher_end_callback_key"

0 comments on commit 3a3acfc

Please sign in to comment.