diff --git a/index_impl.go b/index_impl.go index 5c9538822..87ecf4c32 100644 --- a/index_impl.go +++ b/index_impl.go @@ -496,9 +496,12 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr ctx = context.WithValue(ctx, search.GeoBufferPoolCallbackKey, search.GeoBufferPoolCallbackFunc(getBufferPool)) - // Using a disjunction query to get union of results from KNN query - // and the original query - searchQuery := disjunctQueryWithKNN(req) + // Using a query to get results from KNN queries + // and the original query based on the KNN operator. + searchQuery, err := queryWithKNN(req) + if err != nil { + return nil, err + } searcher, err := searchQuery.Searcher(ctx, indexReader, i.m, search.SearcherOptions{ Explain: req.Explain, diff --git a/search_knn.go b/search_knn.go index 0e20d8d99..3e5b09e19 100644 --- a/search_knn.go +++ b/search_knn.go @@ -26,6 +26,8 @@ import ( "github.com/blevesearch/bleve/v2/search/query" ) +type knnOperator string + type SearchRequest struct { Query query.Query `json:"query"` Size int `json:"size"` @@ -40,7 +42,8 @@ type SearchRequest struct { SearchAfter []string `json:"search_after"` SearchBefore []string `json:"search_before"` - KNN []*KNNRequest `json:"knn"` + KNN []*KNNRequest `json:"knn"` + KNNOperator knnOperator `json:"knn_operator"` sortFunc func(sort.Interface) } @@ -62,6 +65,10 @@ func (r *SearchRequest) AddKNN(field string, vector []float32, k int64, boost fl }) } +func (r *SearchRequest) AddKNNOperator(operator knnOperator) { + r.KNNOperator = operator +} + // UnmarshalJSON deserializes a JSON representation of // a SearchRequest func (r *SearchRequest) UnmarshalJSON(input []byte) error { @@ -79,6 +86,7 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error { SearchAfter []string `json:"search_after"` SearchBefore []string `json:"search_before"` KNN []*KNNRequest `json:"knn"` + KNNOperator knnOperator `json:"knn_operator"` } err := json.Unmarshal(input, &temp) @@ -121,6 +129,10 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error { } r.KNN = temp.KNN + r.KNNOperator = temp.KNNOperator + if r.KNNOperator == "" { + r.KNNOperator = knnOperatorOr + } return nil @@ -143,26 +155,38 @@ func copySearchRequest(req *SearchRequest) *SearchRequest { SearchAfter: req.SearchAfter, SearchBefore: req.SearchBefore, KNN: req.KNN, + KNNOperator: req.KNNOperator, } return &rv } -func disjunctQueryWithKNN(req *SearchRequest) query.Query { +var ( + knnOperatorAnd = knnOperator("and") + knnOperatorOr = knnOperator("or") +) + +func queryWithKNN(req *SearchRequest) (query.Query, error) { if len(req.KNN) > 0 { - disjuncts := []query.Query{req.Query} + subQueries := []query.Query{req.Query} for _, knn := range req.KNN { if knn != nil { knnQuery := query.NewKNNQuery(knn.Vector) knnQuery.SetFieldVal(knn.Field) knnQuery.SetK(knn.K) knnQuery.SetBoost(knn.Boost.Value()) - disjuncts = append(disjuncts, knnQuery) + subQueries = append(subQueries, knnQuery) } } - return query.NewDisjunctionQuery(disjuncts) + if req.KNNOperator == knnOperatorAnd { + return query.NewConjunctionQuery(subQueries), nil + } else if req.KNNOperator == knnOperatorOr || req.KNNOperator == "" { + return query.NewDisjunctionQuery(subQueries), nil + } else { + return nil, fmt.Errorf("unknown knn operator: %s", req.KNNOperator) + } } - return req.Query + return req.Query, nil } func validateKNN(req *SearchRequest) error { diff --git a/search_knn_test.go b/search_knn_test.go new file mode 100644 index 000000000..3cb737deb --- /dev/null +++ b/search_knn_test.go @@ -0,0 +1,133 @@ +//go:build vectors +// +build vectors + +package bleve + +import ( + "fmt" + "log" + "math/rand" + "strconv" + "testing" + + "github.com/blevesearch/bleve/v2/mapping" + "github.com/blevesearch/bleve/v2/search/query" +) + +// Test to see if KNN Operators get added right to the query. +func TestKNNOperator(t *testing.T) { + tmpIndexPath := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath) + + dataset := make([]map[string]interface{}, 10) + + // Indexing just a few docs to populate index. + for i := 0; i < 10; i++ { + docVec := []float32{} + for i := 0; i < 5; i++ { + docVec = append(docVec, rand.Float32()) + } + dataset = append(dataset, map[string]interface{}{ + "type": "vectorStuff", + "content": strconv.Itoa(i), + "vector": docVec, + }) + } + + indexMapping := NewIndexMapping() + indexMapping.TypeField = "type" + indexMapping.DefaultAnalyzer = "en" + documentMapping := NewDocumentMapping() + indexMapping.AddDocumentMapping("vectorStuff", documentMapping) + + contentFieldMapping := NewTextFieldMapping() + contentFieldMapping.Index = true + contentFieldMapping.Store = true + documentMapping.AddFieldMappingsAt("content", contentFieldMapping) + + vecFieldMapping := mapping.NewVectorFieldMapping() + vecFieldMapping.Index = true + vecFieldMapping.Dims = 5 + vecFieldMapping.Similarity = "dot_product" + documentMapping.AddFieldMappingsAt("vector", vecFieldMapping) + + index, err := New(tmpIndexPath, indexMapping) + if err != nil { + log.Fatal(err) + } + defer func() { + err := index.Close() + if err != nil { + log.Fatal(err) + } + }() + + batch := index.NewBatch() + for i := 0; i < len(dataset); i++ { + batch.Index(strconv.Itoa(i), dataset[i]) + } + + err = index.Batch(batch) + if err != nil { + log.Fatal(err) + } + + termQuery := query.NewTermQuery("world") + + searchRequest := NewSearchRequest(termQuery) + queryVec2 := getQueryVec("hilly region worldwide") + searchRequest.AddKNN("vector", queryVec2, 3, 2.0) + searchRequest.AddKNN("vector", queryVec2, 2, 1.5) + searchRequest.Fields = []string{"content", "vector"} + + // Conjunction + searchRequest.AddKNNOperator(knnOperatorAnd) + conjunction, err := queryWithKNN(searchRequest) + if err != nil { + log.Fatal(fmt.Errorf("unexpected error for AND knn operator")) + } + + conj, ok := conjunction.(*query.ConjunctionQuery) + if !ok { + log.Fatal(fmt.Errorf("expected conjunction query")) + } + + if len(conj.Conjuncts) == 3 { + _, ok := conj.Conjuncts[0].(*query.TermQuery) + if !ok { + log.Fatal(fmt.Errorf("expected first query to be a term query,"+ + " but it's %T", conj.Conjuncts[0])) + } + } else { + log.Fatal(fmt.Errorf("expected 3 conjuncts")) + } + + // Disjunction + searchRequest.AddKNNOperator(knnOperatorOr) + disjunction, err := queryWithKNN(searchRequest) + if err != nil { + log.Fatal(fmt.Errorf("unexpected error for OR knn operator")) + } + + disj, ok := disjunction.(*query.DisjunctionQuery) + if !ok { + log.Fatal(fmt.Errorf("expected disjunction query")) + } + + if len(disj.Disjuncts) == 3 { + _, ok := disj.Disjuncts[0].(*query.TermQuery) + if !ok { + log.Fatal(fmt.Errorf("expected first query to be a term query,"+ + " but it's %T", conj.Conjuncts[0])) + } + } else { + log.Fatal(fmt.Errorf("expected 3 disjuncts")) + } + + // Incorrect operator. + searchRequest.AddKNNOperator("bs_op") + searchRequest.Query, err = queryWithKNN(searchRequest) + if err == nil { + log.Fatal(fmt.Errorf("expected error for incorrect knn operator")) + } +} diff --git a/search_no_knn.go b/search_no_knn.go index 4600f8748..37354f0fd 100644 --- a/search_no_knn.go +++ b/search_no_knn.go @@ -144,8 +144,8 @@ func copySearchRequest(req *SearchRequest) *SearchRequest { return &rv } -func disjunctQueryWithKNN(req *SearchRequest) query.Query { - return req.Query +func queryWithKNN(req *SearchRequest) (query.Query, error) { + return req.Query, nil } func validateKNN(req *SearchRequest) error {