Skip to content

Commit

Permalink
field mapping to capture type of scoring; bm25 by default
Browse files Browse the repository at this point in the history
  • Loading branch information
Thejas-bhat committed Dec 11, 2024
1 parent 9d87631 commit dbe105f
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 36 deletions.
2 changes: 1 addition & 1 deletion index_alias_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest
// and NOT a real search
flags := &preSearchFlags{
knn: requestHasKNN(req), // set knn flag if the request has KNN
bm25: true, // TODO Just force setting it to true to test
bm25: true,
}
return preSearchDataSearch(ctx, req, flags, i.indexes...)
}
Expand Down
7 changes: 7 additions & 0 deletions index_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,13 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr

setKnnHitsInCollector(knnHits, req, coll)

fieldMappingCallback := func(field string) string {
rv := i.m.FieldMappingForPath(field)
return rv.Similarity
}
ctx = context.WithValue(ctx, search.GetSimilarityModelCallbackKey,
search.GetSimilarityModelCallbackFn(fieldMappingCallback))

// set the bm25 presearch data (stats important for consistent scoring) in
// the context object
if bm25Data != nil {
Expand Down
2 changes: 1 addition & 1 deletion index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ func TestBM25(t *testing.T) {

ctx := context.Background()
// not setting this doesn't perform a presearch for bm25
ctx = context.WithValue(ctx, search.SearchTypeKey, search.FetchStatsAndSearch)
// ctx = context.WithValue(ctx, search.SearchTypeKey, search.FetchStatsAndSearch)

res, err = multiPartIndex.SearchInContext(ctx, searchRequest)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions mapping/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ type FieldMapping struct {
Dims int `json:"dims,omitempty"`

// Similarity is the similarity algorithm used for scoring
// vector fields.
// See: index.DefaultSimilarityMetric & index.SupportedSimilarityMetrics
// field's content while performing search on it.
// See: index.SimilarityModels
Similarity string `json:"similarity,omitempty"`

// Applicable to vector fields only - optimization string
Expand Down
6 changes: 3 additions & 3 deletions mapping/mapping_vectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func validateVectorFieldAlias(field *FieldMapping, parentName string,
}

if field.Similarity == "" {
field.Similarity = index.DefaultSimilarityMetric
field.Similarity = index.DefaultVectorSimilarityMetric
}

if field.VectorIndexOptimizedFor == "" {
Expand Down Expand Up @@ -249,10 +249,10 @@ func validateVectorFieldAlias(field *FieldMapping, parentName string,
MinVectorDims, MaxVectorDims)
}

if _, ok := index.SupportedSimilarityMetrics[field.Similarity]; !ok {
if _, ok := index.SupportedVectorSimilarityMetrics[field.Similarity]; !ok {
return fmt.Errorf("field: '%s', invalid similarity "+
"metric: '%s', valid metrics are: %+v", field.Name, field.Similarity,
reflect.ValueOf(index.SupportedSimilarityMetrics).MapKeys())
reflect.ValueOf(index.SupportedVectorSimilarityMetrics).MapKeys())
}

if fieldAliasCtx != nil { // writing to a nil map is unsafe
Expand Down
2 changes: 1 addition & 1 deletion search/query/knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader,
fieldMapping := m.FieldMappingForPath(q.VectorField)
similarityMetric := fieldMapping.Similarity
if similarityMetric == "" {
similarityMetric = index.DefaultSimilarityMetric
similarityMetric = index.DefaultVectorSimilarityMetric
}
if q.K <= 0 || len(q.Vector) == 0 {
return nil, fmt.Errorf("k must be greater than 0 and vector must be non-empty")
Expand Down
90 changes: 65 additions & 25 deletions search/searcher/search_term.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,41 +57,81 @@ func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader,
return newTermSearcherFromReader(ctx, indexReader, reader, term, field, boost, options)
}

func tfTDFScoreMetrics(indexReader index.IndexReader) (uint64, int, error) {
// default tf-idf stats
count, err := indexReader.DocCount()
if err != nil {
return 0, 0, err
}
fieldCardinality := 0
return count, fieldCardinality, nil
}

func bm25ScoreMetrics(ctx context.Context, field string,
indexReader index.IndexReader) (uint64, int, error) {
var count uint64
var fieldCardinality int
var err error

bm25Stats, ok := ctx.Value(search.BM25PreSearchDataKey).(map[string]interface{})
if !ok {
count, err = indexReader.DocCount()
if err != nil {
return 0, 0, err
}
dict, err := indexReader.FieldDict(field)
if err != nil {
return 0, 0, err
}
fieldCardinality = dict.Cardinality()

Check failure on line 86 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.20.x, ubuntu-latest)

dict.Cardinality undefined (type index.FieldDict has no field or method Cardinality)

Check failure on line 86 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.20.x, macos-latest)

dict.Cardinality undefined (type index.FieldDict has no field or method Cardinality)

Check failure on line 86 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.21.x, ubuntu-latest)

dict.Cardinality undefined (type index.FieldDict has no field or method Cardinality)

Check failure on line 86 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.21.x, macos-latest)

dict.Cardinality undefined (type index.FieldDict has no field or method Cardinality)

Check failure on line 86 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.22.x, ubuntu-latest)

dict.Cardinality undefined (type index.FieldDict has no field or method Cardinality)

Check failure on line 86 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.22.x, macos-latest)

