diff --git a/index_alias_impl.go b/index_alias_impl.go index ae0097b7d..eccdbf9ec 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -174,7 +174,7 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest // and NOT a real search flags := &preSearchFlags{ knn: requestHasKNN(req), // set knn flag if the request has KNN - bm25: true, // TODO Just force setting it to true to test + bm25: true, } return preSearchDataSearch(ctx, req, flags, i.indexes...) } diff --git a/index_impl.go b/index_impl.go index d03da3916..903e6774f 100644 --- a/index_impl.go +++ b/index_impl.go @@ -563,6 +563,13 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr setKnnHitsInCollector(knnHits, req, coll) + fieldMappingCallback := func(field string) string { + rv := i.m.FieldMappingForPath(field) + return rv.Similarity + } + ctx = context.WithValue(ctx, search.GetSimilarityModelCallbackKey, + search.GetSimilarityModelCallbackFn(fieldMappingCallback)) + // set the bm25 presearch data (stats important for consistent scoring) in // the context object if bm25Data != nil { diff --git a/index_test.go b/index_test.go index 70e4b4588..0603d2e29 100644 --- a/index_test.go +++ b/index_test.go @@ -466,7 +466,7 @@ func TestBM25(t *testing.T) { ctx := context.Background() // not setting this doesn't perform a presearch for bm25 - ctx = context.WithValue(ctx, search.SearchTypeKey, search.FetchStatsAndSearch) + // ctx = context.WithValue(ctx, search.SearchTypeKey, search.FetchStatsAndSearch) res, err = multiPartIndex.SearchInContext(ctx, searchRequest) if err != nil { diff --git a/mapping/field.go b/mapping/field.go index 5c064fddd..8efb52556 100644 --- a/mapping/field.go +++ b/mapping/field.go @@ -74,8 +74,8 @@ type FieldMapping struct { Dims int `json:"dims,omitempty"` // Similarity is the similarity algorithm used for scoring - // vector fields. - // See: index.DefaultSimilarityMetric & index.SupportedSimilarityMetrics + // field's content while performing search on it. + // See: index.SimilarityModels Similarity string `json:"similarity,omitempty"` // Applicable to vector fields only - optimization string diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go index dbfde1fb0..20cbac6a8 100644 --- a/mapping/mapping_vectors.go +++ b/mapping/mapping_vectors.go @@ -204,7 +204,7 @@ func validateVectorFieldAlias(field *FieldMapping, parentName string, } if field.Similarity == "" { - field.Similarity = index.DefaultSimilarityMetric + field.Similarity = index.DefaultVectorSimilarityMetric } if field.VectorIndexOptimizedFor == "" { @@ -249,10 +249,10 @@ func validateVectorFieldAlias(field *FieldMapping, parentName string, MinVectorDims, MaxVectorDims) } - if _, ok := index.SupportedSimilarityMetrics[field.Similarity]; !ok { + if _, ok := index.SupportedVectorSimilarityMetrics[field.Similarity]; !ok { return fmt.Errorf("field: '%s', invalid similarity "+ "metric: '%s', valid metrics are: %+v", field.Name, field.Similarity, - reflect.ValueOf(index.SupportedSimilarityMetrics).MapKeys()) + reflect.ValueOf(index.SupportedVectorSimilarityMetrics).MapKeys()) } if fieldAliasCtx != nil { // writing to a nil map is unsafe diff --git a/search/query/knn.go b/search/query/knn.go index 4d105d943..8221fbcea 100644 --- a/search/query/knn.go +++ b/search/query/knn.go @@ -82,7 +82,7 @@ func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader, fieldMapping := m.FieldMappingForPath(q.VectorField) similarityMetric := fieldMapping.Similarity if similarityMetric == "" { - similarityMetric = index.DefaultSimilarityMetric + similarityMetric = index.DefaultVectorSimilarityMetric } if q.K <= 0 || len(q.Vector) == 0 { return nil, fmt.Errorf("k must be greater than 0 and vector must be non-empty") diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index f7bffabbb..b84f3705a 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -57,41 +57,81 @@ func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, return newTermSearcherFromReader(ctx, indexReader, reader, term, field, boost, options) } +func tfTDFScoreMetrics(indexReader index.IndexReader) (uint64, int, error) { + // default tf-idf stats + count, err := indexReader.DocCount() + if err != nil { + return 0, 0, err + } + fieldCardinality := 0 + return count, fieldCardinality, nil +} + +func bm25ScoreMetrics(ctx context.Context, field string, + indexReader index.IndexReader) (uint64, int, error) { + var count uint64 + var fieldCardinality int + var err error + + bm25Stats, ok := ctx.Value(search.BM25PreSearchDataKey).(map[string]interface{}) + if !ok { + count, err = indexReader.DocCount() + if err != nil { + return 0, 0, err + } + dict, err := indexReader.FieldDict(field) + if err != nil { + return 0, 0, err + } + fieldCardinality = dict.Cardinality() + } else { + count = bm25Stats["docCount"].(uint64) + fieldCardinalityMap := bm25Stats["fieldCardinality"].(map[string]int) + fieldCardinality, ok = fieldCardinalityMap[field] + if !ok { + return 0, 0, fmt.Errorf("field stat for bm25 not present %s", field) + } + } + + fmt.Println("----------bm25 stats--------") + fmt.Println("docCount: ", count) + fmt.Println("fieldCardinality: ", fieldCardinality) + fmt.Println("avgDocLength: ", fieldCardinality/int(count)) + + return count, fieldCardinality, nil +} + func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReader, reader index.TermFieldReader, term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) { var count uint64 var fieldCardinality int + var err error if ctx != nil { - bm25Stats, ok := ctx.Value(search.BM25PreSearchDataKey).(map[string]interface{}) - if !ok { - var err error - count, err = indexReader.DocCount() - if err != nil { - _ = reader.Close() - return nil, err + if similaritModelCallback, ok := ctx.Value(search. + GetSimilarityModelCallbackKey).(search.GetSimilarityModelCallbackFn); ok { + similarityModel := similaritModelCallback(field) + if similarityModel == "" || similarityModel == index.BM25Similarity { + // in case of bm25 need to fetch the multipliers as well (perhaps via context's presearch data) + count, fieldCardinality, err = bm25ScoreMetrics(ctx, field, indexReader) + if err != nil { + _ = reader.Close() + return nil, err + } + } else { + count, fieldCardinality, err = tfTDFScoreMetrics(indexReader) + if err != nil { + _ = reader.Close() + return nil, err + } } - dict, err := indexReader.FieldDict(field) + } else { + // default tf-idf stats + count, fieldCardinality, err = tfTDFScoreMetrics(indexReader) if err != nil { - _ = indexReader.Close() + _ = reader.Close() return nil, err } - fieldCardinality = dict.Cardinality() - fmt.Println("------------------") - fmt.Println("the num docs", count) - fmt.Println("the field cardinality", fieldCardinality) - } else { - fmt.Printf("fetched from ctx \n") - count = bm25Stats["docCount"].(uint64) - fieldCardinalityMap := bm25Stats["fieldCardinality"].(map[string]int) - fieldCardinality, ok = fieldCardinalityMap[field] - if !ok { - return nil, fmt.Errorf("field stat for bm25 not present %s", field) - } - - fmt.Println("average doc length for", field, "is", fieldCardinality/int(count)) } - - // in case of bm25 need to fetch the multipliers as well (perhaps via context's presearch data) } scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), float64(fieldCardinality/int(count)), options) return &TermSearcher{ diff --git a/search/util.go b/search/util.go index de72bacb5..307060fda 100644 --- a/search/util.go +++ b/search/util.go @@ -143,10 +143,14 @@ const PreSearchKey = "_presearch_key" const SearchTypeKey = "_search_type_key" const FetchStatsAndSearch = "fetch_stats_and_search" -type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation) +const SearcherStartCallbackKey = "_searcher_start_callback_key" +const SearcherEndCallbackKey = "_searcher_end_callback_key" type SearcherStartCallbackFn func(size uint64) error type SearcherEndCallbackFn func(size uint64) error -const SearcherStartCallbackKey = "_searcher_start_callback_key" -const SearcherEndCallbackKey = "_searcher_end_callback_key" +const GetSimilarityModelCallbackKey = "_get_similarity_model" + +type GetSimilarityModelCallbackFn func(field string) string + +type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation)