Skip to content

Commit

Permalink
Add KNN Operator
Browse files Browse the repository at this point in the history
  • Loading branch information
metonymic-smokey committed Nov 21, 2023
1 parent 645d0e3 commit ee4c5c2
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 11 deletions.
9 changes: 6 additions & 3 deletions index_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,12 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
ctx = context.WithValue(ctx, search.GeoBufferPoolCallbackKey,
search.GeoBufferPoolCallbackFunc(getBufferPool))

// Using a disjunction query to get union of results from KNN query
// and the original query
searchQuery := disjunctQueryWithKNN(req)
// Using a query to get results from KNN queries
// and the original query based on the KNN operator.
searchQuery, err := queryWithKNN(req)
if err != nil {
return nil, err
}

searcher, err := searchQuery.Searcher(ctx, indexReader, i.m, search.SearcherOptions{
Explain: req.Explain,
Expand Down
36 changes: 30 additions & 6 deletions search_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"github.com/blevesearch/bleve/v2/search/query"
)

type knnOperator string

type SearchRequest struct {
Query query.Query `json:"query"`
Size int `json:"size"`
Expand All @@ -40,7 +42,8 @@ type SearchRequest struct {
SearchAfter []string `json:"search_after"`
SearchBefore []string `json:"search_before"`

KNN []*KNNRequest `json:"knn"`
KNN []*KNNRequest `json:"knn"`
KNNOperator knnOperator `json:"knn_operator"`

sortFunc func(sort.Interface)
}
Expand All @@ -62,6 +65,10 @@ func (r *SearchRequest) AddKNN(field string, vector []float32, k int64, boost fl
})
}

func (r *SearchRequest) AddKNNOperator(operator knnOperator) {
r.KNNOperator = operator
}

// UnmarshalJSON deserializes a JSON representation of
// a SearchRequest
func (r *SearchRequest) UnmarshalJSON(input []byte) error {
Expand All @@ -79,6 +86,7 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error {
SearchAfter []string `json:"search_after"`
SearchBefore []string `json:"search_before"`
KNN []*KNNRequest `json:"knn"`
KNNOperator knnOperator `json:"knn_operator"`
}

err := json.Unmarshal(input, &temp)
Expand Down Expand Up @@ -121,6 +129,10 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error {
}

r.KNN = temp.KNN
r.KNNOperator = temp.KNNOperator
if r.KNNOperator == "" {
r.KNNOperator = knnOperatorOr
}

return nil

Expand All @@ -143,26 +155,38 @@ func copySearchRequest(req *SearchRequest) *SearchRequest {
SearchAfter: req.SearchAfter,
SearchBefore: req.SearchBefore,
KNN: req.KNN,
KNNOperator: req.KNNOperator,
}
return &rv

}