dict.Cardinality undefined (type index.FieldDict has no field or method Cardinality)
} else {
count = bm25Stats["docCount"].(uint64)
fieldCardinalityMap := bm25Stats["fieldCardinality"].(map[string]int)
fieldCardinality, ok = fieldCardinalityMap[field]
if !ok {
return 0, 0, fmt.Errorf("field stat for bm25 not present %s", field)
}
}

fmt.Println("----------bm25 stats--------")
fmt.Println("docCount: ", count)
fmt.Println("fieldCardinality: ", fieldCardinality)
fmt.Println("avgDocLength: ", fieldCardinality/int(count))

return count, fieldCardinality, nil
}

func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReader, reader index.TermFieldReader,
term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) {
var count uint64
var fieldCardinality int
var err error
if ctx != nil {
bm25Stats, ok := ctx.Value(search.BM25PreSearchDataKey).(map[string]interface{})
if !ok {
var err error
count, err = indexReader.DocCount()
if err != nil {
_ = reader.Close()
return nil, err
if similaritModelCallback, ok := ctx.Value(search.
GetSimilarityModelCallbackKey).(search.GetSimilarityModelCallbackFn); ok {
similarityModel := similaritModelCallback(field)
if similarityModel == "" || similarityModel == index.BM25Similarity {

Check failure on line 113 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.20.x, ubuntu-latest)

undefined: index.BM25Similarity

Check failure on line 113 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.20.x, macos-latest)

undefined: index.BM25Similarity

Check failure on line 113 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.21.x, ubuntu-latest)

undefined: index.BM25Similarity

Check failure on line 113 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.21.x, macos-latest)

undefined: index.BM25Similarity

Check failure on line 113 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.22.x, ubuntu-latest)

undefined: index.BM25Similarity

Check failure on line 113 in search/searcher/search_term.go

View workflow job for this annotation

GitHub Actions / test (1.22.x, macos-latest)

undefined: index.BM25Similarity
// in case of bm25 need to fetch the multipliers as well (perhaps via context's presearch data)
count, fieldCardinality, err = bm25ScoreMetrics(ctx, field, indexReader)
if err != nil {
_ = reader.Close()
return nil, err
}
} else {
count, fieldCardinality, err = tfTDFScoreMetrics(indexReader)
if err != nil {
_ = reader.Close()
return nil, err
}
}
dict, err := indexReader.FieldDict(field)
} else {
// default tf-idf stats
count, fieldCardinality, err = tfTDFScoreMetrics(indexReader)
if err != nil {
_ = indexReader.Close()
_ = reader.Close()
return nil, err
}
fieldCardinality = dict.Cardinality()
fmt.Println("------------------")
fmt.Println("the num docs", count)
fmt.Println("the field cardinality", fieldCardinality)
} else {
fmt.Printf("fetched from ctx \n")
count = bm25Stats["docCount"].(uint64)
fieldCardinalityMap := bm25Stats["fieldCardinality"].(map[string]int)
fieldCardinality, ok = fieldCardinalityMap[field]
if !ok {
return nil, fmt.Errorf("field stat for bm25 not present %s", field)
}

fmt.Println("average doc length for", field, "is", fieldCardinality/int(count))
}

// in case of bm25 need to fetch the multipliers as well (perhaps via context's presearch data)
}
scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), float64(fieldCardinality/int(count)), options)
return &TermSearcher{
Expand Down
10 changes: 7 additions & 3 deletions search/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,14 @@ const PreSearchKey = "_presearch_key"
const SearchTypeKey = "_search_type_key"
const FetchStatsAndSearch = "fetch_stats_and_search"

type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation)
const SearcherStartCallbackKey = "_searcher_start_callback_key"
const SearcherEndCallbackKey = "_searcher_end_callback_key"

type SearcherStartCallbackFn func(size uint64) error
type SearcherEndCallbackFn func(size uint64) error

const SearcherStartCallbackKey = "_searcher_start_callback_key"
const SearcherEndCallbackKey = "_searcher_end_callback_key"
const GetSimilarityModelCallbackKey = "_get_similarity_model"

type GetSimilarityModelCallbackFn func(field string) string

type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation)

0 comments on commit dbe105f

Please sign in to comment.