diff --git a/index_alias_impl.go b/index_alias_impl.go
index 0761eae81..ae0097b7d 100644
--- a/index_alias_impl.go
+++ b/index_alias_impl.go
@@ -16,7 +16,6 @@ package bleve
 
 import (
 	"context"
-	"fmt"
 	"sync"
 	"time"
 
@@ -215,7 +214,7 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest
 	// - the request requires preSearch
 	var preSearchDuration time.Duration
 	var sr *SearchResult
-	flags := preSearchRequired(req, i.mapping)
+	flags := preSearchRequired(ctx, req, i.mapping)
 	if req.PreSearchData == nil && flags != nil {
 		searchStart := time.Now()
 		preSearchResult, err := preSearch(ctx, req, flags, i.indexes...)
@@ -223,7 +222,6 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest
 			return nil, err
 		}
 
-		fmt.Println("presearch result", preSearchResult.docCount)
 		// check if the preSearch result has any errors and if so
 		// return the search result as is without executing the query
 		// so that the errors are not lost
@@ -557,14 +555,22 @@ type preSearchFlags struct {
 
 // preSearchRequired checks if preSearch is required and returns the presearch flags struct
 // indicating which preSearch is required
-func preSearchRequired(req *SearchRequest, m mapping.IndexMapping) *preSearchFlags {
+func preSearchRequired(ctx context.Context, req *SearchRequest, m mapping.IndexMapping) *preSearchFlags {
 	// Check for KNN query
 	knn := requestHasKNN(req)
 	var bm25 bool
-	if _, ok := m.(mapping.BM25Mapping); ok {
-		bm25 = true
+	if !isMatchNoneQuery(req.Query) {
+		if ctx != nil {
+			if searchType := ctx.Value(search.SearchTypeKey); searchType != nil {
+				if searchType.(string) == search.FetchStatsAndSearch {
+					// todo: check mapping to see if bm25 is needed
+					if _, ok := m.(mapping.BM25Mapping); ok {
+						bm25 = true
+					}
+				}
+			}
+		}
 	}
-
 	if knn || bm25 {
 		return &preSearchFlags{
 			knn: knn,
diff --git a/index_impl.go b/index_impl.go
index b4da3baa0..d03da3916 100644
--- a/index_impl.go
+++ b/index_impl.go
@@ -548,7 +548,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 		if v != nil {
 			bm25Data, ok = v.(map[string]interface{})
 			if !ok {
-				return nil, fmt.Errorf("bm25 preSearchData must be of type uint64")
+				return nil, fmt.Errorf("bm25 preSearchData must be of type map[string]interface{}")
 			}
 		}
 	}
@@ -563,6 +563,8 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 
 	setKnnHitsInCollector(knnHits, req, coll)
 
+	// set the bm25 presearch data (stats needed for consistent scoring across
+	// indexes) in the context object
 	if bm25Data != nil {
 		ctx = context.WithValue(ctx, search.BM25PreSearchDataKey, bm25Data)
 	}
diff --git a/index_test.go b/index_test.go
index 61dbd0bb4..70e4b4588 100644
--- a/index_test.go
+++ b/index_test.go
@@ -464,20 +464,15 @@ func TestBM25(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	res, err = multiPartIndex.Search(searchRequest)
+	ctx := context.Background()
+	// without this context value, the bm25 presearch phase is skipped
+	ctx = context.WithValue(ctx, search.SearchTypeKey, search.FetchStatsAndSearch)
+
+	res, err = multiPartIndex.SearchInContext(ctx, searchRequest)
 	if err != nil {
 		t.Error(err)
 	}
 
-	// ctx := context.Background()
-	// ctx = context.WithValue(ctx, search.PreSearchKey,
-	// 	search.SearcherStartCallbackFn(bleveCtxSearcherStartCallback))
-
-	// res, err = multiPartIndex.SearchInContext(ctx, searchRequest)
-	// if err != nil {
-	// 	t.Error(err)
-	// }
-
 	fmt.Println("length of hits alias search", res.Hits[0].Score)
 }
 
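Usage note: with this change the bm25 presearch phase is opt-in. It only runs when the request is issued through SearchInContext with the new search.SearchTypeKey set to search.FetchStatsAndSearch, which is exactly what the updated TestBM25 does. A minimal caller-side sketch (the indexes, documents, and query are illustrative; whether the bm25 presearch actually fires also depends on the index mapping satisfying mapping.BM25Mapping, per preSearchRequired above):

	package main

	import (
		"context"
		"fmt"
		"log"

		"github.com/blevesearch/bleve/v2"
		"github.com/blevesearch/bleve/v2/search"
	)

	func main() {
		// two throwaway in-memory indexes behind one alias (illustrative setup)
		idxA, err := bleve.NewMemOnly(bleve.NewIndexMapping())
		if err != nil {
			log.Fatal(err)
		}
		idxB, err := bleve.NewMemOnly(bleve.NewIndexMapping())
		if err != nil {
			log.Fatal(err)
		}
		_ = idxA.Index("a", map[string]interface{}{"desc": "light beer"})
		_ = idxB.Index("b", map[string]interface{}{"desc": "dark beer"})

		alias := bleve.NewIndexAlias(idxA, idxB)
		req := bleve.NewSearchRequest(bleve.NewMatchQuery("beer"))

		// opting in: without this key/value the alias skips the bm25 presearch
		ctx := context.WithValue(context.Background(),
			search.SearchTypeKey, search.FetchStatsAndSearch)

		res, err := alias.SearchInContext(ctx, req)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println(res.Hits)
	}
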
diff --git a/search/scorer/scorer_term.go b/search/scorer/scorer_term.go
index 7967ac393..f98d4288d 100644
--- a/search/scorer/scorer_term.go
+++ b/search/scorer/scorer_term.go
@@ -62,17 +62,23 @@ func (s *TermQueryScorer) Size() int {
 	return sizeInBytes
 }
 
-func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal,
-	docTerm uint64, avgDocLength float64, options search.SearcherOptions) *TermQueryScorer {
-
-	var idfVal float64
+func (s *TermQueryScorer) computeIDF(avgDocLength float64, docTotal, docTerm uint64) float64 {
+	var rv float64
 	if avgDocLength > 0 {
 		// avgDocLength is set only for bm25 scoring
-		idfVal = math.Log(1 + (float64(docTotal)-float64(docTerm)+0.5)/(float64(docTerm)+0.5))
+		rv = math.Log(1 + (float64(docTotal)-float64(docTerm)+0.5)/
+			(float64(docTerm)+0.5))
 	} else {
-		idfVal = 1.0 + math.Log(float64(docTotal)/float64(docTerm+1.0))
+		rv = 1.0 + math.Log(float64(docTotal)/
+			float64(docTerm+1.0))
 	}
+	return rv
+}
+
+func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal,
+	docTerm uint64, avgDocLength float64, options search.SearcherOptions) *TermQueryScorer {
+
 	rv := TermQueryScorer{
 		queryTerm:  string(queryTerm),
 		queryField: queryField,
@@ -80,12 +86,12 @@ func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64,
 		docTerm:      docTerm,
 		docTotal:     docTotal,
 		avgDocLength: avgDocLength,
-		idf:          idfVal,
 		options:      options,
 		queryWeight:  1.0,
 		includeScore: options.Score != "none",
 	}
 
+	rv.idf = rv.computeIDF(avgDocLength, docTotal, docTerm)
 	if options.Explain {
 		rv.idfExplanation = &search.Explanation{
 			Value:   rv.idf,
@@ -126,6 +132,24 @@ func (s *TermQueryScorer) SetQueryNorm(qnorm float64) {
 	}
 }
 
+func (s *TermQueryScorer) docScore(tf, norm float64) float64 {
+	// tf-idf scoring by default
+	score := tf * norm * s.idf
+	if s.avgDocLength > 0 {
+		// bm25 scoring
+		// using the posting's norm value to recompute the field length for the doc num
+		fieldLength := 1 / (norm * norm)
+
+		// multipliers deciding how much the doc length affects the score and how
+		// much the term frequency can affect the score
+		var k1 float64 = 1
+		var b float64 = 1
+		score = s.idf * (tf * k1) /
+			(tf + k1*(1-b+(b*fieldLength/s.avgDocLength)))
+	}
+	return score
+}
+
 func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.TermFieldDoc) *search.DocumentMatch {
 	rv := ctx.DocumentMatchPool.Get()
 	// perform any score computations only when needed
@@ -138,18 +162,7 @@ func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.Term
 			tf = math.Sqrt(float64(termMatch.Freq))
 		}
 
-		// tf-idf scoring by default
-		score := tf * termMatch.Norm * s.idf
-		if s.avgDocLength > 0 {
-			// using the posting's norm value to recompute the field length for the doc num
-			fieldLength := 1 / (termMatch.Norm * termMatch.Norm)
-
-			// multipliers. todo: these are something to be set in the scorer by parent layer
-			var k float64 = 1
-			var b float64 = 1
-			score = s.idf * (tf * k) / (tf + k*(1-b+(b*fieldLength/s.avgDocLength)))
-		}
-
+		score := s.docScore(tf, termMatch.Norm)
 		// todo: explain stuff properly
 		if s.options.Explain {
 			childrenExplanations := make([]*search.Explanation, 3)
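To make the new computeIDF/docScore path concrete, here is a small standalone program that reproduces the bm25 branch with made-up numbers (the stats, field length, and term frequency are illustrative; k1 = 1 and b = 1 match the placeholder multipliers above). It also spells out the norm-to-field-length recovery the scorer relies on: per the fieldLength := 1/(norm*norm) line in docScore, the posting norm is treated as 1/sqrt(fieldLength).

	package main

	import (
		"fmt"
		"math"
	)

	func main() {
		// illustrative stats: 1000 docs total, 50 containing the term, avg field length 120
		docTotal, docTerm := 1000.0, 50.0
		avgDocLength := 120.0

		// idf as computed by computeIDF's bm25 branch
		idf := math.Log(1 + (docTotal-docTerm+0.5)/(docTerm+0.5))

		// the posting norm encodes the field length as 1/sqrt(fieldLength),
		// so docScore can recover it with 1/(norm*norm)
		fieldLength := 200.0
		norm := 1 / math.Sqrt(fieldLength)
		recovered := 1 / (norm * norm) // == fieldLength

		tf := math.Sqrt(3.0) // term frequency 3, same sqrt shaping as Score()
		k1, b := 1.0, 1.0    // placeholder multipliers, as in docScore

		score := idf * (tf * k1) / (tf + k1*(1-b+(b*recovered/avgDocLength)))
		fmt.Printf("idf=%.4f fieldLength=%.0f score=%.4f\n", idf, recovered, score)
	}
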
diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go
index 7e89f17fc..f7bffabbb 100644
--- a/search/searcher/search_term.go
+++ b/search/searcher/search_term.go
@@ -39,14 +39,16 @@ type TermSearcher struct {
 	tfd index.TermFieldDoc
 }
 
-func NewTermSearcher(ctx context.Context, indexReader index.IndexReader, term string, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) {
+func NewTermSearcher(ctx context.Context, indexReader index.IndexReader,
+	term string, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) {
 	if isTermQuery(ctx) {
 		ctx = context.WithValue(ctx, search.QueryTypeKey, search.Term)
 	}
 	return NewTermSearcherBytes(ctx, indexReader, []byte(term), field, boost, options)
 }
 
-func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) {
+func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader,
+	term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) {
 	needFreqNorm := options.Score != "none"
 	reader, err := indexReader.TermFieldReader(ctx, term, field, needFreqNorm, needFreqNorm, options.IncludeTermVectors)
 	if err != nil {
@@ -89,9 +91,7 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade
 		fmt.Println("average doc length for", field, "is", fieldCardinality/int(count))
 	}
 
-	// in case of bm25 need to fetch the multipliers as well (maybe something set in index mapping?)
-	// fieldMapping := m.FieldMappingForPath(q.VectorField)
-	// but tbd how to pass on the field mapping here, can we pass it (the multipliers) in the context?
+	// in case of bm25 need to fetch the multipliers as well (perhaps via context's presearch data)
 	scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), float64(fieldCardinality/int(count)), options)
 	return &TermSearcher{
diff --git a/search/util.go b/search/util.go
index 14c21b5b6..de72bacb5 100644
--- a/search/util.go
+++ b/search/util.go
@@ -140,12 +140,13 @@ const BM25PreSearchDataKey = "_bm25_pre_search_data_key"
 
 const PreSearchKey = "_presearch_key"
 
+const SearchTypeKey = "_search_type_key"
+const FetchStatsAndSearch = "fetch_stats_and_search"
+
 type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation)
 
 type SearcherStartCallbackFn func(size uint64) error
 type SearcherEndCallbackFn func(size uint64) error
 
-const BM25MapKey = "_bm25_map_key"
-
 const SearcherStartCallbackKey = "_searcher_start_callback_key"
 const SearcherEndCallbackKey = "_searcher_end_callback_key"
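Following up on the comment in newTermSearcherFromReader about fetching the multipliers/stats via the context: below is a hypothetical helper sketching how searcher-side code could read back the bm25 presearch data that index_impl.go now stores under search.BM25PreSearchDataKey. Only the key and the outer map[string]interface{} type come from this change; the helper name and the per-field float64 payload are assumptions made purely for illustration.

	import (
		"context"

		"github.com/blevesearch/bleve/v2/search"
	)

	// avgDocLengthFromCtx (hypothetical) pulls a per-field average doc length
	// out of the bm25 presearch data, falling back when it is absent.
	func avgDocLengthFromCtx(ctx context.Context, field string) (float64, bool) {
		v := ctx.Value(search.BM25PreSearchDataKey)
		if v == nil {
			return 0, false
		}
		stats, ok := v.(map[string]interface{})
		if !ok {
			return 0, false
		}
		avg, ok := stats[field].(float64) // assumed per-field payload shape
		return avg, ok
	}
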