func disjunctQueryWithKNN(req *SearchRequest) query.Query {
var (
knnOperatorAnd = knnOperator("and")
knnOperatorOr = knnOperator("or")
)

func queryWithKNN(req *SearchRequest) (query.Query, error) {
if len(req.KNN) > 0 {
disjuncts := []query.Query{req.Query}
subQueries := []query.Query{req.Query}
for _, knn := range req.KNN {
if knn != nil {
knnQuery := query.NewKNNQuery(knn.Vector)
knnQuery.SetFieldVal(knn.Field)
knnQuery.SetK(knn.K)
knnQuery.SetBoost(knn.Boost.Value())
disjuncts = append(disjuncts, knnQuery)
subQueries = append(subQueries, knnQuery)
}
}
return query.NewDisjunctionQuery(disjuncts)
if req.KNNOperator == knnOperatorAnd {
return query.NewConjunctionQuery(subQueries), nil
} else if req.KNNOperator == knnOperatorOr || req.KNNOperator == "" {
return query.NewDisjunctionQuery(subQueries), nil
} else {
return nil, fmt.Errorf("unknown knn operator: %s", req.KNNOperator)
}
}
return req.Query
return req.Query, nil
}

func validateKNN(req *SearchRequest) error {
Expand Down
133 changes: 133 additions & 0 deletions search_knn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
//go:build vectors
// +build vectors

package bleve

import (
"fmt"
"log"
"math/rand"
"strconv"
"testing"

"github.com/blevesearch/bleve/v2/mapping"
"github.com/blevesearch/bleve/v2/search/query"
)

// Test to see if KNN Operators get added right to the query.
func TestKNNOperator(t *testing.T) {
tmpIndexPath := createTmpIndexPath(t)
defer cleanupTmpIndexPath(t, tmpIndexPath)

dataset := make([]map[string]interface{}, 10)

// Indexing just a few docs to populate index.
for i := 0; i < 10; i++ {
docVec := []float32{}
for i := 0; i < 5; i++ {
docVec = append(docVec, rand.Float32())
}
dataset = append(dataset, map[string]interface{}{
"type": "vectorStuff",
"content": strconv.Itoa(i),
"vector": docVec,
})
}

indexMapping := NewIndexMapping()
indexMapping.TypeField = "type"
indexMapping.DefaultAnalyzer = "en"
documentMapping := NewDocumentMapping()
indexMapping.AddDocumentMapping("vectorStuff", documentMapping)

contentFieldMapping := NewTextFieldMapping()
contentFieldMapping.Index = true
contentFieldMapping.Store = true
documentMapping.AddFieldMappingsAt("content", contentFieldMapping)

vecFieldMapping := mapping.NewVectorFieldMapping()
vecFieldMapping.Index = true
vecFieldMapping.Dims = 5
vecFieldMapping.Similarity = "dot_product"
documentMapping.AddFieldMappingsAt("vector", vecFieldMapping)

index, err := New(tmpIndexPath, indexMapping)
if err != nil {
log.Fatal(err)
}
defer func() {
err := index.Close()
if err != nil {
log.Fatal(err)
}
}()

batch := index.NewBatch()
for i := 0; i < len(dataset); i++ {
batch.Index(strconv.Itoa(i), dataset[i])
}

err = index.Batch(batch)
if err != nil {
log.Fatal(err)
}

termQuery := query.NewTermQuery("world")

searchRequest := NewSearchRequest(termQuery)
queryVec2 := getQueryVec("hilly region worldwide")
searchRequest.AddKNN("vector", queryVec2, 3, 2.0)
searchRequest.AddKNN("vector", queryVec2, 2, 1.5)
searchRequest.Fields = []string{"content", "vector"}

// Conjunction
searchRequest.AddKNNOperator(knnOperatorAnd)
conjunction, err := queryWithKNN(searchRequest)
if err != nil {
log.Fatal(fmt.Errorf("unexpected error for AND knn operator"))
}

conj, ok := conjunction.(*query.ConjunctionQuery)
if !ok {
log.Fatal(fmt.Errorf("expected conjunction query"))
}

if len(conj.Conjuncts) == 3 {
_, ok := conj.Conjuncts[0].(*query.TermQuery)
if !ok {
log.Fatal(fmt.Errorf("expected first query to be a term query,"+
" but it's %T", conj.Conjuncts[0]))
}
} else {
log.Fatal(fmt.Errorf("expected 3 conjuncts"))
}

// Disjunction
searchRequest.AddKNNOperator(knnOperatorOr)
disjunction, err := queryWithKNN(searchRequest)
if err != nil {
log.Fatal(fmt.Errorf("unexpected error for OR knn operator"))
}

disj, ok := disjunction.(*query.DisjunctionQuery)
if !ok {
log.Fatal(fmt.Errorf("expected disjunction query"))
}

if len(disj.Disjuncts) == 3 {
_, ok := disj.Disjuncts[0].(*query.TermQuery)
if !ok {
log.Fatal(fmt.Errorf("expected first query to be a term query,"+
" but it's %T", conj.Conjuncts[0]))
}
} else {
log.Fatal(fmt.Errorf("expected 3 disjuncts"))
}

// Incorrect operator.
searchRequest.AddKNNOperator("bs_op")
searchRequest.Query, err = queryWithKNN(searchRequest)
if err == nil {
log.Fatal(fmt.Errorf("expected error for incorrect knn operator"))
}
}
4 changes: 2 additions & 2 deletions search_no_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ func copySearchRequest(req *SearchRequest) *SearchRequest {
return &rv
}

func disjunctQueryWithKNN(req *SearchRequest) query.Query {
return req.Query
func queryWithKNN(req *SearchRequest) (query.Query, error) {
return req.Query, nil
}

func validateKNN(req *SearchRequest) error {
Expand Down

0 comments on commit ee4c5c2

Please sign in to comment.