From 00b5d9fad4929600a99f2b4c68169fadf0781bec Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Fri, 17 Nov 2023 10:05:26 +0530 Subject: [PATCH] Add KNN Operator --- index_impl.go | 10 ++++++---- search_knn.go | 36 ++++++++++++++++++++++++++++++------ search_no_knn.go | 4 ++-- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/index_impl.go b/index_impl.go index fe3a62e9e..3865dc260 100644 --- a/index_impl.go +++ b/index_impl.go @@ -496,10 +496,12 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr ctx = context.WithValue(ctx, search.GeoBufferPoolCallbackKey, search.GeoBufferPoolCallbackFunc(getBufferPool)) - - // Using a disjunction query to get union of results from KNN query - // and the original query - searchQuery := disjunctQueryWithKNN(req) + // Using a query to get results from KNN queries + // and the original query based on the KNN operator. + searchQuery, err := queryWithKNN(req) + if err != nil { + return nil, err + } searcher, err := searchQuery.Searcher(ctx, indexReader, i.m, search.SearcherOptions{ Explain: req.Explain, diff --git a/search_knn.go b/search_knn.go index a2f8d343c..3528af438 100644 --- a/search_knn.go +++ b/search_knn.go @@ -25,6 +25,8 @@ import ( "github.com/blevesearch/bleve/v2/search/query" ) +type knnOperator string + type SearchRequest struct { Query query.Query `json:"query"` Size int `json:"size"` @@ -39,7 +41,8 @@ type SearchRequest struct { SearchAfter []string `json:"search_after"` SearchBefore []string `json:"search_before"` - KNN []*KNNRequest `json:"knn"` + KNN []*KNNRequest `json:"knn"` + KNNOperator knnOperator `json:"knn_operator,omitempty"` sortFunc func(sort.Interface) } @@ -61,6 +64,10 @@ func (r *SearchRequest) AddKNN(field string, vector []float32, k int64, boost fl }) } +func (r *SearchRequest) AddKNNOperator(operator knnOperator) { + r.KNNOperator = operator +} + // UnmarshalJSON deserializes a JSON representation of // a SearchRequest func (r *SearchRequest) UnmarshalJSON(input []byte) error { @@ -78,6 +85,7 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error { SearchAfter []string `json:"search_after"` SearchBefore []string `json:"search_before"` KNN []*KNNRequest `json:"knn"` + KNNOperator knnOperator `json:"knn_operator"` } err := json.Unmarshal(input, &temp) @@ -120,6 +128,10 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error { } r.KNN = temp.KNN + r.KNNOperator = temp.KNNOperator + if r.KNNOperator == "" { + r.KNNOperator = knnOperatorOr + } return nil @@ -142,24 +154,36 @@ func copySearchRequest(req *SearchRequest) *SearchRequest { SearchAfter: req.SearchAfter, SearchBefore: req.SearchBefore, KNN: req.KNN, + KNNOperator: req.KNNOperator, } return &rv } -func disjunctQueryWithKNN(req *SearchRequest) query.Query { +var ( + knnOperatorAnd = knnOperator("and") + knnOperatorOr = knnOperator("or") +) + +func queryWithKNN(req *SearchRequest) (query.Query, error) { if len(req.KNN) > 0 { - disjuncts := []query.Query{req.Query} + subQueries := []query.Query{req.Query} for _, knn := range req.KNN { if knn != nil { knnQuery := query.NewKNNQuery(knn.Vector) knnQuery.SetFieldVal(knn.Field) knnQuery.SetK(knn.K) knnQuery.SetBoost(knn.Boost.Value()) - disjuncts = append(disjuncts, knnQuery) + subQueries = append(subQueries, knnQuery) } } - return query.NewDisjunctionQuery(disjuncts) + if req.KNNOperator == knnOperatorAnd { + return query.NewConjunctionQuery(subQueries), nil + } else if req.KNNOperator == knnOperatorOr { + return query.NewDisjunctionQuery(subQueries), nil + } else { + return nil, fmt.Errorf("unknown knn operator: %s", req.KNNOperator) + } } - return req.Query + return req.Query, nil } diff --git a/search_no_knn.go b/search_no_knn.go index fb3814911..ad7b35220 100644 --- a/search_no_knn.go +++ b/search_no_knn.go @@ -144,6 +144,6 @@ func copySearchRequest(req *SearchRequest) *SearchRequest { return &rv } -func disjunctQueryWithKNN(req *SearchRequest) query.Query { - return req.Query +func queryWithKNN(req *SearchRequest) (query.Query, error) { + return req.Query, nil }