From e7e309df17b9166216ac04f678c4505b8773d04d Mon Sep 17 00:00:00 2001
From: Aditi Ahuja <aditi.ahuja@couchbase.com>
Date: Fri, 17 Nov 2023 10:05:26 +0530
Subject: [PATCH] Add KNN Operator

---
 index_impl.go    | 10 ++++++----
 search_knn.go    | 37 +++++++++++++++++++++++++++++++------
 search_no_knn.go |  4 ++--
 3 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/index_impl.go b/index_impl.go
index fe3a62e9e..3865dc260 100644
--- a/index_impl.go
+++ b/index_impl.go
@@ -496,10 +496,12 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 	ctx = context.WithValue(ctx, search.GeoBufferPoolCallbackKey,
 		search.GeoBufferPoolCallbackFunc(getBufferPool))
 
-
-	// Using a disjunction query to get union of results from KNN query
-	// and the original query
-	searchQuery := disjunctQueryWithKNN(req)
+	// Using a query to get results from KNN queries
+	// and the original query based on the KNN operator.
+	searchQuery, err := queryWithKNN(req)
+	if err != nil {
+		return nil, err
+	}
 
 	searcher, err := searchQuery.Searcher(ctx, indexReader, i.m, search.SearcherOptions{
 		Explain:            req.Explain,
diff --git a/search_knn.go b/search_knn.go
index a2f8d343c..f5fee6341 100644
--- a/search_knn.go
+++ b/search_knn.go
@@ -19,12 +19,15 @@ package bleve
 
 import (
 	"encoding/json"
+	"fmt"
 	"sort"
 
 	"github.com/blevesearch/bleve/v2/search"
 	"github.com/blevesearch/bleve/v2/search/query"
 )
 
+type knnOperator string
+
 type SearchRequest struct {
 	Query            query.Query       `json:"query"`
 	Size             int               `json:"size"`
@@ -39,7 +42,8 @@ type SearchRequest struct {
 	SearchAfter      []string          `json:"search_after"`
 	SearchBefore     []string          `json:"search_before"`
 
-	KNN []*KNNRequest `json:"knn"`
+	KNN         []*KNNRequest `json:"knn"`
+	KNNOperator knnOperator   `json:"knn_operator"`
 
 	sortFunc func(sort.Interface)
 }
@@ -61,6 +65,10 @@ func (r *SearchRequest) AddKNN(field string, vector []float32, k int64, boost fl
 	})
 }
 
+func (r *SearchRequest) AddKNNOperator(operator knnOperator) {
+	r.KNNOperator = operator
+}
+
 // UnmarshalJSON deserializes a JSON representation of
 // a SearchRequest
 func (r *SearchRequest) UnmarshalJSON(input []byte) error {
@@ -78,6 +86,7 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error {
 		SearchAfter      []string          `json:"search_after"`
 		SearchBefore     []string          `json:"search_before"`
 		KNN              []*KNNRequest     `json:"knn"`
+		KNNOperator      knnOperator       `json:"knn_operator"`
 	}
 
 	err := json.Unmarshal(input, &temp)
@@ -120,6 +129,10 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error {
 	}
 
 	r.KNN = temp.KNN
+	r.KNNOperator = temp.KNNOperator
+	if r.KNNOperator == "" {
+		r.KNNOperator = knnOperatorOr
+	}
 
 	return nil
 
@@ -142,24 +155,36 @@ func copySearchRequest(req *SearchRequest) *SearchRequest {
 		SearchAfter:      req.SearchAfter,
 		SearchBefore:     req.SearchBefore,
 		KNN:              req.KNN,
+		KNNOperator:      req.KNNOperator,
 	}
 	return &rv
 
 }
 
-func disjunctQueryWithKNN(req *SearchRequest) query.Query {
+var (
+	knnOperatorAnd = knnOperator("and")
+	knnOperatorOr  = knnOperator("or")
+)
+
+func queryWithKNN(req *SearchRequest) (query.Query, error) {
 	if len(req.KNN) > 0 {
-		disjuncts := []query.Query{req.Query}
+		subQueries := []query.Query{req.Query}
 		for _, knn := range req.KNN {
 			if knn != nil {
 				knnQuery := query.NewKNNQuery(knn.Vector)
 				knnQuery.SetFieldVal(knn.Field)
 				knnQuery.SetK(knn.K)
 				knnQuery.SetBoost(knn.Boost.Value())
-				disjuncts = append(disjuncts, knnQuery)
+				subQueries = append(subQueries, knnQuery)
 			}
 		}
-		return query.NewDisjunctionQuery(disjuncts)
+		if req.KNNOperator == knnOperatorAnd {
+			return query.NewConjunctionQuery(subQueries), nil
+		} else if req.KNNOperator == knnOperatorOr || req.KNNOperator == "" {
+			return query.NewDisjunctionQuery(subQueries), nil
+		} else {
+			return nil, fmt.Errorf("unknown knn operator: %s", req.KNNOperator)
+		}
 	}
-	return req.Query
+	return req.Query, nil
 }
diff --git a/search_no_knn.go b/search_no_knn.go
index fb3814911..ad7b35220 100644
--- a/search_no_knn.go
+++ b/search_no_knn.go
@@ -144,6 +144,6 @@ func copySearchRequest(req *SearchRequest) *SearchRequest {
 	return &rv
 }
 
-func disjunctQueryWithKNN(req *SearchRequest) query.Query {
-	return req.Query
+func queryWithKNN(req *SearchRequest) (query.Query, error) {
+	return req.Query, nil
 }