Skip to content

Commit

Permalink
bug fixes and BM25 UT pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Thejas-bhat committed Dec 10, 2024
1 parent c020f10 commit 6ca0e3e
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 15 deletions.
14 changes: 7 additions & 7 deletions index_alias_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package bleve

import (
"context"
"fmt"
"sync"
"time"

Expand Down Expand Up @@ -221,6 +222,8 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest
if err != nil {
return nil, err
}

fmt.Println("presearch result", preSearchResult.docCount)
// check if the preSearch result has any errors and if so
// return the search result as is without executing the query
// so that the errors are not lost
Expand Down Expand Up @@ -549,7 +552,7 @@ type asyncSearchResult struct {
// preSearchFlags records which features require a presearch phase to be
// run before the main query executes across the alias members.
type preSearchFlags struct {
	knn  bool // the request contains a KNN clause
	bm25 bool // the index mapping opts in to BM25 global statistics
}

// preSearchRequired checks if preSearch is required and returns the presearch flags struct
Expand All @@ -558,13 +561,10 @@ func preSearchRequired(req *SearchRequest, m mapping.IndexMapping) *preSearchFla
// Check for KNN query
knn := requestHasKNN(req)
var bm25 bool
if !isMatchNoneQuery(req.Query) {
// todo fix this cuRRENTLY ALL INDEX mappings are BM25 mappings, need to fix
// this is just a placeholder.
if _, ok := m.(mapping.BM25Mapping); ok {
bm25 = true
}
if _, ok := m.(mapping.BM25Mapping); ok {
bm25 = true
}

if knn || bm25 {
return &preSearchFlags{
knn: knn,
Expand Down
8 changes: 4 additions & 4 deletions index_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
}

var knnHits []*search.DocumentMatch
var bm25TotalDocs uint64
var bm25Data map[string]interface{}
var ok bool
var skipKnnCollector bool
if req.PreSearchData != nil {
Expand All @@ -546,7 +546,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
skipKnnCollector = true
case search.BM25PreSearchDataKey:
if v != nil {
bm25TotalDocs, ok = v.(uint64)
bm25Data, ok = v.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("bm25 preSearchData must be of type uint64")
}
Expand All @@ -563,8 +563,8 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr

setKnnHitsInCollector(knnHits, req, coll)

if bm25TotalDocs > 0 {
ctx = context.WithValue(ctx, search.BM25MapKey, bm25TotalDocs)
if bm25Data != nil {
ctx = context.WithValue(ctx, search.BM25PreSearchDataKey, bm25Data)
}

// This callback and variable handles the tracking of bytes read
Expand Down
143 changes: 141 additions & 2 deletions index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,138 @@ func TestBytesWritten(t *testing.T) {
cleanupTmpIndexPath(t, tmpIndexPath4)
}

// TestBM25 indexes the same dataset once into a single index and once
// split across two indexes queried through an alias, then runs the same
// match query against both. The alias path exercises the BM25 presearch
// merge; the test verifies both searches succeed and produce hits.
func TestBM25(t *testing.T) {
	tmpIndexPath := createTmpIndexPath(t)
	defer cleanupTmpIndexPath(t, tmpIndexPath)

	indexMapping := NewIndexMapping()
	indexMapping.TypeField = "type"
	indexMapping.DefaultAnalyzer = "en"
	documentMapping := NewDocumentMapping()
	indexMapping.AddDocumentMapping("hotel", documentMapping)
	indexMapping.StoreDynamic = false
	indexMapping.DocValuesDynamic = false
	contentFieldMapping := NewTextFieldMapping()
	contentFieldMapping.Store = false

	reviewsMapping := NewDocumentMapping()
	reviewsMapping.AddFieldMappingsAt("content", contentFieldMapping)
	documentMapping.AddSubDocumentMapping("reviews", reviewsMapping)

	typeFieldMapping := NewTextFieldMapping()
	typeFieldMapping.Store = false
	documentMapping.AddFieldMappingsAt("type", typeFieldMapping)

	idxSinglePartition, err := NewUsing(tmpIndexPath, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil)
	if err != nil {
		t.Fatal(err)
	}

	defer func() {
		if err := idxSinglePartition.Close(); err != nil {
			t.Fatal(err)
		}
	}()

	batch, err := getBatchFromData(idxSinglePartition, "sample-data.json")
	if err != nil {
		t.Fatalf("failed to form a batch: %v", err)
	}
	if err = idxSinglePartition.Batch(batch); err != nil {
		t.Fatalf("failed to index batch %v\n", err)
	}

	query := NewMatchQuery("Apartments")
	query.FieldVal = "name"
	searchRequest := NewSearchRequestOptions(query, int(10), 0, true)

	res, err := idxSinglePartition.Search(searchRequest)
	if err != nil {
		// Fatal (not Error) so we never dereference res.Hits below on failure.
		t.Fatal(err)
	}
	if len(res.Hits) == 0 {
		t.Fatal("expected hits from the single-partition search")
	}
	singlePartitionScore := res.Hits[0].Score

	dataset, err := readDataFromFile("sample-data.json")
	if err != nil {
		t.Fatal(err)
	}

	// Split the same dataset across two indexes queried through an alias,
	// which is the path that merges BM25 stats via presearch.
	tmpIndexPath1 := createTmpIndexPath(t)
	defer cleanupTmpIndexPath(t, tmpIndexPath1)

	idxPart1, err := NewUsing(tmpIndexPath1, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil)
	if err != nil {
		t.Fatal(err)
	}

	defer func() {
		if err := idxPart1.Close(); err != nil {
			t.Fatal(err)
		}
	}()

	batch1 := idxPart1.NewBatch()
	for _, doc := range dataset[:len(dataset)/2] {
		if err = batch1.Index(fmt.Sprintf("%d", doc["id"]), doc); err != nil {
			t.Fatal(err)
		}
	}
	if err = idxPart1.Batch(batch1); err != nil {
		t.Fatal(err)
	}

	tmpIndexPath2 := createTmpIndexPath(t)
	defer cleanupTmpIndexPath(t, tmpIndexPath2)

	idxPart2, err := NewUsing(tmpIndexPath2, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil)
	if err != nil {
		t.Fatal(err)
	}

	defer func() {
		if err := idxPart2.Close(); err != nil {
			t.Fatal(err)
		}
	}()

	batch2 := idxPart2.NewBatch()
	for _, doc := range dataset[len(dataset)/2:] {
		if err = batch2.Index(fmt.Sprintf("%d", doc["id"]), doc); err != nil {
			t.Fatal(err)
		}
	}
	if err = idxPart2.Batch(batch2); err != nil {
		t.Fatal(err)
	}

	multiPartIndex := NewIndexAlias(idxPart1, idxPart2)
	if err = multiPartIndex.SetIndexMapping(indexMapping); err != nil {
		t.Fatal(err)
	}

	res, err = multiPartIndex.Search(searchRequest)
	if err != nil {
		t.Fatal(err)
	}
	if len(res.Hits) == 0 {
		t.Fatal("expected hits from the alias search")
	}

	// NOTE(review): without a presearch context the alias score is not
	// guaranteed to equal the single-partition score, so we only log both
	// here rather than asserting equality — confirm once the presearch
	// context path (SearchInContext with search.PreSearchKey) is wired up.
	t.Logf("single-partition top score %v, alias top score %v",
		singlePartitionScore, res.Hits[0].Score)
}

func TestBytesRead(t *testing.T) {
tmpIndexPath := createTmpIndexPath(t)
defer cleanupTmpIndexPath(t, tmpIndexPath)
Expand Down Expand Up @@ -671,23 +803,30 @@ func TestBytesReadStored(t *testing.T) {
}
}

// readDataFromFile loads the JSON test fixture at data/test/<fileName>
// (resolved relative to the current working directory) and decodes it
// into a slice of generic documents, one map per document. It returns an
// error if the file cannot be read or is not valid JSON.
func readDataFromFile(fileName string) ([]map[string]interface{}, error) {
	pwd, err := os.Getwd()
	if err != nil {
		return nil, err
	}
	path := filepath.Join(pwd, "data", "test", fileName)

	fileContent, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}

	var dataset []map[string]interface{}
	if err := json.Unmarshal(fileContent, &dataset); err != nil {
		return nil, err
	}

	return dataset, nil
}

func getBatchFromData(idx Index, fileName string) (*Batch, error) {
dataset, err := readDataFromFile(fileName)
batch := idx.NewBatch()
for _, doc := range dataset {
err = batch.Index(fmt.Sprintf("%d", doc["id"]), doc)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions mapping/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,7 @@ func (im *IndexMappingImpl) FieldMappingForPath(path string) FieldMapping {
// DefaultSearchField returns the field name that queries fall back to
// when no field is specified explicitly.
func (im *IndexMappingImpl) DefaultSearchField() string {
return im.DefaultField
}

// BM25Impl marks IndexMappingImpl as satisfying the BM25Mapping
// interface, which opts the mapping in to BM25 presearch. It is a pure
// marker method and intentionally has no behavior; the previous debug
// print has been removed as it does not belong in library code.
func (im *IndexMappingImpl) BM25Impl() {}
2 changes: 2 additions & 0 deletions mapping/mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,6 @@ type IndexMapping interface {
}
// BM25Mapping is implemented by index mappings that opt in to BM25
// relevance scoring; the alias search path type-asserts against it to
// decide whether a presearch phase is needed to gather global stats.
type BM25Mapping interface {
IndexMapping

// BM25Impl is a marker method with no behavior; implementing it
// signals BM25 support.
BM25Impl()
}
4 changes: 3 additions & 1 deletion pre_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package bleve

import "fmt"

// A preSearchResultProcessor processes the data in
// the preSearch result from multiple
// indexes in an alias and merges them together to
Expand Down Expand Up @@ -60,7 +62,7 @@ func newBM25PreSearchResultProcessor() *bm25PreSearchResultProcessor {
// TODO How will this work for queries other than term queries?
func (b *bm25PreSearchResultProcessor) add(sr *SearchResult, indexName string) {
b.docCount += (sr.docCount)

fmt.Println("docCount: ", b.docCount)
for field, cardinality := range sr.fieldCardinality {
b.fieldCardinality[field] += cardinality
}
Expand Down
4 changes: 3 additions & 1 deletion search/searcher/search_term.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade
return nil, err
}
fieldCardinality = dict.Cardinality()

Check failure on line 76 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.20.x, ubuntu-latest)

dict.Cardinality undefined (type index.FieldDict has no field or method Cardinality)

Check failure on line 76 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.20.x, macos-latest)

dict.Cardinality undefined (type index.FieldDict has no field or method Cardinality)

Check failure on line 76 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.21.x, ubuntu-latest)

dict.Cardinality undefined (type index.FieldDict has no field or method Cardinality)

Check failure on line 76 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.21.x, macos-latest)

dict.Cardinality undefined (type index.FieldDict has no field or method Cardinality)

Check failure on line 76 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.22.x, ubuntu-latest)

dict.Cardinality undefined (type index.FieldDict has no field or method Cardinality)

Check failure on line 76 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.22.x, macos-latest)

dict.Cardinality undefined (type index.FieldDict has no field or method Cardinality)
fmt.Println("average doc length for", field, "is", fieldCardinality/int(count))
fmt.Println("------------------")
fmt.Println("the num docs", count)
fmt.Println("the field cardinality", fieldCardinality)
} else {
fmt.Printf("fetched from ctx \n")
count = bm25Stats["docCount"].(uint64)
Expand Down

0 comments on commit 6ca0e3e

Please sign in to comment.